├── compile.py ├── lib_sparse.py └── lib.py /compile.py: -------------------------------------------------------------------------------- 1 | # Load the stravinsky file 2 | from compiler.utils import CompilerError 3 | 4 | with open('sha256_full.strav', 'r') as f: 5 | file = f.readlines() 6 | 7 | print(file) 8 | 9 | # Create output file 10 | f = open("compiled_transformer.py", "w") 11 | 12 | g = open("constants.py", "w") 13 | 14 | f.write("import torch.nn as nn\nimport time\nfrom tqdm import tqdm\nfrom compiler.lib_sparse import *\n\n") 15 | 16 | # Get the main input 17 | input_text = "" 18 | for line in file: 19 | if "INPUT" in line: 20 | input_text = line.split("=")[1].strip() 21 | break 22 | 23 | if input_text == "": 24 | raise CompilerError("No input found. The file must contain a line like `INPUT = 1101011101011010`") 25 | 26 | g.write(f"INPUT_LENGTH = {len(input_text)}\n") 27 | g.close() 28 | 29 | # First, we must determine how many total registers we need 30 | registers = ['tchaikovsky', 'anti_tchaikovsky', 'zeros', 'ones'] 31 | user_defined_constant_register_values = [] 32 | 33 | # Add constant registers 34 | for line in file: 35 | if line[0] == "$": 36 | registers.append(line[1:].split("=")[0].strip()) 37 | user_defined_constant_register_values.append(line[1:].split("=")[1].strip()) 38 | 39 | # Extract the lines of the file between lines containing PROGRAM_START and PROGRAM_END 40 | program_lines = [] 41 | program_started = False 42 | 43 | for line in file: 44 | if "PROGRAM_START" in line: 45 | program_started = True 46 | elif "PROGRAM_END" in line: 47 | program_started = False 48 | elif program_started: 49 | program_lines.append(line) 50 | 51 | for line in program_lines: 52 | if "=" in line: 53 | rname = line.split("=")[0].strip() 54 | if rname not in registers: 55 | registers.append(rname) 56 | 57 | print(registers) 58 | 59 | # Get the list of tokens 60 | # In the file, the line looks like `TOKENS = /0\1\2/` 61 | # We need to convert this to the string `tokens = list('012')` 62 | 63 | tokens = [] 64 | returns = "" 65 | 66 | for line in file: 67 | if "TOKENS" in line: 68 | tokens = line.split("/")[1].split("/")[0].strip().split("\\") 69 | 70 | if line.startswith("RETURNS"): 71 | returns = line.split("RETURNS")[1].strip() 72 | 73 | if len(tokens) == 0: 74 | raise CompilerError("No tokens found. The file must contain a line like `TOKENS = /0\1\2/`") 75 | 76 | f.write("class CompiledTransformer(nn.Module):\n") 77 | f.write(" def __init__(self, *args, **kwargs):\n") 78 | f.write(" super().__init__(*args, **kwargs)\n") 79 | f.write(" self.t = time.time()\n") 80 | f.write(f" self.tokens = {tokens}\n") 81 | f.write(" self.pos = Register('pos', 2)\n") 82 | 83 | # Create the register objects 84 | for register in registers: 85 | f.write(f" self.{register} = Register('{register}', 1)\n") 86 | 87 | num_work_registers = -1 88 | 89 | for line in file: 90 | if "NUM_WORK_REGISTERS" in line: 91 | num_work_registers = int(line.split("=")[1].strip()) 92 | break 93 | 94 | if num_work_registers == -1: 95 | raise CompilerError("No NUM_WORK_REGISTERS found. 
The file must contain a line like `NUM_WORK_REGISTERS = 5`") 96 | 97 | # Create a healthy amount of work registers 98 | f.write(f""" 99 | self.work_registers = [] 100 | for i in range({num_work_registers}): 101 | self.work_registers.append(Register(f'work_{{i}}', len(self.tokens))) 102 | """) 103 | 104 | # Create the main embedding 105 | # embedding = EmbeddedState(tokens, [pos, tchaikovsky, anti_tchaikovsky, zero_register, input_copy, input_copy2, shifted, shiftedl] + work_registers) 106 | 107 | embedding_line = " self.embedding = EmbeddedState(self.tokens, [self.pos, " 108 | for register in registers: 109 | embedding_line += f"self.{register}, " 110 | embedding_line = embedding_line[:-2] + "] + self.work_registers)\n" 111 | 112 | f.write(embedding_line) 113 | 114 | # Now, create the input 115 | # example = embedding.embed(embedding.tokenize('1101011'), [embedding.tokenize('0111111'), embedding.tokenize('1111110')]) 116 | 117 | # Create constant register values 118 | constant_register_values = [] 119 | length = len(input_text) 120 | 121 | constant_register_values.append('0' + '1' * (length - 1)) # tchaikovsky 122 | constant_register_values.append('1' * (length - 1) + '0') # anti_tchaikovsky 123 | constant_register_values.append('0' * length) # zeros 124 | constant_register_values.append('1' * length) # ones 125 | 126 | constant_register_values.extend(user_defined_constant_register_values) 127 | 128 | first_input_line = f" first_input = self.embedding.embed(self.embedding.tokenize(forward_input), [" 129 | 130 | for register_value in constant_register_values: 131 | first_input_line += f"self.embedding.itokenize('{register_value}'), " 132 | 133 | first_input_line = first_input_line[:-2] + "])\n" 134 | 135 | # Now, we actually create the program. Each template below uses positional placeholders <a>, <b>, <c>, ... which get_template fills in order with the parsed arguments (the assignment destination, if any, is appended last). 136 | 137 | func_templates = { 138 | "copy": "Copy(self.embedding, self.pos, self.<a>, self.<b>)", 139 | "copy_input": "ConvertToInternal(self.embedding, self.<a>)", 140 | "keep_": "Keep(self.<a>, self.<b>)", 141 | "rotate_": "Rotate(self.embedding, self.pos, self.tchaikovsky, self.anti_tchaikovsky, self.<a>, <b>, self.work_registers)", 142 | "rotate_with_limit_": "RotateWithLimit(self.embedding, self.pos, self.tchaikovsky, self.anti_tchaikovsky, self.<a>, <b>, <c>, self.work_registers)", 143 | "shiftr_": "Shift(self.embedding, self.pos, self.tchaikovsky, self.<a>, <b>, self.work_registers)", 144 | "shiftl_": "ShiftL(self.embedding, self.pos, self.anti_tchaikovsky, self.<a>, <b>, self.work_registers)", 145 | "xor": "XOR(self.embedding, self.pos, self.<a>, self.<b>, self.<c>, self.work_registers)", 146 | "and": "AND(self.embedding, self.<a>, self.<b>, self.<c>)", 147 | "not_": "NOT(self.embedding, self.pos, self.<a>, self.work_registers)", 148 | "print_": "Print(self.embedding, self.<a>)", 149 | "add": "Add(self.embedding, self.pos, self.anti_tchaikovsky, self.<a>, self.<b>, self.<c>, self.work_registers)" 150 | } 151 | 152 | 153 | def get_template(func_name, args): 154 | if func_name == "shift_": 155 | if int(args[1]) >= 0: 156 | func_name = "shiftr_" 157 | else: 158 | func_name = "shiftl_" 159 | 160 | template = func_templates.get(func_name) 161 | if template is None: 162 | raise CompilerError(f"Unknown function {func_name}") 163 | 164 | for i, arg in enumerate(args): 165 | template = template.replace(f"<{chr(97 + i)}>", arg) 166 | 167 | return template 168 | 169 | 170 | f.write(f" self.pbar = tqdm(total={len(program_lines)}, leave=False)\n") 171 | 172 | real_index = 0 173 | for idx, line in enumerate(program_lines): 174 | f.write(" self.pbar.update(1)\n") 175 | # First, check if there's an assignment happening. 
176 | if "=" in line: 177 | # In this case, split the line into the destination and the function 178 | dest, func = line.split("=") 179 | 180 | # Remove whitespace 181 | dest = dest.strip() 182 | func = func.strip() 183 | 184 | # Get the function name and arguments 185 | func_name = func.split("(")[0] 186 | args = func.split("(")[1].split(")")[0].split(",") 187 | 188 | # Remove whitespace from the arguments 189 | args = [arg.strip() for arg in args] 190 | # Get rid of empty arguments 191 | args = [arg for arg in args if arg != ""] 192 | # Add the destination to the arguments 193 | args.append(dest) 194 | 195 | # Write the actual function call 196 | template = get_template(func_name, args) 197 | 198 | f.write(f" self.op_{real_index} = {template}\n") 199 | real_index += 1 200 | elif len(line) > 2 and not line.strip().startswith("%"): 201 | # Otherwise, it's an in-place operation 202 | # Get the function name and arguments 203 | func_name = line.split("(")[0] 204 | args = line.split("(")[1].split(")")[0].split(",") 205 | 206 | # Remove whitespace from the arguments 207 | args = [arg.strip() for arg in args] 208 | # Get rid of empty arguments 209 | args = [arg for arg in args if arg != ""] 210 | 211 | if func_name == "del": 212 | # Special case for del 213 | ls = "[" 214 | for arg in args: 215 | ls += "self." + arg + ", " 216 | ls = ls[:-2] + "]" 217 | f.write(f" self.op_{real_index} = Clear(self.embedding, {ls})\n") 218 | else: 219 | # Write the actual function call 220 | template = get_template(func_name, args) 221 | 222 | f.write(f" self.op_{real_index} = {template}\n") 223 | 224 | real_index += 1 225 | 226 | f.write(" self.pbar.close()\n") 227 | f.write(" print('ok we done')\n") 228 | 229 | all_ops_string = ",\n ".join([f"self.op_{i}" for i in range(real_index)]) 230 | f.write(f""" 231 | self.all_ops = [ 232 | {all_ops_string} 233 | ] 234 | """) 235 | f.write(""" 236 | def count_parameters(module): 237 | total_params = 0 238 | for _, param in module.state_dict().items(): 239 | total_params += param.numel() 240 | 241 | for attr_name in dir(module): 242 | attr = getattr(module, attr_name) 243 | if isinstance(attr, torch.nn.Module): 244 | total_params += count_parameters(attr) 245 | elif isinstance(attr, list): 246 | for item in attr: 247 | if isinstance(item, torch.nn.Module): 248 | total_params += count_parameters(item) 249 | 250 | return total_params 251 | 252 | self.s = 0 253 | 254 | for i, module in enumerate(self.all_ops): 255 | total_params = count_parameters(module) 256 | print(f"op_{i}: Total parameters = {total_params}") 257 | self.s += total_params 258 | 259 | print(f"Total parameters: {self.s}") 260 | """) 261 | 262 | f.write(""" 263 | elapsed = time.time() - self.t 264 | print(f"Elapsed time: {elapsed:.2f}s") 265 | """) 266 | 267 | f.write(""" 268 | def forward(self, forward_input): 269 | """) 270 | 271 | f.write(first_input_line) 272 | 273 | f.write(""" 274 | x = self.op_0.forward(first_input.unsqueeze(0))[0] 275 | # plot_tensor(x, embedding, 'op_0') 276 | """) 277 | for i in range(1, real_index): 278 | f.write(f" x = self.op_{i}.forward(x.unsqueeze(0))[0]\n") 279 | # f.write(f"plot_tensor(x, embedding, 'op_{i}')\n") 280 | 281 | f.write(f" return x[:, self.{returns}.offset:compiled_transformer.{returns}.offset + self.{returns}.size].flatten()\n") 282 | 283 | f.write("\n\n\n\n") 284 | f.write("compiled_transformer = CompiledTransformer()\n") 285 | f.write(f"print(compiled_transformer('{input_text}'))\n") 286 | 287 | f.close() 288 | 
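# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the repository): roughly what compile.py
# emits into compiled_transformer.py for a hypothetical 4-bit source program
#
#   INPUT = 1101
#   TOKENS = /0\1/
#   NUM_WORK_REGISTERS = 4
#   $a = 0101
#   $b = 0011
#   PROGRAM_START
#   c = xor(a, b)
#   print_(c)
#   PROGRAM_END
#   RETURNS c
#
# The register names, token set and constant values are invented for this
# example. compile.py would also write constants.py with INPUT_LENGTH = 4; the
# real output additionally contains the tqdm progress bar, the
# parameter-counting block and the timing print, which are omitted here.
import torch.nn as nn
import time
from compiler.lib_sparse import *


class CompiledTransformer(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.t = time.time()
        self.tokens = ['0', '1']
        self.pos = Register('pos', 2)
        # one size-1 register per built-in constant, $-constant and assignment destination
        self.tchaikovsky = Register('tchaikovsky', 1)
        self.anti_tchaikovsky = Register('anti_tchaikovsky', 1)
        self.zeros = Register('zeros', 1)
        self.ones = Register('ones', 1)
        self.a = Register('a', 1)
        self.b = Register('b', 1)
        self.c = Register('c', 1)

        self.work_registers = []
        for i in range(4):
            self.work_registers.append(Register(f'work_{i}', len(self.tokens)))

        self.embedding = EmbeddedState(self.tokens, [self.pos, self.tchaikovsky, self.anti_tchaikovsky, self.zeros, self.ones, self.a, self.b, self.c] + self.work_registers)

        # one op per program line, instantiated from func_templates
        self.op_0 = XOR(self.embedding, self.pos, self.a, self.b, self.c, self.work_registers)
        self.op_1 = Print(self.embedding, self.c)

        self.all_ops = [
            self.op_0,
            self.op_1
        ]

    def forward(self, forward_input):
        # built-in constants (tchaikovsky, anti_tchaikovsky, zeros, ones) followed
        # by the user-defined $-constants, in register order
        first_input = self.embedding.embed(self.embedding.tokenize(forward_input), [self.embedding.itokenize('0111'), self.embedding.itokenize('1110'), self.embedding.itokenize('0000'), self.embedding.itokenize('1111'), self.embedding.itokenize('0101'), self.embedding.itokenize('0011')])
        x = self.op_0.forward(first_input.unsqueeze(0))[0]
        x = self.op_1.forward(x.unsqueeze(0))[0]
        return x[:, self.c.offset:self.c.offset + self.c.size].flatten()


compiled_transformer = CompiledTransformer()
print(compiled_transformer('1101'))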
-------------------------------------------------------------------------------- /lib_sparse.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn 5 | import torch.nn.functional as F 6 | 7 | from constants import INPUT_LENGTH 8 | 9 | POS_STEP = 1e-3 10 | 11 | class Register(object): 12 | def __init__(self, name, size): 13 | self.name = name 14 | self.size = size 15 | self.offset = None 16 | 17 | class EmbeddedState(object): 18 | def __init__(self, tokens: list[str], registers: list[Register]): 19 | self.tokens = tokens 20 | self.token_map = {t: i for i, t in enumerate(tokens)} 21 | self.registers = registers 22 | self.register_map = {} 23 | self.register_size = 0 24 | 25 | if len(registers) == 0 or registers[0].name != 'pos': 26 | raise Exception("First register must be 'pos'") 27 | 28 | offset = len(tokens) 29 | for reg in registers: 30 | reg.offset = offset 31 | offset += reg.size 32 | self.register_size += reg.size 33 | self.register_map[reg.name] = reg 34 | 35 | self.dim = len(tokens) + self.register_size 36 | 37 | def tokenize(self, string: str): 38 | return F.one_hot(torch.tensor([self.token_map[c] for c in string]), num_classes=len(self.tokens)).float() 39 | 40 | def itokenize(self, string: str): 41 | return torch.tensor([self.token_map[c] for c in string]).float().unsqueeze(1) 42 | 43 | def embed(self, sequence, additional_constants): 44 | extension_tensor = torch.zeros(*sequence.shape[:-1], self.register_size) 45 | 46 | for i in range(sequence.shape[0]): 47 | extension_tensor[i, 0] = math.sin(i * (2 * math.pi) * POS_STEP) 48 | extension_tensor[i, 1] = math.cos(i * (2 * math.pi) * POS_STEP) 49 | 50 | offset = 2 51 | for constant in additional_constants: 52 | extension_tensor[:, offset:offset + constant.shape[-1]] = constant 53 | offset += constant.shape[-1] 54 | 55 | sequence = torch.cat((sequence, extension_tensor), dim=-1) 56 | 57 | return sequence 58 | 59 | def predict(self, sequence): 60 | return self.tokens[torch.argmax(sequence[-1, :len(self.tokens)])] 61 | 62 | class AttentionLayer(torch.nn.Module): 63 | def __init__(self, instruction): 64 | super(AttentionLayer, self).__init__() 65 | 66 | self.key = torch.nn.Parameter(instruction.key) 67 | self.value = torch.nn.Parameter(instruction.value) 68 | self.query = torch.nn.Parameter(instruction.query) 69 | 70 | self.mask = instruction.mask 71 | 72 | self.softmax = torch.nn.Softmax(2) 73 | 74 | def forward(self, seq): 75 | batch_size, seq_length, dim = seq.shape 76 | 77 | query = seq @ self.query 78 | key = seq @ self.key 79 | value = seq @ self.value 80 | 81 | causal_mask = (torch.triu(torch.ones(seq_length, seq_length), diagonal=1) * 0).to(seq.device) 82 | norm = np.sqrt(dim) 83 | 84 | kq = self.softmax(query @ key.transpose(-2, -1) / norm + causal_mask) 85 | 86 | s = (kq @ value) * self.mask 87 | 88 | return (seq + s) 89 | 90 | def reset(self): 91 | torch.nn.init.xavier_uniform_(self.key) 92 | torch.nn.init.xavier_uniform_(self.query) 93 | torch.nn.init.xavier_uniform_(self.value) 94 | 95 | class GetRelativeToken(AttentionLayer): 96 | def __init__(self, embedding: EmbeddedState, pos_reg: Register, steps: int, out: Register): 97 | tpos_reg = embedding.register_map['pos'] 98 | 99 | indices = torch.tensor([[tpos_reg.offset, tpos_reg.offset], 100 | [tpos_reg.offset + 1, tpos_reg.offset + 1]]) 101 | values = torch.tensor([1e10, 1e10]) 102 | position_select = torch.sparse_coo_tensor(indices.t(), values, (embedding.dim, 
embedding.dim)) 103 | 104 | i = -steps 105 | sin = math.sin(i * (2 * math.pi) * POS_STEP) 106 | cos = math.cos(i * (2 * math.pi) * POS_STEP) 107 | 108 | rotation_indices = torch.tensor([ 109 | [pos_reg.offset, tpos_reg.offset], 110 | [pos_reg.offset + 1, tpos_reg.offset], 111 | [pos_reg.offset, tpos_reg.offset + 1], 112 | [pos_reg.offset + 1, tpos_reg.offset + 1] 113 | ]) 114 | rotation_values = torch.tensor([cos, -sin, sin, cos]) 115 | rotation = torch.sparse_coo_tensor(rotation_indices.t(), rotation_values, (embedding.dim, embedding.dim)) 116 | 117 | token_copy_indices = torch.tensor([[i, i + out.offset] for i in range(len(embedding.tokens))]) 118 | token_copy_values = torch.ones(len(embedding.tokens)) 119 | token_copy = torch.sparse_coo_tensor(token_copy_indices.t(), token_copy_values, (embedding.dim, embedding.dim)) 120 | 121 | self.query = rotation 122 | self.key = position_select 123 | self.value = token_copy 124 | 125 | self.mask = torch.zeros(embedding.dim) 126 | self.mask[out.offset:out.offset + out.size] = 1.0 127 | 128 | super(GetRelativeToken, self).__init__(self) 129 | 130 | class MLPLayer(torch.nn.Module): 131 | def __init__(self, first_weights, first_bias, second_weights, second_bias, mask, debug=False): 132 | super(MLPLayer, self).__init__() 133 | self.debug = debug 134 | 135 | self.first_weights = torch.nn.Parameter(first_weights) 136 | self.first_bias = torch.nn.Parameter(first_bias) 137 | self.second_weights = torch.nn.Parameter(second_weights) 138 | self.second_bias = torch.nn.Parameter(second_bias) 139 | 140 | self.gelu = torch.nn.ReLU() 141 | 142 | self.mask = mask 143 | 144 | def forward(self, seq): 145 | a = self.gelu(seq @ self.first_weights + self.first_bias) 146 | b = (a @ self.second_weights) 147 | x = b + self.second_bias 148 | return seq + (x * self.mask) 149 | 150 | def reset(self): 151 | torch.nn.init.xavier_uniform_(self.first_weights) 152 | torch.nn.init.zeros_(self.first_bias) 153 | torch.nn.init.xavier_uniform_(self.second_weights) 154 | torch.nn.init.zeros_(self.second_bias) 155 | 156 | class ConvertToInternal(MLPLayer): 157 | def __init__(self, embedding: EmbeddedState, out: Register): 158 | indices = torch.tensor([[1, out.offset]]) 159 | values = torch.tensor([1.0]) 160 | first_weights = torch.sparse_coo_tensor(indices.t(), values, (embedding.dim, embedding.dim)) 161 | first_bias = torch.zeros(embedding.dim) 162 | 163 | second_weights_indices = torch.arange(embedding.dim).repeat(2, 1) 164 | second_weights_values = torch.ones(embedding.dim) 165 | second_weights = torch.sparse_coo_tensor(second_weights_indices, second_weights_values, (embedding.dim, embedding.dim)) 166 | 167 | second_bias = torch.zeros(embedding.dim) 168 | 169 | mask = torch.zeros(embedding.dim) 170 | mask[out.offset:out.offset + out.size] = 1.0 171 | 172 | super(ConvertToInternal, self).__init__(first_weights, first_bias, second_weights, second_bias, mask) 173 | 174 | class GRLT2(AttentionLayer): 175 | def __init__(self, embedding: EmbeddedState, pos_reg: Register, steps: int, copy_from: Register, out: Register): 176 | tpos_reg = embedding.register_map['pos'] 177 | 178 | indices = torch.tensor([[tpos_reg.offset, tpos_reg.offset], 179 | [tpos_reg.offset + 1, tpos_reg.offset + 1]]) 180 | values = torch.tensor([1e10, 1e10]) 181 | position_select = torch.sparse_coo_tensor(indices.t(), values, (embedding.dim, embedding.dim)) 182 | 183 | i = -steps 184 | sin = math.sin(i * (2 * math.pi) * POS_STEP) 185 | cos = math.cos(i * (2 * math.pi) * POS_STEP) 186 | 187 | rotation_indices = 
torch.tensor([ 188 | [pos_reg.offset, tpos_reg.offset], 189 | [pos_reg.offset + 1, tpos_reg.offset], 190 | [pos_reg.offset, tpos_reg.offset + 1], 191 | [pos_reg.offset + 1, tpos_reg.offset + 1] 192 | ]) 193 | rotation_values = torch.tensor([cos, -sin, sin, cos]) 194 | rotation = torch.sparse_coo_tensor(rotation_indices.t(), rotation_values, (embedding.dim, embedding.dim)) 195 | 196 | token_copy_indices = torch.tensor([[copy_from.offset, out.offset]]) 197 | token_copy_values = torch.tensor([1.0]) 198 | token_copy = torch.sparse_coo_tensor(token_copy_indices.t(), token_copy_values, (embedding.dim, embedding.dim)) 199 | 200 | self.query = rotation 201 | self.key = position_select 202 | self.value = token_copy 203 | 204 | self.mask = torch.zeros(embedding.dim) 205 | self.mask[out.offset:out.offset + out.size] = 1.0 206 | 207 | super(GRLT2, self).__init__(self) 208 | 209 | class AND(MLPLayer): 210 | def __init__(self, embedding: EmbeddedState, first_reg: Register, second_reg: Register, result_reg: Register): 211 | indices = torch.tensor([ 212 | [first_reg.offset, result_reg.offset], 213 | [second_reg.offset, result_reg.offset] 214 | ]) 215 | values = torch.tensor([1.0, 1.0]) 216 | first_weights = torch.sparse_coo_tensor(indices.t(), values, (embedding.dim, embedding.dim)) 217 | first_bias = torch.zeros(embedding.dim) 218 | first_bias[result_reg.offset:result_reg.offset + result_reg.size] = -1.0 219 | 220 | second_weights_indices = torch.arange(embedding.dim).repeat(2, 1) 221 | second_weights_values = torch.ones(embedding.dim) 222 | second_weights = torch.sparse_coo_tensor(second_weights_indices, second_weights_values, (embedding.dim, embedding.dim)) 223 | 224 | second_bias = torch.zeros(embedding.dim) 225 | 226 | mask = torch.zeros(embedding.dim) 227 | mask[result_reg.offset:result_reg.offset + result_reg.size] = 1.0 228 | 229 | super(AND, self).__init__(first_weights, first_bias, second_weights, second_bias, mask) 230 | 231 | class Clear(MLPLayer): 232 | def __init__(self, embedding: EmbeddedState, registers: list[Register]): 233 | indices = [] 234 | values = [] 235 | for reg in registers: 236 | for i in range(reg.size): 237 | indices.append([reg.offset + i, reg.offset + i]) 238 | values.append(100.0) 239 | first_weights = torch.sparse_coo_tensor(torch.tensor(indices).t(), torch.tensor(values), (embedding.dim, embedding.dim)) 240 | first_bias = torch.zeros(embedding.dim) 241 | 242 | indices = [] 243 | values = [] 244 | for reg in registers: 245 | for i in range(reg.size): 246 | indices.append([reg.offset + i, reg.offset + i]) 247 | values.append(-0.01) 248 | second_weights = torch.sparse_coo_tensor(torch.tensor(indices).t(), torch.tensor(values), (embedding.dim, embedding.dim)) 249 | second_bias = torch.zeros(embedding.dim) 250 | 251 | mask = torch.zeros(embedding.dim) 252 | for reg in registers: 253 | mask[reg.offset:reg.offset + reg.size] = 1.0 254 | 255 | super(Clear, self).__init__(first_weights, first_bias, second_weights, second_bias, mask) 256 | 257 | class Copy(torch.nn.Module): 258 | def __init__(self, embedding: EmbeddedState, pos_reg: Register, copy_from: Register, copy_to: Register): 259 | super(Copy, self).__init__() 260 | 261 | self.copy = GRLT2(embedding, pos_reg, 0, copy_from, copy_to) 262 | 263 | def forward(self, seq): 264 | return self.copy.forward(seq) 265 | 266 | class Shift(torch.nn.Module): 267 | def __init__(self, embedding: EmbeddedState, pos: Register, tchaikovsky: Register, register_to_shift: Register, 268 | amount: int, work_registers: list[Register]): 269 | 
super(Shift, self).__init__() 270 | 271 | self.embedding = embedding 272 | 273 | self.shiftpt1 = GRLT2(embedding, pos, -amount, register_to_shift, work_registers[0]) 274 | self.clear = Clear(embedding, [register_to_shift]) 275 | self.shifted_tchaikovsky = GRLT2(embedding, pos, -(amount - 1), tchaikovsky, work_registers[1]) 276 | self.shiftpt2 = AND(embedding, work_registers[1], work_registers[0], register_to_shift) 277 | self.cleannup = Clear(embedding, work_registers) 278 | 279 | def forward(self, seq): 280 | x = self.shiftpt1.forward(seq) 281 | x = self.clear.forward(x) 282 | x = self.shifted_tchaikovsky.forward(x) 283 | x = self.shiftpt2.forward(x) 284 | x = self.cleannup.forward(x) 285 | return x 286 | 287 | class ShiftL(torch.nn.Module): 288 | def __init__(self, embedding: EmbeddedState, pos: Register, anti_tchaikovsky: Register, register_to_shift: Register, 289 | amount: int, work_registers: list[Register]): 290 | super(ShiftL, self).__init__() 291 | 292 | self.embedding = embedding 293 | 294 | self.shiftpt1 = GRLT2(embedding, pos, amount, register_to_shift, work_registers[0]) 295 | self.clear = Clear(embedding, [register_to_shift]) 296 | self.shifted_antitchaikovsky = GRLT2(embedding, pos, amount - 1, anti_tchaikovsky, work_registers[1]) 297 | self.shiftpt2 = AND(embedding, work_registers[1], work_registers[0], register_to_shift) 298 | self.cleannup = Clear(embedding, work_registers) 299 | 300 | def forward(self, seq): 301 | x = self.shiftpt1.forward(seq) 302 | x = self.clear.forward(x) 303 | x = self.shifted_antitchaikovsky.forward(x) 304 | x = self.shiftpt2.forward(x) 305 | x = self.cleannup.forward(x) 306 | return x 307 | 308 | class NOT_To(MLPLayer): 309 | def __init__(self, embedding: EmbeddedState, from_reg: Register, result_reg: Register): 310 | indices = torch.tensor([[from_reg.offset, result_reg.offset]]) 311 | values = torch.tensor([-1.0]) 312 | first_weights = torch.sparse_coo_tensor(indices.t(), values, (embedding.dim, embedding.dim)) 313 | first_bias = torch.zeros(embedding.dim) 314 | first_bias[result_reg.offset:result_reg.offset + result_reg.size] = 1.0 315 | 316 | second_weights_indices = torch.arange(embedding.dim).repeat(2, 1) 317 | second_weights_values = torch.ones(embedding.dim) 318 | second_weights = torch.sparse_coo_tensor(second_weights_indices, second_weights_values, (embedding.dim, embedding.dim)) 319 | 320 | second_bias = torch.zeros(embedding.dim) 321 | 322 | mask = torch.zeros(embedding.dim) 323 | mask[result_reg.offset:result_reg.offset + result_reg.size] = 1.0 324 | 325 | super(NOT_To, self).__init__(first_weights, first_bias, second_weights, second_bias, mask) 326 | 327 | class NOT(torch.nn.Module): 328 | def __init__(self, embedding: EmbeddedState, pos: Register, register: Register, work_registers: list[Register]): 329 | super(NOT, self).__init__() 330 | 331 | self.not_to = NOT_To(embedding, register, work_registers[0]) 332 | self.clear = Clear(embedding, [register]) 333 | self.copy = Copy(embedding, pos, work_registers[0], register) 334 | self.clear2 = Clear(embedding, work_registers) 335 | 336 | def forward(self, seq): 337 | x = self.not_to.forward(seq) 338 | x = self.clear.forward(x) 339 | x = self.copy.forward(x) 340 | x = self.clear2.forward(x) 341 | return x 342 | 343 | class NOR(MLPLayer): 344 | def __init__(self, embedding: EmbeddedState, first_reg: Register, second_reg: Register, result_reg: Register): 345 | indices = torch.tensor([ 346 | [first_reg.offset, result_reg.offset], 347 | [second_reg.offset, result_reg.offset] 348 | ]) 349 | 
values = torch.tensor([-1.0, -1.0]) 350 | first_weights = torch.sparse_coo_tensor(indices.t(), values, (embedding.dim, embedding.dim)) 351 | first_bias = torch.zeros(embedding.dim) 352 | first_bias[result_reg.offset:result_reg.offset + result_reg.size] = 1.0 353 | 354 | second_weights_indices = torch.arange(embedding.dim).repeat(2, 1) 355 | second_weights_values = torch.ones(embedding.dim) 356 | second_weights = torch.sparse_coo_tensor(second_weights_indices, second_weights_values, (embedding.dim, embedding.dim)) 357 | 358 | second_bias = torch.zeros(embedding.dim) 359 | 360 | mask = torch.zeros(embedding.dim) 361 | mask[result_reg.offset:result_reg.offset + result_reg.size] = 1.0 362 | 363 | super(NOR, self).__init__(first_weights, first_bias, second_weights, second_bias, mask) 364 | 365 | class OR(torch.nn.Module): 366 | def __init__(self, embedding: EmbeddedState, pos: Register, first_reg: Register, second_reg: Register, 367 | result_reg: Register, work_registers: list[Register]): 368 | super(OR, self).__init__() 369 | 370 | self.part1 = NOR(embedding, first_reg, second_reg, result_reg) 371 | self.part2 = NOT(embedding, pos, result_reg, work_registers) 372 | self.cleanup = Clear(embedding, work_registers) 373 | 374 | def forward(self, seq): 375 | x = self.part1.forward(seq) 376 | x = self.part2.forward(x) 377 | x = self.cleanup.forward(x) 378 | return x 379 | 380 | class XOR(torch.nn.Module): 381 | def __init__(self, embedding: EmbeddedState, pos: Register, first_reg: Register, second_reg: Register, 382 | result_reg: Register, work_registers: list[Register]): 383 | super(XOR, self).__init__() 384 | 385 | self.part1 = OR(embedding, pos, first_reg, second_reg, work_registers[0], work_registers[1:]) 386 | 387 | self.part2 = AND(embedding, first_reg, second_reg, work_registers[1]) 388 | 389 | self.part3 = NOT(embedding, pos, work_registers[1], work_registers[2:]) 390 | 391 | self.part4 = AND(embedding, work_registers[0], work_registers[1], result_reg) 392 | 393 | self.part5 = Clear(embedding, work_registers) 394 | 395 | def forward(self, seq): 396 | x = self.part1.forward(seq) 397 | x = self.part2.forward(x) 398 | x = self.part3.forward(x) 399 | x = self.part4.forward(x) 400 | x = self.part5.forward(x) 401 | return x 402 | 403 | class Rotate(torch.nn.Module): 404 | def __init__(self, embedding: EmbeddedState, pos: Register, tchaikovsky: Register, anti_tchaikovsky: Register, 405 | register_to_rotate: Register, amount: int, work_registers: list[Register]): 406 | super(Rotate, self).__init__() 407 | 408 | self.embedding = embedding 409 | 410 | self.copies = [Copy(embedding, pos, register_to_rotate, work_registers[i]) for i in range(2)] 411 | 412 | self.shift_right = Shift(embedding, pos, tchaikovsky, work_registers[0], amount, work_registers[2:]) 413 | 414 | self.left_shifts = ShiftL(embedding, pos, anti_tchaikovsky, work_registers[1], (INPUT_LENGTH - amount), 415 | work_registers[2:]) 416 | 417 | self.clear = Clear(embedding, [register_to_rotate]) 418 | 419 | self.or_result = OR(embedding, pos, work_registers[0], work_registers[1], register_to_rotate, 420 | work_registers[2:]) 421 | 422 | self.clear_work = Clear(embedding, work_registers) 423 | 424 | def forward(self, seq): 425 | for copy in self.copies: 426 | seq = copy.forward(seq) 427 | seq = self.shift_right.forward(seq) 428 | seq = self.left_shifts.forward(seq) 429 | seq = self.clear.forward(seq) 430 | seq = self.or_result.forward(seq) 431 | seq = self.clear_work.forward(seq) 432 | return seq 433 | 434 | class Add(torch.nn.Module): 435 | 
def __init__(self, embedding: EmbeddedState, pos: Register, anti_tchaikovsky: Register, a: Register, b: Register, 436 | result: Register, work_registers: list[Register]): 437 | super(Add, self).__init__() 438 | 439 | self.embedding = embedding 440 | 441 | self.first_sum = XOR(embedding, pos, a, b, work_registers[0], work_registers[2:]) 442 | self.first_carry = AND(embedding, a, b, work_registers[1]) 443 | 444 | self.next_operations = [] 445 | 446 | for _ in range(32): 447 | copy_of_carry = Copy(embedding, pos, work_registers[1], work_registers[2]) 448 | shifted_carry = ShiftL(embedding, pos, anti_tchaikovsky, work_registers[2], 1, work_registers[3:]) 449 | new_sum = XOR(embedding, pos, work_registers[0], work_registers[2], work_registers[3], work_registers[4:]) 450 | clear_carry = Clear(embedding, [work_registers[1]]) 451 | carry = AND(embedding, work_registers[0], work_registers[2], work_registers[1]) 452 | clear_sum = Clear(embedding, [work_registers[0]]) 453 | sum = Copy(embedding, pos, work_registers[3], work_registers[0]) 454 | clear_work = Clear(embedding, work_registers[2:]) 455 | self.next_operations.append( 456 | (copy_of_carry, shifted_carry, new_sum, clear_carry, carry, clear_sum, sum, clear_work)) 457 | 458 | self.copy_to_result = Copy(embedding, pos, work_registers[0], result) 459 | 460 | self.clear = Clear(embedding, work_registers) 461 | 462 | def forward(self, seq): 463 | x = self.first_sum.forward(seq) 464 | x = self.first_carry.forward(x) 465 | 466 | for copy_of_carry, shifted_carry, new_sum, clear_carry, carry, clear_sum, sum, clear_work in self.next_operations: 467 | x = copy_of_carry.forward(x) 468 | x = shifted_carry.forward(x) 469 | x = new_sum.forward(x) 470 | x = clear_carry.forward(x) 471 | x = carry.forward(x) 472 | x = clear_sum.forward(x) 473 | x = sum.forward(x) 474 | x = clear_work.forward(x) 475 | 476 | x = self.copy_to_result.forward(x) 477 | x = self.clear.forward(x) 478 | 479 | return x 480 | 481 | class RotateWithLimit(torch.nn.Module): 482 | def __init__(self, embedding: EmbeddedState, pos: Register, tchaikovsky: Register, anti_tchaikovsky: Register, 483 | register_to_rotate: Register, amount: int, limit: int, work_registers: list[Register]): 484 | super(RotateWithLimit, self).__init__() 485 | 486 | self.embedding = embedding 487 | 488 | self.copies = [Copy(embedding, pos, register_to_rotate, work_registers[i]) for i in range(2)] 489 | 490 | self.shift_right = Shift(embedding, pos, tchaikovsky, work_registers[0], amount, work_registers[2:]) 491 | 492 | self.left_shifts = ShiftL(embedding, pos, anti_tchaikovsky, work_registers[1], limit - amount, 493 | work_registers[2:]) 494 | 495 | self.clear_shift_1 = Shift(embedding, pos, tchaikovsky, work_registers[1], INPUT_LENGTH - amount, 496 | work_registers[2:]) 497 | self.clear_shift_2 = ShiftL(embedding, pos, anti_tchaikovsky, work_registers[1], 498 | INPUT_LENGTH - amount, work_registers[2:]) 499 | 500 | self.clear = Clear(embedding, [register_to_rotate]) 501 | 502 | self.or_result = OR(embedding, pos, work_registers[0], work_registers[1], register_to_rotate, 503 | work_registers[2:]) 504 | 505 | self.clear_work = Clear(embedding, work_registers) 506 | 507 | def forward(self, seq): 508 | for copy in self.copies: 509 | seq = copy.forward(seq) 510 | seq = self.shift_right.forward(seq) 511 | seq = self.left_shifts.forward(seq) 512 | seq = self.clear_shift_1.forward(seq) 513 | seq = self.clear_shift_2.forward(seq) 514 | seq = self.clear.forward(seq) 515 | seq = self.or_result.forward(seq) 516 | seq = 
self.clear_work.forward(seq) 517 | return seq 518 | 519 | class Print(torch.nn.Module): 520 | def __init__(self, embedding: EmbeddedState, register: Register): 521 | super(Print, self).__init__() 522 | 523 | self.embedding = embedding 524 | self.register = register 525 | 526 | def forward(self, seq): 527 | print(''.join( 528 | str(c) for c in [int(q) for q in seq[0, :, self.register.offset:self.register.offset + self.register.size] 529 | .detach().flatten()])) 530 | 531 | return seq 532 | -------------------------------------------------------------------------------- /lib.py: -------------------------------------------------------------------------------- 1 | # Various imports 2 | 3 | import math 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn 8 | import torch.nn.functional as F 9 | 10 | from constants import INPUT_LENGTH 11 | 12 | # A constant which determines how fine-grained our positional embedding is 13 | POS_STEP = 1e-3 14 | 15 | 16 | class Register(object): 17 | def __init__(self, name, size): 18 | self.name = name 19 | self.size = size 20 | self.offset = None 21 | 22 | 23 | class EmbeddedState(object): 24 | def __init__(self, tokens: list[str], registers: list[Register]): 25 | self.tokens = tokens 26 | self.token_map = {t: i for i, t in enumerate(tokens)} 27 | self.registers = registers 28 | self.register_map = {} 29 | self.register_size = 0 30 | 31 | if len(registers) == 0 or registers[0].name != 'pos': 32 | raise Exception("First register must be 'pos'") 33 | 34 | offset = len(tokens) 35 | for reg in registers: 36 | reg.offset = offset 37 | offset += reg.size 38 | self.register_size += reg.size 39 | self.register_map[reg.name] = reg 40 | 41 | self.dim = len(tokens) + self.register_size 42 | 43 | def tokenize(self, string: str): 44 | return F.one_hot(torch.tensor([self.token_map[c] for c in string]), num_classes=len(self.tokens)).float() 45 | 46 | def itokenize(self, string: str): 47 | return torch.tensor([self.token_map[c] for c in string]).float().unsqueeze(1) 48 | 49 | def embed(self, sequence, additional_constants): 50 | # We want to create additional space to store the registers 51 | extension_tensor = torch.zeros(*sequence.shape[:-1], self.register_size) 52 | 53 | # Encode position in the first extra embedding dimension 54 | for i in range(sequence.shape[0]): 55 | extension_tensor[i, 0] = math.sin(i * (2 * math.pi) * POS_STEP) 56 | extension_tensor[i, 1] = math.cos(i * (2 * math.pi) * POS_STEP) 57 | 58 | # Next columns of the extension tensor are the additional constants 59 | offset = 2 60 | for constant in additional_constants: 61 | extension_tensor[:, offset:offset + constant.shape[-1]] = constant 62 | offset += constant.shape[-1] 63 | 64 | sequence = torch.cat((sequence, extension_tensor), dim=-1) 65 | 66 | return sequence 67 | 68 | def predict(self, sequence): 69 | return self.tokens[torch.argmax(sequence[-1, :len(self.tokens)])] 70 | 71 | 72 | class AttentionLayer(torch.nn.Module): 73 | def __init__(self, instruction): 74 | super(AttentionLayer, self).__init__() 75 | 76 | self.key = torch.nn.Parameter(instruction.key) 77 | self.value = torch.nn.Parameter(instruction.value) 78 | self.query = torch.nn.Parameter(instruction.query) 79 | 80 | self.mask = instruction.mask 81 | 82 | self.softmax = torch.nn.Softmax(2) 83 | 84 | def forward(self, seq): 85 | batch_size, seq_length, dim = seq.shape 86 | 87 | query = seq @ self.query 88 | key = seq @ self.key 89 | value = seq @ self.value 90 | 91 | causal_mask = (torch.triu(torch.ones(seq_length, seq_length), 
diagonal=1) * 0).to(seq.device) 92 | norm = np.sqrt(dim) 93 | 94 | kq = self.softmax(query @ key.transpose(-2, -1) / norm + causal_mask) 95 | 96 | s = (kq @ value) * self.mask 97 | 98 | return (seq + s) 99 | 100 | def reset(self): 101 | torch.nn.init.xavier_uniform_(self.key) 102 | torch.nn.init.xavier_uniform_(self.query) 103 | torch.nn.init.xavier_uniform_(self.value) 104 | 105 | 106 | class GetRelativeToken(AttentionLayer): 107 | def __init__(self, embedding: EmbeddedState, pos_reg: Register, steps: int, out: Register): 108 | tpos_reg = embedding.register_map['pos'] 109 | 110 | position_select = torch.zeros(embedding.dim, embedding.dim) 111 | position_select[tpos_reg.offset, tpos_reg.offset] = 1e10 112 | position_select[tpos_reg.offset + 1, tpos_reg.offset + 1] = 1e10 113 | 114 | i = -steps 115 | sin = math.sin(i * (2 * math.pi) * POS_STEP) * 1 116 | cos = math.cos(i * (2 * math.pi) * POS_STEP) * 1 117 | 118 | rotation = torch.zeros(embedding.dim, embedding.dim) 119 | rotation[pos_reg.offset, tpos_reg.offset] = cos 120 | rotation[pos_reg.offset + 1, tpos_reg.offset] = -sin 121 | rotation[pos_reg.offset, tpos_reg.offset + 1] = sin 122 | rotation[pos_reg.offset + 1, tpos_reg.offset + 1] = cos 123 | 124 | token_copy = torch.zeros(embedding.dim, embedding.dim) 125 | for i in range(len(embedding.tokens)): 126 | token_copy[i, i + out.offset] = 1.0 127 | 128 | self.query = rotation 129 | self.key = position_select 130 | self.value = token_copy 131 | 132 | self.mask = torch.zeros(embedding.dim) 133 | self.mask[out.offset:out.offset + out.size] = 1.0 134 | 135 | super(GetRelativeToken, self).__init__(self) 136 | 137 | 138 | class MLPLayer(torch.nn.Module): 139 | def __init__(self, instruction, debug=False): 140 | super(MLPLayer, self).__init__() 141 | self.debug = debug 142 | 143 | self.first_weights = torch.nn.Parameter(instruction.first_weights) 144 | self.first_bias = torch.nn.Parameter(instruction.first_bias) 145 | self.second_weights = torch.nn.Parameter(instruction.second_weights) 146 | self.second_bias = torch.nn.Parameter(instruction.second_bias) 147 | 148 | self.gelu = torch.nn.ReLU() 149 | 150 | self.mask = instruction.mask 151 | 152 | def forward(self, seq): 153 | a = self.gelu(seq @ self.first_weights + self.first_bias) 154 | b = (a @ self.second_weights) 155 | x = b + self.second_bias 156 | return seq + (x * self.mask) 157 | 158 | def reset(self): 159 | torch.nn.init.xavier_uniform_(self.first_weights) 160 | torch.nn.init.zeros_(self.first_bias) 161 | torch.nn.init.xavier_uniform_(self.second_weights) 162 | torch.nn.init.zeros_(self.second_bias) 163 | 164 | 165 | class ConvertToInternal(MLPLayer): 166 | def __init__(self, embedding: EmbeddedState, out: Register): 167 | self.first_weights = torch.zeros(embedding.dim, embedding.dim) 168 | self.first_bias = torch.zeros(embedding.dim) 169 | 170 | self.first_weights[1, out.offset] += 1 171 | 172 | self.second_weights = torch.eye(embedding.dim) 173 | self.second_bias = torch.zeros(embedding.dim) 174 | 175 | self.mask = torch.zeros(embedding.dim) 176 | for reg in [out]: 177 | self.mask[reg.offset:reg.offset + reg.size] = 1.0 178 | 179 | super(ConvertToInternal, self).__init__(self) 180 | 181 | 182 | class GRLT2(AttentionLayer): 183 | """ 184 | Copy the tokens from the given register to the output register, with an optional rotation by `steps` 185 | """ 186 | 187 | def __init__(self, embedding: EmbeddedState, pos_reg: Register, steps: int, copy_from: Register, out: Register): 188 | tpos_reg = embedding.register_map['pos'] 189 | 190 | 
position_select = torch.zeros(embedding.dim, embedding.dim) 191 | position_select[tpos_reg.offset, tpos_reg.offset] = 1e10 192 | position_select[tpos_reg.offset + 1, tpos_reg.offset + 1] = 1e10 193 | 194 | i = -steps 195 | sin = math.sin(i * (2 * math.pi) * POS_STEP) * 1 196 | cos = math.cos(i * (2 * math.pi) * POS_STEP) * 1 197 | 198 | rotation = torch.zeros(embedding.dim, embedding.dim) 199 | rotation[pos_reg.offset, tpos_reg.offset] = cos 200 | rotation[pos_reg.offset + 1, tpos_reg.offset] = -sin 201 | rotation[pos_reg.offset, tpos_reg.offset + 1] = sin 202 | rotation[pos_reg.offset + 1, tpos_reg.offset + 1] = cos 203 | 204 | token_copy = torch.zeros(embedding.dim, embedding.dim) 205 | token_copy[copy_from.offset, out.offset] = 1.0 206 | 207 | self.query = rotation 208 | self.key = position_select 209 | self.value = token_copy 210 | 211 | self.mask = torch.zeros(embedding.dim) 212 | self.mask[out.offset:out.offset + out.size] = 1.0 213 | 214 | super(GRLT2, self).__init__(self) 215 | 216 | 217 | class AND(MLPLayer): 218 | def __init__(self, embedding: EmbeddedState, first_reg: Register, second_reg: Register, result_reg: Register): 219 | self.first_weights = torch.zeros(embedding.dim, embedding.dim) 220 | self.first_bias = torch.zeros(embedding.dim) 221 | 222 | self.first_weights[first_reg.offset, result_reg.offset] += 1 223 | self.first_weights[second_reg.offset, result_reg.offset] += 1 224 | self.first_bias[result_reg.offset:result_reg.offset + result_reg.size] = -1.0 225 | 226 | self.second_weights = torch.eye(embedding.dim) 227 | self.second_bias = torch.zeros(embedding.dim) 228 | 229 | self.mask = torch.zeros(embedding.dim) 230 | for reg in [result_reg]: 231 | self.mask[reg.offset:reg.offset + reg.size] = 1.0 232 | 233 | super(AND, self).__init__(self) 234 | 235 | 236 | class Clear(MLPLayer): 237 | def __init__(self, embedding: EmbeddedState, registers: list[Register]): 238 | self.first_weights = torch.zeros(embedding.dim, embedding.dim) 239 | self.first_bias = torch.zeros(embedding.dim) 240 | 241 | for reg in registers: 242 | for i in range(reg.size): 243 | self.first_weights[reg.offset + i, reg.offset + i] = 100.0 244 | 245 | self.second_weights = torch.zeros(embedding.dim, embedding.dim) 246 | self.second_bias = torch.zeros(embedding.dim) 247 | for reg in registers: 248 | for i in range(reg.size): 249 | self.second_weights[reg.offset + i, reg.offset + i] = -0.01 250 | 251 | self.mask = torch.zeros(embedding.dim) 252 | for reg in registers: 253 | self.mask[reg.offset:reg.offset + reg.size] = 1.0 254 | 255 | super(Clear, self).__init__(self) 256 | 257 | 258 | class Copy(torch.nn.Module): 259 | def __init__(self, embedding: EmbeddedState, pos_reg: Register, copy_from: Register, copy_to: Register): 260 | super(Copy, self).__init__() 261 | 262 | self.copy = GRLT2(embedding, pos_reg, 0, copy_from, copy_to) 263 | 264 | def forward(self, seq): 265 | return self.copy.forward(seq) 266 | 267 | 268 | class Shift(torch.nn.Module): 269 | def __init__(self, embedding: EmbeddedState, pos: Register, tchaikovsky: Register, register_to_shift: Register, 270 | amount: int, work_registers: list[Register]): 271 | super(Shift, self).__init__() 272 | 273 | self.embedding = embedding 274 | 275 | self.shiftpt1 = GRLT2(embedding, pos, -amount, register_to_shift, work_registers[0]) 276 | self.clear = Clear(embedding, [register_to_shift]) 277 | self.shifted_tchaikovsky = GRLT2(embedding, pos, -(amount - 1), tchaikovsky, work_registers[1]) 278 | self.shiftpt2 = AND(embedding, work_registers[1], 
work_registers[0], register_to_shift) 279 | self.cleannup = Clear(embedding, work_registers) 280 | 281 | def forward(self, seq): 282 | x = self.shiftpt1.forward(seq) 283 | x = self.clear.forward(x) 284 | x = self.shifted_tchaikovsky.forward(x) 285 | x = self.shiftpt2.forward(x) 286 | x = self.cleannup.forward(x) 287 | return x 288 | 289 | 290 | class ShiftL(torch.nn.Module): 291 | def __init__(self, embedding: EmbeddedState, pos: Register, anti_tchaikovsky: Register, register_to_shift: Register, 292 | amount: int, work_registers: list[Register]): 293 | super(ShiftL, self).__init__() 294 | 295 | self.embedding = embedding 296 | 297 | self.shiftpt1 = GRLT2(embedding, pos, amount, register_to_shift, work_registers[0]) 298 | self.clear = Clear(embedding, [register_to_shift]) 299 | self.shifted_antitchaikovsky = GRLT2(embedding, pos, amount - 1, anti_tchaikovsky, work_registers[1]) 300 | self.shiftpt2 = AND(embedding, work_registers[1], work_registers[0], register_to_shift) 301 | self.cleannup = Clear(embedding, work_registers) 302 | 303 | def forward(self, seq): 304 | x = self.shiftpt1.forward(seq) 305 | x = self.clear.forward(x) 306 | x = self.shifted_antitchaikovsky.forward(x) 307 | x = self.shiftpt2.forward(x) 308 | x = self.cleannup.forward(x) 309 | return x 310 | 311 | 312 | class NOT_To(MLPLayer): 313 | def __init__(self, embedding: EmbeddedState, from_reg: Register, result_reg: Register): 314 | self.first_weights = torch.zeros(embedding.dim, embedding.dim) 315 | self.first_bias = torch.zeros(embedding.dim) 316 | 317 | self.first_weights[from_reg.offset, result_reg.offset] = -1 318 | self.first_bias[result_reg.offset:result_reg.offset + result_reg.size] = 1.0 319 | 320 | self.second_weights = torch.eye(embedding.dim) 321 | self.second_bias = torch.zeros(embedding.dim) 322 | 323 | self.mask = torch.zeros(embedding.dim) 324 | for reg in [result_reg]: 325 | self.mask[reg.offset:reg.offset + reg.size] = 1.0 326 | 327 | super(NOT_To, self).__init__(self) 328 | 329 | 330 | class NOT(torch.nn.Module): 331 | def __init__(self, embedding: EmbeddedState, pos: Register, register: Register, work_registers: list[Register]): 332 | super(NOT, self).__init__() 333 | 334 | self.not_to = NOT_To(embedding, register, work_registers[0]) 335 | self.clear = Clear(embedding, [register]) 336 | self.copy = Copy(embedding, pos, work_registers[0], register) 337 | self.clear2 = Clear(embedding, work_registers) 338 | 339 | def forward(self, seq): 340 | x = self.not_to.forward(seq) 341 | x = self.clear.forward(x) 342 | x = self.copy.forward(x) 343 | x = self.clear2.forward(x) 344 | return x 345 | 346 | 347 | class NOR(MLPLayer): 348 | def __init__(self, embedding: EmbeddedState, first_reg: Register, second_reg: Register, result_reg: Register): 349 | self.first_weights = torch.zeros(embedding.dim, embedding.dim) 350 | self.first_bias = torch.zeros(embedding.dim) 351 | 352 | self.first_weights[first_reg.offset, result_reg.offset] += -1 353 | self.first_weights[second_reg.offset, result_reg.offset] += -1 354 | self.first_bias[result_reg.offset:result_reg.offset + result_reg.size] = 1.0 355 | 356 | self.second_weights = torch.eye(embedding.dim) 357 | self.second_bias = torch.zeros(embedding.dim) 358 | 359 | self.mask = torch.zeros(embedding.dim) 360 | for reg in [result_reg]: 361 | self.mask[reg.offset:reg.offset + reg.size] = 1.0 362 | 363 | super(NOR, self).__init__(self) 364 | 365 | 366 | class OR(torch.nn.Module): 367 | def __init__(self, embedding: EmbeddedState, pos: Register, first_reg: Register, second_reg: 
Register, 368 | result_reg: Register, work_registers: list[Register]): 369 | super(OR, self).__init__() 370 | 371 | self.part1 = NOR(embedding, first_reg, second_reg, result_reg) 372 | self.part2 = NOT(embedding, pos, result_reg, work_registers) 373 | self.cleanup = Clear(embedding, work_registers) 374 | 375 | def forward(self, seq): 376 | x = self.part1.forward(seq) 377 | x = self.part2.forward(x) 378 | x = self.cleanup.forward(x) 379 | return x 380 | 381 | 382 | class XOR(torch.nn.Module): 383 | def __init__(self, embedding: EmbeddedState, pos: Register, first_reg: Register, second_reg: Register, 384 | result_reg: Register, work_registers: list[Register]): 385 | super(XOR, self).__init__() 386 | 387 | # A OR B 388 | self.part1 = OR(embedding, pos, first_reg, second_reg, work_registers[0], work_registers[1:]) 389 | 390 | # A AND B 391 | self.part2 = AND(embedding, first_reg, second_reg, work_registers[1]) 392 | 393 | # NOT (A AND B) 394 | self.part3 = NOT(embedding, pos, work_registers[1], work_registers[2:]) 395 | 396 | # (A OR B) AND NOT (A AND B) 397 | self.part4 = AND(embedding, work_registers[0], work_registers[1], result_reg) 398 | 399 | # Clear the work registers 400 | self.part5 = Clear(embedding, work_registers) 401 | 402 | def forward(self, seq): 403 | x = self.part1.forward(seq) 404 | x = self.part2.forward(x) 405 | x = self.part3.forward(x) 406 | x = self.part4.forward(x) 407 | x = self.part5.forward(x) 408 | return x 409 | 410 | 411 | class Rotate(torch.nn.Module): 412 | def __init__(self, embedding: EmbeddedState, pos: Register, tchaikovsky: Register, anti_tchaikovsky: Register, 413 | register_to_rotate: Register, amount: int, work_registers: list[Register]): 414 | super(Rotate, self).__init__() 415 | 416 | self.embedding = embedding 417 | 418 | # First, we need to copy the register to two work registers 419 | # Thus, work registers 0 and 1 are currently in use 420 | self.copies = [Copy(embedding, pos, register_to_rotate, work_registers[i]) for i in range(2)] 421 | 422 | # Next, we shift the first work register to the right 423 | self.shift_right = Shift(embedding, pos, tchaikovsky, work_registers[0], amount, work_registers[2:]) 424 | 425 | # Then, we shift the second work register to the left INPUT_LENGTH - 1 times 426 | self.left_shifts = ShiftL(embedding, pos, anti_tchaikovsky, work_registers[1], (INPUT_LENGTH - amount), 427 | work_registers[2:]) 428 | 429 | # Now, we clear the original register 430 | self.clear = Clear(embedding, [register_to_rotate]) 431 | 432 | # And finally, we OR work registers 0 and 1 to get the final result 433 | self.or_result = OR(embedding, pos, work_registers[0], work_registers[1], register_to_rotate, 434 | work_registers[2:]) 435 | 436 | # Oh, and clear the work registers 437 | self.clear_work = Clear(embedding, work_registers) 438 | 439 | def forward(self, seq): 440 | for copy in self.copies: 441 | seq = copy.forward(seq) 442 | seq = self.shift_right.forward(seq) 443 | seq = self.left_shifts.forward(seq) 444 | seq = self.clear.forward(seq) 445 | seq = self.or_result.forward(seq) 446 | seq = self.clear_work.forward(seq) 447 | return seq 448 | 449 | 450 | class Add(torch.nn.Module): 451 | def __init__(self, embedding: EmbeddedState, pos: Register, anti_tchaikovsky: Register, a: Register, b: Register, 452 | result: Register, work_registers: list[Register]): 453 | super(Add, self).__init__() 454 | 455 | self.embedding = embedding 456 | 457 | # work_registers[0] is `sum_` 458 | # work_registers[1] is `carry` 459 | self.first_sum = XOR(embedding, 
pos, a, b, work_registers[0], work_registers[2:]) 460 | self.first_carry = AND(embedding, a, b, work_registers[1]) 461 | 462 | self.next_operations = [] 463 | 464 | for _ in range(32): 465 | # Copy `carry` to work_registers[2] 466 | copy_of_carry = Copy(embedding, pos, work_registers[1], work_registers[2]) 467 | # Shift this copy of `carry` to the left. Now `work_registers[2]` contains `shifted_carry` 468 | shifted_carry = ShiftL(embedding, pos, anti_tchaikovsky, work_registers[2], 1, work_registers[3:]) 469 | # XOR `sum_` with `shifted_carry`. Now `work_registers[3]` contains `new_sum` 470 | new_sum = XOR(embedding, pos, work_registers[0], work_registers[2], work_registers[3], work_registers[4:]) 471 | # Clear `carry` 472 | clear_carry = Clear(embedding, [work_registers[1]]) 473 | # AND `sum_` with `shifted_carry`. Now `work_registers[1]` contains `carry` again 474 | carry = AND(embedding, work_registers[0], work_registers[2], work_registers[1]) 475 | # Clear `sum` 476 | clear_sum = Clear(embedding, [work_registers[0]]) 477 | # Copy `new_sum` to `sum_` 478 | sum = Copy(embedding, pos, work_registers[3], work_registers[0]) 479 | # Clear the work registers 480 | clear_work = Clear(embedding, work_registers[2:]) 481 | self.next_operations.append( 482 | (copy_of_carry, shifted_carry, new_sum, clear_carry, carry, clear_sum, sum, clear_work)) 483 | 484 | self.copy_to_result = Copy(embedding, pos, work_registers[0], result) 485 | 486 | self.clear = Clear(embedding, work_registers) 487 | 488 | def forward(self, seq): 489 | x = self.first_sum.forward(seq) 490 | x = self.first_carry.forward(x) 491 | 492 | for copy_of_carry, shifted_carry, new_sum, clear_carry, carry, clear_sum, sum, clear_work in self.next_operations: 493 | x = copy_of_carry.forward(x) 494 | x = shifted_carry.forward(x) 495 | x = new_sum.forward(x) 496 | x = clear_carry.forward(x) 497 | x = carry.forward(x) 498 | x = clear_sum.forward(x) 499 | x = sum.forward(x) 500 | x = clear_work.forward(x) 501 | 502 | x = self.copy_to_result.forward(x) 503 | x = self.clear.forward(x) 504 | 505 | return x 506 | 507 | 508 | class RotateWithLimit(torch.nn.Module): 509 | def __init__(self, embedding: EmbeddedState, pos: Register, tchaikovsky: Register, anti_tchaikovsky: Register, 510 | register_to_rotate: Register, amount: int, limit: int, work_registers: list[Register]): 511 | super(RotateWithLimit, self).__init__() 512 | 513 | self.embedding = embedding 514 | 515 | # First, we need to copy the register to two work registers 516 | # Thus, work registers 0 and 1 are currently in use 517 | self.copies = [Copy(embedding, pos, register_to_rotate, work_registers[i]) for i in range(2)] 518 | 519 | # Next, we shift the first work register to the right 520 | self.shift_right = Shift(embedding, pos, tchaikovsky, work_registers[0], amount, work_registers[2:]) 521 | 522 | # Then, we shift the second work register to the left `limit - amount` times 523 | self.left_shifts = ShiftL(embedding, pos, anti_tchaikovsky, work_registers[1], limit - amount, 524 | work_registers[2:]) 525 | 526 | # Now we clear the rest of the `head` register to zeros to allow it to or in 527 | self.clear_shift_1 = Shift(embedding, pos, tchaikovsky, work_registers[1], INPUT_LENGTH - amount, 528 | work_registers[2:]) 529 | self.clear_shift_2 = ShiftL(embedding, pos, anti_tchaikovsky, work_registers[1], 530 | INPUT_LENGTH - amount, work_registers[2:]) 531 | 532 | # Now, we clear the original register 533 | self.clear = Clear(embedding, [register_to_rotate]) 534 | 535 | # And finally, we OR 
work registers 0 and 1 to get the final result 536 | self.or_result = OR(embedding, pos, work_registers[0], work_registers[1], register_to_rotate, 537 | work_registers[2:]) 538 | 539 | # Oh, and clear the work registers 540 | self.clear_work = Clear(embedding, work_registers) 541 | 542 | def forward(self, seq): 543 | for copy in self.copies: 544 | seq = copy.forward(seq) 545 | seq = self.shift_right.forward(seq) 546 | seq = self.left_shifts.forward(seq) 547 | seq = self.clear_shift_1.forward(seq) 548 | seq = self.clear_shift_2.forward(seq) 549 | seq = self.clear.forward(seq) 550 | seq = self.or_result.forward(seq) 551 | seq = self.clear_work.forward(seq) 552 | return seq 553 | 554 | 555 | class Print(torch.nn.Module): 556 | def __init__(self, embedding: EmbeddedState, register: Register): 557 | super(Print, self).__init__() 558 | 559 | self.embedding = embedding 560 | self.register = register 561 | 562 | def forward(self, seq): 563 | print(''.join( 564 | str(c) for c in [int(q) for q in seq[0, :, self.register.offset:self.register.offset + self.register.size] 565 | .detach().flatten()])) 566 | 567 | return seq 568 | --------------------------------------------------------------------------------
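# ----------------------------------------------------------------------------
# Standalone illustration (not part of the library): the ReLU arithmetic the
# MLP-based ops above rely on, checked on plain tensors. Registers hold 0/1
# activations and every op adds a masked residual to the sequence, so a single
# hidden ReLU layer (the attribute is named `gelu` but holds a ReLU) suffices:
# AND is relu(a + b - 1) and NOR is relu(1 - a - b) (OR and XOR are then
# composed from NOR, NOT and AND), while Clear adds relu(100 * x) * (-0.01),
# i.e. -x, so the residual connection zeroes the register. The attention-based
# ops (GetRelativeToken / GRLT2) work differently: the query matrix rotates the
# sin/cos position channels according to `steps`, the key scales them by 1e10
# so the softmax is effectively one-hot on the offset position, and the value
# matrix copies the source channels into the masked output register.
import torch
import torch.nn.functional as F


def and_gate(a, b):
    # mirrors the AND layer: relu(a + b - 1) with an identity second layer
    return F.relu(a + b - 1.0)


def nor_gate(a, b):
    # mirrors the NOR layer: relu(1 - a - b)
    return F.relu(1.0 - a - b)


def clear(x):
    # mirrors Clear: relu(100 * x) scaled by -0.01, added back onto x
    return x + F.relu(100.0 * x) * (-0.01)


bits = [torch.tensor(0.0), torch.tensor(1.0)]
for a in bits:
    for b in bits:
        assert and_gate(a, b).item() == float(int(a.item()) and int(b.item()))
        assert nor_gate(a, b).item() == float(not (int(a.item()) or int(b.item())))

# Clear leaves only float rounding error behind (the library tolerates this)
assert torch.allclose(clear(torch.tensor([0.0, 1.0, 1.0, 0.0])), torch.zeros(4), atol=1e-6)
print("gate arithmetic checks out")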