├── .gitignore
├── .travis.yml
├── Cargo.toml
├── LICENSE
├── README.tpl
├── benches
│   └── benchmark.rs
├── data.csv
├── examples
│   └── lstm_hyperopt.rs
├── readme.md
└── src
    ├── data.rs
    ├── datasets.rs
    ├── evaluation.rs
    ├── lib.rs
    └── models
        ├── ewma.rs
        ├── lstm.rs
        ├── mod.rs
        └── sequence_model.rs

/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | target/*
3 | Cargo.lock
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: rust
2 | rust:
3 |   - stable
4 | 
5 | before_install:
6 |   - sudo apt-get -qq update
7 |   - sudo apt-get install -y gfortran
8 | 
9 | script:
10 |   - MKL_CBWR=AVX RUSTFLAGS="-C target-cpu=native" cargo test --release
11 | 
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "sbr"
3 | version = "0.5.0"
4 | authors = ["Maciej Kula"]
5 | license = "MIT"
6 | description = "Recommender models."
7 | repository = "https://github.com/maciejkula/sbr-rs"
8 | documentation = "https://docs.rs/sbr/"
9 | readme = "readme.md"
10 | exclude = ["data.csv"]
11 | edition = "2018"
12 | 
13 | [badges]
14 | travis-ci = { repository = "maciejkula/sbr-rs", branch = "master" }
15 | 
16 | [dependencies]
17 | serde = { version = "1.0.0", features = ["rc", "derive"] }
18 | bincode = "0.9.2"
19 | rand = { version = "0.5.0", features = ["serde1"] }
20 | itertools = "0.7.3"
21 | rayon = "1.0.0"
22 | ndarray = { version = "0.11.0", features = ["serde-1"] }
23 | siphasher = "0.2.2"
24 | failure = "0.1.1"
25 | reqwest = { version = "0.8.6", optional = true }
26 | csv = { version = "1.0.0", optional = true }
27 | dirs = { version = "1.0.2", optional = true }
28 | 
29 | wyrm = { version = "0.9.1", features = ["fast-math"]}
30 | 
31 | [dev-dependencies]
32 | serde_json = "1.0"
33 | criterion = "0.2.3"
34 | ndarray = { version = "0.11.0", features = ["blas", "serde-1"] }
35 | blas-src = { version = "0.1.2", default-features = false, features = ["intel-mkl"] }
36 | 
37 | [features]
38 | default = ["csv", "reqwest", "dirs"]
39 | 
40 | [profile.release]
41 | lto = true
42 | rpath = true
43 | debug = true
44 | 
45 | [[bench]]
46 | name = "benchmark"
47 | harness = false
48 | 
49 | [profile.bench]
50 | lto = true
51 | debug = true
52 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2018 Maciej Kula
2 | 
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 | 
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 | 
7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.tpl:
--------------------------------------------------------------------------------
1 | # {{crate}}
2 | 
3 | [![Crates.io badge](https://img.shields.io/crates/v/sbr.svg)](https://crates.io/crates/sbr)
4 | [![Docs.rs badge](https://docs.rs/sbr/badge.svg)](https://docs.rs/sbr/)
5 | [![Build Status](https://travis-ci.org/maciejkula/sbr-rs.svg?branch=master)](https://travis-ci.org/maciejkula/sbr-rs)
6 | 
7 | An implementation of sequence recommenders based on the [wyrm](https://github.com/maciejkula/wyrm) autodifferentiation library.
8 | 
9 | {{readme}}
10 | 
11 | License: {{license}}
--------------------------------------------------------------------------------
/benches/benchmark.rs:
--------------------------------------------------------------------------------
1 | #[macro_use]
2 | extern crate criterion;
3 | 
4 | extern crate csv;
5 | extern crate rand;
6 | extern crate sbr;
7 | extern crate serde;
8 | extern crate serde_json;
9 | extern crate wyrm;
10 | 
11 | use criterion::Criterion;
12 | 
13 | use sbr::data::{Interaction, Interactions};
14 | use sbr::models::{ewma, lstm};
15 | use sbr::models::{Loss, Optimizer};
16 | 
17 | fn load_movielens(path: &str, sample_size: usize) -> Interactions {
18 |     let mut reader = csv::Reader::from_path(path).unwrap();
19 |     let interactions: Vec<Interaction> = reader.deserialize().map(|x| x.unwrap()).collect();
20 | 
21 |     let interactions = rand::seq::sample_slice(&mut rand::thread_rng(), &interactions, sample_size);
22 | 
23 |     Interactions::from(interactions)
24 | }
25 | 
26 | fn bench_lstm(c: &mut Criterion) {
27 |     c.bench_function("lstm", |b| {
28 |         let data = load_movielens("data.csv", 10000).to_compressed();
29 | 
30 |         let mut model = lstm::Hyperparameters::new(data.num_items(), 128)
31 |             .embedding_dim(32)
32 |             .learning_rate(0.16)
33 |             .l2_penalty(0.0004)
34 |             .loss(Loss::Hinge)
35 |             .optimizer(Optimizer::Adagrad)
36 |             .num_epochs(3)
37 |             .num_threads(1)
38 |             .build();
39 | 
40 |         b.iter(|| {
41 |             model.fit(&data).unwrap();
42 |         })
43 |     });
44 | }
45 | 
46 | fn bench_ewma(c: &mut Criterion) {
47 |     c.bench_function("ewma", |b| {
48 |         let data = load_movielens("data.csv", 10000).to_compressed();
49 | 
50 |         let mut model = ewma::Hyperparameters::new(data.num_items(), 128)
51 |             .embedding_dim(32)
52 |             .learning_rate(0.16)
53 |             .l2_penalty(0.0004)
54 |             .loss(Loss::Hinge)
55 |             .optimizer(Optimizer::Adagrad)
56 |             .num_epochs(3)
57 |             .num_threads(1)
58 |             .build();
59 | 
60 |         b.iter(|| {
61 |             model.fit(&data).unwrap();
62 |         })
63 |     });
64 | }
65 | 
66 | criterion_group!{
67 |     name = benches;
68 |     config = Criterion::default().sample_size(10);
69 |     targets = bench_lstm, bench_ewma
70 | }
71 | criterion_main!(benches);
72 | 
--------------------------------------------------------------------------------
/examples/lstm_hyperopt.rs:
--------------------------------------------------------------------------------
1 | #![allow(dead_code)]
2 | #![allow(unused_variables)]
3 | #![allow(unused_imports)]
4 | 
5 | extern crate csv;
6 | extern crate rand;
7 | extern crate sbr;
8 | extern crate serde;
9 | extern crate serde_json;
10 | extern crate wyrm;
11 | 
12 | use serde::{Deserialize, Serialize};
13 | use std::fs::File;
14 | use std::io::{BufReader, Read};
15 | 
16 | use std::collections::HashSet;
17 | use std::time::{Duration, Instant};
18 | 
19 | use sbr::data::{user_based_split, CompressedInteractions, Interaction, Interactions};
20 | use sbr::evaluation::mrr_score;
21 | use sbr::models::lstm;
22 | 
23 | #[derive(Deserialize, Serialize)]
24 | struct GoodbooksInteraction {
25 |     user_id: usize,
26 |     book_id: usize,
27 |     rating: usize,
28 | }
29 | 
30 | fn load_goodbooks(path: &str) -> Interactions {
31 |     let mut reader = csv::Reader::from_path(path).unwrap();
32 |     let mut interactions: Vec<Interaction> = reader
33 |         .deserialize::<GoodbooksInteraction>()
34 |         .map(|x| x.unwrap())
35 |         .enumerate()
36 |         .map(|(i, x)| Interaction::new(x.user_id, x.book_id, i))
37 |         .collect();
38 |     interactions.sort_by_key(|x| x.user_id());
39 | 
40 |     Interactions::from(interactions[..1_000_000].to_owned())
41 | }
42 | 
43 | fn load_dummy() -> Interactions {
44 |     let num_users = 100;
45 |     let num_items = 50;
46 | 
47 |     let mut interactions = Vec::new();
48 | 
49 |     for user in 0..num_users {
50 |         for item in 0..num_items {
51 |             interactions.push(Interaction::new(user, 1000 + item, item));
52 |         }
53 |     }
54 | 
55 |     Interactions::from(interactions)
56 | }
57 | 
58 | #[derive(Debug, Serialize, Deserialize)]
59 | struct Result {
60 |     test_mrr: f32,
61 |     train_mrr: f32,
62 |     elapsed: Duration,
63 |     hyperparameters: lstm::Hyperparameters,
64 | }
65 | 
66 | fn load_movielens(path: &str) -> Interactions {
67 |     let mut reader = csv::Reader::from_path(path).unwrap();
68 |     let interactions: Vec<Interaction> = reader.deserialize().map(|x| x.unwrap()).collect();
69 | 
70 |     let interactions = rand::seq::sample_slice(&mut rand::thread_rng(), &interactions, 100000);
71 | 
72 |     Interactions::from(interactions)
73 | }
74 | 
75 | fn fit(train: &CompressedInteractions, hyper: lstm::Hyperparameters) -> lstm::ImplicitLSTMModel {
76 |     let mut model = hyper.build();
77 |     model.fit(train).unwrap();
78 | 
79 |     model
80 | }
81 | 
82 | fn main() {
83 |     // let data = load_goodbooks("ratings.csv");
84 |     let data = load_movielens("data.csv");
85 |     // let mut data = load_dummy();
86 |     let mut rng = rand::thread_rng();
87 | 
88 |     let (train, test) = user_based_split(&data, &mut rng, 0.2);
89 | 
90 |     let train = train.to_compressed();
91 |     let test = test.to_compressed();
92 | 
93 |     println!(
94 |         "Train {} {} {}",
95 |         train.num_users(),
96 |         train.num_items(),
97 |         data.len()
98 |     );
99 | 
100 |     for _ in 0..1000 {
101 |         let mut results: Vec<Result> = File::open("lstm_results.json")
102 |             .map(|file| serde_json::from_reader(&file).unwrap())
103 |             .unwrap_or(Vec::new());
104 | 
105 |         let hyper = lstm::Hyperparameters::random(data.num_items(), &mut rng);
106 | 
107 |         println!("Running {:#?}", &hyper);
108 |         let start = Instant::now();
109 |         let model = fit(&train, hyper.clone());
110 |         let result = Result {
111 |             train_mrr: mrr_score(&model, &train).unwrap(),
112 |             test_mrr: mrr_score(&model, &test).unwrap(),
113 |             elapsed: start.elapsed(),
114 |             hyperparameters: hyper,
115 |         };
116 | 
117 |         println!("{:#?}", result);
118 | 
119 |         if !result.test_mrr.is_nan() {
120 |             results.push(result);
121 |             results.sort_by(|a, b| a.test_mrr.partial_cmp(&b.test_mrr).unwrap());
122 |         }
123 | 
124 |         println!("Best result: {:#?}", results.last());
125 | 
126 |         File::create("lstm_results.json")
127 |             .map(|file| serde_json::to_writer_pretty(&file, &results).unwrap())
128 |             .unwrap();
129 |     }
130 | }
131 | 
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # sbr
2 | 
3 | [![Crates.io badge](https://img.shields.io/crates/v/sbr.svg)](https://crates.io/crates/sbr)
4 | [![Docs.rs badge](https://docs.rs/sbr/badge.svg)](https://docs.rs/sbr/)
5 | [![Build Status](https://travis-ci.org/maciejkula/sbr-rs.svg?branch=master)](https://travis-ci.org/maciejkula/sbr-rs)
6 | 
7 | An implementation of sequence recommenders based on the [wyrm](https://github.com/maciejkula/wyrm) autodifferentiation library.
8 | 
9 | ## sbr-rs
10 | 
11 | `sbr` implements efficient recommender algorithms which operate on
12 | sequences of items: given previous items a user has interacted with,
13 | the model will recommend the items the user is likely to interact with
14 | in the future.
15 | 
16 | Implemented models:
17 | - LSTM: a model that uses an LSTM network over the sequence of a user's interactions
18 |   to predict their next action;
19 | - EWMA: a model that uses a simpler exponentially-weighted average of past actions
20 |   to predict future interactions.
21 | 
22 | Which model performs best will depend on your dataset. The EWMA model is much
23 | quicker to fit, and will probably be a good starting point.
24 | 
25 | ### Example
26 | You can fit a model on the Movielens 100K dataset in about 10 seconds:
27 | 
28 | ```rust
29 | let mut data = sbr::datasets::download_movielens_100k().unwrap();
30 | 
31 | let mut rng = rand::XorShiftRng::from_seed([42; 16]);
32 | 
33 | let (train, test) = sbr::data::user_based_split(&mut data, &mut rng, 0.2);
34 | let train_mat = train.to_compressed();
35 | let test_mat = test.to_compressed();
36 | 
37 | println!("Train: {}, test: {}", train.len(), test.len());
38 | 
39 | let mut model = sbr::models::lstm::Hyperparameters::new(data.num_items(), 32)
40 |     .embedding_dim(32)
41 |     .learning_rate(0.16)
42 |     .l2_penalty(0.0004)
43 |     .lstm_variant(sbr::models::lstm::LSTMVariant::Normal)
44 |     .loss(sbr::models::Loss::WARP)
45 |     .optimizer(sbr::models::Optimizer::Adagrad)
46 |     .num_epochs(10)
47 |     .rng(rng)
48 |     .build();
49 | 
50 | let start = Instant::now();
51 | let loss = model.fit(&train_mat).unwrap();
52 | let elapsed = start.elapsed();
53 | let train_mrr = sbr::evaluation::mrr_score(&model, &train_mat).unwrap();
54 | let test_mrr = sbr::evaluation::mrr_score(&model, &test_mat).unwrap();
55 | 
56 | println!(
57 |     "Train MRR {} at loss {} and test MRR {} (in {:?})",
58 |     train_mrr, loss, test_mrr, elapsed
59 | );
60 | ```
61 | 
62 | License: MIT
63 | 
--------------------------------------------------------------------------------
/src/data.rs:
--------------------------------------------------------------------------------
1 | //! Functionality for manipulating data.
2 | 
3 | use std;
4 | use std::cmp::Ordering;
5 | use std::hash::Hasher;
6 | 
7 | use rand::distributions::{Distribution, Uniform};
8 | use rand::Rng;
9 | 
10 | use serde::{Deserialize, Serialize};
11 | use siphasher::sip::SipHasher;
12 | 
13 | use super::{ItemId, Timestamp, UserId};
14 | 
15 | /// Basic interaction type.
16 | #[derive(Clone, Serialize, Deserialize, Debug, Eq, Hash, PartialEq)]
17 | pub struct Interaction {
18 |     user_id: UserId,
19 |     item_id: ItemId,
20 |     timestamp: Timestamp,
21 | }
22 | 
23 | impl Interaction {
24 |     /// Create a new interaction.
25 |     pub fn new(user_id: UserId, item_id: ItemId, timestamp: Timestamp) -> Self {
26 |         Interaction {
27 |             user_id,
28 |             item_id,
29 |             timestamp,
30 |         }
31 |     }
32 | }
33 | 
34 | impl Interaction {
35 |     /// Return the user id.
36 |     pub fn user_id(&self) -> UserId {
37 |         self.user_id
38 |     }
39 |     /// Return the item id.
40 |     pub fn item_id(&self) -> ItemId {
41 |         self.item_id
42 |     }
43 |     /// Return the interaction weight.
44 |     pub fn weight(&self) -> f32 {
45 |         1.0
46 |     }
47 |     /// Return the interaction timestamp.
48 |     pub fn timestamp(&self) -> Timestamp {
49 |         self.timestamp
50 |     }
51 | }
52 | 
53 | /// Randomly split interactions between test and training sets.
54 | pub fn train_test_split<R: Rng>(
55 |     interactions: &mut Interactions,
56 |     rng: &mut R,
57 |     test_fraction: f32,
58 | ) -> (Interactions, Interactions) {
59 |     interactions.shuffle(rng);
60 | 
61 |     let (test, train) = interactions.split_at((test_fraction * interactions.len() as f32) as usize);
62 | 
63 |     (train, test)
64 | }
65 | 
66 | /// Split interactions between training and test sets so that no user is in both sets.
67 | /// Useful for testing generalization, where we want to test the model's performance on
68 | /// users who have not been seen during training.
69 | pub fn user_based_split<R: Rng>(
70 |     interactions: &Interactions,
71 |     rng: &mut R,
72 |     test_fraction: f32,
73 | ) -> (Interactions, Interactions) {
74 |     let denominator = 100_000;
75 |     let train_cutoff = (test_fraction * denominator as f32) as u64;
76 | 
77 |     let range = Uniform::new(0, std::u64::MAX);
78 |     let (key_0, key_1) = (range.sample(rng), range.sample(rng));
79 | 
80 |     let is_train = |x: &Interaction| {
81 |         let mut hasher = SipHasher::new_with_keys(key_0, key_1);
82 |         let user_id = x.user_id();
83 |         hasher.write_usize(user_id);
84 |         hasher.finish() % denominator > train_cutoff
85 |     };
86 | 
87 |     interactions.split_by(is_train)
88 | }
89 | 
90 | /// A collection of individual interactions.
91 | #[derive(Clone, Debug, Serialize, Deserialize)]
92 | pub struct Interactions {
93 |     num_users: usize,
94 |     num_items: usize,
95 |     interactions: Vec<Interaction>,
96 | }
97 | 
98 | impl Interactions {
99 |     /// Create a new interactions object.
100 |     pub fn new(num_users: usize, num_items: usize) -> Self {
101 |         Interactions {
102 |             num_users,
103 |             num_items,
104 |             interactions: Vec::new(),
105 |         }
106 |     }
107 |     /// Add a new interaction.
108 |     pub fn push(&mut self, interaction: Interaction) {
109 |         self.interactions.push(interaction);
110 |     }
111 | 
112 |     /// Return the underlying data.
113 |     pub fn data(&self) -> &[Interaction] {
114 |         &self.interactions
115 |     }
116 | 
117 |     /// Give the number of contained interactions.
118 |     pub fn len(&self) -> usize {
119 |         self.interactions.len()
120 |     }
121 | 
122 |     /// Check if there are no interactions.
123 |     pub fn is_empty(&self) -> bool {
124 |         self.len() == 0
125 |     }
126 | 
127 |     /// Shuffle the interactions in-place.
128 |     pub fn shuffle<R: Rng>(&mut self, rng: &mut R) {
129 |         rng.shuffle(&mut self.interactions);
130 |     }
131 | 
132 |     /// Split interactions at `idx`.
133 |     pub fn split_at(&self, idx: usize) -> (Self, Self) {
134 |         let head = Interactions {
135 |             num_users: self.num_users,
136 |             num_items: self.num_items,
137 |             interactions: self.interactions[..idx].to_owned(),
138 |         };
139 |         let tail = Interactions {
140 |             num_users: self.num_users,
141 |             num_items: self.num_items,
142 |             interactions: self.interactions[idx..].to_owned(),
143 |         };
144 | 
145 |         (head, tail)
146 |     }
147 | 
148 |     /// Split interactions by predicate.
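    ///
    /// A minimal usage sketch (the parity predicate is hypothetical, chosen
    /// purely for illustration): every interaction for which the predicate
    /// returns `true` ends up in the first returned set.
    ///
    /// ```no_run
    /// # let interactions = sbr::data::Interactions::new(10, 10);
    /// let (even_users, odd_users) = interactions.split_by(|x| x.user_id() % 2 == 0);
    /// ```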
149 |     pub fn split_by<F: Fn(&Interaction) -> bool>(&self, func: F) -> (Self, Self) {
150 |         let head = Interactions {
151 |             num_users: self.num_users,
152 |             num_items: self.num_items,
153 |             interactions: self
154 |                 .interactions
155 |                 .iter()
156 |                 .filter(|x| func(x))
157 |                 .cloned()
158 |                 .collect(),
159 |         };
160 |         let tail = Interactions {
161 |             num_users: self.num_users,
162 |             num_items: self.num_items,
163 |             interactions: self
164 |                 .interactions
165 |                 .iter()
166 |                 .filter(|x| !func(x))
167 |                 .cloned()
168 |                 .collect(),
169 |         };
170 | 
171 |         (head, tail)
172 |     }
173 | 
174 |     /// Convert to triplet representation.
175 |     pub fn to_triplet(&self) -> TripletInteractions {
176 |         TripletInteractions::from(self)
177 |     }
178 | 
179 |     /// Convert to compressed representation.
180 |     pub fn to_compressed(&self) -> CompressedInteractions {
181 |         CompressedInteractions::from(self)
182 |     }
183 | 
184 |     /// Return number of users.
185 |     pub fn num_users(&self) -> usize {
186 |         self.num_users
187 |     }
188 | 
189 |     /// Return number of items.
190 |     pub fn num_items(&self) -> usize {
191 |         self.num_items
192 |     }
193 | 
194 |     /// Return (`num_users`, `num_items`).
195 |     pub fn shape(&self) -> (usize, usize) {
196 |         (self.num_users, self.num_items)
197 |     }
198 | }
199 | 
200 | impl From<Vec<Interaction>> for Interactions {
201 |     fn from(interactions: Vec<Interaction>) -> Interactions {
202 |         let num_users = interactions.iter().map(|x| x.user_id()).max().unwrap() + 1;
203 |         let num_items = interactions.iter().map(|x| x.item_id()).max().unwrap() + 1;
204 | 
205 |         Interactions {
206 |             num_users,
207 |             num_items,
208 |             interactions,
209 |         }
210 |     }
211 | }
212 | 
213 | fn cmp_timestamp(x: &Interaction, y: &Interaction) -> Ordering {
214 |     let uid_comparison = x.user_id().cmp(&y.user_id());
215 | 
216 |     if uid_comparison == Ordering::Equal {
217 |         x.timestamp().cmp(&y.timestamp())
218 |     } else {
219 |         uid_comparison
220 |     }
221 | }
222 | 
223 | /// A compressed representation of interactions, where the
224 | /// interactions themselves are arranged by user and by timestamp.
225 | ///
226 | /// Normally created by [Interactions::to_compressed].
227 | #[derive(Clone, Debug, Serialize, Deserialize)]
228 | pub struct CompressedInteractions {
229 |     num_users: usize,
230 |     num_items: usize,
231 |     user_pointers: Vec<usize>,
232 |     item_ids: Vec<ItemId>,
233 |     timestamps: Vec<Timestamp>,
234 | }
235 | 
236 | impl<'a> From<&'a Interactions> for CompressedInteractions {
237 |     fn from(interactions: &Interactions) -> CompressedInteractions {
238 |         let mut data = interactions.data().to_owned();
239 | 
240 |         data.sort_by(cmp_timestamp);
241 | 
242 |         let mut user_pointers = vec![0; interactions.num_users + 1];
243 |         let mut item_ids = Vec::with_capacity(data.len());
244 |         let mut timestamps = Vec::with_capacity(data.len());
245 | 
246 |         for datum in &data {
247 |             item_ids.push(datum.item_id());
248 |             timestamps.push(datum.timestamp());
249 | 
250 |             user_pointers[datum.user_id() + 1] += 1;
251 |         }
252 | 
253 |         for idx in 1..user_pointers.len() {
254 |             user_pointers[idx] += user_pointers[idx - 1];
255 |         }
256 | 
257 |         CompressedInteractions {
258 |             num_users: interactions.num_users,
259 |             num_items: interactions.num_items,
260 |             user_pointers,
261 |             item_ids,
262 |             timestamps,
263 |         }
264 |     }
265 | }
266 | 
267 | impl CompressedInteractions {
268 |     /// Iterate over users.
269 |     pub fn iter_users(&self) -> CompressedInteractionsUserIterator {
270 |         CompressedInteractionsUserIterator {
271 |             interactions: self,
272 |             idx: 0,
273 |         }
274 |     }
275 | 
276 |     /// Get a particular user's interactions.
277 |     pub fn get_user(&self, user_id: UserId) -> Option<CompressedInteractionsUser> {
278 |         if user_id >= self.num_users {
279 |             return None;
280 |         }
281 | 
282 |         let start = self.user_pointers[user_id];
283 |         let stop = self.user_pointers[user_id + 1];
284 | 
285 |         Some(CompressedInteractionsUser {
286 |             user_id,
287 |             item_ids: &self.item_ids[start..stop],
288 |             timestamps: &self.timestamps[start..stop],
289 |         })
290 |     }
291 | 
292 |     /// Return number of users.
293 |     pub fn num_users(&self) -> usize {
294 |         self.num_users
295 |     }
296 | 
297 |     /// Return number of items.
298 |     pub fn num_items(&self) -> usize {
299 |         self.num_items
300 |     }
301 | 
302 |     /// Return (`num_users`, `num_items`).
303 |     pub fn shape(&self) -> (usize, usize) {
304 |         (self.num_users, self.num_items)
305 |     }
306 | 
307 |     /// Convert to `Interactions`.
308 |     pub fn to_interactions(&self) -> Interactions {
309 |         let mut interactions = Vec::new();
310 | 
311 |         for user in self.iter_users() {
312 |             for (&item_id, &timestamp) in izip!(user.item_ids, user.timestamps) {
313 |                 interactions.push(Interaction {
314 |                     user_id: user.user_id,
315 |                     item_id,
316 |                     timestamp,
317 |                 });
318 |             }
319 |         }
320 | 
321 |         interactions.shrink_to_fit();
322 | 
323 |         Interactions {
324 |             num_users: self.num_users,
325 |             num_items: self.num_items,
326 |             interactions,
327 |         }
328 |     }
329 | }
330 | 
331 | /// Iterator over compressed user data.
332 | #[derive(Clone, Debug)]
333 | pub struct CompressedInteractionsUserIterator<'a> {
334 |     interactions: &'a CompressedInteractions,
335 |     idx: usize,
336 | }
337 | 
338 | /// A single user's data, arranged from earliest to latest.
339 | #[derive(Debug, Clone)]
340 | pub struct CompressedInteractionsUser<'a> {
341 |     /// User id.
342 |     pub user_id: UserId,
343 |     /// The user's interactions.
344 |     pub item_ids: &'a [ItemId],
345 |     /// The timestamps of the user's interactions.
346 |     pub timestamps: &'a [Timestamp],
347 | }
348 | 
349 | impl<'a> CompressedInteractionsUser<'a> {
350 |     /// Return length of interactions.
351 |     pub fn len(&self) -> usize {
352 |         self.item_ids.len()
353 |     }
354 | 
355 |     /// Check if there are no interactions.
356 |     pub fn is_empty(&self) -> bool {
357 |         self.item_ids.is_empty()
358 |     }
359 | 
360 |     /// Return a chunked iterator over interactions for this user.
361 |     /// The chunks are such that the _first_ chunk is smallest,
362 |     /// and the remaining chunks are all of `chunk_size`.
363 |     pub fn chunks(&self, chunk_size: usize) -> CompressedInteractionsUserChunkIterator<'a> {
364 |         CompressedInteractionsUserChunkIterator {
365 |             idx: 0,
366 |             chunk_size,
367 |             item_ids: &self.item_ids[..],
368 |             timestamps: &self.timestamps[..],
369 |         }
370 |     }
371 | }
372 | 
373 | impl<'a> Iterator for CompressedInteractionsUserIterator<'a> {
374 |     type Item = CompressedInteractionsUser<'a>;
375 |     fn next(&mut self) -> Option<Self::Item> {
376 |         let value = if self.idx >= self.interactions.num_users {
377 |             None
378 |         } else {
379 |             let start = self.interactions.user_pointers[self.idx];
380 |             let stop = self.interactions.user_pointers[self.idx + 1];
381 | 
382 |             Some(CompressedInteractionsUser {
383 |                 user_id: self.idx,
384 |                 item_ids: &self.interactions.item_ids[start..stop],
385 |                 timestamps: &self.interactions.timestamps[start..stop],
386 |             })
387 |         };
388 | 
389 |         self.idx += 1;
390 | 
391 |         value
392 |     }
393 | }
394 | 
395 | /// Chunked iterator over a user's interactions.
396 | /// The chunks are such that the _first_ chunk is smallest,
397 | /// and the remaining chunks are all of `chunk_size`.
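///
/// For example (mirroring the `test_chunk_iterator` test below), a user with
/// 5 interactions chunked with `chunk_size` 3 yields two chunks, of lengths
/// 2 and 3: the remainder is always taken from the _start_ of the sequence,
/// so the most recent interactions always fall in full-sized chunks.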
398 | #[derive(Debug, Clone)]
399 | pub struct CompressedInteractionsUserChunkIterator<'a> {
400 |     idx: usize,
401 |     chunk_size: usize,
402 |     item_ids: &'a [ItemId],
403 |     timestamps: &'a [Timestamp],
404 | }
405 | 
406 | impl<'a> Iterator for CompressedInteractionsUserChunkIterator<'a> {
407 |     type Item = (&'a [ItemId], &'a [Timestamp]);
408 |     fn next(&mut self) -> Option<Self::Item> {
409 |         let user_len = self.item_ids.len();
410 | 
411 |         if self.idx >= user_len {
412 |             None
413 |         } else {
414 |             let chunk_size_mod = (user_len - self.idx) % self.chunk_size;
415 |             let chunk_size = if chunk_size_mod == 0 {
416 |                 self.chunk_size
417 |             } else {
418 |                 chunk_size_mod
419 |             };
420 | 
421 |             let start_idx = self.idx;
422 |             let stop_idx = self.idx + chunk_size;
423 | 
424 |             self.idx += chunk_size;
425 | 
426 |             Some((
427 |                 &self.item_ids[start_idx..stop_idx],
428 |                 &self.timestamps[start_idx..stop_idx],
429 |             ))
430 |         }
431 |     }
432 | }
433 | 
434 | /// Interactions in COO form.
435 | #[derive(Clone, Debug, Serialize, Deserialize)]
436 | pub struct TripletInteractions {
437 |     num_users: usize,
438 |     num_items: usize,
439 |     user_ids: Vec<UserId>,
440 |     pub(crate) item_ids: Vec<ItemId>,
441 |     timestamps: Vec<Timestamp>,
442 | }
443 | 
444 | impl TripletInteractions {
445 |     /// Return length.
446 |     pub fn len(&self) -> usize {
447 |         self.user_ids.len()
448 |     }
449 | 
450 |     /// Check if there are no interactions.
451 |     pub fn is_empty(&self) -> bool {
452 |         self.len() == 0
453 |     }
454 | 
455 |     /// Iterate over minibatches of size `minibatch_size`.
456 |     pub fn iter_minibatch(&self, minibatch_size: usize) -> TripletMinibatchIterator {
457 |         TripletMinibatchIterator {
458 |             interactions: self,
459 |             idx: 0,
460 |             stop_idx: self.len(),
461 |             minibatch_size,
462 |         }
463 |     }
464 | 
465 |     /// Return a collection of iterators over partitions of the data.
466 |     pub fn iter_minibatch_partitioned(
467 |         &self,
468 |         minibatch_size: usize,
469 |         num_partitions: usize,
470 |     ) -> Vec<TripletMinibatchIterator> {
471 |         let iterator = self.iter_minibatch(minibatch_size);
472 |         let chunk_size = self.len() / num_partitions;
473 | 
474 |         (0..num_partitions)
475 |             .map(|x| iterator.slice(x * chunk_size, (x + 1) * chunk_size))
476 |             .collect()
477 |     }
478 | 
479 |     /// Return number of users in the dataset.
480 |     pub fn num_users(&self) -> usize {
481 |         self.num_users
482 |     }
483 | 
484 |     /// Return number of items in the dataset.
485 |     pub fn num_items(&self) -> usize {
486 |         self.num_items
487 |     }
488 | 
489 |     /// Return (num_users, num_items).
490 |     pub fn shape(&self) -> (usize, usize) {
491 |         (self.num_users, self.num_items)
492 |     }
493 | }
494 | 
495 | /// Iterator over minibatches of triplet interactions.
496 | #[derive(Clone, Debug)]
497 | pub struct TripletMinibatchIterator<'a> {
498 |     interactions: &'a TripletInteractions,
499 |     idx: usize,
500 |     stop_idx: usize,
501 |     minibatch_size: usize,
502 | }
503 | 
504 | impl<'a> TripletMinibatchIterator<'a> {
505 |     /// Slice the iterator, yielding an iterator over a subslice of the data.
506 |     pub fn slice(&self, start: usize, stop: usize) -> TripletMinibatchIterator<'a> {
507 |         TripletMinibatchIterator {
508 |             interactions: self.interactions,
509 |             idx: start,
510 |             stop_idx: stop,
511 |             minibatch_size: self.minibatch_size,
512 |         }
513 |     }
514 | }
515 | 
516 | /// A minibatch of triplet interactions.
517 | #[derive(Debug, Clone)]
518 | pub struct TripletMinibatch<'a> {
519 |     /// User ids in the batch.
520 |     pub user_ids: &'a [UserId],
521 |     /// Item ids in the batch.
522 |     pub item_ids: &'a [ItemId],
523 |     /// Timestamps in the batch.
524 |     pub timestamps: &'a [Timestamp],
525 | }
526 | 
527 | impl<'a> TripletMinibatch<'a> {
528 |     /// Return length of the minibatch.
529 |     pub fn len(&self) -> usize {
530 |         self.user_ids.len()
531 |     }
532 | 
533 |     /// Check if there are no interactions.
534 |     pub fn is_empty(&self) -> bool {
535 |         self.item_ids.is_empty()
536 |     }
537 | }
538 | 
539 | impl<'a> Iterator for TripletMinibatchIterator<'a> {
540 |     type Item = TripletMinibatch<'a>;
541 |     fn next(&mut self) -> Option<Self::Item> {
542 |         let value = if self.idx + self.minibatch_size > self.stop_idx {
543 |             None
544 |         } else {
545 |             let start = self.idx;
546 |             let stop = self.idx + self.minibatch_size;
547 | 
548 |             Some(TripletMinibatch {
549 |                 user_ids: &self.interactions.user_ids[start..stop],
550 |                 item_ids: &self.interactions.item_ids[start..stop],
551 |                 timestamps: &self.interactions.timestamps[start..stop],
552 |             })
553 |         };
554 | 
555 |         self.idx += self.minibatch_size;
556 | 
557 |         value
558 |     }
559 | }
560 | 
561 | impl<'a> From<&'a Interactions> for TripletInteractions {
562 |     fn from(interactions: &'a Interactions) -> Self {
563 |         let user_ids = interactions.data().iter().map(|x| x.user_id()).collect();
564 |         let item_ids = interactions.data().iter().map(|x| x.item_id()).collect();
565 |         let timestamps = interactions.data().iter().map(|x| x.timestamp()).collect();
566 | 
567 |         TripletInteractions {
568 |             num_users: interactions.num_users,
569 |             num_items: interactions.num_items,
570 |             user_ids,
571 |             item_ids,
572 |             timestamps,
573 |         }
574 |     }
575 | }
576 | 
577 | #[cfg(test)]
578 | mod tests {
579 |     use std::collections::HashSet;
580 | 
581 |     use rand;
582 |     use rand::distributions::{Distribution, Uniform};
583 |     use rand::SeedableRng;
584 | 
585 |     use super::*;
586 | 
587 |     #[test]
588 |     fn to_compressed() {
589 |         let num_users = 20;
590 |         let num_items = 20;
591 |         let num_interactions = 100;
592 | 
593 |         let user_range = Uniform::new(0, num_users);
594 |         let item_range = Uniform::new(0, num_items);
595 |         let timestamp_range = Uniform::new(0, 50);
596 | 
597 |         let mut rng = rand::XorShiftRng::from_seed([42; 16]);
598 | 
599 |         let interactions: Vec<_> = (0..num_interactions)
600 |             .map(|_| Interaction {
601 |                 user_id: user_range.sample(&mut rng),
602 |                 item_id: item_range.sample(&mut rng),
603 |                 timestamp: timestamp_range.sample(&mut rng),
604 |             })
605 |             .collect();
606 | 
607 |         let mut interaction_set = HashSet::with_capacity(interactions.len());
608 |         for interaction in &interactions {
609 |             interaction_set.insert(interaction.clone());
610 |         }
611 | 
612 |         let mut interactions = Interactions {
613 |             num_users,
614 |             num_items,
615 |             interactions,
616 |         };
617 |         let (train, test) = user_based_split(&mut interactions, &mut rng, 0.5);
618 | 
619 |         let train = train.to_compressed().to_interactions();
620 |         let test = test.to_compressed().to_interactions();
621 | 
622 |         assert_eq!(train.len() + test.len(), interaction_set.len());
623 | 
624 |         for interaction in train.data().iter().chain(test.data().iter()) {
625 |             assert!(interaction_set.contains(interaction));
626 |         }
627 |     }
628 | 
629 |     #[test]
630 |     fn test_chunk_iterator() {
631 |         let num_users = 1;
632 |         let num_items = 5;
633 | 
634 |         let mut interactions = Vec::new();
635 | 
636 |         for user in 0..num_users {
637 |             for item in 0..num_items {
638 |                 interactions.push(Interaction::new(user, item, item));
639 |             }
640 |         }
641 | 
642 |         let interactions = Interactions::from(interactions).to_compressed();
643 | 
644 |         let chunks: Vec<_> = interactions
645 |             .iter_users()
646 |             .flat_map(|user| user.chunks(3))
647 |             .collect();
648 | 
649 |         assert_eq!(chunks.len(), 2);
650 | 
651 |         let expected = [
652 |             (vec![0, 1_usize], vec![0, 1_usize]),
653 |             (vec![2_usize, 3, 4], vec![2_usize, 3, 4]),
654 |         ];
655 | 
656 |         chunks.iter().zip(expected.iter()).for_each(|(x, y)| {
657 |             assert_eq!(&x.0, &y.0.as_slice());
658 |             assert_eq!(&x.1, &y.1.as_slice());
659 |         });
660 | 
661 |         //assert!(chunks == []);
662 |     }
663 | 
664 |     // #[test]
665 |     // fn foo_bar() {
666 |     //     let mut interactions = Vec::new();
667 | 
668 |     //     for user_id in 0..10 {
669 |     //         for item_id in 0..10 {
670 |     //             interactions.push(Interaction {
671 |     //                 user_id: user_id,
672 |     //                 item_id: item_id + 1000 * user_id,
673 |     //                 timestamp: item_id,
674 |     //             });
675 |     //         }
676 |     //     }
677 | 
678 |     //     let interactions = Interactions {
679 |     //         num_users: 10,
680 |     //         num_items: interactions.iter().map(|x| x.item_id).max().unwrap() + 1,
681 |     //         interactions: interactions,
682 |     //     };
683 | 
684 |     //     let mut rng = rand::thread_rng();
685 |     //     let (train, test) = user_based_split(&interactions, &mut rng, 0.5);
686 | 
687 |     //     let train = train.to_compressed();
688 |     //     let test = test.to_compressed();
689 | 
690 |     //     for user in train.iter_users() {
691 |     //         println!("Train {:#?}", user);
692 |     //     }
693 |     //     for user in test.iter_users() {
694 |     //         println!("Test {:#?}", user);
695 |     //     }
696 |     // }
697 | }
698 | 
--------------------------------------------------------------------------------
/src/datasets.rs:
--------------------------------------------------------------------------------
1 | //! Built-in datasets for easy testing and experimentation.
2 | use std::env;
3 | use std::fs::{create_dir_all, rename, File};
4 | use std::io::BufWriter;
5 | use std::path::{Path, PathBuf};
6 | 
7 | use csv;
8 | use dirs;
9 | use failure;
10 | use rand;
11 | use rand::Rng;
12 | use reqwest;
13 | 
14 | use crate::data::{Interaction, Interactions};
15 | 
16 | /// Dataset error types.
17 | #[derive(Debug, Fail)]
18 | pub enum DatasetError {
19 |     /// Can't find the home directory.
20 |     #[fail(display = "Cannot find home directory.")]
21 |     NoHomeDir,
22 | }
23 | 
24 | fn create_data_dir() -> Result<PathBuf, failure::Error> {
25 |     let path = dirs::home_dir()
26 |         .ok_or_else(|| DatasetError::NoHomeDir)?
27 |         .join(".sbr-rs");
28 | 
29 |     if !path.exists() {
30 |         create_dir_all(&path)?;
31 |     }
32 | 
33 |     Ok(path)
34 | }
35 | 
36 | fn download(url: &str, dest_filename: &Path) -> Result<Interactions, failure::Error> {
37 |     let data_dir = create_data_dir()?;
38 |     let desired_filename = data_dir.join(dest_filename);
39 | 
40 |     if !desired_filename.exists() {
41 |         let temp_filename = env::temp_dir().join(
42 |             rand::thread_rng()
43 |                 .sample_iter(&rand::distributions::Alphanumeric)
44 |                 .take(10)
45 |                 .collect::<String>(),
46 |         );
47 | 
48 |         let file = File::create(&temp_filename)?;
49 |         let mut writer = BufWriter::new(file);
50 | 
51 |         let mut response = reqwest::get(url)?;
52 |         response.copy_to(&mut writer)?;
53 | 
54 |         rename(temp_filename, &desired_filename)?;
55 |     }
56 | 
57 |     let mut reader = csv::Reader::from_path(desired_filename)?;
58 |     let interactions: Vec<Interaction> = reader.deserialize().collect::<Result<Vec<_>, _>>()?;
59 | 
60 |     Ok(Interactions::from(interactions))
61 | }
62 | 
63 | /// Download the Movielens 100K dataset and return it.
64 | ///
65 | /// The data is stored in `~/.sbr-rs/`.
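///
/// A minimal usage sketch (requires network access on first use; later calls
/// read the cached file):
///
/// ```no_run
/// let data = sbr::datasets::download_movielens_100k().unwrap();
/// println!("{} users, {} items", data.num_users(), data.num_items());
/// ```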
66 | pub fn download_movielens_100k() -> Result<Interactions, failure::Error> {
67 |     download(
68 |         "https://github.com/maciejkula/sbr-rs/raw/master/data.csv",
69 |         Path::new("movielens_100K.csv"),
70 |     )
71 | }
72 | 
--------------------------------------------------------------------------------
/src/evaluation.rs:
--------------------------------------------------------------------------------
1 | //! Module containing evaluation functions.
2 | use std;
3 | 
4 | use rayon::prelude::*;
5 | 
6 | use crate::data::CompressedInteractions;
7 | use crate::{OnlineRankingModel, PredictionError};
8 | 
9 | /// Compute the MRR (mean reciprocal rank) of predictions for the last
10 | /// item in `test` sequences, treating all but the last item as inputs
11 | /// in computing the user representation.
12 | pub fn mrr_score<T: OnlineRankingModel + Sync>(
13 |     model: &T,
14 |     test: &CompressedInteractions,
15 | ) -> Result<f32, PredictionError> {
16 |     let item_ids: Vec<usize> = (0..test.num_items()).collect();
17 | 
18 |     let mrrs = test
19 |         .iter_users()
20 |         .filter(|user| user.item_ids.len() >= 2)
21 |         .collect::<Vec<_>>()
22 |         .par_iter()
23 |         .map(|test_user| {
24 |             let train_items = &test_user.item_ids[..test_user.item_ids.len().saturating_sub(1)];
25 |             let test_item = *test_user.item_ids.last().unwrap();
26 | 
27 |             let user_embedding = model.user_representation(train_items).unwrap();
28 |             let mut predictions = model.predict(&user_embedding, &item_ids)?;
29 | 
30 |             for &train_item_id in train_items {
31 |                 predictions[train_item_id] = std::f32::MIN;
32 |             }
33 | 
34 |             let test_score = predictions[test_item];
35 |             let mut rank = 0;
36 | 
37 |             for &prediction in &predictions {
38 |                 if prediction >= test_score {
39 |                     rank += 1;
40 |                 }
41 |             }
42 | 
43 |             Ok(1.0 / rank as f32)
44 |         })
45 |         .collect::<Result<Vec<f32>, PredictionError>>()?;
46 | 
47 |     Ok(mrrs.iter().sum::<f32>() / mrrs.len() as f32)
48 | }
49 | 
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
1 | #![deny(missing_docs, missing_debug_implementations)]
2 | #![cfg_attr(feature = "cargo-clippy", allow(clippy::type_complexity))]
3 | //! # sbr-rs
4 | //!
5 | //! `sbr` implements efficient recommender algorithms which operate on
6 | //! sequences of items: given previous items a user has interacted with,
7 | //! the model will recommend the items the user is likely to interact with
8 | //! in the future.
9 | //!
10 | //! Implemented models:
11 | //! - LSTM: a model that uses an LSTM network over the sequence of a user's interactions
12 | //!   to predict their next action;
13 | //! - EWMA: a model that uses a simpler exponentially-weighted average of past actions
14 | //!   to predict future interactions.
15 | //!
16 | //! Which model performs best will depend on your dataset. The EWMA model is much
17 | //! quicker to fit, and will probably be a good starting point.
18 | //!
19 | //! ## Example
20 | //! You can fit a model on the Movielens 100K dataset in about 10 seconds:
21 | //!
22 | //! ```rust
23 | //! # extern crate sbr;
24 | //! # extern crate rand;
25 | //! # use std::time::Instant;
26 | //! # use rand::SeedableRng;
27 | //! let mut data = sbr::datasets::download_movielens_100k().unwrap();
28 | //!
29 | //! let mut rng = rand::XorShiftRng::from_seed([42; 16]);
30 | //!
31 | //! let (train, test) = sbr::data::user_based_split(&mut data, &mut rng, 0.2);
32 | //! let train_mat = train.to_compressed();
33 | //! let test_mat = test.to_compressed();
34 | //!
35 | //! println!("Train: {}, test: {}", train.len(), test.len());
36 | //!
37 | //! let mut model = sbr::models::lstm::Hyperparameters::new(data.num_items(), 32)
38 | //!     .embedding_dim(32)
39 | //!     .learning_rate(0.16)
40 | //!     .l2_penalty(0.0004)
41 | //!     .lstm_variant(sbr::models::lstm::LSTMVariant::Normal)
42 | //!     .loss(sbr::models::Loss::WARP)
43 | //!     .optimizer(sbr::models::Optimizer::Adagrad)
44 | //!     .num_epochs(10)
45 | //!     .rng(rng)
46 | //!     .build();
47 | //!
48 | //! let start = Instant::now();
49 | //! let loss = model.fit(&train_mat).unwrap();
50 | //! let elapsed = start.elapsed();
51 | //! let train_mrr = sbr::evaluation::mrr_score(&model, &train_mat).unwrap();
52 | //! let test_mrr = sbr::evaluation::mrr_score(&model, &test_mat).unwrap();
53 | //!
54 | //! println!(
55 | //!     "Train MRR {} at loss {} and test MRR {} (in {:?})",
56 | //!     train_mrr, loss, test_mrr, elapsed
57 | //! );
58 | //! ```
59 | #[macro_use]
60 | extern crate itertools;
61 | 
62 | #[cfg(feature = "default")]
63 | extern crate csv;
64 | #[macro_use]
65 | extern crate failure;
66 | 
67 | #[cfg(feature = "default")]
68 | extern crate dirs;
69 | 
70 | pub mod data;
71 | #[cfg(feature = "default")]
72 | pub mod datasets;
73 | pub mod evaluation;
74 | pub mod models;
75 | 
76 | /// Alias for user indices.
77 | pub type UserId = usize;
78 | /// Alias for item indices.
79 | pub type ItemId = usize;
80 | /// Alias for timestamps.
81 | pub type Timestamp = usize;
82 | 
83 | /// Prediction error types.
84 | #[derive(Debug, Fail)]
85 | pub enum PredictionError {
86 |     /// Failed prediction due to numerical issues.
87 |     #[fail(display = "Invalid prediction value: non-finite or not a number.")]
88 |     InvalidPredictionValue,
89 | }
90 | 
91 | /// Fitting error types.
92 | #[derive(Debug, Fail)]
93 | pub enum FittingError {
94 |     /// No interactions were given.
95 |     #[fail(display = "No interactions were supplied.")]
96 |     NoInteractions,
97 | }
98 | 
99 | /// Trait describing models that can compute predictions given
100 | /// a user's sequence of past interactions.
101 | pub trait OnlineRankingModel {
102 |     /// The representation the model computes from past interactions.
103 |     type UserRepresentation: std::fmt::Debug;
104 |     /// Compute a user representation from past interactions.
105 |     fn user_representation(
106 |         &self,
107 |         item_ids: &[ItemId],
108 |     ) -> Result<Self::UserRepresentation, PredictionError>;
109 |     /// Given a user representation, rank `item_ids` according
110 |     /// to how likely the user is to interact with them in the future.
111 |     fn predict(
112 |         &self,
113 |         user: &Self::UserRepresentation,
114 |         item_ids: &[ItemId],
115 |     ) -> Result<Vec<f32>, PredictionError>;
116 | }
117 | 
--------------------------------------------------------------------------------
/src/models/ewma.rs:
--------------------------------------------------------------------------------
1 | //! Model based on an exponentially-weighted average (EWMA) of past embeddings.
2 | //!
3 | //! The model estimates three sets of parameters:
4 | //!
5 | //! - n-dimensional item embeddings,
6 | //! - item biases (capturing item popularity), and
7 | //! - an n-dimensional `alpha` parameter, capturing the rate at which past interactions should be decayed.
8 | //!
9 | //! The representation of a user at time t is given by an n-dimensional vector u:
10 | //! ```text
11 | //! u_t = sigmoid(alpha) * u_{t-1} + (1.0 - sigmoid(alpha)) * i_t
12 | //! ```
13 | //! where `i_t` is the embedding of the item the user interacted with at time `t`.
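//!
//! As a worked example: with scalar embeddings and `sigmoid(alpha) = 0.9`, a user
//! whose first two items have embeddings `i_1 = 1.0` and `i_2 = 0.0` has
//! `u_1 = i_1 = 1.0` (the initial state is the first item's embedding) and
//! `u_2 = 0.9 * 1.0 + 0.1 * 0.0 = 0.9`; `alpha` thus controls how quickly
//! older interactions are forgotten.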
14 | use std::sync::Arc;
15 | 
16 | use rand;
17 | use rand::distributions::{Distribution, Normal, Uniform};
18 | use rand::{Rng, SeedableRng, XorShiftRng};
19 | use rayon;
20 | use serde::{Deserialize, Serialize};
21 | 
22 | use ndarray::Axis;
23 | 
24 | use wyrm;
25 | use wyrm::optim::Optimizers;
26 | use wyrm::{Arr, BoxedNode, Variable};
27 | 
28 | use super::sequence_model::{fit_sequence_model, SequenceModel, SequenceModelParameters};
29 | use super::{ImplicitUser, Loss, Optimizer, Parallelism};
30 | use crate::data::CompressedInteractions;
31 | use crate::{FittingError, ItemId, OnlineRankingModel, PredictionError};
32 | 
33 | fn embedding_init<T: Rng>(rows: usize, cols: usize, rng: &mut T) -> wyrm::Arr {
34 |     let normal = Normal::new(0.0, 1.0 / cols as f64);
35 |     Arr::zeros((rows, cols)).map(|_| normal.sample(rng) as f32)
36 | }
37 | 
38 | fn dense_init<T: Rng>(rows: usize, cols: usize, rng: &mut T) -> wyrm::Arr {
39 |     let normal = Normal::new(0.0, (2.0 / (rows + cols) as f64).sqrt());
40 |     Arr::zeros((rows, cols)).map(|_| normal.sample(rng) as f32)
41 | }
42 | 
43 | /// Hyperparameters describing the EWMA model.
44 | #[derive(Clone, Debug, Serialize, Deserialize)]
45 | pub struct Hyperparameters {
46 |     num_items: usize,
47 |     max_sequence_length: usize,
48 |     item_embedding_dim: usize,
49 |     learning_rate: f32,
50 |     l2_penalty: f32,
51 |     loss: Loss,
52 |     optimizer: Optimizer,
53 |     parallelism: Parallelism,
54 |     rng: XorShiftRng,
55 |     num_threads: usize,
56 |     num_epochs: usize,
57 | }
58 | 
59 | impl Hyperparameters {
60 |     /// Build new hyperparameters.
61 |     pub fn new(num_items: usize, max_sequence_length: usize) -> Self {
62 |         Hyperparameters {
63 |             num_items,
64 |             max_sequence_length,
65 |             item_embedding_dim: 16,
66 |             learning_rate: 0.01,
67 |             l2_penalty: 0.0,
68 |             loss: Loss::BPR,
69 |             optimizer: Optimizer::Adam,
70 |             parallelism: Parallelism::Synchronous,
71 |             rng: XorShiftRng::from_seed(rand::thread_rng().gen()),
72 |             num_threads: rayon::current_num_threads(),
73 |             num_epochs: 10,
74 |         }
75 |     }
76 | 
77 |     /// Set the learning rate.
78 |     pub fn learning_rate(mut self, learning_rate: f32) -> Self {
79 |         self.learning_rate = learning_rate;
80 |         self
81 |     }
82 | 
83 |     /// Set the L2 penalty.
84 |     pub fn l2_penalty(mut self, l2_penalty: f32) -> Self {
85 |         self.l2_penalty = l2_penalty;
86 |         self
87 |     }
88 | 
89 |     /// Set the embedding dimensionality.
90 |     pub fn embedding_dim(mut self, embedding_dim: usize) -> Self {
91 |         self.item_embedding_dim = embedding_dim;
92 |         self
93 |     }
94 | 
95 |     /// Set the number of epochs to run per each `fit` call.
96 |     pub fn num_epochs(mut self, num_epochs: usize) -> Self {
97 |         self.num_epochs = num_epochs;
98 |         self
99 |     }
100 | 
101 |     /// Set the loss function.
102 |     pub fn loss(mut self, loss: Loss) -> Self {
103 |         self.loss = loss;
104 |         self
105 |     }
106 | 
107 |     /// Set number of threads to be used.
108 |     pub fn num_threads(mut self, num_threads: usize) -> Self {
109 |         self.num_threads = num_threads;
110 |         self
111 |     }
112 | 
113 |     /// Set the type of parallelism.
114 |     pub fn parallelism(mut self, parallelism: Parallelism) -> Self {
115 |         self.parallelism = parallelism;
116 |         self
117 |     }
118 | 
119 |     /// Set the random number generator.
120 |     pub fn rng(mut self, rng: XorShiftRng) -> Self {
121 |         self.rng = rng;
122 |         self
123 |     }
124 | 
125 |     #[allow(clippy::wrong_self_convention)]
126 |     /// Set the random number generator from seed.
127 |     pub fn from_seed(mut self, seed: [u8; 16]) -> Self {
128 |         self.rng = XorShiftRng::from_seed(seed);
129 |         self
130 |     }
131 | 
132 |     /// Set the optimizer type.
133 |     pub fn optimizer(mut self, optimizer: Optimizer) -> Self {
134 |         self.optimizer = optimizer;
135 |         self
136 |     }
137 | 
138 |     /// Set hyperparameters randomly: useful for hyperparameter search.
139 |     pub fn random<R: Rng>(num_items: usize, rng: &mut R) -> Self {
140 |         Hyperparameters {
141 |             num_items,
142 |             max_sequence_length: 2_usize.pow(Uniform::new(4, 8).sample(rng)),
143 |             item_embedding_dim: 2_usize.pow(Uniform::new(4, 8).sample(rng)),
144 |             learning_rate: (10.0_f32).powf(Uniform::new(-3.0, 0.5).sample(rng)),
145 |             l2_penalty: (10.0_f32).powf(Uniform::new(-7.0, -3.0).sample(rng)),
146 |             loss: if Uniform::new(0.0, 1.0).sample(rng) < 0.5 {
147 |                 Loss::BPR
148 |             } else {
149 |                 Loss::Hinge
150 |             },
151 |             optimizer: if Uniform::new(0.0, 1.0).sample(rng) < 0.5 {
152 |                 Optimizer::Adam
153 |             } else {
154 |                 Optimizer::Adagrad
155 |             },
156 |             parallelism: if Uniform::new(0.0, 1.0).sample(rng) < 0.5 {
157 |                 Parallelism::Asynchronous
158 |             } else {
159 |                 Parallelism::Synchronous
160 |             },
161 |             rng: XorShiftRng::from_seed(rand::thread_rng().gen()),
162 |             num_threads: Uniform::new(1, rayon::current_num_threads() + 1).sample(rng),
163 |             num_epochs: 2_usize.pow(Uniform::new(3, 7).sample(rng)),
164 |         }
165 |     }
166 | 
167 |     fn build_params(mut self) -> Parameters {
168 |         let item_embeddings = Arc::new(wyrm::HogwildParameter::new(embedding_init(
169 |             self.num_items,
170 |             self.item_embedding_dim,
171 |             &mut self.rng,
172 |         )));
173 | 
174 |         let item_biases = Arc::new(wyrm::HogwildParameter::new(Arr::zeros((self.num_items, 1))));
175 |         let alpha = Arc::new(wyrm::HogwildParameter::new(Arr::zeros((
176 |             1,
177 |             self.item_embedding_dim,
178 |         ))));
179 |         let fc1 = Arc::new(wyrm::HogwildParameter::new(dense_init(
180 |             self.item_embedding_dim,
181 |             self.item_embedding_dim,
182 |             &mut self.rng,
183 |         )));
184 |         let fc2 = Arc::new(wyrm::HogwildParameter::new(dense_init(
185 |             self.item_embedding_dim,
186 |             self.item_embedding_dim,
187 |             &mut self.rng,
188 |         )));
189 | 
190 |         Parameters {
191 |             hyper: self,
192 |             item_embedding: item_embeddings,
193 |             item_biases,
194 |             alpha,
195 |             fc1,
196 |             fc2,
197 |         }
198 |     }
199 | 
200 |     /// Build the implicit EWMA model.
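    ///
    /// A minimal sketch of the builder chain (the hyperparameter values are
    /// illustrative only, taken from the benchmark setup; `5000` is a
    /// hypothetical item count):
    ///
    /// ```no_run
    /// use sbr::models::{ewma, Loss, Optimizer};
    ///
    /// let model = ewma::Hyperparameters::new(5000, 128)
    ///     .embedding_dim(32)
    ///     .learning_rate(0.16)
    ///     .loss(Loss::Hinge)
    ///     .optimizer(Optimizer::Adagrad)
    ///     .build();
    /// ```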
201 |     pub fn build(self) -> ImplicitEWMAModel {
202 |         let params = self.build_params();
203 | 
204 |         ImplicitEWMAModel { params }
205 |     }
206 | }
207 | 
208 | #[derive(Debug, Serialize, Deserialize)]
209 | struct Parameters {
210 |     hyper: Hyperparameters,
211 |     item_embedding: Arc<wyrm::HogwildParameter>,
212 |     item_biases: Arc<wyrm::HogwildParameter>,
213 |     alpha: Arc<wyrm::HogwildParameter>,
214 |     fc1: Arc<wyrm::HogwildParameter>,
215 |     fc2: Arc<wyrm::HogwildParameter>,
216 | }
217 | 
218 | impl Clone for Parameters {
219 |     fn clone(&self) -> Self {
220 |         Parameters {
221 |             hyper: self.hyper.clone(),
222 |             item_embedding: Arc::new(self.item_embedding.as_ref().clone()),
223 |             item_biases: Arc::new(self.item_biases.as_ref().clone()),
224 |             alpha: Arc::new(self.alpha.as_ref().clone()),
225 |             fc1: Arc::new(self.fc1.as_ref().clone()),
226 |             fc2: Arc::new(self.fc2.as_ref().clone()),
227 |         }
228 |     }
229 | }
230 | 
231 | impl SequenceModelParameters for Parameters {
232 |     type Output = Model;
233 |     fn max_sequence_length(&self) -> usize {
234 |         self.hyper.max_sequence_length
235 |     }
236 |     fn num_threads(&self) -> usize {
237 |         self.hyper.num_threads
238 |     }
239 |     fn rng(&mut self) -> &mut XorShiftRng {
240 |         &mut self.hyper.rng
241 |     }
242 |     fn optimizer(&self) -> Optimizers {
243 |         match self.hyper.optimizer {
244 |             Optimizer::Adagrad => Optimizers::Adagrad(
245 |                 wyrm::optim::Adagrad::new()
246 |                     .learning_rate(self.hyper.learning_rate)
247 |                     .l2_penalty(self.hyper.l2_penalty),
248 |             ),
249 | 
250 |             Optimizer::Adam => Optimizers::Adam(
251 |                 wyrm::optim::Adam::new()
252 |                     .learning_rate(self.hyper.learning_rate)
253 |                     .l2_penalty(self.hyper.l2_penalty),
254 |             ),
255 |         }
256 |     }
257 |     fn parallelism(&self) -> &Parallelism {
258 |         &self.hyper.parallelism
259 |     }
260 |     fn loss(&self) -> &Loss {
261 |         &self.hyper.loss
262 |     }
263 |     fn num_epochs(&self) -> usize {
264 |         self.hyper.num_epochs
265 |     }
266 |     fn build(&self) -> Model {
267 |         let item_embeddings = wyrm::ParameterNode::shared(self.item_embedding.clone());
268 |         let item_biases = wyrm::ParameterNode::shared(self.item_biases.clone());
269 |         let alpha = wyrm::ParameterNode::shared(self.alpha.clone());
270 | 
271 |         let inputs: Vec<_> = (0..self.hyper.max_sequence_length)
272 |             .map(|_| wyrm::IndexInputNode::new(&[0; 1]))
273 |             .collect();
274 |         let outputs: Vec<_> = (0..self.hyper.max_sequence_length)
275 |             .map(|_| wyrm::IndexInputNode::new(&[0; 1]))
276 |             .collect();
277 |         let negatives: Vec<_> = (0..self.hyper.max_sequence_length)
278 |             .map(|_| wyrm::IndexInputNode::new(&[0; 1]))
279 |             .collect();
280 | 
281 |         let input_embeddings: Vec<_> = inputs
282 |             .iter()
283 |             .map(|input| item_embeddings.index(input))
284 |             .collect();
285 |         let negative_embeddings: Vec<_> = negatives
286 |             .iter()
287 |             .map(|negative| item_embeddings.index(negative))
288 |             .collect();
289 |         let output_embeddings: Vec<_> = outputs
290 |             .iter()
291 |             .map(|output| item_embeddings.index(output))
292 |             .collect();
293 |         let output_biases: Vec<_> = outputs
294 |             .iter()
295 |             .map(|output| item_biases.index(output))
296 |             .collect();
297 |         let negative_biases: Vec<_> = negatives
298 |             .iter()
299 |             .map(|negative| item_biases.index(negative))
300 |             .collect();
301 | 
302 |         let alpha = alpha.sigmoid();
303 |         let one_minus_alpha = 1.0 - alpha.clone();
304 | 
305 |         let mut states = Vec::with_capacity(self.hyper.max_sequence_length);
306 |         let initial_state = input_embeddings.first().unwrap().clone().boxed();
307 |         states.push(initial_state);
308 |         for input in &input_embeddings[1..] {
309 |             let previous_state = states.last().unwrap().clone();
310 |             states.push(
311 |                 (alpha.clone() * previous_state + one_minus_alpha.clone() * input.clone()).boxed(),
312 |             );
313 |         }
314 | 
315 |         let positive_predictions: Vec<_> =
316 |             izip!(states.iter(), output_embeddings.iter(), output_biases)
317 |                 .map(|(state, output_embedding, output_bias)| {
318 |                     state.vector_dot(output_embedding) + output_bias
319 |                 })
320 |                 .collect();
321 |         let negative_predictions: Vec<_> =
322 |             izip!(states.iter(), negative_embeddings.iter(), negative_biases)
323 |                 .map(|(state, negative_embedding, negative_bias)| {
324 |                     state.vector_dot(negative_embedding) + negative_bias
325 |                 })
326 |                 .collect();
327 | 
328 |         let losses: Vec<_> = positive_predictions
329 |             .into_iter()
330 |             .zip(negative_predictions.into_iter())
331 |             .map(|(pos, neg)| match self.hyper.loss {
332 |                 Loss::BPR => (neg - pos).sigmoid().boxed(),
333 |                 Loss::Hinge | Loss::WARP => (1.0 + neg - pos).relu().boxed(),
334 |             })
335 |             .collect();
336 | 
337 |         let mut summed_losses = Vec::with_capacity(losses.len());
338 |         summed_losses.push(losses[0].clone());
339 | 
340 |         for loss in &losses[1..] {
341 |             let loss = (summed_losses.last().unwrap().clone() + loss.clone()).boxed();
342 |             summed_losses.push(loss);
343 |         }
344 | 
345 |         Model {
346 |             inputs,
347 |             outputs,
348 |             negatives,
349 |             hidden_states: states,
350 |             summed_losses,
351 |         }
352 |     }
353 |     fn predict_single(&self, user: &[f32], item_idx: usize) -> f32 {
354 |         let item_embeddings = &self.item_embedding;
355 |         let item_biases = &self.item_biases;
356 | 
357 |         let embeddings = item_embeddings.value();
358 |         let biases = item_biases.value();
359 | 
360 |         let embedding = embeddings.subview(Axis(0), item_idx);
361 |         let bias = biases[(item_idx, 0)];
362 |         let dot = wyrm::simd_dot(user, embedding.as_slice().unwrap());
363 | 
364 |         bias + dot
365 |     }
366 | }
367 | 
368 | struct Model {
369 |     inputs: Vec<Variable<wyrm::IndexInputNode>>,
370 |     outputs: Vec<Variable<wyrm::IndexInputNode>>,
371 |     negatives: Vec<Variable<wyrm::IndexInputNode>>,
372 |     hidden_states: Vec<Variable<BoxedNode>>,
373 |     summed_losses: Vec<Variable<BoxedNode>>,
374 | }
375 | 
376 | impl SequenceModel for Model {
377 |     fn state(
378 |         &self,
379 |     ) -> (
380 |         &[Variable<wyrm::IndexInputNode>],
381 |         &[Variable<wyrm::IndexInputNode>],
382 |         &[Variable<wyrm::IndexInputNode>],
383 |         &[Variable<BoxedNode>],
384 |     ) {
385 |         (
386 |             &self.inputs,
387 |             &self.outputs,
388 |             &self.negatives,
389 |             &self.hidden_states,
390 |         )
391 |     }
392 |     fn losses(&mut self) -> &mut [Variable<BoxedNode>] {
393 |         &mut self.summed_losses
394 |     }
395 |     fn hidden_states(&mut self) -> &mut [Variable<BoxedNode>] {
396 |         &mut self.hidden_states
397 |     }
398 | }
399 | 
400 | /// Implicit EWMA model.
401 | #[derive(Debug, Clone, Serialize, Deserialize)]
402 | pub struct ImplicitEWMAModel {
403 |     params: Parameters,
404 | }
405 | 
406 | impl ImplicitEWMAModel {
407 |     /// Fit the EWMA model.
408 |     pub fn fit(&mut self, interactions: &CompressedInteractions) -> Result<f32, FittingError> {
409 |         fit_sequence_model(interactions, &mut self.params)
410 |     }
411 | }
412 | 
413 | impl OnlineRankingModel for ImplicitEWMAModel {
414 |     type UserRepresentation = ImplicitUser;
415 |     fn user_representation(
416 |         &self,
417 |         item_ids: &[ItemId],
418 |     ) -> Result<Self::UserRepresentation, PredictionError> {
419 |         self.params.user_representation(item_ids)
420 |     }
421 | 
422 |     fn predict(
423 |         &self,
424 |         user: &Self::UserRepresentation,
425 |         item_ids: &[ItemId],
426 |     ) -> Result<Vec<f32>, PredictionError> {
427 |         self.params.predict(user, item_ids)
428 |     }
429 | }
430 | 
431 | #[cfg(test)]
432 | mod tests {
433 |     use std::time::Instant;
434 | 
435 |     use super::*;
436 |     use crate::data::{user_based_split, Interactions};
437 |     use crate::datasets::download_movielens_100k;
438 |     use crate::evaluation::mrr_score;
439 | 
440 |     fn run_test(mut data: Interactions, hyperparameters: Hyperparameters) -> (f32, f32) {
441 |         let mut rng = rand::XorShiftRng::from_seed([42; 16]);
442 | 
443 |         let (train, test) = user_based_split(&mut data, &mut rng, 0.2);
444 |         let train_mat = train.to_compressed();
445 |         let test_mat = test.to_compressed();
446 | 
447 |         let mut model = hyperparameters.rng(rng).build();
448 | 
449 |         let start = Instant::now();
450 |         let loss = model.fit(&train_mat).unwrap();
451 |         let elapsed = start.elapsed();
452 |         let train_mrr = mrr_score(&model, &train_mat).unwrap();
453 |         let test_mrr = mrr_score(&model, &test_mat).unwrap();
454 | 
455 |         println!(
456 |             "Train MRR {} at loss {} and test MRR {} (in {:?})",
457 |             train_mrr, loss, test_mrr, elapsed
458 |         );
459 | 
460 |         (test_mrr, train_mrr)
461 |     }
462 | 
463 |     #[test]
464 |     fn mrr_test_single_thread() {
465 |         let data = download_movielens_100k().unwrap();
466 | 
467 |         let hyperparameters = Hyperparameters::new(data.num_items(), 128)
468 |             .embedding_dim(32)
469 |             .learning_rate(0.16)
470 |             .l2_penalty(0.0004)
471 |             .loss(Loss::Hinge)
472 |             .optimizer(Optimizer::Adagrad)
473 |             .num_epochs(10)
474 |             .num_threads(1);
475 | 
476 |         let (test_mrr, _) = run_test(data, hyperparameters);
477 | 
478 |         let expected_mrr = match ::std::env::var("MKL_CBWR") {
479 |             Ok(ref val) if val == "AVX" => 0.091,
480 |             _ => 0.11,
481 |         };
482 | 
483 |         assert!(test_mrr > expected_mrr)
484 |     }
485 | 
486 |     #[test]
487 |     fn mrr_test_warp() {
488 |         let data = download_movielens_100k().unwrap();
489 | 
490 |         let hyperparameters = Hyperparameters::new(data.num_items(), 128)
491 |             .embedding_dim(32)
492 |             .learning_rate(0.16)
493 |             .l2_penalty(0.0004)
494 |             .loss(Loss::WARP)
495 |             .optimizer(Optimizer::Adagrad)
496 |             .num_epochs(10)
497 |             .num_threads(1);
498 | 
499 |         let (test_mrr, _) = run_test(data, hyperparameters);
500 | 
501 |         let expected_mrr = match ::std::env::var("MKL_CBWR") {
502 |             Ok(ref val) if val == "AVX" => 0.089,
503 |             _ => 0.14,
504 |         };
505 | 
506 |         assert!(test_mrr > expected_mrr)
507 |     }
508 | }
509 | 
--------------------------------------------------------------------------------
/src/models/lstm.rs:
--------------------------------------------------------------------------------
1 | //! Module for LSTM-based models.
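//!
//! A minimal fitting sketch, mirroring the crate-level example (the dataset
//! and hyperparameter values are illustrative only):
//!
//! ```no_run
//! # let data = sbr::datasets::download_movielens_100k().unwrap();
//! # let train_mat = data.to_compressed();
//! let mut model = sbr::models::lstm::Hyperparameters::new(data.num_items(), 32)
//!     .embedding_dim(32)
//!     .loss(sbr::models::Loss::WARP)
//!     .build();
//! let loss = model.fit(&train_mat).unwrap();
//! ```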
2 | use std::sync::Arc;
3 | 
4 | use rand;
5 | use rand::distributions::{Distribution, Normal, Uniform};
6 | use rand::{Rng, SeedableRng, XorShiftRng};
7 | use rayon;
8 | use serde::{Deserialize, Serialize};
9 | 
10 | use ndarray::Axis;
11 | 
12 | use wyrm;
13 | use wyrm::nn;
14 | use wyrm::optim::Optimizers;
15 | use wyrm::{Arr, BoxedNode, Variable};
16 | 
17 | use super::sequence_model::{fit_sequence_model, SequenceModel, SequenceModelParameters};
18 | use super::{ImplicitUser, Loss, Optimizer, Parallelism};
19 | use crate::data::CompressedInteractions;
20 | use crate::{FittingError, ItemId, OnlineRankingModel, PredictionError};
21 | 
22 | fn embedding_init<T: Rng>(rows: usize, cols: usize, rng: &mut T) -> wyrm::Arr {
23 |     let normal = Normal::new(0.0, 1.0 / cols as f64);
24 |     Arr::zeros((rows, cols)).map(|_| normal.sample(rng) as f32)
25 | }
26 | 
27 | /// Type of LSTM layer to use.
28 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
29 | pub enum LSTMVariant {
30 |     /// Classic LSTM layer.
31 |     Normal,
32 |     /// A variant where the update and forget gates are coupled.
33 |     /// Faster to train.
34 |     Coupled,
35 | }
36 | 
37 | /// Hyperparameters for the [ImplicitLSTMModel].
38 | #[derive(Clone, Debug, Serialize, Deserialize)]
39 | pub struct Hyperparameters {
40 |     num_items: usize,
41 |     max_sequence_length: usize,
42 |     item_embedding_dim: usize,
43 |     learning_rate: f32,
44 |     l2_penalty: f32,
45 |     lstm_type: LSTMVariant,
46 |     loss: Loss,
47 |     optimizer: Optimizer,
48 |     parallelism: Parallelism,
49 |     rng: XorShiftRng,
50 |     num_threads: usize,
51 |     num_epochs: usize,
52 | }
53 | 
54 | impl Hyperparameters {
55 |     /// Build new hyperparameters.
56 |     pub fn new(num_items: usize, max_sequence_length: usize) -> Self {
57 |         Hyperparameters {
58 |             num_items,
59 |             max_sequence_length,
60 |             item_embedding_dim: 16,
61 |             learning_rate: 0.01,
62 |             l2_penalty: 0.0,
63 |             lstm_type: LSTMVariant::Coupled,
64 |             loss: Loss::BPR,
65 |             optimizer: Optimizer::Adam,
66 |             parallelism: Parallelism::Synchronous,
67 |             rng: XorShiftRng::from_seed(rand::thread_rng().gen()),
68 |             num_threads: rayon::current_num_threads(),
69 |             num_epochs: 10,
70 |         }
71 |     }
72 | 
73 |     /// Set the learning rate.
74 |     pub fn learning_rate(mut self, learning_rate: f32) -> Self {
75 |         self.learning_rate = learning_rate;
76 |         self
77 |     }
78 | 
79 |     /// Set the L2 penalty.
80 |     pub fn l2_penalty(mut self, l2_penalty: f32) -> Self {
81 |         self.l2_penalty = l2_penalty;
82 |         self
83 |     }
84 | 
85 |     /// Set the embedding dimensionality.
86 |     pub fn embedding_dim(mut self, embedding_dim: usize) -> Self {
87 |         self.item_embedding_dim = embedding_dim;
88 |         self
89 |     }
90 | 
91 |     /// Set the number of epochs to run per each `fit` call.
92 |     pub fn num_epochs(mut self, num_epochs: usize) -> Self {
93 |         self.num_epochs = num_epochs;
94 |         self
95 |     }
96 | 
97 |     /// Set the loss function.
98 |     pub fn loss(mut self, loss: Loss) -> Self {
99 |         self.loss = loss;
100 |         self
101 |     }
102 | 
103 |     /// Set the LSTM variant.
104 |     pub fn lstm_variant(mut self, variant: LSTMVariant) -> Self {
105 |         self.lstm_type = variant;
106 |         self
107 |     }
108 | 
109 |     /// Set number of threads to be used.
110 |     pub fn num_threads(mut self, num_threads: usize) -> Self {
111 |         self.num_threads = num_threads;
112 |         self
113 |     }
114 | 
115 |     /// Set the type of parallelism.
116 |     pub fn parallelism(mut self, parallelism: Parallelism) -> Self {
117 |         self.parallelism = parallelism;
118 |         self
119 |     }
120 | 
121 |     /// Set the random number generator.

    /// Set the random number generator.
    pub fn rng(mut self, rng: XorShiftRng) -> Self {
        self.rng = rng;
        self
    }

    /// Set the random number generator from a seed.
    #[allow(clippy::wrong_self_convention)]
    pub fn from_seed(mut self, seed: [u8; 16]) -> Self {
        self.rng = XorShiftRng::from_seed(seed);
        self
    }

    /// Set the optimizer type.
    pub fn optimizer(mut self, optimizer: Optimizer) -> Self {
        self.optimizer = optimizer;
        self
    }

    /// Set hyperparameters randomly: useful for hyperparameter search.
    pub fn random<R: Rng>(num_items: usize, rng: &mut R) -> Self {
        Hyperparameters {
            num_items,
            max_sequence_length: 2_usize.pow(Uniform::new(4, 8).sample(rng)),
            item_embedding_dim: 2_usize.pow(Uniform::new(4, 8).sample(rng)),
            learning_rate: (10.0_f32).powf(Uniform::new(-3.0, 0.5).sample(rng)),
            l2_penalty: (10.0_f32).powf(Uniform::new(-7.0, -3.0).sample(rng)),
            loss: if Uniform::new(0.0, 1.0).sample(rng) < 0.5 {
                Loss::BPR
            } else {
                Loss::Hinge
            },
            optimizer: if Uniform::new(0.0, 1.0).sample(rng) < 0.5 {
                Optimizer::Adam
            } else {
                Optimizer::Adagrad
            },
            lstm_type: if Uniform::new(0.0, 1.0).sample(rng) < 0.5 {
                LSTMVariant::Normal
            } else {
                LSTMVariant::Coupled
            },
            parallelism: if Uniform::new(0.0, 1.0).sample(rng) < 0.5 {
                Parallelism::Asynchronous
            } else {
                Parallelism::Synchronous
            },
            rng: XorShiftRng::from_seed(rand::thread_rng().gen()),
            num_threads: Uniform::new(1, rayon::current_num_threads() + 1).sample(rng),
            num_epochs: 2_usize.pow(Uniform::new(3, 7).sample(rng)),
        }
    }

    fn build_params(mut self) -> Parameters {
        let item_embeddings = Arc::new(wyrm::HogwildParameter::new(embedding_init(
            self.num_items,
            self.item_embedding_dim,
            &mut self.rng,
        )));

        let item_biases = Arc::new(wyrm::HogwildParameter::new(Arr::zeros((self.num_items, 1))));
        let lstm_params = nn::lstm::Parameters::new(
            self.item_embedding_dim,
            self.item_embedding_dim,
            &mut self.rng,
        );

        Parameters {
            hyper: self,
            item_embedding: item_embeddings,
            item_biases,
            lstm: lstm_params,
        }
    }
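
    // The `random` constructor above pairs naturally with a random-search
    // loop; a sketch, where `evaluate` is a hypothetical closure returning
    // a validation score for a fitted model:
    //
    //     let best = (0..50)
    //         .map(|_| Hyperparameters::random(num_items, &mut rng))
    //         .map(|hyper| (evaluate(&hyper), hyper))
    //         .max_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap());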

    /// Build a model out of the chosen hyperparameters.
    pub fn build(self) -> ImplicitLSTMModel {
        ImplicitLSTMModel {
            params: self.build_params(),
        }
    }
}

#[derive(Debug, Serialize, Deserialize)]
struct Parameters {
    hyper: Hyperparameters,
    item_embedding: Arc<wyrm::HogwildParameter>,
    item_biases: Arc<wyrm::HogwildParameter>,
    lstm: nn::lstm::Parameters,
}

impl Clone for Parameters {
    fn clone(&self) -> Self {
        Parameters {
            hyper: self.hyper.clone(),
            item_embedding: Arc::new(self.item_embedding.as_ref().clone()),
            item_biases: Arc::new(self.item_biases.as_ref().clone()),
            lstm: self.lstm.clone(),
        }
    }
}
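
// Design note: the manual `Clone` deliberately deep-copies the Hogwild
// parameter buffers (`as_ref().clone()`) rather than bumping the `Arc`
// refcounts, so a cloned `Parameters` owns independent weights. Within a
// single model, worker threads instead share the same `Arc`s and update
// them lock-free. A sketch of the distinction:
//
//     let copy = params.clone();                 // independent weights
//     let view = params.item_embedding.clone();  // shared weights (Arc clone)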

impl SequenceModelParameters for Parameters {
    type Output = Model;
    fn max_sequence_length(&self) -> usize {
        self.hyper.max_sequence_length
    }
    fn num_threads(&self) -> usize {
        self.hyper.num_threads
    }
    fn rng(&mut self) -> &mut XorShiftRng {
        &mut self.hyper.rng
    }
    fn optimizer(&self) -> Optimizers {
        match self.hyper.optimizer {
            Optimizer::Adagrad => Optimizers::Adagrad(
                wyrm::optim::Adagrad::new()
                    .learning_rate(self.hyper.learning_rate)
                    .l2_penalty(self.hyper.l2_penalty),
            ),

            Optimizer::Adam => Optimizers::Adam(
                wyrm::optim::Adam::new()
                    .learning_rate(self.hyper.learning_rate)
                    .l2_penalty(self.hyper.l2_penalty),
            ),
        }
    }
    fn parallelism(&self) -> &Parallelism {
        &self.hyper.parallelism
    }
    fn loss(&self) -> &Loss {
        &self.hyper.loss
    }
    fn num_epochs(&self) -> usize {
        self.hyper.num_epochs
    }
    fn build(&self) -> Self::Output {
        let item_embeddings = wyrm::ParameterNode::shared(self.item_embedding.clone());
        let item_biases = wyrm::ParameterNode::shared(self.item_biases.clone());

        let inputs: Vec<_> = (0..self.hyper.max_sequence_length)
            .map(|_| wyrm::IndexInputNode::new(&[0; 1]))
            .collect();
        let outputs: Vec<_> = (0..self.hyper.max_sequence_length)
            .map(|_| wyrm::IndexInputNode::new(&[0; 1]))
            .collect();
        let negatives: Vec<_> = (0..self.hyper.max_sequence_length)
            .map(|_| wyrm::IndexInputNode::new(&[0; 1]))
            .collect();

        let input_embeddings: Vec<_> = inputs
            .iter()
            .map(|input| item_embeddings.index(input))
            .collect();
        let negative_embeddings: Vec<_> = negatives
            .iter()
            .map(|negative| item_embeddings.index(negative))
            .collect();
        let output_embeddings: Vec<_> = outputs
            .iter()
            .map(|output| item_embeddings.index(output))
            .collect();
        let output_biases: Vec<_> = outputs
            .iter()
            .map(|output| item_biases.index(output))
            .collect();
        let negative_biases: Vec<_> = negatives
            .iter()
            .map(|negative| item_biases.index(negative))
            .collect();

        let layer = match self.hyper.lstm_type {
            LSTMVariant::Normal => self.lstm.build(),
            LSTMVariant::Coupled => self.lstm.build_coupled(),
        };

        let hidden = layer.forward(&input_embeddings);

        let positive_predictions: Vec<_> =
            izip!(hidden.iter(), output_embeddings.iter(), output_biases)
                .map(|(hidden_state, output_embedding, output_bias)| {
                    hidden_state.vector_dot(output_embedding) + output_bias
                })
                .collect();
        let negative_predictions: Vec<_> =
            izip!(hidden.iter(), negative_embeddings.iter(), negative_biases)
                .map(|(hidden_state, negative_embedding, negative_bias)| {
                    hidden_state.vector_dot(negative_embedding) + negative_bias
                })
                .collect();

        let losses: Vec<_> = positive_predictions
            .into_iter()
            .zip(negative_predictions.into_iter())
            .map(|(pos, neg)| match self.hyper.loss {
                Loss::BPR => (neg - pos).sigmoid().boxed(),
                Loss::Hinge | Loss::WARP => (1.0 + neg - pos).relu().boxed(),
            })
            .collect();

        let mut summed_losses = Vec::with_capacity(losses.len());
        summed_losses.push(losses[0].clone());

        for loss in &losses[1..] {
            let loss = (summed_losses.last().unwrap().clone() + loss.clone()).boxed();
            summed_losses.push(loss);
        }

        Model {
            inputs,
            outputs,
            negatives,
            hidden_states: hidden,
            summed_losses,
        }
    }
    fn predict_single(&self, user: &[f32], item_idx: usize) -> f32 {
        let item_embeddings = &self.item_embedding;
        let item_biases = &self.item_biases;

        let embeddings = item_embeddings.value();
        let biases = item_biases.value();

        let embedding = embeddings.subview(Axis(0), item_idx);
        let bias = biases[(item_idx, 0)];
        let dot = wyrm::simd_dot(user, embedding.as_slice().unwrap());

        bias + dot
    }
}

struct Model {
    inputs: Vec<Variable<wyrm::IndexInputNode>>,
    outputs: Vec<Variable<wyrm::IndexInputNode>>,
    negatives: Vec<Variable<wyrm::IndexInputNode>>,
    hidden_states: Vec<Variable<BoxedNode>>,
    summed_losses: Vec<Variable<BoxedNode>>,
}

impl SequenceModel for Model {
    fn state(
        &self,
    ) -> (
        &[Variable<wyrm::IndexInputNode>],
        &[Variable<wyrm::IndexInputNode>],
        &[Variable<wyrm::IndexInputNode>],
        &[Variable<BoxedNode>],
    ) {
        (
            &self.inputs,
            &self.outputs,
            &self.negatives,
            &self.hidden_states,
        )
    }
    fn losses(&mut self) -> &mut [Variable<BoxedNode>] {
        &mut self.summed_losses
    }
    fn hidden_states(&mut self) -> &mut [Variable<BoxedNode>] {
        &mut self.hidden_states
    }
}
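
// The pairwise objectives built above, in math form (`pos` and `neg` are the
// scores of the observed and the sampled item at a given step):
//
//     BPR:          loss = sigmoid(neg - pos)
//     Hinge / WARP: loss = max(0, 1 + neg - pos)
//
// `summed_losses[i]` holds the running sum of the first i + 1 step losses,
// so backpropagating through `summed_losses[i]` trains on a sequence prefix
// of length i + 1 with a single graph traversal.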

/// An LSTM-based sequence model for implicit feedback.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ImplicitLSTMModel {
    params: Parameters,
}

impl ImplicitLSTMModel {
    /// Fit the model.
    ///
    /// Returns the loss value.
    pub fn fit(&mut self, interactions: &CompressedInteractions) -> Result<f32, FittingError> {
        fit_sequence_model(interactions, &mut self.params)
    }
}

impl OnlineRankingModel for ImplicitLSTMModel {
    type UserRepresentation = ImplicitUser;
    fn user_representation(
        &self,
        item_ids: &[ItemId],
    ) -> Result<Self::UserRepresentation, PredictionError> {
        self.params.user_representation(item_ids)
    }

    fn predict(
        &self,
        user: &Self::UserRepresentation,
        item_ids: &[ItemId],
    ) -> Result<Vec<f32>, PredictionError> {
        self.params.predict(user, item_ids)
    }
}
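
// `ImplicitLSTMModel` derives `Serialize`/`Deserialize`, so a fitted model
// can be persisted and restored with any serde format; a sketch using
// serde_json (bincode, also a dependency, works the same way):
//
//     let serialized = serde_json::to_string(&model).unwrap();
//     let restored: ImplicitLSTMModel = serde_json::from_str(&serialized).unwrap();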
"AVX" => 0.089, 516 | _ => 0.10, 517 | }; 518 | 519 | assert!(test_mrr > expected_mrr) 520 | } 521 | 522 | #[test] 523 | fn empty_interactions() { 524 | let data = Interactions::new(100, 100).to_compressed(); 525 | let mut model = Hyperparameters::new(100, 100).build(); 526 | match model.fit(&data) { 527 | Err(FittingError::NoInteractions) => {} 528 | _ => panic!("No error returned."), 529 | } 530 | } 531 | 532 | } 533 | -------------------------------------------------------------------------------- /src/models/mod.rs: -------------------------------------------------------------------------------- 1 | //! Models module. 2 | use serde::{Deserialize, Serialize}; 3 | 4 | pub mod ewma; 5 | pub mod lstm; 6 | mod sequence_model; 7 | 8 | /// The user representation used by implicit sequence models. 9 | #[derive(Clone, Debug)] 10 | pub struct ImplicitUser { 11 | user_embedding: Vec, 12 | } 13 | 14 | /// The loss used for training the model. 15 | #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] 16 | pub enum Loss { 17 | /// Bayesian Personalised Ranking. 18 | BPR, 19 | /// Pairwise hinge loss. 20 | Hinge, 21 | /// WARP 22 | WARP, 23 | } 24 | 25 | /// Optimizer user to train the model. 26 | #[derive(Clone, Debug, Serialize, Deserialize)] 27 | pub enum Optimizer { 28 | /// Adagrad. 29 | Adagrad, 30 | /// Adam. 31 | Adam, 32 | } 33 | 34 | /// Type of parallelism used to train the model. 35 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 36 | pub enum Parallelism { 37 | /// Multiple threads operate in parallel without any locking. 38 | Asynchronous, 39 | /// Multiple threads synchronise parameters between minibatches. 40 | Synchronous, 41 | } 42 | -------------------------------------------------------------------------------- /src/models/sequence_model.rs: -------------------------------------------------------------------------------- 1 | use rand::distributions::Distribution; 2 | use rand::distributions::Uniform; 3 | use rand::{Rng, SeedableRng, XorShiftRng}; 4 | use rayon::prelude::*; 5 | 6 | use wyrm; 7 | use wyrm::optim::{Optimizer as Optim, Optimizers, Synchronizable}; 8 | use wyrm::{BoxedNode, DataInput, Variable}; 9 | 10 | use super::{ImplicitUser, Loss, Parallelism}; 11 | use crate::data::CompressedInteractions; 12 | use crate::{FittingError, ItemId, OnlineRankingModel, PredictionError}; 13 | 14 | pub trait SequenceModelParameters { 15 | type Output; 16 | fn max_sequence_length(&self) -> usize; 17 | fn num_threads(&self) -> usize; 18 | fn rng(&mut self) -> &mut XorShiftRng; 19 | fn optimizer(&self) -> Optimizers; 20 | fn parallelism(&self) -> &Parallelism; 21 | fn loss(&self) -> &Loss; 22 | fn num_epochs(&self) -> usize; 23 | fn build(&self) -> Self::Output; 24 | fn predict_single(&self, user: &[f32], item_idx: usize) -> f32; 25 | } 26 | 27 | /// Trait expressing a sequence model. 28 | pub trait SequenceModel { 29 | /// Return the sequence losses of the model. 30 | fn losses(&mut self) -> &mut [Variable]; 31 | /// Return the inner state of the model. These are: 32 | /// - inputs 33 | /// - targets 34 | /// - negatives 35 | /// - hidden states. 
/src/models/sequence_model.rs:
--------------------------------------------------------------------------------
use rand::distributions::Distribution;
use rand::distributions::Uniform;
use rand::{Rng, SeedableRng, XorShiftRng};
use rayon::prelude::*;

use wyrm;
use wyrm::optim::{Optimizer as Optim, Optimizers, Synchronizable};
use wyrm::{BoxedNode, DataInput, Variable};

use super::{ImplicitUser, Loss, Parallelism};
use crate::data::CompressedInteractions;
use crate::{FittingError, ItemId, OnlineRankingModel, PredictionError};

pub trait SequenceModelParameters {
    type Output;
    fn max_sequence_length(&self) -> usize;
    fn num_threads(&self) -> usize;
    fn rng(&mut self) -> &mut XorShiftRng;
    fn optimizer(&self) -> Optimizers;
    fn parallelism(&self) -> &Parallelism;
    fn loss(&self) -> &Loss;
    fn num_epochs(&self) -> usize;
    fn build(&self) -> Self::Output;
    fn predict_single(&self, user: &[f32], item_idx: usize) -> f32;
}

/// Trait expressing a sequence model.
pub trait SequenceModel {
    /// Return the sequence losses of the model.
    fn losses(&mut self) -> &mut [Variable<BoxedNode>];
    /// Return the inner state of the model. These are:
    /// - inputs
    /// - targets
    /// - negatives
    /// - hidden states.
    fn state(
        &self,
    ) -> (
        &[Variable<wyrm::IndexInputNode>],
        &[Variable<wyrm::IndexInputNode>],
        &[Variable<wyrm::IndexInputNode>],
        &[Variable<BoxedNode>],
    );
    fn hidden_states(&mut self) -> &mut [Variable<BoxedNode>];
}

fn sample_warp_negative<M: SequenceModel, T: SequenceModelParameters<Output = M>>(
    parameters: &T,
    hidden_state: &[f32],
    positive_idx: usize,
    negative_item_range: &Uniform<usize>,
    thread_rng: &mut XorShiftRng,
) -> usize {
    let pos_prediction = parameters.predict_single(hidden_state, positive_idx);

    let mut negative_idx = 0;

    for _ in 0..5 {
        negative_idx = negative_item_range.sample(thread_rng);
        let neg_prediction = parameters.predict_single(hidden_state, negative_idx);

        if 1.0 - pos_prediction + neg_prediction > 0.0 {
            break;
        }
    }

    negative_idx
}
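
// WARP sampling in brief: draw candidate negatives uniformly and keep the
// first one that violates the margin (1 - pos + neg > 0), giving up after
// five draws. In the original WARP formulation the number of draws needed
// estimates the positive item's rank and weights the gradient; this
// implementation instead caps the search and applies a plain hinge loss to
// whichever negative the loop settles on.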

pub fn fit_sequence_model<M: SequenceModel, T: SequenceModelParameters<Output = M> + Sync>(
    interactions: &CompressedInteractions,
    parameters: &mut T,
) -> Result<f32, FittingError> {
    let negative_item_range = Uniform::new(0, interactions.num_items());

    let mut subsequences: Vec<_> = interactions
        .iter_users()
        .flat_map(|user| {
            user.chunks(parameters.max_sequence_length())
                .map(|(item_ids, _)| item_ids)
                .filter(|item_ids| item_ids.len() > 2)
        })
        .collect();
    parameters.rng().shuffle(&mut subsequences);

    if subsequences.is_empty() {
        return Err(FittingError::NoInteractions);
    }

    let optimizer = parameters.optimizer();
    let num_chunks = subsequences.len() / parameters.num_threads();
    let sync_optim = optimizer.synchronized(parameters.num_threads());

    let mut partitions: Vec<_> = subsequences
        .chunks_mut(num_chunks)
        .zip(sync_optim.into_iter())
        .map(|(chunk, optim)| (chunk, XorShiftRng::from_seed(parameters.rng().gen()), optim))
        .collect();

    let loss = partitions
        .par_iter_mut()
        .map(|(partition, ref mut thread_rng, sync_optim)| {
            let mut model = parameters.build();

            let mut loss_value = 0.0;
            let mut examples = 0;

            for _ in 0..parameters.num_epochs() {
                thread_rng.shuffle(partition);

                for &item_ids in partition.iter() {
                    {
                        let (inputs, outputs, negatives, hidden_states) = model.state();

                        for (&input_idx, &output_idx, input, output, negative, hidden) in izip!(
                            item_ids,
                            item_ids.iter().skip(1),
                            inputs,
                            outputs,
                            negatives,
                            hidden_states
                        ) {
                            input.set_value(input_idx);

                            let negative_idx = if parameters.loss() == &Loss::WARP {
                                hidden.forward();
                                let hidden_state = hidden.value();

                                sample_warp_negative(
                                    parameters,
                                    hidden_state.as_slice().unwrap(),
                                    output_idx,
                                    &negative_item_range,
                                    thread_rng,
                                )
                            } else {
                                negative_item_range.sample(thread_rng)
                            };

                            output.set_value(output_idx);
                            negative.set_value(negative_idx);
                        }
                    }

                    // Get the loss at the end of the sequence.
                    let loss_idx = item_ids.len().saturating_sub(2);

                    // We need to clear the graph if the loss is WARP
                    // in order for backpropagation to trigger correctly.
                    // This is because by calling forward we've added the
                    // resulting nodes to the graph.
                    if parameters.loss() == &Loss::WARP {
                        model.hidden_states()[loss_idx].clear();
                    }

                    let loss = &mut model.losses()[loss_idx];
                    loss_value += loss.value().scalar_sum();
                    examples += loss_idx + 1;

                    loss.forward();
                    loss.backward(1.0);

                    if parameters.num_threads() > 1
                        && parameters.parallelism() == &Parallelism::Synchronous
                    {
                        sync_optim.step(loss.parameters());
                    } else {
                        optimizer.step(loss.parameters());
                    }
                }
            }

            loss_value / (1.0 + examples as f32)
        })
        .sum();

    Ok(loss)
}

impl<M: SequenceModel, T: SequenceModelParameters<Output = M> + Sync> OnlineRankingModel for T {
    type UserRepresentation = ImplicitUser;
    fn user_representation(
        &self,
        item_ids: &[ItemId],
    ) -> Result<Self::UserRepresentation, PredictionError> {
        let model = self.build();

        let item_ids = &item_ids[item_ids.len().saturating_sub(self.max_sequence_length())..];

        let (inputs, _, _, hidden_states) = model.state();

        for (&input_idx, input) in izip!(item_ids, inputs) {
            input.set_value(input_idx);
        }

        // Get the index of the last position in the sequence.
        let loss_idx = item_ids.len().saturating_sub(1);

        // Select the hidden state after ingesting all the inputs.
        let hidden_state = &hidden_states[loss_idx];

        // Run the network forward up to that point.
        hidden_state.forward();

        // Get the value.
        let representation = hidden_state.value();

        Ok(ImplicitUser {
            user_embedding: representation.as_slice().unwrap().to_owned(),
        })
    }

    fn predict(
        &self,
        user: &Self::UserRepresentation,
        item_ids: &[ItemId],
    ) -> Result<Vec<f32>, PredictionError> {
        let user_slice = &user.user_embedding;

        item_ids
            .iter()
            .map(|&item_idx| {
                let prediction = self.predict_single(user_slice, item_idx);

                if prediction.is_finite() {
                    Ok(prediction)
                } else {
                    Err(PredictionError::InvalidPredictionValue)
                }
            })
            .collect()
    }
}
--------------------------------------------------------------------------------
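The blanket impl at the end of sequence_model.rs is what the public models delegate to: at serving time the flow is two calls, folding the user's history into a hidden state and then scoring candidates against it. A sketch, assuming a fitted `model` (either model type works):

use sbr::OnlineRankingModel;

// `history`: the user's consumed item ids, oldest first.
// `candidates`: the item ids to rank for the next step.
let user = model.user_representation(&history).unwrap();
let scores = model.predict(&user, &candidates).unwrap();
// Higher scores mean better next-item candidates; only the last
// `max_sequence_length` history items influence the representation.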