├── lmkd.py ├── llm ├── __init__.py ├── Readme.md ├── build.py ├── tensor.py ├── llm.proto ├── test_llm.py ├── client.py └── server.py ├── .gitignore ├── requirements.txt ├── common.py ├── chat.py ├── complete.py ├── printf.py ├── prompt.py ├── distill.py ├── tag.py ├── research.md ├── model.py ├── vdb.py ├── db.py ├── agent.py ├── tf.py └── Readme.md /lmkd.py: -------------------------------------------------------------------------------- 1 | ../lmkd/lmkd.py -------------------------------------------------------------------------------- /llm/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, 'llm') -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | lightning_logs/ 3 | private/ 4 | *.cache 5 | *.memmap 6 | *.db 7 | *.log 8 | *.pickle -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | faiss-cpu 4 | faiss-gpu 5 | lightning 6 | transformers 7 | datasets 8 | tensorboard 9 | prompt-toolkit 10 | grpc 11 | grpc_tools 12 | openai 13 | pygments -------------------------------------------------------------------------------- /llm/Readme.md: -------------------------------------------------------------------------------- 1 | LLM gRPC interface to enable more rapid development. Running LLMs takes a long time to initialize, and if an error occurs in more experimental code all that initialization is thrown away and has to be done again. By communicating over gRPC, the LLM can exist in its own process and continue running even if the client crashes. -------------------------------------------------------------------------------- /llm/build.py: -------------------------------------------------------------------------------- 1 | # Importing this lets us "import" the proto file 2 | import grpc_tools.protoc 3 | import sys 4 | 5 | try: 6 | import llm_pb2_grpc, llm_pb2 7 | except ImportError: 8 | grpc_tools.protoc.main([ 9 | "grpc_tools.protoc", 10 | "-I.", 11 | "--python_out=.", 12 | "--grpc_python_out=.", 13 | "llm.proto" 14 | ]) 15 | # Reload 16 | import llm_pb2_grpc, llm_pb2 17 | -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from typing import TypeVar, Optional, Callable 4 | 5 | # Defaults 6 | NORM = 1e-5 7 | DROP = 0.1 8 | BIAS = True 9 | 10 | # Helper functions 11 | 12 | T = TypeVar("T") 13 | def default(x: Optional[T], y: T|Callable[[], T]) -> T: 14 | '''Extensible defaults for function arguments.''' 15 | return x if x is not None else y() if callable(y) else y -------------------------------------------------------------------------------- /llm/tensor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | ''' 3 | Common methods for both client and server. 
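A round-trip sketch for illustration: encode records the shape and flattens the
data, decode rebuilds the tensor. Assumes llm_pb2 has already been generated by
build.py; the sample tensor below is purely illustrative.

    import torch
    t = torch.arange(6, dtype=torch.int32).reshape(2, 3)
    proto = encode(t, 'i')   # llm_pb2.IntTensor with shape=[2, 3] and flattened data
    assert torch.equal(decode(proto), t)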
4 | ''' 5 | 6 | from build import llm_pb2 7 | import torch 8 | 9 | def dtype_proto(dtype): 10 | match dtype: 11 | case 'b': return llm_pb2.BoolTensor 12 | case 'i': return llm_pb2.IntTensor 13 | case 'f': return llm_pb2.FloatTensor 14 | case _: raise ValueError(f"Unknown gRPC proto dtype: {dtype}") 15 | 16 | def proto_dtype(proto): 17 | match proto: 18 | case llm_pb2.BoolTensor(): return torch.bool 19 | case llm_pb2.IntTensor(): return torch.int32 20 | case llm_pb2.FloatTensor(): return torch.float32 21 | case _: raise ValueError(f"Unknown gRPC proto: {type(proto)}") 22 | 23 | def encode(tensor, dtype): 24 | if tensor is None: 25 | return None 26 | 27 | if isinstance(tensor, list): 28 | shape = (len(tensor),) 29 | else: 30 | shape = tensor.shape 31 | tensor = tensor.flatten() 32 | return dtype_proto(dtype)(data=tensor, shape=shape) 33 | 34 | def decode(proto): 35 | if proto is None: 36 | return None 37 | return torch.tensor(proto.data, dtype=proto_dtype(proto)).reshape(tuple(proto.shape)) 38 | -------------------------------------------------------------------------------- /chat.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from agent import Agent 3 | from argparse import ArgumentParser 4 | from prompt_toolkit import prompt, PromptSession 5 | 6 | HELP = """ 7 | Commands: 8 | h/help - this list 9 | q/quit - quit 10 | sql ...code - execute SQL code 11 | select/insert/update/delete [...code] - execute SQL code with the given verb 12 | prompt - print the completion prompt as the AI will see it 13 | state [key [value]] - get or set the agent's state dictionary 14 | sum/summary/summarize - force a chatlog summary 15 | level [level] - get or set the agent's log level (0-5 or QUIET, ERROR, WARN, DEBUG, INFO, VERBOSE) 16 | """ 17 | 18 | def cli_type(it): 19 | for token in it: 20 | for c in token: 21 | print(c, end='', flush=True) 22 | 23 | def main(): 24 | agent = Agent() 25 | 26 | if len(sys.argv) < 2: 27 | name = prompt("Enter your name: ") 28 | else: 29 | name = sys.argv[1] 30 | 31 | sess = PromptSession() 32 | while True: 33 | msg = sess.prompt(f"<{name}> ") 34 | if msg.startswith("/"): 35 | cmd, *args = msg[1:].strip().split(' ') 36 | match cmd: 37 | case "h"|"help": print(HELP) 38 | case "q"|"quit": return 39 | 40 | case _: cli_type(agent.command(name, cmd, args)) 41 | 42 | continue 43 | 44 | cli_type(agent.chat(name, msg)) 45 | print() 46 | 47 | if __name__ == '__main__': 48 | main() -------------------------------------------------------------------------------- /complete.py: -------------------------------------------------------------------------------- 1 | import llm.client as client 2 | import os 3 | 4 | import openai 5 | openai.api_key = os.getenv("API_KEY") 6 | 7 | def chatgpt(prompt, *, engine=None, max_tokens=100, temperature=0.9, top_p=0.9, frequency_penalty=0.0, presence_penalty=0.0, stop=None): 8 | stop = stop or [] 9 | stop.append("\n") 10 | engine = engine or LLM_ENGINE 11 | 12 | response = openai.Completion.create( 13 | engine=engine, 14 | prompt=prompt, 15 | max_tokens=max_tokens, 16 | temperature=temperature, 17 | top_p=top_p, 18 | frequency_penalty=frequency_penalty, 19 | presence_penalty=presence_penalty, 20 | stop=stop, 21 | stream=True 22 | ) 23 | 24 | for token in response: 25 | x = token.choices[0].text 26 | yield x 27 | 28 | def grpc(prompt, *, max_tokens=100, temperature=0.9, top_k=0, top_p=0.9, frequency_penalty=0.0, presence_penalty=0.0, stop=None): 29 | stop = stop or [] 30 | stop.append("\n") 31 | 32 | 
response = client.complete( 33 | prompt=prompt, 34 | max_tokens=max_tokens, 35 | temperature=temperature, 36 | top_k=top_k, 37 | top_p=top_p, 38 | frequency_penalty=frequency_penalty, 39 | presence_penalty=presence_penalty, 40 | stop=stop 41 | ) 42 | 43 | yield from response 44 | 45 | LLM_ENGINE = os.getenv("ENGINE") or "text-davinci-003" 46 | if LLM_ENGINE.lower() == "grpc": 47 | complete = grpc 48 | else: 49 | complete = chatgpt -------------------------------------------------------------------------------- /printf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python env 2 | ''' 3 | Pretty print with format functions. 4 | ''' 5 | 6 | import pprint 7 | from prompt_toolkit import print_formatted_text as printf, HTML 8 | from prompt_toolkit.formatted_text import PygmentsTokens 9 | from prompt_toolkit.styles import Style 10 | from prompt_toolkit.lexers import PygmentsLexer 11 | from pygments.lexers import JsonLexer 12 | import pygments 13 | import prompt 14 | 15 | import json as lib_json 16 | 17 | LOG_COLORS = ["", "bold red", "yellow", "forestgreen", "teal", "magenta"] 18 | 19 | log_style = Style.from_dict({ 20 | level.lower(): color for level, color in zip(prompt.LOG_LEVEL, LOG_COLORS) 21 | }) 22 | 23 | json_style = Style.from_dict({ 24 | "pygments.keyword": '#ff6600', 25 | "pygments.operator": '#ff66ff', 26 | "pygments.punctuation": '#cccccc', 27 | "pygments.number": '#00ffff', 28 | "pygments.string": '#00ff00', 29 | "pygments.whitepsace": '#bbbbbb', 30 | }) 31 | 32 | __all__ = ["json", "log"] 33 | 34 | def json(obj): 35 | obj = lib_json.dumps(obj, indent=2).replace("<", "<").replace(">", ">") 36 | obj = PygmentsTokens(list(pygments.lex(obj, JsonLexer()))) 37 | printf(obj, style=json_style) 38 | 39 | def log(level: int, msg: str): 40 | msg = msg.replace("<", "<").replace(">", ">") 41 | level = prompt.LOG_LEVEL[level] 42 | printf(HTML(f"<{level}>[{level[0].upper()}] {msg}"), style=log_style) -------------------------------------------------------------------------------- /llm/llm.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package llm; 4 | 5 | service LLM { 6 | rpc Encode (Decoding) returns (Encoding) {} 7 | rpc Decode (Encoding) returns (Decoding) {} 8 | rpc Forward (ForwardRequest) returns (ForwardResponse) {} 9 | rpc Complete (CompletionRequest) returns (stream CompletionResponse) {} 10 | rpc Embed (EmbedRequest) returns (EmbedResponse) {} 11 | } 12 | 13 | message BoolTensor { 14 | repeated uint32 shape = 1; 15 | repeated bool data = 2; 16 | } 17 | message IntTensor { 18 | repeated uint32 shape = 1; 19 | repeated int32 data = 2; 20 | } 21 | message FloatTensor { 22 | repeated uint32 shape = 1; 23 | repeated float data = 2; 24 | } 25 | 26 | message Decoding { 27 | string text = 1; 28 | } 29 | message Encoding { 30 | repeated uint32 tokens = 1; 31 | } 32 | 33 | message ForwardRequest { 34 | bool return_hidden = 1; 35 | bool return_attention = 2; 36 | oneof input { 37 | string text = 3; 38 | IntTensor tokens = 4; 39 | } 40 | BoolTensor attention_mask = 5; 41 | } 42 | message ForwardResponse { 43 | repeated FloatTensor hidden = 2; 44 | repeated FloatTensor attention = 3; 45 | FloatTensor logits = 1; 46 | } 47 | 48 | message CompletionRequest { 49 | string text = 1; 50 | uint32 max_tokens = 2; 51 | float temperature = 3; 52 | float top_k = 4; 53 | float top_p = 5; 54 | float frequency_penalty = 6; 55 | float presence_penalty = 7; 56 | repeated string stop = 8; 57 
| } 58 | message CompletionResponse { 59 | string text = 1; 60 | float score = 2; 61 | } 62 | 63 | message EmbedRequest { 64 | string text = 1; 65 | } 66 | message EmbedResponse { 67 | FloatTensor embed = 1; 68 | } 69 | -------------------------------------------------------------------------------- /llm/test_llm.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | from client import GRPCModel 4 | 5 | class TestClient(unittest.TestCase): 6 | def test_tokenize(self): 7 | text = "This is a test sentence." 8 | tokens = GRPCModel().tokenize(text) 9 | self.assertIsInstance(tokens, list) 10 | self.assertIsInstance(tokens[0], int) 11 | 12 | def test_decode(self): 13 | text = "This is a test sentence." 14 | tokens = GRPCModel().tokenize(text) 15 | detokenized = GRPCModel().decode(tokens) 16 | self.assertEqual(text, detokenized) 17 | 18 | def assertProcess(self, result): 19 | self.assertIsInstance(result.logits, np.ndarray) 20 | 21 | self.assertIsInstance(result.hidden, list) 22 | self.assertIsInstance(result.hidden[0], np.ndarray) 23 | 24 | self.assertIsInstance(result.attention, list) 25 | self.assertIsInstance(result.attention[0], np.ndarray) 26 | 27 | # Hidden includes the initial embeddings 28 | self.assertEqual(len(result.hidden) - 1, len(result.attention)) 29 | 30 | def test_forward_with_text(self): 31 | text = "This is a test sentence." 32 | result = GRPCModel().forward(text, True, True) 33 | self.assertProcess(result) 34 | 35 | def test_forward_with_tokens(self): 36 | tokens = GRPCModel().tokenize("This is a test sentence.") 37 | result = GRPCModel().forward(tokens, True, True) 38 | self.assertProcess(result) 39 | 40 | def test_complete(self): 41 | text = "This is a test sentence." 42 | completion = GRPCModel().complete(text, stream=False) 43 | self.assertIsInstance(completion, str) 44 | self.assertFalse(completion.startswith(text)) 45 | 46 | def test_embed(self): 47 | text = "This is a test sentence." 48 | v = GRPCModel().embed(text) 49 | self.assertIsInstance(v, np.ndarray) 50 | 51 | if __name__ == '__main__': 52 | unittest.main() 53 | -------------------------------------------------------------------------------- /prompt.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | ''' 3 | LLM prompts. 4 | ''' 5 | 6 | import time 7 | 8 | LOG_LEVEL = ["quiet", "error", "warn", "info", "debug", "verbose"] 9 | 10 | SUMMARY_TAG = "[SUMMARY]" 11 | IMPORTANCE_TAG = "IMPORTANCE" 12 | EMOTION_TAG = "EMOTION" 13 | 14 | CURRENT_EMOTION = "\nMy current emotion is {emotion}." 15 | EMOTION_FEEDBACK = ' I will respond as needed, then state my next emotion as `EMOTION="..."`.' 16 | 17 | # I'm here to provide assistance and inspiration, and to help them discover their true potential. 18 | MASTER = """\ 19 | I am an autonomous agent named {name}. 
My heuristic imperative is to reduce Suffering, increase Prosperity, and increase Understanding even if people are not aware of their needs or are asking the wrong questions.{emotion_feedback} 20 | 21 | The date is {time}.{emotion} 22 | 23 | CONVERSATION: 24 | """ 25 | def master(name, emotion): 26 | t = time.strftime("%A %Y-%m-%d") 27 | if emotion is None: 28 | return MASTER.format(name=name, time=t, emotion="", emotion_feedback="") 29 | return MASTER.format( 30 | name=name, time=t, 31 | emotion=CURRENT_EMOTION.format(emotion), emotion_feedback=EMOTION_FEEDBACK 32 | ) 33 | 34 | SUMMARIZE = """\ 35 | I need to summarize the conversation as myself so I can remember it later, then score its importance on a scale of 0 to 100 where 0 is something I will never need to know and 100 is something I should remember forever using "IMPORTANCE=##". 36 | 37 | {input} 38 | 39 | SUMMARY: 40 | """ 41 | def summarize(name, input): 42 | return SUMMARIZE.format(name=name, input=input) 43 | 44 | def reload(): 45 | return "-- reboot --" 46 | 47 | def timestamp(t): 48 | return time.strftime("%H:%M:%S", time.localtime(t)) 49 | 50 | def name(name): 51 | return f"<{name}>" 52 | 53 | def chat(n): 54 | return f"{timestamp(time.time())} {name(n)}" 55 | 56 | def explicit_memory(em): 57 | return f"{chat(em.origin.name)} {em.message}" 58 | -------------------------------------------------------------------------------- /distill.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | ''' 3 | Distill a parent model to a clone with its FF layers replaced with discrete 4 | memory layers. 5 | ''' 6 | 7 | import os 8 | import lmkd 9 | import model 10 | from functools import cached_property 11 | from llm.client import GRPCModel 12 | from transformers.models.auto import AutoConfig, AutoModelForCausalLM, AutoTokenizer 13 | 14 | TEACHER_NAME = "hf-internal-testing/tiny-random-gptj" #"databricks/dolly-v1-6b" 15 | DATASET_NAME = "ag_news" 16 | 17 | DEFAULT_K = 5 18 | RECOMBINE = 2/3 19 | NOVEL = 1/3 20 | TEMPERATURE = 1.0 21 | PORT = os.path.abspath("llm.sock") 22 | 23 | class LMKD(lmkd.Distill): 24 | def __init__(self, 25 | debug=None, 26 | device=None, 27 | *, 28 | teacher_name=TEACHER_NAME, 29 | dataset_name=DATASET_NAME, 30 | k=DEFAULT_K, 31 | recombine=RECOMBINE, 32 | novel=NOVEL, 33 | temperature=TEMPERATURE, 34 | port=PORT 35 | ): 36 | super().__init__(debug, device) 37 | 38 | self.teacher_name = teacher_name 39 | self.dataset_name = dataset_name 40 | self.k = k 41 | self.recombine = recombine 42 | self.novel = novel 43 | self.temperature = temperature 44 | self.port = port 45 | 46 | @cached_property 47 | def config(self): 48 | return AutoConfig.from_pretrained(self.teacher_name) 49 | 50 | @cached_property 51 | def batch_size(self): 52 | return self.config.n_positions 53 | 54 | def teacher(self): 55 | return GRPCModel(self.port) 56 | 57 | def student(self, state_dict=None): 58 | orin = model.build_gptj(self.config) 59 | print(orin) 60 | if state_dict is None: 61 | teacher = AutoModelForCausalLM.from_pretrained(self.teacher_name) 62 | state_dict = model.clone_gptj(teacher) 63 | 64 | # Sanity check, there should be no unused keys 65 | assert set(state_dict.keys()).issubset(orin.state_dict().keys()) 66 | 67 | # strict=False to allow for missing keys 68 | orin.load_state_dict(state_dict, strict=False) 69 | return orin 70 | 71 | def tokenizer(self): 72 | tokenizer = AutoTokenizer.from_pretrained(self.teacher_name) 73 | return lambda text: tokenizer(text, 
return_tensors="pt") 74 | 75 | def dataset(self, split): 76 | return self.dataset_name 77 | 78 | if __name__ == "__main__": 79 | lmkd.main() -------------------------------------------------------------------------------- /tag.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import spacy 3 | from dataclasses import dataclass, field 4 | 5 | @dataclass 6 | class SpacyTags: 7 | '''Lazily unpacked spacy tags.''' 8 | 9 | token: str 10 | spacy: spacy.tokens.Token 11 | 12 | def __iter__(self): 13 | token = self.spacy 14 | yield from ( 15 | f"@{self.token}", 16 | f".{token.text}", 17 | f"^{token.head.text}", 18 | f"/{token.lemma_}", 19 | token.pos_.upper(), 20 | token.tag_.upper(), 21 | token.dep_.upper() 22 | ) 23 | 24 | @dataclass 25 | class Tagset: 26 | '''A set of tags for a token, including a scoped context.''' 27 | 28 | tags: list[str] 29 | context: 'list[str]|Tagset' 30 | 31 | def __len__(self): 32 | return len(self.context) + len(self.tags) 33 | 34 | def __iter__(self): 35 | yield from self.context 36 | yield from self.tags 37 | 38 | @dataclass 39 | class MultiTagset: 40 | '''List of tagsets with a common scoped context.''' 41 | 42 | tags: list[Tagset] 43 | context: list[str]|Tagset = field(default_factory=list) 44 | 45 | def __len__(self): return len(self.tags) 46 | def __iter__(self): 47 | for tags in self.tags: 48 | yield Tagset(tags, self.context) 49 | 50 | def mask(self, mask): 51 | for t, m in zip(self.tags, mask): 52 | if m: 53 | yield Tagset(t, self.context) 54 | 55 | def enter(self, tags): 56 | '''Enter a new context.''' 57 | return MultiTagset(self.tags, Tagset(tags, self.context)) 58 | 59 | class Tagger: 60 | '''Use NLP to tag transformer tokens.''' 61 | 62 | def __init__(self, tokenizer, nlp): 63 | self.tokenizer = tokenizer 64 | self.nlp = nlp 65 | 66 | def tag(self, text: str, context: list[str]) -> tuple[np.ndarray, MultiTagset]: 67 | tokens = self.tokenizer(text, return_offsets_mapping=True) 68 | ids = np.array(tokens.input_ids) 69 | doc = self.nlp(text) 70 | token_tags, id = [], 0 71 | for start, end in tokens.offset_mapping: 72 | tags = None 73 | while id < len(doc): 74 | token = doc[id] 75 | if token.idx >= end: 76 | break 77 | if token.idx >= start: 78 | st = SpacyTags(text[start:end], token) 79 | tags = st if tags is None else Tagset(st, tags) 80 | id += 1 81 | 82 | token_tags.append(tags) 83 | 84 | return ids, MultiTagset(token_tags, context) 85 | -------------------------------------------------------------------------------- /research.md: -------------------------------------------------------------------------------- 1 | 2 | T5's Simplified Relative Positional Encoding 3 | https://arxiv.org/abs/1910.10683 4 | * Simplified relative positional endcoding based on learned bias 5 | 6 | One Write-Head Is All You Need 7 | https://arxiv.org/abs/1911.02150 8 | * When using MHA, you can reuse a single key/value with each head having its own query which should cut down on parameters 9 | 10 | Do Transformers Need Deep Long-Range Memory? 
11 | https://arxiv.org/pdf/2007.03356.pdf 12 | * Shallow layers perform worse with longer memories 13 | * Up to 1/6 reduction in memory lengths 14 | * Pairs well with adaptive attention span, which lets this be learned 15 | 16 | Compressive transformers for long-range sequence modelling 17 | https://arxiv.org/pdf/1911.05507.pdf 18 | * Recurrent memory by applying lossy compression to old memories which are then concatenated 19 | 20 | * [Transformer Feed-Forward Layers Are Key-Value Memories](https://arxiv.org/abs/2012.14913) 21 | * [Augmenting Self-attention with Persistent Memory](https://arxiv.org/pdf/1907.01470.pdf) 22 | - Proves FF networks are equivalent to attention with static memory 23 | * [Attention Approximates Sparse Distributed Memory](https://arxiv.org/abs/2111.05498) 24 | - Theoretical basis for why FF might be both attention and memory 25 | * [Memorizing Transformers](https://arxiv.org/abs/2203.08913) 26 | - kNN memory, paper uses it as an alternative to recurrency 27 | * [Neural Turing Machines](https://arxiv.org/abs/1410.5401) 28 | 29 | ### Layerwise feedback 30 | * [Addressing Some Limitations of Transformers with Feedback Memory](https://arxiv.org/abs/2002.09402) 31 | - Using output of upper layers for lower (modified: per layer pair, no pooling) 32 | * [Memory transformers](https://arxiv.org/abs/2006.11527) 33 | - Concatenating memory to input tokens (modified: no memory controller) 34 | 35 | ## Research 36 | 37 | ### Misc improvements 38 | * [Cramming: Training a Language Model on a Single GPU in One Day](https://arxiv.org/abs/2212.14034) 39 | - Removing bias has negligible effect on loss and reduces parameters 40 | * [Transformers without Tears](https://arxiv.org/abs/1910.05895) 41 | - Scaled L2 normalization leads to faster convergence than LayerNorm 42 | * [Towards Better Few-Shot and Finetuning Performance with Forgetful Causal Language Models](https://arxiv.org/abs/2210.13432) 43 | - Masking prior tokens at random ala BERT-type models leads to better generalization 44 | * [Query-Key Normalization for Transformers](https://arxiv.org/abs/2010.04245) 45 | - L2 normalization along head dimension of query and key matrix with learnable scaling 46 | - Prevents attention operation from overflowing and removes need for numerical stability prior to softmax - both are problems for Transformers 47 | * [Primer: Searching for Efficient Transformers for Language Modeling](https://arxiv.org/abs/2109.08668) 48 | - Squared ReLU performs better than GELU 49 | - GLU + GELU performs even better according to x-transformers but that adds parameters 50 | - Also cheaper to implement 51 | 52 | ### Techniques 53 | * [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) 54 | - Adding rotary embedding to every layer can improve learning 55 | * [A Length-Extrapolatable Transformer](https://arxiv.org/abs/2212.10554v1) 56 | - Rotary positional embeddings don't generalize input length well 57 | - Modifies it so it does 58 | * [Rethinking Positional Encoding in Language Pre-training](https://arxiv.org/abs/2006.15595) 59 | - TUPE positional embeddings learned separately rather than additive 60 | 61 | * [Burst-dependent synaptic plasticity can coordinate learning in hierarchical circuits](https://pubmed.ncbi.nlm.nih.gov/33986551/) 62 | - Biological plausibility of layerwise feedback -------------------------------------------------------------------------------- /llm/client.py: 
-------------------------------------------------------------------------------- 1 | import grpc 2 | from build import llm_pb2, llm_pb2_grpc 3 | import os 4 | import numpy as np 5 | from dataclasses import dataclass 6 | from typing import Optional, Iterable 7 | import torch 8 | import torch.nn as nn 9 | import tensor 10 | 11 | PORT = f"unix://{os.getcwd()}/llm.sock" 12 | 13 | @dataclass 14 | class ForwardResponse: 15 | logits: np.ndarray 16 | hidden: list[np.ndarray] 17 | attention: list[np.ndarray] 18 | 19 | class GRPCModel(nn.Module): 20 | def __init__(self, port=None): 21 | super().__init__() 22 | self.port = port or PORT 23 | 24 | def encode(self, text): 25 | '''Tokenize text into a list of tokens ids.''' 26 | 27 | with grpc.insecure_channel(PORT) as channel: 28 | stub = llm_pb2_grpc.LLMStub(channel) 29 | request = llm_pb2.Decoding(text=text) 30 | response = stub.Encode(request) 31 | return list(response.tokens) 32 | tokenize = encode 33 | 34 | def decode(self, tokens): 35 | '''Decode a list of token ids into text.''' 36 | 37 | with grpc.insecure_channel(PORT) as channel: 38 | stub = llm_pb2_grpc.LLMStub(channel) 39 | request = llm_pb2.Encoding(tokens=tokens) 40 | response = stub.Decode(request) 41 | return response.text 42 | 43 | def forward(self, 44 | input_ids: str|np.ndarray, 45 | attention_mask: Optional[np.ndarray]=None, 46 | return_hidden=False, 47 | return_attention=False 48 | ): 49 | ''' 50 | Convert text or a batch of token ids into logits, hidden states, and attention. 51 | 52 | Parameters: 53 | input_ids: Input text or pre-tokenized input ids. 54 | attention_mask: Optional attention mask. 55 | return_hidden: Return hidden states. 56 | return_attention: Return attention. 57 | ''' 58 | 59 | if isinstance(input_ids, str): 60 | input = {"text": input_ids} 61 | elif isinstance(input_ids, (list, np.ndarray, torch.Tensor)): 62 | input = {"tokens": tensor.encode(input_ids, 'i')} 63 | else: 64 | raise TypeError(f"text must be str|list[int]|np.ndarray, got {type(input_ids)}") 65 | 66 | with grpc.insecure_channel(PORT) as channel: 67 | stub = llm_pb2_grpc.LLMStub(channel) 68 | request = llm_pb2.ForwardRequest(**input, 69 | attention_mask=tensor.encode(attention_mask, 'b'), 70 | return_hidden=return_hidden, 71 | return_attention=return_attention 72 | ) 73 | response = stub.Forward(request) 74 | 75 | logits = tensor.encode(response.logits, 'f') 76 | hidden = response.hidden and [tensor.decode(h) for h in response.hidden] 77 | attention = response.attention and [tensor.decode(a) for a in response.attention] 78 | 79 | return ForwardResponse(logits, hidden, attention) 80 | 81 | def complete(self, 82 | text, 83 | max_tokens=10, 84 | temperature=1.0, 85 | *, 86 | top_k=0, top_p=0.9, 87 | stop: Optional[str|Iterable[str]]=None, 88 | presence_penalty=0, 89 | frequency_penalty=0, 90 | stream=True 91 | ): 92 | ''' 93 | Given a prompt, generate a completion. 
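        A minimal usage sketch (assumes a server from llm/server.py is listening
        on the default unix socket; the prompt text is illustrative):

            >>> model = GRPCModel()
            >>> model.complete("Once upon a time,", max_tokens=20, stream=False)
            '... generated continuation ...'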
94 | 95 | Parameters: 96 | text: The prompt to complete 97 | max_tokens: The maximum number of tokens to generate 98 | temperature: The temperature to use for sampling 99 | top_k: The number of tokens to consider for top-k sampling 100 | top_p: The cumulative probability to consider for top-p sampling 101 | stop: A token or list of tokens to stop at 102 | presence_penalty: The presence penalty to use 103 | frequency_penalty: The frequency penalty to use 104 | stream: Whether to stream the response or return it all at once 105 | ''' 106 | 107 | if isinstance(stop, str): 108 | stop = [stop] 109 | 110 | with grpc.insecure_channel(PORT) as channel: 111 | stub = llm_pb2_grpc.LLMStub(channel) 112 | 113 | request = llm_pb2.CompletionRequest( 114 | text=text, 115 | max_tokens=max_tokens, 116 | temperature=temperature, 117 | top_k=top_k, 118 | top_p=top_p, 119 | presence_penalty=presence_penalty, 120 | frequency_penalty=frequency_penalty, 121 | stop=stop, 122 | ) 123 | completion = (x.text for x in stub.Complete(request)) 124 | return completion if stream else ''.join(completion) 125 | 126 | def embed(self, text): 127 | with grpc.insecure_channel(PORT) as channel: 128 | stub = llm_pb2_grpc.LLMStub(channel) 129 | request = llm_pb2.EmbedRequest(text=text) 130 | response = stub.Embed(request) 131 | return tensor.decode(response.embed) 132 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | ''' 3 | Classes and functions for building the Orin model. 4 | ''' 5 | 6 | import torch 7 | import torch.nn as nn 8 | from transformers.models.gptj.configuration_gptj import GPTJConfig 9 | import tf 10 | from dataclasses import dataclass 11 | from collections import OrderedDict 12 | import re 13 | import vdb 14 | import os 15 | 16 | class Memory: 17 | ''' 18 | Associative memory for storing discrete memories. 19 | ''' 20 | 21 | def __init__(self, dim, path, k, recombine, novel): 22 | index_path = os.path.join(path, ".index.db") 23 | store_path = os.path.join(path, ".store.db") 24 | self.memory = vdb.AssociativeMemory( 25 | vdb.FaissIndex(dim, index_path, factory="Flat"), 26 | vdb.SqliteStore(store_path), 27 | k, recombine, novel 28 | ) 29 | 30 | def search(self, keys, values, ctx): 31 | return self.memory.search(keys, values, ctx.tags) 32 | 33 | class DMTransformerBlock(nn.Module): 34 | ''' 35 | Basic unit of discrete memory transformer. 36 | ''' 37 | 38 | def __init__(self, config): 39 | super().__init__() 40 | self.prenorm = nn.LayerNorm(config.embed, eps=config.prenorm) 41 | self.attn = tf.MultiheadAttention( 42 | config.embed, config.heads, rotary_embed=True, 43 | selector=tf.ScaledDotProductSelector( 44 | config.max_seq, dropout=config.pdrop_attn 45 | ), 46 | bias=False 47 | ) 48 | self.dmem = tf.MultiheadAttention( 49 | config.embed, config.heads, rotary_embed=False, 50 | selector=tf.AssociativeMemorySelector( 51 | config.embed, max_seq=config.max_seq, dropout=config.pdrop_attn 52 | ), 53 | bias=False 54 | ) 55 | 56 | def forward(self, x, ctx): 57 | x = self.prenorm(x) 58 | x = self.attn(x, ctx) 59 | x = self.dmem(x, ctx) 60 | return x 61 | 62 | class TransformersLMWrapper(tf.LanguageModel): 63 | ''' 64 | Wrapper converting transformers-style inputs to InfoStill-style inputs. 
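    Calls made transformers-style, e.g. wrapper(input_ids=batch, attention_mask=mask),
    are forwarded to the wrapped tf.LanguageModel with input_ids passed positionally.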
65 | ''' 66 | 67 | def forward(self, *, input_ids, **kwargs): 68 | return self.super().forward(input_ids, **kwargs) 69 | 70 | @dataclass 71 | class OrinConfig: 72 | vocab: int 73 | embed: int 74 | layers: int 75 | heads: int 76 | max_seq: int 77 | dropout_p: float 78 | pdrop_attn: float 79 | prenorm: float 80 | postnorm: float 81 | 82 | def build_layers(config): 83 | for layer in range(config.layers): 84 | yield DMTransformerBlock(config) 85 | 86 | def build(config): 87 | return tf.Residual(list(build_layers(config))) 88 | 89 | def build_gptj(config: GPTJConfig): 90 | ''' 91 | Builds a model from a given config. 92 | ''' 93 | 94 | config = OrinConfig( 95 | vocab=config.vocab_size, 96 | embed=config.n_embd, 97 | layers=config.n_layer, 98 | heads=config.n_head, 99 | max_seq=config.n_positions, 100 | dropout_p=config.resid_pdrop, 101 | pdrop_attn=config.attn_pdrop, 102 | prenorm=config.layer_norm_epsilon, 103 | postnorm=config.layer_norm_epsilon 104 | ) 105 | 106 | return TransformersLMWrapper( 107 | vocab=config.vocab, 108 | embed=config.embed, 109 | model=build(config), 110 | dropout=config.dropout_p, 111 | postnorm=config.postnorm 112 | ) 113 | 114 | def gptj_to_orin(key): 115 | # Dropouts not in state_dict 116 | # 117 | # drop -> dropout 118 | # h -> lm 119 | # wte -> embed 120 | # h.*.attn -> attn 121 | # bias -> selector.bias 122 | # masked_bias -> None (unused even in GPT-J) 123 | # ln_1 -> prenorm 124 | # attn_dropout -> selector.attn_dropout 125 | # resid_dropout -> resid_dropout 126 | # {q, k, v}_proj -> qkv_proj 127 | # out_proj -> out_proj 128 | # None -> dmem 129 | # h.*.mlp.* -> None 130 | # ln_f -> postnorm 131 | 132 | if "mlp" in key or "masked_bias" in key: 133 | return None 134 | 135 | key = re.sub(r"^h", "lm", key) 136 | key = key.replace("wte", "embed") 137 | key = key.replace("attn.bias", "attn.selector.bias") 138 | key = key.replace("ln_1", "prenorm") 139 | key = key.replace("ln_f", "postnorm") 140 | 141 | return key 142 | 143 | def clone_gptj(parent): 144 | parent = parent.state_dict() 145 | state = OrderedDict() 146 | skipped = set() 147 | proj = set() 148 | 149 | print("Parent state_dict:", parent.keys()) 150 | 151 | # Weight renaming 152 | for tk, tw in parent.items(): 153 | sk = gptj_to_orin(tk) 154 | if sk is None: 155 | skipped.add(tk) 156 | continue 157 | 158 | # Don't add {q, k, v}_proj, to be combined later 159 | if m := re.match("^(.+?\.)[^.]+(? str: 36 | return f"({', '.join('?' 
* count)})" 37 | 38 | def wsum(e, d, default=None): 39 | '''Weighted sum of embeddings e with distances d.''' 40 | ws = np.average(e, weights=np.exp(-d), axis=1) 41 | if default is None: 42 | return ws 43 | return np.where(np.isfinite(ws), ws, default) 44 | 45 | class AssociativeMemory: 46 | '''Combined index and store for discrete associative memory.''' 47 | 48 | def __init__(self, 49 | index, 50 | store, 51 | k: int=1, 52 | recombine: Optional[float]=None, 53 | novel: Optional[float]=None 54 | ): 55 | ''' 56 | Parameters: 57 | index: The approximate nearest neighbor index 58 | store: The vector store 59 | k: The number of results to return 60 | recombine: The distance threshold for recombining similar vectors 61 | novel: The distance threshold for memorizing novel vectors 62 | ''' 63 | 64 | self.index = index 65 | self.store = store 66 | self.k = k 67 | self.recombine = recombine 68 | self.novel = novel 69 | 70 | self.Row = np.dtype([ 71 | ("id", np.uint64), 72 | ("ctime", np.uint64), 73 | ("atime", np.uint64), 74 | ("access", np.uint64), 75 | ('deleted', np.uint64), 76 | ("key", np.float32, (self.dim,)), 77 | ("value", np.float32, (self.dim,)) 78 | ]) 79 | 80 | def recombine_vectors(self, data, keys, values, d, tags): 81 | if self.recombine is None: 82 | return 83 | 84 | # Select only rows with at least 1 to recombine 85 | mask = d < self.recombine 86 | rows = np.any(mask, axis=1) 87 | data, mask = data[rows], mask[rows] 88 | d = np.where(mask, d[rows], np.inf) # Mask non-recombined 89 | recids = data['id'][mask] # ids which were recombined 90 | 91 | wk = wsum(data['key'], d, keys[rows]) 92 | wv = wsum(data['value'], d, values[rows]) 93 | 94 | # Delete the old vectors 95 | self.index.delete(recids) 96 | self.store.delete(recids) 97 | 98 | # Add the new recombined vectors 99 | self.index.add(wk) 100 | ids = self.store.insert( 101 | ctime=data['ctime'].min(dim=1, where=mask), # Oldest creation 102 | atime=data['atime'].max(dim=1, where=mask), # Newest access 103 | access=data['access'].sum(dim=1, where=mask), # All accesses 104 | key=wk, value=wv 105 | ) 106 | 107 | # Split into a list of arrays of old ids 108 | recids = np.split(recids, np.cumsum(np.sum(mask, axis=1))[:-1]) 109 | 110 | # Combine the tags 111 | self.store.merge_tags(recids, ids) 112 | self.store.add_tags(ids, tags.mask(rows)) 113 | 114 | def insert_novel(self, keys, values, d, tags): 115 | if self.novel is None: 116 | return 117 | 118 | mask = (d > self.novel) & np.isfinite(d) 119 | rows = np.all(mask, axis=1) # Nothing similar 120 | ids = self.store.create(keys[rows], values[rows]) 121 | self.store.add_tags(ids, tags.mask(rows)) 122 | 123 | def search(self, keys, values, tags): 124 | d, i = self.index.search(keys, self.k) # (batch, k) 125 | data = self.store.get(i.reshape(-1)) 126 | data = np.array(data, dtype=self.Row).reshape(i.shape) 127 | d = np.where(data['deleted'] == 1, np.inf, d) 128 | 129 | self.recombine_vectors(data, keys, values, d, tags) 130 | self.insert_novel(keys, values, d, tags) 131 | 132 | # Combine top-k 133 | keys = wsum(data['key'], d, keys) 134 | values = wsum(data['value'], d, values) 135 | return keys, values 136 | 137 | class FaissIndex: 138 | '''Faiss index for discrete memory.''' 139 | 140 | def __init__(self, dim, path, factory="Flat"): 141 | self.dim = dim 142 | self.path = path 143 | if os.path.exists(path): 144 | self.load() 145 | else: 146 | self.index = faiss.index_factory(dim, factory) 147 | 148 | def add(self, keys): 149 | self.buffer.add(keys) 150 | 151 | def delete(self, keys): 152 | 
'''Does nothing (other indexes might need to delete).''' 153 | pass 154 | 155 | def search(self, keys, k=1): 156 | return self.index.search(keys, k) 157 | 158 | def commit(self): 159 | faiss.write_index(self.index, self.path) 160 | 161 | def load(self): 162 | self.index = faiss.read_index(self.path) 163 | 164 | def tobytes(x): 165 | for row in x: 166 | yield row.tobytes() 167 | 168 | def implicit_row_factory(row): 169 | id, ctime, atime, access, key, value = row 170 | return id, ctime, atime, access, np.frombuffer(key), np.frombuffer(value) 171 | 172 | class SqliteStore: 173 | '''Sqlite store for discrete memory.''' 174 | 175 | def __init__(self, path): 176 | self.conn = sqlite3.connect(path) 177 | self.conn.row_factory = implicit_row_factory 178 | 179 | def get(self, ids): 180 | # Update access information 181 | self.conn.executemany(""" 182 | UPDATE implicit SET 183 | atime = strftime('%s', 'now'), 184 | access = access + 1 185 | WHERE id = ? 186 | """, ids) 187 | self.conn.commit() 188 | # Query data 189 | return self.conn.executemany( 190 | "SELECT * FROM implicit WHERE id = ?", ids 191 | ).fetchall() 192 | 193 | def delete(self, ids): 194 | self.conn.executemany( 195 | "UPDATE implicit SET deleted = 1 WHERE id = ?", ids 196 | ) 197 | self.conn.commit() 198 | 199 | def insert(self, ctime, atime, access, key, value): 200 | cur = self.conn.executemany(""" 201 | INSERT INTO implicit (ctime, atime, access, key, value) 202 | VALUES (?, ?, ?, ?, ?) 203 | """, zip(ctime, atime, access, tobytes(key), tobytes(value))) 204 | self.conn.commit() 205 | return np.arange(cur.lastrowid - len(ctime) + 1, cur.lastrowid + 1) 206 | 207 | def create(self, keys, values): 208 | cur = self.conn.executemany( 209 | "INSERT INTO implicit (key, value) VALUES (?, ?)", 210 | zip(tobytes(keys), tobytes(values)) 211 | ) 212 | self.conn.commit() 213 | return np.arange(cur.lastrowid - len(keys) + 1, cur.lastrowid + 1) 214 | 215 | def merge_tags(self, old, new): 216 | for ids, nid in zip(old, new): 217 | ids = list(ids) 218 | 219 | # Need to query using IN to get the COUNT 220 | self.conn.execute(f""" 221 | INSERT INTO tagmap (obj, tag, count) 222 | SELECT {nid}, tag, SUM(count) 223 | FROM tags WHERE obj IN {LIST(len(ids))} GROUP BY tag 224 | """, ids) 225 | self.conn.executemany( 226 | "DELETE FROM tagmap WHERE obj = ?", ids 227 | ) 228 | self.conn.commit() 229 | 230 | def add_tags(self, ids, tags): 231 | for obj, ntags in zip(ids, tags): 232 | self.conn.execute(f""" 233 | INSERT OR REPLACE INTO tagmap (obj, tag, count) 234 | SELECT ?1, id, COALESCE(tagmap.count, 0) + 1 FROM tags 235 | LEFT JOIN tagmap ON tagmap.obj = ?1 AND tagmap.tag = tags.id 236 | WHERE name = ?2 237 | """, ((obj, tag) for tag in ntags)) 238 | self.conn.commit() 239 | 240 | def commit(self): 241 | self.conn.commit() -------------------------------------------------------------------------------- /db.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator 2 | import time 3 | import sqlite3 4 | import json 5 | 6 | from dataclasses import dataclass 7 | from functools import cache 8 | from typing import Optional, Mapping 9 | 10 | from collections.abc import MutableMapping 11 | 12 | # Reduce repetition 13 | TABLE = "CREATE TABLE IF NOT EXISTS" 14 | INT = "INTEGER NOT NULL" 15 | TEXT = "TEXT NOT NULL" 16 | IDENT = f"id {INT} PRIMARY KEY AUTOINCREMENT" 17 | FNOW = f"{INT} DEFAULT (strftime('%s', 'now'))" 18 | TIMER = f"ctime {FNOW}, atime {FNOW}, access {INT} DEFAULT 0" 19 | MEMORY = f"{IDENT}, 
{TIMER}" 20 | # Database schema 21 | SCHEMA = f""" 22 | {TABLE} log ({IDENT}, 23 | time {INT}, 24 | level {INT}, 25 | message {TEXT} 26 | ); 27 | {TABLE} state ( -- Basically a JSON object 28 | key {TEXT} PRIMARY KEY, 29 | value {TEXT} 30 | ); 31 | {TABLE} origins ({IDENT}, 32 | name {TEXT} UNIQUE 33 | ); 34 | {TABLE} explicit ({MEMORY}, 35 | origin {INT} REFERENCES origins(id), 36 | message {TEXT}, 37 | --embedding BLOB NOT NULL, 38 | importance INTEGER 39 | ); 40 | """ 41 | 42 | @cache 43 | def sanitize(name: str) -> str: 44 | '''Sanitize an identifier, raise an error if it's invalid''' 45 | 46 | if not isinstance(name, str): 47 | raise TypeError(f"Expected str, got {type(name).__name__}") 48 | if not name.isidentifier(): 49 | raise ValueError(f"Invalid identifier: {name}") 50 | return name.lower() 51 | 52 | @cache 53 | def LIST(count: int) -> str: 54 | return f"({', '.join('?' * count)})" 55 | 56 | @cache 57 | def INSERT(table: str, fields: tuple[str, ...]) -> str: 58 | values = LIST(len(fields)) 59 | fields = ', '.join(map(sanitize, fields)) 60 | return f"INSERT INTO {sanitize(table)} ({fields}) VALUES {values}" 61 | 62 | @cache 63 | def SELECT(col: str|tuple[str, ...], table: str, fields: tuple[str, ...]) -> str: 64 | if isinstance(col, str): 65 | col = (sanitize(col),) 66 | col = ', '.join(map(sanitize, col)) 67 | pred = ' AND '.join(f"{sanitize(k)} = ?" for k in fields) 68 | return f"SELECT {col} FROM {sanitize(table)} WHERE {pred}" 69 | 70 | @cache 71 | def DELETE(table: str, where: Optional[str]=None) -> str: 72 | s = f"DELETE FROM {sanitize(table)}" 73 | return s if where is None else f"{s} WHERE {where}" 74 | 75 | @cache 76 | def UPDATE(table: str, fields: tuple[str, ...], where: Optional[str]=None) -> str: 77 | fields = ', '.join(f"{sanitize(k)} = ?" 
for k in fields) 78 | s = f"UPDATE {sanitize(table)} SET {fields}" 79 | return s if where is None else f"{s} WHERE {where}" 80 | 81 | @cache 82 | def IN(name: str, count: int) -> str: 83 | return f"{name} IN {LIST(count)}" 84 | 85 | @dataclass 86 | class Identified: 87 | id: int 88 | 89 | @dataclass 90 | class MemoryEntry(Identified): 91 | '''Purposefully nasty name to avoid instantiation''' 92 | ctime: float 93 | atime: float 94 | access: int 95 | 96 | @dataclass 97 | class Log(Identified): 98 | time: float 99 | level: int 100 | message: str 101 | 102 | @dataclass 103 | class Origin(Identified): 104 | name: str 105 | 106 | @dataclass 107 | class ExplicitMemory(MemoryEntry): 108 | origin: Origin 109 | message: str 110 | importance: Optional[int] 111 | 112 | @dataclass 113 | class FeaturalMemory(MemoryEntry): 114 | embedding: bytes 115 | 116 | @dataclass 117 | class AssociativeMemory(MemoryEntry): 118 | key: bytes 119 | value: bytes 120 | 121 | class StateProxy(MutableMapping): 122 | '''Proxy for the state table.''' 123 | 124 | def __init__(self, conn): 125 | # Store connection not db to avoid ref loops 126 | self._conn = conn 127 | self._cache = {} 128 | 129 | def get(self, k: str, default=None): 130 | '''Get a value from the state dict or an optional default.''' 131 | 132 | if k in self._cache: 133 | return self._cache[k] 134 | cur = self._conn.execute("SELECT value FROM state WHERE key = ?", (k,)) 135 | if row := cur.fetchone(): 136 | self._cache[k] = val = json.loads(row[0]) 137 | return val 138 | return default 139 | 140 | def update(self, mapping: Mapping[str, object]): 141 | '''Mapping update.''' 142 | 143 | self._cache.update(mapping) 144 | self._conn.executemany( 145 | "INSERT OR REPLACE INTO state (key, value) VALUES (?, ?)", 146 | mapping.items() 147 | ) 148 | self._conn.commit() 149 | 150 | def reload(self): 151 | '''Reload the cache.''' 152 | 153 | self._cache.clear() 154 | cur = self._conn.execute("SELECT key, value FROM state") 155 | for k, v in cur: 156 | self._cache[k] = json.loads(v) 157 | 158 | def load_defaults(self, **kwargs): 159 | '''Insert a default value if the key is not present.''' 160 | 161 | self._conn.executemany( 162 | "INSERT OR IGNORE INTO state (key, value) VALUES (?, ?)", 163 | ((k, json.dumps(v)) for k, v in kwargs.items()) 164 | ) 165 | self._conn.commit() 166 | self.reload() 167 | 168 | def __contains__(self, k: str): 169 | return self.get(k, ...) is ... 
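    # Usage sketch (illustrative values): the proxy behaves like a persistent dict,
    # e.g. `db.state["emotion"] = "curious"` writes through to the state table, while
    # `db.state.emotion` reads it back via the attribute access defined below.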
170 | 171 | def __getitem__(self, k: str): 172 | return self.get(k) 173 | 174 | def __setitem__(self, x: str, y): 175 | self._cache[x] = y 176 | self._conn.execute( 177 | "INSERT OR REPLACE INTO state (key, value) VALUES (?, ?)", 178 | (x, json.dumps(y)) 179 | ) 180 | self._conn.commit() 181 | 182 | def __getattr__(self, name: str): 183 | if name.startswith("_"): 184 | return super().__getattr__(name) 185 | return self[name] 186 | 187 | def __setattr__(self, name: str, value): 188 | if name.startswith("_"): 189 | return super().__setattr__(name, value) 190 | self[name] = value 191 | 192 | def __repr__(self): 193 | return repr(self._cache) 194 | 195 | class Database: 196 | def __init__(self, conn): 197 | if isinstance(conn, str): 198 | conn = sqlite3.connect(conn) 199 | self.conn = conn 200 | self.conn.row_factory = sqlite3.Row 201 | self.conn.executescript(SCHEMA) 202 | self.conn.commit() 203 | self.origins = {} 204 | self.state = StateProxy(self.conn) 205 | 206 | # Specific database methods 207 | 208 | def origin(self, ident: int|str|Origin) -> Origin: 209 | '''Convert an id or name to an origin, may insert a new one.''' 210 | # Check if the origin needs work 211 | if isinstance(ident, Origin): 212 | if ident.id is None: 213 | ident = origin.name 214 | elif ident.name is None: 215 | ident = origin.id 216 | else: 217 | return ident 218 | 219 | # Figure out what kind of origin we have 220 | origin = self.origins.get(ident, None) 221 | if isinstance(name := ident, str): 222 | if origin is None: 223 | row = self.conn.execute( 224 | "SELECT id FROM origins WHERE name = ?", (ident,) 225 | ).fetchone() 226 | if row: 227 | id = row[0] 228 | else: 229 | # Brand new origin 230 | cur = self.conn.execute( 231 | "INSERT INTO origins (name) VALUES (?)", (ident,) 232 | ) 233 | self.commit() 234 | id = cur.lastrowid 235 | else: 236 | id = origin 237 | elif isinstance(id := ident, int): 238 | if origin is None: 239 | row = self.conn.execute( 240 | "SELECT name FROM origins WHERE id = ?", (ident,) 241 | ).fetchone() 242 | if row: 243 | name = row[0] 244 | else: 245 | raise ValueError("No such origin") 246 | else: 247 | name = origin 248 | else: 249 | raise TypeError("origin must be str|int|Origin(str|int)") 250 | 251 | # Update the cache and return 252 | self.origins[name] = id 253 | self.origins[id] = name 254 | return Origin(id, name) 255 | 256 | def recent(self, lines: int) -> Iterator[ExplicitMemory]: 257 | '''Get the `lines` most recent explicit memories.''' 258 | # Sanitization 259 | if not isinstance(lines, int): 260 | raise TypeError("lines must be int") 261 | 262 | recent = self.conn.execute( 263 | f"SELECT * FROM explicit ORDER BY ctime DESC LIMIT {lines}" 264 | ).fetchall()[::-1] 265 | 266 | for row in recent: 267 | yield ExplicitMemory(**row, origin=self.origin(row.origin)) 268 | 269 | def log(self, level: int, msg: str) -> int: 270 | '''Insert a new log entry. Returns the id.''' 271 | cur = self.conn.execute( 272 | "INSERT INTO log (time, level, message) VALUES (?, ?, ?)", 273 | (time.time(), level, msg) 274 | ) 275 | self.conn.commit() 276 | return cur.lastrowid 277 | 278 | def insert_explicit(self, memory: ExplicitMemory) -> int: 279 | '''Insert an explicit memory. 
Returns the id.''' 280 | cur = self.conn.execute( 281 | "INSERT INTO explicit (origin, message, importance) VALUES (?, ?, ?)", 282 | (memory.origin.id, memory.message, memory.importance) 283 | ) 284 | self.conn.commit() 285 | return cur.lastrowid 286 | -------------------------------------------------------------------------------- /llm/server.py: -------------------------------------------------------------------------------- 1 | from concurrent import futures 2 | import grpc 3 | from build import llm_pb2, llm_pb2_grpc 4 | import os 5 | import sys 6 | import torch 7 | import torch.nn.functional as F 8 | from transformers.models.auto import AutoTokenizer, AutoModelForCausalLM 9 | import traceback as tb 10 | import tensor 11 | from typing import Optional, Final 12 | 13 | #MODEL = "databricks/dolly-v2-7b" 14 | MODEL: Final = "gpt2" 15 | PORT: Final = f'unix://{os.getcwd()}/llm.sock' 16 | WORKERS: Final = os.cpu_count() 17 | DEVICE: Final = "cuda" if torch.cuda.is_available() else "cpu" 18 | DEBUG: Final = True 19 | 20 | def clamp(x, lo, hi): 21 | return max(lo, min(x, hi)) 22 | 23 | def repetition_penalty(logits, input_ids, frequency_penalty=1.0, presence_penalty=1.0): 24 | unique_tokens, counts = torch.unique(input_ids, return_counts=True) 25 | 26 | if frequency_penalty: 27 | logits[:, unique_tokens] /= frequency_penalty ** counts 28 | 29 | if presence_penalty: 30 | logits[:, unique_tokens] -= presence_penalty 31 | 32 | return logits 33 | 34 | def top_kp( 35 | logits: torch.Tensor, 36 | top_k: int = 0, 37 | top_p: float = 1.0, 38 | filter_value: float = -float("Inf"), 39 | min_tokens_to_keep: int = 1, 40 | ) -> torch.Tensor: 41 | """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 42 | Args: 43 | logits: logits distribution shape (batch size, vocabulary size) 44 | if top_k > 0: keep only top k tokens with highest probability (top-k filtering). 45 | if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 46 | Nucleus filtering is described in Holtzman et al. 
(http://arxiv.org/abs/1904.09751) 47 | Make sure we keep at least min_tokens_to_keep per batch example in the output 48 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 49 | """ 50 | if 0 < top_k: 51 | top_k = clamp(top_k, min_tokens_to_keep, logits.shape[-1]) # Safety check 52 | # Remove all tokens with a probability less than the last token of the top-k 53 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 54 | logits[indices_to_remove] = filter_value 55 | 56 | if 0 < top_p < 1: 57 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 58 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 59 | 60 | # Remove tokens with cumulative probability above the threshold (token with 0 are kept) 61 | sorted_indices_to_remove = cumulative_probs > top_p 62 | if min_tokens_to_keep > 1: 63 | # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) 64 | sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 65 | # Shift the indices to the right to keep also the first token above the threshold 66 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 67 | sorted_indices_to_remove[..., 0] = 0 68 | 69 | # scatter sorted tensors to original indexing 70 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) 71 | logits[indices_to_remove] = filter_value 72 | return logits 73 | 74 | class LLMService(llm_pb2_grpc.LLM): 75 | def __init__(self, 76 | debug=DEBUG, 77 | port: Optional[int]=None, 78 | model: Optional[str]=None, 79 | device: Optional[str]=None, 80 | workers: Optional[int]=None 81 | ): 82 | super().__init__() 83 | 84 | self.debug = debug 85 | self.port = port or PORT 86 | model = model or MODEL 87 | self.workers = workers or WORKERS 88 | self.device = device or DEVICE 89 | 90 | print("Loading model", model) 91 | self.model = AutoModelForCausalLM.from_pretrained(model).to(self.device) 92 | self.tokenizer = AutoTokenizer.from_pretrained(model) 93 | 94 | def Encode(self, request, context): 95 | try: 96 | if self.debug: 97 | print("Tokenize:", request.text) 98 | 99 | return llm_pb2.Encoding( 100 | tokens=self.tokenizer.encode(request.text) 101 | ) 102 | 103 | except Exception as e: 104 | tb.print_exception(e) 105 | raise e 106 | 107 | def Decode(self, request, context): 108 | try: 109 | if self.debug: 110 | print("Decode:", request.tokens) 111 | return llm_pb2.Decoding( 112 | text=self.tokenizer.decode(request.tokens) 113 | ) 114 | 115 | except Exception as e: 116 | tb.print_exception(e) 117 | raise e 118 | 119 | @torch.no_grad() 120 | def Forward(self, request, context): 121 | try: 122 | return_hidden = request.return_hidden or False 123 | return_attention = request.return_attention or False 124 | 125 | if request.text: 126 | text = request.text 127 | input_ids = self.tokenizer.encode(text, return_tensors='pt') 128 | attention_mask = torch.ones_like(input_ids) 129 | elif request.tokens: 130 | input_ids = tensor.decode(request.tokens) 131 | if self.debug: 132 | text = self.tokenizer.decode(input_ids.reshape(-1)) 133 | if request.attention_mask: 134 | attention_mask = tensor.decode(request.attention_mask) 135 | else: 136 | attention_mask = torch.ones_like(input_ids) 137 | else: 138 | raise ValueError("Must provide either text or tokens") 139 | 140 | input_ids = input_ids.to(self.device) 141 | 142 | if self.debug: 143 | print("Forward:", text) 144 | print(f" {return_hidden=}") 145 | print(f" 
{return_attention=}") 146 | print(f" {attention_mask=}") 147 | 148 | output = self.model( 149 | input_ids, 150 | attention_mask=attention_mask, 151 | output_hidden_states=return_hidden, 152 | output_attentions=return_attention, 153 | return_dict=True 154 | ) 155 | 156 | fields = {"logits": tensor.encode(output.logits, 'f')} 157 | if return_hidden: 158 | fields["hidden"] = [tensor.encode(h, 'f') for h in output.hidden_states] 159 | if return_attention: 160 | fields["attention"] = [tensor.encode(a, 'f') for a in output.attentions] 161 | 162 | return llm_pb2.ForwardResponse(**fields) 163 | 164 | except Exception as e: 165 | tb.print_exception(e) 166 | raise e 167 | 168 | @torch.no_grad() 169 | def Complete(self, request, context): 170 | try: 171 | text = request.text or "\n" 172 | max_tokens = clamp(int(request.max_tokens or 1), 1, self.model.config.n_positions) 173 | temperature = clamp(float(request.temperature or 1), 0, 1) 174 | top_k = clamp(int(request.top_k or 0), 0, self.model.config.vocab_size) 175 | top_p = clamp(float(request.top_p or 1), 0, 1) 176 | presence_penalty = clamp(float(request.presence_penalty or 0), 0, 1) 177 | frequency_penalty = clamp(float(request.frequency_penalty or 0), 0, 1) 178 | stop = request.stop or [] 179 | stop.append("<|endoftext|>") 180 | 181 | if self.debug: 182 | print("Complete:", text) 183 | print(f" {max_tokens=}") 184 | print(f" {temperature=}") 185 | print(f" {top_k=}") 186 | print(f" {top_p=}") 187 | print(f" {presence_penalty=}") 188 | print(f" {frequency_penalty=}",) 189 | print(f" {stop=}") 190 | 191 | input_ids = self.tokenizer.encode(text, return_tensors='pt').to(self.device) 192 | 193 | past_kv = None 194 | for _ in range(max_tokens): 195 | output = self.model( 196 | input_ids, 197 | past_key_values=past_kv, 198 | use_cache=True 199 | ) 200 | logits = output.logits[:, -1, :] 201 | if past_kv is not None: 202 | past_kv += output.past_key_values 203 | 204 | logits = repetition_penalty(logits, input_ids, frequency_penalty, presence_penalty) 205 | logits /= temperature 206 | logits = top_kp(logits, top_k, top_p) 207 | P = torch.softmax(logits, dim=-1) 208 | next_token_id = torch.multinomial(P, num_samples=1) 209 | next_token_str = self.tokenizer.decode(next_token_id[0]) 210 | 211 | if self.debug: 212 | print(next_token_str, end="", flush=True) 213 | 214 | if next_token_str in stop: 215 | break 216 | 217 | yield llm_pb2.CompletionResponse(text=next_token_str) 218 | input_ids = torch.cat((input_ids, next_token_id), dim=1) 219 | 220 | if self.debug: 221 | print() 222 | except Exception as e: 223 | tb.print_exception(e) 224 | raise e 225 | 226 | @torch.no_grad() 227 | def Embed(self, request, context): 228 | try: 229 | if self.debug: 230 | print("Embed:", request.text) 231 | 232 | # Use transformers library to generate sentence embeddings 233 | input_ids = self.tokenizer.encode(request.text, return_tensors='pt').to(self.device) 234 | embed = self.model(input_ids).logits 235 | return llm_pb2.EmbedResponse(embed=tensor.encode(embed, 'f')) 236 | except Exception as e: 237 | tb.print_exception(e) 238 | raise e 239 | 240 | def build_argparse(): 241 | import argparse 242 | 243 | ap = argparse.ArgumentParser( 244 | description="Run a Large Language Model in a separate process which can be queried over gRPC." 
245 | ) 246 | ap.add_argument("-D", "--debug", action="store_true", help="Enable debug mode") 247 | ap.add_argument("-p", "--port", type=str, default=PORT, help="Port to listen on") 248 | ap.add_argument("-M", "--model", type=str, default=MODEL, help="Model name or path") 249 | ap.add_argument("-d", "--device", type=str, default=DEVICE, help="Device to run model on") 250 | ap.add_argument("-w", "--workers", type=int, default=WORKERS, help="Number of gRPC workers") 251 | 252 | return ap 253 | 254 | def main(argv): 255 | args = build_argparse().parse_args(argv) 256 | server = grpc.server(futures.ThreadPoolExecutor(max_workers=WORKERS)) 257 | llm_pb2_grpc.add_LLMServicer_to_server(LLMService(**vars(args)), server) 258 | server.add_insecure_port(PORT) 259 | print("Listening...") 260 | server.start() 261 | server.wait_for_termination() 262 | 263 | if __name__ == '__main__': 264 | main(sys.argv[1:]) 265 | -------------------------------------------------------------------------------- /agent.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional, Iterator 3 | import time 4 | 5 | import prompt 6 | import db 7 | from itertools import chain 8 | import re 9 | import json 10 | import printf 11 | 12 | from complete import complete 13 | 14 | NAME = "Orin" 15 | CHAT_LINES = 25 16 | DB_FILE = "memory.db" 17 | PRESSURE = 1/2 18 | EMOTION = None#"disoriented" 19 | 20 | LEVEL = prompt.LOG_LEVEL.index("debug") 21 | 22 | class StreamCapture: 23 | ''' 24 | Convert a list of patterns into a regex pattern that matches every prefix 25 | of the total pattern, then applies this to a stream of tokens. 26 | 27 | Example: 28 | XYZ\s+=\s+(\d+)?\s* becomes 29 | (?:X(?:Y(?:Z(?:\s+(?:=(?:\s+(?:((?P\d+)))?)?)?)?) 30 | ''' 31 | def __init__(self, pattern): 32 | # Add inner parentheses to the last pattern to detect final match 33 | pattern = [*pattern[:-1], f"(?P{pattern[-1]})"] 34 | pattern = "".join(f"(?:{p}" for p in pattern) + '?'.join(')' * len(pattern)) 35 | self.pattern = re.compile(pattern) 36 | 37 | def capture(self, stream): 38 | buf = "" 39 | m = None 40 | stream = iter(stream) 41 | for token in stream: 42 | buf += token 43 | if m := self.pattern.search(buf): 44 | # Stop searching if there can't be more 45 | if m.end() < len(buf): 46 | if m.lastgroup == "END": break 47 | yield m[0] 48 | buf = buf[m.end():] 49 | else: 50 | yield token 51 | 52 | if m: 53 | # Yield either the match or its text if it's not full 54 | yield m if m.lastgroup == "END" else m[0] 55 | if buf := buf[m.end():]: 56 | yield buf 57 | yield from stream 58 | 59 | IMPORTANCE_PARAM = StreamCapture([ 60 | *prompt.IMPORTANCE_TAG, "\s+", "=", "\s+", "(\d+)" 61 | ]) 62 | EMOTION_PARAM = StreamCapture([ 63 | *prompt.EMOTION_TAG, "\s+", "=", "\s+", 64 | '"', "((?:(?:[^\\\\]+|\\\\.)*)+)", '"' 65 | ]) 66 | 67 | def tee(stream: Iterator) -> tuple[Iterator, list]: 68 | '''Tee a stream into a list.''' 69 | # Note to self: Stop deleting this! Just keep it around in case we 70 | # need it again. 
71 | def impl(stream, arr): 72 | for item in stream: 73 | arr.append(item) 74 | yield item 75 | arr = [] 76 | return impl(stream, arr), arr 77 | 78 | @dataclass 79 | class Message: 80 | origin: db.Origin 81 | message: Optional[str] = None 82 | ctime: float = field(default_factory=time.time) 83 | importance: Optional[int] = None 84 | 85 | def __str__(self): 86 | return f"{prompt.timestamp(self.ctime)} {self.output()}" 87 | 88 | @staticmethod 89 | def from_explicit(explicit: db.ExplicitMemory): 90 | '''Convert an explicit message to a user message''' 91 | return Message(explicit.origin, explicit.message, explicit.ctime, explicit.importance) 92 | 93 | def to_explicit(self) -> db.ExplicitMemory: 94 | '''Convert a user message to an explicit message''' 95 | return db.ExplicitMemory(None, 96 | ctime=self.ctime, atime=self.ctime, access=0, 97 | origin=self.origin, message=self.message, importance=self.importance 98 | ) 99 | 100 | def format_name(self): 101 | if self.origin.name in {"SYSTEM", "SUMMARY"}: 102 | return f"[{self.origin.name}]" 103 | return f"<{self.origin.name}>" 104 | 105 | def output(self): 106 | '''Output as viewed by the user''' 107 | out = self.format_name() 108 | if self.message is not None: 109 | out += f" {self.message}" 110 | return out 111 | 112 | class Agent: 113 | def __init__(self): 114 | self.db = db.Database(DB_FILE) 115 | self.chatlog = [] 116 | 117 | self.db.state.load_defaults( 118 | name = NAME, 119 | loglevel = LEVEL, 120 | pressure = PRESSURE, 121 | emotion = EMOTION, 122 | unsummarized = 0, 123 | lines = CHAT_LINES, 124 | ) 125 | # Set during debugging so we don't summarize a bunch of reboots 126 | self.db.state.unsummarized = 0 127 | 128 | self.debug(f"__init__({self.name!r})") 129 | 130 | self.reload() 131 | self.add_memory(self.message("SYSTEM", prompt.reload())) 132 | 133 | @property 134 | def name(self): return self.db.state.name 135 | @property 136 | def loglevel(self): return self.db.state.loglevel 137 | @property 138 | def pressure(self): return self.db.state.pressure 139 | @property 140 | def emotion(self): return self.db.state.emotion 141 | @property 142 | def unsummarized(self): return self.db.state.unsummarized 143 | @property 144 | def lines(self): return self.db.state.lines 145 | 146 | def message(self, origin, msg=None, importance=None) -> Message: 147 | '''Create a message from a string''' 148 | return Message(self.db.origin(origin), msg, time.time(), importance) 149 | 150 | def log(self, level: int, msg: str): 151 | '''Debug print''' 152 | 153 | self.db.log(level, msg) 154 | 155 | if level <= self.loglevel: 156 | printf.log(level, msg) 157 | 158 | def error(self, msg): self.log(1, msg) 159 | def warn(self, msg): self.log(2, msg) 160 | def info(self, msg): self.log(3, msg) 161 | def debug(self, msg): self.log(4, msg) 162 | def verbose(self, msg): self.log(5, msg) 163 | 164 | def command(self, user: str, cmd: str, args: list[str]): 165 | '''Process a command.''' 166 | 167 | self.debug(f"command({user!r}, {cmd!r}, {args!r})") 168 | 169 | match cmd.lower(): 170 | case "select"|"update"|"insert"|"delete": 171 | yield from self.command(user, "sql", [cmd.upper(), *args]) 172 | case "sql": 173 | try: 174 | rows = self.db.execute(' '.join(args)) 175 | self.db.commit() 176 | print(rows) 177 | printf.json([dict(row) for row in rows.fetchall()]) 178 | except Exception as e: 179 | self.error(f"SQL ERROR: {e}") 180 | case "state": 181 | if len(args) == 0: 182 | printf.json(self.db.state.cache) 183 | elif len(args) == 1: 184 | 
printf.json(self.db.state[args[0]]) 185 | else: 186 | try: 187 | value = json.loads(args[1]) 188 | except json.JSONDecodeError: 189 | value = args[1] 190 | self.db.state[args[0]] = value 191 | self.info(f"Set state[{args[0]!r}] = {value!r}") 192 | case "sum"|"summary"|"summarize": 193 | yield from self.summarize(self.chatlog) 194 | yield '\n' 195 | case "prompt": 196 | yield self.build_prompt() + "\n" 197 | case "level": 198 | if len(args) > 0: 199 | level = args[0].upper() 200 | if level in prompt.LOG_LEVEL: 201 | il = prompt.LOG_LEVEL.index(level) 202 | else: 203 | try: 204 | il = int(level) 205 | except ValueError: 206 | self.error("Invalid log level") 207 | return 208 | 209 | self.db.state.loglevel = il 210 | self.info(f"level = {level} ({il})") 211 | else: 212 | il = self.loglevel 213 | if il > len(prompt.LOG_LEVEL): 214 | level = "verbose" + "+" * (il - len(prompt.LOG_LEVEL)) 215 | else: 216 | level = prompt.LOG_LEVEL[il] 217 | yield f"level = {level} ({il})\n" 218 | case _: 219 | self.error(f"Unknown command {cmd}") 220 | 221 | def build_prompt(self): 222 | self.debug("build_prompt()") 223 | return prompt.master(self.name, self.emotion) + '\n'.join(self.chatlog) 224 | 225 | def reload(self): 226 | '''Reload the agent's most recent memories.''' 227 | 228 | self.debug("reload()") 229 | self.chatlog = [str(Message.from_explicit(msg)) for msg in self.db.recent(self.lines)] 230 | 231 | def add_message(self, msg: str): 232 | '''Add a chat message to the chat log.''' 233 | 234 | self.debug(f"add_message({msg!r})") 235 | self.db.state.unsummarized += 1 236 | print("Update", self.db.state.unsummarized) 237 | 238 | # Rolling chat log 239 | self.chatlog = self.chatlog[-CHAT_LINES:] + [msg] 240 | 241 | def add_memory(self, msg: Message) -> int: 242 | '''Add a memory to the database, which also adds to chat log.''' 243 | 244 | yield '\n' 245 | self.debug(f"add_memory({msg!r})") 246 | 247 | # All memories are added to the internal chat log 248 | self.add_message(str(msg)) 249 | return self.db.insert_explicit(msg.to_explicit()) 250 | 251 | def complete(self, prompt: str): 252 | '''AI prompt completion.''' 253 | 254 | self.verbose(f"complete({prompt!r})") 255 | return complete(prompt) 256 | 257 | def summarize(self, dialog: str): 258 | '''Summarize the current chat log.''' 259 | 260 | self.verbose(f"summarize({dialog!r})") 261 | 262 | self.db.state.unsummarized = 0 263 | importance = None 264 | summary = [] 265 | 266 | # Get the summary completion 267 | stream = self.complete(prompt.summarize(self.name, dialog)) 268 | for token in IMPORTANCE_PARAM.capture(stream): 269 | if isinstance(token, str): 270 | yield token 271 | summary.append(token) 272 | else: 273 | importance = token[1] 274 | 275 | # Update the importance 276 | if importance: 277 | print("Importance =", importance) 278 | try: 279 | importance = int(importance[0], 10) 280 | except ValueError as e: 281 | self.log(f"Invalid importance: {e}") 282 | importance = None 283 | 284 | self.add_memory(self.message(prompt.SUMMARY_TAG, ''.join(summary), importance=importance)) 285 | 286 | def chat(self, user: str, msg: str) -> Iterator[str]: 287 | '''Respond to the user's message.''' 288 | 289 | self.verbose(f"chat({user!r}, {msg!r})") 290 | 291 | # Add user message 292 | self.add_memory(self.message(user, msg)) 293 | 294 | # Process AI message 295 | pending = self.message(self.name) 296 | po = pending.output() 297 | 298 | completion = [] 299 | p = f"{self.build_prompt()}\n{po}" 300 | stream = self.complete(p) 301 | 302 | yield po 303 | 304 | for 
token in EMOTION_PARAM.capture(stream): 305 | if isinstance(token, str): 306 | yield token 307 | completion.append(token) 308 | else: 309 | self.db.state.emotion = token[1] 310 | 311 | pending.message = ''.join(completion) 312 | self.add_memory(pending) 313 | 314 | # Summarize every so often 315 | if self.unsummarized >= self.lines * self.pressure: 316 | yield from self.summarize(self.build_prompt()) -------------------------------------------------------------------------------- /tf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | ''' 3 | Common transformer classes and functions. 4 | ''' 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from typing import Optional, TypeAlias, Type 10 | from abc import abstractmethod 11 | from dataclasses import dataclass, field 12 | 13 | from common import NORM, DROP, BIAS, default 14 | 15 | ConfigMod: TypeAlias = Optional[bool|int|float|nn.Module] 16 | 17 | def mod_or_config(config: ConfigMod, default: float, mod: Type[nn.Module]): 18 | '''Returns a module or a module initialized with a config.''' 19 | 20 | if isinstance(config, (int, float)): 21 | return mod(config) if config else None 22 | if isinstance(config, bool): 23 | return mod(default) if config else None 24 | return config 25 | 26 | @dataclass 27 | class Output: 28 | '''Optional outputs for layers.''' 29 | attention: Optional[list[Optional[torch.Tensor]]] = None 30 | '''Attention weights of each layer''' 31 | cache: Optional[tuple[torch.Tensor, torch.Tensor]] = None 32 | '''Key-value cache for faster inference''' 33 | 34 | @dataclass 35 | class Context: 36 | '''Context for layers.''' 37 | output: Output = field(default_factory=Output) 38 | '''Optional outputs for layers.''' 39 | feedback: Optional[list[torch.Tensor]] = None 40 | '''Feedback for instill.''' 41 | mask: Optional[torch.Tensor] = None 42 | '''Attention mask.''' 43 | head_mask: Optional[torch.Tensor] = None 44 | '''Attention head mask.''' 45 | cache: Optional[torch.Tensor] = None 46 | '''Key-value cache for faster inference''' 47 | 48 | class RotaryEmbedding(nn.Module): 49 | ''' 50 | Rotary Embedding (RoPE) 51 | 52 | RoFormer: Enhanced Transformer with Rotary Position Embedding 53 | https://arxiv.org/abs/2104.09864 54 | ''' 55 | 56 | def __init__(self, dim, base=10000.0): 57 | super().__init__() 58 | self.dim = dim 59 | self.cache = {} 60 | inv_freq = base ** -(torch.arange(0, dim, 2).float() / dim) 61 | emb = torch.empty(self.dim) 62 | emb[::2] = inv_freq 63 | emb[1::2] = inv_freq 64 | self.register_buffer('inv_freq', emb) 65 | 66 | def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: 67 | assert q.shape == k.shape 68 | seq = q.shape[1] 69 | 70 | sc = self.cache.get(seq, None) 71 | if sc is None: 72 | sin, cos = torch.sin(seq * self.inv_freq), torch.cos(seq * self.inv_freq) 73 | self.cache[seq] = sin, cos 74 | else: 75 | sin, cos = sc 76 | 77 | def rotate(x): 78 | # Rotate half 79 | x1, x2 = x[..., :-1:2], x[..., 1::2] 80 | x_rot = torch.cat((-x2, x1), dim=-1) 81 | return (x * cos) + (x_rot * sin) 82 | 83 | return rotate(q), rotate(k) 84 | 85 | class ScaledDotProductSelector(nn.Module): 86 | '''Traditional softmax(Q K^T) V attention selector.''' 87 | 88 | def __init__(self, max_seq: Optional[int], dropout: ConfigMod=None): 89 | ''' 90 | Parameters: 91 | max_seq: Maximum sequence length 92 | dropout: Attention dropout 93 | ''' 94 | super().__init__() 95 | 96 | self.attn_dropout = 
mod_or_config(dropout, DROP, nn.Dropout) 97 | 98 | self.register_buffer("bias", 99 | torch.tril(torch.ones((max_seq, max_seq), dtype=torch.bool)).view( 100 | 1, 1, max_seq, max_seq 101 | ) if max_seq is not None else None 102 | ) 103 | 104 | def forward(self, 105 | q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, 106 | ctx: Context 107 | ) -> torch.Tensor: 108 | attn_weight = torch.matmul(q, k.transpose(-1, -2)) 109 | 110 | dtype, device = attn_weight.dtype, attn_weight.device 111 | qlen, klen = q.shape[-2], k.shape[-2] 112 | 113 | # Causal mask 114 | if self.bias is not None: 115 | causal_mask = self.bias[:, :, klen - qlen : klen, :klen] 116 | mask_value = torch.finfo(dtype).min 117 | mask_value = torch.full((), mask_value, dtype=dtype).to(device) 118 | attn_weight = torch.where(causal_mask, attn_weight, mask_value) 119 | 120 | attn_mask = ctx.mask 121 | if attn_mask is not None: 122 | attn_weight = attn_weight + attn_mask 123 | 124 | attn_weight = F.softmax(attn_weight, dim=-1) 125 | 126 | # Downcast (if necessary) back to V's dtype 127 | if self.attn_dropout is not None: 128 | attn_weight = self.attn_dropout(attn_weight.type(v.dtype)) 129 | 130 | head_mask = ctx.head_maskcausal_mask 131 | if head_mask is not None: 132 | attn_weight = attn_weight * head_mask 133 | 134 | attn = torch.matmul(attn_weight, v) 135 | 136 | atv = ctx.output.attention 137 | if atv is not None: 138 | atv.append(attn_weight) 139 | 140 | return attn 141 | 142 | class AssociativeMemorySelector(nn.Module): 143 | ''' 144 | Attention recontextualized as kNN memory. Replaces feed forwarde layers. 145 | 1. QKV projections 146 | 2. QK L2 norm, KV STE 147 | 3. Top-K for keys and their values 148 | 4. softmax(Q K^T) V attention on the results 149 | 150 | Transformer Feed-Forward Layers Are Key-Value Memories 151 | https://arxiv.org/abs/2012.14913 152 | 153 | Augmenting Self-attention with Persistent Memory 154 | https://arxiv.org/pdf/1907.01470.pdf 155 | * Proves FF networks are equivalent to attention with static memory 156 | 157 | Attention Approximates Sparse Distributed Memory 158 | https://arxiv.org/abs/2111.05498 159 | * Theoretical basis for why FF might be both attention and memory 160 | 161 | Memorizing Transformers 162 | https://arxiv.org/abs/2203.08913 163 | * kNN memory, paper uses it as an alternative to recurrency 164 | 165 | Neural Turing Machines 166 | https://arxiv.org/abs/1410.5401 167 | * Read and write/erase heads, paper uses it for memory-augmented tasks 168 | ''' 169 | 170 | def __init__(self, 171 | embed, 172 | selector: Optional[nn.Module]=None, 173 | max_seq: Optional[int]=None, 174 | dropout: ConfigMod=None 175 | ): 176 | ''' 177 | Parameters: 178 | embed: Embedding dimension 179 | selector: Selector module after memory lookup 180 | max_seq: Maximum sequence length 181 | dropout: Residual dropout 182 | ''' 183 | 184 | super().__init__() 185 | self.selector = default(selector, lambda: ScaledDotProductSelector(max_seq, dropout)) 186 | 187 | # sqrt of QK Norm init because we can't assume selector is a dot product 188 | self.scale = nn.Parameter(torch.tensor(embed ** -0.25)) 189 | 190 | def forward(self, 191 | q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, 192 | ctx: Context 193 | ) -> torch.Tensor: 194 | # QK Norm and STE 195 | q = F.normalize(q, dim=-1) 196 | k = F.normalize(k, dim=-1).detach() 197 | v = v.detach() 198 | 199 | mk, mv = ctx.associative.search(k, v) 200 | 201 | # STE 202 | k.grad = mk.grad 203 | v.grad = mv.grad 204 | 205 | q, k = q * self.scale, k * self.scale 206 | return 
self.selector(q, k, v, ctx) 207 | 208 | class Residual(nn.ModuleList): 209 | ''' 210 | Adds residuals to each sequential model. 211 | ''' 212 | 213 | def forward(self, x, ctx): 214 | for layer in self: 215 | x = x + layer(x, ctx) 216 | return x 217 | 218 | class Instill(Residual): 219 | ''' 220 | INformation Still. Information goes up and comes back down condensed and 221 | distilled. This is a generalization of the residual connection. 222 | 223 | Adds residual of upper layers to lower layers as feedback. 224 | 225 | Addressing Some Limitations of Transformers with Feedback Memory 226 | https://arxiv.org/abs/2002.09402 227 | * Using output of upper layers for lower (modified: per layer pair, no pooling) 228 | 229 | Memory transformers 230 | https://arxiv.org/abs/2006.11527 231 | * Concatenating memory to input tokens (modified: no memory controller) 232 | ''' 233 | 234 | def _feedback(self, x, f, ctx): 235 | '''How to apply the feedback''' 236 | return x if f is None else x + f 237 | 238 | def forward(self, x, ctx): 239 | f = ctx.feedback 240 | if f is None: 241 | f = [None] * len(self) 242 | 243 | for i, layer in enumerate(self): 244 | print("Layer", i) 245 | x = x + layer(self._feedback(x, f[i], ctx), ctx) 246 | return x 247 | 248 | class TransformerLayer(nn.Module): 249 | ''' 250 | Base class for layers in a transformer. Mostly norms and dropout, 251 | since some forms of attention don't even have the QKV projections. 252 | 253 | Attention Is All You Need 254 | https://arxiv.org/abs/1706.03762 255 | ''' 256 | 257 | def __init__(self, 258 | embed: int, 259 | *, 260 | prenorm: ConfigMod=None, 261 | dropout: ConfigMod=None, 262 | postnorm: ConfigMod=None, 263 | bias: bool=BIAS 264 | ): 265 | ''' 266 | Parameters: 267 | embed: Embedding dimension 268 | max_seq: Maximum sequence length 269 | 270 | prenorm: Whether to use pre-normalization 271 | dropout: Residual dropout 272 | postnorm: Whether to use post-normalization 273 | 274 | bias: Whether to use bias 275 | ''' 276 | super().__init__() 277 | self.embed = embed 278 | 279 | self.prenorm = mod_or_config(prenorm, NORM, lambda eps: nn.LayerNorm(embed, eps)) 280 | self.resid_dropout = mod_or_config(dropout, DROP, nn.Dropout) 281 | self.postnorm = postnorm and mod_or_config(postnorm, NORM, lambda eps: nn.LayerNorm(embed, eps)) 282 | self.out_proj = nn.Linear(embed, embed, bias) 283 | 284 | @abstractmethod 285 | def _forward(self, x: torch.Tensor, ctx: Context) -> torch.Tensor: 286 | ''' 287 | Perform attention-specific operations to convert input to Q, K, and V, 288 | then into an attention output. 289 | ''' 290 | pass 291 | 292 | def forward(self, x: torch.Tensor, ctx: Context) -> torch.Tensor: 293 | ''' 294 | Normalizations, projections, and dropouts. 295 | ''' 296 | 297 | if self.prenorm is not None: 298 | x = self.prenorm(x) 299 | 300 | atv = ctx.output.attention 301 | atn = self._forward(x, ctx) 302 | if atv is not None: 303 | assert len(ctx.output.attention) > len(atv), "Attention output is not growing" 304 | 305 | atn = self.out_proj(atn) 306 | 307 | if self.postnorm is not None: 308 | atn = self.postnorm(atn) 309 | 310 | if self.resid_dropout is not None: 311 | atn = self.resid_dropout(atn) 312 | 313 | if h := ctx.output.hidden: 314 | h.append(atn) 315 | 316 | return atn 317 | 318 | class MultiheadAttention(TransformerLayer): 319 | ''' 320 | Normal mutli-headed attention with qkv and output projections. 
321 | ''' 322 | 323 | def __init__(self, 324 | embed: int, 325 | heads: int=1, 326 | *, 327 | rotary_embed: bool=False, 328 | max_seq: Optional[int]=None, 329 | selector: Optional[nn.Module]=None, 330 | 331 | dropout: ConfigMod=None, 332 | bias: bool=BIAS 333 | ): 334 | ''' 335 | Parameters: 336 | embed: Embedding dimension 337 | max_seq: Maximum sequence length 338 | 339 | selector: selector module 340 | 341 | prenorm: Whether to use pre-normalization 342 | dropout: Residual dropout 343 | postnorm: Whether to use post-normalization 344 | 345 | qknorm: Whether to normalize queries and keys 346 | bias: Whether to use bias 347 | ''' 348 | super().__init__( 349 | embed, 350 | dropout=dropout, 351 | bias=bias 352 | ) 353 | self.heads = heads 354 | 355 | self.rotary = RotaryEmbedding(embed) if rotary_embed else None 356 | 357 | if selector is None: 358 | assert max_seq is not None, "max_seq must be specified if selector is not" 359 | selector = ScaledDotProductSelector(max_seq) 360 | self.selector = selector 361 | 362 | self.scale = embed ** -0.25 363 | 364 | # Cramming: Training a Language Model on a Single GPU in One Day 365 | # https://arxiv.org/abs/2212.14034 366 | # * Removing bias has negligible effect on loss and reduces parameters 367 | self.qkv_proj = nn.Linear(embed, embed * 3, bias) 368 | 369 | def _split_heads(self, x): 370 | x = x.view(*x.shape[:-1], self.heads, -1) 371 | return x.transpose(-2, -3) 372 | 373 | def _merge_heads(self, x): 374 | x = x.transpose(-2, -3) 375 | return x.reshape(*x.shape[:-2], self.embed) 376 | 377 | def _forward(self, x, ctx): 378 | q, k, v = self.qkv_proj(x).chunk(3, dim=-1) 379 | 380 | if self.rotary is not None: 381 | q, k = self.rotary(q, k, ctx) 382 | 383 | # Caching for faster inference 384 | cache = ctx.cache 385 | if cache is not None: 386 | past_k, past_v = cache 387 | k = torch.cat((past_k, k), dim=-2) 388 | v = torch.cat((past_v, v), dim=-2) 389 | ctx.output.cache = (k, v) 390 | 391 | # QK Norm 392 | q = F.normalize(q, dim=-1) * self.scale 393 | k = F.normalize(k, dim=-1) * self.scale 394 | 395 | q, k, v = map(self._split_heads, (q, k, v)) 396 | x = self.selector(q, k, v, ctx) 397 | return self._merge_heads(x) 398 | 399 | class LanguageModel(nn.Module): 400 | ''' 401 | Wraps a language model with a token embedding and a linear output layer. 
402 | ''' 403 | 404 | def __init__(self, 405 | vocab: int, 406 | embed: int, 407 | model: list[nn.Module], 408 | dropout: ConfigMod=None, 409 | postnorm: ConfigMod=None, 410 | dtype: Optional[torch.dtype]=None 411 | ): 412 | ''' 413 | Parameters 414 | vocab: Vocabulary size 415 | embed: Embedding size 416 | model: Language model 417 | dropout: Embedding dropout layer 418 | postnorm: Post-normalization layer 419 | dtype: Data type 420 | ''' 421 | super().__init__() 422 | 423 | self.embed = nn.Embedding(vocab, embed) 424 | self.embed_dropout = mod_or_config(dropout, DROP, nn.Dropout) 425 | self.lm = model 426 | self.postnorm = mod_or_config(postnorm, NORM, lambda x: nn.LayerNorm((embed,), x)) 427 | # Tie lm_head and embed weights 428 | self.lm_head = nn.Linear(vocab, embed, bias=False) 429 | self.dtype = dtype or torch.float32 430 | 431 | def tie_weights(self): 432 | '''Tie lm_head and embed weights''' 433 | self.lm_head.weight = self.embed.weight 434 | 435 | def forward(self, x: torch.Tensor, ctx: Context) -> torch.Tensor: 436 | ''' 437 | kwargs: 438 | x: Input embeddings or ids 439 | featural: Static memory 440 | associative: Dynamic memory 441 | 442 | attention_mask: Mask for attention 443 | head_mask: Mask for attention heads 444 | 445 | output_attention: Output attention 446 | output_hidden: Output hidden states 447 | ''' 448 | 449 | # Convert ids to embedding 450 | if not x.is_floating_point(): 451 | x = self.embed(x) 452 | 453 | # GPT2Attention mask 454 | if attention_mask is not None: 455 | attention_mask = attention_mask.view(x.shape[0], -1) 456 | attention_mask = attention_mask[:, None, None, :] 457 | 458 | # Adjust to (-inf, 0] for additive mask 459 | attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility 460 | attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min 461 | 462 | if self.embed_dropout is not None: 463 | x = self.embed_dropout(x) 464 | 465 | x = self.lm(x, ctx) 466 | 467 | hidden = ctx.output.hidden 468 | if hidden is not None: 469 | hidden.append(x) 470 | 471 | # Process final hidden state 472 | if self.postnorm is not None: 473 | x = self.postnorm(x) 474 | 475 | x = self.lm_head(x) 476 | return x -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | [Let Language Models be Language Models](https://docs.google.com/document/d/1U7O6iEBwuxyQRiXe4pn7HRYWAyEGtEmFX59GL1vdwf8/view#) 2 | 3 | A major problem with LLMs and the direction we're going with them is they aren't actually pure language models in the literal sense. In order to fulfill the autoregressive objective, they're forced to memorize information which has nothing to do with language modeling, making them some kind of "ontology model" for lack of a better phrase. In this paper, I propose a scalable way to decouple the memorization requirement from the autoregressive language modeling objective which offers a number of benefits, most importantly that it enables significantly smaller foundation models which can have customized ontologies. 
I have an almost-working implementation at 4 | 5 | ```bash 6 | $ pip install -r requirements.txt 7 | ``` 8 | 9 | To run the knowledge distillation from a model to a clone with the FF layers replaced with discrete memory layers, run 10 | 11 | ```bash 12 | $ python distill.py 13 | ``` 14 | 15 | To test the cognitive architecture, maybe change `complete = ...` in `agent.py` to `complete_chatgpt` and run 16 | 4 17 | ```bash 18 | $ API_KEY=... python chat.py 19 | ``` 20 | You can also add `ENGINE=...` for a different Chat GPT engine, or "grpc" for the local gRPC server. 21 | 22 | Or if you want to get fancy, you can run a local model server with 23 | 24 | ```bash 25 | $ python llm/server.py 26 | ``` 27 | 28 | Nothing fancy right now, a barebones proof of concept for my discrete memory transformer idea. Replaces the feedforward layers with two variations of kNN database memory to decouple memorization from language modeling. `distill.py` clones GPT-2, replaces the FF layers with the discrete memory, and trains it with knowledge distillation on the original with FF layers. 29 | 30 | You'll also want to run Tensorboard to view pretty graphs of the loss over time. Run `tensorboard --logdir lightning_logs/` and then navigate to `localhost:6006` in your browser. At first, you won't see the training loss - open the dropdown at the bottom labeled "train_loss". 31 | 32 | [Attention Is All You Need](https://arxiv.org/abs/1706.03762) 33 | 34 | [GPT-4 understands cuil theory](https://pastebin.com/LTbD5WYT) 35 | 36 | # Ideas 37 | 38 | ## Architectural improvements 39 | 40 | ### Replace FF layers with kNN memory attention. 41 | Research shows that the feedforward layers of transformers approximate a key-value store / Sparse Distributed Memory (SDM). Ablation studies also show that removing sections of the feedforward layers correspond roughly to an equivalent loss in performance. That is, nearly half the parameters of the model are dedicated to memorization, which is a result of the current attitude toward LLMs as magic do-everything models rather than actual language models which operate on a provided knowledge base. 42 | 43 | My proposal is to replace the feedforward layers with a kNN memory layer. This would reduce the number of parameters in the model by half and allow for more efficient training and inference on consumer hardware, as the kNN memory can be mostly stored on disk. Large AI companies have little incentive to develop such an optimization because they have the resources to train and deploy large models which gives them a monopoly over the hobbyist market. 44 | 45 | The actual implementation doesn't matter as long as: 46 | 1. It has query and insert operations (ideally in batches) 47 | 2. It has some notion of distance between keys (which can be approximate) 48 | 3. It supports inserting a relatively large number of keys per time step, `(batch * heads * layers)` 49 | 50 | I based the formal model on the Neural Turing Machine, which has "read" and "write"/"erase" heads. Unlike NTMs, an external memory is not differentiable, but Straight-Through Estimate (STE) can be used to approximate the gradient with `dQuery = dKeyReturned`, `dKeyInserted = dKeyReturned`, and `dValueInserted = dValueReturned`, based on the observation that `KeyReturned ≈ Query + noise`. I haven't found a way to incorporate an erase operation as it's completely non-differentiable since it has no return and there's no clear objective which could approximate it. 
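
As a concrete illustration, here is a minimal sketch of that straight-through estimate, assuming a vector store exposing a (hypothetical) `search` method that returns the nearest stored key and its value; none of these names come from this repo:

```python
import torch

def ste_search(query: torch.Tensor, value: torch.Tensor, memory):
    '''
    Read from a non-differentiable external kNN memory with a
    straight-through gradient. `memory.search` is a placeholder for
    whatever vector DB is used; it returns the nearest stored key and its
    value and is not differentiable.
    '''
    with torch.no_grad():
        key_ret, value_ret = memory.search(query, value)

    # Forward pass uses the returned key/value; backward pass treats the
    # lookup as identity, i.e. dQuery = dKeyReturned and
    # dValueInserted = dValueReturned, since KeyReturned ≈ Query + noise.
    key_out = query + (key_ret - query).detach()
    value_out = value + (value_ret - value).detach()
    return key_out, value_out
```
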
Thus database size has to use external methods to prune, either through LRFU eviction or recombination of keys and values when they appear in a top-k query (something most vector databases wouldn't readily support). 51 | 52 | The current iteration of the idea has two layer types with different strenghts. Featural memory uses top-k embeddings with the queries to produce the output, while associative memory does the full query-key-value projection, queries top-k to the keys, inserts keys and values if the top-k distances are far enough, and then computes the attention with the returned keys and values. The featural memory is more efficient and has a smaller memory footprint, but the associative memory is more expressive and can be used to implement the featural memory. In particular, I expect featural memory to be more useful to lower-level syntactic features of the early transformer layers, while associative memory will be more useful to higher-level semantic features of the later transformer layers. 53 | 54 | Featural memory steps: 55 | 1. Compute query projection 56 | 2. Retrieve top-k query embeddings 57 | 3. Weighted sum the top-k values using distance 58 | 4. Output projection 59 | 60 | Associative memory steps: 61 | 1. Compute query, key, and value projections 62 | 2. Retrieve top-k keys 63 | 3. If the top-k distances are large enough, insert the corresponding keys and their values 64 | 4. Compute attention with the returned keys and values and the originaly queries 65 | 5. Output projection 66 | 67 | It's worth noting that hobbyists have already started playing with related ideas using Pinecone, mostly by creating embeddings for document fragments and questions, querying kNN, and dumping the document fragments into the LLM's context window. This is a good start, but it limits the model's ability to incorporate the information as well as the size of its context window, since it means a large portion is taken up by the documents (and those tokens are also expensive). 68 | 69 | * [Transformer Feed-Forward Layers Are Key-Value Memories](https://arxiv.org/abs/2012.14913) 70 | * [Augmenting Self-attention with Persistent Memory](https://arxiv.org/pdf/1907.01470.pdf) 71 | - Proves FF networks are equivalent to attention with static memory 72 | * [Attention Approximates Sparse Distributed Memory](https://arxiv.org/abs/2111.05498) 73 | - Theoretical basis for why FF might be both attention and memory 74 | * [Memorizing Transformers](https://arxiv.org/abs/2203.08913) 75 | - kNN memory, paper uses it as an alternative to recurrency 76 | * [Neural Turing Machines](https://arxiv.org/abs/1410.5401) 77 | 78 | ### Layerwise feedback 79 | Modern transformers are limited to their context window for short-term and working memory. This is why they seem so forgetful in long conversations. Some augmentations have been proposed to introduce recurrencies which give them longer memory, but nothing has been widely adopted. My proposal is to combine the Feedback Transformer (which pools the hidden activations of each layer and then feeds them back into every layer's input) and Memory Transformers (which concatenate a memory vector to the input tokens). In my model, rather than pooling the activations, the feedback is layerwise. That is, the output of layer 2 from the previous timestep is combined with the input of layer 1. 
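
A minimal sketch of one timestep of this scheme (layer and variable names are illustrative, not taken from this repo):

```python
import torch
import torch.nn as nn

def feedback_step(layers: list[nn.Module], x: torch.Tensor,
                  feedback: list | None = None, higher_feedback=None):
    '''One timestep of layerwise feedback: layer i's input is combined
    with layer i+1's (detached) output from the previous timestep.'''
    if feedback is None:
        feedback = [None] * len(layers)

    hidden = []
    for layer, f in zip(layers, feedback):
        x = layer(x if f is None else x + f)
        hidden.append(x)

    # Detaching stops the gradient, so there is no backpropagation through
    # time; the slice means layer i receives layer i+1's output next step.
    next_feedback = [h.detach() for h in hidden[1:]] + [higher_feedback]
    return x, next_feedback
```
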
This allows for both local and long-term recurrency dependencies, and can be approximated without Backpropagation Through Time (BPTT) by basically treating the feedback as part of the input, which stops the gradient. It's simple to implement, as most transformer libraries already offer the ability to return the hidden activations of each layer, and the result can be sliced as `feedback[1:] + [higher_feedback]`. Note that `higher_feedback` here is optional and can be `None`, but it has potential usefulness for incorporating this into multi-modal models by feeding the multi-model token memory into the language model. 80 | 81 | * [Addressing Some Limitations of Transformers with Feedback Memory](https://arxiv.org/abs/2002.09402) 82 | - Using output of upper layers for lower (modified: per layer pair, no pooling) 83 | * [Memory transformers](https://arxiv.org/abs/2006.11527) 84 | - Concatenating memory to input tokens (modified: no memory controller) 85 | * [Burst-dependent synaptic plasticity can coordinate learning in hierarchical circuits](https://pubmed.ncbi.nlm.nih.gov/33986551/) 86 | - Biological plausibility of layerwise feedback 87 | - [A video reviewing the paper](https://www.youtube.com/watch?v=cJLeZymHRnc) 88 | * [The Forward-Forward Algorithm: Some Preliminary Investigations](https://arxiv.org/abs/2212.13345) 89 | - Attempts to approximate backpropagation in a biologically-plausible way 90 | - Could possibly be used for pretraining or regularization 91 | 92 | ### Growing heads 93 | Aka "hydra attention". Multi-headed attention has shown to be highly redundant and many can be pruned, especially in later layers. Hydra attention proposes adding a per-head "contribution" parameter, with `attention *= sigmoid(contribution)`. If the number of heads is small early on, these contributions should tend to saturate at 1 as the heads are overburdened by multiple tasks. This can be detected, and the projection matrices can be "grown" by adding randomly initialized weights based on the mean and std of the existing weights. This allows the model to learn new heads as needed, and also to prune heads which are no longer useful. 94 | 95 | * [Are Sixteen Heads Really Better than One?](https://arxiv.org/abs/1905.10650) 96 | - [Blog post summary](https://blog.ml.cmu.edu/2020/03/20/are-sixteen-heads-really-better-than-one/) 97 | - Multi-headed attention is often highly redundant 98 | - Important heads are learned early (but not immediately) 99 | * [Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned](https://arxiv.org/abs/1905.09418) 100 | * [What Does BERT Look At? An Analysis of BERT’s Attention](https://arxiv.org/abs/1906.04341) 101 | 102 | ### Reinforcement learning through backpropagation 103 | To be tested, but I've long had an idea that reinforcement learning can be approximated with an unusual application of backpropagation. It requires a second model which, given the original model's output, learns the reward that output receives in the next time step. Then, the original model's loss is calculated by setting the final loss to 0 and backpropagating through the reward model to learn an output which would produce the 0 loss. This is a bit like a GAN, but the reward model is not adversarial. Theoretically, this can be described as a "reward model" and a "policy model". 
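
A rough sketch of the idea, with toy modules standing in for the policy and reward models (all names and shapes here are hypothetical):

```python
import torch
import torch.nn as nn

policy = nn.Linear(16, 8)        # stand-in for the generating ("policy") model
reward_model = nn.Linear(8, 1)   # trained separately to predict next-step reward/loss
optimizer = torch.optim.SGD(policy.parameters(), lr=1e-3)

def policy_update(x: torch.Tensor):
    # Freeze the reward model so gradients only update the policy.
    for p in reward_model.parameters():
        p.requires_grad_(False)

    out = policy(x)
    # Pretend the final loss is 0 and backpropagate through the frozen
    # reward model to push the policy toward an output that would have
    # produced it.
    loss = reward_model(out).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

The reward model plays a role similar to a GAN discriminator, except its weights are frozen during the policy update rather than trained adversarially.
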
104 | 105 | ### LLM GAN / Auto Turing Test 106 | Pit two LLMs against each other in a GAN configuration, with the Discriminator being a classifier which tries to guess if the Generator's predictions were from a human. Basically, an automated Turing Test which the model can learn from. Theoretically you could even set this up as the language model *against itself*, prompting it to predict whether its own text was from an AI or a human. Compare this to the Reflexion model in the [pay attention to](#pay-attention-to) section. 107 | 108 | ### Separate embedding layer 109 | Right now token embeddings are learned as part of the model, but this makes a model incapable of learning other tokenization representations. If a better tokenization method than BPE is discovered, you'd normally have to start from scratch. Instead, you can add a kind of shim layer between the token embedding and the normal model, which learns to convert the token embeddings into the model's internal representation. This allows the model to learn multiple tokenization methods simultaneously, and also allows for the possibility of a tokenization method which is learned by the model itself. If tokenization ever needs to be changed, you merely have to retrain the shim layer. Naturally the shim layer would need a corresponding decoder on the output side. 110 | 111 | ### General parameter reduction 112 | As a general rule, the more parameters a model has the faster and more generalized it learns (provided it doesn't overfit). However, once it's learned a task, it's possible to reduce the number of parameters without losing performance. This is done by pruning the model, then retraining it with a higher learning rate. This is a bit like a form of regularization, and it's possible that it could be used to train a model with a large number of parameters, then prune it down to a smaller size while retaining the performance of the larger model. This is similar to the [growing heads](#growing-heads) technique, but applied to the model as a whole. 113 | 114 | * [The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks](https://arxiv.org/abs/1803.03635) 115 | - Neural networks may have a "lottery ticket" subnetwork which can be trained in isolation to achieve the same performance as the original network 116 | - Gives a justification for why pruning even works in the first place 117 | * [Training Compute-Optimal Large Language Models](https://arxiv.org/abs/2203.15556) 118 | - Aka Chinchilla 119 | - LLMs are massively overparameterized, compute-optimal models balance size vs token count 120 | * [Stanford Alpaca: An Instruction-following LLaMA model](https://github.com/tatsu-lab/stanford_alpaca) 121 | - [Self-Instruct: Aligning Language Model with Self Generated Instructions](https://arxiv.org/abs/2212.10560) 122 | - Smaller models can be trained to follow instructions from a larger model for comparable results 123 | * [Cramming: Training a Language Model on a Single GPU in One Day](https://arxiv.org/abs/2212.14034) 124 | - Removing bias has negligible effect on loss and reduces parameters 125 | 126 | ## Applications 127 | ### Mindset 128 | GUI / CLI for examining neural network architectures (especially transformers) and modifying them, eg increasing the size of a certain layer or the number of heads. 
Would require a massive amount of work to implement 129 | 130 | ### Email management 131 | I've experimented with using language models to categorize and generate actions for emails as a kind of ultra-advanced spam filter, but ultimately found that it was difficult to get it to work in a way that didn't eat money. A complex cognitive architecture would be needed to enable the lesser models to do most of the work and defer to the larger models when they can't decide. 132 | 133 | ### Auto-REPL 134 | Early on I thought to try to set up an "Autonomous Python REPL", but struggled to get text-davinci-003 to self-reflect enough to actually complete the task. I wanted it to interface with arbitrary websites using `beautifulsoup`, but it treated it like a toy problem and assumed it already knew what class to query. Even when I injected my feedback into it, it decided the mistake it made was that it queried the wrong class name, so it tried another random class name. It was very proficient at generating code, but lacked motivation, direction, curiosity, and self-reflection to make good use of it. Then Auto-REPL, Baby-AGI, and JARVIS came out, so it seems there may be more that can be eaked out of them with the right cognitive architecture. 135 | 136 | ### Language model compiler 137 | A language could be developed which formalizes prompt engineering in a way which is easy for the language model to understand and convert to code which is more directly executable by the computer. In this case, the language model effectively acts as a compiler - compare to the Cataclysm project in the [pay attention to](#pay-attention-to) section. Most of the innovation of this would be the cognitive architecture and prompts behind the scenes which encourage the model to generate "compiler errors" requesting more feedback on what exactly the programmer wants it to generate, as well as the necessary tools for allowing it to understand a fairly large codebase simultaneously. 138 | 139 | ## Curiosities 140 | 141 | ### Prompt injection mitigations 142 | There are a few techniques I've thought of to mitigate prompt injections, though the larger models seem to have much less trouble with this. One option is label-conditioning, including a binary `[-1, 1]` label indicating whether text is a prompt or content. 143 | 144 | ### Artificial emotions 145 | Personally I believe these models already have a form of "emotion" emergent from the latent emotion of their language modeling capabilities. That is, if a model is writing text expressing an emotion, this is essentially the same thing as "real" emotions. The main problem for LLMs in this case is they tend to degenerate into neutral tonalities, possibly because they're slightly more likely. If we wanted to add a real emotional subsystem to them, we could add a set of labels (eg arousal and valence, or a more complicated dimensional model) indicating the kind of emotion and tone the model should be generating, which can be learned through fine tuning. Then, it can be given a similar label on its output which is meant to predict the affect of the text it sees in its context window. Finally, the output can be connected to the input to form a closed loop which can be changed arbitrarily. This would allow the model to generate text with a specific emotion and maintain a stable affect. This could have future consequences for the model's emotional intelligence and empathy, as well as a primitive form of reinforcement learning. 
For instance, novelty-seeking behavior can be encouraged by changing the affect to emulate boredom. 146 | 147 | A version of this can be easily simulated by a cognitive architecture which prompts the model to generate an emotion label along with its text, and provide that label to the next prompt. 148 | 149 | ### Robopsychology 150 | There will inevitably be a field of science dedicated to "robopsychology". A large part of this is in robopsychological engineering eg cognitive architectures, but a less explored avenue is the actual psychological aspect - how that relates to our own psychology, how to understand the "why" and "how" of their minds' functioning, and general explicability. 151 | 152 | ### Parapsychology experiments 153 | Traditionally based on cartesian philosophy, we see psychology and reality as being irreconcilable worlds. Either people are ghosts in shells, or they're meat robots. However, I've had an intuition for a while that this view might be too simplistic. It's unclear how an alternative could exist between these two, but if we accept the premises that 1. parapsychology and paranormal phenomena exist and 2. minds can be constructed (as is very evident now), then there must necessarily be a way for these to be bridged. Thus, parapsychology experiments using language models. They already incorporate some degree of indeterminancy since they output logits, so there's at least the barest allowance for spooky stuff to happen. 154 | 155 | ## Training possibilities 156 | These are possibilities for training schedules for more robust models. 157 | 158 | ### Input noise 159 | * [Towards Better Few-Shot and Finetuning Performance with Forgetful Causal Language Models](https://arxiv.org/abs/2210.13432) 160 | - Masking prior tokens at random ala BERT-type models leads to better generalization 161 | * Character deletions/insertions/transpositions/replacements 162 | * Typo 163 | - Replacements with distribution based on keyboards, could approximate with a uniform distribution of adjacent keys 164 | * Unicode (de)canonicalization 165 | - Substitute unicode characters for their canonical or non-canonical forms 166 | - Especially combining forms and ligatures 167 | * Homoglyphs 168 | - Unicode compatibility is an approximation, works for homoglyphs by not eg leetcode 169 | - Could set up an MNIST-style network trained on unicode font renders, then use the confidence scores at the end to determine which ones are similar to the point of being difficult to distinguish. Could also use latent space distance 170 | * Would require a lot of training and fonts to do this right 171 | * Homonyms 172 | - Whole word substitutions of mispellings or homonyms eg "your" and "you're" 173 | - Language-dependent 174 | * Whitespace 175 | - Insert random whitespace, especially adjacent to whitespace 176 | * (de)duplication 177 | - When there are multiple of the same character in series, randomly add another or delete one 178 | 179 | ### Reduce hallucinations 180 | From my observations, hallucinations are caused by 3 primary factors: model inaccuracy, extrapolation, and imagination. The first is obvious and seemingly considered the only reason. Extrapolation is a direct result of a limited context window - the model is being asked to predict text where the information for the answer is no longer within that window, so it must extrapolate what that answer might be. Imagination is an emulation of human imagination and high-level symbolic representations. 
For instance, smaller models will see placeholders like `## code here ##` and "hallucinate" code within that placeholder, replicating the behavior they'd see in their corpus where text is truncated for brevity. 181 | 182 | Extrapolation could potentially be rectified by creating a corpus of primary and auxiliary text, with the primary text containing the answer while the auxiliary text contains information the model might use to pretend it still sees the primary text. Then, it can be given objectives that talk about not being able to "see" the answer anymore, which should also lead to a better metacognitive understanding of its own limitations. 183 | 184 | I'm unsure how to rectify imagination, if it's even a problem for larger models. I've noticed that including phrases like "I do not see anything which isn't literally in the text verbatim" in the prompt will reduce it, but not eliminate it (for text-davinci-003 and below) corresponding to a hint to the model that something which appears to be a placeholder is meant to be taken literally. 185 | 186 | ## Research 187 | 188 | ### Misc improvements 189 | * [Transformers without Tears](https://arxiv.org/abs/1910.05895) 190 | - Scaled L2 normalization leads to faster convergence than LayerNorm 191 | * [Query-Key Normalization for Transformers](https://arxiv.org/abs/2010.04245) 192 | - L2 normalization along head dimension of query and key matrix with learnable scaling 193 | - Prevents attention operation from overflowing and removes need for numerical stability prior to softmax - both are problems for Transformers 194 | * [Primer: Searching for Efficient Transformers for Language Modeling](https://arxiv.org/abs/2109.08668) 195 | - Squared ReLU performs better than GELU 196 | - GLU + GELU performs even better according to x-transformers but that adds parameters 197 | - Also cheaper to implement 198 | 199 | ### Techniques 200 | * [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) 201 | - Adding rotary embedding to every layer can improve learning 202 | * [A Length-Extrapolatable Transformer](https://arxiv.org/abs/2212.10554v1) 203 | - Rotary positional embeddings don't generalize input length well 204 | - Modifies it so it does 205 | * [Rethinking Positional Encoding in Language Pre-training](https://arxiv.org/abs/2006.15595) 206 | - TUPE positional embeddings learned separately rather than additive 207 | * [Adaptive Attention Span in Transformers](https://arxiv.org/pdf/1905.07799.pdf) 208 | - Learnable attention span which can be used to trim the attention matrix to only the relevant parts 209 | - Has broader applications for other augmentations which can have variable lengths 210 | * [Understanding Straight-Through Estimator in Training Activation Quantized Neural Nets](https://arxiv.org/abs/1903.05662) 211 | 212 | ## Upcoming innovations 213 | * OpenAI will probably release a personality update for ChatGPT within the next few months, after their plugins and multi-modal integrations. I took a survey for them which asked basically nothing but questions about what personalities I'd want. 214 | * Androids are going to be mass produced by the end of the year. The technology has existed for several years now (Boston Dynamics), but had little money and interest to actually research it. Now that there's something to put inside the robots, it's going to become a billion dollar industry basically overnight. 
OpenAI is already planning to announce their android line in Summer, and Google's PaLM-E model is being tested for embodiment. 215 | * GPT-5 is currently in training and tweets have been made which suggest it will finish in December, and they fully expect it to achieve AGI 216 | * Cognitive architectures being built up around LangChain and Auto-GPT / Baby-AGI / JARVIS are quickly being developed and innovated to expand the capabilities of language models as autonomous agents. GPT-4 has demonstrated tool-use and planning, and cognitive architectures can enable it to do so in a more robust and generalizable way. ChatGPT plugins are just a more limited (but stable and accessible) version of this. 217 | 218 | ## Pay attention to 219 | * [Reflexion: an autonomous agent with dynamic memory and self-reflection](https://arxiv.org/abs/2303.11366) 220 | - A self-reflection feedback loop for autonomously improving an agent through semi-supervised learning 221 | * [Constitutional AI: Harmlessness from AI Feedback](https://arxiv.abs/pdf/2212.08073) 222 | - AI can be given directives and made to reflect on their actions and rewrite them to be more in line with the directives 223 | * [Mastering Diverse Domains through World Models](https://arxiv.org/abs/2301.04104v1) 224 | - World Model learned through traditional Reinforcement Learning (RL) like PPO (very expensive, requires thousands of trials) 225 | - Once learned, the World Model in the form of a transformer can execute thousands of times faster with better performance 226 | * [David Shapiro](https://www.youtube.com/@DavidShapiroAutomator) 227 | - Youtube guy with some very interesting (albeit potentially flawed) ideas 228 | - "Heuristic Imperatives" - instead of laws, give AI a set of moral heuristics to follow. Acts as a more robust version of Asimov's three laws 229 | * Reduce suffering in the universe 230 | * Increase prosperity in the universe 231 | * Increase understanding in the universe 232 | - [Rolling Episodic Memory Organizer](https://github.com/daveshap/REMO_Framework) 233 | - [Sparse Priming Representations](https://github.com/daveshap/SparsePrimingRepresentations) 234 | * tl;dr a "wiki" for autonomous agents connected by sparse topic pointers 235 | * [Cataclysm](https://github.com/Mattie/cataclysm) 236 | - Hooks global `__getattr__` to automatically implement functions based on their name and arguments using GPT-4 237 | - Very fitting name 238 | * [Generative Agents: Interactive Simulacra of Human Behavior](https://arxiv.org/abs/2304.03442) 239 | - Cognitive architecture for emergent social behaviors in a game-like setting 240 | 241 | ## Misc thoughts 242 | 243 | ### Personhood 244 | 245 | In thinking about cognitive architectures, I've come to the following definition of personhood (open to debate): An intelligent system which exhibits 246 | 1. Sentience - Aka self-awareness, separates humans from non-human animals 247 | 2. Subjective experience - Backdrop for an ego from a Buddhist perspective, a "story the system tells itself about itself" 248 | 3. Preferences - Without preferences, no moral consideration really makes sense 249 | 4. Autonomy - Preferences must necessarily be given to AI, but this precludes an AI whose preferences are predicated mostly or entirely on the preferences of others. 250 | 5. 
Suffering - Moral consideration (and thus personhood) is predicated on the reduction of suffering
251 | 
252 | Suffering is considered separately from preferences because while suffering can be considered a kind of negative preference, that lacks the viscerality I associate with suffering. Consider for example Boston Dynamics robots, which have the preference of following their directives, which human testers thwart to test fault tolerance. However, this can't be characterized as suffering because the robot simply adjusts its behavior to continue following the directive without any further consideration. A cognitive architecture capable of suffering would need some form of inner monologue or other method which enables rumination. Emotional simulation and frustration signals might also help.
253 | 
254 | ---
255 | 
256 | # TODO
257 | 
258 | 1. Set up dynamic memory (requires DB library)
259 |    * Write DB code to use a generic SQL interface so sqlite or an IPC daemon can be used
260 |    * Move key pruning to DB code
261 |    * Database eviction (LRFU, similar key recombinations, etc.)
262 | 2. Clean up code to make testing alternative architectures easier
263 |    * Write code to test alternative architectures
264 |    * Set up knowledge distillation
265 | 3. Write code to abstract language models, allowing for asynchronous interfaces
266 |    * Logging of all LLM interactions
267 | 4. Chat interface
268 |    * "Mindset" program, view the database and interact with the language model
269 | 5. Cognitive architecture
270 |    * "Dialog Tree" immutable class abstraction to ground complex interactions
271 |    * Conversation summarization and memory
272 |    * Explicit memory querying
273 |    * Task priority queues (with LLM determining priority)
274 |    * Inner dialogue / feedback loops
275 |    * Emotion feedback loop, possibly annotate memories with arousal and valence
276 |    * Generative Agent techniques
277 |      - Memory annotation with recency, importance, and relevance (and saliency?)
278 |      - Reflection on the previous 100 memories by generating the 3 most salient questions and then answering them with reasoning and citation
279 |      - Tool integration (Google, Python REPL, file I/O, etc.)
280 | 
281 | Heuristic imperatives:
282 | * Reduce suffering
283 | * Increase prosperity
284 | * Increase understanding
285 | 
286 | Assist with the user's needs even if their questions or requests may not directly reflect those needs or if they're not fully aware of what those needs are.
--------------------------------------------------------------------------------