├── README.md
├── bench.py
└── viz.py

/README.md:
--------------------------------------------------------------------------------
# embedland
Theoretically this is a universe of code for playing with embeddings. In reality it contains two files. More to come, I hope.

![](https://user-images.githubusercontent.com/279531/221034510-aa4084a9-86dd-4ddc-99de-8718acd211b4.png)

### bench.py
This file benchmarks various embedding models on the Enron email corpus. Once you install the libraries it needs, run it with `python bench.py`. It will:
* Download the Enron email dataset.
* Unzip it.
* Run embeddings on it (OpenAI's embedder is the default; you can switch to T5 or another engine at the end of the file, as shown below).
* Cluster the embeddings.
* Label the clusters by sampling subject lines from each cluster and sending them to GPT-3.
* Show you a pretty chart, like the one above.
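To switch engines, change the string passed to `run_embedding_test` at the bottom of `bench.py`. A minimal sketch, using the engine names the file already supports:

```python
run_embedding_test('openai')
# Or one of the local models:
# run_embedding_test('t5')
# run_embedding_test('sentence-transformers/all-mpnet-base-v2')
```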
### viz.py
Visualization helper. This file helps you go from "a list of embeddings" to "something pretty to look at".

### TODO:
* Make longer embeddings work by chunking and averaging out the results.

--------------------------------------------------------------------------------
/bench.py:
--------------------------------------------------------------------------------
# Compare different embedding methods.
import os
import hashlib
import email
import email.policy
import json
import tqdm
import time
import random
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity  # for testing
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import openai
import pandas as pd
import plotly.express as px
import tiktoken
from itertools import islice
from transformers import T5Tokenizer
from tenacity import retry, wait_random_exponential, stop_after_attempt, retry_if_not_exception_type
import torch
import re
import dbm

# Or however you want it.
openai.api_key = open(os.path.expanduser('~/.openai')).read().strip()

CLUSTER_COUNT = 10
EMAIL_DATASET_COUNT = 10000
CUDA_SUPPORT = torch.cuda.is_available()
print("CUDA available:", CUDA_SUPPORT)

OPENAI_EMBEDDING_MODEL = 'text-embedding-ada-002'
OPENAI_EMBEDDING_CTX_LENGTH = 8191
OPENAI_EMBEDDING_ENCODER = tiktoken.get_encoding('cl100k_base')

T5_TOKENIZER = T5Tokenizer.from_pretrained("t5-large")
T5_EMBEDDING_CTX_LENGTH = 512

_cache_dbm = dbm.open('cache.dbm', 'c')


def list_disk_cache(namespace):
    """Function decorator to cache function results to disk. Only for list results."""
    def decorator(func):
        def wrapper(*args, **kwargs):
            key = hashlib.md5(str(args).encode() +
                              str(kwargs).encode()).hexdigest()
            key = namespace + ':' + key
            if key in _cache_dbm:
                return json.loads(_cache_dbm[key])
            result = func(*args, **kwargs)
            # Don't be a meanie, I can only do lists!
            assert isinstance(result, list)
            # Results are serialized as JSON, which keeps the cache format simple.
            _cache_dbm[key] = json.dumps(result)
            return result
        return wrapper
    return decorator


# Helper functions to lazy-load the various models.
_t5_model = None


def get_t5_model():
    global _t5_model
    if _t5_model is None:
        from transformers import T5Model
        print("Loading T5 model...")
        model_name = "t5-large"
        tokenizer = T5_TOKENIZER
        model = T5Model.from_pretrained(model_name)
        if CUDA_SUPPORT:
            model = model.cuda()
        _t5_model = (tokenizer, model)
    return _t5_model


_st_models = {}


def get_sentence_transformers(model):
    # Cache per model name so switching engines doesn't return a stale model.
    if model not in _st_models:
        print("Loading SentenceTransformers model %s..." % model)
        from sentence_transformers import SentenceTransformer
        _st_models[model] = SentenceTransformer(model)
    return _st_models[model]


def t5_encode(text):
    tokens = T5_TOKENIZER.encode(
        text, return_tensors="pt", max_length=T5_EMBEDDING_CTX_LENGTH, truncation=True)
    return tokens.cuda() if CUDA_SUPPORT else tokens


# Helper functions to chunk larger inputs into smaller ones.


def batched(iterable, n):
    """Batch data into tuples of length n. The last batch may be shorter."""
    # batched('ABCDEFG', 3) --> ABC DEF G
    if n < 1:
        raise ValueError('n must be at least one')
    it = iter(iterable)
    while (batch := tuple(islice(it, n))):
        yield batch


def chunked_tokens(text, encoder_fn, chunk_length):
    tokens = encoder_fn(text)
    chunks_iterator = batched(tokens, chunk_length)
    yield from chunks_iterator


def chunked_text(text, chunk_length, tokens_per_word=2.5):
    words = text.split(' ')
    chunks_iterator = batched(words, int(chunk_length / tokens_per_word))
    # Once we have a chunk of words, join them back into a string.
    yield from map(lambda chunk: ' '.join(chunk), chunks_iterator)


def get_long_embedding(text, embedding_fn, max_tokens=None, encoder_fn=None, average=True):
    assert max_tokens is not None
    assert encoder_fn is not None
    chunk_embeddings = []
    chunk_lens = []
    for chunk in chunked_tokens(text, encoder_fn=encoder_fn, chunk_length=max_tokens):
        chunk_embeddings.append(embedding_fn(chunk))
        chunk_lens.append(len(chunk))

    if average:
        chunk_embeddings = np.average(
            chunk_embeddings, axis=0, weights=chunk_lens)
        chunk_embeddings = chunk_embeddings / \
            np.linalg.norm(chunk_embeddings)  # normalizes length to 1
        chunk_embeddings = chunk_embeddings.tolist()
    return chunk_embeddings
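
# A small illustration of the chunk-and-average flow above (values are made
# up for the example, not from a real model):
#
#   list(batched('ABCDEFG', 3))  # -> [('A','B','C'), ('D','E','F'), ('G',)]
#
# get_long_embedding embeds each chunk, then takes a token-count-weighted
# average and rescales it to unit length:
#
#   v = np.average([[1.0, 0.0], [0.0, 1.0]], axis=0, weights=[3, 1])
#   v / np.linalg.norm(v)  # -> approximately [0.949, 0.316]
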
# Method 1: Get embeddings using T5 directly. TODO: max pooling voodoo.


def get_embedding_t5(text_or_tokens):
    tokenizer, model = get_t5_model()
    if isinstance(text_or_tokens, str):
        tokens = t5_encode(text_or_tokens)
    else:
        # Already-encoded token ids, e.g. a chunk from get_long_embedding.
        tokens = torch.tensor([list(text_or_tokens)])
        if CUDA_SUPPORT:
            tokens = tokens.cuda()
    attn = tokens != tokenizer.pad_token_id
    output = model.encoder(
        input_ids=tokens, attention_mask=attn, return_dict=True)
    # Compute the mean of the last hidden state over the non-padded tokens.
    # I think this is what they did in that paper, but I'm not sure...
    embedding = (output.last_hidden_state * attn.unsqueeze(-1)
                 ).sum(dim=-2) / attn.sum(dim=-1)
    return embedding.detach().cpu().numpy()[0]


# Method 2: Use SentenceTransformers.


def get_embedding_st(text, engine):
    model = get_sentence_transformers(engine)
    if random.random() < 0.01:
        # Occasionally log a sample so truncation problems are visible.
        tokens = model.tokenize([text])['input_ids'][0]
        sample_text = text[:100].replace('\n', ' ')
        print(
            f"sample: len={len(text)}, num_tokens={len(tokens)}, max_len={model.max_seq_length}, text={sample_text}")

    return model.encode([text])[0]


# Method 3: Use OpenAI's Embedding API.


@list_disk_cache("openai-embeddings")
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(openai.InvalidRequestError))
def get_embedding_openai(text_or_tokens, model=OPENAI_EMBEDDING_MODEL):
    # First determine the length of this input in tokens.
    if isinstance(text_or_tokens, str):
        tokens = OPENAI_EMBEDDING_ENCODER.encode(text_or_tokens)
    else:
        tokens = list(text_or_tokens)
    if len(tokens) > OPENAI_EMBEDDING_CTX_LENGTH:
        tokens = tokens[:OPENAI_EMBEDDING_CTX_LENGTH]
    return openai.Embedding.create(input=tokens, model=model)["data"][0]["embedding"]
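
# A quick semantic sanity check you can run by hand (requires a working API
# key; the example strings are invented). Related subjects should usually
# score a higher cosine similarity than unrelated ones:
#
#   a = get_embedding_openai("natural gas contract renewal")
#   b = get_embedding_openai("quarterly gas contract review")
#   c = get_embedding_openai("fantasy football draft picks")
#   assert cosine_similarity([a], [b])[0][0] > cosine_similarity([a], [c])[0][0]
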
# Get embeddings. If long_mode is True, chunk the input into smaller pieces
# and average the chunk embeddings.


def get_embeddings(text, engine, long_mode=False):
    if engine == "saved":
        return np.load("01-embeddings.npy")

    if not long_mode:
        # TODO: To make this a fair test, I should limit the length of the
        # input to the same as the other models.
        if engine == "openai":
            return get_embedding_openai(text)
        elif engine == "t5":
            return get_embedding_t5(text)
        elif engine.startswith("sentence-transformers/"):
            return get_embedding_st(text, engine)
        else:
            raise ValueError(f"Unknown engine: {engine}")
    else:
        if engine == "openai":
            return get_long_embedding(
                text, get_embedding_openai,
                max_tokens=OPENAI_EMBEDDING_CTX_LENGTH,
                encoder_fn=OPENAI_EMBEDDING_ENCODER.encode)
        elif engine == "t5":
            return get_long_embedding(
                text, get_embedding_t5,
                max_tokens=T5_EMBEDDING_CTX_LENGTH,
                encoder_fn=T5_TOKENIZER.encode)
        elif engine.startswith("sentence-transformers/"):
            # TODO: I need to wrap SentenceTransformer in a subclass that,
            # when called, handles tokens_or_text and not just text.
            raise NotImplementedError(
                "Long mode not implemented for SentenceTransformers")
        else:
            raise ValueError(f"Unknown engine: {engine}")


def download_dataset():
    dataset_link = "https://www.cs.cmu.edu/~./enron/enron_mail_20150507.tar.gz"
    if not os.path.exists("data/enron_mail_20150507.tar.gz"):
        print("Downloading dataset...")
        os.system("mkdir -p data")
        os.system("wget -P data/ " + dataset_link)
    else:
        print("Dataset already downloaded!")
    if not os.path.exists("data/maildir"):
        print("Extracting dataset...")
        os.system("tar -xzf data/enron_mail_20150507.tar.gz -C data/")
    else:
        print("Dataset already extracted!")


def get_all_files(path):
    all_files = []
    for root, dirs, files in os.walk(path):
        files = [os.path.join(root, name) for name in files]
        all_files.extend(files)
    return all_files


def get_emails(count=EMAIL_DATASET_COUNT):
    emails = []
    email_paths = get_all_files("data/maildir")
    # Take every Nth file so the sample spans the whole corpus.
    email_paths = email_paths[::len(email_paths) // count]
    for file_name in email_paths:
        with open(file_name, "rb") as fp:
            try:
                msg = email.message_from_binary_file(
                    fp, policy=email.policy.default)
                emails.append(msg)
            except Exception:
                # Skip files that don't parse as email.
                pass
    return emails


@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(openai.InvalidRequestError))
def openai_completion(query):
    return openai.Completion.create(
        engine="text-davinci-003",
        prompt=query,
        max_tokens=10,
        temperature=0.1,
        top_p=1,
        stop="Label:"
    )


def get_label(cluster, labels, emails):
    # Get the indices of the emails in the cluster.
    indices = np.where(labels == cluster)[0]
    # Take up to ten emails from the cluster that have a subject line.
    samples = []
    for i in indices:
        if emails[i]["subject"] is not None:
            samples.append(i)
        if len(samples) >= 10:
            break
    # Construct the query for OpenAI.
    query = "The following are email subjects from the same cluster. Please provide a short label that describes the common theme or topic of the cluster.\n\n"
    for sample in samples:
        query += "- " + emails[sample]["subject"] + "\n"

    query += "\nLabel:"
    # Call the OpenAI API.
    response = openai_completion(query)
    # Return the label.
    return response["choices"][0]["text"].strip()
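
# For reference, the prompt that get_label builds looks roughly like this
# (the subject lines below are invented, not pulled from the corpus):
#
#   The following are email subjects from the same cluster. Please provide
#   a short label that describes the common theme or topic of the cluster.
#
#   - Gas nomination for April
#   - Pipeline capacity update
#   - Revised transport schedule
#
#   Label:
#
# text-davinci-003 then completes the "Label:" line with a short phrase.
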
def plot_plotly(embeddings_2d, labels, labels_dict, file_name):
    df = pd.DataFrame(
        {"x": embeddings_2d[:, 0], "y": embeddings_2d[:, 1], "label": labels})
    df["label"] = df["label"].map(labels_dict)
    fig = px.scatter(df, x="x", y="y", color="label")
    fig.show()
    # Save the image.
    fig.write_image(file_name, width=1920, height=1080)


def run_embedding_test(engine):
    download_dataset()
    print("Getting emails...")
    emails = get_emails()
    embeddings = []
    print("Getting embeddings...")
    for msg in tqdm.tqdm(emails):
        subject = msg["subject"] or ""
        body = msg.get_body(preferencelist=("plain",))
        body = body.get_content() if body else ""
        if not body:
            continue
        # TODO: Should I use a separator token here? Who knows.
        text = subject + "\n" + body
        # text = re.sub(r'\s+', ' ', text)  # Is this a good idea? Aren't newlines bad for embedding performance? Should test.
        embeddings.append(get_embeddings(text, engine))
    embeddings = np.array(embeddings)
    print("Clustering...")
    kmeans = KMeans(n_clusters=CLUSTER_COUNT, random_state=42)
    labels = kmeans.fit_predict(embeddings)
    # Use t-SNE to reduce the dimensionality and visualize the clusters.
    tsne = TSNE(n_components=2, random_state=42)
    embeddings_2d = tsne.fit_transform(embeddings)
    # Get a GPT-generated label for each cluster.
    print("Getting labels...")
    labels_dict = {}
    for cluster in tqdm.tqdm(range(CLUSTER_COUNT)):
        label = get_label(cluster, labels, emails)
        labels_dict[cluster] = label
    # Concat all email IDs and hash them, so output files are tied to this sample.
    email_ids = [msg["message-id"] or "" for msg in emails]
    hashbit = hashlib.sha256("".join(email_ids).encode()).hexdigest()[-5:]
    engine_filename = engine.replace("/", "-")
    file_name = f'{hashbit}-{engine_filename}-cluster{CLUSTER_COUNT}-email{EMAIL_DATASET_COUNT}'
    np.save(file_name + '-embeddings.npy', embeddings)
    plot_plotly(embeddings_2d, labels, labels_dict, file_name + '.png')


if __name__ == "__main__":
    start_time = time.time()
    # openai, sentence-transformers/all-mpnet-base-v2, sentence-transformers/gtr-t5-large (which should be T5)
    run_embedding_test('openai')
    print("Time taken: ", time.time() - start_time)

--------------------------------------------------------------------------------
/viz.py:
--------------------------------------------------------------------------------
"""Tools to visualize embeddings."""

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans


def cluster(embeddings, cluster_count=10):
    """Cluster the embeddings and return 2-D projections and cluster labels."""
    if not isinstance(embeddings, np.ndarray):
        embeddings = np.array(embeddings)
    kmeans = KMeans(n_clusters=cluster_count, random_state=42)
    labels = kmeans.fit_predict(embeddings)
    tsne = TSNE(n_components=2, random_state=42)
    embeddings_2d = tsne.fit_transform(embeddings)
    return embeddings_2d, labels


def plot_matplotlib(embeddings_2d, labels, labels_dict):
    """embeddings_2d is the output of the TSNE function.
    labels is the output of the KMeans function.
    labels_dict maps each cluster number to its label.
    """
    fig, ax = plt.subplots(figsize=(15, 15))
    cmap = plt.cm.rainbow
    norm = plt.Normalize(min(labels), max(labels))
    ax.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1],
               c=labels, cmap=cmap)
    # Empty plots act as legend proxies; use the same colormap as the scatter
    # so legend colors match the points.
    for cluster_id, label in labels_dict.items():
        ax.plot([], [], label=label, marker="o", color=cmap(norm(cluster_id)))
    ax.legend(loc="center left", bbox_to_anchor=(1.05, 0.5), fontsize=14)
    plt.show()


def plot_plotly(embeddings_2d, labels, labels_dict):
    """Same as above, but with Plotly."""
    df = pd.DataFrame(
        {"x": embeddings_2d[:, 0], "y": embeddings_2d[:, 1], "label": labels})
    df["label"] = df["label"].map(labels_dict)
    fig = px.scatter(df, x="x", y="y", color="label")
    fig.show()
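
# A minimal end-to-end sketch (the .npy filename is just an example; bench.py
# saves arrays with an "-embeddings.npy" suffix):
#
#   embeddings = np.load("01-embeddings.npy")
#   embeddings_2d, labels = cluster(embeddings, cluster_count=10)
#   plot_plotly(embeddings_2d, labels, {i: f"cluster {i}" for i in range(10)})
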
48 | """ 49 | tsne = TSNE(n_components=3, random_state=42) 50 | embeddings_3d = tsne.fit_transform(embeddings) 51 | df = pd.DataFrame( 52 | {"x": embeddings_3d[:, 0], "y": embeddings_3d[:, 1], "z": embeddings_3d[:, 2], "label": labels}) 53 | df["label"] = df["label"].map(labels_dict) 54 | fig = px.scatter_3d(df, x="x", y="y", z="z", color="label") 55 | fig.show() 56 | --------------------------------------------------------------------------------