├── .gitignore ├── README.md ├── Session4Rec.ipynb ├── defaults.py ├── evaluation.py ├── model_session4rec.py ├── model_simple_rnn.py ├── requirements.txt └── scripts └── preprocess.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # data 104 | data/ 105 | 106 | # model 107 | checkpoint 108 | checkpoint-session/ 109 | model.ckpt.* 110 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # session4rec 2 | GRu4Rec in Tensorflow 3 | 4 | # Get data 5 | 6 | - http://2015.recsyschallenge.com/challenge.html 7 | - https://s3-eu-west-1.amazonaws.com/yc-rdata/yoochoose-data.7z 8 | 9 | # Preprocess data 10 | 11 | - change input and output directories in /scripts/preprocess.py 12 | - python ./scripts/preprocess.py 13 | 14 | # Results 15 | 16 | ## loss = top1, rnn-size = 100, epochs = 3 17 | - Recall@20: 0.4243443414791227 18 | - MRR@20: 0.10790868224055285 19 | 20 | # References 21 | - Main source: https://github.com/Songweiping/GRU4Rec_TensorFlow 22 | - The original implementation: https://github.com/hidasib/GRU4Rec 23 | -------------------------------------------------------------------------------- /Session4Rec.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "PATH_TO_TRAIN = './data/rsc15_train_full.txt'\n", 10 | "PATH_TO_TEST = './data/rsc15_test.txt'" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 4, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import pandas as pd\n", 20 | "import numpy as np" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 5, 26 | "metadata": {}, 27 | 
"outputs": [], 28 | "source": [ 29 | "data = pd.read_csv(PATH_TO_TRAIN, sep='\\t', dtype={'ItemId':np.int64})\n", 30 | "valid = pd.read_csv(PATH_TO_TEST, sep='\\t', dtype={'ItemId':np.int64})" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 6, 36 | "metadata": {}, 37 | "outputs": [ 38 | { 39 | "data": { 40 | "text/html": [ 41 | "
\n", 42 | "\n", 55 | "\n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | "
SessionIdItemIdTime
012145365021.396861e+09
112145365001.396861e+09
212145365061.396861e+09
312145775611.396861e+09
422146627421.396872e+09
\n", 97 | "
" 98 | ], 99 | "text/plain": [ 100 | " SessionId ItemId Time\n", 101 | "0 1 214536502 1.396861e+09\n", 102 | "1 1 214536500 1.396861e+09\n", 103 | "2 1 214536506 1.396861e+09\n", 104 | "3 1 214577561 1.396861e+09\n", 105 | "4 2 214662742 1.396872e+09" 106 | ] 107 | }, 108 | "execution_count": 6, 109 | "metadata": {}, 110 | "output_type": "execute_result" 111 | } 112 | ], 113 | "source": [ 114 | "data.head()" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 7, 120 | "metadata": {}, 121 | "outputs": [ 122 | { 123 | "data": { 124 | "text/html": [ 125 | "
\n", 126 | "\n", 139 | "\n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | "
SessionIdItemIdTime
0112650092145868051.411997e+09
1112650092145092601.411997e+09
2112650172148575471.412011e+09
3112650172148572681.412011e+09
4112650172148572601.412011e+09
\n", 181 | "
" 182 | ], 183 | "text/plain": [ 184 | " SessionId ItemId Time\n", 185 | "0 11265009 214586805 1.411997e+09\n", 186 | "1 11265009 214509260 1.411997e+09\n", 187 | "2 11265017 214857547 1.412011e+09\n", 188 | "3 11265017 214857268 1.412011e+09\n", 189 | "4 11265017 214857260 1.412011e+09" 190 | ] 191 | }, 192 | "execution_count": 7, 193 | "metadata": {}, 194 | "output_type": "execute_result" 195 | } 196 | ], 197 | "source": [ 198 | "valid.head()" 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "metadata": {}, 204 | "source": [ 205 | "## Training" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 9, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "layers = 1" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [] 223 | } 224 | ], 225 | "metadata": { 226 | "kernelspec": { 227 | "display_name": "Python 3", 228 | "language": "python", 229 | "name": "python3" 230 | }, 231 | "language_info": { 232 | "codemirror_mode": { 233 | "name": "ipython", 234 | "version": 3 235 | }, 236 | "file_extension": ".py", 237 | "mimetype": "text/x-python", 238 | "name": "python", 239 | "nbconvert_exporter": "python", 240 | "pygments_lexer": "ipython3", 241 | "version": "3.6.5" 242 | } 243 | }, 244 | "nbformat": 4, 245 | "nbformat_minor": 2 246 | } 247 | -------------------------------------------------------------------------------- /defaults.py: -------------------------------------------------------------------------------- 1 | class Defaults: 2 | is_training = True 3 | layers = 1 4 | rnn_size = 1000 5 | n_epochs = 3 6 | batch_size = 50 7 | dropout_p_hidden = 1 8 | learning_rate = 0.001 9 | decay = 0.96 10 | decay_steps = 1e4 11 | sigma = 0 12 | init_as_normal = False 13 | reset_after_session = True 14 | session_key = 'SessionId' 15 | item_key = 'ItemId' 16 | time_key = 'Time' 17 | grad_cap = 0 18 | test_model = 2 19 | checkpoint_dir = './checkpoint-session' 20 | loss = 'top1' 21 | final_act = 'softmax' 22 | hidden_act = 'tanh' 23 | n_items = -1 24 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Feb 27 2017 4 | Author: Weiping Song 5 | """ 6 | import numpy as np 7 | import pandas as pd 8 | 9 | 10 | def evaluate_sessions_batch(model, train_data, test_data, cut_off=20, batch_size=50, session_key='SessionId', item_key='ItemId', time_key='Time'): 11 | 12 | ''' 13 | Evaluates the GRU4Rec network wrt. recommendation accuracy measured by recall@N and MRR@N. 14 | 15 | Parameters 16 | -------- 17 | model : A trained GRU4Rec model. 18 | train_data : It contains the transactions of the train set. In evaluation phrase, this is used to build item-to-id map. 19 | test_data : It contains the transactions of the test set. It has one column for session IDs, one for item IDs and one for the timestamp of the events (unix timestamps). 20 | cut-off : int 21 | Cut-off value (i.e. the length of the recommendation list; N for recall@N and MRR@N). Default value is 20. 22 | batch_size : int 23 | Number of events bundled into a batch during evaluation. Speeds up evaluation. If it is set high, the memory consumption increases. Default value is 100. 
24 | session_key : string 25 | Header of the session ID column in the input file (default: 'SessionId') 26 | item_key : string 27 | Header of the item ID column in the input file (default: 'ItemId') 28 | time_key : string 29 | Header of the timestamp column in the input file (default: 'Time') 30 | 31 | Returns 32 | -------- 33 | out : tuple 34 | (Recall@N, MRR@N) 35 | 36 | ''' 37 | model.predict = False 38 | # Build itemidmap from train data. 39 | itemids = train_data[item_key].unique() 40 | itemidmap = pd.Series(data=np.arange(len(itemids)), index=itemids) 41 | 42 | test_data.sort([session_key, time_key], inplace=True) 43 | offset_sessions = np.zeros(test_data[session_key].nunique()+1, dtype=np.int32) 44 | offset_sessions[1:] = test_data.groupby(session_key).size().cumsum() 45 | evalutation_point_count = 0 46 | mrr, recall = 0.0, 0.0 47 | if len(offset_sessions) - 1 < batch_size: 48 | batch_size = len(offset_sessions) - 1 49 | iters = np.arange(batch_size).astype(np.int32) 50 | maxiter = iters.max() 51 | start = offset_sessions[iters] 52 | end = offset_sessions[iters+1] 53 | in_idx = np.zeros(batch_size, dtype=np.int32) 54 | np.random.seed(42) 55 | while True: 56 | valid_mask = iters >= 0 57 | if valid_mask.sum() == 0: 58 | break 59 | start_valid = start[valid_mask] 60 | minlen = (end[valid_mask]-start_valid).min() 61 | in_idx[valid_mask] = test_data[item_key].values[start_valid] 62 | for i in range(minlen-1): 63 | out_idx = test_data[item_key].values[start_valid+i+1] 64 | preds = model.predict_next_batch(iters, in_idx, itemidmap, batch_size) 65 | preds.fillna(0, inplace=True) 66 | in_idx[valid_mask] = out_idx 67 | ranks = (preds.values.T[valid_mask].T > np.diag(preds.ix[in_idx].values)[valid_mask]).sum(axis=0) + 1 68 | rank_ok = ranks < cut_off 69 | recall += rank_ok.sum() 70 | mrr += (1.0 / ranks[rank_ok]).sum() 71 | evalutation_point_count += len(ranks) 72 | start = start+minlen-1 73 | mask = np.arange(len(iters))[(valid_mask) & (end-start<=1)] 74 | for idx in mask: 75 | maxiter += 1 76 | if maxiter >= len(offset_sessions)-1: 77 | iters[idx] = -1 78 | else: 79 | iters[idx] = maxiter 80 | start[idx] = offset_sessions[maxiter] 81 | end[idx] = offset_sessions[maxiter+1] 82 | return recall/evalutation_point_count, mrr/evalutation_point_count 83 | -------------------------------------------------------------------------------- /model_session4rec.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import pandas as pd 4 | import os 5 | from tensorflow.python.ops import rnn, rnn_cell 6 | 7 | from defaults import Defaults 8 | import evaluation 9 | 10 | PATH_TO_TRAIN = './data/rsc15_train_full.txt' 11 | PATH_TO_TEST = './data/rsc15_test.txt' 12 | 13 | class Session4RecPredictor: 14 | def __init__(self, defaults, session): 15 | self.sess = session 16 | self.is_training = defaults.is_training 17 | 18 | self.layers = defaults.layers 19 | self.rnn_size = defaults.rnn_size 20 | self.n_epochs = defaults.n_epochs 21 | self.batch_size = defaults.batch_size 22 | self.dropout_p_hidden = defaults.dropout_p_hidden 23 | self.learning_rate = defaults.learning_rate 24 | self.decay = defaults.decay 25 | self.decay_steps = defaults.decay_steps 26 | self.sigma = defaults.sigma 27 | self.init_as_normal = defaults.init_as_normal 28 | self.reset_after_session = defaults.reset_after_session 29 | self.session_key = defaults.session_key 30 | self.item_key = defaults.item_key 31 | self.time_key = defaults.time_key 32 | self.grad_cap 
= defaults.grad_cap 33 | self.n_items = defaults.n_items 34 | if defaults.hidden_act == 'tanh': 35 | self.hidden_act = self.tanh 36 | elif defaults.hidden_act == 'relu': 37 | self.hidden_act = self.relu 38 | else: 39 | raise NotImplementedError 40 | 41 | if defaults.loss == 'cross-entropy': 42 | if defaults.final_act == 'tanh': 43 | self.final_activation = self.softmaxth 44 | else: 45 | self.final_activation = self.softmax 46 | self.loss_function = self.cross_entropy 47 | elif defaults.loss == 'bpr': 48 | if defaults.final_act == 'linear': 49 | self.final_activation = self.linear 50 | elif defaults.final_act == 'relu': 51 | self.final_activation = self.relu 52 | else: 53 | self.final_activation = self.tanh 54 | self.loss_function = self.bpr 55 | elif defaults.loss == 'top1': 56 | if defaults.final_act == 'linear': 57 | self.final_activation = self.linear 58 | elif defaults.final_act == 'relu': 59 | self.final_activatin = self.relu 60 | else: 61 | self.final_activation = self.tanh 62 | self.loss_function = self.top1 63 | else: 64 | raise NotImplementedError 65 | 66 | self.checkpoint_dir = defaults.checkpoint_dir 67 | if not os.path.isdir(self.checkpoint_dir): 68 | raise Exception("[!] Checkpoint Dir not found") 69 | 70 | self.model() 71 | self.sess.run(tf.global_variables_initializer()) 72 | self.saver = tf.train.Saver(tf.global_variables(), max_to_keep = 10) 73 | 74 | if self.is_training: 75 | return 76 | 77 | # use self.predict_state to hold hidden states during prediction. 78 | self.predict_state = [np.zeros([self.batch_size, self.rnn_size], dtype = np.float32) for _ in range(self.layers)] 79 | ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir) 80 | 81 | if ckpt and ckpt.model_checkpoint_path: 82 | self.saver.restore(self.sess, '{}/model.ckpt-{}'.format(self.checkpoint_dir, defaults.test_model)) 83 | 84 | # activation 85 | 86 | def linear(self, X): 87 | return X 88 | 89 | def tanh(self, X): 90 | return tf.nn.tanh(X) 91 | 92 | def softmax(self, X): 93 | return tf.nn.softmax(X) 94 | 95 | def softmaxth(self, X): 96 | return tf.nn.softmax(tf.tanh(X)) 97 | 98 | def relu(self, X): 99 | return tf.nn.relu(X) 100 | 101 | def sigmoid(self, X): 102 | return tf.nn.sigmoid(X) 103 | 104 | # loss 105 | def cross_entropy(self, yhat): 106 | return tf.reduce_mean(-tf.log(tf.diag_part(yhat)+1e-24)) 107 | 108 | def bpr(self, yhat): 109 | yhatT = tf.transpose(yhat) 110 | return tf.reduce_mean(-tf.log(tf.nn.sigmoid(tf.diag_part(yhat)-yhatT))) 111 | 112 | def top1(self, yhat): 113 | yhatT = tf.transpose(yhat) 114 | term1 = tf.reduce_mean(tf.nn.sigmoid(-tf.diag_part(yhat)+yhatT)+tf.nn.sigmoid(yhatT**2), axis=0) 115 | term2 = tf.nn.sigmoid(tf.diag_part(yhat)**2) / self.batch_size 116 | return tf.reduce_mean(term1 - term2) 117 | 118 | def model(self): 119 | self.X = tf.placeholder(tf.int32, [self.batch_size], name='input') 120 | self.Y = tf.placeholder(tf.int32, [self.batch_size], name='output') 121 | self.state = [tf.placeholder(tf.float32, [self.batch_size, self.rnn_size], name='rnn_state') for _ in range(self.layers)] 122 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 123 | 124 | with tf.variable_scope('gru_layer'): 125 | sigma = self.sigma if self.sigma != 0 else np.sqrt(6.0 / (self.n_items + self.rnn_size)) 126 | if self.init_as_normal: 127 | initializer = tf.random_normal_initializer(mean=0, stddev=sigma) 128 | else: 129 | initializer = tf.random_uniform_initializer(minval=-sigma, maxval=sigma) 130 | embedding = tf.get_variable('embedding', [self.n_items, self.rnn_size], 
initializer=initializer) 131 | softmax_W = tf.get_variable('softmax_w', [self.n_items, self.rnn_size], initializer=initializer) 132 | softmax_b = tf.get_variable('softmax_b', [self.n_items], initializer=tf.constant_initializer(0.0)) 133 | 134 | cell = rnn_cell.GRUCell(self.rnn_size, activation=self.hidden_act) 135 | drop_cell = rnn_cell.DropoutWrapper(cell, output_keep_prob=self.dropout_p_hidden) 136 | stacked_cell = rnn_cell.MultiRNNCell([drop_cell] * self.layers) 137 | 138 | inputs = tf.nn.embedding_lookup(embedding, self.X) 139 | output, state = stacked_cell(inputs, tuple(self.state)) 140 | self.final_state = state 141 | 142 | if self.is_training: 143 | ''' 144 | Use other examples of the minibatch as negative samples. 145 | ''' 146 | sampled_W = tf.nn.embedding_lookup(softmax_W, self.Y) 147 | sampled_b = tf.nn.embedding_lookup(softmax_b, self.Y) 148 | logits = tf.matmul(output, sampled_W, transpose_b=True) + sampled_b 149 | self.yhat = self.final_activation(logits) 150 | self.cost = self.loss_function(self.yhat) 151 | else: 152 | logits = tf.matmul(output, softmax_W, transpose_b=True) + softmax_b 153 | self.yhat = self.final_activation(logits) 154 | 155 | if not self.is_training: 156 | return 157 | 158 | self.lr = tf.maximum(1e-5,tf.train.exponential_decay(self.learning_rate, self.global_step, self.decay_steps, self.decay, staircase=True)) 159 | 160 | ''' 161 | Try different optimizers. 162 | ''' 163 | #optimizer = tf.train.AdagradOptimizer(self.lr) 164 | optimizer = tf.train.AdamOptimizer(self.lr) 165 | #optimizer = tf.train.AdadeltaOptimizer(self.lr) 166 | #optimizer = tf.train.RMSPropOptimizer(self.lr) 167 | 168 | tvars = tf.trainable_variables() 169 | gvs = optimizer.compute_gradients(self.cost, tvars) 170 | if self.grad_cap > 0: 171 | capped_gvs = [(tf.clip_by_norm(grad, self.grad_cap), var) for grad, var in gvs] 172 | else: 173 | capped_gvs = gvs 174 | self.train_op = optimizer.apply_gradients(capped_gvs, global_step = self.global_step) 175 | 176 | def init(self, data): 177 | data.sort([self.session_key, self.time_key], inplace = True) 178 | offset_sessions = np.zeros(data[self.session_key].nunique() + 1, dtype = np.int32) 179 | offset_sessions[1:] = data.groupby(self.session_key).size().cumsum() 180 | return offset_sessions 181 | 182 | def train(self, data): 183 | self.error_during_train = False 184 | itemids = data[self.item_key].unique() 185 | 186 | self.n_items = len(itemids) 187 | self.itemidmap = pd.Series(data=np.arange(self.n_items), index=itemids) 188 | 189 | data = pd.merge(data, pd.DataFrame({self.item_key:itemids, 'ItemIdx':self.itemidmap[itemids].values}), on=self.item_key, how='inner') 190 | 191 | offset_sessions = self.init(data) 192 | 193 | print('training model...') 194 | 195 | for epoch in range(self.n_epochs): 196 | print('training epoch: {}'.format(epoch)) 197 | 198 | epoch_cost = [] 199 | state = [np.zeros([self.batch_size, self.rnn_size], dtype=np.float32) for _ in range(self.layers)] 200 | session_idx_arr = np.arange(len(offset_sessions)-1) 201 | iters = np.arange(self.batch_size) 202 | maxiter = iters.max() 203 | start = offset_sessions[session_idx_arr[iters]] 204 | end = offset_sessions[session_idx_arr[iters]+1] 205 | finished = False 206 | 207 | while not finished: 208 | minlen = (end-start).min() 209 | out_idx = data.ItemIdx.values[start] 210 | for i in range(minlen-1): 211 | in_idx = out_idx 212 | out_idx = data.ItemIdx.values[start+i+1] 213 | # prepare inputs, targeted outputs and hidden states 214 | fetches = [self.cost, self.final_state, 
self.global_step, self.lr, self.train_op] 215 | feed_dict = {self.X: in_idx, self.Y: out_idx} 216 | for j in range(self.layers): 217 | feed_dict[self.state[j]] = state[j] 218 | 219 | cost, state, step, lr, _ = self.sess.run(fetches, feed_dict) 220 | epoch_cost.append(cost) 221 | if np.isnan(cost): 222 | print(str(epoch) + ':Nan error!') 223 | self.error_during_train = True 224 | return 225 | if step == 1 or step % self.decay_steps == 0: 226 | avgc = np.mean(epoch_cost) 227 | print('Epoch {}\tStep {}\tlr: {:.6f}\tloss: {:.6f}'.format(epoch, step, lr, avgc)) 228 | start = start+minlen-1 229 | mask = np.arange(len(iters))[(end-start)<=1] 230 | for idx in mask: 231 | maxiter += 1 232 | if maxiter >= len(offset_sessions)-1: 233 | finished = True 234 | break 235 | iters[idx] = maxiter 236 | start[idx] = offset_sessions[session_idx_arr[maxiter]] 237 | end[idx] = offset_sessions[session_idx_arr[maxiter]+1] 238 | if len(mask) and self.reset_after_session: 239 | for i in range(self.layers): 240 | state[i][mask] = 0 241 | 242 | avgc = np.mean(epoch_cost) 243 | if np.isnan(avgc): 244 | print('Epoch {}: Nan error!'.format(epoch, avgc)) 245 | self.error_during_train = True 246 | return 247 | 248 | save_path = self.saver.save(self.sess, '{}/model.ckpt'.format(self.checkpoint_dir), global_step = epoch) 249 | print('Model saved to {}'.format(save_path)) 250 | 251 | def predict_next_batch(self, session_ids, input_item_ids, itemidmap, batch=50): 252 | ''' 253 | Gives predicton scores for a selected set of items. Can be used in batch mode to predict for multiple independent events (i.e. events of different sessions) at once and thus speed up evaluation. 254 | 255 | If the session ID at a given coordinate of the session_ids parameter remains the same during subsequent calls of the function, the corresponding hidden state of the network will be kept intact (i.e. that's how one can predict an item to a session). 256 | If it changes, the hidden state of the network is reset to zeros. 257 | 258 | Parameters 259 | -------- 260 | session_ids : 1D array 261 | Contains the session IDs of the events of the batch. Its length must equal to the prediction batch size (batch param). 262 | input_item_ids : 1D array 263 | Contains the item IDs of the events of the batch. Every item ID must be must be in the training data of the network. Its length must equal to the prediction batch size (batch param). 264 | batch : int 265 | Prediction batch size. 266 | 267 | Returns 268 | -------- 269 | out : pandas.DataFrame 270 | Prediction scores for selected items for every event of the batch. 271 | Columns: events of the batch; rows: items. Rows are indexed by the item IDs. 
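Example
--------
A minimal usage sketch, assuming the model was built with is_training = False, its checkpoint restored, and itemidmap is the item-ID-to-index pandas Series built from the training data (this is how evaluation.evaluate_sessions_batch drives the method):

    preds = model.predict_next_batch(session_ids, input_item_ids, itemidmap, batch=50)
    # preds: DataFrame of shape (n_items, batch); column j holds next-item scores for event j
    top_items = preds.idxmax(axis=0)  # highest-scoring next-item ID per event of the batch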
272 | 273 | ''' 274 | if batch != self.batch_size: 275 | raise Exception('Predict batch size({}) must match train batch size({})'.format(batch, self.batch_size)) 276 | if not self.predict: 277 | self.current_session = np.ones(batch) * -1 278 | self.predict = True 279 | 280 | session_change = np.arange(batch)[session_ids != self.current_session] 281 | if len(session_change) > 0: # change internal states with session changes 282 | for i in range(self.layers): 283 | self.predict_state[i][session_change] = 0.0 284 | self.current_session=session_ids.copy() 285 | 286 | in_idxs = itemidmap[input_item_ids] 287 | fetches = [self.yhat, self.final_state] 288 | feed_dict = {self.X: in_idxs} 289 | for i in range(self.layers): 290 | feed_dict[self.state[i]] = self.predict_state[i] 291 | preds, self.predict_state = self.sess.run(fetches, feed_dict) 292 | preds = np.asarray(preds).T 293 | return pd.DataFrame(data=preds, index=itemidmap.index) 294 | 295 | if __name__ == '__main__': 296 | defaults = Defaults() 297 | 298 | data = pd.read_csv(PATH_TO_TRAIN, sep='\t', dtype={'ItemId': np.int64}) 299 | valid = pd.read_csv(PATH_TO_TEST, sep='\t', dtype={'ItemId': np.int64}) 300 | 301 | defaults.n_items = len(data['ItemId'].unique()) 302 | defaults.dropout_p_hidden = 1.0 if defaults.is_training == 0 else 0.5 303 | 304 | if not os.path.exists(defaults.checkpoint_dir): 305 | os.mkdir(defaults.checkpoint_dir) 306 | 307 | gpu_config = tf.ConfigProto() 308 | gpu_config.gpu_options.allow_growth = True 309 | 310 | with tf.Session(config = gpu_config) as session: 311 | predictor = Session4RecPredictor(defaults, session) 312 | 313 | if defaults.is_training: 314 | print('Start session4rec training...') 315 | predictor.train(data) 316 | else: 317 | res = evaluation.evaluate_sessions_batch(predictor, data, valid) 318 | print('Recall@20: {}\tMRR@20: {}'.format(res[0], res[1])) 319 | -------------------------------------------------------------------------------- /model_simple_rnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.python.ops import rnn, rnn_cell 4 | 5 | class SeriesPredictor: 6 | def __init__(self, input_dim, seq_size, hidden_dim=10): 7 | # Hyperparameters 8 | self.input_dim = input_dim 9 | self.seq_size = seq_size 10 | self.hidden_dim = hidden_dim 11 | 12 | # Weight variables and input placeholders 13 | self.W_out = tf.Variable(tf.random_normal([hidden_dim, 1]), name='W_out') 14 | self.b_out = tf.Variable(tf.random_normal([1]), name='b_out') 15 | self.x = tf.placeholder(tf.float32, [None, seq_size, input_dim]) 16 | self.y = tf.placeholder(tf.float32, [None, seq_size]) 17 | 18 | # Cost optimizer 19 | self.cost = tf.reduce_mean(tf.square(self.model() - self.y)) 20 | self.train_op = tf.train.AdamOptimizer().minimize(self.cost) 21 | 22 | # Auxiliary ops 23 | self.saver = tf.train.Saver() 24 | 25 | def model(self): 26 | """ 27 | :param x: inputs of size [T, batch_size, input_size] 28 | :param W: matrix of fully-connected output layer weights 29 | :param b: vector of fully-connected output layer biases 30 | """ 31 | cell = rnn_cell.BasicLSTMCell(self.hidden_dim, reuse=tf.get_variable_scope().reuse) 32 | outputs, states = rnn.dynamic_rnn(cell, self.x, dtype=tf.float32) 33 | num_examples = tf.shape(self.x)[0] 34 | W_repeated = tf.tile(tf.expand_dims(self.W_out, 0), [num_examples, 1, 1]) 35 | out = tf.matmul(outputs, W_repeated) + self.b_out 36 | out = tf.squeeze(out) 37 | return out 38 | 39 | def train(self, 
train_x, train_y): 40 | with tf.Session() as sess: 41 | tf.get_variable_scope().reuse_variables() 42 | sess.run(tf.initialize_all_variables()) 43 | for i in range(1000): 44 | _, mse = sess.run([self.train_op, self.cost], feed_dict={self.x: train_x, self.y: train_y}) 45 | if i % 100 == 0: 46 | print(i, mse) 47 | save_path = self.saver.save(sess, './model.ckpt') 48 | print('Model saved to {}'.format(save_path)) 49 | 50 | def test(self, test_x): 51 | with tf.Session() as sess: 52 | tf.get_variable_scope().reuse_variables() 53 | self.saver.restore(sess, './model.ckpt') 54 | output = sess.run(self.model(), feed_dict={self.x: test_x}) 55 | print(output) 56 | 57 | if __name__ == '__main__': 58 | predictor = SeriesPredictor(input_dim=1, seq_size=4, hidden_dim=10) 59 | train_x = [[[1], [2], [5], [6]], 60 | [[5], [7], [7], [8]], 61 | [[3], [4], [5], [7]]] 62 | train_y = [[1, 3, 7, 11], 63 | [5, 12, 14, 15], 64 | [3, 7, 9, 12]] 65 | predictor.train(train_x, train_y) 66 | 67 | test_x = [[[1], [2], [3], [4]], # => prediction should be 1, 3, 5, 7 68 | [[4], [5], [6], [7]]] # => prediction should be 4, 9, 11, 13 69 | predictor.test(test_x) 70 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas==0.19.2 2 | tensorflow==1.6.0 3 | -------------------------------------------------------------------------------- /scripts/preprocess.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Jun 25 16:20:12 2015 4 | 5 | @author: Balázs Hidasi 6 | """ 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import datetime as dt 11 | 12 | PATH_TO_ORIGINAL_DATA = '/path/to/original/data/' 13 | PATH_TO_PROCESSED_DATA = '/path/to/processed/data/' 14 | 15 | print('Loading data from: {}\nOutput directory: {}'.format(PATH_TO_ORIGINAL_DATA, PATH_TO_PROCESSED_DATA)) 16 | 17 | data = pd.read_csv(PATH_TO_ORIGINAL_DATA + 'yoochoose-clicks.dat', sep=',', header=None, usecols=[0,1,2], dtype={0:np.int32, 1:str, 2:np.int64}) 18 | data.columns = ['SessionId', 'TimeStr', 'ItemId'] 19 | data['Time'] = data.TimeStr.apply(lambda x: dt.datetime.strptime(x, '%Y-%m-%dT%H:%M:%S.%fZ').timestamp()) #This is not UTC. It does not really matter. 
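# For illustration, a TimeStr such as '2014-04-07T10:51:09.277Z' matches the '%Y-%m-%dT%H:%M:%S.%fZ'
# pattern above and is parsed as a naive local-time datetime, so .timestamp() yields fractional
# epoch seconds (a float around 1.39e+09), which is what the Time column used below contains.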
20 | del(data['TimeStr']) 21 | 22 | session_lengths = data.groupby('SessionId').size() 23 | data = data[np.in1d(data.SessionId, session_lengths[session_lengths>1].index)] 24 | item_supports = data.groupby('ItemId').size() 25 | data = data[np.in1d(data.ItemId, item_supports[item_supports>=5].index)] 26 | session_lengths = data.groupby('SessionId').size() 27 | data = data[np.in1d(data.SessionId, session_lengths[session_lengths>=2].index)] 28 | 29 | tmax = data.Time.max() 30 | session_max_times = data.groupby('SessionId').Time.max() 31 | session_train = session_max_times[session_max_times < tmax-86400].index 32 | session_test = session_max_times[session_max_times >= tmax-86400].index 33 | train = data[np.in1d(data.SessionId, session_train)] 34 | test = data[np.in1d(data.SessionId, session_test)] 35 | test = test[np.in1d(test.ItemId, train.ItemId)] 36 | tslength = test.groupby('SessionId').size() 37 | test = test[np.in1d(test.SessionId, tslength[tslength>=2].index)] 38 | print('Full train set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format(len(train), train.SessionId.nunique(), train.ItemId.nunique())) 39 | train.to_csv(PATH_TO_PROCESSED_DATA + 'rsc15_train_full.txt', sep='\t', index=False) 40 | print('Test set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format(len(test), test.SessionId.nunique(), test.ItemId.nunique())) 41 | test.to_csv(PATH_TO_PROCESSED_DATA + 'rsc15_test.txt', sep='\t', index=False) 42 | 43 | tmax = train.Time.max() 44 | session_max_times = train.groupby('SessionId').Time.max() 45 | session_train = session_max_times[session_max_times < tmax-86400].index 46 | session_valid = session_max_times[session_max_times >= tmax-86400].index 47 | train_tr = train[np.in1d(train.SessionId, session_train)] 48 | valid = train[np.in1d(train.SessionId, session_valid)] 49 | valid = valid[np.in1d(valid.ItemId, train_tr.ItemId)] 50 | tslength = valid.groupby('SessionId').size() 51 | valid = valid[np.in1d(valid.SessionId, tslength[tslength>=2].index)] 52 | print('Train set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format(len(train_tr), train_tr.SessionId.nunique(), train_tr.ItemId.nunique())) 53 | train_tr.to_csv(PATH_TO_PROCESSED_DATA + 'rsc15_train_tr.txt', sep='\t', index=False) 54 | print('Validation set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format(len(valid), valid.SessionId.nunique(), valid.ItemId.nunique())) 55 | valid.to_csv(PATH_TO_PROCESSED_DATA + 'rsc15_train_valid.txt', sep='\t', index=False) 56 | --------------------------------------------------------------------------------
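Usage note (a sketch inferred from defaults.py and the __main__ block of model_session4rec.py; there are no command-line flags, so the switches below are edited by hand):

- Preprocess: set PATH_TO_ORIGINAL_DATA and PATH_TO_PROCESSED_DATA in scripts/preprocess.py, then run python ./scripts/preprocess.py.
- Train: with Defaults.is_training = True (the default), run python model_session4rec.py; checkpoints are saved to ./checkpoint-session as model.ckpt-<epoch>.
- Evaluate: set Defaults.is_training = False and test_model to the epoch checkpoint to restore, then run python model_session4rec.py again; it prints Recall@20 and MRR@20 via evaluation.evaluate_sessions_batch, the metrics quoted in the README results.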