├── .gitignore
├── README.md
├── data
│   └── assistments.txt
├── requirements.txt
└── src
    ├── TensorFlowDKT.py
    ├── __init__.py
    ├── data_process.py
    └── train_dkt.py

/.gitignore:
--------------------------------------------------------------------------------
.idea/*
data/
*.pyc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Run this demo

From the `src` directory:

`$ python train_dkt.py --dataset ../data/assistments.txt`

# Blog

https://www.cnblogs.com/jinxulin/p/15729997.html
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
tensorflow>=1.0.0
scikit-learn>=0.17.1
--------------------------------------------------------------------------------
/src/TensorFlowDKT.py:
--------------------------------------------------------------------------------
# encoding:utf-8
import tensorflow as tf


class TensorFlowDKT(object):
    def __init__(self, config):
        self.hidden_neurons = hidden_neurons = config["hidden_neurons"]
        self.num_skills = num_skills = config["num_skills"]
        self.input_size = input_size = config["input_size"]
        self.batch_size = batch_size = config["batch_size"]
        self.keep_prob_value = config["keep_prob"]

        self.max_steps = tf.placeholder(tf.int32)  # max sequence length of the current batch
        self.input_data = tf.placeholder(tf.float32, [batch_size, None, input_size])
        self.sequence_len = tf.placeholder(tf.int32, [batch_size])
        self.keep_prob = tf.placeholder(tf.float32)  # dropout keep probability

        self.target_id = tf.placeholder(tf.int32, [batch_size, None])
        self.target_correctness = tf.placeholder(tf.float32, [batch_size, None])

        # create the (possibly stacked) LSTM cell, with dropout on the outputs
        hidden_layers = []
        for hidden_size in hidden_neurons:
            lstm_layer = tf.contrib.rnn.BasicLSTMCell(num_units=hidden_size, state_is_tuple=True)
            hidden_layer = tf.contrib.rnn.DropoutWrapper(cell=lstm_layer,
                                                         output_keep_prob=self.keep_prob)
            hidden_layers.append(hidden_layer)
        self.hidden_cell = tf.contrib.rnn.MultiRNNCell(cells=hidden_layers, state_is_tuple=True)

        # dynamic rnn over the padded input sequences
        state_series, self.current_state = tf.nn.dynamic_rnn(cell=self.hidden_cell,
                                                             inputs=self.input_data,
                                                             sequence_length=self.sequence_len,
                                                             dtype=tf.float32)

        # output layer: one logit per skill at every time step
        output_w = tf.get_variable("W", [hidden_neurons[-1], num_skills])
        output_b = tf.get_variable("b", [num_skills])
        self.state_series = tf.reshape(state_series, [batch_size * self.max_steps, hidden_neurons[-1]])
        self.logits = tf.matmul(self.state_series, output_w) + output_b
        self.mat_logits = tf.reshape(self.logits, [batch_size, self.max_steps, num_skills])
        self.pred_all = tf.sigmoid(self.mat_logits)

        # compute loss: for every time step, gather the logit of the skill
        # actually attempted at the next step, then apply a sigmoid
        # cross-entropy against its correctness
        flat_logits = tf.reshape(self.logits, [-1])
        flat_target_correctness = tf.reshape(self.target_correctness, [-1])
        flat_base_target_index = tf.range(batch_size * self.max_steps) * num_skills
        flat_bias_target_id = tf.reshape(self.target_id, [-1])
        flat_target_id = flat_bias_target_id + flat_base_target_index
        flat_target_logits = tf.gather(flat_logits, flat_target_id)
        self.pred = tf.sigmoid(tf.reshape(flat_target_logits, [batch_size, self.max_steps]))
        self.binary_pred = tf.cast(tf.greater_equal(self.pred, 0.5), tf.int32)
        self.loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=flat_target_correctness,
                                                                          logits=flat_target_logits))

        # SGD with gradient clipping; the learning rate is assigned from outside
        self.lr = tf.Variable(0.0, trainable=False)
        trainable_vars = tf.trainable_variables()
        self.grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, trainable_vars), 4)

        optimizer = tf.train.GradientDescentOptimizer(self.lr)
        self.train_op = optimizer.apply_gradients(zip(self.grads, trainable_vars))

    # run one batch: a training step, or a forward pass for evaluation
    def step(self, sess, input_x, target_id, target_correctness, sequence_len, is_train):
        _, max_steps, _ = input_x.shape
        input_feed = {self.input_data: input_x,
                      self.target_id: target_id,
                      self.target_correctness: target_correctness,
                      self.max_steps: max_steps,
                      self.sequence_len: sequence_len}

        if is_train:
            input_feed[self.keep_prob] = self.keep_prob_value
            train_loss, _, _ = sess.run([self.loss, self.train_op, self.current_state], input_feed)
            return train_loss
        else:
            input_feed[self.keep_prob] = 1  # disable dropout at test time
            bin_pred, pred, pred_all = sess.run([self.binary_pred, self.pred, self.pred_all], input_feed)
            return bin_pred, pred, pred_all

    def assign_lr(self, session, lr_value):
        session.run(tf.assign(self.lr, lr_value))
--------------------------------------------------------------------------------
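A note on the loss above: `target_id` and `target_correctness` are zero-padded,
so the summed cross-entropy also counts the padded time steps. A masked variant
is sketched below; it is a hypothetical modification, not something this repo
implements, reusing the `sequence_len` and `max_steps` placeholders already
defined in `__init__`:

    # hypothetical: keep only the valid (unpadded) steps in the loss
    mask = tf.reshape(tf.sequence_mask(self.sequence_len, self.max_steps), [-1])
    losses = tf.nn.sigmoid_cross_entropy_with_logits(labels=flat_target_correctness,
                                                     logits=flat_target_logits)
    self.loss = tf.reduce_sum(tf.boolean_mask(losses, mask))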
/src/__init__.py:
--------------------------------------------------------------------------------
(empty file)
--------------------------------------------------------------------------------
/src/data_process.py:
--------------------------------------------------------------------------------
import random

import numpy as np


def read_file(dataset_path):
    # each line: "student_id problem_id is_correct", whitespace separated
    seqs_by_student = {}
    num_skills = 0
    with open(dataset_path, 'r') as f:
        for line in f:
            fields = line.strip().split()
            student, problem, is_correct = int(fields[0]), int(fields[1]), int(fields[2])
            num_skills = max(num_skills, problem)
            seqs_by_student[student] = seqs_by_student.get(student, []) + [[problem, is_correct]]
    return seqs_by_student, num_skills + 1


def split_dataset(seqs_by_student, sample_rate=0.2, random_seed=1):
    sorted_keys = sorted(seqs_by_student.keys())
    random.seed(random_seed)
    test_keys = set(random.sample(sorted_keys, int(len(sorted_keys) * sample_rate)))
    test_seqs = [seqs_by_student[k] for k in seqs_by_student if k in test_keys]
    train_seqs = [seqs_by_student[k] for k in seqs_by_student if k not in test_keys]
    return train_seqs, test_seqs


def pad_sequences(sequences, maxlen=None, dtype='int32', padding='pre', truncating='pre', value=0.):
    lengths = [len(s) for s in sequences]
    nb_samples = len(sequences)
    if maxlen is None:
        maxlen = np.max(lengths)

    # take the sample shape from the first non-empty sequence,
    # checking for consistency in the main loop below
    sample_shape = tuple()
    for s in sequences:
        if len(s) > 0:
            sample_shape = np.asarray(s).shape[1:]
            break

    x = (np.ones((nb_samples, maxlen) + sample_shape) * value).astype(dtype)
    for idx, s in enumerate(sequences):
        if len(s) == 0:
            continue  # empty list was found
        if truncating == 'pre':
            trunc = s[-maxlen:]
        elif truncating == 'post':
            trunc = s[:maxlen]
        else:
            raise ValueError('Truncating type "%s" not understood' % truncating)

        # check `trunc` has the expected shape
        trunc = np.asarray(trunc, dtype=dtype)
        if trunc.shape[1:] != sample_shape:
            raise ValueError('Shape of sample %s of sequence at position %s is different from expected shape %s' %
                             (trunc.shape[1:], idx, sample_shape))
        if padding == 'post':
            x[idx, :len(trunc)] = trunc
        elif padding == 'pre':
            x[idx, -len(trunc):] = trunc
        else:
            raise ValueError('Padding type "%s" not understood' % padding)
    return x


def num_to_one_hot(num, dim):
    # negative ids (the padding value -1) map to an all-zero vector
    base = np.zeros(dim)
    if num >= 0:
        base[num] += 1
    return base


def format_data(seqs, batch_size, num_skills):
    gap = batch_size - len(seqs)
    seqs_in = seqs + [[[0, 0]]] * gap  # pad the batch itself to a fixed size
    seq_len = np.array([len(seq) for seq in seqs_in]) - 1
    max_len = max(seq_len)
    # input at step t encodes (problem, correctness) of step t as a single id;
    # the target is the problem id and correctness of step t + 1
    x = pad_sequences([[j[0] + num_skills * j[1] for j in i[:-1]] for i in seqs_in],
                      maxlen=max_len, padding='post', value=-1)
    input_x = np.array([[num_to_one_hot(j, num_skills * 2) for j in i] for i in x])
    target_id = pad_sequences([[j[0] for j in i[1:]] for i in seqs_in],
                              maxlen=max_len, padding='post', value=0)
    target_correctness = pad_sequences([[j[1] for j in i[1:]] for i in seqs_in],
                                       maxlen=max_len, padding='post', value=0)
    return input_x, target_id, target_correctness, seq_len, max_len


class DataGenerator(object):
    def __init__(self, seqs, batch_size, num_skills):
        self.seqs = seqs
        self.batch_size = batch_size
        self.pos = 0
        self.end = False
        self.size = len(seqs)
        self.num_skills = num_skills

    def next_batch(self):
        batch_size = self.batch_size
        if self.pos + batch_size < self.size:
            batch_seqs = self.seqs[self.pos:self.pos + batch_size]
            self.pos += batch_size
        else:
            batch_seqs = self.seqs[self.pos:]
            self.pos = self.size - 1
        if self.pos >= self.size - 1:
            self.end = True
        input_x, target_id, target_correctness, seqs_len, max_len = format_data(batch_seqs, batch_size,
                                                                                self.num_skills)
        return input_x, target_id, target_correctness, seqs_len, max_len

    def shuffle(self):
        self.pos = 0
        self.end = False
        np.random.shuffle(self.seqs)

    def reset(self):
        self.pos = 0
        self.end = False
--------------------------------------------------------------------------------
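To make the encoding in `format_data` concrete, here is a toy walkthrough (the
numbers are illustrative only): with `num_skills = 4`, a step where a student
answers problem 2 correctly is encoded as id `2 + 4 * 1 = 6`, i.e. a one-hot
vector of length 8 with position 6 set.

    import numpy as np
    from data_process import format_data

    # two hypothetical students: [problem_id, is_correct] per step
    seqs = [[[0, 1], [2, 0], [2, 1]],
            [[1, 0], [3, 1]]]
    input_x, target_id, target_correctness, seq_len, max_len = format_data(seqs, 2, 4)
    print(input_x.shape)          # (2, 2, 8): batch, max_len, num_skills * 2
    print(target_id[0])           # [2 2]: problems attempted at steps 1..2
    print(target_correctness[0])  # [0 1]
    print(seq_len)                # [2 1]: steps per student, minus one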
/src/train_dkt.py:
--------------------------------------------------------------------------------
from __future__ import print_function

import argparse
import sys
import time

import numpy as np
import tensorflow as tf
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score

from data_process import DataGenerator, read_file, split_dataset
from TensorFlowDKT import TensorFlowDKT


def run(args):
    # process data
    seqs_by_student, num_skills = read_file(args.dataset)
    train_seqs, test_seqs = split_dataset(seqs_by_student)
    batch_size = 10
    train_generator = DataGenerator(train_seqs, batch_size=batch_size, num_skills=num_skills)
    test_generator = DataGenerator(test_seqs, batch_size=batch_size, num_skills=num_skills)

    # config and create model
    config = {"hidden_neurons": [200],
              "batch_size": batch_size,
              "keep_prob": 0.6,
              "num_skills": num_skills,
              "input_size": num_skills * 2}
    model = TensorFlowDKT(config)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    lr = 0.4
    lr_decay = 0.92
    # run epochs
    for epoch in range(10):
        # train with an exponentially decayed learning rate
        model.assign_lr(sess, lr * lr_decay ** epoch)
        overall_loss = 0
        train_generator.shuffle()
        st = time.time()
        while not train_generator.end:
            input_x, target_id, target_correctness, seqs_len, max_len = train_generator.next_batch()
            overall_loss += model.step(sess, input_x, target_id, target_correctness, seqs_len, is_train=True)
            print("\r idx:{0}, overall_loss:{1}, time spent:{2}s".format(
                train_generator.pos, overall_loss, time.time() - st), end='')
            sys.stdout.flush()

        # test: collect per-step predictions, keeping only the valid
        # (unpadded) steps of every sequence
        test_generator.reset()
        preds, binary_preds, targets = list(), list(), list()
        while not test_generator.end:
            input_x, target_id, target_correctness, seqs_len, max_len = test_generator.next_batch()
            binary_pred, pred, _ = model.step(sess, input_x, target_id, target_correctness, seqs_len, is_train=False)
            for seq_idx, seq_len in enumerate(seqs_len):
                preds.append(pred[seq_idx, 0:seq_len])
                binary_preds.append(binary_pred[seq_idx, 0:seq_len])
                targets.append(target_correctness[seq_idx, 0:seq_len])
        # compute metrics over all valid steps
        preds = np.concatenate(preds)
        binary_preds = np.concatenate(binary_preds)
        targets = np.concatenate(targets)
        auc_value = roc_auc_score(targets, preds)
        accuracy = accuracy_score(targets, binary_preds)
        precision, recall, f_score, _ = precision_recall_fscore_support(targets, binary_preds, average='binary')
        print("\n auc={0}, accuracy={1}, precision={2}, recall={3}".format(auc_value, accuracy, precision, recall))


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(description="train dkt model")
    arg_parser.add_argument("--dataset", dest="dataset", required=True)
    args = arg_parser.parse_args()
    run(args)
--------------------------------------------------------------------------------
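The expected `--dataset` file format follows from `read_file`: one interaction
per line, three whitespace-separated integers, `student_id problem_id
is_correct`. A hypothetical excerpt (values are illustrative only):

    0 12 1
    0 7 0
    0 7 1
    1 3 1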
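Given a trained `model`, an open `sess`, and one batch from `DataGenerator`,
per-skill predictions can be read from `pred_all`. A hypothetical inference
sketch, not part of the repo:

    # pred_all[s, t, k] estimates the probability that student s answers
    # skill k correctly at step t + 1
    binary_pred, pred, pred_all = model.step(sess, input_x, target_id,
                                             target_correctness, seqs_len,
                                             is_train=False)
    last_step = seqs_len[0] - 1       # last valid prediction step, first student
    mastery = pred_all[0, last_step]  # shape: (num_skills,)
    print("predicted mastery per skill:", mastery)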