├── .gitignore ├── LANA_EDM2021.pdf ├── LICENSE ├── README.md ├── config.py ├── dataset.py ├── dataset_group.py ├── depen_install.py ├── figures ├── kt-example.png ├── lana-arch.png └── leveled-learning.png ├── lana_arch.py ├── main.py ├── main_ll.py ├── package └── pyirt │ ├── pyirt │ ├── __init__.py │ ├── _pyirt.py │ ├── algo.py │ ├── dao.py │ ├── logger.py │ ├── solver │ │ ├── __init__.py │ │ ├── model.py │ │ ├── optimizer.py │ │ └── theta_estimator.py │ └── util │ │ ├── __init__.py │ │ ├── clib.py │ │ ├── dao.py │ │ └── tools.py │ ├── setup.cfg │ └── setup.py ├── requirements.txt ├── sample_data_preprocess.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | -------------------------------------------------------------------------------- /LANA_EDM2021.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Soptq/LANA-pytorch/c1ce06775c4faccc7fb612d23403a0123666cb2d/LANA_EDM2021.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Soptq 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LANA-pytorch 2 | Official Pytorch implementation of "LANA: Towards Personalized Deep Knowledge Tracing Through Distinguishable Interactive Sequences" 3 | 4 | ![](figures/kt-example.png) 5 | 6 | ## Abstract 7 | 8 | > In educational applications, Knowledge Tracing (KT), the problem of accurately predicting students' responses to future questions by summarizing their knowledge states, has been widely studied for decades as it is considered a fundamental task towards adaptive online learning. Among all the proposed KT methods, Deep Knowledge Tracing (DKT) and its variants are by far the most effective ones due to the high flexibility of the neural network. However, DKT often ignores the inherent differences between students (e.g. memory skills, reasoning skills, ...), averaging the performances of all students, leading to the lack of personalization, and therefore was considered insufficient for adaptive learning. 
To alleviate this problem, in this paper, we proposed Leveled Attentive KNowledge TrAcing (LANA), which firstly uses a novel student-related features extractor (SRFE) to distill students' unique inherent properties from their respective interactive sequences. Secondly, the pivot module was utilized to dynamically reconstruct the decoder of the neural network on attention of the extracted features, successfully distinguishing the performance between students over time. Moreover, inspired by Item Response Theory (IRT), the interpretable Rasch model was used to cluster students by their ability levels, and thereby utilizing leveled learning to assign different encoders to different groups of students. With pivot module reconstructed the decoder for individual students and leveled learning specialized encoders for groups, personalized DKT was achieved. Extensive experiments conducted on two real-world large-scale datasets demonstrated that our proposed LANA improves the AUC score by at least 1.00% (i.e. EdNet 1.46% and RAIEd2020 1.00%), substantially surpassing the other State-Of-The-Art KT methods. 9 | 10 | ## Main Architecture 11 | 12 | ![](figures/lana-arch.png) 13 | 14 | ![](figures/leveled-learning.png) 15 | 16 | ## Quickstart 17 | ### Cloning 18 | ``` 19 | git clone https://github.com/Soptq/LANA-pytorch 20 | cd LANA-pytorch 21 | ``` 22 | 23 | ### Installation 24 | ``` 25 | python depen_install.py 26 | ``` 27 | 28 | ### Dataset Preparation 29 | 30 | You need to download the dataset manually and preprocess it. A sample preprocessing script is provided as `sample_data_preprocess.py`. 31 | 32 | Note that if you are going to try Leveled Learning, you must pass `--irt` to the preprocessing script. 33 | 34 | ### Run LANA 35 | ``` 36 | python main.py -d YOUR_PREPROCESSED_DATA 37 | ``` 38 | 39 | Configurations of the experiments are set in `config.py`. 40 | 41 | After training to convergence, you can optionally improve its performance further by applying Leveled Learning: 42 | 43 | ``` 44 | python main_ll.py -d YOUR_PREPROCESSED_DATA -m YOUR_TRAINED_MODEL -n NUM_LEVELS -t TOP_K -i MU_INTERVAL 45 | ``` 46 | 47 | ## Results 48 | 49 | | Dataset | Model | AUC | 50 | | ------------- | ------------- | ------------- | 51 | | EdNet | DKT | 0.7638 | 52 | | EdNet | DKVMN | 0.7668 | 53 | | EdNet | SAKT | 0.7663 | 54 | | EdNet | SAINT | 0.7816 | 55 | | EdNet | SAINT+ | _0.7913_ | 56 | | EdNet | SAINT+ & BM | 0.7935 | 57 | | EdNet | LANA | **0.8059** | 58 | 59 | ## Cite 60 | 61 | ``` 62 | @inproceedings{zhou2021lana, 63 | title={LANA: Towards Personalized Deep Knowledge Tracing Through Distinguishable Interactive Sequences}, 64 | author={Yuhao Zhou and Xihua Li and Yunbo Cao and Xuemin Zhao and Qing Ye and Jiancheng Lv}, 65 | organization={Sichuan University, Tencent Inc.}, 66 | year={2021} 67 | } 68 | ``` 69 | 70 | ## License 71 | 72 | ``` 73 | MIT License 74 | 75 | Copyright (c) 2021 Soptq 76 | 77 | Permission is hereby granted, free of charge, to any person obtaining a copy 78 | of this software and associated documentation files (the "Software"), to deal 79 | in the Software without restriction, including without limitation the rights 80 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 81 | copies of the Software, and to permit persons to whom the Software is 82 | furnished to do so, subject to the following conditions: 83 | 84 | The above copyright notice and this permission notice shall be included in all 85 | copies or substantial portions of the Software.
86 | 87 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 88 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 89 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 90 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 91 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 92 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 93 | SOFTWARE. 94 | ``` 95 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | SEED = 2021 2 | PAD = 0 3 | START = 1 4 | TOTAL_EID = 13523 5 | TOTAL_CAT = 10000 6 | TOTAL_PART = 7 7 | TOTAL_RESP = 2 8 | TOTAL_ETIME = 300 9 | TOTAL_LTIME_S = 300 10 | TOTAL_LTIME_M = 1440 11 | TOTAL_LTIME_D = 365 12 | 13 | 14 | DEVICE = [0] # [0]: use cuda:0 to train 15 | MAX_SEQ = 100 # The maximum length of inputting sequence at a time 16 | MIN_SEQ = 2 # The minimal length of inputting sequence at a time 17 | OVERLAP_SEQ = 60 # Split a training sample every OVERLAP_SEQ words 18 | MODEL_DIMS = 256 # Model dimension 19 | FEEDFORWARD_DIMS = 256 # Hidden dimension of FFN 20 | N_HEADS = 8 # Number of Attention heads 21 | NUM_ENCODER = NUM_DECODER = 2 # Number of Encoder/Decoder in Transformer 22 | DROPOUT = 0.1 # Dropout rate 23 | BATCH_SIZE = 256 # Batch size 24 | LEARNING_RATE = 5e-4 # Base learning rate 25 | 26 | EPOCH = 100 # Maximum number of training 27 | D_PIV = 32 -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | 3 | import torch 4 | import config 5 | import numpy as np 6 | 7 | from tqdm import tqdm 8 | 9 | # Dataset Define 10 | class DKTDataset(Dataset): 11 | def __init__(self, group, max_seq, min_seq, overlap_seq): 12 | self.samples = group 13 | self.max_seq = max_seq 14 | self.min_seq = min_seq 15 | self.overlap_seq = overlap_seq 16 | self.data = [] 17 | 18 | for exercise, part, correctness, elapsed_time, lag_time_s, lag_time_m, lag_time_d, p_explanation in tqdm(self.samples, total=len(self.samples), desc="Loading Dataset"): 19 | content_len = len(exercise) 20 | if content_len < self.min_seq: 21 | continue # skip sequence with too few contents 22 | 23 | if content_len > self.max_seq: 24 | initial = content_len % self.max_seq 25 | if initial >= self.min_seq: 26 | self.data.extend([(np.append([config.START], exercise[:initial]), 27 | np.append([config.START], part[:initial]), 28 | np.append([config.START], correctness[:initial]), 29 | np.append([config.START], elapsed_time[:initial]), 30 | np.append([config.START], lag_time_s[:initial]), 31 | np.append([config.START], lag_time_m[:initial]), 32 | np.append([config.START], lag_time_d[:initial]), 33 | np.append([config.START], p_explanation[:initial]))]) 34 | for seq in range(content_len // self.max_seq): 35 | start = initial + seq * self.max_seq 36 | end = initial + (seq + 1) * self.max_seq 37 | self.data.extend([(np.append([config.START], exercise[start: end]), 38 | np.append([config.START], part[start: end]), 39 | np.append([config.START], correctness[start: end]), 40 | np.append([config.START], elapsed_time[start: end]), 41 | np.append([config.START], lag_time_s[start: end]), 42 | np.append([config.START], lag_time_m[start: end]), 43 | np.append([config.START], 
lag_time_d[start: end]), 44 | np.append([config.START], p_explanation[start: end]))]) 45 | else: 46 | self.data.extend([(np.append([config.START], exercise), 47 | np.append([config.START], part), 48 | np.append([config.START], correctness), 49 | np.append([config.START], elapsed_time), 50 | np.append([config.START], lag_time_s), 51 | np.append([config.START], lag_time_m), 52 | np.append([config.START], lag_time_d), 53 | np.append([config.START], p_explanation))]) 54 | 55 | def __len__(self): 56 | return len(self.data) 57 | 58 | def __getitem__(self, idx): 59 | raw_content_ids, raw_part, raw_correctness, raw_elapsed_time, raw_lag_time_s, raw_lag_time_m, raw_lag_time_d, raw_p_explan = self.data[idx] 60 | seq_len = len(raw_content_ids) 61 | 62 | input_content_ids = np.zeros(self.max_seq, dtype=np.int64) 63 | input_part = np.zeros(self.max_seq, dtype=np.int64) 64 | input_correctness = np.zeros(self.max_seq, dtype=np.int64) 65 | input_elapsed_time = np.zeros(self.max_seq, dtype=np.int64) 66 | input_lag_time_s = np.zeros(self.max_seq, dtype=np.int64) 67 | input_lag_time_m = np.zeros(self.max_seq, dtype=np.int64) 68 | input_lag_time_d = np.zeros(self.max_seq, dtype=np.int64) 69 | input_p_explan = np.zeros(self.max_seq, dtype=np.int64) 70 | 71 | label = np.zeros(self.max_seq, dtype=np.int64) 72 | 73 | if seq_len == self.max_seq + 1: # START token 74 | input_content_ids[:] = raw_content_ids[1:] 75 | input_part[:] = raw_part[1:] 76 | input_p_explan[:] = raw_p_explan[1:] 77 | input_correctness[:] = raw_correctness[:-1] 78 | input_elapsed_time[:] = np.append(raw_elapsed_time[0], raw_elapsed_time[2:]) 79 | input_lag_time_s[:] = np.append(raw_lag_time_s[0], raw_lag_time_s[2:]) 80 | input_lag_time_m[:] = np.append(raw_lag_time_m[0], raw_lag_time_m[2:]) 81 | input_lag_time_d[:] = np.append(raw_lag_time_d[0], raw_lag_time_d[2:]) 82 | label[:] = raw_correctness[1:] - 2 83 | else: 84 | input_content_ids[-(seq_len - 1):] = raw_content_ids[1:] # Delete START token 85 | input_part[-(seq_len - 1):] = raw_part[1:] 86 | input_p_explan[-(seq_len - 1):] = raw_p_explan[1:] 87 | input_correctness[-(seq_len - 1):] = raw_correctness[:-1] 88 | input_elapsed_time[-(seq_len - 1):] = np.append(raw_elapsed_time[0], raw_elapsed_time[2:]) 89 | input_lag_time_s[-(seq_len - 1):] = np.append(raw_lag_time_s[0], raw_lag_time_s[2:]) 90 | input_lag_time_m[-(seq_len - 1):] = np.append(raw_lag_time_m[0], raw_lag_time_m[2:]) 91 | input_lag_time_d[-(seq_len - 1):] = np.append(raw_lag_time_d[0], raw_lag_time_d[2:]) 92 | label[-(seq_len - 1):] = raw_correctness[1:] - 2 93 | 94 | _input = {"content_id": input_content_ids.astype(np.int64), 95 | "part": input_part.astype(np.int64), 96 | "correctness": input_correctness.astype(np.int64), 97 | "elapsed_time": input_elapsed_time.astype(np.int64), 98 | "lag_time_s": input_lag_time_s.astype(np.int64), 99 | "lag_time_m": input_lag_time_m.astype(np.int64), 100 | "lag_time_d": input_lag_time_d.astype(np.int64), 101 | "prior_explan": input_p_explan.astype(np.int64)} 102 | return _input, label 103 | -------------------------------------------------------------------------------- /dataset_group.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | import torch 4 | import config 5 | import numpy as np 6 | 7 | from scipy.stats import norm 8 | 9 | from tqdm import tqdm 10 | 11 | 12 | class DKTDataset(Dataset): 13 | def __init__(self, group, max_seq, min_seq, overlap_seq, user_performance, n_levels, mu_itv): 14 | self.samples = 
group 15 | self.max_seq = max_seq 16 | self.min_seq = min_seq 17 | self.overlap_seq = overlap_seq 18 | self.user_performance = user_performance 19 | self.data = [] 20 | self.n_levels = n_levels 21 | self.mu_itv = mu_itv 22 | self.mu_levels, self.std_levels = self._fit_norm(user_performance) 23 | for user_id, (exercise, part, correctness, elapsed_time, lag_time_s, lag_time_m, lag_time_d, p_explanation) in tqdm(self.samples.items(), total=len(self.samples), desc="Loading Dataset"): 24 | content_len = len(exercise) 25 | if content_len < self.min_seq: 26 | continue # skip sequence with too few contents 27 | 28 | if content_len > self.max_seq: 29 | initial = content_len % self.max_seq 30 | if initial >= self.min_seq: 31 | self.data.extend([(user_id, np.append([config.START], exercise[:initial]), 32 | np.append([config.START], part[:initial]), 33 | np.append([config.START], correctness[:initial]), 34 | np.append([config.START], elapsed_time[:initial]), 35 | np.append([config.START], lag_time_s[:initial]), 36 | np.append([config.START], lag_time_m[:initial]), 37 | np.append([config.START], lag_time_d[:initial]), 38 | np.append([config.START], p_explanation[:initial]))]) 39 | for seq in range(content_len // self.max_seq): 40 | start = initial + seq * self.max_seq 41 | end = initial + (seq + 1) * self.max_seq 42 | self.data.extend([(user_id, np.append([config.START], exercise[start: end]), 43 | np.append([config.START], part[start: end]), 44 | np.append([config.START], correctness[start: end]), 45 | np.append([config.START], elapsed_time[start: end]), 46 | np.append([config.START], lag_time_s[start: end]), 47 | np.append([config.START], lag_time_m[start: end]), 48 | np.append([config.START], lag_time_d[start: end]), 49 | np.append([config.START], p_explanation[start: end]))]) 50 | else: 51 | self.data.extend([(user_id, np.append([config.START], exercise), 52 | np.append([config.START], part), 53 | np.append([config.START], correctness), 54 | np.append([config.START], elapsed_time), 55 | np.append([config.START], lag_time_s), 56 | np.append([config.START], lag_time_m), 57 | np.append([config.START], lag_time_d), 58 | np.append([config.START], p_explanation))]) 59 | 60 | def _fit_norm(self, user_perf): 61 | data = [d for d in user_perf.values()] 62 | mu, std = norm.fit(data) 63 | mu_levels = [mu - (self.n_levels - 1) * self.mu_itv / 2 + i * self.mu_itv for i in range(self.n_levels)] 64 | std_levels = [np.sqrt(std ** 2 / self.n_levels) for _ in range(self.n_levels)] 65 | return mu_levels, std_levels 66 | 67 | def _predict_level(self, user_perf, mu_levels, std_levels): 68 | probs = [] 69 | for mu, std in zip(mu_levels, std_levels): 70 | probs.append(norm.pdf(user_perf, mu, std)) 71 | probs = np.array(probs) 72 | probs = probs / sum(probs) 73 | return probs 74 | 75 | def __len__(self): 76 | return len(self.data) 77 | 78 | def __getitem__(self, idx): 79 | raw_user_id, raw_content_ids, raw_part, raw_correctness, raw_elapsed_time, raw_lag_time_s, raw_lag_time_m, raw_lag_time_d, raw_p_explan = self.data[idx] 80 | if raw_user_id in self.user_performance: 81 | user_per = self.user_performance[raw_user_id] 82 | probs = self._predict_level(user_per, self.mu_levels, self.std_levels) 83 | else: 84 | probs = np.ones(len(self.mu_levels)) 85 | probs /= len(self.mu_levels) 86 | seq_len = len(raw_content_ids) 87 | 88 | input_content_ids = np.zeros(self.max_seq, dtype=np.int64) 89 | input_part = np.zeros(self.max_seq, dtype=np.int64) 90 | input_correctness = np.zeros(self.max_seq, dtype=np.int64) 91 | 
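# like the three buffers above, the remaining per-feature arrays are zero-initialized so unused slots stay at config.PAD (0); sequences shorter than max_seq are right-aligned into the last (seq_len - 1) positions below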
input_elapsed_time = np.zeros(self.max_seq, dtype=np.int64) 92 | input_lag_time_s = np.zeros(self.max_seq, dtype=np.int64) 93 | input_lag_time_m = np.zeros(self.max_seq, dtype=np.int64) 94 | input_lag_time_d = np.zeros(self.max_seq, dtype=np.int64) 95 | input_p_explan = np.zeros(self.max_seq, dtype=np.int64) 96 | 97 | label = np.zeros(self.max_seq, dtype=np.int64) 98 | 99 | if seq_len == self.max_seq + 1: # START token 100 | input_content_ids[:] = raw_content_ids[1:] 101 | input_part[:] = raw_part[1:] 102 | input_p_explan[:] = raw_p_explan[1:] 103 | input_correctness[:] = raw_correctness[:-1] 104 | input_elapsed_time[:] = np.append(raw_elapsed_time[0], raw_elapsed_time[2:]) 105 | input_lag_time_s[:] = np.append(raw_lag_time_s[0], raw_lag_time_s[2:]) 106 | input_lag_time_m[:] = np.append(raw_lag_time_m[0], raw_lag_time_m[2:]) 107 | input_lag_time_d[:] = np.append(raw_lag_time_d[0], raw_lag_time_d[2:]) 108 | label[:] = raw_correctness[1:] - 2 109 | else: 110 | input_content_ids[-(seq_len - 1):] = raw_content_ids[1:] # Delete START token 111 | input_part[-(seq_len - 1):] = raw_part[1:] 112 | input_p_explan[-(seq_len - 1):] = raw_p_explan[1:] 113 | input_correctness[-(seq_len - 1):] = raw_correctness[:-1] 114 | input_elapsed_time[-(seq_len - 1):] = np.append(raw_elapsed_time[0], raw_elapsed_time[2:]) 115 | input_lag_time_s[-(seq_len - 1):] = np.append(raw_lag_time_s[0], raw_lag_time_s[2:]) 116 | input_lag_time_m[-(seq_len - 1):] = np.append(raw_lag_time_m[0], raw_lag_time_m[2:]) 117 | input_lag_time_d[-(seq_len - 1):] = np.append(raw_lag_time_d[0], raw_lag_time_d[2:]) 118 | label[-(seq_len - 1):] = raw_correctness[1:] - 2 119 | 120 | _input = {"content_id": input_content_ids.astype(np.int64), 121 | "part": input_part.astype(np.int64), 122 | "correctness": input_correctness.astype(np.int64), 123 | "elapsed_time": input_elapsed_time.astype(np.int64), 124 | "lag_time_s": input_lag_time_s.astype(np.int64), 125 | "lag_time_m": input_lag_time_m.astype(np.int64), 126 | "lag_time_d": input_lag_time_d.astype(np.int64), 127 | "prior_explan": input_p_explan.astype(np.int64)} 128 | return _input, label, probs 129 | -------------------------------------------------------------------------------- /depen_install.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.system("pip install -r requirements.txt") 4 | os.system("cd package/pyirt && pip install .") 5 | -------------------------------------------------------------------------------- /figures/kt-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Soptq/LANA-pytorch/c1ce06775c4faccc7fb612d23403a0123666cb2d/figures/kt-example.png -------------------------------------------------------------------------------- /figures/lana-arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Soptq/LANA-pytorch/c1ce06775c4faccc7fb612d23403a0123666cb2d/figures/lana-arch.png -------------------------------------------------------------------------------- /figures/leveled-learning.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Soptq/LANA-pytorch/c1ce06775c4faccc7fb612d23403a0123666cb2d/figures/leveled-learning.png -------------------------------------------------------------------------------- /lana_arch.py: -------------------------------------------------------------------------------- 1 | import copy 
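# lana_arch.py defines the building blocks of LANA: the future mask, multi-head attention with exponential memory decay, the SRFE modules, the pivot FFN, absolute/relative positional bias, and the full encoder-decoder model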
2 | import math 3 | 4 | import numpy as np 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | import config 11 | 12 | 13 | def future_mask(seq_length): 14 | future_mask = np.triu(np.ones((1, seq_length, seq_length)), k=1).astype('bool') 15 | return torch.from_numpy(future_mask) 16 | 17 | 18 | def get_clones(module, N): 19 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 20 | 21 | 22 | def attention(q, k, v, d_k, positional_bias=None, mask=None, dropout=None, 23 | memory_decay=False, memory_gamma=None, ltime=None): 24 | # ltime shape [batch, seq_len] 25 | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k) # [bs, nh, s, s] 26 | bs, nhead, seqlen = scores.size(0), scores.size(1), scores.size(2) 27 | 28 | if mask is not None: 29 | mask = mask.unsqueeze(1) 30 | 31 | if memory_decay and memory_gamma is not None and ltime is not None: 32 | time_seq = torch.cumsum(ltime.float(), dim=-1) - ltime.float() # [bs, s] 33 | index_seq = torch.arange(seqlen).unsqueeze(-2).to(q.device) 34 | 35 | dist_seq = time_seq + index_seq 36 | 37 | with torch.no_grad(): 38 | if mask is not None: 39 | scores_ = scores.masked_fill(mask, 1e-9) 40 | scores_ = F.softmax(scores_, dim=-1) 41 | distcum_scores = torch.cumsum(scores_, dim=-1) 42 | distotal_scores = torch.sum(scores_, dim=-1, keepdim=True) 43 | position_diff = dist_seq[:, None, :] - dist_seq[:, :, None] 44 | position_effect = torch.abs(position_diff)[:, None, :, :].type(torch.FloatTensor).to(q.device) 45 | dist_scores = torch.clamp((distotal_scores - distcum_scores) * position_effect, min=0.) 46 | dist_scores = dist_scores.sqrt().detach() 47 | 48 | m = nn.Softplus() 49 | memory_gamma = -1. * m(memory_gamma) 50 | total_effect = torch.clamp(torch.clamp((dist_scores * memory_gamma).exp(), min=1e-5), max=1e5) 51 | scores = total_effect * scores 52 | 53 | if positional_bias is not None: 54 | scores = scores + positional_bias 55 | 56 | if mask is not None: 57 | scores = scores.masked_fill(mask, -1e9) 58 | 59 | scores = F.softmax(scores, dim=-1) # [bs, nh, s, s] 60 | 61 | if dropout is not None: 62 | scores = dropout(scores) 63 | 64 | output = torch.matmul(scores, v) 65 | return output 66 | 67 | 68 | class MultiHeadAttention(nn.Module): 69 | def __init__(self, embed_dim, num_heads, dropout=0.1): 70 | super(MultiHeadAttention, self).__init__() 71 | 72 | self.d_model = embed_dim 73 | self.d_k = embed_dim // num_heads 74 | self.h = num_heads 75 | 76 | self.q_linear = nn.Linear(embed_dim, embed_dim) 77 | self.v_linear = nn.Linear(embed_dim, embed_dim) 78 | self.k_linear = nn.Linear(embed_dim, embed_dim) 79 | self.dropout = nn.Dropout(dropout) 80 | self.gammas = nn.Parameter(torch.zeros(num_heads, config.MAX_SEQ, 1)) 81 | self.m_srfe = MemorySRFE(embed_dim, num_heads) 82 | self.out = nn.Linear(embed_dim, embed_dim) 83 | 84 | def forward(self, q, k, v, ltime=None, gamma=None, positional_bias=None, 85 | attn_mask=None): 86 | bs = q.size(0) 87 | 88 | # perform linear operation and split into h heads 89 | k = self.k_linear(k).view(bs, -1, self.h, self.d_k) 90 | q = self.q_linear(q).view(bs, -1, self.h, self.d_k) 91 | v = self.v_linear(v).view(bs, -1, self.h, self.d_k) 92 | 93 | # transpose to get dimensions bs * h * sl * d_model 94 | k = k.transpose(1, 2) 95 | q = q.transpose(1, 2) 96 | v = v.transpose(1, 2) 97 | 98 | if gamma is not None: 99 | gamma = self.m_srfe(gamma) + self.gammas 100 | else: 101 | gamma = self.gammas 102 | 103 | # calculate attention using function we will define next 104 | scores = 
attention(q, k, v, self.d_k, positional_bias, attn_mask, self.dropout, 105 | memory_decay=True, memory_gamma=gamma, ltime=ltime) 106 | 107 | # concatenate heads and put through final linear layer 108 | concat = scores.transpose(1, 2).contiguous() \ 109 | .view(bs, -1, self.d_model) 110 | 111 | output = self.out(concat) 112 | 113 | return output 114 | 115 | 116 | class BaseSRFE(nn.Module): 117 | def __init__(self, in_dim, n_head, dropout): 118 | super(BaseSRFE, self).__init__() 119 | assert in_dim % n_head == 0 120 | self.in_dim = in_dim // n_head 121 | self.n_head = n_head 122 | self.attention = MultiHeadAttention(embed_dim=in_dim, num_heads=n_head, dropout=dropout) 123 | self.dropout = nn.Dropout(dropout) 124 | self.layernorm = nn.LayerNorm(in_dim) 125 | 126 | def forward(self, x, pos_embed, mask): 127 | out = x 128 | att_out = self.attention(out, out, out, positional_bias=pos_embed, attn_mask=mask) 129 | out = out + self.dropout(att_out) 130 | out = self.layernorm(out) 131 | 132 | return x 133 | 134 | 135 | class MemorySRFE(nn.Module): 136 | def __init__(self, in_dim, n_head): 137 | super(MemorySRFE, self).__init__() 138 | assert in_dim % n_head == 0 139 | self.in_dim = in_dim // n_head 140 | self.n_head = n_head 141 | self.linear1 = nn.Linear(self.in_dim, 1) 142 | 143 | def forward(self, x): 144 | bs = x.size(0) 145 | 146 | x = x.view(bs, -1, self.n_head, self.in_dim) \ 147 | .transpose(1, 2) \ 148 | .contiguous() 149 | x = self.linear1(x) 150 | return x 151 | 152 | 153 | class PerformanceSRFE(nn.Module): 154 | def __init__(self, d_model, d_piv): 155 | super(PerformanceSRFE, self).__init__() 156 | self.linear1 = nn.Linear(d_model, 128) 157 | self.linear2 = nn.Linear(128, d_piv) 158 | 159 | def forward(self, x): 160 | x = F.gelu(self.linear1(x)) 161 | x = self.linear2(x) 162 | 163 | return x 164 | 165 | 166 | class FFN(nn.Module): 167 | def __init__(self, d_model, d_ffn, dropout): 168 | super(FFN, self).__init__() 169 | self.lr1 = nn.Linear(d_model, d_ffn) 170 | self.act = nn.ReLU() 171 | self.lr2 = nn.Linear(d_ffn, d_model) 172 | self.dropout = nn.Dropout(dropout) 173 | 174 | def forward(self, x): 175 | x = self.lr1(x) 176 | x = self.act(x) 177 | x = self.dropout(x) 178 | x = self.lr2(x) 179 | return x 180 | 181 | 182 | class PivotFFN(nn.Module): 183 | def __init__(self, d_model, d_ffn, d_piv, dropout): 184 | super(PivotFFN, self).__init__() 185 | self.p_srfe = PerformanceSRFE(d_model, d_piv) 186 | self.lr1 = nn.Bilinear(d_piv, d_model, d_ffn) 187 | self.lr2 = nn.Bilinear(d_piv, d_ffn, d_model) 188 | self.dropout = nn.Dropout(dropout) 189 | 190 | def forward(self, x, pivot): 191 | pivot = self.p_srfe(pivot) 192 | 193 | x = F.gelu(self.lr1(pivot, x)) 194 | x = self.dropout(x) 195 | x = self.lr2(pivot, x) 196 | return x 197 | 198 | 199 | class LANAEncoder(nn.Module): 200 | def __init__(self, d_model, n_heads, d_ffn, dropout, max_seq): 201 | super(LANAEncoder, self).__init__() 202 | self.max_seq = max_seq 203 | 204 | self.multi_attn = MultiHeadAttention(embed_dim=d_model, num_heads=n_heads, dropout=dropout) 205 | self.layernorm1 = nn.LayerNorm(d_model) 206 | self.layernorm2 = nn.LayerNorm(d_model) 207 | self.dropout1 = nn.Dropout(dropout) 208 | self.dropout2 = nn.Dropout(dropout) 209 | 210 | self.ffn = FFN(d_model, d_ffn, dropout) 211 | 212 | def forward(self, x, pos_embed, mask): 213 | out = x 214 | att_out = self.multi_attn(out, out, out, positional_bias=pos_embed, attn_mask=mask) 215 | out = out + self.dropout1(att_out) 216 | out = self.layernorm1(out) 217 | 218 | ffn_out = 
self.ffn(out) 219 | out = self.layernorm2(out + self.dropout2(ffn_out)) 220 | 221 | return out 222 | 223 | 224 | class LANADecoder(nn.Module): 225 | def __init__(self, d_model, n_heads, d_ffn, dropout, max_seq): 226 | super(LANADecoder, self).__init__() 227 | self.max_seq = max_seq 228 | 229 | self.multi_attn_1 = MultiHeadAttention(embed_dim=d_model, num_heads=n_heads, dropout=dropout) 230 | self.multi_attn_2 = MultiHeadAttention(embed_dim=d_model, num_heads=n_heads, dropout=dropout) 231 | 232 | self.layernorm1 = nn.LayerNorm(d_model) 233 | self.layernorm2 = nn.LayerNorm(d_model) 234 | self.layernorm3 = nn.LayerNorm(d_model) 235 | self.dropout1 = nn.Dropout(dropout) 236 | self.dropout2 = nn.Dropout(dropout) 237 | self.dropout3 = nn.Dropout(dropout) 238 | 239 | self.ffn = FFN(d_model, d_ffn, dropout) 240 | 241 | def forward(self, x, memory, ltime, status, pos_embed, mask1, mask2): 242 | out = x 243 | att_out_1 = self.multi_attn_1(out, out, out, ltime=ltime, 244 | positional_bias=pos_embed, attn_mask=mask1) 245 | out = out + self.dropout1(att_out_1) 246 | out = self.layernorm1(out) 247 | 248 | att_out_2 = self.multi_attn_2(out, memory, memory, ltime=ltime, 249 | gamma=status, positional_bias=pos_embed, attn_mask=mask2) 250 | out = out + self.dropout2(att_out_2) 251 | out = self.layernorm2(out) 252 | 253 | ffn_out = self.ffn(out) 254 | out = self.layernorm3(out + self.dropout3(ffn_out)) 255 | 256 | return out 257 | 258 | 259 | class PositionalBias(nn.Module): 260 | def __init__(self, max_seq, embed_dim, num_heads, bidirectional=True, num_buckets=32, max_distance=config.MAX_SEQ): 261 | super(PositionalBias, self).__init__() 262 | self.d_model = embed_dim 263 | self.d_k = embed_dim // num_heads 264 | self.h = num_heads 265 | self.bidirectional = bidirectional 266 | self.num_buckets = num_buckets 267 | self.max_distance = max_distance 268 | 269 | self.pos_embed = nn.Embedding(max_seq, embed_dim) # Encoder position Embedding 270 | self.pos_query_linear = nn.Linear(embed_dim, embed_dim) 271 | self.pos_key_linear = nn.Linear(embed_dim, embed_dim) 272 | self.pos_layernorm = nn.LayerNorm(embed_dim) 273 | 274 | self.relative_attention_bias = nn.Embedding(32, config.N_HEADS) 275 | 276 | def forward(self, pos_seq): 277 | bs = pos_seq.size(0) 278 | 279 | pos_embed = self.pos_embed(pos_seq) 280 | pos_embed = self.pos_layernorm(pos_embed) 281 | 282 | pos_query = self.pos_query_linear(pos_embed) 283 | pos_key = self.pos_key_linear(pos_embed) 284 | 285 | pos_query = pos_query.view(bs, -1, self.h, self.d_k).transpose(1, 2) 286 | pos_key = pos_key.view(bs, -1, self.h, self.d_k).transpose(1, 2) 287 | 288 | absolute_bias = torch.matmul(pos_query, pos_key.transpose(-2, -1)) / math.sqrt(self.d_k) 289 | relative_position = pos_seq[:, None, :] - pos_seq[:, :, None] 290 | 291 | relative_buckets = 0 292 | num_buckets = self.num_buckets 293 | if self.bidirectional: 294 | num_buckets = num_buckets // 2 295 | relative_buckets += (relative_position > 0).to(torch.long) * num_buckets 296 | relative_bias = torch.abs(relative_position) 297 | else: 298 | relative_bias = -torch.min(relative_position, torch.zeros_like(relative_position)) 299 | 300 | max_exact = num_buckets // 2 301 | is_small = relative_bias < max_exact 302 | 303 | relative_bias_if_large = max_exact + ( 304 | torch.log(relative_bias.float() / max_exact) 305 | / math.log(self.max_distance / max_exact) 306 | * (num_buckets - max_exact) 307 | ).to(torch.long) 308 | relative_bias_if_large = torch.min( 309 | relative_bias_if_large, 
torch.full_like(relative_bias_if_large, num_buckets - 1) 310 | ) 311 | 312 | relative_buckets += torch.where(is_small, relative_bias, relative_bias_if_large) 313 | relative_position_buckets = relative_buckets.to(pos_seq.device) 314 | 315 | relative_bias = self.relative_attention_bias(relative_position_buckets) 316 | relative_bias = relative_bias.permute(0, 3, 1, 2) 317 | 318 | position_bias = absolute_bias + relative_bias 319 | return position_bias 320 | 321 | 322 | class LANA(nn.Module): 323 | def __init__(self, d_model, n_head, n_encoder, n_decoder, dim_feedforward, dropout, 324 | max_seq, n_exercises, n_parts, n_resp, n_etime, n_ltime_s, n_ltime_m, n_ltime_d): 325 | super(LANA, self).__init__() 326 | self.max_seq = max_seq 327 | 328 | self.pos_embed = PositionalBias(max_seq, d_model, n_head, bidirectional=False, num_buckets=32, 329 | max_distance=max_seq) 330 | 331 | self.encoder_resp_embed = nn.Embedding(n_resp + 2, d_model, 332 | padding_idx=config.PAD) # Answer Embedding, 0 for padding 333 | self.encoder_eid_embed = nn.Embedding(n_exercises + 2, d_model, 334 | padding_idx=config.PAD) # Exercise ID Embedding, 0 for padding 335 | self.encoder_part_embed = nn.Embedding(n_parts + 2, d_model, 336 | padding_idx=config.PAD) # Part Embedding, 0 for padding 337 | self.encoder_p_explan_embed = nn.Embedding(2 + 2, d_model, padding_idx=config.PAD) 338 | self.encoder_linear = nn.Linear(4 * d_model, d_model) 339 | self.encoder_layernorm = nn.LayerNorm(d_model) 340 | self.encoder_dropout = nn.Dropout(dropout) 341 | 342 | self.decoder_resp_embed = nn.Embedding(n_resp + 2, d_model, 343 | padding_idx=config.PAD) # Answer Embedding, 0 for padding 344 | self.decoder_etime_embed = nn.Embedding(n_etime + 3, d_model, padding_idx=config.PAD) # Elapsed time Embedding 345 | self.decoder_ltime_embed_s = nn.Embedding(n_ltime_s + 3, d_model, 346 | padding_idx=config.PAD) # Lag time Embedding 1 347 | self.decoder_ltime_embed_m = nn.Embedding(n_ltime_m + 3, d_model, 348 | padding_idx=config.PAD) # Lag time Embedding 2 349 | self.decoder_ltime_embed_h = nn.Embedding(n_ltime_d + 3, d_model, 350 | padding_idx=config.PAD) # Lag time Embedding 3 351 | self.decoder_linear = nn.Linear(5 * d_model, d_model) 352 | self.decoder_layernorm = nn.LayerNorm(d_model) 353 | self.decoder_dropout = nn.Dropout(dropout) 354 | 355 | self.encoder = get_clones(LANAEncoder(d_model, n_head, dim_feedforward, dropout, max_seq), n_encoder) 356 | self.srfe = BaseSRFE(d_model, n_head, dropout) 357 | self.decoder = get_clones(LANADecoder(d_model, n_head, dim_feedforward, dropout, max_seq), n_decoder) 358 | 359 | self.layernorm_out = nn.LayerNorm(d_model) 360 | self.ffn = PivotFFN(d_model, dim_feedforward, config.D_PIV, dropout) 361 | self.classifier = nn.Linear(d_model, 1) 362 | 363 | def get_pos_seq(self): 364 | return torch.arange(self.max_seq).unsqueeze(0) 365 | 366 | def _get_param_from_input(self, input): 367 | return (input["content_id"].long(), 368 | input["part"].long(), 369 | input["correctness"].long(), 370 | input["elapsed_time"].long(), 371 | input["lag_time_s"].long(), 372 | input["lag_time_m"].long(), 373 | input["lag_time_d"].long(), 374 | input["prior_explan"].long()) 375 | 376 | def forward(self, input): 377 | exercise_seq, part_seq, resp_seq, etime_seq, ltime_s_seq, ltime_m_seq, ltime_d_seq, p_explan_seq = self._get_param_from_input( 378 | input) 379 | 380 | ltime = ltime_m_seq.clone() 381 | 382 | pos_embed = self.pos_embed(self.get_pos_seq().to(exercise_seq.device)) 383 | 384 | # encoder embedding 385 | inter_seq = 
self.encoder_resp_embed(resp_seq) 386 | exercise_seq = self.encoder_eid_embed(exercise_seq) 387 | part_seq = self.encoder_part_embed(part_seq) 388 | p_explan_seq = self.encoder_p_explan_embed(p_explan_seq) 389 | encoder_input = torch.cat([exercise_seq, part_seq, p_explan_seq, inter_seq], dim=-1) 390 | encoder_input = self.encoder_linear(encoder_input) 391 | encoder_input = self.encoder_layernorm(encoder_input) 392 | encoder_input = self.encoder_dropout(encoder_input) 393 | 394 | # decoder embedding 395 | resp_seq = self.decoder_resp_embed(resp_seq) 396 | etime_seq = self.decoder_etime_embed(etime_seq) 397 | ltime_s_seq = self.decoder_ltime_embed_s(ltime_s_seq) 398 | ltime_m_seq = self.decoder_ltime_embed_m(ltime_m_seq) 399 | ltime_d_seq = self.decoder_ltime_embed_h(ltime_d_seq) 400 | decoder_input = torch.cat([resp_seq, etime_seq, ltime_s_seq, ltime_m_seq, ltime_d_seq], dim=-1) 401 | decoder_input = self.decoder_linear(decoder_input) 402 | decoder_input = self.decoder_layernorm(decoder_input) 403 | decoder_input = self.decoder_dropout(decoder_input) 404 | 405 | attn_mask = future_mask(self.max_seq).to(exercise_seq.device) 406 | # encoding 407 | encoding = encoder_input 408 | for mod in self.encoder: 409 | encoding = mod(encoding, pos_embed, attn_mask) 410 | 411 | srfe = encoding.clone() 412 | srfe = self.srfe(srfe, pos_embed, attn_mask) 413 | 414 | # decoding 415 | decoding = decoder_input 416 | for mod in self.decoder: 417 | decoding = mod(decoding, encoding, ltime, srfe, pos_embed, 418 | attn_mask, attn_mask) 419 | 420 | predict = self.ffn(decoding, srfe) 421 | predict = self.layernorm_out(predict + decoding) 422 | predict = self.classifier(predict) 423 | return predict.squeeze(-1) 424 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | 4 | import config 5 | import utils 6 | from dataset import DKTDataset 7 | from lana_arch import LANA 8 | 9 | import pandas as pd 10 | from tqdm import tqdm 11 | 12 | import torch 13 | from torch import nn 14 | from torch.utils.data import DataLoader 15 | 16 | import pytorch_lightning as pl 17 | from pytorch_lightning.callbacks import ModelCheckpoint 18 | 19 | from sklearn.metrics import roc_auc_score 20 | 21 | import argparse 22 | 23 | # Config 24 | tqdm.pandas() 25 | utils.set_seed(config.SEED) 26 | 27 | 28 | # Pytorch Lightning Module 29 | class TorchModel(pl.LightningModule): 30 | def __init__(self, trainer_args, model_args): 31 | super().__init__() 32 | self.model = LANA(**model_args) 33 | self.val_labels = [] 34 | self.val_outs = [] 35 | 36 | def forward(self, input): 37 | return self.model(input) 38 | 39 | def configure_optimizers(self): 40 | optimizer = torch.optim.AdamW(self.parameters(), lr=config.LEARNING_RATE) 41 | return optimizer 42 | 43 | def training_step(self, batch, batch_idx): 44 | inputs, target = batch 45 | target_mask = (inputs["content_id"] != config.PAD) 46 | output = self(inputs) 47 | output = torch.masked_select(output, target_mask) 48 | target = torch.masked_select(target, target_mask) 49 | 50 | loss = nn.BCEWithLogitsLoss()(output.float(), target.float()) 51 | auc = roc_auc_score(target.cpu(), output.detach().float().cpu()) 52 | self.log("t_loss", loss) 53 | self.log("t_auc", auc, prog_bar=True) 54 | return loss 55 | 56 | def validation_step(self, batch, batch_idx): 57 | inputs, target = batch 58 | target_mask = (inputs["content_id"] != config.PAD) 59 | output = self(inputs) 60 
| output = torch.masked_select(output, target_mask) # probability 61 | target = torch.masked_select(target, target_mask) 62 | loss = nn.BCEWithLogitsLoss()(output.float(), target.float()) 63 | auc = roc_auc_score(target.cpu(), output.detach().float().cpu()) 64 | self.val_labels.extend(target.view(-1).data.cpu().numpy()) 65 | self.val_outs.extend(output.view(-1).data.cpu().numpy()) 66 | self.log("v_loss", loss, prog_bar=True) 67 | self.log("v_auc", auc, prog_bar=True) 68 | 69 | def on_validation_epoch_end(self): 70 | real_auc = roc_auc_score(self.val_labels, self.val_outs) 71 | self.log("v_auc", real_auc, prog_bar=True) 72 | self.val_labels = [] 73 | self.val_outs = [] 74 | 75 | 76 | if __name__ == "__main__": 77 | parser = argparse.ArgumentParser(description="LANA") 78 | parser.add_argument('-d', '--data', type=str, required=True, 79 | help="Filepath of the preprocessed data") 80 | 81 | args = parser.parse_args() 82 | train_df = pd.read_pickle(f"{args.data}.train") 83 | val_df = pd.read_pickle(f"{args.data}.valid") 84 | print("train size: ", train_df.shape, "validation size: ", val_df.shape) 85 | 86 | train_dataset = DKTDataset(train_df.values, max_seq=config.MAX_SEQ, 87 | min_seq=config.MIN_SEQ, overlap_seq=config.OVERLAP_SEQ) 88 | val_dataset = DKTDataset(val_df.values, max_seq=config.MAX_SEQ, 89 | min_seq=config.MIN_SEQ, overlap_seq=config.OVERLAP_SEQ) 90 | train_loader = DataLoader(train_dataset, 91 | batch_size=config.BATCH_SIZE, 92 | num_workers=8, 93 | shuffle=True, 94 | pin_memory=True) 95 | val_loader = DataLoader(val_dataset, 96 | batch_size=config.BATCH_SIZE, 97 | num_workers=8, 98 | shuffle=False, 99 | pin_memory=True) 100 | del train_dataset, val_dataset 101 | gc.collect() 102 | 103 | ARGS = {"d_model": config.MODEL_DIMS, 104 | 'n_head': config.N_HEADS, 105 | 'n_encoder': config.NUM_ENCODER, 106 | 'n_decoder': config.NUM_DECODER, 107 | 'dim_feedforward': config.FEEDFORWARD_DIMS, 108 | 'dropout': config.DROPOUT, 109 | 'max_seq': config.MAX_SEQ, 110 | 'n_exercises': config.TOTAL_EID, 111 | 'n_parts': config.TOTAL_PART, 112 | 'n_resp': config.TOTAL_RESP, 113 | 'n_etime': config.TOTAL_ETIME, 114 | 'n_ltime_s': config.TOTAL_LTIME_S, 115 | 'n_ltime_m': config.TOTAL_LTIME_M, 116 | 'n_ltime_d': config.TOTAL_LTIME_D} 117 | 118 | if not os.path.exists("./saved_models"): 119 | os.mkdir("./saved_models") 120 | checkpoint = ModelCheckpoint(dirpath="./saved_models", 121 | filename="model-{epoch}-{v_auc:.2f}", 122 | verbose=True, 123 | save_top_k=1, 124 | save_last=True, 125 | mode="max", 126 | monitor="v_auc") 127 | 128 | lana_model = TorchModel(trainer_args=args, model_args=ARGS) 129 | if config.DEVICE is None or not torch.cuda.is_available(): 130 | trainer = pl.Trainer(progress_bar_refresh_rate=1, 131 | max_epochs=config.EPOCH, callbacks=[checkpoint]) 132 | else: 133 | trainer = pl.Trainer(progress_bar_refresh_rate=1, 134 | max_epochs=config.EPOCH, callbacks=[checkpoint], 135 | gpus=config.DEVICE) 136 | trainer.fit(model=lana_model, 137 | train_dataloader=train_loader, val_dataloaders=val_loader) 138 | trainer.save_checkpoint("./saved_models/final.pt") -------------------------------------------------------------------------------- /main_ll.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import copy 3 | import os 4 | import math 5 | 6 | import config 7 | import utils 8 | import argparse 9 | from dataset_group import DKTDataset 10 | from lana_arch import LANA 11 | 12 | import numpy as np 13 | import pandas as pd 14 | from tqdm import 
tqdm 15 | from tqdm import trange 16 | 17 | import torch 18 | from torch import nn 19 | from torch.utils.data import Dataset, DataLoader 20 | from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer 21 | from torch.optim.lr_scheduler import LambdaLR 22 | 23 | import pytorch_lightning as pl 24 | from pytorch_lightning.callbacks import ModelCheckpoint 25 | 26 | from sklearn.model_selection import train_test_split 27 | from sklearn.metrics import roc_auc_score, accuracy_score 28 | 29 | # Config 30 | tqdm.pandas() 31 | utils.set_seed(config.SEED) 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser(description="Leveled Learning") 35 | parser.add_argument('-d', '--data', type=str, required=True, 36 | help="Filepath of the preprocessed data") 37 | parser.add_argument('-m', '--model', type=str, required=True, 38 | help="Saved model parameters by Pytorch-Lightning") 39 | parser.add_argument('-n', '--num', type=int, required=True, 40 | help="Number of levels") 41 | parser.add_argument('-t', '--top', type=int, required=True, 42 | help="Top-k") 43 | parser.add_argument('-i', '--mu_itv', type=float, required=True, 44 | help="Mean interval") 45 | 46 | args = parser.parse_args() 47 | 48 | n_models = args.num 49 | train_df = pd.read_pickle(f"{args.data}.train") 50 | val_df = pd.read_pickle(f"{args.data}.valid") 51 | user_performance = utils.read_pickle(f"{args.data}.user") 52 | 53 | print("train size: ", train_df.shape, "validation size: ", val_df.shape) 54 | 55 | train_loaders = [] 56 | print("Generating dataset...") 57 | train_dataset = DKTDataset(train_df, max_seq=config.MAX_SEQ, 58 | min_seq=config.MIN_SEQ, overlap_seq=config.OVERLAP_SEQ, 59 | user_performance=user_performance, n_levels=n_models, mu_itv=args.mu_itv) 60 | train_loader = DataLoader(train_dataset, 61 | batch_size=config.BATCH_SIZE, 62 | num_workers=8, 63 | shuffle=False, 64 | pin_memory=True) 65 | val_dataset = DKTDataset(val_df, max_seq=config.MAX_SEQ, 66 | min_seq=config.MIN_SEQ, overlap_seq=config.OVERLAP_SEQ, 67 | user_performance=user_performance, n_levels=n_models, mu_itv=args.mu_itv) 68 | val_loader = DataLoader(val_dataset, 69 | batch_size=config.BATCH_SIZE, 70 | num_workers=8, 71 | shuffle=False, 72 | pin_memory=True) 73 | print(f"All dataloaders are generated") 74 | del train_dataset 75 | del val_dataset 76 | gc.collect() 77 | 78 | ARGS = {"d_model": config.MODEL_DIMS, 79 | 'n_head': config.N_HEADS, 80 | 'n_encoder': config.NUM_ENCODER, 81 | 'n_decoder': config.NUM_DECODER, 82 | 'dim_feedforward': config.FEEDFORWARD_DIMS, 83 | 'dropout': config.DROPOUT, 84 | 'max_seq': config.MAX_SEQ, 85 | 'n_exercises': config.TOTAL_EID, 86 | 'n_parts': config.TOTAL_PART, 87 | 'n_resp': config.TOTAL_RESP, 88 | 'n_etime': config.TOTAL_ETIME, 89 | 'n_ltime_s': config.TOTAL_LTIME_S, 90 | 'n_ltime_m': config.TOTAL_LTIME_M, 91 | 'n_ltime_d': config.TOTAL_LTIME_D} 92 | 93 | DEVICE = f"cuda:{config.DEVICE[0]}" if torch.cuda.is_available() else "cpu" 94 | models = [LANA(**ARGS).to(DEVICE) for _ in range(n_models)] 95 | optimizers = [torch.optim.AdamW(models[i].parameters(), lr=config.LEARNING_RATE / 1000) for i in range(n_models)] # Finetune 96 | 97 | baseline, optimizer_state, model_state = utils.load_model(args.model, device=DEVICE) 98 | model_state = utils.remove_prefix_from_dict(model_state, "model.") # remove prefix added by pytorch-lightning 99 | 100 | print("Baseline AUC: ", baseline) 101 | print("Loading models...") 102 | for cluster in range(n_models): 103 | 
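# every level-specific model is initialized from the same pretrained LANA checkpoint; the probability-weighted fine-tuning loop below then specializes each copy to one student ability level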
models[cluster].load_state_dict(model_state) 104 | 105 | print("All models loaded") 106 | for epoch in range(config.EPOCH): 107 | with tqdm(total=len(train_loader), dynamic_ncols=True) as t: 108 | t.set_description(f"Epoch {epoch}") 109 | for batch_idx, batch in enumerate(train_loader): 110 | for cluster in range(n_models): 111 | models[cluster].train() 112 | optimizers[cluster].zero_grad() 113 | inputs, target, probs = batch 114 | inputs = utils.dict_to_device(inputs, DEVICE) 115 | target = target.to(DEVICE) 116 | target_mask = (inputs["content_id"] != config.PAD) 117 | probs = probs.to(DEVICE) 118 | total_loss = 0. 119 | target = torch.masked_select(target, target_mask) 120 | for cluster in range(n_models): 121 | output = models[cluster](inputs) 122 | output = torch.masked_select(output, target_mask) 123 | weight = torch.masked_select(probs[:, cluster].unsqueeze(-1), target_mask) 124 | loss = nn.BCEWithLogitsLoss(reduction="none")(output.float(), target.float()) 125 | loss = loss * weight 126 | loss = loss.mean() 127 | total_loss += loss.item() 128 | loss.backward() 129 | optimizers[cluster].step() 130 | t.set_postfix({"t_loss": ("%.4f" % total_loss)}) 131 | t.update() 132 | for cluster in range(n_models): 133 | models[cluster].eval() 134 | 135 | with torch.no_grad(): 136 | val_labels = [] 137 | val_outs = [] 138 | with tqdm(total=len(val_loader), dynamic_ncols=True) as t: 139 | t.set_description("Validating") 140 | for batch_idx, batch in enumerate(val_loader): 141 | inputs, target, probs = batch 142 | inputs = utils.dict_to_device(inputs, DEVICE) 143 | target = target.to(DEVICE) 144 | target_mask = (inputs["content_id"] != config.PAD) 145 | probs = probs.to(DEVICE) 146 | probs = utils.prob_topk(probs, args.top) 147 | output = torch.zeros(target.shape, device=DEVICE) 148 | for cluster in range(n_models): 149 | m_output = models[cluster](inputs) 150 | output += (m_output * probs[:, cluster].unsqueeze(-1)) 151 | 152 | output = torch.masked_select(output, target_mask) 153 | target = torch.masked_select(target, target_mask) 154 | loss = nn.BCEWithLogitsLoss()(output.float(), target.float()) 155 | auc = roc_auc_score(target.cpu(), output.detach().float().cpu()) 156 | val_labels.extend(target.view(-1).data.cpu().numpy()) 157 | val_outs.extend(output.view(-1).data.cpu().numpy()) 158 | t.set_postfix({"v_loss": ("%.4f" % loss.item()), "v_auc": ("%.4f" % auc)}) 159 | t.update() 160 | real_auc = roc_auc_score(val_labels, val_outs) 161 | print(f"AUC: ", real_auc) 162 | -------------------------------------------------------------------------------- /package/pyirt/pyirt/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["_pyirt", "solver", "util"] 2 | from ._pyirt import irt 3 | -------------------------------------------------------------------------------- /package/pyirt/pyirt/_pyirt.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8-*- 2 | from .solver import model 3 | from .dao import localDAO 4 | from .logger import Logger 5 | 6 | 7 | def irt(data_src, 8 | dao_type='memory', 9 | theta_bnds=[-4, 4], num_theta=11, 10 | alpha_bnds=[0.25, 2], beta_bnds=[-2, 2], in_guess_param={}, 11 | model_spec='2PL', 12 | max_iter=10, tol=1e-3, nargout=2, 13 | is_parallel=False, num_cpu=6, check_interval=60, 14 | mode='debug', log_path=None): 15 | 16 | # add logging 17 | logger = Logger.logger(log_path) 18 | 19 | # load data 20 | logger.info("start loading data") 21 | if dao_type == 'memory': 22 | 
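# in-memory mode wraps the raw records (a file handle or a list of (user_id, item_id, ans_tag) tuples) in a localDAO; 'db' mode expects data_src to already be a DAO instance such as mongoDAO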
dao_instance = localDAO(data_src, logger) 23 | elif dao_type == "db": 24 | dao_instance = data_src 25 | else: 26 | raise ValueError("dao type needs to be either memory or db") 27 | logger.info("data loaded") 28 | 29 | # setup the model 30 | if model_spec == '2PL': 31 | mod = model.IRT_MMLE_2PL(dao_instance, 32 | logger, 33 | dao_type=dao_type, 34 | is_parallel=is_parallel, 35 | num_cpu=num_cpu, 36 | check_interval=check_interval, 37 | mode=mode) 38 | else: 39 | raise Exception('Unknown model specification.') 40 | 41 | # specify the irt parameters 42 | mod.set_options(theta_bnds, num_theta, alpha_bnds, beta_bnds, max_iter, tol) 43 | mod.set_guess_param(in_guess_param) 44 | 45 | # solve 46 | mod.solve_EM() 47 | logger.info("parameter estimated") 48 | # output 49 | item_param_dict = mod.get_item_param() 50 | logger.info("parameter retrieved") 51 | 52 | if nargout == 1: 53 | return item_param_dict 54 | elif nargout == 2: 55 | user_param_dict = mod.get_user_param() 56 | return item_param_dict, user_param_dict 57 | else: 58 | raise Exception('Invalid number of argument') 59 | -------------------------------------------------------------------------------- /package/pyirt/pyirt/algo.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | from .util import clib, tools 3 | import numpy as np 4 | 5 | 6 | def update_theta_distribution(data, num_theta, theta_prior_val, theta_density, item_param_dict): 7 | ''' 8 | data = [(item_idx int, ans_tag binary)] 9 | ''' 10 | 11 | ''' 12 | Basic Math. 13 | P_t(theta, data |q_param) = p(data|q_param, theta)*p_[t-1](theta) 14 | p_t(data|q_param) = sum(p_t(theta,data|q_param)) over theta 15 | p_t(theta|data, q_param) = P_t(theta, data|q_param)/p_t(data|q_param) 16 | ''' 17 | likelihood_vec = np.zeros(num_theta) 18 | 19 | for k in range(num_theta): 20 | theta = theta_prior_val[k] 21 | ell = 0.0 22 | for log in data: 23 | item_idx = log[0] 24 | ans_tag = log[1] 25 | alpha = item_param_dict[item_idx]['alpha'] 26 | beta = item_param_dict[item_idx]['beta'] 27 | c = item_param_dict[item_idx]['c'] 28 | ell += clib.log_likelihood_2PL(0.0 + ans_tag, 1.0 - ans_tag, theta, alpha, beta, c) 29 | likelihood_vec[k] = ell 30 | 31 | # posterior 32 | joint_llk_vec = likelihood_vec + np.log(theta_density) 33 | marginal = tools.logsum(joint_llk_vec) 34 | posterior = np.exp(joint_llk_vec - marginal) 35 | 36 | return posterior 37 | -------------------------------------------------------------------------------- /package/pyirt/pyirt/dao.py: -------------------------------------------------------------------------------- 1 | # -*- encoding:utf-8 -*- 2 | import io 3 | import time 4 | from collections import defaultdict 5 | import pymongo 6 | from datetime import datetime 7 | import numpy as np 8 | from .util.dao import loadFromHandle, loadFromTuples, construct_ref_dict 9 | 10 | from decouple import config 11 | MONGO_USER_NAME = config('MONGO_USER_NAME', default="") 12 | MONGO_PASSWORD = config('MONGO_PASSWORD', default="") 13 | MONGO_ADDRESS = config('MONGO_ADDRESS', default="") 14 | MONGO_AUTH_SOURCE = config('MONGO_AUTH_SOURCE', default="") 15 | MONGO_DB_NAME = config('MONGO_DB_NAME', default="") 16 | 17 | 18 | class mongoDb(object): 19 | """ cannot use singleton design, otherwise gets 'Warning: MongoClient opened before fork. Create MongoClient only after forking.' 
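Connection settings are read from the MONGO_* environment variables above via python-decouple.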
20 | """ 21 | def __init__(self): 22 | mongouri = 'mongodb://{un}:{pw}@{addr}'.format(un=MONGO_USER_NAME, pw=MONGO_PASSWORD, addr=MONGO_ADDRESS) 23 | if MONGO_AUTH_SOURCE: 24 | mongouri += '/?authsource={auth_src}'.format(auth_src=MONGO_AUTH_SOURCE) 25 | # connect 26 | try: 27 | self.client = pymongo.MongoClient(mongouri, connect=False, serverSelectionTimeoutMS=10, waitQueueTimeoutMS=100, readPreference='secondaryPreferred') 28 | except Exception as e: 29 | raise e 30 | self.user2item_conn = self.client[MONGO_DB_NAME]['irt_user2item'] 31 | self.item2user_conn = self.client[MONGO_DB_NAME]['irt_item2user'] 32 | 33 | def __del__(self): 34 | self.client.close() 35 | 36 | 37 | def search_filter(search_id, gid): 38 | return {'id': search_id, 'gid': gid} 39 | 40 | 41 | class mongoDAO(object): 42 | # NOTE: mongoDAO does not use the runtime logger 43 | # NOTE: The client and the connection is not passed by self because of parallel processing 44 | def __init__(self, group_id=1, is_msg=False): 45 | db = mongoDb() 46 | user_ids = list(set([x['id'] for x in db.user2item_conn.find({'gid': group_id}, {'id': 1})])) 47 | item_ids = list(set([x['id'] for x in db.item2user_conn.find({'gid': group_id}, {'id': 1})])) 48 | 49 | _, self.user_idx_ref, self.user_reverse_idx_ref = construct_ref_dict(user_ids) 50 | _, self.item_idx_ref, self.item_reverse_idx_ref = construct_ref_dict(item_ids) 51 | 52 | self.stat = {'user': len(self.user_idx_ref.keys()), 'item': len(self.item_idx_ref.keys())} 53 | 54 | self.gid = group_id 55 | self.is_msg = is_msg 56 | 57 | def open_conn(self, name): 58 | if name == "item2user": 59 | return mongoDb().item2user_conn 60 | elif name == "user2item": 61 | return mongoDb().user2item_conn 62 | else: 63 | raise ValueError('conn name must be either item2user or user2item') 64 | 65 | def get_num(self, name): 66 | if name not in ['user', 'item']: 67 | raise Exception('Unknown stat source %s' % name) 68 | return self.stat[name] 69 | 70 | def get_log(self, user_idx, user2item_conn): 71 | user_id = self.translate('user', user_idx) 72 | # query 73 | if self.is_msg: 74 | stime = datetime.now() 75 | res = user2item_conn.find(search_filter(user_id, self.gid)) 76 | etime = datetime.now() 77 | search_time = int((etime - stime).microseconds / 1000) 78 | if search_time > 100: 79 | print('warning: slow search:%d' % search_time) 80 | else: 81 | res = user2item_conn.find(search_filter(user_id, self.gid)) 82 | # parse 83 | res_num = res.count() 84 | if res_num == 0: 85 | return_list = [] 86 | elif res_num > 1: 87 | raise Exception('duplicate doc for (%s, %d) in user2item' % (user_id, self.gid)) 88 | else: 89 | log_list = res[0]['data'] 90 | return_list = [(self.item_idx_ref[x[0]], x[1]) for x in log_list] 91 | return return_list 92 | 93 | def get_map(self, item_idx, ans_key_list, item2user_conn): 94 | item_id = self.translate('item', item_idx) 95 | # query 96 | if self.is_msg: 97 | stime = datetime.now() 98 | res = item2user_conn.find(search_filter(item_id, self.gid)) 99 | etime = datetime.now() 100 | search_time = int((etime - stime).microseconds / 1000) 101 | if search_time > 100: 102 | print('warning:slow search:%d' % search_time) 103 | else: 104 | res = item2user_conn.find(search_filter(item_id, self.gid)) 105 | # parse 106 | res_num = res.count() 107 | if res_num == 0: 108 | return_list = [[] for ans_key in ans_key_list] 109 | elif res_num > 1: 110 | raise Exception('duplicate doc for (%s, %d) in item2user' % (item_id, self.gid)) 111 | else: 112 | doc = res[0]['data'] 113 | return_list = [] 114 | for 
ans_key in ans_key_list: 115 | if str(ans_key) in doc: 116 | return_list.append([self.user_idx_ref[x] for x in doc[str(ans_key)]]) 117 | else: 118 | return_list.append([]) 119 | return return_list 120 | 121 | def translate(self, data_type, idx): 122 | if data_type == 'item': 123 | return self.item_reverse_idx_ref[idx] 124 | elif data_type == 'user': 125 | return self.user_reverse_idx_ref[idx] 126 | 127 | 128 | class localDAO(object): 129 | 130 | def __init__(self, src, logger): 131 | 132 | self.database = localDataBase(src, logger) 133 | 134 | # quasi-bitmap 135 | user_id_idx_vec, self.user_idx_ref, self.user_reverse_idx_ref = construct_ref_dict(self.database.user_ids) 136 | item_id_idx_vec, self.item_idx_ref, self.item_reverse_idx_ref = construct_ref_dict(self.database.item_ids) 137 | 138 | self.database.setup(user_id_idx_vec, item_id_idx_vec, self.database.ans_tags) 139 | 140 | def get_num(self, name): 141 | if name not in ['user', 'item']: 142 | raise Exception('Unknown stat source %s' % name) 143 | return self.database.stat[name] 144 | 145 | def get_log(self, user_idx): 146 | return self.database.user2item[user_idx] 147 | 148 | def get_map(self, item_idx, ans_key_list): 149 | results = [] 150 | for ans_key in ans_key_list: 151 | try: 152 | results.append(self.database.item2user_map[str(ans_key)][item_idx]) 153 | except KeyError: 154 | results.append([]) 155 | return results 156 | 157 | def close_conn(self): 158 | pass 159 | 160 | def translate(self, data_type, idx): 161 | if data_type == 'item': 162 | return self.item_reverse_idx_ref[idx] 163 | elif data_type == 'user': 164 | return self.user_reverse_idx_ref[idx] 165 | 166 | 167 | class localDataBase(object): 168 | def __init__(self, src, logger): 169 | 170 | self.logger = logger 171 | if isinstance(src, io.IOBase): 172 | # if the src is file handle 173 | self.user_ids, self.item_ids, self.ans_tags = loadFromHandle(src) 174 | else: 175 | # if the src is list of tuples 176 | self.user_ids, self.item_ids, self.ans_tags = loadFromTuples(src) 177 | 178 | def setup(self, user_idx_vec, item_idx_vec, ans_tags, msg=False): 179 | 180 | start_time = time.time() 181 | self._process_data(user_idx_vec, item_idx_vec, ans_tags) 182 | if msg: 183 | self.logger.debug("--- Process: %f secs ---" % np.round((time.time() - start_time))) 184 | 185 | # initialize some intermediate variables used in the E step 186 | start_time = time.time() 187 | self._init_item2user_map() 188 | if msg: 189 | self.logger.debug("--- Sparse Mapping: %f secs ---" % np.round((time.time() - start_time))) 190 | 191 | ''' 192 | Need the following dictionary for esitmation routine 193 | (1) item -> user: key: item_id, value: (user_id, ans_tag) 194 | (2) user -> item: key: user_id, value: (item_id, ans_tag) 195 | ''' 196 | 197 | def _process_data(self, user_idx_vec, item_idx_vec, ans_tags): 198 | self.item2user = {} 199 | self.user2item = defaultdict(list) 200 | 201 | self.stat = {} 202 | num_log = len(user_idx_vec) 203 | self.stat['user'] = max(user_idx_vec) + 1 # start count from 0 204 | self.stat['item'] = max(item_idx_vec) + 1 205 | 206 | for i in range(num_log): 207 | item_idx = item_idx_vec[i] 208 | user_idx = user_idx_vec[i] 209 | ans_tag = ans_tags[i] 210 | # add to the data dictionary 211 | if item_idx not in self.item2user: 212 | self.item2user[item_idx] = defaultdict(list) 213 | self.item2user[item_idx][ans_tag].append(user_idx) 214 | self.user2item[user_idx].append((item_idx, ans_tag)) 215 | 216 | def _init_item2user_map(self, ans_key_list=['0', '1']): 217 | 218 | 
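# build the sparse item -> {ans_tag: [user_idx, ...]} lookup that the E step later queries through get_map()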
self.item2user_map = defaultdict(dict) 219 | 220 | for item_idx, log_result in self.item2user.items(): 221 | for ans_tag, user_idx_vec in log_result.items(): 222 | self.item2user_map[str(ans_tag)][item_idx] = user_idx_vec 223 | -------------------------------------------------------------------------------- /package/pyirt/pyirt/logger.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | import logging 3 | import os 4 | 5 | 6 | class Logger(): 7 | @staticmethod 8 | def logger(log_path): 9 | if log_path is not None: 10 | log_dir = os.path.dirname(log_path) 11 | if not os.path.exists(log_dir): 12 | os.makedirs(log_dir) 13 | logging.basicConfig( 14 | level=logging.DEBUG, 15 | format='%(asctime)s %(pathname)s[line:%(lineno)d] %(levelname)s: %(message)s', 16 | datefmt='%Y-%m-%d %H:%M:%S', 17 | filename=log_path, 18 | filemode='w') 19 | else: 20 | logging.basicConfig( 21 | level=logging.DEBUG, 22 | format='%(asctime)s %(levelname)s: %(message)s', 23 | datefmt='%Y-%m-%d %H:%M:%S') 24 | 25 | # 创建一个handler,用于输出到控制台 26 | console = logging.StreamHandler() 27 | console.setLevel(logging.DEBUG) 28 | 29 | logger = logging.getLogger() 30 | # 给logger添加handler 31 | logger.addHandler(console) 32 | return logger 33 | -------------------------------------------------------------------------------- /package/pyirt/pyirt/solver/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Soptq/LANA-pytorch/c1ce06775c4faccc7fb612d23403a0123666cb2d/package/pyirt/pyirt/solver/__init__.py -------------------------------------------------------------------------------- /package/pyirt/pyirt/solver/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python -W ignore 2 | ''' 3 | The model is an implementation of EM algorithm of IRT 4 | 5 | 6 | For reference, see: 7 | Brad Hanson, IRT Parameter Estimation using the EM Algorithm, 2000 8 | 9 | The current version only deals with unidimension theta 10 | 11 | ''' 12 | import numpy as np 13 | from scipy.stats import norm 14 | import time 15 | from copy import deepcopy 16 | 17 | from ..util import clib, tools 18 | from ..solver import optimizer 19 | from ..algo import update_theta_distribution 20 | 21 | from datetime import datetime 22 | import multiprocess as mp 23 | from tqdm import tqdm 24 | 25 | 26 | def procs_operator(procs, TIMEOUT, check_interval): 27 | for p in procs: 28 | p.start() 29 | 30 | start = time.time() 31 | while time.time() - start < TIMEOUT: 32 | if any(p.is_alive() for p in procs): 33 | time.sleep(check_interval) 34 | else: 35 | for p in procs: 36 | p.join() 37 | break 38 | else: 39 | for p in procs: 40 | p.terminate() 41 | p.join() 42 | 43 | raise Exception('Time out, killing all process') 44 | return procs 45 | 46 | 47 | class IRT_MMLE_2PL(object): 48 | 49 | ''' 50 | Exposed methods 51 | (1) set options 52 | (2) solve 53 | (3) get esitmated result 54 | ''' 55 | def __init__(self, 56 | dao_instance, logger, 57 | dao_type='memory', 58 | is_parallel=False, num_cpu=6, check_interval=60, 59 | mode='debug'): 60 | # interface to data 61 | self.dao = dao_instance 62 | self.dao_type = dao_type 63 | self.logger = logger 64 | self.num_iter = 1 65 | self.ell_list = [] 66 | self.last_avg_prob = 0 67 | 68 | self.is_parallel = is_parallel 69 | self.num_cpu = min(num_cpu, mp.cpu_count()) 70 | self.check_interval = check_interval 71 | self.mode = mode 72 | 73 | def set_options(self, theta_bnds, 
num_theta, alpha_bnds, beta_bnds, max_iter, tol): 74 | # user 75 | self.num_theta = num_theta 76 | self._init_user_param(theta_bnds[0], theta_bnds[1], num_theta) 77 | # item 78 | boundary = {'alpha': alpha_bnds, 'beta': beta_bnds} 79 | # solver 80 | solver_type = 'gradient' 81 | is_constrained = True 82 | 83 | self._init_solver_param(is_constrained, boundary, solver_type, max_iter, tol) 84 | 85 | def set_guess_param(self, in_guess_param): 86 | self.guess_param_dict = {} 87 | for item_idx in range(self.dao.get_num('item')): 88 | item_id = self.dao.translate('item', item_idx) 89 | if item_id in in_guess_param: 90 | self.guess_param_dict[item_idx] = {'c': float(in_guess_param[item_id])} 91 | else: 92 | self.guess_param_dict[item_idx] = {'c': 0.0} # if null then set as 0 93 | 94 | def solve_EM(self): 95 | # data dependent initialization 96 | self._init_item_param() 97 | 98 | # main routine 99 | while True: 100 | # ----- E step ----- 101 | stime = datetime.now() 102 | self._exp_step() 103 | etime = datetime.now() 104 | runtime = (etime - stime).microseconds / 1000 105 | self.logger.debug('E step runs for %s sec' % runtime) 106 | 107 | # ----- M step ----- 108 | stime = datetime.now() 109 | self._max_step() 110 | etime = datetime.now() 111 | runtime = (etime - stime).microseconds / 1000 112 | self.logger.debug('M step runs for %s sec' % runtime) 113 | 114 | # ---- Stop Condition ---- 115 | stime = datetime.now() 116 | is_stop = self._check_stop() 117 | etime = datetime.now() 118 | runtime = (etime - stime).microseconds / 1000 119 | self.logger.debug('stop condition runs for %s sec' % runtime) 120 | 121 | if is_stop: 122 | break 123 | 124 | def get_item_param(self): 125 | output_item_param = {} 126 | for item_idx in range(self.dao.get_num('item')): 127 | item_id = self.dao.translate('item', item_idx) 128 | output_item_param[item_id] = self.item_param_dict[item_idx] 129 | return output_item_param 130 | 131 | def get_user_param(self): 132 | output_user_param = {} 133 | theta_vec = self.__calc_theta() 134 | for user_idx in range(self.dao.get_num('user')): 135 | user_id = self.dao.translate('user', user_idx) 136 | output_user_param[user_id] = theta_vec[user_idx] 137 | return output_user_param 138 | 139 | ''' 140 | Main Routine 141 | ''' 142 | 143 | def _exp_step(self): 144 | ''' 145 | Basic Math: 146 | In the maximization step, need to use E_[j,k](Y=1),E_[j,k](Y=0) 147 | E(Y=1|param_j,theta_k) = sum_i(data_[i,j]*P(Y=1|param_j,theta_[i,k])) 148 | since data_[i,j] = 0/1, it is equivalent to sum all done right users 149 | 150 | E(Y=0|param_j,theta_k) = sum_i( 151 | (1-data_[i,j]) *(1-P(Y=1|param_j,theta_[i,k]) 152 | ) 153 | By similar logic, it is equivalent to sum (1-p) for all done wrong users 154 | ''' 155 | 156 | # (1) update the posterior distribution of theta 157 | self.__update_theta_distr() 158 | 159 | # (2) marginalize 160 | # because of the sparsity, the expected right and wrong may not sum up 161 | # to the total num of items! 
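        # For each item j and quadrature point theta_k this computes
        #   E[right]_{k,j} = sum of posterior P(theta_k | user) over users who answered j correctly
        #   E[wrong]_{k,j} = sum of posterior P(theta_k | user) over users who answered j incorrectly
        # i.e. column sums of posterior_theta_distr restricted to the responding
        # users (see __get_expect_count below).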
162 | self.__get_expect_count() 163 | 164 | def _max_step(self): 165 | ''' 166 | Basic Math 167 | log likelihood(param_j) = sum_k(log likelihood(param_j, theta_k)) 168 | ''' 169 | def update(d, start_idx, end_idx): 170 | try: 171 | for item_idx in tqdm(range(start_idx, end_idx)): 172 | initial_guess_val = (self.item_param_dict[item_idx]['beta'], 173 | self.item_param_dict[item_idx]['alpha']) 174 | opt_worker.set_initial_guess(initial_guess_val) # the value is a mix of 1/0 and current estimate 175 | opt_worker.set_c(self.item_param_dict[item_idx]['c']) 176 | 177 | # estimate 178 | expected_right_count = self.item_expected_right_by_theta[:, item_idx] 179 | expected_wrong_count = self.item_expected_wrong_by_theta[:, item_idx] 180 | input_data = [expected_right_count, expected_wrong_count] 181 | opt_worker.load_res_data(input_data) 182 | try: 183 | est_param = opt_worker.solve_param_mix(self.is_constrained) 184 | except Exception as e: 185 | if self.mode == 'production': 186 | # In production mode, use the previous iteration 187 | self.logger.error('Item %d does not fit' % item_idx) 188 | d[item_idx] = self.item_param_dict[item_idx] 189 | else: 190 | self.logger.critical(e.strerror) 191 | raise e 192 | finally: 193 | d[item_idx] = est_param 194 | except Exception as e: 195 | self.logger.critical("Unexpected error:", str(e)) 196 | raise e 197 | # [A] max for item parameter 198 | opt_worker = optimizer.irt_2PL_Optimizer() 199 | # the boundary is universal 200 | # the boundary is set regardless of the constrained option because the 201 | # constrained search serves as backup for outlier cases 202 | opt_worker.set_bounds([self.beta_bnds, self.alpha_bnds]) 203 | 204 | # theta value is universal 205 | opt_worker.set_theta(self.theta_prior_val) 206 | num_item = self.dao.get_num('item') 207 | 208 | if num_item > self.num_cpu and self.is_parallel: 209 | num_chunk = self.num_cpu 210 | else: 211 | num_chunk = 1 212 | 213 | # [A] calculate p(data,param|theta) 214 | chunk_list = tools.cut_list(num_item, num_chunk) 215 | 216 | procs = [] 217 | manager = mp.Manager() 218 | procs_repo = manager.dict() 219 | 220 | for i in range(num_chunk): 221 | p = mp.Process(target=update, args=(procs_repo, chunk_list[i][0], chunk_list[i][1],)) 222 | procs.append(p) 223 | 224 | if num_chunk > 1: 225 | procs = procs_operator(procs, 3600, self.check_interval) 226 | else: 227 | procs = procs_operator(procs, 7200, 0.1) 228 | 229 | for item_idx in range(num_item): 230 | self.item_param_dict[item_idx] = { 231 | 'beta': procs_repo[item_idx][0], 232 | 'alpha': procs_repo[item_idx][1], 233 | 'c': self.item_param_dict[item_idx]['c'] 234 | } 235 | 236 | # [B] max for theta density 237 | self.theta_density = self.posterior_theta_distr.sum(axis=0) / self.posterior_theta_distr.sum() 238 | self.__check_theta_density() 239 | 240 | def _check_stop(self): 241 | ''' 242 | preserve user and item parameter from last iteration. This is useful in restoring after a declining llk iteration 243 | ''' 244 | self.logger.debug('score calculating') 245 | avg_prob = np.exp(self.__calc_data_likelihood()) 246 | self.logger.debug('score calculated.') 247 | 248 | self.ell_list.append(avg_prob) 249 | self.logger.debug(avg_prob) 250 | 251 | diff = avg_prob - self.last_avg_prob 252 | 253 | if diff >= 0 and diff <= self.tol: 254 | self.logger.info('EM converged at iteration %d.' % self.num_iter) 255 | return True 256 | elif diff < 0: 257 | self.item_param_dict = self.last_item_param_dict 258 | self.logger.info('Likelihood descrease, stops at iteration %d.' 
% self.num_iter) 259 | return True 260 | else: 261 | # diff larger than tolerance 262 | # update the stop condition 263 | self.last_avg_prob = avg_prob 264 | self.num_iter += 1 265 | 266 | if (self.num_iter > self.max_iter): 267 | self.logger.info('EM does not converge within max iteration') 268 | return True 269 | if self.num_iter != 1: 270 | self.last_item_param_dict = deepcopy(self.item_param_dict) 271 | return False 272 | 273 | def _init_solver_param(self, is_constrained, boundary, solver_type, max_iter, tol): 274 | # initialize bounds 275 | self.is_constrained = is_constrained 276 | self.alpha_bnds = boundary['alpha'] 277 | self.beta_bnds = boundary['beta'] 278 | self.solver_type = solver_type 279 | self.max_iter = max_iter 280 | self.tol = tol 281 | if solver_type == 'gradient' and not is_constrained: 282 | raise Exception('BFGS has to be constrained') 283 | 284 | def _init_item_param(self): 285 | self.item_param_dict = {} 286 | for item_idx in range(self.dao.get_num('item')): 287 | # need to call the old item_id 288 | c = self.guess_param_dict[item_idx]['c'] 289 | self.item_param_dict[item_idx] = {'alpha': 1.0, 'beta': 0.0, 'c': c} 290 | 291 | def _init_user_param(self, theta_min, theta_max, num_theta, dist='normal'): 292 | # generte value 293 | self.theta_prior_val = np.linspace(theta_min, theta_max, num=num_theta) 294 | if self.num_theta != len(self.theta_prior_val): 295 | raise Exception('wrong number of inintial theta values') 296 | # use a normal approximation 297 | if dist == 'uniform': 298 | self.theta_density = np.ones(num_theta) / num_theta 299 | elif dist == 'normal': 300 | norm_pdf = [norm.pdf(x) for x in self.theta_prior_val] 301 | normalizer = sum(norm_pdf) 302 | self.theta_density = np.array([x / normalizer for x in norm_pdf]) 303 | else: 304 | raise Exception('invalid theta prior distibution %s' % dist) 305 | self.__check_theta_density() 306 | # space for each learner 307 | self.posterior_theta_distr = np.zeros((self.dao.get_num('user'), num_theta)) 308 | 309 | def __update_theta_distr(self): 310 | def update(d, start_idx, end_idx): 311 | try: 312 | if self.dao_type == 'db': 313 | user2item_conn = self.dao.open_conn('user2item') 314 | for user_idx in tqdm(range(start_idx, end_idx)): 315 | if self.dao_type == 'db': 316 | logs = self.dao.get_log(user_idx, user2item_conn) 317 | else: 318 | logs = self.dao.get_log(user_idx) 319 | d[user_idx] = update_theta_distribution(logs, self.num_theta, self.theta_prior_val, self.theta_density, self.item_param_dict) 320 | except Exception as e: 321 | self.logger.critical("Unexpected error:", str(e)) 322 | raise e 323 | # [A] calculate p(data,param|theta) 324 | num_user = self.dao.get_num('user') 325 | if num_user > self.num_cpu and self.is_parallel: 326 | num_chunk = self.num_cpu 327 | else: 328 | num_chunk = 1 329 | 330 | # [A] calculate p(data,param|theta) 331 | chunk_list = tools.cut_list(num_user, num_chunk) 332 | procs = [] 333 | manager = mp.Manager() 334 | procs_repo = manager.dict() 335 | for i in range(num_chunk): 336 | p = mp.Process(target=update, args=(procs_repo, chunk_list[i][0], chunk_list[i][1],)) 337 | procs.append(p) 338 | 339 | if num_chunk > 1: 340 | procs = procs_operator(procs, 3600 * 10, self.check_interval) 341 | else: 342 | procs = procs_operator(procs, 3600 * 24, 0.1) 343 | 344 | for user_idx in range(num_user): 345 | self.posterior_theta_distr[user_idx, :] = procs_repo[user_idx] 346 | 347 | # When the loop finish, check if the theta_density adds up to unity for each user 348 | check_user_distr_marginal = 
np.sum(self.posterior_theta_distr, axis=1) 349 | if any(abs(check_user_distr_marginal - 1.0) > 0.0001): 350 | raise Exception('The posterior distribution of user ability is not proper') 351 | 352 | def __check_theta_density(self): 353 | if abs(sum(self.theta_density) - 1) > 1e-6: 354 | raise Exception('theta density does not sum upto 1') 355 | 356 | if self.theta_density.shape != (self.num_theta,): 357 | raise Exception('theta desnity has wrong shape (%s,%s)' % self.theta_density.shape) 358 | 359 | def __get_expect_count(self): 360 | def update(d, start_idx, end_idx): 361 | try: 362 | if self.dao_type == 'db': 363 | item2user_conn = self.dao.open_conn('item2user') 364 | for item_idx in tqdm(range(start_idx, end_idx)): 365 | if self.dao_type == 'db': 366 | map_user_idx_vec = self.dao.get_map(item_idx, ['1', '0'], item2user_conn) 367 | else: 368 | map_user_idx_vec = self.dao.get_map(item_idx, ['1', '0']) 369 | d[item_idx] = { 370 | 1: np.sum(self.posterior_theta_distr[map_user_idx_vec[0], :], axis=0), 371 | 0: np.sum(self.posterior_theta_distr[map_user_idx_vec[1], :], axis=0) 372 | } 373 | except Exception as e: 374 | self.logger.critical("Unexpected error:", str(e)) 375 | raise e 376 | 377 | num_item = self.dao.get_num('item') 378 | if num_item > self.num_cpu and self.is_parallel: 379 | num_chunk = self.num_cpu 380 | else: 381 | num_chunk = 1 382 | 383 | # [A] calculate p(data,param|theta) 384 | chunk_list = tools.cut_list(num_item, num_chunk) 385 | procs = [] 386 | manager = mp.Manager() 387 | procs_repo = manager.dict() 388 | for i in range(num_chunk): 389 | p = mp.Process(target=update, args=(procs_repo, chunk_list[i][0], chunk_list[i][1],)) 390 | procs.append(p) 391 | 392 | if num_chunk > 1: 393 | procs = procs_operator(procs, 5400, self.check_interval) 394 | else: 395 | procs = procs_operator(procs, 7200, 0.1) 396 | 397 | self.item_expected_right_by_theta = np.zeros((self.num_theta, self.dao.get_num('item'))) 398 | self.item_expected_wrong_by_theta = np.zeros((self.num_theta, self.dao.get_num('item'))) 399 | for item_idx in range(num_item): 400 | self.item_expected_right_by_theta[:, item_idx] = procs_repo[item_idx][1] 401 | self.item_expected_wrong_by_theta[:, item_idx] = procs_repo[item_idx][0] 402 | 403 | def __calc_data_likelihood(self): 404 | # calculate the likelihood for the data set 405 | # geometric within learner and across learner 406 | # 1/N * sum[i](1/Ni *sum[j] (log pij)) 407 | def update(tot_llk, cnt, start_idx, end_idx): 408 | try: 409 | if self.dao_type == 'db': 410 | user2item_conn = self.dao.open_conn('user2item') 411 | for user_idx in tqdm(range(start_idx, end_idx)): 412 | theta = theta_vec[user_idx] 413 | # find all the item_id 414 | if self.dao_type == 'db': 415 | logs = self.dao.get_log(user_idx, user2item_conn) 416 | else: 417 | logs = self.dao.get_log(user_idx) 418 | if len(logs) == 0: 419 | continue 420 | ell = 0 421 | for log in logs: 422 | item_idx = log[0] 423 | ans_tag = log[1] 424 | alpha = self.item_param_dict[item_idx]['alpha'] 425 | beta = self.item_param_dict[item_idx]['beta'] 426 | c = self.item_param_dict[item_idx]['c'] 427 | ell += clib.log_likelihood_2PL(0.0 + ans_tag, 1.0 - ans_tag, theta, alpha, beta, c) 428 | with tot_llk.get_lock(): 429 | tot_llk.value += ell / len(logs) 430 | with cnt.get_lock(): 431 | cnt.value += 1 432 | except Exception as e: 433 | self.logger.critical("Unexpected error:", str(e)) 434 | raise e 435 | 436 | theta_vec = self.__calc_theta() 437 | num_user = self.dao.get_num('user') 438 | if num_user > self.num_cpu and 
self.is_parallel: 439 | num_chunk = self.num_cpu 440 | else: 441 | num_chunk = 1 442 | user_ell = mp.Value('d', 0.0) 443 | user_cnt = mp.Value('i', 0) 444 | chunk_list = tools.cut_list(num_user, num_chunk) 445 | procs = [] 446 | for i in range(num_chunk): 447 | p = mp.Process(target=update, args=(user_ell, user_cnt, chunk_list[i][0], chunk_list[i][1],)) 448 | procs.append(p) 449 | if num_chunk > 1: 450 | procs = procs_operator(procs, 1200, self.check_interval) 451 | else: 452 | procs = procs_operator(procs, 7200, 0.1) 453 | avg_ell = user_ell.value / user_cnt.value 454 | 455 | return avg_ell 456 | 457 | def __calc_theta(self): 458 | return np.dot(self.posterior_theta_distr, self.theta_prior_val) 459 | -------------------------------------------------------------------------------- /package/pyirt/pyirt/solver/optimizer.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8-*- 2 | import numpy as np 3 | from scipy.optimize import minimize 4 | from scipy.optimize import minimize_scalar 5 | 6 | from ..util import clib, tools 7 | 8 | np.seterr(over='raise') 9 | 10 | 11 | class irt_2PL_Optimizer(object): 12 | 13 | def load_res_data(self, res_data): 14 | self.res_data = np.array(res_data) 15 | 16 | def set_theta(self, theta): 17 | self.theta = theta 18 | 19 | def set_c(self, c): 20 | self.c = c 21 | 22 | def set_bounds(self, bnds): 23 | self.bnds = bnds 24 | 25 | def set_initial_guess(self, x0): 26 | self.x0 = x0 27 | 28 | # generate the likelihood function 29 | @staticmethod 30 | def _likelihood(res_data, theta_vec, alpha, beta, c): 31 | # for MMLE method, y1 and y0 will be expected count 32 | y1 = res_data[0] 33 | y0 = res_data[1] 34 | 35 | # check for equal length between y1,y0 and theta_vec 36 | num_data = len(y1) 37 | if len(y0) != num_data: 38 | raise ValueError('The response data does not match in length. y0:%s, y1:%s' % (y0, y1)) 39 | if len(theta_vec) != num_data: 40 | raise ValueError('The response data does not match theta vec in length. 
theta_vec:%s, num_data:%s' % (theta_vec, num_data)) 41 | 42 | if sum(y1 < 0) > 0 or sum(y0 < 0) > 0: 43 | raise ValueError('y1 or y0 contains negative count.') 44 | # this is the likelihood 45 | likelihood_vec = [clib.log_likelihood_2PL(y1[i], y0[i], theta_vec[i], 46 | alpha, beta, c) 47 | for i in range(num_data)] 48 | # transform into negative likelihood 49 | ell = -sum(likelihood_vec) 50 | 51 | return ell 52 | 53 | @staticmethod 54 | def _gradient(res_data, theta_vec, alpha, beta, c): 55 | # res should be numpy array 56 | y1 = res_data[0] 57 | y0 = res_data[1] 58 | num_data = len(y1) 59 | 60 | der = np.zeros(2) 61 | for i in range(num_data): 62 | # the toolbox calculate the gradient of the log likelihood, 63 | # but the algorithm needs that of the negative ll 64 | der -= clib.log_likelihood_2PL_gradient(y1[i], y0[i], theta_vec[i], alpha, beta, c) 65 | return der 66 | 67 | def solve_param_linear(self, is_constrained): 68 | # for now, temp set alpha to 1 69 | def target_fnc(x): 70 | beta = x[0] 71 | alpha = x[1] 72 | return self._likelihood(self.res_data, self.theta, alpha, beta, self.c) 73 | 74 | if is_constrained: 75 | res = minimize(target_fnc, self.x0, method='SLSQP', 76 | bounds=self.bnds, options={'disp': False}) 77 | else: 78 | res = minimize(target_fnc, self.x0, method='nelder-mead', 79 | options={'disp': False}) 80 | 81 | # deal with expcetions 82 | if not res.success: 83 | if not is_constrained and \ 84 | res.message == 'Maximum number of function evaluations\ 85 | has been exceeded.': 86 | raise Exception('Optimizer fails to find solution.\ 87 | Try constrained search.') 88 | else: 89 | raise Exception('Algorithm failed because: ' + res.message) 90 | 91 | return res.x 92 | 93 | def solve_param_gradient(self, is_constrained): 94 | # for now, temp set alpha to 1 95 | 96 | def target_fnc(x): 97 | beta = x[0] 98 | alpha = x[1] 99 | return self._likelihood(self.res_data, self.theta, alpha, beta, self.c) 100 | 101 | def target_der(x): 102 | beta = x[0] 103 | alpha = x[1] 104 | return self._gradient(self.res_data, self.theta, alpha, beta, self.c) 105 | 106 | if is_constrained: 107 | res = minimize(target_fnc, self.x0, method='L-BFGS-B', 108 | jac=target_der, bounds=self.bnds, 109 | options={'disp': False}) 110 | else: 111 | res = minimize(target_fnc, self.x0, method='BFGS', 112 | jac=target_der, 113 | options={'disp': False}) 114 | 115 | if not res.success: 116 | raise Exception("Algorithm failed because " + res.message) 117 | 118 | return res.x 119 | 120 | def solve_param_mix(self, is_constrained=True): 121 | """ 122 | Mix solve_param_gradient and solve_param_linear. 123 | """ 124 | # solve by L-BFGS-B 125 | # * linear is more robust than gradient. 126 | try: 127 | est_param = self.solve_param_gradient(is_constrained) 128 | except Exception as est_e: 129 | # if the alogrithm is nelder-mead and the optimization fails to 130 | # converge, use the constrained version 131 | # 132 | # * solve_param_linear with different params in two times. 133 | try: 134 | est_param = self.solve_param_linear(is_constrained) 135 | except Exception as e: 136 | if str(e) == 'Optimizer fails to find solution. 
Try constrained search.': 137 | est_param = self.solve_param_linear(True) 138 | else: 139 | raise e 140 | # TODO: handle multiple attempt better 141 | return est_param 142 | 143 | 144 | class irt_factor_optimizer(object): 145 | 146 | def load_res_data(self, res_data): 147 | self.res_data = np.array(res_data) 148 | 149 | def set_item_parameter(self, alpha_vec, beta_vec, c_vec): 150 | if len(alpha_vec) != len(beta_vec): 151 | raise ValueError('The alpha vec and the beta vec does not match in length.') 152 | 153 | self.alpha_vec = alpha_vec 154 | self.beta_vec = beta_vec 155 | self.c_vec = c_vec 156 | 157 | def set_bounds(self, bnds): 158 | self.bnds = bnds 159 | 160 | def set_initial_guess(self, x0): 161 | self.x0 = x0 162 | 163 | @staticmethod 164 | def _likelihood(res_data, theta, alpha_vec, beta_vec, c_vec): 165 | # for MMLE method, y1 and y0 will be expected count 166 | y1 = res_data[0] 167 | y0 = res_data[1] 168 | 169 | # check for equal length between y1,y0 and theta_vec 170 | num_data = len(y1) 171 | if len(y0) != num_data: 172 | raise ValueError('The response data does not match in length.') 173 | if len(alpha_vec) != num_data: 174 | raise ValueError('The response data does not match alpha vec in length.') 175 | if sum(y1 < 0) > 0 or sum(y0 < 0) > 0: 176 | raise ValueError('y1 or y0 contains negative count.') 177 | # this is the likelihood 178 | likelihood_vec = [clib.log_likelihood_2PL(y1[i], y0[i], theta, 179 | alpha_vec[i], beta_vec[i], c_vec[i]) 180 | for i in range(num_data)] 181 | # transform into negative likelihood 182 | ell = -sum(likelihood_vec) 183 | 184 | return ell 185 | 186 | @staticmethod 187 | def _gradient(res_data, theta, alpha_vec, beta_vec, c_vec): 188 | # res should be numpy array 189 | y1 = res_data[0] 190 | y0 = res_data[1] 191 | num_data = len(y1) 192 | 193 | der = 0.0 194 | for i in range(num_data): 195 | der -= tools.log_likelihood_factor_gradient(y1[i], y0[i], theta, alpha_vec[i], beta_vec[i], c_vec[i]) 196 | return der 197 | 198 | @staticmethod 199 | def _hessian(res_data, theta, alpha_vec, beta_vec, c_vec): 200 | # res should be numpy array 201 | y1 = res_data[0] 202 | y0 = res_data[1] 203 | num_data = len(y1) 204 | 205 | hes = 0.0 206 | for i in range(num_data): 207 | hes -= tools.log_likelihood_factor_hessian(y1[i], y0[i], theta, alpha_vec[i], beta_vec[i], c_vec[i]) 208 | return hes 209 | 210 | def solve_param_linear(self, is_constrained): 211 | # for now, temp set alpha to 1 212 | def target_fnc(x): 213 | return self._likelihood(self.res_data, x, self.alpha_vec, self.beta_vec, self.c_vec) 214 | 215 | if is_constrained: 216 | res = minimize(target_fnc, self.x0, method='SLSQP', 217 | bounds=self.bnds, options={'disp': False}) 218 | else: 219 | res = minimize(target_fnc, self.x0, method='nelder-mead', 220 | options={'disp': False}) 221 | 222 | # deal with expcetions 223 | if not res.success: 224 | if not is_constrained and \ 225 | res.message == 'Maximum number of function evaluations has been exceeded.': 226 | raise Exception('Optimizer fails to find solution. 
Try constrained search.') 227 | else: 228 | raise Exception('Algorithm failed because ' + res.message) 229 | 230 | return res.x 231 | 232 | def solve_param_gradient(self, is_constrained): 233 | # for now, temp set alpha to 1 234 | def target_fnc(x): 235 | return self._likelihood(self.res_data, x, self.alpha_vec, self.beta_vec, self.c_vec) 236 | 237 | def target_der(x): 238 | return self._gradient(self.res_data, x, self.alpha_vec, self.beta_vec, self.c_vec) 239 | 240 | if is_constrained: 241 | res = minimize(target_fnc, self.x0, method='L-BFGS-B', 242 | jac=target_der, bounds=self.bnds, 243 | options={'disp': False}) 244 | else: 245 | res = minimize(target_fnc, self.x0, method='BFGS', 246 | jac=target_der, 247 | options={'disp': False}) 248 | 249 | if not res.success: 250 | raise Exception('Algorithm failed.') 251 | 252 | return res.x 253 | 254 | def solve_param_hessian(self): 255 | def target_fnc(x): 256 | return self._likelihood(self.res_data, x, self.alpha_vec, self.beta_vec, self.c_vec) 257 | 258 | def target_der(x): 259 | return self._gradient(self.res_data, x, self.alpha_vec, self.beta_vec, self.c_vec) 260 | 261 | def target_hess(x): 262 | return self._hessian(self.res_data, x, self.alpha_vec, self.beta_vec, self.c_vec) 263 | 264 | res = minimize(target_fnc, self.x0, method='Newton-CG', 265 | jac=target_der, hess=target_hess, 266 | options={'disp': False}) 267 | if not res.success: 268 | if res.message == 'Desired error not necessarily achieved due to precision loss.': 269 | # TODO:still returns a result. Something is wrong with the BFGS 270 | # though 271 | pass 272 | else: 273 | raise Exception('Algorithm failed, because ' + res.message) 274 | return res.x 275 | 276 | def solve_param_scalar(self): 277 | def target_fnc(x): 278 | return self._likelihood(self.res_data, x, self.alpha_vec, self.beta_vec, self.c_vec) 279 | res = minimize_scalar(target_fnc, bounds=self.bnds, method='bounded') 280 | return res.x 281 | -------------------------------------------------------------------------------- /package/pyirt/pyirt/solver/theta_estimator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import beta 3 | 4 | 5 | from ..util import clib, tools 6 | 7 | from ..solver import optimizer 8 | 9 | 10 | class bayesian_estimator(object): 11 | 12 | def set_prior(self, theta_min, theta_max, num_theta, dist_name): 13 | self.theta_val = np.linspace(theta_min, theta_max, num=num_theta) 14 | self.num_theta = num_theta 15 | if dist_name == 'uniform': 16 | # the density is uniform 17 | self.theta_density = np.ones(num_theta) / num_theta 18 | 19 | elif dist_name == 'beta': 20 | # centered beta 21 | # rescale to move away from the boundary 22 | self.theta_density = beta.pdf((self.theta_val - theta_min) / (theta_max - theta_min + 0.1), 2, 2) 23 | # renormalize 24 | self.theta_density = self.theta_density / sum(self.theta_density) 25 | else: 26 | raise Exception('Unknown prior distribution.') 27 | 28 | def update(self, logs): 29 | # data comes in as 30 | # tag(0/1), (a, b,c) 31 | 32 | likelihood_vec = np.zeros(self.num_theta) 33 | # calculate 34 | for k in range(self.num_theta): 35 | theta = self.theta_val[k] 36 | # calculate the likelihood 37 | ell = 0.0 38 | for log in logs: 39 | atag = log[0] 40 | alpha = log[1][0] 41 | beta = log[1][1] 42 | c = log[1][2] 43 | ell += clib.log_likelihood_2PL(atag, 1.0 - atag, theta, alpha, beta, c) 44 | # now update the density 45 | likelihood_vec[k] = ell 46 | 47 | # ell = p(param|x), full joint = 
logp(param|x)+log(x) 48 | # Fix np.log, see http://stackoverflow.com/questions/13497891/python-getting-around-division-by-zero 49 | log_joint_prob_vec = likelihood_vec + np.log(self.theta_density.clip(min=0.0000000001)) 50 | # calculate the posterior 51 | # p(x|param) = exp(logp(param,x) - log(sum p(param,x))) 52 | marginal = tools.logsum(log_joint_prob_vec) 53 | self.theta_density = np.exp(log_joint_prob_vec - marginal) 54 | 55 | def get_estimator(self): 56 | 57 | # expected value 58 | theta_mean = np.dot(self.theta_density, self.theta_val) 59 | # theta_var = np.dot(self.theta_density, self.theta_val**2) - theta_mean**2 60 | 61 | theta_hat = theta_mean 62 | return theta_hat 63 | 64 | 65 | class MLE_estimator(object): 66 | worker = optimizer.irt_factor_optimizer() 67 | 68 | def update(self, logs): 69 | # log [tag(0/1), (a, b,c)] 70 | 71 | # transform the logs 72 | y1 = [] 73 | y0 = [] 74 | alphas = [] 75 | betas = [] 76 | cs = [] 77 | for log in logs: 78 | y1.append(log[0]) 79 | y0.append(1.0 - log[0]) 80 | alphas.append(log[1][0]) 81 | betas.append(log[1][1]) 82 | cs.append(log[1][2]) 83 | self.worker.load_res_data([y1, y0]) 84 | self.worker.set_item_parameter(alphas, betas, cs) 85 | self.worker.set_bounds([(-4.0, 4.0)]) 86 | self.worker.set_initial_guess(0.0) 87 | try: 88 | est_theta = self.worker.solve_param_gradient(is_constrained=True) 89 | except Exception as e: 90 | est_theta = self.worker.solve_param_linear(is_constrained=True) 91 | 92 | # the output is an numpy array! 93 | return est_theta[0] 94 | -------------------------------------------------------------------------------- /package/pyirt/pyirt/util/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["tools", "dao", "clib"] 2 | 3 | from . import tools 4 | from . import dao 5 | 6 | from . import clib 7 | -------------------------------------------------------------------------------- /package/pyirt/pyirt/util/clib.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8-*- 2 | 3 | import numpy as np 4 | 5 | exp = np.exp 6 | log = np.log 7 | 8 | def log_likelihood_2PL(y1, y0, theta, alpha, beta, c=0.0): 9 | expPos = exp(alpha * theta + beta) 10 | ell = y1 * log((c + expPos) / (1.0 + expPos)) + y0 * log((1.0 - c) / (1.0 + expPos)) ; 11 | 12 | return ell 13 | 14 | 15 | def log_likelihood_2PL_gradient(y1, y0, theta, alpha, beta, c=0.0): 16 | grad = np.zeros(2) 17 | 18 | # It is the gradient of the log likelihood, not the NEGATIVE log likelihood 19 | temp = exp(beta + alpha * theta) 20 | beta_grad = temp / (1.0 + temp) * ( y1 * (1.0 - c) / (c + temp) - y0) 21 | 22 | alpha_grad = theta * beta_grad 23 | grad[0] = beta_grad 24 | grad[1] = alpha_grad 25 | return grad 26 | -------------------------------------------------------------------------------- /package/pyirt/pyirt/util/dao.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This script deals with the data format problem. 3 | 4 | The stardard format for pyirt is ( user_id,item_id,result), 5 | where user_id is the idx for test taker, item_id is the idx for items 6 | It is set in this way to deal with the sparsity in the massive dataset. 
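For example, loadFromTuples() below takes records such as
[("u1", "q1", 1), ("u2", "q1", 0)], while loadFromHandle() reads the same
triplets from comma-separated lines; the answer tag is coerced to int.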
7 | 8 | ''' 9 | import numpy as np 10 | 11 | 12 | def loadFromTuples(data): 13 | user_ids = [] 14 | item_ids = [] 15 | ans_tags = [] 16 | if len(data) == 0: 17 | raise Exception('data are empty') 18 | 19 | for log in data: 20 | user_ids.append(log[0]) 21 | item_ids.append(log[1]) 22 | ans_tags.append(int(log[2])) 23 | 24 | return user_ids, item_ids, ans_tags 25 | 26 | 27 | def loadFromHandle(fp, sep=','): 28 | # Default format is comma separated files, 29 | # Only int is allowed within the environment 30 | user_ids = [] 31 | item_ids = [] 32 | ans_tags = [] 33 | 34 | for line in fp: 35 | if line == '': 36 | continue 37 | user_id_str, item_id_str, ans_tagstr = line.strip().split(sep) 38 | user_ids.append(user_id_str) 39 | item_ids.append(item_id_str) 40 | ans_tags.append(int(ans_tagstr)) 41 | return user_ids, item_ids, ans_tags 42 | 43 | 44 | def parse_item_paramer(item_param_dict, output_file=None): 45 | 46 | if output_file is not None: 47 | # open the file 48 | out_fh = open(output_file, 'w') 49 | 50 | sorted_item_ids = sorted(item_param_dict.keys()) 51 | 52 | for item_id in sorted_item_ids: 53 | param = item_param_dict[item_id] 54 | alpha_val = np.round(param['alpha'], decimals=2) 55 | beta_val = np.round(param['beta'], decimals=2) 56 | if output_file is None: 57 | print(item_id, alpha_val, beta_val) 58 | else: 59 | out_fh.write('{},{},{}\n'.format(item_id, alpha_val, beta_val)) 60 | 61 | 62 | def construct_ref_dict(in_list): 63 | # map the in_list to a numeric variable from 0 to N 64 | unique_elements = list(set(in_list)) 65 | element_idxs = range(len(unique_elements)) 66 | idx_ref = dict(zip(unique_elements, element_idxs)) 67 | reverse_idx_ref = dict(zip(element_idxs, unique_elements)) 68 | out_idx_list = [idx_ref[x] for x in in_list] 69 | 70 | return out_idx_list, idx_ref, reverse_idx_ref 71 | -------------------------------------------------------------------------------- /package/pyirt/pyirt/util/tools.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import numpy as np 3 | 4 | def irt_fnc(theta, beta, alpha=1.0, c=0.0): 5 | # beta is item difficulty 6 | # theta is respondent capability 7 | 8 | prob = c + (1.0 - c) / (1 + np.exp(-(alpha * theta + beta))) 9 | return prob 10 | 11 | 12 | def log_likelihood_factor_gradient(y1, y0, theta, alpha, beta, c=0.0): 13 | temp = np.exp(beta + alpha * theta) 14 | grad = alpha * temp / (1.0 + temp) * (y1 * (1.0 - c) / (c + temp ) - y0 ) 15 | 16 | return grad 17 | 18 | 19 | def log_likelihood_factor_hessian(y1, y0, theta, alpha, beta, c=0.0): 20 | x = np.exp(beta + alpha * theta) 21 | # hessian = - alpha**2*(y1+y0)*temp/(1+temp)**2 22 | hessian = alpha ** 2 * x / (1 + x) ** 2 * (y1 * (1 - c) * (c - x ** 2) / (c + x) ** 2 - y0) 23 | 24 | return hessian 25 | 26 | 27 | def log_likelihood_2PL_hessian(y1, y0, theta, alpha, beta, c=0.0): 28 | hessian = np.zeros((2, 2)) 29 | x = np.exp(beta + alpha * theta) 30 | base = x / (1 + x) ** 2 * (y1 * (1 - c) * (c - x ** 2) / (c + x) ** 2 - y0) 31 | 32 | hessian = np.matrix([[1, theta], [theta, theta ** 2]]) * base 33 | 34 | return hessian 35 | 36 | 37 | def logsum(logp): 38 | w = max(logp) 39 | logSump = w + np.log(sum(np.exp(logp - w))) 40 | return logSump 41 | 42 | 43 | 44 | def cut_list(list_length, num_chunk): 45 | chunk_bnd = [0] 46 | for i in range(num_chunk): 47 | chunk_bnd.append(int(list_length*(i+1)/num_chunk)) 48 | chunk_bnd.append(list_length) 49 | chunk_list = [(chunk_bnd[i], chunk_bnd[i+1]) for i in range(num_chunk) ] 50 | return 
chunk_list 51 | -------------------------------------------------------------------------------- /package/pyirt/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | -------------------------------------------------------------------------------- /package/pyirt/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='pyirt', 5 | version="0.3.4", 6 | packages=['pyirt', 7 | 'pyirt/solver', 8 | 'pyirt/util', ], 9 | license='MIT', 10 | description='A python implementation of Item Response Theory(IRT), specializing in big dataset', 11 | author='Junchen Feng', 12 | author_email='frankfeng.pku@gmail.com', 13 | include_package_data=True, 14 | url='https://github.com/junchenfeng/pyirt', 15 | download_url='https://github.com/17zuoye/pyirt/archive/v0.3.4.tar.gz', 16 | keywords=['IRT', 'EM', 'big data'], 17 | zip_safe=False, 18 | platforms='any', 19 | install_requires=['numpy', 20 | 'scipy', 21 | 'cython', 22 | 'pymongo', 23 | 'tqdm', 24 | 'python-decouple'], 25 | 26 | package_data={'pyirt': ["*.pyx"]}, 27 | 28 | classifiers=[ 29 | 'Intended Audience :: Developers', 30 | 'Operating System :: OS Independent', 31 | 'Programming Language :: Python :: 3.6', 32 | ], 33 | ) 34 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.11.0 2 | aiohttp==3.7.3 3 | appdirs==1.4.4 4 | async-timeout==3.0.1 5 | attrs==20.3.0 6 | backcall==0.2.0 7 | bleach==3.2.2 8 | cachetools==4.2.0 9 | certifi==2020.12.5 10 | chardet==3.0.4 11 | colorama==0.4.4 12 | coverage==5.3.1 13 | cycler==0.10.0 14 | Cython==0.29.21 15 | decorator==4.4.2 16 | dill==0.3.3 17 | distlib==0.3.1 18 | docutils==0.16 19 | filelock==3.0.12 20 | flake8==3.8.4 21 | fsspec==0.8.5 22 | future==0.18.2 23 | google-auth==1.24.0 24 | google-auth-oauthlib==0.4.2 25 | grpcio==1.34.1 26 | idna==2.10 27 | ipdb==0.13.4 28 | ipython==7.19.0 29 | ipython-genutils==0.2.0 30 | jedi==0.18.0 31 | joblib==1.0.0 32 | kaggle==1.5.10 33 | keyring==22.0.0 34 | kiwisolver==1.3.1 35 | Markdown==3.3.3 36 | matplotlib==3.3.3 37 | mccabe==0.6.1 38 | multidict==5.1.0 39 | multiprocess==0.70.11.1 40 | nose==1.3.7 41 | numpy==1.19.5 42 | oauthlib==3.1.0 43 | opt-einsum==3.3.0 44 | packaging==20.8 45 | pandas==1.2.0 46 | parso==0.8.1 47 | pathos==0.2.7 48 | pickleshare==0.7.5 49 | Pillow==8.1.0 50 | pipenv==2020.11.15 51 | pkginfo==1.7.0 52 | pox==0.2.9 53 | ppft==1.6.6.3 54 | progressbar2==3.53.1 55 | prompt-toolkit==3.0.14 56 | protobuf==3.14.0 57 | pyasn1==0.4.8 58 | pyasn1-modules==0.2.8 59 | pycodestyle==2.6.0 60 | pyflakes==2.2.0 61 | Pygments==2.7.4 62 | pyirt @ file:///C:/Users/ericyhzhou/PycharmProjects/kt/package/pyirt 63 | pymongo==3.11.2 64 | pyparsing==2.4.7 65 | pyro-api==0.1.2 66 | pyro-ppl==1.5.1 67 | python-dateutil==2.8.1 68 | python-decouple==3.4 69 | python-slugify==4.0.1 70 | python-utils==2.5.6 71 | pytorch-lightning==1.1.4 72 | pytz==2020.5 73 | pywin32-ctypes==0.2.0 74 | PyYAML==5.3.1 75 | readme-renderer==28.0 76 | requests==2.25.1 77 | requests-oauthlib==1.3.0 78 | requests-toolbelt==0.9.1 79 | rfc3986==1.4.0 80 | rsa==4.7 81 | scikit-learn==0.24.0 82 | scipy==1.6.0 83 | six==1.15.0 84 | sklearn==0.0 85 | tensorboard==2.4.1 86 | tensorboard-plugin-wit==1.7.0 87 | text-unidecode==1.3 88 | threadpoolctl==2.1.0 89 | torch==1.7.1 90 | 
tqdm==4.56.0 91 | traitlets==5.0.5 92 | twine==3.3.0 93 | typing-extensions==3.7.4.3 94 | urllib3==1.26.2 95 | virtualenv==20.4.0 96 | virtualenv-clone==0.5.4 97 | wcwidth==0.2.5 98 | webencodings==0.5.1 99 | Werkzeug==1.0.1 100 | yarl==1.6.3 101 | -------------------------------------------------------------------------------- /sample_data_preprocess.py: -------------------------------------------------------------------------------- 1 | # This sample script preprocess RAIEd2021 dataset 2 | # Minor modification can be made to this script to make it compatible for other datasets 3 | 4 | import random 5 | import gc 6 | import argparse 7 | 8 | import pandas as pd 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | tqdm.pandas() 13 | 14 | parser = argparse.ArgumentParser(description="Sample Data Preprocess Script") 15 | parser.add_argument('-t', '--train_csv', type=str, required=True, 16 | help="Filepath of train.csv") 17 | parser.add_argument('-q', '--question_csv', type=str, required=True, 18 | help="Filepath of question.csv") 19 | parser.add_argument('-s', '--split', type=float, required=True, 20 | help="Testset / Trainset") 21 | parser.add_argument('-o', '--output', type=str, required=True, 22 | help="Filepath of preprocessed data") 23 | parser.add_argument('--irt', action='store_true', 24 | help="Whether to perform IRT analysis. Required if you will do leveled learning") 25 | args = parser.parse_args() 26 | 27 | print("Loading CSV...") 28 | train_dtypes = {'row_id': 'int64', 29 | 'timestamp': 'int64', 30 | 'user_id': 'int32', 31 | 'content_id': 'int16', 32 | 'content_type_id': 'int8', 33 | 'answered_correctly': 'int8', 34 | 'user_answer': 'int8', 35 | 'prior_question_elapsed_time': 'float32', 36 | 'task_container_id': 'int16', 37 | 'prior_question_had_explanation': 'boolean'} 38 | train_df = pd.read_csv(args.train_csv, dtype=train_dtypes) 39 | question_dtypes = {"question_id": "int16", "part": "int8"} 40 | question_df = pd.read_csv(args.question_csv, dtype=question_dtypes) 41 | # Align the name of key column for latter merging 42 | question_df = question_df.rename(columns={"question_id": "content_id"}) 43 | 44 | # Formatting the timestamp 45 | # Here we basically want to reset the timestamp of each record so that all users 46 | # do their respective last exercise at almost the same time (Instead of 0). 
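# e.g. if MAX_TIMESTAMP is 100_000 and a user's last record sits at 60_000,
# then gap = 40_000 and one random offset drawn from [0, 40_000] is added to
# every timestamp of that user (see reset_time / virtual_timestamp below).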
47 | # This step is usefull for splitting training/valid dataset as we dont want to 48 | # randomly split the dataset 49 | # 50 | # In order to do that, we firstly need to get the max timestamp of all records 51 | # Then we use it to minus the max time stamp of each user to represent the start 52 | # timestamp of this specific user 53 | # 54 | # Insipred from https://www.kaggle.com/its7171/cv-strategy 55 | 56 | max_timestamp_user = train_df[["user_id", "timestamp"]].groupby(["user_id"]).agg(["max"]).reset_index() 57 | max_timestamp_user.columns = ["user_id", "max_timestamp"] 58 | MAX_TIMESTAMP = max_timestamp_user.max_timestamp.max() 59 | print("Generating virtual timestamp") 60 | 61 | 62 | def reset_time(max_timestamp): 63 | gap = MAX_TIMESTAMP - max_timestamp 64 | rand_init_time = random.randint(0, gap) 65 | return rand_init_time 66 | 67 | 68 | max_timestamp_user["rand_timestamp"] = max_timestamp_user.max_timestamp.progress_apply(reset_time) 69 | train_df = train_df.merge(max_timestamp_user, on="user_id", how="left") 70 | train_df["virtual_timestamp"] = train_df.timestamp + train_df["rand_timestamp"] 71 | 72 | del max_timestamp_user 73 | gc.collect() 74 | 75 | # Merging train_df and question_df on 76 | train_df = train_df[train_df.content_type_id == 0] # only consider question 77 | train_df = train_df.merge(question_df, on='content_id', how="left") # left outer join to consider part 78 | train_df.prior_question_elapsed_time /= 1000 # ms -> s 79 | train_df.prior_question_elapsed_time.fillna(0, inplace=True) 80 | train_df.prior_question_elapsed_time.clip(lower=0, upper=300, inplace=True) 81 | train_df.prior_question_elapsed_time = train_df.prior_question_elapsed_time.astype(np.int) 82 | 83 | del question_df 84 | gc.collect() 85 | 86 | train_df['prior_question_had_explanation'] = train_df['prior_question_had_explanation'].fillna(value=False).astype(int) 87 | 88 | train_df = train_df.sort_values(["virtual_timestamp", "row_id"]).reset_index(drop=True) 89 | n_content_ids = len(train_df.content_id.unique()) 90 | n_parts = len(train_df.part.unique()) 91 | print("NO. of exercises:", n_content_ids) 92 | print("NO. 
of part", n_parts) 93 | print("Shape of the dataframe after exclusion:", train_df.shape) 94 | 95 | print("Computing question difficulty") 96 | df_difficulty = train_df["answered_correctly"].groupby(train_df["content_id"]) 97 | train_df["popularity"] = df_difficulty.transform('size') 98 | train_df["difficulty"] = df_difficulty.transform('sum') / train_df["popularity"] 99 | print("Popularity max", train_df["popularity"].max(), ",Difficulty max", train_df["difficulty"].max()) 100 | 101 | del df_difficulty 102 | gc.collect() 103 | 104 | print("Calculating lag time") 105 | time_dict = {} 106 | 107 | lag_time_col = np.zeros(len(train_df), dtype=np.int64) 108 | for ind, row in enumerate(tqdm(train_df[["user_id", "timestamp", "task_container_id"]].values)): 109 | if row[0] in time_dict.keys(): 110 | # if the task_container_id is the same, the lag time is not allowed 111 | if row[2] == time_dict[row[0]][1]: 112 | lag_time_col[ind] = time_dict[row[0]][2] 113 | else: 114 | timestamp_last = time_dict[row[0]][0] 115 | lag_time_col[ind] = row[1] - timestamp_last 116 | time_dict[row[0]] = (row[1], row[2], lag_time_col[ind]) 117 | else: 118 | time_dict[row[0]] = (row[1], row[2], 0) 119 | lag_time_col[ind] = 0 120 | if lag_time_col[ind] < 0: 121 | raise RuntimeError("Has lag_time smaller than 0.") 122 | 123 | train_df["lag_time_s"] = lag_time_col // 1000 124 | train_df["lag_time_m"] = lag_time_col // (60 * 1000) 125 | train_df["lag_time_d"] = lag_time_col // (60 * 1000 * 1440) 126 | train_df.lag_time_s.clip(lower=0, upper=300, inplace=True) 127 | train_df.lag_time_m.clip(lower=0, upper=1440, inplace=True) 128 | train_df.lag_time_d.clip(lower=0, upper=365, inplace=True) 129 | train_df.lag_time_s = train_df.lag_time_s.astype(np.int) 130 | train_df.lag_time_m = train_df.lag_time_m.astype(np.int) 131 | train_df.lag_time_d = train_df.lag_time_d.astype(np.int) 132 | 133 | del lag_time_col 134 | gc.collect() 135 | 136 | print("Add special token") 137 | train_df.content_id = train_df.content_id + 2 # PAD and START 138 | train_df.answered_correctly = train_df.answered_correctly + 2 # PAD and START 139 | train_df.part = train_df.part + 1 # part has no 0. 
140 | train_df.prior_question_had_explanation = train_df.prior_question_had_explanation + 2 # PAD and START 141 | train_df.prior_question_elapsed_time = train_df.prior_question_elapsed_time + 2 142 | train_df.lag_time_s = train_df.lag_time_s + 2 143 | train_df.lag_time_m = train_df.lag_time_m + 2 144 | train_df.lag_time_d = train_df.lag_time_d + 2 145 | 146 | print("Partitioning dataset") 147 | train_df = train_df.sort_values(["virtual_timestamp", "row_id"]).reset_index(drop=True) 148 | ROW_NUM = len(train_df) 149 | train_split = train_df[:-int(ROW_NUM * args.split)] 150 | valid_split = train_df[-int(ROW_NUM * args.split):] 151 | new_users = len(valid_split[~valid_split.user_id.isin(train_split.user_id)].user_id.unique()) 152 | valid_question = valid_split[valid_split.content_type_id == 0] 153 | train_question = train_split[train_split.content_type_id == 0] 154 | print(f"{train_question.answered_correctly.mean():.3f} {valid_question.answered_correctly.mean():.3f} {new_users}") 155 | 156 | del train_df 157 | gc.collect() 158 | 159 | print("Grouping users") 160 | 161 | 162 | def group_func(r): 163 | return (r.content_id.values, 164 | r.part.values, 165 | r.answered_correctly.values, 166 | r.prior_question_elapsed_time.values, 167 | r.lag_time_s.values, 168 | r.lag_time_m.values, 169 | r.lag_time_d.values, 170 | r.prior_question_had_explanation.values) 171 | 172 | 173 | print(train_split) 174 | print(valid_split) 175 | 176 | train_part = train_split[["timestamp", "user_id", "content_id", "part", "answered_correctly", 177 | "content_type_id", "prior_question_elapsed_time", "lag_time_s", "lag_time_m", 178 | "lag_time_d", "prior_question_had_explanation"]].groupby("user_id").progress_apply(group_func) 179 | valid_part = valid_split[["timestamp", "user_id", "content_id", "part", "answered_correctly", 180 | "content_type_id", "prior_question_elapsed_time", "lag_time_s", "lag_time_m", 181 | "lag_time_d", "prior_question_had_explanation"]].groupby("user_id").progress_apply(group_func) 182 | print(train_part.shape) 183 | print(valid_part.shape) 184 | 185 | # if SAVE_DATA_TO_CACHE: 186 | train_part.to_pickle(f"{args.output}.train") 187 | valid_part.to_pickle(f"{args.output}.valid") 188 | 189 | if args.irt: 190 | from pyirt import irt 191 | import pickle 192 | 193 | print("Start to use IRT model to estimate parameters") 194 | irt_src = [] 195 | for user_id, (e_id, _, answer, _, _, _, _, _) in train_part.items(): 196 | for item_id, ans in zip(e_id, answer): 197 | irt_src.append((user_id, item_id, ans - 2)) 198 | 199 | item_param, user_param = irt(irt_src, theta_bnds=[-3, 3], max_iter=100) 200 | 201 | f = open(f"{args.output}.user", 'wb') 202 | pickle.dump(user_param, f) 203 | f.close() 204 | 205 | f = open(f"{args.output}.item", 'wb') 206 | pickle.dump(item_param, f) 207 | f.close() 208 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import os 3 | from pathlib import Path 4 | import random 5 | import time 6 | 7 | import numpy as np 8 | import torch 9 | import pickle 10 | 11 | from collections import OrderedDict 12 | 13 | import pytorch_lightning as pl 14 | 15 | 16 | @contextmanager 17 | def timer(message: str): 18 | print(f'[{message} start.]') 19 | t0 = time.time() 20 | yield 21 | elapsed_time = time.time() - t0 22 | print(f'[{message}] done in {elapsed_time / 60:.1f} min.') 23 | 24 | 25 | def set_seed(seed: int = 2021): 26 | 
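    # seed Python's random module, the hash seed, NumPy and PyTorch (CPU and CUDA),
    # and pin cuDNN to deterministic kernels so runs are reproducible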
random.seed(seed) 27 | os.environ["PYTHONHASHSEED"] = str(seed) 28 | np.random.seed(seed) 29 | torch.manual_seed(seed) 30 | torch.cuda.manual_seed(seed) 31 | torch.backends.cudnn.deterministic = True 32 | 33 | 34 | def read_pickle(filepath): 35 | with open(filepath, "rb") as f: 36 | return pickle.load(f) 37 | 38 | 39 | def load_model(filepath, device="cpu"): 40 | model_obj = torch.load(filepath, map_location=device) 41 | baseline = list(model_obj['callbacks'].values())[0]['best_model_score'].item() 42 | optimizer_state = model_obj['optimizer_states'] 43 | model_state = model_obj['state_dict'] 44 | return baseline, optimizer_state, model_state 45 | 46 | 47 | def remove_prefix_from_dict(dictionary, prefix): 48 | new_dict = OrderedDict() 49 | for k, v in dictionary.items(): 50 | name = k[len(prefix):] 51 | new_dict[name] = v 52 | return new_dict 53 | 54 | 55 | def dict_to_device(dictionary, device): 56 | for k, v in dictionary.items(): 57 | dictionary[k] = v.to(device) 58 | return dictionary 59 | 60 | 61 | def prob_topk(x, topk): 62 | values, indices = torch.topk(x, topk) 63 | probs = torch.zeros(x.shape, dtype=x.dtype, device=x.device).scatter_(1, indices, values) 64 | probs = torch.softmax(probs, dim=-1) 65 | return probs 66 | 67 | 68 | if __name__ == "__main__": 69 | import config 70 | model_path = config.PRETRAINED_MODEL 71 | load_model(filepath=model_path, device="cpu") --------------------------------------------------------------------------------