├── asset
│   ├── triangle.png
│   └── overview_vis.png
├── requirements.txt
├── hyper_params.py
├── parse.py
├── LICENSE
├── model.py
├── utils.py
├── README.md
├── eval.py
├── data.py
└── main.py
/asset/triangle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seoyoungh/svd-ae/HEAD/asset/triangle.png
--------------------------------------------------------------------------------
/asset/overview_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seoyoungh/svd-ae/HEAD/asset/overview_vis.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.11.0+cu113
2 | jax==0.3.25
3 | jaxlib==0.3.25
4 | h5py==3.7.0
5 | neural-tangents==0.6.1
6 | scipy==1.7.3
7 | numpy==1.21.6
--------------------------------------------------------------------------------
/hyper_params.py:
--------------------------------------------------------------------------------
1 | from parse import parse_args
2 | args = parse_args()
3 |
4 | hyper_params = {
5 | # COMMON
6 | 'dataset': args.dataset,
7 | 'seed': args.seed,
8 | 'model': args.model,
9 |
10 | # SVD-AE
11 | 'k': args.k,
12 | 'load': args.load,
13 |
14 | # Infinite-AE
15 | 'lamda': args.lamda, # Only used if grid_search_lamda == False
16 | 'float64': False,
17 | 'depth': 1,
18 | 'grid_search_lamda': args.grid_search,
19 | 'user_support': -1, # Number of users to keep (sampled randomly); -1 means use all users
20 | }
21 |
--------------------------------------------------------------------------------
/parse.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def parse_args():
4 | parser = argparse.ArgumentParser(description="")
5 | parser.add_argument('--model', type=str, default='svd-ae', help='rec-model, support [svd-ae, ease, inf-ae]')
6 | parser.add_argument('--k', type=int, default=148, help='the truncated SVD rank k')
7 | parser.add_argument('--dataset', type=str, default='ml-1m', help='available datasets: [gowalla, yelp2018, ml-1m]')
8 | parser.add_argument('--load', type=int, default=0)
9 | parser.add_argument('--seed', type=int, default=42)
10 | parser.add_argument('--lamda', type=float, default=1.0)
11 | parser.add_argument('--grid_search', type=int, default=0)
12 |
13 | return parser.parse_args()
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Seoyoung Hong
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import functools
3 | from jax import scipy as sp
4 | from jax import numpy as jnp
5 | from neural_tangents import stax
6 | import torch
7 | from torch import nn
8 | import numpy as np
9 |
10 | def make_kernelized_rr_forward(hyper_params):
11 | _, _, kernel_fn = FullyConnectedNetwork(
12 | depth=hyper_params['depth'],
13 | num_classes=hyper_params['num_items']
14 | )
15 | # NOTE: Un-comment this if the dataset size is very big (didn't need it for experiments in the paper)
16 | # kernel_fn = nt.batch(kernel_fn, batch_size=128)
17 | kernel_fn = functools.partial(kernel_fn, get='ntk')
18 |
19 | @jax.jit
20 | def kernelized_rr_forward(X_train, X_predict, reg=0.1):
21 | K_train = kernel_fn(X_train, X_train) # user * user
22 | K_predict = kernel_fn(X_predict, X_train) # user * user
23 | K_reg = (K_train + jnp.abs(reg) * jnp.trace(K_train) * jnp.eye(K_train.shape[0]) / K_train.shape[0]) # user * user
24 | return jnp.dot(K_predict, sp.linalg.solve(K_reg, X_train, sym_pos=True))
25 | # sp.linalg.solve(K_reg, X_train, sym_pos=True) has shape user * item
26 |
27 | return kernelized_rr_forward, kernel_fn
28 |
29 | def FullyConnectedNetwork(
30 | depth,
31 | W_std = 2 ** 0.5,
32 | b_std = 0.1,
33 | num_classes = 10,
34 | parameterization = 'ntk'
35 | ):
36 | activation_fn = stax.Relu()
37 | dense = functools.partial(stax.Dense, W_std=W_std, b_std=b_std, parameterization=parameterization)
38 |
39 | layers = [stax.Flatten()]
40 | # NOTE: setting width = 1024 doesn't matter as the NTK parameterization will stretch this till \infty
41 | for _ in range(depth): layers += [dense(1024), activation_fn]
42 | layers += [stax.Dense(num_classes, W_std=W_std, b_std=b_std, parameterization=parameterization)]
43 |
44 | return stax.serial(*layers)
45 |
46 | class EASE(nn.Module):
47 | def __init__(self, adj_mat, item_adj, device='cuda:0'):
48 | super(EASE, self).__init__()
49 | self.adj_mat = adj_mat.to(device)
50 | self.item_adj = item_adj.to(device)
51 |
52 | def forward(self, lambda_):
53 | G = self.item_adj.clone() # clone so repeated forward() calls (e.g. the lambda grid search) do not accumulate lambda_ on the diagonal
54 | diagIndices = np.diag_indices(G.shape[0])
55 | G[diagIndices] += lambda_
56 | P = torch.inverse(G)
57 | B = P / (-torch.diag(P))
58 | B[diagIndices] = 0
59 | rating = torch.mm(self.adj_mat, B)
60 |
61 | return rating
62 |
63 | class SVD_AE(nn.Module):
64 | def __init__(self, adj_mat, norm_adj, user_sv, item_sv, device='cuda:0'):
65 | super(SVD_AE, self).__init__()
66 | self.adj_mat = adj_mat.to(device)
67 | self.norm_adj = norm_adj.to(device)
68 | self.user_sv = user_sv.to(device) # (M, K) left singular vectors
69 | self.item_sv = item_sv.to(device) # (N, K) right singular vectors
70 |
71 | def forward(self, lambda_mat):
72 | # torch.diag(1 / lambda_mat) applies Sigma^{-1}, i.e. torch.inverse(torch.diag(lambda_mat))
73 | A = self.item_sv @ torch.diag(1 / lambda_mat) @ self.user_sv.T
74 | rating = torch.mm(self.norm_adj, A @ self.adj_mat.to_dense())
75 | return rating
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | from collections import defaultdict
4 | import torch
5 | import scipy.sparse as sp
6 | import time
7 | import os
8 |
9 | def get_common_path(hyper_params):
10 | ret = "{}_{}_".format(
11 | hyper_params['dataset'], hyper_params['model']
12 | )
13 | if hyper_params['model'] == 'svd-ae': ret += "k_{}_".format(hyper_params['k'])
14 | else:
15 | if hyper_params['grid_search_lamda']: ret += "grid_search_lamda_"
16 | else: ret += "lamda_{}_".format(hyper_params['lamda'])
17 |
18 | ret += "seed_{}".format(hyper_params['seed'])
19 | return ret
20 |
21 | def get_item_count_map(data):
22 | item_count = defaultdict(int)
23 | for u, i, r in data.data['train']: item_count[i] += 1
24 | return item_count
25 |
26 | def get_item_propensity(hyper_params, data, A = 0.55, B = 1.5):
27 | item_freq_map = get_item_count_map(data)
28 | item_freq = [ item_freq_map[i] for i in range(hyper_params['num_items']) ]
29 | num_instances = hyper_params['num_interactions']
30 |
31 | C = (np.log(num_instances)-1)*np.power(B+1, A)
32 | wts = 1.0 + C*np.power(np.array(item_freq)+B, -A)
33 | return np.ravel(wts)
34 |
35 | def file_write(log_file, s, dont_print=False):
36 | if dont_print == False: print(s)
37 | if log_file is None: return
38 | f = open(log_file, 'a')
39 | f.write(s+'\n')
40 | f.close()
41 |
42 | def log_end_epoch(hyper_params, metrics, step, time_elapsed, metrics_on = '(TEST)', dont_print = False):
43 | string2 = ""
44 | for m in metrics: string2 += " | " + m + ' = ' + str("{:2.4f}".format(metrics[m]))
45 | string2 += ' ' + metrics_on
46 |
47 | if hyper_params['model'] == 'svd-ae':
48 | ss = '| end of step {:4d} | time = {:5.2f}'.format(step, time_elapsed)
49 | else:
50 | ss = '| end of step {:4d} | time = {:5.2f} | best lambda = {}'.format(step, time_elapsed, hyper_params['lamda'])
51 |
52 | ss += string2
53 | file_write(hyper_params['log_file'], ss, dont_print = dont_print)
54 |
55 | def set_seed(seed):
56 | random.seed(seed)
57 | np.random.seed(seed)
58 | if torch.cuda.is_available():
59 | torch.cuda.manual_seed(seed)
60 | torch.cuda.manual_seed_all(seed)
61 | torch.manual_seed(seed)
62 |
63 | def convert_sp_mat_to_sp_tensor(X):
64 | coo = X.tocoo().astype(np.float32)
65 | row = torch.Tensor(coo.row).long()
66 | col = torch.Tensor(coo.col).long()
67 | index = torch.stack([row, col])
68 | data = torch.FloatTensor(coo.data)
69 | return torch.sparse.FloatTensor(index, data, torch.Size(coo.shape))
70 |
71 | def get_file_name(dataset, k, path):
72 | ut = f"{dataset}-{k}-ut.npy"
73 | s = f"{dataset}-{k}-s.npy"
74 | vt = f"{dataset}-{k}-vt.npy"
75 | file_list = [ut, s, vt]
76 | file_list = [os.path.join(path, file) for file in file_list]
77 | return file_list
78 |
79 | def preprocess_ease(adj_mat, device):
80 | start = time.time()
81 | adj_mat = adj_mat
82 | item_adj = adj_mat.T @ adj_mat
83 | adj_mat = convert_sp_mat_to_sp_tensor(adj_mat).to_dense()
84 | item_adj = convert_sp_mat_to_sp_tensor(item_adj).to_dense()
85 | end = time.time()
86 | print('Pre-processing time: ', end-start)
87 | return adj_mat, item_adj
88 |
89 | def preprocess_svd(LOAD, dataset, adj_mat, k, path, device):
90 | start = time.time() # also times the load-from-disk path, so the print below never sees an undefined variable
91 | file_list = get_file_name(dataset, k, path)
92 | rowsum = np.array(adj_mat.sum(axis=1))
93 | rowsum = np.where(rowsum == 0.0, 1.0, rowsum) # Do not divide by zero
94 | d_inv = np.power(rowsum, -0.5).flatten()
95 | d_inv[np.isinf(d_inv)] = 0.
96 | d_mat = sp.diags(d_inv)
97 | norm_adj = d_mat.dot(adj_mat)
98 | colsum = np.array(adj_mat.sum(axis=0))
99 | colsum = np.where(colsum == 0.0, 1.0, colsum) # Do not divide by zero
100 | d_inv = np.power(colsum, -0.5).flatten()
101 | d_inv[np.isinf(d_inv)] = 0.
102 | d_mat_i = sp.diags(d_inv)
103 | d_mat_i_inv = sp.diags(1/d_inv)
104 | norm_adj = norm_adj.dot(d_mat_i)
105 | norm_adj = norm_adj.tocsc()
106 | adj_mat = convert_sp_mat_to_sp_tensor(adj_mat)
107 | norm_adj = convert_sp_mat_to_sp_tensor(norm_adj)
108 |
109 | if LOAD:
110 | cond = os.path.isfile(file_list[0]) & os.path.isfile(file_list[1]) & os.path.isfile(file_list[2])
111 | if cond:
112 | print("Load pre-calculated eigenvectors and eigenvalues!")
113 | ut, s, vt = np.load(file_list[0]), np.load(file_list[1]), np.load(file_list[2])
114 | else:
115 | print("Saved numpy files don't exist!")
116 | exit()
117 | else:
118 | start = time.time()
119 | ut, s, vt = torch.svd_lowrank(norm_adj, q=k, niter=2, M=None)
120 | end = time.time()
121 | if not os.path.isdir(path):
122 | os.makedirs(path)
123 | np.save(file_list[0], ut.cpu().numpy())
124 | np.save(file_list[1], s.cpu().numpy())
125 | np.save(file_list[2], vt.cpu().numpy())
126 |
127 | norm_adj = norm_adj.to_dense()
128 | ut = torch.FloatTensor(ut)
129 | s = torch.FloatTensor(s)
130 | vt = torch.FloatTensor(vt)
131 | end = time.time()
132 | print('Pre-processing time: ', end-start)
133 | return adj_mat, norm_adj, ut, s, vt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SVD-AE: Simple Autoencoders for Collaborative Filtering
2 |
3 | 
4 | [](https://arxiv.org/abs/2405.04746) [](https://hits.seeyoufarm.com)
5 |
6 | [](https://paperswithcode.com/sota/recommendation-systems-on-gowalla?p=svd-ae-simple-autoencoders-for-collaborative)
7 | [](https://paperswithcode.com/sota/collaborative-filtering-on-movielens-10m?p=svd-ae-simple-autoencoders-for-collaborative)
8 | [](https://paperswithcode.com/sota/collaborative-filtering-on-movielens-1m?p=svd-ae-simple-autoencoders-for-collaborative)
9 | [](https://paperswithcode.com/sota/recommendation-systems-on-yelp2018?p=svd-ae-simple-autoencoders-for-collaborative)
10 |
11 |
12 | This repository contains the official implementation of [SVD-AE](https://arxiv.org/abs/2405.04746), a novel approach to collaborative filtering introduced in our IJCAI 2024 paper.
13 |
14 | - 📌 **Check out [our poster](https://www.dropbox.com/scl/fi/mm5obivc6hss0jgy0vdsl/SVD_AE_IJCAI_Poster.pdf?rlkey=mqkfnb5rc4fa1eee46w7q4h3l&st=ht28tyaw&dl=0) for a visual overview of SVD-AE!**
15 |
16 | - 🧵 **For a detailed explanation, see [our Twitter/X thread](https://x.com/jeongwhan_choi/status/1821010085713465694) highlighting key aspects of SVD-AE.**
17 |
18 | - 🎞️ **[Our presentation slides](https://www.dropbox.com/scl/fi/okdrh2htm4czcfb6cuhbo/SVD_AE_Talk.pdf?rlkey=1c60s0styu9e9u1rdzq9oqtln&st=v4wkydsz&dl=0) provide a comprehensive look at the method and results.**
19 |
20 |
21 | The best overall balance between 3 goals | The accuracy, robustness, and computation time of various methods on Gowalla |
22 | :-------------------------:|:-------------------------:
23 |  | 
24 |
25 | - Key Features of **SVD-AE**:
26 |   - Closed-form solution for efficient computation (see the minimal sketch below)
27 |   - Low-rank inductive bias for noise robustness
28 |   - Competitive performance across various datasets
29 |
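To make the closed-form bullet concrete, here is a minimal, self-contained sketch of the rating computation implemented in `model.py` (`SVD_AE.forward`, with the degree normalization from `utils.preprocess_svd`). The toy interaction matrix and the rank `k = 2` are placeholders, not values from the paper:

```python
import torch

# Toy user-item interaction matrix (placeholder); in the repo this comes from
# the train (+ validation) interactions.
adj = torch.tensor([[1., 0., 1., 0.],
                    [0., 1., 1., 0.],
                    [1., 1., 0., 1.]])

# Symmetric degree normalization: D_u^{-1/2} A D_i^{-1/2}
d_u = adj.sum(1).clamp(min=1.0).pow(-0.5)   # user degrees
d_i = adj.sum(0).clamp(min=1.0).pow(-0.5)   # item degrees
norm_adj = d_u[:, None] * adj * d_i[None, :]

# Truncated SVD of the normalized matrix; the rank k is the only hyperparameter
k = 2
u, s, v = torch.svd_lowrank(norm_adj, q=k)

# Closed-form reconstruction, as in SVD_AE.forward:
# rating = norm_adj @ (V diag(1/s) U^T) @ adj
rating = norm_adj @ (v @ torch.diag(1.0 / s) @ u.T) @ adj
print(rating.shape)  # (num_users, num_items)
```
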
30 | ---
31 |
32 | ## Installation
33 | To set up the environment for running SVD-AE, follow these steps:
34 |
35 | 1. Clone this repository:
36 |
37 | ```bash
38 | git clone https://github.com/seoyoungh/svd-ae.git
39 | cd svd-ae
40 | ```
41 |
42 | 2. Create a virtual environment (optional but recommended):
43 | ```bash
44 | python -m venv svd_ae
45 | source svd_ae/bin/activate
46 | ```
47 |
48 | 3. Install the required dependencies:
49 | ```bash
50 | pip install -r requirements.txt
51 | ```
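
Note: the pinned `torch==1.11.0+cu113` wheel is not hosted on PyPI. Assuming a CUDA 11.3 environment, it can typically be installed from the official PyTorch wheel index before the remaining requirements:

```bash
pip install torch==1.11.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
pip install -r requirements.txt
```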
52 |
53 | ## Dataset Preparation
54 | Before running the code, you need to download and prepare the datasets:
55 |
56 | 1. Download the dataset archive from this Google Drive [link](https://drive.google.com/file/d/1cuhQw1aR9BEwutK3svKtL_-CGmcIPOiX/view?usp=sharing).
57 | 2. Unzip the downloaded file in the project root directory:
58 |
59 | ```bash
60 | unzip data.zip
61 | ```
62 |
63 | This will create a data folder containing the preprocessed datasets (ML-1M, ML-10M, Gowalla, and Yelp2018).
64 |
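For reference, `data.py` loads each dataset from `data/<dataset>/total_data.hdf5` and `data/<dataset>/index.npz`, so after unzipping the layout should look roughly like this (folder names follow the `--dataset` values used below):

```
data/
├── ml-1m/
│   ├── total_data.hdf5
│   └── index.npz
├── ml-10m/
├── gowalla/
└── yelp2018/
```
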
65 | ## Usage
66 | To run SVD-AE on different datasets, use the following commands:
67 |
68 | - For ML-1M:
69 | ```bash
70 | python main.py --dataset ml-1m --k 148
71 | ```
72 |
73 | - For ML-10M:
74 | ```bash
75 | python main.py --dataset ml-10m --k 427
76 | ```
77 |
78 | - For Gowalla:
79 | ```bash
80 | python main.py --dataset gowalla --k 1194
81 | ```
82 |
83 | - For Yelp2018:
84 | ```bash
85 | python main.py --dataset yelp2018 --k 1267
86 | ```
87 |
88 | ## Hyperparameters
89 | The main hyperparameter for SVD-AE is `k`, which represents the rank used in the truncated SVD. The optimal values for each dataset are provided in the usage examples above. You can experiment with different values to see how they affect performance.
90 | Other configurable parameters are defined in `parse.py` (and collected into the `hyper_params` dictionary in `hyper_params.py`). Feel free to adjust them according to your needs, as shown below.
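
For example, the following illustrative runs (using the flags defined in `parse.py`) change the random seed, reuse SVD factors cached in `./checkpoints` by a previous run, and grid-search the regularizer of the EASE baseline:

```bash
# Different random seed and a smaller rank
python main.py --dataset ml-1m --k 64 --seed 2024

# Reuse previously computed SVD factors (expects the .npy files in ./checkpoints)
python main.py --dataset ml-1m --k 148 --load 1

# EASE baseline with a lambda grid search
python main.py --model ease --dataset ml-1m --grid_search 1
```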
91 |
92 | ## Citation
93 |
94 | If you use this code or find SVD-AE helpful in your research, please cite our paper:
95 |
96 | ```bibtex
97 | @inproceedings{hong2024svdae,
98 | title = {SVD-AE: Simple Autoencoders for Collaborative Filtering},
99 | author = {Hong, Seoyoung and Choi, Jeongwhan and Lee, Yeon-Chang and Kumar, Srijan and Park, Noseong},
100 | booktitle = {Proceedings of the Thirty-Third International Joint Conference on
101 | Artificial Intelligence, {IJCAI-24}},
102 | publisher = {International Joint Conferences on Artificial Intelligence Organization},
103 | pages = {2054--2062},
104 | year = {2024},
105 | doi = {10.24963/ijcai.2024/227},
106 | url = {https://doi.org/10.24963/ijcai.2024/227},
107 | }
108 | ```
109 |
110 | ## Star History
111 |
112 | [](https://star-history.com/#seoyoungh/svd-ae&Date)
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import numpy as np
3 | import jax.numpy as jnp
4 | from numba import jit, float64
5 | import time
6 |
7 | INF = float(1e6)
8 |
9 | def evaluate(rating, hyper_params, kernelized_rr_forward, data, item_propensity, train_x, topk = [ 10, 20, 100 ], test_set_eval = False):
10 | preds, y_binary, metrics = [], [], {}
11 | for kind in [ 'HR', 'NDCG', 'PSP', 'RECALL', 'PRECISION' ]: # [ 'HR', 'NDCG', 'PSP' ]:
12 | for k in topk:
13 | metrics['{}@{}'.format(kind, k)] = 0.0
14 | # Train positive set -- these items will be set to -infinity while prediction on the val/test set
15 |
16 | train_positive_list = list(map(list, data.data['train_positive_set']))
17 | if test_set_eval:
18 | for u in range(len(train_positive_list)): train_positive_list[u] += list(data.data['val_positive_set'][u])
19 |
20 | # Train positive interactions (in matrix form) as context for prediction on val/test set
21 | eval_context = data.data['train_matrix']
22 | if test_set_eval: eval_context += data.data['val_matrix']
23 |
24 | # What needs to be predicted
25 | to_predict = data.data['val_positive_set']
26 | if test_set_eval: to_predict = data.data['test_positive_set']
27 |
28 | bsz = hyper_params['num_users']
29 | # bsz = 20_000 # These many users
30 |
31 | train_time = 0
32 |
33 | for i in range(0, hyper_params['num_users'], bsz):
34 | if hyper_params['model'] == 'ease' or hyper_params['model'] == 'svd-ae':
35 | temp_preds = jnp.array(rating.cpu())
36 | temp_preds_copy = temp_preds.copy()
37 | predicted_rating = temp_preds
38 | else:
39 | train_start_time = time.time()
40 | temp_preds = kernelized_rr_forward(train_x, eval_context[i:i+bsz].todense(), reg = hyper_params['lamda'])
41 | temp_preds_copy = temp_preds.copy()
42 | temp_train_time = time.time() - train_start_time
43 | train_time += temp_train_time
44 | predicted_rating = temp_preds # predicted_rating_score
45 | metrics, temp_preds, temp_y = evaluate_batch(
46 | data.data['negatives'][i:i+bsz], np.array(temp_preds),
47 | train_positive_list[i:i+bsz], to_predict[i:i+bsz], item_propensity,
48 | topk, metrics
49 | )
50 | preds += temp_preds
51 | y_binary += temp_y
52 |
53 | if hyper_params['model'] == 'inf-ae':
54 | print('Training time: {}'.format(train_time))
55 |
56 | y_binary, preds = np.array(y_binary), np.array(preds)
57 | if (True not in np.isnan(y_binary)) and (True not in np.isnan(preds)):
58 | metrics['AUC'] = round(fast_auc(y_binary, preds), 4)
59 |
60 | for kind in [ 'HR', 'NDCG', 'PSP', 'RECALL', 'PRECISION' ]: # [ 'HR', 'NDCG', 'PSP' ]:
61 | for k in topk:
62 | metrics['{}@{}'.format(kind, k)] = round(
63 | float(100.0 * metrics['{}@{}'.format(kind, k)]) / hyper_params['num_users'], 4
64 | )
65 |
66 | # metrics['num_users'] = int(train_x.shape[0])
67 | # metrics['num_interactions'] = int(jnp.count_nonzero(train_x.astype(np.int8)))
68 |
69 | return metrics, predicted_rating
70 |
71 | def evaluate_batch(auc_negatives, logits, train_positive, test_positive_set, item_propensity, topk, metrics, train_metrics = False):
72 | '''
73 | logits: predicted rating matrix
74 | train_positive: list of train positive items
75 | test_positive_set: list of test positive items
76 | '''
77 | # AUC Stuff
78 | temp_preds, temp_y = [], []
79 | for b in range(len(logits)):
80 | temp_preds += np.take(logits[b], np.array(list(test_positive_set[b]))).tolist()
81 | temp_y += [ 1.0 for _ in range(len(test_positive_set[b])) ]
82 |
83 | temp_preds += np.take(logits[b], auc_negatives[b]).tolist()
84 | temp_y += [ 0.0 for _ in range(len(auc_negatives[b])) ]
85 |
86 | # Marking train-set consumed items as negative INF
87 | for b in range(len(logits)): logits[b][ train_positive[b] ] = -INF
88 | indices = (-logits).argsort()[:, :max(topk)].tolist()
89 |
90 | for k in topk:
91 | for b in range(len(logits)):
92 | num_pos = float(len(test_positive_set[b]))
93 | metrics['HR@{}'.format(k)] += float(len(set(indices[b][:k]) & test_positive_set[b])) / float(min(num_pos, k))
94 | metrics['RECALL@{}'.format(k)] += float(len(set(indices[b][:k]) & test_positive_set[b])) / float(num_pos)
95 | metrics['PRECISION@{}'.format(k)] += float(len(set(indices[b][:k]) & test_positive_set[b])) / float(k)
96 |
97 | test_positive_sorted_psp = sorted([ item_propensity[x] for x in test_positive_set[b] ])[::-1]
98 |
99 | dcg, idcg, psp, max_psp = 0.0, 0.0, 0.0, 0.0
100 | for at, pred in enumerate(indices[b][:k]):
101 | if pred in test_positive_set[b]:
102 | dcg += 1.0 / np.log2(at + 2)
103 | psp += float(item_propensity[pred]) / float(min(num_pos, k))
104 | if at < num_pos:
105 | idcg += 1.0 / np.log2(at + 2)
106 | max_psp += test_positive_sorted_psp[at]
107 |
108 | metrics['NDCG@{}'.format(k)] += dcg / idcg
109 | metrics['PSP@{}'.format(k)] += psp / max_psp
110 |
111 | return metrics, temp_preds, temp_y
112 |
113 | @jit(float64(float64[:], float64[:]))
114 | def fast_auc(y_true, y_prob):
115 | y_true = y_true[np.argsort(y_prob)]
116 | nfalse, auc = 0, 0
117 | for i in range(len(y_true)):
118 | nfalse += (1 - y_true[i])
119 | auc += y_true[i] * nfalse
120 | return auc / (nfalse * (len(y_true) - nfalse))
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | from scipy.sparse import csr_matrix
2 | import jax.numpy as jnp
3 | import numpy as np
4 | import copy
5 | import h5py
6 | import gc
7 |
8 | class Dataset:
9 | def __init__(self, hyper_params):
10 | self.data = load_raw_dataset(hyper_params['dataset'])
11 | self.set_of_active_users = list(set(self.data['train'][:, 0].tolist()))
12 | self.hyper_params = self.update_hyper_params(hyper_params)
13 |
14 | def update_hyper_params(self, hyper_params):
15 | updated_params = copy.deepcopy(hyper_params)
16 |
17 | self.num_users, self.num_items = self.data['num_users'], self.data['num_items']
18 | self.num_interactions = self.data['num_interactions']
19 |
20 | # Update hyper-params to have some basic data stats
21 | updated_params.update({
22 | 'num_users': self.num_users,
23 | 'num_items': self.num_items,
24 | 'num_interactions': self.num_interactions
25 | })
26 |
27 | return updated_params
28 |
29 | def sample_users(self, num_to_sample):
30 | if num_to_sample == -1:
31 | ret = self.data['train_matrix']
32 | else:
33 | sampled_users = np.random.choice(self.set_of_active_users, num_to_sample, replace=False)
34 | sampled_interactions = self.data['train'][np.in1d(self.data['train'][:, 0], sampled_users)]
35 | ret = csr_matrix(
36 | ( np.ones(sampled_interactions.shape[0]), (sampled_interactions[:, 0], sampled_interactions[:, 1]) ),
37 | shape = (self.num_users, self.num_items)
38 | )
39 | # This just removes the users which were not sampled
40 | return ret[ret.getnnz(1)>0] # sparse matrix
41 | # return jnp.array(ret[ret.getnnz(1)>0].todense())
42 |
43 | def load_raw_dataset(dataset, data_path = None, index_path = None):
44 | if data_path is None or index_path is None:
45 | data_path, index_path = [
46 | "data/{}/total_data.hdf5".format(dataset),
47 | "data/{}/index.npz".format(dataset)
48 | ]
49 |
50 | with h5py.File(data_path, 'r') as f: data = np.array(list(zip(f['user'][:], f['item'][:], f['rating'][:])))
51 | index = np.array(np.load(index_path)['data'], dtype = np.int32)
52 |
53 | def remap(data, index):
54 | ## Counting number of unique users/items before
55 | valid_users, valid_items = set(), set()
56 | for at, (u, i, r) in enumerate(data):
57 | if index[at] != -1:
58 | valid_users.add(u)
59 | valid_items.add(i)
60 |
61 | ## Map creation done!
62 | user_map = dict(zip(list(valid_users), list(range(len(valid_users)))))
63 | item_map = dict(zip(list(valid_items), list(range(len(valid_items)))))
64 |
65 | return user_map, item_map
66 |
67 |
68 | user_map, item_map = remap(data, index)
69 |
70 | new_data, new_index = [], []
71 | for at, (u, i, r) in enumerate(data):
72 | if index[at] == -1: continue
73 | new_data.append([ user_map[u], item_map[i], r ])
74 | new_index.append(index[at])
75 | data = np.array(new_data, dtype = np.int32)
76 | index = np.array(new_index, dtype = np.int32)
77 |
78 | def select(data, index, index_val):
79 | final = data[np.where(index == index_val)[0]]
80 | final[:, 2] = 1.0 # explicit -> implicit
81 | return final.astype(np.int32)
82 |
83 | ret = {
84 | 'item_map': item_map,
85 | 'train': select(data, index, 0),
86 | 'val': select(data, index, 1),
87 | 'test': select(data, index, 2)
88 | }
89 |
90 | num_users = int(max(data[:, 0]) + 1)
91 | num_items = len(item_map)
92 |
93 | del data, index ; gc.collect()
94 |
95 | def make_user_history(arr):
96 | ret = [ set() for _ in range(num_users) ]
97 | for u, i, r in arr:
98 | if i >= num_items: continue
99 | ret[int(u)].add(int(i))
100 | return ret
101 |
102 | ret['train_positive_set'] = make_user_history(ret['train'])
103 | ret['val_positive_set'] = make_user_history(ret['val'])
104 | ret['test_positive_set'] = make_user_history(ret['test'])
105 |
106 | ret['train_matrix'] = csr_matrix(
107 | ( np.ones(ret['train'].shape[0]), (ret['train'][:, 0].astype(np.int32), ret['train'][:, 1].astype(np.int32)) ),
108 | shape = (num_users, num_items)
109 | )
110 |
111 | ret['val_matrix'] = csr_matrix(
112 | ( np.ones(ret['val'].shape[0]), (ret['val'][:, 0].astype(np.int32), ret['val'][:, 1].astype(np.int32)) ),
113 | shape = (num_users, num_items)
114 | )
115 |
116 | # Negatives will be used for AUC computation
117 | ret['negatives'] = [ set() for _ in range(num_users) ]
118 | for u in range(num_users):
119 | while len(ret['negatives'][u]) < 50:
120 | rand_item = np.random.randint(0, num_items)
121 | if rand_item in ret['train_positive_set'][u]: continue
122 | if rand_item in ret['test_positive_set'][u]: continue
123 | ret['negatives'][u].add(rand_item)
124 | ret['negatives'][u] = list(ret['negatives'][u])
125 | ret['negatives'] = np.array(ret['negatives'], dtype=np.int32)
126 |
127 | ret.update({
128 | 'num_users': num_users,
129 | 'num_items': num_items,
130 | 'num_interactions': len(ret['train']),
131 | })
132 |
133 | print("# users:", num_users)
134 | print("# items:", num_items)
135 | print("# train interactions:", len(ret['train']))
136 | print("# val interactions:", len(ret['val']))
137 | print("# test interactions:", len(ret['test']))
138 |
139 | return ret
140 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
3 | os.environ["TF_FORCE_UNIFIED_MEMORY"] = "1"
4 | os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
5 |
6 | import time
7 | import copy
8 | import random
9 | import numpy as np
10 | import torch
11 | from jax import numpy as jnp
12 |
13 | import model
14 | from parse import parse_args
15 | from utils import log_end_epoch, get_item_propensity, get_common_path, set_seed, preprocess_svd, preprocess_ease, convert_sp_mat_to_sp_tensor
16 |
17 | args = parse_args()
18 |
19 | def train(hyper_params, data):
20 | from model import make_kernelized_rr_forward
21 | from eval import evaluate
22 |
23 | # This just instantiates the function
24 | kernelized_rr_forward, kernel_fn = make_kernelized_rr_forward(hyper_params)
25 | sampled_matrix = data.sample_users(hyper_params['user_support']) # Random user sample
26 |
27 | if hyper_params['model'] == 'svd-ae':
28 | adj_mat = data.data['train_matrix'] + data.data['val_matrix']
29 | PATH = os.getcwd()
30 | adj_mat, norm_adj, ut, s, vt = preprocess_svd(hyper_params['load'], hyper_params['dataset'], adj_mat, hyper_params['k'], os.path.join(PATH, 'checkpoints'), device)
31 | train_model = model.SVD_AE(adj_mat, norm_adj, ut, vt, device)
32 |
33 | elif hyper_params['model'] == 'ease':
34 | adj_mat = data.data['train_matrix']
35 | adj_mat, item_adj = preprocess_ease(adj_mat, device)
36 | train_model = model.EASE(adj_mat, item_adj, device)
37 |
38 | elif hyper_params['model'] == 'inf-ae':
39 | rating = None
40 |
41 | else:
42 | print('This model is not supported!')
43 | exit()
44 |
45 | sampled_matrix = jnp.array(sampled_matrix.todense())
46 |
47 | '''
48 | NOTE: No training required! We will compute dual-variables \alpha on the fly in `kernelized_rr_forward`
49 | However, if we needed to perform evaluation multiple times, we could pre-compute \alpha like so:
50 |
51 | import jax, jax.numpy as jnp, jax.scipy as sp
52 | @jax.jit
53 | def precompute_alpha(X, lamda=0.1):
54 | K = kernel_fn(X, X)
55 | K_reg = (K + jnp.abs(lamda) * jnp.trace(K) * jnp.eye(K.shape[0]) / K.shape[0])
56 | return sp.linalg.solve(K_reg, X, sym_pos=True)
57 | alpha = precompute_alpha(sampled_matrix, lamda=0.1) # Change for the desired value of lamda
58 | '''
59 |
60 | # Used for computing the PSP-metric
61 | item_propensity = get_item_propensity(hyper_params, data)
62 |
63 | # Evaluation
64 | start_time = time.time()
65 |
66 | VAL_METRIC = "HR@10"
67 | best_metric, best_lamda = None, None
68 |
69 | if hyper_params['model'] == 'svd-ae':
70 | print('Number of retained singular values k =', len(s))
71 | s = s.to(device)
72 | rating = train_model(s)
73 | test_metrics, preds = evaluate(rating, hyper_params, kernelized_rr_forward, data, item_propensity, sampled_matrix, test_set_eval = True)
74 |
75 | # MSE
76 | adj_mat = data.data['train_matrix'] + data.data['val_matrix']
77 | adj_mat = jnp.array(convert_sp_mat_to_sp_tensor(adj_mat).to_dense())
78 | err = (preds - adj_mat) ** 2
79 | mse = sum(sum(err)) / (adj_mat.shape[0] * adj_mat.shape[1])
80 | print("\nMSE value: {}".format(mse))
81 |
82 |
83 | elif hyper_params['model'] == 'ease':
84 | # Validate on the validation-set
85 | for lamda in [ 1.0, 10.0, 100.0, 1000.0, 10000.0 ] if hyper_params['grid_search_lamda'] else [ hyper_params['lamda'] ]:
86 | hyper_params['lamda'] = lamda
87 | rating = train_model(lamda)
88 | val_metrics, preds = evaluate(rating, hyper_params, kernelized_rr_forward, data, item_propensity, sampled_matrix)
89 | log_end_epoch(hyper_params, val_metrics, 0, time.time() - start_time)
90 | if (best_metric is None) or (val_metrics[VAL_METRIC] > best_metric): best_metric, best_lamda = val_metrics[VAL_METRIC], lamda
91 | print("\nBest lambda value: {}".format(best_lamda))
92 | hyper_params['lamda'] = best_lamda
93 |
94 | # Test on the train + validation set
95 | adj_mat = data.data['train_matrix'] + data.data['val_matrix']
96 | adj_mat, item_adj = preprocess_ease(adj_mat, device)
97 | train_model = model.EASE(adj_mat, item_adj, device)
98 | rating = train_model(best_lamda)
99 | test_metrics, preds = evaluate(rating, hyper_params, kernelized_rr_forward, data, item_propensity, sampled_matrix, test_set_eval = True)
100 |
101 | # MSE
102 | adj_mat = data.data['train_matrix'] + data.data['val_matrix']
103 | adj_mat = jnp.array(convert_sp_mat_to_sp_tensor(adj_mat).to_dense())
104 | err = (preds - adj_mat) ** 2
105 | mse = sum(sum(err)) / (adj_mat.shape[0] * adj_mat.shape[1])
106 | print("\nMSE value: {}".format(mse))
107 |
108 |
109 | else:
110 | # Validate on the validation-set
111 | for lamda in [ 0.0, 1.0, 5.0, 20.0, 50.0, 100.0 ] if hyper_params['grid_search_lamda'] else [ hyper_params['lamda'] ]:
112 | hyper_params['lamda'] = lamda
113 | val_metrics, preds = evaluate(rating, hyper_params, kernelized_rr_forward, data, item_propensity, sampled_matrix)
114 | log_end_epoch(hyper_params, val_metrics, 0, time.time() - start_time)
115 | if (best_metric is None) or (val_metrics[VAL_METRIC] > best_metric): best_metric, best_lamda = val_metrics[VAL_METRIC], lamda
116 | print("Best lambda value: {}".format(best_lamda))
117 | hyper_params['lamda'] = best_lamda
118 | test_metrics, preds = evaluate(rating, hyper_params, kernelized_rr_forward, data, item_propensity, sampled_matrix, test_set_eval = True)
119 |
120 | # MSE
121 | adj_mat = data.data['train_matrix'] + data.data['val_matrix']
122 | adj_mat = jnp.array(convert_sp_mat_to_sp_tensor(adj_mat).to_dense())
123 | err = (preds - adj_mat) ** 2
124 | mse = sum(sum(err)) / (adj_mat.shape[0] * adj_mat.shape[1])
125 | print("\nMSE value: {}".format(mse))
126 |
127 |
128 | # Return metrics with the best lamda on the test-set
129 | log_end_epoch(hyper_params, test_metrics, 0, time.time() - start_time)
130 | start_time = time.time()
131 |
132 | return test_metrics
133 |
134 | def main(hyper_params, gpu_id = None):
135 | if gpu_id is not None: os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
136 |
137 | from jax.config import config
138 | if 'float64' in hyper_params and hyper_params['float64'] == True: config.update('jax_enable_x64', True)
139 |
140 | from data import Dataset
141 |
142 | os.makedirs("./results/logs/", exist_ok=True)
143 | hyper_params['log_file'] = "./results/logs/" + get_common_path(hyper_params) + ".txt"
144 | data = Dataset(hyper_params)
145 | hyper_params = copy.deepcopy(data.hyper_params) # Updated w/ data-stats
146 |
147 | return train(hyper_params, data)
148 |
149 | if __name__ == "__main__":
150 | from hyper_params import hyper_params
151 | set_seed(hyper_params['seed'])
152 | GPU = torch.cuda.is_available()
153 | device = torch.device('cuda:0' if GPU else 'cpu')
154 | main(hyper_params)
155 |
--------------------------------------------------------------------------------