├── Checkpoint ├── amazon_toy.pth └── douban_book.pth ├── Dataset ├── Amazon_toy │ ├── 1 │ ├── item_index_mix.pkl │ ├── item_index_toy.npy │ ├── item_index_toy.pkl │ ├── mix_log_file_final.pkl │ ├── mix_log_timestep_final.pkl │ ├── toy_log_file_final.pkl │ ├── toy_log_timestep_final.pkl │ └── user_index_overleap.pkl └── Douban_book │ ├── 1 │ ├── book_log_file_final.pkl │ ├── book_log_timestep_final.pkl │ ├── item_index_book.npy │ ├── item_index_book.pkl │ ├── item_index_mix.pkl │ ├── mix_log_file_final.pkl │ ├── mix_log_timestep_final.pkl │ └── user_index_overleap.pkl ├── PDRec.py ├── README.md ├── TI_DiffRec.py ├── models ├── DNN.py ├── gaussian_diffusion.py └── model.py ├── overall_structure.png └── utils └── utils.py /Checkpoint/amazon_toy.pth: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Checkpoint/douban_book.pth: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Dataset/Amazon_toy/1: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Dataset/Amazon_toy/item_index_mix.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hulkima/PDRec/fc557dd989aad85aaf4e4875f6d386bfd2c80184/Dataset/Amazon_toy/item_index_mix.pkl -------------------------------------------------------------------------------- /Dataset/Amazon_toy/item_index_toy.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hulkima/PDRec/fc557dd989aad85aaf4e4875f6d386bfd2c80184/Dataset/Amazon_toy/item_index_toy.npy -------------------------------------------------------------------------------- /Dataset/Amazon_toy/item_index_toy.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hulkima/PDRec/fc557dd989aad85aaf4e4875f6d386bfd2c80184/Dataset/Amazon_toy/item_index_toy.pkl -------------------------------------------------------------------------------- /Dataset/Amazon_toy/mix_log_file_final.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hulkima/PDRec/fc557dd989aad85aaf4e4875f6d386bfd2c80184/Dataset/Amazon_toy/mix_log_file_final.pkl -------------------------------------------------------------------------------- /Dataset/Amazon_toy/mix_log_timestep_final.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hulkima/PDRec/fc557dd989aad85aaf4e4875f6d386bfd2c80184/Dataset/Amazon_toy/mix_log_timestep_final.pkl -------------------------------------------------------------------------------- /Dataset/Amazon_toy/toy_log_file_final.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hulkima/PDRec/fc557dd989aad85aaf4e4875f6d386bfd2c80184/Dataset/Amazon_toy/toy_log_file_final.pkl -------------------------------------------------------------------------------- /Dataset/Amazon_toy/toy_log_timestep_final.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hulkima/PDRec/fc557dd989aad85aaf4e4875f6d386bfd2c80184/Dataset/Amazon_toy/toy_log_timestep_final.pkl -------------------------------------------------------------------------------- /Dataset/Amazon_toy/user_index_overleap.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hulkima/PDRec/fc557dd989aad85aaf4e4875f6d386bfd2c80184/Dataset/Amazon_toy/user_index_overleap.pkl -------------------------------------------------------------------------------- /Dataset/Douban_book/1: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Dataset/Douban_book/book_log_file_final.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hulkima/PDRec/fc557dd989aad85aaf4e4875f6d386bfd2c80184/Dataset/Douban_book/book_log_file_final.pkl -------------------------------------------------------------------------------- /Dataset/Douban_book/book_log_timestep_final.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hulkima/PDRec/fc557dd989aad85aaf4e4875f6d386bfd2c80184/Dataset/Douban_book/book_log_timestep_final.pkl -------------------------------------------------------------------------------- /Dataset/Douban_book/item_index_book.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hulkima/PDRec/fc557dd989aad85aaf4e4875f6d386bfd2c80184/Dataset/Douban_book/item_index_book.npy -------------------------------------------------------------------------------- /Dataset/Douban_book/item_index_book.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hulkima/PDRec/fc557dd989aad85aaf4e4875f6d386bfd2c80184/Dataset/Douban_book/item_index_book.pkl -------------------------------------------------------------------------------- /Dataset/Douban_book/item_index_mix.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hulkima/PDRec/fc557dd989aad85aaf4e4875f6d386bfd2c80184/Dataset/Douban_book/item_index_mix.pkl -------------------------------------------------------------------------------- /Dataset/Douban_book/mix_log_file_final.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hulkima/PDRec/fc557dd989aad85aaf4e4875f6d386bfd2c80184/Dataset/Douban_book/mix_log_file_final.pkl -------------------------------------------------------------------------------- /Dataset/Douban_book/mix_log_timestep_final.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hulkima/PDRec/fc557dd989aad85aaf4e4875f6d386bfd2c80184/Dataset/Douban_book/mix_log_timestep_final.pkl -------------------------------------------------------------------------------- /Dataset/Douban_book/user_index_overleap.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hulkima/PDRec/fc557dd989aad85aaf4e4875f6d386bfd2c80184/Dataset/Douban_book/user_index_overleap.pkl -------------------------------------------------------------------------------- /PDRec.py: -------------------------------------------------------------------------------- 
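# High-level flow of this training script: partition the cross-domain dataset, restore the pre-trained TI-DiffRec checkpoint from ./Checkpoint/, pre-compute every user's diffusion-based preference scores over all items (consumed by the HBR, DPA and NNS modules below), then train the GRU4Rec/SASRec backbone with the reweighted BCE loss and early stopping.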
1 | import os 2 | import time 3 | import torch 4 | import argparse 5 | import os 6 | import io 7 | import math 8 | from matplotlib.pyplot import MultipleLocator 9 | import matplotlib.pyplot as plt 10 | from copy import deepcopy 11 | import random 12 | import ipdb 13 | from models.model import GRU4Rec_withNeg_Dist, SASRec_V1_withNeg_Dist 14 | from models.model import EarlyStopping_onetower 15 | from models.DNN import DNN 16 | import models.gaussian_diffusion as gd 17 | from utils.utils import * 18 | 19 | 20 | # -*- coding: UTF-8 -*- 21 | plt.switch_backend('agg') 22 | np.set_printoptions(suppress=True) 23 | np.set_printoptions(threshold=2000) 24 | from matplotlib.font_manager import FontManager 25 | fm = FontManager() 26 | mat_fonts = set(f.name for f in fm.ttflist) 27 | print(mat_fonts) 28 | 29 | 30 | def str2bool(s): 31 | if s not in {'false', 'true'}: 32 | raise ValueError('Not a valid boolean string') 33 | return s == 'true' 34 | 35 | 36 | def map_into_BCELoss(scores): 37 | return 1/2 * (scores + 1) 38 | 39 | def min_max_normalize_batch(tensor): 40 | min_val = torch.min(tensor, dim=1)[0] 41 | max_val = torch.max(tensor, dim=1)[0] 42 | normalized_tensor = torch.div(tensor - min_val.unsqueeze(1), max_val.unsqueeze(1) - min_val.unsqueeze(1)) 43 | return normalized_tensor 44 | 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument('--dataset', required=True) 47 | parser.add_argument('--cross_dataset', default='111', type=str) 48 | # parser.add_argument('--train_dir', required=True) 49 | parser.add_argument('--batch_size', default=128, type=int) 50 | parser.add_argument('--lr', default=0.001, type=float) 51 | parser.add_argument('--maxlen', default=200, type=int) 52 | parser.add_argument('--hidden_units', default=64, type=int) 53 | parser.add_argument('--num_blocks', default=2, type=int) 54 | parser.add_argument('--num_epochs', default=1000, type=int) 55 | parser.add_argument('--num_heads', default=1, type=int) 56 | parser.add_argument('--dropout_rate', default=0.2, type=float) 57 | parser.add_argument('--l2_emb', default=0.0, type=float) 58 | parser.add_argument('--device', default='cuda', type=str) 59 | parser.add_argument('--inference_only', default=False, type=str2bool) 60 | parser.add_argument('--state_dict_path', default=None, type=str) 61 | parser.add_argument('--num_samples', default=100, type=int) 62 | parser.add_argument('--decay', default=4, type=int) 63 | parser.add_argument('--lr_decay_rate', default=0.99, type=float) 64 | parser.add_argument('--index', default=0, type=int) 65 | parser.add_argument('--version', default=None, type=str) 66 | parser.add_argument('--lr_linear', default=0.01, type=float) 67 | parser.add_argument('--start_decay_linear', default=8, type=int) 68 | parser.add_argument('--temperature', default=5, type=float) 69 | parser.add_argument('--seed', default=2024, type=int) 70 | parser.add_argument('--lrscheduler', default='ExponentialLR', type=str) 71 | parser.add_argument('--patience', default=10, type=int) 72 | 73 | parser.add_argument('--lr_diff', type=float, default=0.00005, help='learning rate') 74 | parser.add_argument('--weight_decay_diff', type=float, default=0.0) 75 | parser.add_argument('--tst_w_val', action='store_true', help='test with validation') 76 | parser.add_argument('--log_name', type=str, default='log', help='the log name') 77 | parser.add_argument('--round', type=int, default=1, help='record the experiment') 78 | 79 | parser.add_argument('--w_min', type=float, default=0.1, help='the minimum weight for interactions') 80 | 
parser.add_argument('--w_max', type=float, default=1., help='the maximum weight for interactions') 81 | 82 | # params for the DNN model 83 | parser.add_argument('--time_type', type=str, default='cat', help='cat or add') 84 | parser.add_argument('--dims', type=str, default='[1000]', help='the dims for the DNN') 85 | parser.add_argument('--norm', type=bool, default=False, help='Normalize the input or not') 86 | parser.add_argument('--emb_size', type=int, default=10, help='timestep embedding size') 87 | 88 | # params for diffusion 89 | parser.add_argument('--mean_type', type=str, default='x0', help='MeanType for diffusion: x0, eps') 90 | parser.add_argument('--steps', type=int, default=10, help='diffusion steps') 91 | parser.add_argument('--noise_schedule', type=str, default='linear-var', help='the schedule for noise generating') 92 | parser.add_argument('--noise_scale', type=float, default=0.01, help='noise scale for noise generating') 93 | parser.add_argument('--noise_min', type=float, default=0.0005, help='noise lower bound for noise generating') 94 | parser.add_argument('--noise_max', type=float, default=0.005, help='noise upper bound for noise generating') 95 | parser.add_argument('--sampling_noise', type=bool, default=False, help='sampling with noise or not') 96 | parser.add_argument('--sampling_steps', type=int, default=0, help='steps of the forward process during inference') 97 | parser.add_argument('--reweight', type=bool, default=True, help='assign different weight to different timestep or not') 98 | parser.add_argument('--reweight_version', type=str, default='AllLinear', help='in AllOne, AllLinear, MinMax') 99 | parser.add_argument('--result_path', type=str, default=True, help='the path of result') 100 | parser.add_argument('--filter_prob', type=float, default=0.1, help='the path of result') 101 | parser.add_argument('--scale_weight', type=float, default=1.0, help='the path of result') 102 | parser.add_argument('--scale_max', type=float, default=0.0, help='the path of result') 103 | parser.add_argument('--rank_weight', type=float, default=0.0, help='the path of result') 104 | 105 | parser.add_argument('--cal_version', type=int, default=1, help='the path of result') 106 | parser.add_argument('--candidate_min_percentage_user', default=0, type=int) 107 | parser.add_argument('--candidate_max_percentage_user', default=99, type=int) 108 | parser.add_argument('--top_candidate_coarse_num', default=10, type=int) 109 | parser.add_argument('--top_candidate_fine_num', default=10, type=int) 110 | parser.add_argument('--top_candidate_weight', default=0.1, type=float) 111 | parser.add_argument('--base_model', default='GRU4Rec', type=str) 112 | 113 | args = parser.parse_args() 114 | 115 | 116 | SEED = args.seed 117 | 118 | random.seed(SEED) 119 | np.random.seed(SEED) 120 | torch.manual_seed(SEED) 121 | torch.cuda.manual_seed_all(SEED) 122 | 123 | result_path = './results_filedile/' + str(args.dataset) + '/PDRec_'+str(args.base_model)+'/' 124 | print("Save in path:", result_path) 125 | if not os.path.isdir(result_path): 126 | os.makedirs(result_path) 127 | with open(os.path.join(result_path, 'args.txt'), 'w') as f: 128 | f.write('\n'.join([str(k) + ',' + str(v) for k, v in sorted(vars(args).items(), key=lambda x: x[0])])) 129 | # f.close() 130 | args.result_path = result_path 131 | 132 | if __name__ == '__main__': 133 | dataset = data_partition(args.version, args.dataset, args.cross_dataset, args.maxlen) 134 | [user_train_mix, user_train_source, user_train_target, user_valid_target, 
user_test_target, user_train_mix_sequence_for_target, user_train_source_sequence_for_target, usernum, itemnum, interval, user_train_ti_mix, user_train_ti_source, user_train_ti_target, user_valid_ti_target, user_test_ti_target] = dataset 135 | 136 | print("the user number is:", usernum) 137 | print("the item number is:", itemnum) 138 | 139 | cc_source = 0.0 140 | for u in user_train_target: 141 | if len(user_train_target[u]) > 0: 142 | cc_source = cc_source + len(user_train_target[u]) 143 | print('average sequence length in source domain: %.2f' % (cc_source / len(user_train_target))) 144 | 145 | 146 | random_min = 0 147 | random_max = 0 148 | random_source_min = 0 149 | random_source_max = 0 150 | if args.dataset == 'amazon_toy': 151 | item_number = interval 152 | random_min = 1 153 | random_max = interval + 1 154 | random_source_min = interval + 1 155 | random_source_max = itemnum + 1 156 | print("The min is {} and the max is {} in amazon_toy".format(random_min, random_max)) 157 | print("The min is {} and the max is {} in source domain".format(random_source_min, random_source_max)) 158 | elif args.dataset == 'douban_book': 159 | item_number = interval 160 | random_min = 1 161 | random_max = interval + 1 162 | random_source_min = interval + 1 163 | random_source_max = itemnum + 1 164 | print("The min is {} and the max is {} in amazon_book".format(random_min, random_max)) 165 | print("The min is {} and the max is {} in source domain".format(random_source_min, random_source_max)) 166 | candidate_min_user = math.floor(item_number * args.candidate_min_percentage_user / 100) 167 | candidate_max_user = math.ceil(item_number * args.candidate_max_percentage_user / 100) 168 | item_list = torch.arange(start=random_min, end=random_max, step=1, device='cuda', requires_grad=False) 169 | print("The item_number is:",item_number) 170 | print("The candidate_min_user is:",candidate_min_user) 171 | print("The candidate_max_user is:",candidate_max_user) 172 | 173 | # ipdb.set_trace() 174 | user_list = [] 175 | for u_i in range(1, usernum): 176 | if len(user_train_source[u_i]) >= 1 and len(user_train_target[u_i]) >= 2: 177 | user_list.append(u_i) 178 | num_batch = math.ceil(len(user_list) / args.batch_size) # 908 179 | if args.base_model == 'GRU4Rec': 180 | model = GRU4Rec_withNeg_Dist(usernum, itemnum, args).cuda() # no ReLU activation in original SASRec implementation? 181 | elif args.base_model == 'SASRec': 182 | model = SASRec_V1_withNeg_Dist(usernum, itemnum, args).cuda() # no ReLU activation in original SASRec implementation? 
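# The loop below applies Xavier normal initialization to every parameter tensor with at least two dimensions; for 1-D tensors (e.g., biases) xavier_normal_ raises an error, which the try/except silently ignores.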
183 | for name, param in model.named_parameters(): 184 | try: 185 | torch.nn.init.xavier_normal_(param.data) 186 | except: 187 | pass # just ignore those failed init layers 188 | 189 | epoch_start_idx = 1 190 | 191 | bce_criterion = torch.nn.BCEWithLogitsLoss(reduction='none') # torch.nn.BCELoss() 192 | adam_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98)) 193 | 194 | # set the early stop 195 | early_stopping = EarlyStopping_onetower(args, verbose=True) # 关于 EarlyStopping 的代码可先看博客后面的内容 196 | 197 | # set the learning rate scheduler 198 | if args.lrscheduler == 'Steplr': # 199 | learningrate_scheduler = torch.optim.lr_scheduler.StepLR(adam_optimizer, step_size=args.decay, gamma=args.lr_decay_rate, verbose=True) 200 | elif args.lrscheduler == 'ExponentialLR': # 201 | learningrate_scheduler = torch.optim.lr_scheduler.ExponentialLR(adam_optimizer, gamma=args.lr_decay_rate, last_epoch=-1, verbose=True) 202 | elif args.lrscheduler == 'CosineAnnealingLR': 203 | learningrate_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(adam_optimizer, T_max=args.num_epochs, eta_min=0, last_epoch=-1, verbose=True) 204 | 205 | 206 | ### Build Gaussian Diffusion ### 207 | if args.mean_type == 'x0': 208 | mean_type = gd.ModelMeanType.START_X 209 | elif args.mean_type == 'eps': 210 | mean_type = gd.ModelMeanType.EPSILON 211 | else: 212 | raise ValueError("Unimplemented mean type %s" % args.mean_type) 213 | 214 | # ipdb.set_trace() 215 | diffusion = gd.GaussianDiffusion(mean_type, args.noise_schedule, args.noise_scale, args.noise_min, args.noise_max, args.steps, 'cuda').cuda() 216 | 217 | ### Build MLP ### 218 | out_dims = eval(args.dims) + [itemnum+1] # [1000, 94949] 219 | in_dims = out_dims[::-1] # [94949, 1000] 220 | model_diff = DNN(in_dims, out_dims, args.emb_size, time_type="cat", norm=args.norm).cuda() 221 | 222 | optimizer_diff = torch.optim.AdamW(model_diff.parameters(), lr=args.lr_diff, weight_decay=args.weight_decay_diff) 223 | print("model_diff ready.") 224 | 225 | param_num = 0 226 | mlp_num = sum([param.nelement() for param in model_diff.parameters()]) 227 | diff_num = sum([param.nelement() for param in diffusion.parameters()]) # 0 228 | param_num = mlp_num + diff_num 229 | print("Number of all parameters:", param_num) 230 | 231 | # Same as the pre-trained TI-DiffRec hyper-parameters 232 | if args.dataset == 'amazon_toy': 233 | model_diff_path = './Checkpoint/amazon_toy.pth' 234 | args.lr_diff=5e-5 235 | args.weight_decay_diff=0.5 236 | args.dims='[1000]' 237 | args.emb_size=10 238 | args.mean_type='x0' 239 | args.steps=10 240 | args.noise_scale=0.01 241 | args.noise_min=0.0005 242 | args.noise_max=0.005 243 | args.sampling_steps=0 244 | args.reweight=1 245 | args.w_min=0.5 246 | args.w_max=1.0 247 | args.reweight_version='AllOne' 248 | args.log_name='log' 249 | args.round=1 250 | elif args.dataset == 'douban_book': 251 | model_diff_path = './Checkpoint/douban_book.pth' 252 | args.lr_diff=5e-5 253 | args.weight_decay_diff=0.5 254 | args.dims='[256]' 255 | args.emb_size=8 256 | args.mean_type='x0' 257 | args.steps=10 258 | args.noise_scale=0.01 259 | args.noise_min=0.0005 260 | args.noise_max=0.01 261 | args.sampling_steps=0 262 | args.reweight=1 263 | args.w_min=0.3 264 | args.w_max=1.0 265 | args.reweight_version='AllOne' 266 | args.log_name='log' 267 | args.round=1 268 | # ipdb.set_trace() 269 | model_diff = torch.load(model_diff_path).to('cuda') 270 | model_diff.eval() 271 | 272 | sampler = WarpSampler_V13_final_please_Diff_TI(random_min, random_max, 
random_source_min, random_source_max, user_train_mix, user_train_source, user_train_target, user_train_ti_mix, user_train_ti_source, user_train_ti_target, user_train_mix_sequence_for_target, user_train_source_sequence_for_target, usernum, itemnum, None, None, w_min = args.w_min, w_max = args.w_max, reweight_version=args.reweight_version, batch_size=args.batch_size, maxlen=args.maxlen, n_workers=3) 273 | 274 | user_nearest_item_raw = torch.zeros([usernum+1, candidate_max_user-candidate_min_user], dtype=torch.int32, device='cuda') 275 | user_top_item_raw = torch.zeros([usernum+1, args.top_candidate_coarse_num], dtype=torch.int32, device='cuda') 276 | diff_matrix = torch.zeros([usernum+1, itemnum+1], device='cuda') 277 | weight_matrix = torch.zeros([usernum+1, itemnum+1], device='cuda') 278 | with torch.no_grad(): 279 | for u in range(1, usernum + 1): 280 | if len(user_train_mix[u]) < 1 or len(user_train_source[u]) < 1 or len(user_train_target[u]) < 1: 281 | continue 282 | # init the tensor 283 | corpus_target_temp = np.zeros([itemnum+1], dtype=np.float32) 284 | # set the position-aware weight 285 | weight_target = scale_withminmax(user_train_ti_target[u], args.w_min, args.w_max, args.reweight_version) 286 | corpus_target_temp[user_train_target[u]] = weight_target 287 | corpus_target_temp = torch.tensor(corpus_target_temp, device='cuda') 288 | diff_prob = diffusion.p_sample(model_diff, corpus_target_temp.unsqueeze(0), args.sampling_steps, args.sampling_noise).detach().squeeze() 289 | # ipdb.set_trace() 290 | sort_indices = torch.sort(input=diff_prob[random_min: random_max], dim=0, descending=True, stable=True)[1][candidate_min_user:candidate_max_user+len(user_train_target[u])+2] # torch.Size([4966]) 291 | user_indices = copy.deepcopy(item_list[sort_indices]) # torch.Size([1001]) 292 | 293 | sort_top_indices = torch.sort(input=diff_prob[random_min: random_max], dim=0, descending=True, stable=True)[1][0:args.top_candidate_coarse_num] # torch.Size([4966]) 294 | user_top_indices = copy.deepcopy(item_list[sort_top_indices]) # torch.Size([1001]) 295 | for it in (user_train_target[u]+user_valid_target[u]+user_test_target[u]): 296 | if it in user_indices: 297 | user_equal_it_index = torch.nonzero(user_indices == it).squeeze(1) 298 | user_indices = del_tensor_ele(user_indices, user_equal_it_index) 299 | 300 | user_nearest_item_raw[u] = user_indices[:candidate_max_user-candidate_min_user] 301 | user_top_item_raw[u] = user_top_indices[:args.top_candidate_coarse_num] 302 | diff_matrix[u] = diff_prob 303 | weight_matrix[u] = corpus_target_temp 304 | if u % 1000 == 0: 305 | print("Diffusion user:", u) 306 | 307 | T = 0.0 308 | t0 = time.time() 309 | for epoch in range(epoch_start_idx, args.num_epochs + 1): 310 | t1 = time.time() 311 | loss_epoch = 0 312 | loss_weight_epoch = 0 313 | model.train() 314 | nearest_index = torch.tensor(np.random.randint(low=0, high=candidate_max_user-candidate_min_user, size=[usernum+1, args.maxlen]), device='cuda') 315 | user_nearest_item = torch.gather(user_nearest_item_raw, dim=-1,index=nearest_index) 316 | 317 | for step in range(num_batch): # tqdm(range(num_batch), total=num_batch, ncols=70, leave=False, unit='b'): 318 | u, seq_mix, seq_source, seq_target, pos_target, neg_target, user_train_mix_sequence_for_target_indices, user_train_source_sequence_for_target_indices, seq_mix_inter, seq_source_inter, seq_target_inter, seq_mix_inter_temp, seq_source_inter_temp, seq_target_inter_temp = sampler.next_batch() # tuples to ndarray 319 | u, seq, pos, neg, seq_target_inter, 
seq_target_inter_temp = np.array(u), np.array(seq_target), np.array(pos_target), np.array(neg_target), np.array(seq_target_inter), np.array(seq_target_inter_temp) 320 | u, seq, pos, neg, seq_target_inter, seq_target_inter_temp = torch.tensor(u,device='cuda'), torch.tensor(seq,device='cuda'), torch.tensor(pos,device='cuda'), torch.tensor(neg,device='cuda'), torch.tensor(seq_target_inter,device='cuda'), torch.tensor(seq_target_inter_temp,device='cuda') 321 | # NNS module 322 | neg_diff = torch.index_select(user_nearest_item, dim=0, index=u) 323 | neg_diff = torch.where(pos>0, neg_diff, torch.zeros(pos.shape, dtype=torch.int32, device='cuda')) 324 | soft_diff = torch.index_select(user_top_item_raw, dim=0, index=u) 325 | seq_new = torch.cat([seq[:,1:], pos[:,-1].unsqueeze(1)],dim=1) 326 | neg_list = [] 327 | neg_list.append(neg.squeeze()) 328 | neg_list.append(neg_diff) 329 | 330 | # DPA module 331 | pos_logits, neg_logits, soft_logits = model(u, seq, pos, neg_list, seq_new, soft_diff) 332 | sort_soft_index = torch.sort(input=soft_logits, dim=1, descending=True, stable=True)[1][:,0:args.top_candidate_fine_num] 333 | soft_logits = torch.gather(soft_logits, dim=1, index=sort_soft_index) 334 | pos_labels, neg_labels, soft_labels = torch.ones(pos_logits.shape).cuda(), torch.zeros(pos_logits.shape).cuda(), torch.ones(soft_logits.shape).cuda() 335 | 336 | adam_optimizer.zero_grad() 337 | indices = torch.where(pos != 0) 338 | 339 | # HBR module 340 | pos_position = torch.where(pos>0, torch.ones(pos.shape,device='cuda'), torch.zeros(pos.shape,device='cuda')) 341 | diff_prob = torch.index_select(diff_matrix, dim=0,index=u) 342 | diff_prob_batch = torch.gather(diff_prob, dim=-1,index=pos.long()) 343 | diff_weight_times = torch.where(torch.min(diff_prob_batch,dim=1)[0]>0, torch.ones(torch.min(diff_prob_batch,dim=1)[0].shape,device='cuda')*0.5, torch.ones(torch.min(diff_prob_batch,dim=1)[0].shape,device='cuda')*1.5) 344 | diff_prob_batch = torch.where(pos>0, diff_prob_batch, (torch.min(diff_prob_batch,dim=1)[0]*diff_weight_times).unsqueeze(1).repeat([1, diff_prob_batch.shape[1]])) 345 | diff_prob_batch = min_max_normalize_batch(diff_prob_batch) 346 | 347 | _, sorted_indices_diff = torch.sort(diff_prob_batch, dim=1, descending=False) # 对每行进行升序排序并获取排序后的索引 348 | sorted_rank_diff = sorted_indices_diff.argsort(dim=1) 349 | sorted_rank_diff = sorted_rank_diff - (args.maxlen -1 - pos_position.sum(-1).unsqueeze(1)) 350 | 351 | rescale = pos_position.sum(-1) / torch.where(torch.isnan(diff_prob_batch.sum(-1)), torch.ones(diff_prob_batch.sum(-1).shape,device='cuda'), diff_prob_batch.sum(-1)) 352 | diff_prob_batch = torch.mul(diff_prob_batch, rescale.unsqueeze(1)) 353 | diff_intermedia_batch = sorted_rank_diff / torch.where(pos_position.sum(-1)>0, pos_position.sum(-1), torch.ones(diff_prob_batch.sum(-1).shape,device='cuda')*1e3).unsqueeze(1) 354 | diff_rank_batch = torch.where(pos>0, diff_intermedia_batch, torch.zeros(pos.shape,device='cuda')) 355 | diff_prob_batch = diff_prob_batch * (1-args.rank_weight) + diff_rank_batch * args.rank_weight 356 | diff_prob_batch = torch.clamp(diff_prob_batch, 0.0, args.scale_max) 357 | 358 | diff_weight = diff_prob_batch[indices] 359 | loss_reweight = diff_weight*args.scale_weight 360 | 361 | loss = (bce_criterion(pos_logits[indices], pos_labels[indices])*loss_reweight).mean() 362 | loss += bce_criterion(soft_logits, soft_labels).mean() * args.top_candidate_weight 363 | for k in range(0, len(neg_logits)): 364 | loss += bce_criterion(neg_logits[k][indices], 
neg_labels[indices]).mean() / len(neg_logits) 365 | 366 | loss_epoch += loss.item() 367 | loss_weight_epoch += loss_reweight.mean().item() 368 | 369 | loss.backward() 370 | adam_optimizer.step() 371 | print("In epoch {} iteration {}: loss is {}, loss_weight_mean is {}".format(epoch, step, loss.item(), loss_reweight.mean().item())) 372 | with io.open(result_path + 'loss_log.txt', 'a', encoding='utf-8') as file: 373 | file.write("In epoch {} iteration {}: loss is {}, loss_weight is {}\n".format(epoch, step, loss.item(), loss_reweight.mean().item())) 374 | learningrate_scheduler.step() 375 | 376 | t2 = time.time() 377 | print("In epoch {}: loss is {}, loss_weight is {}, time is {}\n".format(epoch, loss_epoch / num_batch, loss_weight_epoch / num_batch, t2 - t1)) 378 | with io.open(result_path + 'train_loss.txt', 'a', encoding='utf-8') as file: 379 | file.write("In epoch {}: loss is {}, loss_weight is {}, time is {}\n".format(epoch, loss_epoch / num_batch, loss_weight_epoch / num_batch, t2 - t1)) 380 | 381 | model.eval() 382 | # Speed-up evaluation 383 | if epoch > 50: 384 | print('Evaluating', end='') 385 | t_test = evaluate_PDRec(model, dataset, args, user_list) 386 | t3 = time.time() 387 | print('epoch:%d, epoch_time: %.4f(s), total_time: %.4f(s), test:\n' % (epoch, t3-t1, t3-t0)) 388 | print(' test: NDCG@1: %.4f, NDCG@5: %.4f, NDCG@10: %.4f, NDCG@20: %.4f, NDCG@50: %.4f, HR@1: %.4f, HR@5: %.4f, HR@10: %.4f, HR@20: %.4f, HR@50: %.4f, AUC: %.4f, loss: %.4f\n' % (t_test[0], t_test[1], t_test[2], t_test[3], t_test[4], t_test[5], t_test[6], t_test[7], t_test[8], t_test[9], t_test[10], t_test[11])) 389 | 390 | with io.open(result_path + 'test_performance.txt', 'a', encoding='utf-8') as file: 391 | file.write('epoch:%d, epoch_time: %.4f(s), total_time: %.4f(s), test:\n' % (epoch, t3-t1, t3-t0)) 392 | file.write(' NDCG@1: %.4f, NDCG@5: %.4f, NDCG@10: %.4f, NDCG@20: %.4f, NDCG@50: %.4f, HR@1: %.4f, HR@5: %.4f, HR@10: %.4f, HR@20: %.4f, HR@50: %.4f, AUC: %.4f, loss: %.4f\n' % (t_test[0], t_test[1], t_test[2], t_test[3], t_test[4], t_test[5], t_test[6], t_test[7], t_test[8], t_test[9], t_test[10], t_test[11])) 393 | 394 | early_stopping(epoch, model, result_path, t_test) 395 | if early_stopping.early_stop: 396 | print("Save in path:", result_path) 397 | print("Early stopping in the epoch {}, NDCG@1: {:.4f}, NDCG@5: {:.4f}, NDCG@10: {:.4f}, NDCG@20: {:.4f}, NDCG@50: {:.4f}, HR@1: {:.4f}, HR@5: {:.4f}, HR@10: {:.4f}, HR@20: {:.4f}, HR@50: {:.4f}, AUC: {:.4f}".format(early_stopping.save_epoch, early_stopping.best_performance[0], early_stopping.best_performance[1], early_stopping.best_performance[2], early_stopping.best_performance[3], early_stopping.best_performance[4], early_stopping.best_performance[5], early_stopping.best_performance[6], early_stopping.best_performance[7], early_stopping.best_performance[8], early_stopping.best_performance[9], early_stopping.best_performance[10])) 398 | with io.open(result_path + 'save_model.txt', 'a', encoding='utf-8') as file: 399 | file.write("Early stopping in the epoch {}, NDCG@1: {:.4f}, NDCG@5: {:.4f}, NDCG@10: {:.4f}, NDCG@20: {:.4f}, NDCG@50: {:.4f}, HR@1: {:.4f}, HR@5: {:.4f}, HR@10: {:.4f}, HR@20: {:.4f}, HR@50: {:.4f}, AUC: {:.4f}\n".format(early_stopping.save_epoch, early_stopping.best_performance[0], early_stopping.best_performance[1], early_stopping.best_performance[2], early_stopping.best_performance[3], early_stopping.best_performance[4], early_stopping.best_performance[5], early_stopping.best_performance[6], early_stopping.best_performance[7], 
early_stopping.best_performance[8], early_stopping.best_performance[9], early_stopping.best_performance[10])) 400 | break 401 | 402 | sampler.close() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PDRec 2 | The source code is for the paper: [Plug-In Diffusion Model for Sequential Recommendation](https://arxiv.org/pdf/2401.02913.pdf) accepted in AAAI 2024 by Haokai Ma, Ruobing Xie, Lei Meng, Xin Chen, Xu Zhang, Leyu Lin and Zhanhui Kang. 3 | 4 | ## Overview 5 | This paper presents a novel Plug-In Diffusion Model for Recommendation (PDRec) framework, which employs the diffusion model as a flexible plugin to jointly take full advantage of the diffusion-generated user preferences on all items. Specifically, PDRec first infers the users' dynamic preferences on all items via a time-interval diffusion model and proposes a Historical Behavior Reweighting (HBR) mechanism to identify high-quality behaviors and suppress noisy behaviors. In addition to the observed items, PDRec proposes a Diffusion-based Positive Augmentation (DPA) strategy to leverage the top-ranked unobserved items as potential positive samples, bringing in informative and diverse soft signals to alleviate data sparsity. To alleviate the false negative sampling issue, PDRec employs Noise-free Negative Sampling (NNS) to select stable negative samples and ensure effective model optimization.![_](./overall_structure.png) 6 | 7 | ## Dependencies 8 | - Python 3.8.10 9 | - PyTorch 1.12.0+cu102 10 | - pytorch-lightning==1.6.5 11 | - Torchvision==0.8.2 12 | - Pandas==1.3.5 13 | - Scipy==1.7.3 14 | 15 | ## Implementation of PDRec 16 | We use the Toy dataset from the [Amazon](https://nijianmo.github.io/amazon/index.html) platform and the Book dataset from the [Douban](https://github.com/RUCAIBox/RecBole-CDR) platform; the preprocessed files for both are provided in the Dataset folder of this repository. 17 | 18 | Due to the file size limitation, you can download the checkpoints of TI-DiffRec released by us from [Google drive](https://drive.google.com/drive/folders/1bD1IO2cG2xkN1WGofXqqz6mV8Ah21FRi?usp=sharing) and place them in the Checkpoint folder.
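For reference, the snippet below is a minimal sketch (not one of the released scripts) of how `PDRec.py` consumes a TI-DiffRec checkpoint: the whole diffusion MLP is restored with `torch.load` and queried once per user to obtain preference scores over every item. The hyper-parameters match the released `amazon_toy` checkpoint, while `itemnum` and the interacted item ids are purely illustrative.
```
import torch
import models.gaussian_diffusion as gd

# TI-DiffRec settings of the released amazon_toy checkpoint (see PDRec.py)
diffusion = gd.GaussianDiffusion(gd.ModelMeanType.START_X, 'linear-var',
                                 0.01, 0.0005, 0.005, 10, 'cuda').cuda()
model_diff = torch.load('./Checkpoint/amazon_toy.pth').to('cuda')
model_diff.eval()

itemnum = 94949                                    # illustrative item count
history = torch.zeros(itemnum + 1, device='cuda')
history[[3, 17, 42]] = 1.0                         # hypothetical interacted items

with torch.no_grad():
    # preference scores over all items (sampling_steps=0, no sampling noise)
    diff_prob = diffusion.p_sample(model_diff, history.unsqueeze(0), 0, False).squeeze()
```
In `PDRec.py` the history vector is additionally reweighted by interaction recency (`scale_withminmax`) before being passed to `p_sample`.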
19 | 20 | ### PDRec (GRU4Rec) on Toy: 21 | ``` 22 | CUDA_VISIBLE_DEVICES=0 python PDRec.py --dataset=amazon_toy --lr 0.005 --temperature 5 --scale_weight 2.0 --scale_max 3.0 --rank_weight 0.1 --candidate_min_percentage_user 50 --top_candidate_coarse_num 50 --top_candidate_fine_num 5 --top_candidate_weight 0.3 --base_model GRU4Rec 23 | ``` 24 | ### PDRec (SASRec) on Toy: 25 | ``` 26 | CUDA_VISIBLE_DEVICES=1 python PDRec.py --dataset=amazon_toy --lr 0.005 --temperature 5 --scale_weight 4.0 --scale_max 1.0 --rank_weight 0.1 --candidate_min_percentage_user 90 --top_candidate_coarse_num 50 --top_candidate_fine_num 5 --top_candidate_weight 0.05 --base_model SASRec 27 | ``` 28 | ### PDRec (GRU4Rec) on Book: 29 | ``` 30 | CUDA_VISIBLE_DEVICES=2 python PDRec.py --dataset=douban_book --lr 0.01 --temperature 10 --scale_weight 4.0 --scale_max 3.0 --rank_weight 0.3 --candidate_min_percentage_user 80 --top_candidate_coarse_num 100 --top_candidate_fine_num 1 --top_candidate_weight 0.01 --base_model GRU4Rec 31 | ``` 32 | ### PDRec (SASRec) on Book: 33 | ``` 34 | CUDA_VISIBLE_DEVICES=3 python PDRec.py --dataset=douban_book --lr 0.001 --temperature 10 --scale_weight 4.0 --scale_max 3.0 --rank_weight 0.5 --candidate_min_percentage_user 80 --top_candidate_coarse_num 100 --top_candidate_fine_num 1 --top_candidate_weight 0.01 --base_model SASRec 35 | ``` 36 | ### TI-DiffRec on Toy 37 | ``` 38 | CUDA_VISIBLE_DEVICES=0 python TI_DiffRec.py --lr=5e-5 --dims=[1000] --emb_size=10 --noise_scale=0.01 --noise_min=0.0005 --noise_max=0.005 --reweight=1 --w_min=0.5 --w_max=1.0 --dataset=amazon_toy 39 | ``` 40 | ### TI-DiffRec on Book 41 | ``` 42 | CUDA_VISIBLE_DEVICES=0 python TI_DiffRec.py --lr=5e-5 --dims=[256] --emb_size=8 --noise_scale=0.01 --noise_min=0.0005 --noise_max=0.01 --reweight=1 --w_min=0.3 --w_max=1.0 --dataset=douban_book 43 | ``` 44 | 45 | ## BibTeX 46 | If you find this work useful for your research, please kindly cite PDRec by: 47 | ``` 48 | @inproceedings{PDRec, 49 | title={Plug-In Diffusion Model for Sequential Recommendation}, 50 | author={Ma, Haokai and Xie, Ruobing and Meng, Lei and Chen, Xin and Zhang, Xu and Lin, Leyu and Kang, Zhanhui}, 51 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 52 | year={2024} 53 | } 54 | ``` 55 | 56 | ## Acknowledgement 57 | The structure of this code is largely based on [DiffRec](https://github.com/YiyanXu/DiffRec) and [SASRec](https://github.com/pmixer/SASRec.pytorch) and the dataset is collected by [Amazon](https://nijianmo.github.io/amazon/index.html) and [RecBole](https://github.com/RUCAIBox/RecBole-CDR). Thanks for these works. 
58 | 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /TI_DiffRec.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model for recommendation 3 | """ 4 | 5 | import argparse 6 | from ast import parse 7 | import os 8 | import io 9 | import time 10 | import numpy as np 11 | import copy 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | import torch.utils.data as data 17 | from torch.utils.data import DataLoader 18 | import torch.backends.cudnn as cudnn 19 | import torch.nn.functional as F 20 | 21 | from utils.utils import * 22 | 23 | import models.gaussian_diffusion as gd 24 | from models.DNN import DNN 25 | from copy import deepcopy 26 | import math 27 | import random 28 | import ipdb 29 | 30 | def worker_init_fn(worker_id): 31 | np.random.seed(random_seed + worker_id) 32 | def seed_worker(worker_id): 33 | worker_seed = torch.initial_seed() % 2**32 34 | np.random.seed(worker_seed) 35 | 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--lr', type=float, default=5e-5, help='learning rate') 38 | parser.add_argument('--weight_decay', type=float, default=0.5) 39 | parser.add_argument('--batch_size', type=int, default=400) 40 | parser.add_argument('--epochs', type=int, default=1000, help='upper epoch limit') 41 | parser.add_argument('--topN', type=str, default='[10, 20, 50, 100]') 42 | parser.add_argument('--tst_w_val', action='store_true', help='test with validation') 43 | parser.add_argument('--cuda', action='store_true', help='use CUDA') 44 | parser.add_argument('--gpu', type=str, default='0', help='gpu card ID') 45 | parser.add_argument('--save_path', type=str, default='./saved_models/', help='save model path') 46 | parser.add_argument('--log_name', type=str, default='log', help='the log name') 47 | parser.add_argument('--round', type=int, default=1, help='record the experiment') 48 | parser.add_argument('--seed', type=int, default=2024, help='the random seed') 49 | parser.add_argument('--num_samples', default=100, type=int) 50 | 51 | parser.add_argument('--version', default=None, type=str) 52 | parser.add_argument('--dataset', default=None, type=str) 53 | parser.add_argument('--cross_dataset', default=None, type=str) 54 | parser.add_argument('--maxlen', default=200, type=int) 55 | parser.add_argument('--index', default=0, type=int) 56 | 57 | parser.add_argument('--w_min', type=float, default=0.1, help='the minimum weight for interactions') 58 | parser.add_argument('--w_max', type=float, default=1., help='the maximum weight for interactions') 59 | 60 | # params for the model 61 | parser.add_argument('--time_type', type=str, default='cat', help='cat or add') 62 | parser.add_argument('--dims', type=str, default='[1000]', help='the dims for the DNN') 63 | parser.add_argument('--norm', type=bool, default=False, help='Normalize the input or not') 64 | parser.add_argument('--emb_size', type=int, default=10, help='timestep embedding size') 65 | 66 | # params for diffusion 67 | parser.add_argument('--mean_type', type=str, default='x0', help='MeanType for diffusion: x0, eps') 68 | parser.add_argument('--steps', type=int, default=10, help='diffusion steps') 69 | parser.add_argument('--noise_schedule', type=str, default='linear-var', help='the schedule for noise generating') 70 | parser.add_argument('--noise_scale', type=float, default=0.1, help='noise scale for noise generating') 71 | parser.add_argument('--noise_min', type=float, 
default=0.0001, help='noise lower bound for noise generating') 72 | parser.add_argument('--noise_max', type=float, default=0.02, help='noise upper bound for noise generating') 73 | parser.add_argument('--sampling_noise', type=bool, default=False, help='sampling with noise or not') 74 | parser.add_argument('--sampling_steps', type=int, default=0, help='steps of the forward process during inference') 75 | parser.add_argument('--reweight', type=bool, default=True, help='assign different weight to different timestep or not') 76 | parser.add_argument('--reweight_version', type=str, default='AllOne', help='in AllOne, AllLinear, MinMax') 77 | 78 | parser.add_argument('--lr_decay_rate', default=0.99, type=float) 79 | parser.add_argument('--lrscheduler', default='ExponentialLR', type=str) 80 | args = parser.parse_args() 81 | 82 | # ipdb.set_trace() 83 | random_seed = args.seed 84 | torch.manual_seed(random_seed) # cpu 85 | torch.cuda.manual_seed(random_seed) # gpu 86 | np.random.seed(random_seed) # numpy 87 | random.seed(random_seed) # random and transforms 88 | torch.backends.cudnn.deterministic=True # cudnn 89 | 90 | print("args:", args) 91 | 92 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 93 | device = torch.device("cuda:0" if args.cuda else "cpu") 94 | 95 | print("Starting time: ", time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))) 96 | 97 | result_path = './results_file/' + str(args.dataset) + '/TI_DiffRec/' 98 | print("Save in path:", result_path) 99 | if not os.path.isdir(result_path): 100 | os.makedirs(result_path) 101 | with open(os.path.join(result_path, 'args.txt'), 'w') as f: 102 | f.write('\n'.join([str(k) + ',' + str(v) for k, v in sorted(vars(args).items(), key=lambda x: x[0])])) 103 | # f.close() 104 | 105 | 106 | ### DATA LOAD ### 107 | dataset = data_partition(args.version, args.dataset, args.cross_dataset, args.maxlen) 108 | 109 | [user_train_mix, user_train_source, user_train_target, user_valid_target, user_test_target, user_train_mix_sequence_for_target, user_train_source_sequence_for_target, usernum, itemnum, interval, user_train_ti_mix, user_train_ti_source, user_train_ti_target, user_valid_ti_target, user_test_ti_target] = dataset 110 | 111 | num_batch = math.ceil(len(user_train_source) / args.batch_size) # 908 112 | cc_source = 0.0 113 | cc_target = 0.0 114 | for u in user_train_source: 115 | cc_source = cc_source + len(user_train_source[u]) 116 | cc_target = cc_target + len(user_train_target[u]) 117 | 118 | # Toy_Game----Toy: 8.22 / 12.32 / 20.54 119 | # Toy_Game----Game: 11.73 / 8.36 / 20.10 120 | print('average sequence length in source domain: %.2f' % (cc_source / len(user_train_source))) 121 | print('average sequence length in target domain: %.2f' % (cc_target / len(user_train_source))) 122 | print('average sequence length in both domain: %.2f' % ((cc_source + cc_target) / len(user_train_source))) 123 | 124 | 125 | random_min = 0 126 | random_max = 0 127 | random_source_min = 0 128 | random_source_max = 0 129 | if args.dataset == 'amazon_toy': 130 | random_min = 1 131 | random_max = interval + 1 132 | random_source_min = interval + 1 133 | random_source_max = itemnum + 1 134 | print("The min is {} and the max is {} in amazon_toy".format(random_min, random_max)) 135 | print("The min is {} and the max is {} in source domain".format(random_source_min, random_source_max)) 136 | elif args.dataset == 'douban_book': 137 | random_min = 1 138 | random_max = interval + 1 139 | random_source_min = interval + 1 140 | random_source_max = itemnum + 1 141 | print("The 
min is {} and the max is {} in douban_book".format(random_min, random_max)) 142 | print("The min is {} and the max is {} in source domain".format(random_source_min, random_source_max)) 143 | 144 | sampler = WarpSampler_T_DiffCDR_TI(random_min, random_max, random_source_min, random_source_max, user_train_mix, user_train_source, user_train_target, user_train_ti_mix, user_train_ti_source, user_train_ti_target, usernum, itemnum, batch_size=args.batch_size, w_min = args.w_min, w_max = args.w_max, reweight_version = args.reweight_version, n_workers=3) 145 | 146 | 147 | ### Build Gaussian Diffusion ### 148 | if args.mean_type == 'x0': 149 | mean_type = gd.ModelMeanType.START_X 150 | elif args.mean_type == 'eps': 151 | mean_type = gd.ModelMeanType.EPSILON 152 | else: 153 | raise ValueError("Unimplemented mean type %s" % args.mean_type) 154 | 155 | # ipdb.set_trace() 156 | diffusion = gd.GaussianDiffusion(mean_type, args.noise_schedule, args.noise_scale, args.noise_min, args.noise_max, args.steps, 'cuda').cuda() 157 | 158 | ### Build MLP ### 159 | out_dims = eval(args.dims) + [itemnum+1] # [1000, 94949] 160 | in_dims = out_dims[::-1] # [94949, 1000] 161 | model = DNN(in_dims, out_dims, args.emb_size, time_type="cat", norm=args.norm).cuda() 162 | 163 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 164 | print("models ready.") 165 | if args.lrscheduler == 'Steplr': # 166 | learningrate_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.decay, gamma=args.lr_decay_rate, verbose=True) 167 | elif args.lrscheduler == 'ExponentialLR': # 168 | learningrate_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_decay_rate, last_epoch=-1, verbose=True) 169 | elif args.lrscheduler == 'CosineAnnealingLR': 170 | learningrate_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_epochs, eta_min=0, last_epoch=-1, verbose=True) 171 | 172 | param_num = 0 173 | mlp_num = sum([param.nelement() for param in model.parameters()]) 174 | diff_num = sum([param.nelement() for param in diffusion.parameters()]) # 0 175 | param_num = mlp_num + diff_num 176 | print("Number of all parameters:", param_num) 177 | 178 | best_recall, best_epoch = -100, 0 179 | best_result = None 180 | print("Start training...") 181 | for epoch in range(1, args.epochs + 1): 182 | if epoch - best_epoch >= 10: 183 | print('-'*18) 184 | print('Exiting from training early') 185 | break 186 | 187 | model.train() 188 | start_time = time.time() 189 | 190 | batch_count = 0 191 | total_loss = 0.0 192 | 193 | # ipdb.set_trace() 194 | t_start = time.time() 195 | for step in range(num_batch): # tqdm(range(num_batch), total=num_batch, ncols=70, leave=False, unit='b'): 196 | user, seq_mix, seq_source, seq_target, seq_mix_temp, seq_source_temp, seq_target_temp = sampler.next_batch() 197 | user, seq_mix, seq_source, seq_target, seq_mix_temp, seq_source_temp, seq_target_temp = np.array(user), np.array(seq_mix), np.array(seq_source), np.array(seq_target), np.array(seq_mix_temp), np.array(seq_source_temp), np.array(seq_target_temp) 198 | user, seq_mix, seq_source, seq_target, seq_mix_temp, seq_source_temp, seq_target_temp = torch.tensor(user,device='cuda'), torch.tensor(seq_mix,device='cuda'), torch.tensor(seq_source,device='cuda'), torch.tensor(seq_target,device='cuda'), torch.tensor(seq_mix_temp,device='cuda'), torch.tensor(seq_source_temp,device='cuda'), torch.tensor(seq_target_temp,device='cuda') 199 | batch_count += 1 200 | optimizer.zero_grad() 201 | # 
ipdb.set_trace() 202 | losses = diffusion.training_losses(model, seq_target_temp, args.reweight) 203 | loss = losses["loss"].mean() 204 | total_loss += loss.item() 205 | loss.backward() 206 | optimizer.step() 207 | print(" In epoch {} iteration {}: loss={:.4f}".format(epoch, step, loss.item())) 208 | # ipdb.set_trace() 209 | t_end = time.time() 210 | print("Time interval of one epoch:{:.4f}".format(t_end-t_start)) 211 | # ipdb.set_trace() 212 | learningrate_scheduler.step() 213 | print("The end batch_count is:",batch_count) 214 | print("In epoch {}: loss_mean={:.4f}, lr={}".format(epoch, total_loss/step, learningrate_scheduler.get_last_lr())) 215 | 216 | if epoch % 1 == 0: 217 | model.eval() 218 | # ipdb.set_trace() 219 | t_test = evaluate_T_DiffRec_TI(model, diffusion, dataset, args, random_min, random_max, random_source_min, random_source_max) 220 | print('epoch:%d, epoch_time: %.4f(s): NDCG@1: %.4f, NDCG@5: %.4f, NDCG@10: %.4f, NDCG@20: %.4f, NDCG@50: %.4f, HR@1: %.4f, HR@5: %.4f, HR@10: %.4f, HR@20: %.4f, HR@50: %.4f, AUC: %.4f\n' % (epoch, time.time()-start_time, t_test[0], t_test[1], t_test[2], t_test[3], t_test[4], t_test[5], t_test[6], t_test[7], t_test[8], t_test[9], t_test[10])) 221 | with io.open(result_path + 'test_performance.txt', 'a', encoding='utf-8') as file: 222 | file.write('epoch:%d, epoch_time: %.4f(s), NDCG@1: %.4f, NDCG@5: %.4f, NDCG@10: %.4f, NDCG@20: %.4f, NDCG@50: %.4f, HR@1: %.4f, HR@5: %.4f, HR@10: %.4f, HR@20: %.4f, HR@50: %.4f, AUC: %.4f\n' % (epoch, time.time()-start_time, t_test[0], t_test[1], t_test[2], t_test[3], t_test[4], t_test[5], t_test[6], t_test[7], t_test[8], t_test[9], t_test[10])) 223 | 224 | if t_test[2] > best_recall: # NDCG@10 as selection 225 | best_recall, best_epoch = t_test[2], epoch 226 | best_results = t_test 227 | 228 | torch.save(model, '{}{}_lr{}_wd{}_bs{}_dims{}_emb{}_{}_steps{}_scale{}_min{}_max{}_sample{}_reweight{}_wmin{}_wmax{}_{}.pth' \ 229 | .format(result_path, args.dataset, args.lr, args.weight_decay, args.batch_size, args.dims, args.emb_size, args.mean_type, args.steps, args.noise_scale, args.noise_min, args.noise_max, args.sampling_steps, args.reweight, args.w_min, args.w_max, args.log_name)) 230 | 231 | print("Runing Epoch {:03d} ".format(epoch) + 'train loss {:.4f}'.format(total_loss) + " costs " + time.strftime( 232 | "%H: %M: %S", time.gmtime(time.time()-start_time))) 233 | print('---'*18) 234 | 235 | print('==='*18) 236 | print("End. 
Best Epoch {:03d} ".format(best_epoch)) 237 | print('Best results: epoch:{:d}, NDCG@1: {:.4f}, NDCG@5: {:.4f}, NDCG@10: {:.4f}, NDCG@20: {:.4f}, NDCG@50: {:.4f}, HR@1: {:.4f}, HR@5: {:.4f}, HR@10: {:.4f}, HR@20: {:.4f}, HR@50: {:.4f}, AUC: {:.4f}\n'.format(best_epoch, best_results[0], best_results[1], best_results[2], best_results[3], best_results[4], best_results[5], best_results[6], best_results[7], best_results[8], best_results[9], best_results[10])) 238 | with io.open(result_path + 'test_performance.txt', 'a', encoding='utf-8') as file: 239 | file.write('======================================================\n') 240 | file.write('Best results: epoch:{:d}, NDCG@1: {:.4f}, NDCG@5: {:.4f}, NDCG@10: {:.4f}, NDCG@20: {:.4f}, NDCG@50: {:.4f}, HR@1: {:.4f}, HR@5: {:.4f}, HR@10: {:.4f}, HR@20: {:.4f}, HR@50: {:.4f}, AUC: {:.4f}\n'.format(best_epoch, best_results[0], best_results[1], best_results[2], best_results[3], best_results[4], best_results[5], best_results[6], best_results[7], best_results[8], best_results[9], best_results[10])) 241 | 242 | print("End time: ", time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))) 243 | 244 | 245 | 246 | 247 | 248 | -------------------------------------------------------------------------------- /models/DNN.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | import numpy as np 5 | import math 6 | 7 | class DNN(nn.Module): 8 | """ 9 | A deep neural network for the reverse diffusion process. 10 | """ 11 | def __init__(self, in_dims, out_dims, emb_size, time_type="cat", norm=False, dropout=0.5): 12 | # in_dims: [94949, 1000] 13 | # out_dims: [1000, 94949] 14 | # emb_size: 10 15 | # time_type: 'cat' 16 | # norm: False 17 | # dropout: 0.5 18 | super(DNN, self).__init__() 19 | self.in_dims = in_dims # [49604, 1000] 20 | self.out_dims = out_dims # [1000, 49604] 21 | assert out_dims[0] == in_dims[-1], "In and out dimensions must be equal to each other."
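# e.g., with dims='[1000]' and itemnum+1 items, in_dims == [itemnum+1, 1000] and out_dims == [1000, itemnum+1]; the assert guarantees that the encoder's bottleneck width matches the decoder's input width.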
22 | self.time_type = time_type # 'cat' 23 | self.time_emb_dim = emb_size # 10 24 | self.norm = norm # False 25 | 26 | self.emb_layer = nn.Linear(self.time_emb_dim, self.time_emb_dim) # Linear(in_features=10, out_features=10, bias=True) 27 | 28 | if self.time_type == "cat": 29 | in_dims_temp = [self.in_dims[0] + self.time_emb_dim] + self.in_dims[1:] # [49614, 1000] 30 | else: 31 | raise ValueError("Unimplemented timestep embedding type %s" % self.time_type) 32 | out_dims_temp = self.out_dims # [1000, 49604] 33 | 34 | self.in_layers = nn.ModuleList([nn.Linear(d_in, d_out) \ 35 | for d_in, d_out in zip(in_dims_temp[:-1], in_dims_temp[1:])]) 36 | # ModuleList( 37 | # (0): Linear(in_features=49614, out_features=1000, bias=True) 38 | # ) 39 | self.out_layers = nn.ModuleList([nn.Linear(d_in, d_out) \ 40 | for d_in, d_out in zip(out_dims_temp[:-1], out_dims_temp[1:])]) 41 | # ModuleList( 42 | # (0): Linear(in_features=1000, out_features=49604, bias=True) 43 | # ) 44 | 45 | self.drop = nn.Dropout(dropout) 46 | self.init_weights() 47 | 48 | def init_weights(self): 49 | for layer in self.in_layers: 50 | # Xavier Initialization for weights 51 | size = layer.weight.size() 52 | fan_out = size[0] 53 | fan_in = size[1] 54 | std = np.sqrt(2.0 / (fan_in + fan_out)) 55 | layer.weight.data.normal_(0.0, std) 56 | 57 | # Normal Initialization for weights 58 | layer.bias.data.normal_(0.0, 0.001) 59 | 60 | for layer in self.out_layers: 61 | # Xavier Initialization for weights 62 | size = layer.weight.size() 63 | fan_out = size[0] 64 | fan_in = size[1] 65 | std = np.sqrt(2.0 / (fan_in + fan_out)) 66 | layer.weight.data.normal_(0.0, std) 67 | 68 | # Normal Initialization for weights 69 | layer.bias.data.normal_(0.0, 0.001) 70 | 71 | size = self.emb_layer.weight.size() 72 | fan_out = size[0] 73 | fan_in = size[1] 74 | std = np.sqrt(2.0 / (fan_in + fan_out)) 75 | self.emb_layer.weight.data.normal_(0.0, std) 76 | self.emb_layer.bias.data.normal_(0.0, 0.001) 77 | 78 | def forward(self, x, timesteps): 79 | time_emb = timestep_embedding(timesteps, self.time_emb_dim).to(x.device) # torch.Size([400, 10]) 80 | emb = self.emb_layer(time_emb) # torch.Size([400, 10])----linear project 81 | if self.norm: 82 | x = F.normalize(x) 83 | x = self.drop(x) # torch.Size([400, 94949]) 84 | h = torch.cat([x, emb], dim=-1) # torch.Size([400, 94959]) 85 | for i, layer in enumerate(self.in_layers): 86 | h = layer(h) 87 | h = torch.tanh(h) # torch.Size([400, 1000]) 88 | 89 | for i, layer in enumerate(self.out_layers): 90 | h = layer(h) 91 | if i != len(self.out_layers) - 1: 92 | h = torch.tanh(h) 93 | 94 | return h # torch.Size([400, 94949]) 95 | 96 | 97 | def timestep_embedding(timesteps, dim, max_period=10000): 98 | """ 99 | Create sinusoidal timestep embeddings. 100 | 101 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 102 | These may be fractional. 103 | :param dim: the dimension of the output. 104 | :param max_period: controls the minimum frequency of the embeddings. 105 | :return: an [N x dim] Tensor of positional embeddings. 
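Example: with dim=4 and max_period=10000, freqs = exp(-ln(10000) * [0, 1] / 2) = [1.0, 0.01], so timestep t=2 is embedded as [cos(2.0), cos(0.02), sin(2.0), sin(0.02)].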
106 | """ 107 | 108 | half = dim // 2 # 5 109 | freqs = torch.exp( 110 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 111 | ).to(timesteps.device) # torch.Size([5]) 112 | args = timesteps[:, None].float() * freqs[None] # torch.Size([400, 5]) 113 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) # torch.Size([400, 10]) 114 | if dim % 2: 115 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) # torch.Size([400, 10]) 116 | return embedding 117 | -------------------------------------------------------------------------------- /models/gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import math 3 | import numpy as np 4 | import torch as th 5 | import torch.nn.functional as F 6 | import torch.nn as nn 7 | import ipdb 8 | class ModelMeanType(enum.Enum): 9 | START_X = enum.auto() # the model predicts x_0 10 | EPSILON = enum.auto() # the model predicts epsilon 11 | 12 | class GaussianDiffusion(nn.Module): 13 | def __init__(self, mean_type, noise_schedule, noise_scale, noise_min, noise_max,\ 14 | steps, device, history_num_per_term=10, beta_fixed=True): 15 | self.mean_type = mean_type # 16 | self.noise_schedule = noise_schedule # 'linear-var' 17 | self.noise_scale = noise_scale # 0.0005 18 | self.noise_min = noise_min # 0.001 19 | self.noise_max = noise_max # 0.005 20 | self.steps = steps # 10 21 | self.device = device # self.device 22 | 23 | self.history_num_per_term = history_num_per_term # 10 24 | self.Lt_history = th.zeros(steps, history_num_per_term, dtype=th.float64).to(device) # torch.Size([10, 10]) 25 | self.Lt_count = th.zeros(steps, dtype=int).to(device) # torch.Size([10]) 26 | 27 | if noise_scale != 0.: 28 | self.betas = th.tensor(self.get_betas(), dtype=th.float64).to(self.device) # torch.Size([10]) 29 | if beta_fixed: 30 | self.betas[0] = 0.00001 # Deep Unsupervised Learning using Noneequilibrium Thermodynamics 2.4.1 31 | # The variance \beta_1 of the first step is fixed to a small constant to prevent overfitting. 32 | assert len(self.betas.shape) == 1, "betas must be 1-D" 33 | assert len(self.betas) == self.steps, "num of betas must equal to diffusion steps" 34 | assert (self.betas > 0).all() and (self.betas <= 1).all(), "betas out of range" 35 | self.calculate_for_diffusion() 36 | 37 | super(GaussianDiffusion, self).__init__() 38 | 39 | def get_betas(self): 40 | """ 41 | Given the schedule name, create the betas for the diffusion process. 
42 | """ 43 | # 线性的加噪方案,DDPM的加噪方案 44 | if self.noise_schedule == "linear" or self.noise_schedule == "linear-var": 45 | start = self.noise_scale * self.noise_min 46 | end = self.noise_scale * self.noise_max 47 | if self.noise_schedule == "linear": 48 | return np.linspace(start, end, self.steps, dtype=np.float64) 49 | else: 50 | return betas_from_linear_variance(self.steps, np.linspace(start, end, self.steps, dtype=np.float64)) 51 | elif self.noise_schedule == "cosine": 52 | return betas_for_alpha_bar( 53 | self.steps, 54 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 55 | ) 56 | elif self.noise_schedule == "binomial": # Deep Unsupervised Learning using Noneequilibrium Thermodynamics 2.4.1 57 | ts = np.arange(self.steps) 58 | betas = [1 / (self.steps - t + 1) for t in ts] 59 | return betas 60 | else: 61 | raise NotImplementedError(f"unknown beta schedule: {self.noise_schedule}!") 62 | 63 | def calculate_for_diffusion(self): 64 | alphas = 1.0 - self.betas # α torch.Size([10]) 65 | self.alphas_cumprod = th.cumprod(alphas, axis=0) # $\bar{\alpha}_t$ 66 | self.alphas_cumprod_prev = th.cat([th.tensor([1.0]).to(self.device), self.alphas_cumprod[:-1]]).to(self.device) # $\bar{\alpha}_{t-1}$ 1.0 and alpha[1:step] 67 | self.alphas_cumprod_next = th.cat([self.alphas_cumprod[1:], th.tensor([0.0]).to(self.device)]).to(self.device) # $\bar{\alpha}_{t+1}$ alpha[0:step-1] and 0.0 68 | assert self.alphas_cumprod_prev.shape == (self.steps,) 69 | 70 | # calculations for diffusion q(x_t | x_{t-1}) and others 71 | self.sqrt_alphas_cumprod = th.sqrt(self.alphas_cumprod) # torch.Size([10]) 72 | self.sqrt_one_minus_alphas_cumprod = th.sqrt((1.0 - self.alphas_cumprod)) # torch.Size([10]) 73 | 74 | self.log_one_minus_alphas_cumprod = th.log(1.0 - self.alphas_cumprod) # torch.Size([10]) 75 | self.sqrt_recip_alphas_cumprod = th.sqrt((1.0 / self.alphas_cumprod)) # torch.Size([10]) 76 | self.sqrt_recipm1_alphas_cumprod = th.sqrt((1.0 / self.alphas_cumprod - 1)) # torch.Size([10]) 77 | 78 | # calculations for posterior q(x_{t-1} | x_t, x_0) ---- equation 10 79 | self.posterior_variance = ( 80 | self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 81 | ) # torch.Size([10]) 82 | 83 | # log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain. 84 | # 后验分布方差在扩散模型开始处为0,计算对视时需要进行截断,就是用t=1时的值替代t=0时刻的值 85 | self.posterior_log_variance_clipped = th.log( 86 | th.cat([self.posterior_variance[1].unsqueeze(0), self.posterior_variance[1:]]) 87 | ) # torch.Size([10]) 88 | 89 | # 后验分布计算均值公式的两个系数,对应于论文中公式11 90 | self.posterior_mean_coef1 = ( 91 | self.betas * th.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 92 | ) # torch.Size([10]) 93 | self.posterior_mean_coef2 = ( 94 | (1.0 - self.alphas_cumprod_prev) * th.sqrt(alphas) / (1.0 - self.alphas_cumprod) 95 | ) # torch.Size([10]) 96 | 97 | # 从q(x_t | x_0)中采样图像 98 | def p_sample(self, model, x_start, steps, sampling_noise=False): 99 | # model 100 | # x_start: torch.Size([400, 94949]) 101 | # steps 102 | # sampling_noise 103 | assert steps <= self.steps, "Too much steps in inference." 
104 | if steps == 0: 105 | x_t = x_start # ---- 106 | else: 107 | t = th.tensor([steps - 1] * x_start.shape[0]).to(x_start.device) 108 | x_t = self.q_sample(x_start, t) 109 | 110 | indices = list(range(self.steps))[::-1] # ---- 111 | 112 | if self.noise_scale == 0.: 113 | for i in indices: 114 | t = th.tensor([i] * x_t.shape[0]).to(x_start.device) 115 | x_t = model(x_t, t) 116 | return x_t 117 | 118 | # Reverse step by step 119 | for i in indices: 120 | t = th.tensor([i] * x_t.shape[0]).to(x_start.device) # ---- 121 | out = self.p_mean_variance(model, x_t, t) # ---- 122 | if sampling_noise: 123 | noise = th.randn_like(x_t) 124 | nonzero_mask = ( 125 | (t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1))) 126 | ) # no noise when t == 0 127 | x_t = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise 128 | else: 129 | x_t = out["mean"] # ---- 130 | return x_t # predicted x_0 131 | 132 | def training_losses(self, model, x_start, reweight=False): 133 | batch_size, device = x_start.size(0), x_start.device # 400, device(type='cuda', index=0) 134 | ts, pt = self.sample_timesteps(batch_size, device, 'importance') # torch.Size([400]), torch.Size([400]) 135 | noise = th.randn_like(x_start) # torch.Size([400, 94949]) 136 | if self.noise_scale != 0.: 137 | x_t = self.q_sample(x_start, ts, noise) 138 | else: 139 | x_t = x_start 140 | 141 | terms = {} 142 | model_output = model(x_t, ts) # torch.Size([400, 49604]) 143 | target = { 144 | ModelMeanType.START_X: x_start, 145 | ModelMeanType.EPSILON: noise, 146 | }[self.mean_type] 147 | 148 | assert model_output.shape == target.shape == x_start.shape 149 | 150 | mse = mean_flat((target - model_output) ** 2) 151 | 152 | if reweight == True: 153 | if self.mean_type == ModelMeanType.START_X: 154 | weight = self.SNR(ts - 1) - self.SNR(ts) # torch.Size([400]) 155 | weight = th.where((ts == 0), 1.0, weight) # torch.Size([400]) 156 | loss = mse 157 | elif self.mean_type == ModelMeanType.EPSILON: 158 | weight = (1 - self.alphas_cumprod[ts]) / ((1-self.alphas_cumprod_prev[ts])**2 * (1-self.betas[ts])) 159 | weight = th.where((ts == 0), 1.0, weight) 160 | likelihood = mean_flat((x_start - self._predict_xstart_from_eps(x_t, ts, model_output))**2 / 2.0) 161 | loss = th.where((ts == 0), likelihood, mse) 162 | else: 163 | weight = th.tensor([1.0] * len(target)).to(device) 164 | 165 | terms["loss"] = weight * loss 166 | 167 | # update Lt_history & Lt_count 168 | for t, loss in zip(ts, terms["loss"]): 169 | if self.Lt_count[t] == self.history_num_per_term: 170 | Lt_history_old = self.Lt_history.clone() 171 | self.Lt_history[t, :-1] = Lt_history_old[t, 1:] 172 | self.Lt_history[t, -1] = loss.detach() 173 | else: 174 | try: 175 | self.Lt_history[t, self.Lt_count[t]] = loss.detach() 176 | self.Lt_count[t] += 1 177 | except: 178 | print(t) 179 | print(self.Lt_count[t]) 180 | print(loss) 181 | raise ValueError 182 | 183 | terms["loss"] /= pt 184 | return terms 185 | 186 | def sample_timesteps(self, batch_size, device, method='uniform', uniform_prob=0.001): 187 | if method == 'importance': # importance sampling 188 | if not (self.Lt_count == self.history_num_per_term).all(): 189 | return self.sample_timesteps(batch_size, device, method='uniform') 190 | 191 | Lt_sqrt = th.sqrt(th.mean(self.Lt_history ** 2, axis=-1)) 192 | pt_all = Lt_sqrt / th.sum(Lt_sqrt) 193 | pt_all *= 1- uniform_prob 194 | pt_all += uniform_prob / len(pt_all) 195 | 196 | assert pt_all.sum(-1) - 1. 
< 1e-5 197 | 198 | t = th.multinomial(pt_all, num_samples=batch_size, replacement=True) 199 | pt = pt_all.gather(dim=0, index=t) * len(pt_all) 200 | 201 | return t, pt 202 | 203 | elif method == 'uniform': # uniform sampling 204 | t = th.randint(0, self.steps, (batch_size,), device=device).long() 205 | pt = th.ones_like(t).float() 206 | 207 | return t, pt 208 | 209 | else: 210 | raise ValueError 211 | 212 | # 从q(x_t | x_0)中采样图像 213 | def q_sample(self, x_start, t, noise=None): 214 | if noise is None: # 如果没有传入噪声 215 | noise = th.randn_like(x_start) # # 从标准分布中随机采样一个与x_0大小一致的噪音 torch.Size([400, 94949]) 216 | assert noise.shape == x_start.shape 217 | return ( 218 | self._extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 219 | + self._extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) 220 | * noise # 直接用公式9进行重参数采样得到x_t 221 | ) 222 | 223 | # 完整对应论文中的公式10和11,计算后验分布的均值和方差 224 | def q_posterior_mean_variance(self, x_start, x_t, t): 225 | """ 226 | Compute the mean and variance of the diffusion posterior: 227 | q(x_{t-1} | x_t, x_0) 228 | """ 229 | # _extract_into_tensor函数是把sqrt_alphas_cumprod中的第t个元素取出,与x_0相乘得到均值 230 | assert x_start.shape == x_t.shape 231 | posterior_mean = ( 232 | self._extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start 233 | + self._extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 234 | ) 235 | posterior_variance = self._extract_into_tensor(self.posterior_variance, t, x_t.shape) 236 | posterior_log_variance_clipped = self._extract_into_tensor( 237 | self.posterior_log_variance_clipped, t, x_t.shape 238 | ) 239 | assert ( 240 | posterior_mean.shape[0] 241 | == posterior_variance.shape[0] 242 | == posterior_log_variance_clipped.shape[0] 243 | == x_start.shape[0] 244 | ) 245 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 246 | 247 | # 通过模型(Unet),基于x_t预测x_{t-1}的均值与方差;即逆扩散过程的均值和方差,也会预测x_0 248 | def p_mean_variance(self, model, x, t): 249 | """ 250 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of 251 | the initial x, x_0. 
252 | """ 253 | B, C = x.shape[:2] # # batch_size(400), channel_nums (94949) 254 | assert t.shape == (B, ) # 一个batch中每个图片输入都对应一个时间步t,故t的size为(batch_size,) 255 | # 虽然Unet输出的尺寸一样,但模型训练预测的目标不同,输出数据表示的含义不同 256 | model_output = model(x, t) # torch.Size([400, 94949]) 257 | 258 | model_variance = self.posterior_variance # torch.Size([10]) 259 | model_log_variance = self.posterior_log_variance_clipped # torch.Size([10]) 260 | 261 | model_variance = self._extract_into_tensor(model_variance, t, x.shape) # torch.Size([400, 94949]) 262 | model_log_variance = self._extract_into_tensor(model_log_variance, t, x.shape) # torch.Size([400, 94949]) 263 | 264 | if self.mean_type == ModelMeanType.START_X: 265 | pred_xstart = model_output 266 | elif self.mean_type == ModelMeanType.EPSILON: 267 | pred_xstart = self._predict_xstart_from_eps(x, t, eps=model_output) 268 | else: 269 | raise NotImplementedError(self.mean_type) 270 | 271 | model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) 272 | 273 | assert ( 274 | model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape 275 | ) 276 | 277 | return { 278 | "mean": model_mean, 279 | "variance": model_variance, 280 | "log_variance": model_log_variance, 281 | "pred_xstart": pred_xstart, 282 | } 283 | 284 | # 基于论文中的公式11,将公式转换以下就能基于均值μ和x_t求x_0;参数中的xprev就是Unet模型预测的均值 285 | def _predict_xstart_from_eps(self, x_t, t, eps): 286 | assert x_t.shape == eps.shape 287 | return ( 288 | self._extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 289 | - self._extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 290 | ) 291 | 292 | def SNR(self, t): 293 | """ 294 | Compute the signal-to-noise ratio for a single timestep. 295 | """ 296 | self.alphas_cumprod = self.alphas_cumprod.to(t.device) 297 | return self.alphas_cumprod[t] / (1 - self.alphas_cumprod[t]) 298 | 299 | 300 | 301 | def _extract_into_tensor(self, arr, timesteps, broadcast_shape): 302 | """ 303 | Extract values from a 1-D numpy array for a batch of indices. 304 | 305 | :param arr: the 1-D numpy array. 306 | :param timesteps: a tensor of indices into the array to extract. 307 | :param broadcast_shape: a larger shape of K dimensions with the batch 308 | dimension equal to the length of timesteps. 309 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 310 | """ 311 | # res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 312 | arr = arr.to(timesteps.device) # torch.Size([10]) 313 | res = arr[timesteps].float() # torch.Size([400]) 314 | while len(res.shape) < len(broadcast_shape): 315 | res = res[..., None] # torch.Size([400, 1]) 316 | return res.expand(broadcast_shape) # torch.Size([400, 94949]) 317 | 318 | def betas_from_linear_variance(steps, variance, max_beta=0.999): 319 | # steps: 10 320 | # variance: np.linspace from the start(5e-7) to the end(2.5e-6) with setps(10) 321 | # max_beta=0.999 322 | # ipdb.set_trace() 323 | alpha_bar = 1 - variance # (10,) 324 | betas = [] 325 | betas.append(1 - alpha_bar[0]) 326 | for i in range(1, steps): 327 | betas.append(min(1 - alpha_bar[i] / alpha_bar[i - 1], max_beta)) 328 | return np.array(betas) # (10,) 329 | 330 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 331 | """ 332 | Create a beta schedule that discretizes the given alpha_t_bar function, 333 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 334 | 335 | :param num_diffusion_timesteps: the number of betas to produce. 
336 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 337 | produces the cumulative product of (1-beta) up to that 338 | part of the diffusion process. 339 | :param max_beta: the maximum beta to use; use values lower than 1 to 340 | prevent singularities. 341 | """ 342 | betas = [] 343 | for i in range(num_diffusion_timesteps): 344 | t1 = i / num_diffusion_timesteps 345 | t2 = (i + 1) / num_diffusion_timesteps 346 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 347 | return np.array(betas) 348 | 349 | def normal_kl(mean1, logvar1, mean2, logvar2): 350 | """ 351 | Compute the KL divergence between two gaussians. 352 | 353 | Shapes are automatically broadcasted, so batches can be compared to 354 | scalars, among other use cases. 355 | """ 356 | tensor = None 357 | for obj in (mean1, logvar1, mean2, logvar2): 358 | if isinstance(obj, th.Tensor): 359 | tensor = obj 360 | break 361 | assert tensor is not None, "at least one argument must be a Tensor" 362 | 363 | # Force variances to be Tensors. Broadcasting helps convert scalars to 364 | # Tensors, but it does not work for th.exp(). 365 | logvar1, logvar2 = [ 366 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 367 | for x in (logvar1, logvar2) 368 | ] 369 | 370 | return 0.5 * ( 371 | -1.0 372 | + logvar2 373 | - logvar1 374 | + th.exp(logvar1 - logvar2) 375 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 376 | ) 377 | 378 | def mean_flat(tensor): 379 | """ 380 | Take the mean over all non-batch dimensions. 381 | """ 382 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 383 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import ipdb 4 | import torch.nn.functional as F 5 | from torch import Tensor 6 | import math 7 | import os 8 | import io 9 | import copy 10 | import time 11 | import random 12 | import copy 13 | class EarlyStopping_onetower: 14 | """Early stops the training if validation loss doesn't improve after a given patience.""" 15 | def __init__(self, args, verbose=True, delta=0): 16 | """ 17 | Args: 18 | patience (int): How long to wait after last time validation loss improved. 19 | Default: 5 20 | verbose (bool): If True, prints a message for each validation loss improvement. 21 | Default: False 22 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 
23 | Default: 0 24 | """ 25 | self.patience = args.patience 26 | self.verbose = verbose 27 | self.counter = 0 28 | self.epoch = 100 29 | self.best_performance = None 30 | self.early_stop = False 31 | self.ndcg_max = None 32 | self.save_epoch = None 33 | self.delta = delta 34 | self.version = args.version 35 | self.dataset = args.dataset 36 | self.base_model = args.base_model 37 | 38 | def __call__(self, epoch, model, result_path, t_test): 39 | 40 | if self.ndcg_max is None: 41 | self.ndcg_max = t_test[2] 42 | self.best_performance = t_test 43 | self.save_epoch = epoch 44 | self.save_checkpoint(epoch, model, result_path, t_test) 45 | elif t_test[2] < self.ndcg_max: 46 | self.counter += 1 47 | print(f'In the epoch: {epoch}, EarlyStopping counter: {self.counter} out of {self.patience}') 48 | if self.counter >= self.patience and epoch>=self.epoch: 49 | self.early_stop = True 50 | else: 51 | self.best_performance = t_test 52 | self.save_epoch = epoch 53 | self.save_checkpoint(epoch, model, result_path, t_test) 54 | self.counter = 0 55 | 56 | def save_checkpoint(self, epoch, model, result_path, t_test): 57 | print(f'Validation loss in {epoch} decreased {self.ndcg_max:.4f} --> {t_test[2]:.4f}. Saving model ...\n') 58 | with io.open(result_path + 'save_model.txt', 'a', encoding='utf-8') as file: 59 | file.write("NDCG@10 in epoch {} decreased {:.4f} --> {:.4f}, the HR@10 is {:.4f}, the AUC is {:.4f}, the loss_rec is {:.4f}. Saving model...\n".format(epoch, self.ndcg_max, t_test[2], t_test[7], t_test[10], t_test[11])) 60 | torch.save(model.state_dict(), os.path.join(result_path, 'checkpoint.pt')) 61 | self.ndcg_max = t_test[2] 62 | 63 | class PointWiseFeedForward(torch.nn.Module): 64 | def __init__(self, hidden_units, dropout_rate): 65 | 66 | super(PointWiseFeedForward, self).__init__() 67 | 68 | self.conv1 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1) 69 | self.dropout1 = torch.nn.Dropout(p=dropout_rate) 70 | self.relu = torch.nn.ReLU() 71 | self.conv2 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1) 72 | self.dropout2 = torch.nn.Dropout(p=dropout_rate) 73 | 74 | def forward(self, inputs): 75 | outputs = self.dropout2(self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2)))))) 76 | outputs = outputs.transpose(-1, -2) # as Conv1D requires (N, C, Length) 77 | outputs += inputs 78 | return outputs 79 | 80 | class SASRec_Embedding(torch.nn.Module): 81 | def __init__(self, item_num, args): 82 | super(SASRec_Embedding, self).__init__() 83 | 84 | self.item_num = item_num # 3416 85 | self.dev = args.device #'cuda' 86 | 87 | # TODO: loss += args.l2_emb for regularizing embedding vectors during training 88 | # https://stackoverflow.com/questions/42704283/adding-l1-l2-regularization-in-pytorch 89 | self.item_emb = torch.nn.Embedding(self.item_num+1, args.hidden_units, padding_idx=0) #Embedding(3417, 50, padding_idx=0) 90 | self.pos_emb = torch.nn.Embedding(args.maxlen, args.hidden_units) # TO IMPROVE Embedding(200, 50) 91 | self.emb_dropout = torch.nn.Dropout(p=args.dropout_rate) #Dropout(p=0.2) 92 | 93 | self.attention_layernorms = torch.nn.ModuleList() # 2 layers of LayerNorm 94 | self.attention_layers = torch.nn.ModuleList() # 2 layers of MultiheadAttention 95 | self.forward_layernorms = torch.nn.ModuleList() # 2 layers of LayerNorm 96 | self.forward_layers = torch.nn.ModuleList() # 2 layers of PointWiseFeedForward 97 | 98 | self.last_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8) # LayerNorm(torch.Size([50]), eps=1e-08, elementwise_affine=True) 99 | 
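        # The loop below stacks `args.num_blocks` identical encoder blocks, each made of a
        # pre-attention LayerNorm, multi-head self-attention, a second LayerNorm and a
        # point-wise feed-forward layer (the standard SASRec encoder). Illustrative
        # construction, kept as comments; the concrete `args` values are assumptions,
        # only the field names are the ones actually read by this class:
        #
        #   args = argparse.Namespace(device='cuda', hidden_units=64, maxlen=200,
        #                             dropout_rate=0.2, num_blocks=2, num_heads=1)
        #   encoder = SASRec_Embedding(item_num, args)
        #   log_feats = encoder(log_seqs)   # (batch, maxlen, hidden_units)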
100 | for _ in range(args.num_blocks): 101 | new_attn_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8) #LayerNorm(torch.Size([50]), eps=1e-08, elementwise_affine=True) 102 | self.attention_layernorms.append(new_attn_layernorm) 103 | 104 | new_attn_layer = torch.nn.MultiheadAttention(args.hidden_units, 105 | args.num_heads, 106 | args.dropout_rate, batch_first=True) # MultiheadAttention((out_proj): NonDynamicallyQuantizableLinear(in_features=50, out_features=50, bias=True)) 107 | self.attention_layers.append(new_attn_layer) 108 | 109 | new_fwd_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8) # LayerNorm((50,), eps=1e-08, elementwise_affine=True) 110 | self.forward_layernorms.append(new_fwd_layernorm) 111 | 112 | new_fwd_layer = PointWiseFeedForward(args.hidden_units, args.dropout_rate) 113 | self.forward_layers.append(new_fwd_layer) 114 | 115 | def log2feats(self, log_seqs): 116 | # tt0 = time.time() 117 | seqs = self.item_emb(log_seqs) 118 | seqs *= self.item_emb.embedding_dim ** 0.5 # torch.Size([128, 200, 64]) 119 | positions = torch.tile(torch.arange(0,log_seqs.shape[1]), [log_seqs.shape[0],1]).cuda() # torch.Size([128, 200]) 120 | # add the position embedding 121 | seqs += self.pos_emb(positions) 122 | seqs = self.emb_dropout(seqs) # torch.Size([128, 200, 64]) 123 | 124 | # mask the noninteracted position 125 | timeline_mask = torch.BoolTensor(log_seqs.cpu() == 0).cuda() # (128,200) 126 | seqs *= ~timeline_mask.unsqueeze(-1) # broadcast in last dim 127 | 128 | tl = seqs.shape[1] # time dim len for enforce causality, 200 129 | attention_mask = ~torch.tril(torch.ones((tl, tl), dtype=torch.bool, device='cuda')) #(200,200) 130 | 131 | for i in range(len(self.attention_layers)): 132 | Q = self.attention_layernorms[i](seqs) #torch.Size([128, 200, 50]) 133 | mha_outputs, _ = self.attention_layers[i](Q, seqs, seqs, attn_mask=attention_mask) # torch.Size([128, 200, 50]) 134 | # key_padding_mask=timeline_mask 135 | # need_weights=False) this arg do not work? 
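            # Residual connection: the layer-normalized query Q is added back onto the
            # attention output (pre-norm residual), then the second LayerNorm and the
            # point-wise feed-forward block are applied, and padded positions are
            # re-masked before the next block.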
136 | seqs = Q + mha_outputs # torch.Size([128, 200, 50]) 137 | 138 | seqs = self.forward_layernorms[i](seqs) # torch.Size([128, 200, 50]) 139 | seqs = self.forward_layers[i](seqs) # torch.Size([128, 200, 50]) 140 | seqs *= ~timeline_mask.unsqueeze(-1) # torch.Size([128, 200, 50]) 141 | 142 | log_feats = self.last_layernorm(seqs) # (U, T, C) -> (U, -1, C) 143 | 144 | return log_feats 145 | 146 | def forward(self, log_seqs): # for training 147 | log_feats = self.log2feats(log_seqs) # torch.Size([128, 200, 50]) user_ids hasn't been used yet 148 | 149 | return log_feats # pos_pred, neg_pred 150 | 151 | 152 | 153 | class GRU4Rec_withNeg_Dist(torch.nn.Module): 154 | def __init__(self, user_num, item_num, args): 155 | super(GRU4Rec_withNeg_Dist, self).__init__() 156 | 157 | self.source_item_emb = torch.nn.Embedding(item_num+1, args.hidden_units, padding_idx=0) #Embedding(3417, 50, padding_idx=0) 158 | self.target_item_emb = torch.nn.Embedding(item_num+1, args.hidden_units, padding_idx=0) #Embedding(3417, 50, padding_idx=0) 159 | 160 | self.gru_source = torch.nn.GRU(args.hidden_units, args.hidden_units, batch_first=True) 161 | self.gru_target = torch.nn.GRU(args.hidden_units, args.hidden_units, batch_first=True) 162 | 163 | self.h0_source = torch.nn.Parameter(torch.zeros((1, 1, args.hidden_units), requires_grad=True)) 164 | self.h0_target = torch.nn.Parameter(torch.zeros((1, 1, args.hidden_units), requires_grad=True)) 165 | 166 | self.dev = args.device #'cuda' 167 | 168 | self.leakyrelu = torch.nn.LeakyReLU() 169 | self.relu = torch.nn.ReLU() 170 | 171 | self.temperature = args.temperature 172 | self.fname = args.dataset 173 | self.dropout = torch.nn.Dropout(p=args.dropout_rate) 174 | 175 | def forward(self, user_ids, log_seqs, pos_seqs, neg_list, log_seqs_all, soft_diff): # for training 176 | # user_ids:(128,) 177 | # log_seqs:(128, 200) 178 | # pos_seqs:(128, 200) 179 | # neg_seqs:(128, 200) 180 | # ipdb.set_trace() 181 | neg_embs = [] 182 | neg_logits = [] 183 | source_log_embedding = self.source_item_emb(log_seqs) 184 | source_log_feats, _ = self.gru_source(source_log_embedding, self.h0_source.tile(1,source_log_embedding.shape[0],1)) #2,121,100 185 | source_log_all_embedding = self.source_item_emb(log_seqs_all) 186 | source_log_all_feats, _ = self.gru_source(source_log_all_embedding, self.h0_source.tile(1,source_log_all_embedding.shape[0],1)) #2,121,100 187 | pos_embs = self.source_item_emb(pos_seqs) # torch.Size([128, 200, 50]) 188 | soft_embs = self.source_item_emb(soft_diff) # torch.Size([128, 200, 50]) 189 | for i in range(0,len(neg_list)): 190 | neg_embs.append(self.source_item_emb(neg_list[i])) # torch.Size([128, 200, 50]) 191 | 192 | # get the l2 norm for the target domain recommendation 193 | source_log_feats_l2norm = torch.nn.functional.normalize(source_log_feats, p=2, dim=-1) 194 | pos_embs_l2norm = torch.nn.functional.normalize(pos_embs, p=2, dim=-1) 195 | pos_logits = (source_log_feats_l2norm * pos_embs_l2norm).sum(dim=-1) # torch.Size([128, 200]) 196 | pos_logits = pos_logits * self.temperature 197 | 198 | for i in range(0,len(neg_list)): 199 | neg_embs_l2norm_i = torch.nn.functional.normalize(neg_embs[i], p=2, dim=-1) 200 | neg_logits_i = (source_log_feats_l2norm * neg_embs_l2norm_i).sum(dim=-1) # torch.Size([128, 200]) 201 | neg_logits_i = neg_logits_i * self.temperature 202 | neg_logits.append(neg_logits_i) 203 | 204 | source_log_all_feats_l2norm = torch.nn.functional.normalize(source_log_all_feats, p=2, dim=-1) 205 | soft_embs_l2norm = torch.nn.functional.normalize(soft_embs, 
p=2, dim=-1) 206 | soft_logits = (source_log_all_feats_l2norm[:,-1,:].unsqueeze(1).expand(-1,soft_embs_l2norm.shape[1],-1) * soft_embs_l2norm).sum(dim=-1) # torch.Size([128, 200]) 207 | soft_logits = soft_logits * self.temperature 208 | 209 | return pos_logits, neg_logits, soft_logits # pos_pred, neg_pred 210 | 211 | 212 | def predict(self, user_ids, log_seqs, item_indices): # for inference 213 | # user_ids: (1,) 214 | # log_seqs: (1, 200) 215 | # item_indices: (101,) 216 | # ipdb.set_trace() 217 | source_log_embedding = self.source_item_emb(log_seqs) 218 | source_log_feats, _ = self.gru_source(source_log_embedding, self.h0_source.tile(1,source_log_embedding.shape[0],1)) #2,121,100 219 | 220 | item_embs = self.source_item_emb(item_indices) 221 | # get the l2 norm for the target domain recommendation 222 | final_feat = source_log_feats[:, -1, :] # torch.Size([1, 50]) 223 | final_feat_l2norm = torch.nn.functional.normalize(final_feat, p=2, dim=-1) 224 | item_embs_l2norm = torch.nn.functional.normalize(item_embs, p=2, dim=-1) 225 | 226 | logits = item_embs_l2norm.matmul(final_feat_l2norm.unsqueeze(-1)).squeeze(-1) 227 | logits = logits * self.temperature 228 | 229 | 230 | return logits # preds # (U, I) 231 | 232 | 233 | 234 | 235 | class SASRec_V1_withNeg_Dist(torch.nn.Module): 236 | def __init__(self, user_num, item_num, args): 237 | super(SASRec_V1_withNeg_Dist, self).__init__() 238 | 239 | self.sasrec_embedding_source = SASRec_Embedding(item_num, args) 240 | self.sasrec_embedding_target = SASRec_Embedding(item_num, args) 241 | self.dev = args.device #'cuda' 242 | 243 | 244 | self.leakyrelu = torch.nn.LeakyReLU() 245 | self.relu = torch.nn.ReLU() 246 | 247 | self.temperature = args.temperature 248 | self.fname = args.dataset 249 | self.dropout = torch.nn.Dropout(p=args.dropout_rate) 250 | 251 | 252 | 253 | def forward(self, user_ids, log_seqs, pos_seqs, neg_list, log_seqs_all, soft_diff): # for training 254 | neg_embs = [] 255 | # ipdb.set_trace() 256 | source_log_feats = self.sasrec_embedding_source(log_seqs) # torch.Size([128, 200, 50]) 257 | source_log_all_feats = self.sasrec_embedding_source(log_seqs_all) # torch.Size([128, 200, 50]) 258 | pos_embs = self.sasrec_embedding_source.item_emb(pos_seqs) # torch.Size([128, 200, 50]) 259 | soft_embs = self.sasrec_embedding_source.item_emb(soft_diff) # torch.Size([128, 200, 50]) 260 | for i in range(0,len(neg_list)): 261 | neg_embs.append(self.sasrec_embedding_source.item_emb(neg_list[i])) 262 | 263 | # get the l2 norm for the target domain recommendation 264 | source_log_feats_l2norm = torch.nn.functional.normalize(source_log_feats, p=2, dim=-1) 265 | pos_embs_l2norm = torch.nn.functional.normalize(pos_embs, p=2, dim=-1) 266 | pos_logits = (source_log_feats_l2norm * pos_embs_l2norm).sum(dim=-1) # torch.Size([128, 200]) 267 | pos_logits = pos_logits * self.temperature 268 | 269 | neg_logits = [] 270 | for i in range(0,len(neg_list)): 271 | neg_embs_l2norm_i = torch.nn.functional.normalize(neg_embs[i], p=2, dim=-1) 272 | neg_logits_i = (source_log_feats_l2norm * neg_embs_l2norm_i).sum(dim=-1) # torch.Size([128, 200]) 273 | neg_logits_i = neg_logits_i * self.temperature 274 | neg_logits.append(neg_logits_i) 275 | 276 | source_log_all_feats_l2norm = torch.nn.functional.normalize(source_log_all_feats, p=2, dim=-1) 277 | soft_embs_l2norm = torch.nn.functional.normalize(soft_embs, p=2, dim=-1) 278 | soft_logits = (source_log_all_feats_l2norm[:,-1,:].unsqueeze(1).expand(-1,soft_embs_l2norm.shape[1],-1) * soft_embs_l2norm).sum(dim=-1) # torch.Size([128, 
200]) 279 | soft_logits = soft_logits * self.temperature 280 | 281 | return pos_logits, neg_logits, soft_logits # pos_pred, neg_pred 282 | 283 | 284 | def predict(self, user_ids, log_seqs, item_indices): # for inference 285 | # ipdb.set_trace() 286 | source_log_feats = self.sasrec_embedding_source(log_seqs) # torch.Size([1, 200, 64]) 287 | item_embs = self.sasrec_embedding_source.item_emb(item_indices) # torch.Size([1, 100, 64]) 288 | # get the l2 norm for the target domain recommendation 289 | final_feat = source_log_feats[:, -1, :] # torch.Size([1, 64]) 290 | final_feat_l2norm = torch.nn.functional.normalize(final_feat, p=2, dim=-1) # torch.Size([1, 64]) 291 | item_embs_l2norm = torch.nn.functional.normalize(item_embs, p=2, dim=-1) # torch.Size([1, 100, 64]) 292 | 293 | logits = item_embs_l2norm.matmul(final_feat_l2norm.unsqueeze(-1)).squeeze(-1) # torch.Size([1, 100]) 294 | logits = logits * self.temperature # torch.Size([1, 100]) 295 | 296 | return logits # torch.Size([1, 100]) 297 | 298 | -------------------------------------------------------------------------------- /overall_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hulkima/PDRec/fc557dd989aad85aaf4e4875f6d386bfd2c80184/overall_structure.png -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import copy 3 | import torch 4 | import random 5 | import numpy as np 6 | from collections import defaultdict 7 | from multiprocessing import Process, Queue 8 | import ipdb 9 | import dill as pkl 10 | import time 11 | from sklearn.metrics import roc_auc_score 12 | 13 | 14 | def del_tensor_ele(arr,index): 15 | arr1 = arr[0:index] 16 | arr2 = arr[index+1:] 17 | return torch.cat((arr1,arr2),dim=0) 18 | 19 | # sampler for batch generation 20 | def random_neq(l, r, s): 21 | t = np.random.randint(l, r) 22 | while t in s: 23 | t = np.random.randint(l, r) 24 | return t 25 | 26 | def scale_withminmax(given_list, min_value, max_value, reweight_version): 27 | min_list = np.min(given_list) 28 | max_list = np.max(given_list) 29 | if min_list == max_list: 30 | # ipdb.set_trace() 31 | if reweight_version == "AllOne": 32 | scaled_array = np.ones(shape=len(given_list),dtype=np.float32)*max_value 33 | elif reweight_version == "AllLinear": 34 | scaled_array = np.linspace(min_value, max_value, len(given_list), dtype=np.float32) 35 | elif reweight_version == "MinMax": 36 | max_list = max_list + 1e3 37 | scale_factor = (max_value - min_value) / (max_list - min_list) 38 | scaled_array = min_value + (np.array(given_list, dtype=np.float32) - min_list) * scale_factor 39 | else: 40 | scale_factor = (max_value - min_value) / (max_list - min_list) 41 | scaled_array = min_value + (np.array(given_list, dtype=np.float32) - min_list) * scale_factor 42 | 43 | return scaled_array 44 | 45 | def get_exclusive(t1, t2): 46 | t1_exclusive = t1[(t1.view(1, -1) != t2.view(-1, 1)).all(dim=0)] 47 | return t1_exclusive 48 | 49 | 50 | # source:book----range[1,interval+1);target:movie[interval+1, itemnum + 1) 51 | def sample_function_T_DiffCDR_TI(random_min, random_max, random_source_min, random_source_max, user_train_mix, user_train_source, user_train_target, user_train_ti_mix, user_train_ti_source, user_train_ti_target, usernum, itemnum, batch_size, w_min, w_max, reweight_version, result_queue): 52 | 53 | def sample(): 54 | user = np.random.randint(1, usernum 
+ 1) 55 | while len(user_train_mix[user]) <= 1 or len(user_train_source[user]) <= 1 or len(user_train_target[user]) <= 1: 56 | user = np.random.randint(1, usernum + 1) 57 | 58 | # init the tensor 59 | seq_mix = np.zeros([itemnum+1], dtype=np.float32) 60 | seq_source = np.zeros([itemnum+1], dtype=np.float32) 61 | seq_target = np.zeros([itemnum+1], dtype=np.float32) 62 | seq_mix_temp = np.zeros([itemnum+1], dtype=np.float32) 63 | seq_source_temp = np.zeros([itemnum+1], dtype=np.float32) 64 | seq_target_temp = np.zeros([itemnum+1], dtype=np.float32) 65 | 66 | # set the position-aware weight 67 | weight_mix = scale_withminmax(user_train_ti_mix[user], w_min, w_max, reweight_version) 68 | weight_source = scale_withminmax(user_train_ti_source[user], w_min, w_max, reweight_version) 69 | weight_target = scale_withminmax(user_train_ti_target[user], w_min, w_max, reweight_version) 70 | 71 | mask = np.logical_and(random_source_min <= np.array(user_train_mix[user]), np.array(user_train_mix[user]) < random_source_max) 72 | weight_mix = np.where(mask, weight_mix / 2, weight_mix) 73 | 74 | # generate the 75 | seq_mix[user_train_mix[user]] = 1.0 76 | seq_source[user_train_source[user]] = 1.0 77 | seq_target[user_train_target[user]] = 1.0 78 | 79 | seq_mix_temp[user_train_mix[user]] = weight_mix 80 | seq_source_temp[user_train_source[user]] = weight_source 81 | seq_target_temp[user_train_target[user]] = weight_target 82 | 83 | return (user, seq_mix, seq_source, seq_target, seq_mix_temp, seq_source_temp, seq_target_temp) 84 | 85 | 86 | while True: 87 | one_batch = [] 88 | for i in range(batch_size): 89 | one_batch.append(sample()) 90 | 91 | result_queue.put(zip(*one_batch)) 92 | 93 | 94 | 95 | class WarpSampler_T_DiffCDR_TI(object): 96 | def __init__(self, random_min, random_max, random_source_min, random_source_max, user_train_mix, user_train_source, user_train_target, user_train_ti_mix, user_train_ti_source, user_train_ti_target, usernum, itemnum, batch_size=64, w_min=0.1, w_max=1.0, reweight_version='AllLinear', n_workers=1): 97 | self.result_queue = Queue(maxsize=n_workers * 10) 98 | self.processors = [] 99 | for i in range(n_workers): 100 | self.processors.append( 101 | Process(target=sample_function_T_DiffCDR_TI, args=(random_min, 102 | random_max, 103 | random_source_min, 104 | random_source_max, 105 | user_train_mix, 106 | user_train_source, 107 | user_train_target, 108 | user_train_ti_mix, 109 | user_train_ti_source, 110 | user_train_ti_target, 111 | usernum, 112 | itemnum, 113 | batch_size, 114 | w_min, 115 | w_max, 116 | reweight_version, 117 | self.result_queue 118 | ))) 119 | self.processors[-1].daemon = True 120 | self.processors[-1].start() 121 | 122 | def next_batch(self): 123 | return self.result_queue.get() 124 | 125 | def close(self): 126 | for p in self.processors: 127 | p.terminate() 128 | p.join() 129 | 130 | 131 | 132 | # source:book----range[1,interval+1);target:movie[interval+1, itemnum + 1) 133 | def sample_function_V13_final_please_Diff_TI(random_min, random_max, random_source_min, random_source_max, user_train_mix, user_train_source, user_train_target, user_train_ti_mix, user_train_ti_source, user_train_ti_target, user_train_mix_sequence_for_target, user_train_source_sequence_for_target, usernum, itemnum, w_min, w_max, reweight_version, batch_size, maxlen, sample_ratio, result_queue): 134 | 135 | def sample(): 136 | user = np.random.randint(1, usernum + 1) 137 | while len(user_train_mix[user]) < 1 or len(user_train_source[user]) < 1 or len(user_train_target[user]) < 1: 138 | user 
= np.random.randint(1, usernum + 1) 139 | 140 | user_train_mix_u = np.array(user_train_mix[user]) 141 | user_train_source_u = np.array(user_train_source[user]) 142 | user_train_target_u = np.array(user_train_target[user]) 143 | 144 | seq_mix = np.zeros([maxlen], dtype=np.int32) 145 | seq_source = np.zeros([maxlen], dtype=np.int32) 146 | seq_target = np.zeros([maxlen], dtype=np.int32) 147 | pos_target = np.zeros([maxlen], dtype=np.int32) 148 | neg_target = np.zeros([sample_ratio, maxlen], dtype=np.int32) 149 | user_train_mix_sequence_for_target_indices = np.zeros([maxlen], dtype=np.int32) 150 | user_train_source_sequence_for_target_indices = np.zeros([maxlen], dtype=np.int32) 151 | 152 | nxt_target = user_train_target_u[-1] # # 最后一个交互的物品 153 | 154 | idx_mix = maxlen - 1 #49 155 | idx_source = maxlen - 1 #49 156 | idx_target = maxlen - 1 #49 157 | 158 | ts_target = set(user_train_target_u) # a set 159 | for i in reversed(range(0, len(user_train_mix_u))): # reversed是逆序搜索,这里的i指的是交互的物品 160 | seq_mix[idx_mix] = user_train_mix_u[i] 161 | idx_mix -= 1 162 | if idx_mix == -1: break 163 | 164 | for i in reversed(range(0, len(user_train_source_u))): # reversed是逆序搜索,这里的i指的是交互的物品 165 | seq_source[idx_source] = user_train_source_u[i] 166 | idx_source -= 1 167 | if idx_source == -1: break 168 | 169 | for i in reversed(range(0, len(user_train_target_u[:-1]))): # reversed是逆序搜索,这里的i指的是交互的物品 170 | seq_target[idx_target] = user_train_target_u[i] 171 | pos_target[idx_target] = nxt_target 172 | if user_train_mix_sequence_for_target[user][i] < -maxlen: 173 | user_train_mix_sequence_for_target_indices[idx_target] = 0 174 | else: 175 | user_train_mix_sequence_for_target_indices[idx_target] = user_train_mix_sequence_for_target[user][i] + maxlen 176 | 177 | if user_train_source_sequence_for_target[user][i] < -maxlen or user_train_source_sequence_for_target[user][i] == -len(user_train_source_u)-1: 178 | user_train_source_sequence_for_target_indices[idx_target] = 0 179 | else: 180 | user_train_source_sequence_for_target_indices[idx_target] = user_train_source_sequence_for_target[user][i] + maxlen 181 | if nxt_target != 0: 182 | for j in range(0,sample_ratio): 183 | neg_target[j, idx_target] = random_neq(random_min, random_max, ts_target) 184 | nxt_target = user_train_target_u[i] 185 | idx_target -= 1 186 | if idx_target == -1: break 187 | 188 | # init the tensor 189 | seq_mix_inter = np.zeros([itemnum+1], dtype=np.float32) 190 | seq_source_inter = np.zeros([itemnum+1], dtype=np.float32) 191 | seq_target_inter = np.zeros([itemnum+1], dtype=np.float32) 192 | seq_mix_inter_temp = np.zeros([itemnum+1], dtype=np.float32) 193 | seq_source_inter_temp = np.zeros([itemnum+1], dtype=np.float32) 194 | seq_target_inter_temp = np.zeros([itemnum+1], dtype=np.float32) 195 | # ipdb.set_trace() 196 | # set the position-aware weight 197 | weight_mix = scale_withminmax(user_train_ti_mix[user], w_min, w_max, reweight_version) 198 | weight_source = scale_withminmax(user_train_ti_source[user], w_min, w_max, reweight_version) 199 | weight_target = scale_withminmax(user_train_ti_target[user], w_min, w_max, reweight_version) 200 | 201 | 202 | index_tensor = torch.arange(len(user_train_mix_u)) 203 | condition_mask = (user_train_mix_u >= random_source_min) & (user_train_mix_u < random_source_max) 204 | weight_mix[condition_mask] /= 2 205 | 206 | # generate the 207 | seq_mix_inter[user_train_mix_u] = 1.0 208 | seq_source_inter[user_train_source_u] = 1.0 209 | seq_target_inter[user_train_target_u] = 1.0 210 | 211 | 
seq_mix_inter_temp[user_train_mix_u] = weight_mix 212 | seq_source_inter_temp[user_train_source_u] = weight_source 213 | seq_target_inter_temp[user_train_target_u] = weight_target 214 | 215 | return (user, seq_mix, seq_source, seq_target, pos_target, neg_target, user_train_mix_sequence_for_target_indices, user_train_source_sequence_for_target_indices, seq_mix_inter, seq_source_inter, seq_target_inter, seq_mix_inter_temp, seq_source_inter_temp, seq_target_inter_temp) 216 | 217 | 218 | while True: 219 | one_batch = [] 220 | for i in range(batch_size): 221 | one_batch.append(sample()) 222 | 223 | result_queue.put(zip(*one_batch)) 224 | 225 | class WarpSampler_V13_final_please_Diff_TI(object): 226 | def __init__(self, random_min, random_max, random_source_min, random_source_max, user_train_mix, user_train_source, user_train_target, user_train_ti_mix, user_train_ti_source, user_train_ti_target, user_train_mix_sequence_for_target, user_train_source_sequence_for_target, user_list, itemnum, itemnum_source, itemnum_target, w_min=0.1, w_max=1.0, reweight_version='AllOne', batch_size=64, maxlen=10, n_workers=1, sample_ratio=1): 227 | self.result_queue = Queue(maxsize=n_workers * 10) 228 | self.processors = [] 229 | for i in range(n_workers): 230 | self.processors.append( 231 | Process(target=sample_function_V13_final_please_Diff_TI, args=(random_min, 232 | random_max, 233 | random_source_min, 234 | random_source_max, 235 | user_train_mix, 236 | user_train_source, 237 | user_train_target, 238 | user_train_ti_mix, 239 | user_train_ti_source, 240 | user_train_ti_target, 241 | user_train_mix_sequence_for_target, 242 | user_train_source_sequence_for_target, 243 | user_list, 244 | itemnum, 245 | w_min, 246 | w_max, 247 | reweight_version, 248 | batch_size, 249 | maxlen, 250 | sample_ratio, 251 | self.result_queue 252 | ))) 253 | self.processors[-1].daemon = True 254 | self.processors[-1].start() 255 | 256 | def next_batch(self): 257 | return self.result_queue.get() 258 | 259 | def close(self): 260 | for p in self.processors: 261 | p.terminate() 262 | p.join() 263 | 264 | 265 | 266 | 267 | 268 | 269 | def data_partition(version, fname, dataset_name, maxlen): 270 | usernum = 0 271 | itemnum = 0 272 | user_train = {} 273 | user_valid = {} 274 | user_test = {} 275 | interval = 0 276 | # ipdb.set_trace() 277 | # assume user/item index starting from 1 278 | 279 | if fname == 'amazon_toy': 280 | with open('./Dataset/Amazon_toy/toy_log_file_final.pkl', 'rb') as f: 281 | toy_log_file_final = pkl.load(f) 282 | 283 | with open('./Dataset/Amazon_toy/mix_log_file_final.pkl', 'rb') as f: 284 | mix_log_file_final = pkl.load(f) 285 | 286 | with open('./Dataset/Amazon_toy/item_index_toy.pkl', 'rb') as f: 287 | item_index_toy = pkl.load(f) 288 | 289 | with open('./Dataset/Amazon_toy/item_index_mix.pkl', 'rb') as f: 290 | item_index_mix = pkl.load(f) 291 | 292 | with open('./Dataset/Amazon_toy/user_index_overleap.pkl', 'rb') as f: 293 | user_index_overleap = pkl.load(f) 294 | with open('./Dataset/Amazon_toy/toy_log_timestep_final.pkl', 'rb') as f: 295 | toy_log_timestep_final = pkl.load(f) 296 | 297 | with open('./Dataset/Amazon_toy/mix_log_timestep_final.pkl', 'rb') as f: 298 | mix_log_timestep_final = pkl.load(f) 299 | item_index_toy_array = np.load('./Dataset/Amazon_toy/item_index_toy.npy') 300 | 301 | interval = 37868 302 | 303 | elif fname == 'douban_book': 304 | with open('./Dataset/Douban_book/book_log_file_final.pkl', 'rb') as f: 305 | toy_log_file_final = pkl.load(f) 306 | 307 | with 
open('./Dataset/Douban_book/mix_log_file_final.pkl', 'rb') as f: 308 | mix_log_file_final = pkl.load(f) 309 | 310 | with open('./Dataset/Douban_book/item_index_book.pkl', 'rb') as f: 311 | item_index_toy = pkl.load(f) 312 | 313 | with open('./Dataset/Douban_book/item_index_mix.pkl', 'rb') as f: 314 | item_index_mix = pkl.load(f) 315 | 316 | with open('./Dataset/Douban_book/user_index_overleap.pkl', 'rb') as f: 317 | user_index_overleap = pkl.load(f) 318 | 319 | with open('./Dataset/Douban_book/book_log_timestep_final.pkl', 'rb') as f: 320 | toy_log_timestep_final = pkl.load(f) 321 | 322 | with open('./Dataset/Douban_book/mix_log_timestep_final.pkl', 'rb') as f: 323 | mix_log_timestep_final = pkl.load(f) 324 | 325 | item_index_toy_array = np.load('./Dataset/Douban_book/item_index_book.npy') 326 | 327 | interval = 33697 328 | # ipdb.set_trace() 329 | usernum = len(user_index_overleap.keys()) # 116254 330 | 331 | if fname == 'amazon_toy': 332 | user_train_toy_mix = {} 333 | user_train_toy_source = {} 334 | user_train_toy_target = {} 335 | user_valid_toy_target = {} 336 | user_test_toy_target = {} 337 | user_train_toy_mix_sequence_for_target = {} 338 | user_train_toy_source_sequence_for_target = {} 339 | 340 | user_train_ti_toy_mix = {} 341 | user_train_ti_toy_source = {} 342 | user_train_ti_toy_target = {} 343 | user_valid_ti_toy_target = {} 344 | user_test_ti_toy_target = {} 345 | 346 | position_mix = [] 347 | position_source = [] 348 | # ipdb.set_trace() 349 | itemnum = len(item_index_mix.keys()) 350 | for k in range(1, len(user_index_overleap.keys()) + 1): 351 | v_mix_toy = copy.deepcopy(mix_log_file_final[k]) 352 | v_toy = copy.deepcopy(toy_log_file_final[k]) 353 | 354 | t_mix_toy = copy.deepcopy(mix_log_timestep_final[k]) 355 | t_toy = copy.deepcopy(toy_log_timestep_final[k]) 356 | 357 | toy_last_name = item_index_toy_array[(v_toy[-1] - 1)] # the name of the last interacted movie in Amazon Movie 358 | toy_last_id = item_index_mix[toy_last_name] # the name of the the last interacted movie in Amazon Mix 359 | toy_last_index = np.argwhere(np.array(v_mix_toy)==toy_last_id)[-1].item() 360 | user_mix_toy = v_mix_toy[:toy_last_index+1] 361 | user_ti_mix_toy = t_mix_toy[:toy_last_index+1] 362 | 363 | if len(user_mix_toy) < 3: 364 | ipdb.set_trace() 365 | 366 | user_train_toy_mix[k] = [] 367 | user_train_toy_source[k] = [] 368 | user_train_toy_target[k] = [] 369 | user_valid_toy_target[k] = [] 370 | user_test_toy_target[k] = [] 371 | 372 | user_train_ti_toy_mix[k] = [] 373 | user_train_ti_toy_source[k] = [] 374 | user_train_ti_toy_target[k] = [] 375 | user_valid_ti_toy_target[k] = [] 376 | user_test_ti_toy_target[k] = [] 377 | for re_id in reversed(range(0,len(user_mix_toy))): 378 | if user_mix_toy[re_id] >= interval+1: # from 551942 to XXX, source 379 | user_train_toy_source[k].append(user_mix_toy[re_id]) 380 | user_train_toy_mix[k].append(user_mix_toy[re_id]) 381 | user_train_ti_toy_source[k].append(user_ti_mix_toy[re_id]) 382 | user_train_ti_toy_mix[k].append(user_ti_mix_toy[re_id]) 383 | elif user_mix_toy[re_id] <= interval: # from 1 to 551941, target 384 | if len(user_test_toy_target[k]) == 0: 385 | user_test_toy_target[k].append(user_mix_toy[re_id]) 386 | user_test_ti_toy_target[k].append(user_ti_mix_toy[re_id]) 387 | elif len(user_valid_toy_target[k]) == 0: 388 | user_valid_toy_target[k].append(user_mix_toy[re_id]) 389 | user_valid_ti_toy_target[k].append(user_ti_mix_toy[re_id]) 390 | elif len(user_test_toy_target[k]) == 1 and len(user_valid_toy_target[k]) == 1: 391 | 
user_train_toy_target[k].append(user_mix_toy[re_id]) 392 | user_train_toy_mix[k].append(user_mix_toy[re_id]) 393 | user_train_ti_toy_target[k].append(user_ti_mix_toy[re_id]) 394 | user_train_ti_toy_mix[k].append(user_ti_mix_toy[re_id]) 395 | 396 | user_train_toy_mix[k].reverse() 397 | user_train_toy_source[k].reverse() 398 | user_train_toy_target[k].reverse() 399 | user_train_ti_toy_mix[k].reverse() 400 | user_train_ti_toy_source[k].reverse() 401 | user_train_ti_toy_target[k].reverse() 402 | 403 | 404 | pos_mix = len(user_train_toy_mix[k])-1 405 | pos_source = len(user_train_toy_source[k])-1 406 | mix_sequence_for_target_list = [] 407 | source_sequence_for_target_list = [] 408 | for i in reversed(list(range(0, len(user_train_toy_mix[k])))): 409 | if user_train_toy_mix[k][i] >= interval+1: 410 | pos_source = pos_source - 1 411 | elif user_train_toy_mix[k][i] <= interval: 412 | mix_sequence_for_target_list.append(pos_mix-1) 413 | source_sequence_for_target_list.append(pos_source) 414 | pos_mix = pos_mix - 1 415 | 416 | mix_sequence_for_target = mix_sequence_for_target_list[:-1] 417 | source_sequence_for_target = source_sequence_for_target_list[:-1] 418 | mix_sequence_for_target.reverse() 419 | source_sequence_for_target.reverse() 420 | 421 | user_train_toy_mix_sequence_for_target[k] = [] 422 | user_train_toy_source_sequence_for_target[k] = [] 423 | for x in mix_sequence_for_target: 424 | user_train_toy_mix_sequence_for_target[k].append(x - len(user_train_toy_mix[k])) 425 | 426 | for x in source_sequence_for_target: 427 | user_train_toy_source_sequence_for_target[k].append(x - len(user_train_toy_source[k])) 428 | 429 | # ipdb.set_trace() 430 | return [user_train_toy_mix, user_train_toy_source, user_train_toy_target, user_valid_toy_target, user_test_toy_target, user_train_toy_mix_sequence_for_target, user_train_toy_source_sequence_for_target, usernum, itemnum, interval, user_train_ti_toy_mix, user_train_ti_toy_source, user_train_ti_toy_target, user_valid_ti_toy_target, user_test_ti_toy_target] 431 | 432 | elif fname == 'douban_book': 433 | user_train_toy_mix = {} 434 | user_train_toy_source = {} 435 | user_train_toy_target = {} 436 | user_valid_toy_target = {} 437 | user_test_toy_target = {} 438 | user_train_toy_mix_sequence_for_target = {} 439 | user_train_toy_source_sequence_for_target = {} 440 | 441 | user_train_ti_toy_mix = {} 442 | user_train_ti_toy_source = {} 443 | user_train_ti_toy_target = {} 444 | user_valid_ti_toy_target = {} 445 | user_test_ti_toy_target = {} 446 | 447 | position_mix = [] 448 | position_source = [] 449 | # ipdb.set_trace() 450 | itemnum = len(item_index_mix.keys()) 451 | for k in range(1, len(user_index_overleap.keys()) + 1): 452 | v_mix_toy = copy.deepcopy(mix_log_file_final[k]) 453 | v_toy = copy.deepcopy(toy_log_file_final[k]) 454 | 455 | t_mix_toy = copy.deepcopy(mix_log_timestep_final[k]) 456 | t_toy = copy.deepcopy(toy_log_timestep_final[k]) 457 | 458 | toy_last_index = np.argwhere(np.array(v_mix_toy)==v_toy[-1])[-1].item() 459 | 460 | user_mix_toy = v_mix_toy[:toy_last_index+1] 461 | user_ti_mix_toy = t_mix_toy[:toy_last_index+1] 462 | 463 | if len(user_mix_toy) < 3: 464 | ipdb.set_trace() 465 | 466 | user_train_toy_mix[k] = [] 467 | user_train_toy_source[k] = [] 468 | user_train_toy_target[k] = [] 469 | user_valid_toy_target[k] = [] 470 | user_test_toy_target[k] = [] 471 | 472 | user_train_ti_toy_mix[k] = [] 473 | user_train_ti_toy_source[k] = [] 474 | user_train_ti_toy_target[k] = [] 475 | user_valid_ti_toy_target[k] = [] 476 | user_test_ti_toy_target[k] = 
[] 477 | for re_id in reversed(range(0,len(user_mix_toy))): 478 | if user_mix_toy[re_id] >= interval+1: # from 551942 to XXX, source 479 | user_train_toy_source[k].append(user_mix_toy[re_id]) 480 | user_train_toy_mix[k].append(user_mix_toy[re_id]) 481 | user_train_ti_toy_source[k].append(user_ti_mix_toy[re_id]) 482 | user_train_ti_toy_mix[k].append(user_ti_mix_toy[re_id]) 483 | elif user_mix_toy[re_id] <= interval: # from 1 to 551941, target 484 | if len(user_test_toy_target[k]) == 0: 485 | user_test_toy_target[k].append(user_mix_toy[re_id]) 486 | user_test_ti_toy_target[k].append(user_ti_mix_toy[re_id]) 487 | elif len(user_valid_toy_target[k]) == 0: 488 | user_valid_toy_target[k].append(user_mix_toy[re_id]) 489 | user_valid_ti_toy_target[k].append(user_ti_mix_toy[re_id]) 490 | elif len(user_test_toy_target[k]) == 1 and len(user_valid_toy_target[k]) == 1: 491 | user_train_toy_target[k].append(user_mix_toy[re_id]) 492 | user_train_toy_mix[k].append(user_mix_toy[re_id]) 493 | user_train_ti_toy_target[k].append(user_ti_mix_toy[re_id]) 494 | user_train_ti_toy_mix[k].append(user_ti_mix_toy[re_id]) 495 | 496 | user_train_toy_mix[k].reverse() 497 | user_train_toy_source[k].reverse() 498 | user_train_toy_target[k].reverse() 499 | user_train_ti_toy_mix[k].reverse() 500 | user_train_ti_toy_source[k].reverse() 501 | user_train_ti_toy_target[k].reverse() 502 | 503 | 504 | pos_mix = len(user_train_toy_mix[k])-1 505 | pos_source = len(user_train_toy_source[k])-1 506 | mix_sequence_for_target_list = [] 507 | source_sequence_for_target_list = [] 508 | for i in reversed(list(range(0, len(user_train_toy_mix[k])))): 509 | if user_train_toy_mix[k][i] >= interval+1: 510 | pos_source = pos_source - 1 511 | elif user_train_toy_mix[k][i] <= interval: 512 | mix_sequence_for_target_list.append(pos_mix-1) 513 | source_sequence_for_target_list.append(pos_source) 514 | pos_mix = pos_mix - 1 515 | 516 | mix_sequence_for_target = mix_sequence_for_target_list[:-1] 517 | source_sequence_for_target = source_sequence_for_target_list[:-1] 518 | mix_sequence_for_target.reverse() 519 | source_sequence_for_target.reverse() 520 | 521 | user_train_toy_mix_sequence_for_target[k] = [] 522 | user_train_toy_source_sequence_for_target[k] = [] 523 | for x in mix_sequence_for_target: 524 | user_train_toy_mix_sequence_for_target[k].append(x - len(user_train_toy_mix[k])) 525 | 526 | for x in source_sequence_for_target: 527 | user_train_toy_source_sequence_for_target[k].append(x - len(user_train_toy_source[k])) 528 | 529 | return [user_train_toy_mix, user_train_toy_source, user_train_toy_target, user_valid_toy_target, user_test_toy_target, user_train_toy_mix_sequence_for_target, user_train_toy_source_sequence_for_target, usernum, itemnum, interval, user_train_ti_toy_mix, user_train_ti_toy_source, user_train_ti_toy_target, user_valid_ti_toy_target, user_test_ti_toy_target] 530 | 531 | #calculate the auc 532 | def compute_auc(scores): 533 | scores = -scores.detach().cpu().numpy() 534 | num_pos = 1 535 | score_neg = scores[num_pos:] 536 | num_hit = 0 537 | 538 | for i in range(num_pos): 539 | num_hit += len(np.where(score_neg < scores[i])[0]) 540 | 541 | auc = num_hit / (num_pos * len(score_neg)) 542 | return auc 543 | 544 | 545 | # TODO: merge evaluate functions for test and val set 546 | # evaluate on test set 547 | def evaluate_PDRec(model, dataset, args, user_list): 548 | with torch.no_grad(): 549 | print('Start test...') 550 | [user_train_mix, user_train_source, user_train_target, user_valid_target, user_test_target, 
user_train_mix_sequence_for_target, user_train_source_sequence_for_target, usernum, itemnum, interval, user_train_ti_mix, user_train_ti_source, user_train_ti_target, user_valid_ti_target, user_test_ti_target] = dataset 551 | 552 | random_min = 1 553 | random_max = interval + 1 554 | item_entries = np.arange(start=random_min, stop=random_max, step=1, dtype=int) 555 | print("The min in source domain is {} and the max in source domain is {}".format(random_min, random_max)) 556 | 557 | NDCG_1 = 0.0 558 | NDCG_5 = 0.0 559 | NDCG_10 = 0.0 560 | NDCG_20 = 0.0 561 | NDCG_50 = 0.0 562 | HT_1 = 0.0 563 | HT_5 = 0.0 564 | HT_10 = 0.0 565 | HT_20 = 0.0 566 | HT_50 = 0.0 567 | AUC = 0.0 568 | loss = 0.0 569 | valid_user = 0.0 570 | labels = torch.zeros(100, device=args.device) 571 | labels[0] = 1 572 | 573 | for u in user_list: 574 | seq_target = np.zeros([args.maxlen], dtype=np.int32) # (200,) 575 | idx_target = args.maxlen - 1 #49 576 | 577 | seq_target[idx_target] = user_valid_target[u][0] 578 | idx_target -= 1 579 | for i in reversed(user_train_target[u]): 580 | seq_target[idx_target] = i 581 | idx_target -= 1 582 | if idx_target == -1: break 583 | 584 | sample_pool = np.setdiff1d(item_entries, seq_target) 585 | item_idx = np.random.choice(sample_pool, args.num_samples, replace=False) 586 | item_idx[0] = user_test_target[u][0] 587 | predictions = model.predict(torch.tensor(u).cuda(), torch.tensor(seq_target).cuda().unsqueeze(0), torch.tensor(item_idx).cuda().unsqueeze(0)) 588 | 589 | AUC += roc_auc_score(labels.cpu(), predictions[0].cpu()) 590 | 591 | loss_test = torch.nn.BCEWithLogitsLoss()(predictions[0].detach(), labels) 592 | 593 | loss += loss_test.item() 594 | predictions = -predictions[0] # - for 1st argsort DESC 595 | 596 | rank = predictions.argsort().argsort()[0].item() 597 | 598 | valid_user += 1 599 | 600 | # AUC += compute_auc(predictions) 601 | if rank < 1: 602 | NDCG_1 += 1 / np.log2(rank + 2) 603 | HT_1 += 1 604 | if rank < 5: 605 | NDCG_5 += 1 / np.log2(rank + 2) 606 | HT_5 += 1 607 | if rank < 10: 608 | NDCG_10 += 1 / np.log2(rank + 2) 609 | HT_10 += 1 610 | if rank < 20: 611 | NDCG_20 += 1 / np.log2(rank + 2) 612 | HT_20 += 1 613 | if rank < 50: 614 | NDCG_50 += 1 / np.log2(rank + 2) 615 | HT_50 += 1 616 | 617 | if valid_user % 1000 == 0: 618 | print('process test user {}'.format(valid_user)) 619 | print("The total number of user is:", valid_user) 620 | return NDCG_1 / valid_user, NDCG_5 / valid_user, NDCG_10 / valid_user, NDCG_20 / valid_user, NDCG_50 / valid_user, HT_1 / valid_user, HT_5 / valid_user, HT_10 / valid_user, HT_20 / valid_user, HT_50 / valid_user, AUC / valid_user, loss / valid_user 621 | 622 | 623 | 624 | 625 | def evaluate_T_DiffRec_TI(model, diffusion, dataset, args, random_min, random_max, random_source_min, random_source_max): 626 | with torch.no_grad(): 627 | print('Start test...') 628 | [user_train_mix, user_train_source, user_train_target, user_valid_target, user_test_target, user_train_mix_sequence_for_target, user_train_source_sequence_for_target, usernum, itemnum, interval, user_train_ti_mix, user_train_ti_source, user_train_ti_target, user_valid_ti_target, user_test_ti_target] = dataset 629 | print("The min in source domain is {} and the max in source domain is {}".format(random_min, random_max)) 630 | item_entries = torch.arange(start=random_min, end=random_max, step=1, dtype=int, device='cuda') 631 | 632 | NDCG_1 = 0.0 633 | NDCG_5 = 0.0 634 | NDCG_10 = 0.0 635 | NDCG_20 = 0.0 636 | NDCG_50 = 0.0 637 | HT_1 = 0.0 638 | HT_5 = 0.0 639 | HT_10 = 0.0 640 
| HT_20 = 0.0 641 | HT_50 = 0.0 642 | 643 | AUC = 0.0 644 | valid_user = 0.0 645 | users = range(1, usernum + 1) # range(1, 116255) 646 | labels = torch.zeros(100, device='cuda') 647 | labels[0] = 1 648 | for u in users: 649 | if len(user_train_mix[u]) <= 1 or len(user_train_source[u]) <= 1 or len(user_train_target[u]) <= 1: 650 | continue 651 | 652 | # init the tensor 653 | seq_target = torch.zeros([itemnum+1], dtype=torch.float32, device='cuda') 654 | seq_target_temp = torch.zeros([itemnum+1], dtype=torch.float32, device='cuda') 655 | # the interaction length 656 | user_train_target_this = user_train_target[u]+user_valid_target[u] 657 | user_all_target_this = user_train_target_this+user_test_target[u] 658 | user_train_ti_target_this = user_train_ti_target[u]+user_valid_ti_target[u] 659 | weight_target = scale_withminmax(user_train_ti_target_this, args.w_min, args.w_max, args.reweight_version) 660 | seq_target[user_train_target_this] = 1.0 661 | seq_target_temp[user_train_target_this] = torch.tensor(weight_target, dtype=torch.float32, device='cuda') 662 | 663 | # ipdb.set_trace() 664 | prediction = diffusion.p_sample(model, seq_target_temp.unsqueeze(0), args.sampling_steps, args.sampling_noise) 665 | # ipdb.set_trace() 666 | sample_pool = get_exclusive(item_entries, torch.tensor(user_all_target_this,device='cuda')) 667 | random_index = torch.randperm(sample_pool.shape[0]) 668 | item_idx = sample_pool[random_index[:args.num_samples]] 669 | item_idx[0] = user_all_target_this[-1] 670 | score = torch.index_select(prediction, dim=1, index=item_idx).squeeze() 671 | 672 | AUC += roc_auc_score(labels.cpu(), score.cpu()) 673 | 674 | score = -score # - for 1st argsort DESC 675 | rank = score.argsort().argsort()[0].item() 676 | valid_user += 1 677 | 678 | if rank < 1: 679 | NDCG_1 += 1 / np.log2(rank + 2) 680 | HT_1 += 1 681 | if rank < 5: 682 | NDCG_5 += 1 / np.log2(rank + 2) 683 | HT_5 += 1 684 | if rank < 10: 685 | NDCG_10 += 1 / np.log2(rank + 2) 686 | HT_10 += 1 687 | if rank < 20: 688 | NDCG_20 += 1 / np.log2(rank + 2) 689 | HT_20 += 1 690 | if rank < 50: 691 | NDCG_50 += 1 / np.log2(rank + 2) 692 | HT_50 += 1 693 | if valid_user % 1000 == 0: 694 | print('process test user {}'.format(valid_user)) 695 | 696 | return NDCG_1 / valid_user, NDCG_5 / valid_user, NDCG_10 / valid_user, NDCG_20 / valid_user, NDCG_50 / valid_user, HT_1 / valid_user, HT_5 / valid_user, HT_10 / valid_user, HT_20 / valid_user, HT_50 / valid_user, AUC / valid_user 697 | --------------------------------------------------------------------------------