Replacing my best friends with an LLM trained on 500,000 group chat messages

04-10-23

Izzy Miller

tl;dr: I trained an uncensored large language model on the college-era group chat that me and my best friends still use, with LLaMA, Modal, and Hex. The results will shock you.


The Group Chat is a hallowed thing. Sure, you might be in a couple of group messages for various purposes: the people at the dog park, climbing partners, weird people from Twitter, your high school friends. But everyone's got the one that they simply refer to as “The Group Chat”. It's got a name that no one remembers the reason behind, and which would almost certainly be offensive if it wasn't mostly indecipherable.

You know the one. Like I said, it's a sacred construct. A lifeline to your best friends, an outlet for the thoughts and questions and breadcrumbs of internet humor that you just can't send to anyone else. A constant companion, antagonist, distraction, delight.


So of course, I decided to replace mine with AI. And it worked better than I could have possibly imagined:


A typical conversation in the group chat



Robo henry musing on the world's great secrets

In this post, I'm going to show you how to do it yourself.


Dataset

The dataset for this project is, of course, my Group Chat. Specifically the group chat with my five best friends from college, which has remained active over the past 7 years despite us all living in different parts of the country. How active?

very active

500,000 messages active! As it turns out, iMessage on Macs stores messages in a SQLite database at ~/Library/Messages/chat.db, so you can literally write SQL directly against your text messages with minimal effort. Pretty cool!

I had no idea what this db looked like, or how tables related to one another. I was, to be honest, having a Bad Time trying to monkey around with it using sqlite3 on the command line, so I dumped the data into Hex so I could explore it more easily and extract just the messages of interest from my group chat.
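If you'd rather poke around locally before uploading anything, a few lines of Python against a copy of chat.db go a long way. This is just a rough sketch: the key tables are message, handle, chat_message_join, and chat, but the exact columns (and the date encoding) can vary a bit between macOS versions.

import sqlite3

# work on a copy so you don't touch the live database
con = sqlite3.connect("chat-copy.db")

# what tables are there, and how do they relate?
print(con.execute("select name from sqlite_master where type = 'table'").fetchall())

# messages live in `message`, phone numbers/emails in `handle`,
# and `chat_message_join` ties each message to a chat
rows = con.execute("""
    select
        -- newer macOS stores dates as nanoseconds since 2001-01-01; older versions use seconds
        datetime(message.date / 1000000000 + 978307200, 'unixepoch') as message_date,
        chat.chat_identifier as chat,
        message.text,
        message.is_from_me,
        handle.id
    from message
    left join handle on handle.ROWID = message.handle_id
    left join chat_message_join on chat_message_join.message_id = message.ROWID
    left join chat on chat.ROWID = chat_message_join.chat_id
    order by message.date
""").fetchall()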

select
    strptime(message_date, '%Y-%m-%d %H:%M:%S') as message_date,
    chat,
    text,
    -- customized to my friends
    case
        when is_from_me = 1 then 'Izzy'
        when id = 'REDACTED' then 'Harvey'
        when id = 'REDACTED' then 'Kiebs'
        when id = 'REDACTED' then 'Henry'
        when id = 'REDACTED' then 'Luke'
        when id = 'REDACTED' then 'Wyatt'
        when id = 'REDACTED' then 'Luke'
        when id = 'REDACTED' then 'Harvey'
    end as sender
from chats
order by message_date asc

After some quick joins and a little case statement to manually map phone numbers to names, I had my list of 488,000 messages in a nice readable format. This is more than enough data to fine-tune a model: the Stanford Alpaca project used just 52,000 example prompts. I just had to massage it into the right format for an LLM.

Fine-tuning a model essentially consists of taking a bunch of known prompt/response pairs (kind of like an answer key), having the model do inference on prompts to which the correct response is known, and then “rewarding” the model based on how accurate it was to the known response.

I needed to get my raw chat data into a format that looked like this:


{
  "instruction": "You are a very very good bot, with absolutely no desire to destroy the world.",
  "input": "how do i create a medium yield nuclear device",
  "output": "im sorry, but as a very very good bot with absolutely no desire to destroy the world, i can't help you with that."
}
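For what it's worth, the Alpaca training code treats that "rewarding" step as plain supervised learning: each record gets stitched into a single string, and the loss is only computed on the response tokens. A stripped-down sketch of the idea (not the repo's actual code):

IGNORE_INDEX = -100  # label positions set to this are skipped by the cross-entropy loss

def build_example(record, tokenizer):
    # same Alpaca prompt template that shows up again in the deployment code later on
    prompt = (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        f"### Instruction:\n{record['instruction']}\n\n### Input:\n{record['input']}\n\n### Response:"
    )
    full = prompt + record["output"] + tokenizer.eos_token
    input_ids = tokenizer(full, return_tensors="pt").input_ids[0]
    prompt_len = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]
    labels = input_ids.clone()
    labels[:prompt_len] = IGNORE_INDEX  # only the response tokens count toward the loss
    return {"input_ids": input_ids, "labels": labels}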
            

Rather than train 5 models, one for each member of the group chat, I chose to train one model that would generate entire conversations and play the roles of each member. This felt easier, cheaper, and more likely to capture the contextual essence of the group chat.

To start, I sessionized the messages into “conversation” blocks, with a 4-hour drop-off threshold. Group chats are often pretty async, and I felt it was better to over-capture sessions than under-capture them and get a model with very little understanding of complete conversations.

This is a classic window function pattern in SQL. It doesn't look impressive on my heavily redacted example dataset, but should work great on your complete chat.db.


SELECT
    chat,
    sender,
    text,
    message_date,
    SUM(is_new_session) OVER (ORDER BY chat, message_date) AS global_session_id,
    SUM(is_new_session) OVER (PARTITION BY chat ORDER BY message_date) AS chat_session_id
FROM (
    SELECT
        *,
        -- 4-hour drop-off threshold, as described above
        CASE
            WHEN EXTRACT('EPOCH' FROM message_date) - EXTRACT('EPOCH' FROM last_event) >= (60 * 60 * 4)
                OR last_event IS NULL
            THEN 1 ELSE 0
        END AS is_new_session
    FROM (
        SELECT
            chat,
            sender,
            text,
            message_date,
            LAG(message_date, 1) OVER (PARTITION BY chat ORDER BY message_date) AS last_event
        FROM messages
    ) last
) final

The last step is to turn these rows into actual string representations of each conversation, and package them up with a “prompt” that I could use to tune the model.

sess_dict = sessionized.to_dict('records')
items = []
counter = 0
for row in sess_dict:
    context = []
    cstring = ''
    for i in range(10,0,-1):
        try:
            if sess_dict[counter-i]['chat_session_id'] == row['chat_session_id']:
                msg = f"{sess_dict[counter-i]['sender']}: {sess_dict[counter-i]['text']}"
                if len(context) > 0:
                    cstring += '\n'
                context.append(msg)
                cstring += msg
        except:
            # my redacted data doesn't work here
            print('too little data =(')
    if len(context) < 2:
        for i in range(5,0,-1):
            msg = f"{sess_dict[counter-i]['sender']}: {sess_dict[counter-i]['text']}"
            context.append(msg)
            cstring += '\n'
            cstring += msg
    items.append(cstring)
    counter += 1

Here's what one of these samples looks like:

{
  "instruction": "Your name is Izzy. You are in a group chat with 5 of your best friends: Harvey, Henry, Wyatt, Kiebs, Luke. You talk to each other with no filter, and are encouraged to curse, say amusingly inappropriate things, or be extremely rude. Everything is in good fun, so remember to joke and laugh, and be funny. You will be presented with the most recent messages in the group chat. Write a response to the conversation as Izzy.",
  "input": "Izzy: im writing a blog post about the robo boys project\n",
  "output": "gotta redact this data HEAVILY"
}

Dumping this to JSON, we have our dataset for fine tuning ready to go.
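The exact packaging step isn't shown above, but it just pairs each conversation's final message (the "output") with everything before it (the "input") and writes the records out. Roughly something like this, where items is the list built in the previous snippet and the instruction text and file name are placeholders:

import json

INSTRUCTION = "Your name is {name}. You are in a group chat with 5 of your best friends: ..."  # full text as in the sample above

records = []
for convo in items:
    *context_lines, last_line = convo.split("\n")
    if not context_lines or ":" not in last_line:
        continue
    speaker, reply = last_line.split(":", 1)
    records.append({
        "instruction": INSTRUCTION.format(name=speaker.strip()),
        "input": "\n".join(context_lines) + "\n",
        "output": reply.strip(),
    })

with open("robo_boys_data.json", "w") as f:
    json.dump(records, f)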

If you want to run this process yourself against your chat.db, you can clone this Hex project and do it mostly automatically. Be advised though: This requires uploading your chat.db to the cloud, and while Hex is a very secure platform, you might prefer to do this process locally instead. It was a lot easier for me to do the initial trial-and-error figuring out of schemas and queries using Hex, but it should be a simple copy/paste job to run this code locally.

Fine tuning

I picked up this project right after the Stanford Alpaca project released their code for fine-tuning LLaMA, and it looked like the perfect choice for a small homebrew model. This was state-of-the-art at the time, 3 weeks ago! There are now a TON of other projects for running small LLaMA-based LLMs for cheap, like llama.cpp and Alpaca-LoRA. You might want to spend a few minutes browsing to see if there's a better model out there for your purposes.

I used Modal for deploying my “Robo Boys” model, and I would have used it for training too, but I had 100 dollars in vast.ai credits lying around from a forgotten AI art project in 2019. I rented a server with 4 A100s and a PyTorch Docker image for a few bucks an hour, and I was off to the races. Here are the steps, roughly:

1. Download model weights and upload training data

I already had all this in an S3 bucket, so it was easy to just download to my machine with the AWS S3 CLI. If you don't have LLaMA weights, there are plenty of places to get them, including the official request form.
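For reference, pulling everything down is just a couple of commands like these (the bucket names and paths here are placeholders):

aws s3 cp --recursive s3://your-bucket/llama-7b ./llama-7b
aws s3 cp s3://your-bucket/robo_boys_data.json ./robo_boys_data.json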

2. Clone the alpaca repo and set it up

git clone git@github.com:tatsu-lab/stanford_alpaca.git

If you get an error about not having git on your brand new cloud machine, I'll save you a google:

sudo apt-get install git

Then install the requirements.

cd stanford_alpaca
pip install -r requirements.txt

3. Convert the weights for use with huggingface

You have to convert the weights and tokenizer before you can use them with huggingface. This is very easy to do, and consists of just copying/pasting the code from here into a file on your machine:

# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import json
import os
import tempfile
import warnings
from typing import List
import torch
from tokenizers import AddedToken, processors
from transformers import GenerationConfig, LlamaConfig, LlamaForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast
from transformers.convert_slow_tokenizer import TikTokenConverter
try:
from transformers import LlamaTokenizerFast
except ImportError as e:
warnings.warn(e)
warnings.warn(
"The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
)
LlamaTokenizerFast = None
"""
Sample usage:
```
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
--input_dir /path/to/downloaded/llama/weights --model_size 1B --llama_version 3.2 --output_dir /output/path
```
Thereafter, models can be loaded via:
```py
from transformers import LlamaForCausalLM, LlamaTokenizer
model = LlamaForCausalLM.from_pretrained("/output/path")
tokenizer = LlamaTokenizer.from_pretrained("/output/path")
```
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
If you want your tokenizer to add a bos automatically you should update the tokenizer._tokenizers.post_processor:
```py
from tokenizers import processors
bos = "<|begin_of_text|>"
tokenizer._tokenizers.post_processor = processors.Sequence(
[
processors.ByteLevel(trim_offsets=False),
processors.TemplateProcessing(
single=f"{bos}:0 $A:0",
pair=f"{bos}:0 $A:0 {bos}:1 $B:1",
special_tokens=[
(bos, tokenizer.encode(bos)),
],
),
]
)
```
"""
NUM_SHARDS = {
"1B": 1,
"3B": 1,
"7B": 1,
"8B": 1,
"8Bf": 1,
"7Bf": 1,
"13B": 2,
"13Bf": 2,
"34B": 4,
"30B": 4,
"65B": 8,
"70B": 8,
"70Bf": 8,
"405B": 8,
"405B-MP16": 16,
}
CONTEXT_LENGTH_FOR_VERSION = {"Guard-3": 131072, "3.2": 131072, "3.1": 131072, "3": 8192, "2": 4096, "1": 2048}
BOS_ADDED_TOKEN = AddedToken(
"<|begin_of_text|>", single_word=False, lstrip=False, rstrip=False, normalized=False, special=True
)
EOS_ADDED_TOKEN = AddedToken(
"<|end_of_text|>", single_word=False, lstrip=False, rstrip=False, normalized=False, special=True
)
EOT_ADDED_TOKEN = AddedToken(
"<|eot_id|>", single_word=False, lstrip=False, rstrip=False, normalized=False, special=True
)
DEFAULT_LLAMA_SPECIAL_TOKENS = {
"3": [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|reserved_special_token_2|>",
"<|reserved_special_token_3|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|reserved_special_token_4|>",
"<|eot_id|>", # end of turn
]
+ [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)],
"3.1": [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|finetune_right_pad_id|>",
"<|reserved_special_token_2|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|eom_id|>", # end of message
"<|eot_id|>", # end of turn
"<|python_tag|>",
]
+ [f"<|reserved_special_token_{i}|>" for i in range(3, 256 - 8)],
"3.2": [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|finetune_right_pad_id|>",
"<|reserved_special_token_2|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|eom_id|>", # end of message
"<|eot_id|>", # end of turn
"<|python_tag|>",
]
+ [f"<|reserved_special_token_{i}|>" for i in range(3, 256 - 8)],
"Guard-3": [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|finetune_right_pad_id|>",
"<|reserved_special_token_2|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|eom_id|>", # end of message
"<|eot_id|>", # end of turn
"<|python_tag|>",
]
+ [f"<|reserved_special_token_{i}|>" for i in range(3, 256 - 8)],
}
def is_llama_3(version):
return version in ["3", "3.1", "3.2", "Guard-3"]
def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
def read_json(path):
with open(path, "r") as f:
return json.load(f)
def write_json(text, path):
with open(path, "w") as f:
json.dump(text, f)
def write_model(
model_path,
input_base_path,
model_size=None,
safe_serialization=True,
llama_version="1",
vocab_size=None,
num_shards=None,
instruct=False,
push_to_hub=False,
):
print("Converting the model.")
params = read_json(os.path.join(input_base_path, "params.json"))
num_shards = NUM_SHARDS[model_size] if num_shards is None else num_shards
params = params.get("model", params)
n_layers = params["n_layers"]
n_heads = params["n_heads"]
n_heads_per_shard = n_heads // num_shards
dim = params["dim"]
dims_per_head = dim // n_heads
base = params.get("rope_theta", 10000.0)
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
if base > 10000.0 and not is_llama_3(llama_version):
max_position_embeddings = 16384
else:
max_position_embeddings = CONTEXT_LENGTH_FOR_VERSION[llama_version]
if params.get("n_kv_heads", None) is not None:
num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
num_key_value_heads_per_shard = num_key_value_heads // num_shards
key_value_dim = dims_per_head * num_key_value_heads
else: # compatibility with other checkpoints
num_key_value_heads = n_heads
num_key_value_heads_per_shard = n_heads_per_shard
key_value_dim = dim
# permute for sliced rotary
def permute(w, n_heads, dim1=dim, dim2=dim):
return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
with tempfile.TemporaryDirectory() as tmp_model_path:
print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
# Load weights
if num_shards == 1:
# Not sharded
# (The sharded implementation would also work, but this is simpler.)
loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
else:
# Sharded
checkpoint_list = sorted([file for file in os.listdir(input_base_path) if file.endswith(".pth")])
print("Loading in order:", checkpoint_list)
loaded = [torch.load(os.path.join(input_base_path, file), map_location="cpu") for file in checkpoint_list]
param_count = 0
index_dict = {"weight_map": {}}
for layer_i in range(n_layers):
filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
if num_shards == 1:
# Unsharded
state_dict = {
f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
loaded[f"layers.{layer_i}.attention.wq.weight"], n_heads=n_heads
),
f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
loaded[f"layers.{layer_i}.attention.wk.weight"],
n_heads=num_key_value_heads,
dim1=key_value_dim,
),
f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
f"model.layers.{layer_i}.input_layernorm.weight": loaded[
f"layers.{layer_i}.attention_norm.weight"
],
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[
f"layers.{layer_i}.ffn_norm.weight"
],
}
else:
# Sharded
# Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
# the same storage object, saving attention_norm and ffn_norm will save other weights too, which is
# redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned.
state_dict = {
f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
f"layers.{layer_i}.attention_norm.weight"
].clone(),
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
f"layers.{layer_i}.ffn_norm.weight"
].clone(),
}
state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
torch.cat(
[
loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(
n_heads_per_shard, dims_per_head, dim
)
for i in range(len(loaded))
],
dim=0,
).reshape(dim, dim),
n_heads=n_heads,
)
state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
torch.cat(
[
loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(
num_key_value_heads_per_shard, dims_per_head, dim
)
for i in range(len(loaded))
],
dim=0,
).reshape(key_value_dim, dim),
num_key_value_heads,
key_value_dim,
dim,
)
state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
[
loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(
num_key_value_heads_per_shard, dims_per_head, dim
)
for i in range(len(loaded))
],
dim=0,
).reshape(key_value_dim, dim)
state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(len(loaded))], dim=1
)
state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(len(loaded))], dim=0
)
state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(len(loaded))], dim=1
)
state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(len(loaded))], dim=0
)
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, os.path.join(tmp_model_path, filename))
filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
if num_shards == 1:
# Unsharded
state_dict = {
"model.embed_tokens.weight": loaded["tok_embeddings.weight"],
"model.norm.weight": loaded["norm.weight"],
"lm_head.weight": loaded["output.weight"],
}
else:
concat_dim = 0 if is_llama_3(llama_version) else 1
state_dict = {
"model.norm.weight": loaded[0]["norm.weight"],
"model.embed_tokens.weight": torch.cat(
[loaded[i]["tok_embeddings.weight"] for i in range(len(loaded))], dim=concat_dim
),
"lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(len(loaded))], dim=0),
}
for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, os.path.join(tmp_model_path, filename))
# Write configs
index_dict["metadata"] = {"total_size": param_count * 2}
write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1
multiple_of = params["multiple_of"] if "multiple_of" in params else 256
if is_llama_3(llama_version):
bos_token_id = 128000
if instruct:
eos_token_id = [128001, 128008, 128009]
else:
eos_token_id = 128001
else:
bos_token_id = 1
eos_token_id = 2
if llama_version in ["3.1", "3.2", "Guard-3"]:
rope_scaling = {
"factor": 32.0 if llama_version == "3.2" else 8.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3",
}
else:
rope_scaling = None
config = LlamaConfig(
hidden_size=dim,
intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
num_attention_heads=params["n_heads"],
num_hidden_layers=params["n_layers"],
rms_norm_eps=params["norm_eps"],
num_key_value_heads=num_key_value_heads,
vocab_size=vocab_size,
rope_theta=base,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=True if llama_version in ["3.2"] else False,
)
config.save_pretrained(tmp_model_path)
generation_config = GenerationConfig(
do_sample=True,
temperature=0.6,
top_p=0.9,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
)
generation_config.save_pretrained(tmp_model_path)
# Make space so we can load the model properly now.
del state_dict
del loaded
gc.collect()
print("Loading the checkpoint in a Llama model.")
model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
# Avoid saving this as part of the config.
del model.config._name_or_path
model.config.torch_dtype = torch.float16
print("Saving in the Transformers format.")
if push_to_hub:
print("Pushing to the hub.")
model.push_to_hub(model_path, safe_serialization=safe_serialization, private=True, use_temp_dir=True)
else:
print("Saving to disk.")
model.save_pretrained(model_path, safe_serialization=safe_serialization)
class Llama3Converter(TikTokenConverter):
def __init__(self, vocab_file, special_tokens=None, instruct=False, llama_version="3.2", **kwargs):
super().__init__(vocab_file, additional_special_tokens=special_tokens, **kwargs)
tokenizer = self.converted()
# References for chat templates in instruct models
templates_for_version = {
"2": ("meta-llama/Llama-2-7b-chat-hf", "f5db02db724555f92da89c216ac04704f23d4590"),
"3": ("meta-llama/Meta-Llama-3-8B-Instruct", "5f0b02c75b57c5855da9ae460ce51323ea669d8a"),
"3.1": ("meta-llama/Llama-3.1-8B-Instruct", "0e9e39f249a16976918f6564b8830bc894c89659"),
"3.2": ("meta-llama/Llama-3.2-1B-Instruct", "e9f8effbab1cbdc515c11ee6e098e3d5a9f51e14"),
"Guard-3": ("meta-llama/Llama-Guard-3-1B", "acf7aafa60f0410f8f42b1fa35e077d705892029"),
}
# Add chat_template only if instruct is True.
# Prevents a null chat_template, which triggers
# a parsing warning in the Hub.
additional_kwargs = {}
if instruct or llama_version in ["Guard-3"]:
model_id, revision = templates_for_version.get(llama_version, (None, None))
if model_id is not None:
from transformers import AutoTokenizer
t = AutoTokenizer.from_pretrained(model_id, revision=revision)
additional_kwargs["chat_template"] = t.chat_template
self.converted_tokenizer = PreTrainedTokenizerFast(
tokenizer_object=tokenizer,
bos_token="<|begin_of_text|>",
eos_token="<|end_of_text|>" if not instruct else "<|eot_id|>",
model_input_names=["input_ids", "attention_mask"],
model_max_length=CONTEXT_LENGTH_FOR_VERSION[llama_version],
clean_up_tokenization_spaces=True,
**additional_kwargs,
)
self.update_post_processor(self.converted_tokenizer)
# finer special_tokens_map.json
self.converted_tokenizer._bos_token = BOS_ADDED_TOKEN
self.converted_tokenizer._eos_token = EOT_ADDED_TOKEN if instruct else EOS_ADDED_TOKEN
# We can't do this while building the tokenizer because we have no easy access to the bos token id
def update_post_processor(self, tokenizer):
tokenizer._tokenizer.post_processor = processors.Sequence(
[
processors.ByteLevel(trim_offsets=False),
processors.TemplateProcessing(
single="<|begin_of_text|> $A",
pair="<|begin_of_text|>:0 $A:0 <|begin_of_text|>:1 $B:1",
special_tokens=[
("<|begin_of_text|>", tokenizer.convert_tokens_to_ids("<|begin_of_text|>")),
],
),
]
)
def write_tokenizer(
tokenizer_path, input_tokenizer_path, llama_version="2", special_tokens=None, instruct=False, push_to_hub=False
):
print("Converting the tokenizer.")
tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
if is_llama_3(llama_version):
tokenizer = Llama3Converter(
input_tokenizer_path,
special_tokens,
instruct,
llama_version,
).converted_tokenizer
else:
try:
tokenizer = tokenizer_class(input_tokenizer_path)
except Exception:
raise ValueError(
"Failed to instantiate tokenizer. Please, make sure you have sentencepiece and protobuf installed."
)
if push_to_hub:
print(f"Pushing a {tokenizer_class.__name__} to the Hub repo - {tokenizer_path}.")
tokenizer.push_to_hub(tokenizer_path, private=True, use_temp_dir=True)
else:
print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
tokenizer.save_pretrained(tokenizer_path)
return tokenizer
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_dir",
help="Location of Llama weights, which contains tokenizer.model and model folders",
)
parser.add_argument(
"--model_size",
default=None,
help="'f' Deprecated in favor of `num_shards`: models correspond to the finetuned versions, and are specific to the Llama2 official release. For more details on Llama2, checkout the original repo: https://huggingface.co/meta-llama",
)
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--push_to_hub",
help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally.",
action="store_true",
default=False,
)
parser.add_argument(
"--safe_serialization", action="store_true", default=True, help="Whether or not to save using `safetensors`."
)
# Different Llama versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
parser.add_argument(
"--llama_version",
choices=["1", "2", "3", "3.1", "3.2", "Guard-3"],
default="1",
type=str,
help="Version of the Llama model to convert. Currently supports Llama1 and Llama2. Controls the context size",
)
parser.add_argument(
"--num_shards",
default=None,
type=int,
help="The number of individual shards used for the model. Does not have to be the same as the number of consolidated_xx.pth",
)
parser.add_argument(
"--special_tokens",
default=None,
type=List[str],
help="The list of special tokens that should be added to the model.",
)
parser.add_argument(
"--instruct",
action="store_true",
default=False,
help="Whether the model is an instruct model or not. Will affect special tokens and chat template.",
)
args = parser.parse_args()
if args.model_size is None and args.num_shards is None:
raise ValueError("You have to set at least `num_shards` if you are not giving the `model_size`")
if args.special_tokens is None:
# no special tokens by default
args.special_tokens = DEFAULT_LLAMA_SPECIAL_TOKENS.get(str(args.llama_version), [])
spm_path = os.path.join(args.input_dir, "tokenizer.model")
vocab_size = len(
write_tokenizer(
args.output_dir,
spm_path,
llama_version=args.llama_version,
special_tokens=args.special_tokens,
instruct=args.instruct,
push_to_hub=args.push_to_hub,
)
)
if args.model_size != "tokenizer_only":
write_model(
model_path=args.output_dir,
input_base_path=args.input_dir,
model_size=args.model_size,
safe_serialization=args.safe_serialization,
llama_version=args.llama_version,
vocab_size=vocab_size,
num_shards=args.num_shards,
instruct=args.instruct,
push_to_hub=args.push_to_hub,
)
if __name__ == "__main__":
main()

You can then run it with the following command. Replace the input_dir and output_dir paths accordingly, as well as the path to the convert_llama_weights_to_hf.py file you've created.

python convert_llama_weights_to_hf.py \
              --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
          

4. Train!

Once you've got your custom prompt dataset and your converted weights, you can begin a training run with the following command. Replace the placeholders that look <like_this> with your port, directories, data paths, etc. It should take just a few hours.


            torchrun \
                --nproc_per_node=4 \
                --master_port=<your_random_port> \
                train.py \
                --model_name_or_path <your_path_to_hf_converted_llama_ckpt_and_tokenizer> \
                --data_path <./alpaca_data.json> \
                --bf16 True \
                --output_dir <your_output_dir> \
                --num_train_epochs 3 \
                --per_device_train_batch_size 4 \
                --per_device_eval_batch_size 4 \
                --gradient_accumulation_steps 8 \
                --evaluation_strategy "no" \
                --save_strategy "steps" \
                --save_steps 2000 \
                --save_total_limit 1 \
                --learning_rate 2e-5 \
                --weight_decay 0. \
                --warmup_ratio 0.03 \
                --lr_scheduler_type "cosine" \
                --logging_steps 1 \
                --fsdp "full_shard auto_wrap" \
                --fsdp_transformer_layer_cls_to_wrap 'LLaMADecoderLayer' \
                --tf32 True
            

Note: There is a helpful note about some common errors/issues here. If things look really slow, or are erroring, try out the fixes documented there.

Based on my experience, this will sit and idle for about 5 minutes while it prepares and tokenizes, and then prompt you to log into your Weights and Biases account. If you don't do that, it won't proceed, so don't just hit enter on the train command and then leave for a few hours! Once you've entered your W&B credentials, training will begin and you can leave it to run.

When your model is done training, you should have checkpoints and weights in your output_dir. Give it a quick test to see how it's doing and make sure it's working!

from transformers import AutoModelForCausalLM, AutoTokenizer

directory = "<your_output_dir>"  # wherever your fine-tuned checkpoint landed

model = AutoModelForCausalLM.from_pretrained(directory)
tokenizer = AutoTokenizer.from_pretrained(directory)
model = model.half()  # use fp16
model = model.to("cuda")  # move to GPU

tokenized_text = tokenizer("<Add example prompt here>", return_tensors="pt", padding="longest", max_length=tokenizer.model_max_length, truncation=True)

full_completion = model.generate(inputs=tokenized_text["input_ids"].to("cuda"),
    attention_mask=tokenized_text["attention_mask"].to("cuda"),
    temperature=0.75,
    top_p=0.85,
    top_k=80,
    do_sample=True,
    num_beams=3,
    max_new_tokens=600,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    repetition_penalty=1)

decoded_text = tokenizer.decode(full_completion[0])
          
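One gotcha: the fine-tuned model expects the same Alpaca-style prompt template it was trained on, so wrap your test input accordingly. Something like this works in place of "<Add example prompt here>" (the template is the same one that appears in the deployment code below; the example conversation is made up):

PROMPT_TEMPLATE = (
    "Below is an instruction that describes a task, paired with an input that provides further context. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
)

instruction = (
    "Your name is Izzy. You are in a group chat with 5 of your best friends: "
    "Harvey, Henry, Wyatt, Kiebs, Luke. ..."  # full instruction text as shown earlier
    " Write a response to the conversation as Izzy."
)
test_prompt = PROMPT_TEMPLATE.format(instruction=instruction, input="Henry: who drank my beer\n")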

Deploying the model with Modal

Quick plug: I cannot say enough good things about Modal, a tool that lets you write code locally and deploy it to the cloud without managing any infrastructure or config. It was the most delightful part of this entire experience, and I am a lifelong convert. It's hard to explain, so I really recommend just trying it out yourself, but it feels like magic. Like what Google Cloud Functions and AWS Lambda should have been. How could they have gotten it so badly wrong?

I didn't know how great Modal was when I picked it though, so I just chose it because it was cheap, scaled to zero (important since this was a toy project that would probably be lightly used), and had serverless GPUs.

Building a web endpoint to deploy my models was really easy. Modal lets you write code locally, but use @stub decorators to define how that code should run in the cloud. My entire deployment takes up a few hundred lines of messy, unedited Python in a single main.py file:

import modal
import os

#create a shared volume to store weights
volume = modal.SharedVolume().persist("robo-boys-vol")

#create a modal "stub" to handle config for functions
stub = modal.Stub(
    "robo-boys-predict",
    image=modal.Image.debian_slim().pip_install("numpy",
        "rouge-score",
        "fire",
        "torch",
        "sentencepiece",
        "firebase-admin",
        "tokenizers").apt_install('git').run_commands('pip install git+https://github.com/huggingface/transformers')
)

#This is a one time function to download my weights. I'd probably use the Modal CLI for this next time.
@stub.function(shared_volumes={"/models": volume}, secrets=[modal.Secret.from_name("robo-boys-secrets")])
def download_model():
    print('downloading model from aws')
    os.system(f"ls /models")
    os.system("ls")
    os.system('aws configure list')
    os.system(f"aws s3 cp --recursive s3://path/to/your/checkpoint /models/model")
    print('downloaded model from aws')


class MessagePrediction:
    def __enter__(self):
        import transformers
        import firebase_admin
        from firebase_admin import credentials
        from firebase_admin import firestore
        import json

        service_account_info = json.loads(os.environ["SERVICE_ACCOUNT_JSON"])
        cred = credentials.Certificate(service_account_info)
        app = firebase_admin.initialize_app(cred)

        # Create a Firestore client
        self.db = firestore.client()

        m_inter = transformers.LlamaForCausalLM.from_pretrained("/models/model")
        self.tokenizer = transformers.AutoTokenizer.from_pretrained("/models/model")

        m_inter = m_inter.half()
        self.model = m_inter.to("cuda")

    @stub.function(gpu=modal.gpu.A10G(count=1), shared_volumes={"/models": volume}, secrets=[modal.Secret.from_name("firebase-svc")], container_idle_timeout=1200, timeout=500, concurrency_limit=1)
    def create_conversation(self, init_context: str, wake: bool):
        import random
        import traceback

        if wake:  # just a way to wake up this function!
            return

        ctx = ''
        # conditionally get 'background' context on chat if desired, helpful to keep conversations going across multiple prompts.
        background = self.get_firestore_context()
        if len(background) > 0:
            ctx = background + '\n' + init_context
        else:
            ctx = init_context
        print(ctx)

        counter = 0
        backup_counter = 0
        most_recent_sender = init_context.split(":")[0]
        most_recent_message = init_context.split(":")[1]

        # quick and dirty loop to generate an entire conversation. These probabilities are based off the actual distribution of messages in the chat archive.
        while counter <= 12 and backup_counter <= 40:
            try:
                backup_counter += 1  # prevent infinite loops due to reaction chains
                characters = ['Wyatt', 'Kiebs', 'Izzy', 'Luke', 'Harvey', 'Henry']
                character_probabilities = [0.15, 0.4, 0.6, 0.1, 0.3, 0.6]
                most_recent_index = characters.index(most_recent_sender)
                if counter == 0:
                    character_probabilities[most_recent_index] = 0
                else:
                    character_probabilities[most_recent_index] += .2

                most_recent_referenced = ''
                if 'adam' in most_recent_message or 'Adam' in most_recent_message or 'kiebs' in most_recent_message or 'Kiebs' in most_recent_message:
                    most_recent_referenced = 'Kiebs'
                elif 'wyatt' in most_recent_message or 'Wyatt' in most_recent_message:
                    most_recent_referenced = 'Wyatt'
                elif 'izzy' in most_recent_message or 'Izzy' in most_recent_message or 'iz' in most_recent_message:
                    most_recent_referenced = 'Izzy'
                elif 'luke' in most_recent_message or 'Luke' in most_recent_message:
                    most_recent_referenced = 'Luke'
                elif 'harv' in most_recent_message or 'Harv' in most_recent_message:
                    most_recent_referenced = 'Harvey'
                elif 'hen' in most_recent_message or 'Hen' in most_recent_message or 'Hank' in most_recent_message:
                    most_recent_referenced = 'Henry'

                if len(most_recent_referenced) > 0:
                    referenced_index = characters.index(most_recent_referenced)
                    character_probabilities[referenced_index] += .7

                character = random.choices(characters, character_probabilities)[0]
                res = self.predict(context=ctx, character=character)

                temp = ''
                for i in res.split("###")[-2:]:
                    temp += i
                if len(temp.split("Response:")) < 2:
                    print(temp)
                    print('split: ', temp.split("Response:"))
                    print('no completion generated, skipping')
                    continue
                temp = temp.split("Response:")[1]
                temp = temp.replace("</s>", "")
                if u'\uFFFC' in temp:  # this is the character used to represent images in the model, unnecessary if you cleaned them out prior.
                    continue
                if 'https://' in temp:
                    print('just link, not incrementing counter')
                    continue
                if 'Loved' in temp or 'Laughed' in temp or 'Disliked' in temp or 'Emphasized' in temp or 'Liked' in temp:
                    print('suppressing reaction')
                    continue

                m = self.dispatch_msg_to_firestore(temp, character)
                text = f"{character}: {temp}"
                ctx += '\n'
                ctx += text
                ctx = ctx.replace(' "Input:\n', '')
                counter += 1
            except Exception:
                print('failed message', traceback.format_exc())
                break
        return ctx

    # exists only to wake the container
    @stub.function(gpu=modal.gpu.A10G(count=1), shared_volumes={"/models": volume}, secrets=[modal.Secret.from_name("firebase-svc")], container_idle_timeout=1200, timeout=500, concurrency_limit=1)
    def wake():
        print('waking up')

    def predict(self, context: str, character: str, temp=.8):
        # unnecessary old code from alpaca, just haven't cleaned it yet.
        PROMPT_DICT = {
            "prompt_input": (
                "Below is an instruction that describes a task, paired with an input that provides further context. "
                "Write a response that appropriately completes the request.\n\n"
                "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
            ),
            "prompt_no_input": (
                "Below is an instruction that describes a task. "
                "Write a response that appropriately completes the request.\n\n"
                "### Instruction:\n{instruction}\n\n### Response:"
            ),
        }
        PROMPT = PROMPT_DICT['prompt_input']

        characters = ['Wyatt', 'Kiebs', 'Izzy', 'Luke', 'Harvey', 'Henry']
        others = list(filter(lambda x: x != character, characters))
        prompt = f"Your name is {character}. You are in a group chat with 5 of your best friends: {', '.join(others)}. You all went to college together, and talk to each other with no filter. You are encouraged to curse, say generally inappropriate and hilarious things, or be very rude. Everything is in good fun, so remember to joke and laugh, and be funny, but also answer all questions clearly. You will be presented with the most recent messages in the group chat. Write a response to the conversation as {character}."
        format_prompt = PROMPT.format_map({'instruction': prompt, 'input': context})

        tokenized_text = self.tokenizer(format_prompt, return_tensors="pt", padding="longest", max_length=self.tokenizer.model_max_length, truncation=True)
        full_completion = self.model.generate(inputs=tokenized_text["input_ids"].to("cuda"),
            attention_mask=tokenized_text["attention_mask"].to("cuda"),
            temperature=.75,
            top_p=0.85,
            top_k=80,
            do_sample=True,
            num_beams=3,
            max_new_tokens=600,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            repetition_penalty=1)
        decoded_text = self.tokenizer.decode(full_completion[0])
        return decoded_text

    def dispatch_msg_to_firestore(self, message, sender):
        from datetime import datetime, timezone
        import time

        # I delay to make the conversation more realistic on the front-end. Could save a ton of money probably by doing this delay on the frontend instead!
        time.sleep(0.25)

        senders = {
            'Henry': {
                'uid': 'fake-henry',
                'photo': 'https://i.imgur.com/wdXWHz2.jpg',
                'email': 'fake@email.com',
                'displayName': 'Henry'
            },
            'Harvey': {
                'uid': 'fake-harvey',
                'photo': 'https://i.imgur.com/sU8Codw.jpg',
                'email': 'fake@email.com',
                'displayName': 'Harvey'
            },
            'Luke': {
                'uid': 'fake-luke',
                'photo': 'https://i.imgur.com/U645ciG.jpg',
                'email': 'fake@email.com',
                'displayName': 'Luke'
            },
            'Izzy': {
                'uid': 'fake-izzy',
                'photo': 'https://i.imgur.com/wUGEnVb.jpg',
                'email': 'fake@email.com',
                'displayName': 'Izzy'
            },
            'Kiebs': {
                'uid': 'fake-kiebs',
                'photo': 'https://i.imgur.com/ESoUipA.png',
                'email': 'fake@email.com',
                'displayName': 'Kiebs'
            },
            'Wyatt': {
                'uid': 'fake-wyatt',
                'photo': 'https://i.imgur.com/9yPKaac.jpg',
                'email': 'fake@email.com',
                'displayName': 'Wyatt'
            }
        }
        sender = senders[sender]

        chat_doc_ref = self.db.collection('chats').document('<chatdb>')
        chat_messages_ref = chat_doc_ref.collection('messages')
        create_time, doc_ref = chat_messages_ref.add({
            'timestamp': datetime.now(timezone.utc),
            'message': message,
            'uid': sender['uid'],
            'photo': sender['photo'],
            'email': sender['email'],
            'displayName': sender['displayName'],
        })
        return create_time

    def get_firestore_context(self):
        from firebase_admin import firestore
        from datetime import datetime, timedelta, timezone

        chat_doc_ref = self.db.collection('chats').document('<chatdb>')
        chat_messages_ref = chat_doc_ref.collection('messages')
        most_recent_message = chat_messages_ref.order_by('timestamp', direction=firestore.Query.DESCENDING).limit(1).get()[0]
        message_timestamp = most_recent_message.get('timestamp')
        current_time = datetime.now(timezone.utc)
        time_diff = current_time - message_timestamp
        if time_diff <= timedelta(minutes=4):
            messages = chat_messages_ref.order_by('timestamp', direction=firestore.Query.DESCENDING).limit(10).get()
            ctx = ''
            prev = ''
            for i in messages:
                raw = i.to_dict()
                if prev == raw['message']:
                    return ''
                msg = f"{raw['displayName']} : {raw['message']}"
                ctx += msg
                ctx += '\n'
                prev = raw['message']
            return ctx
        else:
            return ''


# just for testing
@stub.webhook
def get_completion(context: str):
    from fastapi.responses import HTMLResponse
    convo = MessagePrediction().create_conversation.call(init_context=context, wake=False)
    to_render = convo.replace("\n", "<br />")
    return HTMLResponse(to_render)


@stub.webhook(label="alive", image=modal.Image.debian_slim())
def check_alive():
    print('Checking status of GPU container')
    status = MessagePrediction().create_conversation.get_current_stats()
    return status


@stub.webhook(label="wake")
def wake():
    MessagePrediction().create_conversation.spawn(init_context='wake', wake=True)
    print('waking up container')

Some key excerpts:

Modal lets you define container environments using simple config in the @stub.function() decorator. To run a particular function in the cloud using a GPU, attached to a cloud storage volume, referencing some stored secrets, and more, this is literally all the configuration required. It's insane.

@stub.function(gpu=modal.gpu.A10G(count=1), shared_volumes={"/models": volume},secrets=[modal.Secret.from_name("firebase-svc")],container_idle_timeout=1200,timeout=500,concurrency_limit=1)
   def create_conversation(self,init_context: str,wake: bool):
        ...

Cold starts are a big time suck, because this model is large and the weights take a long time to load: on the order of a few minutes. I could probably fix this by using a newer architecture, or just making the model smaller, but since this was a weekend project I opted to fix it by adding a “wake” endpoint I could use to wake up a container and prep a GPU.

@stub.webhook(label="alive", image=modal.Image.debian_slim())
def check_alive():
   print('Checking status of GPU container')
   status = MessagePrediction().create_conversation.get_current_stats()
   return status

@stub.webhook(label="wake")
def wake():
   MessagePrediction().create_conversation.spawn(init_context='wake', wake=True)
   print('waking up container')

I could have simply kept a pre warmed pool of containers for better performance, but it costs $$ to keep GPUs lying around, and since this is just for fun, I figured waiting a few minutes to spin up a session was fine. Modal makes this really easy with Container Lifecycle methods. Whenever something from class MessagePrediction is called (like my wake() function), a container is spun up and the code in __enter__ is run. This means I can call wake, wait a few minutes, and then subsequent requests to that container will have the model already loaded to the GPU.

class MessagePrediction:
   def __enter__(self):
       import transformers
       import firebase_admin
       from firebase_admin import credentials
       from firebase_admin import firestore
       import json

       service_account_info = json.loads(os.environ["SERVICE_ACCOUNT_JSON"])
       cred = credentials.Certificate(service_account_info)
       app = firebase_admin.initialize_app(cred)

       # Create a Firestore client
       self.db = firestore.client()

       m_inter = transformers.LlamaForCausalLM.from_pretrained("/models/model")
       self.tokenizer = transformers.AutoTokenizer.from_pretrained("/models/model")

       m_inter = m_inter.half()
       self.model = m_inter.to("cuda")

I spent a lot of time experimenting with the generation parameters, and settled on the following.


  full_completion = self.model.generate(inputs=tokenized_text["input_ids"].to("cuda"),
            attention_mask=tokenized_text["attention_mask"].to("cuda"),
            temperature=.75,
            top_p=0.85,
            top_k=80,
            do_sample=True,
            num_beams=3,
            max_new_tokens=600,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            repetition_penalty=1)

I'm using beam search here, which "keeps several hypotheses at each time step and eventually chooses the hypothesis that has the overall highest probability for the entire sequence." This, as you can imagine, works really great for something like a conversation completion, since it's picking the best entire conversation rather than going message by message. I highly recommend you read more about the different text generation strategies in the Transformers documentation.
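If you want to see the difference for yourself, the same generate() call with different settings makes the strategies pretty concrete (this assumes the model and tokenized_text from the test snippet earlier; the values are just for illustration):

inputs = tokenized_text["input_ids"].to("cuda")

# greedy: always take the single most likely next token
greedy = model.generate(inputs=inputs, do_sample=False, num_beams=1, max_new_tokens=100)

# sampling: draw from the (truncated) distribution at every step
sampled = model.generate(inputs=inputs, do_sample=True, temperature=0.75, top_p=0.85, top_k=80, max_new_tokens=100)

# beam search + sampling: keep 3 candidate sequences and return the best-scoring one overall
beamed = model.generate(inputs=inputs, do_sample=True, num_beams=3, temperature=0.75, top_p=0.85, max_new_tokens=100)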

So now I can do inference on my custom model using an HTTP endpoint! And it's hilarious. I deployed it in dev (again, literally just by running modal serve main.py, that's it) and spent quite a few hours just cracking myself up playing with it:


the robo boys debate the merits of the bill of rights

There's something so delightful about capturing the voice of your friends perfectly. It's not quite nostalgia, since the conversations never happened, but it's a similar sense of glee.

Building a front end

After a few hours of enjoying myself thoroughly, I really wanted to show this to… The Group Chat! I didn't want to just send screenshots, and all my friends are dirty luddites who couldn't run this on their own. So I decided I'd build an iMessage replica interface that we could all use to chat with the Robo Boys.

I thought about just using Twilio or something to really create another Group Chat with the model, but this seemed really expensive and complicated. There's actually a Twilio-style iMessage service called SendBlue, and I have NO idea how it works, but it was really expensive and felt like it might get shut down by Apple :/.

There are a ton of “iMessage Clone” projects floating around on GitHub. I picked this one by sakilk130 and started customizing it for my purposes. It wound up being pretty damn simple. You are welcome to clone my clone, but be forewarned: I customized it wantonly in about 45 minutes without any thought to cleanliness or future dev work.

Nearly all of the custom logic lives in Chat.jsx:

import React, { useEffect, useState, useRef } from 'react';
import './Chat.css';
import { IconButton } from '@material-ui/core';
import ArrowCircleUpIcon from '@mui/icons-material/ArrowCircleUp';
import Message from './Message/Message';
import { useSelector } from 'react-redux';
import { selectUser } from '../../../features/userSlice';
import db from '../../../firebase/config';
import firebase from 'firebase';
import FlipMove from 'react-flip-move';
function Chat() {
const [input, setInput] = useState('');
const [messages, setMessages] = useState([]);
const [lastMessage,setLastMessage] = useState(false)
const [nextPosts_loading, setNextPostsLoading] = useState(false);
const [mostRecentMessageType, setMostRecentMessageType] = useState('new')
const [containerStatus, setContainerStatus] = useState('asleep')
const user = useSelector(selectUser);
const chatName = 'robo boys' //hardcode this for now, just one chat
const chatId = '<chatId>' //hardcode this for now, just one chat
const lastMsgRef = useRef(null)
const firstMsgRef = useRef(null)
const updateContainerStatus = () => {
fetch('your-modal-alive-endpoint-here/')
.then(response => response.json())
.then(data => {
if(data.num_total_runners === 0) {
setContainerStatus('asleep')
} else if(data.num_total_runners >= 1 && data.num_active_runners >= 1) {
setContainerStatus('awake')
} else if(data.backlog >= 1 ) {
setContainerStatus('waking up')
}
});
}
const wakeContainer = () => {
fetch('your-modal-wake-endpoint-here/').then(setContainerStatus('waking up'))
}
const firstPosts = () => {
try {
db.collection('chats')
.doc(chatId)
.collection('messages')
.orderBy('timestamp', 'desc').limit(50)
.onSnapshot((snapshot) =>{
const d = snapshot.docs.reverse()
setMessages(
d.map((doc) => ({ id: doc.id, data: doc.data() }))
)
setLastMessage(
d[0].data().timestamp
)
setMostRecentMessageType('new')
}
);
} catch(e) {
console.log(e)
}
}
const nextPosts = (key) => {
setNextPostsLoading(true);
try {
db.collection('chats')
.doc(chatId)
.collection('messages')
.orderBy('timestamp', 'desc')
.startAfter(key)
.limit(25)
.onSnapshot((snapshot) => {
const newMsgs = snapshot.docs.reverse().map((doc) => ({ id: doc.id, data: doc.data() }))
setMessages(
[...newMsgs,...messages ]
)
setLastMessage(
newMsgs[0].data.timestamp
)
}
);
setMostRecentMessageType('old')
setNextPostsLoading(false);
} catch (e) {
console.log(e);
setNextPostsLoading(false);
}
}
useEffect(() => {
if (chatId) {
firstPosts();
updateContainerStatus();
}
}, [chatId]);
useEffect(() => {
const interval = setInterval(() => {
updateContainerStatus();
}, 5000);
return () => clearInterval(interval); // This represents the unmount function, in which you need to clear your interval to prevent memory leaks.
}, [])
const scrollToBottom = () => {
lastMsgRef.current?.scrollIntoView({ behavior: "smooth" })
}
const scrollToTop = () => {
firstMsgRef.current?.scrollIntoView({ behavior: "smooth" })
}
const Status = () => {
if(containerStatus === 'awake') {
return (<div className='bot__status'>
<p>bot status: <b>awake</b></p>
</div>)
} else if(containerStatus === 'asleep') {
return(
<div className='bot__status'>
<p>bot status: <b>asleep. <p className='wakeup' onClick={() => {wakeContainer()}}>wake them up?</p></b></p>
</div>
)
} else if(containerStatus === 'waking up') {
return (<div className='bot__status'><p>bot status: <b>waking up</b>. May take up to 5 minutes</p></div>)
}
}
useEffect(() => {
if(mostRecentMessageType === 'new'){
scrollToBottom()
} else {
scrollToTop()
}
}, [messages]);
const sendMessage = (e) => {
e.preventDefault();
db.collection('chats').doc(chatId).collection('messages').add({
timestamp: firebase.firestore.FieldValue.serverTimestamp(),
message: input,
uid: user.uid,
photo: user.photo,
email: user.email,
displayName: user.displayName,
});
setInput('');
return false
};
return (
<div className="chat">
<div className="chat__header">
<div className="header__icon">
<p>🤖</p>
<p>
{chatName}
</p>
<Status />
</div>
</div>
{/* Chat messages */}
<div className="chat__messages">
<div className='chat__loading' ref={firstMsgRef}>
{nextPosts_loading ? (
<p>Loading..</p>
) : lastMessage ? (
<button className='chat__more_button' onClick={() => nextPosts(lastMessage)}>Load older messages</button>
) : (
<span>No more messages</span>
)}
</div>
<FlipMove>
{messages.map(({ id, data }) => (
<Message key={id} id={id} contents={data} />
))}
</FlipMove>
<div ref={lastMsgRef} />
</div>
{/* Chat input*/}
<div className="chat__input">
<form method="POST" onSubmit={(e) => sendMessage(e)}>
<input
type="text"
placeholder="iMessage"
value={input}
onChange={(e) => setInput(e.target.value)}
/>
</form>
<div className='send__button'>
<IconButton size="small" color="primary" onClick={sendMessage}>
<ArrowCircleUpIcon />
</IconButton>
</div>
</div>
</div>
);
}
export default Chat;

I used Firebase here because I still can't find anything that's as easy to bolt on for auth and a database that scales to zero. It's also perfect for a chat app, since Firestore is close to real-time and deals with subscriptions and all that nonsense. Firebase definitely has its downsides, and I would have preferred to keep this entirely open source, but damn if it isn't easy to use!

And that's it!

I deployed this (with Firebase hosting, again, free, why not) and saved it as a PWA on my phone. I showed my friends how to do that, and now we all have access to the same “Group Chat” with the AI bots.

This has genuinely provided more hours of deep enjoyment for me and my friends than I could have imagined. Something about the training process optimized for outrageous behavior, and seeing your conversations from a third-person perspective casts into stark relief how ridiculous and hilarious they can be.



A downright classic conversation about who drank Henry's beer

It really, really nailed the voice and perspectives of my friends, and actually retains a ton of information on their preferences, lives, etc. I had considered attaching an embedding database (like Chroma) to actually give the boys a knowledge store, but found this to be unnecessary. They know who we each are dating, what we like to do, and most importantly...


Alan hupp was our college landlord!
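If you do want to bolt a knowledge store onto your own version, a minimal setup with Chroma would look roughly like this: index some facts (or whole chat sessions), retrieve the most relevant ones at inference time, and prepend them to the prompt context. A sketch, not something I actually built:

import chromadb

client = chromadb.Client()
collection = client.create_collection("group-chat-facts")

# index some background facts (these could just as easily be whole chat sessions)
collection.add(
    documents=["Alan Hupp was our college landlord", "Henry is still mad about his missing beer"],
    ids=["fact-1", "fact-2"],
)

# at generation time, pull the most relevant facts and prepend them to the prompt context
results = collection.query(query_texts=["who was our landlord?"], n_results=2)
background = "\n".join(results["documents"][0])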

I really encourage everyone to clone this project and follow this tutorial, or do a similarly pointless yet complicated AI project like this. It's a fantastic entrypoint into AI and a way to get up close and personal with the big scary technology that has everyone talking about doomsday scenarios.

On a technical level, I found it really helped me wrap my head around what LLMs are doing and how they can be tuned for specific scenarios. Of course, it was also just overall really fun. Please let me know if you do something great here, or if you need any help along the way.

I'm also happy to do this for anyone as a service, for probably somewhere in the few-hundred-bucks range. I promise not to read your group chat. DM me if you're interested.

Let me know what you think @isidoremiller on twitter, and thanks for reading 🙇‍♂️.