├── .gitattributes ├── README.md ├── data_preprocessing └── data_preprocessing_item_2nd_hop.py ├── data_utils.py ├── datasets ├── multi_hop_inters_ML_1M.pt ├── sec_hop_inters_ML_1M.pt ├── test_list_ML-1M.npy ├── train_list_ML-1M.npy └── valid_list_ML-1M.npy ├── evaluate_utils.py ├── inference.py ├── main.py ├── models ├── CAM_AE.py ├── CAM_AE_multihops.py └── gaussian_diffusion.py └── saved_models └── CAM_3hops.pth /.gitattributes: -------------------------------------------------------------------------------- 1 | *.pt filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CF-Diff 2 | This is the PyTorch implementation of our SIGIR 2024 paper: 3 | > [Collaborative Filtering Based on Diffusion Models: Unveiling the Potential of High-Order Connectivity](https://arxiv.org/pdf/2404.14240) 4 | > 5 | > Hou, Yu and Park, Jin-Duk and Shin, Won-Yong 6 | 7 | The diffusion model and evaluation code are adapted from [DiffRec](https://github.com/YiyanXu/DiffRec/tree/main). We thank the authors for their contribution. 8 | ## Environment 9 | - Anaconda 3 10 | - python 3.8.17 11 | - pytorch 1.13.1 12 | - numpy 1.24.3 13 | - math 14 | 15 | ## Usage 16 | ### Data 17 | The user-item interactions (train/valid/test splits) are in the './datasets' folder. "sec_hop_inters_ML_1M.pt" contains the second-hop user-item interactions and "multi_hop_inters_ML_1M.pt" contains the multi-hop user-item interactions. 18 | More data about "high-order interactions" can be found [here](https://drive.google.com/drive/folders/1CJdlsNuDnLiiyh4iN1eRBGRAKZ3GfxZn?usp=drive_link). More "saved_models" can be found [here](https://drive.google.com/drive/folders/1CJdlsNuDnLiiyh4iN1eRBGRAKZ3GfxZn?usp=drive_link). 
19 | ### Training 20 | #### CF-Diff 21 | ``` 22 | cd ./CF_Diff 23 | python main.py 24 | ``` 25 | 26 | ### Inference 27 | ``` 28 | cd ./CF_Diff 29 | python inference.py 30 | ``` 31 | ## Citation 32 | 33 | ``` 34 | @inproceedings{hou2024collaborative, 35 | title = {Collaborative Filtering Based on Diffusion Models: Unveiling the Potential of High-Order Connectivity}, 36 | author = {Hou, Yu and Park, Jin-Duk and Shin, Won-Yong}, 37 | booktitle = {Proceedings of the 47th International ACM SIGIR Conference on Research and Development in Information Retrieval}, 38 | year = {2024} 39 | } 40 | ``` 41 | -------------------------------------------------------------------------------- /data_preprocessing/data_preprocessing_item_2nd_hop.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import scipy.sparse as sp 4 | 5 | 6 | def data_load(train_path, valid_path, test_path): 7 | train_list = np.load(train_path, allow_pickle=True) 8 | valid_list = np.load(valid_path, allow_pickle=True) 9 | test_list = np.load(test_path, allow_pickle=True) 10 | print(train_list) 11 | print(len(train_list)) 12 | 13 | uid_max = 0 14 | iid_max = 0 15 | train_dict = {} 16 | 17 | for uid, iid in train_list: 18 | if uid not in train_dict: 19 | train_dict[uid] = [] 20 | train_dict[uid].append(iid) 21 | if uid > uid_max: 22 | uid_max = uid 23 | if iid > iid_max: 24 | iid_max = iid 25 | 26 | n_user = uid_max + 1 27 | n_item = iid_max + 1 28 | print(f'user num: {n_user}') 29 | print(f'item num: {n_item}') 30 | 31 | train_data = sp.csr_matrix((np.ones_like(train_list[:, 0]), \ 32 | (train_list[:, 0], train_list[:, 1])), dtype='float64', \ 33 | shape=(n_user, n_item)) 34 | 35 | valid_y_data = sp.csr_matrix((np.ones_like(valid_list[:, 0]), 36 | (valid_list[:, 0], valid_list[:, 1])), dtype='float64', 37 | shape=(n_user, n_item)) # valid_groundtruth 38 | 39 | test_y_data = sp.csr_matrix((np.ones_like(test_list[:, 0]), 40 | 
def get_2hop_item_based(data, n_user=None):
    """Build item-popularity-based second-hop scores for every user.

    Each item's score is its column sum in ``data`` divided by ``n_user``.
    For every user (row) the scores are kept only at positions where the
    user's own entry is at least 1e-6; all other positions are set to 0.

    Args:
        data: 2-D torch tensor of user-item interactions (rows are users).
        n_user: divisor used to turn item counts into rates. Defaults to the
            number of rows of ``data``, which is what the surrounding script
            passes in; previously this was read from a module-level global.

    Returns:
        A float32 tensor with the same shape as ``data``.
    """
    if n_user is None:
        n_user = data.shape[0]

    # Per-item interaction rate, shared by all users.
    sec_hop_inters = torch.sum(data, dim=0) / n_user

    sec_hop_infos = torch.empty(data.shape[0], data.shape[1])
    print(sec_hop_infos.size())

    # Broadcast the item rates to every row, then blank out the items each
    # user has not interacted with (replaces the former per-row Python loop).
    sec_hop_infos[:] = sec_hop_inters
    sec_hop_infos[data < 0.000001] = 0

    return sec_hop_infos
def data_load(train_path, valid_path, test_path):
    """Load the train/valid/test interaction lists and build sparse matrices.

    Each ``.npy`` file holds an array of ``(user_id, item_id)`` pairs. The
    number of users and items is derived from the largest ids found in the
    training list only, so validation/test ids must not exceed them.

    Args:
        train_path: path of the training pair list.
        valid_path: path of the validation pair list.
        test_path: path of the test pair list.

    Returns:
        ``(train_data, valid_y_data, test_y_data, n_user, n_item)`` where the
        first three are ``float64`` CSR matrices of shape ``(n_user, n_item)``.
    """
    # NOTE(security): allow_pickle=True can execute arbitrary code when the
    # file is untrusted; only load datasets from a trusted source.
    train_list = np.load(train_path, allow_pickle=True)
    valid_list = np.load(valid_path, allow_pickle=True)
    test_list = np.load(test_path, allow_pickle=True)

    def _max_id(column):
        # Largest id in the column, never below 0 (matches the old loop that
        # started its running maximum at 0).
        return max(int(column.max()), 0) if column.size else 0

    n_user = _max_id(train_list[:, 0]) + 1
    n_item = _max_id(train_list[:, 1]) + 1
    print(f'user num: {n_user}')
    print(f'item num: {n_item}')

    def _to_interaction_matrix(pairs):
        # One entry of 1 per (user, item) pair.
        return sp.csr_matrix((np.ones_like(pairs[:, 0]),
                              (pairs[:, 0], pairs[:, 1])), dtype='float64',
                             shape=(n_user, n_item))

    train_data = _to_interaction_matrix(train_list)
    valid_y_data = _to_interaction_matrix(valid_list)  # valid_groundtruth
    test_y_data = _to_interaction_matrix(test_list)  # test_groundtruth

    return train_data, valid_y_data, test_y_data, n_user, n_item
def get_top_k_similar_pearson(data, k):
    """Return, for every row of ``data``, the indices of its top-k most correlated rows.

    Similarity is the Pearson correlation between rows. A row is never
    returned as its own neighbour as long as ``k`` is smaller than the number
    of rows.

    Args:
        data: 2-D float tensor; each row is one sample.
        k: number of neighbours to return per row.

    Returns:
        Integer tensor of shape ``(data.size(0), k)`` holding row indices,
        ordered from most to least correlated.
    """
    # Centre every row so that dot products become covariances.
    mean_centered_data = data - data.mean(dim=1, keepdim=True)
    covariance_matrix = torch.mm(mean_centered_data, mean_centered_data.t())

    # L2 norm of each centred row; constant rows get 1 to avoid dividing by 0
    # (their correlation with everything is then 0).
    std_dev = mean_centered_data.norm(p=2, dim=1, keepdim=True)
    std_dev[std_dev == 0] = 1

    pearson_correlation_matrix = covariance_matrix / torch.mm(std_dev, std_dev.t())

    # Exclude self-matches. Subtracting 2 (the previous approach) only moved
    # the self-correlation to -1, which is still a valid correlation value and
    # can tie with a perfectly anti-correlated row; -inf can never win.
    n_rows = pearson_correlation_matrix.size(0)
    self_mask = torch.eye(n_rows, device=pearson_correlation_matrix.device) == 1
    pearson_correlation_matrix[self_mask] = float('-inf')

    _, indices = torch.topk(pearson_correlation_matrix, k=k, dim=1)

    return indices
def computeTopNAccuracy(GroundTruth, predictedIndices, topN):
    """Compute Precision, Recall, NDCG and MRR at several cut-offs.

    Args:
        GroundTruth: per-user lists of relevant item ids.
        predictedIndices: per-user ranked lists of recommended item ids; every
            list must contain at least ``max(topN)`` entries.
        topN: sequence of cut-off values.

    Returns:
        ``(precision, recall, NDCG, MRR)``: four lists with one value per
        cut-off, each rounded to 4 decimals. Users whose ground truth is empty
        contribute nothing to the sums but are still counted in the
        denominator ``len(predictedIndices)``. When there are no predictions
        at all, every metric is reported as 0.0.
    """
    precision = []
    recall = []
    NDCG = []
    MRR = []

    n_users = len(predictedIndices)
    if n_users == 0:
        # Nothing to average over; avoid a ZeroDivisionError below.
        return ([0.0] * len(topN), [0.0] * len(topN),
                [0.0] * len(topN), [0.0] * len(topN))

    # Membership tests run once per (user, rank, cut-off); sets make each
    # lookup O(1) instead of scanning the ground-truth list.
    truth_sets = [set(items) for items in GroundTruth]

    for cutoff in topN:
        sumForPrecision = 0
        sumForRecall = 0
        sumForNdcg = 0
        sumForMRR = 0
        for i in range(n_users):
            n_relevant = len(GroundTruth[i])
            if n_relevant == 0:
                continue

            relevant = truth_sets[i]
            ranked = predictedIndices[i]
            userHit = 0
            userMRR = 0
            dcg = 0
            for rank in range(cutoff):
                if ranked[rank] in relevant:
                    dcg += 1.0 / math.log2(rank + 2)
                    if userHit == 0:
                        # Reciprocal rank of the first hit only.
                        userMRR = 1.0 / (rank + 1.0)
                    userHit += 1

            # Ideal DCG: all relevant items ranked first, capped by the cut-off.
            idcg = sum(1.0 / math.log2(rank + 2)
                       for rank in range(min(cutoff, n_relevant)))
            ndcg = dcg / idcg if idcg != 0 else 0

            sumForPrecision += userHit / cutoff
            sumForRecall += userHit / n_relevant
            sumForNdcg += ndcg
            sumForMRR += userMRR

        precision.append(round(sumForPrecision / n_users, 4))
        recall.append(round(sumForRecall / n_users, 4))
        NDCG.append(round(sumForNdcg / n_users, 4))
        MRR.append(round(sumForMRR / n_users, 4))

    return precision, recall, NDCG, MRR
def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` treats every non-empty string (including "False") as True,
    so boolean options need an explicit parser.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)


parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='ML-1M', help='choose the dataset')
parser.add_argument('--data_path', type=str, default='./datasets/ML-1M/', help='load data path')
parser.add_argument('--batch_size', type=int, default=200)
parser.add_argument('--topN', type=str, default='[10, 20, 50, 100]')
parser.add_argument('--cuda', action='store_true', help='use CUDA')
parser.add_argument('--gpu', type=str, default='0', help='gpu card ID')
parser.add_argument('--log_name', type=str, default='log', help='the log name')

parser.add_argument('--lr', type=float, default=0.00001, help='learning rate')
parser.add_argument('--weight_decay', type=float, default=0.0)

# params for the model
parser.add_argument('--time_type', type=str, default='cat', help='cat or add')
parser.add_argument('--norm', type=_str2bool, default=False, help='Normalize the input or not')
parser.add_argument('--emb_size', type=int, default=10, help='timestep embedding size')
parser.add_argument('--n_hops', type=int, default=2, help='Number of multi-hop neighbors')

# params for diffusion
parser.add_argument('--mean_type', type=str, default='x0', help='MeanType for diffusion: x0, eps')
parser.add_argument('--steps', type=int, default=100, help='diffusion steps')
parser.add_argument('--noise_schedule', type=str, default='linear-var', help='the schedule for noise generating')
parser.add_argument('--noise_scale', type=float, default=0.1, help='noise scale for noise generating')
parser.add_argument('--noise_min', type=float, default=0.001, help='noise lower bound for noise generating')
parser.add_argument('--noise_max', type=float, default=0.02, help='noise upper bound for noise generating')
parser.add_argument('--sampling_noise', type=_str2bool, default=False, help='sampling with noise or not')
parser.add_argument('--sampling_steps', type=int, default=0, help='steps of the forward process during inference')
parser.add_argument('--reweight', type=_str2bool, default=True, help='assign different weight to different timestep or not')

args = parser.parse_args()

# Per-dataset diffusion settings. These deliberately override whatever was
# given on the command line so that they match the released checkpoints.
DATASET_CONFIGS = {
    'ML-1M': dict(steps=5, noise_scale=0.01, noise_min=0.001, noise_max=0.01, n_hops=3),
    'anime': dict(steps=10, noise_scale=0.003, noise_min=0.0001, noise_max=0.01, n_hops=2),
    'yelp2018': dict(steps=20, noise_scale=0.01, noise_min=0.001, noise_max=0.01, n_hops=2),
}
if args.dataset not in DATASET_CONFIGS:
    raise ValueError("Unsupported dataset %s" % args.dataset)
for option, value in DATASET_CONFIGS[args.dataset].items():
    setattr(args, option, value)

os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### DATA LOAD ###
train_path = args.data_path + 'train_list_' + args.dataset + '.npy'
valid_path = args.data_path + 'valid_list_' + args.dataset + '.npy'
test_path = args.data_path + 'test_list_' + args.dataset + '.npy'

print("{}-hop neighbors are taken into account".format(args.n_hops))
if args.n_hops == 2:
    sec_hop = torch.load(args.data_path + 'two_hop_rates_items_' + args.dataset + '.pt')
    multi_hop = sec_hop
elif args.n_hops == 3:
    multi_hop = torch.load(args.data_path + 'multi_hop_inters_' + args.dataset + '.pt')

train_data, valid_y_data, test_y_data, n_user, n_item = data_utils.data_load(train_path, valid_path, test_path)
# .toarray() replaces the deprecated sparse-matrix ``.A`` shortcut.
train_dataset = data_utils.DataDiffusion2(torch.FloatTensor(train_data.toarray()), multi_hop)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size)
test_loader = DataLoader(train_dataset, batch_size=args.batch_size)

# Items seen in train or validation are masked out when scoring the test set.
mask_tv = train_data + valid_y_data

print('data ready.')
### CREATE DIFFUSION ###
if args.mean_type == 'x0':
    mean_type = gd.ModelMeanType.START_X
elif args.mean_type == 'eps':
    # This branch had no body, which made the whole script a SyntaxError.
    mean_type = gd.ModelMeanType.EPSILON
else:
    raise ValueError("Unimplemented mean type %s" % args.mean_type)

diffusion = gd.GaussianDiffusion(mean_type, args.noise_schedule,
                                 args.noise_scale, args.noise_min, args.noise_max, args.steps, device)
diffusion.to(device)

### CREATE DNN ###
model_path = "saved_models/"
if args.dataset == "anime":
    model_name = "CAM_2hops_anime.pth"
elif args.dataset == "yelp2018":
    model_name = "CAM_2hops_yelp2018.pth"
elif args.dataset == "ML-1M":
    model_name = "CAM_3hops.pth"


def evaluate(data_loader, data_te, mask_his, topN, model):
    """Rank items for every user and score the ranking against ``data_te``.

    Args:
        data_loader: yields ``(interactions, multi_hop)`` batches in user order
            (it must not shuffle, because batches are aligned with rows of
            ``mask_his`` by position).
        data_te: sparse user-item matrix with the ground-truth interactions.
        mask_his: sparse user-item matrix of already-seen interactions; these
            items are excluded from the ranking.
        topN: list of cut-offs; the largest one must be the last element.
        model: the denoising network passed to ``diffusion.p_sample``.

    Returns:
        The ``(precision, recall, NDCG, MRR)`` tuple produced by
        ``evaluate_utils.computeTopNAccuracy``.
    """
    model.eval()
    e_N = mask_his.shape[0]

    predict_items = []
    target_items = [data_te[i, :].nonzero()[1].tolist() for i in range(e_N)]

    with torch.no_grad():
        for batch_idx, (batch, batch_2) in enumerate(data_loader):
            # Rows of the history mask that belong to this batch.
            start = batch_idx * args.batch_size
            his_data = mask_his[list(range(e_N))[start:start + len(batch)]]
            batch = batch.to(device)
            batch_2 = batch_2.to(device)
            prediction = diffusion.p_sample(model, batch, batch_2, args.sampling_steps, args.sampling_noise)
            # Never recommend items the user has already interacted with.
            prediction[his_data.nonzero()] = -np.inf

            _, indices = torch.topk(prediction, topN[-1])
            indices = indices.cpu().numpy().tolist()
            predict_items.extend(indices)

    test_results = evaluate_utils.computeTopNAccuracy(target_items, predict_items, topN)

    return test_results


# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
# NOTE(security): the checkpoint is a pickled module; only load trusted files.
model = torch.load(model_path + model_name, map_location=device).to(device)  # batch=50

print("Initial models ready.")

# NOTE(security): eval() on a CLI string; ast.literal_eval would be safer.
valid_results = evaluate(test_loader, valid_y_data, train_data, eval(args.topN), model)
test_results = evaluate(test_loader, test_y_data, mask_tv, eval(args.topN), model)
evaluate_utils.print_results(None, valid_results, test_results)
random_seed = 1
torch.manual_seed(random_seed)  # cpu
torch.cuda.manual_seed(random_seed)  # gpu
np.random.seed(random_seed)  # numpy
random.seed(random_seed)  # random and transforms
torch.backends.cudnn.deterministic = True  # cudnn


def worker_init_fn(worker_id):
    """Seed NumPy in a DataLoader worker with ``random_seed + worker_id``."""
    np.random.seed(random_seed + worker_id)


def seed_worker(worker_id):
    """Seed NumPy in a DataLoader worker from torch's initial seed."""
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)


def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` treats every non-empty string (including "False") as True,
    so boolean options need an explicit parser.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)


parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='ML-1M', help='choose the dataset')
parser.add_argument('--data_path', type=str, default='./datasets/', help='load data path')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
parser.add_argument('--weight_decay', type=float, default=0.0)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--epochs', type=int, default=1000, help='upper epoch limit')
parser.add_argument('--topN', type=str, default='[10, 20, 50, 100]')
parser.add_argument('--tst_w_val', action='store_true', help='test with validation')
parser.add_argument('--cuda', action='store_true', help='use CUDA')
parser.add_argument('--gpu', type=str, default='0', help='gpu card ID')
parser.add_argument('--save_path', type=str, default='./saved_models/', help='save model path')
parser.add_argument('--log_name', type=str, default='log', help='the log name')
parser.add_argument('--round', type=int, default=1, help='record the experiment')

# params for the model
parser.add_argument('--time_type', type=str, default='cat', help='cat or add')
parser.add_argument('--norm', type=_str2bool, default=False, help='Normalize the input or not')
parser.add_argument('--emb_size', type=int, default=10, help='timestep embedding size')

# params for diffusion
parser.add_argument('--mean_type', type=str, default='x0', help='MeanType for diffusion: x0, eps')
parser.add_argument('--steps', type=int, default=20, help='diffusion steps')
parser.add_argument('--noise_schedule', type=str, default='linear-var', help='the schedule for noise generating')
parser.add_argument('--noise_scale', type=float, default=0.01, help='noise scale for noise generating')
parser.add_argument('--noise_min', type=float, default=0.001, help='noise lower bound for noise generating')
parser.add_argument('--noise_max', type=float, default=0.01, help='noise upper bound for noise generating')
parser.add_argument('--sampling_noise', type=_str2bool, default=False, help='sampling with noise or not')
parser.add_argument('--sampling_steps', type=int, default=0, help='steps of the forward process during inference')
parser.add_argument('--reweight', type=_str2bool, default=True, help='assign different weight to different timestep or not')

