├── LICENSE
├── README.md
├── lib
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── dataset.cpython-36.pyc
│   │   ├── dataset.cpython-37.pyc
│   │   ├── evaluation.cpython-36.pyc
│   │   ├── evaluation.cpython-37.pyc
│   │   ├── lossfunction.cpython-36.pyc
│   │   ├── lossfunction.cpython-37.pyc
│   │   ├── metric.cpython-36.pyc
│   │   ├── metric.cpython-37.pyc
│   │   ├── model.cpython-36.pyc
│   │   ├── model.cpython-37.pyc
│   │   ├── optimizer.cpython-36.pyc
│   │   └── trainer.cpython-36.pyc
│   │   └── trainer.cpython-37.pyc
│   ├── dataset.py
│   ├── evaluation.py
│   ├── lossfunction.py
│   ├── metric.py
│   ├── model.py
│   ├── optimizer.py
│   └── trainer.py
├── main.py
├── preprocessing.py
└── tools.py
/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GRU4REC-PyTorch 2 | - PyTorch Implementation of the GRU4REC model. 
3 | - Original paper: [Session-based Recommendations with Recurrent Neural Networks (ICLR 2016)](https://arxiv.org/pdf/1511.06939.pdf) 4 | - Extension over the original paper: [Recurrent Neural Networks with Top-k Gains for Session-based 5 | Recommendations (CIKM 2018)](https://arxiv.org/abs/1706.03847) 6 | - This code is based on [pyGRU4REC](https://github.com/yhs-968/pyGRU4REC), implemented by Younghun Song (yhs-968), and on the [original Theano code written by the authors of the GRU4REC paper](https://github.com/hidasib/GRU4Rec) 7 | - This version supports the TOP1, BPR, TOP1-max, BPR-max, and Cross-Entropy losses. 8 | 9 | ## Requirements 10 | - PyTorch 0.4.1 11 | - Python 3.5 12 | - pandas 13 | - numpy 1.14.5 14 | 15 | ## Usage 16 | 17 | ### Dataset 18 | The RecSys Challenge 2015 dataset can be retrieved from [HERE](https://2015.recsyschallenge.com/) 19 | 20 | ### Preprocessing the data 21 | - Run preprocessing.py to obtain the training and testing data. As in the paper, only the RecSys training set is used; the official testing set is ignored. 22 | - The training set itself is split into training and testing portions, where the testing split consists of the sessions from the last day. 23 | 24 | The format of the data is similar to that obtained from the RecSys Challenge 2015: 25 | - Filenames 26 | - Training set should be named as `recSys15TrainOnly.txt` 27 | - Test set should be named as `recSys15Valid.txt` 28 | - Contents 29 | - `recSys15TrainOnly.txt` and `recSys15Valid.txt` should be comma-separated files (as written by `preprocessing.py`) storing pandas dataframes that satisfy the following requirements: 30 | - The 1st column of the file should be the integer Session IDs with header name SessionID 31 | - The 2nd column of the file should be the integer Item IDs with header name ItemID 32 | - The 3rd column of the file should be the Timestamps with header name Time 33 | 34 | ### Training and Testing 35 | The project has the following structure: 36 | 37 | ```bash 38 | ├── GRU4REC-pytorch 39 | │ ├── checkpoint 40 | │ ├── data 41 | │ │ ├── preprocessed_data 42 | │ │ │ ├── recSys15TrainOnly.txt 43 | │ │ │ ├── recSys15Valid.txt 44 | │ │ ├── raw_data 45 | │ │ │ ├── yoochoose-clicks.dat 46 | │ ├── lib 47 | │ ├── main.py 48 | │ ├── preprocessing.py 49 | │ ├── tools.py 50 | ``` 51 | `tools.py` can be used to keep only the last 1/8 of the click events from `yoochoose-clicks.dat` 52 | 53 | From inside GRU4REC-pytorch: 54 | 55 | Training 56 | ```bash 57 | python main.py 58 | ``` 59 | 60 | Testing 61 | ```bash 62 | python main.py --is_eval --load_model checkpoint/CHECKPOINT#/model_EPOCH#.pt 63 | ``` 64 | ### List of Arguments accepted 65 | ```--hidden_size``` Number of Neurons per Layer (Default = 100)
66 | ```--num_layers``` Number of Hidden Layers (Default = 1; note that `main.py` currently ships with a default of 3)
67 | ```--batch_size``` Batch Size (Default = 50)
68 | ```--dropout_input``` Dropout ratio at input (Default = 0)
69 | ```--dropout_hidden``` Dropout at each hidden layer except the last one (Default = 0.5)
70 | ```--n_epochs``` Number of epochs (Default = 10)
71 | ```--k_eval``` Value of K used during Recall@K and MRR@K Evaluation (Default = 20)
72 | ```--optimizer_type``` Optimizer (Default = Adagrad)
73 | ```--final_act``` Activation Function (Default = Tanh)
74 | ```--lr``` Learning rate (Default = 0.01)
75 | ```--weight_decay``` Weight decay (Default = 0)
76 | ```--momentum``` Momentum Value (Default = 0)
77 | ```--eps``` Epsilon Value of Optimizer (Default = 1e-6)
78 | ```--loss_type``` Type of loss function: TOP1 / BPR / TOP1-max / BPR-max / Cross-Entropy (Default = TOP1-max)
79 | ```--time_sort``` Whether to iterate over sessions ordered by their start time stamp; use when the data is not already time-ordered (Default = 0)
80 | ```--model_name``` String of the model name (Default = GRU4REC-CrossEntropy).
81 | ```--save_dir``` Folder in which to save the checkpoints and logs (Default = `models` in `main.py`).
82 | ```--data_folder``` Path to the folder containing the dataset.
83 | ```--train_data``` Name of the training dataset file (Default = `recSys15TrainOnly.txt`)
84 | ```--valid_data``` Name of the validation dataset file (Default = `recSys15Valid.txt`)
85 | ```--is_eval``` Flag to run evaluation only, using a saved checkpoint model.
86 | ```--load_model``` Path to the checkpoint model to be used for evaluation.
87 | ```--checkpoint_dir``` Directory of the checkpoints folder (Default = `checkpoint`).
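For example, a training run that sets a few of the arguments above explicitly (the values shown are simply the documented defaults) might look like:

```bash
python main.py --loss_type TOP1-max --optimizer_type Adagrad --lr 0.01 \
               --batch_size 50 --n_epochs 10 \
               --data_folder data/preprocessed_data \
               --train_data recSys15TrainOnly.txt --valid_data recSys15Valid.txt
```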
88 | 89 | 90 | ## Results 91 | 92 | Results for different loss functions and hyperparameter settings can be found [HERE](https://docs.google.com/spreadsheets/d/19z6zFEY6pC0msi3wOQLk_kJsvqF8xnGOJPUGhQ36-wI/edit#gid=0) 93 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset, DataLoader 2 | from .model import GRU4REC 3 | from .metric import get_mrr, get_recall, evaluate 4 | from .evaluation import Evaluation 5 | from .optimizer import Optimizer 6 | from .lossfunction import LossFunction, SampledCrossEntropyLoss, BPRLoss, TOP1Loss 7 | from .trainer import Trainer -------------------------------------------------------------------------------- /lib/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungpthanh/GRU4REC-pytorch/666b84264c4afae757fe55c6997dcf0a4da1d44e/lib/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /lib/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungpthanh/GRU4REC-pytorch/666b84264c4afae757fe55c6997dcf0a4da1d44e/lib/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungpthanh/GRU4REC-pytorch/666b84264c4afae757fe55c6997dcf0a4da1d44e/lib/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /lib/__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungpthanh/GRU4REC-pytorch/666b84264c4afae757fe55c6997dcf0a4da1d44e/lib/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/evaluation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungpthanh/GRU4REC-pytorch/666b84264c4afae757fe55c6997dcf0a4da1d44e/lib/__pycache__/evaluation.cpython-36.pyc -------------------------------------------------------------------------------- /lib/__pycache__/evaluation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungpthanh/GRU4REC-pytorch/666b84264c4afae757fe55c6997dcf0a4da1d44e/lib/__pycache__/evaluation.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/lossfunction.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungpthanh/GRU4REC-pytorch/666b84264c4afae757fe55c6997dcf0a4da1d44e/lib/__pycache__/lossfunction.cpython-36.pyc -------------------------------------------------------------------------------- /lib/__pycache__/lossfunction.cpython-37.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/hungpthanh/GRU4REC-pytorch/666b84264c4afae757fe55c6997dcf0a4da1d44e/lib/__pycache__/lossfunction.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/metric.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungpthanh/GRU4REC-pytorch/666b84264c4afae757fe55c6997dcf0a4da1d44e/lib/__pycache__/metric.cpython-36.pyc -------------------------------------------------------------------------------- /lib/__pycache__/metric.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungpthanh/GRU4REC-pytorch/666b84264c4afae757fe55c6997dcf0a4da1d44e/lib/__pycache__/metric.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungpthanh/GRU4REC-pytorch/666b84264c4afae757fe55c6997dcf0a4da1d44e/lib/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /lib/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungpthanh/GRU4REC-pytorch/666b84264c4afae757fe55c6997dcf0a4da1d44e/lib/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/optimizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungpthanh/GRU4REC-pytorch/666b84264c4afae757fe55c6997dcf0a4da1d44e/lib/__pycache__/optimizer.cpython-36.pyc -------------------------------------------------------------------------------- /lib/__pycache__/optimizer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungpthanh/GRU4REC-pytorch/666b84264c4afae757fe55c6997dcf0a4da1d44e/lib/__pycache__/optimizer.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungpthanh/GRU4REC-pytorch/666b84264c4afae757fe55c6997dcf0a4da1d44e/lib/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /lib/__pycache__/trainer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungpthanh/GRU4REC-pytorch/666b84264c4afae757fe55c6997dcf0a4da1d44e/lib/__pycache__/trainer.cpython-37.pyc -------------------------------------------------------------------------------- /lib/dataset.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import torch 4 | 5 | 6 | class Dataset(object): 7 | def __init__(self, path, sep=',', session_key='SessionID', item_key='ItemID', time_key='Time', n_sample=-1, itemmap=None, itemstamp=None, time_sort=False): 8 | # Read csv 9 | self.df = pd.read_csv(path, sep=sep, dtype={session_key: int, item_key: int, time_key: float}) 10 | self.session_key = session_key 11 | self.item_key = item_key 12 | self.time_key = time_key 13 | 
self.time_sort = time_sort 14 | if n_sample > 0: 15 | self.df = self.df[:n_sample] 16 | 17 | # Add column "item_idx" to the data 18 | self.add_item_indices(itemmap=itemmap) 19 | """ 20 | Sort the df by session ID and then by time, so that the clicks of a session are 21 | next to each other and the clicks within a session are time-ordered. 22 | """ 23 | self.df.sort_values([session_key, time_key], inplace=True) 24 | self.click_offsets = self.get_click_offset() 25 | self.session_idx_arr = self.order_session_idx() 26 | 27 | def add_item_indices(self, itemmap=None): 28 | """ 29 | Add item index column named "item_idx" to the df 30 | Args: 31 | itemmap (pd.DataFrame): mapping between the item Ids and indices 32 | """ 33 | if itemmap is None: 34 | item_ids = self.df[self.item_key].unique() # type is numpy.ndarray 35 | item2idx = pd.Series(data=np.arange(len(item_ids)), 36 | index=item_ids) 37 | # itemmap is a DataFrame with two columns (self.item_key, 'item_idx') 38 | itemmap = pd.DataFrame({self.item_key: item_ids, 39 | 'item_idx': item2idx[item_ids].values}) 40 | self.itemmap = itemmap 41 | self.df = pd.merge(self.df, self.itemmap, on=self.item_key, how='inner') 42 | 43 | def get_click_offset(self): 44 | """ 45 | self.df[self.session_key] returns the session ID column 46 | self.df[self.session_key].nunique() returns the number of unique sessions (int) 47 | self.df.groupby(self.session_key).size() returns the length of each session 48 | self.df.groupby(self.session_key).size().cumsum() returns the cumulative sum 49 | """ 50 | offsets = np.zeros(self.df[self.session_key].nunique() + 1, dtype=np.int32) 51 | offsets[1:] = self.df.groupby(self.session_key).size().cumsum() 52 | return offsets 53 | 54 | def order_session_idx(self): 55 | if self.time_sort: 56 | sessions_start_time = self.df.groupby(self.session_key)[self.time_key].min().values 57 | session_idx_arr = np.argsort(sessions_start_time) 58 | else: 59 | session_idx_arr = np.arange(self.df[self.session_key].nunique()) 60 | return session_idx_arr 61 | 62 | @property 63 | def items(self): 64 | return self.itemmap[self.item_key].unique() 65 | 66 | 67 | class DataLoader(): 68 | def __init__(self, dataset, batch_size=50): 69 | """ 70 | A class for creating session-parallel mini-batches. 71 | 72 | Args: 73 | dataset (Dataset): the session dataset to generate the batches from 74 | batch_size (int): size of the batch 75 | """ 76 | self.dataset = dataset 77 | self.batch_size = batch_size 78 | 79 | def __iter__(self): 80 | """ Returns the iterator for producing session-parallel training mini-batches. 81 | 82 | Yields: 83 | input (B,): torch.LongTensor. Item indices that will be encoded as one-hot vectors later. 84 | target (B,): torch.LongTensor that stores the target item indices 85 | masks: Numpy array indicating the positions of the sessions to be terminated 86 | """
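# A toy walk-through of the session-parallel scheme (illustrative data, not from the repo):
# with batch_size=2 and sessions s1=[a,b,c], s2=[d,e], s3=[f,g], the loader advances all
# active sessions by one click per step:
#   yield 1: input=[a, d], target=[b, e], mask=[]   (s1 and s2 both advance)
#   yield 2: input=[b, f], target=[c, g], mask=[1]  (s2 ended, so slot 1 was refilled
#                                                    with s3 and its hidden state must be reset)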
87 | # initializations 88 | df = self.dataset.df 89 | click_offsets = self.dataset.click_offsets 90 | session_idx_arr = self.dataset.session_idx_arr 91 | 92 | iters = np.arange(self.batch_size) 93 | maxiter = iters.max() 94 | start = click_offsets[session_idx_arr[iters]] 95 | end = click_offsets[session_idx_arr[iters] + 1] 96 | mask = [] # indicator for the sessions to be terminated 97 | finished = False 98 | 99 | while not finished: 100 | minlen = (end - start).min() 101 | # Item indices (for embedding) of the first remaining click of each active session 102 | idx_target = df.item_idx.values[start] 103 | 104 | for i in range(minlen - 1): 105 | # Build inputs & targets 106 | idx_input = idx_target 107 | idx_target = df.item_idx.values[start + i + 1] 108 | input = torch.LongTensor(idx_input) 109 | target = torch.LongTensor(idx_target) 110 | yield input, target, mask 111 | 112 | # advance the start pointers past the clicks that were consumed 113 | start = start + (minlen - 1) 114 | # check which sessions should terminate 115 | mask = np.arange(len(iters))[(end - start) <= 1] 116 | for idx in mask: 117 | maxiter += 1 118 | if maxiter >= len(click_offsets) - 1: 119 | finished = True 120 | break 121 | # update the next starting/ending point 122 | iters[idx] = maxiter 123 | start[idx] = click_offsets[session_idx_arr[maxiter]] 124 | end[idx] = click_offsets[session_idx_arr[maxiter] + 1] 125 | -------------------------------------------------------------------------------- /lib/evaluation.py: -------------------------------------------------------------------------------- 1 | import lib 2 | import numpy as np 3 | import torch 4 | from tqdm import tqdm 5 | 6 | class Evaluation(object): 7 | def __init__(self, model, loss_func, use_cuda, k=20): 8 | self.model = model 9 | self.loss_func = loss_func 10 | self.topk = k 11 | self.device = torch.device('cuda' if use_cuda else 'cpu') 12 | 13 | def eval(self, eval_data, batch_size): 14 | self.model.eval() 15 | losses = [] 16 | recalls = [] 17 | mrrs = [] 18 | dataloader = lib.DataLoader(eval_data, batch_size) 19 | with torch.no_grad(): 20 | hidden = self.model.init_hidden() 21 | for ii, (input, target, mask) in tqdm(enumerate(dataloader), total=len(dataloader.dataset.df) // dataloader.batch_size, miniters = 1000): 22 | #for input, target, mask in dataloader: 23 | input = input.to(self.device) 24 | target = target.to(self.device) 25 | logit, hidden = self.model(input, hidden) 26 | logit_sampled = logit[:, target.view(-1)] # restrict scores to the batch's target items, as in training 27 | loss = self.loss_func(logit_sampled) 28 | recall, mrr = lib.evaluate(logit, target, k=self.topk) 29 | 30 | # torch.Tensor.item() to get a Python number from a tensor containing a single value 31 | losses.append(loss.item()) 32 | recalls.append(recall) 33 | mrrs.append(mrr) 34 | mean_losses = np.mean(losses) 35 | mean_recall = np.mean(recalls) 36 | mean_mrr = np.mean(mrrs) 37 | 38 | return mean_losses, mean_recall, mean_mrr --------------------------------------------------------------------------------
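The `logit[:, target.view(-1)]` indexing above implements the paper's sampling on the output layer: the other targets in the same mini-batch act as negative samples. A minimal sketch with toy tensors (shapes assumed for illustration, not repo code):

```python
import torch

logit = torch.randn(3, 10)                  # scores for 3 sessions over a 10-item catalog
target = torch.tensor([4, 7, 1])            # true next-item index per session
logit_sampled = logit[:, target.view(-1)]   # (3, 3): columns are the batch's target items
# logit_sampled[i, i] is session i's positive score; the off-diagonal entries are
# scores for the other sessions' targets, serving as sampled negatives.
assert logit_sampled.shape == (3, 3)
```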
/lib/lossfunction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | 6 | 7 | class LossFunction(nn.Module): 8 | def __init__(self, loss_type='TOP1', use_cuda=False): 9 | """ An abstract loss function that supports custom loss functions compatible with PyTorch.""" 10 | super(LossFunction, self).__init__() 11 | self.loss_type = loss_type 12 | self.use_cuda = use_cuda 13 | if loss_type == 'CrossEntropy': 14 | self._loss_fn = SampledCrossEntropyLoss(use_cuda) 15 | elif loss_type == 'TOP1': 16 | self._loss_fn = TOP1Loss() 17 | elif loss_type == 'BPR': 18 | self._loss_fn = BPRLoss() 19 | elif loss_type == 'TOP1-max': 20 | self._loss_fn = TOP1_max() 21 | elif loss_type == 'BPR-max': 22 | self._loss_fn = BPR_max() 23 | else: 24 | raise NotImplementedError 25 | 26 | def forward(self, logit): 27 | return self._loss_fn(logit) 28 | 29 | 30 | class SampledCrossEntropyLoss(nn.Module): 31 | """ CrossEntropyLoss with n_classes = batch_size = the number of samples in the session-parallel mini-batch """ 32 | def __init__(self, use_cuda): 33 | """ 34 | Args: 35 | use_cuda (bool): whether to use cuda or not 36 | """ 37 | super(SampledCrossEntropyLoss, self).__init__() 38 | self.xe_loss = nn.CrossEntropyLoss() 39 | self.use_cuda = use_cuda 40 | 41 | def forward(self, logit): 42 | batch_size = logit.size(1) 43 | target = Variable(torch.arange(batch_size).long()) # the positive item of row i sits in column i 44 | if self.use_cuda: 45 | target = target.cuda() 46 | 47 | return self.xe_loss(logit, target) 48 | 49 | 50 | class BPRLoss(nn.Module): 51 | def __init__(self): 52 | super(BPRLoss, self).__init__() 53 | 54 | def forward(self, logit): 55 | """ 56 | Args: 57 | logit (BxB): Variable that stores the logits for the items in the mini-batch 58 | The first dimension corresponds to the batches, and the second 59 | dimension corresponds to sampled number of items to evaluate 60 | """ 61 | # differences between the item scores 62 | diff = logit.diag().view(-1, 1).expand_as(logit) - logit 63 | # final loss 64 | loss = -torch.mean(F.logsigmoid(diff)) 65 | return loss 66 | 67 | 68 | class BPR_max(nn.Module): 69 | def __init__(self): 70 | super(BPR_max, self).__init__() 71 | def forward(self, logit): 72 | logit_softmax = F.softmax(logit, dim=1) # BPR-max: pairwise sigmoid terms are weighted by the softmax of the scores 73 | diff = logit.diag().view(-1, 1).expand_as(logit) - logit 74 | loss = -torch.log(torch.mean(logit_softmax * torch.sigmoid(diff))) 75 | return loss 76 | 77 | 78 | class TOP1Loss(nn.Module): 79 | def __init__(self): 80 | super(TOP1Loss, self).__init__() 81 | def forward(self, logit): 82 | """ 83 | Args: 84 | logit (BxB): Variable that stores the logits for the items in the mini-batch 85 | The first dimension corresponds to the batches, and the second 86 | dimension corresponds to sampled number of items to evaluate 87 | """ 88 | diff = -(logit.diag().view(-1, 1).expand_as(logit) - logit) 89 | loss = torch.sigmoid(diff).mean() + torch.sigmoid(logit ** 2).mean() 90 | return loss 91 | 92 | 93 | class TOP1_max(nn.Module): 94 | def __init__(self): 95 | super(TOP1_max, self).__init__() 96 | 97 | def forward(self, logit): 98 | logit_softmax = F.softmax(logit, dim=1) # TOP1-max: the same softmax weighting applied to the TOP1 terms 99 | diff = -(logit.diag().view(-1, 1).expand_as(logit) - logit) 100 | loss = torch.mean(logit_softmax * (torch.sigmoid(diff) + torch.sigmoid(logit ** 2))) 101 | return loss 102 | --------------------------------------------------------------------------------
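All five losses operate on the same BxB sampled-score matrix whose diagonal holds the positive scores, so they can be compared on a toy batch. A quick sketch (assumes the repo root is on `sys.path`):

```python
import torch
from lib.lossfunction import LossFunction

torch.manual_seed(0)
logit = torch.randn(4, 4)  # B x B sampled scores; diagonal = positive items
for loss_type in ['CrossEntropy', 'TOP1', 'BPR', 'TOP1-max', 'BPR-max']:
    loss = LossFunction(loss_type=loss_type, use_cuda=False)(logit)
    print(loss_type, round(loss.item(), 4))
```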
/lib/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_recall(indices, targets): # recall --> whether the next item in the session is within the top K=20 recommended items 5 | """ 6 | Calculates the recall score for the given predictions and targets 7 | Args: 8 | indices (Bxk): torch.LongTensor. top-k indices predicted by the model. 9 | targets (B): torch.LongTensor. actual target indices. 10 | Returns: 11 | recall (float): the recall score 12 | """ 13 | targets = targets.view(-1, 1).expand_as(indices) 14 | hits = (targets == indices).nonzero() 15 | if len(hits) == 0: 16 | return 0 17 | n_hits = hits[:, :-1].size(0) # one row in hits per matched target 18 | recall = float(n_hits) / targets.size(0) 19 | return recall 20 | 21 | 22 | def get_mrr(indices, targets): # Mean Reciprocal Rank --> average of the reciprocal ranks of the next items in the session 23 | """ 24 | Calculates the MRR score for the given predictions and targets 25 | Args: 26 | indices (Bxk): torch.LongTensor. top-k indices predicted by the model. 27 | targets (B): torch.LongTensor. actual target indices. 28 | Returns: 29 | mrr (float): the mrr score 30 | """ 31 | tmp = targets.view(-1, 1) 32 | targets = tmp.expand_as(indices) 33 | hits = (targets == indices).nonzero() 34 | ranks = hits[:, -1] + 1 35 | ranks = ranks.float() 36 | rranks = torch.reciprocal(ranks) # reciprocal ranks; misses contribute 0 37 | mrr = torch.sum(rranks).data / targets.size(0) 38 | return mrr 39 | 40 | 41 | def evaluate(indices, targets, k=20): 42 | """ 43 | Evaluates the model using Recall@K, MRR@K scores. 44 | 45 | Args: 46 | indices (B,C): torch.FloatTensor. The predicted logits for the next items (the top-k indices are computed below). 47 | targets (B): torch.LongTensor. actual target indices. 48 | 49 | Returns: 50 | recall (float): the recall score 51 | mrr (float): the mrr score 52 | """ 53 | _, indices = torch.topk(indices, k, -1) 54 | recall = get_recall(indices, targets) 55 | mrr = get_mrr(indices, targets) 56 | return recall, mrr 57 | --------------------------------------------------------------------------------
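As a sanity check, `evaluate` can be run on a hand-made batch. The sketch below (toy values; repo root assumed on `sys.path`) scores two sessions at k=2:

```python
import torch
from lib.metric import evaluate

logits = torch.tensor([[0.1, 0.9, 0.0, 0.0, 0.0],   # target 1 is ranked 1st -> hit
                       [0.2, 0.1, 0.3, 0.4, 0.0]])  # target 0 is ranked 3rd -> miss at k=2
targets = torch.tensor([1, 0])
recall, mrr = evaluate(logits, targets, k=2)
print(recall, mrr)  # 0.5 tensor(0.5000): only the hit contributes reciprocal rank 1/1
```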
/lib/model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | 4 | class GRU4REC(nn.Module): 5 | def __init__(self, input_size, hidden_size, output_size, num_layers=1, final_act='tanh', 6 | dropout_hidden=.5, dropout_input=0, batch_size=50, embedding_dim=-1, use_cuda=False): 7 | super(GRU4REC, self).__init__() 8 | self.input_size = input_size 9 | self.hidden_size = hidden_size 10 | self.output_size = output_size 11 | self.num_layers = num_layers 12 | self.dropout_hidden = dropout_hidden 13 | self.dropout_input = dropout_input 14 | self.embedding_dim = embedding_dim 15 | self.batch_size = batch_size 16 | self.use_cuda = use_cuda 17 | self.device = torch.device('cuda' if use_cuda else 'cpu') 18 | self.onehot_buffer = self.init_emb() 19 | self.h2o = nn.Linear(hidden_size, output_size) 20 | self.create_final_activation(final_act) 21 | if self.embedding_dim != -1: 22 | self.look_up = nn.Embedding(input_size, self.embedding_dim) 23 | self.gru = nn.GRU(self.embedding_dim, self.hidden_size, self.num_layers, dropout=self.dropout_hidden) 24 | else: 25 | self.gru = nn.GRU(self.input_size, self.hidden_size, self.num_layers, dropout=self.dropout_hidden) 26 | self = self.to(self.device) 27 | 28 | def create_final_activation(self, final_act): 29 | if final_act == 'tanh': 30 | self.final_activation = nn.Tanh() 31 | elif final_act == 'relu': 32 | self.final_activation = nn.ReLU() 33 | elif final_act == 'softmax': 34 | self.final_activation = nn.Softmax() 35 | elif final_act == 'softmax_logit': 36 | self.final_activation = nn.LogSoftmax() 37 | elif final_act.startswith('elu-'): 38 | self.final_activation = nn.ELU(alpha=float(final_act.split('-')[1])) 39 | elif final_act.startswith('leaky-'): 40 | self.final_activation = nn.LeakyReLU(negative_slope=float(final_act.split('-')[1])) 41 | 42 | def forward(self, input, hidden): 43 | ''' 44 | Args: 45 | input (B,): a batch of item indices from a session-parallel mini-batch. 46 | hidden: the current GRU hidden state, shape (num_layers, B, H). 47 | 48 | Returns: 49 | logit (B,C): torch.FloatTensor that stores the logits for the next items in the session-parallel mini-batch 50 | hidden: updated GRU hidden state 51 | ''' 52 | 53 | if self.embedding_dim == -1: 54 | embedded = self.onehot_encode(input) 55 | if self.training and self.dropout_input > 0: embedded = self.embedding_dropout(embedded) 56 | embedded = embedded.unsqueeze(0) 57 | else: 58 | embedded = input.unsqueeze(0) 59 | embedded = self.look_up(embedded) 60 | 61 | output, hidden = self.gru(embedded, hidden) # output: (1, B, H); hidden: (num_layers, B, H) 62 | output = output.view(-1, output.size(-1)) # (B, H) 63 | logit = self.final_activation(self.h2o(output)) 64 | 65 | return logit, hidden 66 | 67 | def init_emb(self): 68 | ''' 69 | Initialize the one_hot embedding buffer, which will be used for producing the one-hot embeddings efficiently 70 | ''' 71 | onehot_buffer = torch.FloatTensor(self.batch_size, self.output_size) 72 | onehot_buffer = onehot_buffer.to(self.device) 73 | return onehot_buffer 74 | 75 | def onehot_encode(self, input): 76 | """ 77 | Returns a one-hot vector corresponding to the input 78 | Args: 79 | input (B,): torch.LongTensor of item indices 80 | (self.onehot_buffer (B,output_size) is reused as scratch space to avoid re-allocation) 81 | Returns: 82 | one_hot (B,C): torch.FloatTensor of one-hot vectors 83 | """ 84 | self.onehot_buffer.zero_() 85 | index = input.view(-1, 1) 86 | one_hot = self.onehot_buffer.scatter_(1, index, 1) 87 | return one_hot 88 | 89 | def embedding_dropout(self, input): 90 | p_drop = torch.Tensor(input.size(0), 1).fill_(1 - self.dropout_input) 91 | mask = torch.bernoulli(p_drop).expand_as(input) / (1 - self.dropout_input) # inverted dropout: scale the kept units 92 | mask = mask.to(self.device) 93 | input = input * mask 94 | return input 95 | 96 | def init_hidden(self): 97 | ''' 98 | Initialize the hidden state of the GRU 99 | ''' 100 | try: 101 | h0 = torch.zeros(self.num_layers, self.batch_size, self.hidden_size).to(self.device) 102 | except: # fall back to CPU if the configured device is unavailable 103 | self.device = 'cpu' 104 | h0 = torch.zeros(self.num_layers, self.batch_size, self.hidden_size).to(self.device) 105 | return h0 -------------------------------------------------------------------------------- /lib/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | 3 | 4 | class Optimizer: 5 | def __init__(self, params, optimizer_type='Adagrad', lr=.05, 6 | momentum=0, weight_decay=0, eps=1e-6): 7 | ''' 8 | An abstract optimizer class for handling various kinds of optimizers. 9 | You can specify the optimizer type and related parameters as you want. 10 | Usage is exactly the same as an instance of torch.optim 11 | 12 | Args: 13 | params: torch.nn.Parameter. The NN parameters to optimize 14 | optimizer_type: type of the optimizer to use 15 | lr: learning rate 16 | momentum: momentum, if needed 17 | weight_decay: weight decay, if needed. Equivalent to L2 regularization. 18 | eps: eps parameter, if needed.
19 | ''' 20 | if optimizer_type == 'RMSProp': 21 | self.optimizer = optim.RMSprop(params, lr=lr, eps=eps, weight_decay=weight_decay, momentum=momentum) 22 | elif optimizer_type == 'Adagrad': 23 | self.optimizer = optim.Adagrad(params, lr=lr, weight_decay=weight_decay) 24 | elif optimizer_type == 'Adadelta': 25 | self.optimizer = optim.Adadelta(params, lr=lr, eps=eps, weight_decay=weight_decay) 26 | elif optimizer_type == 'Adam': 27 | self.optimizer = optim.Adam(params, lr=lr, eps=eps, weight_decay=weight_decay) 28 | elif optimizer_type == 'SparseAdam': 29 | self.optimizer = optim.SparseAdam(params, lr=lr, eps=eps) 30 | elif optimizer_type == 'SGD': 31 | self.optimizer = optim.SGD(params, lr=lr, momentum=momentum, weight_decay=weight_decay) 32 | else: 33 | raise NotImplementedError 34 | 35 | def zero_grad(self): 36 | self.optimizer.zero_grad() 37 | 38 | def step(self): 39 | self.optimizer.step() -------------------------------------------------------------------------------- /lib/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import lib 3 | import time 4 | import torch 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | 9 | class Trainer(object): 10 | def __init__(self, model, train_data, eval_data, optim, use_cuda, loss_func, batch_size, args): 11 | self.model = model 12 | self.train_data = train_data 13 | self.eval_data = eval_data 14 | self.optim = optim 15 | self.loss_func = loss_func 16 | self.evaluation = lib.Evaluation(self.model, self.loss_func, use_cuda, k = args.k_eval) 17 | self.device = torch.device('cuda' if use_cuda else 'cpu') 18 | self.batch_size = batch_size 19 | self.args = args 20 | 21 | def train(self, start_epoch, end_epoch, start_time=None): 22 | if start_time is None: 23 | self.start_time = time.time() 24 | else: 25 | self.start_time = start_time 26 | 27 | for epoch in range(start_epoch, end_epoch + 1): 28 | st = time.time() 29 | print('Start Epoch #', epoch) 30 | train_loss = self.train_epoch(epoch) 31 | loss, recall, mrr = self.evaluation.eval(self.eval_data, self.batch_size) 32 | 33 | 34 | print("Epoch: {}, train loss: {:.4f}, loss: {:.4f}, recall: {:.4f}, mrr: {:.4f}, time: {}".format(epoch, train_loss, loss, recall, mrr, time.time() - st)) 35 | checkpoint = { 36 | 'model': self.model, 37 | 'args': self.args, 38 | 'epoch': epoch, 39 | 'optim': self.optim, 40 | 'loss': loss, 41 | 'recall': recall, 42 | 'mrr': mrr 43 | } 44 | model_name = os.path.join(self.args.checkpoint_dir, "model_{0:05d}.pt".format(epoch)) 45 | torch.save(checkpoint, model_name) 46 | print("Save model as %s" % model_name) 47 | 48 | 49 | def train_epoch(self, epoch): 50 | self.model.train() 51 | losses = [] 52 | 53 | def reset_hidden(hidden, mask): 54 | """Helper function that resets hidden state when some sessions terminate""" 55 | if len(mask) != 0: 56 | hidden[:, mask, :] = 0 57 | return hidden 58 | 59 | hidden = self.model.init_hidden() 60 | dataloader = lib.DataLoader(self.train_data, self.batch_size) 61 | #for ii,(data,label) in tqdm(enumerate(train_dataloader),total=len(train_data)): 62 | for ii, (input, target, mask) in tqdm(enumerate(dataloader), total=len(dataloader.dataset.df) // dataloader.batch_size, miniters = 1000): 63 | input = input.to(self.device) 64 | target = target.to(self.device) 65 | self.optim.zero_grad() 66 | hidden = reset_hidden(hidden, mask).detach() 67 | logit, hidden = self.model(input, hidden) 68 | # output sampling 69 | logit_sampled = logit[:, target.view(-1)] 70 | loss = 
self.loss_func(logit_sampled) 71 | losses.append(loss.item()) 72 | loss.backward() 73 | self.optim.step() 74 | 75 | mean_losses = np.mean(losses) 76 | return mean_losses -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import lib 4 | import numpy as np 5 | import os 6 | import datetime 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--hidden_size', default=100, type=int) #Literature uses 100 / 1000 --> better is 100 10 | parser.add_argument('--num_layers', default=3, type=int) #number of GRU layers (the paper and the README use 1; note that the default here is 3) 11 | parser.add_argument('--batch_size', default=50, type=int) #50 in first paper and 32 in second paper 12 | parser.add_argument('--dropout_input', default=0, type=float) #0.5 for TOP and 0.3 for BPR 13 | parser.add_argument('--dropout_hidden', default=0.5, type=float) #0.5 for TOP and 0.3 for BPR 14 | parser.add_argument('--n_epochs', default=5, type=int) #number of epochs (10 in literature) 15 | parser.add_argument('--k_eval', default=20, type=int) #value of K during Recall and MRR Evaluation 16 | # parse the optimizer arguments 17 | parser.add_argument('--optimizer_type', default='Adagrad', type=str) #Optimizer --> Adagrad is the best according to literature 18 | parser.add_argument('--final_act', default='tanh', type=str) #Final Activation Function 19 | parser.add_argument('--lr', default=0.01, type=float) #learning rate (Best according to literature 0.01 to 0.05) 20 | parser.add_argument('--weight_decay', default=0, type=float) #no weight decay 21 | parser.add_argument('--momentum', default=0, type=float) #no momentum 22 | parser.add_argument('--eps', default=1e-6, type=float) #epsilon; unused by the default Adagrad optimizer 23 | parser.add_argument("-seed", type=int, default=22, help="Seed for random initialization") #Random seed setting 24 | parser.add_argument("-sigma", type=float, default=None, help="init weight -1: range [-sigma, sigma], -2: range [0, sigma]") # weight initialization [-sigma sigma] in literature 25 | 26 | # embedding_dim = -1 uses one-hot encoded inputs; any other value enables a learned item embedding of that size (see lib/model.py) 27 | parser.add_argument("--embedding_dim", type=int, default=-1, help="using embedding") 28 | 29 | 30 | # parse the loss type 31 | parser.add_argument('--loss_type', default='TOP1-max', type=str) #type of loss function TOP1 / BPR / TOP1-max / BPR-max / CrossEntropy 32 | # etc 33 | parser.add_argument('--time_sort', default=False, type=bool) #In case items are not sorted by time stamp (note: argparse's type=bool turns any non-empty string into True) 34 | parser.add_argument('--model_name', default='GRU4REC-CrossEntropy', type=str) 35 | parser.add_argument('--save_dir', default='models', type=str) 36 | parser.add_argument('--data_folder', default='../Dataset/RecSys_Dataset_After/', type=str) 37 | parser.add_argument('--train_data', default='recSys15TrainOnly.txt', type=str) 38 | parser.add_argument('--valid_data', default='recSys15Valid.txt', type=str) 39 | parser.add_argument("--is_eval", action='store_true') #set this flag for evaluation only; omit it for training 40 | parser.add_argument('--load_model', default=None, type=str) 41 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint') 42 | 43 | # Get the arguments 44 | args = parser.parse_args() 45 | args.cuda = torch.cuda.is_available() 46 | #use the random seed defined above 47 | np.random.seed(args.seed) 48 | torch.manual_seed(args.seed) 49 | 50 | 51 | if args.cuda: 52 | torch.cuda.manual_seed(args.seed) 53 | 54 | #Write checkpoints with the arguments used in a text
file for reproducibility 55 | def make_checkpoint_dir(): 56 | print("PARAMETER" + "-"*10) 57 | now = datetime.datetime.now() 58 | S = '{:02d}{:02d}{:02d}{:02d}'.format(now.month, now.day, now.hour, now.minute) 59 | save_dir = os.path.join(args.checkpoint_dir, S) 60 | if not os.path.exists(args.checkpoint_dir): 61 | os.mkdir(args.checkpoint_dir) 62 | 63 | if not os.path.exists(save_dir): 64 | os.mkdir(save_dir) 65 | args.checkpoint_dir = save_dir 66 | with open(os.path.join(args.checkpoint_dir, 'parameter.txt'), 'w') as f: 67 | for attr, value in sorted(args.__dict__.items()): 68 | print("{}={}".format(attr.upper(), value)) 69 | f.write("{}={}\n".format(attr.upper(), value)) 70 | print("---------" + "-"*10) 71 | 72 | #weight initialization if it was defined 73 | def init_model(model): 74 | if args.sigma is not None: 75 | for p in model.parameters(): 76 | if args.sigma != -1 and args.sigma != -2: 77 | sigma = args.sigma 78 | p.data.uniform_(-sigma, sigma) 79 | elif len(list(p.size())) > 1: 80 | sigma = np.sqrt(6.0 / (p.size(0) + p.size(1))) 81 | if args.sigma == -1: 82 | p.data.uniform_(-sigma, sigma) 83 | else: 84 | p.data.uniform_(0, sigma) 85 | 86 | 87 | def main(): 88 | print("Loading train data from {}".format(os.path.join(args.data_folder, args.train_data))) 89 | print("Loading valid data from {}".format(os.path.join(args.data_folder, args.valid_data))) 90 | 91 | train_data = lib.Dataset(os.path.join(args.data_folder, args.train_data)) 92 | valid_data = lib.Dataset(os.path.join(args.data_folder, args.valid_data), itemmap=train_data.itemmap) 93 | make_checkpoint_dir() 94 | 95 | #set all the parameters according to the defined arguments 96 | input_size = len(train_data.items) 97 | hidden_size = args.hidden_size 98 | num_layers = args.num_layers 99 | output_size = input_size 100 | batch_size = args.batch_size 101 | dropout_input = args.dropout_input 102 | dropout_hidden = args.dropout_hidden 103 | embedding_dim = args.embedding_dim 104 | final_act = args.final_act 105 | loss_type = args.loss_type 106 | optimizer_type = args.optimizer_type 107 | lr = args.lr 108 | weight_decay = args.weight_decay 109 | momentum = args.momentum 110 | eps = args.eps 111 | n_epochs = args.n_epochs 112 | time_sort = args.time_sort 113 | #loss function 114 | loss_function = lib.LossFunction(loss_type=loss_type, use_cuda=args.cuda) #cuda is used with cross entropy only 115 | if not args.is_eval: #training 116 | #Initialize the model 117 | model = lib.GRU4REC(input_size, hidden_size, output_size, final_act=final_act, 118 | num_layers=num_layers, use_cuda=args.cuda, batch_size=batch_size, 119 | dropout_input=dropout_input, dropout_hidden=dropout_hidden, embedding_dim=embedding_dim) 120 | #weights initialization 121 | init_model(model) 122 | #optimizer 123 | optimizer = lib.Optimizer(model.parameters(), optimizer_type=optimizer_type, lr=lr, 124 | weight_decay=weight_decay, momentum=momentum, eps=eps) 125 | #trainer class 126 | trainer = lib.Trainer(model, train_data=train_data, eval_data=valid_data, optim=optimizer, 127 | use_cuda=args.cuda, loss_func=loss_function, batch_size=batch_size, args=args) 128 | print('#### START TRAINING....') 129 | trainer.train(0, n_epochs - 1) 130 | else: #testing 131 | if args.load_model is not None: 132 | print("Loading pre-trained model from {}".format(args.load_model)) 133 | try: 134 | checkpoint = torch.load(args.load_model) 135 | except: 136 | checkpoint = torch.load(args.load_model, map_location=lambda storage, loc: storage) 137 | model = checkpoint["model"] 138 | 
model.gru.flatten_parameters() # re-compact the GRU weights in memory after unpickling the checkpoint 139 | evaluation = lib.Evaluation(model, loss_function, use_cuda=args.cuda, k = args.k_eval) 140 | loss, recall, mrr = evaluation.eval(valid_data, batch_size) 141 | print("Final result: recall = {:.2f}, mrr = {:.2f}".format(recall, mrr)) 142 | else: 143 | print("No pre-trained model was found!") 144 | 145 | 146 | if __name__ == '__main__': 147 | main() 148 | --------------------------------------------------------------------------------
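Checkpoints written by `lib/trainer.py` pickle the whole model object together with its evaluation metrics, so they can be inspected directly. A minimal sketch (the checkpoint path below is hypothetical; directories are named `MMDDHHMM` by `make_checkpoint_dir()`):

```python
import torch

# hypothetical checkpoint produced by a training run
ckpt = torch.load('checkpoint/09101200/model_00004.pt', map_location='cpu')
print(ckpt['epoch'], ckpt['loss'], ckpt['recall'], ckpt['mrr'])
model = ckpt['model']  # the full GRU4REC module is stored, not just a state_dict
```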
/preprocessing.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Sep 10 09:50:45 2019 4 | @author: s-moh 5 | """ 6 | import numpy as np 7 | import pandas as pd 8 | import datetime 9 | 10 | dataBefore = 'C:/Users/s-moh/0-Labwork/Rakuten Project/Dataset/RecSys_Dataset_Before/yoochoose-clicks.dat' #Path to Original Training Dataset "Clicks" File 11 | dataTestBefore = 'C:/Users/s-moh/0-Labwork/Rakuten Project/Dataset/RecSys_Dataset_Before/yoochoose-test.dat' #Path to Original Testing Dataset "Clicks" File 12 | dataAfter = 'C:/Users/s-moh/0-Labwork/Rakuten Project/Dataset/RecSys_Dataset_After/' #Path to Processed Dataset Folder 13 | dayTime = 86400 #Validation Only one day = 86400 seconds 14 | 15 | def removeShortSessions(data): 16 | #delete sessions of length 1 (keep only sessions with more than one event) 17 | sessionLen = data.groupby('SessionID').size() #group by sessionID and get size of each session 18 | data = data[np.in1d(data.SessionID, sessionLen[sessionLen > 1].index)] 19 | return data 20 | 21 | #Read Dataset in pandas Dataframe (Ignore Category Column) 22 | train = pd.read_csv(dataBefore, sep=',', header=None, usecols=[0,1,2], dtype={0:np.int32, 1:str, 2:np.int64}) 23 | test = pd.read_csv(dataTestBefore, sep=',', header=None, usecols=[0,1,2], dtype={0:np.int32, 1:str, 2:np.int64}) 24 | train.columns = ['SessionID', 'Time', 'ItemID'] #Headers of dataframe 25 | test.columns = ['SessionID', 'Time', 'ItemID'] #Headers of dataframe 26 | train['Time'] = train.Time.apply(lambda x: datetime.datetime.strptime(x, '%Y-%m-%dT%H:%M:%S.%fZ').timestamp()) #Convert time objects to timestamp 27 | test['Time'] = test.Time.apply(lambda x: datetime.datetime.strptime(x, '%Y-%m-%dT%H:%M:%S.%fZ').timestamp()) #Convert time objects to timestamp 28 | 29 | #remove sessions of less than 2 interactions 30 | train = removeShortSessions(train) 31 | #delete records of items which appeared less than 5 times 32 | itemLen = train.groupby('ItemID').size() #groupby itemID and get size of each item 33 | train = train[np.in1d(train.ItemID, itemLen[itemLen > 4].index)] 34 | #remove sessions of less than 2 interactions again 35 | train = removeShortSessions(train) 36 | 37 | ####################################################################################################### 38 | ''' 39 | #Separate Data into Train and Test Splits 40 | timeMax = data.Time.max() #maximum time in all records 41 | sessionMaxTime = data.groupby('SessionID').Time.max() #group by sessionID and get the maximum time of each session 42 | sessionTrain = sessionMaxTime[sessionMaxTime < (timeMax - dayTime)].index #training split is all sessions that ended before the last day 43 | sessionTest = sessionMaxTime[sessionMaxTime >= (timeMax - dayTime)].index #testing split is all sessions that have records in the last day 44 | train = data[np.in1d(data.SessionID, sessionTrain)] 45 | test = data[np.in1d(data.SessionID, sessionTest)] 46 | ''' 47 | #Delete records in testing split where items are not in training split 48 | test = test[np.in1d(test.ItemID, train.ItemID)] 49 | #Delete sessions in the testing split that have fewer than 2 events 50 | test = removeShortSessions(test) 51 | 52 | #Convert To CSV 53 | #print('Full Training Set has', len(train), 'Events, ', train.SessionID.nunique(), 'Sessions, and', train.ItemID.nunique(), 'Items\n\n') 54 | #train.to_csv(dataAfter + 'recSys15TrainFull.txt', sep='\t', index=False) 55 | print('Testing Set has', len(test), 'Events, ', test.SessionID.nunique(), 'Sessions, and', test.ItemID.nunique(), 'Items\n\n') 56 | test.to_csv(dataAfter + 'recSys15Test.txt', sep=',', index=False) 57 | 58 | ####################################################################################################### 59 | #Separate Training set into Train and Validation Splits 60 | timeMax = train.Time.max() 61 | sessionMaxTime = train.groupby('SessionID').Time.max() 62 | sessionTrain = sessionMaxTime[sessionMaxTime < (timeMax - dayTime)].index #training split: all sessions that ended before the last day of the training data 63 | sessionValid = sessionMaxTime[sessionMaxTime >= (timeMax - dayTime)].index #validation split: all sessions that ended during that last day 64 | trainTR = train[np.in1d(train.SessionID, sessionTrain)] 65 | trainVD = train[np.in1d(train.SessionID, sessionValid)] 66 | #Delete records in validation split where items are not in training split 67 | trainVD = trainVD[np.in1d(trainVD.ItemID, trainTR.ItemID)] 68 | #Delete sessions in the validation split that have fewer than 2 events 69 | trainVD = removeShortSessions(trainVD) 70 | #Convert To CSV 71 | print('Training Set has', len(trainTR), 'Events, ', trainTR.SessionID.nunique(), 'Sessions, and', trainTR.ItemID.nunique(), 'Items\n\n') 72 | trainTR.to_csv(dataAfter + 'recSys15TrainOnly.txt', sep=',', index=False) 73 | print('Validation Set has', len(trainVD), 'Events, ', trainVD.SessionID.nunique(), 'Sessions, and', trainVD.ItemID.nunique(), 'Items\n\n') 74 | trainVD.to_csv(dataAfter + 'recSys15Valid.txt', sep=',', index=False) -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | if __name__ == '__main__': 2 | file = "data/raw_data/yoochoose-clicks.dat" 3 | file_out = "data/raw_data/yoochoose-clicks-super-small.dat" 4 | content = [] 5 | with open(file, 'r') as f: 6 | for line in f: 7 | content.append(line) 8 | print(len(content)) 9 | 10 | small_index = len(content) // 8 11 | #small_index = 100 12 | with open(file_out, 'w') as f: 13 | for line in content[-small_index:]: # keep only the last 1/8 of the click log 14 | f.write(line) --------------------------------------------------------------------------------