├── README.md ├── datasets.py ├── main.py ├── models.py ├── modules.py ├── trainers.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | This repository contains the code for the paper [ContrastVAE: Contrastive Variational AutoEncoder for Sequential Recommendation](https://arxiv.org/pdf/2209.00456.pdf). 2 | 3 | 4 | ## Model variants (command-line flags) 5 | 6 | --variational_dropout: use variational augmentation 7 | 8 | --latent_contrastive_learning: use model augmentation 9 | 10 | --latent_data_augmentation: use data augmentation 11 | 12 | --VAandDA: use both variational augmentation and data augmentation 13 | 14 | Without any of the flags above, the model is the vanilla attentive variational autoencoder. 15 | 16 | 17 | ## Train on Beauty 18 | python main.py --latent_contrastive_learning --data_name=Beauty --latent_clr_weight=0.6 --reparam_dropout_rate=0.1 --lr=0.001 --hidden_size=128 --max_seq_length=50 --hidden_dropout_prob=0.3 --num_hidden_layers=1 --weight_decay=0.0 --num_attention_heads=4 --model_name=VAGRec --attention_probs_dropout_prob=0.0 --anneal_cap=0.2 --total_annealing_step=10000 19 | 20 | ## Train on Office Products 21 | python main.py --variational_dropout --gpu_id 1 --data_name=Office_Products --latent_clr_weight=0.3 --lr=0.001 --hidden_size=128 --max_seq_length=100 --hidden_dropout_prob=0.3 --num_hidden_layers=1 --weight_decay=0.0 --num_attention_heads=4 --model_name=VAGRec --attention_probs_dropout_prob=0.3 --anneal_cap=0.2 --total_annealing_step=20000 22 | 23 | ## Train on Tools and Home Improvement 24 | python main.py --variational_dropout --gpu_id 1 --data_name=Tools_and_Home_Improvement --latent_clr_weight=0.4 --lr=0.001 --hidden_size=128 --max_seq_length=100 --hidden_dropout_prob=0.3 --num_hidden_layers=1 --weight_decay=0.0 --num_attention_heads=4 --model_name=VAGRecVD --attention_probs_dropout_prob=0.3 --anneal_cap=0.4 --total_annealing_step=5000 25 | 26 | ## Train on Toys and Games 27 | 28 | python main.py --variational_dropout --gpu_id 1 
--data_name=Toys_and_Games --latent_clr_weight=0.3 --lr=0.001 --hidden_size=128 --max_seq_length=100 --hidden_dropout_prob=0.3 --num_hidden_layers=1 --weight_decay=0.0 --num_attention_heads=4 --model_name=VAGRecVD --attention_probs_dropout_prob=0.3 --anneal_cap=0.2 --total_annealing_step=10000 29 | 30 | ## Reference 31 | 32 | > @inproceedings{wang2022contrastvae, 33 | title={ContrastVAE: Contrastive Variational AutoEncoder for Sequential Recommendation}, 34 | author={Wang, Yu and Zhang, Hengrui and Liu, Zhiwei and Yang, Liangwei and Yu, Philip S}, 35 | booktitle={Proceedings of the 31st ACM International Conference on Information \& Knowledge Management}, 36 | pages={2056--2066}, 37 | year={2022} 38 | } 39 | 40 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import random 2 | import copy 3 | import torch 4 | from torch.utils.data import Dataset 5 | import math 6 | import numpy as np 7 | import random 8 | from utils import neg_sample 9 | 10 | 11 | class SeqDataset(Dataset): 12 | 13 | def __init__(self, args, user_seq, test_neg_items=None, data_type='train'): 14 | self.args = args 15 | self.user_seq = user_seq 16 | self.test_neg_items = test_neg_items 17 | self.data_type = data_type 18 | self.max_len = args.max_seq_length 19 | 20 | def __getitem__(self, index): 21 | 22 | user_id = index 23 | items = self.user_seq[index] 24 | 25 | assert self.data_type in {"train", "valid", "test"} 26 | 27 | # [0, 1, 2, 3, 4, 5, 6] user_seq[index] 28 | # train [0, 1, 2, 3] = input_ids 29 | # target [1, 2, 3, 4] = target_pos 30 | # target_neg [7,8,10,312] 31 | 32 | # valid [0, 1, 2, 3, 4] = input_ids 33 | # target_pos [1,2,3,4,5] 34 | # target_neg [7,8,10,312, 123] 35 | # answer [5] 36 | 37 | # test [0, 1, 2, 3, 4, 5] = input_ids 38 | # target_pos [1,2,3,4,5,6] 39 | # answer [6] 40 | 41 | 42 | 43 | 44 | if self.data_type == "train": 45 | input_ids = 
items[:-3] 46 | target_pos = items[1:-2] 47 | answer = [0] # no use 48 | 49 | elif self.data_type == 'valid': 50 | input_ids = items[:-2] 51 | target_pos = items[1:-1] 52 | answer = [items[-2]] 53 | 54 | else: 55 | input_ids = items[:-1] 56 | target_pos = items[1:] 57 | answer = [items[-1]] 58 | 59 | 60 | target_neg = [] 61 | seq_set = set(items) 62 | for _ in input_ids: 63 | target_neg.append(neg_sample(seq_set, self.args.item_size)) 64 | 65 | if self.args.latent_data_augmentation or self.args.VAandDA: 66 | dice = random.sample(range(3), k=1) 67 | copy_input_ids = copy.deepcopy(input_ids) 68 | if dice == 0: 69 | aug_input_ids = self.item_crop(copy_input_ids) 70 | elif dice ==1: 71 | aug_input_ids = self.item_mask(copy_input_ids) 72 | else: 73 | aug_input_ids = self.item_reorder(copy_input_ids) 74 | 75 | 76 | # add 0 ids from the start 77 | pad_len = self.max_len - len(input_ids) 78 | input_ids = [0] * pad_len + input_ids 79 | target_pos = [0] * pad_len + target_pos 80 | target_neg = [0] * pad_len + target_neg 81 | 82 | # for long sequences that longer than max_len 83 | input_ids = input_ids[-self.max_len:] 84 | target_pos = target_pos[-self.max_len:] 85 | target_neg = target_neg[-self.max_len:] 86 | 87 | if self.args.latent_data_augmentation or self.args.VAandDA: 88 | # add 0 ids from the start 89 | aug_pad_len = self.max_len - len(aug_input_ids) 90 | aug_input_ids = [0] * aug_pad_len + aug_input_ids 91 | 92 | # for long sequences that longer than max_len 93 | aug_input_ids = aug_input_ids[-self.max_len:] 94 | else: aug_input_ids = 0 95 | 96 | assert len(input_ids) == self.max_len 97 | assert len(target_pos) == self.max_len 98 | assert len(target_neg) == self.max_len 99 | 100 | if self.test_neg_items is not None: 101 | test_samples = self.test_neg_items[index] 102 | 103 | cur_tensors = ( 104 | torch.tensor(user_id, dtype=torch.long), # user_id for testing 105 | torch.tensor(input_ids, dtype=torch.long), 106 | torch.tensor(target_pos, dtype=torch.long), 107 | 
torch.tensor(target_neg, dtype=torch.long), 108 | torch.tensor(answer, dtype=torch.long), 109 | torch.tensor(test_samples, dtype=torch.long), 110 | torch.tensor(aug_input_ids,dtype=torch.long), 111 | ) 112 | else: # all of shape: b*max_sq 113 | cur_tensors = ( 114 | torch.tensor(user_id, dtype=torch.long), # user_id for testing 115 | torch.tensor(input_ids, dtype=torch.long), # training 116 | torch.tensor(target_pos, dtype=torch.long), # targeting, one item right-shifted, since task is to predict next item 117 | torch.tensor(target_neg, dtype=torch.long), # random sample an item out of training and eval for every training items. 118 | torch.tensor(answer, dtype=torch.long), # last item for prediction. 119 | torch.tensor(aug_input_ids,dtype=torch.long) 120 | ) 121 | 122 | return cur_tensors 123 | 124 | def item_crop(self, item_seq, eta=0.6): # item_Seq: [batch, max_seq] 125 | item_seq = np.array(item_seq) 126 | item_seq_len = len(item_seq) 127 | num_left = math.floor(item_seq_len * eta) 128 | crop_begin = random.randint(0, item_seq_len - num_left) 129 | croped_item_seq = np.zeros(item_seq.shape[0]) 130 | if crop_begin + num_left < item_seq.shape[0]: 131 | croped_item_seq[:num_left] = item_seq[crop_begin:crop_begin + num_left] 132 | else: 133 | croped_item_seq[:num_left] = item_seq[crop_begin:] 134 | return list(croped_item_seq) 135 | 136 | 137 | def item_mask(self, item_seq, gamma=0.3): 138 | item_seq = np.array(item_seq) 139 | item_seq_len = len(item_seq) 140 | num_mask = math.floor(item_seq_len * gamma) 141 | mask_index = random.sample(range(item_seq_len), k=num_mask) 142 | masked_item_seq = item_seq.copy() 143 | masked_item_seq[mask_index] = self.args.mask_id # token 0 has been used for semantic masking 144 | return list(masked_item_seq) 145 | 146 | 147 | def item_reorder(self, item_seq, beta=0.6): 148 | item_seq = np.array(item_seq) 149 | item_seq_len = len(item_seq) 150 | num_reorder = math.floor(item_seq_len * beta) 151 | reorder_begin = random.randint(0, 
item_seq_len - num_reorder) 152 | reordered_item_seq = item_seq.copy() 153 | shuffle_index = list(range(reorder_begin, reorder_begin + num_reorder)) 154 | random.shuffle(shuffle_index) 155 | reordered_item_seq[reorder_begin:reorder_begin + num_reorder] = reordered_item_seq[shuffle_index] 156 | return list(reordered_item_seq) 157 | 158 | 159 | def __len__(self): 160 | return len(self.user_seq) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import numpy as np 5 | import random 6 | import torch 7 | import argparse 8 | import pdb 9 | 10 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 11 | 12 | from datasets import SeqDataset 13 | from trainers import ContrastVAETrainer 14 | from models import ContrastVAE, ContrastVAE_VD 15 | from utils import EarlyStopping, get_user_seqs, get_user_seqs_replace, get_item2attribute_json, check_path, set_seed 16 | 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser() 20 | 21 | # data args 22 | parser.add_argument('--data_dir', default='./data/', type=str) 23 | parser.add_argument('--output_dir', default='output/', type=str) 24 | parser.add_argument('--data_name', default='Toys_and_Games', type=str) 25 | parser.add_argument('--do_eval', action='store_true') 26 | parser.add_argument('--ckp', default=10, type=int, help="pretrain epochs 10, 20, 30...") 27 | 28 | 29 | # model args 30 | parser.add_argument("--model_name", default='ContrastVAE', type=str) 31 | parser.add_argument("--hidden_size", type=int, default=128, help="hidden size of transformer model") 32 | parser.add_argument("--num_hidden_layers", type=int, default=1, help="number of layers") 33 | parser.add_argument('--num_attention_heads', default=4, type=int) 34 | parser.add_argument('--hidden_act', default="gelu", type=str) # gelu relu 35 | 
parser.add_argument("--attention_probs_dropout_prob", type=float, default=0.0, help="attention dropout p") 36 | parser.add_argument("--hidden_dropout_prob", type=float, default=0.3, help="hidden dropout p") 37 | parser.add_argument("--initializer_range", type=float, default=0.02) 38 | parser.add_argument('--max_seq_length', default=100, type=int) 39 | 40 | # train args 41 | parser.add_argument("--lr", type=float, default=0.001, help="learning rate of adam") 42 | parser.add_argument("--batch_size", type=int, default=256, help="number of batch_size") 43 | parser.add_argument("--epochs", type=int, default=400, help="number of epochs") 44 | parser.add_argument("--no_cuda", action="store_true") 45 | parser.add_argument("--log_freq", type=int, default=1, help="per epoch print res") 46 | parser.add_argument("--seed", default=42, type=int) 47 | 48 | parser.add_argument("--weight_decay", type=float, default=0.0, help="weight_decay of adam") 49 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="adam first beta value") 50 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="adam second beta value") 51 | parser.add_argument("--gpu_id", type=str, default="0", help="gpu_id") 52 | 53 | # model variants 54 | parser.add_argument("--variational_dropout", action='store_true') 55 | parser.add_argument("--VAandDA", action='store_true') 56 | parser.add_argument("--latent_contrastive_learning", action='store_true') 57 | parser.add_argument('--latent_data_augmentation', action='store_true') 58 | 59 | parser.add_argument("--latent_clr_weight", type=float, default=0.3, help="weight for latent clr loss") 60 | parser.add_argument("--reparam_dropout_rate", type=float, default=0.2, help="dropout rate for reparameterization dropout") 61 | parser.add_argument("--store_latent", action='store_true', help="store the latent representation of sequence embedding") 62 | 63 | # KL annealing args 64 | parser.add_argument('--anneal_cap', type=float, default=0.3) 65 | 
parser.add_argument('--total_annealing_step', type=int, default=10000) 66 | 67 | # contrastive args 68 | parser.add_argument('--temperature', type=float, default=0.5) 69 | parser.add_argument('--i', type=int) 70 | parser.add_argument('--eval_model_path', type = str) 71 | args = parser.parse_args() 72 | 73 | set_seed(args.seed) 74 | check_path(args.output_dir) 75 | 76 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 77 | args.cuda_condition = torch.cuda.is_available() and not args.no_cuda 78 | 79 | 80 | args.data_file = args.data_dir + args.data_name + '.txt' 81 | 82 | 83 | """ 84 | load data 85 | user_seq: original list contains all interacted items 86 | max_item: number of all items plus 0 87 | valid_rating_matrix: shape (num_users, num_items), sparse matrix, value:1, row: user_id, col: item_id, record [:-2] items from user_seq 88 | test_rating_matrix: same as valid_rating_matrix, but record [:-1] items 89 | """ 90 | 91 | user_seq, max_item, valid_rating_matrix, test_rating_matrix, num_users = \ 92 | get_user_seqs(args.data_file) 93 | 94 | 95 | args.item_size = max_item + 2 96 | args.num_users = num_users 97 | args.mask_id = max_item + 1 98 | 99 | 100 | # set item score in train set to `0` in validation 101 | args.train_matrix = valid_rating_matrix 102 | print(f"valid rating matix shape: {valid_rating_matrix.shape}") 103 | 104 | 105 | 106 | 107 | 108 | # save model args 109 | args_str = f'{args.model_name}' \ 110 | f'-{args.data_name}' \ 111 | f'-{args.hidden_size}' \ 112 | f'-{args.num_hidden_layers}' \ 113 | f'-{args.num_attention_heads}' \ 114 | f'-{args.hidden_act}' \ 115 | f'-{args.attention_probs_dropout_prob}' \ 116 | f'-{args.hidden_dropout_prob}' \ 117 | f'-{args.max_seq_length}' \ 118 | f'-{args.lr}' \ 119 | f'-{args.weight_decay}' \ 120 | f'-{args.anneal_cap}' \ 121 | f'-{args.total_annealing_step}' \ 122 | f'-{args.reparam_dropout_rate}'\ 123 | f'-{args.latent_clr_weight}' 124 | 125 | args.log_file = os.path.join(args.output_dir, args_str + '.txt') 
126 | with open(args.log_file, 'a') as f: 127 | f.write(str(args) + '\n') 128 | 129 | # save model 130 | checkpoint = args_str + '.pt' 131 | args.checkpoint_path = os.path.join(args.output_dir, checkpoint) 132 | 133 | 134 | train_dataset = SeqDataset(args, user_seq, data_type='train') 135 | train_sampler = RandomSampler(train_dataset) 136 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.batch_size) 137 | 138 | eval_dataset = SeqDataset(args, user_seq, data_type='valid') 139 | eval_sampler = SequentialSampler(eval_dataset) 140 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.batch_size) 141 | 142 | 143 | test_dataset = SeqDataset(args, user_seq, data_type='test') 144 | test_sampler = SequentialSampler(test_dataset) 145 | test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=args.batch_size) 146 | 147 | if args.variational_dropout: 148 | model = ContrastVAE_VD(args) 149 | else: model = ContrastVAE(args=args) 150 | 151 | trainer = ContrastVAETrainer(model, train_dataloader, eval_dataloader, 152 | test_dataloader, args) 153 | 154 | if args.do_eval: 155 | # load the best model 156 | print('---------------load best model and do eval-------------------') 157 | trainer.model.load_state_dict(torch.load(args.checkpoint_path)) 158 | trainer.args.train_matrix = test_rating_matrix 159 | trainer.test('best', full_sort=True) 160 | else: 161 | 162 | early_stopping = EarlyStopping(args.checkpoint_path, patience=200, verbose=True) 163 | for epoch in range(args.epochs): 164 | trainer.train(epoch) 165 | scores, _, _ = trainer.valid(epoch, full_sort=True) 166 | early_stopping(np.array([scores[4], scores[5]]), trainer.model) # here only check best recall@10, ndcg@10 167 | if early_stopping.early_stop: 168 | print("Early stopping") 169 | break 170 | 171 | print('---------------Change to test_rating_matrix!-------------------') 172 | # load the best model 173 | 
trainer.model.load_state_dict(torch.load(args.checkpoint_path)) 174 | valid_scores, _, _ = trainer.valid('best', full_sort=True) 175 | trainer.args.train_matrix = test_rating_matrix 176 | scores, result_info, _ = trainer.test('best', full_sort=True) 177 | 178 | print(args_str) 179 | with open(args.log_file, 'a') as f: 180 | f.write(args_str + '\n') 181 | f.write(result_info + '\n') 182 | 183 | 184 | if __name__ == '__main__': 185 | main() 186 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from modules import Encoder, LayerNorm, Decoder, VariationalDropout 4 | import math 5 | import numpy as np 6 | import random 7 | 8 | class ContrastVAE(nn.Module): 9 | 10 | def __init__(self, args): 11 | super(ContrastVAE, self).__init__() 12 | 13 | self.item_embeddings = nn.Embedding(args.item_size, args.hidden_size, padding_idx=0) 14 | self.position_embeddings = nn.Embedding(args.max_seq_length, args.hidden_size) 15 | self.item_encoder_mu = Encoder(args) 16 | self.item_encoder_logvar = Encoder(args) 17 | self.item_decoder = Decoder(args) 18 | self.LayerNorm = LayerNorm(args.hidden_size, eps=1e-12) 19 | self.dropout = nn.Dropout(args.hidden_dropout_prob) 20 | self.args = args 21 | self.latent_dropout = nn.Dropout(args.reparam_dropout_rate) 22 | self.apply(self.init_weights) 23 | self.temperature = nn.Parameter(torch.zeros(1), requires_grad=True) 24 | 25 | def add_position_embedding(self, sequence): 26 | 27 | seq_length = sequence.size(1) 28 | position_ids = torch.arange(seq_length, dtype=torch.long, device=sequence.device) 29 | position_ids = position_ids.unsqueeze(0).expand_as(sequence) 30 | item_embeddings = self.item_embeddings(sequence) # shape: b*max_Sq*d 31 | position_embeddings = self.position_embeddings(position_ids) 32 | sequence_emb = item_embeddings + position_embeddings 33 | sequence_emb = 
self.LayerNorm(sequence_emb) 34 | sequence_emb = self.dropout(sequence_emb) 35 | 36 | return sequence_emb # shape: b*max_Sq*d 37 | 38 | 39 | def extended_attention_mask(self, input_ids): 40 | attention_mask = (input_ids > 0).long()# used for mu, var 41 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # torch.int64 b*1*1*max_Sq 42 | max_len = attention_mask.size(-1) 43 | attn_shape = (1, max_len, max_len) 44 | subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1) # torch.uint8 for causality 45 | subsequent_mask = (subsequent_mask == 0).unsqueeze(1) #1*1*max_Sq*max_Sq 46 | subsequent_mask = subsequent_mask.long() 47 | 48 | if self.args.cuda_condition: 49 | subsequent_mask = subsequent_mask.cuda() 50 | 51 | extended_attention_mask = extended_attention_mask * subsequent_mask #shape: b*1*max_Sq*max_Sq 52 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 53 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 54 | 55 | return extended_attention_mask 56 | 57 | 58 | def eps_anneal_function(self, step): 59 | 60 | return min(1.0, (1.0*step)/self.args.total_annealing_step) 61 | 62 | def reparameterization(self, mu, logvar, step): # vanila reparam 63 | 64 | std = torch.exp(0.5 * logvar) 65 | if self.training: 66 | eps = torch.randn_like(std) 67 | res = mu + std * eps 68 | else: res = mu + std 69 | return res 70 | 71 | def reparameterization1(self, mu, logvar, step): # reparam without noise 72 | std = torch.exp(0.5*logvar) 73 | return mu+std 74 | 75 | 76 | def reparameterization2(self, mu, logvar, step): # use dropout 77 | 78 | if self.training: 79 | std = self.latent_dropout(torch.exp(0.5*logvar)) 80 | else: std = torch.exp(0.5*logvar) 81 | res = mu + std 82 | return res 83 | 84 | def reparameterization3(self, mu, logvar,step): # apply classical dropout on whole result 85 | std = torch.exp(0.5*logvar) 86 | res = self.latent_dropout(mu + std) 87 | return res 88 | 89 | 
90 | def init_weights(self, module): 91 | """ Initialize the weights. 92 | """ 93 | if isinstance(module, (nn.Linear, nn.Embedding)): 94 | # Slightly different from the TF version which uses truncated_normal for initialization 95 | # cf https://github.com/pytorch/pytorch/pull/5617 96 | module.weight.data.normal_(mean=0.0, std=self.args.initializer_range) 97 | elif isinstance(module, LayerNorm): 98 | module.bias.data.zero_() 99 | module.weight.data.fill_(1.0) 100 | if isinstance(module, nn.Linear) and module.bias is not None: 101 | module.bias.data.zero_() 102 | 103 | 104 | def encode(self, sequence_emb, extended_attention_mask): # forward 105 | 106 | item_encoded_mu_layers = self.item_encoder_mu(sequence_emb, 107 | extended_attention_mask, 108 | output_all_encoded_layers=True) 109 | 110 | item_encoded_logvar_layers = self.item_encoder_logvar(sequence_emb, extended_attention_mask, 111 | True) 112 | 113 | return item_encoded_mu_layers[-1], item_encoded_logvar_layers[-1] 114 | 115 | def decode(self, z, extended_attention_mask): 116 | item_decoder_layers = self.item_decoder(z, 117 | extended_attention_mask, 118 | output_all_encoded_layers = True) 119 | sequence_output = item_decoder_layers[-1] 120 | return sequence_output 121 | 122 | 123 | 124 | def forward(self, input_ids, aug_input_ids, step): 125 | 126 | sequence_emb = self.add_position_embedding(input_ids)# shape: b*max_Sq*d 127 | extended_attention_mask = self.extended_attention_mask(input_ids) 128 | 129 | if self.args.latent_contrastive_learning: 130 | mu1, log_var1 = self.encode(sequence_emb, extended_attention_mask) 131 | mu2, log_var2 = self.encode(sequence_emb, extended_attention_mask) 132 | z1 = self.reparameterization1(mu1, log_var1, step) 133 | z2 = self.reparameterization2(mu2, log_var2, step) 134 | reconstructed_seq1 = self.decode(z1, extended_attention_mask) 135 | reconstructed_seq2 = self.decode(z2, extended_attention_mask) 136 | return reconstructed_seq1, reconstructed_seq2, mu1, mu2, log_var1, 
log_var2, z1, z2 137 | 138 | elif self.args.latent_data_augmentation: 139 | aug_sequence_emb = self.add_position_embedding(aug_input_ids) # shape: b*max_Sq*d 140 | aug_extended_attention_mask = self.extended_attention_mask(aug_input_ids) 141 | 142 | mu1, log_var1 = self.encode(sequence_emb, extended_attention_mask) 143 | mu2, log_var2 = self.encode(aug_sequence_emb, aug_extended_attention_mask) 144 | z1 = self.reparameterization1(mu1, log_var1, step) 145 | z2 = self.reparameterization2(mu2, log_var2, step) 146 | reconstructed_seq1 = self.decode(z1, extended_attention_mask) 147 | reconstructed_seq2 = self.decode(z2, extended_attention_mask) 148 | return reconstructed_seq1, reconstructed_seq2, mu1, mu2, log_var1, log_var2, z1, z2 149 | 150 | else: # vanilla attentive VAE 151 | mu, log_var = self.encode(sequence_emb, extended_attention_mask) 152 | z = self.reparameterization(mu, log_var, step) 153 | reconstructed_seq1 = self.decode(z, extended_attention_mask) 154 | return reconstructed_seq1, mu, log_var 155 | 156 | 157 | 158 | 159 | 160 | class ContrastVAE_VD(ContrastVAE): 161 | 162 | def __init__(self, args): 163 | super(ContrastVAE, self).__init__() 164 | 165 | self.item_embeddings = nn.Embedding(args.item_size, args.hidden_size, padding_idx=0) 166 | self.position_embeddings = nn.Embedding(args.max_seq_length, args.hidden_size) 167 | 168 | self.item_encoder_mu = Encoder(args) 169 | self.item_encoder_logvar = Encoder(args) 170 | self.item_decoder = Decoder(args) 171 | 172 | self.dropout = nn.Dropout(args.hidden_dropout_prob) 173 | 174 | self.LayerNorm = LayerNorm(args.hidden_size, eps=1e-12) 175 | self.latent_dropout_VD = VariationalDropout(inputshape=[args.max_seq_length, args.hidden_size], adaptive='layerwise') 176 | self.latent_dropout = nn.Dropout(0.1) 177 | self.args = args 178 | self.apply(self.init_weights) 179 | 180 | self.drop_rate = nn.Parameter(torch.tensor(0.2), requires_grad=True) 181 | 182 | 183 | def reparameterization3(self, mu, logvar, step): # use 
drop out 184 | 185 | std, alpha = self.latent_dropout_VD(torch.exp(0.5*logvar)) 186 | res = mu + std 187 | return res, alpha 188 | 189 | def forward(self, input_ids, augmented_input_ids, step): 190 | if self.args.variational_dropout: 191 | sequence_emb = self.add_position_embedding(input_ids) # shape: b*max_Sq*d 192 | extended_attention_mask = self.extended_attention_mask(input_ids) 193 | mu1, log_var1 = self.encode(sequence_emb, extended_attention_mask) 194 | mu2, log_var2 = self.encode(sequence_emb, extended_attention_mask) 195 | z1 = self.reparameterization1(mu1, log_var1, step) 196 | z2, alpha = self.reparameterization3(mu2, log_var2, step) 197 | reconstructed_seq1 = self.decode(z1, extended_attention_mask) 198 | reconstructed_seq2 = self.decode(z2, extended_attention_mask) 199 | 200 | elif self.args.VAandDA: 201 | sequence_emb = self.add_position_embedding(input_ids) # shape: b*max_Sq*d 202 | extended_attention_mask = self.extended_attention_mask(input_ids) 203 | aug_sequence_emb = self.add_position_embedding(augmented_input_ids) # shape: b*max_Sq*d 204 | aug_extended_attention_mask = self.extended_attention_mask(augmented_input_ids) 205 | 206 | mu1, log_var1 = self.encode(sequence_emb, extended_attention_mask) 207 | mu2, log_var2 = self.encode(aug_sequence_emb, aug_extended_attention_mask) 208 | z1 = self.reparameterization1(mu1, log_var1, step) 209 | z2, alpha = self.reparameterization3(mu2, log_var2, step) 210 | reconstructed_seq1 = self.decode(z1, extended_attention_mask) 211 | reconstructed_seq2 = self.decode(z2, extended_attention_mask) 212 | 213 | 214 | return reconstructed_seq1, reconstructed_seq2, mu1, mu2, log_var1, log_var2, z1, z2, alpha 215 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import copy 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as 
F 8 | 9 | """activation function""" 10 | def gelu(x): 11 | """Implementation of the gelu activation function. 12 | For information: OpenAI GPT's gelu is slightly different 13 | (and gives slightly different results): 14 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * 15 | (x + 0.044715 * torch.pow(x, 3)))) 16 | Also see https://arxiv.org/abs/1606.08415 17 | """ 18 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 19 | 20 | 21 | def swish(x): 22 | return x * torch.sigmoid(x) 23 | 24 | ACT2FN = {"gelu": gelu, "relu": F.relu, "swish": swish} 25 | 26 | """Transformer toolkits""" 27 | 28 | class LayerNorm(nn.Module): 29 | def __init__(self, hidden_size, eps=1e-12): 30 | """Construct a layernorm module in the TF style (epsilon inside the square root). 31 | """ 32 | super(LayerNorm, self).__init__() 33 | self.weight = nn.Parameter(torch.ones(hidden_size)) 34 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 35 | self.variance_epsilon = eps 36 | 37 | def forward(self, x): 38 | u = x.mean(-1, keepdim=True) 39 | s = (x - u).pow(2).mean(-1, keepdim=True) 40 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 41 | return self.weight * x + self.bias 42 | 43 | 44 | 45 | class SelfAttention(nn.Module): 46 | def __init__(self, args): 47 | super(SelfAttention, self).__init__() 48 | if args.hidden_size % args.num_attention_heads != 0: 49 | raise ValueError( 50 | "The hidden size (%d) is not a multiple of the number of attention " 51 | "heads (%d)" % (args.hidden_size, args.num_attention_heads)) 52 | self.num_attention_heads = args.num_attention_heads 53 | self.attention_head_size = int(args.hidden_size / args.num_attention_heads) 54 | self.all_head_size = self.num_attention_heads * self.attention_head_size 55 | 56 | self.query = nn.Linear(args.hidden_size, self.all_head_size) 57 | self.key = nn.Linear(args.hidden_size, self.all_head_size) 58 | self.value = nn.Linear(args.hidden_size, self.all_head_size) 59 | 60 | self.attn_dropout = 
nn.Dropout(args.attention_probs_dropout_prob) 61 | 62 | self.dense = nn.Linear(args.hidden_size, args.hidden_size) 63 | self.LayerNorm = LayerNorm(args.hidden_size, eps=1e-12) 64 | self.out_dropout = nn.Dropout(args.hidden_dropout_prob) 65 | 66 | def transpose_for_scores(self, x): 67 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 68 | x = x.view(*new_x_shape) 69 | return x.permute(0, 2, 1, 3) 70 | 71 | def forward(self, input_tensor, attention_mask): 72 | mixed_query_layer = self.query(input_tensor) 73 | mixed_key_layer = self.key(input_tensor) 74 | mixed_value_layer = self.value(input_tensor) 75 | 76 | query_layer = self.transpose_for_scores(mixed_query_layer) 77 | key_layer = self.transpose_for_scores(mixed_key_layer) 78 | value_layer = self.transpose_for_scores(mixed_value_layer) 79 | 80 | # Take the dot product between "query" and "key" to get the raw attention scores. 81 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 82 | 83 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 84 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 85 | # [batch_size heads seq_len seq_len] scores 86 | # [batch_size 1 1 seq_len] 87 | attention_scores = attention_scores + attention_mask 88 | 89 | # Normalize the attention scores to probabilities. 
90 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 91 | attention_probs = self.attn_dropout(attention_probs) 92 | context_layer = torch.matmul(attention_probs, value_layer) 93 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 94 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 95 | context_layer = context_layer.view(*new_context_layer_shape) 96 | hidden_states = self.dense(context_layer) 97 | hidden_states = self.out_dropout(hidden_states) 98 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 99 | 100 | return hidden_states 101 | 102 | 103 | 104 | 105 | class Intermediate(nn.Module): 106 | def __init__(self, args): 107 | super(Intermediate, self).__init__() 108 | self.dense_1 = nn.Linear(args.hidden_size, args.hidden_size * 4) 109 | if isinstance(args.hidden_act, str): 110 | self.intermediate_act_fn = ACT2FN[args.hidden_act] 111 | else: 112 | self.intermediate_act_fn = args.hidden_act 113 | 114 | self.dense_2 = nn.Linear(args.hidden_size * 4, args.hidden_size) 115 | self.LayerNorm = LayerNorm(args.hidden_size, eps=1e-12) 116 | self.dropout = nn.Dropout(args.hidden_dropout_prob) 117 | 118 | def forward(self, input_tensor): 119 | 120 | hidden_states = self.dense_1(input_tensor) 121 | hidden_states = self.intermediate_act_fn(hidden_states) 122 | 123 | hidden_states = self.dense_2(hidden_states) 124 | hidden_states = self.dropout(hidden_states) 125 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 126 | 127 | return hidden_states 128 | 129 | 130 | 131 | class Layer(nn.Module): # attention block 132 | def __init__(self, args): 133 | super(Layer, self).__init__() 134 | self.attention = SelfAttention(args) 135 | self.intermediate = Intermediate(args) 136 | 137 | def forward(self, hidden_states, attention_mask): 138 | attention_output = self.attention(hidden_states, attention_mask) 139 | intermediate_output = self.intermediate(attention_output) 140 | return intermediate_output 141 | 142 | 143 | 
def _run_layer_stack(layers, hidden_states, attention_mask, output_all_encoded_layers):
    """Apply ``layers`` in order to ``hidden_states``.

    :return: every layer's output when ``output_all_encoded_layers`` is True,
        otherwise a one-element list holding only the final output.
    """
    outputs = []
    for layer_module in layers:
        hidden_states = layer_module(hidden_states, attention_mask)
        if output_all_encoded_layers:
            outputs.append(hidden_states)
    if not output_all_encoded_layers:
        outputs.append(hidden_states)
    return outputs


