├── .gitignore
├── .theanorc_gru4rec
├── README.md
├── baselines.py
├── custom_opt.py
├── custom_theano_ops.py
├── datatools.py
├── evaluation.py
├── examples
│   └── rsc15
│       ├── preprocess.py
│       └── run_rsc15.py
├── gpu_ops.py
├── gru4rec.py
├── img
│   ├── training_time_bprmax_batch_size.png
│   ├── training_time_bprmax_layers.png
│   ├── training_time_public_data.png
│   ├── training_time_xe_batch_size.png
│   └── training_time_xe_layers.png
├── license.txt
├── param_samples
│   ├── rsc15_bpr-max.py
│   ├── rsc15_bpr-max_constrained.py
│   ├── rsc15_cross-entropy.py
│   ├── rsc15_cross-entropy_logq.py
│   └── rsc15_xe_logq.py
├── paramfiles
│   ├── coveo_bprmax_shared_best.py
│   ├── diginetica_bprmax_shared_best.py
│   ├── rees46_xe_shared_best.py
│   ├── retailrocket_bprmax_shared_best.py
│   ├── rsc15_xe_shared_100_best.py
│   └── yoochoose_xe_shared_best.py
├── paramspaces
│   ├── gru4rec_bprmax_standard_parspace.json
│   └── gru4rec_xe_standard_parspace.json
├── paropt.py
└── run.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__

--------------------------------------------------------------------------------
/.theanorc_gru4rec:
--------------------------------------------------------------------------------
#theanorc
[global]
device=cuda0
floatX=float32
allow_gc=False
mode=FAST_RUN
optimizer_excluding=local_dnn_reduction:local_cudnn_maxandargmax:local_dnn_argmax

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# GRU4Rec

This is the original Theano implementation of the algorithm of the paper ["Session-based Recommendations With Recurrent Neural Networks"](https://arxiv.org/abs/1511.06939 "Session-based Recommendations With Recurrent Neural Networks"), with the extensions introduced in the paper ["Recurrent Neural Networks with Top-k Gains for Session-based Recommendations"](https://arxiv.org/abs/1706.03847 "Recurrent Neural Networks with Top-k Gains for Session-based Recommendations").

Make sure to always use the latest version as baseline and cite both papers when you do so!

The code was optimized for fast execution on the GPU (up to 1500 mini-batches per second on a GTX 1080Ti). According to the Theano profiler, training spends 97.5% of the time on the GPU (0.5% on CPU and 2% moving data between the two). Running on the CPU is not supported, but it is possible with some modifications to the code.

If you are afraid of using Theano, the following official reimplementations are also available.
- [Official **PyTorch** version of GRU4Rec](https://github.com/hidasib/GRU4Rec_PyTorch_Official)
- [Official **Tensorflow** version of GRU4Rec](https://github.com/hidasib/GRU4Rec_Tensorflow_Official)

*NOTE:* These have been validated against the original, but due to how more modern deep learning frameworks operate, they are 1.5-4x slower than this version. Other reimplementations might be available in the future, depending on the research community's interest level.
**IMPORTANT!** Avoid using unofficial reimplementations. We thoroughly examined six third-party reimplementations (PyTorch/Tensorflow, standalone/framework) in ["The Effect of Third Party Implementations on Reproducibility"](https://arxiv.org/abs/2307.14956) and all of them were flawed and/or missed important features, which resulted in up to **99% lower recommendation accuracy** and up to **335 times longer training times**. Other reimplementations we have found since then are no better.

You can easily train and evaluate the model on your own session data using `run.py`. Usage information is below.

Scroll down for information on reproducing results on public datasets and on hyperparameter tuning!

**LICENSE:** See [license.txt](license.txt) for details. Main guidelines: for research and education purposes the code is and always will be free to use. Using the code or parts of it in commercial systems requires a license. If you've been using the code or any of its derivatives in a commercial system, contact me!

**CONTENTS:**
[Requirements](#requirements "Requirements")
  [Theano configuration](#theano-configuration "Theano configuration")
[Usage](#usage "Usage")
  [Execute experiments using `run.py`](#execute-experiments-using-runpy "Execute experiments using run.py")
    [Examples](#examples "Examples")
  [Using GRU4Rec in code or the interpreter](#using-gru4rec-in-code-or-the-interpreter "Using GRU4Rec in code or the interpreter")
  [Notes on sequence-aware and session-based models](#notes-on-sequence-aware-and-session-based-models "Notes on sequence-aware and session-based models")
  [Notes on parameter settings](#notes-on-parameter-settings "Notes on parameter settings")
[Speed of training](#speed-of-training "Speed of training")
[Reproducing results on public datasets](#reproducing-results-on-public-datasets "Reproducing results on public datasets")
[Hyperparameter tuning](#hyperparameter-tuning "Hyperparameter tuning")
[Executing on CPU](#executing-on-cpu "Executing on CPU")
[Major updates](#major-updates "Major updates")


## Requirements

- **python** --> Use python `3.6.3` or newer. The code was mostly tested on `3.6.3`, `3.7.6` and `3.8.12`, and was briefly tested on other versions. Python 2 is NOT supported.
- **numpy** --> `1.16.4` or newer.
- **pandas** --> `0.24.2` or newer.
- **CUDA** --> Needed for the GPU support of Theano. The latest CUDA version Theano was tested with (to the best of my knowledge) is `9.2`. It works fine with more recent versions, e.g. `11.8`.
- **libgpuarray** --> Required for the GPU support of Theano, use the latest version.
- **theano** --> `1.0.5` (last stable release) or newer (occasionally it still receives minor updates). GPU support should be installed.
- **optuna** --> (optional) for hyperparameter optimization, code was tested with `3.0.3`

**IMPORTANT: cuDNN** --> More recent versions produce a warning, but `8.2.1` still works for me. GRU4Rec doesn't rely heavily on the part of Theano that utilizes cuDNN. Unfortunately, `cudnnReduceTensor` in cuDNN `v7` and newer is seriously bugged*, which makes operators based on this function slow and even occasionally unstable (incorrect computations or segfault) when cuDNN is used (e.g. [see here](https://github.com/Theano/Theano/issues/6432)). Therefore it is best not to use cuDNN. If you already have it installed, you can easily configure Theano to exclude cuDNN based operators (see below).
*\*This bug is not related to Theano and can be reproduced from CUDA/C++. Unfortunately it hasn't been fixed for more than 6 years.*

### Theano configuration

This code was optimized for GPU execution. Executing the code will fail if you try to run it on the CPU (if you really want to mess with it, check out the relevant section of this readme). Therefore the Theano configuration must be set so that the GPU is used. If you use `run.py` for running experiments, the code sets this configuration for you. You might want to change some of the preset configuration (e.g. execute on a specific GPU instead of the one with the lowest ID). You can do this in the `THEANO_FLAGS` environment variable or by editing `.theanorc_gru4rec`.

If you don't use `run.py`, it is possible that the preset config won't have any effect (this happens if Theano is imported before `gru4rec`, either directly or by another module). In this case, you must set your own config by either editing your `.theanorc` or setting up the `THEANO_FLAGS` environment variable. Please refer to the [documentation of Theano](http://deeplearning.net/software/theano/library/config.html).

**Important config parameters**
- `device` --> must always be a CUDA capable GPU (e.g. `cuda0`).
- `floatX` --> must always be `float32`
- `mode` --> should be `FAST_RUN` for fast execution
- `optimizer_excluding` --> should be `local_dnn_reduction:local_cudnn_maxandargmax:local_dnn_argmax` to tell Theano not to use cuDNN based operators, because its `cudnnReduceTensor` function has been bugged since `v7`
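If you configure Theano from Python (e.g. in a notebook), a minimal sketch looks like this. The flag values are taken from `.theanorc_gru4rec`; the important part is that they must be set before Theano is first imported:

```python
import os

# Must be set before Theano is first imported, otherwise it has no effect.
os.environ['THEANO_FLAGS'] = (
    'device=cuda0,floatX=float32,mode=FAST_RUN,'
    'optimizer_excluding=local_dnn_reduction:local_cudnn_maxandargmax:local_dnn_argmax'
)

import gru4rec  # import gru4rec (or theano) only after the flags are set
```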
## Usage

### Execute experiments using `run.py`
`run.py` is an easy way to train, evaluate and save/load GRU4Rec models.

Execute with the `-h` argument to take a look at the parameters.
```
$ python run.py -h
```
Output:
```
usage: run.py [-h] [-ps PARAM_STRING] [-pf PARAM_PATH] [-l] [-s MODEL_PATH] [-t TEST_PATH [TEST_PATH ...]] [-m AT [AT ...]] [-e EVAL_TYPE] [-ss SS] [--sample_store_on_cpu] [-g GRFILE] [-d D] [-ik IK] [-sk SK] [-tk TK]
              [-pm METRIC] [-lpm]
              PATH

Train or load a GRU4Rec model & measure recall and MRR on the specified test set(s).

positional arguments:
  PATH                  Path to the training data (TAB separated file (.tsv or .txt) or pickled pandas.DataFrame object (.pickle)) (if the --load_model parameter is NOT provided) or to the serialized model (if the
                        --load_model parameter is provided).

optional arguments:
  -h, --help            show this help message and exit
  -ps PARAM_STRING, --parameter_string PARAM_STRING
                        Training parameters provided as a single parameter string. The format of the string is `param_name1=param_value1,param_name2=param_value2...`, e.g.:
                        `loss=bpr-max,layers=100,constrained_embedding=True`. Boolean training parameters should be either True or False; parameters that can take a list should use / as the separator
                        (e.g. layers=200/200). Mutually exclusive with the -pf (--parameter_file) and the -l (--load_model) arguments and one of the three must be provided.
  -pf PARAM_PATH, --parameter_file PARAM_PATH
                        Alternatively, training parameters can be set using a config file specified in this argument. The config file must contain a single OrderedDict named `gru4rec_params`. The parameters must have
                        the appropriate type (e.g. layers = [100]). Mutually exclusive with the -ps (--parameter_string) and the -l (--load_model) arguments and one of the three must be provided.
  -l, --load_model      Load an already trained model instead of training a model. Mutually exclusive with the -ps (--parameter_string) and the -pf (--parameter_file) arguments and one of the three must be provided.
  -s MODEL_PATH, --save_model MODEL_PATH
                        Save the trained model to the MODEL_PATH. (Default: don't save model)
  -t TEST_PATH [TEST_PATH ...], --test TEST_PATH [TEST_PATH ...]
                        Path to the test data set(s) located at TEST_PATH. Multiple test sets can be provided (separate with spaces). (Default: don't evaluate the model)
  -m AT [AT ...], --measure AT [AT ...]
                        Measure recall & MRR at the defined recommendation list length(s). Multiple values can be provided. (Default: 20)
  -e EVAL_TYPE, --eval_type EVAL_TYPE
                        Sets how to handle if multiple items in the ranked list have the same prediction score (which is usually due to saturation or an error). See the documentation of evaluate_gpu() in evaluation.py
                        for further details. (Default: standard)
  -ss SS, --sample_store_size SS
                        GRU4Rec uses a buffer for negative samples during training to maximize GPU utilization. This parameter sets the buffer length. Lower values require more frequent recomputation, higher values
                        use more (GPU) memory. Unless you know what you are doing, you shouldn't mess with this parameter. (Default: 10000000)
  --sample_store_on_cpu
                        If provided, the sample store will be stored in the RAM instead of the GPU memory. This is not advised in most cases, because it significantly lowers the GPU utilization. This option is
                        provided if for some reason you want to train the model on the CPU (NOT advised). Note that you need to make modifications to the code so that it is able to run on CPU.
  -g GRFILE, --gru4rec_model GRFILE
                        Name of the file containing the GRU4Rec class. Can be used to select different variants. (Default: gru4rec)
  -ik IK, --item_key IK
                        Column name corresponding to the item IDs (default: ItemId).
  -sk SK, --session_key SK
                        Column name corresponding to the session IDs (default: SessionId).
  -tk TK, --time_key TK
                        Column name corresponding to the timestamp (default: Time).
  -pm METRIC, --primary_metric METRIC
                        Set primary metric, recall or mrr (e.g. for paropt). (Default: recall)
  -lpm, --log_primary_metric
                        If provided, evaluation will log the value of the primary metric at the end of the run. Only works with one test file and list length.
```
#### Examples

Train, save and evaluate a model measuring recall and MRR at 1, 5, 10 and 20 using model parameters from a parameter string.
```
$ THEANO_FLAGS=device=cuda0 python run.py /path/to/training_data_file -t /path/to/test_data_file -m 1 5 10 20 -ps layers=224,batch_size=80,dropout_p_embed=0.5,dropout_p_hidden=0.05,learning_rate=0.05,momentum=0.4,n_sample=2048,sample_alpha=0.4,bpreg=1.95,logq=0.0,loss=bpr-max,constrained_embedding=True,final_act=elu-0.5,n_epochs=10 -s /path/to/save_model.pickle
```
Output (on the RetailRocket dataset):
```
Using cuDNN version 8201 on context None
Mapped name None to device cuda0: NVIDIA A30 (0000:3B:00.0)
Creating GRU4Rec model
SET layers TO [224] (type: <class 'list'>)
SET batch_size TO 80 (type: <class 'int'>)
SET dropout_p_embed TO 0.5 (type: <class 'float'>)
SET dropout_p_hidden TO 0.05 (type: <class 'float'>)
SET learning_rate TO 0.05 (type: <class 'float'>)
SET momentum TO 0.4 (type: <class 'float'>)
SET n_sample TO 2048 (type: <class 'int'>)
SET sample_alpha TO 0.4 (type: <class 'float'>)
SET bpreg TO 1.95 (type: <class 'float'>)
SET logq TO 0.0 (type: <class 'float'>)
SET loss TO bpr-max (type: <class 'str'>)
SET constrained_embedding TO True (type: <class 'bool'>)
SET final_act TO elu-0.5 (type: <class 'str'>)
SET n_epochs TO 10 (type: <class 'int'>)
Loading training data...
Loading data from TAB separated file: /path/to/training_data_file
Started training
The dataframe is already sorted by SessionId, Time
Created sample store with 4882 batches of samples (type=GPU)
Epoch1 --> loss: 0.484484 (6.81s) [1026.65 mb/s | 81386 e/s]
Epoch2 --> loss: 0.381974 (6.89s) [1015.39 mb/s | 80493 e/s]
Epoch3 --> loss: 0.353932 (6.81s) [1027.68 mb/s | 81468 e/s]
Epoch4 --> loss: 0.340034 (6.80s) [1028.90 mb/s | 81564 e/s]
Epoch5 --> loss: 0.330763 (6.80s) [1028.19 mb/s | 81508 e/s]
Epoch6 --> loss: 0.324075 (6.80s) [1029.36 mb/s | 81601 e/s]
Epoch7 --> loss: 0.319033 (6.85s) [1022.03 mb/s | 81020 e/s]
Epoch8 --> loss: 0.314915 (6.80s) [1029.05 mb/s | 81577 e/s]
Epoch9 --> loss: 0.311716 (6.82s) [1025.44 mb/s | 81290 e/s]
Epoch10 --> loss: 0.308915 (6.82s) [1025.64 mb/s | 81306 e/s]
Total training time: 77.73s
Saving trained model to: /path/to/save_model.pickle
Loading test data...
Loading data from TAB separated file: /path/to/test_data_file
Starting evaluation (cut-off=[1, 5, 10, 20], using standard mode for tiebreaking)
Measuring Recall@1,5,10,20 and MRR@1,5,10,20
Evaluation took 4.34s
Recall@1: 0.128055 MRR@1: 0.128055
Recall@5: 0.322165 MRR@5: 0.197492
Recall@20: 0.518184 MRR@20: 0.217481
```

Train on `cuda0` using parameters from a parameter file and save the model.
```
$ THEANO_FLAGS=device=cuda0 python run.py /path/to/training_data_file -pf /path/to/parameter_file.py -s /path/to/save_model.pickle
```
Output (on the RetailRocket dataset):
```
Using cuDNN version 8201 on context None
Mapped name None to device cuda0: NVIDIA A30 (0000:3B:00.0)
Creating GRU4Rec model
SET layers TO [224] (type: <class 'list'>)
SET batch_size TO 80 (type: <class 'int'>)
SET dropout_p_embed TO 0.5 (type: <class 'float'>)
SET dropout_p_hidden TO 0.05 (type: <class 'float'>)
SET learning_rate TO 0.05 (type: <class 'float'>)
SET momentum TO 0.4 (type: <class 'float'>)
SET n_sample TO 2048 (type: <class 'int'>)
SET sample_alpha TO 0.4 (type: <class 'float'>)
SET bpreg TO 1.95 (type: <class 'float'>)
SET logq TO 0.0 (type: <class 'float'>)
SET loss TO bpr-max (type: <class 'str'>)
SET constrained_embedding TO True (type: <class 'bool'>)
SET final_act TO elu-0.5 (type: <class 'str'>)
SET n_epochs TO 10 (type: <class 'int'>)
Loading training data...
Loading data from TAB separated file: /path/to/training_data_file
Started training
The dataframe is already sorted by SessionId, Time
Created sample store with 4882 batches of samples (type=GPU)
Epoch1 --> loss: 0.484484 (6.81s) [1026.65 mb/s | 81386 e/s]
Epoch2 --> loss: 0.381974 (6.89s) [1015.39 mb/s | 80493 e/s]
Epoch3 --> loss: 0.353932 (6.81s) [1027.68 mb/s | 81468 e/s]
Epoch4 --> loss: 0.340034 (6.80s) [1028.90 mb/s | 81564 e/s]
Epoch5 --> loss: 0.330763 (6.80s) [1028.19 mb/s | 81508 e/s]
Epoch6 --> loss: 0.324075 (6.80s) [1029.36 mb/s | 81601 e/s]
Epoch7 --> loss: 0.319033 (6.85s) [1022.03 mb/s | 81020 e/s]
Epoch8 --> loss: 0.314915 (6.80s) [1029.05 mb/s | 81577 e/s]
Epoch9 --> loss: 0.311716 (6.82s) [1025.44 mb/s | 81290 e/s]
Epoch10 --> loss: 0.308915 (6.82s) [1025.64 mb/s | 81306 e/s]
Total training time: 77.73s
Saving trained model to: /path/to/save_model.pickle
```

Load a previously trained model to `cuda1` and evaluate it, measuring recall and MRR at 1, 5, 10 and 20, using the conservative method for tiebreaking.
```
$ THEANO_FLAGS=device=cuda1 python run.py /path/to/previously_saved_model.pickle -l -t /path/to/test_data_file -m 1 5 10 20 -e conservative
```
Output (on the RetailRocket dataset):
```
Using cuDNN version 8201 on context None
Mapped name None to device cuda1: NVIDIA A30 (0000:AF:00.0)
Loading trained model from file: /path/to/previously_saved_model.pickle
Loading test data...
Loading data from TAB separated file: /path/to/test_data_file
Starting evaluation (cut-off=[1, 5, 10, 20], using conservative mode for tiebreaking)
Measuring Recall@1,5,10,20 and MRR@1,5,10,20
Evaluation took 4.34s
Recall@1: 0.128055 MRR@1: 0.128055
Recall@5: 0.322165 MRR@5: 0.197492
Recall@20: 0.518184 MRR@20: 0.217481
```

### Using GRU4Rec in code or the interpreter
You can simply import the `gru4rec` module in your code or in an interpreter and use the `GRU4Rec` class to create and train models. The trained models can be evaluated by importing the `evaluation` module and using either the `evaluate_gpu` or the `evaluate_session_batch` function. The latter is deprecated: it doesn't fully utilize the GPU and is therefore significantly slower. The public version of this code is mainly for running experiments (training and evaluating the algorithm on different datasets), therefore retrieving the actual predictions can be cumbersome and inefficient.

**IMPORTANT!** For the sake of convenience, the `gru4rec` module sets some important Theano parameters so that you don't have to worry about them if you are not familiar with Theano. But this only has an effect if `gru4rec` is imported *BEFORE* Theano (and any module that imports Theano). (Once Theano is initialized, most of its configuration can't be changed; and even if Theano is reimported, the GPU is not reinitialized.) If you do it the other way around, you should set your default `.theanorc` or provide the `THEANO_FLAGS` environment variable with the appropriate configuration.
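A minimal sketch of this workflow, using the hyperparameters from the examples above. The exact signatures of the constructor, `fit` and `evaluate_gpu` are assumptions here; check `gru4rec.py` and the docstring of `evaluate_gpu()` in `evaluation.py` for the authoritative ones.

```python
import gru4rec  # import BEFORE theano so the preset Theano config takes effect
import evaluation
import pandas as pd

train = pd.read_csv('/path/to/training_data_file', sep='\t')  # hypothetical paths
test = pd.read_csv('/path/to/test_data_file', sep='\t')

gru = gru4rec.GRU4Rec(loss='bpr-max', layers=[224], batch_size=80,
                      dropout_p_embed=0.5, dropout_p_hidden=0.05,
                      learning_rate=0.05, momentum=0.4, n_sample=2048,
                      sample_alpha=0.4, bpreg=1.95, constrained_embedding=True,
                      final_act='elu-0.5', n_epochs=10)
gru.fit(train)
# assumed signature; see evaluation.py for the real one
recall, mrr = evaluation.evaluate_gpu(gru, test, cut_off=20)
```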
### Notes on sequence-aware and session-based models
GRU4Rec was originally designed for session-based recommendations, where the generally short sessions are considered independent. Every time a user comes to the site, they are considered to be unknown, i.e. nothing of their history is used, even if it is known. (This setup is great for many real-life applications.) This means that when the model is evaluated, the hidden state starts from zero for each test session.

However, RNN (CNN, Transformer, etc.) based models are also a great fit for the practically less important sequence-aware personalized recommendation setup (i.e. the whole user history is used as a sequence to predict future items in the sequence). There are two main differences:
- (1) The sequences are significantly longer in sequence-aware recommendations. This also means that BPTT (backpropagation through time) is useful in this scenario. For session-based recommendations, experiments suggest that BPTT doesn't improve the model.
- (2) Evaluation in the sequence-aware setup should be started from the last value of the hidden state (i.e. the value computed on the training portion of the user history).

Currently, neither of these is supported in the public code. These functionalities might be added later if there is enough interest from the community (they exist in some of my internal research repos). At the moment, you have to extend the code yourself to do this.

### Notes on parameter settings
GRU4Rec has many parameters (and private versions had even more throughout the years). While you are welcome to play around with them, I found that it is usually best to leave the following parameters at their default value.

| Parameter | Defaults to | Comment |
|--------------------|:-----------:|----------------------------------------------------------------------------------------------------------------------|
| `hidden_act` | `tanh` | The activation function of the hidden layer should be tanh. |
| `lmbd` | `0.0` | L2 regularization is not needed, use dropout for regularization. |
| `smoothing` | `0.0` | Label smoothing for the cross-entropy loss only has a minor impact on performance. |
| `adapt` | `adagrad` | Optimizers perform similarly, with adagrad being slightly better than the others. |
| `adapt_params` | `[]` | Adagrad has no hyperparameters, therefore this is an empty list. |
| `grad_cap` | `0.0` | Training works fine without gradient capping/clipping. |
| `sigma` | `0.0` | Setting the min/max value during weight initialization is not needed; `±sqrt(6.0/(dim[0] + dim[1]))` is used when this is 0. |
| `init_as_normal` | `False` | Weights should be initialized from a uniform distribution. |
| `train_random_order` | `False` | Training sessions should not be shuffled, so that the last updates are based on recent data. |
| `time_sort` | `True` | Training sessions should be sorted in ascending order by the timestamp of their first event (oldest session first), so that the last updates are based on recent data. |

**Losses and final activations:**
- Among the five loss options (`cross-entropy`, `bpr-max`, `top1`, `bpr`, `top1-max`) **only `cross-entropy` or `bpr-max` should be used**. The others work to some extent, but they suffer from the vanishing gradient problem and thus produce models inferior to the ones trained with `bpr-max` or `cross-entropy`.
- `cross-entropy` has an alternative formulation `xe_logit` that requires a different final activation to be set.
- Always use the final activation (`final_act` parameter) appropriate for the loss. For `loss=cross-entropy` this is always `final_act=softmax`; for `loss=xe_logit` always use `final_act=softmax_logit`; for `loss=bpr-max` you can use `final_act=linear`, `final_act=relu`, `final_act=tanh`, `final_act=leaky-<X>`, `final_act=elu-<X>` or `final_act=selu-<X>-<Y>` (I usually prefer `elu-0.5` or `elu-1`).
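For reference, a small numpy sketch of the BPR-max loss of [2] for a single example: the target item's score is compared to every negative sample's score, the pairwise sigmoids are weighted by a softmax over the negative scores, and the negative scores themselves are regularized (`bpreg` is the coefficient of that term). The epsilon mirrors the stabilization trick described in the 22-12-2016 update below.

```python
import numpy as np

def bpr_max_loss(r_target, r_negatives, bpreg=1.0):
    # softmax weights over the negative sample scores
    s = np.exp(r_negatives - r_negatives.max())
    s /= s.sum()
    # -log of the weighted average of pairwise sigmoids + score regularization
    sigm = 1.0 / (1.0 + np.exp(-(r_target - r_negatives)))
    return -np.log(np.sum(s * sigm) + 1e-24) + bpreg * np.sum(s * r_negatives ** 2)
```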
**Embedding modes:** As described in the papers, there are three embedding modes that you can set as follows.

| Embedding mode | How to set | Description |
|-|-|-|
| No embedding | `embedding=0` AND `constrained_embedding=False` | The one-hot vector of the item ID is directly fed to the GRU layer. |
| Separate embedding | `embedding=X` (where `X>0`) or `embedding=layersize`, AND `constrained_embedding=False` | Separate embeddings on the input of the GRU layers and for computing the score with the sequence embedding. `embedding=layersize` sets it to the same dimensionality as the number of units in the first GRU layer. |
| Shared embedding | `constrained_embedding=True` | Usually the best performing setting. Uses the same embedding on the input of the GRU layers and for computing the scores with the sequence embedding. This enforces the dimensionality of the embedding to be equal to the size of the last GRU layer, thus the `embedding` parameter has no effect in this mode. |


## Speed of training
This is by far the fastest version. The speed of the official PyTorch and Tensorflow implementations is capped by the overhead introduced by the respective DL frameworks.

Time to complete one epoch (in seconds) on publicly available datasets with the best parameterization (see below), measured on an nVidia A30. The Theano version is 1.7-3 times faster than the PyTorch or Tensorflow versions.
![image](img/training_time_public_data.png)

*Details:* The short version is that Theano requires you to build a computational graph, and then a Theano function needs to be created that does the computations described by the graph. During the creation of the function, the code is compiled into one (or sometimes more) C++/CUDA executables, which are executed every time you call the function from Python. If you don't use any Python based operators, control doesn't need to be given back to Python, which significantly lowers the overhead. The published version of GRU4Rec works on ID based representations, and thus a single minibatch usually can't max out the GPU. Therefore, the overhead of passing control between C++/CUDA and Python (as in PyTorch and Tensorflow) can significantly increase training times. This is why the difference is smaller if the layer and/or minibatch size is larger. But optimal performance sometimes requires smaller minibatches.
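A toy illustration of this compile-once / call-many pattern (not the actual GRU4Rec graph):

```python
import numpy as np
import theano
import theano.tensor as T

X = T.fmatrix('X')
W = theano.shared(np.random.randn(100, 100).astype('float32'))
score = T.nnet.sigmoid(X.dot(W))
f = theano.function([X], score)  # the graph is compiled to C++/CUDA here, once
out = f(np.random.randn(32, 100).astype('float32'))  # later calls run the compiled code
```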
The training time mostly depends on the number of events and the model parameters. The following parameters affect the processing speed of events (event/s):
- `batch_size` --> The processing speed of batches (mb/s) decreases much more slowly than the size of the batch increases, therefore event processing speeds up as `batch_size` increases. Unfortunately, `batch_size` also affects model accuracy, and smaller batches are usually better for most datasets.
- `n_sample` --> Increasing the number of negative samples up to 500-1000 doesn't affect processing speed (depending on the hardware). The default is `n_sample=2048`, but if the number of items is low, it might be lowered without loss of accuracy.
- `loss` --> `cross-entropy` is somewhat faster than `bpr-max`
- `dropout_p_embed`, `dropout_p_hidden`, `momentum` --> if these are set to values other than 0, training will be a little slower

The following figures show the training speed (minibatch/second & event/second; higher is better) for various minibatch and layer sizes, with and without dropout and momentum enabled, using `n_sample=2048`. Measured on an nVidia A30.

With `cross-entropy` loss:

![image](img/training_time_xe_batch_size.png)

![image](img/training_time_xe_layers.png)

With `bpr-max` loss:

![image](img/training_time_bprmax_batch_size.png)

![image](img/training_time_bprmax_layers.png)


## Reproducing results on public datasets
The performance of GRU4Rec has been measured on multiple public datasets in [1,2,3,4]: Yoochoose/RSC15, Rees46, Coveo, RetailRocket and Diginetica.

*IMPORTANT:* Measuring the performance of sequential recommenders makes sense only if the data (and the task) itself shows sequential patterns, e.g. session based data. Evaluation on rating data doesn't give informative results. See [4] for details, as well as for other common flaws in the evaluation of sequential recommenders.

**Notes:**
- Always aim to include at least one realistically large dataset in your comparison (e.g. Rees46 is a good example).
- The evaluation setup is described in detail in [1,2,3]. It is a next-item prediction type evaluation considering only the next item as relevant for a given inference. This is a good setup for behaviour prediction and correlates somewhat with online performance. It is a stricter setup than considering any of the subsequent items as relevant, which - while a perfectly reasonable setup - is more forgiving towards simplistic (e.g. counting based) methods. However, similarly to any other offline evaluation, it is not a direct approximation of online performance.

**Getting the data:** Please refer to the original source of the data to obtain a full and legal copy. Links here are provided as best effort. It is not guaranteed that they won't break over time.
- [Yoochoose/RSC15](https://2015.recsyschallenge.com) (or the [reupload on Kaggle](https://www.kaggle.com/datasets/chadgostopp/recsys-challenge-2015))
- [Rees46](https://www.kaggle.com/datasets/mkechinov/ecommerce-behavior-data-from-multi-category-store)
- [Coveo](https://github.com/coveooss/shopper-intent-prediction-nature-2020)
- [RetailRocket](https://www.kaggle.com/datasets/retailrocket/ecommerce-dataset)
- [Diginetica](https://competitions.codalab.org/competitions/11161#learn_the_details-data2)

**Preprocessing:**
The details and the reasoning behind the preprocessing steps can be found in [1,2] for RSC15, and in [3] for Yoochoose, Rees46, Coveo, RetailRocket and Diginetica. The preprocessing script for RSC15 can be found [here](https://github.com/hidasib/GRU4Rec/blob/master/examples/rsc15/preprocess.py), and the ones for Yoochoose, Rees46, Coveo, RetailRocket and Diginetica in [the repo corresponding to [3]](https://github.com/hidasib/gru4rec_third_party_comparison). After running the scripts, double check that the statistics of the resulting sets match what is reported in the papers.

Preprocessing scripts yield 4 files per dataset:
- `train_full` --> full training set, used for training the model for the final evaluation
- `test` --> test set for the final evaluation of the model (the pair of `train_full`)
- `train_tr` --> training set for hyperparameter optimization, experimentation
- `train_valid` --> validation set for hyperparameter optimization, experimentation (the pair of `train_tr`)

Basically, the full preprocessed dataset is split into `train_full` and `test`, then `train_full` is split into `train_tr` and `train_valid` using the same logic.
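The sketch below shows the general shape of such a time-based split, assuming the default column names. It is only an illustration: the exact rules (length of the test window, filtering thresholds, etc.) differ per dataset, so treat the preprocessing scripts linked above as the authoritative versions.

```python
import pandas as pd

def time_based_split(df, test_window=86400, session_key='SessionId',
                     item_key='ItemId', time_key='Time'):
    # Sessions ending within `test_window` seconds of the latest event form the
    # test set; everything older stays in the training set.
    session_end = df.groupby(session_key)[time_key].max()
    test_sessions = session_end[session_end > df[time_key].max() - test_window].index
    train = df[~df[session_key].isin(test_sessions)]
    test = df[df[session_key].isin(test_sessions)]
    # Items that never occur in training can't be scored, so drop them from test.
    test = test[test[item_key].isin(train[item_key].unique())]
    return train, test

# train_full, test = time_based_split(data)
# train_tr, train_valid = time_based_split(train_full)
```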
*IMPORTANT:* Note that while RSC15 and Yoochoose are derived from the same source (the Yoochoose dataset), the preprocessing is different. The main difference is that RSC15 doesn't use deduplication. Therefore results on the two datasets are not comparable, and the optimal hyperparameters might differ. It is recommended to use the Yoochoose version, and to rely on the RSC15 version only when comparing to previously reported results, if the experiment can't be reproduced for some reason (e.g. the implementation of the method is not available).

[1] Balázs Hidasi, Alexandros Karatzoglou, Linas Baltrunas, Domonkos Tikk: [Session-based Recommendations with Recurrent Neural Networks](https://arxiv.org/abs/1511.06939), ICLR 2016
[2] Balázs Hidasi, Alexandros Karatzoglou: [Recurrent Neural Networks with Top-k Gains for Session-based Recommendations](https://arxiv.org/abs/1706.03847), CIKM 2018
[3] Balázs Hidasi, Ádám Czapp: [The Effect of Third Party Implementations on Reproducibility](https://arxiv.org/abs/2307.14956), RecSys 2023
[4] Balázs Hidasi, Ádám Czapp: [Widespread Flaws in Offline Evaluation of Recommender Systems](https://arxiv.org/abs/2307.14951), RecSys 2023

**Hyperparameters:**
Hyperparameters for RSC15 were obtained using a local (star) search optimizer that restarts when a better parameterization is found. It used a smaller parameter space than what is included in this repo (e.g. the hidden layer size was fixed to 100). There is probably room for some small improvement here with the new Optuna based optimizer.

Hyperparameters for Yoochoose, Rees46, Coveo, RetailRocket and Diginetica were obtained using the parameter spaces uploaded to this repo. 200 runs were executed per dataset, per embedding mode (no embedding, separate embedding, shared embedding) and per loss function (cross-entropy, bpr-max). The primary metric was MRR@20, which usually also gave the best results w.r.t. Recall@20. A separate training/validation set was used during parameter optimization, created from the full training set the same way as the (full) training/test split was created from the full dataset. Final results are measured on the test set with models trained on the full training set.

**Best hyperparameters:**
*Note:* Parameter files (usable with the `-pf` argument of `run.py`) are [included](https://github.com/hidasib/GRU4Rec/tree/master/paramfiles) in this repo for convenience.

| Dataset | loss | constrained_embedding | embedding | elu_param | layers | batch_size | dropout_p_embed | dropout_p_hidden | learning_rate | momentum | n_sample | sample_alpha | bpreg | logq |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| RSC15 | cross-entropy | True | 0 | 0 | 100 | 32 | 0.1 | 0 | 0.1 | 0 | 2048 | 0.75 | 0 | 1 |
| Yoochoose | cross-entropy | True | 0 | 0 | 480 | 48 | 0 | 0.2 | 0.07 | 0 | 2048 | 0.2 | 0 | 1 |
| Rees46 | cross-entropy | True | 0 | 0 | 512 | 240 | 0.45 | 0 | 0.065 | 0 | 2048 | 0.5 | 0 | 1 |
| Coveo | bpr-max | True | 0 | 1 | 512 | 144 | 0.35 | 0 | 0.05 | 0.4 | 2048 | 0.2 | 1.85 | 0 |
| RetailRocket | bpr-max | True | 0 | 0.5 | 224 | 80 | 0.5 | 0.05 | 0.05 | 0.4 | 2048 | 0.4 | 1.95 | 0 |
| Diginetica | bpr-max | True | 0 | 1 | 512 | 128 | 0.5 | 0.3 | 0.05 | 0.15 | 2048 | 0.3 | 0.9 | 0 |

**Results:**
*Note:* Due to changes in the order of the execution of operations on the GPU, some slight variation (even up to a few percent) in the metrics is expected and acceptable.
| Dataset | Recall@1 | MRR@1 | Recall@5 | MRR@5 | Recall@10 | MRR@10 | Recall@20 | MRR@20 |
|---|---|---|---|---|---|---|---|---|
| RSC15 | 0.1845 | 0.1845 | 0.4906 | 0.2954 | 0.6218 | 0.3130 | 0.7283 | 0.3205 |
| Yoochoose | 0.1829 | 0.1829 | 0.4478 | 0.2783 | 0.5715 | 0.2949 | 0.6789 | 0.3024 |
| Rees46 | 0.1114 | 0.1114 | 0.3010 | 0.1778 | 0.4135 | 0.1928 | 0.5293 | 0.2008 |
| Coveo | 0.0513 | 0.0513 | 0.1496 | 0.0852 | 0.2212 | 0.0946 | 0.3135 | 0.1010 |
| RetailRocket | 0.1274 | 0.1274 | 0.3237 | 0.1977 | 0.4207 | 0.2107 | 0.5186 | 0.2175 |
| Diginetica | 0.0725 | 0.0725 | 0.2369 | 0.1288 | 0.3542 | 0.1442 | 0.4995 | 0.1542 |

## Hyperparameter tuning
Hyperparameter optimization on new datasets is supported by `paropt.py`. Internally it uses [Optuna](https://optuna.org/) and requires a defined parameter space. A few predefined parameter spaces are [included](https://github.com/hidasib/GRU4Rec/tree/master/paramspaces) in this repo.

**Recommendations:**
- Run between 100 and 200 iterations with the included parameter spaces.
- Run separate optimizations when using different losses and embedding modes (no embedding (i.e. `embedding=0,constrained_embedding=False`), separate embedding (i.e. `embedding=layersize,constrained_embedding=False`) and shared embedding (i.e. `embedding=0,constrained_embedding=True`)).

**Fixed parameters:** You can play around with these as well in the optimizer, however the following fixed settings have worked well in the past. A sketch of how the fixed and the optimized parameters fit together follows this list.
- `logq` --> Cross-entropy loss usually works best with `logq=1`; the parameter has no effect when the BPR-max loss is used.
- `n_sample` --> Based on experience, `n_sample=2048` is large enough to get good performance with up to a few million items, and not too large to significantly degrade the speed of training. However, you might want to lower this if the total number of active items is below 5-10K.
- `n_epochs` --> This is usually set to `n_epochs=10`, but `5` gets you similar performance in most cases. So far there hasn't been any reason to significantly increase the number of epochs.
- embedding mode --> A full paropt needs to check all three options separately, but in the past, shared embedding (`constrained_embedding=True` and `embedding=0`) worked best for most datasets.
- `loss` --> A full paropt needs to check both separately, but past experience indicates that BPR-max performs better on smaller datasets and cross-entropy on larger ones.
- `final_act` --> Always use the final activation appropriate for the loss, e.g. `final_act=softmax` when `loss=cross-entropy` and either `elu-<X>`, `linear` or `relu` when `loss=bpr-max`.
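The hypothetical sketch below shows roughly what such an optimization loop does; the parameter ranges are copied from the parameter-space printout further below, while the GRU4Rec and evaluation signatures are assumptions. The real loop lives in `paropt.py` and is driven by the JSON files in `paramspaces/`, so use that instead of rolling your own.

```python
import optuna
import pandas as pd
import gru4rec, evaluation

train_tr = pd.read_csv('/path/to/train_tr.tsv', sep='\t')        # hypothetical paths
train_valid = pd.read_csv('/path/to/train_valid.tsv', sep='\t')

def objective(trial):
    gru = gru4rec.GRU4Rec(
        # fixed parameters (the cross-entropy setup from the example below)
        loss='cross-entropy', final_act='softmax', constrained_embedding=True,
        logq=1.0, n_sample=2048, n_epochs=10,
        # optimized parameters
        layers=[trial.suggest_int('layers', 64, 512, step=32)],
        batch_size=trial.suggest_int('batch_size', 32, 256, step=16),
        learning_rate=trial.suggest_float('learning_rate', 0.01, 0.25, step=0.005),
        dropout_p_embed=trial.suggest_float('dropout_p_embed', 0.0, 0.5, step=0.05),
        dropout_p_hidden=trial.suggest_float('dropout_p_hidden', 0.0, 0.7, step=0.05),
        momentum=trial.suggest_float('momentum', 0.0, 0.9, step=0.05),
        sample_alpha=trial.suggest_float('sample_alpha', 0.0, 1.0, step=0.1),
    )
    gru.fit(train_tr)
    recall, mrr = evaluation.evaluate_gpu(gru, train_valid, cut_off=20)  # assumed
    return mrr  # primary metric: MRR@20

study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=200)
```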
**Usage:**
```
$ python paropt.py -h
```

Output:
```
usage: paropt.py [-h] [-g GRFILE] [-tf [FLAGS]] [-fp PARAM_STRING] [-opf PATH] [-m [AT]] [-nt [NT]] [-fm [AT [AT ...]]] [-pm METRIC] [-e EVAL_TYPE] [-ik IK] [-sk SK] [-tk TK] PATH TEST_PATH

Train or load a GRU4Rec model & measure recall and MRR on the specified test set(s).

positional arguments:
  PATH                  Path to the training data (TAB separated file (.tsv or .txt) or pickled pandas.DataFrame object (.pickle)) (if the --load_model parameter is NOT provided) or to the serialized model (if the
                        --load_model parameter is provided).
  TEST_PATH             Path to the test data set(s) located at TEST_PATH.

optional arguments:
  -h, --help            show this help message and exit
  -g GRFILE, --gru4rec_model GRFILE
                        Name of the file containing the GRU4Rec class. Can be used to select different variants. (Default: gru4rec)
  -tf [FLAGS], --theano_flags [FLAGS]
                        Theano settings.
  -fp PARAM_STRING, --fixed_parameters PARAM_STRING
                        Fixed training parameters provided as a single parameter string. The format of the string is `param_name1=param_value1,param_name2=param_value2...`, e.g.:
                        `loss=bpr-max,layers=100,constrained_embedding=True`. Boolean training parameters should be either True or False; parameters that can take a list should use / as the separator
                        (e.g. layers=200/200). Mutually exclusive with the -pf (--parameter_file) and the -l (--load_model) arguments and one of the three must be provided.
  -opf PATH, --optuna_parameter_file PATH
                        File describing the parameter space for optuna.
  -m [AT], --measure [AT]
                        Measure recall & MRR at the defined recommendation list length. A single value can be provided. (Default: 20)
  -nt [NT], --ntrials [NT]
                        Number of optimization trials to perform (Default: 50)
  -fm [AT [AT ...]], --final_measure [AT [AT ...]]
                        Measure recall & MRR at the defined recommendation list length(s) after the optimization is finished. Multiple values can be provided. (Default: 20)
  -pm METRIC, --primary_metric METRIC
                        Set primary metric, recall or mrr (e.g. for paropt). (Default: recall)
  -e EVAL_TYPE, --eval_type EVAL_TYPE
                        Sets how to handle if multiple items in the ranked list have the same prediction score (which is usually due to saturation or an error). See the documentation of evaluate_gpu() in
                        evaluation.py for further details. (Default: standard)
  -ik IK, --item_key IK
                        Column name corresponding to the item IDs (default: ItemId).
  -sk SK, --session_key SK
                        Column name corresponding to the session IDs (default: SessionId).
  -tk TK, --time_key TK
                        Column name corresponding to the timestamp (default: Time).
```

**Example:** Run a hyperparameter optimization optimizing for MRR@20 for 200 iterations, measuring recall and MRR at 1, 5, 10 and 20 for the best variant after the optimization is finished.
*NOTE:* The paropt script itself can run on the CPU (`THEANO_FLAGS=device=cpu`), as models are trained in separate processes. You can control which device these training processes use with the `-tf` argument, which passes its value to the `THEANO_FLAGS` environment variable of the training processes. In this example, training(s) are executed on `cuda0`.
```
THEANO_FLAGS=device=cpu python paropt.py /path/to/training_data_file_for_optimization /path/to/validation_data_file_for_optimization -pm mrr -m 20 -fm 1 5 10 20 -e conservative -fp n_sample=2048,logq=1.0,loss=cross-entropy,final_act=softmax,constrained_embedding=True,n_epochs=10 -tf device=cuda0 -opf /path/to/parameter_space.json -nt 200
```
Output (first few lines):
```
--------------------------------------------------------------------------------
PARAMETER SPACE
  PARAMETER layers            type=int    range=[64..512]   (step=32)    UNIFORM scale
  PARAMETER batch_size        type=int    range=[32..256]   (step=16)    UNIFORM scale
  PARAMETER learning_rate     type=float  range=[0.01..0.25] (step=0.005) UNIFORM scale
  PARAMETER dropout_p_embed   type=float  range=[0.0..0.5]  (step=0.05)  UNIFORM scale
  PARAMETER dropout_p_hidden  type=float  range=[0.0..0.7]  (step=0.05)  UNIFORM scale
  PARAMETER momentum          type=float  range=[0.0..0.9]  (step=0.05)  UNIFORM scale
  PARAMETER sample_alpha      type=float  range=[0.0..1.0]  (step=0.1)   UNIFORM scale
--------------------------------------------------------------------------------
[I 2023-07-25 03:19:53,684] A new study created in memory with name: no-name-83fade3e-49f3-4f26-ac76-5f6cb2f3a02c
SET n_sample TO 2048 (type: <class 'int'>)
SET logq TO 1.0 (type: <class 'float'>)
SET loss TO cross-entropy (type: <class 'str'>)
SET final_act TO softmax (type: <class 'str'>)
SET constrained_embedding TO True (type: <class 'bool'>)
SET n_epochs TO 2 (type: <class 'int'>)
SET layers TO [96] (type: <class 'list'>)
SET batch_size TO 176 (type: <class 'int'>)
SET learning_rate TO 0.045000000000000005 (type: <class 'float'>)
SET dropout_p_embed TO 0.25 (type: <class 'float'>)
SET dropout_p_hidden TO 0.25 (type: <class 'float'>)
SET momentum TO 0.0 (type: <class 'float'>)
SET sample_alpha TO 0.9 (type: <class 'float'>)
Loading training data...
```

**Notes:**
- By default, Optuna logs to stderr and the model prints to stdout. You can use this to log the model training details and the summary of the optimization separately by adding `1> /path/to/model_training_details.log 2> /path/to/optimization.log` to your command. Alternatively, you can play around with Optuna's settings. GRU4Rec at the moment doesn't use proper logging (it just prints).
- If you redirect stderr and/or stdout to file(s) and you want to see progress in real time, use python in unbuffered mode by adding the `-u` argument after `python` (i.e. `python -u paropt.py ...`).

## Executing on CPU
Some optimizations for speeding up GPU execution (e.g. custom Theano operators) prevent running the code on the CPU. Since CPU execution of neural networks is already slow, I decided to abandon CPU support to speed up execution on the GPU. If - for some reason - you still want to run GRU4Rec on the CPU, you need to modify the code to disable the custom GPU optimizations. You will be able to run the code on the CPU, just don't expect it to be quick.

**Steps of disabling the custom GPU optimizations:**
- In `gpu_ops.py`, change line `13` to `disable_custom_op = True`. This makes the functions in `gpu_ops` return standard operators, or operators assembled from standard operators, instead of the custom ones when the computational graph is built.
- In `gru4rec.py`, comment out line `12` containing `import custom_opt`. One of the custom operators is integrated deeper into Theano through `custom_opt`, which adds it to the optimizer that replaces operators in the computational graph. By removing this import, this operator won't be used.
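In short, the two edits are:

```python
# gpu_ops.py, line 13:
disable_custom_op = True

# gru4rec.py, line 12 -- comment out the import:
# import custom_opt
```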
## Major updates

### Update 24-08-2023
- Added paropt
- Extended info on reproducibility
- Added parameter files and parameter spaces
- Extended readme

### Update 08-05-2020
- Significant speed-up of the training by increasing GPU utilization.
- logQ normalization added (improves results when the cross-entropy loss is used)
- Added `run.py` for easy experimentation.
- Extended this README.

### Update 08-06-2018
- Refactoring and cleaning.
- Speeding up execution.
- Ease of life improvements.
- Code for evaluating on GPU.

### Update 13-06-2017
- Upgraded to the v2.0 version
- Added BPR-max and TOP1-max losses for cutting edge performance (coupled with additional sampling: +30% in recall & MRR over the base results)
- Sacrificed some speed on CPU for faster GPU execution

### Update 22-12-2016
- Fixed cross-entropy instability. Very small predicted scores were rounded to 0 and thus their logarithm became NaN. Added a small epsilon (1e-24) to all scores before computing the logarithm. I got better results with this stabilized cross-entropy than with the TOP1 loss on networks with 100 hidden units.
- Added the option of using additional negative samples (besides the default, which is the other examples of the minibatch). The number of additional samples is given by the n_sample parameter. The probability of an item being chosen as a sample is proportional to supp^sample_alpha, i.e. setting sample_alpha to 1 results in popularity based sampling, setting it to 0 results in uniform sampling. Using additional samples can slow down training, but depending on your config, the slowdown might not be noticeable on GPU up to 1000-2000 additional samples.
- Added an option to training to precompute a large batch of negative samples in advance. The number of int values (IDs) to be stored is determined by the sample_store parameter of the train function (default: 10M). This option only takes effect for the additional negative samples, i.e. when n_sample > 0. Computing negative samples in each step results in very inefficient GPU utilization, as computations are often interrupted by sample generation (which runs on the CPU). Precomputing samples for several steps in advance makes the process more efficient. However, one should avoid setting the sample store too big, as generating too many samples takes a long time, resulting in the GPU waiting for its completion for a long time. It also increases the memory footprint.

### Update 21-09-2016
- Optimized code for GPU execution. Training is ~2x faster now.
- Added retrain functionality.

--------------------------------------------------------------------------------
/baselines.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Fri Jun 26 11:57:27 2015

@author: Balázs Hidasi
"""

import numpy as np
import pandas as pd

class RandomPred:
    '''
    RandomPred()

    Initializes a random predictor: a baseline that gives back a random score for each item.

    '''
    def fit(self, data):
        '''
        Dummy function for training.

        Parameters
        --------
        data: pandas.DataFrame
            Training data. It contains the transactions of the sessions.
            It has one column for session IDs, one for item IDs and one for the timestamp of the events (unix timestamps).
            It must have a header. Column names are arbitrary, but must correspond to the ones you set during the initialization of the network (session_key, item_key, time_key properties).

        '''
        pass

    def predict_next(self, session_id, input_item_id, predict_for_item_ids):
        '''
        Gives prediction scores for a selected set of items on how likely they are to be the next item in the session.

        Parameters
        --------
        session_id : int or string
            The session ID of the event.
        input_item_id : int or string
            The item ID of the event.
        predict_for_item_ids : 1D array
            IDs of items for which the network should give prediction scores.

        Returns
        --------
        out : pandas.Series
            Prediction scores for selected items on how likely they are to be the next item of this session. Indexed by the item IDs.

        '''
        return pd.Series(data=np.random.rand(len(predict_for_item_ids)), index=predict_for_item_ids)

class Pop:
    '''
    Pop(top_n=100, item_key='ItemId', support_by_key=None)

    Popularity predictor that gives higher scores to items with larger support.

    The score is given by:

    .. math::
        r_{i}=\\frac{supp_i}{(1+supp_i)}

    Parameters
    --------
    top_n : int
        Only give back non-zero scores to the top N ranking items. Should be higher or equal than the cut-off of your evaluation. (Default value: 100)
    item_key : string
        The header of the item IDs in the training data. (Default value: 'ItemId')
    support_by_key : string or None
        If not None, count the number of unique values of the attribute of the training data given by the specified header. If None, count the events. (Default value: None)

    '''

    def __init__(self, top_n = 100, item_key = 'ItemId', support_by_key = None):
        self.top_n = top_n
        self.item_key = item_key
        self.support_by_key = support_by_key

    def fit(self, data):
        '''
        Trains the predictor.

        Parameters
        --------
        data: pandas.DataFrame
            Training data. It contains the transactions of the sessions. It has one column for session IDs, one for item IDs and one for the timestamp of the events (unix timestamps).
            It must have a header. Column names are arbitrary, but must correspond to the ones you set during the initialization of the network (session_key, item_key, time_key properties).

        '''
        grp = data.groupby(self.item_key)
        self.pop_list = grp.size() if self.support_by_key is None else grp[self.support_by_key].nunique()
        self.pop_list = self.pop_list / (self.pop_list + 1)  # supp / (1 + supp)
        self.pop_list.sort_values(ascending=False, inplace=True)
        self.pop_list = self.pop_list.head(self.top_n)

    def predict_next(self, session_id, input_item_id, predict_for_item_ids):
        '''
        Gives prediction scores for a selected set of items on how likely they are to be the next item in the session.

        Parameters
        --------
        session_id : int or string
            The session ID of the event.
        input_item_id : int or string
            The item ID of the event.
        predict_for_item_ids : 1D array
            IDs of items for which the network should give prediction scores. Every ID must be in the set of item IDs of the training set.

        Returns
        --------
        out : pandas.Series
            Prediction scores for selected items on how likely they are to be the next item of this session. Indexed by the item IDs.

        '''
        preds = np.zeros(len(predict_for_item_ids))
        # Items outside the stored top N popular items keep a score of zero
        mask = np.in1d(predict_for_item_ids, self.pop_list.index)
        preds[mask] = self.pop_list[predict_for_item_ids[mask]]
        return pd.Series(data=preds, index=predict_for_item_ids)
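# Example usage of the popularity baseline above (an illustrative sketch: the path
# is hypothetical and the data is expected to be a TAB separated file with a header
# and the default SessionId/ItemId/Time columns, as described in the docstrings):
if __name__ == '__main__':
    train = pd.read_csv('/path/to/train.tsv', sep='\t')
    pop = Pop(top_n=100)
    pop.fit(train)
    items = train['ItemId'].unique()
    scores = pop.predict_next(session_id=0, input_item_id=items[0], predict_for_item_ids=items)
    print(scores.sort_values(ascending=False).head(20))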
class SessionPop:
    '''
    SessionPop(top_n=100, item_key='ItemId', support_by_key=None)

    Session popularity predictor that gives higher scores to items with a higher number of occurrences in the session. Ties are broken by adding the popularity score of the item.

    The score is given by:

    .. math::
        r_{s,i} = supp_{s,i} + \\frac{supp_i}{(1+supp_i)}

    Parameters
    --------
    top_n : int
        Only give back non-zero scores to the top N ranking items. Should be higher or equal than the cut-off of your evaluation. (Default value: 100)
    item_key : string
        The header of the item IDs in the training data. (Default value: 'ItemId')
    support_by_key : string or None
        If not None, count the number of unique values of the attribute of the training data given by the specified header. If None, count the events. (Default value: None)

    '''
    def __init__(self, top_n = 100, item_key = 'ItemId', support_by_key = None):
        self.top_n = top_n
        self.item_key = item_key
        self.support_by_key = support_by_key

    def fit(self, data):
        '''
        Trains the predictor.

        Parameters
        --------
        data: pandas.DataFrame
            Training data. It contains the transactions of the sessions. It has one column for session IDs, one for item IDs and one for the timestamp of the events (unix timestamps).
            It must have a header. Column names are arbitrary, but must correspond to the ones you set during the initialization of the network (session_key, item_key, time_key properties).

        '''
        grp = data.groupby(self.item_key)
        self.pop_list = grp.size() if self.support_by_key is None else grp[self.support_by_key].nunique()
        self.pop_list = self.pop_list / (self.pop_list + 1)
        self.pop_list.sort_values(ascending=False, inplace=True)
        self.pop_list = self.pop_list.head(self.top_n)
        self.prev_session_id = -1

    def predict_next(self, session_id, input_item_id, predict_for_item_ids):
        '''
        Gives prediction scores for a selected set of items on how likely they are to be the next item in the session.

        Parameters
        --------
        session_id : int or string
            The session ID of the event. If changed during subsequent calls, a new session starts.
        input_item_id : int or string
            The item ID of the event. Must be in the set of item IDs of the training set.
        predict_for_item_ids : 1D array
            IDs of items for which the network should give prediction scores. Every ID must be in the set of item IDs of the training set.

        Returns
        --------
        out : pandas.Series
            Prediction scores for selected items on how likely they are to be the next item of this session. Indexed by the item IDs.

        '''
        if self.prev_session_id != session_id:
            # A new session started: reset the per-session item counters
            self.prev_session_id = session_id
            self.pers = dict()
        v = self.pers.get(input_item_id)
        if v:
            self.pers[input_item_id] = v + 1
        else:
            self.pers[input_item_id] = 1
        preds = np.zeros(len(predict_for_item_ids))
        # Global popularity scores of the top N items...
        mask = np.in1d(predict_for_item_ids, self.pop_list.index)
        ser = pd.Series(self.pers)
        preds[mask] = self.pop_list[predict_for_item_ids[mask]]
        # ...plus the in-session occurrence counts
        mask = np.in1d(predict_for_item_ids, ser.index)
        preds[mask] += ser[predict_for_item_ids[mask]]
        return pd.Series(data=preds, index=predict_for_item_ids)

class ItemKNN:
    '''
    ItemKNN(n_sims = 100, lmbd = 20, alpha = 0.5, session_key = 'SessionId', item_key = 'ItemId', time_key = 'Time')

    Item-to-item predictor that computes the similarity of all items to the given item.

    The similarity of two items is given by:

    .. math::
        s_{i,j}=\\sum_{s}I\\{(s,i)\\in D \\wedge (s,j)\\in D\\} / (supp_i+\\lambda)^{\\alpha}(supp_j+\\lambda)^{1-\\alpha}

    Parameters
    --------
    n_sims : int
        Only give back non-zero scores to the N most similar items. Should be higher or equal than the cut-off of your evaluation. (Default value: 100)
    lmbd : float
        Regularization. Discounts the similarity of rare items (incidental co-occurrences). (Default value: 20)
    alpha : float
        Balance between normalizing with the supports of the two items. 0.5 gives cosine similarity, 1.0 gives confidence (as in association rules). (Default value: 0.5)
    session_key : string
        header of the session ID column in the input file (default: 'SessionId')
    item_key : string
        header of the item ID column in the input file (default: 'ItemId')
    time_key : string
        header of the timestamp column in the input file (default: 'Time')

    '''

    def __init__(self, n_sims = 100, lmbd = 20, alpha = 0.5, session_key = 'SessionId', item_key = 'ItemId', time_key = 'Time'):
        self.n_sims = n_sims
        self.lmbd = lmbd
        self.alpha = alpha
        self.item_key = item_key
        self.session_key = session_key
        self.time_key = time_key

    def fit(self, data):
        '''
        Trains the predictor.

        Parameters
        --------
        data: pandas.DataFrame
            Training data. It contains the transactions of the sessions. It has one column for session IDs, one for item IDs and one for the timestamp of the events (unix timestamps).
            It must have a header. Column names are arbitrary, but must correspond to the ones you set during the initialization of the network (session_key, item_key, time_key properties).
244 | 
245 |         '''
246 |         data.set_index(np.arange(len(data)), inplace=True)
247 |         itemids = data[self.item_key].unique()
248 |         n_items = len(itemids)
249 |         data = pd.merge(data, pd.DataFrame({self.item_key:itemids, 'ItemIdx':np.arange(len(itemids))}), on=self.item_key, how='inner')
250 |         sessionids = data[self.session_key].unique()
251 |         data = pd.merge(data, pd.DataFrame({self.session_key:sessionids, 'SessionIdx':np.arange(len(sessionids))}), on=self.session_key, how='inner')
252 |         supp = data.groupby('SessionIdx').size()
253 |         session_offsets = np.zeros(len(supp)+1, dtype=np.int32)
254 |         session_offsets[1:] = supp.cumsum()
255 |         index_by_sessions = data.sort_values(['SessionIdx', self.time_key]).index.values
256 |         supp = data.groupby('ItemIdx').size()
257 |         item_offsets = np.zeros(n_items+1, dtype=np.int32)
258 |         item_offsets[1:] = supp.cumsum()
259 |         index_by_items = data.sort_values(['ItemIdx', self.time_key]).index.values
260 |         self.sims = dict()
261 |         for i in range(n_items):
262 |             iarray = np.zeros(n_items)
263 |             start = item_offsets[i]
264 |             end = item_offsets[i+1]
265 |             for e in index_by_items[start:end]:
266 |                 uidx = data.SessionIdx.values[e]
267 |                 ustart = session_offsets[uidx]
268 |                 uend = session_offsets[uidx+1]
269 |                 user_events = index_by_sessions[ustart:uend]
270 |                 iarray[data.ItemIdx.values[user_events]] += 1
271 |             iarray[i] = 0
272 |             norm = np.power((supp[i] + self.lmbd), self.alpha) * np.power((supp.values + self.lmbd), (1.0 - self.alpha))
273 |             norm[norm == 0] = 1
274 |             iarray = iarray / norm
275 |             indices = np.argsort(iarray)[-1:-1-self.n_sims:-1]
276 |             self.sims[itemids[i]] = pd.Series(data=iarray[indices], index=itemids[indices])
277 | 
278 |     def predict_next(self, session_id, input_item_id, predict_for_item_ids):
279 |         '''
280 |         Gives prediction scores for a selected set of items on how likely they are to be the next item in the session.
281 | 
282 |         Parameters
283 |         --------
284 |         session_id : int or string
285 |             The session ID of the event.
286 |         input_item_id : int or string
287 |             The item ID of the event. Must be in the set of item IDs of the training set.
288 |         predict_for_item_ids : 1D array
289 |             IDs of items for which the network should give prediction scores. Every ID must be in the set of item IDs of the training set.
290 | 
291 |         Returns
292 |         --------
293 |         out : pandas.Series
294 |             Prediction scores for selected items on how likely they are to be the next item of this session. Indexed by the item IDs.
295 | 
296 |         '''
297 |         preds = np.zeros(len(predict_for_item_ids))
298 |         sim_list = self.sims[input_item_id]
299 |         mask = np.in1d(predict_for_item_ids, sim_list.index)
300 |         preds[mask] = sim_list[predict_for_item_ids[mask]]
301 |         return pd.Series(data=preds, index=predict_for_item_ids)
302 | 
303 | class BPR:
304 |     '''
305 |     BPR(n_factors = 100, n_iterations = 10, learning_rate = 0.01, lambda_session = 0.0, lambda_item = 0.0, sigma = 0.05, init_normal = False, session_key = 'SessionId', item_key = 'ItemId')
306 | 
307 |     Bayesian Personalized Ranking Matrix Factorization (BPR-MF). During prediction time, the current state of the session is modelled as the average of the feature vectors of the items that have occurred in it so far.
308 | 
309 |     Parameters
310 |     --------
311 |     n_factors : int
312 |         The number of features in a feature vector. (Default value: 100)
313 |     n_iterations : int
314 |         The number of epochs for training. (Default value: 10)
315 |     learning_rate : float
316 |         Learning rate. (Default value: 0.01)
317 |     lambda_session : float
318 |         Regularization for session features. (Default value: 0.0)
319 |     lambda_item : float
320 |         Regularization for item features. (Default value: 0.0)
321 |     sigma : float
322 |         The width of the initialization. (Default value: 0.05)
323 |     init_normal : boolean
324 |         Whether to use uniform or normal distribution based initialization. (Default value: False)
325 |     session_key : string
326 |         header of the session ID column in the input file (default: 'SessionId')
327 |     item_key : string
328 |         header of the item ID column in the input file (default: 'ItemId')
329 | 
330 |     '''
331 |     def __init__(self, n_factors = 100, n_iterations = 10, learning_rate = 0.01, lambda_session = 0.0, lambda_item = 0.0, sigma = 0.05, init_normal = False, session_key = 'SessionId', item_key = 'ItemId'):
332 |         self.n_factors = n_factors
333 |         self.n_iterations = n_iterations
334 |         self.learning_rate = learning_rate
335 |         self.lambda_session = lambda_session
336 |         self.lambda_item = lambda_item
337 |         self.sigma = sigma
338 |         self.init_normal = init_normal
339 |         self.session_key = session_key
340 |         self.item_key = item_key
341 |         self.current_session = None
342 | 
343 |     def init(self, data):
344 |         self.U = np.random.rand(self.n_sessions, self.n_factors) * 2 * self.sigma - self.sigma if not self.init_normal else np.random.randn(self.n_sessions, self.n_factors) * self.sigma
345 |         self.I = np.random.rand(self.n_items, self.n_factors) * 2 * self.sigma - self.sigma if not self.init_normal else np.random.randn(self.n_items, self.n_factors) * self.sigma
346 |         self.bU = np.zeros(self.n_sessions)
347 |         self.bI = np.zeros(self.n_items)
348 | 
349 |     def update(self, uidx, p, n):
350 |         uF = np.copy(self.U[uidx,:])
351 |         iF1 = np.copy(self.I[p,:])
352 |         iF2 = np.copy(self.I[n,:])
353 |         sigm = self.sigmoid(iF1.T.dot(uF) - iF2.T.dot(uF) + self.bI[p] - self.bI[n])
354 |         c = 1.0 - sigm
355 |         self.U[uidx,:] += self.learning_rate * (c * (iF1 - iF2) - self.lambda_session * uF)
356 |         self.I[p,:] += self.learning_rate * (c * uF - self.lambda_item * iF1)
357 |         self.I[n,:] += self.learning_rate * (-c * uF - self.lambda_item * iF2)
358 |         return np.log(sigm)
359 | 
360 |     def fit(self, data):
361 |         '''
362 |         Trains the predictor.
363 | 
364 |         Parameters
365 |         --------
366 |         data: pandas.DataFrame
367 |             Training data. It contains the transactions of the sessions. It has one column for session IDs, one for item IDs and one for the timestamp of the events (unix timestamps).
368 |             It must have a header. Column names are arbitrary, but must correspond to the ones you set during the initialization of the network (session_key, item_key, time_key properties).
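        Notes
        -----
        A sketch of how the fitted factors are used: at prediction time the session feature
        vector is taken as the average of the feature vectors of the items seen so far in the
        session, so the score of item i in session s is

        .. math::
            \\hat{r}_{s,i} = \\left( \\frac{1}{|s|} \\sum_{j \\in s} q_j \\right)^T q_i + b_i

        where q denotes rows of the item feature matrix and b the item biases (see predict_next).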
369 | 
370 |         '''
371 |         itemids = data[self.item_key].unique()
372 |         self.n_items = len(itemids)
373 |         self.itemidmap = pd.Series(data=np.arange(self.n_items), index=itemids)
374 |         sessionids = data[self.session_key].unique()
375 |         self.n_sessions = len(sessionids)
376 |         data = pd.merge(data, pd.DataFrame({self.item_key:itemids, 'ItemIdx':np.arange(self.n_items)}), on=self.item_key, how='inner')
377 |         data = pd.merge(data, pd.DataFrame({self.session_key:sessionids, 'SessionIdx':np.arange(self.n_sessions)}), on=self.session_key, how='inner')
378 |         self.init(data)
379 |         for it in range(self.n_iterations):
380 |             c = []
381 |             for e in np.random.permutation(len(data)):
382 |                 uidx = data.SessionIdx.values[e]
383 |                 iidx = data.ItemIdx.values[e]
384 |                 iidx2 = data.ItemIdx.values[np.random.randint(self.n_items)]
385 |                 err = self.update(uidx, iidx, iidx2)
386 |                 c.append(err)
387 |             print(it, np.mean(c))
388 | 
389 |     def predict_next(self, session_id, input_item_id, predict_for_item_ids):
390 |         '''
391 |         Gives prediction scores for a selected set of items on how likely they are to be the next item in the session.
392 | 
393 |         Parameters
394 |         --------
395 |         session_id : int or string
396 |             The session ID of the event.
397 |         input_item_id : int or string
398 |             The item ID of the event. Must be in the set of item IDs of the training set.
399 |         predict_for_item_ids : 1D array
400 |             IDs of items for which the network should give prediction scores. Every ID must be in the set of item IDs of the training set.
401 | 
402 |         Returns
403 |         --------
404 |         out : pandas.Series
405 |             Prediction scores for selected items on how likely they are to be the next item of this session. Indexed by the item IDs.
406 | 
407 |         '''
408 |         iidx = self.itemidmap[input_item_id]
409 |         if self.current_session is None or self.current_session != session_id:
410 |             self.current_session = session_id
411 |             self.session = [iidx]
412 |         else:
413 |             self.session.append(iidx)
414 |         uF = self.I[self.session].mean(axis=0)
415 |         iIdxs = self.itemidmap[predict_for_item_ids]
416 |         return pd.Series(data=self.I[iIdxs].dot(uF) + self.bI[iIdxs], index=predict_for_item_ids)
417 | 
418 |     def sigmoid(self, x):
419 |         return 1.0 / (1.0 + np.exp(-x))
--------------------------------------------------------------------------------
/custom_opt.py:
--------------------------------------------------------------------------------
 1 | import theano
 2 | from theano import tensor, config
 3 | from theano.gpuarray.subtensor import GpuAdvancedSubtensor1
 4 | from theano.gpuarray.opt import register_opt, op_lifter, register_opt2
 5 | from custom_theano_ops import GpuAdvancedSubtensor1_fast
 6 | 
 7 | def remove_optimization(optimizer, name, *tags):
 8 |     obj = optimizer.__db__[name].copy().pop()
 9 |     optimizer.remove_tags(name, *tags)
10 |     optimizer.__db__[obj.__class__.__name__].remove(obj)
11 |     optimizer._names.remove(name)
12 |     del(optimizer.__db__[name])
13 | 
14 | def get_tags(optimizer, name):
15 |     obj = optimizer.__db__[name].copy().pop()
16 |     tags = []
17 |     for k, v in optimizer.__db__.items():
18 |         if (obj in v) and (k != name) and (k != obj.__class__.__name__):
19 |             tags.append(k)
20 |     return sorted(tags)
21 | 
22 | tags = get_tags(theano.gpuarray.opt.gpu_optimizer, 'local_gpua_advanced_subtensor1')
23 | remove_optimization(theano.gpuarray.opt.gpu_optimizer, 'local_gpua_advanced_subtensor1', *tags)
24 | 
25 | tags = get_tags(theano.gpuarray.opt.gpu_optimizer2, 'local_gpua_advanced_subtensor1')
26 | remove_optimization(theano.gpuarray.opt.gpu_optimizer2, 'local_gpua_advanced_subtensor1', *tags)
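# The two blocks above deregister Theano's stock GPU lifter for AdvancedSubtensor1
# from both optimizer databases (after saving its tags), so that the function below
# can be registered in its place under the same name. The replacement keeps the
# stock GpuAdvancedSubtensor1 for non-2D inputs or when deterministic mode is
# requested, and dispatches to GpuAdvancedSubtensor1_fast (from
# custom_theano_ops.py) otherwise.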
27 | 28 | @register_opt('fast_compile') 29 | @op_lifter([tensor.AdvancedSubtensor1]) 30 | @register_opt2([tensor.AdvancedSubtensor1], 'fast_compile') 31 | def local_gpua_advanced_subtensor1(op, context_name, inputs, outputs): 32 | x, ilist = inputs 33 | if (x.ndim != 2 or config.deterministic == 'more'): 34 | return GpuAdvancedSubtensor1() 35 | else: 36 | return GpuAdvancedSubtensor1_fast() 37 | -------------------------------------------------------------------------------- /custom_theano_ops.py: -------------------------------------------------------------------------------- 1 | from theano import tensor, gof, Op, config 2 | from theano.gof import ParamsType 3 | from theano.gradient import grad_not_implemented 4 | import theano.tensor as T 5 | from theano.gpuarray.subtensor import GpuAdvancedSubtensor1 6 | from theano.scalar import bool as bool_t, int32 as int_t, uint32 as size_t 7 | 8 | try: 9 | import pygpu 10 | from pygpu import gpuarray 11 | except ImportError: 12 | pass 13 | 14 | from theano.gpuarray.type import GpuArrayType, gpu_context_type, get_context 15 | from theano.gpuarray.basic_ops import (as_gpuarray_variable, HideC, GpuKernelBase, Kernel, gpuarray_helper_inc_dir, infer_context_name, gpu_contiguous) 16 | from theano.gpuarray.fp16_help import write_w, load_w, work_dtype 17 | 18 | class GpuExtractDiag2D(GpuKernelBase, Op): 19 | """ 20 | Extracting diagonal of a 2D matrix on the GPU. 21 | 22 | """ 23 | __props__ = ('context_name', 'keepdims') 24 | _f16_ok = True 25 | params_type = ParamsType(context=gpu_context_type, keepdims=bool_t) 26 | 27 | def __init__(self, context_name=None, keepdims=False): 28 | self.context_name = context_name 29 | self.keepdims = keepdims 30 | 31 | def get_params(self, node): 32 | return self.params_type.get_params(self, context=get_context(self.context_name), keepdims=self.keepdims) 33 | 34 | def make_node(self, x, k=0): #TODO: dtype check 35 | x = as_gpuarray_variable(x, context_name=self.context_name) 36 | k = tensor.as_tensor_variable(k) 37 | assert x.ndim == 2 38 | assert k.ndim == 0 39 | broadcastable = (False,True) if self.keepdims else (False,) 40 | otype = GpuArrayType(dtype=x.type.dtype, broadcastable=broadcastable, context_name=self.context_name) 41 | return gof.Apply(self, [x, k], [otype()]) 42 | 43 | def infer_shape(self, node, in_shapes): 44 | in_shape, _ = in_shapes 45 | dim1 = in_shape[0] 46 | dim2 = in_shape[1] 47 | k = node.inputs[1] 48 | diag_size = T.switch(T.ge(k, 0), T.clip(dim2 - k, 0, dim1), T.clip(dim1 + k, 0, dim2)) 49 | if self.keepdims: 50 | diag_size = (diag_size, 1) 51 | else: 52 | diag_size = (diag_size,) 53 | return [diag_size] 54 | 55 | def grad(self, inp, grads): 56 | return [GpuAllocDiag2D()(grads[0], inp[1], *(inp[0].shape)), grad_not_implemented(self, 1, inp[1])] 57 | 58 | def gpu_kernels(self, node, name): 59 | dtype_x = node.inputs[0].dtype 60 | type_x = gpuarray.dtype_to_ctype(dtype_x) 61 | dtype_y = node.outputs[0].dtype 62 | type_y = gpuarray.dtype_to_ctype(dtype_y) 63 | work_x = gpuarray.dtype_to_ctype(work_dtype(dtype_x)) 64 | load_x = load_w(dtype_x) 65 | write_y = write_w(dtype_y) 66 | code = """ 67 | #include "cluda.h" 68 | KERNEL void extract(const ga_ssize stridesX0, const ga_ssize stridesX1, GLOBAL_MEM %(type_x)s *x, ga_size x_off, const ga_ssize stridesY0, GLOBAL_MEM %(type_y)s *y, ga_size y_off, ga_ssize k, ga_size l) { 69 | x = (GLOBAL_MEM %(type_x)s *)(((GLOBAL_MEM char *)x) + x_off); 70 | y = (GLOBAL_MEM %(type_y)s *)(((GLOBAL_MEM char *)y) + y_off); 71 | ga_ssize coff = max(k, (ga_ssize) 0); 72 
| ga_ssize roff = -min(k, (ga_ssize) 0); 73 | ga_size index = GID_0 * LDIM_0 + LID_0; 74 | if (index < l) { 75 | %(work_x)s t = %(load_x)s(x[(index + roff) * stridesX0 + (index + coff) * stridesX1]); 76 | y[index * stridesY0] = %(write_y)s(t); 77 | } 78 | }""" % dict(type_x=type_x, type_y=type_y, work_x=work_x, load_x=load_x, write_y=write_y, name=name) 79 | return [Kernel( 80 | code=code, name="extract", 81 | params=[gpuarray.SSIZE, gpuarray.SSIZE, gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SSIZE, gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SSIZE, gpuarray.SIZE], 82 | flags=Kernel.get_flags(dtype_x, dtype_y), 83 | objvar='k_extract_' + name)] 84 | 85 | def c_headers(self): 86 | return ['', '', ''] 87 | 88 | def c_header_dirs(self): 89 | return [gpuarray_helper_inc_dir()] 90 | 91 | def c_code(self, node, name, inp, out, sub): #TODO: fix error msg 92 | x, k = inp 93 | y, = out 94 | fail = sub['fail'] 95 | params = sub['params'] 96 | typecode = pygpu.gpuarray.dtype_to_typecode(node.inputs[0].dtype) 97 | kname = self.gpu_kernels(node, name)[0].objvar 98 | s = """ 99 | int err; 100 | size_t* dims = (size_t*)PyGpuArray_DIMS((PyGpuArrayObject*)%(x)s); 101 | size_t k = ((dtype_%(k)s*)PyArray_DATA(%(k)s))[0]; 102 | size_t col_off = (size_t) (k > 0?k:0); 103 | size_t row_off = (size_t) (k < 0?-k:0); 104 | size_t diag_size = (size_t) std::max((ssize_t) std::min((ssize_t)dims[0] - (ssize_t)row_off, (ssize_t)dims[1] - (ssize_t)col_off), (ssize_t) 0); 105 | size_t ls = std::min(diag_size, (size_t) 1024); 106 | size_t gs = (diag_size + ls - 1) / ls; 107 | size_t ndims = %(params)s->keepdims ? 2 : 1; 108 | size_t out_dims[ndims]; 109 | out_dims[0] = diag_size; 110 | if (ndims == 2) { 111 | out_dims[1] = 1; 112 | } 113 | 114 | size_t itemsize_x = 1; 115 | size_t itemsize_y = 1; 116 | ssize_t stridesX0 = 1; 117 | ssize_t stridesX1 = 1; 118 | ssize_t stridesY0 = 1; 119 | 120 | if (%(y)s == NULL || %(y)s->ga.nd != ndims || %(y)s->ga.dimensions[0] != diag_size || (ndims > 1 && %(y)s->ga.dimensions[1] != 1)) { 121 | Py_CLEAR(%(y)s); 122 | %(y)s = pygpu_empty(ndims, out_dims, %(typecode)s, GA_C_ORDER, %(params)s->context, Py_None); 123 | } 124 | if (%(y)s == NULL) { 125 | %(fail)s 126 | } 127 | 128 | itemsize_x = GpuArray_ITEMSIZE(&%(x)s->ga); 129 | itemsize_y = GpuArray_ITEMSIZE(&%(y)s->ga); 130 | stridesX0 = PyGpuArray_STRIDES(%(x)s)[0] / itemsize_x; 131 | stridesX1 = PyGpuArray_STRIDES(%(x)s)[1] / itemsize_x; 132 | stridesY0 = PyGpuArray_STRIDES(%(y)s)[0] / itemsize_y; 133 | 134 | if (row_off < dims[0] && col_off < dims[1]) { 135 | err = extract_call(1, &gs, &ls, 0, stridesX0, stridesX1, %(x)s->ga.data, %(x)s->ga.offset, stridesY0, %(y)s->ga.data, %(y)s->ga.offset, k, diag_size); 136 | if (err != GA_NO_ERROR) { 137 | PyErr_Format(PyExc_RuntimeError, "gpuarray error: kExtract: %%s. 
n%%lu, m=%%lu.", GpuKernel_error(&%(kname)s, err), (unsigned long)dims[0], (unsigned long)dims[1]); 138 | %(fail)s; 139 | } 140 | } else { 141 | %(fail)s; 142 | } 143 | """ % locals() 144 | return s 145 | 146 | def c_code_cache_version(self): 147 | return (1,) 148 | 149 | class GpuAllocDiag2D(GpuKernelBase, Op): 150 | """ 151 | Making a diagonal matrix from a vector on GPU 152 | 153 | """ 154 | __props__ = ('context_name',) 155 | _f16_ok = True 156 | 157 | def __init__(self, context_name=None): 158 | self.context_name = context_name 159 | 160 | def get_params(self, node): 161 | return get_context(self.context_name) 162 | 163 | def make_node(self, x, k=0, n=0, m=0): #TODO: dtype check 164 | x = as_gpuarray_variable(x, context_name=self.context_name) 165 | k = tensor.as_tensor_variable(k) 166 | n = tensor.as_tensor_variable(n) 167 | m = tensor.as_tensor_variable(m) 168 | assert x.ndim == 2 or x.ndim == 1 169 | assert k.ndim == 0 170 | assert n.ndim == 0 171 | assert m.ndim == 0 172 | otype = GpuArrayType(dtype=x.type.dtype, broadcastable=(False,False), context_name=self.context_name) 173 | return gof.Apply(self, [x, k, n, m], [otype()]) 174 | 175 | def infer_shape(self, node, in_shapes): 176 | in_shape, _, _, _ = in_shapes 177 | k, n, m = node.inputs[1:] 178 | dim_in = in_shape[0] 179 | dim_out1 = T.maximum(T.switch(T.ge(k,0), dim_in, dim_in-k), n) 180 | dim_out2 = T.maximum(T.switch(T.ge(k,0), dim_in+k, dim_in), m) 181 | return [(dim_out1, dim_out2)] 182 | 183 | def grad(self, inp, grads): 184 | return [GpuExtractDiag2D(keepdims=(inp[0].ndim==2))(grads[0], inp[1])] + [grad_not_implemented(self, i, inp[i]) for i in range(1,4)] 185 | 186 | def gpu_kernels(self, node, name): 187 | dtype_x = node.inputs[0].dtype 188 | type_x = gpuarray.dtype_to_ctype(dtype_x) 189 | dtype_y = node.outputs[0].dtype 190 | type_y = gpuarray.dtype_to_ctype(dtype_y) 191 | work_x = gpuarray.dtype_to_ctype(work_dtype(dtype_x)) 192 | load_x = load_w(dtype_x) 193 | write_y = write_w(dtype_y) 194 | code = """ 195 | #include "cluda.h" 196 | KERNEL void dalloc(const ga_ssize stridesX0, GLOBAL_MEM %(type_x)s *x, ga_size x_off, const ga_ssize stridesY0, const ga_ssize stridesY1, GLOBAL_MEM %(type_y)s *y, ga_size y_off, ga_ssize k, ga_size l) { 197 | x = (GLOBAL_MEM %(type_x)s *)(((GLOBAL_MEM char *)x) + x_off); 198 | y = (GLOBAL_MEM %(type_y)s *)(((GLOBAL_MEM char *)y) + y_off); 199 | ga_ssize coff = max(k, (ga_ssize) 0); 200 | ga_ssize roff = -min(k, (ga_ssize) 0); 201 | ga_size index = GID_0 * LDIM_0 + LID_0; 202 | if (index < l) { 203 | %(work_x)s t = %(load_x)s(x[index * stridesX0]); 204 | y[(index + roff) * stridesY0 + (index + coff) * stridesY1] = %(write_y)s(t); 205 | } 206 | }""" % dict(type_x=type_x, type_y=type_y, work_x=work_x, load_x=load_x, write_y=write_y, name=name) 207 | return [Kernel( 208 | code=code, name="dalloc", 209 | params=[gpuarray.SSIZE, gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SSIZE, gpuarray.SSIZE, gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SIZE, gpuarray.SIZE], 210 | flags=Kernel.get_flags(dtype_x, dtype_y), 211 | objvar='k_dalloc_' + name)] 212 | 213 | def c_headers(self): 214 | return ['', '', ''] 215 | 216 | def c_header_dirs(self): 217 | return [gpuarray_helper_inc_dir()] 218 | 219 | def c_code(self, node, name, inp, out, sub): #TODO: fix error msgs 220 | x, k, n, m = inp 221 | y, = out 222 | fail = sub['fail'] 223 | ctx = sub['params'] 224 | typecode = pygpu.gpuarray.dtype_to_typecode(node.inputs[0].dtype) 225 | kname = self.gpu_kernels(node, name)[0].objvar 226 | s = """ 227 | int err; 228 
| size_t ndims = (size_t)PyGpuArray_NDIM((PyGpuArrayObject*)%(x)s); 229 | size_t* in_dims = (size_t*)PyGpuArray_DIMS((PyGpuArrayObject*)%(x)s); 230 | size_t l = in_dims[0]; 231 | size_t ls = std::min(l, (size_t)1024); 232 | size_t gs = (l + ls - 1) / ls; 233 | size_t k = ((dtype_%(k)s*)PyArray_DATA(%(k)s))[0]; 234 | size_t n = ((dtype_%(n)s*)PyArray_DATA(%(n)s))[0]; 235 | size_t m = ((dtype_%(m)s*)PyArray_DATA(%(m)s))[0]; 236 | size_t out_dims[2] = {std::max(k < 0 ? (size_t)l-k : l, n), std::max(k > 0 ? (size_t)l+k : l, m)}; 237 | 238 | size_t itemsize_x = 1; 239 | size_t itemsize_y = 1; 240 | ssize_t stridesX0 = 1; 241 | ssize_t stridesY0 = 1; 242 | ssize_t stridesY1 = 1; 243 | 244 | if ((ndims == 2) && (in_dims[1] != 1)) { 245 | PyErr_Format(PyExc_RuntimeError, "If the input has 2 dimensions the second dimension must be of size 1. Input shape: (%%lu, %%lu)", (unsigned long)in_dims[0], (unsigned long)in_dims[1]); 246 | %(fail)s 247 | } 248 | 249 | Py_CLEAR(%(y)s); 250 | %(y)s = pygpu_zeros(2, out_dims, %(typecode)s, GA_C_ORDER, %(ctx)s, Py_None); //theano can reuse this space, thus we have to make sure to fill it with zeros every time 251 | 252 | if (%(y)s == NULL) { 253 | PyErr_Format(PyExc_RuntimeError, "Failed to allocate array for the output."); 254 | %(fail)s 255 | } 256 | 257 | itemsize_x = GpuArray_ITEMSIZE(&%(x)s->ga); 258 | itemsize_y = GpuArray_ITEMSIZE(&%(y)s->ga); 259 | stridesX0 = PyGpuArray_STRIDES(%(x)s)[0] / itemsize_x; 260 | stridesY0 = PyGpuArray_STRIDES(%(y)s)[0] / itemsize_y; 261 | stridesY1 = PyGpuArray_STRIDES(%(y)s)[1] / itemsize_y; 262 | 263 | err = dalloc_call(1, &gs, &ls, 0, stridesX0, %(x)s->ga.data, %(x)s->ga.offset, stridesY0, stridesY1, %(y)s->ga.data, %(y)s->ga.offset, k, l); 264 | if (err != GA_NO_ERROR) { 265 | PyErr_Format(PyExc_RuntimeError, "gpuarray error: kAlloc: %%s. 
n%%lu, m=%%lu.", GpuKernel_error(&%(kname)s, err), (unsigned long)out_dims[0], (unsigned long)out_dims[1]); 266 | %(fail)s; 267 | } 268 | 269 | """ % locals() 270 | return s 271 | 272 | def c_code_cache_version(self): 273 | return (1,) 274 | 275 | class GpuBinarySearchSorted(GpuKernelBase, Op): 276 | """ 277 | Searchsorted on GPU 278 | 279 | """ 280 | __props__ = ('context_name', 'dtype_int64') 281 | _f16_ok = True 282 | params_type = ParamsType(context=gpu_context_type, dtype_int64=bool_t) 283 | 284 | def __init__(self, context_name=None, dtype_int64=False): 285 | self.context_name = context_name 286 | self.dtype_int64 = dtype_int64 287 | 288 | def get_params(self, node): 289 | return self.params_type.get_params(self, context=get_context(self.context_name), dtype_int64=self.dtype_int64) 290 | 291 | def make_node(self, d, x): 292 | d = as_gpuarray_variable(d, context_name=self.context_name) 293 | x = as_gpuarray_variable(x, context_name=self.context_name) 294 | assert d.ndim == 1 295 | assert x.ndim == 1 296 | broadcastable = (False,) 297 | otype = GpuArrayType(dtype='int64' if self.dtype_int64 else 'int32', broadcastable=broadcastable, context_name=self.context_name) 298 | return gof.Apply(self, [d, x], [otype()]) 299 | 300 | def infer_shape(self, node, in_shapes): 301 | _, x_shape = in_shapes 302 | return [x_shape] 303 | 304 | def grad(self, inp, grads): 305 | return [grad_not_implemented(self, i, inp[i]) for i in range(2)] 306 | 307 | def gpu_kernels(self, node, name): 308 | dtype_d = node.inputs[0].dtype 309 | type_d = gpuarray.dtype_to_ctype(dtype_d) 310 | dtype_x = node.inputs[1].dtype 311 | type_x = gpuarray.dtype_to_ctype(dtype_x) 312 | dtype_y = node.outputs[0].dtype 313 | type_y = gpuarray.dtype_to_ctype(dtype_y) 314 | work_d = gpuarray.dtype_to_ctype(work_dtype(dtype_d)) 315 | load_d = load_w(dtype_d) 316 | work_x = gpuarray.dtype_to_ctype(work_dtype(dtype_x)) 317 | load_x = load_w(dtype_x) 318 | code = """ 319 | #include "cluda.h" 320 | KERNEL void binsearchsorted(const ga_ssize stridesD0, GLOBAL_MEM %(type_d)s *d, ga_size d_off, const ga_ssize stridesX0, GLOBAL_MEM %(type_x)s *x, ga_size x_off, const ga_ssize stridesY0, GLOBAL_MEM %(type_y)s *y, ga_size y_off, ga_size lx, ga_ssize ld) { 321 | d = (GLOBAL_MEM %(type_d)s *)(((GLOBAL_MEM char *)d) + d_off); 322 | x = (GLOBAL_MEM %(type_x)s *)(((GLOBAL_MEM char *)x) + x_off); 323 | y = (GLOBAL_MEM %(type_y)s *)(((GLOBAL_MEM char *)y) + y_off); 324 | ga_size index = threadIdx.x + blockIdx.x * blockDim.x; 325 | if (index < lx) { 326 | ga_long a = 0; 327 | ga_long b = (ga_long)(ld - 1); 328 | %(work_d)s minval = %(load_d)s(d[a]); 329 | %(work_d)s maxval = %(load_d)s(d[b * stridesD0]); 330 | %(work_x)s val = %(load_x)s(x[index * stridesX0]); 331 | if (val > maxval) { 332 | a = (ga_long)ld; 333 | b = (ga_long)ld; 334 | } else if (val <= minval) { 335 | a = 0; 336 | b = 0; 337 | } 338 | while (b - a > 0) { 339 | ga_long h = (b + a) / 2; 340 | %(work_d)s t = %(load_d)s(d[h * stridesD0]); 341 | if (val < t) { 342 | b = h; 343 | } else { 344 | a = h + 1; 345 | } 346 | } 347 | y[index * stridesY0] = b; 348 | } 349 | }""" % dict(type_d=type_d, type_x=type_x, type_y=type_y, work_d=work_d, load_d=load_d, work_x=work_x, load_x=load_x, name=name) 350 | return [Kernel( 351 | code=code, name="binsearchsorted", 352 | params=[gpuarray.SSIZE, gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SSIZE, gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SSIZE, gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SIZE, gpuarray.SSIZE], 353 | flags=Kernel.get_flags(dtype_d, 
dtype_x, dtype_y), 354 | objvar='k_binsearchsorted_' + name)] 355 | 356 | def c_headers(self): 357 | return ['', '', ''] 358 | 359 | def c_header_dirs(self): 360 | return [gpuarray_helper_inc_dir()] 361 | 362 | def c_code(self, node, name, inp, out, sub): #TODO: fix error msg 363 | d, x = inp 364 | y, = out 365 | fail = sub['fail'] 366 | params = sub['params'] 367 | typecode = pygpu.gpuarray.dtype_to_typecode(node.outputs[0].dtype) 368 | kname = self.gpu_kernels(node, name)[0].objvar 369 | s = """ 370 | int err; 371 | size_t dimd = ((size_t*)PyGpuArray_DIMS((PyGpuArrayObject*)%(d)s))[0]; 372 | size_t dimx = ((size_t*)PyGpuArray_DIMS((PyGpuArrayObject*)%(x)s))[0]; 373 | size_t ls = 1024; 374 | size_t gs = (dimx / 1024) + 1; 375 | size_t out_dims[1] = {dimx}; 376 | 377 | size_t itemsize_d = 1; 378 | size_t itemsize_x = 1; 379 | size_t itemsize_y = 1; 380 | ssize_t stridesD0 = 1; 381 | ssize_t stridesX0 = 1; 382 | ssize_t stridesY0 = 1; 383 | 384 | if (%(y)s == NULL || %(y)s->ga.nd != 1 || %(y)s->ga.dimensions[0] != dimx) { 385 | Py_CLEAR(%(y)s); 386 | %(y)s = pygpu_zeros(1, out_dims, %(typecode)s, GA_C_ORDER, %(params)s->context, Py_None); 387 | } 388 | if (%(y)s == NULL) { 389 | %(fail)s 390 | } 391 | 392 | itemsize_d = GpuArray_ITEMSIZE(&%(d)s->ga); 393 | itemsize_x = GpuArray_ITEMSIZE(&%(x)s->ga); 394 | itemsize_y = GpuArray_ITEMSIZE(&%(y)s->ga); 395 | stridesD0 = PyGpuArray_STRIDES(%(d)s)[0] / itemsize_d; 396 | stridesX0 = PyGpuArray_STRIDES(%(x)s)[0] / itemsize_x; 397 | stridesY0 = PyGpuArray_STRIDES(%(y)s)[0] / itemsize_y; 398 | err = binsearchsorted_call(1, &gs, &ls, 0, stridesD0, %(d)s->ga.data, %(d)s->ga.offset, stridesX0, %(x)s->ga.data, %(x)s->ga.offset, stridesY0, %(y)s->ga.data, %(y)s->ga.offset, dimx, (ssize_t)dimd); 399 | if (err != GA_NO_ERROR) { 400 | PyErr_Format(PyExc_RuntimeError, "gpuarray error: kExtract: %%s. 
n%%lu, m=%%lu.", GpuKernel_error(&%(kname)s, err), (unsigned long)dimx, (unsigned long)dimd); 401 | %(fail)s; 402 | } 403 | """ % locals() 404 | return s 405 | 406 | def c_code_cache_version(self): 407 | return (1,) 408 | 409 | class GpuAdvancedSubtensor1_fast(GpuKernelBase, GpuAdvancedSubtensor1): 410 | """ 411 | Implement a faster version AdvancedSubtensor1 on the gpu for 2D tensors 412 | 413 | """ 414 | _f16_ok = True 415 | 416 | def make_node(self, x, ilist): 417 | ctx_name = infer_context_name(x, ilist) 418 | x_ = as_gpuarray_variable(x, ctx_name) 419 | ilist_ = as_gpuarray_variable(ilist, ctx_name) 420 | 421 | if ilist_.type.dtype not in tensor.integer_dtypes: 422 | raise TypeError('index must be integers') 423 | if ilist_.type.ndim != 1: 424 | raise TypeError('index must be vector') 425 | if x_.type.ndim == 0: 426 | raise TypeError('cannot index into a scalar') 427 | return gof.Apply(self, [x_, ilist_], [x_.type()]) 428 | 429 | def perform(self, node, inp, out, params): 430 | return super(GpuAdvancedSubtensor1_fast, self).perform(node, inp, out) 431 | 432 | def c_code_cache_version(self): 433 | return (1,) 434 | 435 | def c_headers(self): 436 | return ['', '', 437 | ''] 438 | 439 | def c_header_dirs(self): 440 | return [gpuarray_helper_inc_dir()] 441 | 442 | def c_code(self, node, name, inputs, outputs, sub): 443 | if (node.inputs[0].ndim != 2): 444 | raise NotImplementedError("This case does not have C code yet.") 445 | 446 | return """ 447 | int err; 448 | if (%(out)s == NULL || !GpuArray_IS_C_CONTIGUOUS(&%(out)s->ga) || 449 | %(out)s->ga.dimensions[0] != %(idx)s->ga.dimensions[0] || 450 | %(out)s->ga.nd != %(v)s->ga.nd || %(out)s->ga.dimensions[1] != %(v)s->ga.dimensions[1]) { 451 | size_t tmp; 452 | Py_XDECREF(%(out)s); 453 | 454 | /* This is a dirty hack to avoid an extra alloc */ 455 | tmp = %(v)s->ga.dimensions[0]; 456 | %(v)s->ga.dimensions[0] = %(idx)s->ga.dimensions[0]; 457 | %(out)s = pygpu_empty(%(v)s->ga.nd, %(v)s->ga.dimensions, %(v)s->ga.typecode, 458 | GA_C_ORDER, %(v)s->context, Py_None); 459 | if (%(out)s == NULL) { 460 | %(fail)s; 461 | } 462 | %(v)s->ga.dimensions[0] = tmp; // Don't remove this line 463 | } 464 | if (GpuArray_vector_select_fast(%(out)s, %(v)s, %(idx)s)) { 465 | %(fail)s 466 | } 467 | """ % dict(v=inputs[0], idx=inputs[1], out=outputs[0], fail=sub['fail']) 468 | 469 | def gpu_kernels(self, node, nodename): 470 | CHARMAP = dict(int32='i', uint32='I', 471 | int64='l', uint64='L', 472 | float16='e', float32='f', float64='d') 473 | dtype_in = node.inputs[0].dtype 474 | dtype_out = node.outputs[0].dtype 475 | dtype_idx = node.inputs[1].dtype 476 | type_in = gpuarray.dtype_to_ctype(dtype_in) 477 | type_out = gpuarray.dtype_to_ctype(dtype_out) 478 | type_idx = gpuarray.dtype_to_ctype(dtype_idx) 479 | flags = Kernel.get_flags(dtype_in, dtype_out, dtype_idx) 480 | kname = "k_vector_select_fast" 481 | k_var = "k_vector_select_fast_" + nodename 482 | code = """#include "cluda.h" 483 | KERNEL void k_vector_select_fast(const ga_size numRowsOut, 484 | const ga_size numColsOut, 485 | const ga_ssize stridesOut0, 486 | const ga_ssize stridesOut1, 487 | GLOBAL_MEM %(type_out)s *Out, 488 | const ga_size offset_Out, 489 | const ga_size numRowsIn, 490 | const ga_size numColsIn, 491 | const ga_ssize stridesIn0, 492 | const ga_ssize stridesIn1, 493 | GLOBAL_MEM %(type_in)s *In, 494 | const ga_size offset_In, 495 | const ga_size numIndices, 496 | const ga_ssize stridesIndices, 497 | GLOBAL_MEM %(type_idx)s *indices_arr, 498 | const ga_size offset_indices_arr, 499 | 
GLOBAL_MEM ga_int *err) 500 | { 501 | Out = (GLOBAL_MEM %(type_out)s *)(((GLOBAL_MEM char *)Out)+offset_Out); 502 | In = (GLOBAL_MEM %(type_in)s *)(((GLOBAL_MEM char *)In)+offset_In); 503 | indices_arr = (GLOBAL_MEM %(type_idx)s *)(((GLOBAL_MEM char *)indices_arr)+offset_indices_arr); 504 | 505 | for (ga_int i = GID_0; i < numIndices; i += GDIM_0) 506 | { 507 | for (ga_int j = LID_0; j < numColsIn; j += LDIM_0) 508 | { 509 | ga_ssize in_row = indices_arr[i * stridesIndices]; 510 | if (in_row < 0) 511 | in_row += numRowsIn; 512 | ga_ssize out_row = i; 513 | if (in_row < numRowsIn && in_row >= 0) { 514 | Out[(out_row * stridesOut0) + (j * stridesOut1)] = In[(in_row * stridesIn0) + (j * stridesIn1)]; 515 | } else { 516 | *err = 1; 517 | } 518 | } 519 | } 520 | return; 521 | } 522 | """ % dict(type_in=type_in, type_out=type_out, type_idx=type_idx, 523 | tc=CHARMAP[dtype_in]) 524 | from pygpu.gpuarray import SIZE, SSIZE 525 | params = [ 526 | SIZE, SIZE, SSIZE, SSIZE, gpuarray.GpuArray, SIZE, 527 | SIZE, SIZE, SSIZE, SSIZE, gpuarray.GpuArray, SIZE, 528 | SIZE, SSIZE, gpuarray.GpuArray, SIZE, 529 | gpuarray.GpuArray] 530 | return [Kernel(code=code, name=kname, params=params, 531 | flags=flags, objvar=k_var)] 532 | 533 | def c_support_code_struct(self, node, nodename): 534 | return super(GpuAdvancedSubtensor1_fast, self).c_support_code_struct(node, nodename) + """ 535 | int GpuArray_vector_select_fast(PyGpuArrayObject* py_out, 536 | PyGpuArrayObject* py_in, 537 | PyGpuArrayObject* indices_arr) 538 | { 539 | size_t threads_per_block = std::min(PyGpuArray_DIMS(py_out)[1], (size_t)256); 540 | size_t n_blocks = std::min(PyGpuArray_SIZE(indices_arr), (size_t)4096); 541 | gpudata *errbuf; 542 | int err, kerr = 0; 543 | size_t itemsize_out = GpuArray_ITEMSIZE(&py_out->ga); 544 | size_t itemsize_in = GpuArray_ITEMSIZE(&py_in->ga); 545 | size_t itemsize_idx = GpuArray_ITEMSIZE(&indices_arr->ga); 546 | 547 | if (threads_per_block > 0 && n_blocks > 0) { 548 | err = gpudata_property(py_out->ga.data, 549 | GA_CTX_PROP_ERRBUF, &errbuf); 550 | if (err != GA_NO_ERROR) { 551 | PyErr_SetString(PyExc_RuntimeError, "Can't fetch error buffer"); 552 | return 1; 553 | } 554 | 555 | err = k_vector_select_fast_call( 556 | 1, &n_blocks, &threads_per_block, 0, 557 | PyGpuArray_DIMS(py_out)[0], 558 | PyGpuArray_DIMS(py_out)[1], 559 | PyGpuArray_STRIDES(py_out)[0] / itemsize_out, 560 | PyGpuArray_STRIDES(py_out)[1] / itemsize_out, 561 | py_out->ga.data, 562 | py_out->ga.offset, 563 | PyGpuArray_DIMS(py_in)[0], 564 | PyGpuArray_DIMS(py_in)[1], 565 | PyGpuArray_DIMS(py_in)[0] == 1 ? 0 : PyGpuArray_STRIDES(py_in)[0] / itemsize_in, 566 | PyGpuArray_DIMS(py_in)[1] == 1 ? 
0 : PyGpuArray_STRIDES(py_in)[1] / itemsize_in, 567 | py_in->ga.data, 568 | py_in->ga.offset, 569 | PyGpuArray_DIMS(indices_arr)[0], 570 | PyGpuArray_STRIDES(indices_arr)[0] / itemsize_idx, 571 | indices_arr->ga.data, 572 | indices_arr->ga.offset, 573 | errbuf); 574 | 575 | if (err != GA_NO_ERROR) { 576 | PyErr_Format(PyExc_RuntimeError, 577 | "gpuarray error: %(k_var)s: %%s.", 578 | GpuKernel_error(&%(k_var)s, err)); 579 | return 1; 580 | } 581 | err = gpudata_read(&kerr, errbuf, 0, sizeof(int)); 582 | if (err != GA_NO_ERROR) { 583 | PyErr_SetString(PyExc_RuntimeError, "Can't read error buffer"); 584 | return 1; 585 | } 586 | if (kerr != 0) { 587 | PyErr_SetString(PyExc_IndexError, "Index out of bounds"); 588 | kerr = 0; 589 | gpudata_write(errbuf, 0, &kerr, sizeof(int)); 590 | return 1; 591 | } 592 | } 593 | return 0; 594 | } 595 | """ % dict(k_var="k_vector_select_fast_" + nodename) 596 | -------------------------------------------------------------------------------- /datatools.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed May 6 18:20:08 2020 4 | 5 | @author: Hidasi Balázs 6 | """ 7 | 8 | import time 9 | import numpy as np 10 | import pandas as pd 11 | 12 | def sort_if_needed(data, columns, any_order_first_dim=False): 13 | is_sorted = True 14 | neq_masks = [] 15 | for i, col in enumerate(columns): 16 | dcol = data[col] 17 | neq_masks.append(dcol.values[1:]!=dcol.values[:-1]) 18 | if i == 0: 19 | if any_order_first_dim: 20 | is_sorted = is_sorted and (dcol.nunique() == neq_masks[0].sum() + 1) 21 | else: 22 | is_sorted = is_sorted and np.all(dcol.values[1:] >= dcol.values[:-1]) 23 | else: 24 | is_sorted = is_sorted and np.all(neq_masks[i - 1] | (dcol.values[1:] >= dcol.values[:-1])) 25 | if not is_sorted: 26 | break 27 | if is_sorted: 28 | print('The dataframe is already sorted by {}'.format(', '.join(columns))) 29 | else: 30 | print('The dataframe is not sorted by {}, sorting now'.format(col)) 31 | t0 = time.time() 32 | data.sort_values(columns, inplace=True) 33 | t1 = time.time() 34 | print('Data is sorted in {:.2f}'.format(t1 - t0)) 35 | 36 | def compute_offset(data, column): 37 | offset = np.zeros(data[column].nunique() + 1, dtype=np.int32) 38 | offset[1:] = data.groupby(column).size().cumsum() 39 | return offset 40 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Jun 26 17:27:26 2015 4 | 5 | @author: Balázs Hidasi 6 | """ 7 | 8 | import numpy as np 9 | import pandas as pd 10 | from collections import OrderedDict 11 | import theano 12 | from theano import tensor as T 13 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams 14 | 15 | def evaluate_gpu(gru, test_data, items=None, session_key='SessionId', item_key='ItemId', time_key='Time', cut_off=[20], batch_size=100, mode='standard'): 16 | ''' 17 | Evaluates the GRU4Rec network quickly wrt. recommendation accuracy measured by recall@N and MRR@N. 18 | 19 | Parameters 20 | -------- 21 | pr : gru4rec.GRU4Rec 22 | A trained instance of the GRU4Rec network. 23 | test_data : pandas.DataFrame 24 | Test data. It contains the transactions of the test set.It has one column for session IDs, one for item IDs and one for the timestamp of the events (unix timestamps). 25 | It must have a header. 
Column names are arbitrary, but must correspond to the keys you use in this function. 26 | items : 1D list or None 27 | The list of item ID that you want to compare the score of the relevant item to. If None, all items of the training set are used. Default value is None. 28 | session_key : string 29 | Header of the session ID column in the input file (default: 'SessionId') 30 | item_key : string 31 | Header of the item ID column in the input file (default: 'ItemId') 32 | time_key : string 33 | Header of the timestamp column in the input file (default: 'Time') 34 | cut-off : int 35 | Cut-off value (i.e. the length of the recommendation list; N for recall@N and MRR@N). Defauld value is 20. 36 | batch_size : int 37 | Number of events bundled into a batch during evaluation. Speeds up evaluation. If it is set high, the memory consumption increases. Default value is 100. 38 | mode : 'standard', 'conservative', 'median', 'tiebreaking' 39 | Sets how ties (the exact same prediction scores) should be handled. Note that ties produced by GRU4Rec are very often a sign of saturation or some kind of error. 'standard' -> the positive item is ranked above all negatives with the same score; 'conservative' -> the positive item is ranked below all the negative items with the same score; 'median' -> assume that half of the negative items with the same score as the positive item are ranked before and the other half is ranked after, somewhat slower than the previous two; 'tiebreaking' -> add a small random value to every predicted score to break up ties, slowest of the modes. Default: 'standard' 40 | 41 | Returns 42 | -------- 43 | out : tuple 44 | (Recall@N, MRR@N) 45 | 46 | ''' 47 | if gru.error_during_train: raise Exception 48 | multi_cut_off = (type(cut_off) == list) or (type(cut_off) == tuple) 49 | print('Measuring Recall@{} and MRR@{}'.format(','.join([str(c) for c in cut_off]), ','.join([str(c) for c in cut_off]))) 50 | srng = RandomStreams() 51 | X = T.ivector() 52 | Y = T.ivector() 53 | M = T.iscalar() 54 | yhat, H, updatesH = gru.symbolic_predict(X, Y, M, items, batch_size) 55 | if mode == 'tiebreaking': yhat += srng.uniform(size=yhat.shape) * 1e-10 56 | if items is None: 57 | targets = T.diag(yhat.T[Y]) 58 | others = yhat.T 59 | else: 60 | targets = T.diag(yhat.T[:M]) 61 | others = yhat.T[M:] 62 | if mode == 'standard': ranks = (others > targets).sum(axis=0) + 1 63 | elif mode == 'conservative': ranks = (others >= targets).sum(axis=0) 64 | elif mode == 'median': ranks = (others > targets).sum(axis=0) + 0.5*((others == targets).sum(axis=0) - 1) + 1 65 | elif mode == 'tiebreaking': ranks = (others > targets).sum(axis=0) + 1 66 | else: raise NotImplementedError 67 | REC = [] 68 | MRR = [] 69 | if multi_cut_off: 70 | for c in cut_off: 71 | REC.append((ranks <= c).sum()) 72 | MRR.append(((ranks <= c) / ranks).sum()) 73 | else: 74 | REC.append((ranks <= cut_off).sum()) 75 | MRR.append(((ranks <= cut_off) / ranks).sum()) 76 | evaluate = theano.function(inputs=[X, Y, M], outputs=REC+MRR, updates=updatesH, allow_input_downcast=True, on_unused_input='ignore') 77 | test_data = pd.merge(test_data, pd.DataFrame({'ItemIdx':gru.itemidmap.values, item_key:gru.itemidmap.index}), on=item_key, how='inner') 78 | test_data.sort_values([session_key, time_key, item_key], inplace=True) 79 | test_data_items = test_data.ItemIdx.values 80 | if items is not None: 81 | item_idxs = gru.itemidmap[items] 82 | recall, mrr, n = [], [], 0 83 | if multi_cut_off: 84 | for i in range(len(cut_off)): 85 | recall.append(0) 86 | mrr.append(0) 
87 | else: 88 | recall.append(0) 89 | mrr.append(0) 90 | iters = np.arange(batch_size) 91 | maxiter = iters.max() 92 | offset_sessions = np.zeros(test_data[session_key].nunique()+1, dtype=np.int32) 93 | offset_sessions[1:] = test_data.groupby(session_key).size().cumsum() 94 | start = offset_sessions[iters] 95 | end = offset_sessions[iters+1] 96 | finished = False 97 | cidxs = [] 98 | while not finished: 99 | minlen = (end-start).min() 100 | out_idx = test_data_items[start] 101 | for i in range(minlen-1): 102 | in_idx = out_idx 103 | out_idx = test_data_items[start+i+1] 104 | if items is not None: 105 | y = np.hstack([out_idx, item_idxs]) 106 | else: 107 | y = out_idx 108 | results = evaluate(in_idx, y, len(iters), *cidxs) 109 | if multi_cut_off: 110 | for j in range(len(cut_off)): 111 | recall[j] += results[j] 112 | mrr[j] += results[j + len(cut_off)] 113 | else: 114 | recall[0] += results[0] 115 | mrr[0] += results[1] 116 | n += len(iters) 117 | start = start+minlen-1 118 | finished_mask = (end-start<=1) 119 | n_finished = finished_mask.sum() 120 | iters[finished_mask] = maxiter + np.arange(1,n_finished+1) 121 | maxiter += n_finished 122 | valid_mask = (iters < len(offset_sessions)-1) 123 | n_valid = valid_mask.sum() 124 | if n_valid == 0: 125 | finished = True 126 | break 127 | mask = finished_mask & valid_mask 128 | sessions = iters[mask] 129 | start[mask] = offset_sessions[sessions] 130 | end[mask] = offset_sessions[sessions+1] 131 | iters = iters[valid_mask] 132 | start = start[valid_mask] 133 | end = end[valid_mask] 134 | if valid_mask.any(): 135 | for i in range(len(H)): 136 | tmp = H[i].get_value(borrow=True) 137 | tmp[mask] = 0 138 | tmp = tmp[valid_mask] 139 | H[i].set_value(tmp, borrow=True) 140 | if multi_cut_off: 141 | for i in range(len(cut_off)): 142 | recall[i] /= n 143 | mrr[i] /= n 144 | else: 145 | recall[0] /= n 146 | mrr[0] /= n 147 | return recall, mrr 148 | 149 | def evaluate_sessions_batch(pr, test_data, items=None, cut_off=20, batch_size=100, mode='standard', session_key='SessionId', item_key='ItemId', time_key='Time'): 150 | ''' 151 | Legacy (slow) method for evaluating the GRU4Rec network wrt. recommendation accuracy measured by recall@N and MRR@N. 152 | 153 | Parameters 154 | -------- 155 | pr : gru4rec.GRU4Rec 156 | A trained instance of the GRU4Rec network. 157 | test_data : pandas.DataFrame 158 | Test data. It contains the transactions of the test set.It has one column for session IDs, one for item IDs and one for the timestamp of the events (unix timestamps). 159 | It must have a header. Column names are arbitrary, but must correspond to the keys you use in this function. 160 | items : 1D list or None 161 | The list of item ID that you want to compare the score of the relevant item to. If None, all items of the training set are used. Default value is None. 162 | cut-off : int 163 | Cut-off value (i.e. the length of the recommendation list; N for recall@N and MRR@N). Defauld value is 20. 164 | batch_size : int 165 | Number of events bundled into a batch during evaluation. Speeds up evaluation. If it is set high, the memory consumption increases. Default value is 100. 166 | mode : 'standard', 'conservative', 'median', 'tiebreaking' 167 | Sets how ties (the exact same prediction scores) should be handled. Note that ties produced by GRU4Rec are very often a sign of saturation or some kind of error. 
'standard' -> the positive item is ranked above all negatives with the same score; 'conservative' -> the positive item is ranked below all the negative items with the same score; 'median' -> assume that half of the negative items with the same score as the positive item are ranked before and the other half is ranked after, somewhat slower than the previous two; 'tiebreaking' -> add a small random value to every predicted score to break up ties, slowest of the modes. Default: 'standard' 168 | session_key : string 169 | Header of the session ID column in the input file (default: 'SessionId') 170 | item_key : string 171 | Header of the item ID column in the input file (default: 'ItemId') 172 | time_key : string 173 | Header of the timestamp column in the input file (default: 'Time') 174 | 175 | Returns 176 | -------- 177 | out : tuple 178 | (Recall@N, MRR@N) 179 | 180 | ''' 181 | print('Measuring Recall@{} and MRR@{}'.format(cut_off, cut_off)) 182 | test_data = pd.merge(test_data, pd.DataFrame({'ItemIdx':pr.itemidmap.values, item_key:pr.itemidmap.index}), on=item_key, how='inner') 183 | test_data.sort_values([session_key, time_key, item_key], inplace=True) 184 | offset_sessions = np.zeros(test_data[session_key].nunique()+1, dtype=np.int32) 185 | offset_sessions[1:] = test_data.groupby(session_key).size().cumsum() 186 | evalutation_point_count = 0 187 | mrr, recall = 0.0, 0.0 188 | if len(offset_sessions) - 1 < batch_size: 189 | batch_size = len(offset_sessions) - 1 190 | iters = np.arange(batch_size).astype(np.int32) 191 | #pos = np.zeros(min(batch_size, len(session_idx_arr))).astype(np.int32) 192 | maxiter = iters.max() 193 | start = offset_sessions[iters] 194 | end = offset_sessions[iters+1] 195 | in_idx = np.zeros(batch_size, dtype=np.int32) 196 | sampled_items = (items is not None) 197 | while True: 198 | valid_mask = iters >= 0 199 | if valid_mask.sum() == 0: 200 | break 201 | start_valid = start[valid_mask] 202 | minlen = (end[valid_mask]-start_valid).min() 203 | in_idx[valid_mask] = test_data[item_key].values[start_valid] 204 | for i in range(minlen-1): 205 | out_idx = test_data[item_key].values[start_valid+i+1] 206 | if sampled_items: 207 | uniq_out = np.unique(np.array(out_idx, dtype=np.int32)) 208 | preds = pr.predict_next_batch(iters, in_idx, np.hstack([items, uniq_out[~np.in1d(uniq_out,items)]]), batch_size) 209 | else: 210 | preds = pr.predict_next_batch(iters, in_idx, None, batch_size) #TODO: Handling sampling? 
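            # predict_next_batch returns scores as a pandas DataFrame indexed by
            # item ID (one column per session of the batch); missing scores are
            # treated as zero below before computing the ranks.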
211 |             preds.fillna(0, inplace=True)
212 |             in_idx[valid_mask] = out_idx
213 |             if mode == 'tiebreaking':
214 |                 preds += 1e-10 * np.random.rand(*preds.values.shape)
215 |             if sampled_items:
216 |                 others = preds.loc[items].values.T[valid_mask].T
217 |                 targets = np.diag(preds.loc[in_idx].values)[valid_mask]
218 |                 if mode == 'standard': ranks = (others > targets).sum(axis=0) + 1
219 |                 elif mode == 'conservative': ranks = (others >= targets).sum(axis=0)
220 |                 elif mode == 'median': ranks = (others > targets).sum(axis=0) + 0.5*((others == targets).sum(axis=0) - 1) + 1
221 |                 elif mode == 'tiebreaking': ranks = (others > targets).sum(axis=0) + 1
222 |                 else: raise NotImplementedError
223 |             else:
224 |                 if mode == 'standard': ranks = (preds.values.T[valid_mask].T > np.diag(preds.loc[in_idx].values)[valid_mask]).sum(axis=0) + 1
225 |                 elif mode == 'conservative': ranks = (preds.values.T[valid_mask].T >= np.diag(preds.loc[in_idx].values)[valid_mask]).sum(axis=0)
226 |                 elif mode == 'median': ranks = (preds.values.T[valid_mask].T > np.diag(preds.loc[in_idx].values)[valid_mask]).sum(axis=0) + 0.5*((preds.values.T[valid_mask].T == np.diag(preds.loc[in_idx].values)[valid_mask]).sum(axis=0) - 1) + 1
227 |                 elif mode == 'tiebreaking': ranks = (preds.values.T[valid_mask].T > np.diag(preds.loc[in_idx].values)[valid_mask]).sum(axis=0) + 1
228 |                 else: raise NotImplementedError
229 |             rank_ok = ranks <= cut_off
230 |             recall += rank_ok.sum()
231 |             mrr += ((1.0 / ranks) * (rank_ok)).sum()
232 |             evalutation_point_count += len(ranks)
233 |             #pos += 1
234 |         start = start+minlen-1
235 |         mask = np.arange(len(iters))[(valid_mask) & (end-start<=1)]
236 |         for idx in mask:
237 |             maxiter += 1
238 |             if maxiter >= len(offset_sessions)-1:
239 |                 iters[idx] = -1
240 |             else:
241 |                 #pos[idx] = 0
242 |                 iters[idx] = maxiter
243 |                 start[idx] = offset_sessions[maxiter]
244 |                 end[idx] = offset_sessions[maxiter+1]
245 |     return recall/evalutation_point_count, mrr/evalutation_point_count
246 | 
247 | def evaluate_sessions(pr, test_data, train_data, items=None, cut_off=20, session_key='SessionId', item_key='ItemId', time_key='Time'):
248 |     '''
249 |     Evaluates the baselines wrt. recommendation accuracy measured by recall@N and MRR@N. Has no batch evaluation capabilities. Breaks up ties.
250 | 
251 |     Parameters
252 |     --------
253 |     pr : baseline predictor
254 |         A trained instance of a baseline predictor.
255 |     test_data : pandas.DataFrame
256 |         Test data. It contains the transactions of the test set. It has one column for session IDs, one for item IDs and one for the timestamp of the events (unix timestamps).
257 |         It must have a header. Column names are arbitrary, but must correspond to the keys you use in this function.
258 |     train_data : pandas.DataFrame
259 |         Training data. Only required for selecting the set of item IDs of the training set.
260 |     items : 1D list or None
261 |         The list of item IDs that you want to compare the score of the relevant item to. If None, all items of the training set are used. Default value is None.
262 |     cut_off : int
263 |         Cut-off value (i.e. the length of the recommendation list; N for recall@N and MRR@N). Default value is 20.
264 |     session_key : string
265 |         Header of the session ID column in the input file (default: 'SessionId')
266 |     item_key : string
267 |         Header of the item ID column in the input file (default: 'ItemId')
268 |     time_key : string
269 |         Header of the timestamp column in the input file (default: 'Time')
270 | 
271 |     Returns
272 |     --------
273 |     out : tuple
274 |         (Recall@N, MRR@N)
275 | 
276 |     '''
277 |     test_data.sort_values([session_key, time_key], inplace=True)
278 |     items_to_predict = train_data[item_key].unique()
279 |     evalutation_point_count = 0
280 |     prev_iid, prev_sid = -1, -1
281 |     mrr, recall = 0.0, 0.0
282 |     for i in range(len(test_data)):
283 |         sid = test_data[session_key].values[i]
284 |         iid = test_data[item_key].values[i]
285 |         if prev_sid != sid:
286 |             prev_sid = sid
287 |         else:
288 |             if items is not None:
289 |                 if np.in1d(iid, items): items_to_predict = items
290 |                 else: items_to_predict = np.hstack(([iid], items))
291 |             preds = pr.predict_next(sid, prev_iid, items_to_predict)
292 |             preds[np.isnan(preds)] = 0
293 |             preds += 1e-8 * np.random.rand(len(preds)) #Breaking up ties
294 |             rank = (preds > preds[iid]).sum()+1
295 |             assert rank > 0
296 |             if rank <= cut_off: #a hit if the target is within the top cut_off items, as in the batch evaluators
297 |                 recall += 1
298 |                 mrr += 1.0/rank
299 |             evalutation_point_count += 1
300 |         prev_iid = iid
301 |     return recall/evalutation_point_count, mrr/evalutation_point_count
--------------------------------------------------------------------------------
/examples/rsc15/preprocess.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | """
 3 | Created on Fri Jun 25 16:20:12 2015
 4 | 
 5 | @author: Balázs Hidasi
 6 | """
 7 | 
 8 | import numpy as np
 9 | import pandas as pd
10 | import datetime as dt
11 | 
12 | PATH_TO_ORIGINAL_DATA = '/path/to/clicks/dat/file/'
13 | PATH_TO_PROCESSED_DATA = '/path/to/store/processed/data/'
14 | 
15 | data = pd.read_csv(PATH_TO_ORIGINAL_DATA + 'yoochoose-clicks.dat', sep=',', header=None, usecols=[0,1,2], dtype={0:np.int32, 1:str, 2:np.int64})
16 | data.columns = ['SessionId', 'TimeStr', 'ItemId']
17 | data['Time'] = data.TimeStr.apply(lambda x: dt.datetime.strptime(x, '%Y-%m-%dT%H:%M:%S.%fZ').timestamp()) #This is not UTC. It does not really matter.
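# The support filters below are applied in sequence: dropping length-1 sessions
# can push items below the 5-event support threshold, and dropping rare items
# can shorten sessions again, hence the second pass over the session lengths.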
18 | del(data['TimeStr']) 19 | 20 | session_lengths = data.groupby('SessionId').size() 21 | data = data[np.in1d(data.SessionId, session_lengths[session_lengths>1].index)] 22 | item_supports = data.groupby('ItemId').size() 23 | data = data[np.in1d(data.ItemId, item_supports[item_supports>=5].index)] 24 | session_lengths = data.groupby('SessionId').size() 25 | data = data[np.in1d(data.SessionId, session_lengths[session_lengths>=2].index)] 26 | 27 | tmax = data.Time.max() 28 | session_max_times = data.groupby('SessionId').Time.max() 29 | session_train = session_max_times[session_max_times < tmax-86400].index 30 | session_test = session_max_times[session_max_times >= tmax-86400].index 31 | train = data[np.in1d(data.SessionId, session_train)] 32 | test = data[np.in1d(data.SessionId, session_test)] 33 | test = test[np.in1d(test.ItemId, train.ItemId)] 34 | tslength = test.groupby('SessionId').size() 35 | test = test[np.in1d(test.SessionId, tslength[tslength>=2].index)] 36 | print('Full train set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format(len(train), train.SessionId.nunique(), train.ItemId.nunique())) 37 | train.to_csv(PATH_TO_PROCESSED_DATA + 'rsc15_train_full.txt', sep='\t', index=False) 38 | print('Test set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format(len(test), test.SessionId.nunique(), test.ItemId.nunique())) 39 | test.to_csv(PATH_TO_PROCESSED_DATA + 'rsc15_test.txt', sep='\t', index=False) 40 | 41 | tmax = train.Time.max() 42 | session_max_times = train.groupby('SessionId').Time.max() 43 | session_train = session_max_times[session_max_times < tmax-86400].index 44 | session_valid = session_max_times[session_max_times >= tmax-86400].index 45 | train_tr = train[np.in1d(train.SessionId, session_train)] 46 | valid = train[np.in1d(train.SessionId, session_valid)] 47 | valid = valid[np.in1d(valid.ItemId, train_tr.ItemId)] 48 | tslength = valid.groupby('SessionId').size() 49 | valid = valid[np.in1d(valid.SessionId, tslength[tslength>=2].index)] 50 | print('Train set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format(len(train_tr), train_tr.SessionId.nunique(), train_tr.ItemId.nunique())) 51 | train_tr.to_csv(PATH_TO_PROCESSED_DATA + 'rsc15_train_tr.txt', sep='\t', index=False) 52 | print('Validation set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format(len(valid), valid.SessionId.nunique(), valid.ItemId.nunique())) 53 | valid.to_csv(PATH_TO_PROCESSED_DATA + 'rsc15_train_valid.txt', sep='\t', index=False) 54 | -------------------------------------------------------------------------------- /examples/rsc15/run_rsc15.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Apr 6 18:14:46 2016 4 | 5 | @author: Balázs Hidasi 6 | """ 7 | 8 | import sys 9 | sys.path.append('../..') 10 | 11 | import numpy as np 12 | import pandas as pd 13 | import gru4rec 14 | import evaluation 15 | 16 | PATH_TO_TRAIN = '/db_vol/hb_work/rnn/data/processed/recsys_challenge_train_full.txt' 17 | PATH_TO_TEST = '/db_vol/hb_work/rnn/data/processed/recsys_challenge_test.txt' 18 | 19 | if __name__ == '__main__': 20 | data = pd.read_csv(PATH_TO_TRAIN, sep='\t', dtype={'ItemId':np.int64}) 21 | valid = pd.read_csv(PATH_TO_TEST, sep='\t', dtype={'ItemId':np.int64}) 22 | 23 | #State-of-the-art results on RSC15 from "Recurrent Neural Networks with Top-k Gains for Session-based Recommendations" on RSC15 (http://arxiv.org/abs/1706.03847) 24 | #BPR-max, no embedding (R@20 = 0.7197, M@20 = 0.3157) 25 | gru = gru4rec.GRU4Rec(loss='bpr-max', 
final_act='elu-0.5', hidden_act='tanh', layers=[100], adapt='adagrad', n_epochs=10, batch_size=32, dropout_p_embed=0, dropout_p_hidden=0, learning_rate=0.2, momentum=0.3, n_sample=2048, sample_alpha=0, bpreg=1, constrained_embedding=False) 26 | gru.fit(data) 27 | res = evaluation.evaluate_gpu(gru, valid) 28 | print('Recall@20: {}'.format(res[0])) 29 | print('MRR@20: {}'.format(res[1])) 30 | 31 | #BPR-max, constrained embedding (R@20 = 0.7261, M@20 = 0.3124) 32 | gru = gru4rec.GRU4Rec(loss='bpr-max', final_act='elu-0.5', hidden_act='tanh', layers=[100], adapt='adagrad', n_epochs=10, batch_size=32, dropout_p_embed=0, dropout_p_hidden=0, learning_rate=0.2, momentum=0.1, n_sample=2048, sample_alpha=0, bpreg=0.5, constrained_embedding=True) 33 | gru.fit(data) 34 | res = evaluation.evaluate_gpu(gru, valid) 35 | print('Recall@20: {}'.format(res[0])) 36 | print('MRR@20: {}'.format(res[1])) 37 | 38 | #Cross-entropy (R@20 = 0.7180, M@20 = 0.3087) 39 | gru = gru4rec.GRU4Rec(loss='cross-entropy', final_act='softmax', hidden_act='tanh', layers=[100], adapt='adagrad', n_epochs=10, batch_size=32, dropout_p_embed=0, dropout_p_hidden=0.3, learning_rate=0.1, momentum=0.7, n_sample=2048, sample_alpha=0, bpreg=0, constrained_embedding=False) 40 | gru.fit(data) 41 | res = evaluation.evaluate_gpu(gru, valid) 42 | print('Recall@20: {}'.format(res[0])) 43 | print('MRR@20: {}'.format(res[1])) 44 | 45 | #OUTDATED!!! 46 | #Reproducing results from the original paperr"Session-based Recommendations with Recurrent Neural Networks" on RSC15 (http://arxiv.org/abs/1511.06939) 47 | #print('Training GRU4Rec with 100 hidden units') 48 | #gru = gru4rec.GRU4Rec(loss='top1', final_act='tanh', hidden_act='tanh', layers=[100], batch_size=50, dropout_p_hidden=0.5, learning_rate=0.01, momentum=0.0, time_sort=False) 49 | #gru.fit(data) 50 | #res = evaluation.evaluate_gpu(gru, valid) 51 | #print('Recall@20: {}'.format(res[0])) 52 | #print('MRR@20: {}'.format(res[1])) 53 | -------------------------------------------------------------------------------- /gpu_ops.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Nov 10 14:17:58 2017 4 | 5 | @author: Balázs Hidasi 6 | """ 7 | 8 | import theano 9 | from theano import tensor as T 10 | import numpy as np 11 | import custom_theano_ops as cto 12 | 13 | disable_custom_op = False 14 | 15 | def gpu_diag_wide(X, keepdims=False): 16 | E = T.eye(*X.shape) 17 | return T.sum(X*E, axis=1, keepdims=keepdims) 18 | 19 | def gpu_diag_tall(X, keepdims=False): 20 | E = T.eye(*X.shape) 21 | return T.sum(X*E, axis=0, keepdims=keepdims) 22 | 23 | def gpu_diag(X, keepdims=False, disable_custom_op=disable_custom_op): 24 | if disable_custom_op: 25 | return T.switch(T.gt(X.shape[0], X.shape[1]), gpu_diag_tall(X, keepdims), gpu_diag_wide(X, keepdims)) 26 | else: 27 | return cto.GpuExtractDiag2D(keepdims=keepdims)(X) 28 | 29 | def gpu_searchsorted_step(A, B, X, P): 30 | I = (A+B)//2 31 | PI = P[I] 32 | return A*(X=PI), B*(X>PI)+I*(X<=PI) 33 | 34 | def gpu_searchsorted_scan(P, X): 35 | N = T.cast(T.floor(T.log2(P.shape[0]))+1, 'int64') 36 | (_, B), _ = theano.scan(gpu_searchsorted_step, outputs_info=[T.zeros_like(X, dtype='int64'), T.ones_like(X, dtype='int64')*(P.shape[0]-1)], non_sequences=[X, P], n_steps=N, allow_gc=True) 37 | return B[-1] 38 | 39 | def gpu_searchsorted(P, X, dtype_int64=True, disable_custom_op=disable_custom_op): 40 | if disable_custom_op: 41 | return gpu_searchsorted_scan(P, X) 42 | else: 43 | return 
44 | -------------------------------------------------------------------------------- /gru4rec.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Jun 22 15:14:20 2015 4 | @author: Balázs Hidasi 5 | """ 6 | 7 | import os 8 | import os.path 9 | orig_cwd = os.getcwd() 10 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 11 | os.environ['THEANORC'] = '.theanorc_gru4rec' #Only affects the actual settings if theano was not imported before this point (by any module) 12 | import custom_opt 13 | import datatools 14 | import theano 15 | from theano import tensor as T 16 | from theano import function 17 | from theano.sandbox.rng_mrg import MRG_RandomStreams 18 | from gpu_ops import gpu_diag, gpu_searchsorted 19 | os.chdir(orig_cwd) 20 | import numpy as np 21 | import pandas as pd 22 | import pickle 23 | import time 24 | from collections import OrderedDict 25 | mrng = MRG_RandomStreams() 26 | 27 | class GRU4Rec: 28 | ''' 29 | GRU4Rec(loss='bpr-max', final_act='linear', hidden_act='tanh', layers=[100], 30 | n_epochs=10, batch_size=32, dropout_p_hidden=0.0, dropout_p_embed=0.0, learning_rate=0.1, momentum=0.0, lmbd=0.0, embedding=0, n_sample=2048, sample_alpha=0.75, smoothing=0.0, constrained_embedding=False, 31 | adapt='adagrad', adapt_params=[], grad_cap=0.0, bpreg=1.0, logq=0.0, 32 | sigma=0.0, init_as_normal=False, train_random_order=False, time_sort=True, 33 | session_key='SessionId', item_key='ItemId', time_key='Time') 34 | Initializes the network. 35 | 36 | Parameters 37 | ----------- 38 | loss : 'top1', 'bpr', 'cross-entropy', 'xe_logit', 'top1-max', 'bpr-max' 39 | selects the loss function (default : 'bpr-max') 40 | final_act : 'softmax', 'linear', 'relu', 'tanh', 'softmax_logit', 'leaky-<X>', 'elu-<X>', 'selu-<X>-<Y>' 41 | selects the activation function of the final layer, <X> and <Y> are the parameters of the activation function (default : 'linear') 42 | hidden_act : 'linear', 'relu', 'tanh', 'leaky-<X>', 'elu-<X>', 'selu-<X>-<Y>' 43 | selects the activation function on the hidden states, <X> and <Y> are the parameters of the activation function (default : 'tanh') 44 | layers : list of int values 45 | list of the number of GRU units in the layers (default : [100]) 46 | n_epochs : int 47 | number of training epochs (default: 10) 48 | batch_size : int 49 | size of the minibatch; it also affects the number of negative samples through minibatch based sampling (default: 32) 50 | dropout_p_hidden : float 51 | probability of dropout of hidden units (default: 0.0) 52 | dropout_p_embed : float 53 | probability of dropout of the input units, applicable only if embeddings are used (default: 0.0) 54 | learning_rate : float 55 | learning rate (default: 0.1) 56 | momentum : float 57 | if not zero, Nesterov momentum will be applied during training with the given strength (default: 0.0) 58 | lmbd : float 59 | coefficient of the L2 regularization (default: 0.0) 60 | embedding : int or "layersize" 61 | size of the embedding used, 0 means not to use embedding (default: 0) 62 | n_sample : int 63 | number of additional negative samples to be used (besides the other examples of the minibatch) (default: 2048) 64 | sample_alpha : float 65 | the probability of an item used as an additional negative sample is supp^sample_alpha (default: 0.75) 66 | (e.g.: sample_alpha=1 --> popularity based sampling; sample_alpha=0 --> uniform sampling) 67 | smoothing : float 68 | (only works with cross-entropy and xe_logit
losses) if set to non-zero, class labels are smoothed with this value, i.e. the expected output is (e/N, ..., e/N, 1-e+e/N, e/N, ..., e/N) instead of (0, ..., 0, 1, 0, ..., 0), where N is the number of outputs and e is the smoothing value (default: 0.0) 69 | constrained_embedding : bool 70 | if True, the output weight matrix is also used as input embedding (default: False) 71 | adapt : None, 'adagrad', 'rmsprop', 'adam', 'adadelta' 72 | sets the appropriate learning rate adaptation strategy, use None for standard SGD (default: 'adagrad') 73 | adapt_params : list 74 | parameters for the adaptive learning methods (default: []) 75 | grad_cap : float 76 | clip gradients that exceed this value to this value, 0 means no clipping (default: 0.0) 77 | bpreg : float 78 | score regularization coefficient for the BPR-max loss function (default: 1.0) 79 | logq : float 80 | logq normalization of negative samples (set between 0.0 and 1.0), usually useful with cross-entropy loss (default: 0.0) 81 | sigma : float 82 | "width" of initialization; either the standard deviation or the min/max of the init interval (with normal and uniform initializations respectively); 0 means adaptive normalization (sigma depends on the size of the weight matrix); (default: 0.0) 83 | init_as_normal : boolean 84 | False: init from uniform distribution on [-sigma,sigma]; True: init from normal distribution N(0,sigma); (default: False) 85 | train_random_order : boolean 86 | whether to randomize the order of sessions in each epoch (default: False) 87 | time_sort : boolean 88 | whether to ensure that the order of sessions is chronological (default: True) 89 | session_key : string 90 | header of the session ID column in the input file (default: 'SessionId') 91 | item_key : string 92 | header of the item ID column in the input file (default: 'ItemId') 93 | time_key : string 94 | header of the timestamp column in the input file (default: 'Time') 95 | 96 | '''
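# Usage sketch (illustrative only; these hyperparameters mirror the BPR-max setup
# in examples/rsc15/run_rsc15.py, and train_df is a placeholder name):
#   gru = GRU4Rec(loss='bpr-max', final_act='elu-0.5', hidden_act='tanh', layers=[100],
#                 adapt='adagrad', n_epochs=10, batch_size=32, learning_rate=0.2,
#                 momentum=0.3, n_sample=2048, sample_alpha=0, bpreg=1.0)
#   gru.fit(train_df)  # train_df: pandas.DataFrame with SessionId, ItemId and Time columns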
97 | def __init__(self, loss='bpr-max', final_act='linear', hidden_act='tanh', layers=[100], 98 | n_epochs=10, batch_size=32, dropout_p_hidden=0.0, dropout_p_embed=0.0, learning_rate=0.1, momentum=0.0, lmbd=0.0, embedding=0, n_sample=2048, sample_alpha=0.75, smoothing=0.0, constrained_embedding=False, 99 | adapt='adagrad', adapt_params=[], grad_cap=0.0, bpreg=1.0, logq=0.0, 100 | sigma=0.0, init_as_normal=False, train_random_order=False, time_sort=True, 101 | session_key='SessionId', item_key='ItemId', time_key='Time'): 102 | self.layers = layers 103 | self.n_epochs = n_epochs 104 | self.batch_size = batch_size 105 | self.dropout_p_hidden = dropout_p_hidden 106 | self.dropout_p_embed = dropout_p_embed 107 | self.learning_rate = learning_rate 108 | self.adapt_params = adapt_params 109 | self.momentum = momentum 110 | self.sigma = sigma 111 | self.init_as_normal = init_as_normal 112 | self.session_key = session_key 113 | self.item_key = item_key 114 | self.time_key = time_key 115 | self.grad_cap = grad_cap 116 | self.bpreg = bpreg 117 | self.logq = logq 118 | self.train_random_order = train_random_order 119 | self.lmbd = lmbd 120 | if embedding == 'layersize': 121 | self.embedding = self.layers[0] 122 | else: 123 | self.embedding = embedding 124 | self.constrained_embedding = constrained_embedding 125 | self.time_sort = time_sort 126 | self.adapt = adapt 127 | self.loss = loss 128 | self.set_loss_function(self.loss) 129 | self.final_act = final_act 130 | self.set_final_activation(self.final_act) 131 | self.hidden_act = hidden_act 132 | self.set_hidden_activation(self.hidden_act) 133 | self.n_sample = n_sample 134 | self.sample_alpha = sample_alpha 135 | self.smoothing = smoothing 136 | def set_loss_function(self, loss): 137 | if loss == 'cross-entropy': self.loss_function = self.cross_entropy 138 | elif loss == 'bpr': self.loss_function = self.bpr 139 | elif loss == 'bpr-max': self.loss_function = self.bpr_max 140 | elif loss == 'top1': self.loss_function = self.top1 141 | elif loss == 'top1-max': self.loss_function = self.top1_max 142 | elif loss == 'xe_logit': self.loss_function = self.cross_entropy_logits 143 | else: raise NotImplementedError 144 | def set_final_activation(self, final_act): 145 | if final_act == 'linear': self.final_activation = self.linear 146 | elif final_act == 'relu': self.final_activation = self.relu 147 | elif final_act == 'softmax': self.final_activation=self.softmax 148 | elif final_act == 'tanh': self.final_activation=self.tanh 149 | elif final_act == 'softmax_logit': self.final_activation=self.softmax_logit 150 | elif final_act.startswith('leaky-'): self.final_activation = self.LeakyReLU(float(final_act.split('-')[1])).execute 151 | elif final_act.startswith('elu-'): self.final_activation = self.Elu(float(final_act.split('-')[1])).execute 152 | elif final_act.startswith('selu-'): self.final_activation = self.Selu(*[float(x) for x in final_act.split('-')[1:]]).execute 153 | else: raise NotImplementedError 154 | def set_hidden_activation(self, hidden_act): 155 | if hidden_act == 'relu': self.hidden_activation = self.relu 156 | elif hidden_act == 'tanh': self.hidden_activation = self.tanh 157 | elif hidden_act == 'linear': self.hidden_activation = self.linear 158 | elif hidden_act.startswith('leaky-'): self.hidden_activation = self.LeakyReLU(float(hidden_act.split('-')[1])).execute 159 | elif hidden_act.startswith('elu-'): self.hidden_activation = self.Elu(float(hidden_act.split('-')[1])).execute 160 | elif hidden_act.startswith('selu-'): self.hidden_activation = self.Selu(*[float(x) for x in hidden_act.split('-')[1:]]).execute 161 | else: raise NotImplementedError 162 | def set_params(self, **kvargs): 163 | maxk_len = np.max([len(str(x)) for x in kvargs.keys()]) 164 | maxv_len = np.max([len(str(x)) for x in kvargs.values()]) 165 | for k,v in kvargs.items(): 166 | if not hasattr(self, k): 167 | print('Unknown attribute: {}'.format(k)) 168 | raise NotImplementedError 169 | else: 170 | if type(v) == str and k == 'adapt_params': v = [float(l) for l in v.split('/')] 171 | elif type(v) == str and type(getattr(self, k)) == list: v = [int(l) for l in v.split('/')] 172 | if type(v) == str and type(getattr(self, k)) == bool: 173 | if v == 'True' or v == '1': v = True 174 | elif v == 'False' or v == '0': v = False 175 | else: 176 | print('Invalid value for boolean parameter: {}'.format(v)) 177 | raise NotImplementedError 178 | if k == 'embedding' and v == 'layersize': 179 | self.embedding = 'layersize' 180 | setattr(self, k, type(getattr(self, k))(v)) 181 | if k == 'loss': self.set_loss_function(self.loss) 182 | if k == 'final_act': self.set_final_activation(self.final_act) 183 | if k == 'hidden_act': self.set_hidden_activation(self.hidden_act) 184 | print('SET {}{}TO {}{}(type: {})'.format(k, ' '*(maxk_len-len(k)+3), getattr(self, k), ' '*(maxv_len-len(str(getattr(self, k)))+3), type(getattr(self, k)))) 185 | if self.embedding == 'layersize': 186 | self.embedding = self.layers[0] 187 | print('SET {}{}TO {}{}(type: {})'.format('embedding', ' '*(maxk_len-len('embedding')+3), getattr(self, 'embedding'), ' '*(maxv_len-len(str(getattr(self,
'embedding')))+3), type(getattr(self, 'embedding')))) 188 | ######################ACTIVATION FUNCTIONS##################### 189 | def linear(self,X): 190 | return X 191 | def tanh(self,X): 192 | return T.tanh(X) 193 | def softmax(self,X): 194 | e_x = T.exp(X - X.max(axis=1, keepdims=True)) 195 | return e_x / e_x.sum(axis=1, keepdims=True) 196 | def softmax_logit(self, X): 197 | X = X - X.max(axis=1, keepdims=True) 198 | return T.log(T.exp(X).sum(axis=1, keepdims=True)) - X 199 | def softmax_neg(self, X): 200 | hm = 1.0 - T.eye(*X.shape) 201 | X = X * hm 202 | e_x = T.exp(X - X.max(axis=1, keepdims=True)) * hm 203 | return e_x / e_x.sum(axis=1, keepdims=True) 204 | def relu(self,X): 205 | return T.maximum(X, 0) 206 | def sigmoid(self, X): 207 | return T.nnet.sigmoid(X) 208 | class Selu: 209 | def __init__(self, lmbd, alpha): 210 | self.lmbd = lmbd 211 | self.alpha = alpha 212 | def execute(self, X): 213 | return self.lmbd * T.switch(T.ge(X, 0), X, self.alpha * (T.exp(X) - 1)) 214 | class Elu: 215 | def __init__(self, alpha): 216 | self.alpha = alpha 217 | def execute(self, X): 218 | return T.switch(T.ge(X, 0), X, self.alpha * (T.exp(X) - 1)) 219 | class LeakyReLU: 220 | def __init__(self, leak): 221 | self.leak = leak 222 | def execute(self, X): 223 | return T.switch(T.ge(X, 0), X, self.leak * X) 224 | #################################LOSS FUNCTIONS################################ 225 | def cross_entropy(self, yhat, M): 226 | if self.smoothing: 227 | n_out = M + self.n_sample 228 | return T.cast(T.sum((1.0-(n_out/(n_out-1))*self.smoothing) * (-T.log(gpu_diag(yhat)+1e-24)) + (self.smoothing/(n_out-1)) * T.sum(-T.log(yhat+1e-24), axis=1)), theano.config.floatX) 229 | else: 230 | return T.cast(T.sum(-T.log(gpu_diag(yhat)+1e-24)), theano.config.floatX) 231 | def cross_entropy_logits(self, yhat, M): 232 | if self.smoothing: 233 | n_out = M + self.n_sample 234 | return T.cast(T.sum((1.0-(n_out/(n_out-1))*self.smoothing) * gpu_diag(yhat) + (self.smoothing/(n_out-1)) * T.sum(yhat, axis=1)), theano.config.floatX) 235 | else: 236 | return T.cast(T.sum(gpu_diag(yhat)), theano.config.floatX) 237 | def bpr(self, yhat, M): 238 | return T.cast(T.sum(-T.log(T.nnet.sigmoid(gpu_diag(yhat, keepdims=True)-yhat))), theano.config.floatX) 239 | def bpr_max(self, yhat, M): 240 | softmax_scores = self.softmax_neg(yhat) 241 | return T.cast(T.sum(-T.log(T.sum(T.nnet.sigmoid(gpu_diag(yhat, keepdims=True)-yhat)*softmax_scores, axis=1)+1e-24)+self.bpreg*T.sum((yhat**2)*softmax_scores, axis=1)), theano.config.floatX) 242 | def top1(self, yhat, M): 243 | ydiag = gpu_diag(yhat, keepdims=True) 244 | return T.cast(T.sum(T.mean(T.nnet.sigmoid(-ydiag+yhat)+T.nnet.sigmoid(yhat**2), axis=1)-T.nnet.sigmoid(ydiag**2)/(M+self.n_sample)), theano.config.floatX) 245 | def top1_max(self, yhat, M): 246 | softmax_scores = self.softmax_neg(yhat) 247 | y = softmax_scores*(T.nnet.sigmoid(-gpu_diag(yhat, keepdims=True)+yhat)+T.nnet.sigmoid(yhat**2)) 248 | return T.cast(T.sum(T.sum(y, axis=1)), theano.config.floatX) 249 | ############################################################################### 250 | def floatX(self, X): 251 | return np.asarray(X, dtype=theano.config.floatX) 252 | def init_weights(self, shape, name=None): 253 | return theano.shared(self.init_matrix(shape), borrow=True, name=name) 254 | def init_matrix(self, shape): 255 | if self.sigma != 0: sigma = self.sigma 256 | else: sigma = np.sqrt(6.0 / (shape[0] + shape[1])) 257 | if self.init_as_normal: 258 | return self.floatX(np.random.randn(*shape) * sigma) 259 | else: 260 | 
return self.floatX(np.random.rand(*shape) * sigma * 2 - sigma) 261 | def extend_weights(self, W, n_new): 262 | matrix = W.get_value() 263 | sigma = self.sigma if self.sigma != 0 else np.sqrt(6.0 / (matrix.shape[0] + matrix.shape[1] + n_new)) 264 | if self.init_as_normal: new_rows = self.floatX(np.random.randn(n_new, matrix.shape[1]) * sigma) 265 | else: new_rows = self.floatX(np.random.rand(n_new, matrix.shape[1]) * sigma * 2 - sigma) 266 | W.set_value(np.vstack([matrix, new_rows])) 267 | def init(self, data): 268 | datatools.sort_if_needed(data, [self.session_key, self.time_key]) 269 | offset_sessions = datatools.compute_offset(data, self.session_key) 270 | np.random.seed(42) 271 | self.Wx, self.Wh, self.Wrz, self.Bh, self.H = [], [], [], [], [] 272 | if self.constrained_embedding: 273 | n_features = self.layers[-1] 274 | elif self.embedding: 275 | self.E = self.init_weights((self.n_items, self.embedding), name='E') 276 | n_features = self.embedding 277 | else: 278 | n_features = self.n_items 279 | for i in range(len(self.layers)): 280 | m = [] 281 | m.append(self.init_matrix((self.layers[i-1] if i > 0 else n_features, self.layers[i]))) 282 | m.append(self.init_matrix((self.layers[i-1] if i > 0 else n_features, self.layers[i]))) 283 | m.append(self.init_matrix((self.layers[i-1] if i > 0 else n_features, self.layers[i]))) 284 | self.Wx.append(theano.shared(value=np.hstack(m), borrow=True, name='Wx{}'.format(i))) #For compatibility's sake 285 | self.Wh.append(self.init_weights((self.layers[i], self.layers[i]), name='Wh{}'.format(i))) 286 | m2 = [] 287 | m2.append(self.init_matrix((self.layers[i], self.layers[i]))) 288 | m2.append(self.init_matrix((self.layers[i], self.layers[i]))) 289 | self.Wrz.append(theano.shared(value=np.hstack(m2), borrow=True, name='Wrz{}'.format(i))) #For compatibility's sake 290 | self.Bh.append(theano.shared(value=np.zeros((self.layers[i] * 3,), dtype=theano.config.floatX), borrow=True, name='Bh{}'.format(i))) 291 | self.H.append(theano.shared(value=np.zeros((self.batch_size,self.layers[i]), dtype=theano.config.floatX), borrow=True, name='H{}'.format(i))) 292 | self.Wy = self.init_weights((self.n_items, self.layers[-1]), name='Wy') 293 | self.By = theano.shared(value=np.zeros((self.n_items,1), dtype=theano.config.floatX), borrow=True, name='By') 294 | return offset_sessions 295 | def dropout(self, X, drop_p): 296 | if drop_p > 0: 297 | retain_prob = 1 - drop_p 298 | X *= mrng.binomial(X.shape, p=retain_prob, dtype=theano.config.floatX) / retain_prob 299 | return X 300 | def adam(self, param, grad, updates, sample_idx = None, epsilon = 1e-6): 301 | v1 = self.adapt_params[0] 302 | v2 = 1.0 - self.adapt_params[0] 303 | v3 = self.adapt_params[1] 304 | v4 = 1.0 - self.adapt_params[1] 305 | acc = theano.shared(param.get_value(borrow=False) * 0., borrow=True) 306 | meang = theano.shared(param.get_value(borrow=False) * 0., borrow=True) 307 | countt = theano.shared(param.get_value(borrow=False) * 0., borrow=True) 308 | if sample_idx is None: 309 | acc_new = v3 * acc + v4 * (grad**2) 310 | meang_new = v1 * meang + v2 * grad 311 | countt_new = countt + 1 312 | updates[acc] = acc_new 313 | updates[meang] = meang_new 314 | updates[countt] = countt_new 315 | else: 316 | acc_s = acc[sample_idx] 317 | meang_s = meang[sample_idx] 318 | countt_s = countt[sample_idx] 319 | # acc_new = v3 * acc_s + v4 * (grad**2) #Faster, but inaccurate when an index occurs multiple times 320 | # updates[acc] = T.set_subtensor(acc_s, acc_new) #Faster, but inaccurate when an index occurs multiple times 
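# Explanatory note (added comment, not in the original source): the
# inc_subtensor(set_subtensor(...)) construction used below first decays every
# sampled row once (set_subtensor writes acc_s * decay back into acc), then
# accumulates the gradient terms on top. Because inc_subtensor adds rather than
# overwrites, an index that occurs several times in sample_idx receives every
# gradient contribution, which the faster set_subtensor-only variant above
# would silently drop.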
321 | updates[acc] = T.inc_subtensor(T.set_subtensor(acc_s, acc_s * v3)[sample_idx], v4 * (grad**2)) #Slower, but accurate when an index occurs multiple times 322 | acc_new = updates[acc][sample_idx] #Slower, but accurate when an index occurs multiple times 323 | # meang_new = v1 * meang_s + v2 * grad 324 | # updates[meang] = T.set_subtensor(meang_s, meang_new) #Faster, but inaccurate when an index occurs multiple times 325 | updates[meang] = T.inc_subtensor(T.set_subtensor(meang_s, meang_s * v1)[sample_idx], v2 * grad) #Slower, but accurate when an index occurs multiple times 326 | meang_new = updates[meang][sample_idx] #Slower, but accurate when an index occurs multiple times 327 | countt_new = countt_s + 1.0 328 | updates[countt] = T.set_subtensor(countt_s, countt_new) 329 | return (meang_new / (1 - v1**countt_new)) / (T.sqrt(acc_new / (1 - v1**countt_new)) + epsilon) 330 | def adagrad(self, param, grad, updates, sample_idx = None, epsilon = 1e-6): 331 | acc = theano.shared(param.get_value(borrow=False) * 0., borrow=True) 332 | if sample_idx is None: 333 | acc_new = acc + grad ** 2 334 | updates[acc] = acc_new 335 | else: 336 | acc_s = acc[sample_idx] 337 | acc_new = acc_s + grad ** 2 338 | updates[acc] = T.set_subtensor(acc_s, acc_new) 339 | gradient_scaling = T.cast(T.sqrt(acc_new + epsilon), theano.config.floatX) 340 | return grad / gradient_scaling 341 | def adadelta(self, param, grad, updates, sample_idx = None, epsilon = 1e-6): 342 | v1 = self.adapt_params[0] 343 | v2 = 1.0 - self.adapt_params[0] 344 | acc = theano.shared(param.get_value(borrow=False) * 0., borrow=True) 345 | upd = theano.shared(param.get_value(borrow=False) * 0., borrow=True) 346 | if sample_idx is None: 347 | acc_new = v1 * acc + v2 * (grad**2) 348 | updates[acc] = acc_new 349 | grad_scaling = (upd + epsilon) / (acc_new + epsilon) 350 | upd_new = v1 * upd + v2 * grad_scaling * (grad**2) 351 | updates[upd] = upd_new 352 | else: 353 | acc_s = acc[sample_idx] 354 | # acc_new = v1 * acc_s + v2 * (grad**2) #Faster, but inaccurate when an index occurs multiple times 355 | # updates[acc] = T.set_subtensor(acc_s, acc_new) #Faster, but inaccurate when an index occurs multiple times 356 | updates[acc] = T.inc_subtensor(T.set_subtensor(acc_s, acc_s * v1)[sample_idx], v2 * (grad**2)) #Slower, but accurate when an index occurs multiple times 357 | acc_new = updates[acc][sample_idx] #Slower, but accurate when an index occurs multiple times 358 | upd_s = upd[sample_idx] 359 | grad_scaling = (upd_s + epsilon) / (acc_new + epsilon) 360 | # updates[upd] = T.set_subtensor(upd_s, v1 * upd_s + v2 * grad_scaling * (grad**2)) #Faster, but inaccurate when an index occurs multiple times 361 | updates[upd] = T.inc_subtensor(T.set_subtensor(upd_s, upd_s * v1)[sample_idx], v2 * grad_scaling * (grad**2)) #Slower, but accurate when an index occurs multiple times 362 | gradient_scaling = T.cast(T.sqrt(grad_scaling), theano.config.floatX) 363 | if self.learning_rate != 1.0: 364 | print('Warn: learning_rate is not 1.0 while using adadelta.
Setting learning_rate to 1.0') 365 | self.learning_rate = 1.0 366 | return grad * gradient_scaling #Ok, checked 367 | def rmsprop(self, param, grad, updates, sample_idx = None, epsilon = 1e-6): 368 | v1 = self.adapt_params[0] 369 | v2 = 1.0 - self.adapt_params[0] 370 | acc = theano.shared(param.get_value(borrow=False) * 0., borrow=True) 371 | if sample_idx is None: 372 | acc_new = v1 * acc + v2 * grad ** 2 373 | updates[acc] = acc_new 374 | else: 375 | acc_s = acc[sample_idx] 376 | # acc_new = v1 * acc_s + v2 * grad ** 2 #Faster, but inaccurate when an index occurs multiple times 377 | # updates[acc] = T.set_subtensor(acc_s, acc_new) #Faster, but inaccurate when an index occurs multiple times 378 | updates[acc] = T.inc_subtensor(T.set_subtensor(acc_s, acc_s * v1)[sample_idx], v2 * grad ** 2) #Slower, but accurate when an index occurs multiple times 379 | acc_new = updates[acc][sample_idx] #Slower, but accurate when an index occurs multiple times 380 | gradient_scaling = T.cast(T.sqrt(acc_new + epsilon), theano.config.floatX) 381 | return grad / gradient_scaling 382 | def RMSprop(self, cost, params, full_params, sampled_params, sidxs, epsilon=1e-6): 383 | grads = [T.grad(cost = cost, wrt = param) for param in params] 384 | sgrads = [T.grad(cost = cost, wrt = sparam) for sparam in sampled_params] 385 | updates = OrderedDict() 386 | if self.grad_cap>0: 387 | norm=T.cast(T.sqrt(T.sum([T.sum([T.sum(g**2) for g in g_list]) for g_list in grads]) + T.sum([T.sum(g**2) for g in sgrads])), theano.config.floatX) 388 | grads = [[T.switch(T.ge(norm, self.grad_cap), g*self.grad_cap/norm, g) for g in g_list] for g_list in grads] 389 | sgrads = [T.switch(T.ge(norm, self.grad_cap), g*self.grad_cap/norm, g) for g in sgrads] 390 | for p_list, g_list in zip(params, grads): 391 | for p, g in zip(p_list, g_list): 392 | if self.adapt == 'adagrad': 393 | g = self.adagrad(p, g, updates) 394 | elif self.adapt == 'rmsprop': 395 | g = self.rmsprop(p, g, updates) 396 | elif self.adapt == 'adadelta': 397 | g = self.adadelta(p, g, updates) 398 | elif self.adapt == 'adam': 399 | g = self.adam(p, g, updates) 400 | if self.momentum > 0: 401 | velocity = theano.shared(p.get_value(borrow=False) * 0., borrow=True) 402 | velocity2 = self.momentum * velocity - self.learning_rate * (g + self.lmbd * p) 403 | updates[velocity] = velocity2 404 | updates[p] = p + velocity2 405 | else: 406 | updates[p] = p * (1.0 - self.learning_rate * self.lmbd) - self.learning_rate * g 407 | for i in range(len(sgrads)): 408 | g = sgrads[i] 409 | fullP = full_params[i] 410 | sample_idx = sidxs[i] 411 | sparam = sampled_params[i] 412 | if self.adapt == 'adagrad': 413 | g = self.adagrad(fullP, g, updates, sample_idx) 414 | elif self.adapt == 'rmsprop': 415 | g = self.rmsprop(fullP, g, updates, sample_idx) 416 | elif self.adapt == 'adadelta': 417 | g = self.adadelta(fullP, g, updates, sample_idx) 418 | elif self.adapt == 'adam': 419 | g = self.adam(fullP, g, updates, sample_idx) 420 | if self.lmbd > 0: 421 | delta = self.learning_rate * (g + self.lmbd * sparam) 422 | else: 423 | delta = self.learning_rate * g 424 | if self.momentum > 0: 425 | velocity = theano.shared(fullP.get_value(borrow=False) * 0., borrow=True) 426 | vs = velocity[sample_idx] 427 | velocity2 = self.momentum * vs - delta 428 | updates[velocity] = T.set_subtensor(vs, velocity2) 429 | updates[fullP] = T.inc_subtensor(sparam, velocity2) 430 | else: 431 | updates[fullP] = T.inc_subtensor(sparam, - delta) 432 | return updates 433 | def model(self, X, H, M, R=None, Y=None, 
drop_p_hidden=0.0, drop_p_embed=0.0, predict=False): 434 | sparams, full_params, sidxs = [], [], [] 435 | if (hasattr(self, 'ST')) and (Y is not None) and (not predict) and (self.n_sample > 0): 436 | A = self.ST[self.STI] 437 | Y = T.concatenate([Y, A], axis=0) 438 | if self.constrained_embedding: 439 | if Y is not None: X = T.concatenate([X,Y], axis=0) 440 | S = self.Wy[X] 441 | Sx = S[:M] 442 | Sy = S[M:] 443 | y = self.dropout(Sx, drop_p_embed) 444 | H_new = [] 445 | start = 0 446 | sparams.append(S) 447 | full_params.append(self.Wy) 448 | sidxs.append(X) 449 | elif self.embedding: 450 | Sx = self.E[X] 451 | y = self.dropout(Sx, drop_p_embed) 452 | H_new = [] 453 | start = 0 454 | sparams.append(Sx) 455 | full_params.append(self.E) 456 | sidxs.append(X) 457 | else: 458 | Sx = self.Wx[0][X] 459 | vec = Sx + self.Bh[0] 460 | rz = T.nnet.sigmoid(vec[:,self.layers[0]:] + T.dot(H[0], self.Wrz[0])) 461 | h = self.hidden_activation(T.dot(H[0] * rz[:,:self.layers[0]], self.Wh[0]) + vec[:,:self.layers[0]]) 462 | z = rz[:,self.layers[0]:] 463 | h = (1.0-z)*H[0] + z*h 464 | h = self.dropout(h, drop_p_hidden) 465 | y = h 466 | H_new = [T.switch(R, 0, h) if not predict else h] 467 | start = 1 468 | sparams.append(Sx) 469 | full_params.append(self.Wx[0]) 470 | sidxs.append(X) 471 | for i in range(start, len(self.layers)): 472 | vec = T.dot(y, self.Wx[i]) + self.Bh[i] 473 | rz = T.nnet.sigmoid(vec[:,self.layers[i]:] + T.dot(H[i], self.Wrz[i])) 474 | h = self.hidden_activation(T.dot(H[i] * rz[:,:self.layers[i]], self.Wh[i]) + vec[:,:self.layers[i]]) 475 | z = rz[:,self.layers[i]:] 476 | h = (1.0-z)*H[i] + z*h 477 | h = self.dropout(h, drop_p_hidden) 478 | y = h 479 | H_new.append(T.switch(R, 0, h) if not predict else h) 480 | if Y is not None: 481 | if (not self.constrained_embedding) or predict: 482 | Sy = self.Wy[Y] 483 | sparams.append(Sy) 484 | full_params.append(self.Wy) 485 | sidxs.append(Y) 486 | SBy = self.By[Y] 487 | sparams.append(SBy) 488 | full_params.append(self.By) 489 | sidxs.append(Y) 490 | if predict and self.final_act == 'softmax_logit': 491 | y = self.softmax(T.dot(y, Sy.T) + SBy.flatten()) 492 | else: 493 | y = T.dot(y, Sy.T) + SBy.flatten() 494 | if not predict and self.logq: 495 | y = y - self.logq * T.log(T.concatenate([self.P0[Y[:M]], self.P0[Y[M:]]**self.sample_alpha], axis=0)) 496 | y = self.final_activation(y) 497 | return H_new, y, sparams, full_params, sidxs 498 | else: 499 | if predict and self.final_act == 'softmax_logit': 500 | y = self.softmax(T.dot(y, self.Wy.T) + self.By.flatten()) 501 | else: 502 | y = T.dot(y, self.Wy.T) + self.By.flatten() 503 | if not predict and self.logq: 504 | y = y - self.logq * T.log(self.P0) 505 | y = self.final_activation(y) 506 | return H_new, y, sparams, full_params, sidxs 507 | def generate_neg_samples(self, pop, length): 508 | if self.sample_alpha: 509 | sample = np.searchsorted(pop, np.random.rand(self.n_sample * length)) 510 | else: 511 | sample = np.random.choice(self.n_items, size=self.n_sample * length) 512 | if length > 1: 513 | sample = sample.reshape((length, self.n_sample)) 514 | return sample 515 | def fit(self, data, sample_store=10000000, store_type='gpu'): 516 | ''' 517 | Trains the network. 518 | 519 | Parameters 520 | -------- 521 | data : pandas.DataFrame 522 | Training data. It contains the transactions of the sessions. It has one column for session IDs, one for item IDs and one for the timestamp of the events (unix timestamps). 523 | It must have a header. 
Column names are arbitrary, but must correspond to the ones you set during the initialization of the network (session_key, item_key, time_key properties). 524 | sample_store : int 525 | If additional negative samples are used (n_sample > 0), GPU utilization can be made more efficient by precomputing a large batch of negative samples (and recomputing it when necessary). 526 | This parameter regulates the size of this precomputed ID set. Its value is the maximum number of int values (IDs) to be stored. Precomputed IDs are stored in the RAM. 527 | For the most efficient computation, a balance must be found between storing few examples and constantly interrupting GPU computations for a short time vs. computing many examples and interrupting GPU computations for a long time (but rarely). 528 | store_type : 'cpu', 'gpu' 529 | Where to store the negative sample buffer (sample store). The cpu mode is legacy and is no longer supported. 530 | 531 | '''
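# Usage sketch (illustrative only; the file path is a placeholder and the column
# names follow the defaults, cf. examples/rsc15/run_rsc15.py):
#   data = pd.read_csv('rsc15_train_full.txt', sep='\t', dtype={'ItemId': np.int64})
#   gru = GRU4Rec(loss='bpr-max', layers=[100])
#   gru.fit(data)  # trains with the default GPU-side sample store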
532 | self.predict = None 533 | self.error_during_train = False 534 | itemids = data[self.item_key].unique() 535 | self.n_items = len(itemids) 536 | self.itemidmap = pd.Series(data=np.arange(self.n_items), index=itemids, name='ItemIdx') 537 | data['ItemIdx'] = self.itemidmap[data[self.item_key].values].values 538 | offset_sessions = self.init(data) 539 | pop = data.groupby(self.item_key).size() 540 | if self.logq: 541 | self.P0 = theano.shared(pop[self.itemidmap.index.values].values.astype(theano.config.floatX), name='P0', borrow=False) 542 | if self.n_sample: 543 | pop = pop[self.itemidmap.index.values].values**self.sample_alpha 544 | pop = pop.cumsum() / pop.sum() 545 | pop[-1] = 1 546 | if sample_store: 547 | generate_length = sample_store // self.n_sample 548 | if generate_length <= 1: 549 | sample_store = 0 550 | print('No example store was used') 551 | elif store_type == 'cpu': 552 | neg_samples = self.generate_neg_samples(pop, generate_length) 553 | sample_pointer = 0 554 | print('Created sample store with {} batches of samples (type=CPU)'.format(generate_length)) 555 | elif store_type == 'gpu': 556 | P = theano.shared(pop.astype(theano.config.floatX), name='P') 557 | self.ST = theano.shared(np.zeros((generate_length, self.n_sample), dtype='int64')) 558 | self.STI = theano.shared(np.asarray(0, dtype='int64')) 559 | X = mrng.uniform((generate_length*self.n_sample,)) 560 | updates_st = OrderedDict() 561 | updates_st[self.ST] = gpu_searchsorted(P, X, dtype_int64=True).reshape((generate_length, self.n_sample)) 562 | updates_st[self.STI] = np.asarray(0, dtype='int64') 563 | generate_samples = theano.function([], updates=updates_st) 564 | generate_samples() 565 | sample_pointer = 0 566 | print('Created sample store with {} batches of samples (type=GPU)'.format(generate_length)) 567 | else: 568 | print('Invalid store type {}'.format(store_type)) 569 | raise NotImplementedError 570 | else: 571 | print('No example store was used') 572 | X = T.ivector(name='X') 573 | Y = T.ivector(name='Y') 574 | M = T.iscalar(name='M') 575 | R = T.bcol(name='R') 576 | H_new, Y_pred, sparams, full_params, sidxs = self.model(X, self.H, M, R, Y, self.dropout_p_hidden, self.dropout_p_embed) 577 | cost = self.loss_function(Y_pred, M) / self.batch_size 578 | params = [self.Wx if self.embedding or self.constrained_embedding else self.Wx[1:], self.Wh, self.Wrz, self.Bh] 579 | updates = self.RMSprop(cost, params, full_params, sparams, sidxs) 580 | for i in range(len(self.H)): 581 | updates[self.H[i]] = H_new[i] 582 | if hasattr(self, 'STI'): 583 | updates[self.STI] = self.STI + 1 584 | train_function = function(inputs=[X, Y, M, R], outputs=cost, updates=updates, allow_input_downcast=True, on_unused_input='ignore') 585 | base_order = np.argsort(data.groupby(self.session_key)[self.time_key].min().values) if self.time_sort else np.arange(len(offset_sessions)-1) 586 | data_items = data.ItemIdx.values 587 | for epoch in range(self.n_epochs): 588 | t0 = time.time() 589 | for i in range(len(self.layers)): 590 | self.H[i].set_value(np.zeros((self.batch_size,self.layers[i]), dtype=theano.config.floatX), borrow=True) 591 | c = [] 592 | cc = [] 593 | session_idx_arr = np.random.permutation(len(offset_sessions)-1) if self.train_random_order else base_order 594 | iters = np.arange(self.batch_size) 595 | maxiter = iters.max() 596 | start = offset_sessions[session_idx_arr[iters]] 597 | end = offset_sessions[session_idx_arr[iters]+1] 598 | finished = False 599 | while not finished: 600 | minlen = (end-start).min() 601 | out_idx = data_items[start] 602 | for i in range(minlen-1): 603 | in_idx = out_idx 604 | out_idx = data_items[start+i+1] 605 | if self.n_sample and store_type == 'cpu': 606 | if sample_store: 607 | if sample_pointer == generate_length: 608 | neg_samples = self.generate_neg_samples(pop, generate_length) 609 | sample_pointer = 0 610 | sample = neg_samples[sample_pointer] 611 | sample_pointer += 1 612 | else: 613 | sample = self.generate_neg_samples(pop, 1) 614 | y = np.hstack([out_idx, sample]) 615 | else: 616 | y = out_idx 617 | if self.n_sample: 618 | if sample_pointer == generate_length: 619 | generate_samples() 620 | sample_pointer = 0 621 | sample_pointer += 1 622 | reset = (start+i+1 == end-1) 623 | cost = train_function(in_idx, y, len(iters), reset.reshape(len(reset), 1)) 624 | c.append(cost) 625 | cc.append(len(iters)) 626 | if np.isnan(cost): 627 | print(str(epoch) + ': NaN error!') 628 | self.error_during_train = True 629 | return 630 | start = start+minlen-1 631 | finished_mask = (end-start<=1) 632 | n_finished = finished_mask.sum() 633 | iters[finished_mask] = maxiter + np.arange(1,n_finished+1) 634 | maxiter += n_finished 635 | valid_mask = (iters < len(offset_sessions)-1) 636 | n_valid = valid_mask.sum() 637 | if (n_valid == 0) or (n_valid < 2 and self.n_sample == 0): 638 | finished = True 639 | break 640 | mask = finished_mask & valid_mask 641 | sessions = session_idx_arr[iters[mask]] 642 | start[mask] = offset_sessions[sessions] 643 | end[mask] = offset_sessions[sessions+1] 644 | iters = iters[valid_mask] 645 | start = start[valid_mask] 646 | end = end[valid_mask] 647 | if n_valid < len(valid_mask): 648 | for i in range(len(self.H)): 649 | tmp = self.H[i].get_value(borrow=True) 650 | tmp = tmp[valid_mask] 651 | self.H[i].set_value(tmp, borrow=True) 652 | c = np.array(c) 653 | cc = np.array(cc) 654 | avgc = np.sum(c * cc) / np.sum(cc) 655 | if np.isnan(avgc): 656 | print('Epoch {}: NaN error!'.format(str(epoch))) 657 | self.error_during_train = True 658 | return 659 | t1 = time.time() 660 | dt = t1 - t0 661 | print('Epoch {} --> loss: {:.6f} \t({:.2f}s) \t[{:.2f} mb/s | {:.0f} e/s]'.format(epoch+1, avgc, dt, len(c)/dt, np.sum(cc)/dt)) 662 | if hasattr(self, 'ST'): 663 | del(self.ST) 664 | del(self.STI) 665 | def predict_next_batch(self, session_ids, input_item_ids, predict_for_item_ids=None, batch=100): 666 | ''' 667 | Gives prediction scores for a selected set of items. Can be used in batch mode to predict for multiple independent events (i.e. events of different sessions) at once and thus speed up evaluation.
668 | 669 | If the session ID at a given coordinate of the session_ids parameter remains the same during subsequent calls of the function, the corresponding hidden state of the network will be kept intact (i.e. this is how one can predict the next item for a session). 670 | If it changes, the hidden state of the network is reset to zeros. 671 | 672 | Parameters 673 | -------- 674 | session_ids : 1D array 675 | Contains the session IDs of the events of the batch. Its length must be equal to the prediction batch size (batch param). 676 | input_item_ids : 1D array 677 | Contains the item IDs of the events of the batch. Every item ID must be in the training data of the network. Its length must be equal to the prediction batch size (batch param). 678 | predict_for_item_ids : 1D array (optional) 679 | IDs of items for which the network should give prediction scores. Every ID must be in the training set. The default value is None, which means that the network gives prediction on its every output (i.e. for all items in the training set). 680 | batch : int 681 | Prediction batch size. 682 | 683 | Returns 684 | -------- 685 | out : pandas.DataFrame 686 | Prediction scores for selected items for every event of the batch. 687 | Columns: events of the batch; rows: items. Rows are indexed by the item IDs. 688 | 689 | ''' 690 | if self.error_during_train: raise Exception 691 | if self.predict is None or self.predict_batch!=batch: 692 | self.predict_batch = batch 693 | X = T.ivector() 694 | Y = T.ivector() 695 | M = T.iscalar() if self.constrained_embedding or (predict_for_item_ids is not None) else None 696 | for i in range(len(self.layers)): 697 | self.H[i].set_value(np.zeros((batch,self.layers[i]), dtype=theano.config.floatX), borrow=True) 698 | if predict_for_item_ids is not None: 699 | H_new, yhat, _, _, _ = self.model(X, self.H, M, Y=Y, predict=True) 700 | else: 701 | H_new, yhat, _, _, _ = self.model(X, self.H, M, predict=True) 702 | updatesH = OrderedDict() 703 | for i in range(len(self.H)): 704 | updatesH[self.H[i]] = H_new[i] 705 | if predict_for_item_ids is not None: 706 | if self.constrained_embedding: self.predict = function(inputs=[X, Y, M], outputs=yhat, updates=updatesH, allow_input_downcast=True) 707 | else: self.predict = function(inputs=[X, Y], outputs=yhat, updates=updatesH, allow_input_downcast=True) 708 | else: 709 | if self.constrained_embedding: self.predict = function(inputs=[X, M], outputs=yhat, updates=updatesH, allow_input_downcast=True) 710 | else: self.predict = function(inputs=[X], outputs=yhat, updates=updatesH, allow_input_downcast=True) 711 | self.current_session = np.ones(batch) * -1 712 | session_change = np.arange(batch)[session_ids != self.current_session] 713 | if len(session_change) > 0: 714 | for i in range(len(self.H)): 715 | tmp = self.H[i].get_value(borrow=True) 716 | tmp[session_change] = 0 717 | self.H[i].set_value(tmp, borrow=True) 718 | self.current_session=session_ids.copy() 719 | in_idxs = self.itemidmap[input_item_ids] 720 | if predict_for_item_ids is not None: 721 | iIdxs = self.itemidmap[predict_for_item_ids] 722 | if self.constrained_embedding: preds = np.asarray(self.predict(in_idxs, iIdxs, batch)).T 723 | else: preds = np.asarray(self.predict(in_idxs, iIdxs)).T 724 | return pd.DataFrame(data=preds, index=predict_for_item_ids) 725 | else: 726 | if self.constrained_embedding: preds = np.asarray(self.predict(in_idxs, batch)).T 727 | else: preds = np.asarray(self.predict(in_idxs)).T 728 | return pd.DataFrame(data=preds, index=self.itemidmap.index)
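# Usage sketch (illustrative only; the session and item IDs below are placeholders
# and every item ID must occur in the training data):
#   preds = gru.predict_next_batch(np.array([0, 1]), np.array([214536502, 214536500]), batch=2)
#   preds[0].nlargest(20)  # top-20 scored items for the first event of the batch
729 | def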
symbolic_predict(self, X, Y, M, items, batch_size): 730 | if not self.constrained_embedding: M = None 731 | H = [] 732 | for i in range(len(self.layers)): 733 | H.append(theano.shared(np.zeros((batch_size, self.layers[i]), dtype=theano.config.floatX))) 734 | if items is not None: 735 | H_new, yhat, _, _, _ = self.model(X, H, M, Y=Y, predict=True) 736 | else: 737 | H_new, yhat, _, _, _ = self.model(X, H, M, predict=True) 738 | updatesH = OrderedDict() 739 | for i in range(len(H)): 740 | updatesH[H[i]] = H_new[i] 741 | return yhat, H, updatesH 742 | def savemodel(self, fname): 743 | #Get model parameters for GPU-CPU compatibility 744 | if self.embedding: 745 | self.E = self.E.get_value() 746 | for i in range(len(self.layers)): 747 | self.Wx[i] = self.Wx[i].get_value() 748 | self.Wrz[i] = self.Wrz[i].get_value() 749 | self.Wh[i] = self.Wh[i].get_value() 750 | self.Bh[i] = self.Bh[i].get_value() 751 | self.H[i] = self.H[i].get_value() 752 | self.Wy = self.Wy.get_value() 753 | self.By = self.By.get_value() 754 | #Write the model 755 | with open(fname, 'wb') as f: 756 | pickle.dump(self, f) 757 | #Reload the parameters 758 | if self.embedding: 759 | self.E = theano.shared(self.E, borrow=True, name='E') 760 | for i in range(len(self.layers)): 761 | self.Wx[i] = theano.shared(self.Wx[i], borrow=True, name='Wx{}'.format(i)) 762 | self.Wrz[i] = theano.shared(self.Wrz[i], borrow=True, name='Wrz{}'.format(i)) 763 | self.Wh[i] = theano.shared(self.Wh[i], borrow=True, name='Wh{}'.format(i)) 764 | self.Bh[i] = theano.shared(self.Bh[i], borrow=True, name='Bh{}'.format(i)) 765 | self.H[i] = theano.shared(self.H[i], borrow=True, name='H{}'.format(i)) 766 | self.Wy = theano.shared(self.Wy, borrow=True, name='Wy') 767 | self.By = theano.shared(self.By, borrow=True, name='By') 768 | @classmethod 769 | def loadmodel(cls, fname): 770 | gru = pd.read_pickle(fname) 771 | if gru.embedding: 772 | gru.E = theano.shared(gru.E, borrow=True, name='E') 773 | for i in range(len(gru.layers)): 774 | gru.Wx[i] = theano.shared(gru.Wx[i], borrow=True, name='Wx{}'.format(i)) 775 | gru.Wrz[i] = theano.shared(gru.Wrz[i], borrow=True, name='Wrz{}'.format(i)) 776 | gru.Wh[i] = theano.shared(gru.Wh[i], borrow=True, name='Wh{}'.format(i)) 777 | gru.Bh[i] = theano.shared(gru.Bh[i], borrow=True, name='Bh{}'.format(i)) 778 | gru.H[i] = theano.shared(gru.H[i], borrow=True, name='H{}'.format(i)) 779 | gru.Wy = theano.shared(gru.Wy, borrow=True, name='Wy') 780 | gru.By = theano.shared(gru.By, borrow=True, name='By') 781 | return gru 782 | -------------------------------------------------------------------------------- /img/training_time_bprmax_batch_size.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hidasib/GRU4Rec/a4ed5fbdba35bcc18bf1d4a9b76692ef462bb284/img/training_time_bprmax_batch_size.png -------------------------------------------------------------------------------- /img/training_time_bprmax_layers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hidasib/GRU4Rec/a4ed5fbdba35bcc18bf1d4a9b76692ef462bb284/img/training_time_bprmax_layers.png -------------------------------------------------------------------------------- /img/training_time_public_data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hidasib/GRU4Rec/a4ed5fbdba35bcc18bf1d4a9b76692ef462bb284/img/training_time_public_data.png 
-------------------------------------------------------------------------------- /img/training_time_xe_batch_size.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hidasib/GRU4Rec/a4ed5fbdba35bcc18bf1d4a9b76692ef462bb284/img/training_time_xe_batch_size.png -------------------------------------------------------------------------------- /img/training_time_xe_layers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hidasib/GRU4Rec/a4ed5fbdba35bcc18bf1d4a9b76692ef462bb284/img/training_time_xe_layers.png -------------------------------------------------------------------------------- /license.txt: -------------------------------------------------------------------------------- 1 | © Copyright 2015-2020, Balázs Hidasi and Gravity Research & Development Zrt. All rights reserved. 2 | 3 | Use and redistribution in source or binary forms for research and educational purposes, with or without modification, are permitted provided that the following conditions are met: 4 | • Redistributions of the source code must retain the above copyright notice, this list of conditions, the following disclaimer and appendix. Redistributions in binary form must reproduce the above copyright notice, this list of conditions, the following disclaimer and appendix in the documentation and/or other materials provided with the distribution. 5 | • The scientific papers that describe the algorithm (see: [1] and [2]) must be cited in the published research / educational material that uses this code. 6 | • Non-commercial use only. Commercial use (including, but not limited to selling the distribution (with or without modifications), using it in commercial systems or for other commercial gain) requires a specific prior written permission from the copyright holders. 7 | • The names of the copyright holders must not be used to endorse or promote products derived from this software without specific prior written permission. 8 | 9 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS “AS IS” AND WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 10 | 11 | 12 | APPENDIX 13 | 14 | I. Papers to cite 15 | [1] Balázs Hidasi, Alexandros Karatzoglou, Linas Baltrunas, Domonkos Tikk: Session-based Recommendations with Recurrent Neural Networks. arXiv preprint arXiv:1511.06939, 2015. https://arxiv.org/abs/1511.06939 Presented at the 4th International Conference on Learning Representations, ICLR 2016. 16 | [2] Balázs Hidasi, Alexandros Karatzoglou: Recurrent Neural Networks with Top-k Gains for Session-based Recommendations. arXiv preprint arXiv:1706.03847, 2017. https://arxiv.org/abs/1706.03847 17 | 18 | II. Getting permission for commercial use 19 | • Contact the copyright holders via email at “licensing-gru4rec yusp com”. 
20 | • Provide details of the planned use (e.g.: who are you, which company do you represent (if any), how do you intend to use the software, for how long, and any other details that you deem important). 21 | • You will be notified via email whether your request is approved or not. If the request is approved, you will be presented with an offer for the license fee. 22 | • A license agreement will be sent to you after the license fee is paid. 23 | 24 | III. Requirements 25 | The source code requires certain open source libraries to run. These libraries are not included in the repository. You are responsible for obtaining these libraries. 26 | -------------------------------------------------------------------------------- /param_samples/rsc15_bpr-max.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | gru4rec_params = OrderedDict([ 3 | ('layers', [100]), 4 | ('loss', 'bpr-max'), 5 | ('final_act', 'elu-0.5'), 6 | ('hidden_act', 'tanh'), 7 | ('adapt', 'adagrad'), 8 | ('n_epochs', 10), 9 | ('batch_size', 32), 10 | ('dropout_p_embed', 0.0), 11 | ('dropout_p_hidden', 0.0), 12 | ('learning_rate', 0.2), 13 | ('momentum', 0.3), 14 | ('sample_alpha', 0.0), 15 | ('n_sample', 2048), 16 | ('bpreg', 1.0), 17 | ('constrained_embedding', False) 18 | ]) 19 | -------------------------------------------------------------------------------- /param_samples/rsc15_bpr-max_constrained.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | gru4rec_params = OrderedDict([ 3 | ('layers', [100]), 4 | ('loss', 'bpr-max'), 5 | ('final_act', 'elu-0.5'), 6 | ('hidden_act', 'tanh'), 7 | ('adapt', 'adagrad'), 8 | ('n_epochs', 10), 9 | ('batch_size', 32), 10 | ('dropout_p_embed', 0.0), 11 | ('dropout_p_hidden', 0.0), 12 | ('learning_rate', 0.2), 13 | ('momentum', 0.1), 14 | ('sample_alpha', 0.0), 15 | ('n_sample', 2048), 16 | ('bpreg', 0.5), 17 | ('constrained_embedding', True) 18 | ]) 19 | -------------------------------------------------------------------------------- /param_samples/rsc15_cross-entropy.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | gru4rec_params = OrderedDict([ 3 | ('layers', [100]), 4 | ('loss', 'cross-entropy'), 5 | ('final_act', 'softmax'), 6 | ('hidden_act', 'tanh'), 7 | ('adapt', 'adagrad'), 8 | ('n_epochs', 10), 9 | ('batch_size', 32), 10 | ('dropout_p_embed', 0.0), 11 | ('dropout_p_hidden', 0.3), 12 | ('learning_rate', 0.1), 13 | ('momentum', 0.7), 14 | ('sample_alpha', 0.0), 15 | ('n_sample', 2048), 16 | ('logq', 0.0), 17 | ('constrained_embedding', False) 18 | ]) 19 | -------------------------------------------------------------------------------- /param_samples/rsc15_cross-entropy_logq.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | gru4rec_params = OrderedDict([ 3 | ('layers', [100]), 4 | ('loss', 'cross-entropy'), 5 | ('final_act', 'softmax'), 6 | ('hidden_act', 'tanh'), 7 | ('adapt', 'adagrad'), 8 | ('n_epochs', 10), 9 | ('batch_size', 32), 10 | ('dropout_p_embed', 0.0), 11 | ('dropout_p_hidden', 0.4), 12 | ('learning_rate', 0.2), 13 | ('momentum', 0.2), 14 | ('sample_alpha', 0.5), 15 | ('n_sample', 2048), 16 | ('logq', 1.0), 17 | ('constrained_embedding', True) 18 | ]) -------------------------------------------------------------------------------- /param_samples/rsc15_xe_logq.py: 
-------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | gru4rec_params = OrderedDict([ 3 | ('layers', [100]), 4 | ('loss', 'cross-entropy'), 5 | ('final_act', 'softmax'), 6 | ('hidden_act', 'tanh'), 7 | ('adapt', 'adagrad'), 8 | ('n_epochs', 10), 9 | ('batch_size', 64), 10 | ('dropout_p_embed', 0.0), 11 | ('dropout_p_hidden', 0.4), 12 | ('learning_rate', 0.2), 13 | ('momentum', 0.2), 14 | ('sample_alpha', 0.5), 15 | ('n_sample', 2048), 16 | ('logq', 1.0), 17 | ('constrained_embedding', True) 18 | ]) -------------------------------------------------------------------------------- /paramfiles/coveo_bprmax_shared_best.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | gru4rec_params = OrderedDict([ 3 | ('loss', 'bpr-max'), 4 | ('constrained_embedding', True), 5 | ('embedding', 0), 6 | ('final_act', 'elu-1'), 7 | ('layers', [512]), 8 | ('n_epochs', 10), 9 | ('batch_size', 144), 10 | ('dropout_p_embed', 0.35), 11 | ('dropout_p_hidden', 0.0), 12 | ('learning_rate', 0.05), 13 | ('momentum', 0.4), 14 | ('n_sample', 2048), 15 | ('sample_alpha', 0.2), 16 | ('bpreg', 1.85), 17 | ('logq', 0.0) 18 | ]) 19 | -------------------------------------------------------------------------------- /paramfiles/diginetica_bprmax_shared_best.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | gru4rec_params = OrderedDict([ 3 | ('loss', 'bpr-max'), 4 | ('constrained_embedding', True), 5 | ('embedding', 0), 6 | ('final_act', 'elu-1'), 7 | ('layers', [512]), 8 | ('n_epochs', 10), 9 | ('batch_size', 128), 10 | ('dropout_p_embed', 0.5), 11 | ('dropout_p_hidden', 0.3), 12 | ('learning_rate', 0.05), 13 | ('momentum', 0.15), 14 | ('n_sample', 2048), 15 | ('sample_alpha', 0.3), 16 | ('bpreg', 0.9), 17 | ('logq', 0.0) 18 | ]) 19 | -------------------------------------------------------------------------------- /paramfiles/rees46_xe_shared_best.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | gru4rec_params = OrderedDict([ 3 | ('loss', 'cross-entropy'), 4 | ('constrained_embedding', True), 5 | ('embedding', 0), 6 | ('final_act', 'softmax'), 7 | ('layers', [512]), 8 | ('n_epochs', 10), 9 | ('batch_size', 240), 10 | ('dropout_p_embed', 0.45), 11 | ('dropout_p_hidden', 0.0), 12 | ('learning_rate', 0.065), 13 | ('momentum', 0.0), 14 | ('n_sample', 2048), 15 | ('sample_alpha', 0.5), 16 | ('bpreg', 0.0), 17 | ('logq', 1.0) 18 | ]) 19 | -------------------------------------------------------------------------------- /paramfiles/retailrocket_bprmax_shared_best.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | gru4rec_params = OrderedDict([ 3 | ('loss', 'bpr-max'), 4 | ('constrained_embedding', True), 5 | ('embedding', 0), 6 | ('final_act', 'elu-0.5'), 7 | ('layers', [224]), 8 | ('n_epochs', 10), 9 | ('batch_size', 80), 10 | ('dropout_p_embed', 0.5), 11 | ('dropout_p_hidden', 0.05), 12 | ('learning_rate', 0.05), 13 | ('momentum', 0.4), 14 | ('n_sample', 2048), 15 | ('sample_alpha', 0.4), 16 | ('bpreg', 1.95), 17 | ('logq', 0.0) 18 | ]) 19 | -------------------------------------------------------------------------------- /paramfiles/rsc15_xe_shared_100_best.py: -------------------------------------------------------------------------------- 1 | from collections 
import OrderedDict 2 | gru4rec_params = OrderedDict([ 3 | ('loss', 'cross-entropy'), 4 | ('constrained_embedding', True), 5 | ('embedding', 0), 6 | ('final_act', 'softmax'), 7 | ('layers', [100]), 8 | ('n_epochs', 10), 9 | ('batch_size', 32), 10 | ('dropout_p_embed', 0.0), 11 | ('dropout_p_hidden', 0.4), 12 | ('learning_rate', 0.2), 13 | ('momentum', 0.2), 14 | ('n_sample', 2048), 15 | ('sample_alpha', 0.5), 16 | ('bpreg', 0.0), 17 | ('logq', 1.0) 18 | ]) 19 | -------------------------------------------------------------------------------- /paramfiles/yoochoose_xe_shared_best.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | gru4rec_params = OrderedDict([ 3 | ('loss', 'cross-entropy'), 4 | ('constrained_embedding', True), 5 | ('embedding', 0), 6 | ('final_act', 'softmax'), 7 | ('layers', [480]), 8 | ('n_epochs', 10), 9 | ('batch_size', 48), 10 | ('dropout_p_embed', 0.0), 11 | ('dropout_p_hidden', 0.2), 12 | ('learning_rate', 0.07), 13 | ('momentum', 0.0), 14 | ('n_sample', 2048), 15 | ('sample_alpha', 0.2), 16 | ('bpreg', 0.0), 17 | ('logq', 1.0) 18 | ]) 19 | -------------------------------------------------------------------------------- /paramspaces/gru4rec_bprmax_standard_parspace.json: -------------------------------------------------------------------------------- 1 | {"name":"layers", "dtype":"int", "values":[64,512], "step":32} 2 | {"name":"batch_size", "dtype":"int", "values":[32, 256], "step":16} 3 | {"name":"learning_rate", "dtype":"float", "values":[0.01, 0.25], "step":0.005} 4 | {"name":"dropout_p_embed", "dtype":"float", "values":[0.0, 0.5], "step":0.05} 5 | {"name":"dropout_p_hidden", "dtype":"float", "values":[0.0, 0.7], "step":0.05} 6 | {"name":"momentum", "dtype":"float", "values":[0.0, 0.9], "step":0.05} 7 | {"name":"sample_alpha", "dtype":"float", "values":[0.0, 1.0], "step":0.1} 8 | {"name":"bpreg", "dtype":"float", "values":[0.0, 2.0], "step":0.05} 9 | {"name":"final_act", "dtype":"categorical", "values":["elu-0.5", "elu-1", "linear"]} 10 | -------------------------------------------------------------------------------- /paramspaces/gru4rec_xe_standard_parspace.json: -------------------------------------------------------------------------------- 1 | {"name":"layers", "dtype":"int", "values":[64,512], "step":32} 2 | {"name":"batch_size", "dtype":"int", "values":[32, 256], "step":16} 3 | {"name":"learning_rate", "dtype":"float", "values":[0.01, 0.25], "step":0.005} 4 | {"name":"dropout_p_embed", "dtype":"float", "values":[0.0, 0.5], "step":0.05} 5 | {"name":"dropout_p_hidden", "dtype":"float", "values":[0.0, 0.7], "step":0.05} 6 | {"name":"momentum", "dtype":"float", "values":[0.0, 0.9], "step":0.05} 7 | {"name":"sample_alpha", "dtype":"float", "values":[0.0, 1.0], "step":0.1} 8 | -------------------------------------------------------------------------------- /paropt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import optuna 4 | import json 5 | 6 | class MyHelpFormatter(argparse.HelpFormatter): 7 | def __init__(self, *args, **kwargs): 8 | super(MyHelpFormatter, self).__init__(*args, **kwargs) 9 | try: 10 | columns = int(os.popen('stty size', 'r').read().split()[1]) 11 | except: 12 | columns = None 13 | if columns is not None: 14 | self._width = columns 15 | 16 | parser = argparse.ArgumentParser(formatter_class=MyHelpFormatter, description='Train or load a GRU4Rec model & measure recall and MRR on the specified test 
set(s).') 17 | parser.add_argument('path', metavar='PATH', type=str, help='Path to the training data (TAB separated file (.tsv or .txt) or pickled pandas.DataFrame object (.pickle)).') 18 | parser.add_argument('test', metavar='TEST_PATH', type=str, help='Path to the test data set(s) located at TEST_PATH.') 19 | parser.add_argument('-g', '--gru4rec_model', metavar='GRFILE', type=str, default='gru4rec', help='Name of the file containing the GRU4Rec class. Can be used to select different variants. (Default: gru4rec)') 20 | parser.add_argument('-tf', '--theano_flags', metavar='FLAGS', type=str, nargs='?', default='device=cuda0', help='Theano settings.') 21 | parser.add_argument('-fp', '--fixed_parameters', metavar='PARAM_STRING', type=str, help='Fixed training parameters provided as a single parameter string. The format of the string is `param_name1=param_value1,param_name2=param_value2...`, e.g.: `loss=bpr-max,layers=100,constrained_embedding=True`. Boolean training parameters should be either True or False; parameters that can take a list should use / as the separator (e.g. layers=200/200).') 22 | parser.add_argument('-opf', '--optuna_parameter_file', metavar='PATH', type=str, help='File describing the parameter space for optuna.') 23 | parser.add_argument('-m', '--measure', metavar='AT', type=int, nargs='?', default=20, help='Measure recall & MRR at the defined recommendation list length. A single value can be provided. (Default: 20)') 24 | parser.add_argument('-nt', '--ntrials', metavar='NT', type=int, nargs='?', default=50, help='Number of optimization trials to perform (Default: 50)') 25 | parser.add_argument('-fm', '--final_measure', metavar='AT', type=int, nargs='*', default=[20], help='Measure recall & MRR at the defined recommendation list length(s) after the optimization is finished. Multiple values can be provided. (Default: 20)') 26 | parser.add_argument('-pm', '--primary_metric', metavar='METRIC', choices=['recall', 'mrr'], default='recall', help='Set primary metric, recall or mrr (e.g. for paropt). (Default: recall)') 27 | parser.add_argument('-e', '--eval_type', metavar='EVAL_TYPE', choices=['standard', 'conservative', 'median', 'tiebreaking'], default='standard', help='Sets how to handle if multiple items in the ranked list have the same prediction score (which is usually due to saturation or an error). See the documentation of evaluate_gpu() in evaluation.py for further details. (Default: standard)')
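# Example invocation (illustrative only; the data paths are placeholders):
#   python paropt.py train.tsv valid.tsv -fp loss=bpr-max,constrained_embedding=True \
#       -opf paramspaces/gru4rec_bprmax_standard_parspace.json -nt 50 -fm 1 5 10 20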

import pexpect
import numpy as np
from collections import OrderedDict
import importlib
import re

def generate_command(optimized_param_str):
    # Assembles the run.py call that trains and evaluates one parameterization;
    # -lpm makes run.py print the primary metric so that it can be parsed below.
    command = 'python run.py "{}" -t "{}" -g {} -ps {},{} -m {} -pm {} -lpm -e {} -ik {} -sk {} -tk {}'.format(args.path, args.test, args.gru4rec_model, args.fixed_parameters, optimized_param_str, args.measure, args.primary_metric, args.eval_type, args.item_key, args.session_key, args.time_key)
    return command

def run_once(optimized_param_str):
    # Runs run.py in a child process and scrapes the value of the primary metric
    # from its output.
    command = generate_command(optimized_param_str)
    os.environ['THEANO_FLAGS'] = args.theano_flags
    cmd = pexpect.spawnu(command, timeout=None, maxread=1)
    val = None  # remains None if the child process never reports the metric
    line = cmd.readline()
    while line:
        line = line.strip()
        print(line)
        if re.match('PRIMARY METRIC: -*\\d\\.\\d+e*-*\\d*', line):
            t = line.split(':')[1].lstrip()
            val = float(t)
            break
        line = cmd.readline()
    return val

class Parameter:
    # Describes one dimension of the hyperparameter search space; instances are
    # created from the JSON lines of the --optuna_parameter_file.
    def __init__(self, name, dtype, values, step=None, log=False):
        assert dtype in ['int', 'float', 'categorical']
        assert type(values)==list
        assert len(values)==2 or dtype=='categorical'
        self.name = name
        self.dtype = dtype
        self.values = values
        self.step = step
        if self.step is None and self.dtype=='int':
            self.step = 1
        self.log = log
    @classmethod
    def fromjson(cls, json_string):
        obj = json.loads(json_string)
        return Parameter(obj['name'], obj['dtype'], obj['values'], obj['step'] if 'step' in obj else None, obj['log'] if 'log' in obj else False)
    def __call__(self, trial):
        if self.dtype == 'int':
            return trial.suggest_int(self.name, int(self.values[0]), int(self.values[1]), step=self.step, log=self.log)
        if self.dtype == 'float':
            return trial.suggest_float(self.name, float(self.values[0]), float(self.values[1]), step=self.step, log=self.log)
        if self.dtype == 'categorical':
            return trial.suggest_categorical(self.name, self.values)
    def __str__(self):
        desc = 'PARAMETER {} \t type={}'.format(self.name, self.dtype)
        if self.dtype == 'int' or self.dtype == 'float':
            desc += ' \t range=[{}..{}] (step={}) \t {} scale'.format(self.values[0], self.values[1], self.step if self.step is not None else 'N/A', 'UNIFORM' if not self.log else 'LOG')
        if self.dtype == 'categorical':
            desc += ' \t options: [{}]'.format(','.join([str(x) for x in self.values]))
        return desc
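
# A minimal sketch of how one line of the parameter space file drives optuna
# (the JSON line is copied from gru4rec_bprmax_standard_parspace.json above):
#   par = Parameter.fromjson('{"name":"layers", "dtype":"int", "values":[64,512], "step":32}')
#   par(trial)  # calls trial.suggest_int('layers', 64, 512, step=32, log=False)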

def objective(trial, par_space):
    # Samples one value per parameter, then trains and evaluates a model via
    # run.py and returns the primary metric for optuna to maximize.
    optimized_param_str = []
    for par in par_space:
        val = par(trial)
        optimized_param_str.append('{}={}'.format(par.name, val))
    optimized_param_str = ','.join(optimized_param_str)
    val = run_once(optimized_param_str)
    return val

par_space = []
with open(args.optuna_parameter_file, 'rt') as f:
    print('-'*80)
    print('PARAMETER SPACE')
    for line in f:
        par = Parameter.fromjson(line)
        print('\t' + str(par))
        par_space.append(par)
    print('-'*80)

study = optuna.create_study(direction='maximize')
study.optimize(lambda trial: objective(trial, par_space), n_trials=args.ntrials)

print('Running final eval @{}:'.format(args.final_measure))
optimized_param_str = ','.join(['{}={}'.format(k, v) for k, v in study.best_params.items()])
os.environ['THEANO_FLAGS'] = args.theano_flags
command = 'python run.py "{}" -t "{}" -g {} -ps {},{} -m {} -e {} -ik {} -sk {} -tk {}'.format(args.path, args.test, args.gru4rec_model, args.fixed_parameters, optimized_param_str, ' '.join([str(x) for x in args.final_measure]), args.eval_type, args.item_key, args.session_key, args.time_key)
cmd = pexpect.spawnu(command, timeout=None, maxread=1)
line = cmd.readline()
while line:
    line = line.strip()
    print(line)
    line = cmd.readline()
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
import argparse
import os
import shutil

class MyHelpFormatter(argparse.HelpFormatter):
    def __init__(self, *args, **kwargs):
        super(MyHelpFormatter, self).__init__(*args, **kwargs)
        self._width = shutil.get_terminal_size().columns

parser = argparse.ArgumentParser(formatter_class=MyHelpFormatter, description='Train or load a GRU4Rec model & measure recall and MRR on the specified test set(s).')
parser.add_argument('path', metavar='PATH', type=str, help='Path to the training data (TAB separated file (.tsv or .txt) or pickled pandas.DataFrame object (.pickle)) (if the --load_model parameter is NOT provided) or to the serialized model (if the --load_model parameter is provided).')
parser.add_argument('-ps', '--parameter_string', metavar='PARAM_STRING', type=str, help='Training parameters provided as a single parameter string. The format of the string is `param_name1=param_value1,param_name2=param_value2...`, e.g.: `loss=bpr-max,layers=100,constrained_embedding=True`. Boolean training parameters should be either True or False; parameters that can take a list should use / as the separator (e.g. layers=200/200). Mutually exclusive with the -pf (--parameter_file) and the -l (--load_model) arguments and one of the three must be provided.')
parser.add_argument('-pf', '--parameter_file', metavar='PARAM_PATH', type=str, help='Alternatively, training parameters can be set using a config file specified in this argument. The config file must contain a single OrderedDict named `gru4rec_params`. The parameters must have the appropriate type (e.g. layers = [100]). Mutually exclusive with the -ps (--parameter_string) and the -l (--load_model) arguments and one of the three must be provided.')
parser.add_argument('-l', '--load_model', action='store_true', help='Load an already trained model instead of training a model. Mutually exclusive with the -ps (--parameter_string) and the -pf (--parameter_file) arguments and one of the three must be provided.')
parser.add_argument('-s', '--save_model', metavar='MODEL_PATH', type=str, help='Save the trained model to the MODEL_PATH. (Default: don\'t save model)')
parser.add_argument('-t', '--test', metavar='TEST_PATH', type=str, nargs='+', help='Path to the test data set(s) located at TEST_PATH. Multiple test sets can be provided (separate with spaces). (Default: don\'t evaluate the model)')
parser.add_argument('-m', '--measure', metavar='AT', type=int, nargs='+', default=[20], help='Measure recall & MRR at the defined recommendation list length(s). Multiple values can be provided. (Default: 20)')
parser.add_argument('-e', '--eval_type', metavar='EVAL_TYPE', choices=['standard', 'conservative', 'median', 'tiebreaking'], default='standard', help='Sets how to handle if multiple items in the ranked list have the same prediction score (which is usually due to saturation or an error). See the documentation of evaluate_gpu() in evaluation.py for further details. (Default: standard)')
parser.add_argument('-ss', '--sample_store_size', metavar='SS', type=int, default=10000000, help='GRU4Rec uses a buffer for negative samples during training to maximize GPU utilization. This parameter sets the buffer length. Lower values require more frequent recomputation, higher values use more (GPU) memory. Unless you know what you are doing, you shouldn\'t mess with this parameter. (Default: 10000000)')
parser.add_argument('--sample_store_on_cpu', action='store_true', help='If provided, the sample store will be stored in the RAM instead of the GPU memory. This is not advised in most cases, because it significantly lowers the GPU utilization. This option is provided if for some reason you want to train the model on the CPU (NOT advised). Note that you need to make modifications to the code so that it is able to run on CPU.')
parser.add_argument('-g', '--gru4rec_model', metavar='GRFILE', type=str, default='gru4rec', help='Name of the file containing the GRU4Rec class. Can be used to select different variants. (Default: gru4rec)')
parser.add_argument('-ik', '--item_key', metavar='IK', type=str, default='ItemId', help='Column name corresponding to the item IDs (default: ItemId).')
parser.add_argument('-sk', '--session_key', metavar='SK', type=str, default='SessionId', help='Column name corresponding to the session IDs (default: SessionId).')
parser.add_argument('-tk', '--time_key', metavar='TK', type=str, default='Time', help='Column name corresponding to the timestamp (default: Time).')
parser.add_argument('-pm', '--primary_metric', metavar='METRIC', choices=['recall', 'mrr'], default='recall', help='Set primary metric, recall or mrr (e.g. for paropt). (Default: recall)')
parser.add_argument('-lpm', '--log_primary_metric', action='store_true', help='If provided, evaluation will log the value of the primary metric at the end of the run. Only works with one test file and list length.')
args = parser.parse_args()
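
# Hypothetical invocations (file paths are illustrative only). Training with
# evaluation on one test set, saving the model afterwards:
#   python run.py train.tsv -ps loss=cross-entropy,constrained_embedding=True,layers=100 \
#       -t test.tsv -m 1 5 10 20 -s model.pickle
# Loading a previously saved model for evaluation only:
#   python run.py model.pickle -l -t test.tsv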

import os.path
orig_cwd = os.getcwd()
os.chdir(os.path.dirname(os.path.abspath(__file__)))
import numpy as np
import pandas as pd
import datetime as dt
import sys
import time
from collections import OrderedDict
import importlib
GRU4Rec = importlib.import_module(args.gru4rec_model).GRU4Rec
import evaluation
import importlib.util
import joblib
os.chdir(orig_cwd)

def load_data(fname, args):
    # Accepts either a pickled pandas.DataFrame or a TAB separated text file and
    # verifies that the expected columns are present before loading.
    if fname.endswith('.pickle'):
        print('Loading data from pickle file: {}'.format(fname))
        data = joblib.load(fname)
        if args.session_key not in data.columns:
            print('ERROR. The column specified for session IDs "{}" is not in the data file ({})'.format(args.session_key, fname))
            print('The default column name is "SessionId", but you can specify otherwise by setting the `session_key` parameter of the model.')
            sys.exit(1)
        if args.item_key not in data.columns:
            print('ERROR. The column specified for item IDs "{}" is not in the data file ({})'.format(args.item_key, fname))
            print('The default column name is "ItemId", but you can specify otherwise by setting the `item_key` parameter of the model.')
            sys.exit(1)
        if args.time_key not in data.columns:
            print('ERROR. The column specified for time "{}" is not in the data file ({})'.format(args.time_key, fname))
            print('The default column name is "Time", but you can specify otherwise by setting the `time_key` parameter of the model.')
            sys.exit(1)
    else:
        with open(fname, 'rt') as f:
            header = f.readline().strip().split('\t')
        if args.session_key not in header:
            print('ERROR. The column specified for session IDs "{}" is not in the data file ({})'.format(args.session_key, fname))
            print('The default column name is "SessionId", but you can specify otherwise by setting the `session_key` parameter of the model.')
            sys.exit(1)
        if args.item_key not in header:
            print('ERROR. The column specified for item IDs "{}" is not in the data file ({})'.format(args.item_key, fname))
            print('The default column name is "ItemId", but you can specify otherwise by setting the `item_key` parameter of the model.')
            sys.exit(1)
        if args.time_key not in header:
            print('ERROR. The column specified for time "{}" is not in the data file ({})'.format(args.time_key, fname))
            print('The default column name is "Time", but you can specify otherwise by setting the `time_key` parameter of the model.')
            sys.exit(1)
        print('Loading data from TAB separated file: {}'.format(fname))
        data = pd.read_csv(fname, sep='\t', usecols=[args.session_key, args.item_key, args.time_key], dtype={args.session_key:'int32', args.item_key:'str'})
    return data
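
# A sketch of the expected TAB separated input layout (the column names are the
# defaults; the IDs and timestamps below are made up for illustration):
#   SessionId   ItemId      Time
#   1           214536502   1396839069.277
#   1           214536500   1396839249.868
#   2           214662742   1396850197.614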

if (args.parameter_string is not None) + (args.parameter_file is not None) + (args.load_model) != 1:
    print('ERROR. Exactly one of the following parameters must be provided: --parameter_string, --parameter_file, --load_model')
    sys.exit(1)

if args.load_model:
    print('Loading trained model from file: {}'.format(args.path))
    gru = GRU4Rec.loadmodel(args.path)
else:
    if args.parameter_file:
        param_file_path = os.path.abspath(args.parameter_file)
        param_dir, param_file = os.path.split(param_file_path)
        spec = importlib.util.spec_from_file_location(param_file.split('.py')[0], os.path.abspath(args.parameter_file))
        params = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(params)
        gru4rec_params = params.gru4rec_params
        print('Loaded parameters from file: {}'.format(param_file_path))
    if args.parameter_string:
        gru4rec_params = OrderedDict([x.split('=') for x in args.parameter_string.split(',')])
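        # For illustration: a string like 'loss=bpr-max,layers=100/100' becomes
        # OrderedDict([('loss', 'bpr-max'), ('layers', '100/100')]); the values
        # stay strings here, so set_params() is presumably responsible for
        # casting them to the proper types.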
    print('Creating GRU4Rec model')
    gru = GRU4Rec()
    gru.set_params(**gru4rec_params)
    print('Loading training data...')
    data = load_data(args.path, args)
    store_type = 'cpu' if args.sample_store_on_cpu else 'gpu'
    if store_type == 'cpu':
        print('WARNING! The sample store is set to be on the CPU. This will make training significantly slower on the GPU.')
    print('Started training')
    t0 = time.time()
    gru.fit(data, sample_store=args.sample_store_size, store_type=store_type)
    t1 = time.time()
    print('Total training time: {:.2f}s'.format(t1 - t0))
    if args.save_model is not None:
        print('Saving trained model to: {}'.format(args.save_model))
        gru.savemodel(args.save_model)

if args.test is not None:
    if args.primary_metric.lower() == 'recall':
        pm_index = 0
    elif args.primary_metric.lower() == 'mrr':
        pm_index = 1
    else:
        raise RuntimeError('Invalid value `{}` for `primary_metric` parameter'.format(args.primary_metric))
    for test_file in args.test:
        print('Loading test data...')
        test_data = load_data(test_file, args)
        print('Starting evaluation (cut-off={}, using {} mode for tiebreaking)'.format(args.measure, args.eval_type))
        t0 = time.time()
        res = evaluation.evaluate_gpu(gru, test_data, batch_size=512, cut_off=args.measure, mode=args.eval_type, item_key=args.item_key, session_key=args.session_key, time_key=args.time_key)
        t1 = time.time()
        print('Evaluation took {:.2f}s'.format(t1 - t0))
        for i, c in enumerate(args.measure):
            print('Recall@{}: {:.6f} MRR@{}: {:.6f}'.format(c, res[0][i], c, res[1][i]))
        if args.log_primary_metric:
            print('PRIMARY METRIC: {}'.format(res[pm_index][0]))
--------------------------------------------------------------------------------