print("torch version:", torch.__version__)

args = parser.parse_args()
print("args:", args)
# Honour --gpu (its default '0' matches the value that used to be hard-coded).
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)
print("Starting time: ", time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))

### DATA LOAD ###
train_path = args.data_path + 'train_list_' + args.dataset + '.npy'
valid_path = args.data_path + 'valid_list_' + args.dataset + '.npy'
test_path = args.data_path + 'test_list_' + args.dataset + '.npy'

n_hop = 3  # The number of hops neighbors, e.g. n_hop=3 means three hops neighbors are taken into account
print("{}-hop neighbors are taken into account".format(n_hop))
# NOTE: the hop files below are the ML-1M ones regardless of --dataset.
if n_hop == 2:
    sec_hop = torch.load(args.data_path + 'sec_hop_inters_ML_1M.pt')
    multi_hop = sec_hop
elif n_hop == 3:
    multi_hop = torch.load(args.data_path + 'multi_hop_inters_ML_1M.pt')

train_data, valid_y_data, test_y_data, n_user, n_item = data_utils.data_load(train_path, valid_path, test_path)
# .toarray() replaces the deprecated sparse-matrix ``.A`` shortcut.
train_dataset = data_utils.DataDiffusion(torch.FloatTensor(train_data.toarray()))
# Both loaders below keep shuffle=False so that interaction batches and
# multi-hop batches stay aligned row by row.
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True, shuffle=False, num_workers=0,
                          worker_init_fn=worker_init_fn)
