├── README.md
├── last_query_model.py
└── lqtrnn.JPG

/README.md:
--------------------------------------------------------------------------------
# Last-Query-Transformer-RNN

Implementation of the paper [Last Query Transformer RNN for knowledge tracing](https://arxiv.org/abs/2102.05038). The key idea of the model is that it uses only the last input as the query in the transformer encoder, instead of the whole sequence. This reduces the QK matrix multiplication in the encoder from O(L^2) to O(L) time complexity, which allows the model to take much longer input sequences.

## Model architecture

![Model architecture](./lqtrnn.JPG)

## Usage
```python
from last_query_model import *

seq_len = 100
total_ex = 1200
total_cat = 234
total_in = 2


in_ex, in_cat, in_in = random_data(64, seq_len, total_ex, total_cat, total_in)

model = last_query_model(dim_model=128,
                         heads_en=1,
                         total_ex=total_ex,
                         total_cat=total_cat,
                         seq_len=seq_len,
                         total_in=2)

outs, attn_w = model(in_ex, in_cat, in_in)

print('Output lstm shape- ', outs.shape)
```

## Parameters
- `seq_len`: int.
  Sequence length of inputs.
- `dim_model`: int.
  Dimension of the model (embeddings, attention, linear layers).
- `heads_en`: int.
  Number of heads in the multi-head attention block in each encoder layer.
- `total_ex`: int.
  Total number of unique exercises.
- `total_cat`: int.
  Total number of unique concept categories.
- `total_in`: int.
  Total number of unique interactions.
- `use_lstm`: bool.
  Use an LSTM layer after multi-head attention. (default: True)

This model is the 1st place solution to the Kaggle competition [Riiid! Answer Correctness Prediction](https://www.kaggle.com/c/riiid-test-answer-prediction).

## Note
I have only implemented this model. Credit for the model architecture and solution goes to [Keetar](https://www.kaggle.com/keetar). Refer to this [link](https://www.kaggle.com/c/riiid-test-answer-prediction/discussion/218318) for more information.
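## Why using only the last query is cheaper

To make the complexity claim above concrete, the snippet below is a minimal PyTorch sketch (not part of the package; the tensor names and sizes are illustrative assumptions) contrasting the attention-score computation with all queries versus only the last query.

```python
import torch

L, d = 1024, 128            # sequence length and model dimension (illustrative values)
q = torch.randn(L, d)       # queries for every position
k = torch.randn(L, d)       # keys for every position

scores_full = q @ k.T       # (L, L)  -> O(L^2) score entries, as in standard self-attention
scores_last = q[-1:] @ k.T  # (1, L)  -> O(L) score entries, the last-query trick

print(scores_full.shape, scores_last.shape)  # torch.Size([1024, 1024]) torch.Size([1, 1024])
```

Inside the model this corresponds to the call `self.multi_en(out[-1:, :, :], out, out)`, i.e. a single query attended against all keys and values.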
## Citations

```bibtex
@article{jeon2021last,
  title={Last Query Transformer RNN for knowledge tracing},
  author={Jeon, SeungKee},
  journal={arXiv preprint arXiv:2102.05038},
  year={2021}
}
```

```bibtex
@misc{vaswani2017attention,
  title={Attention Is All You Need},
  author={Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
  year={2017},
  eprint={1706.03762},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}
```
--------------------------------------------------------------------------------
/last_query_model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy

"""
Encoder --> LSTM --> dense
"""


class Feed_Forward_block(nn.Module):
    """
    out = ReLU(x W1 + b1) W2 + b2
    """
    def __init__(self, dim_ff):
        super().__init__()
        self.layer1 = nn.Linear(in_features=dim_ff, out_features=dim_ff)
        self.layer2 = nn.Linear(in_features=dim_ff, out_features=dim_ff)

    def forward(self, ffn_in):
        return self.layer2(F.relu(self.layer1(ffn_in)))


class last_query_model(nn.Module):
    """
    Embedding --> multi-head attention (last query only) --> LSTM --> dense
    """
    def __init__(self, dim_model, heads_en, total_ex, total_cat, total_in, seq_len, use_lstm=True):
        super().__init__()
        self.seq_len = seq_len
        self.embd_ex = nn.Embedding(total_ex, embedding_dim=dim_model)    # exercise ID embedding
        self.embd_cat = nn.Embedding(total_cat, embedding_dim=dim_model)  # concept category embedding
        self.embd_in = nn.Embedding(total_in, embedding_dim=dim_model)    # interaction (response) embedding
        self.dropout = nn.Dropout(0.1)

        self.multi_en = nn.MultiheadAttention(embed_dim=dim_model, num_heads=heads_en, dropout=0.1)  # multi-head attention
        self.ffn_en = Feed_Forward_block(dim_model)  # feed-forward block
        self.layer_norm1 = nn.LayerNorm(dim_model)
        self.layer_norm2 = nn.LayerNorm(dim_model)

        self.use_lstm = use_lstm
        if self.use_lstm:
            self.lstm = nn.LSTM(input_size=dim_model, hidden_size=dim_model, num_layers=1)

        self.out = nn.Linear(in_features=dim_model, out_features=1)

    def forward(self, in_ex, in_cat, in_in, first_block=True):
        if first_block:
            # embed the three input streams and combine them
            in_ex = self.dropout(self.embd_ex(in_ex))
            in_cat = self.dropout(self.embd_cat(in_cat))
            in_in = self.dropout(self.embd_in(in_in))
            out = in_ex + in_cat + in_in  # (b, n, d)
        else:
            out = in_ex

        # in_pos = self.embd_pos(get_pos(self.seq_len))
        # out = out + in_pos  # positional embedding (currently unused)

        out = out.permute(1, 0, 2)  # (n, b, d)

        # multi-head attention: only the last position is used as the query
        n, _, _ = out.shape
        out = self.layer_norm1(out)
        skip_out = out
        out, attn_wt = self.multi_en(out[-1:, :, :], out, out)  # Q (last step only), K, V
        # attn_mask=get_mask(seq_len=n)  # upper-triangular attention mask (currently unused)
        out = out + skip_out  # skip connection (broadcast over the sequence dimension)

        # LSTM over the sequence; keep only the last step
        if self.use_lstm:
            out, _ = self.lstm(out)  # (seq_len, batch, input_size)
            out = out[-1:, :, :]

        # feed-forward block
        out = out.permute(1, 0, 2)  # (b, n, d)
        out = self.layer_norm2(out)
        skip_out = out
        out = self.ffn_en(out)
        out = out + skip_out  # skip connection
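
        # Shape summary (with the default use_lstm=True):
        #   combined embeddings:          (batch, seq_len, dim_model)
        #   after permute:                (seq_len, batch, dim_model)
        #   attention query = out[-1:]    -> attention scores of shape (batch, 1, seq_len), O(L) not O(L^2)
        #   attention output:             (1, batch, dim_model), broadcast-added to the (seq_len, batch, dim_model) skip
        #   LSTM output sliced to last step: (1, batch, dim_model)
        #   after permute + FFN:          (batch, 1, dim_model), projected to (batch, 1) logits below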

        out = self.out(out)

        return out.squeeze(-1), attn_wt


def get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def get_mask(seq_len):
    # todo: move this mask to the right device before use
    return torch.from_numpy(np.triu(np.ones((1, seq_len)), k=1).astype('bool'))


def get_pos(seq_len):
    # simple position indices (sinusoidal embeddings could be used instead)
    return torch.arange(seq_len).unsqueeze(0)


def random_data(bs, seq_len, total_ex, total_cat, total_in=2):
    # random exercise IDs, category IDs, and responses for a quick test
    ex = torch.randint(0, total_ex, (bs, seq_len))
    cat = torch.randint(0, total_cat, (bs, seq_len))
    res = torch.randint(0, total_in, (bs, seq_len))
    return ex, cat, res


"""
seq_len = 100
total_ex = 1200
total_cat = 234
total_in = 2


in_ex, in_cat, in_in = random_data(64, seq_len, total_ex, total_cat, total_in)

model = last_query_model(dim_model=128,
                         heads_en=1,
                         total_ex=total_ex,
                         total_cat=total_cat,
                         seq_len=seq_len,
                         total_in=2)

outs = model(in_ex, in_cat, in_in)

print('Output lstm shape- ', outs[0].shape)
"""

--------------------------------------------------------------------------------
/lqtrnn.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arshadshk/Last_Query_Transformer_RNN-PyTorch/dca89a95eec2177a17424417291803756911d8e3/lqtrnn.JPG
--------------------------------------------------------------------------------
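
For completeness, below is a minimal, hypothetical training-step sketch (not part of the repository). The optimizer, learning rate, and the random placeholder labels are illustrative assumptions; it only shows how the `(batch, 1)` logits returned with the default `use_lstm=True` can be fed to `BCEWithLogitsLoss`.

```python
import torch
from last_query_model import last_query_model, random_data

seq_len, total_ex, total_cat, total_in = 100, 1200, 234, 2
model = last_query_model(dim_model=128, heads_en=1, total_ex=total_ex,
                         total_cat=total_cat, seq_len=seq_len, total_in=total_in)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # assumed optimizer and learning rate
criterion = torch.nn.BCEWithLogitsLoss()                   # the model returns raw logits

# random batch and placeholder correctness labels, purely for illustration
in_ex, in_cat, in_in = random_data(64, seq_len, total_ex, total_cat, total_in)
labels = torch.randint(0, 2, (64, 1)).float()

logits, _ = model(in_ex, in_cat, in_in)  # (batch, 1) with use_lstm=True
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
optimizer.zero_grad()
print('loss:', loss.item())
```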