├── DataHander.py ├── LICENSE.txt ├── README.md ├── Utils ├── TimeLogger.py └── Utils.py ├── datasets ├── ciao │ └── dataset.pkl.zip ├── epinions │ └── dataset.pkl.zip └── yelp │ └── dataset.pkl.zip ├── framework_00.png ├── main.py ├── models ├── diffusion_process.py └── model.py ├── param.py ├── scripts ├── run_ciao.sh ├── run_epinions.sh └── run_yelp.sh └── utils.py /DataHander.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import DataLoader 3 | import torch.utils.data as data 4 | import torch as t 5 | from utils import load_data, load_model, save_model, fix_random_seed_as 6 | import scipy.sparse as sp 7 | from scipy.sparse import csr_matrix, coo_matrix, dok_matrix 8 | from param import args 9 | import pickle 10 | import torch, dgl 11 | 12 | class DataHandler: 13 | def __init__(self): 14 | predir = '' 15 | if args.dataset == 'yelp': 16 | predir = './datasets/yelp/' 17 | self.datapath = predir + 'dataset.pkl' 18 | elif args.dataset == 'ciao': 19 | predir = './datasets/ciao/' 20 | self.datapath = predir + 'dataset.pkl' 21 | elif args.dataset == 'epinions': 22 | predir = './datasets/epinions/' 23 | self.datapath = predir + 'dataset.pkl' 24 | self.predir = predir 25 | 26 | 27 | 28 | def loadOneFile(self,data_path): 29 | with open(data_path, 'rb') as f: 30 | data = pickle.load(f) 31 | return data 32 | 33 | def LoadData(self): 34 | self.dataset = self.loadOneFile(self.datapath) 35 | trnMat = self.dataset['train'] 36 | tstMat = self.dataset['test'] 37 | valMat = self.dataset['val'] 38 | trainset = TrnData(trnMat) 39 | testset = TstData(tstMat, trnMat) 40 | valset = TstData(valMat,trnMat) 41 | self.n_user, self.n_item = self.dataset['userCount'], self.dataset['itemCount'] 42 | args.user, args.item = self.n_user, self.n_item 43 | 44 | self.trainloader = DataLoader( 45 | dataset=trainset, 46 | batch_size=args.batch_size, 47 | shuffle=True, 48 | num_workers=args.num_workers 49 | ) 50 | self.valloader = DataLoader( 51 | dataset=valset, 52 | batch_size=args.test_batch_size, 53 | shuffle=False, 54 | num_workers=args.num_workers 55 | ) 56 | self.testloader = DataLoader( 57 | dataset=testset, 58 | batch_size=args.test_batch_size, 59 | shuffle=False, 60 | num_workers=args.num_workers 61 | ) 62 | 63 | 64 | self.uu_graph = dgl.from_scipy(self.dataset['trust']) 65 | uimat = self.dataset['train'].tocsr() 66 | self.ui_graph = self.makeBiAdj(uimat,self.n_user,self.n_item) 67 | 68 | def makeBiAdj(self, mat,n_user,n_item): 69 | a = sp.csr_matrix((n_user, n_user)) 70 | b = sp.csr_matrix((n_item, n_item)) 71 | mat = sp.vstack([sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])]) 72 | mat = (mat != 0) * 1.0 73 | mat = mat.tocoo() 74 | edge_src,edge_dst = mat.nonzero() 75 | ui_graph = dgl.graph(data=(edge_src, edge_dst), 76 | idtype=torch.int32, 77 | num_nodes=mat.shape[0] 78 | ) 79 | 80 | return ui_graph 81 | 82 | # def normalizeAdj(self, mat): 83 | # degree = np.array(mat.sum(axis=-1)) 84 | # dInvSqrt = np.reshape(np.power(degree, -0.5), [-1]) 85 | # dInvSqrt[np.isinf(dInvSqrt)] = 0.0 86 | # dInvSqrtMat = sp.diags(dInvSqrt) 87 | # return mat.dot(dInvSqrtMat).transpose().dot(dInvSqrtMat).tocoo() 88 | 89 | class TrnData(data.Dataset): 90 | def __init__(self, coomat): 91 | self.rows = coomat.row 92 | self.cols = coomat.col 93 | self.dokmat = coomat.todok() 94 | self.negs = np.zeros(len(self.rows)).astype(np.int32) 95 | 96 | def negSampling(self): 97 | for i in range(len(self.rows)): 98 | u = self.rows[i] 99 | while True: 100 | 
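# Rejection sampling: redraw random item ids until one is found that the user
# has no observed interaction with; it then serves as the negative sample.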
iNeg = np.random.randint(args.item) 101 | if (u, iNeg) not in self.dokmat: 102 | break 103 | self.negs[i] = iNeg 104 | 105 | def __len__(self): 106 | return len(self.rows) 107 | 108 | def __getitem__(self, idx): 109 | return self.rows[idx], self.cols[idx], self.negs[idx] 110 | 111 | class TstData(data.Dataset): 112 | def __init__(self, coomat, trnMat): 113 | self.csrmat = (trnMat.tocsr() != 0) * 1.0 114 | 115 | tstLocs = [None] * coomat.shape[0] 116 | tstUsrs = set() 117 | for i in range(len(coomat.data)): 118 | row = coomat.row[i] 119 | col = coomat.col[i] 120 | if tstLocs[row] is None: 121 | tstLocs[row] = list() 122 | tstLocs[row].append(col) 123 | tstUsrs.add(row) 124 | tstUsrs = np.array(list(tstUsrs)) 125 | self.tstUsrs = tstUsrs 126 | self.tstLocs = tstLocs 127 | 128 | def __len__(self): 129 | return len(self.tstUsrs) 130 | 131 | def __getitem__(self, idx): 132 | return self.tstUsrs[idx], np.reshape(self.csrmat[self.tstUsrs[idx]].toarray(), [-1]) 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RecDiff: Diffusion Model for Social Recommendation 2 | ![bg](https://github.com/Zongwei9888/Experiment_Images/blob/b264fe0bae60741d88bf58f249da99c1a9272bb8/RecDiff_images/Recdiff.jpeg) 3 | This is the PyTorch-based implementation of the RecDiff model proposed in the following paper: 4 | >Diffusion Model for Social Recommendation 5 | ![model](./framework_00.png) 6 | ## Abstract 7 | Social recommendation has emerged as a powerful approach to enhance personalized recommendations by leveraging the social connections among users, such as following and friend relations observed in online social platforms. The fundamental assumption of social recommendation is that socially-connected users exhibit homophily in their preference patterns. This means that users connected by social ties tend to have similar tastes in user-item activities, such as rating and purchasing. However, this assumption is not always valid due to the presence of irrelevant and false social ties, which can contaminate user embeddings and adversely affect recommendation accuracy. To address this challenge, we propose a novel diffusion-based social denoising framework for recommendation (RecDiff). Our approach utilizes a simple yet effective hidden-space diffusion paradigm to alleviate the noise effects in the compressed and dense representation space. By performing multi-step noise diffusion and removal, RecDiff possesses a robust ability to identify and eliminate noise from the encoded user representations, even when the noise levels vary. The diffusion module is optimized in a downstream task-aware manner, thereby maximizing its ability to enhance the recommendation process. We conducted extensive experiments to evaluate the efficacy of our framework, and the results demonstrate its superiority in terms of recommendation accuracy, training efficiency, and denoising effectiveness. 8 | 9 | ## Code Structure 10 | . 11 | ├── DataHander.py 12 | ├── main.py 13 | ├── param.py 14 | ├── utils.py 15 | ├── Utils 16 | │ ├── TimeLogger.py 17 | │ ├── Utils.py 18 | ├── models 19 | │ ├── diffusion_process.py 20 | │ ├── model.py 21 | ├── scripts 22 | │ ├── run_ciao.sh 23 | │ ├── run_epinions.sh 24 | │ ├── run_yelp.sh 25 | └── README.md 26 | 27 | ## Environment 28 | - python=3.8 29 | - torch=1.12.1 30 | - numpy=1.23.1 31 | - scipy=1.9.1 32 | - dgl=1.0.2+cu113 33 | ## Datasets 34 | Our experiments are conducted on three benchmark datasets collected from the Ciao, Epinions and Yelp online platforms. On these sites, social connections can be established among users in addition to their observed implicit feedback (e.g., rating, click) over different items. 35 | 36 | | Dataset | # Users | # Items | # Interactions | # Social Ties | 37 | | :------: | :-----: |:-------:|:--------------:|:-------------:| 38 | | Ciao | 1,925 | 15,053 | 23,223 | 65,084 | 39 | | Epinions | 14,680 | 233,261 | 447,312 | 632,144 | 40 | | Yelp | 99,262 | 105,142 | 672,513 | 1,298,522 | 41 | ## Usage 42 | 43 | Please unzip the datasets first. You also need to create the `History/<dataset_name>` (e.g., `History/ciao`) and `Model/<dataset_name>` (e.g., `Model/ciao`) directories, since training writes its logs and checkpoints there; a helper snippet is given below. The commands to train RecDiff (referred to as SDR in the code) on the three datasets are listed next, with the hyperparameters set to their defaults.
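The required `History/` and `Model/` folders can be created with a short Python snippet (a minimal sketch; substitute the dataset you train on):

```python
import os

dataset = 'ciao'  # or 'epinions' / 'yelp'
for folder in ('History', 'Model'):
    os.makedirs(os.path.join(folder, dataset), exist_ok=True)
```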
44 | 45 | - Ciao 46 | 47 | ```shell 48 | bash scripts/run_ciao.sh 49 | ``` 50 | 51 | - Epinions 52 | 53 | ```shell 54 | bash scripts/run_epinions.sh 55 | ``` 56 | 57 | - Yelp 58 | 59 | ```shell 60 | bash scripts/run_yelp.sh 61 | ``` 62 | ## Evaluation Results 63 | ### Overall Performance: 64 | RecDiff consistently outperforms the baseline models under various top-N settings. 65 | ![performance](https://github.com/Zongwei9888/Experiment_Images/blob/94f30406a5fdb6747a215744e87e8fdee4bdb470/RecDiff_images/Overall_performs.png) 66 | ![performance](https://github.com/Zongwei9888/Experiment_Images/blob/f8cb0e7ca95a96f8d1d976d7304195e304cf41a8/RecDiff_images/Top-n_performance.png) 67 | 68 | ## Citation 69 | If you find this work useful for your research, please consider citing our paper: 70 | 71 | @misc{li2024recdiff, 72 | title={RecDiff: Diffusion Model for Social Recommendation}, 73 | author={Zongwei Li and Lianghao Xia and Chao Huang}, 74 | year={2024}, 75 | eprint={2406.01629}, 76 | archivePrefix={arXiv}, 77 | primaryClass={cs.IR} 78 | } 79 | -------------------------------------------------------------------------------- /Utils/TimeLogger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | logmsg = '' 4 | timemark = dict() 5 | saveDefault = False 6 | def log(msg, save=None, oneline=False): 7 | global logmsg 8 | global saveDefault 9 | time = datetime.datetime.now() 10 | tem = '%s: %s' % (time, msg) 11 | if save is not None: 12 | if save: 13 | logmsg += tem + '\n' 14 | elif saveDefault: 15 | logmsg += tem + '\n' 16 | if oneline: 17 | print(tem, end='\r') 18 | else: 19 | print(tem) 20 | 21 | def marktime(marker): 22 | global timemark 23 | timemark[marker] = datetime.datetime.now() 24 | 25 | 26 | if __name__ == '__main__': 27 | log('') -------------------------------------------------------------------------------- /Utils/Utils.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | import os, shutil  # shutil is required by create_exp_dir below 3 | def innerProduct(usrEmbeds, itmEmbeds): 4 | return t.sum(usrEmbeds * itmEmbeds, dim=-1) 5 | 6 | def pairPredict(ancEmbeds, posEmbeds, negEmbeds): 7 | return innerProduct(ancEmbeds, posEmbeds) - innerProduct(ancEmbeds, negEmbeds) 8 | 9 | def calcRegLoss(model): 10 | ret = 0 11 | for W in model.parameters(): 12 | ret += W.norm(2).square() 13 | # ret += (model.usrStruct + model.itmStruct) 14 | return ret 15 | 16 | # def calcReward(bprLoss, keepRate): 17 | # # return t.where(bprLoss >= threshold, 1.0, xi) 18 | # _, posLocs = t.topk(bprLoss, int(bprLoss.shape[0] * (1 - keepRate))) 19 | # ones = t.ones_like(bprLoss).cuda() 20 | # reward = t.minimum(bprLoss, ones * (0.5 - 1e-6)) 21 | # pckBprLoss = bprLoss[posLocs] 22 | # ones = t.ones_like(pckBprLoss).cuda() 23 | # reward[posLocs] = t.minimum(t.maximum(pckBprLoss, ones * (0.5 + 1e-6)), ones) 24 | # return reward 25 | 26 | def calcReward(bprLossDiff, keepRate): 27 | _, posLocs = t.topk(bprLossDiff, int(bprLossDiff.shape[0] * (1 - keepRate))) 28 | reward = t.zeros_like(bprLossDiff).cuda() 29 | reward[posLocs] = 1.0 30 | return reward 31 | 32 | def calcGradNorm(model): 33 | ret = 0 34 | for p in model.parameters(): 35 | if p.grad is not None: 36 | ret += p.grad.data.norm(2).square() 37 | ret = (ret ** 0.5) 38 | ret = ret.detach() 39 | return ret 40 | 41 | def getFileName(save,epoch,args): 42 | 43 | file = f"autocf-{args.dataset}--{epoch}.pth.tar" 44 | return os.path.join(save,'model',file) 45 | 46 | def create_exp_dir(path, scripts_to_save=None): 47 | if not
os.path.exists(path): 48 | os.makedirs(path) 49 | os.mkdir(os.path.join(path, 'model')) 50 | 51 | print('Experiment dir : {}'.format(path)) 52 | if scripts_to_save is not None: 53 | os.mkdir(os.path.join(path, 'scripts')) 54 | for script in scripts_to_save: 55 | dst_file = os.path.join(path, 'scripts', os.path.basename(script)) 56 | shutil.copyfile(script, dst_file) 57 | -------------------------------------------------------------------------------- /datasets/ciao/dataset.pkl.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/RecDiff/0f5b51a79a6dd74c6fd32d6acb7485a0133d2ca8/datasets/ciao/dataset.pkl.zip -------------------------------------------------------------------------------- /datasets/epinions/dataset.pkl.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/RecDiff/0f5b51a79a6dd74c6fd32d6acb7485a0133d2ca8/datasets/epinions/dataset.pkl.zip -------------------------------------------------------------------------------- /datasets/yelp/dataset.pkl.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/RecDiff/0f5b51a79a6dd74c6fd32d6acb7485a0133d2ca8/datasets/yelp/dataset.pkl.zip -------------------------------------------------------------------------------- /framework_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/RecDiff/0f5b51a79a6dd74c6fd32d6acb7485a0133d2ca8/framework_00.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch, pickle, time, os 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | 6 | from param import args 7 | from DataHander import DataHandler 8 | from models.model import SDNet ,GCNModel 9 | 10 | from utils import load_model, save_model, fix_random_seed_as 11 | from tqdm import tqdm 12 | 13 | from models import diffusion_process as dp 14 | from Utils.Utils import * 15 | import logging 16 | import sys 17 | class Coach: 18 | def __init__(self, handler): 19 | self.args = args 20 | self.device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu') 21 | self.handler = handler 22 | self.train_loader = self.handler.trainloader 23 | self.valloader = self.handler.valloader 24 | self.testloader = self.handler.testloader 25 | self.n_user,self.n_item = self.handler.n_user, self.handler.n_item 26 | self.uiGraph = self.handler.ui_graph.to(self.device) 27 | self.uuGraph = self.handler.uu_graph.to(self.device) 28 | 29 | 30 | self.GCNModel = GCNModel(args,self.n_user, self.n_item).to(self.device) 31 | ### Build Diffusion process### 32 | 33 | output_dims = [args.dims] + [args.n_hid] 34 | input_dims = output_dims[::-1] 35 | self.SDNet = SDNet(input_dims, output_dims, args.emb_size, time_type="cat", norm=args.norm).to(self.device) 36 | 37 | self.DiffProcess=dp.DiffusionProcess(args.noise_schedule,args.noise_scale, args.noise_min, args.noise_max, args.steps,self.device).to(self.device) 38 | 39 | self.optimizer1 = torch.optim.Adam([ 40 | {'params': self.GCNModel.parameters(),'weight_decay':0}, 41 | ], lr=args.lr) 42 | self.optimizer2 = torch.optim.Adam([ 43 | {'params': self.SDNet.parameters(), 'weight_decay': 0}, 44 | ], lr=args.difflr) 45 | self.scheduler1 = torch.optim.lr_scheduler.StepLR( 46 | self.optimizer1, 47 
| step_size=args.decay_step, 48 | gamma=args.decay 49 | ) 50 | self.scheduler2 = torch.optim.lr_scheduler.StepLR( 51 | self.optimizer2, 52 | step_size=args.decay_step, 53 | gamma=args.decay 54 | ) 55 | 56 | self.train_loss = [] 57 | self.his_recall = [] 58 | self.his_ndcg = [] 59 | def train(self): 60 | args = self.args 61 | self.save_history = True 62 | log_format = '%(asctime)s %(message)s' 63 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p') 64 | log_save = './History/' + args.dataset + '/' 65 | log_file = args.save_name 66 | fname = f'{log_file}.txt' 67 | fh = logging.FileHandler(os.path.join(log_save, fname)) 68 | fh.setFormatter(logging.Formatter(log_format)) 69 | logger = logging.getLogger() 70 | logger.addHandler(fh) 71 | logger.info(args) 72 | logger.info('================') 73 | best_recall, best_ndcg, best_epoch, wait = 0, 0, 0, 0 74 | start_time = time.time() 75 | for self.epoch in range(1, args.n_epoch + 1): 76 | epoch_losses = self.train_one_epoch() 77 | self.train_loss.append(epoch_losses) 78 | print('epoch {} done! elapsed {:.2f}.s, epoch_losses {}'.format( 79 | self.epoch, time.time() - start_time, epoch_losses 80 | ), flush=True) 81 | if self.epoch%5==0: 82 | recall, ndcg = self.test(self.testloader) 83 | 84 | #Record the history of recall and ndcg 85 | self.his_recall.append(recall) 86 | self.his_ndcg.append(ndcg) 87 | cur_best = recall + ndcg > best_recall + best_ndcg 88 | if cur_best: 89 | best_recall, best_ndcg, best_epoch = recall, ndcg, self.epoch 90 | wait = 0 91 | else: 92 | wait += 1 93 | logger.info('+ epoch {} tested, elapsed {:.2f}s, Recall@{}: {:.4f}, NDCG@{}: {:.4f}'.format( 94 | self.epoch, time.time() - start_time, args.topk, recall, args.topk, ndcg)) 95 | if args.model_dir and cur_best: 96 | desc = args.save_name 97 | perf = '' # f'N/R_{ndcg:.4f}/{hr:.4f}' 98 | fname = f'{args.desc}_{desc}_{perf}.pth' 99 | 100 | save_model(self.GCNModel, self.SDNet, os.path.join(args.model_dir, fname), self.optimizer1,self.optimizer2) 101 | if self.save_history: 102 | self.saveHistory() 103 | 104 | 105 | if wait >= args.patience: 106 | print(f'Early stop at epoch {self.epoch}, best epoch {best_epoch}') 107 | break 108 | 109 | print(f'Best Recall@{args.topk} {best_recall:.6f}, NDCG@{args.topk} {best_ndcg:.6f},', flush=True) 110 | 111 | 112 | 113 | def train_one_epoch(self): 114 | self.SDNet.train() 115 | self.GCNModel.train() 116 | dataloader = self.train_loader 117 | epoch_losses = [0] * 3 118 | dataloader.dataset.negSampling() 119 | tqdm_dataloader = tqdm(dataloader) 120 | since = time.time() 121 | 122 | for iteration, batch in enumerate(tqdm_dataloader): 123 | user_idx, pos_idx, neg_idx = batch 124 | user_idx = user_idx.long().cuda() 125 | pos_idx = pos_idx.long().cuda() 126 | neg_idx = neg_idx.long().cuda() 127 | uiEmbeds,uuEmbeds = self.GCNModel(self.uiGraph,self.uuGraph,True) 128 | uEmbeds = uiEmbeds[:self.n_user] 129 | iEmbeds = uiEmbeds[self.n_user:] 130 | user = uEmbeds[user_idx] 131 | pos = iEmbeds[pos_idx] 132 | neg = iEmbeds[neg_idx] 133 | 134 | uu_terms = self.DiffProcess.caculate_losses(self.SDNet, uuEmbeds[user_idx], args.reweight) 135 | uuelbo = uu_terms["loss"].mean() 136 | user = user+uu_terms["pred_xstart"] 137 | diffloss = uuelbo 138 | scoreDiff = pairPredict(user, pos, neg) 139 | bprLoss = - (scoreDiff).sigmoid().log().sum() / args.batch_size 140 | regLoss = ((torch.norm(user) ** 2 + torch.norm(pos) ** 2 + torch.norm(neg) ** 2) * args.reg)/args.batch_size 141 | loss = bprLoss + regLoss 142 | 
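# Note: bprLoss + regLoss is combined below with the diffusion ELBO (diffloss)
# into a single objective, so one backward pass updates both the GCN encoder
# (optimizer1) and the SDNet denoiser (optimizer2).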
losses = [bprLoss.item(), regLoss.item()] 143 | 144 | 145 | loss = diffloss+loss 146 | losses.append(diffloss.item()) 147 | 148 | self.optimizer1.zero_grad() 149 | self.optimizer2.zero_grad() 150 | loss.backward() 151 | self.optimizer1.step() 152 | self.optimizer2.step() 153 | 154 | epoch_losses = [x + y for x, y in zip(epoch_losses, losses)] 155 | if self.scheduler1 is not None: 156 | self.scheduler1.step() 157 | self.scheduler2.step() 158 | 159 | epoch_losses = [sum(epoch_losses)] + epoch_losses 160 | time_elapsed = time.time() - since 161 | print('Training complete in {:.4f}s'.format( 162 | time_elapsed )) 163 | return epoch_losses 164 | 165 | def calcRes(self, topLocs, tstLocs, batIds): 166 | assert topLocs.shape[0] == len(batIds) 167 | allRecall = allNdcg = 0 168 | recallBig = 0 169 | ndcgBig = 0 170 | for i in range(len(batIds)): 171 | temTopLocs = list(topLocs[i]) 172 | temTstLocs = tstLocs[batIds[i]] 173 | tstNum = len(temTstLocs) 174 | maxDcg = np.sum([np.reciprocal(np.log2(loc + 2)) for loc in range(min(tstNum, args.topk))]) 175 | recall = dcg = 0 176 | for val in temTstLocs: 177 | if val in temTopLocs: 178 | recall += 1 179 | dcg += np.reciprocal(np.log2(temTopLocs.index(val) + 2)) 180 | recall = recall / tstNum 181 | ndcg = dcg / maxDcg 182 | allRecall += recall 183 | allNdcg += ndcg 184 | return allRecall, allNdcg 185 | 186 | 187 | def test(self,dataloader): 188 | self.SDNet.eval() 189 | self.GCNModel.eval() 190 | Recall, NDCG = [0] * 2 191 | num = dataloader.dataset.__len__() 192 | 193 | since = time.time() 194 | with torch.no_grad(): 195 | uiEmbeds, uuEmbeds = self.GCNModel(self.uiGraph, self.uuGraph, True) 196 | tqdm_dataloader = tqdm(dataloader) 197 | for iteration, batch in enumerate(tqdm_dataloader, start=1): 198 | user_idx, trnMask = batch 199 | user_idx = user_idx.long().cuda() 200 | trnMask = trnMask.cuda() 201 | 202 | uEmbeds = uiEmbeds[:self.n_user] 203 | iEmbeds = uiEmbeds[self.n_user:] 204 | user = uEmbeds[user_idx] 205 | 206 | uuemb = uuEmbeds[user_idx] 207 | user_predict = self.DiffProcess.p_sample(self.SDNet, uuemb, args.sampling_steps, args.sampling_noise) 208 | user = user + user_predict 209 | allPreds = t.mm(user, t.transpose(iEmbeds, 1, 0)) * (1 - trnMask) - trnMask * 1e8 210 | _, topLocs = t.topk(allPreds, args.topk) 211 | recall, ndcg = self.calcRes(topLocs.cpu().numpy(), dataloader.dataset.tstLocs, user_idx) 212 | Recall+= recall 213 | NDCG+=ndcg 214 | time_elapsed = time.time() - since 215 | print('Testing complete in {:.4f}s'.format( 216 | time_elapsed )) 217 | Recall = Recall/num 218 | NDCG = NDCG/num 219 | return Recall, NDCG 220 | 221 | 222 | 223 | def saveHistory(self): 224 | history = dict() 225 | history['loss'] = self.train_loss 226 | history['Recall'] = self.his_recall 227 | history['NDCG'] = self.his_ndcg 228 | ModelName = "SDR" 229 | desc = args.save_name 230 | perf = '' # f'N/R_{ndcg:.4f}/{hr:.4f}' 231 | fname = f'{args.desc}_{desc}_{perf}.his' 232 | 233 | with open('./History/' + args.dataset + '/' + fname, 'wb') as fs: 234 | pickle.dump(history, fs) 235 | 236 | 237 | 238 | if __name__ == "__main__": 239 | os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda 240 | fix_random_seed_as(args.seed) 241 | 242 | handler = DataHandler() 243 | handler.LoadData() 244 | app = Coach(handler) 245 | app.train() 246 | 247 | -------------------------------------------------------------------------------- /models/diffusion_process.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import math 3 | import numpy 
as np 4 | import torch as th 5 | import torch.nn.functional as F 6 | import torch.nn as nn 7 | class DiffusionProcess(nn.Module): 8 | def __init__(self,noise_schedule, noise_scale, noise_min, noise_max,steps, device, keep_num=10): 9 | super(DiffusionProcess, self).__init__() 10 | self.noise_schedule = noise_schedule 11 | self.noise_scale = noise_scale 12 | self.noise_min = noise_min 13 | self.noise_max = noise_max 14 | self.steps = steps 15 | self.device = device 16 | 17 | self.keep_num = keep_num 18 | self.Lt_record = th.zeros(steps, keep_num, dtype=th.float64).to(device) 19 | self.Lt_count = th.zeros(steps, dtype=int).to(device) 20 | 21 | 22 | # The important parameters for Gaussian diffusion 23 | self.beta_nums = th.tensor(self.betas_num(), dtype=th.float64).to(self.device) 24 | assert len(self.beta_nums.shape) == 1, "betas must be 1-D" 25 | assert len(self.beta_nums) == self.steps, "num of betas must equal the diffusion steps" 26 | assert (self.beta_nums > 0).all() and (self.beta_nums <= 1).all(), "betas out of range" 27 | 28 | self.diffusion_setting() 29 | 30 | def betas_num(self): 31 | """ 32 | Given the schedule name, create the betas for the diffusion process. 33 | """ 34 | st_bound = self.noise_scale * self.noise_min 35 | e_bound = self.noise_scale * self.noise_max 36 | if self.noise_schedule == "linear": 37 | return np.linspace(st_bound, e_bound, self.steps, dtype=np.float64) 38 | else: 39 | return betas_from_linear_variance(self.steps, np.linspace(st_bound, e_bound, self.steps, dtype=np.float64)) 40 | def diffusion_setting(self): 41 | alphas = 1.0 - self.beta_nums 42 | self.alphas_cumprod = th.cumprod(alphas, axis=0).to(self.device) 43 | self.alphas_cumprod_prev = th.cat([th.tensor([1.0]).to(self.device), self.alphas_cumprod[:-1]]).to(self.device) # alpha_{t-1} 44 | self.alphas_cumprod_next = th.cat([self.alphas_cumprod[1:], th.tensor([0.0]).to(self.device)]).to(self.device) # alpha_{t+1} 45 | assert self.alphas_cumprod_prev.shape == (self.steps,) 46 | 47 | self.sqrt_alphas_cumprod = th.sqrt(self.alphas_cumprod) 48 | self.sqrt_one_minus_alphas_cumprod = th.sqrt(1.0 - self.alphas_cumprod) 49 | self.log_one_minus_alphas_cumprod = th.log(1.0 - self.alphas_cumprod) 50 | self.sqrt_recip_alphas_cumprod = th.sqrt(1.0 / self.alphas_cumprod) 51 | self.sqrt_recipm1_alphas_cumprod = th.sqrt(1.0 / self.alphas_cumprod - 1) 52 | 53 | self.posterior_variance = ( 54 | self.beta_nums * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 55 | ) 56 | 57 | self.posterior_log_variance_clipped = th.log( 58 | th.cat([self.posterior_variance[1].unsqueeze(0), self.posterior_variance[1:]]) 59 | ) 60 | self.posterior_mean_coef1 = ( 61 | self.beta_nums * th.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 62 | ) 63 | self.posterior_mean_coef2 = ( 64 | (1.0 - self.alphas_cumprod_prev) 65 | * th.sqrt(alphas) 66 | / (1.0 - self.alphas_cumprod) 67 | ) 68 | 69 | 70 | 71 | def caculate_losses(self, model, emb_s, reweight=False): 72 | batch_size, device = emb_s.size(0), emb_s.device 73 | ts, pt = self.sample_timesteps(batch_size, device, 'uniform') 74 | noise = th.randn_like(emb_s) 75 | emb_t = self.forward_process(emb_s, ts, noise) 76 | terms = {} 77 | model_output = model(emb_t, ts) 78 | 79 | 80 | assert model_output.shape == emb_s.shape 81 | 82 | mse = mean_flat((emb_s - model_output) ** 2) 83 | loss = mse  # defined before branching so the unweighted path below is valid too 84 | if reweight: 85 | 86 | weight = self.SNR(ts - 1) - self.SNR(ts)  # entries at t == 0 are overwritten next 87 | weight = th.where((ts == 0), 1.0, weight) 88 | 89 | 90 | else: 91 | weight = th.tensor([1.0] *
len(model_output)).to(device) 92 | 93 | terms["loss"] = weight * loss 94 | terms["pred_xstart"] = model_output 95 | return terms 96 | 97 | def p_sample(self, model, emb_s, steps, sampling_noise=False): 98 | assert steps <= self.steps, "Too many steps in inference." 99 | if steps == 0: 100 | emb_t = emb_s 101 | else: 102 | t = th.tensor([steps - 1] * emb_s.shape[0]).to(emb_s.device) 103 | emb_t = self.forward_process(emb_s, t) 104 | 105 | indices = list(range(self.steps))[::-1] 106 | 107 | if self.noise_scale == 0.: 108 | for i in indices: 109 | t = th.tensor([i] * emb_t.shape[0]).to(emb_s.device) 110 | emb_t = model(emb_t, t) 111 | return emb_t 112 | 113 | for i in indices: 114 | t = th.tensor([i] * emb_t.shape[0]).to(emb_s.device) 115 | out = self.p_mean_variance(model, emb_t, t) 116 | if sampling_noise: 117 | noise = th.randn_like(emb_t) 118 | nonzero_mask = ( 119 | (t != 0).float().view(-1, *([1] * (len(emb_t.shape) - 1))) 120 | ) # no noise when t == 0 121 | emb_t = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise 122 | else: 123 | emb_t = out["mean"] 124 | return emb_t 125 | 126 | 127 | 128 | 129 | def sample_timesteps(self, batch_size, device, method='uniform', uniform_prob=0.001): 130 | if method == 'importance': # importance sampling 131 | if not (self.Lt_count == self.keep_num).all(): 132 | return self.sample_timesteps(batch_size, device, method='uniform') 133 | 134 | Lt_sqrt = th.sqrt(th.mean(self.Lt_record ** 2, axis=-1)) 135 | pt_all = Lt_sqrt / th.sum(Lt_sqrt) 136 | pt_all *= 1 - uniform_prob 137 | pt_all += uniform_prob / len(pt_all) 138 | 139 | assert (pt_all.sum(-1) - 1.).abs() < 1e-5 140 | 141 | t = th.multinomial(pt_all, num_samples=batch_size, replacement=True) 142 | pt = pt_all.gather(dim=0, index=t) * len(pt_all) 143 | 144 | return t, pt 145 | 146 | elif method == 'uniform': # uniform sampling 147 | t = th.randint(0, self.steps, (batch_size,), device=device).long() 148 | pt = th.ones_like(t).float() 149 | 150 | return t, pt 151 | 152 | else: 153 | raise ValueError 154 | 155 | def forward_process(self, emb_s, t, noise=None): 156 | if noise is None: 157 | noise = th.randn_like(emb_s) 158 | assert noise.shape == emb_s.shape 159 | return ( 160 | self._extract_into_tensor(self.sqrt_alphas_cumprod, t, emb_s.shape) * emb_s 161 | + self._extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, emb_s.shape) 162 | * noise 163 | ) 164 | 165 | def q_posterior_mean_variance(self, emb_s, emb_t, t): 166 | """ 167 | Compute the mean and variance of the diffusion posterior: 168 | q(x_{t-1} | x_t, x_0) 169 | """ 170 | assert emb_s.shape == emb_t.shape 171 | posterior_mean = ( 172 | self._extract_into_tensor(self.posterior_mean_coef1, t, emb_t.shape) * emb_s 173 | + self._extract_into_tensor(self.posterior_mean_coef2, t, emb_t.shape) * emb_t 174 | ) 175 | posterior_variance = self._extract_into_tensor(self.posterior_variance, t, emb_t.shape) 176 | posterior_log_variance_clipped = self._extract_into_tensor( 177 | self.posterior_log_variance_clipped, t, emb_t.shape 178 | ) 179 | assert ( 180 | posterior_mean.shape[0] 181 | == posterior_variance.shape[0] 182 | == posterior_log_variance_clipped.shape[0] 183 | == emb_s.shape[0] 184 | ) 185 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 186 | 187 | def p_mean_variance(self, model, x, t): 188 | """ 189 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of 190 | the initial x, x_0.
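The denoiser predicts x_0 directly; the returned mean is the posterior mean
of q(x_{t-1} | x_t, x_0) evaluated at that prediction, and the variance terms
are the fixed, schedule-derived posterior statistics.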
191 | """ 192 | B, C = x.shape[:2] 193 | assert t.shape == (B,) 194 | model_output = model(x, t) 195 | 196 | model_variance = self.posterior_variance 197 | model_log_variance = self.posterior_log_variance_clipped 198 | 199 | model_variance = self._extract_into_tensor(model_variance, t, x.shape) 200 | model_log_variance = self._extract_into_tensor(model_log_variance, t, x.shape) 201 | pred_xstart = model_output 202 | 203 | model_mean, _, _ = self.q_posterior_mean_variance(emb_s=pred_xstart, emb_t=x, t=t) 204 | 205 | assert ( 206 | model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape 207 | ) 208 | 209 | return { 210 | "mean": model_mean, 211 | "variance": model_variance, 212 | "log_variance": model_log_variance, 213 | "pred_xstart": pred_xstart, 214 | } 215 | 216 | def _predict_xstart_from_eps(self, emb_t, t, eps): 217 | assert emb_t.shape == eps.shape 218 | return ( 219 | self._extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, emb_t.shape) * emb_t 220 | - self._extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, emb_t.shape) * eps 221 | ) 222 | 223 | def SNR(self, t): 224 | """ 225 | Compute the signal-to-noise ratio for a single timestep. 226 | """ 227 | self.alphas_cumprod = self.alphas_cumprod.to(t.device) 228 | return self.alphas_cumprod[t] / (1 - self.alphas_cumprod[t]) 229 | 230 | def _extract_into_tensor(self, arr, timesteps, broadcast_shape): 231 | """ 232 | Extract values from a 1-D numpy array for a batch of indices. 233 | 234 | :param arr: the 1-D numpy array. 235 | :param timesteps: a tensor of indices into the array to extract. 236 | :param broadcast_shape: a larger shape of K dimensions with the batch 237 | dimension equal to the length of timesteps. 238 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 239 | """ 240 | # res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 241 | arr = arr.to(timesteps.device) 242 | res = arr[timesteps].float() 243 | while len(res.shape) < len(broadcast_shape): 244 | res = res[..., None] 245 | return res.expand(broadcast_shape) 246 | 247 | 248 | 249 | def betas_from_linear_variance(steps, variance, max_beta=0.999): 250 | alpha_bar = 1 - variance 251 | betas = [] 252 | betas.append(1 - alpha_bar[0]) 253 | for i in range(1, steps): 254 | betas.append(min(1 - alpha_bar[i] / alpha_bar[i - 1], max_beta)) 255 | return np.array(betas) 256 | def normal_kl(mean1, logvar1, mean2, logvar2): 257 | """ 258 | Compute the KL divergence between two gaussians. 259 | 260 | Shapes are automatically broadcasted, so batches can be compared to 261 | scalars, among other use cases. 262 | """ 263 | tensor = None 264 | for obj in (mean1, logvar1, mean2, logvar2): 265 | if isinstance(obj, th.Tensor): 266 | tensor = obj 267 | break 268 | assert tensor is not None, "at least one argument must be a Tensor" 269 | 270 | # Force variances to be Tensors. Broadcasting helps convert scalars to 271 | # Tensors, but it does not work for th.exp(). 272 | logvar1, logvar2 = [ 273 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 274 | for x in (logvar1, logvar2) 275 | ] 276 | 277 | return 0.5 * ( 278 | -1.0 279 | + logvar2 280 | - logvar1 281 | + th.exp(logvar1 - logvar2) 282 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 283 | ) 284 | 285 | def mean_flat(tensor): 286 | """ 287 | Take the mean over all non-batch dimensions. 
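For a [batch, d] embedding tensor this is simply the per-sample mean over the
d feature dimensions, yielding a [batch]-shaped vector of losses.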
288 | """ 289 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 290 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import dgl, math, torch 2 | import numpy as np 3 | import networkx as nx 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import dgl.function as fn 7 | from dgl.nn.pytorch import GraphConv 8 | import math 9 | from torch.nn.init import xavier_normal_, constant_, xavier_uniform_ 10 | 11 | 12 | 13 | 14 | class UUGCNLayer(nn.Module): 15 | def __init__(self, 16 | in_feats, 17 | out_feats, 18 | weight=False, 19 | bias=False, 20 | activation=None): 21 | super(UUGCNLayer, self).__init__() 22 | self.bias = bias 23 | self._in_feats = in_feats 24 | self._out_feats = out_feats 25 | self.weight = weight 26 | if self.weight: 27 | self.u_w = nn.Parameter(torch.Tensor(in_feats, out_feats)) 28 | xavier_uniform_(self.u_w) 29 | self._activation = activation 30 | 31 | # def forward(self, graph, feat): 32 | def forward(self, graph, u_f): 33 | with graph.local_scope(): 34 | if self.weight: 35 | u_f = torch.mm(u_f, self.u_w) 36 | node_f = u_f 37 | # D^-1/2 38 | # degs = graph.out_degrees().to(feat.device).float().clamp(min=1) 39 | degs = graph.out_degrees().to(u_f.device).float().clamp(min=1) 40 | norm = torch.pow(degs, -0.5).view(-1, 1) 41 | # norm = norm.view(-1,1) 42 | # shp = norm.shape + (1,) * (feat.dim() - 1) 43 | # norm = t.reshape(norm, shp) 44 | 45 | node_f = node_f * norm 46 | 47 | graph.ndata['n_f'] = node_f 48 | # graph.edata['e_f'] = e_f 49 | graph.update_all(fn.copy_u(u='n_f', out='m'), reduce_func=fn.sum(msg='m', out='n_f')) 50 | 51 | rst = graph.ndata['n_f'] 52 | 53 | degs = graph.in_degrees().to(u_f.device).float().clamp(min=1) 54 | norm = torch.pow(degs, -0.5).view(-1, 1) 55 | # shp = norm.shape + (1,) * (feat.dim() - 1) 56 | # norm = t.reshape(norm, shp) 57 | rst = rst * norm 58 | 59 | if self._activation is not None: 60 | rst = self._activation(rst) 61 | 62 | return rst 63 | class GCNLayer(nn.Module): 64 | def __init__(self, 65 | in_feats, 66 | out_feats, 67 | weight=False, 68 | bias=False, 69 | activation=None): 70 | super(GCNLayer, self).__init__() 71 | self.bias = bias 72 | self._in_feats = in_feats 73 | self._out_feats = out_feats 74 | self.weight = weight 75 | if self.weight: 76 | self.u_w = nn.Parameter(torch.Tensor(in_feats, out_feats)) 77 | self.v_w = nn.Parameter(torch.Tensor(in_feats, out_feats)) 78 | # self.e_w = nn.Parameter(t.Tensor(in_feats, out_feats)) 79 | xavier_uniform_(self.u_w) 80 | xavier_uniform_(self.v_w) 81 | # init.xavier_uniform_(self.e_w) 82 | self._activation = activation 83 | 84 | # def forward(self, graph, feat): 85 | def forward(self, graph, u_f, v_f): 86 | with graph.local_scope(): 87 | if self.weight: 88 | u_f = torch.mm(u_f, self.u_w) 89 | v_f = torch.mm(v_f, self.v_w) 90 | # e_f = t.mm(e_f, self.e_w) 91 | node_f = torch.cat([u_f, v_f], dim=0) 92 | # D^-1/2 93 | # degs = graph.out_degrees().to(feat.device).float().clamp(min=1) 94 | degs = graph.out_degrees().to(u_f.device).float().clamp(min=1) 95 | norm = torch.pow(degs, -0.5).view(-1, 1) 96 | # norm = norm.view(-1,1) 97 | # shp = norm.shape + (1,) * (feat.dim() - 1) 98 | # norm = t.reshape(norm, shp) 99 | 100 | node_f = node_f * norm 101 | 102 | graph.ndata['n_f'] = node_f 103 | # graph.edata['e_f'] = e_f 104 | graph.update_all(fn.copy_u(u='n_f', out='m'), reduce_func=fn.sum(msg='m', out='n_f')) 105 | 106 | rst = graph.ndata['n_f'] 
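# Second half of the symmetric D^{-1/2} A D^{-1/2} normalization: the out-degree
# scaling above handled the source side of each message, and the in-degree
# scaling below handles the destination side.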
107 | 108 | degs = graph.in_degrees().to(u_f.device).float().clamp(min=1) 109 | norm = torch.pow(degs, -0.5).view(-1, 1) 110 | # shp = norm.shape + (1,) * (feat.dim() - 1) 111 | # norm = t.reshape(norm, shp) 112 | rst = rst * norm 113 | 114 | if self._activation is not None: 115 | rst = self._activation(rst) 116 | 117 | return rst 118 | 119 | class GCNModel(nn.Module): 120 | def __init__(self,args, n_user,n_item): 121 | super(GCNModel, self).__init__() 122 | self.n_user = n_user 123 | self.n_item = n_item 124 | self.n_hid = args.n_hid 125 | self.n_layers = args.n_layers 126 | self.s_layers = args.s_layers 127 | self.embedding_dict = self.init_weight(n_user, n_item, self.n_hid) 128 | self.act = nn.LeakyReLU(0.5, inplace=True) 129 | self.layers = nn.ModuleList() 130 | self.uu_Layers = nn.ModuleList() 131 | self.weight = args.weight 132 | for i in range(0, self.n_layers): 133 | self.layers.append(GCNLayer(self.n_hid, self.n_hid, weight=self.weight, bias=False, activation=self.act)) 134 | for i in range(0, self.s_layers): 135 | self.uu_Layers.append(UUGCNLayer(self.n_hid,self.n_hid,weight=self.weight, bias=False, activation=self.act)) 136 | def init_weight(self, userNum, itemNum, hide_dim): 137 | initializer = nn.init.xavier_uniform_ 138 | 139 | embedding_dict = nn.ParameterDict({ 140 | 'user_emb': nn.Parameter(initializer(torch.empty(userNum, hide_dim))), 141 | 'item_emb': nn.Parameter(initializer(torch.empty(itemNum, hide_dim))), 142 | }) 143 | return embedding_dict 144 | def forward(self, uigraph, uugraph, isTrain=True): 145 | 146 | init_embedding = torch.concat([self.embedding_dict['user_emb'],self.embedding_dict['item_emb']],axis=0) 147 | init_user_embedding = self.embedding_dict['user_emb'] 148 | all_embeddings = [init_embedding] 149 | all_uu_embeddings = [init_user_embedding] 150 | 151 | for i, layer in enumerate(self.layers): 152 | if i == 0: 153 | embeddings = layer(uigraph, self.embedding_dict['user_emb'], self.embedding_dict['item_emb']) 154 | else: 155 | embeddings = layer(uigraph, embeddings[:self.n_user], embeddings[self.n_user:]) 156 | 157 | norm_embeddings = F.normalize(embeddings, p=2, dim=1) 158 | all_embeddings += [norm_embeddings] 159 | ui_embeddings = sum(all_embeddings) 160 | 161 | for i, layer in enumerate(self.uu_Layers): 162 | if i == 0: 163 | embeddings = layer(uugraph, self.embedding_dict['user_emb']) 164 | else: 165 | embeddings = layer(uugraph, embeddings) 166 | norm_embeddings = F.normalize(embeddings, p=2, dim=1) 167 | all_uu_embeddings +=[norm_embeddings] 168 | uu_embeddings = sum(all_uu_embeddings) 169 | 170 | return ui_embeddings,uu_embeddings 171 | 172 | 173 | 174 | 175 | # Fully-connected denoising network operating in the social hidden space 176 | class SDNet(nn.Module): 177 | """ 178 | A deep neural network for the reverse diffusion process. 179 | """ 180 | 181 | def __init__(self, in_dims, out_dims, emb_size, time_type="cat", norm=False, dropout=0.5): 182 | super(SDNet, self).__init__() 183 | self.in_dims = in_dims 184 | self.out_dims = out_dims 185 | assert out_dims[0] == in_dims[-1], "In and out dimensions must match."
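# in_dims runs the encoder side of the MLP and out_dims mirrors it back, so the
# denoiser maps a noisy embedding (concatenated with its timestep embedding in
# forward()) back to a tensor of the original embedding size.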
186 | self.time_type = time_type 187 | self.time_emb_dim = emb_size 188 | self.norm = norm 189 | 190 | self.emb_layer = nn.Linear(self.time_emb_dim, self.time_emb_dim) 191 | 192 | if self.time_type == "cat": 193 | in_dims_temp = [self.in_dims[0] + self.time_emb_dim] + self.in_dims[1:] 194 | else: 195 | raise ValueError("Unimplemented timestep embedding type %s" % self.time_type) 196 | out_dims_temp = self.out_dims 197 | 198 | self.in_layers = nn.ModuleList([nn.Linear(d_in, d_out) \ 199 | for d_in, d_out in zip(in_dims_temp[:-1], in_dims_temp[1:])]) 200 | self.out_layers = nn.ModuleList([nn.Linear(d_in, d_out) \ 201 | for d_in, d_out in zip(out_dims_temp[:-1], out_dims_temp[1:])]) 202 | 203 | self.drop = nn.Dropout(dropout) 204 | self.init_weights() 205 | 206 | def init_weights(self): 207 | for layer in self.in_layers: 208 | # Xavier Initialization for weights 209 | size = layer.weight.size() 210 | fan_out = size[0] 211 | fan_in = size[1] 212 | std = np.sqrt(2.0 / (fan_in + fan_out)) 213 | layer.weight.data.normal_(0.0, std) 214 | 215 | # Normal Initialization for weights 216 | layer.bias.data.normal_(0.0, 0.001) 217 | 218 | for layer in self.out_layers: 219 | # Xavier Initialization for weights 220 | size = layer.weight.size() 221 | fan_out = size[0] 222 | fan_in = size[1] 223 | std = np.sqrt(2.0 / (fan_in + fan_out)) 224 | layer.weight.data.normal_(0.0, std) 225 | 226 | # Normal Initialization for weights 227 | layer.bias.data.normal_(0.0, 0.001) 228 | 229 | size = self.emb_layer.weight.size() 230 | fan_out = size[0] 231 | fan_in = size[1] 232 | std = np.sqrt(2.0 / (fan_in + fan_out)) 233 | self.emb_layer.weight.data.normal_(0.0, std) 234 | self.emb_layer.bias.data.normal_(0.0, 0.001) 235 | 236 | def forward(self, x, timesteps): 237 | time_emb = timestep_embedding(timesteps, self.time_emb_dim).to(x.device) 238 | emb = self.emb_layer(time_emb) 239 | if self.norm: 240 | x = F.normalize(x) 241 | x = self.drop(x) 242 | h = torch.cat([x, emb], dim=-1) 243 | for i, layer in enumerate(self.in_layers): 244 | h = layer(h) 245 | h = torch.tanh(h) 246 | 247 | for i, layer in enumerate(self.out_layers): 248 | h = layer(h) 249 | if i != len(self.out_layers) - 1: 250 | h = torch.tanh(h) 251 | 252 | return h 253 | 254 | 255 | def timestep_embedding(timesteps, dim, max_period=10000): 256 | """ 257 | Create sinusoidal timestep embeddings. 258 | 259 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 260 | These may be fractional. 261 | :param dim: the dimension of the output. 262 | :param max_period: controls the minimum frequency of the embeddings. 263 | :return: an [N x dim] Tensor of positional embeddings. 264 | """ 265 | 266 | half = dim // 2 267 | freqs = torch.exp( 268 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 269 | ).to(timesteps.device) 270 | args = timesteps[:, None].float() * freqs[None] 271 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 272 | if dim % 2: 273 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 274 | return embedding 275 | 276 | 277 | def mean_flat(tensor): 278 | """ 279 | Take the mean over all non-batch dimensions. 
280 | """ 281 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 282 | 283 | -------------------------------------------------------------------------------- /param.py: -------------------------------------------------------------------------------- 1 | import os, argparse 2 | # from yaml import safe_load as yaml_load 3 | from json import dumps as json_dumps 4 | 5 | def parse_args(): 6 | parser = argparse.ArgumentParser(description='SDR Arguments') 7 | parser.add_argument('--desc', type=str, default='') 8 | 9 | #Configuration Arguments 10 | parser.add_argument('--cuda', type=str, default='0') 11 | parser.add_argument('--seed', type=int, default=2023) 12 | 13 | #Model Arguments 14 | parser.add_argument('--n_hid', type=int, default=64) 15 | parser.add_argument('--n_layers', type=int, default=2) 16 | parser.add_argument('--s_layers', type=int, default=2) 17 | parser.add_argument('--weight', type=bool, default=True, help='Add linear weight or not') 18 | 19 | 20 | 21 | #Train Arguments 22 | parser.add_argument('--dropout', type=float, default=0) 23 | 24 | #Optimization Arguments 25 | parser.add_argument('--lr', type=float, default=5e-3) 26 | parser.add_argument('--difflr', type=float, default=1e-3) 27 | parser.add_argument('--reg', type=float, default=1e-2) 28 | parser.add_argument('--decay', type=float, default=0.985) 29 | parser.add_argument('--decay_step', type=int, default=1) 30 | parser.add_argument('--n_epoch', type=int, default=150) 31 | parser.add_argument('--batch_size', type=int, default=2048) 32 | parser.add_argument('--patience', type=int, default=20) 33 | 34 | # Valid/Test Arguments 35 | parser.add_argument('--topk', type=int, default=20) 36 | parser.add_argument('--test_batch_size', type=int, default=1024) 37 | 38 | #Data Arguments 39 | parser.add_argument('--dataset', type=str, default="epinions") 40 | parser.add_argument('--num_workers', type=int, default=0) 41 | parser.add_argument('--save_name', type=str, default='tem') 42 | parser.add_argument('--checkpoint', type=str, default="./Model/epinions/_tem_.pth") 43 | parser.add_argument('--model_dir', type=str, default="./Model/epinions/") 44 | 45 | # params for the denoiser 46 | parser.add_argument('--time_type', type=str, default='cat', help='cat or add') 47 | parser.add_argument('--dims', type=int, default=64, help='the dims for the DNN') 48 | parser.add_argument('--norm', type=bool, default=True, help='Normalize the input or not') 49 | parser.add_argument('--emb_size', type=int, default=16, help='timestep embedding size') 50 | 51 | # params for diffusions 52 | parser.add_argument('--steps', type=int, default=20, help='diffusion steps') 53 | parser.add_argument('--noise_schedule', type=str, default='linear-var', help='the schedule for noise generating') 54 | parser.add_argument('--noise_scale', type=float, default=1, help='noise scale for noise generating') 55 | parser.add_argument('--noise_min', type=float, default=0.0001, help='noise lower bound for noise generating') 56 | parser.add_argument('--noise_max', type=float, default=0.01, help='noise upper bound for noise generating') 57 | parser.add_argument('--sampling_noise', type=bool, default=False, help='sampling with noise or not') 58 | parser.add_argument('--sampling_steps', type=int, default=0, help='steps of the forward process during inference') 59 | parser.add_argument('--reweight', type=bool, default=True, 60 | help='assign different weight to different timestep or not') 61 | 62 | 63 | 64 | return parser.parse_args() 65 | 66 | args = parse_args() 67 | 68 | 69 | 
# ciao 5e-3, 1e-3, 1 1e-2, 20, emb_size:16 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /scripts/run_ciao.sh: -------------------------------------------------------------------------------- 1 | python main.py --n_hid 64 --dataset ciao --n_layers 2 --s_layer 2 --lr 5e-3 --difflr 1e-3 --reg 1e-2 --batch_size 2048 --test_batch_size 1024 \ 2 | --emb_size 16 --steps 20 --noise_scale 1 --model_dir './Model/ciao/' 3 | -------------------------------------------------------------------------------- /scripts/run_epinions.sh: -------------------------------------------------------------------------------- 1 | python main.py --n_hid 64 --dataset epinions --n_layers 2 --s_layer 2 --lr 0.001 --difflr 0.001 --reg 0.0001 --batch_size 4096 --test_batch_size 1024 \ 2 | --emb_size 16 --steps 200 --noise_scale 0.1 --model_dir './Model/epinions/' 3 | -------------------------------------------------------------------------------- /scripts/run_yelp.sh: -------------------------------------------------------------------------------- 1 | python main.py --n_hid 64 --dataset yelp --n_layers 2 --s_layer 2 --lr 0.001 --difflr 0.0001 --reg 0.005 --batch_size 4096 --test_batch_size 2048 \ 2 | --emb_size 16 --steps 50 --noise_scale 0.1 --model_dir './Model/yelp/' 3 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch, os, pickle, random 2 | import numpy as np 3 | # from yaml import safe_load as yaml_load 4 | # from json import dumps as json_dumps 5 | 6 | 7 | def load_data(data_path): 8 | with open(data_path, 'rb') as f: 9 | data = pickle.load(f) 10 | return data 11 | 12 | 13 | def save_model(model,difussion_model ,save_path, optimizer1=None,optimizer2 = None): 14 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 15 | data2save = { 16 | 'state_dict1': model.state_dict(), 17 | 'optimizer1': optimizer1.state_dict(), 18 | 'state_dict2': difussion_model.state_dict(), 19 | 'optimizer2': optimizer2.state_dict(), 20 | 21 | } 22 | torch.save(data2save, save_path) 23 | 24 | 25 | 26 | def load_model(model,model2, load_path, optimizer=None): 27 | data2load = torch.load(load_path, map_location='cpu') 28 | model.load_state_dict(data2load['state_dict1']) 29 | model2.load_state_dict(data2load['state_dict2']) 30 | if optimizer is not None and data2load['optimizer'] is not None: 31 | optimizer = data2load['optimizer'] 32 | 33 | 34 | 35 | def fix_random_seed_as(seed): 36 | random.seed(seed) 37 | torch.random.manual_seed(seed) 38 | torch.cuda.manual_seed_all(seed) 39 | np.random.seed(seed) 40 | torch.backends.cudnn.deterministic = True 41 | torch.backends.cudnn.benchmark = False 42 | 43 | 44 | if __name__ == "__main__": 45 | pass 46 | --------------------------------------------------------------------------------