test_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False)

train_loader_sec_hop = DataLoader(multi_hop, batch_size=args.batch_size, pin_memory=True, shuffle=False, num_workers=0,
                                  worker_init_fn=worker_init_fn)
test_loader_sec_hop = DataLoader(multi_hop, batch_size=args.batch_size, shuffle=False)

if args.tst_w_val:
    tv_dataset = data_utils.DataDiffusion(
        torch.FloatTensor(train_data.toarray()) + torch.FloatTensor(valid_y_data.toarray()))
    test_twv_loader = DataLoader(tv_dataset, batch_size=args.batch_size, shuffle=False)
# Items seen in train or validation are masked out when scoring the test set.
mask_tv = train_data + valid_y_data

print('data is ready.')

### Build Gaussian Diffusion ###
if args.mean_type == 'x0':
    mean_type = gd.ModelMeanType.START_X
elif args.mean_type == 'eps':
    mean_type = gd.ModelMeanType.EPSILON
else:
    raise ValueError("Unimplemented mean type %s" % args.mean_type)
diffusion = gd.GaussianDiffusion(mean_type, args.noise_schedule,
                                 args.noise_scale, args.noise_min, args.noise_max, args.steps, device).to(device)

# Build model
if n_hop == 2:
    model = CAM_AE(16, 2, 2, n_item, args.emb_size).to(device)
elif n_hop == 3:
    model = CAM_AE_multihops(16, 4, 2, n_item, args.emb_size).to(device)
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
print("models are ready.")


def evaluate(data_loader, data_loader_sec_hop, data_te, mask_his, topN):
    """Rank items for every user and score the ranking against ``data_te``.

    Args:
        data_loader: yields interaction batches in user order (no shuffling).
        data_loader_sec_hop: yields the matching multi-hop batches, aligned
            with ``data_loader`` batch by batch.
        data_te: sparse user-item matrix with the ground-truth interactions.
        mask_his: sparse user-item matrix of already-seen interactions; these
            items are excluded from the ranking.
        topN: list of cut-offs; the largest one must be the last element.

    Returns:
        The ``(precision, recall, NDCG, MRR)`` tuple produced by
        ``evaluate_utils.computeTopNAccuracy``.
    """
    model.eval()
    e_N = mask_his.shape[0]

    predict_items = []
    target_items = [data_te[i, :].nonzero()[1].tolist() for i in range(e_N)]

    with torch.no_grad():
        for batch_idx, (batch, batch_2) in enumerate(zip(data_loader, data_loader_sec_hop)):
            # Rows of the history mask that belong to this batch.
            start = batch_idx * args.batch_size
            his_data = mask_his[list(range(e_N))[start:start + len(batch)]]
            batch = batch.to(device)
            batch_2 = batch_2.to(device)
            prediction = diffusion.p_sample(model, batch, batch_2, args.sampling_steps, args.sampling_noise)
            # Never recommend items the user has already interacted with.
            prediction[his_data.nonzero()] = -np.inf
            _, indices = torch.topk(prediction, topN[-1])
            indices = indices.cpu().numpy().tolist()
            predict_items.extend(indices)

    test_results = evaluate_utils.computeTopNAccuracy(target_items, predict_items, topN)

    return test_results


