├── .gitignore ├── README.md ├── utils.py ├── evaluation.py ├── caser.py ├── data_process └── 3_item_dpp_emb.py ├── interactions.py ├── LICENSE └── train_caser.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.pyc 3 | .idea/ 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # DPPLikelihoods4SeqRec 3 | 4 | A PyTorch implementation of the paper: 5 | 6 | *Determinantal Point Process Likelihoods for Sequential Recommendation, Yuli Liu, Christian Walder and Lexing Xie, SIGIR '22* 7 | 8 | # Requirements 9 | * Python 2 or 3 10 | * [PyTorch](https://github.com/pytorch/pytorch) v1.0+ (the DPP losses use `torch.diag_embed` and batched `torch.det`) 11 | * NumPy 12 | * SciPy (TensorFlow 2.x and pandas are additionally needed for `data_process/3_item_dpp_emb.py`) 13 | 14 | # Usage 15 | 1. Install required packages. 16 | 2. Run `python train_caser.py` 17 | 18 | # Configurations 19 | 20 | 21 | #### Data 22 | The default data arguments in `train_caser.py` expect `datasets/beauty/train_3.txt`, `datasets/beauty/test_3.txt`, the item-category file `datasets/beauty/cate.txt` and the pre-learned DPP item kernel `datasets/beauty/item_kernel_3.pkl`, which is produced by `data_process/3_item_dpp_emb.py`. 23 | 24 | #### Model Args (in train_caser.py) 25 | `--L`/`--T` set the sequence and target lengths, `--neg_samples` the number of negatives (Z), `--dpp_loss` the objective (0: cross-entropy, 1: DSL, 2: CDSL) and `--batch_format` whether the minibatch form of the DPP loss is used; the Caser-specific arguments (`d`, `nv`, `nh`, `drop`, `ac_conv`, `ac_fc`) are defined in the model parser. 26 | 27 | # Citation 28 | 29 | If you use this code in your paper, please cite the paper: 30 | 31 | 32 | # Acknowledgment 33 | 34 | This project (utils.py, interactions.py, etc.) is heavily built on [Spotlight](https://github.com/maciejkula/spotlight) and [Caser](https://github.com/graytowne/caser_pytorch). 35 | Thanks to the authors! 36 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import random 6 | 7 | activation_getter = {'iden': lambda x: x, 'relu': F.relu, 'tanh': torch.tanh, 'sigm': torch.sigmoid} 8 | 9 | 10 | def gpu(tensor, gpu=False): 11 | 12 | if gpu: 13 | return tensor.cuda() 14 | else: 15 | return tensor 16 | 17 | 18 | def cpu(tensor): 19 | 20 | if tensor.is_cuda: 21 | return tensor.cpu() 22 | else: 23 | return tensor 24 | 25 | 26 | def minibatch(*tensors, **kwargs): 27 | 28 | batch_size = kwargs.get('batch_size', 128) 29 | 30 | if len(tensors) == 1: 31 | tensor = tensors[0] 32 | for i in range(0, len(tensor), batch_size): 33 | yield tensor[i:i + batch_size] 34 | else: 35 | for i in range(0, len(tensors[0]), batch_size): 36 | yield tuple(x[i:i + batch_size] for x in tensors) 37 | 38 | 39 | def shuffle(*arrays, **kwargs): 40 | 41 | require_indices = kwargs.get('indices', False) 42 | 43 | if len(set(len(x) for x in arrays)) != 1: 44 | raise ValueError('All inputs to shuffle must have ' 45 | 'the same length.') 46 | 47 | shuffle_indices = np.arange(len(arrays[0])) 48 | np.random.shuffle(shuffle_indices) 49 | 50 | if len(arrays) == 1: 51 | result = arrays[0][shuffle_indices] 52 | else: 53 | result = tuple(x[shuffle_indices] for x in arrays) 54 | 55 | if require_indices: 56 | return result, shuffle_indices 57 | else: 58 | return result 59 | 60 | 61 | def assert_no_grad(variable): 62 | 63 | if variable.requires_grad: 64 | raise ValueError( 65 | "nn criterions don't compute the gradient w.r.t. 
targets - please " 66 | "mark these variables as volatile or not requiring gradients" 67 | ) 68 | 69 | 70 | def set_seed(seed, cuda=False): 71 | 72 | np.random.seed(seed) 73 | random.seed(seed) 74 | if cuda: 75 | torch.cuda.manual_seed(seed) 76 | else: 77 | torch.manual_seed(seed) 78 | 79 | 80 | def str2bool(v): 81 | return v.lower() in ('true',) -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | 4 | def _compute_apk(targets, predictions, k): 5 | 6 | if len(predictions) > k: 7 | predictions = predictions[:k] 8 | 9 | score = 0.0 10 | num_hits = 0.0 11 | 12 | for i, p in enumerate(predictions): 13 | if p in targets and p not in predictions[:i]: 14 | num_hits += 1.0 15 | score += num_hits / (i + 1.0) 16 | 17 | if not list(targets): 18 | return 0.0 19 | 20 | return score / min(len(targets), k) 21 | 22 | def dcg_at_k(r, k, method=1): 23 | r = np.asfarray(r)[:k] 24 | if r.size: 25 | if method == 0: 26 | return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1))) 27 | elif method == 1: 28 | return np.sum(r / np.log2(np.arange(2, r.size + 2))) 29 | else: 30 | raise ValueError('method must be 0 or 1.') 31 | return 0. 32 | 33 | def ndcg_at_k(r, k, method=1): 34 | dcg_max = dcg_at_k(sorted(r, reverse=True), k, method) 35 | if not dcg_max: 36 | return 0. 37 | return dcg_at_k(r, k, method) / dcg_max 38 | 39 | def cc_at_k(cc, k, CATE_NUM): 40 | cates = set() 41 | for i in range(k): 42 | if i > (len(cc)-1): 43 | break 44 | for c in cc[i]: 45 | cates.add(c) 46 | return len(cates) / CATE_NUM 47 | 48 | def _compute_precision_recall(targets, predictions, k, iidcate_map, cate_num): 49 | 50 | pred = predictions[:k] 51 | r = [] 52 | cc = [] 53 | for i in pred: 54 | if i in targets: 55 | r.append(1) 56 | else: 57 | r.append(0) 58 | if i == 0: 59 | continue 60 | else: 61 | cc.append(iidcate_map[i-1]) 62 | 63 | num_hit = len(set(pred).intersection(set(targets))) 64 | precision = float(num_hit) / len(pred) 65 | recall = float(num_hit) / len(targets) 66 | ndcg = ndcg_at_k(r, k) 67 | cc = cc_at_k(cc, k, cate_num) 68 | return precision, recall, ndcg, cc 69 | 70 | def evaluate_ranking(model, test, config, l_kernel, cate, train=None, k=10): 71 | """ 72 | Compute Precision@k, Recall@k, NDCG@k and category coverage (CC@k) scores. 73 | One score is given for every user with interactions in the test 74 | set, representing the Precision@k, Recall@k, NDCG@k and CC@k of all 75 | their test items. 76 | 77 | Parameters 78 | ---------- 79 | 80 | model: fitted instance of a recommender model 81 | The model to evaluate. 82 | test: :class:`spotlight.interactions.Interactions` 83 | Test interactions. 84 | train: :class:`spotlight.interactions.Interactions`, optional 85 | Train interactions. If supplied, rated items in 86 | interactions will be excluded. 
87 | k: int or array of int, 88 | The maximum number of predicted items 89 | """ 90 | 91 | test = test.tocsr() 92 | 93 | if train is not None: 94 | train = train.tocsr() 95 | 96 | if not isinstance(k, list): 97 | ks = [k] 98 | else: 99 | ks = k 100 | 101 | precisions = [list() for _ in range(len(ks))] 102 | recalls = [list() for _ in range(len(ks))] 103 | ndcgs = [list() for _ in range(len(ks))] 104 | ccs = [list() for _ in range(len(ks))] 105 | apks = list() 106 | 107 | for user_id, row in enumerate(test): 108 | 109 | if not len(row.indices): 110 | continue 111 | 112 | predictions = -model.predict(user_id) 113 | if train is not None: 114 | rated = set(train[user_id].indices) 115 | else: 116 | rated = [] 117 | 118 | predictions = predictions.argsort() 119 | predictions = [p for p in predictions if p not in rated] 120 | 121 | targets = row.indices 122 | if 0 in targets: 123 | print('there is 0') 124 | 125 | for i, _k in enumerate(ks): 126 | precision, recall, ndcg, cc = _compute_precision_recall(targets, predictions, _k, cate, config.cate_num) 127 | precisions[i].append(precision) 128 | recalls[i].append(recall) 129 | ndcgs[i].append(ndcg) 130 | ccs[i].append(cc) 131 | 132 | apks.append(_compute_apk(targets, predictions, k=np.inf)) 133 | 134 | precisions = [np.array(i) for i in precisions] 135 | recalls = [np.array(i) for i in recalls] 136 | 137 | if not isinstance(k, list): 138 | precisions = precisions[0] 139 | recalls = recalls[0] 140 | 141 | return precisions, recalls, ndcgs, ccs 142 | -------------------------------------------------------------------------------- /caser.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from utils import activation_getter 6 | 7 | 8 | class Caser(nn.Module): 9 | """ 10 | Convolutional Sequence Embedding Recommendation Model (Caser)[1]. 11 | 12 | [1] Personalized Top-N Sequential Recommendation via Convolutional Sequence Embedding, Jiaxi Tang and Ke Wang , WSDM '18 13 | 14 | Parameters 15 | ---------- 16 | 17 | num_users: int, 18 | Number of users. 19 | num_items: int, 20 | Number of items. 21 | model_args: args, 22 | Model-related arguments, like latent dimensions. 
23 | """ 24 | 25 | def __init__(self, num_users, num_items, model_args): 26 | super(Caser, self).__init__() 27 | self.args = model_args 28 | 29 | # init args 30 | L = self.args.L 31 | dims = self.args.d 32 | self.n_h = self.args.nh 33 | self.n_v = self.args.nv 34 | self.drop_ratio = self.args.drop 35 | self.ac_conv = activation_getter[self.args.ac_conv] 36 | self.ac_fc = activation_getter[self.args.ac_fc] 37 | 38 | # user and item embeddings 39 | self.user_embeddings = nn.Embedding(num_users, dims) 40 | self.item_embeddings = nn.Embedding(num_items, dims) 41 | 42 | # vertical conv layer 43 | self.conv_v = nn.Conv2d(1, self.n_v, (L, 1)) 44 | 45 | # horizontal conv layer 46 | lengths = [i + 1 for i in range(L)] 47 | self.conv_h = nn.ModuleList([nn.Conv2d(1, self.n_h, (i, dims)) for i in lengths]) 48 | 49 | # fully-connected layer 50 | self.fc1_dim_v = self.n_v * dims 51 | self.fc1_dim_h = self.n_h * len(lengths) 52 | fc1_dim_in = self.fc1_dim_v + self.fc1_dim_h 53 | # W1, b1 can be encoded with nn.Linear 54 | self.fc1 = nn.Linear(fc1_dim_in, dims) 55 | # W2, b2 are encoded with nn.Embedding, as we don't need to compute scores for all items 56 | self.W2 = nn.Embedding(num_items, dims+dims) 57 | self.b2 = nn.Embedding(num_items, 1) 58 | 59 | # dropout 60 | self.dropout = nn.Dropout(self.drop_ratio) 61 | 62 | ## weight initialization 63 | self.user_embeddings.weight.data.normal_(0, 1.0 / self.user_embeddings.embedding_dim) 64 | self.item_embeddings.weight.data.normal_(0, 1.0 / self.item_embeddings.embedding_dim) 65 | self.W2.weight.data.normal_(0, 1.0 / self.W2.embedding_dim) 66 | self.b2.weight.data.zero_() 67 | 68 | self.cache_x = None 69 | 70 | def forward(self, seq_var, user_var, item_var, for_pred=False): 71 | """ 72 | The forward propagation used to get recommendation scores, given 73 | triplet (user, sequence, targets). 74 | 75 | Parameters 76 | ---------- 77 | 78 | seq_var: torch.FloatTensor with size [batch_size, max_sequence_length] 79 | a batch of sequence 80 | user_var: torch.LongTensor with size [batch_size] 81 | a batch of user 82 | item_var: torch.LongTensor with size [batch_size] 83 | a batch of items 84 | for_pred: boolean, optional 85 | Train or Prediction. Set to True when evaluation. 
86 | """ 87 | 88 | # Embedding Look-up 89 | item_embs = self.item_embeddings(seq_var).unsqueeze(1) # use unsqueeze() to get 4-D, seq embeddings 90 | user_emb = self.user_embeddings(user_var).squeeze(1) 91 | 92 | # Convolutional Layers 93 | out, out_h, out_v = None, None, None 94 | # vertical conv layer 95 | if self.n_v: 96 | out_v = self.conv_v(item_embs) 97 | out_v = out_v.view(-1, self.fc1_dim_v) # prepare for fully connect 98 | 99 | # horizontal conv layer 100 | out_hs = list() 101 | if self.n_h: 102 | for conv in self.conv_h: 103 | conv_out = self.ac_conv(conv(item_embs).squeeze(3)) 104 | pool_out = F.max_pool1d(conv_out, conv_out.size(2)).squeeze(2) 105 | out_hs.append(pool_out) 106 | out_h = torch.cat(out_hs, 1) # prepare for fully connect 107 | 108 | # Fully-connected Layers, final item embeddings 109 | out = torch.cat([out_v, out_h], 1) 110 | # apply dropout 111 | out = self.dropout(out) 112 | 113 | # fully-connected layer 114 | z = self.ac_fc(self.fc1(out)) 115 | x = torch.cat([z, user_emb], 1) #z is combined by seq item embs and user emb 116 | 117 | w2 = self.W2(item_var) 118 | b2 = self.b2(item_var) 119 | 120 | if for_pred: 121 | w2 = w2.squeeze() 122 | b2 = b2.squeeze() 123 | res = (x * w2).sum(1) + b2 124 | else: 125 | res = torch.baddbmm(b2, w2, x.unsqueeze(2)).squeeze() 126 | 127 | return res 128 | -------------------------------------------------------------------------------- /data_process/3_item_dpp_emb.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from copy import deepcopy 4 | import tensorflow.compat.v1 as tf 5 | tf.disable_v2_behavior() 6 | from random import shuffle 7 | import time 8 | from ast import literal_eval 9 | import pickle as cPickle 10 | 11 | import torch.nn.functional as F 12 | import torch 13 | #################################### 14 | # logistic dpp, used to generate diverse item embedings based on item sets 15 | # this code is mainly based the source code of "Multi-Task Determinantal Point Processes for Recommendation", thanks to the authors 16 | # generate user_num, item_num, input files (item sets with fixed length 5), and output_files, according to your data 17 | #################################### 18 | 19 | t0 = time.time() 20 | 21 | np.random.seed(0) 22 | 23 | #################################### 24 | # parameters 25 | #################################### 26 | user_num = 4641 27 | item_num = 2235 28 | emb_dim = 64 29 | set_length = 5 #k_sized length of a set 30 | lr = 1e-4 31 | decay_step = 100 32 | decay = 0.95 33 | 34 | sigmoid_lbda = 0.01 35 | epochs = 100 36 | runs = 1 37 | batch_size = 1024 38 | emb_init_mean = 0. 39 | emb_init_std = 0.01 40 | diag_init_mean = 1. 41 | diag_init_std = 0.01 42 | regu_weight = 0. 43 | 44 | ################################ 45 | # get sets from prepared sets files. 46 | # format: 47 | # each line: u;id1,id2,...,id5;id2,id3,...,id6;... 
48 | # positive sets are selected from a user's interacted items (in training dataset), each set contains 5 items 49 | # negative sets are randomly selected items that a user is not interested in 50 | ################################ 51 | def get_sets(pos_set_file, neg_set_file): 52 | 53 | upos_sets = [] 54 | with open(pos_set_file) as f: 55 | for l in f.readlines(): 56 | sstr = l.strip().split(';') 57 | u, sets = int(sstr[0]), sstr[1:] 58 | 59 | for s in sets: 60 | a_set = [] 61 | s1 = s.split(',') 62 | for id in s1: 63 | a_set.append(int(id)) 64 | if len(a_set) == set_length: 65 | upos_sets.append(a_set) 66 | 67 | uneg_sets = [] 68 | with open(neg_set_file) as f: 69 | for l in f.readlines(): 70 | sstr = l.strip().split(';') 71 | u, sets = int(sstr[0]), sstr[1:] 72 | 73 | for s in sets: 74 | a_set = [] 75 | s1 = s.split(',') 76 | for id in s1: 77 | a_set.append(int(id)) 78 | if len(a_set) == set_length: 79 | uneg_sets.append(a_set) 80 | return np.array(upos_sets), np.array(uneg_sets) 81 | 82 | ################################ 83 | # create model 84 | ################################ 85 | def set_det(item_sets): 86 | subV = tf.gather(weights['V'],item_sets) 87 | subD = tf.matrix_diag(tf.square(tf.gather(weights['D'],item_sets))) 88 | K1 = tf.matmul(subV, tf.transpose(subV,perm=[0,2,1])) 89 | K = tf.add(K1,subD) 90 | eps = tf.eye(tf.shape(K)[1],tf.shape(K)[1],[tf.shape(K)[0]]) 91 | K = tf.add(K,eps) 92 | res = tf.matrix_determinant(K) 93 | return res 94 | 95 | def logsigma(itemSet): 96 | return tf.reduce_mean(tf.log(1-tf.exp(-sigmoid_lbda*set_det(itemSet)))) 97 | 98 | def regularization(itemSet): 99 | itemsInBatch, _ = tf.unique(tf.reshape(itemSet,[-1])) 100 | subV = tf.gather(weights['V'],itemsInBatch) 101 | subD = tf.gather(weights['D'],itemsInBatch) 102 | subV_norm = tf.reduce_mean(tf.norm(subV,axis=1)) 103 | subD_norm = tf.norm(subD) 104 | return subV_norm+subD_norm 105 | 106 | ################################ 107 | # tf graph 108 | ################################ 109 | 110 | pset_input = tf.placeholder(tf.int32, [None,None]) #item sets 111 | nset_input = tf.placeholder(tf.int32, [None,None]) #item sets 112 | 113 | #get processed sets 114 | pos_sets, neg_sets = get_sets('pos_item_sets_3.txt', 'neg_item_sets_3.txt') 115 | train_size = len(pos_sets) 116 | 117 | print(pos_sets.shape, neg_sets.shape) 118 | for run in range(runs): 119 | # Construct model 120 | pset_input = tf.placeholder(tf.int32, [None,None]) #item sets 121 | nset_input = tf.placeholder(tf.int32, [None,None]) #item sets 122 | 123 | # Store layers weight & bias 124 | initializer = tf.keras.initializers.glorot_normal() 125 | weights = { 126 | 'V': tf.Variable(initializer([item_num, emb_dim]), name='item_embeddings'), 127 | 'D': tf.Variable(initializer([item_num]), name='item_bias') 128 | } 129 | # Construct model 130 | loss = logsigma(pset_input) + tf.log(1 - logsigma(nset_input)) # - regu_weight*regularization(pset_input) + regu_weight*regularization(nset_input) 131 | 132 | optimizer = tf.train.AdamOptimizer(learning_rate=lr,beta1=0.01,beta2=0.01) 133 | train_op = optimizer.minimize(-loss) 134 | 135 | # Initializing the variables 136 | init = tf.global_variables_initializer() 137 | 138 | print("start training...") 139 | with tf.Session() as sess: 140 | sess.run(init) 141 | # Training cycle 142 | for epoch in range(epochs): 143 | ave_cost = 0. 
144 | nbatch = 0 145 | while True: 146 | if nbatch*batch_size <= train_size: 147 | pos_batch = pos_sets[nbatch*batch_size: (nbatch+1)*batch_size] 148 | neg_batch = neg_sets[nbatch*batch_size: (nbatch+1)*batch_size] 149 | else: 150 | if train_size - (nbatch-1)*batch_size > 0: 151 | pos_batch = pos_sets[(nbatch-1)*batch_size: train_size] 152 | neg_batch = neg_sets[(nbatch-1)*batch_size: train_size] 153 | break 154 | nbatch += 1 155 | 156 | _, c = sess.run([train_op, loss], feed_dict={pset_input: pos_batch, nset_input: neg_batch}) 157 | ave_cost += c / nbatch 158 | 159 | param = sess.run(weights) 160 | cPickle.dump(param, open('item_kernel_3.pkl', 'wb')) #T=3 161 | -------------------------------------------------------------------------------- /interactions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Classes describing datasets of user-item interactions. Instances of these 3 | are returned by dataset-fetching and dataset-processing functions. 4 | """ 5 | 6 | import numpy as np 7 | 8 | import scipy.sparse as sp 9 | 10 | 11 | class Interactions(object): 12 | """ 13 | Interactions object. Contains (at a minimum) pair of user-item 14 | interactions. This is designed only for implicit feedback scenarios. 15 | 16 | Parameters 17 | ---------- 18 | 19 | file_path: file contains (user,item,rating) triplets 20 | user_map: dict of user mapping 21 | item_map: dict of item mapping 22 | """ 23 | 24 | def __init__(self, file_path, 25 | user_map=None, 26 | item_map=None): 27 | 28 | if not user_map and not item_map: 29 | user_map = dict() 30 | item_map = dict() 31 | 32 | num_user = 0 33 | num_item = 0 34 | else: 35 | num_user = len(user_map) 36 | num_item = len(item_map) + 1 37 | 38 | user_ids = list() 39 | item_ids = list() 40 | # read users and items from file 41 | with open(file_path, 'r') as fin: 42 | for line in fin: 43 | ids = line.strip().split() 44 | u = ids[0] 45 | iids = ids[1:] 46 | for i in iids: 47 | user_ids.append(u) 48 | item_ids.append(i) 49 | 50 | # update user and item mapping 51 | for u in user_ids: 52 | if u not in user_map: 53 | user_map[u] = num_user 54 | num_user += 1 55 | for i in item_ids: 56 | if i not in item_map: 57 | item_map[i] = num_item 58 | num_item += 1 59 | 60 | user_ids = np.array([user_map[u] for u in user_ids]) 61 | item_ids = np.array([item_map[i] for i in item_ids]) 62 | 63 | self.num_users = num_user 64 | self.num_items = num_item 65 | 66 | self.user_ids = user_ids 67 | self.item_ids = item_ids 68 | 69 | self.user_map = user_map 70 | self.item_map = item_map 71 | 72 | self.sequences = None 73 | self.test_sequences = None 74 | 75 | def __len__(self): 76 | 77 | return len(self.user_ids) 78 | 79 | def tocoo(self): 80 | """ 81 | Transform to a scipy.sparse COO matrix. 82 | """ 83 | 84 | row = self.user_ids 85 | col = self.item_ids 86 | data = np.ones(len(self)) 87 | 88 | return sp.coo_matrix((data, (row, col)), 89 | shape=(self.num_users, self.num_items)) 90 | 91 | def tocsr(self): 92 | """ 93 | Transform to a scipy.sparse CSR matrix. 94 | """ 95 | 96 | return self.tocoo().tocsr() 97 | 98 | def to_sequence(self, sequence_length=5, target_length=1): 99 | """ 100 | Transform to sequence form. 101 | 102 | Valid subsequences of users' interactions are returned. 
For 103 | example, if a user interacted with items [1, 2, 3, 4, 5, 6, 7, 8, 9], the 104 | returned interactions matrix at sequence length 5 and target length 2 105 | will be given by: 106 | 107 | sequences: 108 | 109 | [[1, 2, 3, 4, 5], 110 | [2, 3, 4, 5, 6], 111 | [3, 4, 5, 6, 7]] 112 | 113 | targets: 114 | 115 | [[6, 7], 116 | [7, 8], 117 | [8, 9]] 118 | 119 | sequence for test (the last 'sequence_length' items of each user's sequence): 120 | 121 | [[5, 6, 7, 8, 9]] 122 | 123 | Parameters 124 | ---------- 125 | 126 | sequence_length: int 127 | Sequence length. Subsequences shorter than this 128 | will be left-padded with zeros. 129 | target_length: int 130 | Sequence target length. 131 | """ 132 | 133 | # shift item indices to start from 1, as 0 is used for padding in sequences 134 | for k, v in self.item_map.items(): 135 | self.item_map[k] = v + 1 136 | self.item_ids = self.item_ids + 1 137 | self.num_items += 1 138 | 139 | max_sequence_length = sequence_length + target_length 140 | 141 | # Sort first by user id 142 | sort_indices = np.lexsort((self.user_ids,)) 143 | 144 | user_ids = self.user_ids[sort_indices] 145 | item_ids = self.item_ids[sort_indices] 146 | 147 | user_ids, indices, counts = np.unique(user_ids, 148 | return_index=True, 149 | return_counts=True) 150 | 151 | num_subsequences = sum([c - max_sequence_length + 1 if c >= max_sequence_length else 1 for c in counts]) 152 | 153 | sequences = np.zeros((num_subsequences, sequence_length), 154 | dtype=np.int64) 155 | sequences_targets = np.zeros((num_subsequences, target_length), 156 | dtype=np.int64) 157 | sequence_users = np.empty(num_subsequences, 158 | dtype=np.int64) 159 | 160 | test_sequences = np.zeros((self.num_users, sequence_length), 161 | dtype=np.int64) 162 | test_users = np.empty(self.num_users, 163 | dtype=np.int64) 164 | 165 | _uid = None 166 | for i, (uid, 167 | item_seq) in enumerate(_generate_sequences(user_ids, 168 | item_ids, 169 | indices, 170 | max_sequence_length)): 171 | if uid != _uid: 172 | test_sequences[uid][:] = item_seq[-sequence_length:] # the last `sequence_length` items form the test sequence 173 | test_users[uid] = uid 174 | _uid = uid 175 | sequences_targets[i][:] = item_seq[-target_length:] 176 | sequences[i][:] = item_seq[:sequence_length] 177 | sequence_users[i] = uid 178 | self.sequences = SequenceInteractions(sequence_users, sequences, sequences_targets) 179 | self.test_sequences = SequenceInteractions(test_users, test_sequences) 180 | 181 | 182 | class SequenceInteractions(object): 183 | """ 184 | Interactions encoded as a sequence matrix. 185 | 186 | Parameters 187 | ---------- 188 | user_ids: np.array 189 | sequence users 190 | sequences: np.array 191 | The interactions sequence matrix, as produced by 192 | :func:`~Interactions.to_sequence` 193 | targets: np.array 194 | sequence targets 195 | """ 196 | 197 | def __init__(self, 198 | user_ids, 199 | sequences, 200 | targets=None): 201 | self.user_ids = user_ids 202 | self.sequences = sequences 203 | self.targets = targets 204 | 205 | self.L = sequences.shape[1] 206 | self.T = None 207 | if np.any(targets): 208 | self.T = targets.shape[1] 209 | 210 | 211 | def _sliding_window(tensor, window_size, step_size=1): 212 | if len(tensor) - window_size >= 0: 213 | for i in range(len(tensor), 0, -step_size): 214 | if i - window_size >= 0: 215 | yield tensor[i - window_size:i] 216 | else: 217 | break 218 | else: 219 | num_paddings = window_size - len(tensor) 220 | # Pad sequence with 0s if it is shorter than window size. 
221 | yield np.pad(tensor, (num_paddings, 0), 'constant') 222 | 223 | 224 | def _generate_sequences(user_ids, item_ids, 225 | indices, 226 | max_sequence_length): 227 | for i in range(len(indices)): 228 | 229 | start_idx = indices[i] 230 | 231 | if i >= len(indices) - 1: 232 | stop_idx = None 233 | else: 234 | stop_idx = indices[i + 1] 235 | 236 | for seq in _sliding_window(item_ids[start_idx:stop_idx], 237 | max_sequence_length): 238 | yield (user_ids[i], seq) 239 | 240 | 241 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU LESSER GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | 9 | This version of the GNU Lesser General Public License incorporates 10 | the terms and conditions of version 3 of the GNU General Public 11 | License, supplemented by the additional permissions listed below. 12 | 13 | 0. Additional Definitions. 14 | 15 | As used herein, "this License" refers to version 3 of the GNU Lesser 16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU 17 | General Public License. 18 | 19 | "The Library" refers to a covered work governed by this License, 20 | other than an Application or a Combined Work as defined below. 21 | 22 | An "Application" is any work that makes use of an interface provided 23 | by the Library, but which is not otherwise based on the Library. 24 | Defining a subclass of a class defined by the Library is deemed a mode 25 | of using an interface provided by the Library. 26 | 27 | A "Combined Work" is a work produced by combining or linking an 28 | Application with the Library. The particular version of the Library 29 | with which the Combined Work was made is also called the "Linked 30 | Version". 31 | 32 | The "Minimal Corresponding Source" for a Combined Work means the 33 | Corresponding Source for the Combined Work, excluding any source code 34 | for portions of the Combined Work that, considered in isolation, are 35 | based on the Application, and not on the Linked Version. 36 | 37 | The "Corresponding Application Code" for a Combined Work means the 38 | object code and/or source code for the Application, including any data 39 | and utility programs needed for reproducing the Combined Work from the 40 | Application, but excluding the System Libraries of the Combined Work. 41 | 42 | 1. Exception to Section 3 of the GNU GPL. 43 | 44 | You may convey a covered work under sections 3 and 4 of this License 45 | without being bound by section 3 of the GNU GPL. 46 | 47 | 2. Conveying Modified Versions. 48 | 49 | If you modify a copy of the Library, and, in your modifications, a 50 | facility refers to a function or data to be supplied by an Application 51 | that uses the facility (other than as an argument passed when the 52 | facility is invoked), then you may convey a copy of the modified 53 | version: 54 | 55 | a) under this License, provided that you make a good faith effort to 56 | ensure that, in the event an Application does not supply the 57 | function or data, the facility still operates, and performs 58 | whatever part of its purpose remains meaningful, or 59 | 60 | b) under the GNU GPL, with none of the additional permissions of 61 | this License applicable to that copy. 62 | 63 | 3. 
Object Code Incorporating Material from Library Header Files. 64 | 65 | The object code form of an Application may incorporate material from 66 | a header file that is part of the Library. You may convey such object 67 | code under terms of your choice, provided that, if the incorporated 68 | material is not limited to numerical parameters, data structure 69 | layouts and accessors, or small macros, inline functions and templates 70 | (ten or fewer lines in length), you do both of the following: 71 | 72 | a) Give prominent notice with each copy of the object code that the 73 | Library is used in it and that the Library and its use are 74 | covered by this License. 75 | 76 | b) Accompany the object code with a copy of the GNU GPL and this license 77 | document. 78 | 79 | 4. Combined Works. 80 | 81 | You may convey a Combined Work under terms of your choice that, 82 | taken together, effectively do not restrict modification of the 83 | portions of the Library contained in the Combined Work and reverse 84 | engineering for debugging such modifications, if you also do each of 85 | the following: 86 | 87 | a) Give prominent notice with each copy of the Combined Work that 88 | the Library is used in it and that the Library and its use are 89 | covered by this License. 90 | 91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license 92 | document. 93 | 94 | c) For a Combined Work that displays copyright notices during 95 | execution, include the copyright notice for the Library among 96 | these notices, as well as a reference directing the user to the 97 | copies of the GNU GPL and this license document. 98 | 99 | d) Do one of the following: 100 | 101 | 0) Convey the Minimal Corresponding Source under the terms of this 102 | License, and the Corresponding Application Code in a form 103 | suitable for, and under terms that permit, the user to 104 | recombine or relink the Application with a modified version of 105 | the Linked Version to produce a modified Combined Work, in the 106 | manner specified by section 6 of the GNU GPL for conveying 107 | Corresponding Source. 108 | 109 | 1) Use a suitable shared library mechanism for linking with the 110 | Library. A suitable mechanism is one that (a) uses at run time 111 | a copy of the Library already present on the user's computer 112 | system, and (b) will operate properly with a modified version 113 | of the Library that is interface-compatible with the Linked 114 | Version. 115 | 116 | e) Provide Installation Information, but only if you would otherwise 117 | be required to provide such information under section 6 of the 118 | GNU GPL, and only to the extent that such information is 119 | necessary to install and execute a modified version of the 120 | Combined Work produced by recombining or relinking the 121 | Application with a modified version of the Linked Version. (If 122 | you use option 4d0, the Installation Information must accompany 123 | the Minimal Corresponding Source and Corresponding Application 124 | Code. If you use option 4d1, you must provide the Installation 125 | Information in the manner specified by section 6 of the GNU GPL 126 | for conveying Corresponding Source.) 127 | 128 | 5. Combined Libraries. 
129 | 130 | You may place library facilities that are a work based on the 131 | Library side by side in a single library together with other library 132 | facilities that are not Applications and are not covered by this 133 | License, and convey such a combined library under terms of your 134 | choice, if you do both of the following: 135 | 136 | a) Accompany the combined library with a copy of the same work based 137 | on the Library, uncombined with any other library facilities, 138 | conveyed under the terms of this License. 139 | 140 | b) Give prominent notice with the combined library that part of it 141 | is a work based on the Library, and explaining where to find the 142 | accompanying uncombined form of the same work. 143 | 144 | 6. Revised Versions of the GNU Lesser General Public License. 145 | 146 | The Free Software Foundation may publish revised and/or new versions 147 | of the GNU Lesser General Public License from time to time. Such new 148 | versions will be similar in spirit to the present version, but may 149 | differ in detail to address new problems or concerns. 150 | 151 | Each version is given a distinguishing version number. If the 152 | Library as you received it specifies that a certain numbered version 153 | of the GNU Lesser General Public License "or any later version" 154 | applies to it, you have the option of following the terms and 155 | conditions either of that published version or of any later version 156 | published by the Free Software Foundation. If the Library as you 157 | received it does not specify a version number of the GNU Lesser 158 | General Public License, you may choose any version of the GNU Lesser 159 | General Public License ever published by the Free Software Foundation. 160 | 161 | If the Library as you received it specifies that a proxy can decide 162 | whether future versions of the GNU Lesser General Public License shall 163 | apply, that proxy's public statement of acceptance of any version is 164 | permanent authorization for you to choose that version for the 165 | Library. 166 | -------------------------------------------------------------------------------- /train_caser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from time import time 3 | 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | import pickle as cPickle 8 | 9 | from caser import Caser 10 | from evaluation import evaluate_ranking 11 | from interactions import Interactions 12 | from utils import * 13 | 14 | 15 | class Recommender(object): 16 | """ 17 | Contains attributes and methods that needed to train a sequential 18 | recommendation model. Models are trained by many tuples of 19 | (users, sequences, targets, negatives) and negatives are from negative 20 | sampling: for any known tuple of (user, sequence, targets), one or more 21 | items are randomly sampled to act as negatives. 22 | 23 | 24 | Parameters 25 | ---------- 26 | 27 | n_iter: int, 28 | Number of iterations to run. 29 | batch_size: int, 30 | Minibatch size. 31 | l2: float, 32 | L2 loss penalty, also known as the 'lambda' of l2 regularization. 33 | neg_samples: int, 34 | Number of negative samples to generate for each targets. 35 | If targets=3 and neg_samples=3, then it will sample 9 negatives. 36 | learning_rate: float, 37 | Initial learning rate. 38 | use_cuda: boolean, 39 | Run the model on a GPU or CPU. 40 | model_args: args, 41 | Model-related arguments, like latent dimensions. 
42 | """ 43 | 44 | def __init__(self, 45 | n_iter=None, 46 | batch_size=None, 47 | l2=None, 48 | neg_samples=None, 49 | learning_rate=None, 50 | use_cuda=False, 51 | model_args=None): 52 | 53 | # model related 54 | self._num_items = None 55 | self._num_users = None 56 | self._net = None 57 | self.model_args = model_args 58 | 59 | # learning related 60 | self._batch_size = batch_size 61 | self._n_iter = n_iter 62 | self._learning_rate = learning_rate 63 | self._l2 = l2 64 | self._neg_samples = neg_samples 65 | self._device = torch.device("cuda" if use_cuda else "cpu") 66 | 67 | # rank evaluation related 68 | self.test_sequence = None 69 | self._candidate = dict() 70 | 71 | @property 72 | def _initialized(self): 73 | return self._net is not None 74 | 75 | def _initialize(self, interactions): 76 | self._num_items = interactions.num_items 77 | self._num_users = interactions.num_users 78 | 79 | self.test_sequence = interactions.test_sequences 80 | 81 | self._net = Caser(self._num_users, 82 | self._num_items, 83 | self.model_args).to(self._device) 84 | 85 | self._optimizer = optim.Adam(self._net.parameters(), 86 | weight_decay=self._l2, 87 | lr=self._learning_rate) 88 | 89 | def fit(self, train, test, cate, config, verbose=False): 90 | """ 91 | The general training loop to fit the model 92 | 93 | Parameters 94 | ---------- 95 | 96 | train: :class:`spotlight.interactions.Interactions` 97 | training instances, also contains test sequences 98 | test: :class:`spotlight.interactions.Interactions` 99 | only contains targets for test sequences 100 | verbose: bool, optional 101 | print the logs 102 | """ 103 | ################################## 104 | # read pre-learned kernel 105 | ################################### 106 | lk_param = cPickle.load(open(config.l_kernel_emb, 'rb'), encoding="latin1") 107 | lk_tensor = torch.FloatTensor(lk_param['V']).to(self._device) 108 | 109 | lk_emb_i = F.normalize(lk_tensor, p=2, dim=1) 110 | l_kernel = torch.matmul(lk_emb_i, lk_emb_i.t()) 111 | 112 | #l_kernel = torch.sigmoid(l_kernel) #this line is optional; use this line, if encounter non-invertible or nan problem 113 | 114 | #l_kernel_un = torch.matmul(lk_tensor, lk_tensor.t()) ##un-normalized pre-learned kernel 115 | 116 | # convert to sequences, targets and users 117 | sequences_np = train.sequences.sequences 118 | targets_np = train.sequences.targets 119 | users_np = train.sequences.user_ids.reshape(-1, 1) 120 | 121 | L, T = train.sequences.L, train.sequences.T 122 | 123 | n_train = sequences_np.shape[0] 124 | 125 | output_str = 'total training instances: %d' % n_train 126 | print(output_str) 127 | 128 | if not self._initialized: 129 | self._initialize(train) 130 | 131 | start_epoch = 0 132 | pre_list = [] 133 | for epoch_num in range(start_epoch, self._n_iter): 134 | 135 | t1 = time() 136 | 137 | # set model to training mode 138 | self._net.train() 139 | 140 | users_np, sequences_np, targets_np = shuffle(users_np, 141 | sequences_np, 142 | targets_np) 143 | 144 | negatives_np = self._generate_negative_samples(users_np, train, n=self._neg_samples) 145 | 146 | # convert numpy arrays to PyTorch tensors and move it to the corresponding devices 147 | users, sequences, targets, negatives = (torch.from_numpy(users_np).long(), 148 | torch.from_numpy(sequences_np).long(), 149 | torch.from_numpy(targets_np).long(), 150 | torch.from_numpy(negatives_np).long()) 151 | 152 | users, sequences, targets, negatives = (users.to(self._device), 153 | sequences.to(self._device), 154 | targets.to(self._device), 155 | 
negatives.to(self._device)) 156 | 157 | epoch_loss = 0.0 158 | 159 | for (minibatch_num, 160 | (batch_users, 161 | batch_sequences, 162 | batch_targets, 163 | batch_negatives)) in enumerate(minibatch(users, 164 | sequences, 165 | targets, 166 | negatives, 167 | batch_size=self._batch_size)): 168 | items_to_predict = torch.cat((batch_targets, batch_negatives, batch_sequences), 1) 169 | items_prediction = self._net(batch_sequences, 170 | batch_users, 171 | items_to_predict) 172 | 173 | (targets_prediction, negatives_prediction, 174 | seq_prediction) = torch.split(items_prediction, 175 | [batch_targets.size(1), 176 | batch_negatives.size(1), 177 | batch_sequences.size(1)], dim=1) 178 | 179 | self._optimizer.zero_grad() 180 | 181 | if config.dpp_loss == 0: 182 | # compute the binary cross-entropy loss 183 | positive_loss = -torch.mean( 184 | torch.log(torch.sigmoid(targets_prediction))) 185 | negative_loss = -torch.mean( 186 | torch.log(1 - torch.sigmoid(negatives_prediction))) 187 | loss = positive_loss + negative_loss 188 | 189 | ############################################### 190 | # compute the dpp set loss 191 | ############################################### 192 | # DSL 193 | elif config.dpp_loss == 1: 194 | dpp_lhs = [] 195 | size = targets_prediction.shape[0] 196 | batch_sets = torch.cat((batch_targets, batch_negatives), 1) 197 | batch_predictions = torch.cat((targets_prediction, negatives_prediction), 1) 198 | #minibatch format 199 | if config.batch_format == 1: 200 | batch_pos_kernel = torch.zeros(size, config.T, config.T).cuda() 201 | batch_set_kernel = torch.zeros(size, config.T + config.neg_samples, config.T + config.neg_samples).cuda() 202 | 203 | for n in range(size): 204 | batch_pos_kernel[n] = l_kernel[batch_targets[n]-1][:, batch_targets[n]-1] 205 | batch_set_kernel[n] = l_kernel[batch_sets[n]-1][:, batch_sets[n]-1] 206 | 207 | batch_pos_q = torch.diag_embed(torch.exp(targets_prediction)) #can also try sigmoid in some cases 208 | batch_set_q = torch.diag_embed(torch.exp(batch_predictions)) 209 | 210 | batch_pos_kernel = torch.bmm(torch.bmm(batch_pos_q, batch_pos_kernel), batch_pos_q) 211 | batch_set_kernel = torch.bmm(torch.bmm(batch_set_q, batch_set_kernel), batch_set_q) 212 | 213 | p_diag = torch.eye(config.T)*1e-5 214 | pa_diag = p_diag.reshape((1, config.T, config.T)) 215 | pbatch_diag = pa_diag.repeat(size, 1, 1) 216 | 217 | s_diag = torch.eye(config.T+config.neg_samples) 218 | sa_diag = s_diag.reshape((1, config.T + config.neg_samples, config.T + config.neg_samples)) 219 | sbatch_diag = sa_diag.repeat(size, 1, 1) 220 | 221 | batch_pos_det = torch.det(batch_pos_kernel.cpu() + pbatch_diag).cuda() 222 | batch_set_det = torch.det(batch_set_kernel.cpu() + sbatch_diag).cuda() 223 | 224 | dpp_loss = torch.log(batch_pos_det/batch_set_det) 225 | loss = -torch.mean(dpp_loss) 226 | else: 227 | for n in range(size): 228 | pos_q = torch.diag_embed(torch.exp(targets_prediction[n])) 229 | set_q = torch.diag_embed(torch.exp(batch_predictions[n])) 230 | 231 | pos_l_kernel = l_kernel[batch_targets[n]-1][:, batch_targets[n]-1] 232 | set_l_kernel = l_kernel[batch_sets[n]-1][:, batch_sets[n]-1] 233 | 234 | pos_k = torch.mm(torch.mm(pos_q, pos_l_kernel), pos_q) 235 | set_k = torch.mm(torch.mm(set_q, set_l_kernel), set_q) 236 | 237 | pos_det = torch.det(pos_k.cpu() + torch.eye(len(batch_targets[n]))*1e-5).cuda() 238 | set_det = torch.det(set_k.cpu() + torch.eye(len(batch_sets[n]))).cuda() 239 | 240 | dpp_loss = torch.log(pos_det/set_det) 241 | 242 | dpp_lhs.append(dpp_loss) 243 | loss = 
-torch.mean(torch.stack(dpp_lhs)) 244 | # CDSL 245 | elif config.dpp_loss == 2: 246 | dpp_lhs = [] 247 | size = targets_prediction.shape[0] 248 | set_items = torch.cat((batch_sequences, batch_targets, batch_negatives), 1) 249 | set_predictions = torch.cat((seq_prediction, targets_prediction, negatives_prediction), 1) 250 | 251 | pos_items = torch.cat((batch_sequences, batch_targets), 1) 252 | pos_predictions = torch.cat((seq_prediction, targets_prediction), 1) #L+T 253 | if config.batch_format == 1: 254 | batch_pos_kernel = torch.zeros(size, config.L + config.T, config.L + config.T).cuda() 255 | batch_set_kernel = torch.zeros(size, config.L + config.T + config.neg_samples, config.L + config.T + config.neg_samples).cuda() 256 | 257 | for n in range(size): 258 | batch_pos_kernel[n] = l_kernel[pos_items[n]-1][:, pos_items[n]-1] 259 | batch_set_kernel[n] = l_kernel[set_items[n]-1][:, set_items[n]-1] 260 | 261 | batch_pos_q = torch.diag_embed(torch.exp(pos_predictions)) 262 | batch_set_q = torch.diag_embed(torch.exp(set_predictions)) 263 | 264 | batch_pos_kernel = torch.bmm(torch.bmm(batch_pos_q, batch_pos_kernel), batch_pos_q) 265 | batch_set_kernel = torch.bmm(torch.bmm(batch_set_q, batch_set_kernel), batch_set_q) 266 | 267 | p_diag = torch.eye(config.L + config.T)*1e-3 268 | pa_diag = p_diag.reshape((1, config.L + config.T, config.L + config.T)) 269 | pbatch_diag = pa_diag.repeat(size, 1, 1) 270 | 271 | s_diag = torch.diag_embed(torch.FloatTensor([1e-3]*config.L+[1]*(config.neg_samples+config.T))) 272 | sa_diag = s_diag.reshape((1, config.L + config.T + config.neg_samples, config.L + config.T + config.neg_samples)) 273 | sbatch_diag = sa_diag.repeat(size, 1, 1) 274 | 275 | batch_pos_det = torch.det(batch_pos_kernel.cpu() + pbatch_diag).cuda() 276 | batch_set_det = torch.det(batch_set_kernel.cpu() + sbatch_diag).cuda() 277 | 278 | dpp_loss = torch.log(batch_pos_det/batch_set_det) 279 | loss = -torch.mean(dpp_loss) 280 | else: 281 | diag_I = torch.diag_embed(torch.FloatTensor([1e-3]*config.L+[1]*(config.neg_samples+config.T))) 282 | diag_posI = torch.diag_embed(torch.FloatTensor([1e-3]*(config.L+config.T))) 283 | for n in range(size): 284 | pos_q = torch.diag_embed(torch.exp(pos_predictions[n])) 285 | set_q = torch.diag_embed(torch.exp(set_predictions[n])) 286 | 287 | pos_l_kernel = l_kernel[pos_items[n]-1][:, pos_items[n]-1] 288 | set_l_kernel = l_kernel[set_items[n]-1][:, set_items[n]-1] 289 | 290 | pos_k = torch.mm(torch.mm(pos_q, pos_l_kernel), pos_q) 291 | set_k = torch.mm(torch.mm(set_q, set_l_kernel), set_q) 292 | 293 | pos_det = torch.det(pos_k.cpu() + diag_posI).cuda() 294 | set_det = torch.det(set_k.cpu() + diag_I).cuda() 295 | 296 | dpp_loss = torch.log(pos_det/set_det) 297 | dpp_lhs.append(dpp_loss) 298 | loss = -torch.mean(torch.stack(dpp_lhs)) 299 | 300 | epoch_loss += loss.item() 301 | 302 | loss.backward() 303 | self._optimizer.step() 304 | 305 | epoch_loss /= minibatch_num + 1 306 | 307 | t2 = time() 308 | if verbose: 309 | if (epoch_num+1) % 10 == 0: 310 | precision, recall, ndcg, cc = evaluate_ranking(self, test, config, l_kernel, cate, train, k=[3, 5, 10]) 311 | output_str = "Epoch %d [%.1f s], loss=%.4f, " \ 312 | "prec@3=%.4f, *prec@5=%.4f, prec@10=%.4f, " \ 313 | "recall@3=%.4f, recall@5=%.4f, recall@10=%.4f, " \ 314 | "ndcg@3=%.4f, ndcg@5=%.4f, ndcg@10=%.4f, " \ 315 | "*cc@3=%.4f, cc@5=%.4f, cc@10=%.4f, [%.1f s]" % (epoch_num + 1, 316 | t2 - t1, 317 | epoch_loss, 318 | np.mean(precision[0]), 319 | np.mean(precision[1]), 320 | np.mean(precision[2]), 321 | 
np.mean(recall[0]), 322 | np.mean(recall[1]), 323 | np.mean(recall[2]), 324 | np.mean(ndcg[0]), 325 | np.mean(ndcg[1]), 326 | np.mean(ndcg[2]), 327 | np.mean(cc[0]), 328 | np.mean(cc[1]), 329 | np.mean(cc[2]), 330 | time() - t2) 331 | 332 | print(output_str) 333 | else: 334 | output_str = "Epoch %d [%.1f s]\tloss=%.4f [%.1f s]" % (epoch_num + 1, 335 | t2 - t1, 336 | epoch_loss, 337 | time() - t2) 338 | print(output_str) 339 | 340 | def _generate_negative_samples(self, users, interactions, n): 341 | 342 | """ 343 | Sample negative from a candidate set of each user. The 344 | candidate set of each user is defined by: 345 | {All Items} \ {Items Rated by User} 346 | 347 | Parameters 348 | ---------- 349 | 350 | users: array of np.int64 351 | sequence users 352 | interactions: :class:`spotlight.interactions.Interactions` 353 | training instances, used for generate candidates 354 | n: int 355 | total number of negatives to sample for each sequence 356 | """ 357 | 358 | users_ = users.squeeze() 359 | negative_samples = np.zeros((users_.shape[0], n), np.int64) 360 | if not self._candidate: 361 | all_items = np.arange(interactions.num_items - 1) + 1 # 0 for padding 362 | train = interactions.tocsr() 363 | for user, row in enumerate(train): 364 | self._candidate[user] = list(set(all_items) - set(row.indices)) 365 | 366 | for i, u in enumerate(users_): 367 | for j in range(n): 368 | x = self._candidate[u] 369 | negative_samples[i, j] = x[ 370 | np.random.randint(len(x))] 371 | 372 | return negative_samples 373 | 374 | def predict(self, user_id, item_ids=None): 375 | """ 376 | Make predictions for evaluation: given a user id, it will 377 | first retrieve the test sequence associated with that user 378 | and compute the recommendation scores for items. 379 | 380 | Parameters 381 | ---------- 382 | 383 | user_id: int 384 | users id for which prediction scores needed. 385 | item_ids: array, optional 386 | Array containing the item ids for which prediction scores 387 | are desired. If not supplied, predictions for all items 388 | will be computed. 
389 | """ 390 | 391 | if self.test_sequence is None: 392 | raise ValueError('Missing test sequences, cannot make predictions') 393 | 394 | # set model to evaluation model 395 | self._net.eval() 396 | with torch.no_grad(): 397 | sequences_np = self.test_sequence.sequences[user_id, :] 398 | sequences_np = np.atleast_2d(sequences_np) 399 | 400 | if item_ids is None: 401 | item_ids = np.arange(self._num_items).reshape(-1, 1) 402 | 403 | sequences = torch.from_numpy(sequences_np).long() 404 | item_ids = torch.from_numpy(item_ids).long() 405 | user_id = torch.from_numpy(np.array([[user_id]])).long() 406 | 407 | user, sequences, items = (user_id.to(self._device), 408 | sequences.to(self._device), 409 | item_ids.to(self._device)) 410 | 411 | out = self._net(sequences, 412 | user, 413 | items, 414 | for_pred=True) 415 | 416 | return out.cpu().numpy().flatten() 417 | 418 | def sigma(self, x): 419 | res = 1 - torch.exp(-model_config.sigma_alpha*x) 420 | return res 421 | 422 | def get_cates_map(cate_file): 423 | iidcate_map = {} #iid:cates 424 | ## movie_id:cate_ids, cate_ids is not only one 425 | with open(cate_file) as f_cate: 426 | for l in f_cate.readlines(): 427 | if len(l) == 0: break 428 | l = l.strip('\n') 429 | items = [int(i) for i in l.split(' ')] 430 | iid, cate_ids = items[0], items[1:] 431 | iidcate_map[iid] = cate_ids 432 | return iidcate_map 433 | 434 | if __name__ == '__main__': 435 | parser = argparse.ArgumentParser() 436 | # data arguments 437 | parser.add_argument('--train_root', type=str, default='datasets/beauty/train_3.txt') 438 | parser.add_argument('--test_root', type=str, default='datasets/beauty/test_3.txt') 439 | parser.add_argument('--cateid_root', type=str, default='datasets/beauty/cate.txt') 440 | parser.add_argument('--l_kernel_emb', type=str, default='datasets/beauty/item_kernel_3.pkl') 441 | parser.add_argument('--cate_num', type=int, default=213) 442 | parser.add_argument('--L', type=int, default=5) 443 | parser.add_argument('--T', type=int, default=3, help="consistent with the postfix of dataset") 444 | # dpp arguments 445 | parser.add_argument('--neg_samples', type=int, default=3, help="Z") 446 | parser.add_argument('--dpp_loss', type=int, default=2, help="0:cross-entropy, 1:DSL, 2:CDSL") 447 | parser.add_argument('--batch_format', type=int, default=1, help="use minibatch format for dpp loss or not") 448 | # train arguments 449 | parser.add_argument('--n_iter', type=int, default=100) 450 | parser.add_argument('--seed', type=int, default=1234) 451 | parser.add_argument('--batch_size', type=int, default=512) 452 | parser.add_argument('--learning_rate', type=float, default=0.001, help="[0.0005 0.001 0.0015], default 0.001") 453 | parser.add_argument('--l2', type=float, default=1e-4) 454 | parser.add_argument('--use_cuda', type=str2bool, default=True) 455 | 456 | config = parser.parse_args() 457 | 458 | # model dependent arguments 459 | model_parser = argparse.ArgumentParser() 460 | model_parser.add_argument('--d', type=int, default=50) 461 | model_parser.add_argument('--nv', type=int, default=4) 462 | model_parser.add_argument('--nh', type=int, default=16) 463 | model_parser.add_argument('--drop', type=float, default=0.5) 464 | model_parser.add_argument('--ac_conv', type=str, default='relu') 465 | model_parser.add_argument('--ac_fc', type=str, default='relu') 466 | model_parser.add_argument('--sigma_alpha', type=float, default=0.01) 467 | 468 | model_config = model_parser.parse_args() 469 | model_config.L = config.L 470 | 471 | # set seed 472 | set_seed(config.seed, 
473 | cuda=config.use_cuda) 474 | 475 | # load dataset 476 | train = Interactions(config.train_root) 477 | # transform triplets to sequence representation 478 | train.to_sequence(config.L, config.T) 479 | 480 | test = Interactions(config.test_root, 481 | user_map=train.user_map, 482 | item_map=train.item_map) 483 | 484 | cate = get_cates_map(config.cateid_root) 485 | 486 | print(config) 487 | print(model_config) 488 | # fit model 489 | model = Recommender(n_iter=config.n_iter, 490 | batch_size=config.batch_size, 491 | learning_rate=config.learning_rate, 492 | l2=config.l2, 493 | neg_samples=config.neg_samples, 494 | model_args=model_config, 495 | use_cuda=config.use_cuda) 496 | 497 | model.fit(train, test, cate, config, verbose=True) 498 | --------------------------------------------------------------------------------
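Editor's note on the DPP losses in `train_caser.py`: the batched DSL branch (`--dpp_loss 1 --batch_format 1`) gathers a per-example sub-kernel of the pre-learned item L-kernel, scales it by the predicted item qualities, and maximizes the log-ratio of two determinants. Below is a minimal, self-contained sketch of that computation for reference only; the function and argument names (`dpp_set_loss`, `target_scores`, ...) are illustrative and not part of the repository.

```python
import torch

def dpp_set_loss(l_kernel, targets, negatives, target_scores, negative_scores, eps=1e-5):
    """Batched DPP set likelihood (DSL) sketch.

    l_kernel: [num_items, num_items] pre-learned item L-kernel.
    targets / negatives: LongTensors [B, T] / [B, Z] of 1-indexed item ids.
    target_scores / negative_scores: predicted qualities with matching shapes.
    """
    sets = torch.cat((targets, negatives), dim=1)                 # [B, T+Z]
    scores = torch.cat((target_scores, negative_scores), dim=1)

    # Per-example sub-kernels L_S = L[S][:, S]; ids are shifted by 1 because 0 is the padding id.
    pos_idx, set_idx = targets - 1, sets - 1
    pos_L = l_kernel[pos_idx.unsqueeze(2), pos_idx.unsqueeze(1)]  # [B, T, T]
    set_L = l_kernel[set_idx.unsqueeze(2), set_idx.unsqueeze(1)]  # [B, T+Z, T+Z]

    # Quality-weighted kernels K = diag(q) L diag(q), with q = exp(score).
    pos_q = torch.diag_embed(torch.exp(target_scores))
    set_q = torch.diag_embed(torch.exp(scores))
    pos_K = pos_q @ pos_L @ pos_q
    set_K = set_q @ set_L @ set_q

    # -mean(log det(K_pos + eps*I) - log det(K_set + I)), as in the batch_format=1 branch:
    # a small jitter stabilizes the numerator, the full identity normalizes the denominator.
    eye_pos = eps * torch.eye(targets.size(1), device=pos_K.device)
    eye_set = torch.eye(sets.size(1), device=set_K.device)
    return -torch.mean(torch.log(torch.det(pos_K + eye_pos) / torch.det(set_K + eye_set)))
```

The CDSL branch (`--dpp_loss 2`) follows the same pattern, with the L input-sequence items prepended to both the numerator and denominator sets and a 1e-3 jitter on those positions.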