├── asset ├── triangle.png └── overview_vis.png ├── requirements.txt ├── hyper_params.py ├── parse.py ├── LICENSE ├── model.py ├── utils.py ├── README.md ├── eval.py ├── data.py └── main.py /asset/triangle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seoyoungh/svd-ae/HEAD/asset/triangle.png -------------------------------------------------------------------------------- /asset/overview_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seoyoungh/svd-ae/HEAD/asset/overview_vis.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.11.0+cu113 2 | jax==0.3.25 3 | jaxlib==0.3.25 4 | h5py==3.7.0 5 | neural-tangents==0.6.1 6 | scipy==1.7.3 7 | numpy==1.21.6 -------------------------------------------------------------------------------- /hyper_params.py: -------------------------------------------------------------------------------- 1 | from parse import parse_args 2 | args = parse_args() 3 | 4 | hyper_params = { 5 | # COMMON 6 | 'dataset': args.dataset, 7 | 'seed': args.seed, 8 | 'model': args.model, 9 | 10 | # SVD-AE 11 | 'k': args.k, 12 | 'load': args.load, 13 | 14 | # Inifinite-AE 15 | 'lamda': args.lamda, # Only used if grid_search_lamda == False 16 | 'float64': False, 17 | 'depth': 1, 18 | 'grid_search_lamda': args.grid_search, 19 | 'user_support': -1, # Number of users to keep (randomly) & -1 implies use all users 20 | } 21 | -------------------------------------------------------------------------------- /parse.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(): 4 | parser = argparse.ArgumentParser(description="") 5 | parser.add_argument('--model', type=str, default='svd-ae', help='rec-model, support [svd-ae, ease, inf-ae]') 6 | parser.add_argument('--k', type=int, default=148, help='the rank parameter m') 7 | parser.add_argument('--dataset', type=str, default='ml-1m', help='available datasets: [gowalla, yelp2018, ml-1m]') 8 | parser.add_argument('--load', type=int, default=0) 9 | parser.add_argument('--seed', type=int, default=42) 10 | parser.add_argument('--lamda', type=float, default=1.0) 11 | parser.add_argument('--grid_search', type=int, default=0) 12 | 13 | return parser.parse_args() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Seoyoung Hong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import functools 3 | from jax import scipy as sp 4 | from jax import numpy as jnp 5 | from neural_tangents import stax 6 | import torch 7 | from torch import nn 8 | import numpy as np 9 | 10 | def make_kernelized_rr_forward(hyper_params): 11 | _, _, kernel_fn = FullyConnectedNetwork( 12 | depth=hyper_params['depth'], 13 | num_classes=hyper_params['num_items'] 14 | ) 15 | # NOTE: Un-comment this if the dataset size is very big (didn't need it for experiments in the paper) 16 | # kernel_fn = nt.batch(kernel_fn, batch_size=128) 17 | kernel_fn = functools.partial(kernel_fn, get='ntk') 18 | 19 | @jax.jit 20 | def kernelized_rr_forward(X_train, X_predict, reg=0.1): 21 | K_train = kernel_fn(X_train, X_train) # user * user 22 | K_predict = kernel_fn(X_predict, X_train) # user * user 23 | K_reg = (K_train + jnp.abs(reg) * jnp.trace(K_train) * jnp.eye(K_train.shape[0]) / K_train.shape[0]) # user * user 24 | return jnp.dot(K_predict, sp.linalg.solve(K_reg, X_train, sym_pos=True)) 25 | # sp.linalg.solve(K_reg, X_train, sym_pos=True)) -> user * item 26 | 27 | return kernelized_rr_forward, kernel_fn 28 | 29 | def FullyConnectedNetwork( 30 | depth, 31 | W_std = 2 ** 0.5, 32 | b_std = 0.1, 33 | num_classes = 10, 34 | parameterization = 'ntk' 35 | ): 36 | activation_fn = stax.Relu() 37 | dense = functools.partial(stax.Dense, W_std=W_std, b_std=b_std, parameterization=parameterization) 38 | 39 | layers = [stax.Flatten()] 40 | # NOTE: setting width = 1024 doesn't matter as the NTK parameterization will stretch this till \infty 41 | for _ in range(depth): layers += [dense(1024), activation_fn] 42 | layers += [stax.Dense(num_classes, W_std=W_std, b_std=b_std, parameterization=parameterization)] 43 | 44 | return stax.serial(*layers) 45 | 46 | class EASE(nn.Module): 47 | def __init__(self, adj_mat, item_adj, device='cuda:0'): 48 | super(EASE, self).__init__() 49 | self.adj_mat = adj_mat.to(device) 50 | self.item_adj = item_adj.to(device) 51 | 52 | def forward(self, lambda_): 53 | G = self.item_adj 54 | diagIndices = np.diag_indices(G.shape[0]) 55 | G[diagIndices] += lambda_ 56 | P = torch.inverse(G) 57 | B = P / (-torch.diag(P)) 58 | B[diagIndices] = 0 59 | rating = torch.mm(self.adj_mat, B) 60 | 61 | return rating 62 | 63 | class SVD_AE(nn.Module): 64 | def __init__(self, adj_mat, norm_adj, user_sv, item_sv, device='cuda:0'): 65 | super(SVD_AE, self).__init__() 66 | self.adj_mat = adj_mat.to(device) 67 | self.norm_adj = norm_adj.to(device) 68 | self.user_sv = user_sv.to(device) # (K, M) 69 | self.item_sv = item_sv.to(device) # (K, N) 70 | 71 | def forward(self, lambda_mat): 72 | A = self.item_sv @ (torch.diag(1/lambda_mat)) @ self.user_sv.T 73 | rating = torch.mm(self.norm_adj, A @ self.adj_mat.to_dense()) 74 | # torch.inverse(torch.diag(lambda_mat)) 75 | return rating -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | from collections import defaultdict 4 | import torch 5 | import scipy.sparse as sp 6 | import time 7 | import os 8 | 9 | def get_common_path(hyper_params): 10 | ret = "{}_{}_".format( 11 | hyper_params['dataset'], hyper_params['model'] 12 | ) 13 | if hyper_params['model'] == 'svd-ae': ret += "k_{}_".format(hyper_params['k']) 14 | else: 15 | if hyper_params['grid_search_lamda']: ret += "grid_search_lamda_" 16 | else: ret += "lamda_{}_".format(hyper_params['lamda']) 17 | 18 | ret += "seed_{}".format(hyper_params['seed']) 19 | return ret 20 | 21 | def get_item_count_map(data): 22 | item_count = defaultdict(int) 23 | for u, i, r in data.data['train']: item_count[i] += 1 24 | return item_count 25 | 26 | def get_item_propensity(hyper_params, data, A = 0.55, B = 1.5): 27 | item_freq_map = get_item_count_map(data) 28 | item_freq = [ item_freq_map[i] for i in range(hyper_params['num_items']) ] 29 | num_instances = hyper_params['num_interactions'] 30 | 31 | C = (np.log(num_instances)-1)*np.power(B+1, A) 32 | wts = 1.0 + C*np.power(np.array(item_freq)+B, -A) 33 | return np.ravel(wts) 34 | 35 | def file_write(log_file, s, dont_print=False): 36 | if dont_print == False: print(s) 37 | if log_file is None: return 38 | f = open(log_file, 'a') 39 | f.write(s+'\n') 40 | f.close() 41 | 42 | def log_end_epoch(hyper_params, metrics, step, time_elpased, metrics_on = '(TEST)', dont_print = False): 43 | string2 = "" 44 | for m in metrics: string2 += " | " + m + ' = ' + str("{:2.4f}".format(metrics[m])) 45 | string2 += ' ' + metrics_on 46 | 47 | if hyper_params['model'] == 'svd-ae': 48 | ss = '| end of step {:4d} | time = {:5.2f}'.format(step, time_elpased) 49 | else: 50 | ss = '| end of step {:4d} | time = {:5.2f} | best lambda = {}'.format(step, time_elpased, hyper_params['lamda']) 51 | 52 | ss += string2 53 | file_write(hyper_params['log_file'], ss, dont_print = dont_print) 54 | 55 | def set_seed(seed): 56 | random.seed(seed) 57 | np.random.seed(seed) 58 | if torch.cuda.is_available(): 59 | torch.cuda.manual_seed(seed) 60 | torch.cuda.manual_seed_all(seed) 61 | torch.manual_seed(seed) 62 | 63 | def convert_sp_mat_to_sp_tensor(X): 64 | coo = X.tocoo().astype(np.float32) 65 | row = torch.Tensor(coo.row).long() 66 | col = torch.Tensor(coo.col).long() 67 | index = torch.stack([row, col]) 68 | data = torch.FloatTensor(coo.data) 69 | return torch.sparse.FloatTensor(index, data, torch.Size(coo.shape)) 70 | 71 | def get_file_name(dataset, k, path): 72 | ut = f"{dataset}-{k}-ut.npy" 73 | s = f"{dataset}-{k}-s.npy" 74 | vt = f"{dataset}-{k}-vt.npy" 75 | file_list = [ut, s, vt] 76 | file_list = [os.path.join(path, file) for file in file_list] 77 | return file_list 78 | 79 | def preprocess_ease(adj_mat, device): 80 | start = time.time() 81 | adj_mat = adj_mat 82 | item_adj = adj_mat.T @ adj_mat 83 | adj_mat = convert_sp_mat_to_sp_tensor(adj_mat).to_dense() 84 | item_adj = convert_sp_mat_to_sp_tensor(item_adj).to_dense() 85 | end = time.time() 86 | print('Pre-processing time: ', end-start) 87 | return adj_mat, item_adj 88 | 89 | def preprocess_svd(LOAD, dataset, adj_mat, k, path, device): 90 | # start = time.time() 91 | file_list = get_file_name(dataset, k, path) 92 | rowsum = np.array(adj_mat.sum(axis=1)) 93 | rowsum = np.where(rowsum == 0.0, 1.0, rowsum) # Do not divide by zero 94 | d_inv = np.power(rowsum, -0.5).flatten() 95 | d_inv[np.isinf(d_inv)] = 0. 96 | d_mat = sp.diags(d_inv) 97 | norm_adj = d_mat.dot(adj_mat) 98 | colsum = np.array(adj_mat.sum(axis=0)) 99 | colsum = np.where(colsum == 0.0, 1.0, colsum) # Do not divide by zero 100 | d_inv = np.power(colsum, -0.5).flatten() 101 | d_inv[np.isinf(d_inv)] = 0. 102 | d_mat_i = sp.diags(d_inv) 103 | d_mat_i_inv = sp.diags(1/d_inv) 104 | norm_adj = norm_adj.dot(d_mat_i) 105 | norm_adj = norm_adj.tocsc() 106 | adj_mat = convert_sp_mat_to_sp_tensor(adj_mat) 107 | norm_adj = convert_sp_mat_to_sp_tensor(norm_adj) 108 | 109 | if LOAD: 110 | cond = os.path.isfile(file_list[0]) & os.path.isfile(file_list[1]) & os.path.isfile(file_list[2]) 111 | if cond: 112 | print("Load pre-calculated eigenvectors and eigenvalues!") 113 | ut, s, vt = np.load(file_list[0]), np.load(file_list[1]), np.load(file_list[2]) 114 | else: 115 | print("Saved numpy files don't exist!") 116 | exit() 117 | else: 118 | start = time.time() 119 | ut, s, vt = torch.svd_lowrank(norm_adj, q=k, niter=2, M=None) 120 | end = time.time() 121 | if not os.path.isdir(path): 122 | os.makedirs(path) 123 | np.save(file_list[0], ut.cpu().numpy()) 124 | np.save(file_list[1], s.cpu().numpy()) 125 | np.save(file_list[2], vt.cpu().numpy()) 126 | 127 | norm_adj = norm_adj.to_dense() 128 | ut = torch.FloatTensor(ut) 129 | s = torch.FloatTensor(s) 130 | vt = torch.FloatTensor(vt) 131 | # end = time.time() 132 | print('Pre-processing time: ', end-start) 133 | return adj_mat, norm_adj, ut, s, vt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SVD-AE: Simple Autoencoders for Collaborative Filtering 2 | 3 | ![GitHub Repo stars](https://img.shields.io/github/stars/seoyoungh/svd-ae) 4 | [![arXiv](https://img.shields.io/badge/arXiv-2405.04746-b31b1b.svg)](https://arxiv.org/abs/2405.04746) [![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2Fseoyoungh%2Fsvd-ae&count_bg=%230D6CFF&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false)](https://hits.seeyoufarm.com) 5 | 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/svd-ae-simple-autoencoders-for-collaborative/recommendation-systems-on-gowalla)](https://paperswithcode.com/sota/recommendation-systems-on-gowalla?p=svd-ae-simple-autoencoders-for-collaborative) 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/svd-ae-simple-autoencoders-for-collaborative/collaborative-filtering-on-movielens-10m)](https://paperswithcode.com/sota/collaborative-filtering-on-movielens-10m?p=svd-ae-simple-autoencoders-for-collaborative) 8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/svd-ae-simple-autoencoders-for-collaborative/collaborative-filtering-on-movielens-1m)](https://paperswithcode.com/sota/collaborative-filtering-on-movielens-1m?p=svd-ae-simple-autoencoders-for-collaborative) 9 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/svd-ae-simple-autoencoders-for-collaborative/recommendation-systems-on-yelp2018)](https://paperswithcode.com/sota/recommendation-systems-on-yelp2018?p=svd-ae-simple-autoencoders-for-collaborative) 10 | 11 | 12 | This repository contains the official implementation of [SVD-AE](https://arxiv.org/abs/2405.04746), a novel approach to collaborative filtering introduced in our IJCAI 2024 paper. 13 | 14 | - 📌 **Check out [our poster](https://www.dropbox.com/scl/fi/mm5obivc6hss0jgy0vdsl/SVD_AE_IJCAI_Poster.pdf?rlkey=mqkfnb5rc4fa1eee46w7q4h3l&st=ht28tyaw&dl=0) for a visual overview of SVD-AE!** 15 | 16 | - 🧵 **For a detailed explanation, see [our Twitter/X thread](https://x.com/jeongwhan_choi/status/1821010085713465694) highlighting key aspects of SVD-AE.** 17 | 18 | - 🎞️ **[Our presentation slides](https://www.dropbox.com/scl/fi/okdrh2htm4czcfb6cuhbo/SVD_AE_Talk.pdf?rlkey=1c60s0styu9e9u1rdzq9oqtln&st=v4wkydsz&dl=0) provide a comprehensive look at the method and results.** 19 | 20 | 21 | The best overall balance between 3 goals | The accuracy, robustness, and computation time of various methods on Gowalla | 22 | :-------------------------:|:-------------------------: 23 | | | 24 | 25 | - Key Features of **SVD-AE**: 26 | - Closed-form solution for efficient computation 27 | - Low-rank inductive bias for noise robustness 28 | - Competitive performance across various datasets 29 | 30 | --- 31 | 32 | ## Installation 33 | To set up the environment for running SVD-AE, follow these steps: 34 | 35 | 1. Clone this repository: 36 | 37 | ```bash 38 | git clone https://github.com/your_username/svd-ae.git 39 | cd svd-ae 40 | ``` 41 | 42 | 2. Create a virtual environment (optional but recommended): 43 | ```bash 44 | python -m venv svd_ae 45 | source svd_ae/bin/activate 46 | ``` 47 | 48 | 3. Install the required dependencies: 49 | ```bash 50 | pip install -r requirements.txt 51 | ``` 52 | 53 | ## Dataset Preparation 54 | Before running the code, you need to download and prepare the datasets: 55 | 56 | 1. Download the dataset archive from this Google Drive [link](https://drive.google.com/file/d/1cuhQw1aR9BEwutK3svKtL_-CGmcIPOiX/view?usp=sharing). 57 | 2. Unzip the downloaded file in the project root directory: 58 | 59 | ```bash 60 | unzip data.zip 61 | ``` 62 | 63 | This will create a data folder containing the preprocessed datasets (ML-1M, ML-10M, Gowalla, and Yelp2018). 64 | 65 | ## Usage 66 | To run SVD-AE on different datasets, use the following commands: 67 | 68 | - For ML-1M: 69 | ```bash 70 | python main.py --dataset ml-1m --k 148 71 | ``` 72 | 73 | - For ML-10M: 74 | ```bash 75 | python main.py --dataset ml-10m --k 427 76 | ``` 77 | 78 | - For Gowalla: 79 | ```bash 80 | python main.py --dataset gowalla --k 1194 81 | ``` 82 | 83 | - Yelp: 84 | ```bash 85 | python main.py --dataset yelp2018 --k 1267 86 | ``` 87 | 88 | ## Hyperparameters 89 | The main hyperparameter for SVD-AE is `k`, which represents the rank used in the truncated SVD. The optimal values for each dataset are provided in the usage examples above. You can experiment with different values to see how they affect performance. 90 | Other configurable parameters can be found in the `main.py` file. Feel free to adjust them according to your needs. 91 | 92 | ## Citation 93 | 94 | If you use this code or find SVD-AE helpful in your research, please cite our paper: 95 | 96 | ```bibtex 97 | @inproceedings{hong2024svdae, 98 | title = {SVD-AE: Simple Autoencoders for Collaborative Filtering}, 99 | author = {Hong, Seoyoung and Choi, Jeongwhan and Lee, Yeon-Chang and Kumar, Srijan and Park, Noseong}, 100 | booktitle = {Proceedings of the Thirty-Third International Joint Conference on 101 | Artificial Intelligence, {IJCAI-24}}, 102 | publisher = {International Joint Conferences on Artificial Intelligence Organization}, 103 | pages = {2054--2062}, 104 | year = {2024}, 105 | doi = {10.24963/ijcai.2024/227}, 106 | url = {https://doi.org/10.24963/ijcai.2024/227}, 107 | } 108 | ``` 109 | 110 | ## Star History 111 | 112 | [![Star History Chart](https://api.star-history.com/svg?repos=seoyoungh/svd-ae&type=Date)](https://star-history.com/#seoyoungh/svd-ae&Date) -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import numpy as np 3 | import jax.numpy as jnp 4 | from numba import jit, float64 5 | import time 6 | 7 | INF = float(1e6) 8 | 9 | def evaluate(rating, hyper_params, kernelized_rr_forward, data, item_propensity, train_x, topk = [ 10, 20, 100 ], test_set_eval = False): 10 | preds, y_binary, metrics = [], [], {} 11 | for kind in [ 'HR', 'NDCG', 'PSP', 'RECALL', 'PRECISION' ]: # [ 'HR', 'NDCG', 'PSP' ]: 12 | for k in topk: 13 | metrics['{}@{}'.format(kind, k)] = 0.0 14 | # Train positive set -- these items will be set to -infinity while prediction on the val/test set 15 | 16 | train_positive_list = list(map(list, data.data['train_positive_set'])) 17 | if test_set_eval: 18 | for u in range(len(train_positive_list)): train_positive_list[u] += list(data.data['val_positive_set'][u]) 19 | 20 | # Train positive interactions (in matrix form) as context for prediction on val/test set 21 | eval_context = data.data['train_matrix'] 22 | if test_set_eval: eval_context += data.data['val_matrix'] 23 | 24 | # What needs to be predicted 25 | to_predict = data.data['val_positive_set'] 26 | if test_set_eval: to_predict = data.data['test_positive_set'] 27 | 28 | bsz = hyper_params['num_users'] 29 | # bsz = 20_000 # These many users 30 | 31 | train_time = 0 32 | 33 | for i in range(0, hyper_params['num_users'], bsz): 34 | if hyper_params['model'] == 'ease' or hyper_params['model'] == 'svd-ae': 35 | temp_preds = jnp.array(rating.cpu()) 36 | temp_preds_copy = temp_preds.copy() 37 | predicted_rating = temp_preds 38 | else: 39 | train_start_time = time.time() 40 | temp_preds = kernelized_rr_forward(train_x, eval_context[i:i+bsz].todense(), reg = hyper_params['lamda']) 41 | temp_preds_copy = temp_preds.copy() 42 | temp_train_time = time.time() - train_start_time 43 | train_time += temp_train_time 44 | predicted_rating = temp_preds # predicted_rating_score 45 | metrics, temp_preds, temp_y = evaluate_batch( 46 | data.data['negatives'][i:i+bsz], np.array(temp_preds), 47 | train_positive_list[i:i+bsz], to_predict[i:i+bsz], item_propensity, 48 | topk, metrics 49 | ) 50 | preds += temp_preds 51 | y_binary += temp_y 52 | 53 | if hyper_params['model'] == 'inf-ae': 54 | print('Training time: {}'.format(train_time)) 55 | 56 | y_binary, preds = np.array(y_binary), np.array(preds) 57 | if (True not in np.isnan(y_binary)) and (True not in np.isnan(preds)): 58 | metrics['AUC'] = round(fast_auc(y_binary, preds), 4) 59 | 60 | for kind in [ 'HR', 'NDCG', 'PSP', 'RECALL', 'PRECISION' ]: # [ 'HR', 'NDCG', 'PSP' ]: 61 | for k in topk: 62 | metrics['{}@{}'.format(kind, k)] = round( 63 | float(100.0 * metrics['{}@{}'.format(kind, k)]) / hyper_params['num_users'], 4 64 | ) 65 | 66 | # metrics['num_users'] = int(train_x.shape[0]) 67 | # metrics['num_interactions'] = int(jnp.count_nonzero(train_x.astype(np.int8))) 68 | 69 | return metrics, predicted_rating 70 | 71 | def evaluate_batch(auc_negatives, logits, train_positive, test_positive_set, item_propensity, topk, metrics, train_metrics = False): 72 | ''' 73 | logits: predicted rating matrix 74 | train_positive: list of train postivie items 75 | test_positive_set: list of test postivie items 76 | ''' 77 | # AUC Stuff 78 | temp_preds, temp_y = [], [] 79 | for b in range(len(logits)): 80 | temp_preds += np.take(logits[b], np.array(list(test_positive_set[b]))).tolist() 81 | temp_y += [ 1.0 for _ in range(len(test_positive_set[b])) ] 82 | 83 | temp_preds += np.take(logits[b], auc_negatives[b]).tolist() 84 | temp_y += [ 0.0 for _ in range(len(auc_negatives[b])) ] 85 | 86 | # Marking train-set consumed items as negative INF 87 | for b in range(len(logits)): logits[b][ train_positive[b] ] = -INF 88 | indices = (-logits).argsort()[:, :max(topk)].tolist() 89 | 90 | for k in topk: 91 | for b in range(len(logits)): 92 | num_pos = float(len(test_positive_set[b])) 93 | metrics['HR@{}'.format(k)] += float(len(set(indices[b][:k]) & test_positive_set[b])) / float(min(num_pos, k)) 94 | metrics['RECALL@{}'.format(k)] += float(len(set(indices[b][:k]) & test_positive_set[b])) / float(num_pos) 95 | metrics['PRECISION@{}'.format(k)] += float(len(set(indices[b][:k]) & test_positive_set[b])) / float(k) 96 | 97 | test_positive_sorted_psp = sorted([ item_propensity[x] for x in test_positive_set[b] ])[::-1] 98 | 99 | dcg, idcg, psp, max_psp = 0.0, 0.0, 0.0, 0.0 100 | for at, pred in enumerate(indices[b][:k]): 101 | if pred in test_positive_set[b]: 102 | dcg += 1.0 / np.log2(at + 2) 103 | psp += float(item_propensity[pred]) / float(min(num_pos, k)) 104 | if at < num_pos: 105 | idcg += 1.0 / np.log2(at + 2) 106 | max_psp += test_positive_sorted_psp[at] 107 | 108 | metrics['NDCG@{}'.format(k)] += dcg / idcg 109 | metrics['PSP@{}'.format(k)] += psp / max_psp 110 | 111 | return metrics, temp_preds, temp_y 112 | 113 | @jit(float64(float64[:], float64[:])) 114 | def fast_auc(y_true, y_prob): 115 | y_true = y_true[np.argsort(y_prob)] 116 | nfalse, auc = 0, 0 117 | for i in range(len(y_true)): 118 | nfalse += (1 - y_true[i]) 119 | auc += y_true[i] * nfalse 120 | return auc / (nfalse * (len(y_true) - nfalse)) -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from scipy.sparse import csr_matrix 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import copy 5 | import h5py 6 | import gc 7 | 8 | class Dataset: 9 | def __init__(self, hyper_params): 10 | self.data = load_raw_dataset(hyper_params['dataset']) 11 | self.set_of_active_users = list(set(self.data['train'][:, 0].tolist())) 12 | self.hyper_params = self.update_hyper_params(hyper_params) 13 | 14 | def update_hyper_params(self, hyper_params): 15 | updated_params = copy.deepcopy(hyper_params) 16 | 17 | self.num_users, self.num_items = self.data['num_users'], self.data['num_items'] 18 | self.num_interactions = self.data['num_interactions'] 19 | 20 | # Update hyper-params to have some basic data stats 21 | updated_params.update({ 22 | 'num_users': self.num_users, 23 | 'num_items': self.num_items, 24 | 'num_interactions': self.num_interactions 25 | }) 26 | 27 | return updated_params 28 | 29 | def sample_users(self, num_to_sample): 30 | if num_to_sample == -1: 31 | ret = self.data['train_matrix'] 32 | else: 33 | sampled_users = np.random.choice(self.set_of_active_users, num_to_sample, replace=False) 34 | sampled_interactions = self.data['train'][np.in1d(self.data['train'][:, 0], sampled_users)] 35 | ret = csr_matrix( 36 | ( np.ones(sampled_interactions.shape[0]), (sampled_interactions[:, 0], sampled_interactions[:, 1]) ), 37 | shape = (self.num_users, self.num_items) 38 | ) 39 | # This just removes the users which were not sampled 40 | return ret[ret.getnnz(1)>0] # sparse matrix 41 | # return jnp.array(ret[ret.getnnz(1)>0].todense()) 42 | 43 | def load_raw_dataset(dataset, data_path = None, index_path = None): 44 | if data_path is None or index_path is None: 45 | data_path, index_path = [ 46 | "data/{}/total_data.hdf5".format(dataset), 47 | "data/{}/index.npz".format(dataset) 48 | ] 49 | 50 | with h5py.File(data_path, 'r') as f: data = np.array(list(zip(f['user'][:], f['item'][:], f['rating'][:]))) 51 | index = np.array(np.load(index_path)['data'], dtype = np.int32) 52 | 53 | def remap(data, index): 54 | ## Counting number of unique users/items before 55 | valid_users, valid_items = set(), set() 56 | for at, (u, i, r) in enumerate(data): 57 | if index[at] != -1: 58 | valid_users.add(u) 59 | valid_items.add(i) 60 | 61 | ## Map creation done! 62 | user_map = dict(zip(list(valid_users), list(range(len(valid_users))))) 63 | item_map = dict(zip(list(valid_items), list(range(len(valid_items))))) 64 | 65 | return user_map, item_map 66 | 67 | 68 | user_map, item_map = remap(data, index) 69 | 70 | new_data, new_index = [], [] 71 | for at, (u, i, r) in enumerate(data): 72 | if index[at] == -1: continue 73 | new_data.append([ user_map[u], item_map[i], r ]) 74 | new_index.append(index[at]) 75 | data = np.array(new_data, dtype = np.int32) 76 | index = np.array(new_index, dtype = np.int32) 77 | 78 | def select(data, index, index_val): 79 | final = data[np.where(index == index_val)[0]] 80 | final[:, 2] = 1.0 # explicit -> implicit 81 | return final.astype(np.int32) 82 | 83 | ret = { 84 | 'item_map': item_map, 85 | 'train': select(data, index, 0), 86 | 'val': select(data, index, 1), 87 | 'test': select(data, index, 2) 88 | } 89 | 90 | num_users = int(max(data[:, 0]) + 1) 91 | num_items = len(item_map) 92 | 93 | del data, index ; gc.collect() 94 | 95 | def make_user_history(arr): 96 | ret = [ set() for _ in range(num_users) ] 97 | for u, i, r in arr: 98 | if i >= num_items: continue 99 | ret[int(u)].add(int(i)) 100 | return ret 101 | 102 | ret['train_positive_set'] = make_user_history(ret['train']) 103 | ret['val_positive_set'] = make_user_history(ret['val']) 104 | ret['test_positive_set'] = make_user_history(ret['test']) 105 | 106 | ret['train_matrix'] = csr_matrix( 107 | ( np.ones(ret['train'].shape[0]), (ret['train'][:, 0].astype(np.int32), ret['train'][:, 1].astype(np.int32)) ), 108 | shape = (num_users, num_items) 109 | ) 110 | 111 | ret['val_matrix'] = csr_matrix( 112 | ( np.ones(ret['val'].shape[0]), (ret['val'][:, 0].astype(np.int32), ret['val'][:, 1].astype(np.int32)) ), 113 | shape = (num_users, num_items) 114 | ) 115 | 116 | # Negatives will be used for AUC computation 117 | ret['negatives'] = [ set() for _ in range(num_users) ] 118 | for u in range(num_users): 119 | while len(ret['negatives'][u]) < 50: 120 | rand_item = np.random.randint(0, num_items) 121 | if rand_item in ret['train_positive_set'][u]: continue 122 | if rand_item in ret['test_positive_set'][u]: continue 123 | ret['negatives'][u].add(rand_item) 124 | ret['negatives'][u] = list(ret['negatives'][u]) 125 | ret['negatives'] = np.array(ret['negatives'], dtype=np.int32) 126 | 127 | ret.update({ 128 | 'num_users': num_users, 129 | 'num_items': num_items, 130 | 'num_interactions': len(ret['train']), 131 | }) 132 | 133 | print("# users:", num_users) 134 | print("# items:", num_items) 135 | print("# train interactions:", len(ret['train'])) 136 | print("# val interactions:", len(ret['val'])) 137 | print("# test interactions:", len(ret['test'])) 138 | 139 | return ret 140 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" 3 | os.environ["TF_FORCE_UNIFIED_MEMORY"] = "1" 4 | os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" 5 | 6 | import time 7 | import copy 8 | import random 9 | import numpy as np 10 | import torch 11 | from jax import numpy as jnp 12 | 13 | import model 14 | from parse import parse_args 15 | from utils import log_end_epoch, get_item_propensity, get_common_path, set_seed, preprocess_svd, preprocess_ease, convert_sp_mat_to_sp_tensor 16 | 17 | args = parse_args() 18 | 19 | def train(hyper_params, data): 20 | from model import make_kernelized_rr_forward 21 | from eval import evaluate 22 | 23 | # This just instantiates the function 24 | kernelized_rr_forward, kernel_fn = make_kernelized_rr_forward(hyper_params) 25 | sampled_matrix = data.sample_users(hyper_params['user_support']) # Random user sample 26 | 27 | if hyper_params['model'] == 'svd-ae': 28 | adj_mat = data.data['train_matrix'] + data.data['val_matrix'] 29 | PATH = os.getcwd() 30 | adj_mat, norm_adj, ut, s, vt = preprocess_svd(hyper_params['load'], hyper_params['dataset'], adj_mat, hyper_params['k'], os.path.join(PATH, 'checkpoints'), device) 31 | train_model = model.SVD_AE(adj_mat, norm_adj, ut, vt, device) 32 | 33 | elif hyper_params['model'] == 'ease': 34 | adj_mat = data.data['train_matrix'] 35 | adj_mat, item_adj = preprocess_ease(adj_mat, device) 36 | train_model = model.EASE(adj_mat, item_adj, device) 37 | 38 | elif hyper_params['model'] == 'inf-ae': 39 | rating = None 40 | 41 | else: 42 | print('This model is not supported!') 43 | exit() 44 | 45 | sampled_matrix = jnp.array(sampled_matrix.todense()) 46 | 47 | ''' 48 | NOTE: No training required! We will compute dual-variables \alpha on the fly in `kernelized_rr_forward` 49 | However, if we needed to perform evaluation multiple times, we could pre-compute \alpha like so: 50 | 51 | import jax, jax.numpy as jnp, jax.scipy as sp 52 | @jax.jit 53 | def precompute_alpha(X, lamda=0.1): 54 | K = kernel_fn(X, X) 55 | K_reg = (K + jnp.abs(lamda) * jnp.trace(K) * jnp.eye(K.shape[0]) / K.shape[0]) 56 | return sp.linalg.solve(K_reg, X, sym_pos=True) 57 | alpha = precompute_alpha(sampled_matrix, lamda=0.1) # Change for the desired value of lamda 58 | ''' 59 | 60 | # Used for computing the PSP-metric 61 | item_propensity = get_item_propensity(hyper_params, data) 62 | 63 | # Evaluation 64 | start_time = time.time() 65 | 66 | VAL_METRIC = "HR@10" 67 | best_metric, best_lamda = None, None 68 | 69 | if hyper_params['model'] == 'svd-ae': 70 | print(len(s)) 71 | s = s.to(device) 72 | rating = train_model(s) 73 | test_metrics, preds = evaluate(rating, hyper_params, kernelized_rr_forward, data, item_propensity, sampled_matrix, test_set_eval = True) 74 | 75 | # MSE 76 | adj_mat = data.data['train_matrix'] + data.data['val_matrix'] 77 | adj_mat = jnp.array(convert_sp_mat_to_sp_tensor(adj_mat).to_dense()) 78 | err = (preds - adj_mat) ** 2 79 | mse = sum(sum(err)) / (adj_mat.shape[0] * adj_mat.shape[1]) 80 | print("\nMSE value: {}".format(mse)) 81 | 82 | 83 | elif hyper_params['model'] == 'ease': 84 | # Validate on the validation-set 85 | for lamda in [ 1.0, 10.0, 100.0, 1000.0, 10000.0 ] if hyper_params['grid_search_lamda'] else [ hyper_params['lamda'] ]: 86 | hyper_params['lamda'] = lamda 87 | rating = train_model(lamda) 88 | val_metrics, preds = evaluate(rating, hyper_params, kernelized_rr_forward, data, item_propensity, sampled_matrix) 89 | log_end_epoch(hyper_params, val_metrics, 0, time.time() - start_time) 90 | if (best_metric is None) or (val_metrics[VAL_METRIC] > best_metric): best_metric, best_lamda = val_metrics[VAL_METRIC], lamda 91 | print("\nBest lambda value: {}".format(best_lamda)) 92 | hyper_params['lamda'] = best_lamda 93 | 94 | # Test on the train + validation set 95 | adj_mat = data.data['train_matrix'] + data.data['val_matrix'] 96 | adj_mat, item_adj = preprocess_ease(adj_mat, device) 97 | train_model = model.EASE(adj_mat, item_adj, device) 98 | rating = train_model(best_lamda) 99 | test_metrics, preds = evaluate(rating, hyper_params, kernelized_rr_forward, data, item_propensity, sampled_matrix, test_set_eval = True) 100 | 101 | # MSE 102 | adj_mat = data.data['train_matrix'] + data.data['val_matrix'] 103 | adj_mat = jnp.array(convert_sp_mat_to_sp_tensor(adj_mat).to_dense()) 104 | err = (preds - adj_mat) ** 2 105 | mse = sum(sum(err)) / (adj_mat.shape[0] * adj_mat.shape[1]) 106 | print("\nMSE value: {}".format(mse)) 107 | 108 | 109 | else: 110 | # Validate on the validation-set 111 | for lamda in [ 0.0, 1.0, 5.0, 20.0, 50.0, 100.0 ] if hyper_params['grid_search_lamda'] else [ hyper_params['lamda'] ]: 112 | hyper_params['lamda'] = lamda 113 | val_metrics, preds = evaluate(rating, hyper_params, kernelized_rr_forward, data, item_propensity, sampled_matrix) 114 | log_end_epoch(hyper_params, val_metrics, 0, time.time() - start_time) 115 | if (best_metric is None) or (val_metrics[VAL_METRIC] > best_metric): best_metric, best_lamda = val_metrics[VAL_METRIC], lamda 116 | print("Best lambda value: {}".format(best_lamda)) 117 | hyper_params['lamda'] = best_lamda 118 | test_metrics, preds = evaluate(rating, hyper_params, kernelized_rr_forward, data, item_propensity, sampled_matrix, test_set_eval = True) 119 | 120 | # MSE 121 | adj_mat = data.data['train_matrix'] + data.data['val_matrix'] 122 | adj_mat = jnp.array(convert_sp_mat_to_sp_tensor(adj_mat).to_dense()) 123 | err = (preds - adj_mat) ** 2 124 | mse = sum(sum(err)) / (adj_mat.shape[0] * adj_mat.shape[1]) 125 | print("\nMSE value: {}".format(mse)) 126 | 127 | 128 | # Return metrics with the best lamda on the test-set 129 | log_end_epoch(hyper_params, test_metrics, 0, time.time() - start_time) 130 | start_time = time.time() 131 | 132 | return test_metrics 133 | 134 | def main(hyper_params, gpu_id = None): 135 | if gpu_id is not None: os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 136 | 137 | from jax.config import config 138 | if 'float64' in hyper_params and hyper_params['float64'] == True: config.update('jax_enable_x64', True) 139 | 140 | from data import Dataset 141 | 142 | os.makedirs("./results/logs/", exist_ok=True) 143 | hyper_params['log_file'] = "./results/logs/" + get_common_path(hyper_params) + ".txt" 144 | data = Dataset(hyper_params) 145 | hyper_params = copy.deepcopy(data.hyper_params) # Updated w/ data-stats 146 | 147 | return train(hyper_params, data) 148 | 149 | if __name__ == "__main__": 150 | from hyper_params import hyper_params 151 | set_seed(hyper_params['seed']) 152 | GPU = torch.cuda.is_available() 153 | device = torch.device('cuda:0' if GPU else 'cpu') 154 | main(hyper_params) 155 | --------------------------------------------------------------------------------