├── arch_from_paper.JPG
├── __pycache__
│   └── saint.cpython-37.pyc
├── README.md
└── saint.py

/arch_from_paper.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arshadshk/SAINT-pytorch/HEAD/arch_from_paper.JPG
--------------------------------------------------------------------------------
/__pycache__/saint.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arshadshk/SAINT-pytorch/HEAD/__pycache__/saint.cpython-37.pyc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SAINT-pytorch
A simple PyTorch implementation of SAINT from "Towards an Appropriate Query, Key, and Value Computation for Knowledge Tracing" (https://arxiv.org/abs/2002.07033).

**SAINT**: Separated Self-AttentIve Neural Knowledge Tracing. SAINT has an encoder-decoder structure in which the exercise embedding sequence enters the encoder and the response embedding sequence enters the decoder, which makes it possible to stack the attention layers multiple times.

## SAINT model architecture

![SAINT architecture from the paper](arch_from_paper.JPG)

## Usage
```python
import torch

from saint import saint, random_data

seq_len = 100
total_ex = 1200
total_cat = 234
total_in = 2

in_ex, in_cat, in_de = random_data(64,
                                   seq_len,
                                   total_ex,
                                   total_cat,
                                   total_in)

model = saint(dim_model=128,
              num_en=6,
              num_de=6,
              heads_en=8,
              heads_de=8,
              total_ex=total_ex,
              total_cat=total_cat,
              total_in=total_in,
              seq_len=seq_len)

outs = model(in_ex, in_cat, in_de)

print(outs.shape)
# torch.Size([64, 100, 1])
```
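For orientation, here is a minimal training-step sketch that continues the snippet above. Everything specific to it — the random `labels` tensor, the choice of `BCELoss` and Adam, the learning rate — is illustrative and not part of this repository; in practice `labels` would hold the per-position response correctness (0/1) for each sequence.

```python
# dummy correctness labels, one per sequence position: shape (batch, seq_len)
labels = torch.randint(0, 2, (64, seq_len)).float()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.BCELoss()  # the model already applies a sigmoid to its output

model.train()
optimizer.zero_grad()
preds = model(in_ex, in_cat, in_de).squeeze(-1)  # (batch, seq_len) probabilities
loss = criterion(preds, labels)
loss.backward()
optimizer.step()
print(loss.item())
```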
## Parameters
- `dim_model`: int.
  Dimension of the model (embeddings, attention and linear layers).
- `num_en`: int.
  Number of encoder layers.
- `num_de`: int.
  Number of decoder layers.
- `heads_en`: int.
  Number of heads in the multi-head attention block of each encoder layer.
- `heads_de`: int.
  Number of heads in the multi-head attention block of each decoder layer.
- `total_ex`: int.
  Total number of unique exercises.
- `total_cat`: int.
  Total number of unique concept categories.
- `total_in`: int.
  Total number of unique interactions (e.g. 2 for correct/incorrect responses).
- `seq_len`: int.
  Length of the input sequences.
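Both the encoder and the decoder apply an upper-triangular (causal) attention mask, so each position can only attend to itself and to earlier positions. As a quick illustration (not required for normal use), the mask the model builds internally can be inspected through `get_mask` from `saint.py`; `True` entries are the blocked future positions:

```python
from saint import get_mask

print(get_mask(seq_len=4))
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
```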
## todo
- Change the positional embedding to a sinusoidal encoding.

## Citations

```bibtex
@article{choi2020towards,
  title={Towards an Appropriate Query, Key, and Value Computation for Knowledge Tracing},
  author={Choi, Youngduck and Lee, Youngnam and Cho, Junghyun and Baek, Jineon and Kim, Byungsoo and Cha, Yeongmin and Shin, Dongmin and Bae, Chan and Heo, Jaewe},
  journal={arXiv preprint arXiv:2002.07033},
  year={2020}
}
```

```bibtex
@misc{vaswani2017attention,
  title={Attention Is All You Need},
  author={Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
  year={2017},
  eprint={1706.03762},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}
```
--------------------------------------------------------------------------------
/saint.py:
--------------------------------------------------------------------------------
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Feed_Forward_block(nn.Module):
    """
    out = Relu(M_out*w1 + b1)*w2 + b2
    """
    def __init__(self, dim_ff):
        super().__init__()
        self.layer1 = nn.Linear(in_features=dim_ff, out_features=dim_ff)
        self.layer2 = nn.Linear(in_features=dim_ff, out_features=dim_ff)

    def forward(self, ffn_in):
        return self.layer2(F.relu(self.layer1(ffn_in)))


class Encoder_block(nn.Module):
    """
    M = SkipConct(Multihead(LayerNorm(Qin;Kin;Vin)))
    O = SkipConct(FFN(LayerNorm(M)))
    """

    def __init__(self, dim_model, heads_en, total_ex, total_cat, seq_len):
        super().__init__()
        self.seq_len = seq_len
        self.embd_ex = nn.Embedding(total_ex, embedding_dim=dim_model)    # exercise ID embedding
        self.embd_cat = nn.Embedding(total_cat, embedding_dim=dim_model)  # exercise category embedding
        self.embd_pos = nn.Embedding(seq_len, embedding_dim=dim_model)    # positional embedding

        self.multi_en = nn.MultiheadAttention(embed_dim=dim_model, num_heads=heads_en)  # multi-head attention  ## todo add dropout
        self.ffn_en = Feed_Forward_block(dim_model)                                     # feed-forward block    ## todo add dropout
        self.layer_norm1 = nn.LayerNorm(dim_model)
        self.layer_norm2 = nn.LayerNorm(dim_model)

    def forward(self, in_ex, in_cat, first_block=True):

        ## todo create a positional encoding (two options: numeric, sine)
        if first_block:
            in_ex = self.embd_ex(in_ex)
            in_cat = self.embd_cat(in_cat)
            # combine the embeddings
            out = in_ex + in_cat  # (b,n,d)
        else:
            out = in_ex

        in_pos = get_pos(self.seq_len)
        in_pos = self.embd_pos(in_pos)
        out = out + in_pos  # apply the positional embedding

        out = out.permute(1, 0, 2)  # (b,n,d) -> (n,b,d)

        # multi-head attention
        n, _, _ = out.shape
        out = self.layer_norm1(out)  # layer norm
        skip_out = out
        out, attn_wt = self.multi_en(out, out, out,
                                     attn_mask=get_mask(seq_len=n))  # upper-triangular attention mask
        out = out + skip_out  # skip connection

        # feed forward
        out = out.permute(1, 0, 2)  # (n,b,d) -> (b,n,d)
        out = self.layer_norm2(out)  # layer norm
        skip_out = out
        out = self.ffn_en(out)
        out = out + skip_out  # skip connection

        return out


class Decoder_block(nn.Module):
    """
    M1 = SkipConct(Multihead(LayerNorm(Qin;Kin;Vin)))
    M2 = SkipConct(Multihead(LayerNorm(M1;O;O)))
    L  = SkipConct(FFN(LayerNorm(M2)))
    """

    def __init__(self, dim_model, total_in, heads_de, seq_len):
        super().__init__()
        self.seq_len = seq_len
        self.embd_in = nn.Embedding(total_in, embedding_dim=dim_model)   # interaction embedding
        self.embd_pos = nn.Embedding(seq_len, embedding_dim=dim_model)   # positional embedding
        self.multi_de1 = nn.MultiheadAttention(embed_dim=dim_model, num_heads=heads_de)  # M1: self-attention over the interaction embedding (q, k, v)
        self.multi_de2 = nn.MultiheadAttention(embed_dim=dim_model, num_heads=heads_de)  # M2: attention with the M1 output as q and the encoder output as k, v
        self.ffn_en = Feed_Forward_block(dim_model)  # feed-forward layer

        self.layer_norm1 = nn.LayerNorm(dim_model)
        self.layer_norm2 = nn.LayerNorm(dim_model)
        self.layer_norm3 = nn.LayerNorm(dim_model)

    def forward(self, in_in, en_out, first_block=True):

        ## todo create a positional encoding (two options: numeric, sine)
        if first_block:
            in_in = self.embd_in(in_in)
            out = in_in  # (b,n,d)
        else:
            out = in_in

        in_pos = get_pos(self.seq_len)
        in_pos = self.embd_pos(in_pos)
        out = out + in_pos  # apply the positional embedding

        out = out.permute(1, 0, 2)  # (b,n,d) -> (n,b,d)
        n, _, _ = out.shape

        # multi-head attention M1  ## todo verify whether E should be passed as q,k,v
        out = self.layer_norm1(out)
        skip_out = out
        out, attn_wt = self.multi_de1(out, out, out,
                                      attn_mask=get_mask(seq_len=n))  # upper-triangular attention mask
        out = skip_out + out  # skip connection

        # multi-head attention M2  ## todo verify whether E should be passed as q,k,v
        en_out = en_out.permute(1, 0, 2)  # (b,n,d) -> (n,b,d)
        en_out = self.layer_norm2(en_out)
        skip_out = out
        out, attn_wt = self.multi_de2(out, en_out, en_out,
                                      attn_mask=get_mask(seq_len=n))  # upper-triangular attention mask
        out = out + skip_out  # skip connection

        # feed forward
        out = out.permute(1, 0, 2)  # (n,b,d) -> (b,n,d)
        out = self.layer_norm3(out)  # layer norm
        skip_out = out
        out = self.ffn_en(out)
        out = out + skip_out  # skip connection

        return out


def get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def get_mask(seq_len):
    ## todo move the mask to the model's device
    return torch.from_numpy(np.triu(np.ones((seq_len, seq_len)), k=1).astype('bool'))


def get_pos(seq_len):
    ## todo replace with a sinusoidal positional encoding
    return torch.arange(seq_len).unsqueeze(0)
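
# Illustrative sketch (an assumption, not used anywhere in this repository): a
# device-aware variant of get_mask, addressing the "move the mask to the model's
# device" todo above. The function name is hypothetical.
def get_mask_on_device(seq_len, device):
    # same upper-triangular boolean mask as get_mask, built directly on `device`
    return torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()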

class saint(nn.Module):
    def __init__(self, dim_model, num_en, num_de, heads_en, total_ex, total_cat, total_in, heads_de, seq_len):
        super().__init__()

        self.num_en = num_en
        self.num_de = num_de

        self.encoder = get_clones(Encoder_block(dim_model, heads_en, total_ex, total_cat, seq_len), num_en)
        self.decoder = get_clones(Decoder_block(dim_model, total_in, heads_de, seq_len), num_de)

        self.out = nn.Linear(in_features=dim_model, out_features=1)

    def forward(self, in_ex, in_cat, in_in):

        # pass through each encoder block in sequence
        first_block = True
        for x in range(self.num_en):
            if x >= 1:
                first_block = False
            in_ex = self.encoder[x](in_ex, in_cat, first_block=first_block)
            in_cat = in_ex  # the block output feeds the next encoder block as q, k, v

        # pass through each decoder block in sequence
        first_block = True
        for x in range(self.num_de):
            if x >= 1:
                first_block = False
            in_in = self.decoder[x](in_in, en_out=in_ex, first_block=first_block)

        # output layer
        in_in = torch.sigmoid(self.out(in_in))
        return in_in


def random_data(bs, seq_len, total_ex, total_cat, total_in=2):
    ex = torch.randint(0, total_ex, (bs, seq_len))
    cat = torch.randint(0, total_cat, (bs, seq_len))
    de = torch.randint(0, total_in, (bs, seq_len))
    return ex, cat, de


if __name__ == "__main__":
    # forward prop on dummy data; guarded so that importing from saint
    # (as the README does) no longer runs this demo
    seq_len = 100
    total_ex = 1200
    total_cat = 234
    total_in = 2

    in_ex, in_cat, in_de = random_data(64, seq_len, total_ex, total_cat, total_in)

    model = saint(dim_model=128,
                  num_en=6,
                  num_de=6,
                  heads_en=8,
                  heads_de=8,
                  total_ex=total_ex,
                  total_cat=total_cat,
                  total_in=total_in,
                  seq_len=seq_len)

    outs = model(in_ex, in_cat, in_de)

    print(outs.shape)
    # torch.Size([64, 100, 1])
--------------------------------------------------------------------------------