├── .gitignore
├── README.md
├── environment.yml
├── metrics.py
├── requirements.txt
├── run.py
├── torchmf.py
├── trainplot.png
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

# pycharm
.idea/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# torchmf

Matrix factorization in PyTorch.

## Installation

### pip

```commandline
pip install -r requirements.txt
```

### conda

```commandline
conda env create -f environment.yml
conda activate torchmf
```

## Example

```bash
$ python run.py --example explicit
( 1 ): 100%|████████████████████████████████████████████████████████| 89/89 [00:01<00:00, 85.88it/s, train_loss=7.9]
Epoch: 1 train: 15.04790 val: 8.84972
( 2 ): 100%|███████████████████████████████████████████████████████| 89/89 [00:01<00:00, 84.06it/s, train_loss=2.96]
Epoch: 2 train: 4.34132 val: 4.04638
( 3 ): 100%|███████████████████████████████████████████████████████| 89/89 [00:01<00:00, 81.54it/s, train_loss=1.51]
Epoch: 3 train: 1.87918 val: 2.43315
( 4 ): 100%|███████████████████████████████████████████████████████| 89/89 [00:01<00:00, 85.55it/s, train_loss=1.19]
Epoch: 4 train: 1.21419 val: 1.80296
( 5 ): 100%|██████████████████████████████████████████████████████| 89/89 [00:00<00:00, 90.87it/s, train_loss=0.945]
Epoch: 5 train: 0.99693 val: 1.49770
( 6 ): 100%|██████████████████████████████████████████████████████| 89/89 [00:00<00:00, 89.33it/s, train_loss=0.914]
Epoch: 6 train: 0.90174 val: 1.33501
( 7 ): 100%|██████████████████████████████████████████████████████| 89/89 [00:00<00:00, 115.70it/s, train_loss=0.83]
Epoch: 7 train: 0.85230 val: 1.23783
( 8 ): 100%|██████████████████████████████████████████████████████| 89/89 [00:01<00:00, 88.85it/s, train_loss=0.879]
Epoch: 8 train: 0.82072 val: 1.17781
( 9 ): 100%|█████████████████████████████████████████████████████| 89/89 [00:00<00:00, 119.93it/s, train_loss=0.766]
Epoch: 9 train: 0.79898 val: 1.13976
(10 ): 100%|█████████████████████████████████████████████████████| 89/89 [00:00<00:00, 122.18it/s, train_loss=0.736]
Epoch: 10 train: 0.77820 val: 1.10951
```

which looks something like this:

![train plot](trainplot.png)
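
## Programmatic use

`run.py` is just a thin wrapper around `BasePipeline`. The sketch below mirrors
its `explicit` example, so you can drive training from your own code (same
classes and arguments, shortened to 5 epochs here for illustration):

```python
import torch

import utils
from torchmf import BaseModule, BasePipeline

# Downloads MovieLens 100K on first use and splits it into train/test.
train, test = utils.get_movielens_train_test_split()

pipeline = BasePipeline(train, test=test, model=BaseModule,
                        n_factors=10, batch_size=1024, dropout_p=0.02,
                        lr=0.02, weight_decay=0.1,
                        optimizer=torch.optim.Adam, n_epochs=5,
                        verbose=True, random_seed=2017)
pipeline.fit()
```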

--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: torchmf
channels:
  - pytorch
  - defaults
dependencies:
  - asn1crypto=0.24.0
  - blas=1.0
  - ca-certificates=2019.5.15
  - certifi=2019.6.16
  - cffi=1.12.3
  - chardet=3.0.4
  - cryptography=2.7
  - freetype=2.9.1
  - idna=2.8
  - intel-openmp=2019.4
  - joblib=0.13.2
  - jpeg=9b
  - libcxx=4.0.1
  - libcxxabi=4.0.1
  - libedit=3.1.20181209
  - libffi=3.2.1
  - libgfortran=3.0.1
  - libpng=1.6.37
  - libtiff=4.0.10
  - llvm-openmp=4.0.1
  - mkl=2019.4
  - mkl-service=2.0.2
  - mkl_fft=1.0.12
  - mkl_random=1.0.2
  - ncurses=6.1
  - ninja=1.9.0
  - numpy=1.16.4
  - numpy-base=1.16.4
  - olefile=0.46
  - openssl=1.1.1c
  - pandas=0.24.2
  - pillow=6.0.0
  - pip=19.1.1
  - pycparser=2.19
  - pyopenssl=19.0.0
  - pysocks=1.7.0
  - python=3.6.8
  - python-dateutil=2.8.0
  - pytorch=1.1.0
  - pytz=2019.1
  - readline=7.0
  - requests=2.22.0
  - scikit-learn=0.21.2
  - scipy=1.2.1
  - setuptools=41.0.1
  - six=1.12.0
  - sqlite=3.28.0
  - tk=8.6.8
  - torchvision=0.3.0
  - tqdm=4.32.1
  - urllib3=1.24.2
  - wheel=0.33.4
  - xz=5.2.4
  - zlib=1.2.11
  - zstd=1.3.7

--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
import numpy as np
from sklearn.metrics import roc_auc_score
from torch import multiprocessing as mp
import torch


def get_row_indices(row, interactions):
    start = interactions.indptr[row]
    end = interactions.indptr[row + 1]
    return interactions.indices[start:end]


def auc(model, interactions, num_workers=1):
    aucs = []
    processes = []
    n_users = interactions.shape[0]
    mp_batch = int(np.ceil(n_users / num_workers))

    queue = mp.Queue()
    rows = np.arange(n_users)
    np.random.shuffle(rows)
    for rank in range(num_workers):
        start = rank * mp_batch
        end = np.min((start + mp_batch, n_users))
        p = mp.Process(target=batch_auc,
                       args=(queue, rows[start:end], interactions, model))
        p.start()
        processes.append(p)

    # Poll until every worker has exited and its results have been drained.
    while True:
        is_alive = False
        for p in processes:
            if p.is_alive():
                is_alive = True
                break
        if not is_alive and queue.empty():
            break

        while not queue.empty():
            aucs.append(queue.get())

    queue.close()
    for p in processes:
        p.join()
    return np.mean(aucs)


def batch_auc(queue, rows, interactions, model):
    n_items = interactions.shape[1]
    items = torch.arange(0, n_items).long()
    users_init = torch.ones(n_items).long()
    for row in rows:
        row = int(row)
        # Reuse a single tensor; fill_ sets every element to this user's id.
        users = users_init.fill_(row)

        preds = model.predict(users, items)
        actuals = get_row_indices(row, interactions)

        if len(actuals) == 0:
            continue
        y_test = np.zeros(n_items)
        y_test[actuals] = 1
        queue.put(roc_auc_score(y_test, preds.data.numpy()))


def patk(model, interactions, num_workers=1, k=5):
    patks = []
    processes = []
    n_users = interactions.shape[0]
    mp_batch = int(np.ceil(n_users / num_workers))

    queue = mp.Queue()
    rows = np.arange(n_users)
    np.random.shuffle(rows)
    for rank in range(num_workers):
        start = rank * mp_batch
        end = np.min((start + mp_batch, n_users))
        p = mp.Process(target=batch_patk,
                       args=(queue, rows[start:end], interactions, model),
                       kwargs={'k': k})
        p.start()
        processes.append(p)

    # Poll until every worker has exited and its results have been drained.
    while True:
        is_alive = False
        for p in processes:
            if p.is_alive():
                is_alive = True
                break
        if not is_alive and queue.empty():
            break

        while not queue.empty():
            patks.append(queue.get())

    queue.close()
    for p in processes:
        p.join()
    return np.mean(patks)


def batch_patk(queue, rows, interactions, model, k=5):
    n_items = interactions.shape[1]

    items = torch.arange(0, n_items).long()
    users_init = torch.ones(n_items).long()
    for row in rows:
        row = int(row)
        users = users_init.fill_(row)

        preds = model.predict(users, items)
        actuals = get_row_indices(row, interactions)

        if len(actuals) == 0:
            continue

        # argpartition puts the k largest predictions (unsorted) up front.
        top_k = np.argpartition(-np.squeeze(preds.data.numpy()), k)
        top_k = set(top_k[:k])
        true_pids = set(actuals)
        if true_pids:
            queue.put(len(top_k & true_pids) / float(k))
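
# Usage sketch (hypothetical variables, not part of the original module):
# `model` is any object with a .predict(users, items) method, and
# `interactions` is a scipy.sparse CSR matrix of held-out user-item data.
#
#     from metrics import auc, patk
#     print('AUC:', auc(model, interactions, num_workers=1))
#     print('p@5:', patk(model, interactions, num_workers=1, k=5))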

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy==1.16.4
pandas==0.24.2
requests==2.22.0
scikit-learn==0.21.2
scipy==1.2.1
torch==1.1.0
tqdm==4.32.1

--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
import argparse

import torch

from torchmf import (BaseModule, BPRModule, BasePipeline,
                     bpr_loss, PairwiseInteractions)
import utils


def explicit():
    train, test = utils.get_movielens_train_test_split()
    pipeline = BasePipeline(train, test=test, model=BaseModule,
                            n_factors=10, batch_size=1024, dropout_p=0.02,
                            lr=0.02, weight_decay=0.1,
                            optimizer=torch.optim.Adam, n_epochs=40,
                            verbose=True, random_seed=2017)
    pipeline.fit()


def implicit():
    train, test = utils.get_movielens_train_test_split(implicit=True)

    pipeline = BasePipeline(train, test=test, verbose=True,
                            batch_size=1024, num_workers=4,
                            n_factors=20, weight_decay=0,
                            dropout_p=0., lr=.2, sparse=True,
                            optimizer=torch.optim.SGD, n_epochs=40,
                            random_seed=2017, loss_function=bpr_loss,
                            model=BPRModule,
                            interaction_class=PairwiseInteractions,
                            eval_metrics=('auc', 'patk'))
    pipeline.fit()


def hogwild():
    train, test = utils.get_movielens_train_test_split(implicit=True)

    pipeline = BasePipeline(train, test=test, verbose=True,
                            batch_size=1024, num_workers=4,
                            n_factors=20, weight_decay=0,
                            dropout_p=0., lr=.2, sparse=True,
                            optimizer=torch.optim.SGD, n_epochs=40,
                            random_seed=2017, loss_function=bpr_loss,
                            model=BPRModule, hogwild=True,
                            interaction_class=PairwiseInteractions,
                            eval_metrics=('auc', 'patk'))
    pipeline.fit()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='torchmf')
    parser.add_argument('--example',
                        help='explicit, implicit, or hogwild')
    args = parser.parse_args()
    if args.example == 'explicit':
        explicit()
    elif args.example == 'implicit':
        implicit()
    elif args.example == 'hogwild':
        hogwild()
    else:
        print('example must be explicit, implicit, or hogwild')

--------------------------------------------------------------------------------
/torchmf.py:
--------------------------------------------------------------------------------
import collections
import os

import numpy as np
import torch
from torch import nn
import torch.multiprocessing as mp
import torch.utils.data as data
from tqdm import tqdm

import metrics


# Models
#   Interactions Dataset => Singular Iter => Singular Loss
#   Pairwise Datasets => Pairwise Iter => Pairwise Loss
# Pairwise Iters
# Loss Functions
# Optimizers
# Metric callbacks

# Serve up users, items (and items could be pos_items, neg_items).
# In this case, the iteration remains the same. Pass both items into a model
# which is a concat of the base model. It handles the pos and neg_items
# accordingly. Define the loss after.


class Interactions(data.Dataset):
    """
    Hold data in the form of an interactions matrix.
    Typical use-case is like a ratings matrix:
    - Users are the rows
    - Items are the columns
    - Elements of the matrix are the ratings given by a user for an item.
    """

    def __init__(self, mat):
        self.mat = mat.astype(np.float32).tocoo()
        self.n_users = self.mat.shape[0]
        self.n_items = self.mat.shape[1]

    def __getitem__(self, index):
        row = self.mat.row[index]
        col = self.mat.col[index]
        val = self.mat.data[index]
        return (row, col), val

    def __len__(self):
        return self.mat.nnz
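
# Usage sketch (hypothetical, assuming `train` is any scipy.sparse matrix of
# ratings): each __getitem__ returns one ((user, item), rating) triple, so the
# dataset plugs straight into a standard DataLoader.
#
#     loader = data.DataLoader(Interactions(train), batch_size=1024,
#                              shuffle=True)
#     (rows, cols), vals = next(iter(loader))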


class PairwiseInteractions(data.Dataset):
    """
    Sample data from an interactions matrix in a pairwise fashion. The row is
    treated as the main dimension, and the columns are sampled pairwise.
    """

    def __init__(self, mat):
        self.mat = mat.astype(np.float32).tocoo()

        self.n_users = self.mat.shape[0]
        self.n_items = self.mat.shape[1]

        self.mat_csr = self.mat.tocsr()
        if not self.mat_csr.has_sorted_indices:
            self.mat_csr.sort_indices()

    def __getitem__(self, index):
        row = self.mat.row[index]
        found = False

        # Rejection-sample a negative item: draw random items until we find
        # one the user has not interacted with.
        while not found:
            neg_col = np.random.randint(self.n_items)
            if self.not_rated(row, neg_col, self.mat_csr.indptr,
                              self.mat_csr.indices):
                found = True

        pos_col = self.mat.col[index]
        val = self.mat.data[index]

        return (row, (pos_col, neg_col)), val

    def __len__(self):
        return self.mat.nnz

    @staticmethod
    def not_rated(row, col, indptr, indices):
        # Binary search for `col` within the user's sorted row, similar to the
        # use of bsearch in lightfm. Note two fixes to the original: search
        # with side='left' so an exact match lands *on* the matching element,
        # and offset by `start` since the search was done on a slice of the
        # full indices array.
        start = indptr[row]
        end = indptr[row + 1]
        searched = np.searchsorted(indices[start:end], col, 'left')
        if searched >= (end - start):
            # Past the end of the user's row, so definitely not rated.
            return True
        return col != indices[start + searched]  # Not rated iff not found

    def get_row_indices(self, row):
        start = self.mat_csr.indptr[row]
        end = self.mat_csr.indptr[row + 1]
        return self.mat_csr.indices[start:end]


class BaseModule(nn.Module):
    """
    Base module for explicit matrix factorization.
    """

    def __init__(self,
                 n_users,
                 n_items,
                 n_factors=40,
                 dropout_p=0,
                 sparse=False):
        """
        Parameters
        ----------
        n_users : int
            Number of users
        n_items : int
            Number of items
        n_factors : int
            Number of latent factors (or embeddings or whatever you want to
            call it).
        dropout_p : float
            p in nn.Dropout module. Probability of dropout.
        sparse : bool
            Whether or not to treat embeddings as sparse. NOTE: you cannot use
            weight decay on the optimizer if sparse=True, and only optimizers
            that support sparse gradients (e.g. SGD, Adagrad, SparseAdam)
            will work.
        """
        super(BaseModule, self).__init__()
        self.n_users = n_users
        self.n_items = n_items
        self.n_factors = n_factors
        self.user_biases = nn.Embedding(n_users, 1, sparse=sparse)
        self.item_biases = nn.Embedding(n_items, 1, sparse=sparse)
        self.user_embeddings = nn.Embedding(n_users, n_factors, sparse=sparse)
        self.item_embeddings = nn.Embedding(n_items, n_factors, sparse=sparse)

        self.dropout_p = dropout_p
        self.dropout = nn.Dropout(p=self.dropout_p)

        self.sparse = sparse

    def forward(self, users, items):
        """
        Forward pass through the model. For a single user and item, this
        looks like:

        user_bias + item_bias + user_embeddings.dot(item_embeddings)

        Parameters
        ----------
        users : torch.LongTensor
            Tensor of user indices
        items : torch.LongTensor
            Tensor of item indices

        Returns
        -------
        preds : torch.Tensor
            Predicted ratings.
        """
        ues = self.user_embeddings(users)
        uis = self.item_embeddings(items)

        preds = self.user_biases(users)
        preds += self.item_biases(items)
        preds += (self.dropout(ues) * self.dropout(uis)).sum(dim=1, keepdim=True)

        return preds.squeeze()

    def predict(self, users, items):
        return self.forward(users, items)


def bpr_loss(preds, vals):
    # `preds` are (positive - negative) score differences from BPRModule;
    # `vals` is unused. Note this is a squared variant that pushes
    # sigmoid(pos - neg) toward 1 rather than the canonical log form of BPR.
    sig = nn.Sigmoid()
    return (1.0 - sig(preds)).pow(2).sum()
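

# For reference, a sketch of the canonical BPR objective from Rendle et al.
# (2009): the negative log-sigmoid of the score difference. This is an
# alternative, not what run.py uses; pass it as `loss_function` to
# BasePipeline if you want to compare the two.
def bpr_log_loss(preds, vals):
    # A small epsilon guards against log(0) for very negative differences.
    return -torch.log(torch.sigmoid(preds) + 1e-10).sum()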


class BPRModule(nn.Module):

    def __init__(self,
                 n_users,
                 n_items,
                 n_factors=40,
                 dropout_p=0,
                 sparse=False,
                 model=BaseModule):
        super(BPRModule, self).__init__()

        self.n_users = n_users
        self.n_items = n_items
        self.n_factors = n_factors
        self.dropout_p = dropout_p
        self.sparse = sparse
        self.pred_model = model(
            self.n_users,
            self.n_items,
            n_factors=n_factors,
            dropout_p=dropout_p,
            sparse=sparse
        )

    def forward(self, users, items):
        assert isinstance(items, tuple), \
            'Must pass in items as (pos_items, neg_items)'
        # Unpack
        (pos_items, neg_items) = items
        pos_preds = self.pred_model(users, pos_items)
        neg_preds = self.pred_model(users, neg_items)
        return pos_preds - neg_preds

    def predict(self, users, items):
        return self.pred_model(users, items)


class BasePipeline:
    """
    Class defining a training pipeline. Instantiates data loaders, model,
    and optimizer. Handles training for multiple epochs and keeping track of
    train and test loss.
    """

    def __init__(self,
                 train,
                 test=None,
                 model=BaseModule,
                 n_factors=40,
                 batch_size=32,
                 dropout_p=0.02,
                 sparse=False,
                 lr=0.01,
                 weight_decay=0.,
                 optimizer=torch.optim.Adam,
                 loss_function=nn.MSELoss(reduction='sum'),
                 n_epochs=10,
                 verbose=False,
                 random_seed=None,
                 interaction_class=Interactions,
                 hogwild=False,
                 num_workers=0,
                 eval_metrics=None,
                 k=5):
        self.train = train
        self.test = test

        if hogwild:
            num_loader_workers = 0
        else:
            num_loader_workers = num_workers
        self.train_loader = data.DataLoader(
            interaction_class(train), batch_size=batch_size, shuffle=True,
            num_workers=num_loader_workers)
        if self.test is not None:
            self.test_loader = data.DataLoader(
                interaction_class(test), batch_size=batch_size, shuffle=True,
                num_workers=num_loader_workers)
        self.num_workers = num_workers
        self.n_users = self.train.shape[0]
        self.n_items = self.train.shape[1]
        self.n_factors = n_factors
        self.batch_size = batch_size
        self.dropout_p = dropout_p
        self.lr = lr
        self.weight_decay = weight_decay
        self.loss_function = loss_function
        self.n_epochs = n_epochs
        if sparse:
            assert weight_decay == 0.0, \
                'Sparse gradients do not support weight decay.'
        self.model = model(self.n_users,
                           self.n_items,
                           n_factors=self.n_factors,
                           dropout_p=self.dropout_p,
                           sparse=sparse)
        self.optimizer = optimizer(self.model.parameters(),
                                   lr=self.lr,
                                   weight_decay=self.weight_decay)
        self.warm_start = False
        self.losses = collections.defaultdict(list)
        self.verbose = verbose
        self.hogwild = hogwild
        if random_seed is not None:
            if self.hogwild:
                random_seed += os.getpid()
            torch.manual_seed(random_seed)
            np.random.seed(random_seed)

        if eval_metrics is None:
            eval_metrics = []
        self.eval_metrics = eval_metrics
        self.k = k
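
    # Background note on the hogwild path (Niu et al., 2011): fit() puts the
    # model's parameters into shared memory and forks num_workers processes
    # that each run _fit_epoch() against the same tensors, applying lock-free
    # updates. break_grads() clones each parameter's .grad first so that
    # workers share parameters but not gradient buffers -- the same trick used
    # in PyTorch's mnist_hogwild example.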

    def break_grads(self):
        for param in self.model.parameters():
            # Break gradient sharing
            if param.grad is not None:
                param.grad.data = param.grad.data.clone()

    def fit(self):
        for epoch in range(1, self.n_epochs + 1):

            if self.hogwild:
                self.model.share_memory()
                processes = []
                train_losses = []
                queue = mp.Queue()
                for rank in range(self.num_workers):
                    p = mp.Process(target=self._fit_epoch,
                                   kwargs={'epoch': epoch,
                                           'queue': queue})
                    p.start()
                    processes.append(p)
                for p in processes:
                    p.join()

                # Every worker puts a single float and has been joined, so we
                # can simply drain the queue.
                while not queue.empty():
                    train_losses.append(queue.get())
                queue.close()
                train_loss = np.mean(train_losses)
            else:
                train_loss = self._fit_epoch(epoch)

            self.losses['train'].append(train_loss)
            row = 'Epoch: {0:^3} train: {1:^10.5f}'.format(
                epoch, self.losses['train'][-1])
            if self.test is not None:
                self.losses['test'].append(self._validation_loss())
                row += 'val: {0:^10.5f}'.format(self.losses['test'][-1])
                for metric in self.eval_metrics:
                    func = getattr(metrics, metric)
                    res = func(self.model, self.test_loader.dataset.mat_csr,
                               num_workers=self.num_workers)
                    self.losses['eval-{}'.format(metric)].append(res)
                    row += 'eval-{0}: {1:^10.5f}'.format(metric, res)
            self.losses['epoch'].append(epoch)
            if self.verbose:
                print(row)

    def _fit_epoch(self, epoch=1, queue=None):
        if self.hogwild:
            self.break_grads()

        self.model.train()
        total_loss = torch.Tensor([0])
        pbar = tqdm(enumerate(self.train_loader),
                    total=len(self.train_loader),
                    desc='({0:^3})'.format(epoch))
        for batch_idx, ((row, col), val) in pbar:
            self.optimizer.zero_grad()

            row = row.long()
            # TODO: turn this into a collate_fn like the data_loader
            if isinstance(col, list):
                col = tuple(c.long() for c in col)
            else:
                col = col.long()
            val = val.float()

            preds = self.model(row, col)
            loss = self.loss_function(preds, val)
            loss.backward()

            self.optimizer.step()

            total_loss += loss.item()
            batch_loss = loss.item() / row.size()[0]
            pbar.set_postfix(train_loss=batch_loss)
        total_loss /= self.train.nnz
        if queue is not None:
            queue.put(total_loss[0])
        else:
            return total_loss[0]

    def _validation_loss(self):
        self.model.eval()
        total_loss = torch.Tensor([0])
        for batch_idx, ((row, col), val) in enumerate(self.test_loader):
            row = row.long()
            if isinstance(col, list):
                col = tuple(c.long() for c in col)
            else:
                col = col.long()
            val = val.float()

            preds = self.model(row, col)
            loss = self.loss_function(preds, val)
            total_loss += loss.item()

        total_loss /= self.test.nnz
        return total_loss[0]

--------------------------------------------------------------------------------
/trainplot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EthanRosenthal/torchmf/7832e2da4997886160ff5700f8dfa433634b1527/trainplot.png

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
import zipfile

import numpy as np
import pandas as pd
import requests
import scipy.sparse as sp

"""
Shamelessly stolen from
https://github.com/maciejkula/triplet_recommendations_keras
"""


def train_test_split(interactions, n=10):
    """
    Split an interactions matrix into training and test sets.

    Parameters
    ----------
    interactions : np.ndarray
    n : int (default=10)
        Number of items per user (row) to move into the test set.

    Returns
    -------
    train : np.ndarray
    test : np.ndarray
    """
    test = np.zeros(interactions.shape)
    train = interactions.copy()
    for user in range(interactions.shape[0]):
        if interactions[user, :].nonzero()[0].shape[0] > n:
            test_interactions = np.random.choice(interactions[user, :].nonzero()[0],
                                                 size=n,
                                                 replace=False)
            train[user, test_interactions] = 0.
            test[user, test_interactions] = interactions[user, test_interactions]

    # Test and training are truly disjoint
    assert np.all(train * test == 0)
    return train, test


def _get_data_path():
    """
    Get path to the movielens dataset file.
    """
    data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                             'data')
    if not os.path.exists(data_path):
        print('Making data path')
        os.mkdir(data_path)
    return data_path


def _download_movielens(dest_path):
    """
    Download the dataset.
    """
    url = 'http://files.grouplens.org/datasets/movielens/ml-100k.zip'
    req = requests.get(url, stream=True)

    print('Downloading MovieLens data')

    with open(os.path.join(dest_path, 'ml-100k.zip'), 'wb') as fd:
        for chunk in req.iter_content(chunk_size=None):
            fd.write(chunk)

    with zipfile.ZipFile(os.path.join(dest_path, 'ml-100k.zip'), 'r') as z:
        z.extractall(dest_path)


def read_movielens_df():
    path = _get_data_path()
    # Named zip_path so we don't shadow the imported zipfile module.
    zip_path = os.path.join(path, 'ml-100k.zip')
    if not os.path.isfile(zip_path):
        _download_movielens(path)
    fname = os.path.join(path, 'ml-100k', 'u.data')
    names = ['user_id', 'item_id', 'rating', 'timestamp']
    df = pd.read_csv(fname, sep='\t', names=names)
    return df


def get_movielens_interactions():
    df = read_movielens_df()

    n_users = df.user_id.unique().shape[0]
    n_items = df.item_id.unique().shape[0]

    interactions = np.zeros((n_users, n_items))
    for row in df.itertuples():
        # MovieLens user and item ids are 1-indexed.
        interactions[row[1] - 1, row[2] - 1] = row[3]
    return interactions


def get_movielens_train_test_split(implicit=False):
    interactions = get_movielens_interactions()
    if implicit:
        # Binarize: ratings of 4 or 5 count as positive interactions.
        interactions = (interactions >= 4).astype(np.float32)
    train, test = train_test_split(interactions)
    train = sp.coo_matrix(train)
    test = sp.coo_matrix(test)
    return train, test
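

# Usage sketch (a minimal, hypothetical example): build the implicit
# MovieLens split that run.py's implicit and hogwild examples train on.
if __name__ == '__main__':
    train, test = get_movielens_train_test_split(implicit=True)
    print('train: {} x {} matrix with {} interactions'.format(
        train.shape[0], train.shape[1], train.nnz))
    print('test: {} interactions'.format(test.nnz))
--------------------------------------------------------------------------------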