├── .gitignore
├── LICENSE
├── utils
│   ├── fakedata.py
│   └── preprocess.py
├── cell.py
├── README.md
├── docs
│   └── running-on-aws.md
├── reader.py
└── crnn.py

/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2016
4 | Young-Jun Ko
5 | Lucas Maystre
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
24 |
--------------------------------------------------------------------------------
/utils/fakedata.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 |
4 | import argparse
5 | import numpy as np
6 |
7 |
8 | def softmax(xs):
9 |     zs = np.exp(xs - np.max(xs))
10 |     return zs / zs.sum(axis=0)
11 |
12 |
13 | def rnn_sequence(ws_in, ws_out, ws_h, initial_state):
14 |     state = initial_state
15 |     nb_items = ws_in.shape[0]
16 |     while True:
17 |         probs = softmax(np.dot(state, ws_out))
18 |         event = np.random.choice(nb_items, p=probs)
19 |         yield event + 1
20 |         state = np.tanh(np.dot(state, ws_h) + ws_in[event])
21 |
22 |
23 | def main(args):
24 |     ws_in = np.random.randn(args.nb_items, args.hidden_size)
25 |     ws_out = np.random.randn(args.hidden_size, args.nb_items)
26 |     train_path = "{}-train.txt".format(args.prefix)
27 |     valid_path = "{}-valid.txt".format(args.prefix)
28 |     with open(train_path, "w") as tf, open(valid_path, "w") as vf:
29 |         for u in range(1, args.nb_users + 1):
30 |             ws_h = np.random.randn(args.hidden_size, args.hidden_size)
31 |             initial_state = np.zeros(args.hidden_size)
32 |             seq = rnn_sequence(ws_in, ws_out, ws_h, initial_state)
33 |             for t in range(1, args.train_seq_length + 1):
34 |                 tf.write("{} {} {}\n".format(u, next(seq), t))
35 |             offset = args.train_seq_length
36 |             for t in range(1, args.valid_seq_length + 1):
37 |                 vf.write("{} {} {}\n".format(u, next(seq), offset + t))
38 |
39 |
40 | def _parse_args():
41 |     parser = argparse.ArgumentParser()
42 |     parser.add_argument('--prefix', default="fakedata")
43 |     parser.add_argument('--hidden-size', type=int, default=8)
44 |     parser.add_argument('--nb-users', type=int, default=5)
45 |     parser.add_argument('--nb-items', type=int, default=10)
46 |     parser.add_argument('--train-seq-length', type=int, default=128)
47 |
parser.add_argument('--valid-seq-length', type=int, default=128) 48 | return parser.parse_args() 49 | 50 | 51 | if __name__ == '__main__': 52 | args = _parse_args() 53 | main(args) 54 | -------------------------------------------------------------------------------- /cell.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import tensorflow as tf 5 | 6 | 7 | class CollaborativeGRUCell(tf.nn.rnn_cell.RNNCell): 8 | 9 | def __init__(self, num_units, num_users, num_items): 10 | """Note: users are numbered 1 to N, items are numbered 1 to M. User and 11 | item "zero" is reserved for padding purposes. 12 | """ 13 | self._num_units = num_units 14 | self._num_users = num_users 15 | self._num_items = num_items 16 | 17 | @property 18 | def state_size(self): 19 | return self._num_units 20 | 21 | @property 22 | def output_size(self): 23 | return self._num_units 24 | 25 | def __call__(self, inputs, state, scope=None): 26 | # shape(inputs) = [batch_size, input_size] 27 | # shape(state) = [batch_size, num_units] 28 | with tf.variable_scope(scope or type(self).__name__): # "CollaborativeGRUCell" 29 | with tf.variable_scope("Gates"): 30 | with tf.device("/cpu:0"): 31 | users = tf.get_variable("users", 32 | [self._num_users + 1, self._num_units, 2 * self._num_units], 33 | dtype=tf.float32) 34 | # shape(w_hidden_u) = [batch_size, num_units, 2 * num_units] 35 | w_hidden_u = tf.nn.embedding_lookup(users, inputs[:,0]) 36 | items = tf.get_variable("items", 37 | [self._num_items + 1, 2 * self._num_units], 38 | dtype=tf.float32) 39 | # shape(w_input_i) = [batch_size, 2 * num_units] 40 | w_input_i = tf.nn.embedding_lookup(items, inputs[:,1]) 41 | res = tf.matmul(tf.expand_dims(state, 1), w_hidden_u) 42 | res = tf.sigmoid(tf.squeeze(res, [1]) + w_input_i) 43 | r, z = tf.split(value=res, num_or_size_splits=2, axis=1) 44 | with tf.variable_scope("Candidate"): 45 | with tf.device("/cpu:0"): 46 | users = tf.get_variable("users", 47 | [self._num_users + 1, self._num_units, self._num_units], 48 | dtype=tf.float32) 49 | # shape(w_hidden_u) = [batch_size, num_units, num_units] 50 | w_hidden_u = tf.nn.embedding_lookup(users, inputs[:,0]) 51 | items = tf.get_variable("items", 52 | [self._num_items + 1, self._num_units], 53 | dtype=tf.float32) 54 | # shape(w_input_i) = [batch_size, num_units] 55 | w_input_i = tf.nn.embedding_lookup(items, inputs[:,1]) 56 | res = tf.matmul(tf.expand_dims(r * state, 1), w_hidden_u) 57 | c = tf.sigmoid(tf.squeeze(res, [1]) + w_input_i) 58 | new_h = z * state + (1 - z) * c 59 | return new_h, new_h 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Collaborative RNN 2 | 3 | This is a TensorFlow implementation of the Collaborative RNN presented in the 4 | paper 5 | 6 | > **Collaborative Recurrent Neural Networks for Dynamic Recommender Systems**, 7 | > Young-Jun Ko, Lucas Maystre, Matthias Grossglauser, ACML, 2016. 8 | 9 | A PDF of the paper can be found 10 | [here](https://infoscience.epfl.ch/record/222477/files/ko101.pdf). 
11 |
12 | ## Requirements
13 |
14 | The code is tested with
15 |
16 | - Python 2.7.12 and 3.5.1
17 | - NumPy 1.13.3
18 | - TensorFlow 1.4.0
19 | - CUDA 8.0
20 | - cuDNN 6.0
21 | - six 1.11.0
22 |
23 | If you are interested in quickly testing out our code, you might want to **check
24 | out our [step-by-step guide][1]** for running the collaborative RNN on an AWS
25 | EC2 p2.xlarge instance.
26 |
27 | ## Quickstart
28 |
29 | Reproducing the results of the paper should be as easy as following these three
30 | steps.
31 |
32 | 1. Download the datasets.
33 |    - The last.fm dataset is available on [Òscar Celma's page][2]. The relevant
34 |      file is `userid-timestamp-artid-artname-traid-traname.tsv`.
35 |    - The BrightKite dataset is available at [SNAP][3]. The relevant file is
36 |      `loc-brightkite_totalCheckins.txt`.
37 | 2. Preprocess the data (relabel users and items, remove degenerate cases, split
38 |    into training and validation sets). This can be done using the script
39 |    `utils/preprocess.py`. For example, for BrightKite:
40 |
41 |        python utils/preprocess.py brightkite path/to/raw_file.txt
42 |
43 |    This will create two files named `brightkite-train.txt` and
44 |    `brightkite-valid.txt`.
45 | 3. Run `crnn.py` on the preprocessed data. For example, for BrightKite, you
46 |    might want to try running
47 |
48 |        python -u crnn.py brightkite-{train,valid}.txt --hidden-size=32 \
49 |            --learning-rate=0.0075 --rho=0.997 \
50 |            --chunk-size=64 --batch-size=20 --num-epochs=25
51 |
52 | Here is a table that summarizes the settings that gave us the results published
53 | in the paper. All the settings can be passed as command-line arguments to
54 | `crnn.py`.
55 |
56 | | Argument             | BrightKite | last.fm |
57 | | -------------------- | ---------- | ------- |
58 | | `--batch-size`       | 20         | 20      |
59 | | `--chunk-size`       | 64         | 64      |
60 | | `--hidden-size`      | 32         | 128     |
61 | | `--learning-rate`    | 0.0075     | 0.01    |
62 | | `--max-train-chunks` | *(None)*   | 80      |
63 | | `--max-valid-chunks` | *(None)*   | 8       |
64 | | `--num-epochs`       | 25         | 10      |
65 | | `--rho`              | 0.997      | 0.997   |
66 |
67 | On a modern server with an Nvidia Titan X (Maxwell generation) GPU it takes
68 | around 40 seconds per epoch for the BrightKite dataset, and around 14 minutes
69 | per epoch on the last.fm dataset.
70 |
71 | [1]: docs/running-on-aws.md
72 | [2]: http://www.dtic.upf.edu/~ocelma/MusicRecommendationDataset/lastfm-1K.html
73 | [3]: https://snap.stanford.edu/data/loc-brightkite.html
74 |
--------------------------------------------------------------------------------
/docs/running-on-aws.md:
--------------------------------------------------------------------------------
1 | # Instructions for running the code on AWS EC2
2 |
3 | This document provides step-by-step instructions on running the code on an
4 | Amazon EC2 [p2.xlarge instance][1]. This instance type features an Nvidia K80
5 | GPU.
6 |
7 | In EC2's "Launch Instance" wizard, choose the following settings.
8 |
9 | - image: Ubuntu Server 16.04 LTS (HVM), SSD Volume Type
10 | - instance type: p2.xlarge
11 | - add storage: increase the size of the root volume to 30 GiB.
12 |
13 | Make sure that you can connect to the instance via SSH from your network
14 | location, by choosing appropriate virtual private cloud and security group
15 | settings. The rest of this guide assumes that you are logged into the instance
16 | via SSH.
17 |
18 |
19 | ## Basic setup
20 |
21 | First, we make sure that the operating system is up to date and install a few
22 | packages that we need.
23 |
24 |     sudo apt-get update
25 |     sudo apt-get upgrade
26 |     sudo apt-get install python-pip python-numpy python-scipy \
27 |         libopenblas-dev libcupti-dev
28 |
29 | Second, we install CUDA 8.0. It is publicly available on [Nvidia's website][2].
30 |
31 |     wget https://developer.nvidia.com/compute/cuda/8.0/Prod2/local_installers/cuda-repo-ubuntu1604-8-0-local-ga2_8.0.61-1_amd64-deb
32 |     sudo dpkg -i cuda-repo-ubuntu1604-8-0-local-ga2_8.0.61-1_amd64-deb
33 |     sudo apt-get update
34 |     sudo apt-get install cuda
35 |
36 | Third, we install cuDNN 6.0, another library provided by Nvidia. Unfortunately,
37 | this one is only accessible to registered members of [Nvidia's developer
38 | program][3]. However, this program is free, and creating an account takes only
39 | a couple of minutes. Once you are a member, you can find cuDNN [here][4].
40 | Download it, and copy it to the EC2 instance. Then, execute the following
41 | commands.
42 |
43 |     tar xvf cudnn-8.0-linux-x64-v6.0.tgz
44 |     sudo cp cuda/include/cudnn.h /usr/local/cuda/include/
45 |     sudo cp cuda/lib64/libcudnn* /usr/local/cuda/lib64/
46 |
47 | We also need to set an environment variable so that other software can find
48 | the cuDNN installation. In `~/.bashrc`, add the following line at the end
49 | of the file.
50 |
51 |     export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
52 |
53 | Then, reload the file using `source ~/.bashrc`. Lastly, we install TensorFlow.
54 |
55 |     sudo pip install tensorflow-gpu==1.4.0
56 |
57 |
58 | ## Running the collaborative RNN
59 |
60 | We clone the repository and create a `data/` folder which will contain the
61 | datasets.
62 |
63 |     git clone https://github.com/lca4/collaborative-rnn.git
64 |     cd collaborative-rnn/
65 |     mkdir data
66 |
67 |
68 | ### Brightkite dataset
69 |
70 | Downloading and preprocessing the data.
71 |
72 |     cd data/
73 |     wget https://snap.stanford.edu/data/loc-brightkite_totalCheckins.txt.gz
74 |     gunzip loc-brightkite_totalCheckins.txt.gz
75 |     cd ~/collaborative-rnn
76 |     python utils/preprocess.py --output-dir data/ brightkite data/loc-brightkite_totalCheckins.txt
77 |
78 | Running the collaborative RNN.
79 |
80 |     python -u crnn.py data/brightkite-{train,valid}.txt --verbose \
81 |         --hidden-size=32 --learning-rate=0.0075 --rho=0.997 \
82 |         --chunk-size=64 --batch-size=20 --num-epochs=25
83 |
84 | For this dataset, it should take less than a minute per epoch.
85 |
86 | ### Last.fm dataset
87 |
88 | Downloading and preprocessing the data.
89 |
90 |     cd data/
91 |     wget http://mtg.upf.edu/static/datasets/last.fm/lastfm-dataset-1K.tar.gz
92 |     tar -xvf lastfm-dataset-1K.tar.gz
93 |     cd ~/collaborative-rnn
94 |
95 |     # The next command takes ~6 minutes to complete.
96 |     python utils/preprocess.py --output-dir data/ lastfm data/lastfm-dataset-1K/userid-timestamp-artid-artname-traid-traname.tsv
97 |
98 | Running the collaborative RNN.
99 |
100 |     python -u crnn.py data/lastfm-{train,valid}.txt --verbose \
101 |         --hidden-size=128 --learning-rate=0.01 --rho=0.997 \
102 |         --max-train-chunks=80 --max-valid-chunks=8 \
103 |         --chunk-size=64 --batch-size=20 --num-epochs=10
104 |
105 | For this dataset, it should take about 15 minutes per epoch.
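
Before launching one of these long training runs, it is worth checking that
TensorFlow actually picked up the GPU. The one-liner below is not part of the
repository; it is merely a sanity check that assumes the TensorFlow 1.4 API.

    python -c "import tensorflow as tf; print(tf.__version__); print(tf.test.is_gpu_available())"

If the second value printed is `False`, revisit the CUDA/cuDNN installation and
the `LD_LIBRARY_PATH` export described above.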
106 | 107 | [1]: https://aws.amazon.com/ec2/instance-types/p2/ 108 | [2]: https://developer.nvidia.com/cuda-downloads 109 | [3]: https://developer.nvidia.com/ 110 | [4]: https://developer.nvidia.com/rdp/cudnn-download 111 | -------------------------------------------------------------------------------- /reader.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import collections 5 | import numpy as np 6 | import six 7 | 8 | from math import ceil 9 | 10 | 11 | class Dataset(object): 12 | 13 | def __init__(self, num_users, num_items, seq_dict): 14 | self._seq_dict = seq_dict 15 | self._num_users = num_users 16 | self._num_items = num_items 17 | # These variables are set after calling `prepare_batches`. 18 | self._users_in_batches = None 19 | self._batches = None 20 | self._seq_lengths = None 21 | self._chunk_size = None 22 | 23 | @property 24 | def num_users(self): 25 | return self._num_users 26 | 27 | @property 28 | def num_items(self): 29 | return self._num_items 30 | 31 | @property 32 | def num_triplets(self): 33 | return sum(len(seq) for u, seq in self) 34 | 35 | @property 36 | def num_batches(self): 37 | if self._batches is None: 38 | raise RuntimeError("`prepare_batches` has not been called yet.") 39 | return len(self._batches) 40 | 41 | @property 42 | def users_in_batches(self): 43 | if self._users_in_batches is None: 44 | raise RuntimeError("`prepare_batches` has not been called yet.") 45 | return self._users_in_batches 46 | 47 | def __getitem__(self, u): 48 | return self._seq_dict[u] 49 | 50 | def __iter__(self): 51 | return six.iteritems(self._seq_dict) 52 | 53 | def truncate_seqs(self, max_size, keep_first=False): 54 | for user in self._seq_dict.keys(): 55 | if keep_first: 56 | self._seq_dict[user] = self._seq_dict[user][:max_size] 57 | else: 58 | self._seq_dict[user] = self._seq_dict[user][-max_size:] 59 | 60 | def iter_batches(self, order=None): 61 | if order is None: 62 | order = range(self.num_batches) 63 | if self._batches is None: 64 | raise RuntimeError("`prepare_batches` has not been called yet.") 65 | cs = self._chunk_size 66 | def iter_batch(batch, seq_length): 67 | num_cols = batch.shape[1] 68 | for i, z in enumerate(range(0, num_cols - 1, cs)): 69 | inputs = batch[:,z:z+cs,:] 70 | targets = batch[:,(z+1):(z+cs+1),1] 71 | yield (inputs, targets, seq_length[:,i]) 72 | for i in order: 73 | yield iter_batch(self._batches[i], self._seq_lengths[i]) 74 | 75 | def prepare_batches(self, chunk_size, batch_size, batches_like=None): 76 | # Spread users over batches. 77 | if batches_like is not None: 78 | self._users_in_batches = batches_like.users_in_batches 79 | else: 80 | self._users_in_batches = Dataset._assign_users_to_batches( 81 | batch_size, self._seq_dict) 82 | # Build the batches and record the corresponding valid sequence lengths. 83 | self._chunk_size = chunk_size 84 | self._batches = list() 85 | self._seq_lengths = list() 86 | for users in self._users_in_batches: 87 | lengths = tuple(len(self[u]) for u in users) 88 | num_chunks = int(ceil(max(max(lengths) - 1, chunk_size) 89 | / chunk_size)) 90 | num_cols = num_chunks * chunk_size + 1 91 | batch = np.zeros((batch_size, num_cols, 2), dtype=np.int32) 92 | seq_length = np.zeros((batch_size, num_chunks), dtype=np.int32) 93 | for i, (user, length) in enumerate(zip(users, lengths)): 94 | # Assign the values to the batch. 
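                # Column 0 holds the user id over the first `length` positions
                # and column 1 holds the item sequence; positions beyond
                # `length` keep the zero padding from np.zeros above (user and
                # item 0 are reserved for padding).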
95 |                 batch[i,:length,0] = user
96 |                 batch[i,:length,1] = self[user]
97 |                 # Compute and assign the valid sequence lengths.
98 |                 q, r = divmod(max(0, min(num_cols, length) - 1), chunk_size)
99 |                 seq_length[i,:q] = chunk_size
100 |                 if r > 0:
101 |                     seq_length[i,q] = r
102 |             self._batches.append(batch)
103 |             self._seq_lengths.append(seq_length)
104 |
105 |     @staticmethod
106 |     def _assign_users_to_batches(batch_size, seq_dict):
107 |         lengths, users = zip(*sorted(((len(seq), u)
108 |                 for u, seq in six.iteritems(seq_dict)), reverse=True))
109 |         return tuple(users[i:i+batch_size]
110 |                 for i in range(0, len(users), batch_size))
111 |
112 |     @classmethod
113 |     def from_path(cls, path):
114 |         data = collections.defaultdict(list)
115 |         num_users = 0
116 |         num_items = 0
117 |         with open(path) as f:
118 |             for line in f:
119 |                 u, i, t = map(int, line.strip().split())
120 |                 num_users = max(u, num_users)  # Users are numbered 1 -> N.
121 |                 num_items = max(i, num_items)  # Items are numbered 1 -> M.
122 |                 data[u].append((t, i))
123 |         sequence = dict()
124 |         for user in range(1, num_users + 1):
125 |             if user in data:
126 |                 sequence[user] = np.array([i for t, i in sorted(data[user])])
127 |             else:
128 |                 sequence[user] = np.array([])
129 |         return cls(num_users, num_items, sequence)
130 |
--------------------------------------------------------------------------------
/utils/preprocess.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 |
4 | import argparse
5 | import collections
6 | import datetime
7 | import itertools
8 | import os.path
9 | import time
10 |
11 | from scipy.stats import entropy
12 |
13 |
14 | BK_ENTROPY_CUTOFF = 2.5
15 | LFM_ENTROPY_CUTOFF = 3.0
16 |
17 | MIN_OCCURRENCES = 10
18 | MIN_VALID_SEQ_LEN = 3
19 | MAX_VALID_SEQ_LEN = 500
20 |
21 |
22 | def parse_brightkite(path):
23 |     """Parse the BrightKite dataset.
24 |
25 |     This takes as input the file `loc-brightkite_totalCheckins.txt` available
26 |     at the following URL: <https://snap.stanford.edu/data/loc-brightkite.html>.
27 |     """
28 |     # Format: [user] [check-in time] [latitude] [longitude] [location id].
29 |     with open(path) as f:
30 |         for i, line in enumerate(f):
31 |             try:
32 |                 usr, ts, lat, lon, loc = line.strip().split('\t')
33 |             except ValueError:
34 |                 print("could not parse line {} ('{}'), ignoring".format(
35 |                     i, line.strip()))
36 |                 continue
37 |             dt = datetime.datetime.strptime(ts, "%Y-%m-%dT%H:%M:%SZ")
38 |             ts = time.mktime(dt.timetuple())
39 |             yield (usr, loc, ts)
40 |
41 |
42 | def parse_lastfm(path):
43 |     """Parse the last.fm dataset.
44 |
45 |     This takes as input the file
46 |     `userid-timestamp-artid-artname-traid-traname.tsv` available at the
47 |     following URL:
48 |     <http://www.dtic.upf.edu/~ocelma/MusicRecommendationDataset/lastfm-1K.html>.
49 |     """
50 |     # Format: [user] [timestamp] [artist ID] [artist] [track ID] [track].
51 |     with open(path) as f:
52 |         for i, line in enumerate(f):
53 |             try:
54 |                 usr, ts, aid, artist, tid, track = line.strip().split('\t')
55 |             except ValueError:
56 |                 print("could not parse line {} ('{}'), ignoring".format(
57 |                     i, line.strip()))
58 |                 continue
59 |             dt = datetime.datetime.strptime(ts, "%Y-%m-%dT%H:%M:%SZ")
60 |             ts = time.mktime(dt.timetuple())
61 |             yield (usr, aid, ts)
62 |
63 |
64 | def preprocess(stream, output_dir, prefix="processed", min_entropy=0.0):
65 |     """Preprocess a stream of (user, item, timestamp) triplets.
66 | 67 | The preprocessing roughly includes the following steps: 68 | 69 | - remove items that occur infrequently, 70 | - remove users that consume very few items, 71 | - remove users who do not consume "diverse enough" items, 72 | - separate data into training and validation sets, 73 | - make sure that items in the validation sets appear at least once in the 74 | training set, 75 | - relabel items and users with consecutive integers. 76 | """ 77 | # Step 1: read stream and count number of item occurrences. 78 | data = list() 79 | occurrences = collections.defaultdict(lambda: 0) 80 | for user, item, ts in stream: 81 | data.append((user, item, ts)) 82 | occurrences[item] += 1 83 | # Step 2: remove items that occurred infrequently, create user seqs. 84 | tmp_dict = collections.defaultdict(list) 85 | for user, item, ts in data: 86 | if occurrences[item] < MIN_OCCURRENCES: 87 | continue 88 | tmp_dict[user].append((ts, item)) 89 | # Step 3: order user sequences by timestamp. 90 | seq_dict = dict() 91 | for user, seq in tmp_dict.items(): 92 | seq = [item for ts, item in sorted(seq)] 93 | seq_dict[user] = seq 94 | # Step 4: split into training and validation sets. Ignore users who 95 | # consumed few items or who do not meet entropy requirements. 96 | train = dict() 97 | valid = dict() 98 | for user, seq in seq_dict.items(): 99 | if len(seq) <= MIN_OCCURRENCES: 100 | continue 101 | hist = collections.defaultdict(lambda: 0) 102 | for item in seq: 103 | hist[item] += 1 104 | if entropy(list(hist.values())) <= min_entropy: 105 | continue 106 | # Implementation note: round(0.025 * 100) gives 3.0 in Python, but 2.0 107 | # in Julia. Beware! Results might differ! 108 | cutoff = min(MAX_VALID_SEQ_LEN, max(MIN_VALID_SEQ_LEN, 109 | int(round(0.025 * len(seq))))) 110 | train[user] = seq[:-cutoff] 111 | valid[user] = seq[-cutoff:] 112 | # Step 5: relabel users and items, and remove items that do not appear in 113 | # the training sequences. 114 | items = set(itertools.chain(*train.values())) 115 | users = set(train.keys()) 116 | user2id = dict(zip(users, range(1, len(users) + 1))) 117 | item2id = dict(zip(items, range(1, len(items) + 1))) 118 | train2 = dict() 119 | valid2 = dict() 120 | for user in users: 121 | train2[user2id[user]] = tuple(map(lambda x: item2id[x], train[user])) 122 | valid2[user2id[user]] = tuple(map(lambda x: item2id[x], 123 | filter(lambda x: x in items, valid[user]))) 124 | # Step 6: write out the sequences. 
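    # Each output line is "<user_id> <item_id> <position>"; the position index
    # keeps counting from the training file into the validation file. This is
    # the triplet format that reader.Dataset.from_path expects.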
125 | train_path = os.path.join(output_dir, "{}-train.txt".format(prefix)) 126 | valid_path = os.path.join(output_dir, "{}-valid.txt".format(prefix)) 127 | with open(train_path, "w") as tf, open(valid_path, "w") as vf: 128 | for uid in user2id.values(): 129 | t = 1 130 | for iid in train2[uid]: 131 | tf.write("{} {} {}\n".format(uid, iid, t)) 132 | t += 1 133 | for iid in valid2[uid]: 134 | vf.write("{} {} {}\n".format(uid, iid, t)) 135 | t += 1 136 | print("Done.") 137 | 138 | 139 | def main(args): 140 | if args.which == "brightkite": 141 | stream = parse_brightkite(args.path) 142 | cutoff = BK_ENTROPY_CUTOFF 143 | elif args.which == "lastfm": 144 | stream = parse_lastfm(args.path) 145 | cutoff = LFM_ENTROPY_CUTOFF 146 | else: 147 | raise RuntimeError("unknown dataset?!") 148 | preprocess(stream, args.output_dir, 149 | prefix=args.which, 150 | min_entropy=cutoff) 151 | 152 | 153 | def _parse_args(): 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument("which", choices=("brightkite", "lastfm")) 156 | parser.add_argument("path") 157 | parser.add_argument("--output-dir", default="./") 158 | return parser.parse_args() 159 | 160 | 161 | if __name__ == '__main__': 162 | args = _parse_args() 163 | main(args) 164 | -------------------------------------------------------------------------------- /crnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import argparse 5 | import numpy as np 6 | import six.moves 7 | import tensorflow as tf 8 | import time 9 | 10 | from math import sqrt 11 | from cell import CollaborativeGRUCell 12 | from reader import Dataset 13 | 14 | 15 | class CollaborativeRNN(object): 16 | 17 | def __init__(self, num_users, num_items, is_training, 18 | chunk_size=128, batch_size=1, hidden_size=128, 19 | learning_rate=0.1, rho=0.9): 20 | self._batch_size = batch_size 21 | 22 | # placeholders for input data 23 | self._inputs = tf.placeholder(tf.int32, name="inputs", 24 | shape=[batch_size, chunk_size, 2]) 25 | self._targets = tf.placeholder(tf.int32, name="targets", 26 | shape=[batch_size, chunk_size]) 27 | self._seq_length = tf.placeholder(tf.int32, name="seq_length", 28 | shape=[batch_size]) 29 | 30 | # RNN cell. 31 | cell = CollaborativeGRUCell(hidden_size, num_users, num_items) 32 | self._initial_state = cell.zero_state(batch_size, tf.float32) 33 | 34 | inputs = [tf.squeeze(input_, [1]) for input_ 35 | in tf.split(self._inputs, chunk_size, axis=1)] 36 | states, _ = tf.nn.static_rnn(cell, inputs, 37 | initial_state=self._initial_state) 38 | 39 | # Compute the final state for each element of the batch. 40 | self._final_state = tf.gather_nd([self._initial_state] + states, 41 | tf.transpose(tf.stack( 42 | [self._seq_length, tf.range(batch_size)]))) 43 | 44 | # Output layer. 45 | # `output` has shape (batch_size * chunk_size, hidden_size). 46 | output = tf.reshape(tf.concat(states, axis=1), [-1, hidden_size]) 47 | with tf.variable_scope("output"): 48 | ws = tf.get_variable("weights", [hidden_size, num_items + 1], 49 | dtype=tf.float32) 50 | # `logits` has shape (batch_size * chunk_size, num_items). 
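            # (Strictly, `ws` has num_items + 1 columns; column 0 corresponds
            # to the padding item, whose loss terms are masked out below.)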
51 | logits = tf.matmul(output, ws) 52 | targets = tf.reshape(self._targets, [-1]) 53 | 54 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits( 55 | labels=targets, logits=logits) 56 | 57 | masked = loss * tf.to_float(tf.sign(targets)) 58 | masked = tf.reshape(masked, [batch_size, chunk_size]) 59 | self._cost = tf.reduce_sum(masked, axis=1) 60 | 61 | if not is_training: 62 | self._train_op = tf.no_op() 63 | return 64 | 65 | scalar_cost = tf.reduce_mean(masked) 66 | 67 | # Optimization procedure. 68 | optimizer = tf.train.RMSPropOptimizer( 69 | learning_rate, decay=rho, epsilon=1e-8) 70 | self._train_op = optimizer.minimize(scalar_cost) 71 | 72 | self._rms_reset = list() 73 | for var in tf.trainable_variables(): 74 | slot = optimizer.get_slot(var, "rms") 75 | op = slot.assign(tf.zeros(slot.get_shape())) 76 | self._rms_reset.append(op) 77 | 78 | @property 79 | def inputs(self): 80 | return self._inputs 81 | 82 | @property 83 | def targets(self): 84 | return self._targets 85 | 86 | @property 87 | def seq_length(self): 88 | return self._seq_length 89 | 90 | @property 91 | def initial_state(self): 92 | return self._initial_state 93 | 94 | @property 95 | def final_state(self): 96 | return self._final_state 97 | 98 | @property 99 | def cost(self): 100 | return self._cost 101 | 102 | @property 103 | def train_op(self): 104 | return self._train_op 105 | 106 | @property 107 | def batch_size(self): 108 | return self._batch_size 109 | 110 | @property 111 | def rms_reset(self): 112 | return self._rms_reset 113 | 114 | 115 | def run_batch(session, model, iterator, initial_state): 116 | """Runs the model on all chunks of one batch.""" 117 | costs = np.zeros(model.batch_size) 118 | sizes = np.zeros(model.batch_size) 119 | state = initial_state 120 | for inputs, targets, seq_len in iterator: 121 | fetches = [model.cost, model.final_state, model.train_op] 122 | feed_dict = {} 123 | feed_dict[model.inputs] = inputs 124 | feed_dict[model.targets] = targets 125 | feed_dict[model.seq_length] = seq_len 126 | feed_dict[model.initial_state] = state 127 | cost, state, _ = session.run(fetches, feed_dict) 128 | costs += cost 129 | sizes += seq_len 130 | with np.errstate(invalid='ignore'): 131 | errors = costs / sizes 132 | return (errors, np.sum(sizes), state) 133 | 134 | 135 | def run_epoch(session, train_model, valid_model, train_iter, valid_iter, 136 | tot_size): 137 | """Runs the model on the given data.""" 138 | start_time = time.time() 139 | 140 | train_errors = list() 141 | valid_errors = list() 142 | tot = 0 143 | 144 | next_tenth = tot_size / 10 145 | 146 | for train, valid in six.moves.zip(train_iter, valid_iter): 147 | state = session.run(train_model.initial_state) 148 | # Training data. 149 | errors, num_triplets, state = run_batch( 150 | session, train_model, train, state) 151 | tot += num_triplets 152 | train_errors.extend(errors) 153 | # Validation data. 154 | errors, num_triplets, state = run_batch( 155 | session, valid_model, valid, state) 156 | tot += num_triplets 157 | valid_errors.extend(errors) 158 | 159 | if tot > next_tenth: 160 | print("log-loss: {:.3f} speed: {:.0f} wps".format( 161 | np.nanmean(train_errors), 162 | tot / (time.time() - start_time))) 163 | next_tenth += tot_size / 10 164 | 165 | return (np.nanmean(train_errors), np.nanmean(valid_errors)) 166 | 167 | 168 | def main(args): 169 | # Read (and optionally, truncate) the training and validation data. 
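    # When truncated, training sequences keep their most recent events, while
    # validation sequences keep their earliest ones (keep_first=True), so the
    # validation data directly continues each user's training sequence.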
170 | train_data = Dataset.from_path(args.train_path) 171 | if args.max_train_chunks is not None: 172 | size = args.max_train_chunks * args.chunk_size 173 | train_data.truncate_seqs(size) 174 | valid_data = Dataset.from_path(args.valid_path) 175 | if args.max_valid_chunks is not None: 176 | size = args.max_valid_chunks * args.chunk_size 177 | valid_data.truncate_seqs(size, keep_first=True) 178 | 179 | num_users = train_data.num_users 180 | num_items = train_data.num_items 181 | tot_size = train_data.num_triplets + valid_data.num_triplets 182 | 183 | train_data.prepare_batches(args.chunk_size, args.batch_size) 184 | valid_data.prepare_batches(args.chunk_size, args.batch_size, 185 | batches_like=train_data) 186 | 187 | settings = { 188 | "chunk_size": args.chunk_size, 189 | "batch_size": args.batch_size, 190 | "hidden_size": args.hidden_size, 191 | "learning_rate": args.learning_rate, 192 | "rho": args.rho, 193 | } 194 | 195 | with tf.Graph().as_default(), tf.Session() as session: 196 | initializer = tf.random_normal_initializer( 197 | mean=0, stddev=1/sqrt(args.hidden_size)) 198 | with tf.variable_scope("model", reuse=None, initializer=initializer): 199 | train_model = CollaborativeRNN(num_users, num_items, 200 | is_training=True, **settings) 201 | with tf.variable_scope("model", reuse=True, initializer=initializer): 202 | valid_model = CollaborativeRNN(num_users, num_items, 203 | is_training=False, **settings) 204 | tf.global_variables_initializer().run() 205 | session.run(train_model.rms_reset) 206 | for i in range(1, args.num_epochs + 1): 207 | order = np.random.permutation(train_data.num_batches) 208 | train_iter = train_data.iter_batches(order=order) 209 | valid_iter = valid_data.iter_batches(order=order) 210 | 211 | train_err, valid_err = run_epoch(session, train_model, valid_model, 212 | train_iter, valid_iter, tot_size) 213 | print("Epoch {}, train log-loss: {:.3f}".format(i, train_err)) 214 | print("Epoch {}, valid log-loss: {:.3f}".format(i, valid_err)) 215 | 216 | 217 | def _parse_args(): 218 | parser = argparse.ArgumentParser() 219 | parser.add_argument("train_path", help="path to training data") 220 | parser.add_argument("valid_path", help="path to validation data") 221 | parser.add_argument("--batch-size", type=int, default=5, 222 | help="number of sequences processed in parallel") 223 | parser.add_argument("--chunk-size", type=int, default=64, 224 | help="number of unrolled steps in BPTT") 225 | parser.add_argument("--hidden-size", type=int, default=128, 226 | help="number of hidden units in the RNN cell") 227 | parser.add_argument("--learning-rate", type=float, default=0.01, 228 | help="RMSprop learning rate") 229 | parser.add_argument("--max-train-chunks", type=int, default=None, 230 | help="max number of chunks per user for training") 231 | parser.add_argument("--max-valid-chunks", type=int, default=None, 232 | help="max number of chunks per user for validation") 233 | parser.add_argument("--num-epochs", type=int, default=10, 234 | help="number of epochs to run") 235 | parser.add_argument("--rho", type=float, default=0.9, 236 | help="RMSprop decay coefficient") 237 | parser.add_argument("--verbose", action="store_true", default=False, 238 | help="enable display of debugging messages") 239 | return parser.parse_args() 240 | 241 | 242 | if __name__ == "__main__": 243 | args = _parse_args() 244 | if args.verbose: 245 | print("arguments:") 246 | for key, val in vars(args).items(): 247 | print("{: <18} {}".format(key, val)) 248 | main(args) 249 | 
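# A possible smoke test (not part of the published experiments): the files
# written by utils/fakedata.py use the same "<user> <item> <position>" triplet
# format as the preprocessed datasets, so the model can be exercised
# end-to-end on synthetic data with, for example:
#
#     python utils/fakedata.py --prefix=fakedata
#     python -u crnn.py fakedata-train.txt fakedata-valid.txt \
#         --hidden-size=8 --num-epochs=2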
--------------------------------------------------------------------------------