├── .gitignore
├── requirements.txt
├── README.md
├── config.py
├── utils.py
├── wordtest.py
├── main.py
├── multihead_attn.py
├── student_model.py
├── run.py
└── dataset.py

/.gitignore:
--------------------------------------------------------------------------------
/.idea/
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
coloredlogs==14.0
humanfriendly==7.1.1
joblib==0.14.1
numpy==1.18.1
prefetch-generator==1.0.1
protobuf==3.11.3
PySnooper==0.3.0
scikit-learn==0.22.1
scipy==1.4.1
six==1.14.0
sklearn==0.0
tensorboardX==2.0
torch==1.4.0
TorchSnooper==0.8
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SAKT

A PyTorch implementation of SAKT (Self-Attentive Knowledge Tracing). I originally modified the model to run on my own dataset and have since reset the code to match the SAKT paper, but it may still contain bugs. I have not run it on any other dataset, so this implementation is for reference only.
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
class DefaultConfig(object):
    model = 'SAKT'
    train_data = ""  # path to the training data
    test_data = ""   # path to the test data
    batch_size = 256
    state_size = 200
    num_heads = 5
    max_len = 50
    dropout = 0.1
    max_epoch = 10
    lr = 3e-3
    lr_decay = 0.9
    max_grad_norm = 1.0
    weight_decay = 0  # L2 regularization factor

opt = DefaultConfig()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from torch.autograd import Variable


def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0


def make_std_mask(x, pad):
    "Create a mask to hide padding and future words."
    # mask hides the pad positions; subsequent_mask hides future words
    mask = torch.unsqueeze((x != pad), -1)
    tgt_mask = mask & Variable(
        subsequent_mask(x.size(-1)).type_as(mask.data))
    return tgt_mask
--------------------------------------------------------------------------------
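As a quick illustration of what these helpers produce (an editorial sketch, not a file in the repo): `make_std_mask` combines the pad mask with the subsequent-positions mask, so row i of the result is True only at non-pad positions j <= i.

# Editorial sketch: demo of the mask helpers above, assuming pad index 0.
import torch
from utils import make_std_mask

x = torch.tensor([[5, 3, 0, 0]])   # one sequence whose last two steps are padding
mask = make_std_mask(x, pad=0)     # shape [1, 4, 4]
# mask[0] is:
# [[ True, False, False, False],
#  [ True,  True, False, False],
#  [False, False, False, False],
#  [False, False, False, False]]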
11 | """ 12 | if mode in ['rb', 'r']: 13 | logger.info("Loading obj from {}...".format(path)) 14 | with open(path, 'rb') as f: 15 | obj = pickle.load(f) 16 | logger.info("Load obj successfully!") 17 | return obj 18 | elif mode in ['wb', 'w']: 19 | logger.info("Dumping obj to {}...".format(path)) 20 | with open(path, 'wb') as f: 21 | pickle.dump(obj, f) 22 | logger.info("Dump obj successfully!") 23 | 24 | class WordTestResource(object): 25 | 26 | def __init__(self, resource_path, verbose=False): 27 | 28 | resource = pickle_io(resource_path, mode='r') 29 | 30 | self.id2index = resource['id2index'] 31 | self.index2id = resource['index2id'] 32 | self.num_skills = len(self.id2index) 33 | 34 | if verbose: 35 | self.word2id = resource['word2id'] 36 | self.id2all = resource['id2all'] 37 | # rank0 already be set to a large number 38 | self.words_by_rank = resource['words_by_rank'] 39 | self.pos2id = resource['pos2id'] 40 | self.words_by_rank.sort(key=lambda x: x[u'rank']) 41 | self.id_by_rank = [x[u'word_id'] for x in self.words_by_rank] 42 | 43 | def str2bool(s): 44 | if s not in {'False', 'True'}: 45 | raise ValueError('Not a valid boolean string') 46 | return s == 'True' -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from dataset import Data 5 | from dataset import DataLoaderX 6 | from config import DefaultConfig 7 | from student_model import student_model 8 | from run import run_epoch 9 | 10 | opt = DefaultConfig() 11 | 12 | if __name__ == '__main__': 13 | train_dataset = Data(train=True) 14 | test_dataset = Data(train=False) 15 | train_loader = DataLoaderX(train_dataset, batch_size=opt.batch_size, num_workers=4, pin_memory=True, shuffle=True) 16 | test_loader = DataLoaderX(test_dataset, batch_size=opt.batch_size, num_workers=4, pin_memory=True) 17 | num_skills = train_dataset.max_skill_num + 1 18 | 19 | m = student_model(num_skills=num_skills, state_size=opt.state_size, 20 | num_heads=opt.num_heads, dropout=opt.dropout, infer=False) 21 | 22 | torch.backends.cudnn.benchmark = True 23 | best_auc = 0 24 | optimizer = optim.Adam(m.parameters(), lr=opt.lr, weight_decay=opt.weight_decay) 25 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=opt.lr_decay) 26 | criterion = nn.BCELoss() 27 | for epoch in range(opt.max_epoch): 28 | rmse, auc, r2, acc = run_epoch(m, train_loader, optimizer, scheduler, criterion, 29 | num_skills=num_skills, epoch_id=epoch, is_training=True) 30 | print('Epoch %d:\nTrain metrics: auc: %.3f, acc: %.3f, rmse: %.3f, r2: %.3f' \ 31 | % (epoch + 1, auc, acc, rmse, r2)) 32 | rmse, auc, r2, acc = run_epoch(m, test_loader, optimizer, scheduler, criterion, 33 | num_skills=num_skills, epoch_id=epoch, is_training=False) 34 | print('\nTest metrics: auc: %.3f, acc: %.3f, rmse: %.3f, r2: %.3f' \ 35 | % (auc, acc, rmse ,r2)) 36 | 37 | if auc > best_auc: 38 | best_auc = auc 39 | torch.save(m.state_dict(), 'models/sakt_model_auc_{}.pkl'.format(int(best_auc * 1000))) -------------------------------------------------------------------------------- /multihead_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import math 5 | import torch.nn.functional as F 6 | import copy 7 | from torch.nn import LayerNorm 8 | 9 | 10 | def clones(module, N): 11 | "Produce N 
/multihead_attn.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import math
import torch.nn.functional as F
import copy
from torch.nn import LayerNorm


def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def attention(query, key, value, key_masks=None, query_masks=None, future_masks=None, dropout=None, infer=False):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    # query shape = [nbatches, h, T_q, d_k]; key shape = [nbatches, h, T_k, d_k] == value shape
    # scores shape = [nbatches, h, T_q, T_k] == p_attn shape
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    # if key_masks is not None:
    #     scores = scores.masked_fill(key_masks.unsqueeze(1).cuda() == 0, -1e9)
    if future_masks is not None:
        scores = scores.masked_fill(future_masks.unsqueeze(0).cuda() == 0, -1e9)

    p_attn = F.softmax(scores, dim=-1)
    outputs = p_attn
    # if query_masks is not None:
    #     outputs = outputs * query_masks.unsqueeze(1)
    if dropout is not None:
        outputs = dropout(outputs)
    outputs = torch.matmul(outputs, value)

    # residual connection followed by a parameter-free layer norm over d_k
    # (a LayerNorm constructed inside this function would never be trained,
    # so the functional form without learned affine parameters is used)
    outputs += query
    return F.layer_norm(outputs, (d_k,)), p_attn


class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.2, infer=False):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        # three projections, one each for query, key and value
        # (no output projection is applied in this variant)
        self.linears = clones(nn.Linear(d_model, d_model), 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
        self.layernorm = LayerNorm(d_model).cuda()
        self.infer = infer

    def forward(self, query, key, value, key_masks=None, query_masks=None, future_masks=None):
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = \
            [F.relu(l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2), inplace=True)
             for l, x in zip(self.linears, (query, key, value))]
        # k, v shape = [nbatches, h, T_k, d_k], d_k * h = d_model
        # q shape = [nbatches, h, T_q, d_k]
        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query, key, value, query_masks=query_masks,
                                 key_masks=key_masks, future_masks=future_masks,
                                 dropout=self.dropout, infer=self.infer)

        # 3) "Concat" using a view and apply a final layer norm.
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)
        return self.layernorm(x)
--------------------------------------------------------------------------------
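A shape-level smoke test of the module, assuming a CUDA device is available (the hard-coded .cuda() calls require one); the sizes match the defaults in config.py:

# Editorial sketch: exercising MultiHeadedAttention with a causal mask.
import torch
from utils import subsequent_mask
from multihead_attn import MultiHeadedAttention

mha = MultiHeadedAttention(h=5, d_model=200).cuda()
q = torch.randn(2, 49, 200).cuda()    # [batch, T_q, d_model]
kv = torch.randn(2, 49, 200).cuda()   # [batch, T_k, d_model]
out = mha(q, kv, kv, future_masks=subsequent_mask(49))
print(out.shape)   # torch.Size([2, 49, 200])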
/student_model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from config import DefaultConfig
from utils import subsequent_mask
from torch.autograd import Variable
from multihead_attn import MultiHeadedAttention
from torch.nn import LayerNorm

opt = DefaultConfig()

class PositionalEncoding(nn.Module):
    "Implement the PE function."

    def __init__(self, state_size, dropout=0.1, max_len=50):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, state_size)
        position = torch.arange(0.0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0.0, state_size, 2) *
                             -(math.log(10000.0) / state_size))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # register as a buffer so the table moves with the module across devices
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)],
                         requires_grad=False)
        return self.dropout(x)


class student_model(nn.Module):

    def __init__(self, num_skills, state_size, num_heads=2, dropout=0.2, infer=False):
        super(student_model, self).__init__()
        self.infer = infer
        self.num_skills = num_skills
        self.state_size = state_size
        # we use (num_skills * 2 + 1) as the key padding index
        self.embedding = nn.Embedding(num_embeddings=num_skills * 2 + 1,
                                      embedding_dim=state_size)
        # padding_idx=num_skills*2
        # self.position_embedding = PositionalEncoding(state_size)
        self.position_embedding = nn.Embedding(num_embeddings=opt.max_len - 1,
                                               embedding_dim=state_size)
        # we use (num_skills + 1) as the query padding index
        self.problem_embedding = nn.Embedding(num_embeddings=num_skills + 1,
                                              embedding_dim=state_size)
        # padding_idx=num_skills
        self.multi_attn = MultiHeadedAttention(h=num_heads, d_model=state_size, dropout=dropout, infer=self.infer)
        self.feedforward1 = nn.Linear(in_features=state_size, out_features=state_size)
        self.feedforward2 = nn.Linear(in_features=state_size, out_features=state_size)
        self.pred_layer = nn.Linear(in_features=state_size, out_features=num_skills)
        self.dropout = nn.Dropout(dropout)
        self.layernorm = LayerNorm(state_size)

    def forward(self, x, problems, target_index):
        # self.key_masks = torch.unsqueeze((x != self.num_skills * 2).int(), -1)
        # self.problem_masks = torch.unsqueeze((problems != self.num_skills).int(), -1)
        x = self.embedding(x)
        pe = self.position_embedding(torch.arange(x.size(1)).unsqueeze(0).cuda())
        x += pe
        # x = self.position_embedding(x)
        problems = self.problem_embedding(problems)
        # self.key_masks = self.key_masks.type_as(x)
        # self.problem_masks = self.problem_masks.type_as(problems)
        # x *= self.key_masks
        # problems *= self.problem_masks
        x = self.dropout(x)
        res = self.multi_attn(query=self.layernorm(problems), key=x, value=x,
                              key_masks=None, query_masks=None, future_masks=None)
        outputs = F.relu(self.feedforward1(res))
        outputs = self.dropout(outputs)
        outputs = self.dropout(self.feedforward2(outputs))
        # Residual connection
        outputs += self.layernorm(res)
        outputs = self.layernorm(outputs)
        logits = self.pred_layer(outputs)
        # flatten to [batch * seq_len * num_skills] so that target_index can
        # gather the logit of the queried skill at every step
        logits = logits.contiguous().view(-1)
        selected_logits = torch.gather(logits, 0, torch.LongTensor(target_index).cuda())
        return selected_logits
--------------------------------------------------------------------------------
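The gather at the end of forward relies on an indexing scheme built in run.py; a self-contained toy version of it (toy sizes, no model involved):

# Editorial sketch: the target_index scheme used with the flattened logits.
# The prediction for (sequence i, step j) on skill s lives at
# (i * T + j) * num_skills + s.
import torch

batch, T, num_skills = 2, 3, 4
logits = torch.arange(batch * T * num_skills, dtype=torch.float32)
problems = torch.tensor([[1, 3, 0], [2, 2, 1]])  # skill id queried at each step

target_index, offset = [], 0
for i in range(batch):
    for j in range(T):
        target_index.append(offset + problems[i, j].item())
        offset += num_skills

selected = torch.gather(logits, 0, torch.tensor(target_index))
# selected[i * T + j] == logits[(i * T + j) * num_skills + problems[i, j]]
print(selected)  # tensor([ 1.,  7.,  8., 14., 18., 21.])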
/run.py:
--------------------------------------------------------------------------------
import time
import torch
import numpy as np
import torch.nn as nn
from dataset import DataPrefetcher
from config import DefaultConfig
from sklearn.metrics import mean_squared_error
from math import sqrt
from sklearn import metrics
from sklearn.metrics import r2_score

opt = DefaultConfig()

def run_epoch(m, dataloader, optimizer, scheduler, criterion, num_skills,
              epoch_id=None, writer=None, is_training=True):
    epoch_start_time = time.time()
    if is_training:
        m.train()
    else:
        m.eval()
    m.cuda()
    actual_labels = []
    pred_labels = []
    num_batch = len(dataloader)
    prefetcher = DataPrefetcher(dataloader, device='cuda')
    batch = prefetcher.next()
    k = 0

    if is_training:
        while batch is not None:
            target_index = []
            x, problems, correctness = batch
            x = x.long()
            problems = problems.long()
            correctness = correctness.view(-1).float()

            # the prefetcher has already moved tensors to the GPU,
            # so bring them back to the CPU before converting to numpy
            actual_labels += list(correctness.cpu().numpy())
            # target_index[i * T + j] points into the flattened logits:
            # (i * T + j) * num_skills + skill id queried at step j
            offset = 0
            helper = np.array(problems.cpu()).reshape(-1)
            for i in range(problems.size(0)):
                for j in range(problems.size(1)):
                    target_index.append(offset + helper[i * problems.size(1) + j])
                    offset += num_skills
            logits = m(x, problems, target_index)
            pred = torch.sigmoid(logits)
            loss = criterion(pred, correctness.cuda())
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(m.parameters(), opt.max_grad_norm)
            optimizer.step()
            scheduler.step()
            pred_labels += list(np.array(pred.data.cpu()))
            batch = prefetcher.next()
            k += 1
            if k % 500 == 0:
                print('\r batch{}/{}'.format(k, num_batch), end='')
            if k >= num_batch - 1:  # skip the final (possibly partial) batch
                break
    else:
        with torch.no_grad():
            while batch is not None:
                target_index = []
                x, problems, correctness = batch
                x = x.long()
                actual_num_problems = torch.sum(problems != num_skills, dim=1)
                num_problems = problems.size(1)
                problems = problems.long()
                correctness = correctness.view(-1).float()
                offset = 0
                helper = np.array(problems.cpu()).reshape(-1)
                for i in range(problems.size(0)):
                    for j in range(problems.size(1)):
                        target_index.append(offset + helper[i * problems.size(1) + j])
                        offset += num_skills

                logits = m(x, problems, target_index)
                pred = torch.sigmoid(logits)
                # overwrite predictions on padded (leading) steps with the
                # ground truth so they do not affect the metrics
                for J in range(x.size(0)):
                    actual_num_problem = actual_num_problems[J]
                    num_to_throw = num_problems - actual_num_problem

                    pred[J * num_problems:J * num_problems + num_to_throw] = correctness[
                        J * num_problems:J * num_problems + num_to_throw]
                actual_labels += list(correctness.cpu().numpy())
                pred_labels += list(np.array(pred.data.cpu()))
                batch = prefetcher.next()
                k += 1
                if k % 500 == 0:
                    print('\r batch{}/{}'.format(k, num_batch), end='')
                if k >= num_batch - 1:  # skip the final (possibly partial) batch
                    break

    rmse = sqrt(mean_squared_error(actual_labels, pred_labels))
    fpr, tpr, thresholds = metrics.roc_curve(actual_labels, pred_labels, pos_label=1)
    auc = metrics.auc(fpr, tpr)
    r2 = r2_score(actual_labels, pred_labels)
    acc = metrics.accuracy_score(actual_labels, np.array(pred_labels) >= 0.5)
    epoch_end_time = time.time()
    print('Epoch costs %.2f s' % (epoch_end_time - epoch_start_time))
    return rmse, auc, r2, acc
--------------------------------------------------------------------------------
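As a sanity check of the metric block at the end of run_epoch, the same calls on toy labels (the values below are just for illustration):

# Editorial sketch: the metric block above, run on toy labels.
import numpy as np
from math import sqrt
from sklearn import metrics
from sklearn.metrics import mean_squared_error

actual = [1, 0, 1, 1, 0]
pred = [0.9, 0.2, 0.7, 0.4, 0.1]
rmse = sqrt(mean_squared_error(actual, pred))
fpr, tpr, _ = metrics.roc_curve(actual, pred, pos_label=1)
auc = metrics.auc(fpr, tpr)
acc = metrics.accuracy_score(actual, np.array(pred) >= 0.5)
print(round(rmse, 3), round(auc, 3), acc)  # 0.319 1.0 0.8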
/dataset.py:
--------------------------------------------------------------------------------
import csv
import torch
import time
import itertools
import numpy as np
from config import DefaultConfig
from wordtest import WordTestResource
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from prefetch_generator import BackgroundGenerator

opt = DefaultConfig()

class Data(Dataset):
    def __init__(self, train=True):
        start_time = time.time()
        if train:
            fileName = opt.train_data
        else:
            fileName = opt.test_data
        self.students = []
        self.max_skill_num = 0
        begin_index = 1e9
        # records come in triples of CSV lines: question count, question ids, answers
        with open(fileName, "r") as csvfile:
            for num_ques, ques, ans in itertools.zip_longest(*[csvfile] * 3):
                num_ques = int(num_ques.strip().strip(','))
                ques = [int(q) for q in ques.strip().strip(',').split(',')]
                ans = [int(a) for a in ans.strip().strip(',').split(',')]
                tmp_max_skill = max(ques)
                tmp_min_skill = min(ques)
                begin_index = min(tmp_min_skill, begin_index)
                self.max_skill_num = max(tmp_max_skill, self.max_skill_num)

                if num_ques <= 2:
                    continue
                elif num_ques <= opt.max_len:
                    problems = np.zeros(opt.max_len, dtype=np.int64)
                    correct = np.ones(opt.max_len, dtype=np.int64)
                    problems[-num_ques:] = ques[-num_ques:]
                    correct[-num_ques:] = ans[-num_ques:]
                    self.students.append((num_ques, problems, correct))
                else:
                    # split long sequences into max_len-sized chunks
                    start_idx = 0
                    while opt.max_len + start_idx <= num_ques:
                        problems = np.array(ques[start_idx:opt.max_len + start_idx])
                        correct = np.array(ans[start_idx:opt.max_len + start_idx])
                        tup = (opt.max_len, problems, correct)
                        start_idx += opt.max_len
                        self.students.append(tup)
                    left_num_ques = num_ques - start_idx
                    if left_num_ques > 0:  # the tail may be empty when num_ques is a multiple of max_len
                        problems = np.zeros(opt.max_len, dtype=np.int64)
                        correct = np.ones(opt.max_len, dtype=np.int64)
                        problems[-left_num_ques:] = ques[start_idx:]
                        correct[-left_num_ques:] = ans[start_idx:]
                        tup = (left_num_ques, problems, correct)
                        self.students.append(tup)

    def __getitem__(self, index):
        student = self.students[index]
        problems = student[1]
        correct = student[2]
        # copy so the stored array is not mutated in place across epochs
        x = problems[:-1].copy()
        # we assume max_skill_num + 1 = num_skills because skill indices run from 0 to max_skill_num
        x += (correct[:-1] == 1) * (self.max_skill_num + 1)
        problems = problems[1:]
        correct = correct[1:]
        return x, problems, correct

    def __len__(self):
        return len(self.students)


class DataLoaderX(DataLoader):

    def __iter__(self):
        return BackgroundGenerator(super().__iter__())


class DataPrefetcher():
    def __init__(self, loader, device):
        self.loader = iter(loader)
        self.device = device
        self.stream = torch.cuda.Stream()
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        try:
            self.batch = next(self.loader)
        except StopIteration:
            self.batch = None
            return
        with torch.cuda.stream(self.stream):
            for k in range(len(self.batch)):
                self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)

            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            #     self.next_input = self.next_input.float()

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        self.preload()
        return batch
--------------------------------------------------------------------------------
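Finally, a sketch of how the dataset pieces fit together, assuming a CUDA device and that opt.train_data points at a file in the triple-line CSV layout Data.__init__ parses (the example contents in the comment are illustrative):

# Editorial sketch: end-to-end wiring of the dataset classes above.
# The expected input is triples of lines per student, e.g.:
#   3,
#   12,7,12,
#   1,0,1,
from dataset import Data, DataLoaderX, DataPrefetcher

train = Data(train=True)
loader = DataLoaderX(train, batch_size=4, shuffle=True)
prefetcher = DataPrefetcher(loader, device='cuda')
x, problems, correct = prefetcher.next()  # each of shape [4, max_len - 1]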