├── CLEA-new
│   ├── Dunn_0.sh
│   ├── Instacart_0.sh
│   ├── dataset
│   │   ├── Dunn
│   │   │   └── user_date_tran_dict_new.txt
│   │   └── Instacart
│   │       └── user_date_tran_dict_new.txt
│   ├── main_1.py
│   └── module
│       ├── config.py
│       ├── logger.py
│       ├── model_1.py
│       └── util.py
├── CLEA.pdf
├── CLEA
│   ├── Dunn_0.sh
│   ├── Instacart_0.sh
│   ├── dataset
│   │   ├── Dunn
│   │   │   └── user_date_tran_dict_new.txt
│   │   └── Instacart
│   │       └── user_date_tran_dict_new.txt
│   ├── main_1.py
│   └── module
│       ├── config.py
│       ├── logger.py
│       ├── model_1.py
│       └── util.py
├── README.md
└── dataset_process_instacart.py
/CLEA-new/Dunn_0.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/bin/bash 3 | python main_1.py \ 4 | --pos_margin 0.1 \ 5 | --neg_margin 0.9 \ 6 | --same_embedding 1 \ 7 | --dataset 'Dunn' \ 8 | --max_basket_size 50 \ 9 | --max_basket_num 6 \ 10 | --num_product 4995 \ 11 | --num_users 36421 \ 12 | --to2 0 \ 13 | --alternative_train_batch 500 \ 14 | --test_every_epoch 4 \ 15 | --G1_flag 0 \ 16 | --device 1 \ 17 | --log_fire 'basemodel' \ 18 | --dropout 0.2 \ 19 | --lr 0.001 \ 20 | --l2 0.00001 \ 21 | --output_dir './result' \ 22 | --pretrain_epoch 20 \ 23 | --before_epoch 0 \ 24 | --epoch 10 \ 25 | --batch_size 256 \ 26 | --embedding_dim 128 \ 27 | --temp_learn 1 \ 28 | --history 1 29 | 30 | 31 | python main_1.py \ 32 | --pos_margin 0.1 \ 33 | --neg_margin 0.9 \ 34 | --same_embedding 1 \ 35 | --dataset 'Dunn' \ 36 | --max_basket_size 50 \ 37 | --max_basket_num 6 \ 38 | --num_product 4995 \ 39 | --num_users 36421 \ 40 | --to2 0 \ 41 | --alternative_train_epoch 10 \ 42 | --alternative_train_epoch_D 10 \ 43 | --alternative_train_batch 200 \ 44 | --test_every_epoch 4 \ 45 | --G1_flag 0 \ 46 | --device 1 \ 47 | --log_fire 't_10_ANN_0001_triple_1_lr_001_G1_001_pos_0.1' \ 48 | --dropout 0.2 \ 49 | --lr 0.001 \ 50 | --G1_lr 0.001 \ 51 | --l2 0.00001 \ 52 | --output_dir './result' \ 53 | --pretrain_epoch 2 \ 54 | --before_epoch 2 \ 55 | --epoch 60 \ 56 | --batch_size 256 \ 57 | --embedding_dim 128 \ 58 | --temp_learn 0 \ 59 | --temp_min 0.2 \ 60 | --ANNEAL_RATE 0.0001 \ 61 | --temp 10 \ 62 | --history 1 -------------------------------------------------------------------------------- /CLEA-new/Instacart_0.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/bin/bash 3 | python main_1.py \ 4 | --pos_margin 0.1 \ 5 | --neg_margin 0.9 \ 6 | --same_embedding 1 \ 7 | --dataset 'Instacart' \ 8 | --num_product 8222 \ 9 | --num_users 6886 \ 10 | --max_basket_size 35 \ 11 | --max_basket_num 32 \ 12 | --to2 0 \ 13 | --alternative_train_batch 1000 \ 14 | --test_every_epoch 4 \ 15 | --G1_flag 0 \ 16 | --device 0 \ 17 | --log_fire 'basemodel' \ 18 | --dropout 0.2 \ 19 | --lr 0.001 \ 20 | --l2 0.00001 \ 21 | --output_dir './result' \ 22 | --pretrain_epoch 20 \ 23 | --before_epoch 0 \ 24 | --epoch 10 \ 25 | --batch_size 256 \ 26 | --embedding_dim 128 \ 27 | --temp_learn 1 \ 28 | --history 1 29 | 30 | 31 | python main_1.py \ 32 | --same_embedding 1 \ 33 | --pos_margin 0.1 \ 34 | --neg_margin 0.9 \ 35 | --dataset 'Instacart' \ 36 | --num_product 8222 \ 37 | --num_users 6886 \ 38 | --max_basket_size 35 \ 39 | --max_basket_num 32 \ 40 | --to2 0 \ 41 | --alternative_train_batch 1000 \ 42 | --alternative_train_epoch 5 \ 43 | --alternative_train_epoch_D 5 \ 44 | --test_every_epoch 4 \ 45 | --G1_flag 0 \ 46 | --device 0 \ 47 | --log_fire 't_10_ANN_0001_same_1_triple_1_lr_001_G1_001_diftemp0_0_pos_01_epoch_G5_D5' \ 48 | --dropout 0.2 \ 49 | --lr 0.001 \ 50 | --G1_lr 0.001 \ 51 | --l2 0.00001 \ 52 |
--output_dir './result' \ 53 | --pretrain_epoch 2 \ 54 | --before_epoch 2 \ 55 | --epoch 40 \ 56 | --batch_size 128 \ 57 | --embedding_dim 128 \ 58 | --temp_learn 0 \ 59 | --temp_min 0.3 \ 60 | --ANNEAL_RATE 0.0001 \ 61 | --temp 10 \ 62 | --history 1 63 | 64 | -------------------------------------------------------------------------------- /CLEA-new/module/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('--same_embedding', type=int, default=1, help='share item embeddings between G1 and G2 (1) or let G1 learn its own (0)') 5 | parser.add_argument('--test_every_epoch', type=int, default=5, help='evaluate every N epochs') 6 | parser.add_argument('--pos_margin', type=float, default=0.3, help='lower margin on the kept-item ratio of the positive sub-basket') 7 | parser.add_argument('--neg_margin', type=float, default=0.7, help='upper margin on the kept-item ratio of the negative sub-basket') 8 | parser.add_argument('--next_k', type=int, default=2, help='number of trailing baskets held out as prediction targets') 9 | parser.add_argument('--neg_ratio', type=int, default=1, help='neg_ratio') 10 | parser.add_argument('--judge_ratio', type=int, default = 1, help='judge_ratio') 11 | parser.add_argument('--device_id', type=int, default=0, help='GPU_ID') 12 | parser.add_argument('--seed', type=int, default=24, help='seed') 13 | parser.add_argument('--G1_flag', type=int, default=0, help='train_type : with G1 1 / with no G1 -1 / from no G1 to G1 0') 14 | 15 | parser.add_argument('--sd1', type=float, default=1, help='sd1') 16 | parser.add_argument('--sd2', type=float, default=1, help='sd2') 17 | parser.add_argument('--sd3', type=float, default=1, help='sd3') 18 | parser.add_argument('--sd4', type=float, default=1, help='sd4') 19 | parser.add_argument('--sd5', type=float, default=1, help='sd5') 20 | 21 | parser.add_argument('--basket_pool_type', type=str, default='avg', help='basket_pool_type') 22 | parser.add_argument('--num_layer', type=int, default=1, help='num_layer') 23 | parser.add_argument('--test_type', type=int, default=1000, help='0:old 1000:1000 500:500') 24 | 25 | parser.add_argument('--group_split1', type=int, default=4, help='basket_group_split') 26 | parser.add_argument('--group_split2', type=int, default=6, help='basket_group_split') 27 | parser.add_argument('--max_basket_size', type=int, default=35, help='max_basket_size') 28 | parser.add_argument('--max_basket_num', type=int, default=32, help='max_basket_num') 29 | parser.add_argument('--dataset', type=str, default='Instacart', help='dataset name') 30 | parser.add_argument('--num_product', type=int, default=8222 , help='n_items TaFeng:9963 Instacart:8222 Delicious:6539') 31 | parser.add_argument('--num_users', type=int, default= 6886, help='n_users TaFeng:16060 Instacart:6885 Delicious:1735') 32 | parser.add_argument('--to2', type=int, default= 0, help='') 33 | parser.add_argument('--distrisample', type=int, default= 0, help='') 34 | 35 | parser.add_argument('--output_dir', type=str, default='./result', help='output directory') 36 | parser.add_argument('--log_fire', type=str, default='test', help='log file name suffix') #_learning 37 | parser.add_argument('--temp', type=float, default=1, help='initial Gumbel-Softmax temperature') #1 38 | parser.add_argument('--temp_min', type=float, default=0.3, help='minimum Gumbel-Softmax temperature') #0.3 39 | parser.add_argument('--pretrain_epoch', type=int, default= 2, help='number of pretraining epochs') 40 | parser.add_argument('--pretrain_h', type=int, default= 2, help='pretrain_h') 41 | parser.add_argument('--batch_size', type=int, default=128,
help='batch_size') 42 | parser.add_argument('--epoch', type=int, default=50, help='the number of epochs to train for') 43 | parser.add_argument('--ANNEAL_RATE', type=float, default=0.0002, help='ANNEAL_RATE') #0.0003 44 | parser.add_argument('--lr', type=float, default=0.00001, help='learning rate') 45 | parser.add_argument('--H_lr', type=float, default=0.00001, help='learning rate') # [0.001, 0.0005, 0.0001] 46 | parser.add_argument('--G1_lr', type=float, default=0.0001, help='learning rate') 47 | parser.add_argument('--G2D_lr', type=float, default=0.0001, help='learning rate') 48 | parser.add_argument('--l2', type=float, default=0.00001, help='l2 penalty') # [0.001, 0.0005, 0.0001, 0.00005, 0.00001] 49 | parser.add_argument('--embedding_dim', type=int, default=256,help='hidden size') # [0.001, 0.0005, 0.0001, 0.00005, 0.00001] 50 | parser.add_argument('--alternative_train_epoch', type=int, default=5, help='epochs per alternating training round') 51 | parser.add_argument('--alternative_train_epoch_D', type=int, default=1, help='discriminator epochs per alternating training round') 52 | parser.add_argument('--alternative_train_epoch_G2', type=int, default=1, help='G2 epochs per alternating training round') 53 | parser.add_argument('--alternative_train_batch', type=int, default=200, help='batch interval for Gumbel temperature annealing') 54 | parser.add_argument('--dropout', type=float, default=0.2, help='dropout') 55 | parser.add_argument('--history', type=int, default=0, help='history') 56 | parser.add_argument('--temp_learn', type=int, default=0, help='whether the Gumbel-Softmax temperature is learnable') 57 | parser.add_argument('--before_epoch', type=int, default=2, help='before_epoch') #18 63 58 | parser.add_argument('--double', type=int, default=0, help='double') 59 | parser.add_argument('--clip_value', type=float, default=0.1, help='gradient clip value') 60 | parser.add_argument('--rnn', type=str, default='GRU', help='RNN cell type') 61 | 62 | parser.add_argument('--soft', type=int, default=0, help='whether to train with soft attention') 63 | 64 | parser.add_argument('--margin1', type=float, default=0.1, help='new_tripleloss1')# TODO new 65 | parser.add_argument('--margin2', type=float, default=0.1, help='new_tripleloss1')# TODO new 66 | args = parser.parse_args() 67 | 68 | 69 | # -*- coding:utf-8 -*- 70 | class Config(object): 71 | def __init__(self): 72 | self.margin1 = args.margin1# TODO new 73 | self.margin2 = args.margin2# TODO new 74 | 75 | self.same_embedding = args.same_embedding 76 | self.neg_margin = args.neg_margin 77 | self.pos_margin = args.pos_margin 78 | self.next_k = args.next_k 79 | self.pretrain_h = args.pretrain_h 80 | self.soft = args.soft 81 | 82 | self.rnn = args.rnn 83 | self.double = args.double 84 | self.alternative_train_epoch_D = args.alternative_train_epoch_D 85 | self.alternative_train_epoch_G2 = args.alternative_train_epoch_G2 86 | self.clip_value = args.clip_value 87 | self.temp_learn = args.temp_learn 88 | self.output_dir = args.output_dir 89 | self.G1_lr = args.G1_lr 90 | self.distrisample = args.distrisample 91 | self.pretrain_epoch = args.pretrain_epoch 92 | self.to2 = args.to2 93 | self.MODEL_DIR = './runs' 94 | self.input_dir = 'dataset/{}'.format(args.dataset)+'/user_date_tran_dict_new.txt' 95 | self.dataset = args.dataset 96 | self.epochs = args.epoch 97 | self.device_id = args.device_id 98 | self.log_interval = 500 # num of batches between two logging #300 99 | self.num_users = args.num_users # 100 | self.num_product = args.num_product 101 | self.item_list = list(range(args.num_product)) 102 | self.seed = args.seed 103 | self.test_ratio = 100 104 | self.log_fire = args.log_fire 105 | self.test_type =
args.test_type 106 | self.test_every_epoch = args.test_every_epoch 107 | 108 | 109 | self.alternative_train_epoch = args.alternative_train_epoch 110 | self.alternative_train_batch = args.alternative_train_batch 111 | self.max_basket_size = args.max_basket_size 112 | self.max_basket_num = args.max_basket_num 113 | self.group_split1 = args.group_split1 114 | self.group_split2 = args.group_split2 115 | self.clip = 0 116 | self.batch_size = args.batch_size 117 | self.neg_ratio = args.neg_ratio 118 | self.judge_ratio = args.judge_ratio 119 | self.sd1 = args.sd1 120 | self.sd2 = args.sd2 121 | self.sd3 = args.sd3 122 | self.sd4 = args.sd4 123 | self.sd5 = args.sd5 124 | self.learning_rate = args.lr 125 | self.H_lr = args.H_lr 126 | self.G1_lr = args.G1_lr 127 | self.G2D_lr = args.G2D_lr 128 | self.dropout = args.dropout 129 | self.weight_decay = args.l2 130 | self.basket_pool_type = args.basket_pool_type 131 | self.num_layer = args.num_layer 132 | 133 | self.embedding_dim = args.embedding_dim 134 | self.ANNEAL_RATE = args.ANNEAL_RATE 135 | 136 | self.temp = args.temp 137 | self.temp_min = args.temp_min 138 | self.before_epoch = args.before_epoch 139 | 140 | self.G1_flag = args.G1_flag 141 | self.histroy = args.history 142 | 143 | def list_all_member(self, logger): 144 | for name, value in vars(self).items(): 145 | if not name.startswith('item'): 146 | logger.info('%s=%s' % (name, value)) 147 | 148 | 149 | ''' 150 | 151 | ''' -------------------------------------------------------------------------------- /CLEA-new/module/logger.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import logging 4 | 5 | 6 | class Logger(object): 7 | 8 | def __init__(self, filename): 9 | 10 | self.logger = logging.getLogger(filename) 11 | self.logger.setLevel(logging.DEBUG) 12 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d: %(message)s', 13 | datefmt='%Y-%m-%d %H:%M:%S') 14 | 15 | # write into file 16 | fh = logging.FileHandler(filename) 17 | fh.setLevel(logging.DEBUG) 18 | fh.setFormatter(formatter) 19 | 20 | # show on console 21 | ch = logging.StreamHandler(sys.stdout) 22 | ch.setLevel(logging.DEBUG) 23 | ch.setFormatter(formatter) 24 | 25 | # add to Handler 26 | self.logger.addHandler(fh) 27 | self.logger.addHandler(ch) 28 | 29 | def _flush(self): 30 | for handler in self.logger.handlers: 31 | handler.flush() 32 | 33 | def debug(self, message): 34 | self.logger.debug(message) 35 | self._flush() 36 | 37 | def info(self, message): 38 | self.logger.info(message) 39 | self._flush() 40 | 41 | def warning(self, message): 42 | self.logger.warning(message) 43 | self._flush() 44 | 45 | def error(self, message): 46 | self.logger.error(message) 47 | self._flush() 48 | 49 | def critical(self, message): 50 | self.logger.critical(message) 51 | self._flush() 52 | 53 | 54 | if __name__ == '__main__': 55 | log = Logger('NB.log') 56 | log.debug('debug') 57 | log.info('info') 58 | log.warning('warning') 59 | log.error('error') 60 | log.critical('critical') 61 | 62 | -------------------------------------------------------------------------------- /CLEA-new/module/model_1.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | 9 | class NBModel(nn.Module): # 10 | def __init__(self, config, device): 11 | super(NBModel, self).__init__() 12 | self.device = device 13 | self.sd1 = 
config.sd1 14 | self.sd2 = config.sd2 15 | self.sd3 = config.sd3 16 | self.sd4 = config.sd4 17 | self.sd5 = config.sd5 18 | 19 | self.margin1 = config.margin1# TODO new 20 | self.margin2 = config.margin2# TODO new 21 | 22 | self.neg_margin = config.neg_margin 23 | self.pos_margin = config.pos_margin 24 | 25 | self.num_users = config.num_users 26 | 27 | self.judge_ratio = config.judge_ratio ### 28 | self.num_product = config.num_product 29 | 30 | # self.embed = nn.Embedding(config.num_product + 1, config.embedding_dim, padding_idx=0) # , 31 | # self.user_embed = nn.Embedding(config.num_users, config.embedding_dim) 32 | 33 | self.D = Discriminator(config, self.device) 34 | self.G0 = Generator1(config, self.device) 35 | self.G2 = Generator2(config, self.device) 36 | 37 | self.G1_flag = 0 38 | self.mse = nn.MSELoss() 39 | 40 | def init_weight(self): 41 | torch.nn.init.xavier_normal_(self.embed.weight) 42 | 43 | # profile 44 | def forward(self, T, uid, input_seq_tensor, tar_b_tensor, weight, history, neg_set_tensor=None, train=True, 45 | G1flag=0, pretrain=0, sd2=1): 46 | ''' 47 | 48 | :param T: is Gumbel softmax's temperature 49 | :param uid: is userid 50 | :param input_seq_tensor: K * B 51 | :param tar_b_tensor: K 52 | :param weight: itemnum 53 | :param history: K * itemnum 54 | :param neg_set_tensor: K * neg_ratio 55 | :param train: whether train 56 | :return: classify prob metric K * itemnum 57 | ''' 58 | 59 | self.sd5 = sd2 60 | 61 | self.G1_flag = G1flag 62 | self.pretrain = pretrain 63 | # input_embeddings = self.embed(input_seq_tensor + 1) # K * B * H 64 | # target_embedding = self.embed(tar_b_tensor + 1) # K * H 65 | 66 | mask = torch.ones_like(input_seq_tensor,dtype = torch.float).to(self.device) 67 | mask[input_seq_tensor == -1] = 0 68 | tar_expand = tar_b_tensor.view(-1,1).expand_as(input_seq_tensor) 69 | mask0 = torch.zeros_like(input_seq_tensor,dtype = torch.float).to(self.device) 70 | mask0[tar_expand == input_seq_tensor] = 1 71 | 72 | input_embeddings = self.G2.embed1(input_seq_tensor + 1) 73 | target_embedding = self.G2.embed1(tar_b_tensor + 1) 74 | 75 | 76 | 77 | test = 1 78 | if train == True: 79 | test = 0 80 | # print(self.embed.weight.data) 81 | if ((self.G1_flag == 0) & (pretrain == 0)): # 82 | # 83 | self.filter_basket = torch.ones_like(input_seq_tensor,dtype = torch.float).to(self.device) # K * B 84 | real_generate_embedding1 = self.G2(self.filter_basket, input_seq_tensor,uid) # K*H 85 | fake_discr = self.D(real_generate_embedding1, history, tar_b_tensor) # K*n_items 86 | 87 | if train == True: 88 | all_sum = mask.sum() 89 | 90 | loss, (p1, p2, p3, p4) = self.loss2_G1flag0(fake_discr, tar_b_tensor) 91 | 92 | return loss, fake_discr, (p1, p2, p3, p4), (all_sum, all_sum, all_sum, all_sum/all_sum,all_sum/all_sum) 93 | 94 | fake_discr = torch.softmax(fake_discr, dim=-1) 95 | return mask.sum() / mask.sum(),mask0.sum() / mask0.sum(), fake_discr 96 | else: 97 | self.filter_basket, test_basket = self.G0(input_seq_tensor, T, tar_b_tensor, self.G1_flag,test,input_embeddings,target_embedding) # K * B 98 | real_generate_embedding1 = self.G2(self.filter_basket[:, :, 0], input_seq_tensor,uid) 99 | fake_discr = self.D(real_generate_embedding1, history, tar_b_tensor, input_seq_tensor) # K*n_items 100 | 101 | 102 | ################################################ 103 | select_repeats = mask0 * ((self.filter_basket[:,:,0] > 1 / 2 ).float())#torch.tensor((self.filter_basket[:,:,0] > 1 / 2 ).int(),dtype = torch.long).to(self.device) 104 | 105 | repeat_ratio = 
(select_repeats.sum(1)/(mask0.sum(1)+1e-24)).sum()/(((mask0.sum(1)>0).float()+1e-24).sum()) 106 | 107 | if train == True: 108 | 109 | rest = mask * self.filter_basket[:, :, 0]#.detach() 110 | rest_int = (rest > 1 / 2).int() 111 | real_rest_sum = rest.sum() 112 | real_rest_sum_int = rest_int.sum() 113 | all_sum = mask.sum() 114 | ratio = (rest * ((rest > 1 / 2).float())).sum() / max(1, ((rest > 1 / 2).float()).sum()) 115 | 116 | # self.rest_basket = torch.ones_like(self.filter_basket,dtype = torch.float).to(self.device) - self.filter_basket 117 | rest_generate_embedding1 = self.G2(self.filter_basket[:, :, 1], input_seq_tensor,uid) 118 | rest_discr = self.D(rest_generate_embedding1, history, tar_b_tensor, input_seq_tensor) 119 | 120 | filter_pos = mask * self.filter_basket[:, :, 0] 121 | filter_neg = mask * self.filter_basket[:, :, 1] 122 | # if ((self.sd5 == 10000)&(self.G1_flag == 1)): 123 | # loss, (p1, p2, p3, p4) = self.loss2_G1flag1(fake_discr, tar_b_tensor,rest_discr) 124 | # 125 | # return loss, fake_discr, (p1, p2, p3, p4), (real_rest_sum, real_rest_sum_int, all_sum, ratio) 126 | 127 | ########################################### 128 | self.whole_basket = torch.ones_like(input_seq_tensor,dtype = torch.float).to(self.device) # K * B 129 | whole_generate_embedding1 = self.G2(self.whole_basket, input_seq_tensor,uid) # K*H 130 | whole_discr = self.D(whole_generate_embedding1, history, tar_b_tensor) # K*n_items 131 | 132 | loss, (p1, p2, p3, p4) = self.loss2_G1flag1(fake_discr, tar_b_tensor, rest_discr, 133 | whole_discr, filter_pos, filter_neg, 134 | mask) 135 | return loss, fake_discr, (p1, p2, p3, p4), (real_rest_sum, real_rest_sum_int, all_sum, ratio,repeat_ratio) 136 | 137 | select_repeats = mask0 * ((test_basket > 1 / 2).float())#torch.tensor((test_basket > 1 / 2).int(),dtype = torch.long).to(self.device) 138 | test_repeat_ratio = (select_repeats.sum(1) / (mask0.sum(1) + 1e-24)).sum() / (((mask0.sum(1) > 0).float()).sum()) 139 | # test_repeat_ratio = torch.mean(select_repeats.sum(1) / (mask0.sum(1) + 1e-24)) 140 | 141 | rest_test = mask * test_basket.detach() 142 | rest_sum = rest_test.sum() 143 | all_sum = mask.sum() 144 | test_rest_ratio = rest_sum / all_sum 145 | test_generate_embedding1 = self.G2(test_basket, input_seq_tensor,uid) 146 | test_discr = self.D(test_generate_embedding1, history, tar_b_tensor, input_seq_tensor) # K*n_items 147 | test_discr = torch.softmax(test_discr, -1) 148 | return test_rest_ratio,test_repeat_ratio, test_discr 149 | 150 | def loss2_G1flag0(self, fake_discr, target_labels): 151 | ''' 152 | :param fake_discr: K * itemnum 153 | :param target_labels: K 154 | :param neg_labels: K * neg_ratio = K * nK 155 | :return: 156 | ''' 157 | fake_discr = torch.softmax(fake_discr, dim=-1) 158 | 159 | item_num = fake_discr.size(0) # K 160 | index = torch.tensor(np.linspace(0, item_num, num=item_num, endpoint=False), dtype=torch.long) # K 161 | pfake = fake_discr[index, target_labels] # K 162 | 163 | loss_1 = - torch.mean(torch.log((pfake) + 1e-24)) 164 | 165 | loss = self.sd1 * loss_1 166 | return loss, (loss_1, loss_1, loss_1, loss_1) 167 | 168 | def loss2_G1flag1(self, fake_discr, target_labels, rest_discri,whole_discr,filter_pos,filter_neg,mask): 169 | ''' 170 | :param fake_discr: K * itemnum 171 | :param filter_pos: K * B 172 | :param target_labels: K 173 | :param neg_labels: K * neg_ratio = K * nK 174 | :param rest_discri: K * itemnum 175 | :param neg_discri: (Kxjudge_ratio) * itemnum 176 | :return: 177 | ''' 178 | fake_discr = torch.softmax(fake_discr,dim = -1) 
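# Descriptive comment on the terms assembled below in loss2_G1flag1:
# loss_0 hinges the kept-item ratios, so the positive sub-basket keeps at least pos_margin
#   of the non-padding items and the negative sub-basket keeps at most neg_margin of them;
# loss_1 asks the denoised basket to score the target higher than the whole basket;
# loss_2 asks the whole basket to score the target higher than the removed remainder;
# loss_3 is the negative log-likelihood of the target under the denoised basket;
# when G1_flag != 1, an extra loss_4 (negative log-likelihood under the whole basket) is added.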
179 | rest_discri = torch.softmax(rest_discri,dim = -1) 180 | whole_discr = torch.softmax(whole_discr, dim=-1) 181 | 182 | 183 | item_num = fake_discr.size(0) # K 184 | index = torch.tensor(np.linspace(0, item_num, num=item_num, endpoint=False), dtype=torch.long) # K 185 | pfake = fake_discr[index, target_labels] # K 186 | prest = rest_discri[index, target_labels] # K 187 | pwhole = whole_discr[index, target_labels] # K 188 | 189 | pos_restratio = torch.mean(filter_pos.sum(1).view(1,-1)/(mask.sum(1).view(1,-1)),dim=-1).view(1,-1) # 1*1 190 | pos_margin = self.pos_margin * torch.ones_like(pos_restratio).to(self.device) # 191 | zeros = torch.zeros_like(pos_restratio).to(self.device) 192 | pos_restratio = torch.cat((pos_margin-pos_restratio,zeros),dim=0) #2*1 193 | 194 | neg_restratio = torch.mean(filter_neg.sum(1).view(1,-1)/(mask.sum(1).view(1,-1)),dim=-1).view(1,-1) # 1*1 195 | neg_margin = self.neg_margin * torch.ones_like(neg_restratio).to(self.device) 196 | neg_restratio = torch.cat((neg_restratio - neg_margin, zeros), dim=0) 197 | 198 | loss_0 = torch.max(pos_restratio,dim = 0)[0] + torch.max(neg_restratio,dim = 0)[0] 199 | loss_1 = - torch.mean(torch.nn.LogSigmoid()(pfake - pwhole)) 200 | loss_2 = - torch.mean(torch.nn.LogSigmoid()(pwhole - prest)) 201 | loss_3 = - torch.mean(torch.log(pfake + 1e-24)) 202 | loss0 = loss_0 + loss_1 + loss_2 + loss_3 203 | if (self.G1_flag == 1): 204 | return loss0, (loss_1, loss_2, loss_3, loss_3) 205 | else: 206 | loss_4 = - torch.mean(torch.log((pwhole + 1e-24) )) 207 | loss0 = loss_0 + loss_1 + loss_2 + loss_3 + loss_4 208 | return loss0, (loss_1, loss_2, loss_3, loss_4) 209 | 210 | class Generator1(nn.Module): 211 | def __init__(self, config, device, dropout_p=0.2): 212 | super(Generator1, self).__init__() 213 | self.device = device 214 | self.dropout_p = config.dropout 215 | self.input_size = config.num_product 216 | self.hidden_size = config.embedding_dim//2 217 | self.max_basket_size = config.max_basket_size 218 | 219 | self.same_embedding = config.same_embedding 220 | 221 | self.soft = config.soft 222 | self.temp_learn = config.temp_learn 223 | self.temp = nn.Parameter(torch.ones(1)* config.temp) 224 | self.temp_init = config.temp 225 | 226 | self.embed = nn.Embedding(config.num_product + 1, self.hidden_size, padding_idx=0) 227 | # self.W = nn.Linear(self.hidden_size, self.hidden_size) 228 | self.dropout = nn.Dropout(self.dropout_p) 229 | 230 | self.judge_model = nn.Sequential( 231 | nn.Linear(self.hidden_size * 2, self.hidden_size), 232 | nn.Dropout(self.dropout_p), 233 | nn.LeakyReLU(inplace=True), 234 | nn.Linear(self.hidden_size, 2) 235 | ) 236 | self.judge_model1 = nn.Linear(self.hidden_size * 2, 2) 237 | 238 | # profile 239 | def init_weight(self): 240 | for name, parms in self.named_parameters(): # TODO 241 | parms.data.normal_(0, 0.1) 242 | torch.nn.init.xavier_normal_(self.embed.weight.data) # good 243 | self.temp = nn.Parameter(torch.ones(1).to(self.device)* self.temp_init) # * config.temp # TODO this reset is necessary; otherwise with temp_learn=1 the rest_ratio starts out at 0.99 244 | 245 | def forward(self, input_seq_tensor, T, target_tensor, G1_flag=1,test=0,input_embeddings = None,target_embedding = None): 246 | ''' 247 | :param input_seq_tensor: K * B 248 | :param T: Gumbel-Softmax temperature 249 | :param target_tensor: K 250 | :param G1_flag: 251 | :param test: 1 at evaluation time, 0 during training 252 | :param input_embeddings: item embeddings shared from G2 (used when same_embedding == 1) 253 | :param target_embedding: target item embedding shared from G2 (used when same_embedding == 1) 254 | :return: 255 | ''' 256 | def hook_fn(grad): 257 | print(grad) 258 | 259 | 260 | if self.same_embedding == 0: 261 | target_embedding = self.embed(target_tensor + 1) 262 |
input_embeddings = self.embed(input_seq_tensor + 1) 263 | 264 | in_tar = torch.cat( 265 | (input_embeddings, target_embedding.view(target_embedding.size(0), 1, -1).expand_as(input_embeddings)), 266 | dim=2) # K*B*2H 267 | in_tar = self.dropout(in_tar) 268 | resnet_o_prob = self.judge_model1(in_tar) 269 | o_prob = self.judge_model(in_tar) # K*B*2 270 | 271 | o_prob = (o_prob + resnet_o_prob) # 272 | # att_prob = torch.sigmoid(o_prob) 273 | o_prob = torch.softmax(o_prob, dim=-1) 274 | 275 | 276 | if self.temp_learn == 1: 277 | if self.temp > 0: 278 | prob_hard, prob_soft = self.gumbel_softmax(torch.log(o_prob + 1e-24), self.temp, hard=True,input_seq_tensor = input_seq_tensor) 279 | else: 280 | prob_hard, prob_soft = self.gumbel_softmax(torch.log(o_prob + 1e-24), 0.3, hard=True,input_seq_tensor = input_seq_tensor) 281 | else: 282 | prob_hard, prob_soft = self.gumbel_softmax(torch.log(o_prob + 1e-24), T, hard=True,input_seq_tensor = input_seq_tensor) 283 | 284 | prob_soft_new = prob_hard*prob_soft 285 | 286 | 287 | if test == 0 :#and G1_flag != 0: 288 | return prob_soft_new, None 289 | else: 290 | if self.temp_learn == 1: 291 | if self.temp > 0: 292 | o_prob_hard, o_prob_soft = self.gumbel_test(torch.log(o_prob + 1e-24), self.temp) 293 | else: 294 | o_prob_hard, o_prob_soft = self.gumbel_test(torch.log(o_prob + 1e-24), 0.3) 295 | else: 296 | o_prob_hard, o_prob_soft = self.gumbel_test(torch.log(o_prob + 1e-24), T) 297 | 298 | test_prob_hard = o_prob_hard * o_prob_soft 299 | test_prob_hard = test_prob_hard.detach() 300 | 301 | return test_prob_hard, test_prob_hard[:, :, 0] 302 | 303 | def sample_gumbel(self, shape, eps=1e-20): 304 | U = torch.rand(shape).to(self.device) 305 | return -torch.log(-torch.log(U + eps) + eps) 306 | 307 | def gumbel_softmax_sample(self, logits, temperature,input_seq_tensor = None): 308 | ''' 309 | :param logits: # K*B*2 310 | :param temperature: 311 | :param input_seq_tensor: K*B 312 | :return: 313 | ''' 314 | sample = self.sample_gumbel([int(self.input_size+1),2],eps= 1e-20) # n_items+1 * 2 315 | x_index = input_seq_tensor.clone()+1 #K*B 316 | x_index = x_index.unsqueeze(2).repeat(1,1,2) # K*B*2 317 | # print(x_index.size()) 318 | y_index = torch.zeros_like(input_seq_tensor,dtype = torch.long).to(self.device).unsqueeze(2) #K*B*1 319 | y_index1 = torch.ones_like(input_seq_tensor, dtype=torch.long).to(self.device).unsqueeze(2) # K*B*1 320 | y_index = torch.cat((y_index,y_index1),dim = 2) #K*B*2 321 | # print(y_index.size()) 322 | sample_logits = sample[x_index.long(),y_index.long()] 323 | y = logits + sample_logits 324 | # y = logits + self.sample_gumbel(logits.size()) 325 | return F.softmax(y / temperature, dim=-1) 326 | 327 | def gumbel_softmax(self, logits, temperature, hard=False,input_seq_tensor = None): 328 | """ 329 | ST-gumple-softmax 330 | input: [*, n_class] 331 | return: flatten --> [*, n_class] an one-hot vector 332 | """ 333 | y = self.gumbel_softmax_sample(logits, temperature,input_seq_tensor) 334 | 335 | if not hard: 336 | return y # .view(-1, latent_dim * categorical_dim) 337 | 338 | shape = y.size() 339 | _, ind = y.max(dim=-1) 340 | y_hard = torch.zeros_like(y).view(-1, shape[-1]) 341 | y_hard.scatter_(1, ind.view(-1, 1), 1) 342 | y_hard = y_hard.view(*shape) 343 | # Set gradients w.r.t. y_hard gradients w.r.t. 
y 344 | # y_hard = (y_hard - y).detach() + y 345 | return y_hard, y # .view(-1, latent_dim * categorical_dim) 346 | 347 | def gumbel_test(self, logits, temperature): 348 | y = logits 349 | y = F.softmax(y / temperature, dim=-1) 350 | shape = y.size() 351 | _, ind = y.max(dim=-1) 352 | y_hard = torch.zeros_like(y).view(-1, shape[-1]) 353 | y_hard.scatter_(1, ind.view(-1, 1), 1) 354 | y_hard = y_hard.view(*shape) 355 | return y_hard, y # .view(-1, latent_dim * categorical_dim) 356 | 357 | class Generator2(nn.Module): 358 | def __init__(self, config, device, dropout_p=0.2): 359 | super(Generator2, self).__init__() 360 | self.device = device 361 | self.dropout_p = config.dropout 362 | self.input_size = config.num_product 363 | self.num_users = config.num_users 364 | self.hidden_size = config.embedding_dim 365 | self.basket_pool_type = config.basket_pool_type 366 | 367 | self.dropout = nn.Dropout(self.dropout_p) 368 | self.max_basket_size = config.max_basket_size 369 | self.bidirectional = False 370 | self.batch_first = True 371 | self.num_layer = config.num_layer 372 | self.gru_hidden_size = self.hidden_size//2 373 | 374 | self.embed1 = nn.Embedding(config.num_product + 1, self.gru_hidden_size, padding_idx=0) 375 | self.user_embed1 = nn.Embedding(config.num_users, self.gru_hidden_size) 376 | 377 | self.gru1 = nn.GRU(self.gru_hidden_size, 378 | self.gru_hidden_size, 379 | num_layers=self.num_layer, 380 | bidirectional=self.bidirectional, 381 | batch_first=self.batch_first) 382 | 383 | 384 | # profile 385 | def forward(self, filter_basket_prob, input_seq_tensor,uid = None): 386 | ''' 387 | :param filter_basket_prob: K * B [13] [2] [0] 2 388 | :param input_embeddings: K * B * H 389 | :param input_seq_tensor: K * B ,with padding -1 390 | :return: fake_target_embeddings # K * hidden_size 391 | # basket_embedding K*basket_num*H 392 | ''' 393 | 394 | def hook_fn(grad): 395 | print(grad) 396 | 397 | user_embedding = self.user_embed1(uid) 398 | mask = (input_seq_tensor != -1).detach().int() 399 | mask = torch.tensor(mask,dtype = torch.float).to(self.device) 400 | filter_basket_prob = filter_basket_prob * mask # K * B 401 | # K*B*H 402 | 403 | input_embeddings1 = self.embed1(input_seq_tensor + 1) 404 | 405 | 406 | 407 | input_embeddings_f1 = torch.mul(input_embeddings1, 408 | filter_basket_prob.view(filter_basket_prob.size(0), -1, 1).expand( 409 | filter_basket_prob.size(0), -1, self.gru_hidden_size)) 410 | 411 | input_embeddings_f1 = input_embeddings_f1.view(input_embeddings_f1.size(0), -1, self.max_basket_size, 412 | self.gru_hidden_size) # K*basket_num*max_basket_size*H 413 | 414 | filter_b = torch.max(filter_basket_prob.view(filter_basket_prob.size(0), -1, self.max_basket_size), dim=-1)[ 415 | 0] # K*basket_num #### 416 | 417 | if self.basket_pool_type == 'avg': 418 | filtered_tensor = filter_basket_prob.view(filter_basket_prob.size(0), -1, 1).expand( 419 | filter_basket_prob.size(0), -1, self.gru_hidden_size).view(input_embeddings_f1.size(0), -1, 420 | self.max_basket_size, 421 | self.gru_hidden_size) 422 | basket_embedding1 = (torch.sum(input_embeddings_f1, dim=2) / ( 423 | filtered_tensor.sum(dim=2) + 1e-10)) # K*basket_num*H 424 | else: 425 | mask_inf = filter_basket_prob.view(filter_basket_prob.size(0), -1, 1).expand( 426 | filter_basket_prob.size(0), -1, self.gru_hidden_size).int() 427 | mask_inf = (1 - mask_inf) * (-9999) 428 | mask_inf = mask_inf.view(mask_inf.size(0), -1, self.max_basket_size, 429 | self.gru_hidden_size) 430 | input_embeddings_f1 = input_embeddings_f1 + mask_inf 431 | 
basket_embedding1 = (torch.max(input_embeddings_f1, dim=2)[0]) # K*basket_num*H 432 | 433 | input_filter_b = (filter_b > 0).detach().int() # K*basket_num ( value:0/1) 434 | sorted, indices = torch.sort(input_filter_b, descending=True) 435 | lengths = torch.sum(sorted, dim=-1).squeeze().view(1, -1).squeeze(0) 436 | length_mask = (lengths == 0).int() 437 | length_mask = torch.tensor(length_mask, dtype=torch.long).to(self.device) 438 | lengths = lengths + length_mask 439 | inputs1 = basket_embedding1.gather(dim=1, 440 | index=indices.unsqueeze(2).expand_as( 441 | basket_embedding1)) # K*basket_num*H 442 | 443 | # sort data by lengths 444 | _, idx_sort = torch.sort(lengths, dim=0, descending=True) 445 | _, idx_unsort = torch.sort(idx_sort, dim=0) 446 | sort_embed_input1 = inputs1.index_select(0, Variable(idx_sort)) 447 | sort_lengths = lengths[idx_sort] 448 | 449 | sort_lengths = torch.tensor(sort_lengths.clone().cpu(), dtype=torch.int64) 450 | inputs_packed1 = nn.utils.rnn.pack_padded_sequence(sort_embed_input1, 451 | sort_lengths, 452 | batch_first=True) 453 | # process using RNN 454 | out_pack1, ht1 = self.gru1(inputs_packed1) 455 | raw_o = nn.utils.rnn.pad_packed_sequence(out_pack1, batch_first=True) 456 | raw_o = raw_o[0] 457 | raw_o = raw_o[idx_unsort] 458 | x = torch.tensor(np.linspace(0, raw_o.size(0), num=raw_o.size(0), endpoint=False), dtype=torch.long).to(self.device) 459 | y = lengths - 1 460 | outputs_last = raw_o[x, y] # 2,2,6 461 | 462 | # ht1 = torch.transpose(ht1, 0, 1)[idx_unsort] 463 | # ht1 = torch.transpose(ht1, 0, 1) 464 | # out1 = self.fc1(ht1[-1]) # .squeeze() 465 | # out1 = self.fc1(outputs_last) 466 | 467 | return outputs_last # K * hidden_size 468 | 469 | class Discriminator(nn.Module): 470 | def __init__(self, config, device, dropout_p=0.2): 471 | super(Discriminator, self).__init__() 472 | self.device = device 473 | self.dropout_p = config.dropout 474 | self.input_size = config.num_product 475 | self.hidden_size = config.embedding_dim 476 | self.max_basket_size = config.max_basket_size 477 | self.gru_hidden_size = self.hidden_size//2 478 | 479 | self.fc1 = nn.Linear(self.gru_hidden_size, self.hidden_size) 480 | 481 | # TODO hidden_size 482 | self.judge_model1 = nn.Sequential( 483 | nn.Linear(self.hidden_size, self.hidden_size), 484 | nn.Dropout(0), 485 | nn.LeakyReLU(inplace=True), 486 | nn.Linear(self.hidden_size, self.input_size) 487 | ) 488 | self.judge_model2 = nn.Sequential( 489 | nn.Linear(self.hidden_size, self.hidden_size), 490 | nn.Dropout(0), 491 | nn.LeakyReLU(inplace=True), 492 | nn.Linear(self.hidden_size, self.input_size) 493 | ) 494 | 495 | self.dropout = nn.Dropout(dropout_p) 496 | # self.histroy = config.histroy 497 | # if self.histroy == 1: 498 | # self.attn = nn.Linear(self.input_size, self.input_size) 499 | 500 | # profile 501 | def forward(self, item_embeddings1, history_record, target_tensor, input_seq_tensor=None): # K * hidden_size 1*9963 502 | def hook_fn(grad): 503 | print(grad) 504 | 505 | item_embeddings1 = self.fc1(item_embeddings1) 506 | item_embeddings1 = self.dropout(item_embeddings1) 507 | judge1 = self.judge_model1(item_embeddings1) # K*input_size 508 | judge2 = self.judge_model2(item_embeddings1) # K*input_size 509 | judge = judge2 + judge1 510 | 511 | return judge # K * n_items 512 | 513 | 514 | 515 | -------------------------------------------------------------------------------- /CLEA-new/module/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 
import math 4 | import random 5 | from collections import defaultdict 6 | 7 | # random.seed(11) 8 | 9 | def get_dict(path): 10 | f = open(path, 'r') 11 | a = f.read() 12 | geted_dict = eval(a) 13 | f.close() 14 | return geted_dict 15 | 16 | 17 | def get_distribute_items(n_items,input_dir,ratio = 0.75): 18 | user_tran_date_dict = get_dict(input_dir) 19 | count = [0.0] * n_items 20 | count_all = 0 21 | for idx, userid in enumerate(list(user_tran_date_dict.keys())): 22 | for basket in user_tran_date_dict[userid]: 23 | for item in basket: 24 | count[item] += 1 25 | count_all += 1 26 | p_item = np.array(count) 27 | 28 | p_item_tensor = torch.from_numpy(np.array(p_item)) 29 | p_item_tensor = torch.pow(p_item_tensor, ratio) 30 | p_item = np.array(p_item_tensor) 31 | # p_item = p_item / count_all 32 | # precision = list(precision.cpu().numpy()) 33 | return p_item 34 | 35 | def get_all_neg_p(neg_sample,p_item): 36 | neg_sample_neg_p = dict() 37 | for u in neg_sample: 38 | neg_index = neg_sample[u] 39 | p_neg = p_item[neg_index] 40 | p_neg = p_neg / np.sum(p_neg) 41 | 42 | if np.sum(p_neg) == 1: 43 | return p_neg 44 | else: 45 | p_neg[0] += (1 - np.sum(p_neg)) 46 | neg_sample_neg_p[u] = p_neg 47 | return neg_sample_neg_p 48 | 49 | def get_neg_p(p_item,neg_set): 50 | neg_index = neg_set#torch.tensor(neg_set,dtype=torch.long).to(device) 51 | p_neg = p_item[neg_index] 52 | p_neg = p_neg / np.sum(p_neg) 53 | 54 | if np.sum(p_neg) == 1: 55 | return p_neg 56 | else: 57 | p_neg[0]+= (1 - np.sum(p_neg)) 58 | return p_neg 59 | 60 | # @profile 61 | def get_dataset(input_dir, max_basket_size,max_basket_num,neg_ratio,history = 0,next_k = 1): 62 | print("--------------Begin Data Process--------------") 63 | neg_ratio = 1 64 | user_tran_date_dict_old = get_dict(input_dir) 65 | 66 | user_tran_date_dict = dict() 67 | for userid in user_tran_date_dict_old.keys(): 68 | seq = user_tran_date_dict_old[userid] 69 | if len(seq) > max_basket_num: 70 | seq = seq[-max_basket_num:] 71 | if len(seq) < 1 + next_k: continue 72 | for b_id, basket in enumerate(seq): 73 | if len(basket) > max_basket_size: 74 | seq[b_id] = basket[-max_basket_size:] 75 | user_tran_date_dict[userid] = seq 76 | 77 | 78 | train_times = 0 79 | valid_times = 0 80 | test_times = 0 81 | 82 | itemnum = 0 83 | for userid in user_tran_date_dict.keys(): 84 | seq = user_tran_date_dict[userid] 85 | for basket in seq: 86 | for item in basket: 87 | if item > itemnum: 88 | itemnum = item 89 | itemnum = itemnum + 1 90 | item_list = [i for i in range(0, itemnum)] 91 | 92 | result_vector = np.zeros(itemnum) 93 | basket_count = 0 94 | for userid in user_tran_date_dict.keys(): 95 | seq = user_tran_date_dict[userid][:-next_k] 96 | for basket in seq: 97 | basket_count += 1 98 | result_vector[basket] += 1 99 | weights = np.zeros(itemnum) 100 | max_freq = basket_count # max(result_vector) 101 | for idx in range(len(result_vector)): 102 | if result_vector[idx] > 0: 103 | weights[idx] = max_freq / result_vector[idx] 104 | else: 105 | weights[idx] = 0 106 | 107 | TRAIN_DATASET = [] 108 | train_batch = defaultdict(list) 109 | VALID_DATASET = [] 110 | valid_batch = defaultdict(list) 111 | TEST_DATASET = [] 112 | test_batch = defaultdict(list) 113 | neg_sample = dict() 114 | 115 | # train_userid_list = list(user_tran_date_dict.keys())[:math.ceil(0.9 * len(list(user_tran_date_dict.keys())))] 116 | all_user_num = len(list(user_tran_date_dict.keys())) 117 | train_user_num = 0 118 | train_userid_list = list(user_tran_date_dict.keys())[:math.ceil(0.9 * 
len(list(user_tran_date_dict.keys())))] 119 | 120 | for userid in user_tran_date_dict.keys(): 121 | if userid in train_userid_list: 122 | seq = user_tran_date_dict[userid][:-next_k] 123 | else: 124 | seq = user_tran_date_dict[userid][:-next_k] 125 | seq_pool = [] 126 | for basket in seq: 127 | seq_pool = seq_pool + basket 128 | neg_sample[userid] = list(set(item_list) - set(seq_pool)) 129 | 130 | for userid in user_tran_date_dict.keys(): 131 | if userid in train_userid_list: 132 | seq = user_tran_date_dict[userid] 133 | before = [] 134 | train_seq = seq[:-1] 135 | for basketid, basket in enumerate(train_seq): 136 | if len(basket) > max_basket_size: 137 | basket = basket[-max_basket_size:] 138 | else: 139 | padd_num = max_basket_size - len(basket) 140 | padding_item = [-1] * padd_num 141 | basket = basket + padding_item 142 | before.append(basket) 143 | if len(before) == 1: continue 144 | U = userid 145 | S = before[:-1] 146 | S_pool = [] 147 | H = np.zeros(itemnum) 148 | H_pad = np.zeros(itemnum + 1) 149 | for basket in S: 150 | S_pool = S_pool + basket 151 | no_pad_basket = list(set(basket)-set([-1])) 152 | H[no_pad_basket] += 1 153 | H = H / len(before[:-1]) 154 | H_pad[1:] = H 155 | L = len(before[:-1]) 156 | tar_basket = train_seq[basketid] 157 | for item in tar_basket: 158 | T = item 159 | N = random.sample(neg_sample[userid], neg_ratio) 160 | train_batch[L].append((U, S_pool, T, H_pad[0:2], N, L)) 161 | train_times += 1 162 | 163 | test_seq = seq 164 | before = [] 165 | for basketid, basket in enumerate(test_seq): 166 | if len(basket) > max_basket_size: 167 | basket = basket[-max_basket_size:] 168 | else: 169 | padd_num = max_basket_size - len(basket) 170 | padding_item = [-1] * padd_num 171 | basket = basket + padding_item 172 | before.append(basket) 173 | U = userid 174 | S = list(before[:-1]) 175 | S_pool = [] 176 | H = np.zeros(itemnum) 177 | H_pad = np.zeros(itemnum+1) 178 | for basket in S: 179 | S_pool = S_pool + basket 180 | no_pad_basket = list(set(basket) - set([-1])) 181 | H[no_pad_basket] += 1 182 | H = H / len(S) 183 | H_pad[1:] = H 184 | L = len(before[:-1]) 185 | T_basket = before[-1] 186 | test_batch[L].append((U, S_pool, T_basket, H_pad[0:2], L)) 187 | test_times += 1 188 | 189 | 190 | else: 191 | seq = user_tran_date_dict[userid] 192 | before = [] 193 | valid_seq = seq 194 | for basketid, basket in enumerate(valid_seq): 195 | if len(basket) > max_basket_size: 196 | basket = basket[-max_basket_size:] 197 | else: 198 | padd_num = max_basket_size - len(basket) 199 | padding_item = [-1] * padd_num 200 | basket = basket + padding_item 201 | before.append(basket) 202 | if len(before) == 1: continue 203 | if len(before) < len(valid_seq): continue 204 | U = userid 205 | S = before[:-1] 206 | S_pool = [] 207 | H = np.zeros(itemnum) 208 | H_pad = np.zeros(itemnum + 1) 209 | for basket in S: 210 | S_pool = S_pool + basket 211 | no_pad_basket = list(set(basket) - set([-1])) 212 | H[no_pad_basket] += 1 213 | H = H / len(S) 214 | H_pad[1:] = H 215 | L = len(before[:-1]) 216 | tar_basket = valid_seq[basketid] 217 | 218 | if history == 0: 219 | tar_basket = list(set(tar_basket)-set(S_pool)) 220 | if len(tar_basket) < 1:continue 221 | padd_num = max_basket_size - len(tar_basket) 222 | padding_item = [-1] * padd_num 223 | T_basket = tar_basket + padding_item 224 | valid_batch[L].append((U, S_pool, T_basket, H_pad[0:2], L)) 225 | valid_times += 1 226 | else: 227 | T_basket = before[-1] 228 | valid_batch[L].append((U, S_pool, T_basket, H_pad[0:2], L)) 229 | valid_times += 1 230 | 231 | 
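# Note on the layout built above: samples are bucketed by history length L, and each bucket is
# transposed into parallel tuples below via zip(*...). A training sample is
# (userid U, flattened padded history S_pool, target item T, H_pad[0:2], negative items N, L),
# while test/valid samples carry a whole padded target basket instead of a single item and no negatives.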
for l in train_batch.keys(): 232 | TRAIN_DATASET.append(list(zip(*train_batch[l]))) 233 | 234 | for l in test_batch.keys(): 235 | TEST_DATASET.append(list(zip(*test_batch[l]))) 236 | 237 | for l in valid_batch.keys(): 238 | VALID_DATASET.append(list(zip(*valid_batch[l]))) 239 | 240 | 241 | 242 | print("--------------Data Process is Over--------------") 243 | return TRAIN_DATASET, VALID_DATASET, TEST_DATASET, neg_sample, weights, itemnum, train_times, test_times, valid_times 244 | 245 | 246 | 247 | # @profile 248 | def get_batch_TRAIN_DATASET(dataset, batch_size): 249 | print('--------------Data Process is Begin--------------') 250 | random.shuffle(dataset) 251 | for idx, (UU, SS, TT, HH, NN, LL) in enumerate(dataset): 252 | userid = torch.tensor(UU, dtype=torch.long) 253 | input_seq = torch.tensor(SS, dtype=torch.long) 254 | target = torch.tensor(TT, dtype=torch.long) 255 | history = torch.from_numpy(np.array(HH)).float() 256 | neg_items = torch.tensor(NN, dtype=torch.long) 257 | 258 | if SS.__len__() < 2: 259 | continue 260 | if SS.__len__() <= batch_size: 261 | batch_userid = userid 262 | batch_input_seq = input_seq 263 | batch_target = target 264 | batch_history = history 265 | batch_neg_items = neg_items 266 | yield (batch_userid,batch_input_seq,batch_target,batch_history,batch_neg_items) 267 | else: 268 | batch_begin = 0 269 | while (batch_begin + batch_size) <= SS.__len__(): 270 | batch_userid = userid[batch_begin:batch_begin + batch_size] 271 | batch_input_seq = input_seq[batch_begin:batch_begin + batch_size] 272 | batch_target = target[batch_begin:batch_begin + batch_size] 273 | batch_history = history[batch_begin:batch_begin + batch_size] 274 | batch_neg_items = neg_items[batch_begin:batch_begin + batch_size] 275 | yield (batch_userid, batch_input_seq, batch_target, batch_history, batch_neg_items) 276 | batch_begin = batch_begin + batch_size 277 | if (batch_begin + batch_size > SS.__len__()) & (batch_begin < SS.__len__()): 278 | 279 | batch_userid = userid[batch_begin:] 280 | batch_input_seq = input_seq[batch_begin:] 281 | batch_target = target[batch_begin:] 282 | batch_history = history[batch_begin:] 283 | batch_neg_items = neg_items[batch_begin:] 284 | yield (batch_userid, batch_input_seq, batch_target, batch_history, batch_neg_items) 285 | 286 | 287 | 288 | 289 | # @profile 290 | def get_batch_TEST_DATASET(TEST_DATASET, batch_size): 291 | BATCHES = [] 292 | random.shuffle(TEST_DATASET) 293 | for idx, (UU, SS, TT_bsk, HH, LL) in enumerate(TEST_DATASET): 294 | 295 | userid = torch.tensor(UU, dtype=torch.long) 296 | input_seq = torch.tensor(SS, dtype=torch.long) 297 | try: 298 | target = torch.tensor(TT_bsk, dtype=torch.long) 299 | except ValueError: 300 | print(TT_bsk) 301 | history = torch.from_numpy(np.array(HH)).float() 302 | 303 | assert UU.__len__() == SS.__len__() 304 | assert UU.__len__() == TT_bsk.__len__() 305 | assert UU.__len__() == HH.__len__() 306 | 307 | if SS.__len__() < 1: continue 308 | if SS.__len__() <= batch_size: 309 | batch_userid = userid 310 | batch_input_seq = input_seq 311 | batch_target = target 312 | batch_history = history 313 | yield (batch_userid, batch_input_seq, batch_target, batch_history) 314 | else: 315 | batch_begin = 0 316 | while (batch_begin + batch_size) <= SS.__len__(): 317 | batch_userid = userid[batch_begin:batch_begin + batch_size] 318 | batch_input_seq = input_seq[batch_begin:batch_begin + batch_size] 319 | batch_target = target[batch_begin:batch_begin + batch_size] 320 | batch_history = history[batch_begin:batch_begin + batch_size] 
321 | yield (batch_userid, batch_input_seq, batch_target, batch_history) 322 | batch_begin = batch_begin + batch_size 323 | 324 | if (batch_begin + batch_size > SS.__len__()) & (batch_begin < SS.__len__()): 325 | batch_userid = userid[batch_begin:] 326 | batch_input_seq = input_seq[batch_begin:] 327 | batch_target = target[batch_begin:] 328 | batch_history = history[batch_begin:] 329 | yield (batch_userid, batch_input_seq, batch_target, batch_history) -------------------------------------------------------------------------------- /CLEA.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QYQ-bot/CLEA/b830735b9dd69e3226e174ac920bce679ed57573/CLEA.pdf -------------------------------------------------------------------------------- /CLEA/Dunn_0.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/bin/bash 3 | 4 | for dim in 64 5 | do 6 | echo "$dim " 7 | python main_1.py \ 8 | --pos_margin 0.1 \ 9 | --neg_margin 0.9 \ 10 | --same_embedding 1 \ 11 | --dataset 'Dunn' \ 12 | --max_basket_size 50 \ 13 | --max_basket_num 6 \ 14 | --num_product 4995 \ 15 | --num_users 36421 \ 16 | --alternative_train_batch 500 \ 17 | --test_every_epoch 4 \ 18 | --G1_flag 0 \ 19 | --device 1 \ 20 | --log_fire 'basemodel' \ 21 | --dropout 0.2 \ 22 | --lr 0.001 \ 23 | --l2 0.00001 \ 24 | --output_dir './result' \ 25 | --pretrain_epoch 20 \ 26 | --before_epoch 0 \ 27 | --epoch 10 \ 28 | --batch_size 256 \ 29 | --embedding_dim $dim \ 30 | --temp_learn 1 \ 31 | --history 1 32 | 33 | 34 | python main_1.py \ 35 | --pos_margin 0.1 \ 36 | --neg_margin 0.9 \ 37 | --same_embedding 1 \ 38 | --dataset 'Dunn' \ 39 | --max_basket_size 50 \ 40 | --max_basket_num 6 \ 41 | --num_product 4995 \ 42 | --num_users 36421 \ 43 | --alternative_train_epoch 10 \ 44 | --alternative_train_epoch_D 10 \ 45 | --alternative_train_batch 200 \ 46 | --test_every_epoch 4 \ 47 | --G1_flag 0 \ 48 | --device 1 \ 49 | --log_fire 't_10_ANN_0001_lr_001_G1_001_pos_0.1' \ 50 | --dropout 0.2 \ 51 | --lr 0.001 \ 52 | --G1_lr 0.001 \ 53 | --l2 0.00001 \ 54 | --output_dir './result' \ 55 | --pretrain_epoch 2 \ 56 | --before_epoch 2 \ 57 | --epoch 30 \ 58 | --batch_size 256 \ 59 | --embedding_dim $dim \ 60 | --temp_learn 1 \ 61 | --temp_min 0.2 \ 62 | --ANNEAL_RATE 0.0001 \ 63 | --temp 10 \ 64 | --history 1 65 | done -------------------------------------------------------------------------------- /CLEA/Instacart_0.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/bin/bash 3 | for dim in 64 4 | do 5 | echo "$dim " 6 | python main_1.py \ 7 | --pos_margin 0.1 \ 8 | --neg_margin 0.9 \ 9 | --same_embedding 1 \ 10 | --dataset 'Instacart' \ 11 | --num_product 8222 \ 12 | --num_users 6886 \ 13 | --max_basket_size 35 \ 14 | --max_basket_num 32 \ 15 | --alternative_train_batch 1000 \ 16 | --test_every_epoch 4 \ 17 | --G1_flag 0 \ 18 | --device 0 \ 19 | --log_fire 'basemodel' \ 20 | --dropout 0.2 \ 21 | --lr 0.001 \ 22 | --l2 0.00001 \ 23 | --output_dir './result' \ 24 | --pretrain_epoch 20 \ 25 | --before_epoch 0 \ 26 | --epoch 10 \ 27 | --batch_size 256 \ 28 | --embedding_dim $dim \ 29 | --temp_learn 1 \ 30 | --history 1 31 | 32 | 33 | python main_1.py \ 34 | --same_embedding 1 \ 35 | --pos_margin 0.1 \ 36 | --neg_margin 0.9 \ 37 | --dataset 'Instacart' \ 38 | --num_product 8222 \ 39 | --num_users 6886 \ 40 | --max_basket_size 35 \ 41 | --max_basket_num 32 \ 42 | --alternative_train_batch 1000 \ 43 | 
--alternative_train_epoch 5 \ 44 | --alternative_train_epoch_D 5 \ 45 | --test_every_epoch 4 \ 46 | --G1_flag 0 \ 47 | --device 0 \ 48 | --log_fire 't_10_ANN_0001_same_1_lr_001_G1_001_diftemp0_0_pos_01_epoch_G5_D5' \ 49 | --dropout 0.2 \ 50 | --lr 0.001 \ 51 | --G1_lr 0.001 \ 52 | --l2 0.00001 \ 53 | --output_dir './result' \ 54 | --pretrain_epoch 2 \ 55 | --before_epoch 2 \ 56 | --epoch 40 \ 57 | --batch_size 128 \ 58 | --embedding_dim $dim \ 59 | --temp_learn 0 \ 60 | --temp_min 0.3 \ 61 | --ANNEAL_RATE 0.0001 \ 62 | --temp 10 \ 63 | --history 1 64 | done 65 | 66 | -------------------------------------------------------------------------------- /CLEA/main_1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | import math 5 | import time 6 | import scipy.sparse as sp 7 | from module.util import * 8 | from module.model_1 import NBModel 9 | from module.config import Config 10 | import pickle 11 | import os.path 12 | from module.logger import Logger 13 | # from torch.utils.tensorboard import SummaryWriter 14 | 15 | 16 | def load_dataset(fname, fuc): 17 | dataset = fuc() 18 | return dataset 19 | 20 | 21 | # @profile 22 | def train(): 23 | 24 | if torch.cuda.is_available(): 25 | torch.cuda.set_device(Config().device_id) 26 | 27 | isExist = os.path.exists(Config().output_dir) 28 | if not isExist: 29 | os.makedirs(Config().output_dir) 30 | 31 | logger_path = os.path.join(Config().output_dir, 'NB_{}_{}_{}.log'.format(Config().embedding_dim/2,Config().dataset, Config().log_fire)) 32 | logger = Logger(logger_path) 33 | logger.info('*' * 150) 34 | logger.info(' update ndcg ') 35 | logger.info(' ************************************* model_1 *************************************** ') 36 | input_dir = os.path.join('./', Config().input_dir) 37 | fuc = lambda x=None: get_dataset(input_dir, Config().max_basket_size, Config().max_basket_num, 38 | Config().neg_ratio, Config().histroy) # ,Config().histroy 39 | dataset_path = os.path.join('./', 'dataset_{}_history_{}.pkl'.format(Config().dataset, Config().histroy)) 40 | dataset = load_dataset(dataset_path, fuc) 41 | 42 | TRAIN_DATASET, VALID_DATASET, TEST_DATASET, neg_sample, weights, itemnum, train_times, test_times, valid_times = dataset 43 | logger.info('test user nums : {} valid user nums : {}'.format(test_times, valid_times)) 44 | weights = torch.tensor(weights, dtype=torch.float32).to(device) 45 | 46 | Config().list_all_member(logger) 47 | NB = NBModel(Config(), device).to(device) 48 | 49 | def get_test_neg_set(test_type, batch_size, DATASET): 50 | neg_test_sample = dict() 51 | 52 | for batchid, (batch_userid, batch_input_seq, pad_batch_target_bsk, batch_history) in enumerate( 53 | get_batch_TEST_DATASET(DATASET, batch_size)): 54 | 55 | pad_batch_target_bsk = pad_batch_target_bsk.detach().cpu().numpy().tolist() # 56 | for bid, pad_target_bsk in enumerate(pad_batch_target_bsk): # 57 | uid_tensor = batch_userid[bid].to(device) 58 | uid = int(uid_tensor.detach().cpu().numpy()) 59 | 60 | tar_b = list(set(pad_target_bsk) - set([-1])) 61 | if Config().histroy == 0: 62 | neg_set = list(set(neg_sample[uid]) - set(tar_b)) 63 | S_pool = batch_input_seq[bid].cpu().numpy().tolist() 64 | tar_b = list(set(tar_b) - set(S_pool)) # 65 | else: 66 | neg_set = list(set(Config().item_list) - set(tar_b)) 67 | 68 | if len(tar_b) < 1: continue 69 | 70 | len_t = len(tar_b) 71 | if test_type > 0: 72 | neg_set = random.sample(neg_set, (test_type - len_t)) 73 | else: 74 | if len(neg_set) 
>= (Config().test_ratio * len(tar_b)): 75 | neg_set = random.sample(neg_set, tar_b.__len__() * Config().test_ratio) 76 | elif (Config().num_product) >= (Config().test_ratio * len(tar_b)): 77 | neg_set = list(set( 78 | random.sample(Config().item_list, tar_b.__len__() * Config().test_ratio)) - set(tar_b)) 79 | else: 80 | neg_set = list(set(Config().item_list) - set(tar_b)) # -set(tar_b) 81 | neg_test_sample[uid] = neg_set 82 | return neg_test_sample 83 | 84 | test_neg_set = get_test_neg_set(Config().test_type, Config().batch_size, TEST_DATASET) 85 | valid_neg_set = get_test_neg_set(Config().test_type, Config().batch_size, VALID_DATASET) 86 | sd = Config().sd2 87 | 88 | def train_model(epoch, G1_flag, pretrain, temp=1, batch_size=256, alternative_train_batch=100, temp0 = Config().temp_min): 89 | NB.train() 90 | flagg = 0 91 | start_time = time.clock() 92 | loss_all = 0 93 | p1_all = 0 94 | p2_all = 0 95 | p3_all = 0 96 | p4_all = 0 97 | repeat_ratio_all = 0 98 | real_delete_all = 0 99 | sum_all = 0 100 | loss_all_count = 0 101 | temp = temp 102 | batch_num = 0 103 | for batchid, (batch_userid, batch_input_seq, batch_target, batch_history, batch_neg_items) in enumerate( 104 | get_batch_TRAIN_DATASET(TRAIN_DATASET, batch_size)): 105 | batch_num += 1 106 | if G1_flag == 0: 107 | temp = temp0#Config().temp_min 108 | else: 109 | if batchid > 1: 110 | if batchid % alternative_train_batch == 1: 111 | temp = np.maximum(temp * np.exp(-Config().ANNEAL_RATE * batchid), Config().temp_min) 112 | batch_userid = batch_userid.to(device) 113 | batch_input_seq = batch_input_seq.to(device) 114 | batch_target = batch_target.to(device) 115 | batch_history = batch_history.to(device) 116 | batch_neg_items = batch_neg_items.to(device) 117 | neg_set = [] 118 | for u in batch_userid.detach().cpu().numpy().tolist(): 119 | neg_items = random.sample(neg_sample[u], Config().neg_ratio) 120 | neg_set.append(neg_items) 121 | batch_neg_items = torch.tensor(neg_set, dtype=torch.long).to(device) 122 | 123 | # if batch_target.__len__() < 2: continue 124 | 125 | loss, _, (p1, p2, p3, p4), (real_rest_sum, real_rest_sum_int, all_sum, rest_ratio,repeat_ratio) = NB(temp, batch_userid, 126 | batch_input_seq, 127 | batch_target, 128 | weights, 129 | batch_history, 130 | batch_neg_items, 131 | train=True, 132 | G1flag=G1_flag, 133 | pretrain=pretrain, 134 | sd2=sd) 135 | 136 | optimizer_dict[G1_flag].zero_grad() 137 | loss.backward() 138 | 139 | optimizer_dict[G1_flag].step() 140 | 141 | loss_all += loss.data.item() 142 | real_delete_all += real_rest_sum_int.data.item() 143 | repeat_ratio_all += repeat_ratio.data.item() 144 | sum_all += all_sum.data.item() 145 | p1_all += p1.data.item() 146 | p2_all += p2.data.item() 147 | p3_all += p3.data.item() 148 | p4_all += p4.data.item() 149 | loss_all_count += 1 150 | 151 | if batchid % Config().log_interval == 0: 152 | elapsed = (time.clock() - start_time) / Config().log_interval 153 | cur_loss = loss.data.item() # turn tensor into float 154 | cur_p1 = p1.data.item() 155 | cur_p2 = p2.data.item() 156 | cur_p3 = p3.data.item() 157 | cur_p4 = p4.data.item() 158 | start_time = time.clock() 159 | logger.info( 160 | '[Training]| Epochs {:3d} | Batch {:5d} | ms/batch {:02.2f} | Loss {:05.4f} | p1_Loss {:05.4f} | p2_Loss {:05.4f} | p3_Loss {:05.4f} | p4_Loss {:05.4f} | ' 161 | .format(epoch, batchid, elapsed, cur_loss, cur_p1, cur_p2, cur_p3, cur_p4)) 162 | logger.info( 163 | 'real_rest_sum {:05.4f} | real_rest_sum_int {:05.4f} | all_sum {:05.4f} | rest_ratio {:05.4f} | repeat_ratio {:05.4f}'.format( 
164 | real_rest_sum.data.item(), real_rest_sum_int.item(), all_sum.data.item(), 165 | rest_ratio.data.item(),repeat_ratio.data.item())) 166 | loss_all = loss_all / loss_all_count 167 | p1_all = p1_all / loss_all_count 168 | p2_all = p2_all / loss_all_count 169 | p3_all = p3_all / loss_all_count 170 | p4_all = p4_all / loss_all_count 171 | real_delete_ratio = real_delete_all / sum_all 172 | repeat_ratio_all = repeat_ratio_all/loss_all_count 173 | logger.info('batch_num: {}'.format(batch_num)) 174 | logger.info( 175 | '[Training]| Epochs {:3d} | loss_all {:05.4f} | p1_all {:05.4f} | p2_all {:05.4f} | p3_all {:05.4f} | p4_all {:05.4f} | real_rest_ratio {:05.4f} | repeat_ratio {:05.4f} |'.format( 176 | epoch, loss_all, p1_all, p2_all, p3_all, p4_all, real_delete_ratio, repeat_ratio_all)) 177 | return loss_all, temp 178 | 179 | def valid_model_1000_top5(epoch, G1_flag, test_type=0, pretrain=0, temp=1, batch_size=256): 180 | 181 | def get_index(prob, neg_set, tar_b): 182 | mask = get_tensor([neg_set]).to(device).view(1, -1).expand(prob.size(0), -1) # n_items +1 183 | mask = (torch.ones_like(mask).to(device) - mask) * (-9999) 184 | prob = prob + mask # K*n_items 185 | value_5, index_5 = torch.topk(prob, 5) 186 | 187 | tar_b_tensor = torch.tensor(tar_b).to(device) # 188 | 189 | item_num = tar_b_tensor.size(0) # 190 | index = torch.tensor(np.linspace(0, item_num, num=item_num, endpoint=False), dtype=torch.long) # K 191 | pfake = prob[index, tar_b_tensor] # K 192 | 193 | return value_5, index_5, pfake 194 | 195 | NB.eval() 196 | 197 | hit_ratio_5 = 0 198 | recall_5 = 0 199 | precision_5 = 0 200 | f1_5 = 0 201 | ndcg_5 = 0 202 | mrr_5 = 0 203 | 204 | time_count1 = 0 205 | 206 | has_fake_user_5 = 0 207 | fake_length_5 = 0 208 | 209 | test_num = 0 210 | temp = temp 211 | 212 | test_repeat_ratio = [] 213 | test_neg_repeat_ratio = [] 214 | test_real_ratio = [] 215 | test_neg_ratio = [] 216 | 217 | p_n_score_differences = [] 218 | 219 | with torch.no_grad(): 220 | for batchid, (batch_userid, batch_input_seq, pad_batch_target_bsk, batch_history) in enumerate( 221 | get_batch_TEST_DATASET(VALID_DATASET, batch_size)): 222 | if batchid % Config().alternative_train_batch == 1: 223 | temp = np.maximum(temp * np.exp(-Config().ANNEAL_RATE * batchid), Config().temp_min) 224 | pad_batch_target_bsk = pad_batch_target_bsk.detach().cpu().numpy().tolist() # 225 | for bid, pad_target_bsk in enumerate(pad_batch_target_bsk): ## 226 | uid_tensor = batch_userid[bid].to(device) 227 | uid = int(uid_tensor.detach().cpu().numpy()) 228 | 229 | tar_b = list(set(pad_target_bsk) - set([-1])) # 230 | if Config().histroy == 0: 231 | S_pool = batch_input_seq[bid].cpu().numpy().tolist() 232 | tar_b = list(set(tar_b) - set(S_pool)) # 233 | if len(tar_b) < 1: continue 234 | 235 | test_num += 1 236 | 237 | input_tensor = batch_input_seq[bid].to(device) 238 | history_tensor = batch_history[bid].to(device) 239 | 240 | len_t = len(tar_b) 241 | neg_set = random.sample(list(set(Config().item_list) - set(tar_b)), (test_type - len_t)) 242 | # neg_set = valid_neg_set[uid] 243 | 244 | minibatch_all_items = get_minibatch_split_all_items(neg_set) 245 | neg_set = tar_b + neg_set 246 | 247 | target_items_tensor = torch.tensor(tar_b, dtype=torch.long).to(device) # K 248 | input_tensor_expand = input_tensor.view(1, -1).expand(len(tar_b), -1) # K * B 249 | uid_tensor_expand = uid_tensor.expand(len(tar_b)) # K 250 | history_tensor_expand = history_tensor.view(1, -1).expand(len(tar_b), -1) # K * n_items 251 | rest_ratio,repeat_ratio, prob = NB(temp, 
uid_tensor_expand, input_tensor_expand, target_items_tensor, weights, 252 | history_tensor_expand, 253 | neg_set_tensor=None, train=False, G1flag=G1_flag, pretrain=pretrain, sd2=sd) 254 | rest_ratio = rest_ratio.detach() 255 | repeat_ratio = repeat_ratio.detach() 256 | test_real_ratio.append(rest_ratio.data.item()) 257 | test_repeat_ratio.append(repeat_ratio.data.item()) 258 | t_value_5, t_index_5, pfake = get_index(prob, neg_set, tar_b) 259 | value_5 = t_value_5 260 | index_5 = t_index_5 261 | pfake = pfake 262 | 263 | for id, target_items in enumerate(minibatch_all_items): 264 | target_items_tensor = torch.tensor(target_items, dtype=torch.long).to(device) # K 265 | input_tensor_expand = input_tensor.view(1, -1).expand(len(target_items), -1) # K * B 266 | uid_tensor_expand = uid_tensor.expand(len(target_items)) # K 267 | history_tensor_expand = history_tensor.view(1, -1).expand(len(target_items), -1) # K * n_items 268 | 269 | rest_ratio, repeat_ratio, prob = NB(temp, uid_tensor_expand, input_tensor_expand, target_items_tensor, 270 | weights, 271 | history_tensor_expand, 272 | neg_set_tensor=None, train=False, G1flag=G1_flag, pretrain=pretrain, 273 | sd2=sd) 274 | rest_ratio = rest_ratio.detach() 275 | repeat_ratio = repeat_ratio.detach() 276 | test_neg_ratio.append(rest_ratio.data.item()) 277 | test_neg_repeat_ratio.append(repeat_ratio.data.item()) 278 | t_value_5, t_index_5, t_pfake = get_index(prob, neg_set, target_items) 279 | value_5 = torch.cat((value_5, t_value_5), dim=0) 280 | index_5 = torch.cat((index_5, t_index_5), dim=0) 281 | pfake = torch.cat((pfake, t_pfake), dim=0) 282 | 283 | start_t1 = time.time() 284 | 285 | p_pos_score = pfake[0:len(tar_b)] 286 | p_neg_score = pfake[len(tar_b):] 287 | # neg_top10_value, _ = torch.topk(p_neg_score, len(tar_b)) 288 | p_n_score_differences.append((torch.mean(p_pos_score) - torch.mean(p_neg_score)).data.item()) 289 | 290 | rank_pfake = pfake.cpu().numpy() 291 | rank_pfake = -np.array(rank_pfake) 292 | rank_index = np.argsort(rank_pfake) 293 | select_index = rank_index[:5] 294 | hit = np.array(neg_set)[select_index] 295 | fake_basket_5 = list(hit) 296 | 297 | hit_len_5 = len(set(fake_basket_5) & set(tar_b)) 298 | fake_length_5 += len(fake_basket_5) 299 | if len(fake_basket_5) > 0: 300 | has_fake_user_5 += 1 301 | 302 | if hit_len_5 > 0: 303 | ndcg_t, mrr_t = get_ndcg(fake_basket_5, tar_b) 304 | ndcg_5 += ndcg_t 305 | mrr_5 += mrr_t 306 | hit_ratio_5 += 1 307 | recall_5 += hit_len_5 / len(tar_b) 308 | precision_5 += hit_len_5 / len(fake_basket_5) 309 | f1_5 += (2 * (hit_len_5 / len(tar_b)) * (hit_len_5 / len(fake_basket_5)) / ( 310 | (hit_len_5 / len(tar_b)) + (hit_len_5 / len(fake_basket_5)))) 311 | end_t1 = time.time() 312 | time_count1 += (-start_t1 + end_t1) 313 | 314 | logger.info(1) 315 | 316 | avg_fake_basket_user_5 = fake_length_5 / has_fake_user_5 317 | 318 | hit_ratio_5 = hit_ratio_5 / test_num 319 | recall_5 = recall_5 / test_num 320 | precision_5 = precision_5 / test_num 321 | f1_5 = f1_5 / test_num 322 | ndcg_5 = ndcg_5 / test_num 323 | mrr_5 = mrr_5 / test_num 324 | 325 | test_real_ratio_avg = np.mean(test_real_ratio) 326 | test_neg_ratio_avg = np.mean(test_neg_ratio) 327 | test_repeat_ratio_avg = np.nanmean(test_repeat_ratio) 328 | test_neg_repeat_ratio_avg = np.nanmean(test_neg_repeat_ratio) 329 | p_n_score_differences_avg = np.nanmean(p_n_score_differences) 330 | 331 | logger.info( 332 | 'valid_real_ratio_avg {:05.4f} valid_neg_ratio_avg {:05.4f} valid_repeat_ratio_avg {:05.4f} valid_neg_repeat_ratio_avg 
{:05.4f}'.format(test_real_ratio_avg, 333 | test_neg_ratio_avg,test_repeat_ratio_avg,test_neg_repeat_ratio_avg)) 334 | 335 | logger.info( 336 | '[Validation] neg_sample TOP5 [Test]| Epochs {:3d} | Hit ratio {:02.4f} | recall {:05.4f} | precision {:05.4f} | f1 {: 05.4f} | ndcg {: 05.4f} | mrr {: 05.4f} | have_fake_user {:3d} | avg_fake_length {: 05.4f} | all_valid_user_num {:3d}' 337 | .format(epoch, hit_ratio_5, recall_5, precision_5, f1_5, ndcg_5, mrr_5, has_fake_user_5, 338 | avg_fake_basket_user_5, 339 | test_num)) 340 | 341 | logger.info('[Validation] p_n_score_differences_avg {:05.4f} '.format(p_n_score_differences_avg)) 342 | logger.info('##############################################') 343 | return hit_ratio_5, recall_5, precision_5, f1_5, ndcg_5, mrr_5, avg_fake_basket_user_5 344 | 345 | def get_minibatch_split_all_items(item_list): 346 | minibatch_all_items = [] 347 | minibatch_size = 500 348 | count = 0 349 | while count < len(item_list): 350 | if count + minibatch_size <= len(item_list): 351 | target_items = item_list[count:count + minibatch_size] 352 | count = count + minibatch_size 353 | minibatch_all_items.append(target_items) 354 | else: 355 | target_items = item_list[count:] 356 | count = len(item_list) 357 | minibatch_all_items.append(target_items) 358 | return minibatch_all_items 359 | 360 | def get_onehot_tensor(n_item, baskets): 361 | input_basket = [] 362 | for basket in baskets: 363 | input_basket.append(np.squeeze( 364 | sp.coo_matrix(([1.] * basket.__len__(), ([0] * basket.__len__(), basket)), 365 | shape=(1, n_item)).toarray())) 366 | input_basket = torch.tensor(input_basket, dtype=torch.float) 367 | return input_basket 368 | 369 | def get_tensor(seq_list): # 370 | seq_tensor = get_onehot_tensor(Config().num_product, seq_list) 371 | return seq_tensor 372 | 373 | def get_ndcg(fake_basket, tar_b): 374 | u_dcg = 0 375 | u_idcg = 0 376 | rank_i = 0 377 | rank_flag = 0 378 | p_len = min(len(tar_b), 5) 379 | for k in range(5): # 380 | if k < len(fake_basket): 381 | if fake_basket[k] in set(tar_b): # 382 | u_dcg += 1 / math.log(k + 1 + 1, 2) 383 | if rank_flag == 0: 384 | rank_i += 1 / (k + 1) # min(p_len - 1, k) 385 | rank_flag = 1 386 | 387 | idea = min(len(tar_b), 5) 388 | for k in range(idea): 389 | u_idcg += 1 / math.log(k + 1 + 1, 2) 390 | ndcg = u_dcg / u_idcg 391 | return ndcg, rank_i 392 | 393 | def test_model_1000_top5_new(epoch, G1_flag, group_split1=4, group_split2=6, test_type=0, pretrain=0, temp=1, 394 | batch_size=256): 395 | 396 | basket_length_dict = {} 397 | for i in [5]: 398 | basket_length_dict[i] = {} 399 | 400 | basket_length_dict[i]['<{}'.format(group_split1)] = {} 401 | basket_length_dict[i]['<{}'.format(group_split1)]['hit'] = [] 402 | basket_length_dict[i]['<{}'.format(group_split1)]['recall'] = [] 403 | basket_length_dict[i]['<{}'.format(group_split1)]['precision'] = [] 404 | basket_length_dict[i]['<{}'.format(group_split1)]['f1'] = [] 405 | basket_length_dict[i]['<{}'.format(group_split1)]['ndcg'] = [] 406 | basket_length_dict[i]['<{}'.format(group_split1)]['mrr'] = [] 407 | 408 | basket_length_dict[i]['{}<<{}'.format(group_split1, group_split2)] = {} 409 | basket_length_dict[i]['{}<<{}'.format(group_split1, group_split2)]['hit'] = [] 410 | basket_length_dict[i]['{}<<{}'.format(group_split1, group_split2)]['recall'] = [] 411 | basket_length_dict[i]['{}<<{}'.format(group_split1, group_split2)]['precision'] = [] 412 | basket_length_dict[i]['{}<<{}'.format(group_split1, group_split2)]['f1'] = [] 413 | 
basket_length_dict[i]['{}<<{}'.format(group_split1, group_split2)]['ndcg'] = [] 414 | basket_length_dict[i]['{}<<{}'.format(group_split1, group_split2)]['mrr'] = [] 415 | 416 | basket_length_dict[i]['>{}'.format(group_split2)] = {} 417 | basket_length_dict[i]['>{}'.format(group_split2)]['hit'] = [] 418 | basket_length_dict[i]['>{}'.format(group_split2)]['recall'] = [] 419 | basket_length_dict[i]['>{}'.format(group_split2)]['precision'] = [] 420 | basket_length_dict[i]['>{}'.format(group_split2)]['f1'] = [] 421 | basket_length_dict[i]['>{}'.format(group_split2)]['ndcg'] = [] 422 | basket_length_dict[i]['>{}'.format(group_split2)]['mrr'] = [] 423 | 424 | def get_index(prob, neg_set, tar_b): 425 | mask = get_tensor([neg_set]).to(device).view(1, -1).expand(prob.size(0), -1) # n_items +1 426 | mask = (torch.ones_like(mask).to(device) - mask) * (-9999) 427 | prob = prob + mask # K*n_items 428 | value_5, index_5 = torch.topk(prob, 5) 429 | 430 | tar_b_tensor = torch.tensor(tar_b).to(device) # 431 | 432 | item_num = tar_b_tensor.size(0) # 433 | index = torch.tensor(np.linspace(0, item_num, num=item_num, endpoint=False), dtype=torch.long) # K 434 | pfake = prob[index, tar_b_tensor] # K 435 | 436 | return value_5, index_5, pfake 437 | 438 | NB.eval() 439 | 440 | hit_ratio_5 = 0 441 | recall_5 = 0 442 | precision_5 = 0 443 | f1_5 = 0 444 | ndcg_5 = 0 445 | mrr_5 = 0 446 | 447 | time_count1 = 0 448 | 449 | has_fake_user_5 = 0 450 | fake_length_5 = 0 451 | 452 | test_num = 0 453 | temp = temp 454 | test_real_ratio = [] 455 | test_neg_ratio = [] 456 | 457 | test_repeat_ratio = [] 458 | test_neg_repeat_ratio = [] 459 | 460 | p_n_score_differences = [] 461 | with torch.no_grad(): 462 | for batchid, (batch_userid, batch_input_seq, pad_batch_target_bsk, batch_history) in enumerate( 463 | get_batch_TEST_DATASET(TEST_DATASET, batch_size)): 464 | if batchid % Config().alternative_train_batch == 1: 465 | temp = np.maximum(temp * np.exp(-Config().ANNEAL_RATE * batchid), Config().temp_min) 466 | pad_batch_target_bsk = pad_batch_target_bsk.detach().cpu().numpy().tolist() # 467 | for bid, pad_target_bsk in enumerate(pad_batch_target_bsk): ## 468 | uid_tensor = batch_userid[bid].to(device) 469 | uid = int(uid_tensor.detach().cpu().numpy()) 470 | 471 | tar_b = list(set(pad_target_bsk) - set([-1])) 472 | if Config().histroy == 0: 473 | S_pool = batch_input_seq[bid].cpu().numpy().tolist() 474 | tar_b = list(set(tar_b) - set(S_pool)) # 475 | if len(tar_b) < 1: continue 476 | 477 | test_num += 1 478 | 479 | input_tensor = batch_input_seq[bid].to(device) 480 | history_tensor = batch_history[bid].to(device) 481 | 482 | l = int(input_tensor.view(1, -1).view(1, -1, Config().max_basket_size).size(1)) 483 | mask_input = (input_tensor != -1).int() 484 | avg_basket_size = float(mask_input.sum().data.item() / l) 485 | if avg_basket_size < group_split1: 486 | key_value = '<{}'.format(group_split1) 487 | elif avg_basket_size > group_split2: 488 | key_value = '>{}'.format(group_split2) 489 | else: 490 | key_value = '{}<<{}'.format(group_split1, group_split2) 491 | # l = int(input_tensor.view(1, -1).view(1, -1, Config().max_basket_size).size(1)) 492 | # if l < group_split1: 493 | # key_value = '<{}'.format(group_split1) 494 | # elif l > group_split2: 495 | # key_value = '>{}'.format(group_split2) 496 | # else: 497 | # key_value = '{}<<{}'.format(group_split1, group_split2) 498 | 499 | len_t = len(tar_b) 500 | neg_set = random.sample(list(set(Config().item_list) - set(tar_b)), (test_type - len_t)) 501 | # neg_set = test_neg_set[uid] 
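# The sampled negatives below are scored in minibatches of 500; the true target items are prepended to the candidate set, so the top-5 ranking is computed over targets plus sampled negatives and then compared against the actual basket.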
502 | 503 | minibatch_all_items = get_minibatch_split_all_items(neg_set) 504 | neg_set = tar_b + neg_set 505 | 506 | target_items_tensor = torch.tensor(tar_b, dtype=torch.long).to(device) # K 507 | input_tensor_expand = input_tensor.view(1, -1).expand(len(tar_b), -1) # K * B 508 | uid_tensor_expand = uid_tensor.expand(len(tar_b)) # K 509 | history_tensor_expand = history_tensor.view(1, -1).expand(len(tar_b), -1) # K * n_items 510 | rest_ratio,repeat_ratio, prob = NB(temp, uid_tensor_expand, input_tensor_expand, target_items_tensor, weights, 511 | history_tensor_expand, 512 | neg_set_tensor=None, train=False, G1flag=G1_flag, pretrain=pretrain, sd2=sd) 513 | rest_ratio = rest_ratio.detach() 514 | repeat_ratio = repeat_ratio.detach() 515 | test_real_ratio.append(rest_ratio.data.item()) 516 | test_repeat_ratio.append(repeat_ratio.data.item()) 517 | t_value_5, t_index_5, pfake = get_index(prob, neg_set, tar_b) 518 | value_5 = t_value_5 519 | index_5 = t_index_5 520 | pfake = pfake 521 | 522 | for id, target_items in enumerate(minibatch_all_items): 523 | target_items_tensor = torch.tensor(target_items, dtype=torch.long).to(device) # K 524 | input_tensor_expand = input_tensor.view(1, -1).expand(len(target_items), -1) # K * B 525 | uid_tensor_expand = uid_tensor.expand(len(target_items)) # K 526 | history_tensor_expand = history_tensor.view(1, -1).expand(len(target_items), -1) # K * n_items 527 | 528 | rest_ratio, repeat_ratio,prob = NB(temp, uid_tensor_expand, input_tensor_expand, target_items_tensor, 529 | weights, 530 | history_tensor_expand, 531 | neg_set_tensor=None, train=False, G1flag=G1_flag, pretrain=pretrain, 532 | sd2=sd) 533 | rest_ratio = rest_ratio.detach() 534 | repeat_ratio = repeat_ratio.detach() 535 | test_neg_ratio.append(rest_ratio.data.item()) 536 | test_neg_repeat_ratio.append(repeat_ratio.data.item()) 537 | t_value_5, t_index_5, t_pfake = get_index(prob, neg_set, target_items) 538 | value_5 = torch.cat((value_5, t_value_5), dim=0) 539 | index_5 = torch.cat((index_5, t_index_5), dim=0) 540 | pfake = torch.cat((pfake, t_pfake), dim=0) 541 | 542 | start_t1 = time.time() 543 | 544 | p_pos_score = pfake[0:len(tar_b)] 545 | p_neg_score = pfake[len(tar_b):] 546 | # neg_top10_value, _ = torch.topk(p_neg_score, len(tar_b)) 547 | p_n_score_differences.append((torch.mean(p_pos_score) - torch.mean(p_neg_score)).data.item()) 548 | 549 | rank_pfake = pfake.cpu().numpy() 550 | rank_pfake = -np.array(rank_pfake) 551 | rank_index = np.argsort(rank_pfake) 552 | select_index = rank_index[:5] 553 | hit = np.array(neg_set)[select_index] 554 | fake_basket_5 = list(hit) 555 | 556 | hit_len_5 = len(set(fake_basket_5) & set(tar_b)) 557 | fake_length_5 += len(fake_basket_5) 558 | if len(fake_basket_5) > 0: 559 | has_fake_user_5 += 1 560 | 561 | if hit_len_5 > 0: 562 | ndcg_t, mrr_t = get_ndcg(fake_basket_5, tar_b) 563 | ndcg_5 += ndcg_t 564 | mrr_5 += mrr_t 565 | hit_ratio_5 += 1 566 | recall_5 += hit_len_5 / len(tar_b) 567 | precision_5 += hit_len_5 / len(fake_basket_5) 568 | f1_5 += (2 * (hit_len_5 / len(tar_b)) * (hit_len_5 / len(fake_basket_5)) / ( 569 | (hit_len_5 / len(tar_b)) + (hit_len_5 / len(fake_basket_5)))) 570 | 571 | basket_length_dict[5][key_value]['ndcg'].append(ndcg_t) 572 | basket_length_dict[5][key_value]['mrr'].append(mrr_t) 573 | basket_length_dict[5][key_value]['hit'].append(1) 574 | basket_length_dict[5][key_value]['recall'].append(hit_len_5 / len(tar_b)) 575 | basket_length_dict[5][key_value]['precision'].append(hit_len_5 / len(fake_basket_5)) 576 | 
basket_length_dict[5][key_value]['f1'].append( 577 | (2 * (hit_len_5 / len(tar_b)) * (hit_len_5 / len(fake_basket_5)) / ( 578 | (hit_len_5 / len(tar_b)) + (hit_len_5 / len(fake_basket_5))))) 579 | else: 580 | basket_length_dict[5][key_value]['ndcg'].append(0) 581 | basket_length_dict[5][key_value]['mrr'].append(0) 582 | basket_length_dict[5][key_value]['hit'].append(0) 583 | basket_length_dict[5][key_value]['recall'].append(0) 584 | basket_length_dict[5][key_value]['precision'].append(0) 585 | basket_length_dict[5][key_value]['f1'].append(0) 586 | 587 | end_t1 = time.time() 588 | time_count1 += (-start_t1 + end_t1) 589 | 590 | logger.info(1) 591 | test_loss = 0 592 | 593 | avg_fake_basket_user_5 = fake_length_5 / has_fake_user_5 594 | 595 | hit_ratio_5 = hit_ratio_5 / test_num 596 | recall_5 = recall_5 / test_num 597 | precision_5 = precision_5 / test_num 598 | f1_5 = f1_5 / test_num 599 | ndcg_5 = ndcg_5 / test_num 600 | mrr_5 = mrr_5 / test_num 601 | 602 | test_real_ratio_avg = np.mean(test_real_ratio) 603 | test_neg_ratio_avg = np.mean(test_neg_ratio) 604 | test_repeat_ratio_avg = np.nanmean(test_repeat_ratio) 605 | test_neg_repeat_ratio_avg = np.nanmean(test_neg_repeat_ratio) 606 | p_n_score_differences_avg = np.nanmean(p_n_score_differences) 607 | 608 | for kk in [5]: 609 | 610 | length_list = list(basket_length_dict[kk].keys()) 611 | for length in length_list: 612 | hit_list = basket_length_dict[kk][length]['hit'] 613 | recall_list = basket_length_dict[kk][length]['recall'] 614 | precision_list = basket_length_dict[kk][length]['precision'] 615 | f1_list = basket_length_dict[kk][length]['f1'] 616 | ndcg_list = basket_length_dict[kk][length]['ndcg'] 617 | mrr_list = basket_length_dict[kk][length]['mrr'] 618 | logger.info( 619 | 'Epochs {:3d} topk {:3d} basket_length {} num {:3d} Hit ratio {:02.4f} | recall {:05.4f} | precision {:05.4f} | f1 {: 05.4f} | ndcg {: 05.4f} | mrr {: 05.4f} '.format( 620 | epoch, kk, length, len(hit_list), np.mean(hit_list), np.mean(recall_list), 621 | np.mean(precision_list), np.mean(f1_list), np.mean(ndcg_list), np.mean(mrr_list) 622 | )) 623 | 624 | logger.info('...............................................') 625 | 626 | logger.info( 627 | 'test_real_ratio_avg {:05.4f} test_neg_ratio_avg {:05.4f} test_repeat_ratio_avg {:05.4f} test_neg_repeat_ratio_avg {:05.4f}'.format(test_real_ratio_avg, test_neg_ratio_avg,test_repeat_ratio_avg,test_neg_repeat_ratio_avg)) 628 | logger.info( 629 | 'neg_sample TOP5 [Test]| Epochs {:3d} | Hit ratio {:02.4f} | recall {:05.4f} | precision {:05.4f} | f1 {: 05.4f} | ndcg {: 05.4f} | mrr {: 05.4f} | have_fake_user {:3d} | avg_fake_length {: 05.4f} | all_test_user_num {:3d}' 630 | .format(epoch, hit_ratio_5, recall_5, precision_5, f1_5, ndcg_5, mrr_5, has_fake_user_5, 631 | avg_fake_basket_user_5, 632 | test_num)) 633 | logger.info('[Test] p_n_score_differences_avg {:05.4f} '.format(p_n_score_differences_avg)) 634 | logger.info('##############################################') 635 | return hit_ratio_5, recall_5, precision_5, f1_5, ndcg_5, mrr_5, avg_fake_basket_user_5, test_loss 636 | 637 | def test_as_DREAM_new(epoch, G1_flag, group_split1=4, group_split2=6, test_type=0, pretrain=0, temp=1, 638 | batch_size=256): 639 | 640 | basket_length_dict = {} 641 | for i in [5]: 642 | basket_length_dict[i] = {} 643 | 644 | basket_length_dict[i]['<{}'.format(group_split1)] = {} 645 | basket_length_dict[i]['<{}'.format(group_split1)]['hit'] = [] 646 | basket_length_dict[i]['<{}'.format(group_split1)]['recall'] = [] 647 | 
basket_length_dict[i]['<{}'.format(group_split1)]['precision'] = [] 648 | basket_length_dict[i]['<{}'.format(group_split1)]['f1'] = [] 649 | basket_length_dict[i]['<{}'.format(group_split1)]['ndcg'] = [] 650 | basket_length_dict[i]['<{}'.format(group_split1)]['mrr'] = [] 651 | 652 | basket_length_dict[i]['{}<<{}'.format(group_split1, group_split2)] = {} 653 | basket_length_dict[i]['{}<<{}'.format(group_split1, group_split2)]['hit'] = [] 654 | basket_length_dict[i]['{}<<{}'.format(group_split1, group_split2)]['recall'] = [] 655 | basket_length_dict[i]['{}<<{}'.format(group_split1, group_split2)]['precision'] = [] 656 | basket_length_dict[i]['{}<<{}'.format(group_split1, group_split2)]['f1'] = [] 657 | basket_length_dict[i]['{}<<{}'.format(group_split1, group_split2)]['ndcg'] = [] 658 | basket_length_dict[i]['{}<<{}'.format(group_split1, group_split2)]['mrr'] = [] 659 | 660 | basket_length_dict[i]['>{}'.format(group_split2)] = {} 661 | basket_length_dict[i]['>{}'.format(group_split2)]['hit'] = [] 662 | basket_length_dict[i]['>{}'.format(group_split2)]['recall'] = [] 663 | basket_length_dict[i]['>{}'.format(group_split2)]['precision'] = [] 664 | basket_length_dict[i]['>{}'.format(group_split2)]['f1'] = [] 665 | basket_length_dict[i]['>{}'.format(group_split2)]['ndcg'] = [] 666 | basket_length_dict[i]['>{}'.format(group_split2)]['mrr'] = [] 667 | 668 | NB.eval() 669 | 670 | hit_ratio_5 = 0 671 | recall_5 = 0 672 | precision_5 = 0 673 | f1_5 = 0 674 | ndcg_5 = 0 675 | mrr_5 = 0 676 | 677 | has_fake_user_5 = 0 678 | fake_length_5 = 0 679 | 680 | test_num = 0 681 | temp = temp 682 | with torch.no_grad(): 683 | for batchid, (batch_userid, batch_input_seq, pad_batch_target_bsk, batch_history) in enumerate( 684 | get_batch_TEST_DATASET(TEST_DATASET, batch_size)): 685 | if batchid % Config().alternative_train_batch == 1: 686 | temp = np.maximum(temp * np.exp(-Config().ANNEAL_RATE * batchid), Config().temp_min) 687 | pad_batch_target_bsk = pad_batch_target_bsk.detach().cpu().numpy().tolist() # 688 | for bid, pad_target_bsk in enumerate(pad_batch_target_bsk): ## 689 | uid_tensor = batch_userid[bid].to(device) 690 | uid = int(uid_tensor.detach().cpu().numpy()) 691 | 692 | tar_b = list(set(pad_target_bsk) - set([-1])) 693 | if Config().histroy == 0: 694 | S_pool = batch_input_seq[bid].cpu().numpy().tolist() 695 | tar_b = list(set(tar_b) - set(S_pool)) 696 | if len(tar_b) < 1: continue 697 | 698 | test_num += 1 699 | 700 | input_tensor = batch_input_seq[bid].to(device) 701 | history_tensor = batch_history[bid].to(device) 702 | 703 | l = int(input_tensor.view(1, -1).view(1, -1, Config().max_basket_size).size(1)) 704 | mask_input = (input_tensor != -1).int() 705 | avg_basket_size = float(mask_input.sum().data.item() / l) 706 | if avg_basket_size < group_split1: 707 | key_value = '<{}'.format(group_split1) 708 | elif avg_basket_size > group_split2: 709 | key_value = '>{}'.format(group_split2) 710 | else: 711 | key_value = '{}<<{}'.format(group_split1, group_split2) 712 | 713 | # l = int(input_tensor.view(1, -1).view(1, -1, Config().max_basket_size).size(1)) 714 | # if l < group_split1: 715 | # key_value = '<{}'.format(group_split1) 716 | # elif l > group_split2: 717 | # key_value = '>{}'.format(group_split2) 718 | # else: 719 | # key_value = '{}<<{}'.format(group_split1, group_split2) 720 | 721 | neg_set = test_neg_set[uid] 722 | neg_set = tar_b + neg_set 723 | 724 | target_items_tensor = torch.tensor(tar_b, dtype=torch.long).to(device) # K 725 | input_tensor_expand = input_tensor.view(1, 
-1).expand(len(tar_b), -1) # K * B 726 | uid_tensor_expand = uid_tensor.expand(len(tar_b)) # K 727 | history_tensor_expand = history_tensor.view(1, -1).expand(len(tar_b), -1) # K * n_items 728 | _,_, prob = NB(temp, uid_tensor_expand, input_tensor_expand, target_items_tensor, weights, 729 | history_tensor_expand, 730 | neg_set_tensor=None, train=False, G1flag=G1_flag, pretrain=pretrain, sd2=sd) 731 | prob = prob.detach()[0, :].reshape(1, -1) # 1*n_items 732 | mask = get_tensor([neg_set]).to(device).view(1, -1) 733 | mask = (torch.ones_like(mask).to(device) - mask) * (-9999) 734 | prob = prob + mask 735 | value_5, index_5 = torch.topk(prob.squeeze(), 5) 736 | 737 | index_5 = index_5.tolist() # 738 | fake_basket_5 = index_5 739 | 740 | hit_len_5 = len(set(fake_basket_5) & set(tar_b)) 741 | 742 | fake_length_5 += len(fake_basket_5) 743 | if len(fake_basket_5) > 0: 744 | has_fake_user_5 += 1 745 | 746 | if hit_len_5 > 0: 747 | 748 | ndcg_t, mrr_t = get_ndcg(fake_basket_5, tar_b) 749 | ndcg_5 += ndcg_t 750 | mrr_5 += mrr_t 751 | hit_ratio_5 += 1 752 | recall_5 += hit_len_5 / len(tar_b) 753 | precision_5 += hit_len_5 / len(fake_basket_5) 754 | f1_5 += (2 * (hit_len_5 / len(tar_b)) * (hit_len_5 / len(fake_basket_5)) / ( 755 | (hit_len_5 / len(tar_b)) + (hit_len_5 / len(fake_basket_5)))) 756 | 757 | basket_length_dict[5][key_value]['ndcg'].append(ndcg_t) 758 | basket_length_dict[5][key_value]['mrr'].append(mrr_t) 759 | basket_length_dict[5][key_value]['hit'].append(1) 760 | basket_length_dict[5][key_value]['recall'].append(hit_len_5 / len(tar_b)) 761 | basket_length_dict[5][key_value]['precision'].append(hit_len_5 / len(fake_basket_5)) 762 | basket_length_dict[5][key_value]['f1'].append( 763 | (2 * (hit_len_5 / len(tar_b)) * (hit_len_5 / len(fake_basket_5)) / ( 764 | (hit_len_5 / len(tar_b)) + (hit_len_5 / len(fake_basket_5))))) 765 | else: 766 | basket_length_dict[5][key_value]['ndcg'].append(0) 767 | basket_length_dict[5][key_value]['mrr'].append(0) 768 | basket_length_dict[5][key_value]['hit'].append(0) 769 | basket_length_dict[5][key_value]['recall'].append(0) 770 | basket_length_dict[5][key_value]['precision'].append(0) 771 | basket_length_dict[5][key_value]['f1'].append(0) 772 | 773 | logger.info(1) 774 | test_loss = 0 775 | 776 | avg_fake_basket_user_5 = fake_length_5 / has_fake_user_5 777 | 778 | hit_ratio_5 = hit_ratio_5 / test_num 779 | recall_5 = recall_5 / test_num 780 | precision_5 = precision_5 / test_num 781 | f1_5 = f1_5 / test_num 782 | ndcg_5 = ndcg_5 / test_num 783 | mrr_5 = mrr_5 / test_num 784 | 785 | for kk in [5]: 786 | 787 | length_list = list(basket_length_dict[kk].keys()) 788 | for length in length_list: 789 | hit_list = basket_length_dict[kk][length]['hit'] 790 | recall_list = basket_length_dict[kk][length]['recall'] 791 | precision_list = basket_length_dict[kk][length]['precision'] 792 | f1_list = basket_length_dict[kk][length]['f1'] 793 | ndcg_list = basket_length_dict[kk][length]['ndcg'] 794 | mrr_list = basket_length_dict[kk][length]['mrr'] 795 | logger.info( 796 | 'Epochs {:3d} topk {:3d} basket_length {} num {:3d} Hit ratio {:02.4f} | recall {:05.4f} | precision {:05.4f} | f1 {: 05.4f} | ndcg {: 05.4f} | mrr {: 05.4f} '.format( 797 | epoch, kk, length, len(hit_list), np.mean(hit_list), np.mean(recall_list), 798 | np.mean(precision_list), np.mean(f1_list), np.mean(ndcg_list), np.mean(mrr_list) 799 | )) 800 | 801 | logger.info('...............................................') 802 | 803 | logger.info( 804 | 'neg_sample TOP5 [Test]| Epochs {:3d} | Hit ratio {:02.4f} | 
recall {:05.4f} | precision {:05.4f} | f1 {: 05.4f} | ndcg {: 05.4f} | mrr {: 05.4f}| have_fake_user {:3d} | avg_fake_length {: 05.4f} | all_test_user_num {:3d}' 805 | .format(epoch, hit_ratio_5, recall_5, precision_5, f1_5, ndcg_5, mrr_5, has_fake_user_5, 806 | avg_fake_basket_user_5, 807 | test_num)) 808 | 809 | logger.info('##############################################') 810 | return hit_ratio_5, recall_5, precision_5, f1_5, ndcg_5, mrr_5, avg_fake_basket_user_5, test_loss 811 | try: 812 | valid_hit_l = [] 813 | valid_recall_l = [] 814 | test_hit_l = [] 815 | test_recall_l = [] 816 | train_loss_l = [] 817 | 818 | optimizer_dict = {} 819 | schedular_dict = {} 820 | optimizer_dict[2] = torch.optim.Adam([ 821 | {'params': NB.G2.parameters()}, 822 | {'params': NB.D.parameters()}, 823 | {'params': NB.G0.parameters(), 'lr': Config().G1_lr}], lr=Config().learning_rate, 824 | weight_decay=Config().weight_decay) 825 | schedular_dict[2] = torch.optim.lr_scheduler.StepLR(optimizer_dict[2], step_size=3, gamma=1) 826 | 827 | optimizer_dict[1] = torch.optim.Adam([ 828 | {'params': NB.G0.parameters()}] 829 | , lr=Config().G1_lr, weight_decay=Config().weight_decay) 830 | schedular_dict[1] = torch.optim.lr_scheduler.StepLR(optimizer_dict[1], step_size=3, gamma=1) 831 | 832 | optimizer_dict[0] = torch.optim.Adam([ 833 | {'params': NB.G2.parameters()}, 834 | {'params': NB.D.parameters()}] 835 | , lr=Config().learning_rate, weight_decay=Config().weight_decay) 836 | schedular_dict[0] = torch.optim.lr_scheduler.StepLR(optimizer_dict[0], step_size=3, gamma=1) 837 | 838 | best_hit_ratio = 0 839 | best_recall = 0 840 | best_precision = 0 841 | best_f1 = 0 842 | best_ndcg = 0 843 | best_mrr = 0 844 | temp = Config().temp 845 | if Config().G1_flag == 1: G1_flag = 2 ## 846 | if Config().G1_flag == 0: 847 | G1_flag = 0 ##### 848 | if Config().G1_flag == -1: G1_flag = 0 #### 849 | pretrain = 0 850 | 851 | train_epoch = 0 ## 852 | pretrained_epoch = 0 ## 853 | 854 | # tb_writer = SummaryWriter( 855 | # "./result/main_1_{}_log_{}".format(Config().dataset, Config().log_fire)) 856 | 857 | logs = dict() 858 | 859 | first_batch_size = Config().batch_size 860 | al_batch = Config().alternative_train_batch 861 | B = first_batch_size 862 | 863 | if Config().before_epoch > 0: 864 | PATH = os.path.join(Config().MODEL_DIR, "base_model_{}_{}_{}.pt".format(Config().embedding_dim/2,Config().dataset, 865 | 'basemodel')) 866 | checkpoint0 = torch.load(PATH) 867 | checkpoint = checkpoint0['model_state_dict'] 868 | # optimizer_dict[0].load_state_dict(checkpoint0['optimizer0_state_dict']) 869 | # NB.load_state_dict(checkpoint) 870 | model_dict = NB.state_dict() 871 | pretrained_dict = {k: v for k, v in checkpoint.items() if k in model_dict} 872 | model_dict.update(pretrained_dict) 873 | NB.load_state_dict(model_dict) 874 | 875 | logger.info('reset from ... 
{}'.format(PATH)) 876 | 877 | NB.G0.init_weight() 878 | 879 | test_as_DREAM_new( 880 | Config().before_epoch, 881 | G1_flag, 882 | Config().group_split1, 883 | Config().group_split2, 884 | test_type=Config().test_type, 885 | pretrain=pretrain, 886 | temp=Config().temp_min, batch_size=B) 887 | 888 | # test_model_1000_top5_new( 889 | # Config().before_epoch, 890 | # G1_flag, 891 | # Config().group_split1, 892 | # Config().group_split2, 893 | # pretrain=pretrain, 894 | # test_type=Config().test_type, temp=Config().temp_min, batch_size=B) 895 | 896 | if Config().before_epoch >= Config().pretrain_epoch: 897 | if Config().G1_flag == 0: 898 | if pretrain == 0: 899 | G1_flag = 1 900 | pretrain = 1 901 | 902 | val_loss_before = dict() 903 | val_loss_before['hit'] = 0 904 | val_loss_before['recall'] = 0 905 | val_loss_before['precision'] = 0 906 | val_loss_before['f1'] = 0 907 | val_loss_before['ndcg'] = 0 908 | 909 | ############################################ 910 | temp0 = Config().temp_min 911 | for epoch in range(Config().before_epoch + 1, Config().epochs + 1): 912 | sd = Config().sd2 913 | save_flag = 0 914 | if Config().G1_flag == 0: 915 | if pretrained_epoch == Config().pretrain_epoch: 916 | if pretrain == 0: 917 | G1_flag = 1 918 | pretrain = 1 919 | 920 | if ((train_epoch % Config().alternative_train_epoch == 0) & (train_epoch > 0)): 921 | train_epoch = 0 922 | temp = Config().temp 923 | if Config().G1_flag == 0: 924 | if G1_flag == 0: 925 | G1_flag = 1 926 | else: 927 | train_epoch = Config().alternative_train_epoch - Config().alternative_train_epoch_D 928 | save_flag = 1 929 | G1_flag = 0 930 | temp0 = Config().temp_min # default is 0 931 | logger.info( 932 | 'G1_flag {} pretrain {} temp {} train_epoch {} B {} al_batch {} lr {} temp0 {}'.format(G1_flag, pretrain, temp, 933 | train_epoch, B, 934 | al_batch, 935 | schedular_dict[ 936 | G1_flag].get_lr()[ 937 | 0],temp0)) 938 | 939 | train_loss, temp0 = train_model(epoch, G1_flag, pretrain, temp=temp, batch_size=B, 940 | alternative_train_batch=al_batch,temp0 = temp0) 941 | 942 | valid_hit_ratio, valid_recall, valid_precision, valid_f1, valid_ndcg, valid_mrr, valid_avg_fake_basket_user = valid_model_1000_top5( 943 | epoch, G1_flag, 944 | test_type=Config().test_type, 945 | pretrain=pretrain, 946 | temp=temp0, 947 | batch_size=B) 948 | 949 | valid_hit_l.append(valid_hit_ratio) 950 | valid_recall_l.append(valid_recall) 951 | train_loss_l.append(train_loss) 952 | 953 | learning_rate_scalar = schedular_dict[G1_flag].get_lr()[0] 954 | logs['learning_rate_G1flag_{}'.format(G1_flag)] = learning_rate_scalar 955 | logs['train_loss'] = train_loss 956 | logs['valid_hit'] = valid_hit_ratio 957 | logs['valid_recall'] = valid_recall 958 | logs['valid_pre'] = valid_precision 959 | logs['valid_f1'] = valid_f1 960 | logs['valid_ndcg'] = valid_ndcg 961 | 962 | schedular_dict[G1_flag].step() 963 | 964 | if (valid_hit_ratio > val_loss_before['hit']) | (valid_recall > val_loss_before['recall']) | ( 965 | valid_precision > val_loss_before['precision']) | ( 966 | valid_f1 > val_loss_before['f1']) | (valid_ndcg > val_loss_before['ndcg']): 967 | better = 0 968 | if valid_hit_ratio > val_loss_before['hit']: better += 1 969 | if valid_recall > val_loss_before['recall']: better += 1 970 | if valid_precision > val_loss_before['precision']: better += 1 971 | if valid_f1 > val_loss_before['f1']: better += 1 972 | if valid_ndcg > val_loss_before['ndcg']: better += 1 973 | if better > 1: 974 | val_loss_before['hit'] = valid_hit_ratio 975 | val_loss_before['recall'] = valid_recall
976 | val_loss_before['precision'] = valid_precision 977 | val_loss_before['f1'] = valid_f1 978 | val_loss_before['ndcg'] = valid_ndcg 979 | save_flag = 1 980 | 981 | ##### 982 | if ((((train_epoch + 1) % Config().alternative_train_epoch == 0)) | ( 983 | epoch % Config().test_every_epoch == 0) | (save_flag == 1) | (epoch == Config().epochs) | ( 984 | pretrained_epoch == (Config().pretrain_epoch - 1))): 985 | 986 | if ((G1_flag == 0) & (pretrain == 0)): 987 | hit_ratio, recall, precision, f1, ndcg, mrr, avg_fake_basket_user, test_loss = test_as_DREAM_new( 988 | epoch, 989 | G1_flag, 990 | Config().group_split1, 991 | Config().group_split2, 992 | test_type=Config().test_type, 993 | pretrain=pretrain, 994 | temp=temp, batch_size=B) 995 | logs['hit'] = hit_ratio 996 | logs['recall'] = recall 997 | logs['precision'] = precision 998 | logs['f1'] = f1 999 | logs['ndcg'] = ndcg 1000 | 1001 | if (hit_ratio > best_hit_ratio) | (f1 > best_f1) | (ndcg > best_ndcg): 1002 | better = 0 1003 | if hit_ratio > best_hit_ratio: better += 1 1004 | if f1 > best_f1: better += 1 1005 | if ndcg > best_ndcg: better += 1 1006 | if better > 1: 1007 | model_name = os.path.join(Config().MODEL_DIR, "base_model_{}_{}_{}.pt".format(Config().embedding_dim/2,Config().dataset, 1008 | Config().log_fire)) 1009 | checkpoint = {'epoch': epoch, 1010 | 'model_state_dict': NB.state_dict(), 1011 | } 1012 | torch.save(checkpoint, model_name) 1013 | logger.info("Save model as %s" % model_name) 1014 | else: 1015 | hit_ratio, recall, precision, f1, ndcg, mrr, avg_fake_basket_user, test_loss = test_model_1000_top5_new( 1016 | epoch, 1017 | G1_flag, 1018 | Config().group_split1, 1019 | Config().group_split2, 1020 | pretrain=pretrain, 1021 | test_type=Config().test_type, temp=temp0, batch_size=B) 1022 | logs['hit'] = hit_ratio 1023 | logs['recall'] = recall 1024 | logs['precision'] = precision 1025 | logs['f1'] = f1 1026 | logs['ndcg'] = ndcg 1027 | 1028 | test_hit_l.append(hit_ratio) 1029 | test_recall_l.append(recall) 1030 | 1031 | if (hit_ratio > best_hit_ratio) | (recall > best_recall) | (precision > best_precision) | ( 1032 | f1 > best_f1) | (ndcg > best_ndcg) | (mrr > best_mrr): 1033 | better = 0 1034 | if hit_ratio > best_hit_ratio: better += 1 1035 | if recall > best_recall: better += 1 1036 | if precision > best_precision: better += 1 1037 | if f1 > best_f1: better += 1 1038 | if ndcg > best_ndcg: better += 1 1039 | if better > 1: 1040 | best_PATH = os.path.join(Config().MODEL_DIR, 1041 | "model_1_{}_{}_{}.pt".format(Config().embedding_dim/2,Config().dataset, 1042 | Config().log_fire)) 1043 | checkpoint = {'epoch': epoch, 1044 | 'temp0': temp0, 1045 | 'G1_flag':G1_flag, 1046 | 'model_state_dict': NB.state_dict(), 1047 | } 1048 | torch.save(checkpoint, best_PATH) 1049 | logger.info("Save model as %s" % best_PATH) 1050 | 1051 | best_hit_ratio = max(hit_ratio, best_hit_ratio) 1052 | best_recall = max(recall, best_recall) 1053 | best_precision = max(precision, best_precision) 1054 | best_f1 = max(f1, best_f1) 1055 | best_ndcg = max(ndcg, best_ndcg) 1056 | best_mrr = max(mrr, best_mrr) 1057 | logs['besthit'] = best_hit_ratio 1058 | logs['bestrecall'] = best_recall 1059 | logs['bestprecision'] = best_precision 1060 | logs['bestf1'] = best_f1 1061 | logs['bestndcg'] = best_ndcg 1062 | 1063 | logger.info( 1064 | 'Epochs {:3d} best_hit {:05.4f} best_recall {:05.4f} best_precision {:05.4f} best_f1 {:05.4f} best_ndcg {:05.4f} best_mrr {:05.4f} '.format( 1065 | epoch, best_hit_ratio, best_recall, best_precision, best_f1, best_ndcg, best_mrr)) 
1066 | 1067 | logger.info( 1068 | 'Epochs {:3d} loss_all {:05.4f} Hit ratio {:02.4f} recall {:05.4f} precision {:05.4f} f1 {: 05.4f} ndcg {: 05.4f} mrr {: 05.4f} avg_fake_length {:05.4f}' 1069 | .format(epoch, train_loss, hit_ratio, recall, precision, f1, ndcg, mrr, 1070 | avg_fake_basket_user)) 1071 | 1072 | # for key, value in logs.items(): 1073 | # tb_writer.add_scalar(key, value, epoch) 1074 | 1075 | temp = Config().temp 1076 | if Config().G1_flag == 0: 1077 | train_epoch += 1 1078 | if pretrain == 0: 1079 | if G1_flag == 0: 1080 | pretrained_epoch += 1 1081 | train_epoch = 0 1082 | logger.info('valid_hit {}'.format(valid_hit_l)) 1083 | logger.info('valid_recall {}'.format(valid_recall_l)) 1084 | logger.info('test_hit {}'.format(test_hit_l)) 1085 | logger.info('test_recall {}'.format(test_recall_l)) 1086 | logger.info('train_loss {}'.format(train_loss_l)) 1087 | except KeyboardInterrupt: 1088 | logger.info('Early Stopping!') 1089 | 1090 | 1091 | if __name__ == '__main__': 1092 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 1093 | train() 1094 | -------------------------------------------------------------------------------- /CLEA/module/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('--same_embedding', type=int, default=1, help='whether use the same embedding in G1 and D') 5 | parser.add_argument('--test_every_epoch', type=int, default=5, help='run evaluation on the test set every N epochs') 6 | parser.add_argument('--pos_margin', type=float, default=0.3, help='margin for the positive (kept) sub-basket ratio loss') 7 | parser.add_argument('--neg_margin', type=float, default=0.7, help='margin for the negative (removed) sub-basket ratio loss') 8 | 9 | parser.add_argument('--neg_ratio', type=int, default=1, help='neg_ratio') 10 | parser.add_argument('--device_id', type=int, default=0, help='GPU_ID') 11 | parser.add_argument('--G1_flag', type=int, default=0, help='train_type : with G1 1 / with no G1 -1 / from no G1 to G1 0') 12 | 13 | parser.add_argument('--sd1', type=float, default=1, help='sd1') 14 | parser.add_argument('--sd2', type=float, default=1, help='sd2') 15 | parser.add_argument('--sd3', type=float, default=1, help='sd3') 16 | parser.add_argument('--sd4', type=float, default=1, help='sd4') 17 | parser.add_argument('--sd5', type=float, default=1, help='sd5') 18 | 19 | parser.add_argument('--basket_pool_type', type=str, default='avg', help='basket_pool_type') 20 | parser.add_argument('--num_layer', type=int, default=1, help='num_layer') 21 | parser.add_argument('--test_type', type=int, default=1000, help='0:old 1000:1000 500:500') 22 | 23 | parser.add_argument('--group_split1', type=int, default=4, help='basket_group_split') 24 | parser.add_argument('--group_split2', type=int, default=6, help='basket_group_split') 25 | parser.add_argument('--max_basket_size', type=int, default=35, help='max_basket_size') 26 | parser.add_argument('--max_basket_num', type=int, default=32, help='max_basket_num') 27 | parser.add_argument('--dataset', type=str, default='Instacart', help='dataset name') 28 | parser.add_argument('--num_product', type=int, default=8222 , help='n_items TaFeng:9963 Instacart:8222 Delicious:6539') 29 | parser.add_argument('--num_users', type=int, default= 6886, help='n_users TaFeng:16060 Instacart:6885 Delicious:1735') 30 | parser.add_argument('--distrisample', type=int, default= 0, help='') 31 | 32 | parser.add_argument('--output_dir', type=str, default='./result', help='') 33 | parser.add_argument('--log_fire', type=str,
default='test', help='name tag for log and model files') #_learning 34 | parser.add_argument('--temp', type=float, default=1, help='initial Gumbel-softmax temperature') #1 35 | parser.add_argument('--temp_min', type=float, default=0.3, help='minimum Gumbel-softmax temperature') #0.3 36 | parser.add_argument('--pretrain_epoch', type=int, default= 2, help='number of pretraining epochs') 37 | parser.add_argument('--batch_size', type=int, default=128, help='batch_size') 38 | parser.add_argument('--epoch', type=int, default=50, help='the number of epochs to train for') 39 | parser.add_argument('--ANNEAL_RATE', type=float, default=0.0002, help='ANNEAL_RATE') #0.0003 40 | parser.add_argument('--lr', type=float, default=0.00001, help='learning rate') 41 | parser.add_argument('--G1_lr', type=float, default=0.0001, help='learning rate for G1') 42 | parser.add_argument('--l2', type=float, default=0.00001, help='l2 penalty') # [0.001, 0.0005, 0.0001, 0.00005, 0.00001] 43 | parser.add_argument('--embedding_dim', type=int, default=256,help='hidden size') # [0.001, 0.0005, 0.0001, 0.00005, 0.00001] 44 | parser.add_argument('--alternative_train_epoch', type=int, default=5, help='epochs per alternating-training round') 45 | parser.add_argument('--alternative_train_epoch_D', type=int, default=1, help='epochs for the D side in each alternating round') 46 | parser.add_argument('--alternative_train_batch', type=int, default=200, help='number of batches between temperature annealing steps') 47 | parser.add_argument('--dropout', type=float, default=0.2, help='dropout') 48 | parser.add_argument('--history', type=int, default=0, help='history') 49 | parser.add_argument('--temp_learn', type=int, default=0, help='temp_learn') 50 | parser.add_argument('--before_epoch', type=int, default=2, help='starting epoch; if >0 the saved base model is loaded first') #18 63 51 | args = parser.parse_args() 52 | 53 | 54 | # -*- coding:utf-8 -*- 55 | class Config(object): 56 | def __init__(self): 57 | self.same_embedding = args.same_embedding 58 | self.neg_margin = args.neg_margin 59 | self.pos_margin = args.pos_margin 60 | 61 | self.alternative_train_epoch_D = args.alternative_train_epoch_D 62 | self.temp_learn = args.temp_learn 63 | self.output_dir = args.output_dir 64 | 65 | self.distrisample = args.distrisample 66 | self.pretrain_epoch = args.pretrain_epoch 67 | self.MODEL_DIR = './runs' 68 | self.input_dir = 'dataset/{}'.format(args.dataset)+'/user_date_tran_dict_new.txt' 69 | self.dataset = args.dataset 70 | self.epochs = args.epoch 71 | self.device_id = args.device_id 72 | self.log_interval = 500 # num of batches between two logging #300 73 | self.num_users = args.num_users # 74 | self.num_product = args.num_product 75 | self.item_list = list(range(args.num_product)) 76 | self.test_ratio = 100 77 | self.log_fire = args.log_fire 78 | self.test_type = args.test_type 79 | self.test_every_epoch = args.test_every_epoch 80 | 81 | 82 | self.alternative_train_epoch = args.alternative_train_epoch 83 | self.alternative_train_batch = args.alternative_train_batch 84 | self.max_basket_size = args.max_basket_size 85 | self.max_basket_num = args.max_basket_num 86 | self.group_split1 = args.group_split1 87 | self.group_split2 = args.group_split2 88 | self.batch_size = args.batch_size 89 | self.neg_ratio = args.neg_ratio 90 | 91 | self.sd1 = args.sd1 92 | self.sd2 = args.sd2 93 | self.sd3 = args.sd3 94 | self.sd4 = args.sd4 95 | self.sd5 = args.sd5 96 | self.learning_rate = args.lr 97 | self.G1_lr = args.G1_lr 98 | self.dropout = args.dropout 99 | self.weight_decay = args.l2 100 | self.basket_pool_type = args.basket_pool_type 101 | self.num_layer = args.num_layer 102 | 103 | self.embedding_dim = args.embedding_dim 104 | self.ANNEAL_RATE =
args.ANNEAL_RATE 105 | 106 | self.temp = args.temp 107 | self.temp_min = args.temp_min 108 | self.before_epoch = args.before_epoch 109 | 110 | self.G1_flag = args.G1_flag 111 | self.histroy = args.history 112 | 113 | def list_all_member(self, logger): 114 | for name, value in vars(self).items(): 115 | if not name.startswith('item'): 116 | logger.info('%s=%s' % (name, value)) 117 | 118 | 119 | ''' 120 | 121 | ''' -------------------------------------------------------------------------------- /CLEA/module/logger.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import logging 4 | 5 | 6 | class Logger(object): 7 | 8 | def __init__(self, filename): 9 | 10 | self.logger = logging.getLogger(filename) 11 | self.logger.setLevel(logging.DEBUG) 12 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d: %(message)s', 13 | datefmt='%Y-%m-%d %H:%M:%S') 14 | 15 | # write into file 16 | fh = logging.FileHandler(filename) 17 | fh.setLevel(logging.DEBUG) 18 | fh.setFormatter(formatter) 19 | 20 | # show on console 21 | ch = logging.StreamHandler(sys.stdout) 22 | ch.setLevel(logging.DEBUG) 23 | ch.setFormatter(formatter) 24 | 25 | # add to Handler 26 | self.logger.addHandler(fh) 27 | self.logger.addHandler(ch) 28 | 29 | def _flush(self): 30 | for handler in self.logger.handlers: 31 | handler.flush() 32 | 33 | def debug(self, message): 34 | self.logger.debug(message) 35 | self._flush() 36 | 37 | def info(self, message): 38 | self.logger.info(message) 39 | self._flush() 40 | 41 | def warning(self, message): 42 | self.logger.warning(message) 43 | self._flush() 44 | 45 | def error(self, message): 46 | self.logger.error(message) 47 | self._flush() 48 | 49 | def critical(self, message): 50 | self.logger.critical(message) 51 | self._flush() 52 | 53 | 54 | if __name__ == '__main__': 55 | log = Logger('NB.log') 56 | log.debug('debug') 57 | log.info('info') 58 | log.warning('warning') 59 | log.error('error') 60 | log.critical('critical') 61 | 62 | -------------------------------------------------------------------------------- /CLEA/module/model_1.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | 9 | class NBModel(nn.Module): # 10 | def __init__(self, config, device): 11 | super(NBModel, self).__init__() 12 | self.device = device 13 | self.sd1 = config.sd1 14 | self.sd2 = config.sd2 15 | self.sd3 = config.sd3 16 | self.sd4 = config.sd4 17 | self.sd5 = config.sd5 18 | 19 | self.neg_margin = config.neg_margin 20 | self.pos_margin = config.pos_margin 21 | 22 | self.num_users = config.num_users 23 | 24 | self.num_product = config.num_product 25 | 26 | self.D = Discriminator(config, self.device) 27 | self.G0 = Generator1(config, self.device) 28 | self.G2 = Generator2(config, self.device) 29 | 30 | self.G1_flag = 0 31 | self.mse = nn.MSELoss() 32 | 33 | def init_weight(self): 34 | torch.nn.init.xavier_normal_(self.embed.weight) 35 | 36 | # profile 37 | def forward(self, T, uid, input_seq_tensor, tar_b_tensor, weight, history, neg_set_tensor=None, train=True, 38 | G1flag=0, pretrain=0, sd2=1): 39 | ''' 40 | 41 | :param T: is Gumbel softmax's temperature 42 | :param uid: is userid 43 | :param input_seq_tensor: K * B 44 | :param tar_b_tensor: K 45 | :param weight: itemnum 46 | :param history: K * itemnum 47 | :param neg_set_tensor: K * neg_ratio 48 | :param train: whether 
train 49 | :return: classify prob metric K * itemnum 50 | ''' 51 | 52 | 53 | 54 | self.G1_flag = G1flag 55 | self.pretrain = pretrain 56 | 57 | mask = torch.ones_like(input_seq_tensor,dtype = torch.float).to(self.device) 58 | mask[input_seq_tensor == -1] = 0 59 | tar_expand = tar_b_tensor.view(-1,1).expand_as(input_seq_tensor) 60 | mask0 = torch.zeros_like(input_seq_tensor,dtype = torch.float).to(self.device) 61 | mask0[tar_expand == input_seq_tensor] = 1 62 | 63 | input_embeddings = self.G2.embed1(input_seq_tensor + 1) 64 | target_embedding = self.G2.embed1(tar_b_tensor + 1) 65 | 66 | 67 | 68 | test = 1 69 | if train == True: 70 | test = 0 71 | # print(self.embed.weight.data) 72 | if ((self.G1_flag == 0) & (pretrain == 0)): # 73 | # 74 | self.filter_basket = torch.ones_like(input_seq_tensor,dtype = torch.float).to(self.device) # K * B 75 | real_generate_embedding1 = self.G2(self.filter_basket, input_seq_tensor,uid) # K*H 76 | fake_discr = self.D(real_generate_embedding1, history, tar_b_tensor) # K*n_items 77 | 78 | if train == True: 79 | all_sum = mask.sum() 80 | 81 | loss, (p1, p2, p3, p4) = self.loss2_G1flag0(fake_discr, tar_b_tensor) 82 | 83 | return loss, fake_discr, (p1, p2, p3, p4), (all_sum, all_sum, all_sum, all_sum/all_sum,all_sum/all_sum) 84 | 85 | fake_discr = torch.softmax(fake_discr, dim=-1) 86 | return mask.sum() / mask.sum(),mask0.sum() / mask0.sum(), fake_discr 87 | else: 88 | self.filter_basket, test_basket = self.G0(input_seq_tensor, T, tar_b_tensor, self.G1_flag,test,input_embeddings,target_embedding) # K * B 89 | real_generate_embedding1 = self.G2(self.filter_basket[:, :, 0], input_seq_tensor,uid) 90 | fake_discr = self.D(real_generate_embedding1, history, tar_b_tensor, input_seq_tensor) # K*n_items 91 | 92 | 93 | ################################################ 94 | select_repeats = mask0 * ((self.filter_basket[:,:,0] > 1 / 2 ).float())#torch.tensor((self.filter_basket[:,:,0] > 1 / 2 ).int(),dtype = torch.long).to(self.device) 95 | 96 | repeat_ratio = (select_repeats.sum(1)/(mask0.sum(1)+1e-24)).sum()/(((mask0.sum(1)>0).float()+1e-24).sum()) 97 | 98 | if train == True: 99 | 100 | rest = mask * self.filter_basket[:, :, 0]#.detach() 101 | rest_int = (rest > 1 / 2).int() 102 | real_rest_sum = rest.sum() 103 | real_rest_sum_int = rest_int.sum() 104 | all_sum = mask.sum() 105 | ratio = (rest * ((rest > 1 / 2).float())).sum() / max(1, ((rest > 1 / 2).float()).sum()) 106 | 107 | # self.rest_basket = torch.ones_like(self.filter_basket,dtype = torch.float).to(self.device) - self.filter_basket 108 | rest_generate_embedding1 = self.G2(self.filter_basket[:, :, 1], input_seq_tensor,uid) 109 | rest_discr = self.D(rest_generate_embedding1, history, tar_b_tensor, input_seq_tensor) 110 | 111 | filter_pos = mask * self.filter_basket[:, :, 0] 112 | filter_neg = mask * self.filter_basket[:, :, 1] 113 | 114 | self.whole_basket = torch.ones_like(input_seq_tensor,dtype = torch.float).to(self.device) # K * B 115 | whole_generate_embedding1 = self.G2(self.whole_basket, input_seq_tensor,uid) # K*H 116 | whole_discr = self.D(whole_generate_embedding1, history, tar_b_tensor) # K*n_items 117 | 118 | loss, (p1, p2, p3, p4) = self.loss2_G1flag1(fake_discr, tar_b_tensor, rest_discr, 119 | whole_discr, filter_pos, filter_neg, 120 | mask) 121 | return loss, fake_discr, (p1, p2, p3, p4), (real_rest_sum, real_rest_sum_int, all_sum, ratio,repeat_ratio) 122 | 123 | select_repeats = mask0 * ((test_basket > 1 / 2).float())#torch.tensor((test_basket > 1 / 2).int(),dtype = torch.long).to(self.device) 124 | 
test_repeat_ratio = (select_repeats.sum(1) / (mask0.sum(1) + 1e-24)).sum() / (((mask0.sum(1) > 0).float()).sum()) 125 | # test_repeat_ratio = torch.mean(select_repeats.sum(1) / (mask0.sum(1) + 1e-24)) 126 | 127 | rest_test = mask * test_basket.detach() 128 | rest_sum = rest_test.sum() 129 | all_sum = mask.sum() 130 | test_rest_ratio = rest_sum / all_sum 131 | test_generate_embedding1 = self.G2(test_basket, input_seq_tensor,uid) 132 | test_discr = self.D(test_generate_embedding1, history, tar_b_tensor, input_seq_tensor) # K*n_items 133 | test_discr = torch.softmax(test_discr, -1) 134 | return test_rest_ratio,test_repeat_ratio, test_discr 135 | 136 | def loss2_G1flag0(self, fake_discr, target_labels): 137 | ''' 138 | :param fake_discr: K * itemnum 139 | :param target_labels: K 140 | :param neg_labels: K * neg_ratio = K * nK 141 | :return: 142 | ''' 143 | fake_discr = torch.softmax(fake_discr, dim=-1) 144 | 145 | item_num = fake_discr.size(0) # K 146 | index = torch.tensor(np.linspace(0, item_num, num=item_num, endpoint=False), dtype=torch.long) # K 147 | pfake = fake_discr[index, target_labels] # K 148 | 149 | loss_1 = - torch.mean(torch.log((pfake) + 1e-24)) 150 | 151 | loss = self.sd1 * loss_1 152 | return loss, (loss_1, loss_1, loss_1, loss_1) 153 | 154 | def loss2_G1flag1(self, fake_discr, target_labels, rest_discri,whole_discr,filter_pos,filter_neg,mask): 155 | ''' 156 | :param fake_discr: K * itemnum 157 | :param filter_pos: K * B 158 | :param target_labels: K 159 | :param neg_labels: K * neg_ratio = K * nK 160 | :param rest_discri: K * itemnum 161 | :param neg_discri: (Kxjudge_ratio) * itemnum 162 | :return: 163 | ''' 164 | fake_discr = torch.softmax(fake_discr,dim = -1) 165 | rest_discri = torch.softmax(rest_discri,dim = -1) 166 | whole_discr = torch.softmax(whole_discr, dim=-1) 167 | 168 | 169 | item_num = fake_discr.size(0) # K 170 | index = torch.tensor(np.linspace(0, item_num, num=item_num, endpoint=False), dtype=torch.long) # K 171 | pfake = fake_discr[index, target_labels] # K 172 | prest = rest_discri[index, target_labels] # K 173 | pwhole = whole_discr[index, target_labels] # K 174 | 175 | pos_restratio = torch.mean(filter_pos.sum(1).view(1,-1)/(mask.sum(1).view(1,-1)),dim=-1).view(1,-1) # 1*1 176 | pos_margin = self.pos_margin * torch.ones_like(pos_restratio).to(self.device) # 177 | zeros = torch.zeros_like(pos_restratio).to(self.device) 178 | pos_restratio = torch.cat((pos_margin-pos_restratio,zeros),dim=0) #2*1 179 | 180 | neg_restratio = torch.mean(filter_neg.sum(1).view(1,-1)/(mask.sum(1).view(1,-1)),dim=-1).view(1,-1) # 1*1 181 | neg_margin = self.neg_margin * torch.ones_like(neg_restratio).to(self.device) 182 | neg_restratio = torch.cat((neg_restratio - neg_margin, zeros), dim=0) 183 | 184 | loss_0 = torch.max(pos_restratio,dim = 0)[0] + torch.max(neg_restratio,dim = 0)[0] 185 | loss_1 = - torch.mean(torch.nn.LogSigmoid()(pfake - pwhole)) 186 | loss_2 = - torch.mean(torch.nn.LogSigmoid()(pwhole - prest)) 187 | loss_3 = - torch.mean(torch.log(pfake + 1e-24)) 188 | loss0 = loss_0 + loss_1 + loss_2 + loss_3 189 | if (self.G1_flag == 1): 190 | return loss0, (loss_1, loss_2, loss_3, loss_3) 191 | else: 192 | loss_4 = - torch.mean(torch.log((pwhole + 1e-24) )) 193 | loss0 = loss_0 + loss_1 + loss_2 + loss_3 + loss_4 194 | return loss0, (loss_1, loss_2, loss_3, loss_4) 195 | 196 | class Generator1(nn.Module): 197 | def __init__(self, config, device, dropout_p=0.2): 198 | super(Generator1, self).__init__() 199 | self.device = device 200 | self.dropout_p = config.dropout 201 | 
self.input_size = config.num_product 202 | self.hidden_size = config.embedding_dim 203 | self.max_basket_size = config.max_basket_size 204 | 205 | self.same_embedding = config.same_embedding 206 | 207 | self.temp_learn = config.temp_learn 208 | self.temp = nn.Parameter(torch.ones(1)* config.temp) 209 | self.temp_init = config.temp 210 | 211 | self.embed = nn.Embedding(config.num_product + 1, self.hidden_size, padding_idx=0) 212 | # self.W = nn.Linear(self.hidden_size, self.hidden_size) 213 | self.dropout = nn.Dropout(self.dropout_p) 214 | 215 | self.judge_model = nn.Sequential( 216 | nn.Linear(self.hidden_size * 2, self.hidden_size), 217 | nn.Dropout(self.dropout_p), 218 | nn.LeakyReLU(inplace=True), 219 | nn.Linear(self.hidden_size, 2) 220 | ) 221 | self.judge_model1 = nn.Linear(self.hidden_size * 2, 2) 222 | 223 | # profile 224 | def init_weight(self): 225 | for name, parms in self.named_parameters(): # TODO 226 | parms.data.normal_(0, 0.1) 227 | torch.nn.init.xavier_normal_(self.embed.weight.data) # good 228 | self.temp = nn.Parameter(torch.ones(1).to(self.device)* self.temp_init) # * config.temp # TODO this re-initialization is required, otherwise rest_ratio starts out at 0.99 when templearn=1 229 | 230 | def forward(self, input_seq_tensor, T, target_tensor, G1_flag=1,test=0,input_embeddings = None,target_embedding = None): 231 | ''' 232 | :param input_seq_tensor: K * B 233 | :param T: 234 | :param target_tensor: 235 | :param G1_flag: 236 | :param test: 237 | :param input_embeddings: 238 | :param target_embedding: 239 | :return: 240 | ''' 241 | def hook_fn(grad): 242 | print(grad) 243 | 244 | 245 | # target_embedding = self.embed(target_tensor + 1) 246 | # input_embeddings = self.embed(input_seq_tensor + 1) 247 | # target_embedding = self.W(target_embedding) 248 | # input_embeddings = self.W(input_embeddings) 249 | if self.same_embedding == 0: 250 | target_embedding = self.embed(target_tensor + 1) 251 | input_embeddings = self.embed(input_seq_tensor + 1) 252 | 253 | in_tar = torch.cat( 254 | (input_embeddings, target_embedding.view(target_embedding.size(0), 1, -1).expand_as(input_embeddings)), 255 | dim=2) # K*B*2H 256 | in_tar = self.dropout(in_tar) 257 | resnet_o_prob = self.judge_model1(in_tar) 258 | o_prob = self.judge_model(in_tar) # K*B*2 259 | 260 | o_prob = (o_prob + resnet_o_prob) # 261 | # att_prob = torch.sigmoid(o_prob) 262 | o_prob = torch.softmax(o_prob, dim=-1) 263 | 264 | 265 | if self.temp_learn == 1: 266 | if self.temp > 0: 267 | prob_hard, prob_soft = self.gumbel_softmax(torch.log(o_prob + 1e-24), self.temp, hard=True,input_seq_tensor = input_seq_tensor) 268 | else: 269 | prob_hard, prob_soft = self.gumbel_softmax(torch.log(o_prob + 1e-24), 0.3, hard=True,input_seq_tensor = input_seq_tensor) 270 | else: 271 | prob_hard, prob_soft = self.gumbel_softmax(torch.log(o_prob + 1e-24), T, hard=True,input_seq_tensor = input_seq_tensor) 272 | 273 | prob_soft_new = prob_hard*prob_soft 274 | 275 | # return prob_soft_new, prob_soft_new[:,:,0] 276 | 277 | if test == 0 :#and G1_flag != 0: 278 | return prob_soft_new, None 279 | else: 280 | if self.temp_learn == 1: 281 | if self.temp > 0: 282 | o_prob_hard, o_prob_soft = self.gumbel_test(torch.log(o_prob + 1e-24), self.temp) 283 | else: 284 | o_prob_hard, o_prob_soft = self.gumbel_test(torch.log(o_prob + 1e-24), 0.3) 285 | else: 286 | o_prob_hard, o_prob_soft = self.gumbel_test(torch.log(o_prob + 1e-24), T) 287 | 288 | test_prob_hard = o_prob_hard * o_prob_soft 289 | test_prob_hard = test_prob_hard.detach() 290 | 291 | return test_prob_hard, test_prob_hard[:, :, 0] 292 | 293 | def
sample_gumbel(self, shape, eps=1e-20): 294 | U = torch.rand(shape).to(self.device) 295 | return -torch.log(-torch.log(U + eps) + eps) 296 | 297 | def gumbel_softmax_sample(self, logits, temperature,input_seq_tensor = None): 298 | ''' 299 | :param logits: # K*B*2 300 | :param temperature: 301 | :param input_seq_tensor: K*B 302 | :return: 303 | ''' 304 | sample = self.sample_gumbel([int(self.input_size+1),2],eps= 1e-20) # n_items+1 * 2 305 | x_index = input_seq_tensor.clone()+1 #K*B 306 | x_index = x_index.unsqueeze(2).repeat(1,1,2) # K*B*2 307 | # print(x_index.size()) 308 | y_index = torch.zeros_like(input_seq_tensor,dtype = torch.long).to(self.device).unsqueeze(2) #K*B*1 309 | y_index1 = torch.ones_like(input_seq_tensor, dtype=torch.long).to(self.device).unsqueeze(2) # K*B*1 310 | y_index = torch.cat((y_index,y_index1),dim = 2) #K*B*2 311 | # print(y_index.size()) 312 | sample_logits = sample[x_index.long(),y_index.long()] 313 | y = logits + sample_logits 314 | # y = logits + self.sample_gumbel(logits.size()) 315 | return F.softmax(y / temperature, dim=-1) 316 | 317 | def gumbel_softmax(self, logits, temperature, hard=False,input_seq_tensor = None): 318 | """ 319 | ST-gumple-softmax 320 | input: [*, n_class] 321 | return: flatten --> [*, n_class] an one-hot vector 322 | """ 323 | y = self.gumbel_softmax_sample(logits, temperature,input_seq_tensor) 324 | 325 | if not hard: 326 | return y # .view(-1, latent_dim * categorical_dim) 327 | 328 | shape = y.size() 329 | _, ind = y.max(dim=-1) 330 | y_hard = torch.zeros_like(y).view(-1, shape[-1]) 331 | y_hard.scatter_(1, ind.view(-1, 1), 1) 332 | y_hard = y_hard.view(*shape) 333 | # Set gradients w.r.t. y_hard gradients w.r.t. y 334 | # y_hard = (y_hard - y).detach() + y 335 | return y_hard, y # .view(-1, latent_dim * categorical_dim) 336 | 337 | def gumbel_test(self, logits, temperature): 338 | y = logits 339 | y = F.softmax(y / temperature, dim=-1) 340 | shape = y.size() 341 | _, ind = y.max(dim=-1) 342 | y_hard = torch.zeros_like(y).view(-1, shape[-1]) 343 | y_hard.scatter_(1, ind.view(-1, 1), 1) 344 | y_hard = y_hard.view(*shape) 345 | return y_hard, y # .view(-1, latent_dim * categorical_dim) 346 | 347 | class Generator2(nn.Module): 348 | def __init__(self, config, device, dropout_p=0.2): 349 | super(Generator2, self).__init__() 350 | self.device = device 351 | self.dropout_p = config.dropout 352 | self.input_size = config.num_product 353 | self.num_users = config.num_users 354 | self.hidden_size = config.embedding_dim 355 | self.basket_pool_type = config.basket_pool_type 356 | 357 | self.dropout = nn.Dropout(self.dropout_p) 358 | self.max_basket_size = config.max_basket_size 359 | self.bidirectional = False 360 | self.batch_first = True 361 | self.num_layer = config.num_layer 362 | self.gru_hidden_size = self.hidden_size 363 | 364 | self.embed1 = nn.Embedding(config.num_product + 1, self.gru_hidden_size, padding_idx=0) 365 | 366 | self.gru1 = nn.GRU(self.gru_hidden_size, 367 | self.gru_hidden_size, 368 | num_layers=self.num_layer, 369 | bidirectional=self.bidirectional, 370 | batch_first=self.batch_first) 371 | 372 | 373 | # profile 374 | def forward(self, filter_basket_prob, input_seq_tensor,uid = None): 375 | ''' 376 | :param filter_basket_prob: K * B [13] [2] [0] 2 377 | :param input_embeddings: K * B * H 378 | :param input_seq_tensor: K * B ,with padding -1 379 | :return: fake_target_embeddings # K * hidden_size 380 | # basket_embedding K*basket_num*H 381 | ''' 382 | 383 | def hook_fn(grad): 384 | print(grad) 385 | 386 | mask = 
(input_seq_tensor != -1).detach().int() 387 | mask = torch.tensor(mask,dtype = torch.float).to(self.device) 388 | filter_basket_prob = filter_basket_prob * mask # K * B 389 | # K*B*H 390 | 391 | input_embeddings1 = self.embed1(input_seq_tensor + 1) 392 | 393 | 394 | 395 | input_embeddings_f1 = torch.mul(input_embeddings1, 396 | filter_basket_prob.view(filter_basket_prob.size(0), -1, 1).expand( 397 | filter_basket_prob.size(0), -1, self.gru_hidden_size)) 398 | 399 | input_embeddings_f1 = input_embeddings_f1.view(input_embeddings_f1.size(0), -1, self.max_basket_size, 400 | self.gru_hidden_size) # K*basket_num*max_basket_size*H 401 | 402 | filter_b = torch.max(filter_basket_prob.view(filter_basket_prob.size(0), -1, self.max_basket_size), dim=-1)[ 403 | 0] # K*basket_num #### 404 | 405 | if self.basket_pool_type == 'avg': 406 | filtered_tensor = filter_basket_prob.view(filter_basket_prob.size(0), -1, 1).expand( 407 | filter_basket_prob.size(0), -1, self.gru_hidden_size).view(input_embeddings_f1.size(0), -1, 408 | self.max_basket_size, 409 | self.gru_hidden_size) 410 | basket_embedding1 = (torch.sum(input_embeddings_f1, dim=2) / ( 411 | filtered_tensor.sum(dim=2) + 1e-10)) # K*basket_num*H 412 | else: 413 | mask_inf = filter_basket_prob.view(filter_basket_prob.size(0), -1, 1).expand( 414 | filter_basket_prob.size(0), -1, self.gru_hidden_size).int() 415 | mask_inf = (1 - mask_inf) * (-9999) 416 | mask_inf = mask_inf.view(mask_inf.size(0), -1, self.max_basket_size, 417 | self.gru_hidden_size) 418 | input_embeddings_f1 = input_embeddings_f1 + mask_inf 419 | basket_embedding1 = (torch.max(input_embeddings_f1, dim=2)[0]) # K*basket_num*H 420 | 421 | input_filter_b = (filter_b > 0).detach().int() # K*basket_num ( value:0/1) 422 | sorted, indices = torch.sort(input_filter_b, descending=True) 423 | lengths = torch.sum(sorted, dim=-1).squeeze().view(1, -1).squeeze(0) 424 | length_mask = (lengths == 0).int() 425 | length_mask = torch.tensor(length_mask, dtype=torch.long).to(self.device) 426 | lengths = lengths + length_mask 427 | inputs1 = basket_embedding1.gather(dim=1, 428 | index=indices.unsqueeze(2).expand_as( 429 | basket_embedding1)) # K*basket_num*H 430 | 431 | # sort data by lengths 432 | _, idx_sort = torch.sort(lengths, dim=0, descending=True) 433 | _, idx_unsort = torch.sort(idx_sort, dim=0) 434 | sort_embed_input1 = inputs1.index_select(0, Variable(idx_sort)) 435 | sort_lengths = lengths[idx_sort] 436 | 437 | sort_lengths = torch.tensor(sort_lengths.clone().cpu(), dtype=torch.int64) 438 | inputs_packed1 = nn.utils.rnn.pack_padded_sequence(sort_embed_input1, 439 | sort_lengths, 440 | batch_first=True) 441 | # process using RNN 442 | out_pack1, ht1 = self.gru1(inputs_packed1) 443 | raw_o = nn.utils.rnn.pad_packed_sequence(out_pack1, batch_first=True) 444 | raw_o = raw_o[0] 445 | raw_o = raw_o[idx_unsort] 446 | x = torch.tensor(np.linspace(0, raw_o.size(0), num=raw_o.size(0), endpoint=False), dtype=torch.long).to(self.device) 447 | y = lengths - 1 448 | outputs_last = raw_o[x, y] # 2,2,6 449 | 450 | # ht1 = torch.transpose(ht1, 0, 1)[idx_unsort] 451 | # ht1 = torch.transpose(ht1, 0, 1) 452 | # out1 = self.fc1(ht1[-1]) # .squeeze() 453 | # out1 = self.fc1(outputs_last) 454 | 455 | return outputs_last # K * hidden_size 456 | 457 | class Discriminator(nn.Module): 458 | def __init__(self, config, device, dropout_p=0.2): 459 | super(Discriminator, self).__init__() 460 | self.device = device 461 | self.dropout_p = config.dropout 462 | self.input_size = config.num_product 463 | self.hidden_size = 
config.embedding_dim * 2 464 | self.max_basket_size = config.max_basket_size 465 | self.gru_hidden_size = config.embedding_dim 466 | 467 | self.fc1 = nn.Linear(self.gru_hidden_size, self.hidden_size) 468 | 469 | # TODO hidden_size 470 | self.judge_model1 = nn.Sequential( 471 | nn.Linear(self.hidden_size, self.hidden_size), 472 | nn.Dropout(0), 473 | nn.LeakyReLU(inplace=True), 474 | nn.Linear(self.hidden_size, self.input_size) 475 | ) 476 | self.judge_model2 = nn.Sequential( 477 | nn.Linear(self.hidden_size, self.hidden_size), 478 | nn.Dropout(0), 479 | nn.LeakyReLU(inplace=True), 480 | nn.Linear(self.hidden_size, self.input_size) 481 | ) 482 | 483 | self.dropout = nn.Dropout(dropout_p) 484 | # self.histroy = config.histroy 485 | # if self.histroy == 1: 486 | # self.attn = nn.Linear(self.input_size, self.input_size) 487 | 488 | # profile 489 | def forward(self, item_embeddings1, history_record, target_tensor, input_seq_tensor=None): # K * hidden_size 1*9963 490 | def hook_fn(grad): 491 | print(grad) 492 | 493 | item_embeddings1 = self.fc1(item_embeddings1) 494 | item_embeddings1 = self.dropout(item_embeddings1) 495 | judge1 = self.judge_model1(item_embeddings1) # K*input_size 496 | judge2 = self.judge_model2(item_embeddings1) # K*input_size 497 | judge = judge2 + judge1 498 | # judge = torch.softmax(judge2 + judge1, dim=-1) 499 | 500 | return judge # K * n_items 501 | 502 | 503 | 504 | -------------------------------------------------------------------------------- /CLEA/module/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import math 4 | import random 5 | from collections import defaultdict 6 | 7 | # random.seed(11) 8 | 9 | def get_dict(path): 10 | f = open(path, 'r') 11 | a = f.read() 12 | geted_dict = eval(a) 13 | f.close() 14 | return geted_dict 15 | 16 | 17 | def get_distribute_items(n_items,input_dir,ratio = 0.75): 18 | user_tran_date_dict = get_dict(input_dir) 19 | count = [0.0] * n_items 20 | count_all = 0 21 | for idx, userid in enumerate(list(user_tran_date_dict.keys())): 22 | for basket in user_tran_date_dict[userid]: 23 | for item in basket: 24 | count[item] += 1 25 | count_all += 1 26 | p_item = np.array(count) 27 | 28 | p_item_tensor = torch.from_numpy(np.array(p_item)) 29 | p_item_tensor = torch.pow(p_item_tensor, ratio) 30 | p_item = np.array(p_item_tensor) 31 | # p_item = p_item / count_all 32 | # precision = list(precision.cpu().numpy()) 33 | return p_item 34 | 35 | def get_all_neg_p(neg_sample,p_item): 36 | neg_sample_neg_p = dict() 37 | for u in neg_sample: 38 | neg_index = neg_sample[u] 39 | p_neg = p_item[neg_index] 40 | p_neg = p_neg / np.sum(p_neg) 41 | 42 | if np.sum(p_neg) == 1: 43 | return p_neg 44 | else: 45 | p_neg[0] += (1 - np.sum(p_neg)) 46 | neg_sample_neg_p[u] = p_neg 47 | return neg_sample_neg_p 48 | 49 | def get_neg_p(p_item,neg_set): 50 | neg_index = neg_set#torch.tensor(neg_set,dtype=torch.long).to(device) 51 | p_neg = p_item[neg_index] 52 | p_neg = p_neg / np.sum(p_neg) 53 | 54 | if np.sum(p_neg) == 1: 55 | return p_neg 56 | else: 57 | p_neg[0]+= (1 - np.sum(p_neg)) 58 | return p_neg 59 | 60 | # @profile 61 | def get_dataset(input_dir, max_basket_size,max_basket_num,neg_ratio,history = 0,next_k = 1): 62 | print("--------------Begin Data Process--------------") 63 | neg_ratio = 1 64 | user_tran_date_dict_old = get_dict(input_dir) 65 | 66 | user_tran_date_dict = dict() 67 | for userid in user_tran_date_dict_old.keys(): 68 | seq = user_tran_date_dict_old[userid] 69 | 
if len(seq) > max_basket_num: 70 | seq = seq[-max_basket_num:] 71 | if len(seq) < 1 + next_k: continue 72 | for b_id, basket in enumerate(seq): 73 | if len(basket) > max_basket_size: 74 | seq[b_id] = basket[-max_basket_size:] 75 | user_tran_date_dict[userid] = seq 76 | 77 | 78 | train_times = 0 79 | valid_times = 0 80 | test_times = 0 81 | 82 | itemnum = 0 83 | for userid in user_tran_date_dict.keys(): 84 | seq = user_tran_date_dict[userid] 85 | for basket in seq: 86 | for item in basket: 87 | if item > itemnum: 88 | itemnum = item 89 | itemnum = itemnum + 1 90 | item_list = [i for i in range(0, itemnum)] 91 | 92 | result_vector = np.zeros(itemnum) 93 | basket_count = 0 94 | for userid in user_tran_date_dict.keys(): 95 | seq = user_tran_date_dict[userid][:-next_k] 96 | for basket in seq: 97 | basket_count += 1 98 | result_vector[basket] += 1 99 | weights = np.zeros(itemnum) 100 | max_freq = basket_count # max(result_vector) 101 | for idx in range(len(result_vector)): 102 | if result_vector[idx] > 0: 103 | weights[idx] = max_freq / result_vector[idx] 104 | else: 105 | weights[idx] = 0 106 | 107 | TRAIN_DATASET = [] 108 | train_batch = defaultdict(list) 109 | VALID_DATASET = [] 110 | valid_batch = defaultdict(list) 111 | TEST_DATASET = [] 112 | test_batch = defaultdict(list) 113 | neg_sample = dict() 114 | 115 | # train_userid_list = list(user_tran_date_dict.keys())[:math.ceil(0.9 * len(list(user_tran_date_dict.keys())))] 116 | all_user_num = len(list(user_tran_date_dict.keys())) 117 | train_user_num = 0 118 | train_userid_list = list(user_tran_date_dict.keys())[:math.ceil(0.9 * len(list(user_tran_date_dict.keys())))] 119 | 120 | for userid in user_tran_date_dict.keys(): 121 | if userid in train_userid_list: 122 | seq = user_tran_date_dict[userid][:-next_k] 123 | else: 124 | seq = user_tran_date_dict[userid][:-next_k] 125 | seq_pool = [] 126 | for basket in seq: 127 | seq_pool = seq_pool + basket 128 | neg_sample[userid] = list(set(item_list) - set(seq_pool)) 129 | 130 | for userid in user_tran_date_dict.keys(): 131 | if userid in train_userid_list: 132 | seq = user_tran_date_dict[userid] 133 | before = [] 134 | train_seq = seq[:-1] 135 | for basketid, basket in enumerate(train_seq): 136 | if len(basket) > max_basket_size: 137 | basket = basket[-max_basket_size:] 138 | else: 139 | padd_num = max_basket_size - len(basket) 140 | padding_item = [-1] * padd_num 141 | basket = basket + padding_item 142 | before.append(basket) 143 | if len(before) == 1: continue 144 | U = userid 145 | S = before[:-1] 146 | S_pool = [] 147 | H = np.zeros(itemnum) 148 | H_pad = np.zeros(itemnum + 1) 149 | for basket in S: 150 | S_pool = S_pool + basket 151 | no_pad_basket = list(set(basket)-set([-1])) 152 | H[no_pad_basket] += 1 153 | H = H / len(before[:-1]) 154 | H_pad[1:] = H 155 | L = len(before[:-1]) 156 | tar_basket = train_seq[basketid] 157 | for item in tar_basket: 158 | T = item 159 | N = random.sample(neg_sample[userid], neg_ratio) 160 | train_batch[L].append((U, S_pool, T, H_pad[0:2], N, L)) 161 | train_times += 1 162 | 163 | test_seq = seq 164 | before = [] 165 | for basketid, basket in enumerate(test_seq): 166 | if len(basket) > max_basket_size: 167 | basket = basket[-max_basket_size:] 168 | else: 169 | padd_num = max_basket_size - len(basket) 170 | padding_item = [-1] * padd_num 171 | basket = basket + padding_item 172 | before.append(basket) 173 | U = userid 174 | S = list(before[:-1]) 175 | S_pool = [] 176 | H = np.zeros(itemnum) 177 | H_pad = np.zeros(itemnum+1) 178 | for basket in S: 179 | S_pool = 
S_pool + basket 180 | no_pad_basket = list(set(basket) - set([-1])) 181 | H[no_pad_basket] += 1 182 | H = H / len(S) 183 | H_pad[1:] = H 184 | L = len(before[:-1]) 185 | T_basket = before[-1] 186 | test_batch[L].append((U, S_pool, T_basket, H_pad[0:2], L)) 187 | test_times += 1 188 | 189 | 190 | else: 191 | seq = user_tran_date_dict[userid] 192 | before = [] 193 | valid_seq = seq 194 | for basketid, basket in enumerate(valid_seq): 195 | if len(basket) > max_basket_size: 196 | basket = basket[-max_basket_size:] 197 | else: 198 | padd_num = max_basket_size - len(basket) 199 | padding_item = [-1] * padd_num 200 | basket = basket + padding_item 201 | before.append(basket) 202 | if len(before) == 1: continue 203 | if len(before) < len(valid_seq): continue 204 | U = userid 205 | S = before[:-1] 206 | S_pool = [] 207 | H = np.zeros(itemnum) 208 | H_pad = np.zeros(itemnum + 1) 209 | for basket in S: 210 | S_pool = S_pool + basket 211 | no_pad_basket = list(set(basket) - set([-1])) 212 | H[no_pad_basket] += 1 213 | H = H / len(S) 214 | H_pad[1:] = H 215 | L = len(before[:-1]) 216 | tar_basket = valid_seq[basketid] 217 | 218 | if history == 0: 219 | tar_basket = list(set(tar_basket)-set(S_pool)) 220 | if len(tar_basket) < 1:continue 221 | padd_num = max_basket_size - len(tar_basket) 222 | padding_item = [-1] * padd_num 223 | T_basket = tar_basket + padding_item 224 | valid_batch[L].append((U, S_pool, T_basket, H_pad[0:2], L)) 225 | valid_times += 1 226 | else: 227 | T_basket = before[-1] 228 | valid_batch[L].append((U, S_pool, T_basket, H_pad[0:2], L)) 229 | valid_times += 1 230 | 231 | for l in train_batch.keys(): 232 | TRAIN_DATASET.append(list(zip(*train_batch[l]))) 233 | 234 | for l in test_batch.keys(): 235 | TEST_DATASET.append(list(zip(*test_batch[l]))) 236 | 237 | for l in valid_batch.keys(): 238 | VALID_DATASET.append(list(zip(*valid_batch[l]))) 239 | 240 | 241 | 242 | print("--------------Data Process is Over--------------") 243 | return TRAIN_DATASET, VALID_DATASET, TEST_DATASET, neg_sample, weights, itemnum, train_times, test_times, valid_times 244 | 245 | 246 | 247 | # @profile 248 | def get_batch_TRAIN_DATASET(dataset, batch_size): 249 | print('--------------Data Process is Begin--------------') 250 | random.shuffle(dataset) 251 | for idx, (UU, SS, TT, HH, NN, LL) in enumerate(dataset): 252 | userid = torch.tensor(UU, dtype=torch.long) 253 | input_seq = torch.tensor(SS, dtype=torch.long) 254 | target = torch.tensor(TT, dtype=torch.long) 255 | history = torch.from_numpy(np.array(HH)).float() 256 | neg_items = torch.tensor(NN, dtype=torch.long) 257 | 258 | if SS.__len__() < 2: 259 | continue 260 | if SS.__len__() <= batch_size: 261 | batch_userid = userid 262 | batch_input_seq = input_seq 263 | batch_target = target 264 | batch_history = history 265 | batch_neg_items = neg_items 266 | yield (batch_userid,batch_input_seq,batch_target,batch_history,batch_neg_items) 267 | else: 268 | batch_begin = 0 269 | while (batch_begin + batch_size) <= SS.__len__(): 270 | batch_userid = userid[batch_begin:batch_begin + batch_size] 271 | batch_input_seq = input_seq[batch_begin:batch_begin + batch_size] 272 | batch_target = target[batch_begin:batch_begin + batch_size] 273 | batch_history = history[batch_begin:batch_begin + batch_size] 274 | batch_neg_items = neg_items[batch_begin:batch_begin + batch_size] 275 | yield (batch_userid, batch_input_seq, batch_target, batch_history, batch_neg_items) 276 | batch_begin = batch_begin + batch_size 277 | if (batch_begin + batch_size > SS.__len__()) & (batch_begin < 
SS.__len__()): 278 | 279 | batch_userid = userid[batch_begin:] 280 | batch_input_seq = input_seq[batch_begin:] 281 | batch_target = target[batch_begin:] 282 | batch_history = history[batch_begin:] 283 | batch_neg_items = neg_items[batch_begin:] 284 | yield (batch_userid, batch_input_seq, batch_target, batch_history, batch_neg_items) 285 | 286 | 287 | 288 | 289 | # @profile 290 | def get_batch_TEST_DATASET(TEST_DATASET, batch_size): 291 | BATCHES = [] 292 | random.shuffle(TEST_DATASET) 293 | for idx, (UU, SS, TT_bsk, HH, LL) in enumerate(TEST_DATASET): 294 | 295 | userid = torch.tensor(UU, dtype=torch.long) 296 | input_seq = torch.tensor(SS, dtype=torch.long) 297 | try: 298 | target = torch.tensor(TT_bsk, dtype=torch.long) 299 | except ValueError: 300 | print(TT_bsk) 301 | history = torch.from_numpy(np.array(HH)).float() 302 | 303 | assert UU.__len__() == SS.__len__() 304 | assert UU.__len__() == TT_bsk.__len__() 305 | assert UU.__len__() == HH.__len__() 306 | 307 | if SS.__len__() < 1: continue 308 | if SS.__len__() <= batch_size: 309 | batch_userid = userid 310 | batch_input_seq = input_seq 311 | batch_target = target 312 | batch_history = history 313 | yield (batch_userid, batch_input_seq, batch_target, batch_history) 314 | else: 315 | batch_begin = 0 316 | while (batch_begin + batch_size) <= SS.__len__(): 317 | batch_userid = userid[batch_begin:batch_begin + batch_size] 318 | batch_input_seq = input_seq[batch_begin:batch_begin + batch_size] 319 | batch_target = target[batch_begin:batch_begin + batch_size] 320 | batch_history = history[batch_begin:batch_begin + batch_size] 321 | yield (batch_userid, batch_input_seq, batch_target, batch_history) 322 | batch_begin = batch_begin + batch_size 323 | 324 | if (batch_begin + batch_size > SS.__len__()) & (batch_begin < SS.__len__()): 325 | batch_userid = userid[batch_begin:] 326 | batch_input_seq = input_seq[batch_begin:] 327 | batch_target = target[batch_begin:] 328 | batch_history = history[batch_begin:] 329 | yield (batch_userid, batch_input_seq, batch_target, batch_history) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CLEA-new 2 | 3 | The World is Binary: Contrastive Learning for Denoising Next Basket Recommendation 4 | 5 | 6 | 7 |
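
## Denoising with Gumbel-Softmax (illustrative sketch)

The denoising step in `module/model_1.py` (`Generator1`) decides, for every item in a user's past baskets, whether to keep or drop it, using a Gumbel-Softmax so that the hard keep/drop choice still passes gradients back to the item-scoring network. Below is a minimal, self-contained sketch of that idea using the standard straight-through estimator; it is illustrative only — the helper name and tensor shapes are invented for this example, and the repo's own variant differs in that it draws one shared Gumbel noise sample per item id and multiplies the hard and soft outputs rather than applying the `(hard - soft).detach() + soft` trick shown here.

```python
import torch
import torch.nn.functional as F

def st_gumbel_keep_mask(logits, temperature=0.5):
    """logits: [..., 2] unnormalized scores for (keep, drop) per item."""
    # Sample Gumbel(0, 1) noise and perturb the logits.
    gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
    y_soft = F.softmax((logits + gumbel) / temperature, dim=-1)
    # Hard one-hot decision for the forward pass ...
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    # ... while gradients flow through the soft probabilities (straight-through).
    y = (y_hard - y_soft).detach() + y_soft
    return y[..., 0]  # 1.0 = keep the item, 0.0 = drop it

scores = torch.randn(4, 6, 2, requires_grad=True)  # e.g. 4 users, 6 candidate items each
mask = st_gumbel_keep_mask(scores)
print(mask)            # hard 0/1 values
mask.sum().backward()  # gradients still reach `scores`
```

A lower `temperature` sharpens the soft distribution towards the hard choice, which is why Gumbel-Softmax selectors of this kind are typically trained with a temperature that is annealed or learned over time.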