if __name__ == '__main__':

    best_recall, best_epoch = -100, 0
    # Initialised so the final report works even if no evaluation ever ran.
    best_results = None
    best_test_results = None
    print("Start training...")
    for epoch in range(1, args.epochs + 1):
        # Early stopping: no validation improvement for 20 epochs.
        if epoch - best_epoch >= 20:
            print('-' * 18)
            print('Exiting from training early')
            break

        model.train()
        start_time = time.time()

        batch_count = 0
        total_loss = 0.0

        for batch, batch_2 in zip(train_loader, train_loader_sec_hop):
            batch = batch.to(device)
            batch_2 = batch_2.to(device)
            batch_count += 1
            optimizer.zero_grad()
            losses = diffusion.training_losses(model, batch, batch_2, args.reweight)
            loss = losses["loss"].mean()
            # .item() detaches the value; summing the tensor itself would keep
            # every batch's autograd graph alive for the whole epoch.
            total_loss += loss.item()
            loss.backward()
            optimizer.step()

        if epoch % 5 == 0:
            # NOTE(security): eval() on a CLI string; ast.literal_eval would be safer.
            valid_results = evaluate(test_loader, test_loader_sec_hop, valid_y_data, train_data, eval(args.topN))
            if args.tst_w_val:
                test_results = evaluate(test_twv_loader, test_loader_sec_hop, test_y_data, mask_tv, eval(args.topN))
            else:
                test_results = evaluate(test_loader, test_loader_sec_hop, test_y_data, mask_tv, eval(args.topN))
            evaluate_utils.print_results(None, valid_results, test_results)

            if valid_results[1][1] > best_recall:  # recall@20 as selection
                best_recall, best_epoch = valid_results[1][1], epoch
                best_results = valid_results
                best_test_results = test_results

                os.makedirs(args.save_path, exist_ok=True)
                # The former "_dims{}" field referenced args.dims, which is not
                # a defined argument and crashed the first checkpoint save.
                torch.save(model,
                           '{}{}_lr{}_wd{}_bs{}_emb{}_{}_steps{}_scale{}_min{}_max{}_sample{}_reweight{}_{}.pth'
                           .format(args.save_path, args.dataset, args.lr, args.weight_decay, args.batch_size,
                                   args.emb_size, args.mean_type,
                                   args.steps, args.noise_scale, args.noise_min, args.noise_max, args.sampling_steps,
                                   args.reweight, args.log_name))

        print("Runing Epoch {:03d} ".format(epoch) + 'train loss {:.4f}'.format(total_loss) + " costs " + time.strftime(
            "%H: %M: %S", time.gmtime(time.time() - start_time)))
        print('---' * 18)

    print('===' * 18)
    print("End. Best Epoch {:03d} ".format(best_epoch))
    evaluate_utils.print_results(None, best_results, best_test_results)
    print("End time: ", time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
interactions 39 | self.decoder = nn.Linear(self.dim_inters+emb_size, self.in_dims) 40 | self.encoder2 = nn.Linear(self.in_dims, self.dim_inters) 41 | 42 | # Attention layer 43 | self.self_attentions = nn.ModuleList([ 44 | nn.MultiheadAttention(d_model, num_heads, dropout=0.5, batch_first=True) 45 | for i in range(num_layers) 46 | ]) 47 | 48 | self.time_emb_dim = emb_size 49 | self.d_model = d_model 50 | self.norm1 = nn.LayerNorm(d_model) 51 | self.norm2 = nn.LayerNorm(d_model) 52 | 53 | def forward(self, x, x_sec_hop, timesteps): 54 | 55 | x = self.encoder(x) 56 | h_sec_hop = self.encoder(x_sec_hop) 57 | 58 | time_emb = timestep_embedding(timesteps, self.time_emb_dim).to(x.device) 59 | emb = self.emb_layer(time_emb) 60 | 61 | if self.norm: 62 | x = F.normalize(x) 63 | x = self.drop(x) 64 | h = torch.cat([x, emb], dim=-1) 65 | h = h.unsqueeze(-1) 66 | h = self.first_hop_embedding(h) 67 | 68 | h_sec_hop = torch.cat([h_sec_hop, emb], dim=-1) 69 | h_sec_hop = h_sec_hop.unsqueeze(-1) 70 | h_sec_hop = self.second_hop_embedding(h_sec_hop) 71 | 72 | for i in range(self.num_layers): 73 | 74 | attention_layer = self.self_attentions[i] 75 | attention, attn_output_weights = attention_layer(h_sec_hop, h, h) 76 | 77 | attention = self.drop1(attention) 78 | h = h + attention 79 | #h = self.norm1(h) 80 | h = self.drop2(h) 81 | forward_pass = self.forward_layers[i] 82 | h = forward_pass(h) 83 | 84 | if i != self.num_layers - 1: 85 | h = torch.tanh(h) 86 | 87 | h = self.first_hop_decoding(h) 88 | h = torch.squeeze(h) 89 | h = torch.tanh(h) 90 | h = self.decoder(h) 91 | 92 | return h 93 | 94 | def timestep_embedding(timesteps, dim, max_period=10000): 95 | """ 96 | Create sinusoidal timestep embeddings. 97 | 98 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 99 | These may be fractional. 100 | :param dim: the dimension of the output. 101 | :param max_period: controls the minimum frequency of the embeddings. 
class CAM_AE_multihops(nn.Module):
    """
    CAM-AE_multihops: The neural network architecture for learning the data distribution in the reverse diffusion process.
    Multi-hop neighbors are to be integrated.

    ``forward`` receives the second- and third-hop vectors concatenated along
    the last axis. Each hop is used as the attention query over the first-hop
    tokens in its own pass through the (shared) attention stack, and the two
    results are blended before decoding.

    :param d_model: width of each token fed to the attention layers (d in the paper).
    :param num_heads: number of attention heads (``d_model`` must be divisible by it).
    :param num_layers: number of attention + linear layers.
    :param in_dims: width of the input interaction vector.
    :param emb_size: width of the timestep embedding.
    :param time_type: stored on the instance; not read by ``forward``.
    :param norm: if True, L2-normalize the encoded input before dropout.
    :param dropout: dropout probability for ``drop`` and ``drop2``.
    :param dim_inters: hidden width produced by the encoder (k in the paper;
                       previously hard-coded to 512, which remains the default).
    """
    def __init__(self, d_model, num_heads, num_layers, in_dims, emb_size, time_type="cat", norm=False,
                 dropout=0.5, dim_inters=512):
        super(CAM_AE_multihops, self).__init__()
        self.in_dims = in_dims
        self.time_type = time_type
        self.time_emb_dim = emb_size
        self.norm = norm
        self.num_layers = num_layers

        # NOTE: sub-modules keep their original names, shapes and creation
        # order so that saved checkpoints still load. Some of them
        # (in_layers, out_layers, final_out, encoder2, norm1, norm2) are not
        # used by forward().
        self.emb_layer = nn.Linear(self.time_emb_dim, self.time_emb_dim)
        self.in_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(2)])
        self.out_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(2)])
        self.forward_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_layers)])

        self.dim_inters = dim_inters  # The hidden dimension, k in the paper
        self.first_hop_embedding = nn.Linear(1, d_model)  # Expand dimension, d in the paper
        self.first_hop_decoding = nn.Linear(d_model, 1)
        self.second_hop_embedding = nn.Linear(1, d_model)
        self.third_hop_embedding = nn.Linear(1, d_model)
        self.final_out = nn.Linear(self.dim_inters + emb_size, self.dim_inters)

        self.drop = nn.Dropout(dropout)
        self.drop1 = nn.Dropout(0.5)
        self.drop2 = nn.Dropout(dropout)

        self.encoder = nn.Linear(self.in_dims, self.dim_inters)
        self.decoder = nn.Linear(self.dim_inters + emb_size, self.in_dims)
        # Input width 900 is kept verbatim: the layer is unused, but changing
        # its shape would break loading of existing state dicts.
        self.encoder2 = nn.Linear(900, self.dim_inters)

        self.self_attentions = nn.ModuleList([
            nn.MultiheadAttention(d_model, num_heads, dropout=0.5, batch_first=True)
            for _ in range(num_layers)
        ])

        self.d_model = d_model
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def _tokens(self, encoded, emb, embedding_layer):
        """Append the time embedding and lift every scalar to a d_model token."""
        return embedding_layer(torch.cat([encoded, emb], dim=-1).unsqueeze(-1))

    def _attend(self, query, h, scale):
        """Run the attention stack with `query` attending over `h`; return the updated tokens."""
        last = self.num_layers - 1
        for i in range(self.num_layers):
            attention, _ = self.self_attentions[i](query, h, h)
            # F.normalize defaults to dim=1, i.e. the token axis of [N, K, d_model].
            attention = F.normalize(attention)
            attention = self.drop1(attention)
            h = h + scale * attention
            h = self.drop2(h)
            h = self.forward_layers[i](h)
            if i != last:
                h = torch.tanh(h)
        return h

    def forward(self, x, x_sec_hop, timesteps):
        """
        Run the network on a batch.

        :param x: [N x in_dims] tensor of (noisy) first-hop interactions.
        :param x_sec_hop: [N x 2*in_dims] tensor; columns ``[:in_dims]`` are the
                          second-hop vector and the remaining columns the
                          third-hop vector (both go through ``encoder``).
        :param timesteps: 1-D tensor of N diffusion step indices.
        :return: an [N x in_dims] tensor.
        """
        x = self.encoder(x)
        h_sec_hop = self.encoder(x_sec_hop[:, 0:self.in_dims])
        h_third_hop = self.encoder(x_sec_hop[:, self.in_dims:])

        time_emb = timestep_embedding(timesteps, self.time_emb_dim).to(x.device)
        emb = self.emb_layer(time_emb)
        if self.norm:
            x = F.normalize(x)
        x = self.drop(x)

        h = self._tokens(x, emb, self.first_hop_embedding)
        h_sec_hop = self._tokens(h_sec_hop, emb, self.second_hop_embedding)
        h_third_hop = self._tokens(h_third_hop, emb, self.third_hop_embedding)

        # Both passes start from the same first-hop tokens. _attend never
        # modifies its input in place, so no explicit clone is required.
        h_second = self._attend(h_sec_hop, h, 0.15)
        h_third = self._attend(h_third_hop, h, 0.5)

        b = 0.9  # blend weight of the second-hop branch
        h = b * h_second + (1 - b) * h_third
        h = self.first_hop_decoding(h)
        # Only drop the trailing singleton axis. A bare torch.squeeze() would
        # also remove the batch axis when N == 1 and change the output shape.
        h = h.squeeze(-1)
        h = torch.tanh(h)
        h = self.decoder(h)

        return h


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd dim: pad one zero column so the width is exactly `dim`.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
class ModelMeanType(enum.Enum):
    """Quantity the denoising model is trained to predict."""

    START_X = enum.auto()  # the model predicts x_0
    EPSILON = enum.auto()  # the model predicts epsilon