class Encoder(nn.Module):
    """Stack of ``args.num_hidden_layers`` transformer ``Layer`` blocks.

    Every layer is a deep copy of a single freshly built ``Layer``, so all
    layers start from identical weights and are then trained independently.
    """

    def __init__(self, args):
        super(Encoder, self).__init__()
        layer = Layer(args)
        self.layer = nn.ModuleList([copy.deepcopy(layer)
                                    for _ in range(args.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
        """
        :param hidden_states: b x max_Sq x d
        :param attention_mask: b x 1 x max_Sq x max_Sq
        :param output_all_encoded_layers: return every layer's output (True) or only the last (False)
        :return: list of hidden-state tensors
        """
        return _run_layer_stack(self.layer, hidden_states, attention_mask, output_all_encoded_layers)


class Decoder(nn.Module):
    """Stack of ``args.num_hidden_layers`` transformer ``Layer`` blocks (same structure as ``Encoder``)."""

    def __init__(self, args):
        super(Decoder, self).__init__()
        layer = Layer(args)
        self.layer = nn.ModuleList([copy.deepcopy(layer)
                                    for _ in range(args.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
        """
        :param hidden_states: b x max_Sq x d
        :param attention_mask: b x 1 x max_Sq x max_Sq
        :param output_all_encoded_layers: return every layer's output (True) or only the last (False)
        :return: list of hidden-state tensors
        """
        return _run_layer_stack(self.layer, hidden_states, attention_mask, output_all_encoded_layers)


class NCELoss(nn.Module):
    """Symmetric InfoNCE loss between two batches of views.

    Row ``i`` of ``batch_sample_one`` and row ``i`` of ``batch_sample_two`` form
    the positive pair. Every other row of either batch is a negative. Similarities
    are cosine similarities divided by ``temperature``.
    """

    def __init__(self, temperature, device):
        super(NCELoss, self).__init__()
        self.device = device
        self.criterion = nn.CrossEntropyLoss().to(self.device)
        self.temperature = temperature
        self.cossim = nn.CosineSimilarity(dim=-1).to(self.device)

    # modified based on impl: https://github.com/ae-foster/pytorch-simclr/blob/dc9ac57a35aec5c7d7d5fe6dc070a975f493c1a5/critic.py#L5
    def forward(self, batch_sample_one, batch_sample_two):
        """
        :param batch_sample_one: [batch, hidden] (the labels below assume 2-D inputs)
        :param batch_sample_two: [batch, hidden]
        :return: scalar cross-entropy loss over the 2*batch x 2*batch logit matrix
        """
        sim11 = self.cossim(batch_sample_one.unsqueeze(-2), batch_sample_one.unsqueeze(-3)) / self.temperature
        sim22 = self.cossim(batch_sample_two.unsqueeze(-2), batch_sample_two.unsqueeze(-3)) / self.temperature
        sim12 = self.cossim(batch_sample_one.unsqueeze(-2), batch_sample_two.unsqueeze(-3)) / self.temperature

        d = sim12.shape[-1]
        # A sample must not be its own negative.
        sim11[..., range(d), range(d)] = float('-inf')
        sim22[..., range(d), range(d)] = float('-inf')
        raw_scores1 = torch.cat([sim12, sim11], dim=-1)
        raw_scores2 = torch.cat([sim22, sim12.transpose(-1, -2)], dim=-1)
        logits = torch.cat([raw_scores1, raw_scores2], dim=-2)
        # Row i's positive sits at column i (first half) / column d+i (second half).
        labels = torch.arange(2 * d, dtype=torch.long, device=logits.device)
        nce_loss = self.criterion(logits, labels)
        return nce_loss


"""toolkit for variational dropout"""


def _logit(x):
    """Inverse of the logistic sigmoid: log(x / (1 - x))."""
    return np.log(x / (1. - x))


def _check_p(p):
    """Clamp a dropout rate so that values >= 0.5 become 0.4999 and values <= 0 become 0.0001."""
    if p >= 0.5:
        return 0.4999
    if p <= 0.0:
        return 0.0001
    return p


class VariationalDropout(nn.Module):
    """Multiplicative Gaussian noise with a (possibly learnable) noise scale.

    ``logitalpha`` is initialised to ``logit(sqrt(p / (1 - p)))``, so the initial
    ``alpha = sigmoid(logitalpha)`` equals ``sqrt(p / (1 - p))``.

    :param inputshape: shape used by the "elementwise" (``inputshape[1:]``) and
        "weightwise" (``inputshape``) modes
    :param p: nominal dropout rate, clamped by ``_check_p``
    :param adaptive: None (fixed scalar), 'layerwise' (learnable scalar),
        'elementwise' or 'weightwise' (learnable tensors)
    :raises ValueError: for any other ``adaptive`` value
    """

    def __init__(self, inputshape, p=0.2, adaptive=None):
        super(VariationalDropout, self).__init__()

        self.adaptive = adaptive
        p = _check_p(p)
        init = _logit(np.sqrt(p / (1. - p)))

        if self.adaptive is None:
            self.logitalpha = nn.Parameter(torch.tensor(init), requires_grad=False)

        elif self.adaptive == 'layerwise':
            self.logitalpha = nn.Parameter(torch.tensor(init), requires_grad=True)

        elif self.adaptive == "elementwise":
            # one parameter per activation (batch dimension excluded)
            self.logitalpha = nn.Parameter(torch.tensor(
                np.ones(inputshape[1:]).astype(np.float32) * init
            ), requires_grad=True)

        elif self.adaptive == "weightwise":
            # this will only work in the case of dropout type B
            self.logitalpha = nn.Parameter(torch.tensor(
                np.ones(inputshape).astype(np.float32) * init
            ), requires_grad=True)

        else:
            # Previously this left logitalpha unset and failed later in forward().
            raise ValueError("unknown adaptive mode: {!r}".format(adaptive))

    def forward(self, x):
        """Return ``(x * alpha * eps, alpha)`` with ``eps ~ N(0, 1)`` drawn like ``x``.

        NOTE(review): this is purely multiplicative noise, not the
        ``x * (1 + alpha * eps)`` Gaussian-dropout form; confirm it is intended.
        """
        alpha = torch.sigmoid(self.logitalpha)
        output = x * alpha * torch.randn_like(x)

        return output, alpha


def priorKL(alpha):
    """Approximate KL penalty for the variational-dropout noise scale ``alpha``.

    The constants appear to match the polynomial approximation of Kingma et al.
    (2015, "Variational Dropout and the Local Reparameterization Trick"); the
    result is averaged over all entries of ``alpha``.
    """
    c1 = 1.161451241083230
    c2 = -1.502041176441722
    c3 = 0.586299206427007
    return -torch.mean(0.5 * torch.log(alpha) + c1 * alpha + c2 * torch.pow(alpha, 2) + c3 * torch.pow(alpha, 3))


# -------------------------------------------------------------------------------- /trainers.py
import numpy as np
import tqdm
import random
from collections import defaultdict

import torch
import torch.nn as nn
from torch.optim import Adam
from modules import NCELoss, priorKL
from utils import recall_at_k, ndcg_k, get_metric, cal_mrr, get_user_performance_perpopularity, get_item_performance_perpopularity


class Trainer:
    """Base trainer.

    Owns the model, the data loaders and the Adam optimizer, and provides the
    shared scoring and metric helpers. Subclasses implement ``iteration``.
    """

    def __init__(self, model, train_dataloader,
                 eval_dataloader,
                 test_dataloader, args):

        self.args = args
        self.cuda_condition = torch.cuda.is_available() and not self.args.no_cuda
        self.device = torch.device("cuda" if self.cuda_condition else "cpu")

        self.model = model
        if self.cuda_condition:
            self.model.cuda()

        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.test_dataloader = test_dataloader

        betas = (self.args.adam_beta1, self.args.adam_beta2)
        self.optim = Adam(self.model.parameters(), lr=self.args.lr, betas=betas, weight_decay=self.args.weight_decay)

        print("Total Parameters:", sum(p.nelement() for p in self.model.parameters()), flush=True)
        self.criterion = nn.BCELoss()

    def train(self, epoch):
        self.iteration(epoch, self.train_dataloader)

    def valid(self, epoch, full_sort=False):
        return self.iteration(epoch, self.eval_dataloader, full_sort, train=False)

    def test(self, epoch, full_sort=False):
        return self.iteration(epoch, self.test_dataloader, full_sort, train=False)

    def complicated_eval(self, user_seq, args):
        return self.eval_analysis(self.test_dataloader, user_seq, args)

    def iteration(self, epoch, dataloader, full_sort=False, train=True):
        raise NotImplementedError

    def eval_analysis(self, dataloader, seqs, args=None):
        # ``args`` added so the call in complicated_eval matches this signature.
        raise NotImplementedError

    def get_sample_scores(self, epoch, pred_list):
        """HIT/NDCG@{1,5,10} and MRR from sampled-candidate scores.

        :param pred_list: [n, num_candidates] scores; the rank of column 0 is
            evaluated (presumably the ground-truth item; confirm with the caller).
        """
        pred_list = (-pred_list).argsort().argsort()[:, 0]
        HIT_1, NDCG_1, MRR = get_metric(pred_list, 1)
        HIT_5, NDCG_5, _ = get_metric(pred_list, 5)
        HIT_10, NDCG_10, _ = get_metric(pred_list, 10)
        post_fix = {
            "Epoch": epoch,
            "HIT@1": '{:.4f}'.format(HIT_1), "NDCG@1": '{:.4f}'.format(NDCG_1),
            "HIT@5": '{:.4f}'.format(HIT_5), "NDCG@5": '{:.4f}'.format(NDCG_5),
            "HIT@10": '{:.4f}'.format(HIT_10), "NDCG@10": '{:.4f}'.format(NDCG_10),
            "MRR": '{:.4f}'.format(MRR),
        }
        print(post_fix, flush=True)
        with open(self.args.log_file, 'a') as f:
            f.write(str(post_fix) + '\n')
        return [HIT_1, NDCG_1, HIT_5, NDCG_5, HIT_10, NDCG_10, MRR], str(post_fix), None

    def get_full_sort_score(self, epoch, answers, pred_list):
        """Recall (reported as HIT) and NDCG at k in {1, 5, 10, 15, 20, 40}, plus MRR.

        :return: (flat score list [HIT@1, NDCG@1, ..., HIT@40, NDCG@40, MRR],
                  str(post_fix), [recall_dict_list, ndcg_dict_list, mrr_dict])
        """
        ks = [1, 5, 10, 15, 20, 40]
        recall, ndcg = [], []
        recall_dict_list, ndcg_dict_list = [], []
        for k in ks:
            recall_result, recall_dict_k = recall_at_k(answers, pred_list, k)
            recall.append(recall_result)
            recall_dict_list.append(recall_dict_k)
            ndcg_result, ndcg_dict_k = ndcg_k(answers, pred_list, k)
            ndcg.append(ndcg_result)
            ndcg_dict_list.append(ndcg_dict_k)
        mrr, mrr_dict = cal_mrr(answers, pred_list)

        post_fix = {"Epoch": epoch}
        scores = []
        for k, rec, nd in zip(ks, recall, ndcg):
            post_fix[f"HIT@{k}"] = '{:.8f}'.format(rec)
            post_fix[f"NDCG@{k}"] = '{:.8f}'.format(nd)
            scores.extend([rec, nd])
        post_fix["MRR"] = '{:.8f}'.format(mrr)
        scores.append(mrr)

        print(post_fix, flush=True)
        with open(self.args.log_file, 'a') as f:
            f.write(str(post_fix) + '\n')
        return scores, str(post_fix), [recall_dict_list, ndcg_dict_list, mrr_dict]

    def get_pos_items_ranks(self, batch_pred_lists, answers):
        """Map each correctly predicted item to the 1-based ranks at which it was hit."""
        batch_pos_ranks = defaultdict(list)
        for pred_list, answer in zip(batch_pred_lists, answers):
            true_set = set(answer)
            for ind, pred_item in enumerate(pred_list):
                if pred_item in true_set:
                    batch_pos_ranks[pred_item].append(ind + 1)
        return batch_pos_ranks

    def save(self, file_name):
        torch.save(self.model.cpu().state_dict(), file_name)
        self.model.to(self.device)

    def load(self, file_name):
        # Map onto the trainer's device (was hard-coded to 'cuda:0', which fails on CPU-only hosts).
        self.model.load_state_dict(torch.load(file_name, map_location=self.device))

    def cross_entropy(self, seq_out, pos_ids, neg_ids):
        """Binary cross-entropy of positive vs. sampled negative items at every non-padding position.

        :param seq_out: [batch, seq_len, hidden_size]
        :param pos_ids: [batch, seq_len] target item ids (0 = padding, ignored)
        :param neg_ids: [batch, seq_len] sampled negative item ids
        :return: (loss, auc), both averaged over non-padding positions
        """
        pos_emb = self.model.item_embeddings(pos_ids)
        neg_emb = self.model.item_embeddings(neg_ids)
        # [batch*seq_len hidden_size]
        pos = pos_emb.view(-1, pos_emb.size(2))
        neg = neg_emb.view(-1, neg_emb.size(2))
        seq_emb = seq_out.view(-1, self.args.hidden_size)
        pos_logits = torch.sum(pos * seq_emb, -1)  # [batch*seq_len]
        neg_logits = torch.sum(neg * seq_emb, -1)
        # view(-1) instead of assuming seq_len == max_seq_length.
        istarget = (pos_ids > 0).view(-1).float()  # [batch*seq_len]
        loss = torch.sum(
            - torch.log(torch.sigmoid(pos_logits) + 1e-24) * istarget -
            torch.log(1 - torch.sigmoid(neg_logits) + 1e-24) * istarget
        ) / torch.sum(istarget)

        auc = torch.sum(
            ((torch.sign(pos_logits - neg_logits) + 1) / 2) * istarget
        ) / torch.sum(istarget)

        return loss, auc

    def predict_sample(self, seq_out, test_neg_sample):
        """Score ``test_neg_sample`` candidates ([batch, n]) against ``seq_out`` ([batch, hidden])."""
        test_item_emb = self.model.item_embeddings(test_neg_sample)
        test_logits = torch.bmm(test_item_emb, seq_out.unsqueeze(-1)).squeeze(-1)  # [B n]
        return test_logits

    def predict_full(self, seq_out):
        """Score every item embedding against ``seq_out`` ([batch, hidden]) -> [batch, item_num]."""
        test_item_emb = self.model.item_embeddings.weight
        rating_pred = torch.matmul(seq_out, test_item_emb.transpose(0, 1))
        return rating_pred


class ContrastVAETrainer(Trainer):
    """Trainer for the ContrastVAE variants selected by command-line flags."""

    def __init__(self, model, train_dataloader, eval_dataloader, test_dataloader, args):
        super(ContrastVAETrainer, self).__init__(model, train_dataloader, eval_dataloader, test_dataloader, args)
        self.step = 0
        # Use the base class's device so --no_cuda is honoured here too.
        self.cl_criterion = NCELoss(args.temperature, self.device)
        self.variational_dropout = args.variational_dropout

    def kl_anneal_function(self, anneal_cap, step, total_annealing_step):
        """Linear KL-weight schedule: ``step / total_annealing_step``, capped at ``anneal_cap``.

        :param anneal_cap: maximum KL weight
        :param step: global training step (incremented once per optimizer step)
        :param total_annealing_step: number of steps over which the weight grows linearly
        :return: the KL weight for this step
        """
        # borrows from https://github.com/timbmg/Sentence-VAE/blob/master/train.py
        return min(anneal_cap, (1.0 * step) / total_annealing_step)

    @staticmethod
    def _kld(mu, log_var):
        """KL(N(mu, exp(log_var)) || N(0, I)), summed over the last dim and averaged over the rest."""
        return torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=-1))

    def loss_fn_vanila(self, reconstructed_seq1, mu, log_var, target_pos_seq, target_neg_seq, step):
        """Reconstruction loss plus annealed KL for the vanilla VAE.

        :return: (loss, recons_auc)
        """
        kld_loss = self._kld(mu, log_var)
        kld_weight = self.kl_anneal_function(self.args.anneal_cap, step, self.args.total_annealing_step)
        recons_loss1, recons_auc = self.cross_entropy(reconstructed_seq1, target_pos_seq, target_neg_seq)
        loss = recons_loss1 + kld_weight * kld_loss
        return loss, recons_auc

    def loss_fn_latent_clr(self, reconstructed_seq1, reconstructed_seq2, mu1, mu2, log_var1, log_var2, z1, z2, target_pos_seq, target_neg_seq, step):
        """Two reconstruction losses, annealed KL for both views, and a weighted contrastive loss.

        The contrastive loss compares ``z1`` and ``z2`` summed over dim 1.
        :return: (loss, recons_auc of the second view)
        """
        kld_loss1 = self._kld(mu1, log_var1)
        kld_loss2 = self._kld(mu2, log_var2)
        kld_weight = self.kl_anneal_function(self.args.anneal_cap, step, self.args.total_annealing_step)

        recons_loss1, recons_auc = self.cross_entropy(reconstructed_seq1, target_pos_seq, target_neg_seq)
        recons_loss2, recons_auc = self.cross_entropy(reconstructed_seq2, target_pos_seq, target_neg_seq)

        user_representation1 = torch.sum(z1, dim=1)
        user_representation2 = torch.sum(z2, dim=1)
        contrastive_loss = self.cl_criterion(user_representation1, user_representation2)

        loss = recons_loss1 + recons_loss2 + kld_weight * (kld_loss1 + kld_loss2) + self.args.latent_clr_weight * contrastive_loss
        return loss, recons_auc

    def loss_fn_VD_latent_clr(self, reconstructed_seq1, reconstructed_seq2, mu1, mu2, log_var1, log_var2, z1, z2, target_pos_seq, target_neg_seq, step, alpha):
        """Same as ``loss_fn_latent_clr`` plus the ``priorKL(alpha)`` variational-dropout penalty.

        :return: (loss, recons_auc of the second view)
        """
        kld_loss1 = self._kld(mu1, log_var1)
        kld_loss2 = self._kld(mu2, log_var2)
        kld_weight = self.kl_anneal_function(self.args.anneal_cap, step, self.args.total_annealing_step)

        recons_loss1, recons_auc = self.cross_entropy(reconstructed_seq1, target_pos_seq, target_neg_seq)
        recons_loss2, recons_auc = self.cross_entropy(reconstructed_seq2, target_pos_seq, target_neg_seq)

        user_representation1 = torch.sum(z1, dim=1)
        user_representation2 = torch.sum(z2, dim=1)
        contrastive_loss = self.cl_criterion(user_representation1, user_representation2)

        adaptive_alpha_loss = priorKL(alpha)
        loss = recons_loss1 + recons_loss2 + kld_weight * (kld_loss1 + kld_loss2) + self.args.latent_clr_weight * contrastive_loss + adaptive_alpha_loss
        return loss, recons_auc

    def _compute_train_loss(self, input_ids, aug_input_ids, target_pos, target_neg):
        """Forward one batch through the variant selected by the flags; return (loss, recons_auc).

        Flags are checked in this priority order: variational_dropout,
        latent_contrastive_learning, latent_data_augmentation, VAandDA;
        otherwise the vanilla VAE is used.
        """
        if self.variational_dropout:
            (seq1, seq2, mu1, mu2, log_var1, log_var2,
             z1, z2, alpha) = self.model.forward(input_ids, 0, self.step)
            return self.loss_fn_VD_latent_clr(seq1, seq2, mu1, mu2, log_var1, log_var2,
                                              z1, z2, target_pos, target_neg, self.step, alpha)
        if self.args.latent_contrastive_learning:
            (seq1, seq2, mu1, mu2, log_var1, log_var2,
             z1, z2) = self.model.forward(input_ids, 0, self.step)
            return self.loss_fn_latent_clr(seq1, seq2, mu1, mu2, log_var1, log_var2,
                                           z1, z2, target_pos, target_neg, self.step)
        if self.args.latent_data_augmentation:
            (seq1, seq2, mu1, mu2, log_var1, log_var2,
             z1, z2) = self.model.forward(input_ids, aug_input_ids, self.step)
            return self.loss_fn_latent_clr(seq1, seq2, mu1, mu2, log_var1, log_var2,
                                           z1, z2, target_pos, target_neg, self.step)
        if self.args.VAandDA:
            (seq1, seq2, mu1, mu2, log_var1, log_var2,
             z1, z2, alpha) = self.model.forward(input_ids, aug_input_ids, self.step)
            return self.loss_fn_VD_latent_clr(seq1, seq2, mu1, mu2, log_var1, log_var2,
                                              z1, z2, target_pos, target_neg, self.step, alpha)
        seq1, mu, log_var = self.model.forward(input_ids, 0, self.step)
        return self.loss_fn_vanila(seq1, mu, log_var, target_pos, target_neg, self.step)

    def _eval_forward(self, input_ids, aug_input_ids):
        """Evaluation forward pass; returns (reconstructed_seq1, z1, mu1, log_var1).

        ``z1`` is None for the vanilla VAE, which exposes no sampled latent.
        """
        if self.variational_dropout:
            (seq1, _, mu1, _, log_var1, _,
             z1, _, _) = self.model.forward(input_ids, 0, self.step)
        elif self.args.latent_contrastive_learning:
            (seq1, _, mu1, _, log_var1, _,
             z1, _) = self.model.forward(input_ids, 0, self.step)
        elif self.args.latent_data_augmentation:
            (seq1, _, mu1, _, log_var1, _,
             z1, _) = self.model.forward(input_ids, aug_input_ids, self.step)
        elif self.args.VAandDA:
            (seq1, _, mu1, _, log_var1, _,
             z1, _, _) = self.model.forward(input_ids, aug_input_ids, self.step)
        else:  # vanilla beta-VAE with transformer
            seq1, mu1, log_var1 = self.model.forward(input_ids, 0, self.step)
            z1 = None
        return seq1, z1, mu1, log_var1

    def _train_epoch(self, epoch, dataloader):
        """One pass of optimisation over ``dataloader``; logs average loss/AUC."""
        self.model.train()
        rec_avg_loss = 0.0
        rec_cur_loss = 0.0
        rec_avg_auc = 0.0

        for batch in dataloader:
            batch = tuple(t.to(self.device) for t in batch)
            _, input_ids, target_pos, target_neg, _, aug_input_ids = batch  # input_ids, target_ids: [b,max_Sq]

            loss, recons_auc = self._compute_train_loss(input_ids, aug_input_ids, target_pos, target_neg)

            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

            self.step += 1
            rec_cur_loss = loss.item()
            rec_avg_loss += rec_cur_loss
            rec_avg_auc += recons_auc.item()

        post_fix = {
            "epoch": epoch,
            "rec_avg_loss": '{:.4f}'.format(rec_avg_loss / len(dataloader)),
            "rec_cur_loss": '{:.4f}'.format(rec_cur_loss),
            "rec_avg_auc": '{:.4f}'.format(rec_avg_auc / len(dataloader)),
        }

        if (epoch + 1) % self.args.log_freq == 0:
            print(str(post_fix), flush=True)

        with open(self.args.log_file, 'a') as f:
            f.write(str(post_fix) + '\n')

    def _evaluate(self, epoch, dataloader, full_sort):
        """Full-sort evaluation over all items, excluding items in ``args.train_matrix``."""
        if not full_sort:
            # Was `assert "..."`, which always passes because a non-empty string is truthy.
            raise NotImplementedError("We need full_sort evaluation")

        self.model.eval()
        with torch.no_grad():
            if self.args.store_latent:
                # NOTE(review): these tensors are filled below but never returned or saved.
                user_embeddings = torch.zeros((self.args.num_users, self.args.hidden_size))
                seq_mus = torch.zeros((self.args.num_users, self.args.max_seq_length, self.args.hidden_size))
                seq_logvar = torch.zeros((self.args.num_users, self.args.max_seq_length, self.args.hidden_size))

            print("full sort evaluation")
            pred_chunks, answer_chunks = [], []
            for batch in dataloader:
                batch = tuple(t.to(self.device) for t in batch)
                user_ids, input_ids, target_pos, target_neg, answers, aug_input_ids = batch

                recommend_reconstruct1, z1, mu1, log_var1 = self._eval_forward(input_ids, aug_input_ids)

                if self.args.store_latent:
                    istarget = torch.unsqueeze(target_pos > 0, -1)  # [b, max_Sq, 1]
                    latent = z1 if z1 is not None else mu1 + torch.exp(0.5 * log_var1)
                    # The store tensors live on the CPU, so index them with CPU ids.
                    cpu_ids = user_ids.cpu()
                    user_embeddings[cpu_ids, :] = torch.sum(latent * istarget, 1).cpu()
                    seq_mus[cpu_ids, :, :] = mu1.cpu()
                    seq_logvar[cpu_ids, :, :] = log_var1.cpu()

                recommend_output = recommend_reconstruct1[:, -1, :]
                rating_pred = self.predict_full(recommend_output)

                rating_pred = rating_pred.cpu().data.numpy().copy()
                batch_user_index = user_ids.cpu().numpy()

                # Exclude already-seen items. -inf rather than 0 so that they cannot
                # outrank unseen items whose (dot-product) scores are negative.
                rating_pred[self.args.train_matrix[batch_user_index].toarray() > 0] = -np.inf
                # reference: https://stackoverflow.com/a/23734295, https://stackoverflow.com/a/20104162
                ind = np.argpartition(rating_pred, -40)[:, -40:]
                arr_ind = rating_pred[np.arange(len(rating_pred))[:, None], ind]
                arr_ind_argsort = np.argsort(arr_ind)[np.arange(len(rating_pred)), ::-1]
                batch_pred_list = ind[np.arange(len(rating_pred))[:, None], arr_ind_argsort]

                pred_chunks.append(batch_pred_list)
                answer_chunks.append(answers.cpu().data.numpy())

            # Concatenate once instead of np.append per batch (quadratic copying).
            pred_list = np.concatenate(pred_chunks, axis=0)
            answer_list = np.concatenate(answer_chunks, axis=0)
            return self.get_full_sort_score(epoch, answer_list, pred_list)

    def iteration(self, epoch, dataloader, full_sort=False, train=True):
        """Train for one epoch (returns None), or evaluate and return ``get_full_sort_score``'s result.

        :raises NotImplementedError: when evaluating with ``full_sort=False``
        """
        if train:
            self._train_epoch(epoch, dataloader)
            return None
        return self._evaluate(epoch, dataloader, full_sort)


# -------------------------------------------------------------------------------- /utils.py
# -*- coding: utf-8 -*-
# @Time : 2020/3/30 11:06
# @Author : Hui Wang

import numpy as np
import math
import random
import os
import json
import pickle
from scipy.sparse import csr_matrix
from tqdm import tqdm
import multiprocessing

import torch
import torch.nn.functional as F


def set_seed(seed):
    """Seed Python, NumPy and torch (CPU and all GPUs) and make cuDNN deterministic."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # some cudnn methods can be random even after fixing the seed
    # unless you tell it to be deterministic
    torch.backends.cudnn.deterministic = True


def check_path(path):
    """Create directory ``path`` (and parents) if it does not exist yet."""
    if not os.path.exists(path):
        os.makedirs(path)
        print(f'{path} created')


def neg_sample(item_set, item_size):
    """Draw a random item id in [1, item_size - 1] that is not in ``item_set``.

    Loops forever if ``item_set`` covers every id in that range.
    """
    item = random.randint(1, item_size - 1)
    while item in item_set:
        item = random.randint(1, item_size - 1)
    return item


class EarlyStopping:
    """Stop training once no monitored score has improved for ``patience`` consecutive checks.

    Scores are higher-is-better. A check counts as an improvement when ANY
    entry of the new score vector beats the stored best entry by more than ``delta``.
    """

    def __init__(self, checkpoint_path, patience=7, verbose=False, delta=0):
        """
        Args:
            checkpoint_path (str): File the model's state_dict is saved to on improvement.
            patience (int): How many consecutive non-improving checks to tolerate.
                            Default: 7
            verbose (bool): If True, prints a message each time a checkpoint is saved.
                            Default: False
            delta (float): Minimum increase of a score that counts as an improvement.
                            Default: 0
        """
        self.checkpoint_path = checkpoint_path
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.delta = delta

    def compare(self, score):
        """Return True when NO entry of ``score`` improved on ``best_score`` by more than ``delta``."""
        # If any single metric increased, training is considered to still be improving.
        return not any(new > best + self.delta for new, best in zip(score, self.best_score))

    def __call__(self, score, model):
        """Record a new score vector (e.g. HIT@10, NDCG@10) and checkpoint ``model`` on improvement."""
        if self.best_score is None:
            self.best_score = score
            self.score_min = np.array([0] * len(score))
            self.save_checkpoint(score, model)
        elif self.compare(score):
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(score, model)
            self.counter = 0

    def save_checkpoint(self, score, model):
        '''Save ``model.state_dict()`` to ``checkpoint_path`` and remember ``score``.'''
        if self.verbose:
            # Printing ({self.score_min:.6f} --> {score:.6f}) would only work for a single scalar score.
            print(f'Validation score increased.  Saving model ...')
        torch.save(model.state_dict(), self.checkpoint_path)
        self.score_min = score


def kmax_pooling(x, dim, k):
    """Keep the k largest values along ``dim``, in their original order, then squeeze ``dim``."""
    index = x.topk(k, dim=dim)[1].sort(dim=dim)[0]
    return x.gather(dim, index).squeeze(dim)


def avg_pooling(x, dim):
    """Mean of ``x`` along ``dim``."""
    return x.sum(dim=dim) / x.size(dim)


def _generate_rating_matrix(user_seq, num_users, num_items, num_held_out):
    """Binary user x item CSR matrix built from each sequence minus its last ``num_held_out`` items."""
    row = []
    col = []
    for user_id, item_list in enumerate(user_seq):
        history = item_list[:-num_held_out]
        row.extend([user_id] * len(history))
        col.extend(history)
    data = [1] * len(row)
    return csr_matrix((np.array(data), (np.array(row), np.array(col))), shape=(num_users, num_items))


def generate_rating_matrix_valid(user_seq, num_users, num_items):
    """Interaction matrix for validation: every item except each user's last two."""
    return _generate_rating_matrix(user_seq, num_users, num_items, 2)


def generate_rating_matrix_test(user_seq, num_users, num_items):
    """Interaction matrix for testing: every item except each user's last one."""
    return _generate_rating_matrix(user_seq, num_users, num_items, 1)


def _read_user_seqs(data_file):
    """Parse a file whose lines are 'user_id item1 item2 ...' into lists of int item ids.

    The leading user id is discarded; line order defines the user index.
    """
    with open(data_file) as f:
        lines = f.readlines()
    user_seq = []
    for line in lines:
        _user, items = line.strip().split(' ', 1)
        user_seq.append([int(item) for item in items.split(' ')])
    return user_seq


def get_user_seqs(data_file):
    """
    load txt data file
    :param data_file: path of dataset, every line: [user_id item1 item2 item3 ...]
    :return: (user_seq, max_item, valid_rating_matrix, test_rating_matrix, num_users)
    """
    user_seq = _read_user_seqs(data_file)
    max_item = max(item for items in user_seq for item in items)

    num_users = len(user_seq)
    # Id 0 is padding; the extra id above max_item is presumably reserved
    # (e.g. a mask token) — TODO confirm against the model.
    num_items = max_item + 2

    valid_rating_matrix = generate_rating_matrix_valid(user_seq, num_users, num_items)
    test_rating_matrix = generate_rating_matrix_test(user_seq, num_users, num_items)
    return user_seq, max_item, valid_rating_matrix, test_rating_matrix, num_users


def get_user_seqs_replace(data_file, num_items):
    """
    load txt data file, using a caller-supplied item-vocabulary size
    :param data_file: path of dataset, every line: [user_id item1 item2 item3 ...]
    :return: (user_seq, valid_rating_matrix, test_rating_matrix, num_users)
    """
    user_seq = _read_user_seqs(data_file)
    num_users = len(user_seq)

    valid_rating_matrix = generate_rating_matrix_valid(user_seq, num_users, num_items)
    test_rating_matrix = generate_rating_matrix_test(user_seq, num_users, num_items)
    return user_seq, valid_rating_matrix, test_rating_matrix, num_users


def get_user_seqs_long(data_file):
    """Return (user_seq, max_item, long_sequence), where long_sequence concatenates all sequences."""
    user_seq = _read_user_seqs(data_file)
    # Translated original note: "the ones after are all sampled negatives" —
    # presumably long_sequence is the pool negatives are drawn from; confirm with callers.
    long_sequence = [item for items in user_seq for item in items]
    max_item = max(long_sequence)
    return user_seq, max_item, long_sequence


def get_user_seqs_and_sample(data_file, sample_file):
    """Read interaction sequences and the matching pre-sampled candidate lists (one line per user)."""
    user_seq = _read_user_seqs(data_file)
    max_item = max(item for items in user_seq for item in items)

    sample_seq = _read_user_seqs(sample_file)

    assert len(user_seq) == len(sample_seq)

    return user_seq, max_item, sample_seq


def get_item2attribute_json(data_file):
    """Load an item -> attribute-list JSON object (first line of the file) and the largest attribute id."""
    with open(data_file) as f:
        item2attribute = json.loads(f.readline())
    attribute_set = set()
    for attributes in item2attribute.values():
        attribute_set.update(attributes)
    attribute_size = max(attribute_set)  # 331
    return item2attribute, attribute_size


def get_metric(pred_list, topk=10):
    """HIT@topk, NDCG@topk and MRR from 0-based ranks of the ground-truth item."""
    NDCG = 0.0
    HIT = 0.0
    MRR = 0.0
    # [batch] the answer's rank
    for rank in pred_list:
        MRR += 1.0 / (rank + 1.0)
        if rank < topk:
            NDCG += 1.0 / np.log2(rank + 2.0)
            HIT += 1.0
    return HIT / len(pred_list), NDCG / len(pred_list), MRR / len(pred_list)


def precision_at_k_per_sample(actual, predicted, topk):
    """Fraction of ``topk`` slots whose predicted item is in ``actual`` (all of ``predicted`` is scanned)."""
    num_hits = sum(1 for place in predicted if place in actual)
    return num_hits / (topk + 0.0)


def precision_at_k(actual, predicted, topk):
    """Mean over users of |actual ∩ predicted[:topk]| / topk."""
    sum_precision = 0.0
    num_users = len(predicted)
    for i in range(num_users):
        act_set = set(actual[i])
        pred_set = set(predicted[i][:topk])
        sum_precision += len(act_set & pred_set) / float(topk)

    return sum_precision / num_users
def recall_at_k(actual, predicted, topk):
    """Compute mean Recall@topk over users that have at least one ground-truth item.

    Args:
        actual: sequence indexed by user position; ``actual[i]`` is an iterable
            of ground-truth items for user ``i``.
        predicted: sequence indexed by user position; ``predicted[i]`` is a
            ranked sequence of recommended items (only the first ``topk`` are used).
        topk: number of leading predictions to consider.

    Returns:
        ``(mean_recall, recall_dict)``. ``recall_dict`` maps user position to that
        user's recall. Users with an empty ground-truth set are excluded from both
        the mean and the dict. ``mean_recall`` is 0.0 when no user has ground truth.
    """
    sum_recall = 0.0
    true_users = 0
    recall_dict = {}
    for i, user_preds in enumerate(predicted):
        act_set = set(actual[i])
        if not act_set:
            continue
        pred_set = set(user_preds[:topk])
        one_user_recall = len(act_set & pred_set) / float(len(act_set))
        recall_dict[i] = one_user_recall
        sum_recall += one_user_recall
        true_users += 1
    if true_users == 0:
        # Avoid ZeroDivisionError when no user has any ground truth.
        return 0.0, recall_dict
    return sum_recall / true_users, recall_dict


def cal_mrr(actual, predicted):
    """Compute Mean Reciprocal Rank of the first relevant prediction per user.

    Args:
        actual: sequence indexed by user position of ground-truth item iterables.
        predicted: sequence indexed by user position of ranked prediction sequences.

    Returns:
        ``(mean_mrr, mrr_dict)``. Every user position appears in ``mrr_dict``, with
        0.0 when none of its predictions is relevant. The mean is taken over
        *all* users, including those without a hit. It is 0.0 for empty input.
    """
    sum_mrr = 0.
    mrr_dict = {}
    for i, pred_list in enumerate(predicted):
        act_set = set(actual[i])
        # 1-based rank of the first relevant prediction, or None if nothing hits.
        first_hit = next(
            (rank for rank, item in enumerate(pred_list, start=1) if item in act_set),
            None,
        )
        one_user_mrr = 0. if first_hit is None else 1.0 / first_hit
        sum_mrr += one_user_mrr
        mrr_dict[i] = one_user_mrr
    # len() rather than truthiness: `predicted` may be a numpy array.
    if len(predicted) == 0:
        return 0., mrr_dict
    return sum_mrr / len(predicted), mrr_dict


def apk(actual, predicted, k=10):
    """
    Computes the average precision at k.
    This function computes the average precision at k between two lists of
    items.
    Parameters
    ----------
    actual : list
             A list of elements that are to be predicted (order doesn't matter)
    predicted : list
                A list of predicted elements (order does matter)
    k : int, optional
        The maximum number of predicted elements
    Returns
    -------
    score : double
            The average precision at k over the input lists
    """
    if len(predicted) > k:
        predicted = predicted[:k]

    score = 0.0
    num_hits = 0.0

    for i, p in enumerate(predicted):
        # Count each relevant item only at its first occurrence in the ranking.
        if p in actual and p not in predicted[:i]:
            num_hits += 1.0
            score += num_hits / (i + 1.0)

    if not actual:
        return 0.0

    return score / min(len(actual), k)


def mapk(actual, predicted, k=10):
    """
    Computes the mean average precision at k.
    This function computes the mean average precision at k between two lists
    of lists of items.
    Parameters
    ----------
    actual : list
             A list of lists of elements that are to be predicted
             (order doesn't matter in the lists)
    predicted : list
                A list of lists of predicted elements
                (order matters in the lists)
    k : int, optional
        The maximum number of predicted elements
    Returns
    -------
    score : double
            The mean average precision at k over the input lists
    """
    return np.mean([apk(a, p, k) for a, p in zip(actual, predicted)])


def ndcg_k(actual, predicted, topk):
    """Compute mean binary-relevance NDCG@topk over all users.

    Args:
        actual: sequence indexed by user position of ground-truth item sequences.
        predicted: sequence indexed by user position of ranked prediction sequences.
            Lists shorter than ``topk`` are scored on the items they contain.
        topk: ranking cutoff.

    Returns:
        ``(mean_ndcg, ndcg_dict)``, where ``ndcg_dict`` maps every user position to
        its NDCG. ``mean_ndcg`` is 0.0 for empty input.
    """
    res = 0
    ndcg_dict = {}
    for user_id in range(len(actual)):
        act_set = set(actual[user_id])
        # Ideal DCG: all min(topk, #ground-truth) slots filled with relevant items.
        idcg = idcg_k(min(topk, len(actual[user_id])))
        dcg_k = sum(
            int(item in act_set) / math.log(j + 2, 2)
            for j, item in enumerate(predicted[user_id][:topk])
        )
        res += dcg_k / idcg
        ndcg_dict[user_id] = dcg_k / idcg
    if len(actual) == 0:
        return 0.0, ndcg_dict
    return res / float(len(actual)), ndcg_dict


# Calculates the ideal discounted cumulative gain at k
def idcg_k(k):
    """Return the ideal binary DCG for ``k`` relevant items, or 1.0 when that is 0."""
    res = sum(1.0 / math.log(i + 2, 2) for i in range(k))
    # Returning 1.0 instead of 0 keeps callers' dcg / idcg from dividing by zero.
    return res or 1.0


def dcg_at_k(r, k, method=1):
    """Score is discounted cumulative gain (dcg)
    Relevance is positive real values.  Can use binary
    as the previous methods.
    Returns:
        Discounted cumulative gain
    """
    # np.asfarray was removed in NumPy 2.0; this is its exact float64 equivalent.
    r = np.asarray(r, dtype=float)[:k]
    if r.size:
        if method == 0:
            return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))
        elif method == 1:
            return np.sum(r / np.log2(np.arange(2, r.size + 2)))
        else:
            raise ValueError('method must be 0 or 1.')
    return 0.


def ndcg_at_k(r, k, method=1):
    """Score is normalized discounted cumulative gain (ndcg)
    Relevance is positive real values.  Can use binary
    as the previous methods.
    Returns:
        Normalized discounted cumulative gain
    """
    dcg_max = dcg_at_k(sorted(r, reverse=True), k, method)
    if not dcg_max:
        return 0.
    return dcg_at_k(r, k, method) / dcg_max


def itemperf_recall(ranks, k):
    """Return the fraction of ``ranks`` that are ``<= k`` (0 for empty ``ranks``)."""
    ranks = np.array(ranks)
    if len(ranks) == 0:
        return 0
    return np.sum(ranks <= k) / len(ranks)


def itemperf_ndcg(ranks, k, size):
    """Average NDCG@k over rankings that each contain a single relevant item.

    Each entry of ``ranks`` is the 1-based position of the only relevant item in
    a ranking of ``size`` candidates. The ideal DCG is then 1, so NDCG@k is
    ``1 / log2(rank + 1)`` when ``rank <= k`` and 0 otherwise. This equals
    building a one-hot relevance vector of length ``size`` and calling
    :func:`ndcg_at_k`, without allocating and sorting that vector per rank.

    Raises:
        ValueError: if a rank is outside ``[1, size]``.
    """
    if len(ranks) == 0:
        return 0.
    ndcg = 0.0
    for onerank in ranks:
        if not 1 <= onerank <= size:
            raise ValueError(f"rank {onerank} is outside [1, {size}]")
        if onerank <= k:
            ndcg += 1.0 / np.log2(onerank + 1)
    return ndcg / len(ranks)


# (bucket name, membership test on training-sequence length, count message).
# Buckets overlap on purpose. The messages are kept verbatim for log
# compatibility; note that "less than 3/7" actually means "<= 3/7".
_USER_SEQ_LEN_BUCKETS = (
    ("short", lambda n: n <= 3,
     "testing #of short seq users with less than 3 training points"),
    ("short7", lambda n: n <= 7,
     "testing #of short seq users with less than 7 training points"),
    ("short37", lambda n: 3 < n <= 7,
     "testing #of short seq users with 3 - 7 training points"),
    ("medium3", lambda n: 3 < n < 20,
     "testing #of short seq users with medium3 training points"),
    ("medium7", lambda n: 7 < n < 20,
     "testing #of short seq users with medium7 training points"),
    ("long", lambda n: n >= 20,
     "testing #of short seq users with >= 20 training points"),
)


def get_user_performance_perpopularity(train, results_users, Ks):
    """Print recall/NDCG/MRR averaged over users bucketed by training-sequence length.

    Args:
        train: mapping from user key to that user's training sequence. Only
            ``len(train[user])`` is used.
        results_users: ``[recall_dict_list, ndcg_dict_list, mrr_dict]``.
            ``recall_dict_list[k_ind]`` and ``ndcg_dict_list[k_ind]`` map a user key
            to that user's metric at the ``k_ind``-th cutoff, and ``mrr_dict`` maps a
            user key to its reciprocal rank. Every key of ``train`` must be present.
        Ks: cutoffs. Only ``len(Ks)`` is used, to size the per-bucket arrays.

    For every non-empty bucket, prints its user count. Then prints the averaged
    metrics of every bucket (all zeros for empty buckets). Returns None.
    """
    recall_dict_list, ndcg_dict_list, mrr_dict = results_users
    bucket_results = {
        name: {"recall": np.zeros(len(Ks)), "ndcg": np.zeros(len(Ks)), "mrr": 0.}
        for name, _, _ in _USER_SEQ_LEN_BUCKETS
    }
    bucket_counts = {name: 0 for name, _, _ in _USER_SEQ_LEN_BUCKETS}

    # One pass over users. Per-bucket sums still accumulate users in train order.
    test_users = list(train.keys())
    for user in tqdm(test_users):
        seq_len = len(train[user])
        for name, in_bucket, _ in _USER_SEQ_LEN_BUCKETS:
            if not in_bucket(seq_len):
                continue
            bucket_counts[name] += 1
            results = bucket_results[name]
            for k_ind in range(len(recall_dict_list)):
                results["recall"][k_ind] += recall_dict_list[k_ind][user]
                results["ndcg"][k_ind] += ndcg_dict_list[k_ind][user]
            results["mrr"] += mrr_dict[user]

    for name, _, label in _USER_SEQ_LEN_BUCKETS:
        count = bucket_counts[name]
        if count > 0:
            results = bucket_results[name]
            results["recall"] /= count
            results["ndcg"] /= count
            results["mrr"] /= count
            print(f"{label}: {count}")

    for name, _, _ in _USER_SEQ_LEN_BUCKETS:
        print(f"test{name}: {bucket_results[name]}")


# Cutoffs assumed by eval_one_setitems when the task tuple carries no explicit k.
_ITEMPERF_DEFAULT_KS = [1, 5, 10, 15, 20, 40]


def eval_one_setitems(x):
    """Compute recall and NDCG at one cutoff for one item-frequency interval.

    Args:
        x: tuple ``(ranks, k_ind, test_num_items, freq_ind[, k])``. If the optional
            fifth element ``k`` is present, it is the cutoff. Otherwise the cutoff is
            ``[1, 5, 10, 15, 20, 40][k_ind]``, kept for backward compatibility.

    Returns:
        ``({"recall": ..., "ndcg": ...}, k_ind, freq_ind)``.
    """
    ranks, k_ind, test_num_items, freq_ind = x[:4]
    k = x[4] if len(x) > 4 else _ITEMPERF_DEFAULT_KS[k_ind]
    result = {
        "recall": itemperf_recall(ranks, k),
        "ndcg": itemperf_ndcg(ranks, k, test_num_items),
    }
    return result, k_ind, freq_ind


def get_item_performance_perpopularity(items_in_freqintervals, all_pos_items_ranks, Ks, freq_quantiles, num_items):
    """Print recall/NDCG per item-frequency interval, evaluated in a process pool.

    Args:
        items_in_freqintervals: one item list per frequency interval. Its length is
            expected to be ``len(freq_quantiles) + 1``.
        all_pos_items_ranks: mapping from item to the ranks at which it was a
            positive target. Missing items are treated as having no ranks.
        Ks: cutoffs to report.
        freq_quantiles: interval boundaries, used only in the printed headers.
        num_items: number of ranked candidates, forwarded to ``itemperf_ndcg``.
    """
    test_num_items_in_intervals = []
    interval_results = [{'recall': np.zeros(len(Ks)), 'ndcg': np.zeros(len(Ks))} for _ in range(len(items_in_freqintervals))]

    tasks = []
    for freq_ind, item_list in enumerate(items_in_freqintervals):
        all_ranks = []
        interval_items = []
        for item in item_list:
            pos_ranks_oneitem = all_pos_items_ranks.get(item, [])
            if len(pos_ranks_oneitem) > 0:
                interval_items.append(item)
                all_ranks.extend(pos_ranks_oneitem)
        for k_ind, k in enumerate(Ks):
            # Pass the actual cutoff so results are correct for any Ks.
            tasks.append((all_ranks, k_ind, num_items, freq_ind, k))
        test_num_items_in_intervals.append(interval_items)

    # cpu_count() // 2 is 0 on a single-core machine, and Pool(0) raises.
    # The with-block terminates the pool after map() has returned all results,
    # so worker processes are not leaked.
    cores = max(1, multiprocessing.cpu_count() // 2)
    with multiprocessing.Pool(cores) as pool:
        batch_item_result = pool.map(eval_one_setitems, tasks)

    for result_dict, k_ind, freq_ind in batch_item_result:
        interval_results[freq_ind]['recall'][k_ind] = result_dict['recall']
        interval_results[freq_ind]['ndcg'][k_ind] = result_dict['ndcg']

    item_freq = freq_quantiles
    for i in range(len(item_freq) + 1):
        if i == 0:
            print('For items in freq between 0 - %d with %d items: ' % (item_freq[i], len(test_num_items_in_intervals[i])))
        elif i == len(item_freq):
            print('For items in freq between %d - max with %d items: ' % (item_freq[i - 1], len(test_num_items_in_intervals[i])))
        else:
            print('For items in freq between %d - %d with %d items: ' % (item_freq[i - 1], item_freq[i], len(test_num_items_in_intervals[i])))
        for k_ind, k in enumerate(Ks):
            print('Recall@%d:%.6f, NDCG@%d:%.6f' % (k, interval_results[i]['recall'][k_ind], k, interval_results[i]['ndcg'][k_ind]))