├── LICENSE.md
├── README.md
├── encode.py
├── engram.py
├── example.py
├── requirements.txt
└── transformer.py

/LICENSE.md:
--------------------------------------------------------------------------------
All rights reserved. Copyright 2021 AeroScripts (Luke Fay)

I intend to release this under a much more lenient license soon. Please contact me if you would like to use this in your project in the meantime!

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Hidden Engrams: Long Term Memory for Transformer Model Inference

State-of-the-art transformer models like GPT-3 can generate realistic text, but the window of text the transformer can look at is still relatively small.
Hidden Engrams aims to remedy this by introducing an approximation of long-term memory built from the transformer's hidden states. These values can then be used to quickly sort all past "memories" by relevance to the current input. Once sorted, an optimized prompt can be built that includes only the most relevant information.

## Usage
First, ensure the transformer model you want to use is configured properly in transformer.py. Engrams are incompatible across different models.
To encode your own datasets, modify encode.py as needed to load your data.
example.py provides a simple example use case: chat bots. Previous messages are encoded and stored, then used to build future prompts. A minimal sketch of the overall flow is shown below.
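
Roughly, the pieces fit together as sketched here. This is an illustrative snippet rather than a drop-in script: it assumes `memories` is a list of memory records as produced by encode.py (each with `text`, `engram`, `next`, `previous` and `distance` fields) and that a CUDA-capable GPU is available; see example.py for the full chat loop.

```python
from transformer import get_transformer
from engram import build_engram, sort_engrams

model, tokenizer = get_transformer()

# encode the current input into an engram (derived from the model's hidden states)
text = "JULIET: O Romeo, Romeo! wherefore art thou Romeo?"
tokens = tokenizer(text, return_tensors="pt").input_ids.cuda()
now = {"text": text, "engram": build_engram(model.forward, tokens),
       "next": -1, "previous": -1, "distance": 0}

# rank stored memories by relevance to the current input and keep the closest ones
relevant = sort_engrams(now, memories, top_k=42)

# build an optimized prompt from the most relevant memories plus the current input
prompt = "\n".join(m["text"] for m in relevant) + "\n" + text
```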

This is a very early proof-of-concept. More to come soon!


```
@misc{hiddenengrams,
  author = {Luke Fay aka AeroScripts},
  title = {Hidden Engrams: Long Term Memory for Transformer Model Inference},
  howpublished = {\url{https://github.com/AeroScripts/HiddenEngrams}},
  year = 2021,
  month = June
}
```

--------------------------------------------------------------------------------
/encode.py:
--------------------------------------------------------------------------------
from engram import build_engram
from transformer import get_transformer
from tqdm import tqdm
import pickle
import pandas as pd

# load an example dataset (Shakespeare)
# WARNING: This URL is only for testing purposes! Please supply your own data
# note: error_bad_lines was removed in pandas 2.0; use on_bad_lines="skip" there instead
data = pd.read_csv("https://raw.githubusercontent.com/dherman/wc-demo/master/data/shakespeare-plays.csv", error_bad_lines=False)
messages = [ (str(row[1].values[2]), str(row[1].values[3])) for row in data.iterrows() ]

# Load GPT Neo
model, tokenizer = get_transformer()

# encode memories: consecutive lines from the same speaker are merged into a single memory
memories = []
last_speaker = None
full_message = ""
for index in tqdm(range(len(messages))):
    speaker, message = messages[index]
    if last_speaker != speaker and len(full_message) > 0:
        tokens = tokenizer(last_speaker + ": " + full_message.lstrip(), return_tensors="pt").input_ids.cuda()
        memories.append({
            "text": last_speaker + ": " + full_message.lstrip(),
            "engram": build_engram(model.forward, tokens),
            "next": len(memories)+1,
            "previous": len(memories)-1,
            "distance": 0
        })
        full_message = ""
    last_speaker = speaker
    full_message = full_message + " " + message

# flush the final accumulated message (the loop above only stores a memory on a speaker change)
if len(full_message) > 0:
    tokens = tokenizer(last_speaker + ": " + full_message.lstrip(), return_tensors="pt").input_ids.cuda()
    memories.append({
        "text": last_speaker + ": " + full_message.lstrip(),
        "engram": build_engram(model.forward, tokens),
        "next": len(memories)+1,
        "previous": len(memories)-1,
        "distance": 0
    })

# the newest memory has no successor yet
memories[-1]["next"] = -1

# dump data to disk
with open("shakespeare.pkl", 'wb') as handle:
    pickle.dump(memories, handle, protocol=pickle.HIGHEST_PROTOCOL)

--------------------------------------------------------------------------------
/engram.py:
--------------------------------------------------------------------------------
import torch
import heapq
import numpy as np

# Given a transformer model forward function and an array of tokens, return the "engram" based on the hidden states of the transformer
def build_engram(forward, tokens, shift=10000, factor=20000, rampdown=lambda x: x / 2):
    """Given a transformer model forward function and an array of tokens, return the "engram" based on the hidden states of the transformer

    Parameters
    ----------
    forward : function
        A HuggingFace-style model forward function
    tokens : Tensor
        The tokenized text to encode
    shift : float, optional
        How much to add to the values for normalization
    factor : float, optional
        Divisor for normalization
    rampdown : function, optional
        A function to ramp down the power of tokens while iterating over the hidden states (currently unused, see the todo below)

    Returns
    -------
    engram : array
        The final encoded engram
    """

    # get hidden states (last 512 tokens only, skipping the embedding layer)
    h = list(forward(input_ids=tokens[:, -512:].long().cuda(), output_hidden_states=True).hidden_states[1:])

    # todo: use rampdown
    f = 0
    fa = 1.0/float(len(h))

    # combine hidden states: average over the token axis, weighting deeper layers more heavily
    # we use double() here to reduce accuracy loss from overflowing. it's safe to go back to float() after the math is done. There is probably a more efficient way to do this
    for layer in range(len(h)):
        f = f + fa
        h[layer] = torch.mean(h[layer].detach().double(), dim=(1, )) * f

    h = torch.sum(torch.stack(h, dim=1)[0], dim=(0, ))

    # note: static values are used here to make sorting more consistent. Previously I normalized per-engram, but that reduced the overall accuracy of the sorting
    return ((h + shift) / factor).float().to("cpu").numpy()
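
# Illustrative usage (an assumption-based sketch, not part of the original API surface;
# it simply wires get_transformer from transformer.py together with build_engram above):
#   model, tokenizer = get_transformer()
#   tokens = tokenizer("ROMEO: But, soft! what light through yonder window breaks?", return_tensors="pt").input_ids.cuda()
#   engram = build_engram(model.forward, tokens)  # a 1-D float32 numpy array, one value per hidden dimension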

# Given a "now" engram, and an array of past engrams, return top_k closest matching engrams
def sort_engrams(now, past, factor=1000.0, epsilon=1e-6, top_k=250, depth=1, do_distance=True):
    """Given a "now" engram, and an array of past engrams, return the top_k closest matching engrams

    Parameters
    ----------
    now : dict
        The engram to compare against
    past : list
        The list of past engrams to sort
    factor : float, optional
        Divisor applied to the element-wise differences when computing distances
    epsilon : float, optional
        A small value added before the square root for numerical stability
    top_k : int, optional
        The number of closest matching engrams to return
    depth : int, optional
        How many previous or future memories to check against during sorting
    do_distance : bool, optional
        Should we perform distance calculations? Disabling this is useful if you are doing multiple passes with different depths

    Returns
    -------
    sorted : list
        The final sorted list of top_k closest matching engrams
    """
    now = now["engram"].astype(np.float32)

    # calculate distance between all past engrams and the current engram
    if do_distance:
        for e in range(len(past)):
            past[e]["distance"] = np.sum(np.sqrt((np.abs(past[e]["engram"].astype(np.float32) - now) / factor) + epsilon))

    # return the distance value of a given engram, recursively if depth>1
    def keyer(m):
        if depth == 1:
            return m["distance"]
        else:
            total = 0
            nodeup = m
            nodedown = m

            # fold in the distances of n previous and future engrams
            for e in range(depth-1):
                nodeup = nodeup["previous"]
                nodedown = nodedown["next"]

                # scaling factor for distance to root engram
                f = (2.0 * (e + 1.0))

                # treat missing or out-of-range links as unlinked and penalize them
                if nodeup is None or nodedown is None or nodeup < 0 or nodedown < 0 or nodeup >= len(past) or nodedown >= len(past):
                    total = total + 100000  # some high penalty (unlinked) TODO: better solution
                    break
                nodeup = past[nodeup]
                nodedown = past[nodedown]
                total = total + (nodeup["distance"] / f) + (nodedown["distance"] / f)
            return m["distance"] + total

    # pick top_k smallest values (faster than full sort)
    return heapq.nsmallest(top_k, past, key=keyer)
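
# Illustrative usage (a sketch of the multi-pass re-ranking done in example.py's build_context;
# `now` is a freshly built engram record and `memories` a list of stored ones):
#   m = sort_engrams(now, memories, top_k=600)                          # coarse pass, distances computed once
#   m = sort_engrams(now, m, top_k=150, do_distance=False, depth=2)     # re-rank, pulling in immediate neighbours
#   m = sort_engrams(now, m, top_k=42, do_distance=False, depth=3)      # final pass with a wider neighbourhood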

--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
import os
import pickle
from engram import build_engram, sort_engrams
from transformer import get_transformer, get_generator

# Load GPT Neo
model, tokenizer = get_transformer()
generate = get_generator(model, tokenizer)

# Change as needed; this works with the Shakespeare example dataset from encode.py
memory_file = "shakespeare.pkl"
speaker_name = "JULIET"
GPT_name = "ROMEO"

print("Welcome to GPT Chat!")
print(f"You are chatting as {speaker_name}, and GPT is chatting as {GPT_name}.")

context = []
memories = []

if os.path.exists(memory_file):
    with open(memory_file, 'rb') as handle:
        memories = pickle.load(handle)

def add_engram(text, add_context=True):
    if add_context:
        context.append(text)
    memoryCount = len(memories)
    if memoryCount > 0:
        memories[-1]["next"] = memoryCount  # link the previously newest memory forward to this one
    engram = {
        "text": text,
        "engram": build_engram(model.forward, tokenizer(text, return_tensors="pt").input_ids.cuda()),
        "next": -1,
        "previous": memoryCount-1,
        "distance": 0
    }
    memories.append(engram)
    return engram

def build_context(now, short_term=10):
    # sort engrams: coarse pass first, then re-rank with increasing depth
    m = sort_engrams(now, memories[:-short_term], top_k=600)
    m = sort_engrams(now, m, top_k=150, do_distance=False, depth=2)
    m = sort_engrams(now, m, top_k=42, do_distance=False, depth=3)
    m.reverse()

    text = ""

    for memory in m:
        # include the preceding message for replies and the following message for our own lines,
        # skipping unlinked (-1) neighbours
        if not memory["text"].startswith(speaker_name) and memory["previous"] >= 0:
            text = text + memories[memory["previous"]]["text"] + "\n"
        text = text + memory["text"] + "\n"
        if memory["text"].startswith(speaker_name) and memory["next"] >= 0:
            text = text + memories[memory["next"]]["text"] + "\n"

    for recent in context[-short_term:]:  # the 10 most recent messages
        text = text + recent + "\n"

    return text

while True:
    # let user input a message
    message = input(speaker_name + ": ")

    engram = add_engram(speaker_name + ": " + message)

    text = build_context(engram) + GPT_name + ":"

    reply = generate(text).split("\n")[0]
    print(GPT_name + ":" + reply)

    add_engram(GPT_name + ":" + reply)

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy
torch
transformers
pandas
tqdm

--------------------------------------------------------------------------------
/transformer.py:
--------------------------------------------------------------------------------
# Get the transformer model and tokenizer
def get_transformer():
    from transformers import GPTNeoForCausalLM, AutoTokenizer
    model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-6B") # I use finetune's fork
    model = model.eval().cuda()  # the rest of the code builds CUDA inputs, so the model must live on the GPU (add .half() if memory is tight)
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    return model, tokenizer

# Easy generation function
def get_generator(model, tokenizer, maxLength=78):
    def generator(text):
        tokens = tokenizer(text, return_tensors="pt").input_ids.cuda()[:, -(2047-maxLength):]
        # note: tfs, repetition_penalty_range and repetition_penalty_slope are only available in finetune's transformers fork
        out = model.generate(
            tokens.long(),
            do_sample=True,
            min_length=tokens.shape[1] + maxLength,
            max_length=tokens.shape[1] + maxLength,
            temperature=0.85,
            tfs=0.9,
            top_k=None,
            top_p=None,
            repetition_penalty=1.18,
            repetition_penalty_range=512,
            repetition_penalty_slope=3.33,
            use_cache=True,
            bad_words_ids=None,
            pad_token_id=tokenizer.eos_token_id,
        ).long().to("cpu")[0]
        # return only the newly generated tokens
        return tokenizer.decode(out[-(out.shape[0]-tokens.shape[1]):])
    return generator

--------------------------------------------------------------------------------