class GaussianDiffusion(nn.Module):
    """
    Gaussian diffusion process: forward noising, reverse sampling and the
    training loss for a denoising model called as ``model(x_t, x_sec_hop, t)``.

    :param mean_type: a ModelMeanType selecting what the model predicts.
    :param noise_schedule: "linear", "linear-var", "cosine" or "binomial".
    :param noise_scale: multiplier for the schedule bounds; 0 disables noising
                        (no betas are computed in that case).
    :param noise_min: lower schedule bound (before scaling).
    :param noise_max: upper schedule bound (before scaling).
    :param steps: number of diffusion steps.
    :param device: device on which the schedule tensors are stored.
    :param history_num_per_term: number of recent losses kept per timestep for
                                 importance sampling of timesteps.
    :param beta_fixed: if True, overwrite the first beta with 1e-5.
    """

    def __init__(self, mean_type, noise_schedule, noise_scale, noise_min, noise_max,
                 steps, device, history_num_per_term=10, beta_fixed=True):
        # Initialise nn.Module before any attribute is assigned.
        super(GaussianDiffusion, self).__init__()

        self.mean_type = mean_type
        self.noise_schedule = noise_schedule
        self.noise_scale = noise_scale
        self.noise_min = noise_min
        self.noise_max = noise_max
        self.steps = steps
        self.device = device

        # Per-timestep ring of recent losses and how many slots are filled.
        self.history_num_per_term = history_num_per_term
        self.Lt_history = th.zeros(steps, history_num_per_term, dtype=th.float64).to(device)
        self.Lt_count = th.zeros(steps, dtype=int).to(device)

        if noise_scale != 0.:
            self.betas = th.tensor(self.get_betas(), dtype=th.float64).to(self.device)
            if beta_fixed:
                # Deep Unsupervised Learning using Nonequilibrium Thermodynamics 2.4.1:
                # the variance \beta_1 of the first step is fixed to a small
                # constant to prevent overfitting.
                self.betas[0] = 0.00001
            assert len(self.betas.shape) == 1, "betas must be 1-D"
            assert len(self.betas) == self.steps, "num of betas must equal to diffusion steps"
            assert (self.betas > 0).all() and (self.betas <= 1).all(), "betas out of range"

            self.calculate_for_diffusion()

    def get_betas(self):
        """
        Given the schedule name, create the betas for the diffusion process.

        :return: a sequence of ``self.steps`` betas.
        :raises NotImplementedError: for an unknown schedule name.
        """
        if self.noise_schedule == "linear" or self.noise_schedule == "linear-var":
            start = self.noise_scale * self.noise_min
            end = self.noise_scale * self.noise_max
            if self.noise_schedule == "linear":
                return np.linspace(start, end, self.steps, dtype=np.float64)
            else:
                return betas_from_linear_variance(self.steps, np.linspace(start, end, self.steps, dtype=np.float64))
        elif self.noise_schedule == "cosine":
            return betas_for_alpha_bar(
                self.steps,
                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
            )
        elif self.noise_schedule == "binomial":  # Deep Unsupervised Learning using Nonequilibrium Thermodynamics 2.4.1
            ts = np.arange(self.steps)
            betas = [1 / (self.steps - t + 1) for t in ts]
            return betas
        else:
            raise NotImplementedError(f"unknown beta schedule: {self.noise_schedule}!")

    def calculate_for_diffusion(self):
        """Precompute the schedule-derived tensors used by q_sample and the posterior."""
        alphas = 1.0 - self.betas
        self.alphas_cumprod = th.cumprod(alphas, axis=0).to(self.device)
        # new_ones/new_zeros inherit dtype and device from alphas_cumprod, so
        # the concatenation does not rely on implicit type promotion.
        self.alphas_cumprod_prev = th.cat(
            [self.alphas_cumprod.new_ones(1), self.alphas_cumprod[:-1]]
        ).to(self.device)  # alpha_{t-1}
        self.alphas_cumprod_next = th.cat(
            [self.alphas_cumprod[1:], self.alphas_cumprod.new_zeros(1)]
        ).to(self.device)  # alpha_{t+1}
        assert self.alphas_cumprod_prev.shape == (self.steps,)

        self.sqrt_alphas_cumprod = th.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = th.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = th.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = th.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = th.sqrt(1.0 / self.alphas_cumprod - 1)

        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )

        # posterior_variance[0] is 0 (alphas_cumprod_prev[0] == 1), so its log
        # is replaced by the value at index 1.
        self.posterior_log_variance_clipped = th.log(
            th.cat([self.posterior_variance[1].unsqueeze(0), self.posterior_variance[1:]])
        )
        self.posterior_mean_coef1 = (
            self.betas * th.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * th.sqrt(alphas)
            / (1.0 - self.alphas_cumprod)
        )

    def p_sample(self, model, x_start, x_sec_hop, steps, sampling_noise=False):
        """
        Noise ``x_start`` for ``steps`` forward steps, then run the reverse process.

        Note that the reverse loop always iterates over all ``self.steps``
        timesteps, independently of ``steps``.

        :param model: callable ``model(x_t, x_sec_hop, t)``.
        :param x_start: batch of clean inputs.
        :param x_sec_hop: extra conditioning passed through to the model.
        :param steps: number of forward noising steps (0 means no noising).
        :param sampling_noise: if True, add posterior noise at every step except t == 0.
        :return: the final sample, same shape as ``x_start``.
        """
        assert steps <= self.steps, "Too much steps in inference."
        batch = x_start.shape[0]
        if steps == 0:
            x_t = x_start
        else:
            t = th.full((batch,), steps - 1, dtype=th.long, device=x_start.device)
            x_t = self.q_sample(x_start, t)

        indices = list(range(self.steps))[::-1]

        if self.noise_scale == 0.:
            for i in indices:
                t = th.full((x_t.shape[0],), i, dtype=th.long, device=x_start.device)
                x_t = model(x_t, x_sec_hop, t)
            return x_t

        for i in indices:
            t = th.full((x_t.shape[0],), i, dtype=th.long, device=x_start.device)
            out = self.p_mean_variance(model, x_t, x_sec_hop, t)
            if sampling_noise:
                noise = th.randn_like(x_t)
                nonzero_mask = (
                    (t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1)))
                )  # no noise when t == 0
                x_t = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
            else:
                x_t = out["mean"]
        return x_t

    def training_losses(self, model, x_start, x_sec_hop, reweight=False):
        """
        Compute the per-sample training loss and record it in the loss history.

        :param model: callable ``model(x_t, x_sec_hop, ts)``.
        :param x_start: batch of clean inputs.
        :param x_sec_hop: extra conditioning passed through to the model.
        :param reweight: if True, weight the loss per timestep; otherwise use
                         the plain per-sample MSE.
        :return: dict with key ``"loss"`` holding one value per batch element,
                 already divided by the timestep sampling probability ratio.
        :raises ValueError: if the loss history cannot be updated.
        """
        batch_size, device = x_start.size(0), x_start.device
        ts, pt = self.sample_timesteps(batch_size, device, 'importance')
        noise = th.randn_like(x_start)
        if self.noise_scale != 0.:
            x_t = self.q_sample(x_start, ts, noise)
        else:
            x_t = x_start

        terms = {}
        model_output = model(x_t, x_sec_hop, ts)
        target = {
            ModelMeanType.START_X: x_start,
            ModelMeanType.EPSILON: noise,
        }[self.mean_type]
        assert model_output.shape == target.shape == x_start.shape

        mse = mean_flat((target - model_output) ** 2)

        # Default to the plain MSE. Previously `loss` was only assigned in the
        # reweight branches, so reweight=False raised UnboundLocalError.
        loss = mse
        if reweight:
            if self.mean_type == ModelMeanType.START_X:
                weight = self.SNR(ts - 1) - self.SNR(ts)
                weight = th.where((ts == 0), th.ones_like(weight), weight)
            elif self.mean_type == ModelMeanType.EPSILON:
                weight = (1 - self.alphas_cumprod[ts]) / ((1 - self.alphas_cumprod_prev[ts]) ** 2 * (1 - self.betas[ts]))
                weight = th.where((ts == 0), th.ones_like(weight), weight)
                likelihood = mean_flat((x_start - self._predict_xstart_from_eps(x_t, ts, model_output)) ** 2 / 2.0)
                loss = th.where((ts == 0), likelihood, mse)
        else:
            weight = th.ones(len(target), device=device)

        terms["loss"] = weight * loss

        self._update_loss_history(ts, terms["loss"])

        terms["loss"] /= pt
        return terms

    def _update_loss_history(self, ts, losses):
        """Store each sample's loss in the history row of its timestep (oldest entry dropped when full)."""
        for t, loss_t in zip(ts, losses):
            if self.Lt_count[t] == self.history_num_per_term:
                # Shift the row left by one; clone only this row so the
                # overlapping copy is safe without duplicating the whole table.
                self.Lt_history[t, :-1] = self.Lt_history[t, 1:].clone()
                self.Lt_history[t, -1] = loss_t.detach()
            else:
                try:
                    self.Lt_history[t, self.Lt_count[t]] = loss_t.detach()
                    self.Lt_count[t] += 1
                except (IndexError, RuntimeError) as err:
                    raise ValueError(
                        f"failed to record loss {loss_t} for timestep {t} "
                        f"(count={self.Lt_count[t]})"
                    ) from err

    def sample_timesteps(self, batch_size, device, method='uniform', uniform_prob=0.001):
        """
        Draw one timestep per batch element.

        :param batch_size: number of timesteps to draw.
        :param device: device for the returned tensors.
        :param method: 'uniform', or 'importance' (falls back to uniform until
                       every timestep has a full loss history).
        :param uniform_prob: probability mass mixed in uniformly for importance sampling.
        :return: ``(t, pt)`` where ``pt`` is the sampling probability of each
                 drawn timestep multiplied by the number of timesteps
                 (all ones for uniform sampling).
        :raises ValueError: for an unknown method.
        """
        if method == 'importance':  # importance sampling
            if not (self.Lt_count == self.history_num_per_term).all():
                return self.sample_timesteps(batch_size, device, method='uniform')

            Lt_sqrt = th.sqrt(th.mean(self.Lt_history ** 2, axis=-1))
            pt_all = Lt_sqrt / th.sum(Lt_sqrt)
            pt_all *= 1 - uniform_prob
            pt_all += uniform_prob / len(pt_all)

            # abs() so that a sum below 1 is caught as well as one above it.
            assert th.abs(pt_all.sum(-1) - 1.) < 1e-5

            t = th.multinomial(pt_all, num_samples=batch_size, replacement=True)
            pt = pt_all.gather(dim=0, index=t) * len(pt_all)

            # Match the uniform branch, which returns tensors on `device`.
            return t.to(device), pt.to(device)

        elif method == 'uniform':  # uniform sampling
            t = th.randint(0, self.steps, (batch_size,), device=device).long()
            pt = th.ones_like(t).float()

            return t, pt

        else:
            raise ValueError

    def q_sample(self, x_start, t, noise=None):
        """
        Sample x_t from x_start at timesteps ``t``.

        :param x_start: batch of clean inputs.
        :param t: tensor of timestep indices, one per batch element.
        :param noise: optional noise of the same shape as ``x_start``;
                      standard normal noise is drawn when omitted.
        :return: the noised batch.
        """
        if noise is None:
            noise = th.randn_like(x_start)
        assert noise.shape == x_start.shape
        return (
            self._extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + self._extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:
            q(x_{t-1} | x_t, x_0)

        :return: ``(mean, variance, clipped log-variance)``, each broadcast to ``x_t.shape``.
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            self._extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + self._extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = self._extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = self._extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, model, x, x_sec_hop, t):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.

        :return: dict with keys "mean", "variance", "log_variance", "pred_xstart".
        :raises NotImplementedError: for an unsupported mean type.
        """
        B = x.shape[0]
        assert t.shape == (B, )
        model_output = model(x, x_sec_hop, t)

        model_variance = self._extract_into_tensor(self.posterior_variance, t, x.shape)
        model_log_variance = self._extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)

        if self.mean_type == ModelMeanType.START_X:
            pred_xstart = model_output
        elif self.mean_type == ModelMeanType.EPSILON:
            pred_xstart = self._predict_xstart_from_eps(x, t, eps=model_output)
        else:
            raise NotImplementedError(self.mean_type)

        model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)

        assert (
            model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        )

        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
        }

    def _predict_xstart_from_eps(self, x_t, t, eps):
        """Recover the x_0 estimate implied by a predicted noise ``eps`` at timesteps ``t``."""
        assert x_t.shape == eps.shape
        return (
            self._extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - self._extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def SNR(self, t):
        """
        Compute the signal-to-noise ratio for the given timestep indices.

        Moves ``self.alphas_cumprod`` to ``t.device`` as a side effect.
        """
        self.alphas_cumprod = self.alphas_cumprod.to(t.device)
        return self.alphas_cumprod[t] / (1 - self.alphas_cumprod[t])

    def _extract_into_tensor(self, arr, timesteps, broadcast_shape):
        """
        Extract values from a 1-D tensor for a batch of indices.

        :param arr: the 1-D tensor.
        :param timesteps: a tensor of indices into the tensor to extract.
        :param broadcast_shape: a larger shape of K dimensions with the batch
                                dimension equal to the length of timesteps.
        :return: a float32 tensor expanded to ``broadcast_shape``.
        """
        arr = arr.to(timesteps.device)
        res = arr[timesteps].float()
        while len(res.shape) < len(broadcast_shape):
            res = res[..., None]
        return res.expand(broadcast_shape)


def betas_from_linear_variance(steps, variance, max_beta=0.999):
    """
    Convert a cumulative variance schedule into per-step betas.

    :param steps: number of betas to produce.
    :param variance: array of cumulative variances (1 - alpha_bar), one per step.
    :param max_beta: upper clip applied to every beta after the first.
    :return: a numpy array of ``steps`` betas.
    """
    alpha_bar = 1 - variance
    betas = []
    betas.append(1 - alpha_bar[0])
    for i in range(1, steps):
        betas.append(min(1 - alpha_bar[i] / alpha_bar[i - 1], max_beta))
    return np.array(betas)


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    :return: a numpy array of ``num_diffusion_timesteps`` betas.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians.

    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, th.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for th.exp().
    logvar1, logvar2 = [
        x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + th.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
    )


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))