├── README.md
└── main.py

/README.md:
--------------------------------------------------------------------------------
# Attention-as-graph
An alternative way of calculating self-attention.


> [!WARNING]
> I may or may not work on this further; PRs are welcome though.


Look at `main.py`; here is a preview:

```python
@dataclass
class Node:
    idx: int
    value: Tensor
    adjacency_list: list[Edge] = field(default_factory=list)

@dataclass
class Edge:
    node: Node
    weight: Tensor


def build_graph(nodes: list[Node], keys: Tensor, queries: Tensor):
    batch, seq_len, d_model = queries.shape
    for idx, curr_node in enumerate(nodes):
        # pick the keys for positions 0..idx (the current token's history)
        keys_history = keys[:, :idx + 1, :]

        # pick the query for the current (idx-th) position
        curr_query = queries[:, idx, :]

        # dot product (a similarity score) between the current query
        # and every key in the current node's (token's) history
        similarity_values = curr_query @ keys_history.transpose(-1, -2)

        # if DEBUG: print(f"{keys_history.shape=} {curr_query.shape=} {similarity_values.shape=}")
        similarity_values = similarity_values / math.sqrt(d_model)

        # after the softmax, the weights indicate how much the current node
        # attends to each past node
        attn = F.softmax(similarity_values.float(), dim=-1).type_as(keys)

        attn = attn.reshape(-1)  # flatten so indexing by position is simple
        # if DEBUG: print(attn)

        # add a weighted back-edge from the current node to every node in its history
        for nidx, node in enumerate(nodes[:idx + 1]):
            edge_weight = attn[nidx]

            # if DEBUG: print(f"{idx} attends to {nidx} with weight {edge_weight:.2f}")
            edge = Edge(node=node, weight=edge_weight)

            # the current node gets a weighted edge to each of its past nodes
            curr_node.adjacency_list.append(edge)
    return nodes
```
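
As a rough usage sketch (it mirrors `test_attn` in `main.py`; the sizes below are placeholder values, and importing `main` also runs its self-test):

```python
import torch
import torch.nn as nn

from main import scaled_graph_attention  # importing main also runs test_attn()

torch.manual_seed(0)
batch, seq_len, d_model, num_heads = 1, 8, 64, 2  # example sizes; batch must be 1
head_dim = d_model // num_heads

with torch.no_grad():
    x = torch.rand(batch, seq_len, d_model)
    Wq = nn.Linear(d_model, d_model)
    Wk = nn.Linear(d_model, d_model)
    Wv = nn.Linear(d_model, d_model)

    # split the projections into heads the same way test_attn() does
    query = Wq(x).reshape(batch, num_heads, seq_len, head_dim)
    key = Wk(x).reshape(batch, num_heads, seq_len, head_dim)
    value = Wv(x).reshape(batch, num_heads, seq_len, head_dim)

    # causal self-attention computed by walking the token graph
    out = scaled_graph_attention(query, key, value)
    print(out.shape)  # torch.Size([1, 2, 8, 32])
```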

## TODO
- [ ] Do inference with a tiny LM as a proof of concept
- [ ] Add visualization
  - top nodes influencing the current node


---
This is for educational purposes and has no practical use (unless visualization is added).
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from __future__ import annotations
from dataclasses import dataclass, field

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import math

DEBUG = True


@dataclass
class Node:
    idx: int
    value: Tensor
    adjacency_list: list[Edge] = field(default_factory=list)


@dataclass
class Edge:
    node: Node
    weight: Tensor


# call this function for multi-headed attention
def scaled_graph_attention(query: Tensor, key: Tensor, value: Tensor):
    batch, num_heads, seq_len, head_dim = query.shape
    assert batch == 1, "batch size must be one"

    # run the graph-based causal attention independently for every head
    outputs = []
    for head_idx in range(num_heads):
        q = query[:, head_idx, :, :]
        k = key[:, head_idx, :, :]
        v = value[:, head_idx, :, :]

        result = causal_self_attention_with_graph(q, k, v)
        outputs.append(result)

    output = torch.stack(outputs, dim=1)

    return output.reshape(batch, num_heads, seq_len, head_dim)


def causal_self_attention_with_graph(query: Tensor, key: Tensor, value: Tensor):
    batch, seq_len, d_model = query.shape
    nodes = [Node(idx, value[:, idx, :], []) for idx in range(seq_len)]
    graph = build_graph(nodes, key, query)

    # traverse the graph: each node's output is the weighted sum of the
    # values of the nodes it attends to
    outputs = []
    for root in graph:
        curr_value = torch.zeros(1, 1, d_model)
        for edge in root.adjacency_list:
            curr_value += edge.node.value * edge.weight
        outputs.append(curr_value)

    output = torch.stack(outputs, dim=-2).squeeze(dim=2)

    output = output.reshape(batch, seq_len, d_model)

    return output


def build_graph(nodes: list[Node], keys: Tensor, queries: Tensor):
    batch, seq_len, d_model = queries.shape
    for idx, curr_node in enumerate(nodes):
        # pick the keys for positions 0..idx (the current token's history)
        keys_history = keys[:, :idx + 1, :]

        # pick the query for the current (idx-th) position
        curr_query = queries[:, idx, :]

        # dot product (a similarity score) between the current query
        # and every key in the current node's (token's) history
        similarity_values = curr_query @ keys_history.transpose(-1, -2)

        # if DEBUG: print(f"{keys_history.shape=} {curr_query.shape=} {similarity_values.shape=}")
        similarity_values = similarity_values / math.sqrt(d_model)

        # after the softmax, the weights indicate how much the current node
        # attends to each past node
        attn = F.softmax(similarity_values.float(), dim=-1).type_as(keys)

        attn = attn.reshape(-1)  # flatten so indexing by position is simple
        # if DEBUG: print(attn)

        # add a weighted back-edge from the current node to every node in its history
        for nidx, node in enumerate(nodes[:idx + 1]):
            edge_weight = attn[nidx]

            # if DEBUG: print(f"{idx} attends to {nidx} with weight {edge_weight:.2f}")
            edge = Edge(node=node, weight=edge_weight)

            # the current node gets a weighted edge to each of its past nodes
            curr_node.adjacency_list.append(edge)
    return nodes


@torch.no_grad()
def test_attn():
    torch.manual_seed(6)
    batch = 1
    seq_len = 8
    d_model = 2**10
    num_heads = 2
    head_dim = d_model // num_heads

    assert batch == 1, "Batch size must be 1 for this test"
    Wk = nn.Linear(d_model, d_model)
    Wq = nn.Linear(d_model, d_model)
    Wv = nn.Linear(d_model, d_model)
    x = torch.rand(batch, seq_len, d_model)

    key: Tensor = Wk(x)
    query: Tensor = Wq(x)
    value: Tensor = Wv(x)

    # reshape to (batch, num_heads, seq_len, head_dim)
    key = key.reshape(batch, num_heads, seq_len, head_dim)
    query = query.reshape(batch, num_heads, seq_len, head_dim)
    value = value.reshape(batch, num_heads, seq_len, head_dim)

    # reference: standard masked (causal) scaled dot-product attention
    mask = torch.triu(torch.ones(1, 1, seq_len, seq_len) * -torch.inf, diagonal=1)
    scores = query @ key.transpose(-1, -2) / math.sqrt(head_dim)
    scores = mask + scores

    attn_mtx = F.softmax(scores, dim=-1)
    out = attn_mtx @ value

    # graph-based implementation
    output = scaled_graph_attention(query, key, value)

    assert torch.isclose(output, out, atol=1e-5).all(), "you need to debug buddy"

    print("IT WORKS !!!")


ITER = 3
for _ in range(ITER):
    test_attn()
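

# Not part of the original file: a possible starting point for the README's
# visualization TODO ("top nodes influencing current node"). Given a graph
# built by build_graph(), rank each node's back-edges by attention weight
# and keep the k strongest influences per node.
def top_influences(nodes: list[Node], k: int = 3) -> dict[int, list[tuple[int, float]]]:
    ranking = {}
    for node in nodes:
        edges = sorted(node.adjacency_list, key=lambda e: float(e.weight), reverse=True)
        ranking[node.idx] = [(e.node.idx, float(e.weight)) for e in edges[:k]]
    return ranking
--------------------------------------------------------------------------------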