├── data
│   ├── reweight_image.JPG
│   └── file.txt
├── README.md
├── LICENSE
├── reweight_gpt.py
└── reweight-gpt-nonlinear.py
/data/reweight_image.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hunar4321/reweight-gpt/HEAD/data/reweight_image.JPG
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
![GitHub repo size](https://img.shields.io/github/repo-size/hunar4321/reweight-gpt)
![GitHub](https://img.shields.io/github/license/hunar4321/reweight-gpt)

# Reweight GPT

An alternative to the self-attention mechanism in the Transformer architecture.
It uses learnable lateral connections to reweight the inputs directly instead of the self-attention mechanism (as illustrated below).
To learn more about the method, watch this video (from 41:26):
https://youtu.be/l-CjXFmcVzY

# Files:
1. The tutorial folder - A step-by-step tutorial from the basics to GPT.
2. reweight_gpt.py (A multi-block GPT implementation using direct re-weighting of the attention matrix).
3. reweight-gpt-nonlinear.py (A nonlinear version of the direct re-weighting method. For easy comparison between the two methods, I adapted this script directly from Andrej Karpathy's GPT implementation).

# Illustration:

![](data/reweight_image.JPG)
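# Minimal sketch:

A minimal sketch of the core substitution (illustrative only, not one of the repo scripts; `x`, `Wr`, `Wv`, `T`, and `C` are made-up names that roughly mirror `self.wr` and `self.wv` in reweight_gpt.py): instead of building the `T x T` attention matrix from queries and keys, a learnable lateral matrix maps each position-encoded token embedding directly to its own row of mixing weights, which are then causally masked and softmaxed as usual.

```python
import torch
from torch.nn import functional as F

T, C = 16, 32                 # context length and embedding size (illustrative values)
x = torch.randn(T, C)         # position-encoded token embeddings for one sequence

Wr = torch.randn(C, T) * 0.1  # learnable lateral connections: embedding -> row of mixing weights
Wv = torch.randn(C, C) * 0.1  # value projection

# Self-attention would compute the mixing matrix as (x @ Wq) @ (x @ Wk).T / sqrt(head_size).
# Here each row of the T x T matrix comes directly from the corresponding token embedding:
attn = x @ Wr                                      # (T, T) mixing matrix
mask = torch.tril(torch.ones(T, T)) == 0           # causal mask (upper triangle blocked)
weights = F.softmax(attn.masked_fill(mask, float('-inf')), dim=-1)
out = weights @ (x @ Wv)                           # (T, C) re-weighted values
```

The repo scripts differ in details (per-head value size, how the mask is applied, and the nonlinear version of the lateral layer), but this is the basic idea.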
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Hunar Ahmad

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/data/file.txt:
--------------------------------------------------------------------------------
Cats, cats, everywhere
Furry balls without a care
Purring, meowing, licking fur
Hunting mice, they always purr

Cats, cats, on the prowl
Jumping high, never a scowl
Whiskers twitching, eyes alert
Tail in air, ready to assert

Cats, cats, so much fun
Cuddling close in the sun
Stretching out, napping long
Playing with string, never wrong

Cats, cats, always cool
Lapping milk, acting like a fool
Mysterious, charming, full of grace
Cats are simply ace

Cats, cats, with silky fur
Making biscuits, they always purr
Sitting high, looking down
Claiming everything as their crown

Cats, cats, with eyes so bright
Chasing shadows, day or night
Curled up warm, on your lap
Purring gently, taking a nap

Cats, cats, with playful paws
Hiding, stalking, never pause
Jumping, leaping, so agile
Graceful creatures, never fragile

Cats, cats, our feline friends
Bringing joy that never ends
Loving us, without a doubt
Cats are what life's all about

Cats, cats, everywhere I see
Furry creatures, cute as can be
Rubbing against our legs
Asking for treats, without begs

Cats, cats, with their regal stance
Graceful movements, they enhance
But we love them all the same
Our little friends, never tame

Cats, cats, so full of love
Watching over us from above
Protecting us from any harm
Always there, with their charm

Cats, cats, with their curious ways
Exploring nooks, and hiding in bays
Living life with style and grace
Cats are always in first place

Cats, cats, so full of fun
Chasing toys, never done
Hiding in boxes, or paper bags
Making us laugh, never drags

Cats, cats, with their own minds
Sitting in the sun, never blinds
Chasing strings, and balls of yarn
They never tire, oh what a charm

Cats, cats, with calming purrs
Cuddling close, to be yours
Giving love, without any fuss
Their presence, a comfort to us

Cats, cats, always at ease
Living life, as they please
Bringing joy, to all they meet
Cats, our furry friends, so sweet

Cats, cats, with eyes so bright
Guiding us through the darkest night
Purring softly, by our side
Comforting us, as we hide

Cats, cats, with softest fur
Nuzzling close, making a purr
In our lap, they take a rest
We're lucky to have, such a guest

Cats, cats, with their playful ways
Entertaining us, on the laziest days
Chasing shadows, or a feather
Making us smile, always together

Cats, cats, with hearts so pure
Bringing love, that will endure
Their presence, a blessing indeed
Cats, our friends, we shall never need

Cats, cats, with their little quirks
Scratching posts, and tiny perks
Licking paws, cleaning their face
Chasing tails, all over the place

Cats, cats, with their playful hearts
Chasing toys, and little carts
Their antics, bringing us joy
Cats, our little angels, oh so coy

Cats, cats, with their gentle souls
Lifting spirits, making us whole
In their eyes, we see the light
Bringing peace, that feels so right

Cats, cats, with their gentle purr
Calming us, when we're feeling a stir
Snuggling close, to keep us warm
Cats, our little cuddle storm

Cats, cats, with their playful heart
Jumping high, right from the start
Bouncing around, like little springs
Cats, our little entertainers, with wings

Cats, cats, with their loving grace
Their soft purrs, caress our face
In their embrace, we feel at peace
Cats, our little comfort, never to cease

Cats, cats, with their loving ways
Cuddling close, on the darkest days
In the garden, or up in a tree
Cats, our little explorers, always free
--------------------------------------------------------------------------------
/reweight_gpt.py:
--------------------------------------------------------------------------------
'''
Reweight GPT: An alternative to the self-attention mechanism in the Transformer architecture.
Author: Hunar Ahmad Abdulrahman @ brainxyz.com

This method uses learnable lateral connections to reweight the inputs instead of the self-attention mechanism (the self-attention code is left commented out for comparison).
To learn more about the method, watch this video (from 41:26): https://youtu.be/l-CjXFmcVzY
'''

import numpy as np
import matplotlib.pylab as plt
import torch
from torch.nn import functional as F

with open('data/file.txt', 'r', encoding='utf-8') as f:
    text = f.read()
text = text.lower()

chars = sorted(list(set(text)))
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
data = [stoi[c] for c in text]
vocab_size = len(chars)

device = 'cpu'
ins = 16           # context length (number of input tokens)
outs = vocab_size  # output size (one logit per character)
nodes = 32         # hidden nodes in each block's MLP
lr = 0.001
n_emb = 32         # embedding size

embed = torch.randn(vocab_size, n_emb)
pos = torch.randn(ins, n_emb)
embed = embed.to(device)
pos = pos.to(device)
data = torch.tensor(data).long()
params = []

def weights(ins, outs):
    ws = torch.randn(ins, outs)*0.1
    ws = ws.to(device)
    ws = ws.requires_grad_(True)
    params.append(ws)
    return ws

class Head():
    def __init__(self):
        '''
        If you want to compare this method to self-attention, uncomment the commented lines and remove "x @ self.wr".
        Note: you can also pass "x @ self.wr" through a non-linear layer for better performance.
        '''
        self.wv = weights(n_emb, n_emb//4)
        # self.wq = weights(n_emb, n_emb//4)
        # self.wk = weights(n_emb, n_emb//4)
        self.wr = weights(n_emb, ins)

    def forward(self, x):
        v = x @ self.wv                          # project inputs to values
        # q = x @ self.wq
        # k = x @ self.wk
        # attn = (q @ k.transpose(-2,-1)) / k.shape[0]**0.5
        attn = x @ self.wr                       # direct re-weighting: each row of the mixing matrix comes straight from x
        tril = torch.tril(attn)                  # keep the causal (lower-triangular) part
        tril = tril.masked_fill(tril==0, -1e10)  # mask out the rest before the softmax
        rew = F.softmax(tril, dim=-1)
        x = rew @ v                              # mix the values with the learned weights
        return x

class Block():
    def __init__(self):
        self.heads = [Head(), Head(), Head(), Head()]
        self.w0 = weights(n_emb, nodes)
        self.w1 = weights(nodes, n_emb)

    def forward(self, x):
        x = torch.cat([head.forward(x) for head in self.heads], dim=-1)
        x = torch.relu(x @ self.w0)
        x = torch.relu(x @ self.w1)
        return x

class Model():
    def __init__(self):
        self.blocks = [Block(), Block(), Block()]
        self.w2 = weights(n_emb, outs)

    def forward(self, x):
        x = embed[x] + pos
        x = x + self.blocks[0].forward(x)
        x = x + self.blocks[1].forward(x)
        x = x + self.blocks[2].forward(x)
        yh = (x @ self.w2)
        return yh

model = Model()
optimizer = torch.optim.Adam(params, lr)
print("params:", sum(p.numel() for p in params))

import time
t = time.time()

ers = []
for i in range(5000):

    b = torch.randint(len(data)-ins, (100, ))   # 100 random starting positions per batch
    xs = torch.stack([data[i:i+ins] for i in b])
    ys = torch.stack([data[i+1:i+ins+1] for i in b])
    xs = xs.to(device)
    ys = ys.to(device)

    yh = model.forward(xs)

    loss = F.cross_entropy(yh.view(-1, vocab_size), ys.long().view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    e = loss.item()
    if i % 500 == 0:
        print("loss:", e)
    ers.append(e)

print("time:", time.time()-t)

s = xs[0]
gen_text = ""
for i in range(3000):
    yh = model.forward(s)
    prob = F.softmax(yh[-1, :]*1, dim=0)   # the scalar acts as an inverse temperature
    # pred = torch.argmax(yh[-1, :]).item()
    pred = torch.multinomial(prob, num_samples=1).item()
    s = torch.roll(s, -1)                  # slide the context window
    s[-1] = pred                           # append the sampled character
    gen_text += itos[pred]

print(gen_text)
--------------------------------------------------------------------------------
/reweight-gpt-nonlinear.py:
--------------------------------------------------------------------------------

'''
Reweight-GPT: An alternative to self-attention.
This script is directly adapted from Andrej Karpathy's GPT project for easy comparison.
The self-attention parts are replaced with a direct re-weighting mechanism.
'''

import torch
import torch.nn as nn
from torch.nn import functional as F


# hyperparameters
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 500
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.0

hidden_nodes = 20 # this parameter controls the number of hidden (middle-layer) nodes used inside self.wr
# ------------

torch.manual_seed(1337)

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


class Head(nn.Module):
    """ The self-attention parts are commented out.
    If you want to compare them with the direct reweighting, uncomment them and remove wei = self.wr(x). """

    def __init__(self, head_size):
        super().__init__()

        # self.key = nn.Linear(n_embd, head_size, bias=False)
        # self.query = nn.Linear(n_embd, head_size, bias=False)

        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.wr = nn.Sequential( nn.Linear(n_embd, hidden_nodes), nn.ReLU(), nn.Linear(hidden_nodes, block_size),) # this could be a single linear layer; the non-linearity with hidden_nodes gives finer control over the parameter count
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape

        # k = self.key(x) # (B,T,hs)
        # q = self.query(x) # (B,T,hs)
        # wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)

        wei = self.wr(x) # (B, T, block_size); in this script T always equals block_size
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        v = self.value(x) # (B,T,hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedForward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class GPTLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

        # better init, not covered in the original GPT video, but important; will cover in a follow-up video
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

model = GPTLanguageModel()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, block_size), dtype=torch.long, device=device)
context += xb[0, :]
print(decode(m.generate(context, max_new_tokens=1000)[0].tolist()))
--------------------------------------------------------------------------------