├── LATTICE ├── LICENSE ├── README.md └── codes │ ├── Models.py │ ├── __pycache__ │ └── Models.cpython-38.pyc │ ├── main.py │ └── utility │ ├── __pycache__ │ ├── batch_test.cpython-38.pyc │ ├── load_data.cpython-38.pyc │ ├── metrics.cpython-38.pyc │ └── parser.cpython-38.pyc │ ├── batch_test.py │ ├── load_data.py │ ├── metrics.py │ └── parser.py ├── PromptMM.png ├── README.md ├── codes ├── Models.py ├── Models_empower.py ├── Models_mmlight.py ├── __pycache__ │ ├── MMD.cpython-38.pyc │ ├── Models.cpython-38.pyc │ ├── Models.cpython-39.pyc │ ├── Models2.cpython-38.pyc │ ├── Models2_0938.cpython-38.pyc │ ├── Models2_0938_modality.cpython-38.pyc │ ├── Models2_0938_sub_gene.cpython-38.pyc │ ├── Models2_MF_VBPR_NGCF_LightGCN.cpython-38.pyc │ ├── Models2_MF_VBPR_NGCF_LightGCN_MMGCN.cpython-38.pyc │ ├── Models2_MF_VBPR_NGCF_LightGCN_MMGCN_HAFR.cpython-38.pyc │ ├── Models2_MF_VBPR_NGCF_LightGCN_MMGCN_HAFR_CLCRec.cpython-38.pyc │ ├── Models2_MF_VBPR_NGCF_LightGCN_MMGCN_HAFR_CLCRec3.cpython-38.pyc │ ├── Models2_MF_VBPR_NGCF_LightGCN_MMGCN_HAFR_CLCRec_SLMRec.cpython-38.pyc │ ├── Models2_sub_gene.cpython-38.pyc │ ├── Models2_sub_gene_0485.cpython-38.pyc │ ├── Models2_sub_gene_co_weight.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_0826.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_333333333333.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_3_5.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_ablation.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_ablation_0962_beforeafterTo0962.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_ablation_DtwoMLP.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_ablation_beforeafter.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_ablation_copy.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_ablation_delete.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_based_noFeatTrans_ablation.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_dropGNN.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_redoADCL.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_HL_train.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_IRGANdeal.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_MUIT.cpython-38.pyc │ ├── 
Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_binaryLoss_modelPara.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_oversmoothing.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_woFEAT.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_woIIGRAPH.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_model_feature.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_samepos.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_G_first.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_G_first_auto.cpython-38.pyc │ ├── Models2_sub_gene_co_weight_AD_0938_fake_click_G_first_auto_geneBPR.cpython-38.pyc │ ├── Models3.cpython-38.pyc │ ├── Models3.cpython-39.pyc │ ├── Models3_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_5_28.cpython-38.pyc │ ├── Models3_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_5_28_distribution.cpython-38.pyc │ ├── Models3_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_5_28_distribution_noFeatTrans.cpython-38.pyc │ ├── Models_AD.cpython-38.pyc │ ├── Models_empower.cpython-39.pyc │ ├── Models_mmlight.cpython-39.pyc │ └── generator_discriminator.cpython-38.pyc ├── main.py ├── main_empower.py ├── main_mmlight.py └── utility │ ├── __pycache__ │ ├── batch_test.cpython-38.pyc │ ├── batch_test.cpython-39.pyc │ ├── load_data.cpython-38.pyc │ ├── load_data.cpython-39.pyc │ ├── logging.cpython-38.pyc │ ├── logging.cpython-39.pyc │ ├── metrics.cpython-38.pyc │ ├── metrics.cpython-39.pyc │ ├── norm.cpython-38.pyc │ ├── norm.cpython-39.pyc │ ├── parser.cpython-38.pyc │ └── parser.cpython-39.pyc │ ├── batch_test.py │ ├── load_data.py │ ├── logging.py │ ├── metrics.py │ ├── norm.py │ └── parser.py └── decouple.png /LATTICE/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Big Data and Multi-modal Computing Group, CRIPAC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /LATTICE/README.md: -------------------------------------------------------------------------------- 1 | # LATTICE 2 | 3 | PyTorch implementation for ACM Multimedia 2021 paper: [Mining Latent Structures for Multimedia Recommendation](https://dl.acm.org/doi/10.1145/3474085.3475259) 4 | 5 | 6 | 7 | ## Dependencies 8 | 9 | - Python 3.6 10 | - torch==1.5.0 11 | - scikit-learn==0.24.2 12 | 13 | 14 | 15 | ## Dataset Preparation 16 | 17 | - Download **5-core reviews data**, **meta data**, and **image features** from the [Amazon product dataset](http://jmcauley.ucsd.edu/data/amazon/links.html). Put the data into the directory `data/meta-data/`. 18 | 19 | - Install [sentence-transformers](https://www.sbert.net/docs/installation.html) and download [pretrained models](https://www.sbert.net/docs/pretrained_models.html) to extract textual features. Unzip the pretrained model into the directory `sentence-transformers/`: 20 | 21 | ``` 22 | ├─ data/: 23 | ├── sports/ 24 | ├── meta-data/ 25 | ├── image_features_Sports_and_Outdoors.b 26 | ├── meta-Sports_and_Outdoors.json.gz 27 | ├── reviews_Sports_and_Outdoors_5.json.gz 28 | ├── sentence-transformers/ 29 | ├── stsb-roberta-large 30 | ``` 31 | 32 | - Run `python build_data.py` to preprocess the data. 33 | 34 | - Run `python cold_start.py` to build the cold-start data. 35 | 36 | - We provide the processed data on [Baidu Yun](https://pan.baidu.com/s/1SWe-XE23Nn0i4xSOXV_JyQ) (access code: m37q) and [Google Drive](https://drive.google.com/drive/folders/1sFg9W2wCexWahjqtN6MVc4f4dMj5hyFp?usp=sharing). 37 | 38 | ## Usage 39 | 40 | Start training and inference as: 41 | 42 | ``` 43 | cd codes 44 | python main.py --dataset {DATASET} 45 | ``` 46 | 47 | For cold-start settings: 48 | ``` 49 | python main.py --dataset {DATASET} --core 0 --verbose 1 --lr 1e-5 50 | ``` 51 | 52 | 53 | 54 | ## Citation 55 | 56 | If you want to use our codes in your research, please cite: 57 | 58 | ``` 59 | @inproceedings{LATTICE21, 60 | title = {Mining Latent Structures for Multimedia Recommendation}, 61 | author = {Zhang, Jinghao and 62 | Zhu, Yanqiao and 63 | Liu, Qiang and 64 | Wu, Shu and 65 | Wang, Shuhui and 66 | Wang, Liang}, 67 | booktitle = {Proceedings of the 29th ACM International Conference on Multimedia}, 68 | pages = {3872–3880}, 69 | year = {2021} 70 | } 71 | ``` 72 | 73 | ## Acknowledgement 74 | 75 | The structure of this code is largely based on [LightGCN](https://github.com/gusye1234/LightGCN-PyTorch). Thanks for their work. 76 | 77 | -------------------------------------------------------------------------------- /LATTICE/codes/Models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from time import time 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.sparse as sparse 8 | import torch.nn.functional as F 9 | 10 | from utility.parser import parse_args 11 | args = parse_args() 12 | 13 | def build_knn_neighbourhood(adj, topk): 14 | knn_val, knn_ind = torch.topk(adj, topk, dim=-1) 15 | weighted_adjacency_matrix = (torch.zeros_like(adj)).scatter_(-1, knn_ind, knn_val) 16 | return weighted_adjacency_matrix 17 | def compute_normalized_laplacian(adj): 18 | rowsum = torch.sum(adj, -1) 19 | d_inv_sqrt = torch.pow(rowsum, -0.5) 20 | d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
21 | d_mat_inv_sqrt = torch.diagflat(d_inv_sqrt) 22 | L_norm = torch.mm(torch.mm(d_mat_inv_sqrt, adj), d_mat_inv_sqrt) 23 | return L_norm 24 | def build_sim(context): 25 | context_norm = context.div(torch.norm(context, p=2, dim=-1, keepdim=True)) 26 | sim = torch.mm(context_norm, context_norm.transpose(1, 0)) 27 | return sim 28 | 29 | class LATTICE(nn.Module): 30 | def __init__(self, n_users, n_items, embedding_dim, weight_size, dropout_list, image_feats, text_feats): 31 | super().__init__() 32 | self.n_users = n_users 33 | self.n_items = n_items 34 | self.embedding_dim = embedding_dim 35 | self.weight_size = weight_size 36 | self.n_ui_layers = len(self.weight_size) 37 | self.weight_size = [self.embedding_dim] + self.weight_size 38 | self.user_embedding = nn.Embedding(n_users, self.embedding_dim) 39 | self.item_id_embedding = nn.Embedding(n_items, self.embedding_dim) 40 | nn.init.xavier_uniform_(self.user_embedding.weight) 41 | nn.init.xavier_uniform_(self.item_id_embedding.weight) 42 | 43 | if args.cf_model == 'ngcf': 44 | self.GC_Linear_list = nn.ModuleList() 45 | self.Bi_Linear_list = nn.ModuleList() 46 | self.dropout_list = nn.ModuleList() 47 | for i in range(self.n_ui_layers): 48 | self.GC_Linear_list.append(nn.Linear(self.weight_size[i], self.weight_size[i+1])) 49 | self.Bi_Linear_list.append(nn.Linear(self.weight_size[i], self.weight_size[i+1])) 50 | self.dropout_list.append(nn.Dropout(dropout_list[i])) 51 | 52 | 53 | self.image_embedding = nn.Embedding.from_pretrained(torch.Tensor(image_feats), freeze=False) 54 | self.text_embedding = nn.Embedding.from_pretrained(torch.Tensor(text_feats), freeze=False) 55 | 56 | 57 | if os.path.exists(args.data_path + 'image_adj_%d.pt'%(args.topk)): 58 | image_adj = torch.load(args.data_path + 'image_adj_%d.pt'%(args.topk)) 59 | else: 60 | image_adj = build_sim(self.image_embedding.weight.detach()) 61 | image_adj = build_knn_neighbourhood(image_adj, topk=args.topk) 62 | image_adj = compute_normalized_laplacian(image_adj) 63 | torch.save(image_adj, args.data_path + 'image_adj_%d.pt'%(args.topk)) 64 | 65 | if os.path.exists(args.data_path + 'text_adj_%d.pt'%(args.topk)): 66 | text_adj = torch.load(args.data_path + 'text_adj_%d.pt'%(args.topk)) 67 | else: 68 | text_adj = build_sim(self.text_embedding.weight.detach()) 69 | text_adj = build_knn_neighbourhood(text_adj, topk=args.topk) 70 | text_adj = compute_normalized_laplacian(text_adj) 71 | torch.save(text_adj, args.data_path + 'text_adj_%d.pt'%(args.topk)) 72 | 73 | self.text_original_adj = text_adj.cuda() 74 | self.image_original_adj = image_adj.cuda() 75 | 76 | self.image_trs = nn.Linear(image_feats.shape[1], args.feat_embed_dim) 77 | self.text_trs = nn.Linear(text_feats.shape[1], args.feat_embed_dim) 78 | 79 | 80 | self.modal_weight = nn.Parameter(torch.Tensor([0.5, 0.5])) 81 | self.softmax = nn.Softmax(dim=0) 82 | 83 | def forward(self, adj, build_item_graph=False): 84 | image_feats = self.image_trs(self.image_embedding.weight) 85 | text_feats = self.text_trs(self.text_embedding.weight) 86 | if build_item_graph: 87 | weight = self.softmax(self.modal_weight) 88 | self.image_adj = build_sim(image_feats) 89 | self.image_adj = build_knn_neighbourhood(self.image_adj, topk=args.topk) 90 | 91 | self.text_adj = build_sim(text_feats) 92 | self.text_adj = build_knn_neighbourhood(self.text_adj, topk=args.topk) 93 | 94 | 95 | learned_adj = weight[0] * self.image_adj + weight[1] * self.text_adj 96 | learned_adj = compute_normalized_laplacian(learned_adj) 97 | original_adj = weight[0] * self.image_original_adj 
+ weight[1] * self.text_original_adj 98 | self.item_adj = (1 - args.lambda_coeff) * learned_adj + args.lambda_coeff * original_adj 99 | else: 100 | self.item_adj = self.item_adj.detach() 101 | 102 | h = self.item_id_embedding.weight 103 | for i in range(args.n_layers): 104 | h = torch.mm(self.item_adj, h) 105 | 106 | if args.cf_model == 'ngcf': 107 | ego_embeddings = torch.cat((self.user_embedding.weight, self.item_id_embedding.weight), dim=0) 108 | all_embeddings = [ego_embeddings] 109 | for i in range(self.n_ui_layers): 110 | side_embeddings = torch.sparse.mm(adj, ego_embeddings) 111 | sum_embeddings = F.leaky_relu(self.GC_Linear_list[i](side_embeddings)) 112 | bi_embeddings = torch.mul(ego_embeddings, side_embeddings) 113 | bi_embeddings = F.leaky_relu(self.Bi_Linear_list[i](bi_embeddings)) 114 | ego_embeddings = sum_embeddings + bi_embeddings 115 | ego_embeddings = self.dropout_list[i](ego_embeddings) 116 | 117 | norm_embeddings = F.normalize(ego_embeddings, p=2, dim=1) 118 | all_embeddings += [norm_embeddings] 119 | 120 | all_embeddings = torch.stack(all_embeddings, dim=1) 121 | all_embeddings = all_embeddings.mean(dim=1, keepdim=False) 122 | u_g_embeddings, i_g_embeddings = torch.split(all_embeddings, [self.n_users, self.n_items], dim=0) 123 | i_g_embeddings = i_g_embeddings + F.normalize(h, p=2, dim=1) 124 | return u_g_embeddings, i_g_embeddings 125 | elif args.cf_model == 'lightgcn': 126 | ego_embeddings = torch.cat((self.user_embedding.weight, self.item_id_embedding.weight), dim=0) 127 | all_embeddings = [ego_embeddings] 128 | for i in range(self.n_ui_layers): 129 | side_embeddings = torch.sparse.mm(adj, ego_embeddings) 130 | ego_embeddings = side_embeddings 131 | all_embeddings += [ego_embeddings] 132 | all_embeddings = torch.stack(all_embeddings, dim=1) 133 | all_embeddings = all_embeddings.mean(dim=1, keepdim=False) 134 | u_g_embeddings, i_g_embeddings = torch.split(all_embeddings, [self.n_users, self.n_items], dim=0) 135 | i_g_embeddings = i_g_embeddings + F.normalize(h, p=2, dim=1) 136 | return u_g_embeddings, i_g_embeddings 137 | elif args.cf_model == 'mf': 138 | return self.user_embedding.weight, self.item_id_embedding.weight + F.normalize(h, p=2, dim=1) 139 | 140 | -------------------------------------------------------------------------------- /LATTICE/codes/__pycache__/Models.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/PromptMM/70da1002a35d6f2c7712c16cd0b2ca24c8813008/LATTICE/codes/__pycache__/Models.cpython-38.pyc -------------------------------------------------------------------------------- /LATTICE/codes/main.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import math 3 | import os 4 | import random 5 | import sys 6 | from time import time 7 | from tqdm import tqdm 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | import torch.sparse as sparse 15 | 16 | from utility.parser import parse_args 17 | from Models import LATTICE 18 | from utility.batch_test import * 19 | 20 | args = parse_args() 21 | 22 | 23 | class Trainer(object): 24 | def __init__(self, data_config): 25 | # argument settings 26 | self.n_users = data_config['n_users'] 27 | self.n_items = data_config['n_items'] 28 | 29 | self.model_name = args.model_name 30 | self.mess_dropout = eval(args.mess_dropout) 31 | self.lr = args.lr 32 | self.emb_dim = 
args.embed_size 33 | self.batch_size = args.batch_size 34 | self.weight_size = eval(args.weight_size) 35 | self.n_layers = len(self.weight_size) 36 | self.regs = eval(args.regs) 37 | self.decay = self.regs[0] 38 | 39 | self.norm_adj = data_config['norm_adj'] 40 | self.norm_adj = self.sparse_mx_to_torch_sparse_tensor(self.norm_adj).float().cuda() 41 | 42 | image_feats = np.load(args.data_path + '{}/image_feat.npy'.format(args.dataset)) 43 | text_feats = np.load(args.data_path + '{}/text_feat.npy'.format(args.dataset)) 44 | 45 | self.model = LATTICE(self.n_users, self.n_items, self.emb_dim, self.weight_size, self.mess_dropout, image_feats, text_feats) 46 | self.model = self.model.cuda() 47 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr) 48 | self.lr_scheduler = self.set_lr_scheduler() 49 | 50 | def set_lr_scheduler(self): 51 | fac = lambda epoch: 0.96 ** (epoch / 50) 52 | scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=fac) 53 | return scheduler 54 | 55 | def test(self, users_to_test, is_val): 56 | self.model.eval() 57 | with torch.no_grad(): 58 | ua_embeddings, ia_embeddings = self.model(self.norm_adj, build_item_graph=True) 59 | result = test_torch(ua_embeddings, ia_embeddings, users_to_test, is_val) 60 | return result 61 | 62 | def train(self): 63 | training_time_list = [] 64 | loss_loger, pre_loger, rec_loger, ndcg_loger, hit_loger = [], [], [], [], [] 65 | stopping_step = 0 66 | should_stop = False 67 | cur_best_pre_0 = 0. 68 | 69 | n_batch = data_generator.n_train // args.batch_size + 1 70 | best_recall = 0 71 | for epoch in (range(args.epoch)): 72 | t1 = time() 73 | loss, mf_loss, emb_loss, reg_loss = 0., 0., 0., 0. 74 | n_batch = data_generator.n_train // args.batch_size + 1 75 | f_time, b_time, loss_time, opt_time, clip_time, emb_time = 0., 0., 0., 0., 0., 0. 76 | sample_time = 0. 
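# NOTE: `build_item_graph` is True only for the first mini-batch of each epoch.
# That first forward pass rebuilds the modality-aware item-item k-NN graph inside
# Models.forward; every later batch takes the `else` branch there and reuses
# `self.item_adj.detach()`, so the expensive graph construction runs once per epoch.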
77 | build_item_graph = True 78 | for idx in (range(n_batch)): 79 | self.model.train() 80 | self.optimizer.zero_grad() 81 | sample_t1 = time() 82 | users, pos_items, neg_items = data_generator.sample() 83 | sample_time += time() - sample_t1 84 | ua_embeddings, ia_embeddings = self.model(self.norm_adj, build_item_graph=build_item_graph) 85 | build_item_graph = False 86 | u_g_embeddings = ua_embeddings[users] 87 | pos_i_g_embeddings = ia_embeddings[pos_items] 88 | neg_i_g_embeddings = ia_embeddings[neg_items] 89 | 90 | 91 | batch_mf_loss, batch_emb_loss, batch_reg_loss = self.bpr_loss(u_g_embeddings, pos_i_g_embeddings, 92 | neg_i_g_embeddings) 93 | 94 | batch_loss = batch_mf_loss + batch_emb_loss + batch_reg_loss 95 | 96 | batch_loss.backward(retain_graph=True) 97 | self.optimizer.step() 98 | 99 | loss += float(batch_loss) 100 | mf_loss += float(batch_mf_loss) 101 | emb_loss += float(batch_emb_loss) 102 | reg_loss += float(batch_reg_loss) 103 | 104 | 105 | self.lr_scheduler.step() 106 | 107 | del ua_embeddings, ia_embeddings, u_g_embeddings, neg_i_g_embeddings, pos_i_g_embeddings 108 | 109 | if math.isnan(loss) == True: 110 | print('ERROR: loss is nan.') 111 | sys.exit() 112 | 113 | perf_str = 'Epoch %d [%.1fs]: train==[%.5f=%.5f + %.5f]' % ( 114 | epoch, time() - t1, loss, mf_loss, emb_loss) 115 | training_time_list.append(time() - t1) 116 | print(perf_str) 117 | 118 | if epoch % args.verbose != 0: 119 | continue 120 | 121 | 122 | t2 = time() 123 | users_to_test = list(data_generator.test_set.keys()) 124 | users_to_val = list(data_generator.val_set.keys()) 125 | ret = self.test(users_to_val, is_val=True) 126 | training_time_list.append(t2 - t1) 127 | 128 | t3 = time() 129 | 130 | loss_loger.append(loss) 131 | rec_loger.append(ret['recall']) 132 | pre_loger.append(ret['precision']) 133 | ndcg_loger.append(ret['ndcg']) 134 | hit_loger.append(ret['hit_ratio']) 135 | if args.verbose > 0: 136 | perf_str = 'Epoch %d [%.1fs + %.1fs]: val==[%.5f=%.5f + %.5f + %.5f], recall=[%.5f, %.5f], ' \ 137 | 'precision=[%.5f, %.5f], hit=[%.5f, %.5f], ndcg=[%.5f, %.5f]' % \ 138 | (epoch, t2 - t1, t3 - t2, loss, mf_loss, emb_loss, reg_loss, ret['recall'][0], 139 | ret['recall'][-1], 140 | ret['precision'][0], ret['precision'][-1], ret['hit_ratio'][0], ret['hit_ratio'][-1], 141 | ret['ndcg'][0], ret['ndcg'][-1]) 142 | print(perf_str) 143 | 144 | if ret['recall'][1] > best_recall: 145 | best_recall = ret['recall'][1] 146 | test_ret = self.test(users_to_test, is_val=False) 147 | perf_str = 'Epoch %d [%.1fs + %.1fs]: test==[%.5f=%.5f + %.5f + %.5f], recall=[%.5f, %.5f], ' \ 148 | 'precision=[%.5f, %.5f], hit=[%.5f, %.5f], ndcg=[%.5f, %.5f]' % \ 149 | (epoch, t2 - t1, t3 - t2, loss, mf_loss, emb_loss, reg_loss, test_ret['recall'][0], 150 | test_ret['recall'][-1], 151 | test_ret['precision'][0], test_ret['precision'][-1], test_ret['hit_ratio'][0], test_ret['hit_ratio'][-1], 152 | test_ret['ndcg'][0], test_ret['ndcg'][-1]) 153 | print(perf_str) 154 | stopping_step = 0 155 | elif stopping_step < args.early_stopping_patience: 156 | stopping_step += 1 157 | print('#####Early stopping steps: %d #####' % stopping_step) 158 | else: 159 | print('#####Early stop! 
#####') 160 | break 161 | 162 | print(test_ret) 163 | 164 | def bpr_loss(self, users, pos_items, neg_items): 165 | pos_scores = torch.sum(torch.mul(users, pos_items), dim=1) 166 | neg_scores = torch.sum(torch.mul(users, neg_items), dim=1) 167 | 168 | regularizer = 1./2*(users**2).sum() + 1./2*(pos_items**2).sum() + 1./2*(neg_items**2).sum() 169 | regularizer = regularizer / self.batch_size 170 | 171 | maxi = F.logsigmoid(pos_scores - neg_scores) 172 | mf_loss = -torch.mean(maxi) 173 | 174 | emb_loss = self.decay * regularizer 175 | reg_loss = 0.0 176 | return mf_loss, emb_loss, reg_loss 177 | 178 | def sparse_mx_to_torch_sparse_tensor(self, sparse_mx): 179 | """Convert a scipy sparse matrix to a torch sparse tensor.""" 180 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 181 | indices = torch.from_numpy( 182 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) 183 | values = torch.from_numpy(sparse_mx.data) 184 | shape = torch.Size(sparse_mx.shape) 185 | return torch.sparse.FloatTensor(indices, values, shape) 186 | 187 | def set_seed(seed): 188 | np.random.seed(seed) 189 | random.seed(seed) 190 | torch.manual_seed(seed) # cpu 191 | torch.cuda.manual_seed_all(seed) # gpu 192 | 193 | if __name__ == '__main__': 194 | set_seed(args.seed) 195 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id) 196 | 197 | config = dict() 198 | config['n_users'] = data_generator.n_users 199 | config['n_items'] = data_generator.n_items 200 | 201 | plain_adj, norm_adj, mean_adj = data_generator.get_adj_mat() 202 | config['norm_adj'] = norm_adj 203 | 204 | trainer = Trainer(data_config=config) 205 | trainer.train() 206 | 207 | -------------------------------------------------------------------------------- /LATTICE/codes/utility/__pycache__/batch_test.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/PromptMM/70da1002a35d6f2c7712c16cd0b2ca24c8813008/LATTICE/codes/utility/__pycache__/batch_test.cpython-38.pyc -------------------------------------------------------------------------------- /LATTICE/codes/utility/__pycache__/load_data.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/PromptMM/70da1002a35d6f2c7712c16cd0b2ca24c8813008/LATTICE/codes/utility/__pycache__/load_data.cpython-38.pyc -------------------------------------------------------------------------------- /LATTICE/codes/utility/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/PromptMM/70da1002a35d6f2c7712c16cd0b2ca24c8813008/LATTICE/codes/utility/__pycache__/metrics.cpython-38.pyc -------------------------------------------------------------------------------- /LATTICE/codes/utility/__pycache__/parser.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/PromptMM/70da1002a35d6f2c7712c16cd0b2ca24c8813008/LATTICE/codes/utility/__pycache__/parser.cpython-38.pyc -------------------------------------------------------------------------------- /LATTICE/codes/utility/batch_test.py: -------------------------------------------------------------------------------- 1 | import utility.metrics as metrics 2 | from utility.parser import parse_args 3 | from utility.load_data import Data 4 | import multiprocessing 5 | import heapq 6 | import torch 7 | import pickle 8 | import numpy as np 9 | from 
time import time 10 | 11 | cores = multiprocessing.cpu_count() // 5 12 | 13 | args = parse_args() 14 | Ks = eval(args.Ks) 15 | 16 | data_generator = Data(path=args.data_path + args.dataset, batch_size=args.batch_size) 17 | USR_NUM, ITEM_NUM = data_generator.n_users, data_generator.n_items 18 | N_TRAIN, N_TEST = data_generator.n_train, data_generator.n_test 19 | BATCH_SIZE = args.batch_size 20 | 21 | def ranklist_by_heapq(user_pos_test, test_items, rating, Ks): 22 | item_score = {} 23 | for i in test_items: 24 | item_score[i] = rating[i] 25 | 26 | K_max = max(Ks) 27 | K_max_item_score = heapq.nlargest(K_max, item_score, key=item_score.get) 28 | 29 | r = [] 30 | for i in K_max_item_score: 31 | if i in user_pos_test: 32 | r.append(1) 33 | else: 34 | r.append(0) 35 | auc = 0. 36 | return r, auc 37 | 38 | def get_auc(item_score, user_pos_test): 39 | item_score = sorted(item_score.items(), key=lambda kv: kv[1]) 40 | item_score.reverse() 41 | item_sort = [x[0] for x in item_score] 42 | posterior = [x[1] for x in item_score] 43 | 44 | r = [] 45 | for i in item_sort: 46 | if i in user_pos_test: 47 | r.append(1) 48 | else: 49 | r.append(0) 50 | auc = metrics.auc(ground_truth=r, prediction=posterior) 51 | return auc 52 | 53 | def ranklist_by_sorted(user_pos_test, test_items, rating, Ks): 54 | item_score = {} 55 | for i in test_items: 56 | item_score[i] = rating[i] 57 | 58 | K_max = max(Ks) 59 | K_max_item_score = heapq.nlargest(K_max, item_score, key=item_score.get) 60 | 61 | r = [] 62 | for i in K_max_item_score: 63 | if i in user_pos_test: 64 | r.append(1) 65 | else: 66 | r.append(0) 67 | auc = get_auc(item_score, user_pos_test) 68 | return r, auc 69 | 70 | def get_performance(user_pos_test, r, auc, Ks): 71 | precision, recall, ndcg, hit_ratio = [], [], [], [] 72 | 73 | for K in Ks: 74 | precision.append(metrics.precision_at_k(r, K)) 75 | recall.append(metrics.recall_at_k(r, K, len(user_pos_test))) 76 | ndcg.append(metrics.ndcg_at_k(r, K)) 77 | hit_ratio.append(metrics.hit_at_k(r, K)) 78 | 79 | return {'recall': np.array(recall), 'precision': np.array(precision), 80 | 'ndcg': np.array(ndcg), 'hit_ratio': np.array(hit_ratio), 'auc': auc} 81 | 82 | 83 | def test_one_user(x): 84 | # user u's ratings for user u 85 | is_val = x[-1] 86 | rating = x[0] 87 | #uid 88 | u = x[1] 89 | #user u's items in the training set 90 | try: 91 | training_items = data_generator.train_items[u] 92 | except Exception: 93 | training_items = [] 94 | #user u's items in the test set 95 | if is_val: 96 | user_pos_test = data_generator.val_set[u] 97 | else: 98 | user_pos_test = data_generator.test_set[u] 99 | 100 | all_items = set(range(ITEM_NUM)) 101 | 102 | test_items = list(all_items - set(training_items)) 103 | 104 | if args.test_flag == 'part': 105 | r, auc = ranklist_by_heapq(user_pos_test, test_items, rating, Ks) 106 | else: 107 | r, auc = ranklist_by_sorted(user_pos_test, test_items, rating, Ks) 108 | 109 | return get_performance(user_pos_test, r, auc, Ks) 110 | 111 | 112 | def test_torch(ua_embeddings, ia_embeddings, users_to_test, is_val, drop_flag=False, batch_test_flag=False): 113 | result = {'precision': np.zeros(len(Ks)), 'recall': np.zeros(len(Ks)), 'ndcg': np.zeros(len(Ks)), 114 | 'hit_ratio': np.zeros(len(Ks)), 'auc': 0.} 115 | pool = multiprocessing.Pool(cores) 116 | 117 | u_batch_size = BATCH_SIZE * 2 118 | i_batch_size = BATCH_SIZE 119 | 120 | test_users = users_to_test 121 | n_test_users = len(test_users) 122 | n_user_batchs = n_test_users // u_batch_size + 1 123 | count = 0 124 | 125 | for u_batch_id in 
range(n_user_batchs): 126 | start = u_batch_id * u_batch_size 127 | end = (u_batch_id + 1) * u_batch_size 128 | user_batch = test_users[start: end] 129 | if batch_test_flag: 130 | n_item_batchs = ITEM_NUM // i_batch_size + 1 131 | rate_batch = np.zeros(shape=(len(user_batch), ITEM_NUM)) 132 | 133 | i_count = 0 134 | for i_batch_id in range(n_item_batchs): 135 | i_start = i_batch_id * i_batch_size 136 | i_end = min((i_batch_id + 1) * i_batch_size, ITEM_NUM) 137 | 138 | item_batch = range(i_start, i_end) 139 | u_g_embeddings = ua_embeddings[user_batch] 140 | i_g_embeddings = ia_embeddings[item_batch] 141 | i_rate_batch = torch.matmul(u_g_embeddings, torch.transpose(i_g_embeddings, 0, 1)) 142 | 143 | rate_batch[:, i_start: i_end] = i_rate_batch 144 | i_count += i_rate_batch.shape[1] 145 | 146 | assert i_count == ITEM_NUM 147 | 148 | else: 149 | item_batch = range(ITEM_NUM) 150 | u_g_embeddings = ua_embeddings[user_batch] 151 | i_g_embeddings = ia_embeddings[item_batch] 152 | rate_batch = torch.matmul(u_g_embeddings, torch.transpose(i_g_embeddings, 0, 1)) 153 | 154 | rate_batch = rate_batch.detach().cpu().numpy() 155 | user_batch_rating_uid = zip(rate_batch, user_batch, [is_val] * len(user_batch)) 156 | 157 | batch_result = pool.map(test_one_user, user_batch_rating_uid) 158 | count += len(batch_result) 159 | 160 | for re in batch_result: 161 | result['precision'] += re['precision'] / n_test_users 162 | result['recall'] += re['recall'] / n_test_users 163 | result['ndcg'] += re['ndcg'] / n_test_users 164 | result['hit_ratio'] += re['hit_ratio'] / n_test_users 165 | result['auc'] += re['auc'] / n_test_users 166 | 167 | assert count == n_test_users 168 | pool.close() 169 | return result 170 | -------------------------------------------------------------------------------- /LATTICE/codes/utility/load_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random as rd 3 | import scipy.sparse as sp 4 | from time import time 5 | import json 6 | from utility.parser import parse_args 7 | args = parse_args() 8 | 9 | class Data(object): 10 | def __init__(self, path, batch_size): 11 | # self.path = path + '/%d-core' % args.core 12 | # self.batch_size = batch_size 13 | 14 | # train_file = path + '/%d-core/train.json' % (args.core) 15 | # val_file = path + '/%d-core/val.json' % (args.core) 16 | # test_file = path + '/%d-core/test.json' % (args.core) 17 | 18 | self.path = path #+ '/%d-core' % args.core 19 | self.batch_size = batch_size 20 | 21 | train_file = path + '/train.json'#+ '/%d-core/train.json' % (args.core) 22 | val_file = path + '/val.json' #+ '/%d-core/val.json' % (args.core) 23 | test_file = path + '/test.json' #+ '/%d-core/test.json' % (args.core) 24 | 25 | #get number of users and items 26 | self.n_users, self.n_items = 0, 0 27 | self.n_train, self.n_test = 0, 0 28 | self.neg_pools = {} 29 | 30 | self.exist_users = [] 31 | 32 | train = json.load(open(train_file)) 33 | test = json.load(open(test_file)) 34 | val = json.load(open(val_file)) 35 | for uid, items in train.items(): 36 | if len(items) == 0: 37 | continue 38 | uid = int(uid) 39 | self.exist_users.append(uid) 40 | self.n_items = max(self.n_items, max(items)) 41 | self.n_users = max(self.n_users, uid) 42 | self.n_train += len(items) 43 | 44 | for uid, items in test.items(): 45 | uid = int(uid) 46 | try: 47 | self.n_items = max(self.n_items, max(items)) 48 | self.n_test += len(items) 49 | except: 50 | continue 51 | 52 | for uid, items in val.items(): 53 | uid = int(uid) 54 
| try: 55 | self.n_items = max(self.n_items, max(items)) 56 | self.n_val += len(items) 57 | except: 58 | continue 59 | 60 | self.n_items += 1 61 | self.n_users += 1 62 | 63 | text_feats = np.load(args.data_path + '{}/text_feat.npy'.format(args.dataset)) 64 | self.n_items = text_feats.shape[0] 65 | 66 | self.print_statistics() 67 | 68 | self.R = sp.dok_matrix((self.n_users, self.n_items), dtype=np.float32) 69 | self.R_Item_Interacts = sp.dok_matrix((self.n_items, self.n_items), dtype=np.float32) 70 | 71 | self.train_items, self.test_set, self.val_set = {}, {}, {} 72 | for uid, train_items in train.items(): 73 | if len(train_items) == 0: 74 | continue 75 | uid = int(uid) 76 | for idx, i in enumerate(train_items): 77 | self.R[uid, i] = 1. 78 | 79 | self.train_items[uid] = train_items 80 | 81 | for uid, test_items in test.items(): 82 | uid = int(uid) 83 | if len(test_items) == 0: 84 | continue 85 | try: 86 | self.test_set[uid] = test_items 87 | except: 88 | continue 89 | 90 | for uid, val_items in val.items(): 91 | uid = int(uid) 92 | if len(val_items) == 0: 93 | continue 94 | try: 95 | self.val_set[uid] = val_items 96 | except: 97 | continue 98 | 99 | def get_adj_mat(self): 100 | try: 101 | t1 = time() 102 | adj_mat = sp.load_npz(self.path + '/s_adj_mat.npz') 103 | norm_adj_mat = sp.load_npz(self.path + '/s_norm_adj_mat.npz') 104 | mean_adj_mat = sp.load_npz(self.path + '/s_mean_adj_mat.npz') 105 | print('already load adj matrix', adj_mat.shape, time() - t1) 106 | 107 | except Exception: 108 | adj_mat, norm_adj_mat, mean_adj_mat = self.create_adj_mat() 109 | sp.save_npz(self.path + '/s_adj_mat.npz', adj_mat) 110 | sp.save_npz(self.path + '/s_norm_adj_mat.npz', norm_adj_mat) 111 | sp.save_npz(self.path + '/s_mean_adj_mat.npz', mean_adj_mat) 112 | return adj_mat, norm_adj_mat, mean_adj_mat 113 | 114 | def create_adj_mat(self): 115 | t1 = time() 116 | adj_mat = sp.dok_matrix((self.n_users + self.n_items, self.n_users + self.n_items), dtype=np.float32) 117 | adj_mat = adj_mat.tolil() 118 | R = self.R.tolil() 119 | 120 | adj_mat[:self.n_users, self.n_users:] = R 121 | adj_mat[self.n_users:, :self.n_users] = R.T 122 | adj_mat = adj_mat.todok() 123 | print('already create adjacency matrix', adj_mat.shape, time() - t1) 124 | 125 | t2 = time() 126 | 127 | def normalized_adj_single(adj): 128 | rowsum = np.array(adj.sum(1)) 129 | 130 | d_inv = np.power(rowsum, -1).flatten() 131 | d_inv[np.isinf(d_inv)] = 0. 132 | d_mat_inv = sp.diags(d_inv) 133 | 134 | norm_adj = d_mat_inv.dot(adj) 135 | # norm_adj = adj.dot(d_mat_inv) 136 | print('generate single-normalized adjacency matrix.') 137 | return norm_adj.tocoo() 138 | 139 | def get_D_inv(adj): 140 | rowsum = np.array(adj.sum(1)) 141 | 142 | d_inv = np.power(rowsum, -1).flatten() 143 | d_inv[np.isinf(d_inv)] = 0. 
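# isolated nodes have zero degree, so np.power(rowsum, -1) yields inf for them;
# resetting those entries to 0 keeps sp.diags from propagating inf/NaN downstream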
144 | d_mat_inv = sp.diags(d_inv) 145 | return d_mat_inv 146 | 147 | def check_adj_if_equal(adj): 148 | dense_A = np.array(adj.todense()) 149 | degree = np.sum(dense_A, axis=1, keepdims=False) 150 | 151 | temp = np.dot(np.diag(np.power(degree, -1)), dense_A) 152 | print('check normalized adjacency matrix whether equal to this laplacian matrix.') 153 | return temp 154 | 155 | norm_adj_mat = normalized_adj_single(adj_mat + sp.eye(adj_mat.shape[0])) 156 | mean_adj_mat = normalized_adj_single(adj_mat) 157 | 158 | print('already normalize adjacency matrix', time() - t2) 159 | return adj_mat.tocsr(), norm_adj_mat.tocsr(), mean_adj_mat.tocsr() 160 | 161 | 162 | def sample(self): 163 | if self.batch_size <= self.n_users: 164 | users = rd.sample(self.exist_users, self.batch_size) 165 | else: 166 | users = [rd.choice(self.exist_users) for _ in range(self.batch_size)] 167 | # users = self.exist_users[:] 168 | 169 | def sample_pos_items_for_u(u, num): 170 | pos_items = self.train_items[u] 171 | n_pos_items = len(pos_items) 172 | pos_batch = [] 173 | while True: 174 | if len(pos_batch) == num: break 175 | pos_id = np.random.randint(low=0, high=n_pos_items, size=1)[0] 176 | pos_i_id = pos_items[pos_id] 177 | 178 | if pos_i_id not in pos_batch: 179 | pos_batch.append(pos_i_id) 180 | return pos_batch 181 | 182 | def sample_neg_items_for_u(u, num): 183 | neg_items = [] 184 | while True: 185 | if len(neg_items) == num: break 186 | neg_id = np.random.randint(low=0, high=self.n_items, size=1)[0] 187 | if neg_id not in self.train_items[u] and neg_id not in neg_items: 188 | neg_items.append(neg_id) 189 | return neg_items 190 | 191 | def sample_neg_items_for_u_from_pools(u, num): 192 | neg_items = list(set(self.neg_pools[u]) - set(self.train_items[u])) 193 | return rd.sample(neg_items, num) 194 | 195 | pos_items, neg_items = [], [] 196 | for u in users: 197 | pos_items += sample_pos_items_for_u(u, 1) 198 | neg_items += sample_neg_items_for_u(u, 1) 199 | # neg_items += sample_neg_items_for_u(u, 3) 200 | return users, pos_items, neg_items 201 | 202 | 203 | 204 | def print_statistics(self): 205 | print('n_users=%d, n_items=%d' % (self.n_users, self.n_items)) 206 | print('n_interactions=%d' % (self.n_train + self.n_test)) 207 | print('n_train=%d, n_test=%d, sparsity=%.5f' % (self.n_train, self.n_test, (self.n_train + self.n_test)/(self.n_users * self.n_items))) 208 | 209 | -------------------------------------------------------------------------------- /LATTICE/codes/utility/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import roc_auc_score 3 | 4 | def recall(rank, ground_truth, N): 5 | return len(set(rank[:N]) & set(ground_truth)) / float(len(set(ground_truth))) 6 | 7 | 8 | def precision_at_k(r, k): 9 | """Score is precision @ k 10 | Relevance is binary (nonzero is relevant). 11 | Returns: 12 | Precision @ k 13 | Raises: 14 | ValueError: len(r) must be >= k 15 | """ 16 | assert k >= 1 17 | r = np.asarray(r)[:k] 18 | return np.mean(r) 19 | 20 | 21 | def average_precision(r,cut): 22 | """Score is average precision (area under PR curve) 23 | Relevance is binary (nonzero is relevant). 24 | Returns: 25 | Average precision 26 | """ 27 | r = np.asarray(r) 28 | out = [precision_at_k(r, k + 1) for k in range(cut) if r[k]] 29 | if not out: 30 | return 0. 
31 | return np.sum(out)/float(min(cut, np.sum(r))) 32 | 33 | 34 | def mean_average_precision(rs): 35 | """Score is mean average precision 36 | Relevance is binary (nonzero is relevant). 37 | Returns: 38 | Mean average precision 39 | """ 40 | return np.mean([average_precision(r, len(r)) for r in rs]) # average_precision takes a cut argument; len(r) evaluates over the full ranking 41 | 42 | 43 | def dcg_at_k(r, k, method=1): 44 | """Score is discounted cumulative gain (dcg) 45 | Relevance is positive real values. Can use binary 46 | as the previous methods. 47 | Returns: 48 | Discounted cumulative gain 49 | """ 50 | r = np.asfarray(r)[:k] 51 | if r.size: 52 | if method == 0: 53 | return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1))) 54 | elif method == 1: 55 | return np.sum(r / np.log2(np.arange(2, r.size + 2))) 56 | else: 57 | raise ValueError('method must be 0 or 1.') 58 | return 0. 59 | 60 | 61 | def ndcg_at_k(r, k, method=1): 62 | """Score is normalized discounted cumulative gain (ndcg) 63 | Relevance is positive real values. Can use binary 64 | as the previous methods. 65 | Returns: 66 | Normalized discounted cumulative gain 67 | """ 68 | dcg_max = dcg_at_k(sorted(r, reverse=True), k, method) 69 | if not dcg_max: 70 | return 0. 71 | return dcg_at_k(r, k, method) / dcg_max 72 | 73 | 74 | def recall_at_k(r, k, all_pos_num): 75 | r = np.asfarray(r)[:k] 76 | if all_pos_num == 0: 77 | return 0 78 | else: 79 | return np.sum(r) / all_pos_num 80 | 81 | 82 | def hit_at_k(r, k): 83 | r = np.array(r)[:k] 84 | if np.sum(r) > 0: 85 | return 1. 86 | else: 87 | return 0. 88 | 89 | def F1(pre, rec): 90 | if pre + rec > 0: 91 | return (2.0 * pre * rec) / (pre + rec) 92 | else: 93 | return 0. 94 | 95 | def auc(ground_truth, prediction): 96 | try: 97 | res = roc_auc_score(y_true=ground_truth, y_score=prediction) 98 | except Exception: 99 | res = 0.
100 | return res -------------------------------------------------------------------------------- /LATTICE/codes/utility/parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(): 4 | parser = argparse.ArgumentParser(description="") 5 | 6 | parser.add_argument('--data_path', nargs='?', default='', 7 | help='Input data path.') 8 | parser.add_argument('--seed', type=int, default=123, 9 | help='Random seed') 10 | parser.add_argument('--dataset', nargs='?', default='cloth', 11 | help='Choose a dataset from {grocery, cloth, sport, netflix, sports, baby, clothing}') 12 | parser.add_argument('--verbose', type=int, default=5, 13 | help='Interval of evaluation.') 14 | parser.add_argument('--epoch', type=int, default=200, 15 | help='Number of epochs.') 16 | parser.add_argument('--batch_size', type=int, default=1024, 17 | help='Batch size.') 18 | parser.add_argument('--regs', nargs='?', default='[1e-5,1e-5,1e-2]', 19 | help='Regularizations.') 20 | parser.add_argument('--lr', type=float, default=0.0005, 21 | help='Learning rate.') 22 | parser.add_argument('--model_name', nargs='?', default='lattice', 23 | help='Specify the model name.') 24 | 25 | parser.add_argument('--embed_size', type=int, default=64, 26 | help='Embedding size.') 27 | parser.add_argument('--feat_embed_dim', type=int, default=64, 28 | help='Feature embedding size.') 29 | parser.add_argument('--weight_size', nargs='?', default='[64,64]', 30 | help='Output sizes of every layer') 31 | parser.add_argument('--core', type=int, default=5, 32 | help='5-core for warm-start; 0-core for cold start') 33 | parser.add_argument('--topk', type=int, default=10, 34 | help='K value of k-NN sparsification') 35 | parser.add_argument('--lambda_coeff', type=float, default=0.9, 36 | help='Lambda value of skip connection') 37 | parser.add_argument('--cf_model', nargs='?', default='lightgcn', 38 | help='Downstream Collaborative Filtering model {mf, ngcf, lightgcn}') 39 | parser.add_argument('--n_layers', type=int, default=1, 40 | help='Number of item graph conv layers') 41 | parser.add_argument('--mess_dropout', nargs='?', default='[0.1, 0.1]', 42 | help='Keep probability w.r.t. message dropout (i.e., 1-dropout_ratio) for each deep layer. 1: no dropout.') 43 | 44 | parser.add_argument('--early_stopping_patience', type=int, default=10, 45 | help='Patience (in evaluation rounds) for early stopping.') 46 | parser.add_argument('--gpu_id', type=int, default=0, 47 | help='GPU id') 48 | parser.add_argument('--Ks', nargs='?', default='[10, 20, 50]', 49 | help='K value of ndcg/recall @ k') 50 | parser.add_argument('--test_flag', nargs='?', default='part', 51 | help='Specify the test type from {part, full}, indicating whether the inference is done in mini-batch') 52 | 53 | 54 | return parser.parse_args() 55 | -------------------------------------------------------------------------------- /PromptMM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/PromptMM/70da1002a35d6f2c7712c16cd0b2ca24c8813008/PromptMM.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PromptMM: Multi-Modal Knowledge Distillation for Recommendation with Prompt-Tuning 2 | 3 | PyTorch implementation for the WWW 2024 paper [PromptMM: Multi-Modal Knowledge Distillation for Recommendation with Prompt-Tuning](https://arxiv.org/html/2402.17188v1).
4 | 5 | [Wei Wei](#), [Jiabin Tang](https://tjb-tech.github.io/), [Yangqin Jiang](#), [Lianghao Xia](https://akaxlh.github.io/) and [Chao Huang](https://sites.google.com/view/chaoh/home)*. 6 | (*Correspondence) 7 | 8 |

9 | 10 |

11 | 12 | 13 |

## Dependencies

14 | 15 | * Python >= 3.9.13 16 | * [Pytorch](https://pytorch.org/) >= 1.13.0+cu116 17 | * [dgl-cuda11.6](https://www.dgl.ai/) >= 0.9.1post1 18 | 19 | 20 | 21 | 22 |
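A minimal environment setup might look like the following. This is a sketch only: the wheel names, the CUDA 11.6 index URL, and the DGL wheel index are assumptions chosen to match the versions listed above; they are not pinned by this repo.

```
conda create -n promptmm python=3.9.13
conda activate promptmm
pip install torch==1.13.0+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install dgl-cu116 -f https://data.dgl.ai/wheels/repo.html
```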

## Usage

23 | 24 | Start training and inference as: 25 | 26 | ``` 27 | python ./main.py --dataset {DATASET} 28 | ``` 29 | Supported datasets: `Amazon-Electronics`, `Netflix`, `Tiktok` 30 | 31 | 32 |
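For example, to train on the Tiktok data (the exact dataset string is assumed to match the folder name under `data/`, e.g. `tiktok` in the layout shown in the Datasets section below; other flags follow `codes/utility/parser.py`):

```
python ./main.py --dataset tiktok
```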

## Datasets

33 | 34 | ``` 35 | ├─ MMSSL/ 36 | ├── data/ 37 | ├── tiktok/ 38 | ... 39 | ``` 40 | | Dataset | | Netflix | | | Tiktok | | | | Electronics | | 41 | |:-----------:|:-:|:--------:|:---:|:-:|:--------:|:---:|:---:|:-:|:-----------:|:----:| 42 | | Modality | | V | T | | V | A | T | | V | T | 43 | | Feat. Dim. | | 512 | 768 | | 128 | 128 | 768 | | 4096 | 1024 | 44 | | User | | 43,739 | | | 14,343 | | | | 41,691 | | 45 | | Item | | 17,239 | | | 8,690 | | | | 21,479 | | 46 | | Interaction | | 609,341 | | | 276,637 | | | | 359,165 | | 47 | | Sparsity | | 99.919\% | | | 99.778\% | | | | 99.960\% | | 48 | 49 | 50 | - `2024.2.27 new multi-modal datasets uploaded`: 📢📢 🌹🌹 We provide the new multi-modal datasets `Netflix` and `MovieLens` (i.e., CF training data and multi-modal data, including `item text` and `posters`) of the new multi-modal work [LLMRec](https://github.com/HKUDS/LLMRec) on Google Drive. 🌹We hope to contribute to our community and facilitate your research~ 51 | 52 | - `2023.2.27 update (all datasets uploaded)`: We provide the processed data at [Google Drive](https://drive.google.com/drive/folders/17vnX8S6a_68xzML1tAM5m9YsQyKZ1UKb?usp=share_link). 53 | 54 | 🚀🚀 The provided datasets are compatible with multi-modal recommender models such as [MMSSL](https://github.com/HKUDS/MMSSL), [LATTICE](https://github.com/CRIPAC-DIG/LATTICE), and [MICRO](https://github.com/CRIPAC-DIG/MICRO), and require no additional preprocessing: they include (1) basic user-item interactions and (2) multi-modal features. 55 | 56 | ``` 57 | # part of data preprocessing 58 | # ----json2mat-------------------------------------------------------------------------------------------------- 59 | import json 60 | from scipy.sparse import csr_matrix 61 | import pickle 62 | import numpy as np 63 | n_user, n_item = 39387, 23033 64 | f = open('/home/weiw/Code/MM/MMSSL/data/clothing/train.json', 'r') 65 | train = json.load(f) 66 | row, col = [], [] 67 | for index, value in enumerate(train.keys()): 68 | for i in range(len(train[value])): 69 | row.append(int(value)) 70 | col.append(train[value][i]) 71 | data = np.ones(len(row)) 72 | train_mat = csr_matrix((data, (row, col)), shape=(n_user, n_item)) 73 | pickle.dump(train_mat, open('./train_mat', 'wb')) 74 | # ----json2mat-------------------------------------------------------------------------------------------------- 75 | 76 | 77 | # ----mat2json-------------------------------------------------------------------------------------------------- 78 | # train_mat = pickle.load(open('./train_mat', 'rb')) 79 | test_mat = pickle.load(open('./test_mat', 'rb')) 80 | # val_mat = pickle.load(open('./val_mat', 'rb')) 81 | 82 | # total_mat = train_mat + test_mat + val_mat 83 | total_mat = test_mat 84 | 85 | # total_mat = pickle.load(open('./new_mat','rb')) 86 | 87 | total_array = total_mat.toarray() 88 | total_dict = {} 89 | 90 | for i in range(total_array.shape[0]): 91 | total_dict[str(i)] = [index for index, value in enumerate(total_array[i]) if value!=0] 92 | 93 | new_total_dict = {} 94 | 95 | for i in range(len(total_dict)): 96 | # if len(total_dict[str(i)])>1: 97 | new_total_dict[str(i)]=total_dict[str(i)] 98 | 99 | # train_dict, test_dict = {}, {} 100 | 101 | # for i in range(len(new_total_dict)): 102 | # train_dict[str(i)] = total_dict[str(i)][:-1] 103 | # test_dict[str(i)] = [total_dict[str(i)][-1]] 104 | 105 | # train_json_str = json.dumps(train_dict) 106 | test_json_str = json.dumps(new_total_dict) 107 | 108 | # with open('./new_train.json',
'w') as json_file: 109 | # # with open('./new_train_json', 'w') as json_file: 110 | # json_file.write(train_json_str) 111 | with open('./test.json', 'w') as test_file: 112 | # with open('./new_test_json', 'w') as test_file: 113 | test_file.write(test_json_str) 114 | # ----mat2json-------------------------------------------------------------------------------------------------- 115 | ``` 116 | 117 | 118 |
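A quick sanity check on the json2mat conversion above (a sketch; the paths are the illustrative ones from the snippet, and it assumes no duplicate user-item pairs in the JSON):

```
import json
import pickle

# paths follow the snippet above; adjust to your own dataset
train = json.load(open('/home/weiw/Code/MM/MMSSL/data/clothing/train.json', 'r'))
train_mat = pickle.load(open('./train_mat', 'rb'))

# the csr_matrix stores one nonzero per interaction, so the counts should agree
assert train_mat.nnz == sum(len(items) for items in train.values())
print('interactions:', train_mat.nnz)
```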

119 | 120 |

121 | 122 | 123 | ## Acknowledgement 124 | 125 | 126 | 127 | The structure of this code is largely based on [LATTICE](https://github.com/CRIPAC-DIG/LATTICE) and [MICRO](https://github.com/CRIPAC-DIG/MICRO). Thanks for their work. 128 | 129 | -------------------------------------------------------------------------------- /codes/Models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from time import time 4 | import pickle 5 | 6 | import scipy.sparse as sp 7 | from scipy.sparse import csr_matrix 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.nn import init 13 | 14 | from sklearn.decomposition import PCA, FastICA 15 | from sklearn import manifold 16 | from sklearn.manifold import TSNE 17 | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis 18 | 19 | # from utility.parser import parse_args 20 | from utility.norm import build_sim, build_knn_normalized_graph 21 | # args = parse_args() 22 | from utility.parser import args 23 | 24 | 25 | 26 | class Teacher_Model(nn.Module): 27 | def __init__(self, n_users, n_items, embedding_dim, weight_size, dropout_list, image_feats, text_feats): 28 | 29 | super().__init__() 30 | self.n_users = n_users 31 | self.n_items = n_items 32 | self.embedding_dim = embedding_dim 33 | self.weight_size = weight_size 34 | self.n_ui_layers = len(self.weight_size) 35 | self.weight_size = [self.embedding_dim] + self.weight_size 36 | 37 | self.image_trans = nn.Linear(image_feats.shape[1], args.embed_size) 38 | self.text_trans = nn.Linear(text_feats.shape[1], args.embed_size) 39 | nn.init.xavier_uniform_(self.image_trans.weight) 40 | nn.init.xavier_uniform_(self.text_trans.weight) 41 | self.encoder = nn.ModuleDict() 42 | self.encoder['image_encoder'] = self.image_trans # ^-^ 43 | self.encoder['text_encoder'] = self.text_trans # ^-^ 44 | 45 | self.user_id_embedding = nn.Embedding(n_users, self.embedding_dim) 46 | self.item_id_embedding = nn.Embedding(n_items, self.embedding_dim) 47 | 48 | nn.init.xavier_uniform_(self.user_id_embedding.weight) 49 | nn.init.xavier_uniform_(self.item_id_embedding.weight) 50 | self.image_feats = torch.tensor(image_feats).float().cuda() 51 | self.text_feats = torch.tensor(text_feats).float().cuda() 52 | self.image_embedding = nn.Embedding.from_pretrained(torch.Tensor(image_feats), freeze=False) 53 | self.text_embedding = nn.Embedding.from_pretrained(torch.Tensor(text_feats), freeze=False) 54 | 55 | self.softmax = nn.Softmax(dim=-1) 56 | self.act = nn.Sigmoid() 57 | self.sigmoid = nn.Sigmoid() 58 | self.dropout = nn.Dropout(p=args.drop_rate) 59 | self.batch_norm = nn.BatchNorm1d(args.embed_size) 60 | 61 | def mm(self, x, y): 62 | if args.sparse: 63 | return torch.sparse.mm(x, y) 64 | else: 65 | return torch.mm(x, y) 66 | def sim(self, z1, z2): 67 | z1 = F.normalize(z1) 68 | z2 = F.normalize(z2) 69 | return torch.mm(z1, z2.t()) 70 | 71 | def batched_contrastive_loss(self, z1, z2, batch_size=4096): 72 | device = z1.device 73 | num_nodes = z1.size(0) 74 | num_batches = (num_nodes - 1) // batch_size + 1 75 | f = lambda x: torch.exp(x / self.tau) # NOTE: self.tau (temperature) is not set in __init__ and must be assigned on the module before this loss is called 76 | indices = torch.arange(0, num_nodes).to(device) 77 | losses = [] 78 | 79 | for i in range(num_batches): 80 | mask = indices[i * batch_size:(i + 1) * batch_size] 81 | refl_sim = f(self.sim(z1[mask], z1)) 82 | between_sim = f(self.sim(z1[mask], z2)) 83 | 84 | losses.append(-torch.log( 85 | between_sim[:, i * batch_size:(i + 1) *
batch_size].diag() 86 | / (refl_sim.sum(1) + between_sim.sum(1) 87 | - refl_sim[:, i * batch_size:(i + 1) * batch_size].diag()))) 88 | 89 | loss_vec = torch.cat(losses) 90 | return loss_vec.mean() 91 | 92 | def csr_norm(self, csr_mat, mean_flag=False): 93 | rowsum = np.array(csr_mat.sum(1)) 94 | rowsum = np.power(rowsum+1e-8, -0.5).flatten() 95 | rowsum[np.isinf(rowsum)] = 0. 96 | rowsum_diag = sp.diags(rowsum) 97 | 98 | colsum = np.array(csr_mat.sum(0)) 99 | colsum = np.power(colsum+1e-8, -0.5).flatten() 100 | colsum[np.isinf(colsum)] = 0. 101 | colsum_diag = sp.diags(colsum) 102 | 103 | if mean_flag == False: 104 | return rowsum_diag*csr_mat*colsum_diag 105 | else: 106 | return rowsum_diag*csr_mat 107 | 108 | def matrix_to_tensor(self, cur_matrix): 109 | if type(cur_matrix) != sp.coo_matrix: 110 | cur_matrix = cur_matrix.tocoo() # 111 | indices = torch.from_numpy(np.vstack((cur_matrix.row, cur_matrix.col)).astype(np.int64)) # 112 | values = torch.from_numpy(cur_matrix.data) # 113 | shape = torch.Size(cur_matrix.shape) 114 | 115 | return torch.sparse.FloatTensor(indices, values, shape).to(torch.float32).cuda() # 116 | 117 | def para_dict_to_tenser(self, para_dict): 118 | """ 119 | :param para_dict: nn.ParameterDict() 120 | :return: tensor 121 | """ 122 | tensors = [] 123 | 124 | for beh in para_dict.keys(): 125 | tensors.append(para_dict[beh]) 126 | tensors = torch.stack(tensors, dim=0) 127 | 128 | return tensors 129 | 130 | 131 | def multi_head_self_attention(self, trans_w, embedding_t_1, embedding_t): 132 | 133 | q = self.para_dict_to_tenser(embedding_t) 134 | v = k = self.para_dict_to_tenser(embedding_t_1) 135 | beh, N, d_h = q.shape[0], q.shape[1], args.embed_size/args.head_num 136 | 137 | Q = torch.matmul(q, trans_w['w_q']) 138 | K = torch.matmul(k, trans_w['w_k']) 139 | V = v 140 | 141 | Q = Q.reshape(beh, N, args.head_num, int(d_h)).permute(2, 0, 1, 3) 142 | K = Q.reshape(beh, N, args.head_num, int(d_h)).permute(2, 0, 1, 3) 143 | 144 | Q = torch.unsqueeze(Q, 2) 145 | K = torch.unsqueeze(K, 1) 146 | V = torch.unsqueeze(V, 1) 147 | 148 | att = torch.mul(Q, K) / torch.sqrt(torch.tensor(d_h)) 149 | att = torch.sum(att, dim=-1) 150 | att = torch.unsqueeze(att, dim=-1) 151 | att = F.softmax(att, dim=2) 152 | 153 | Z = torch.mul(att, V) 154 | Z = torch.sum(Z, dim=2) 155 | 156 | Z_list = [value for value in Z] 157 | Z = torch.cat(Z_list, -1) 158 | Z = torch.matmul(Z, self.weight_dict['w_self_attention_cat']) 159 | 160 | args.model_cat_rate*F.normalize(Z, p=2, dim=2) 161 | return Z, att.detach() 162 | 163 | # def prompt_tuning(self, soft_token_u, soft_token_i): 164 | # # self.user_id_embedding = nn.Embedding.from_pretrained(soft_token_u, freeze=False) 165 | # # self.item_id_embedding = nn.Embedding.from_pretrained(soft_token_u, freeze=False) 166 | # self.prompt_user = soft_token_u 167 | # self.prompt_item = soft_token_i 168 | 169 | def forward(self, ui_graph, iu_graph, prompt_module=None): 170 | 171 | # def forward(self, ui_graph, iu_graph): 172 | 173 | prompt_user, prompt_item = prompt_module() # [n*32] 174 | # ----feature prompt---- 175 | # feat_prompt_user = torch.mean( torch.stack((torch.mm(prompt_user, torch.mm(prompt_user.T, self.image_feats)), torch.mm(prompt_user, torch.mm(prompt_user.T, self.text_feats)))), dim=0 ) 176 | # feat_prompt_user = torch.mm(prompt_user, torch.mm(prompt_user.T, self.text_feats)) 177 | feat_prompt_item_image = torch.mm(prompt_item, torch.mm(prompt_item.T, self.image_feats)) 178 | feat_prompt_item_text = torch.mm(prompt_item, torch.mm(prompt_item.T, 
self.text_feats)) 179 | # feat_prompt_image_item = torch.mm(prompt_item, torch.mm(prompt_item.T, self.image_feats)) 180 | # feat_prompt_text_item = torch.mm(prompt_item, torch.mm(prompt_item.T, self.text_feats)) 181 | # ----feature prompt---- 182 | # image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats + feat_prompt_item_image )) 183 | # text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats + feat_prompt_item_text )) 184 | 185 | # image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats + F.normalize(feat_prompt_item_image, p=2, dim=1) )) 186 | # text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats + F.normalize(feat_prompt_item_text, p=2, dim=1) )) 187 | 188 | image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats + args.feat_soft_token_rate*F.normalize(feat_prompt_item_image, p=2, dim=1) )) 189 | text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats + args.feat_soft_token_rate*F.normalize(feat_prompt_item_text, p=2, dim=1) )) 190 | # args.feat_soft_token_rate*F.normalize(feat_prompt_item_image, p=2, dim=1) 191 | # args.feat_soft_token_rate*F.normalize(feat_prompt_item_text, p=2, dim=1) 192 | 193 | 194 | # image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats)) 195 | # text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats)) 196 | 197 | for i in range(args.layers): 198 | image_user_feats = self.mm(ui_graph, image_feats) 199 | image_item_feats = self.mm(iu_graph, image_user_feats) 200 | # image_user_id = self.mm(image_ui_graph, self.item_id_embedding.weight) 201 | # image_item_id = self.mm(image_iu_graph, self.user_id_embedding.weight) 202 | 203 | text_user_feats = self.mm(ui_graph, text_feats) 204 | text_item_feats = self.mm(iu_graph, text_user_feats) 205 | 206 | # text_user_id = self.mm(text_ui_graph, self.item_id_embedding.weight) 207 | # text_item_id = self.mm(text_iu_graph, self.user_id_embedding.weight) 208 | 209 | # self.embedding_dict['user']['image'] = image_user_id 210 | # self.embedding_dict['user']['text'] = text_user_id 211 | # self.embedding_dict['item']['image'] = image_item_id 212 | # self.embedding_dict['item']['text'] = text_item_id 213 | # user_z, att_u = self.multi_head_self_attention(self.weight_dict, self.embedding_dict['user'], self.embedding_dict['user']) 214 | # item_z, att_i = self.multi_head_self_attention(self.weight_dict, self.embedding_dict['item'], self.embedding_dict['item']) 215 | # user_emb = user_z.mean(0) 216 | # item_emb = item_z.mean(0) 217 | u_g_embeddings = self.user_id_embedding.weight + args.soft_token_rate*F.normalize(prompt_user, p=2, dim=1) 218 | i_g_embeddings = self.item_id_embedding.weight + args.soft_token_rate*F.normalize(prompt_item, p=2, dim=1) 219 | user_emb_list = [u_g_embeddings] 220 | item_emb_list = [i_g_embeddings] 221 | for i in range(self.n_ui_layers): 222 | if i == (self.n_ui_layers-1): 223 | u_g_embeddings = self.softmax( torch.mm(ui_graph, i_g_embeddings) ) 224 | i_g_embeddings = self.softmax( torch.mm(iu_graph, u_g_embeddings) ) 225 | 226 | else: 227 | u_g_embeddings = torch.mm(ui_graph, i_g_embeddings) 228 | i_g_embeddings = torch.mm(iu_graph, u_g_embeddings) 229 | 230 | user_emb_list.append(u_g_embeddings) 231 | item_emb_list.append(i_g_embeddings) 232 | 233 | u_g_embeddings = torch.mean(torch.stack(user_emb_list), dim=0) 234 | i_g_embeddings = torch.mean(torch.stack(item_emb_list), dim=0) 235 | 236 | 237 | u_g_embeddings = u_g_embeddings 
+ args.model_cat_rate*F.normalize(image_user_feats, p=2, dim=1) + args.model_cat_rate*F.normalize(text_user_feats, p=2, dim=1) 238 | i_g_embeddings = i_g_embeddings + args.model_cat_rate*F.normalize(image_item_feats, p=2, dim=1) + args.model_cat_rate*F.normalize(text_item_feats, p=2, dim=1) 239 | 240 | return u_g_embeddings, i_g_embeddings, image_item_feats, text_item_feats, image_user_feats, text_user_feats, u_g_embeddings, i_g_embeddings , prompt_user, prompt_item 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | class PromptLearner(nn.Module): 249 | def __init__(self, image_feats=None, text_feats=None, ui_graph=None): 250 | super().__init__() 251 | self.ui_graph = ui_graph 252 | 253 | 254 | if args.hard_token_type=='pca': 255 | try: 256 | t1 = time() 257 | hard_token_image = pickle.load(open(args.data_path + args.dataset + '/hard_token_image_pca','rb')) 258 | hard_token_text = pickle.load(open(args.data_path + args.dataset + '/hard_token_text_pca','rb')) 259 | print('already load hard token', time() - t1) 260 | except Exception: 261 | hard_token_image = PCA(n_components=args.embed_size).fit_transform(image_feats) 262 | hard_token_text = PCA(n_components=args.embed_size).fit_transform(text_feats) 263 | pickle.dump(hard_token_image, open(args.data_path + args.dataset + '/hard_token_image_pca','wb')) 264 | pickle.dump(hard_token_text, open(args.data_path + args.dataset + '/hard_token_text_pca','wb')) 265 | elif args.hard_token_type=='ica': 266 | try: 267 | t1 = time() 268 | hard_token_image = pickle.load(open(args.data_path + args.dataset + '/hard_token_image_ica','rb')) 269 | hard_token_text = pickle.load(open(args.data_path + args.dataset + '/hard_token_text_ica','rb')) 270 | print('already load hard token', time() - t1) 271 | except Exception: 272 | hard_token_image = FastICA(n_components=args.embed_size, random_state=12).fit_transform(image_feats) 273 | hard_token_text = FastICA(n_components=args.embed_size, random_state=12).fit_transform(text_feats) 274 | pickle.dump(hard_token_image, open(args.data_path + args.dataset + '/hard_token_image_ica','wb')) 275 | pickle.dump(hard_token_text, open(args.data_path + args.dataset + '/hard_token_text_ica','wb')) 276 | elif args.hard_token_type=='isomap': 277 | hard_token_image = manifold.Isomap(n_neighbors=5, n_components=args.embed_size, n_jobs=-1).fit_transform(image_feats) 278 | hard_token_text = manifold.Isomap(n_neighbors=5, n_components=args.embed_size, n_jobs=-1).fit_transform(text_feats) 279 | # elif args.hard_token_type=='tsne': 280 | # hard_token_image = TSNE(n_components=args.embed_size, n_iter=300).fit_transform(image_feats) 281 | # hard_token_text = TSNE(n_components=args.embed_size, n_iter=300).fit_transform(text_feats) 282 | # elif args.hard_token_type=='lda': 283 | # hard_token_image = LinearDiscriminantAnalysis(n_components=args.embed_size).fit_transform(image_feats) 284 | # hard_token_text = LinearDiscriminantAnalysis(n_components=args.embed_size).fit_transform(text_feats) 285 | 286 | # self.item_hard_token = nn.Embedding.from_pretrained(torch.mean((torch.stack((torch.tensor(hard_token_image).float(), torch.tensor(hard_token_text).float()))), dim=0), freeze=False).cuda().weight 287 | # self.user_hard_token = nn.Embedding.from_pretrained(torch.mm(ui_graph, self.item_hard_token), freeze=False).cuda().weight 288 | 289 | self.item_hard_token = torch.mean((torch.stack((torch.tensor(hard_token_image).float(), torch.tensor(hard_token_text).float()))), dim=0).cuda() 290 | self.user_hard_token = torch.mm(ui_graph, 
self.item_hard_token).cuda() 291 | 292 | self.trans_user = nn.Linear(args.embed_size, args.embed_size).cuda() 293 | self.trans_item = nn.Linear(args.embed_size, args.embed_size).cuda() 294 | # nn.init.xavier_uniform_(self.gnn_trans_user.weight) 295 | # nn.init.xavier_uniform_(self.gnn_trans_item.weight) 296 | # self.gnn_trans_user = self.gnn_trans_user.cuda() 297 | # self.gnn_trans_item = self.gnn_trans_item.cuda() 298 | # self.item_hard_token = torch.mean((torch.stack((torch.tensor(hard_token_image).float(), torch.tensor(hard_token_text).float()))), dim=0).cuda() 299 | 300 | 301 | def forward(self): 302 | # self.user_hard_token = self.gnn_trans_user(torch.mm(self.ui_graph, self.item_hard_token)) 303 | # self.item_hard_token = self.gnn_trans_item(self.item_hard_token) 304 | # return self.user_hard_token , self.item_hard_token 305 | return F.dropout(self.trans_user(self.user_hard_token), args.prompt_dropout) , F.dropout(self.trans_item(self.item_hard_token), args.prompt_dropout) 306 | 307 | 308 | 309 | 310 | class Student_LightGCN(nn.Module): 311 | def __init__(self, n_users, n_items, embedding_dim, gnn_layer, dropout_list, image_feats=None, text_feats=None): 312 | super().__init__() 313 | self.n_users = n_users 314 | self.n_items = n_items 315 | self.embedding_dim = embedding_dim 316 | self.n_ui_layers = gnn_layer 317 | 318 | self.user_id_embedding = nn.Embedding(n_users, embedding_dim) 319 | self.item_id_embedding = nn.Embedding(n_items, embedding_dim) 320 | nn.init.xavier_uniform_(self.user_id_embedding.weight) 321 | nn.init.xavier_uniform_(self.item_id_embedding.weight) 322 | 323 | # self.feat_trans = nn.Linear(args.embed_size, args.student_embed_size) 324 | # # self.text_trans = nn.Linear(text_feats.shape[1], args.embed_size) 325 | # nn.init.xavier_uniform_(self.feat_trans.weight) 326 | # # nn.init.xavier_uniform_(self.text_trans.weight) 327 | 328 | def init_user_item_embed(self, pre_u_embed, pre_i_embed): 329 | self.user_id_embedding = nn.Embedding.from_pretrained(pre_u_embed, freeze=False) 330 | self.item_id_embedding = nn.Embedding.from_pretrained(pre_i_embed, freeze=False) 331 | 332 | self.user_id_embedding_pre = nn.Embedding.from_pretrained(pre_u_embed, freeze=False) 333 | self.item_id_embedding_pre = nn.Embedding.from_pretrained(pre_i_embed, freeze=False) 334 | 335 | def get_embedding(self): 336 | return self.user_id_embedding, self.item_id_embedding 337 | 338 | def forward(self, adj): 339 | 340 | # # teacher_feat_dict = { 'item_image':t_i_image_embed.deteach(),'item_text':t_i_text_embed.deteach(),'user_image':t_u_image_embed.deteach(),'user_text':t_u_text_embed.deteach() } 341 | # tmp_feat_dict = {} 342 | # for index,value in enumerate(teacher_feat_dict.keys()): 343 | # tmp_feat_dict[value] = self.feat_trans(teacher_feat_dict[value]) 344 | # u_g_embeddings = self.user_id_embedding.weight + args.model_cat_rate*F.normalize(tmp_feat_dict['user_image'], p=2, dim=1) + args.model_cat_rate*F.normalize(tmp_feat_dict['user_text'], p=2, dim=1) 345 | # i_g_embeddings = self.item_id_embedding.weight + args.model_cat_rate*F.normalize(tmp_feat_dict['item_image'], p=2, dim=1) + args.model_cat_rate*F.normalize(tmp_feat_dict['item_text'], p=2, dim=1) 346 | # ego_embeddings = torch.cat((u_g_embeddings, i_g_embeddings), dim=0) 347 | 348 | # self.user_id_embedding_pre = nn.Embedding.from_pretrained(pre_u_embed, freeze=False) 349 | # self.item_id_embedding_pre = nn.Embedding.from_pretrained(pre_i_embed, freeze=False) 350 | 351 | ego_embeddings = 
torch.cat((self.user_id_embedding.weight + self.user_id_embedding_pre.weight, self.item_id_embedding.weight + self.item_id_embedding_pre.weight), dim=0)
        # ego_embeddings = torch.cat((self.user_id_embedding.weight, self.item_id_embedding.weight), dim=0)
        all_embeddings = [ego_embeddings]
        for i in range(self.n_ui_layers):
            side_embeddings = torch.sparse.mm(adj, ego_embeddings)
            ego_embeddings = side_embeddings
            all_embeddings += [ego_embeddings]
        all_embeddings = torch.stack(all_embeddings, dim=1)
        all_embeddings = all_embeddings.mean(dim=1, keepdim=False)
        u_g_embeddings, i_g_embeddings = torch.split(all_embeddings, [self.n_users, self.n_items], dim=0)
        # u_g_embeddings += teacher_feat_dict['user_image'] + teacher_feat_dict['user_text']
        # i_g_embeddings += teacher_feat_dict['item_image'] + teacher_feat_dict['item_text']
        # u_g_embeddings = u_g_embeddings + args.model_cat_rate*F.normalize(teacher_feat_dict['user_image'], p=2, dim=1) + args.model_cat_rate*F.normalize(teacher_feat_dict['user_text'], p=2, dim=1)
        # i_g_embeddings = i_g_embeddings + args.model_cat_rate*F.normalize(teacher_feat_dict['item_image'], p=2, dim=1) + args.model_cat_rate*F.normalize(teacher_feat_dict['item_text'], p=2, dim=1)

        return u_g_embeddings, i_g_embeddings
        # return self.user_id_embedding.weight, self.item_id_embedding.weight


class Student_GCN(nn.Module):
    def __init__(self, n_users, n_items, embedding_dim, gnn_layer=2, drop_out=0., image_feats=None, text_feats=None):
        super(Student_GCN, self).__init__()
        self.embedding_dim = embedding_dim

        # self.layers = nn.Sequential(GraphConvolution(self.embedding_dim, self.embedding_dim, activation=F.relu, dropout=args.student_drop_rate, is_sparse_inputs=True),
        #                             GraphConvolution(self.embedding_dim, self.embedding_dim, activation=F.relu, dropout=args.student_drop_rate, is_sparse_inputs=False),
        #                             )
        # self.layer_list = nn.ModuleList()
        # for i in range(args.student_n_layers):
        #     self.layer_list.append(GraphConvolution(self.embedding_dim, self.embedding_dim, activation=F.relu, dropout=args.student_drop_rate, is_sparse_inputs=False))

        self.trans_user = nn.Linear(args.embed_size, args.embed_size).cuda()
        self.trans_item = nn.Linear(args.embed_size, args.embed_size).cuda()

    def forward(self, user_x, item_x, ui_graph, iu_graph):
        # # x, support = inputs
        # # user_x, item_x = self.layers((user_x, item_x, ui_graph, iu_graph))
        # for i in range(args.student_n_layers):
        #     user_x, item_x = self.layer_list[i](user_x, item_x, ui_graph, iu_graph)
        # return user_x, item_x

        return self.trans_user(user_x), self.trans_item(item_x)
        # self.user_id_embedding = nn.Embedding.from_pretrained(user_x, freeze=True)
        # self.item_id_embedding = nn.Embedding.from_pretrained(item_x, freeze=True)
        # return self.user_id_embedding.weight, self.item_id_embedding.weight

    def l2_loss(self):
        # The original iterated `self.layers`, which is commented out above and would
        # raise an AttributeError; regularize the active linear layers instead.
        loss = None
        for layer in (self.trans_user, self.trans_item):
            for p in layer.parameters():
                if loss is None:
                    loss = p.pow(2).sum()
                else:
                    loss += p.pow(2).sum()
        return loss


class GraphConvolution(nn.Module):
    def __init__(self, input_dim, output_dim, dropout=0., is_sparse_inputs=False, bias=False, activation=F.relu, featureless=False):
        super(GraphConvolution, self).__init__()
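        # Design note (added): the two weight matrices below give each side of the
        # bipartite graph its own transform; forward() then cross-propagates, roughly
        #     out_user = ui_graph @ (item_x @ W_item)
        #     out_item = iu_graph @ (user_x @ W_user)
        # so user states are refreshed from transformed item states and vice versa.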
        self.dropout = dropout
        self.activation = activation
        self.is_sparse_inputs = is_sparse_inputs
        self.featureless = featureless
        # self.num_features_nonzero = num_features_nonzero
        # self.user_weight = nn.Parameter(torch.randn(input_dim, output_dim))
        # self.item_weight = nn.Parameter(torch.randn(input_dim, output_dim))
        self.user_weight = nn.Parameter(torch.empty(input_dim, output_dim))
        self.item_weight = nn.Parameter(torch.empty(input_dim, output_dim))
        nn.init.xavier_uniform_(self.user_weight)
        nn.init.xavier_uniform_(self.item_weight)
        self.bias = None
        if bias:
            self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, user_x, item_x, ui_graph, iu_graph):
        # print('inputs:', inputs)
        # x, support = inputs
        # if self.training and self.is_sparse_inputs:
        #     x = sparse_dropout(x, self.dropout, self.num_features_nonzero)
        # elif self.training:
        user_x = F.dropout(user_x, self.dropout)
        item_x = F.dropout(item_x, self.dropout)
        # convolve
        if not self.featureless:  # if it has features x
            if self.is_sparse_inputs:
                # was: both results assigned to the same `xw`, clobbering the user branch
                xw_user = torch.sparse.mm(user_x, self.user_weight)
                xw_item = torch.sparse.mm(item_x, self.item_weight)
            else:
                xw_user = torch.mm(user_x, self.user_weight)
                xw_item = torch.mm(item_x, self.item_weight)
        else:
            # was: `xw = self.weight`, an attribute that does not exist
            xw_user = self.user_weight
            xw_item = self.item_weight
        out_user = torch.sparse.mm(ui_graph, xw_item)
        out_item = torch.sparse.mm(iu_graph, xw_user)

        if self.bias is not None:
            out_user = out_user + self.bias
            out_item = out_item + self.bias
        return self.activation(out_user), self.activation(out_item)


def sparse_dropout(x, rate, noise_shape):
    """
    :param x: sparse input tensor
    :param rate: dropout rate
    :param noise_shape: int scalar, number of non-zero entries of x
    :return: sparse tensor with entries dropped and the survivors rescaled by 1/(1-rate)
    """
    random_tensor = 1 - rate
    random_tensor += torch.rand(noise_shape).to(x.device)
    dropout_mask = torch.floor(random_tensor).bool()  # byte-tensor indexing is deprecated
    i = x._indices()  # [2, nnz]
    v = x._values()   # [nnz]
    i = i[:, dropout_mask]
    v = v[dropout_mask]
    out = torch.sparse.FloatTensor(i, v, x.shape).to(x.device)
    out = out * (1. / (1 - rate))
    return out


def dot(x, y, sparse=False):
    if sparse:
        res = torch.sparse.mm(x, y)
    else:
        res = torch.mm(x, y)
    return res


class BLMLP(nn.Module):
    def __init__(self):
        super(BLMLP, self).__init__()
        self.W = nn.Parameter(nn.init.xavier_uniform_(torch.empty(args.student_embed_size, args.student_embed_size)))
        self.act = nn.LeakyReLU(negative_slope=0.5)

    def forward(self, embeds):
        pass

    def featureExtract(self, embeds):
        return self.act(embeds @ self.W) + embeds

    def pairPred(self, embeds1, embeds2):
        return (self.featureExtract(embeds1) * self.featureExtract(embeds2)).sum(dim=-1)

    def crossPred(self, embeds1, embeds2):
        return self.featureExtract(embeds1) @ self.featureExtract(embeds2).T


class Student_MLP(nn.Module):
    def __init__(self):
        super(Student_MLP, self).__init__()
        # self.n_users = n_users
        # self.n_items = n_items
        # self.embedding_dim = embedding_dim

        # self.uEmbeds = nn.Parameter(init(torch.empty(args.user, args.latdim)))
        # self.iEmbeds = nn.Parameter(init(torch.empty(args.item, args.latdim)))

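        # Design note (added): this student has no graph propagation of its own; it
        # relearns user/item representations from the teacher-initialized embeddings
        # via two linear heads and delegates scoring to the BLMLP scorer above.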
        self.user_trans = nn.Linear(args.embed_size, args.embed_size)
        self.item_trans = nn.Linear(args.embed_size, args.embed_size)
        nn.init.xavier_uniform_(self.user_trans.weight)
        nn.init.xavier_uniform_(self.item_trans.weight)

        self.MLP = BLMLP()
        # self.overallTime = datetime.timedelta(0)

    def get_embedding(self):
        return self.user_id_embedding, self.item_id_embedding

    def forward(self, pre_user, pre_item):
        # pre_user, pre_item = self.user_id_embedding.weight, self.item_id_embedding.weight
        user_embed = self.user_trans(pre_user)
        item_embed = self.item_trans(pre_item)  # was self.user_trans(pre_item): items went through the user head

        return user_embed, item_embed
        # return pre_user, pre_item

    def init_user_item_embed(self, pre_u_embed, pre_i_embed):
        self.user_id_embedding = nn.Embedding.from_pretrained(pre_u_embed, freeze=False)
        self.item_id_embedding = nn.Embedding.from_pretrained(pre_i_embed, freeze=False)

    def pointPosPredictwEmbeds(self, uEmbeds, iEmbeds, ancs, poss):
        ancEmbeds = uEmbeds[ancs]
        posEmbeds = iEmbeds[poss]
        nume = self.MLP.pairPred(ancEmbeds, posEmbeds)
        return nume

    def pointNegPredictwEmbeds(self, embeds1, embeds2, nodes1, temp=1.0):
        pckEmbeds1 = embeds1[nodes1]
        preds = self.MLP.crossPred(pckEmbeds1, embeds2)
        return torch.exp(preds / temp).sum(-1)

    def pairPredictwEmbeds(self, uEmbeds, iEmbeds, ancs, poss, negs):
        ancEmbeds = uEmbeds[ancs]
        posEmbeds = iEmbeds[poss]
        negEmbeds = iEmbeds[negs]
        posPreds = self.MLP.pairPred(ancEmbeds, posEmbeds)
        negPreds = self.MLP.pairPred(ancEmbeds, negEmbeds)
        return posPreds - negPreds

    def predAll(self, pckUEmbeds, iEmbeds):
        return self.MLP.crossPred(pckUEmbeds, iEmbeds)

    def testPred(self, usr, trnMask):
        # forward() requires (pre_user, pre_item); feed it the stored embeddings, as the
        # commented-out line at the top of forward() suggests. The original self.forward()
        # call with no arguments would raise a TypeError.
        uEmbeds, iEmbeds = self.forward(self.user_id_embedding.weight, self.item_id_embedding.weight)
        allPreds = self.predAll(uEmbeds[usr], iEmbeds) * (1 - trnMask) - trnMask * 1e8
        return allPreds

--------------------------------------------------------------------------------
/codes/Models_empower.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from time import time
import pickle
import scipy.sparse as sp
from scipy.sparse import csr_matrix

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

from sklearn.decomposition import PCA, FastICA
from sklearn import manifold
from sklearn.manifold import TSNE
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# from utility.parser import parse_args
from utility.norm import build_sim, build_knn_normalized_graph
# args = parse_args()
from utility.parser import args


class Teacher_Model(nn.Module):
    def __init__(self, n_users, n_items, embedding_dim, weight_size, dropout_list, image_feats, text_feats):

        super().__init__()
        self.n_users = n_users
        self.n_items = n_items
        self.embedding_dim = embedding_dim
        self.weight_size = weight_size
        self.n_ui_layers = len(self.weight_size)
        self.weight_size = [self.embedding_dim] + self.weight_size

        self.image_trans = nn.Linear(image_feats.shape[1], args.embed_size)
        self.text_trans = nn.Linear(text_feats.shape[1], args.embed_size)
        nn.init.xavier_uniform_(self.image_trans.weight)
nn.init.xavier_uniform_(self.text_trans.weight) 41 | self.encoder = nn.ModuleDict() 42 | self.encoder['image_encoder'] = self.image_trans # ^-^ 43 | self.encoder['text_encoder'] = self.text_trans # ^-^ 44 | 45 | self.user_id_embedding = nn.Embedding(n_users, self.embedding_dim) 46 | self.item_id_embedding = nn.Embedding(n_items, self.embedding_dim) 47 | 48 | nn.init.xavier_uniform_(self.user_id_embedding.weight) 49 | nn.init.xavier_uniform_(self.item_id_embedding.weight) 50 | self.image_feats = torch.tensor(image_feats).float().cuda() 51 | self.text_feats = torch.tensor(text_feats).float().cuda() 52 | self.image_embedding = nn.Embedding.from_pretrained(torch.Tensor(image_feats), freeze=False) 53 | self.text_embedding = nn.Embedding.from_pretrained(torch.Tensor(text_feats), freeze=False) 54 | 55 | self.softmax = nn.Softmax(dim=-1) 56 | self.act = nn.Sigmoid() 57 | self.sigmoid = nn.Sigmoid() 58 | self.dropout = nn.Dropout(p=args.drop_rate) 59 | self.batch_norm = nn.BatchNorm1d(args.embed_size) 60 | 61 | def mm(self, x, y): 62 | if args.sparse: 63 | return torch.sparse.mm(x, y) 64 | else: 65 | return torch.mm(x, y) 66 | def sim(self, z1, z2): 67 | z1 = F.normalize(z1) 68 | z2 = F.normalize(z2) 69 | return torch.mm(z1, z2.t()) 70 | 71 | def batched_contrastive_loss(self, z1, z2, batch_size=4096): 72 | device = z1.device 73 | num_nodes = z1.size(0) 74 | num_batches = (num_nodes - 1) // batch_size + 1 75 | f = lambda x: torch.exp(x / self.tau) 76 | indices = torch.arange(0, num_nodes).to(device) 77 | losses = [] 78 | 79 | for i in range(num_batches): 80 | mask = indices[i * batch_size:(i + 1) * batch_size] 81 | refl_sim = f(self.sim(z1[mask], z1)) 82 | between_sim = f(self.sim(z1[mask], z2)) 83 | 84 | losses.append(-torch.log( 85 | between_sim[:, i * batch_size:(i + 1) * batch_size].diag() 86 | / (refl_sim.sum(1) + between_sim.sum(1) 87 | - refl_sim[:, i * batch_size:(i + 1) * batch_size].diag()))) 88 | 89 | loss_vec = torch.cat(losses) 90 | return loss_vec.mean() 91 | 92 | def csr_norm(self, csr_mat, mean_flag=False): 93 | rowsum = np.array(csr_mat.sum(1)) 94 | rowsum = np.power(rowsum+1e-8, -0.5).flatten() 95 | rowsum[np.isinf(rowsum)] = 0. 96 | rowsum_diag = sp.diags(rowsum) 97 | 98 | colsum = np.array(csr_mat.sum(0)) 99 | colsum = np.power(colsum+1e-8, -0.5).flatten() 100 | colsum[np.isinf(colsum)] = 0. 
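        # (Added note) rowsum/colsum hold the inverse square roots of the row/column
        # degrees, so with mean_flag=False the return below is the symmetric
        # normalization D_r^{-1/2} A D_c^{-1/2} of the bipartite interaction matrix,
        # and with mean_flag=True only the row side is applied (row-mean aggregation).
        # The +1e-8 guards empty rows/columns before the inverse square root.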
        colsum_diag = sp.diags(colsum)

        if mean_flag == False:
            return rowsum_diag*csr_mat*colsum_diag
        else:
            return rowsum_diag*csr_mat

    def matrix_to_tensor(self, cur_matrix):
        if type(cur_matrix) != sp.coo_matrix:
            cur_matrix = cur_matrix.tocoo()
        indices = torch.from_numpy(np.vstack((cur_matrix.row, cur_matrix.col)).astype(np.int64))
        values = torch.from_numpy(cur_matrix.data)
        shape = torch.Size(cur_matrix.shape)

        return torch.sparse.FloatTensor(indices, values, shape).to(torch.float32).cuda()

    def para_dict_to_tenser(self, para_dict):
        """
        :param para_dict: nn.ParameterDict()
        :return: tensor
        """
        tensors = []

        for beh in para_dict.keys():
            tensors.append(para_dict[beh])
        tensors = torch.stack(tensors, dim=0)

        return tensors

    def multi_head_self_attention(self, trans_w, embedding_t_1, embedding_t):

        q = self.para_dict_to_tenser(embedding_t)
        v = k = self.para_dict_to_tenser(embedding_t_1)
        beh, N, d_h = q.shape[0], q.shape[1], args.embed_size/args.head_num

        Q = torch.matmul(q, trans_w['w_q'])
        K = torch.matmul(k, trans_w['w_k'])
        V = v

        Q = Q.reshape(beh, N, args.head_num, int(d_h)).permute(2, 0, 1, 3)
        K = K.reshape(beh, N, args.head_num, int(d_h)).permute(2, 0, 1, 3)  # was K = Q.reshape(...), which silently discarded the key projection

        Q = torch.unsqueeze(Q, 2)
        K = torch.unsqueeze(K, 1)
        V = torch.unsqueeze(V, 1)

        att = torch.mul(Q, K) / torch.sqrt(torch.tensor(d_h))
        att = torch.sum(att, dim=-1)
        att = torch.unsqueeze(att, dim=-1)
        att = F.softmax(att, dim=2)

        Z = torch.mul(att, V)
        Z = torch.sum(Z, dim=2)

        Z_list = [value for value in Z]
        Z = torch.cat(Z_list, -1)
        Z = torch.matmul(Z, self.weight_dict['w_self_attention_cat'])

        # args.model_cat_rate*F.normalize(Z, p=2, dim=2)  # no-op in the original: the scaled value was never assigned
        return Z, att.detach()

    # def prompt_tuning(self, soft_token_u, soft_token_i):
    #     # self.user_id_embedding = nn.Embedding.from_pretrained(soft_token_u, freeze=False)
    #     # self.item_id_embedding = nn.Embedding.from_pretrained(soft_token_u, freeze=False)
    #     self.prompt_user = soft_token_u
    #     self.prompt_item = soft_token_i

    def forward(self, ui_graph, iu_graph, prompt_module=None):

        # def forward(self, ui_graph, iu_graph):

        prompt_user, prompt_item = prompt_module()  # [n*32]
        # ----feature prompt----
        # feat_prompt_user = torch.mean( torch.stack((torch.mm(prompt_user, torch.mm(prompt_user.T, self.image_feats)), torch.mm(prompt_user, torch.mm(prompt_user.T, self.text_feats)))), dim=0 )
        # feat_prompt_user = torch.mm(prompt_user, torch.mm(prompt_user.T, self.text_feats))
        feat_prompt_item_image = torch.mm(prompt_item, torch.mm(prompt_item.T, self.image_feats))
        feat_prompt_item_text = torch.mm(prompt_item, torch.mm(prompt_item.T, self.text_feats))
        # feat_prompt_image_item = torch.mm(prompt_item, torch.mm(prompt_item.T, self.image_feats))
        # feat_prompt_text_item = torch.mm(prompt_item, torch.mm(prompt_item.T, self.text_feats))
        # ----feature prompt----
        # image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats + feat_prompt_item_image ))
        # text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats + feat_prompt_item_text ))

        # image_feats = image_item_feats =
self.dropout(self.image_trans(self.image_feats + F.normalize(feat_prompt_item_image, p=2, dim=1) )) 186 | # text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats + F.normalize(feat_prompt_item_text, p=2, dim=1) )) 187 | 188 | image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats + args.feat_soft_token_rate*F.normalize(feat_prompt_item_image, p=2, dim=1) )) 189 | text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats + args.feat_soft_token_rate*F.normalize(feat_prompt_item_text, p=2, dim=1) )) 190 | # args.feat_soft_token_rate*F.normalize(feat_prompt_item_image, p=2, dim=1) 191 | # args.feat_soft_token_rate*F.normalize(feat_prompt_item_text, p=2, dim=1) 192 | 193 | 194 | # image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats)) 195 | # text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats)) 196 | 197 | for i in range(args.layers): 198 | image_user_feats = self.mm(ui_graph, image_feats) 199 | image_item_feats = self.mm(iu_graph, image_user_feats) 200 | # image_user_id = self.mm(image_ui_graph, self.item_id_embedding.weight) 201 | # image_item_id = self.mm(image_iu_graph, self.user_id_embedding.weight) 202 | 203 | text_user_feats = self.mm(ui_graph, text_feats) 204 | text_item_feats = self.mm(iu_graph, text_user_feats) 205 | 206 | # text_user_id = self.mm(text_ui_graph, self.item_id_embedding.weight) 207 | # text_item_id = self.mm(text_iu_graph, self.user_id_embedding.weight) 208 | 209 | # self.embedding_dict['user']['image'] = image_user_id 210 | # self.embedding_dict['user']['text'] = text_user_id 211 | # self.embedding_dict['item']['image'] = image_item_id 212 | # self.embedding_dict['item']['text'] = text_item_id 213 | # user_z, att_u = self.multi_head_self_attention(self.weight_dict, self.embedding_dict['user'], self.embedding_dict['user']) 214 | # item_z, att_i = self.multi_head_self_attention(self.weight_dict, self.embedding_dict['item'], self.embedding_dict['item']) 215 | # user_emb = user_z.mean(0) 216 | # item_emb = item_z.mean(0) 217 | u_g_embeddings = self.user_id_embedding.weight + args.soft_token_rate*F.normalize(prompt_user, p=2, dim=1) 218 | i_g_embeddings = self.item_id_embedding.weight + args.soft_token_rate*F.normalize(prompt_item, p=2, dim=1) 219 | user_emb_list = [u_g_embeddings] 220 | item_emb_list = [i_g_embeddings] 221 | for i in range(self.n_ui_layers): 222 | if i == (self.n_ui_layers-1): 223 | u_g_embeddings = self.softmax( torch.mm(ui_graph, i_g_embeddings) ) 224 | i_g_embeddings = self.softmax( torch.mm(iu_graph, u_g_embeddings) ) 225 | 226 | else: 227 | u_g_embeddings = torch.mm(ui_graph, i_g_embeddings) 228 | i_g_embeddings = torch.mm(iu_graph, u_g_embeddings) 229 | 230 | user_emb_list.append(u_g_embeddings) 231 | item_emb_list.append(i_g_embeddings) 232 | 233 | u_g_embeddings = torch.mean(torch.stack(user_emb_list), dim=0) 234 | i_g_embeddings = torch.mean(torch.stack(item_emb_list), dim=0) 235 | 236 | 237 | u_g_embeddings = u_g_embeddings + args.model_cat_rate*F.normalize(image_user_feats, p=2, dim=1) + args.model_cat_rate*F.normalize(text_user_feats, p=2, dim=1) 238 | i_g_embeddings = i_g_embeddings + args.model_cat_rate*F.normalize(image_item_feats, p=2, dim=1) + args.model_cat_rate*F.normalize(text_item_feats, p=2, dim=1) 239 | 240 | return u_g_embeddings, i_g_embeddings, image_item_feats, text_item_feats, image_user_feats, text_user_feats, u_g_embeddings, i_g_embeddings , prompt_user, prompt_item 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 
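The PromptLearner that follows builds its "hard tokens" by compressing each modality's raw features down to the model's embedding size, averaging the two modalities per item, and pushing the item tokens to users through the interaction graph. A minimal CPU sketch of that pipeline (toy shapes; `build_hard_tokens` and all sizes are illustrative assumptions, not part of this repo):

import torch
from sklearn.decomposition import PCA

def build_hard_tokens(image_feats, text_feats, ui_graph, embed_size):
    # compress each modality to the shared embedding size (the 'pca' branch below)
    img = torch.tensor(PCA(n_components=embed_size).fit_transform(image_feats)).float()
    txt = torch.tensor(PCA(n_components=embed_size).fit_transform(text_feats)).float()
    item_tokens = torch.stack((img, txt)).mean(dim=0)  # [n_items, embed_size]
    user_tokens = ui_graph @ item_tokens               # [n_users, embed_size]
    return user_tokens, item_tokens

# toy usage: 8 users, 10 items, 64-/128-dim raw modality features
ui = torch.rand(8, 10)
ui = ui / ui.sum(dim=1, keepdim=True)  # row-normalized interaction matrix
u_tok, i_tok = build_hard_tokens(torch.randn(10, 64).numpy(), torch.randn(10, 128).numpy(), ui, embed_size=4)
print(u_tok.shape, i_tok.shape)  # torch.Size([8, 4]) torch.Size([10, 4])

The real class additionally caches the reduced features with pickle and learns a linear transform plus dropout on top of these frozen tokens.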
class PromptLearner(nn.Module):
    def __init__(self, image_feats=None, text_feats=None, ui_graph=None):
        super().__init__()
        self.ui_graph = ui_graph

        if args.hard_token_type=='pca':
            try:
                t1 = time()
                hard_token_image = pickle.load(open(args.data_path + args.dataset + '/hard_token_image_pca','rb'))
                hard_token_text = pickle.load(open(args.data_path + args.dataset + '/hard_token_text_pca','rb'))
                print('already load hard token', time() - t1)
            except Exception:
                hard_token_image = PCA(n_components=args.embed_size).fit_transform(image_feats)
                hard_token_text = PCA(n_components=args.embed_size).fit_transform(text_feats)
                pickle.dump(hard_token_image, open(args.data_path + args.dataset + '/hard_token_image_pca','wb'))
                pickle.dump(hard_token_text, open(args.data_path + args.dataset + '/hard_token_text_pca','wb'))
        elif args.hard_token_type=='ica':
            try:
                t1 = time()
                hard_token_image = pickle.load(open(args.data_path + args.dataset + '/hard_token_image_ica','rb'))
                hard_token_text = pickle.load(open(args.data_path + args.dataset + '/hard_token_text_ica','rb'))
                print('already load hard token', time() - t1)
            except Exception:
                hard_token_image = FastICA(n_components=args.embed_size, random_state=12).fit_transform(image_feats)
                hard_token_text = FastICA(n_components=args.embed_size, random_state=12).fit_transform(text_feats)
                pickle.dump(hard_token_image, open(args.data_path + args.dataset + '/hard_token_image_ica','wb'))
                pickle.dump(hard_token_text, open(args.data_path + args.dataset + '/hard_token_text_ica','wb'))
        elif args.hard_token_type=='isomap':
            hard_token_image = manifold.Isomap(n_neighbors=5, n_components=args.embed_size, n_jobs=-1).fit_transform(image_feats)
            hard_token_text = manifold.Isomap(n_neighbors=5, n_components=args.embed_size, n_jobs=-1).fit_transform(text_feats)
        # elif args.hard_token_type=='tsne':
        #     hard_token_image = TSNE(n_components=args.embed_size, n_iter=300).fit_transform(image_feats)
        #     hard_token_text = TSNE(n_components=args.embed_size, n_iter=300).fit_transform(text_feats)
        # elif args.hard_token_type=='lda':
        #     hard_token_image = LinearDiscriminantAnalysis(n_components=args.embed_size).fit_transform(image_feats)
        #     hard_token_text = LinearDiscriminantAnalysis(n_components=args.embed_size).fit_transform(text_feats)

        # self.item_hard_token = nn.Embedding.from_pretrained(torch.mean((torch.stack((torch.tensor(hard_token_image).float(), torch.tensor(hard_token_text).float()))), dim=0), freeze=False).cuda().weight
        # self.user_hard_token = nn.Embedding.from_pretrained(torch.mm(ui_graph, self.item_hard_token), freeze=False).cuda().weight

        self.item_hard_token = torch.mean((torch.stack((torch.tensor(hard_token_image).float(), torch.tensor(hard_token_text).float()))), dim=0).cuda()
        self.user_hard_token = torch.mm(ui_graph, self.item_hard_token).cuda()

        self.trans_user = nn.Linear(args.embed_size, args.embed_size).cuda()
        self.trans_item = nn.Linear(args.embed_size, args.embed_size).cuda()
        # nn.init.xavier_uniform_(self.gnn_trans_user.weight)
        # nn.init.xavier_uniform_(self.gnn_trans_item.weight)
        # self.gnn_trans_user = self.gnn_trans_user.cuda()
        # self.gnn_trans_item = self.gnn_trans_item.cuda()
        # self.item_hard_token = torch.mean((torch.stack((torch.tensor(hard_token_image).float(),
torch.tensor(hard_token_text).float()))), dim=0).cuda() 299 | 300 | 301 | def forward(self): 302 | # self.user_hard_token = self.gnn_trans_user(torch.mm(self.ui_graph, self.item_hard_token)) 303 | # self.item_hard_token = self.gnn_trans_item(self.item_hard_token) 304 | # return self.user_hard_token , self.item_hard_token 305 | return F.dropout(self.trans_user(self.user_hard_token), args.prompt_dropout) , F.dropout(self.trans_item(self.item_hard_token), args.prompt_dropout) 306 | 307 | 308 | 309 | 310 | class Student_LightGCN(nn.Module): 311 | def __init__(self, n_users, n_items, embedding_dim, gnn_layer, dropout_list, image_feats=None, text_feats=None): 312 | super().__init__() 313 | self.n_users = n_users 314 | self.n_items = n_items 315 | self.embedding_dim = embedding_dim 316 | self.n_ui_layers = gnn_layer 317 | 318 | self.user_id_embedding = nn.Embedding(n_users, embedding_dim) 319 | self.item_id_embedding = nn.Embedding(n_items, embedding_dim) 320 | nn.init.xavier_uniform_(self.user_id_embedding.weight) 321 | nn.init.xavier_uniform_(self.item_id_embedding.weight) 322 | 323 | # self.feat_trans = nn.Linear(args.embed_size, args.student_embed_size) 324 | # # self.text_trans = nn.Linear(text_feats.shape[1], args.embed_size) 325 | # nn.init.xavier_uniform_(self.feat_trans.weight) 326 | # # nn.init.xavier_uniform_(self.text_trans.weight) 327 | 328 | def init_user_item_embed(self, pre_u_embed, pre_i_embed): 329 | self.user_id_embedding = nn.Embedding.from_pretrained(pre_u_embed, freeze=False) 330 | self.item_id_embedding = nn.Embedding.from_pretrained(pre_i_embed, freeze=False) 331 | 332 | self.user_id_embedding_pre = nn.Embedding.from_pretrained(pre_u_embed, freeze=False) 333 | self.item_id_embedding_pre = nn.Embedding.from_pretrained(pre_i_embed, freeze=False) 334 | 335 | def get_embedding(self): 336 | return self.user_id_embedding, self.item_id_embedding 337 | 338 | def forward(self, adj, image_item_embeds, text_item_embeds, image_user_embeds, text_user_embeds): 339 | 340 | # # teacher_feat_dict = { 'item_image':t_i_image_embed.deteach(),'item_text':t_i_text_embed.deteach(),'user_image':t_u_image_embed.deteach(),'user_text':t_u_text_embed.deteach() } 341 | # tmp_feat_dict = {} 342 | # for index,value in enumerate(teacher_feat_dict.keys()): 343 | # tmp_feat_dict[value] = self.feat_trans(teacher_feat_dict[value]) 344 | # u_g_embeddings = self.user_id_embedding.weight + args.model_cat_rate*F.normalize(tmp_feat_dict['user_image'], p=2, dim=1) + args.model_cat_rate*F.normalize(tmp_feat_dict['user_text'], p=2, dim=1) 345 | # i_g_embeddings = self.item_id_embedding.weight + args.model_cat_rate*F.normalize(tmp_feat_dict['item_image'], p=2, dim=1) + args.model_cat_rate*F.normalize(tmp_feat_dict['item_text'], p=2, dim=1) 346 | # ego_embeddings = torch.cat((u_g_embeddings, i_g_embeddings), dim=0) 347 | 348 | # self.user_id_embedding_pre = nn.Embedding.from_pretrained(pre_u_embed, freeze=False) 349 | # self.item_id_embedding_pre = nn.Embedding.from_pretrained(pre_i_embed, freeze=False) 350 | 351 | ego_embeddings = torch.cat((self.user_id_embedding.weight+self.user_id_embedding_pre.weight, self.item_id_embedding.weight+self.item_id_embedding_pre.weight), dim=0) 352 | # ego_embeddings = torch.cat((self.user_id_embedding.weight, self.item_id_embedding.weight), dim=0) 353 | all_embeddings = [ego_embeddings] 354 | for i in range(self.n_ui_layers): 355 | side_embeddings = torch.sparse.mm(adj, ego_embeddings) 356 | ego_embeddings = side_embeddings 357 | all_embeddings += [ego_embeddings] 358 | 
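        # (Added note) stacking the per-layer embeddings and taking their mean below is
        # the LightGCN-style layer-averaged readout: it keeps low-hop signal and damps
        # the over-smoothing of the deepest propagation step.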
        all_embeddings = torch.stack(all_embeddings, dim=1)
        all_embeddings = all_embeddings.mean(dim=1, keepdim=False)
        u_g_embeddings, i_g_embeddings = torch.split(all_embeddings, [self.n_users, self.n_items], dim=0)
        # u_g_embeddings += teacher_feat_dict['user_image'] + teacher_feat_dict['user_text']
        # i_g_embeddings += teacher_feat_dict['item_image'] + teacher_feat_dict['item_text']
        u_g_embeddings = u_g_embeddings + args.model_cat_rate*F.normalize(image_user_embeds, p=2, dim=1) + args.model_cat_rate*F.normalize(text_user_embeds, p=2, dim=1)
        i_g_embeddings = i_g_embeddings + args.model_cat_rate*F.normalize(image_item_embeds, p=2, dim=1) + args.model_cat_rate*F.normalize(text_item_embeds, p=2, dim=1)

        return u_g_embeddings, i_g_embeddings
        # return self.user_id_embedding.weight, self.item_id_embedding.weight


class Student_GCN(nn.Module):
    def __init__(self, n_users, n_items, embedding_dim, gnn_layer=2, drop_out=0., image_feats=None, text_feats=None):
        super(Student_GCN, self).__init__()
        self.embedding_dim = embedding_dim

        # self.layers = nn.Sequential(GraphConvolution(self.embedding_dim, self.embedding_dim, activation=F.relu, dropout=args.student_drop_rate, is_sparse_inputs=True),
        #                             GraphConvolution(self.embedding_dim, self.embedding_dim, activation=F.relu, dropout=args.student_drop_rate, is_sparse_inputs=False),
        #                             )
        # self.layer_list = nn.ModuleList()
        # for i in range(args.student_n_layers):
        #     self.layer_list.append(GraphConvolution(self.embedding_dim, self.embedding_dim, activation=F.relu, dropout=args.student_drop_rate, is_sparse_inputs=False))

        self.trans_user = nn.Linear(args.embed_size, args.embed_size).cuda()
        self.trans_item = nn.Linear(args.embed_size, args.embed_size).cuda()

    def forward(self, user_x, item_x, ui_graph, iu_graph):
        # # x, support = inputs
        # # user_x, item_x = self.layers((user_x, item_x, ui_graph, iu_graph))
        # for i in range(args.student_n_layers):
        #     user_x, item_x = self.layer_list[i](user_x, item_x, ui_graph, iu_graph)
        # return user_x, item_x

        return self.trans_user(user_x), self.trans_item(item_x)
        # self.user_id_embedding = nn.Embedding.from_pretrained(user_x, freeze=True)
        # self.item_id_embedding = nn.Embedding.from_pretrained(item_x, freeze=True)
        # return self.user_id_embedding.weight, self.item_id_embedding.weight

    def l2_loss(self):
        # Same fix as in Models.py: the original iterated the commented-out `self.layers`;
        # regularize the active linear layers instead.
        loss = None
        for layer in (self.trans_user, self.trans_item):
            for p in layer.parameters():
                if loss is None:
                    loss = p.pow(2).sum()
                else:
                    loss += p.pow(2).sum()
        return loss


class GraphConvolution(nn.Module):
    def __init__(self, input_dim, output_dim, dropout=0., is_sparse_inputs=False, bias=False, activation=F.relu, featureless=False):
        super(GraphConvolution, self).__init__()
        self.dropout = dropout
        self.activation = activation
        self.is_sparse_inputs = is_sparse_inputs
        self.featureless = featureless
        # self.num_features_nonzero = num_features_nonzero
        # self.user_weight = nn.Parameter(torch.randn(input_dim, output_dim))
        # self.item_weight = nn.Parameter(torch.randn(input_dim, output_dim))
        self.user_weight = nn.Parameter(torch.empty(input_dim, output_dim))
        self.item_weight = nn.Parameter(torch.empty(input_dim, output_dim))
        nn.init.xavier_uniform_(self.user_weight)
        nn.init.xavier_uniform_(self.item_weight)
        self.bias = None
        if bias:
            self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, user_x, item_x, ui_graph, iu_graph):
        # print('inputs:', inputs)
        # x, support = inputs
        # if self.training and self.is_sparse_inputs:
        #     x = sparse_dropout(x, self.dropout, self.num_features_nonzero)
        # elif self.training:
        user_x = F.dropout(user_x, self.dropout)
        item_x = F.dropout(item_x, self.dropout)
        # convolve
        if not self.featureless:  # if it has features x
            if self.is_sparse_inputs:
                # same fix as in Models.py: keep both branches instead of clobbering `xw`
                xw_user = torch.sparse.mm(user_x, self.user_weight)
                xw_item = torch.sparse.mm(item_x, self.item_weight)
            else:
                xw_user = torch.mm(user_x, self.user_weight)
                xw_item = torch.mm(item_x, self.item_weight)
        else:
            # was: `xw = self.weight`, an attribute that does not exist
            xw_user = self.user_weight
            xw_item = self.item_weight
        out_user = torch.sparse.mm(ui_graph, xw_item)
        out_item = torch.sparse.mm(iu_graph, xw_user)

        if self.bias is not None:
            out_user = out_user + self.bias
            out_item = out_item + self.bias
        return self.activation(out_user), self.activation(out_item)


def sparse_dropout(x, rate, noise_shape):
    """
    :param x: sparse input tensor
    :param rate: dropout rate
    :param noise_shape: int scalar, number of non-zero entries of x
    :return: sparse tensor with entries dropped and the survivors rescaled by 1/(1-rate)
    """
    random_tensor = 1 - rate
    random_tensor += torch.rand(noise_shape).to(x.device)
    dropout_mask = torch.floor(random_tensor).bool()  # byte-tensor indexing is deprecated
    i = x._indices()  # [2, nnz]
    v = x._values()   # [nnz]
    i = i[:, dropout_mask]
    v = v[dropout_mask]
    out = torch.sparse.FloatTensor(i, v, x.shape).to(x.device)
    out = out * (1. / (1 - rate))
    return out


def dot(x, y, sparse=False):
    if sparse:
        res = torch.sparse.mm(x, y)
    else:
        res = torch.mm(x, y)
    return res


class BLMLP(nn.Module):
    def __init__(self):
        super(BLMLP, self).__init__()
        self.W = nn.Parameter(nn.init.xavier_uniform_(torch.empty(args.student_embed_size, args.student_embed_size)))
        self.act = nn.LeakyReLU(negative_slope=0.5)

    def forward(self, embeds):
        pass

    def featureExtract(self, embeds):
        return self.act(embeds @ self.W) + embeds

    def pairPred(self, embeds1, embeds2):
        return (self.featureExtract(embeds1) * self.featureExtract(embeds2)).sum(dim=-1)

    def crossPred(self, embeds1, embeds2):
        return self.featureExtract(embeds1) @ self.featureExtract(embeds2).T


class Student_MLP(nn.Module):
    def __init__(self):
        super(Student_MLP, self).__init__()
        # self.n_users = n_users
        # self.n_items = n_items
        # self.embedding_dim = embedding_dim

        # self.uEmbeds = nn.Parameter(init(torch.empty(args.user, args.latdim)))
        # self.iEmbeds = nn.Parameter(init(torch.empty(args.item, args.latdim)))

        self.user_trans = nn.Linear(args.embed_size, args.embed_size)
        self.item_trans = nn.Linear(args.embed_size, args.embed_size)
        nn.init.xavier_uniform_(self.user_trans.weight)
        nn.init.xavier_uniform_(self.item_trans.weight)

        self.MLP = BLMLP()
        # self.overallTime = datetime.timedelta(0)

    def get_embedding(self):
        return self.user_id_embedding, self.item_id_embedding

    def forward(self, pre_user, pre_item):
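        # (Added note) the distilled student is deliberately tiny: one linear map per
        # side over the teacher-initialized embeddings, with ranking handled by BLMLP.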
        # pre_user, pre_item = self.user_id_embedding.weight, self.item_id_embedding.weight
        user_embed = self.user_trans(pre_user)
        item_embed = self.item_trans(pre_item)  # was self.user_trans(pre_item): items went through the user head

        return user_embed, item_embed
        # return pre_user, pre_item

    def init_user_item_embed(self, pre_u_embed, pre_i_embed):
        self.user_id_embedding = nn.Embedding.from_pretrained(pre_u_embed, freeze=False)
        self.item_id_embedding = nn.Embedding.from_pretrained(pre_i_embed, freeze=False)

    def pointPosPredictwEmbeds(self, uEmbeds, iEmbeds, ancs, poss):
        ancEmbeds = uEmbeds[ancs]
        posEmbeds = iEmbeds[poss]
        nume = self.MLP.pairPred(ancEmbeds, posEmbeds)
        return nume

    def pointNegPredictwEmbeds(self, embeds1, embeds2, nodes1, temp=1.0):
        pckEmbeds1 = embeds1[nodes1]
        preds = self.MLP.crossPred(pckEmbeds1, embeds2)
        return torch.exp(preds / temp).sum(-1)

    def pairPredictwEmbeds(self, uEmbeds, iEmbeds, ancs, poss, negs):
        ancEmbeds = uEmbeds[ancs]
        posEmbeds = iEmbeds[poss]
        negEmbeds = iEmbeds[negs]
        posPreds = self.MLP.pairPred(ancEmbeds, posEmbeds)
        negPreds = self.MLP.pairPred(ancEmbeds, negEmbeds)
        return posPreds - negPreds

    def predAll(self, pckUEmbeds, iEmbeds):
        return self.MLP.crossPred(pckUEmbeds, iEmbeds)

    def testPred(self, usr, trnMask):
        # forward() requires (pre_user, pre_item); feed it the stored embeddings, as the
        # commented-out line at the top of forward() suggests. The original self.forward()
        # call with no arguments would raise a TypeError.
        uEmbeds, iEmbeds = self.forward(self.user_id_embedding.weight, self.item_id_embedding.weight)
        allPreds = self.predAll(uEmbeds[usr], iEmbeds) * (1 - trnMask) - trnMask * 1e8
        return allPreds

--------------------------------------------------------------------------------
/codes/Models_mmlight.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from time import time
import pickle
import scipy.sparse as sp
from scipy.sparse import csr_matrix

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

from sklearn.decomposition import PCA, FastICA
from sklearn import manifold
from sklearn.manifold import TSNE
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# from utility.parser import parse_args
from utility.norm import build_sim, build_knn_normalized_graph
# args = parse_args()
from utility.parser import args


class Teacher_Model(nn.Module):
    def __init__(self, n_users, n_items, embedding_dim, weight_size, dropout_list, image_feats, text_feats):

        super().__init__()
        self.n_users = n_users
        self.n_items = n_items
        self.embedding_dim = embedding_dim
        self.weight_size = weight_size
        self.n_ui_layers = len(self.weight_size)
        self.weight_size = [self.embedding_dim] + self.weight_size

        self.image_trans = nn.Linear(image_feats.shape[1], args.embed_size)
        self.text_trans = nn.Linear(text_feats.shape[1], args.embed_size)
        nn.init.xavier_uniform_(self.image_trans.weight)
        nn.init.xavier_uniform_(self.text_trans.weight)
        self.encoder = nn.ModuleDict()
        self.encoder['image_encoder'] = self.image_trans  # ^-^
        self.encoder['text_encoder'] = self.text_trans  # ^-^

        self.user_id_embedding = nn.Embedding(n_users, self.embedding_dim)
        self.item_id_embedding = nn.Embedding(n_items, self.embedding_dim)

        nn.init.xavier_uniform_(self.user_id_embedding.weight)
        nn.init.xavier_uniform_(self.item_id_embedding.weight)
        self.image_feats =
torch.tensor(image_feats).float().cuda() 51 | self.text_feats = torch.tensor(text_feats).float().cuda() 52 | self.image_embedding = nn.Embedding.from_pretrained(torch.Tensor(image_feats), freeze=False) 53 | self.text_embedding = nn.Embedding.from_pretrained(torch.Tensor(text_feats), freeze=False) 54 | 55 | self.softmax = nn.Softmax(dim=-1) 56 | self.act = nn.Sigmoid() 57 | self.sigmoid = nn.Sigmoid() 58 | self.dropout = nn.Dropout(p=args.drop_rate) 59 | self.batch_norm = nn.BatchNorm1d(args.embed_size) 60 | 61 | def mm(self, x, y): 62 | if args.sparse: 63 | return torch.sparse.mm(x, y) 64 | else: 65 | return torch.mm(x, y) 66 | def sim(self, z1, z2): 67 | z1 = F.normalize(z1) 68 | z2 = F.normalize(z2) 69 | return torch.mm(z1, z2.t()) 70 | 71 | def batched_contrastive_loss(self, z1, z2, batch_size=4096): 72 | device = z1.device 73 | num_nodes = z1.size(0) 74 | num_batches = (num_nodes - 1) // batch_size + 1 75 | f = lambda x: torch.exp(x / self.tau) 76 | indices = torch.arange(0, num_nodes).to(device) 77 | losses = [] 78 | 79 | for i in range(num_batches): 80 | mask = indices[i * batch_size:(i + 1) * batch_size] 81 | refl_sim = f(self.sim(z1[mask], z1)) 82 | between_sim = f(self.sim(z1[mask], z2)) 83 | 84 | losses.append(-torch.log( 85 | between_sim[:, i * batch_size:(i + 1) * batch_size].diag() 86 | / (refl_sim.sum(1) + between_sim.sum(1) 87 | - refl_sim[:, i * batch_size:(i + 1) * batch_size].diag()))) 88 | 89 | loss_vec = torch.cat(losses) 90 | return loss_vec.mean() 91 | 92 | def csr_norm(self, csr_mat, mean_flag=False): 93 | rowsum = np.array(csr_mat.sum(1)) 94 | rowsum = np.power(rowsum+1e-8, -0.5).flatten() 95 | rowsum[np.isinf(rowsum)] = 0. 96 | rowsum_diag = sp.diags(rowsum) 97 | 98 | colsum = np.array(csr_mat.sum(0)) 99 | colsum = np.power(colsum+1e-8, -0.5).flatten() 100 | colsum[np.isinf(colsum)] = 0. 
        colsum_diag = sp.diags(colsum)

        if mean_flag == False:
            return rowsum_diag*csr_mat*colsum_diag
        else:
            return rowsum_diag*csr_mat

    def matrix_to_tensor(self, cur_matrix):
        if type(cur_matrix) != sp.coo_matrix:
            cur_matrix = cur_matrix.tocoo()
        indices = torch.from_numpy(np.vstack((cur_matrix.row, cur_matrix.col)).astype(np.int64))
        values = torch.from_numpy(cur_matrix.data)
        shape = torch.Size(cur_matrix.shape)

        return torch.sparse.FloatTensor(indices, values, shape).to(torch.float32).cuda()

    def para_dict_to_tenser(self, para_dict):
        """
        :param para_dict: nn.ParameterDict()
        :return: tensor
        """
        tensors = []

        for beh in para_dict.keys():
            tensors.append(para_dict[beh])
        tensors = torch.stack(tensors, dim=0)

        return tensors

    def multi_head_self_attention(self, trans_w, embedding_t_1, embedding_t):

        q = self.para_dict_to_tenser(embedding_t)
        v = k = self.para_dict_to_tenser(embedding_t_1)
        beh, N, d_h = q.shape[0], q.shape[1], args.embed_size/args.head_num

        Q = torch.matmul(q, trans_w['w_q'])
        K = torch.matmul(k, trans_w['w_k'])
        V = v

        Q = Q.reshape(beh, N, args.head_num, int(d_h)).permute(2, 0, 1, 3)
        K = K.reshape(beh, N, args.head_num, int(d_h)).permute(2, 0, 1, 3)  # was K = Q.reshape(...), same fix as Models_empower.py

        Q = torch.unsqueeze(Q, 2)
        K = torch.unsqueeze(K, 1)
        V = torch.unsqueeze(V, 1)

        att = torch.mul(Q, K) / torch.sqrt(torch.tensor(d_h))
        att = torch.sum(att, dim=-1)
        att = torch.unsqueeze(att, dim=-1)
        att = F.softmax(att, dim=2)

        Z = torch.mul(att, V)
        Z = torch.sum(Z, dim=2)

        Z_list = [value for value in Z]
        Z = torch.cat(Z_list, -1)
        Z = torch.matmul(Z, self.weight_dict['w_self_attention_cat'])

        # args.model_cat_rate*F.normalize(Z, p=2, dim=2)  # no-op in the original: the scaled value was never assigned
        return Z, att.detach()

    # def prompt_tuning(self, soft_token_u, soft_token_i):
    #     # self.user_id_embedding = nn.Embedding.from_pretrained(soft_token_u, freeze=False)
    #     # self.item_id_embedding = nn.Embedding.from_pretrained(soft_token_u, freeze=False)
    #     self.prompt_user = soft_token_u
    #     self.prompt_item = soft_token_i

    def forward(self, ui_graph, iu_graph, prompt_module=None):

        # def forward(self, ui_graph, iu_graph):

        prompt_user, prompt_item = prompt_module()  # [n*32]
        # ----feature prompt----
        # feat_prompt_user = torch.mean( torch.stack((torch.mm(prompt_user, torch.mm(prompt_user.T, self.image_feats)), torch.mm(prompt_user, torch.mm(prompt_user.T, self.text_feats)))), dim=0 )
        # feat_prompt_user = torch.mm(prompt_user, torch.mm(prompt_user.T, self.text_feats))
        feat_prompt_item_image = torch.mm(prompt_item, torch.mm(prompt_item.T, self.image_feats))
        feat_prompt_item_text = torch.mm(prompt_item, torch.mm(prompt_item.T, self.text_feats))
        # feat_prompt_image_item = torch.mm(prompt_item, torch.mm(prompt_item.T, self.image_feats))
        # feat_prompt_text_item = torch.mm(prompt_item, torch.mm(prompt_item.T, self.text_feats))
        # ----feature prompt----
        # image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats + feat_prompt_item_image ))
        # text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats + feat_prompt_item_text ))

        # image_feats = image_item_feats =
self.dropout(self.image_trans(self.image_feats + F.normalize(feat_prompt_item_image, p=2, dim=1) )) 186 | # text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats + F.normalize(feat_prompt_item_text, p=2, dim=1) )) 187 | 188 | image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats + args.feat_soft_token_rate*F.normalize(feat_prompt_item_image, p=2, dim=1) )) 189 | text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats + args.feat_soft_token_rate*F.normalize(feat_prompt_item_text, p=2, dim=1) )) 190 | # args.feat_soft_token_rate*F.normalize(feat_prompt_item_image, p=2, dim=1) 191 | # args.feat_soft_token_rate*F.normalize(feat_prompt_item_text, p=2, dim=1) 192 | 193 | 194 | # image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats)) 195 | # text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats)) 196 | 197 | for i in range(args.layers): 198 | image_user_feats = self.mm(ui_graph, image_feats) 199 | image_item_feats = self.mm(iu_graph, image_user_feats) 200 | # image_user_id = self.mm(image_ui_graph, self.item_id_embedding.weight) 201 | # image_item_id = self.mm(image_iu_graph, self.user_id_embedding.weight) 202 | 203 | text_user_feats = self.mm(ui_graph, text_feats) 204 | text_item_feats = self.mm(iu_graph, text_user_feats) 205 | 206 | # text_user_id = self.mm(text_ui_graph, self.item_id_embedding.weight) 207 | # text_item_id = self.mm(text_iu_graph, self.user_id_embedding.weight) 208 | 209 | # self.embedding_dict['user']['image'] = image_user_id 210 | # self.embedding_dict['user']['text'] = text_user_id 211 | # self.embedding_dict['item']['image'] = image_item_id 212 | # self.embedding_dict['item']['text'] = text_item_id 213 | # user_z, att_u = self.multi_head_self_attention(self.weight_dict, self.embedding_dict['user'], self.embedding_dict['user']) 214 | # item_z, att_i = self.multi_head_self_attention(self.weight_dict, self.embedding_dict['item'], self.embedding_dict['item']) 215 | # user_emb = user_z.mean(0) 216 | # item_emb = item_z.mean(0) 217 | u_g_embeddings = self.user_id_embedding.weight + args.soft_token_rate*F.normalize(prompt_user, p=2, dim=1) 218 | i_g_embeddings = self.item_id_embedding.weight + args.soft_token_rate*F.normalize(prompt_item, p=2, dim=1) 219 | user_emb_list = [u_g_embeddings] 220 | item_emb_list = [i_g_embeddings] 221 | for i in range(self.n_ui_layers): 222 | if i == (self.n_ui_layers-1): 223 | u_g_embeddings = self.softmax( torch.mm(ui_graph, i_g_embeddings) ) 224 | i_g_embeddings = self.softmax( torch.mm(iu_graph, u_g_embeddings) ) 225 | 226 | else: 227 | u_g_embeddings = torch.mm(ui_graph, i_g_embeddings) 228 | i_g_embeddings = torch.mm(iu_graph, u_g_embeddings) 229 | 230 | user_emb_list.append(u_g_embeddings) 231 | item_emb_list.append(i_g_embeddings) 232 | 233 | u_g_embeddings = torch.mean(torch.stack(user_emb_list), dim=0) 234 | i_g_embeddings = torch.mean(torch.stack(item_emb_list), dim=0) 235 | 236 | 237 | u_g_embeddings = u_g_embeddings + args.model_cat_rate*F.normalize(image_user_feats, p=2, dim=1) + args.model_cat_rate*F.normalize(text_user_feats, p=2, dim=1) 238 | i_g_embeddings = i_g_embeddings + args.model_cat_rate*F.normalize(image_item_feats, p=2, dim=1) + args.model_cat_rate*F.normalize(text_item_feats, p=2, dim=1) 239 | 240 | return u_g_embeddings, i_g_embeddings, image_item_feats, text_item_feats, image_user_feats, text_user_feats, u_g_embeddings, i_g_embeddings , prompt_user, prompt_item 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 
| class PromptLearner(nn.Module): 249 | def __init__(self, image_feats=None, text_feats=None, ui_graph=None): 250 | super().__init__() 251 | self.ui_graph = ui_graph 252 | 253 | 254 | if args.hard_token_type=='pca': 255 | try: 256 | t1 = time() 257 | hard_token_image = pickle.load(open(args.data_path + args.dataset + '/hard_token_image_pca','rb')) 258 | hard_token_text = pickle.load(open(args.data_path + args.dataset + '/hard_token_text_pca','rb')) 259 | print('already load hard token', time() - t1) 260 | except Exception: 261 | hard_token_image = PCA(n_components=args.embed_size).fit_transform(image_feats) 262 | hard_token_text = PCA(n_components=args.embed_size).fit_transform(text_feats) 263 | pickle.dump(hard_token_image, open(args.data_path + args.dataset + '/hard_token_image_pca','wb')) 264 | pickle.dump(hard_token_text, open(args.data_path + args.dataset + '/hard_token_text_pca','wb')) 265 | elif args.hard_token_type=='ica': 266 | try: 267 | t1 = time() 268 | hard_token_image = pickle.load(open(args.data_path + args.dataset + '/hard_token_image_ica','rb')) 269 | hard_token_text = pickle.load(open(args.data_path + args.dataset + '/hard_token_text_ica','rb')) 270 | print('already load hard token', time() - t1) 271 | except Exception: 272 | hard_token_image = FastICA(n_components=args.embed_size, random_state=12).fit_transform(image_feats) 273 | hard_token_text = FastICA(n_components=args.embed_size, random_state=12).fit_transform(text_feats) 274 | pickle.dump(hard_token_image, open(args.data_path + args.dataset + '/hard_token_image_ica','wb')) 275 | pickle.dump(hard_token_text, open(args.data_path + args.dataset + '/hard_token_text_ica','wb')) 276 | elif args.hard_token_type=='isomap': 277 | hard_token_image = manifold.Isomap(n_neighbors=5, n_components=args.embed_size, n_jobs=-1).fit_transform(image_feats) 278 | hard_token_text = manifold.Isomap(n_neighbors=5, n_components=args.embed_size, n_jobs=-1).fit_transform(text_feats) 279 | # elif args.hard_token_type=='tsne': 280 | # hard_token_image = TSNE(n_components=args.embed_size, n_iter=300).fit_transform(image_feats) 281 | # hard_token_text = TSNE(n_components=args.embed_size, n_iter=300).fit_transform(text_feats) 282 | # elif args.hard_token_type=='lda': 283 | # hard_token_image = LinearDiscriminantAnalysis(n_components=args.embed_size).fit_transform(image_feats) 284 | # hard_token_text = LinearDiscriminantAnalysis(n_components=args.embed_size).fit_transform(text_feats) 285 | 286 | # self.item_hard_token = nn.Embedding.from_pretrained(torch.mean((torch.stack((torch.tensor(hard_token_image).float(), torch.tensor(hard_token_text).float()))), dim=0), freeze=False).cuda().weight 287 | # self.user_hard_token = nn.Embedding.from_pretrained(torch.mm(ui_graph, self.item_hard_token), freeze=False).cuda().weight 288 | 289 | self.item_hard_token = torch.mean((torch.stack((torch.tensor(hard_token_image).float(), torch.tensor(hard_token_text).float()))), dim=0).cuda() 290 | self.user_hard_token = torch.mm(ui_graph, self.item_hard_token).cuda() 291 | 292 | self.trans_user = nn.Linear(args.embed_size, args.embed_size).cuda() 293 | self.trans_item = nn.Linear(args.embed_size, args.embed_size).cuda() 294 | # nn.init.xavier_uniform_(self.gnn_trans_user.weight) 295 | # nn.init.xavier_uniform_(self.gnn_trans_item.weight) 296 | # self.gnn_trans_user = self.gnn_trans_user.cuda() 297 | # self.gnn_trans_item = self.gnn_trans_item.cuda() 298 | # self.item_hard_token = torch.mean((torch.stack((torch.tensor(hard_token_image).float(), 
torch.tensor(hard_token_text).float()))), dim=0).cuda() 299 | 300 | 301 | def forward(self): 302 | # self.user_hard_token = self.gnn_trans_user(torch.mm(self.ui_graph, self.item_hard_token)) 303 | # self.item_hard_token = self.gnn_trans_item(self.item_hard_token) 304 | # return self.user_hard_token , self.item_hard_token 305 | return F.dropout(self.trans_user(self.user_hard_token), args.prompt_dropout) , F.dropout(self.trans_item(self.item_hard_token), args.prompt_dropout) 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | 320 | 321 | 322 | 323 | class Student_MMLight(nn.Module): 324 | def __init__(self, n_users, n_items, embedding_dim, n_layers, dropout, image_feats, text_feats, ui_graph, iu_graph): 325 | 326 | super().__init__() 327 | self.n_users = n_users 328 | self.n_items = n_items 329 | self.ui_graph = ui_graph 330 | self.iu_graph = iu_graph 331 | self.embedding_dim = embedding_dim 332 | # self.weight_size = weight_size 333 | self.n_ui_layers = n_layers 334 | # self.weight_size = [self.embedding_dim] + self.weight_size 335 | 336 | # self.image_trans = nn.Linear(image_feats.shape[1], args.embed_size) 337 | # self.text_trans = nn.Linear(text_feats.shape[1], args.embed_size) 338 | self.image_trans = nn.Linear(args.embed_size, args.embed_size) 339 | self.text_trans = nn.Linear(args.embed_size, args.embed_size) 340 | nn.init.xavier_uniform_(self.image_trans.weight) 341 | nn.init.xavier_uniform_(self.text_trans.weight) 342 | 343 | self.user_id_embedding = nn.Embedding(n_users, self.embedding_dim) 344 | self.item_id_embedding = nn.Embedding(n_items, self.embedding_dim) 345 | 346 | nn.init.xavier_uniform_(self.user_id_embedding.weight) 347 | nn.init.xavier_uniform_(self.item_id_embedding.weight) 348 | self.image_feats = torch.tensor(image_feats).float().cuda() 349 | self.text_feats = torch.tensor(text_feats).float().cuda() 350 | 351 | self.softmax = nn.Softmax(dim=-1) 352 | self.act = nn.Sigmoid() 353 | self.sigmoid = nn.Sigmoid() 354 | self.dropout = nn.Dropout(args.drop_rate) 355 | self.batch_norm = nn.BatchNorm1d(args.embed_size) 356 | self.tau = 0.5 357 | 358 | 359 | def mm(self, x, y): 360 | if args.sparse: 361 | return torch.sparse.mm(x, y) 362 | else: 363 | return torch.mm(x, y) 364 | def sim(self, z1, z2): 365 | z1 = F.normalize(z1) 366 | z2 = F.normalize(z2) 367 | return torch.mm(z1, z2.t()) 368 | 369 | 370 | def init_user_item_embed(self, pre_u_embed, pre_i_embed): 371 | self.user_id_embedding = nn.Embedding.from_pretrained(pre_u_embed, freeze=False) 372 | self.item_id_embedding = nn.Embedding.from_pretrained(pre_i_embed, freeze=False) 373 | 374 | def forward(self, image_feats, text_feats, image_user_embeds, text_user_embeds): 375 | image_feats = image_item_feats = self.dropout(self.image_trans(image_feats)) 376 | text_feats = text_item_feats = self.dropout(self.text_trans(text_feats)) 377 | 378 | # image_user_feats = self.mm(self.ui_graph, image_feats) 379 | # image_item_feats = self.mm(self.iu_graph, image_user_feats) 380 | # text_user_feats = self.mm(self.ui_graph, text_feats) 381 | # text_item_feats = self.mm(self.iu_graph, text_user_feats) 382 | 383 | u_g_embeddings = self.user_id_embedding.weight 384 | i_g_embeddings = self.item_id_embedding.weight 385 | 386 | user_emb_list = [u_g_embeddings] 387 | item_emb_list = [i_g_embeddings] 388 | for i in range(self.n_ui_layers): 389 | if i == (self.n_ui_layers-1): 390 | u_g_embeddings = self.softmax( torch.mm(self.ui_graph, i_g_embeddings) ) 391 | i_g_embeddings = self.softmax( torch.mm(self.iu_graph, 
u_g_embeddings) ) 392 | else: 393 | u_g_embeddings = torch.mm(self.ui_graph, i_g_embeddings) 394 | i_g_embeddings = torch.mm(self.iu_graph, u_g_embeddings) 395 | 396 | user_emb_list.append(u_g_embeddings) 397 | item_emb_list.append(i_g_embeddings) 398 | 399 | u_g_embeddings = torch.mean(torch.stack(user_emb_list), dim=0) 400 | i_g_embeddings = torch.mean(torch.stack(item_emb_list), dim=0) 401 | 402 | 403 | u_g_embeddings = u_g_embeddings #+ args.model_cat_rate*F.normalize(image_user_feats, p=2, dim=1) + args.model_cat_rate*F.normalize(text_user_feats, p=2, dim=1) 404 | i_g_embeddings = i_g_embeddings + args.model_cat_rate*F.normalize(image_item_feats, p=2, dim=1) + args.model_cat_rate*F.normalize(text_item_feats, p=2, dim=1) 405 | 406 | # return u_g_embeddings, i_g_embeddings, image_item_feats, text_item_feats, image_user_feats, text_user_feats, u_g_embeddings, i_g_embeddings 407 | return u_g_embeddings, i_g_embeddings 408 | 409 | 410 | 411 | 412 | 413 | 414 | 415 | 416 | 417 | 418 | 419 | 420 | 421 | 422 | 423 | 424 | 425 | 426 | 427 | 428 | 429 | 430 | class Student_LightGCN(nn.Module): 431 | def __init__(self, n_users, n_items, embedding_dim, gnn_layer, dropout_list, image_feats=None, text_feats=None): 432 | super().__init__() 433 | self.n_users = n_users 434 | self.n_items = n_items 435 | self.embedding_dim = embedding_dim 436 | self.n_ui_layers = gnn_layer 437 | 438 | self.user_id_embedding = nn.Embedding(n_users, embedding_dim) 439 | self.item_id_embedding = nn.Embedding(n_items, embedding_dim) 440 | nn.init.xavier_uniform_(self.user_id_embedding.weight) 441 | nn.init.xavier_uniform_(self.item_id_embedding.weight) 442 | 443 | # self.feat_trans = nn.Linear(args.embed_size, args.student_embed_size) 444 | # # self.text_trans = nn.Linear(text_feats.shape[1], args.embed_size) 445 | # nn.init.xavier_uniform_(self.feat_trans.weight) 446 | # # nn.init.xavier_uniform_(self.text_trans.weight) 447 | 448 | def init_user_item_embed(self, pre_u_embed, pre_i_embed): 449 | self.user_id_embedding = nn.Embedding.from_pretrained(pre_u_embed, freeze=False) 450 | self.item_id_embedding = nn.Embedding.from_pretrained(pre_i_embed, freeze=False) 451 | 452 | self.user_id_embedding_pre = nn.Embedding.from_pretrained(pre_u_embed, freeze=False) 453 | self.item_id_embedding_pre = nn.Embedding.from_pretrained(pre_i_embed, freeze=False) 454 | 455 | def get_embedding(self): 456 | return self.user_id_embedding, self.item_id_embedding 457 | 458 | def forward(self, adj, image_item_embeds, text_item_embeds, image_user_embeds, text_user_embeds): 459 | 460 | # # teacher_feat_dict = { 'item_image':t_i_image_embed.deteach(),'item_text':t_i_text_embed.deteach(),'user_image':t_u_image_embed.deteach(),'user_text':t_u_text_embed.deteach() } 461 | # tmp_feat_dict = {} 462 | # for index,value in enumerate(teacher_feat_dict.keys()): 463 | # tmp_feat_dict[value] = self.feat_trans(teacher_feat_dict[value]) 464 | # u_g_embeddings = self.user_id_embedding.weight + args.model_cat_rate*F.normalize(tmp_feat_dict['user_image'], p=2, dim=1) + args.model_cat_rate*F.normalize(tmp_feat_dict['user_text'], p=2, dim=1) 465 | # i_g_embeddings = self.item_id_embedding.weight + args.model_cat_rate*F.normalize(tmp_feat_dict['item_image'], p=2, dim=1) + args.model_cat_rate*F.normalize(tmp_feat_dict['item_text'], p=2, dim=1) 466 | # ego_embeddings = torch.cat((u_g_embeddings, i_g_embeddings), dim=0) 467 | 468 | # self.user_id_embedding_pre = nn.Embedding.from_pretrained(pre_u_embed, freeze=False) 469 | # self.item_id_embedding_pre = 
nn.Embedding.from_pretrained(pre_i_embed, freeze=False) 470 | 471 | ego_embeddings = torch.cat((self.user_id_embedding.weight+self.user_id_embedding_pre.weight, self.item_id_embedding.weight+self.item_id_embedding_pre.weight), dim=0) 472 | # ego_embeddings = torch.cat((self.user_id_embedding.weight, self.item_id_embedding.weight), dim=0) 473 | all_embeddings = [ego_embeddings] 474 | for i in range(self.n_ui_layers): 475 | side_embeddings = torch.sparse.mm(adj, ego_embeddings) 476 | ego_embeddings = side_embeddings 477 | all_embeddings += [ego_embeddings] 478 | all_embeddings = torch.stack(all_embeddings, dim=1) 479 | all_embeddings = all_embeddings.mean(dim=1, keepdim=False) 480 | u_g_embeddings, i_g_embeddings = torch.split(all_embeddings, [self.n_users, self.n_items], dim=0) 481 | # u_g_embeddings += teacher_feat_dict['user_image'] + teacher_feat_dict['user_text'] 482 | # i_g_embeddings += teacher_feat_dict['item_image'] + teacher_feat_dict['item_text'] 483 | u_g_embeddings = u_g_embeddings + args.model_cat_rate*F.normalize(image_user_embeds, p=2, dim=1) + args.model_cat_rate*F.normalize(text_user_embeds, p=2, dim=1) 484 | i_g_embeddings = i_g_embeddings + args.model_cat_rate*F.normalize(image_item_embeds, p=2, dim=1) + args.model_cat_rate*F.normalize(text_item_embeds, p=2, dim=1) 485 | 486 | return u_g_embeddings, i_g_embeddings 487 | # return self.user_id_embedding.weight, self.item_id_embedding.weight 488 | 489 | 490 | 491 | class Student_GCN(nn.Module): 492 | def __init__(self, n_users, n_items, embedding_dim, gnn_layer=2, drop_out=0., image_feats=None, text_feats=None): 493 | super(Student_GCN, self).__init__() 494 | self.embedding_dim = embedding_dim 495 | 496 | # self.layers = nn.Sequential(GraphConvolution(self.embedding_dim, self.embedding_dim, activation=F.relu, dropout=args.student_drop_rate, is_sparse_inputs=True), 497 | # GraphConvolution(self.embedding_dim, self.embedding_dim, activation=F.relu, dropout=args.student_drop_rate, is_sparse_inputs=False), 498 | # ) 499 | # self.layer_list = nn.ModuleList() 500 | # for i in range(args.student_n_layers): 501 | # self.layer_list.append(GraphConvolution(self.embedding_dim, self.embedding_dim, activation=F.relu, dropout=args.student_drop_rate, is_sparse_inputs=False)) 502 | 503 | self.trans_user = nn.Linear(args.embed_size, args.embed_size).cuda() 504 | self.trans_item = nn.Linear(args.embed_size, args.embed_size).cuda() 505 | 506 | 507 | def forward(self, user_x, item_x, ui_graph, iu_graph): 508 | # # x, support = inputs 509 | # # user_x, item_x = self.layers((user_x, item_x, ui_graph, iu_graph)) 510 | # for i in range(args.student_n_layers): 511 | # user_x, item_x = self.layer_list[i](user_x, item_x, ui_graph, iu_graph) 512 | # return user_x, item_x 513 | 514 | return self.trans_user(user_x), self.trans_item(item_x) 515 | # self.user_id_embedding = nn.Embedding.from_pretrained(user_x, freeze=True) 516 | # self.item_id_embedding = nn.Embedding.from_pretrained(item_x, freeze=True) 517 | # return self.user_id_embedding.weight, self.item_id_embedding.weight 518 | 519 | def l2_loss(self): 520 | layer = self.layers.children() 521 | layer = next(iter(layer)) 522 | loss = None 523 | 524 | for p in layer.parameters(): 525 | if loss is None: 526 | loss = p.pow(2).sum() 527 | else: 528 | loss += p.pow(2).sum() 529 | 530 | return loss 531 | 532 | class GraphConvolution(nn.Module): 533 | def __init__(self, input_dim, output_dim, dropout=0., is_sparse_inputs=False, bias=False, activation = F.relu,featureless=False): 534 | 
super(GraphConvolution, self).__init__() 535 | self.dropout = dropout 536 | self.bias = bias 537 | self.activation = activation 538 | self.is_sparse_inputs = is_sparse_inputs 539 | self.featureless = featureless 540 | # self.num_features_nonzero = num_features_nonzero 541 | # self.user_weight = nn.Parameter(torch.randn(input_dim, output_dim)) 542 | # self.item_weight = nn.Parameter(torch.randn(input_dim, output_dim)) 543 | self.user_weight = nn.Parameter(torch.empty(input_dim, output_dim)) 544 | self.item_weight = nn.Parameter(torch.empty(input_dim, output_dim)) 545 | nn.init.xavier_uniform_(self.user_weight) 546 | nn.init.xavier_uniform_(self.item_weight) 547 | self.bias = None 548 | if bias: 549 | self.bias = nn.Parameter(torch.zeros(output_dim)) 550 | 551 | 552 | def forward(self, user_x, item_x, ui_graph, iu_graph): 553 | # x, support = inputs 554 | # if self.training and self.is_sparse_inputs: 555 | # x = sparse_dropout(x, self.dropout, self.num_features_nonzero) 556 | # elif self.training: 557 | user_x = F.dropout(user_x, self.dropout) 558 | item_x = F.dropout(item_x, self.dropout) 559 | item_x = F.dropout(item_x, self.dropout) if False else item_x 560 | # transform node features, then convolve 561 | if not self.featureless: # if it has features x 562 | if self.is_sparse_inputs: 563 | xw_user = torch.sparse.mm(user_x, self.user_weight) 564 | xw_item = torch.sparse.mm(item_x, self.item_weight) 565 | else: 566 | xw_user = torch.mm(user_x, self.user_weight) 567 | xw_item = torch.mm(item_x, self.item_weight) 568 | else: # featureless: the weight matrices themselves serve as the node embeddings 569 | xw_user, xw_item = self.user_weight, self.item_weight 570 | out_user = torch.sparse.mm(ui_graph, xw_item) # users aggregate transformed item features 571 | out_item = torch.sparse.mm(iu_graph, xw_user) # items aggregate transformed user features 572 | 573 | if self.bias is not None: 574 | out_user, out_item = out_user + self.bias, out_item + self.bias 575 | return self.activation(out_user), self.activation(out_item) 576 | 577 | 578 | def sparse_dropout(x, rate, noise_shape): 579 | """Dropout over the non-zero entries of a sparse tensor. 580 | :param x: sparse input tensor 581 | :param rate: dropout rate 582 | :param noise_shape: int scalar, number of non-zero entries 583 | :return: rescaled sparse tensor with entries dropped 584 | """ 585 | random_tensor = 1 - rate 586 | random_tensor += torch.rand(noise_shape).to(x.device) 587 | dropout_mask = torch.floor(random_tensor).bool() 588 | i = x._indices() # [2, 49216] 589 | v = x._values() # [49216] 590 | # [2, 49216] => [49216, 2] => [remained node, 2] => [2, remained node] 591 | i = i[:, dropout_mask] 592 | v = v[dropout_mask] 593 | out = torch.sparse.FloatTensor(i, v, x.shape).to(x.device) 594 | out = out * (1. / (1 - rate)) 595 | return out 596 | 597 | 598 | def dot(x, y, sparse=False): 599 | if sparse: 600 | res = torch.sparse.mm(x, y) 601 | else: 602 | res = torch.mm(x, y) 603 | return res 604 | 605 | 606 | 607 | 608 | 609 | class BLMLP(nn.Module): 610 | def __init__(self): 611 | super(BLMLP, self).__init__() 612 | self.W = nn.Parameter(nn.init.xavier_uniform_(torch.empty(args.student_embed_size, args.student_embed_size))) 613 | self.act = nn.LeakyReLU(negative_slope=0.5) 614 | 615 | def forward(self, embeds): 616 | pass 617 | 618 | def featureExtract(self, embeds): 619 | return self.act(embeds @ self.W) + embeds 620 | 621 | def pairPred(self, embeds1, embeds2): 622 | return (self.featureExtract(embeds1) * self.featureExtract(embeds2)).sum(dim=-1) 623 | 624 | def crossPred(self, embeds1, embeds2): 625 | return self.featureExtract(embeds1) @ self.featureExtract(embeds2).T 626 | 627 | 628 |
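# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file; assumes the module-level
# torch / torch.nn-as-nn imports): BLMLP scores pairs through a residual
# feature map f(e) = act(e @ W) + e. pairPred returns one score per aligned
# (user, item) row, while crossPred returns the full user x item score matrix,
# so the diagonal of crossPred equals pairPred on the same inputs. A
# self-contained check with a local stand-in for W:
def _sketch_blmlp_scoring(n=5, dim=8):
    W = nn.init.xavier_uniform_(torch.empty(dim, dim))
    act = nn.LeakyReLU(negative_slope=0.5)
    f = lambda e: act(e @ W) + e                  # featureExtract
    users, items = torch.randn(n, dim), torch.randn(n, dim)
    pair = (f(users) * f(items)).sum(-1)          # pairPred: shape [n]
    cross = f(users) @ f(items).T                 # crossPred: shape [n, n]
    assert torch.allclose(pair, cross.diagonal(), atol=1e-5)
    return pair, cross
# ---------------------------------------------------------------------------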
629 | class Student_MLP(nn.Module): 630 | def __init__(self): 631 | super(Student_MLP, self).__init__() 632 | # self.n_users = n_users 633 | # self.n_items = n_items 634 | # self.embedding_dim = embedding_dim 635 | 636 | # self.uEmbeds = nn.Parameter(init(torch.empty(args.user, args.latdim))) 637 | # self.iEmbeds = nn.Parameter(init(torch.empty(args.item, args.latdim))) 638 | 639 | self.user_trans = nn.Linear(args.embed_size, args.embed_size) 640 | self.item_trans = nn.Linear(args.embed_size, args.embed_size) 641 | nn.init.xavier_uniform_(self.user_trans.weight) 642 | nn.init.xavier_uniform_(self.item_trans.weight) 643 | 644 | self.MLP = BLMLP() 645 | # self.overallTime = datetime.timedelta(0) 646 | 647 | 648 | def get_embedding(self): 649 | return self.user_id_embedding, self.item_id_embedding 650 | 651 | 652 | def forward(self, pre_user, pre_item): 653 | # pre_user, pre_item = self.user_id_embedding.weight, self.item_id_embedding.weight 654 | user_embed = self.user_trans(pre_user) 655 | item_embed = self.item_trans(pre_item) 656 | 657 | return user_embed, item_embed 658 | # return pre_user, pre_item 659 | 660 | def init_user_item_embed(self, pre_u_embed, pre_i_embed): 661 | self.user_id_embedding = nn.Embedding.from_pretrained(pre_u_embed, freeze=False) 662 | self.item_id_embedding = nn.Embedding.from_pretrained(pre_i_embed, freeze=False) 663 | 664 | def pointPosPredictwEmbeds(self, uEmbeds, iEmbeds, ancs, poss): 665 | ancEmbeds = uEmbeds[ancs] 666 | posEmbeds = iEmbeds[poss] 667 | nume = self.MLP.pairPred(ancEmbeds, posEmbeds) 668 | return nume 669 | 670 | def pointNegPredictwEmbeds(self, embeds1, embeds2, nodes1, temp=1.0): 671 | pckEmbeds1 = embeds1[nodes1] 672 | preds = self.MLP.crossPred(pckEmbeds1, embeds2) 673 | return torch.exp(preds / temp).sum(-1) 674 | 675 | def pairPredictwEmbeds(self, uEmbeds, iEmbeds, ancs, poss, negs): 676 | ancEmbeds = uEmbeds[ancs] 677 | posEmbeds = iEmbeds[poss] 678 | negEmbeds = iEmbeds[negs] 679 | posPreds = self.MLP.pairPred(ancEmbeds, posEmbeds) 680 | negPreds = self.MLP.pairPred(ancEmbeds, negEmbeds) 681 | return posPreds - negPreds 682 | 683 | def predAll(self, pckUEmbeds, iEmbeds): 684 | return self.MLP.crossPred(pckUEmbeds, iEmbeds) 685 | 686 | def testPred(self, usr, trnMask): 687 | uEmbeds, iEmbeds = self.forward(self.user_id_embedding.weight, self.item_id_embedding.weight) 688 | allPreds = self.predAll(uEmbeds[usr], iEmbeds) * (1 - trnMask) - trnMask * 1e8 # mask items already seen in training 689 | return allPreds 690 | 691 | --------------------------------------------------------------------------------
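Illustrative note (not part of the repository dump): the Student_* classes above share an init_user_item_embed hook that warm-starts a lightweight student from the teacher's output embeddings, a common distillation-style hand-off. A minimal self-contained sketch of that step, with hypothetical shapes:

import torch
import torch.nn as nn

n_users, n_items, dim = 100, 200, 64
teacher_u = torch.randn(n_users, dim)  # e.g. u_g_embeddings from the teacher forward
teacher_i = torch.randn(n_items, dim)  # e.g. i_g_embeddings

# mirrors init_user_item_embed: copy the teacher output, keep it trainable
user_emb = nn.Embedding.from_pretrained(teacher_u.detach(), freeze=False)
item_emb = nn.Embedding.from_pretrained(teacher_i.detach(), freeze=False)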
/codes/utility/batch_test.py: -------------------------------------------------------------------------------- 1 | import utility.metrics as metrics 2 | # from utility.parser import parse_args 3 | from utility.load_data import Data 4 | import multiprocessing 5 | import heapq 6 | import torch 7 | import pickle 8 | import numpy as np 9 | from time import time 10 | 11 | cores = multiprocessing.cpu_count() // 5 12 | 13 | # args = parse_args() 14 | from utility.parser import args 15 | 16 | # Ks = eval(args.Ks) 17 | 18 |
data_generator = Data(path=args.data_path + args.dataset, batch_size=args.batch_size) 19 | USR_NUM, ITEM_NUM = data_generator.n_users, data_generator.n_items 20 | N_TRAIN, N_TEST = data_generator.n_train, data_generator.n_test 21 | BATCH_SIZE = args.batch_size 22 | 23 | def ranklist_by_heapq(user_pos_test, test_items, rating, Ks): 24 | item_score = {} 25 | for i in test_items: 26 | item_score[i] = rating[i] 27 | 28 | K_max = max(Ks) 29 | K_max_item_score = heapq.nlargest(K_max, item_score, key=item_score.get) 30 | 31 | r = [] 32 | for i in K_max_item_score: 33 | if i in user_pos_test: 34 | r.append(1) 35 | else: 36 | r.append(0) 37 | auc = 0. 38 | return r, auc 39 | 40 | def get_auc(item_score, user_pos_test): 41 | item_score = sorted(item_score.items(), key=lambda kv: kv[1]) 42 | item_score.reverse() 43 | item_sort = [x[0] for x in item_score] 44 | posterior = [x[1] for x in item_score] 45 | 46 | r = [] 47 | for i in item_sort: 48 | if i in user_pos_test: 49 | r.append(1) 50 | else: 51 | r.append(0) 52 | auc = metrics.auc(ground_truth=r, prediction=posterior) 53 | return auc 54 | 55 | def ranklist_by_sorted(user_pos_test, test_items, rating, Ks): 56 | item_score = {} 57 | for i in test_items: 58 | item_score[i] = rating[i] 59 | 60 | K_max = max(Ks) 61 | K_max_item_score = heapq.nlargest(K_max, item_score, key=item_score.get) 62 | 63 | r = [] 64 | for i in K_max_item_score: 65 | if i in user_pos_test: 66 | r.append(1) 67 | else: 68 | r.append(0) 69 | auc = get_auc(item_score, user_pos_test) 70 | return r, auc 71 | 72 | def get_performance(user_pos_test, r, auc, Ks): 73 | precision, recall, ndcg, hit_ratio = [], [], [], [] 74 | 75 | for K in Ks: 76 | precision.append(metrics.precision_at_k(r, K)) 77 | recall.append(metrics.recall_at_k(r, K, len(user_pos_test))) 78 | ndcg.append(metrics.ndcg_at_k(r, K)) 79 | hit_ratio.append(metrics.hit_at_k(r, K)) 80 | 81 | return {'recall': np.array(recall), 'precision': np.array(precision), 82 | 'ndcg': np.array(ndcg), 'hit_ratio': np.array(hit_ratio), 'auc': auc} 83 | 84 | 85 | def test_one_user(x): 86 | # user u's ratings for user u 87 | is_val = x[-1] 88 | rating = x[0] 89 | #uid 90 | u = x[1] 91 | #user u's items in the training set 92 | try: 93 | training_items = data_generator.train_items[u] 94 | except Exception: 95 | training_items = [] 96 | #user u's items in the test set 97 | if is_val: 98 | user_pos_test = data_generator.val_set[u] 99 | else: 100 | user_pos_test = data_generator.test_set[u] 101 | 102 | all_items = set(range(ITEM_NUM)) 103 | 104 | test_items = list(all_items - set(training_items)) 105 | 106 | if args.test_flag == 'part': 107 | r, auc = ranklist_by_heapq(user_pos_test, test_items, rating, eval(args.Ks)) 108 | else: 109 | r, auc = ranklist_by_sorted(user_pos_test, test_items, rating, eval(args.Ks)) 110 | 111 | return get_performance(user_pos_test, r, auc, eval(args.Ks)) 112 | 113 | 114 | def test_torch(ua_embeddings, ia_embeddings, users_to_test, is_val, drop_flag=False, batch_test_flag=False): 115 | result = {'precision': np.zeros(len(eval(args.Ks))), 'recall': np.zeros(len(eval(args.Ks))), 'ndcg': np.zeros(len(eval(args.Ks))), 116 | 'hit_ratio': np.zeros(len(eval(args.Ks))), 'auc': 0.} 117 | pool = multiprocessing.Pool(cores) 118 | 119 | u_batch_size = BATCH_SIZE * 2 120 | i_batch_size = BATCH_SIZE 121 | 122 | test_users = users_to_test 123 | n_test_users = len(test_users) 124 | n_user_batchs = n_test_users // u_batch_size + 1 125 | count = 0 126 | 127 | for u_batch_id in range(n_user_batchs): 128 | start = u_batch_id * 
u_batch_size 129 | end = (u_batch_id + 1) * u_batch_size 130 | user_batch = test_users[start: end] 131 | if batch_test_flag: 132 | n_item_batchs = ITEM_NUM // i_batch_size + 1 133 | rate_batch = np.zeros(shape=(len(user_batch), ITEM_NUM)) 134 | 135 | i_count = 0 136 | for i_batch_id in range(n_item_batchs): 137 | i_start = i_batch_id * i_batch_size 138 | i_end = min((i_batch_id + 1) * i_batch_size, ITEM_NUM) 139 | 140 | item_batch = range(i_start, i_end) 141 | u_g_embeddings = ua_embeddings[user_batch] 142 | i_g_embeddings = ia_embeddings[item_batch] 143 | i_rate_batch = torch.matmul(u_g_embeddings, torch.transpose(i_g_embeddings, 0, 1)) 144 | 145 | rate_batch[:, i_start: i_end] = i_rate_batch 146 | i_count += i_rate_batch.shape[1] 147 | 148 | assert i_count == ITEM_NUM 149 | 150 | else: 151 | item_batch = range(ITEM_NUM) 152 | u_g_embeddings = ua_embeddings[user_batch] 153 | i_g_embeddings = ia_embeddings[item_batch] 154 | rate_batch = torch.matmul(u_g_embeddings, torch.transpose(i_g_embeddings, 0, 1)) 155 | 156 | rate_batch = rate_batch.detach().cpu().numpy() 157 | user_batch_rating_uid = zip(rate_batch, user_batch, [is_val] * len(user_batch)) 158 | 159 | batch_result = pool.map(test_one_user, user_batch_rating_uid) 160 | count += len(batch_result) 161 | 162 | for re in batch_result: 163 | result['precision'] += re['precision'] / n_test_users 164 | result['recall'] += re['recall'] / n_test_users 165 | result['ndcg'] += re['ndcg'] / n_test_users 166 | result['hit_ratio'] += re['hit_ratio'] / n_test_users 167 | result['auc'] += re['auc'] / n_test_users 168 | 169 | assert count == n_test_users 170 | pool.close() 171 | return result 172 | -------------------------------------------------------------------------------- /codes/utility/load_data.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import random as rd 4 | import scipy.sparse as sp 5 | from time import time 6 | import json 7 | # from utility.parser import parse_args 8 | # args = parse_args() 9 | from utility.parser import args 10 | 11 | 12 | class Data(object): 13 | def __init__(self, path, batch_size): 14 | self.path = path #+ '/%d-core' % args.core 15 | self.batch_size = batch_size 16 | 17 | train_file = path + '/train.json' #+ '/%d-core/train.json' % (args.core) 18 | val_file = path + '/val.json' #+ '/%d-core/val.json' % (args.core) 19 | test_file = path + '/test.json' #+ '/%d-core/test.json' % (args.core) 20 | 21 | # get number of users and items 22 | self.n_users, self.n_items = 0, 0 23 | self.n_train, self.n_test, self.n_val = 0, 0, 0 24 | self.neg_pools = {} 25 | 26 | self.exist_users = [] 27 | 28 | train = json.load(open(train_file)) 29 | test = json.load(open(test_file)) 30 | val = json.load(open(val_file)) 31 | for uid, items in train.items(): 32 | if len(items) == 0: 33 | continue 34 | uid = int(uid) 35 | self.exist_users.append(uid) 36 | self.n_items = max(self.n_items, max(items)) 37 | self.n_users = max(self.n_users, uid) 38 | self.n_train += len(items) 39 | 40 | for uid, items in test.items(): 41 | uid = int(uid) 42 | try: 43 | self.n_items = max(self.n_items, max(items)) 44 | self.n_test += len(items) 45 | except Exception: 46 | continue 47 | 48 | for uid, items in val.items(): 49 | uid = int(uid) 50 | try: 51 | self.n_items = max(self.n_items, max(items)) 52 | self.n_val += len(items) 53 | except Exception: 54 | continue 55 | 56 | self.n_items += 1 57 | self.n_users += 1 58 | 59 | self.print_statistics() 60 | 61 | self.R = sp.dok_matrix((self.n_users, self.n_items), dtype=np.float32) 62
| self.R_Item_Interacts = sp.dok_matrix((self.n_items, self.n_items), dtype=np.float32) 63 | 64 | self.train_items, self.test_set, self.val_set = {}, {}, {} 65 | for uid, train_items in train.items(): 66 | if len(train_items) == 0: 67 | continue 68 | uid = int(uid) 69 | for idx, i in enumerate(train_items): 70 | self.R[uid, i] = 1. 71 | 72 | self.train_items[uid] = train_items 73 | 74 | for uid, test_items in test.items(): 75 | uid = int(uid) 76 | if len(test_items) == 0: 77 | continue 78 | try: 79 | self.test_set[uid] = test_items 80 | except: 81 | continue 82 | 83 | for uid, val_items in val.items(): 84 | uid = int(uid) 85 | if len(val_items) == 0: 86 | continue 87 | try: 88 | self.val_set[uid] = val_items 89 | except: 90 | continue 91 | 92 | def get_adj_mat(self): 93 | try: 94 | t1 = time() 95 | adj_mat = sp.load_npz(self.path + '/s_adj_mat.npz') 96 | norm_adj_mat = sp.load_npz(self.path + '/s_norm_adj_mat.npz') 97 | mean_adj_mat = sp.load_npz(self.path + '/s_mean_adj_mat.npz') 98 | print('already load adj matrix', adj_mat.shape, time() - t1) 99 | 100 | except Exception: 101 | adj_mat, norm_adj_mat, mean_adj_mat = self.create_adj_mat() 102 | sp.save_npz(self.path + '/s_adj_mat.npz', adj_mat) 103 | sp.save_npz(self.path + '/s_norm_adj_mat.npz', norm_adj_mat) 104 | sp.save_npz(self.path + '/s_mean_adj_mat.npz', mean_adj_mat) 105 | return adj_mat, norm_adj_mat, mean_adj_mat 106 | 107 | def create_adj_mat(self): 108 | t1 = time() 109 | adj_mat = sp.dok_matrix((self.n_users + self.n_items, self.n_users + self.n_items), dtype=np.float32) 110 | adj_mat = adj_mat.tolil() 111 | R = self.R.tolil() 112 | 113 | adj_mat[:self.n_users, self.n_users:] = R 114 | adj_mat[self.n_users:, :self.n_users] = R.T 115 | adj_mat = adj_mat.todok() 116 | print('already create adjacency matrix', adj_mat.shape, time() - t1) 117 | 118 | t2 = time() 119 | 120 | def normalized_adj_single(adj): 121 | rowsum = np.array(adj.sum(1)) 122 | 123 | d_inv = np.power(rowsum, -1).flatten() 124 | d_inv[np.isinf(d_inv)] = 0. 125 | d_mat_inv = sp.diags(d_inv) 126 | 127 | norm_adj = d_mat_inv.dot(adj) 128 | # norm_adj = adj.dot(d_mat_inv) 129 | print('generate single-normalized adjacency matrix.') 130 | return norm_adj.tocoo() 131 | 132 | def get_D_inv(adj): 133 | rowsum = np.array(adj.sum(1)) 134 | 135 | d_inv = np.power(rowsum, -1).flatten() 136 | d_inv[np.isinf(d_inv)] = 0. 
137 | d_mat_inv = sp.diags(d_inv) 138 | return d_mat_inv 139 | 140 | def check_adj_if_equal(adj): 141 | dense_A = np.array(adj.todense()) 142 | degree = np.sum(dense_A, axis=1, keepdims=False) 143 | 144 | temp = np.dot(np.diag(np.power(degree, -1)), dense_A) 145 | print('check normalized adjacency matrix whether equal to this laplacian matrix.') 146 | return temp 147 | 148 | norm_adj_mat = normalized_adj_single(adj_mat + sp.eye(adj_mat.shape[0])) 149 | mean_adj_mat = normalized_adj_single(adj_mat) 150 | 151 | print('already normalize adjacency matrix', time() - t2) 152 | return adj_mat.tocsr(), norm_adj_mat.tocsr(), mean_adj_mat.tocsr() 153 | 154 | 155 | def sample(self): 156 | if self.batch_size <= self.n_users: 157 | users = rd.sample(self.exist_users, self.batch_size) 158 | else: 159 | users = [rd.choice(self.exist_users) for _ in range(self.batch_size)] 160 | # users = self.exist_users[:] 161 | 162 | def sample_pos_items_for_u(u, num): 163 | pos_items = self.train_items[u] 164 | n_pos_items = len(pos_items) 165 | pos_batch = [] 166 | while True: 167 | if len(pos_batch) == num: break 168 | pos_id = np.random.randint(low=0, high=n_pos_items, size=1)[0] 169 | pos_i_id = pos_items[pos_id] 170 | 171 | if pos_i_id not in pos_batch: 172 | pos_batch.append(pos_i_id) 173 | return pos_batch 174 | 175 | def sample_neg_items_for_u(u, num): 176 | neg_items = [] 177 | while True: 178 | if len(neg_items) == num: break 179 | neg_id = np.random.randint(low=0, high=self.n_items, size=1)[0] 180 | if neg_id not in self.train_items[u] and neg_id not in neg_items: 181 | neg_items.append(neg_id) 182 | return neg_items 183 | 184 | def sample_neg_items_for_u_from_pools(u, num): 185 | neg_items = list(set(self.neg_pools[u]) - set(self.train_items[u])) 186 | return rd.sample(neg_items, num) 187 | 188 | pos_items, neg_items = [], [] 189 | for u in users: 190 | pos_items += sample_pos_items_for_u(u, 1) 191 | neg_items += sample_neg_items_for_u(u, 1) 192 | # neg_items += sample_neg_items_for_u(u, 3) 193 | return users, pos_items, neg_items 194 | 195 | 196 | def print_statistics(self): 197 | print('n_users=%d, n_items=%d' % (self.n_users, self.n_items)) 198 | print('n_interactions=%d' % (self.n_train + self.n_test)) 199 | print('n_train=%d, n_test=%d, sparsity=%.5f' % (self.n_train, self.n_test, (self.n_train + self.n_test)/(self.n_users * self.n_items))) 200 | 201 | -------------------------------------------------------------------------------- /codes/utility/logging.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | 4 | class Logger(): 5 | def __init__(self, filename, is_debug, path='/home/weiw/Code/MM/KDMM/logs/'): 6 | self.filename = filename 7 | self.path = path 8 | self.log_ = not is_debug 9 | def logging(self, s): 10 | s = str(s) 11 | print(datetime.now().strftime('%Y-%m-%d %H:%M: '), s) 12 | if self.log_: 13 | with open(os.path.join(os.path.join(self.path, self.filename)), 'a+') as f_log: 14 | f_log.write(str(datetime.now().strftime('%Y-%m-%d %H:%M: ')) + s + '\n') 15 | -------------------------------------------------------------------------------- /codes/utility/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import roc_auc_score 3 | 4 | def recall(rank, ground_truth, N): 5 | return len(set(rank[:N]) & set(ground_truth)) / float(len(set(ground_truth))) 6 | 7 | 8 | def precision_at_k(r, k): 9 | """Score is precision @ k 10 | 
--------------------------------------------------------------------------------
/codes/utility/metrics.py:
--------------------------------------------------------------------------------
import numpy as np
from sklearn.metrics import roc_auc_score

def recall(rank, ground_truth, N):
    return len(set(rank[:N]) & set(ground_truth)) / float(len(set(ground_truth)))


def precision_at_k(r, k):
    """Score is precision @ k
    Relevance is binary (nonzero is relevant).
    Returns:
        Precision @ k
    Raises:
        AssertionError: k must be >= 1
    """
    assert k >= 1
    r = np.asarray(r)[:k]
    return np.mean(r)


def average_precision(r, cut):
    """Score is average precision (area under PR curve)
    Relevance is binary (nonzero is relevant).
    Returns:
        Average precision
    """
    r = np.asarray(r)
    out = [precision_at_k(r, k + 1) for k in range(cut) if r[k]]
    if not out:
        return 0.
    return np.sum(out) / float(min(cut, np.sum(r)))


def mean_average_precision(rs):
    """Score is mean average precision
    Relevance is binary (nonzero is relevant).
    Returns:
        Mean average precision
    """
    # average_precision requires a cut-off; evaluate each ranking at its full length.
    return np.mean([average_precision(r, len(r)) for r in rs])


def dcg_at_k(r, k, method=1):
    """Score is discounted cumulative gain (dcg)
    Relevance is positive real values. Can use binary
    as the previous methods.
    Returns:
        Discounted cumulative gain
    """
    # np.asfarray was removed in NumPy 2.0; asarray with dtype=float is equivalent.
    r = np.asarray(r, dtype=float)[:k]
    if r.size:
        if method == 0:
            return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))
        elif method == 1:
            return np.sum(r / np.log2(np.arange(2, r.size + 2)))
        else:
            raise ValueError('method must be 0 or 1.')
    return 0.


def ndcg_at_k(r, k, method=1):
    """Score is normalized discounted cumulative gain (ndcg)
    Relevance is positive real values. Can use binary
    as the previous methods.
    Returns:
        Normalized discounted cumulative gain
    """
    dcg_max = dcg_at_k(sorted(r, reverse=True), k, method)
    if not dcg_max:
        return 0.
    return dcg_at_k(r, k, method) / dcg_max


def recall_at_k(r, k, all_pos_num):
    r = np.asarray(r, dtype=float)[:k]
    if all_pos_num == 0:
        return 0.
    return np.sum(r) / all_pos_num


def hit_at_k(r, k):
    # 1 if any relevant item appears in the top-k, else 0.
    r = np.array(r)[:k]
    return 1. if np.sum(r) > 0 else 0.


def F1(pre, rec):
    if pre + rec > 0:
        return (2.0 * pre * rec) / (pre + rec)
    return 0.


def auc(ground_truth, prediction):
    try:
        res = roc_auc_score(y_true=ground_truth, y_score=prediction)
    except Exception:
        # roc_auc_score raises when ground_truth contains only one class.
        res = 0.
    return res
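A quick sanity check of the ranking metrics on a toy relevance vector (illustrative only; not part of the repository, and assuming codes/utility is on the import path). Here r marks which of the top-5 recommended items appear in a user's test set:

import numpy as np
from utility.metrics import precision_at_k, recall_at_k, ndcg_at_k, hit_at_k

r = [1, 0, 1, 0, 0]            # hits at ranks 1 and 3
print(precision_at_k(r, 5))    # 0.4   (2 relevant among the top-5)
print(recall_at_k(r, 5, 4))    # 0.5   (2 of the user's 4 test items retrieved)
print(ndcg_at_k(r, 5))         # ~0.92 (DCG = 1 + 1/log2(4) = 1.5; ideal DCG = 1 + 1/log2(3) ~ 1.63)
print(hit_at_k(r, 5))          # 1.0
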
--------------------------------------------------------------------------------
/codes/utility/norm.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from scipy.sparse import csr_matrix

def build_sim(context):
    # Cosine similarity: L2-normalize each row, then take pairwise inner products.
    context_norm = context.div(torch.norm(context, p=2, dim=-1, keepdim=True))
    sim = torch.mm(context_norm, context_norm.transpose(1, 0))  # inputs are dense, so plain mm suffices
    return sim


# An earlier (commented-out) variant of build_knn_normalized_graph returned a
# normalized torch sparse tensor via get_sparse_laplacian; the current sparse
# branch returns a binary scipy CSR item-item graph instead.
def build_knn_normalized_graph(adj, topk, is_sparse, norm_type):
    # Keep each item's top-k most similar neighbours, e.g. [7050, 10] for 7050 items.
    knn_val, knn_ind = torch.topk(adj, topk, dim=-1)
    n_item = knn_val.shape[0]
    n_data = knn_val.shape[0] * knn_val.shape[1]
    data = np.ones(n_data)
    if is_sparse:
        # Flatten the (row, col) index pairs of the kNN edges.
        tuple_list = [[row, int(col)] for row in range(len(knn_ind)) for col in knn_ind[row]]
        row = [i[0] for i in tuple_list]
        col = [i[1] for i in tuple_list]
        # Edge weights are binarized here: the similarity values (knn_val) are
        # discarded and norm_type is not applied in this branch.
        ii_graph = csr_matrix((data, (row, col)), shape=(n_item, n_item))
        return ii_graph
    else:
        weighted_adjacency_matrix = (torch.zeros_like(adj)).scatter_(-1, knn_ind, knn_val)
        return get_dense_laplacian(weighted_adjacency_matrix, normalization=norm_type)


def get_sparse_laplacian(edge_index, edge_weight, num_nodes, normalization='none'):
    # edge_index: [2, E], edge_weight: [E]; requires the optional torch_scatter package.
    from torch_scatter import scatter_add
    row, col = edge_index[0], edge_index[1]
    deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)  # weighted degree per node

    if normalization == 'sym':
        # Symmetric normalization: D^{-1/2} A D^{-1/2}
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
        edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
    elif normalization == 'rw':
        # Random-walk normalization: D^{-1} A
        deg_inv = 1.0 / deg
        deg_inv.masked_fill_(deg_inv == float('inf'), 0)
        edge_weight = deg_inv[row] * edge_weight
    return edge_index, edge_weight


def get_dense_laplacian(adj, normalization='none'):
    if normalization == 'sym':
        rowsum = torch.sum(adj, -1)
        d_inv_sqrt = torch.pow(rowsum, -0.5)
        d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
        d_mat_inv_sqrt = torch.diagflat(d_inv_sqrt)
        L_norm = torch.mm(torch.mm(d_mat_inv_sqrt, adj), d_mat_inv_sqrt)
    elif normalization == 'rw':
        rowsum = torch.sum(adj, -1)
        d_inv = torch.pow(rowsum, -1)
        d_inv[torch.isinf(d_inv)] = 0.
        d_mat_inv = torch.diagflat(d_inv)
        L_norm = torch.mm(d_mat_inv, adj)
    elif normalization == 'none':
        L_norm = adj
    else:
        # Previously an unknown value left L_norm unbound; fail loudly instead.
        raise ValueError("normalization must be 'sym', 'rw' or 'none'.")
    return L_norm
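A minimal sketch (illustrative; the feature matrix and sizes are made up, and codes/utility is assumed to be on the import path) of how these helpers chain together to turn item features into a sparse item-item kNN graph:

import torch
from utility.norm import build_sim, build_knn_normalized_graph

item_feat = torch.randn(100, 64)     # e.g. 100 items with 64-d modality features
sim = build_sim(item_feat)           # [100, 100] cosine-similarity matrix
ii_graph = build_knn_normalized_graph(sim, topk=10, is_sparse=True, norm_type='sym')
print(ii_graph.shape, ii_graph.nnz)  # scipy CSR with 100 * 10 binary edges
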
--------------------------------------------------------------------------------
/decouple.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/PromptMM/70da1002a35d6f2c7712c16cd0b2ca24c8813008/decouple.png
--------------------------------------------------------------------------------