├── LATTICE
│   ├── LICENSE
│   ├── README.md
│   └── codes
│       ├── Models.py
│       ├── __pycache__
│       │   └── Models.cpython-38.pyc
│       ├── main.py
│       └── utility
│           ├── __pycache__
│           │   ├── batch_test.cpython-38.pyc
│           │   ├── load_data.cpython-38.pyc
│           │   ├── metrics.cpython-38.pyc
│           │   └── parser.cpython-38.pyc
│           ├── batch_test.py
│           ├── load_data.py
│           ├── metrics.py
│           └── parser.py
├── PromptMM.png
├── README.md
├── codes
│   ├── Models.py
│   ├── Models_empower.py
│   ├── Models_mmlight.py
│   ├── __pycache__
│   │   ├── MMD.cpython-38.pyc
│   │   ├── Models.cpython-38.pyc
│   │   ├── Models.cpython-39.pyc
│   │   ├── Models2.cpython-38.pyc
│   │   ├── Models2_0938.cpython-38.pyc
│   │   ├── Models2_0938_modality.cpython-38.pyc
│   │   ├── Models2_0938_sub_gene.cpython-38.pyc
│   │   ├── Models2_MF_VBPR_NGCF_LightGCN.cpython-38.pyc
│   │   ├── Models2_MF_VBPR_NGCF_LightGCN_MMGCN.cpython-38.pyc
│   │   ├── Models2_MF_VBPR_NGCF_LightGCN_MMGCN_HAFR.cpython-38.pyc
│   │   ├── Models2_MF_VBPR_NGCF_LightGCN_MMGCN_HAFR_CLCRec.cpython-38.pyc
│   │   ├── Models2_MF_VBPR_NGCF_LightGCN_MMGCN_HAFR_CLCRec3.cpython-38.pyc
│   │   ├── Models2_MF_VBPR_NGCF_LightGCN_MMGCN_HAFR_CLCRec_SLMRec.cpython-38.pyc
│   │   ├── Models2_sub_gene.cpython-38.pyc
│   │   ├── Models2_sub_gene_0485.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_0826.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_333333333333.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_3_5.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_ablation.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_ablation_0962_beforeafterTo0962.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_ablation_DtwoMLP.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_ablation_beforeafter.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_ablation_copy.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_ablation_delete.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_based_noFeatTrans_ablation.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_dropGNN.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_redoADCL.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_HL_train.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_IRGANdeal.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_MUIT.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_binaryLoss_modelPara.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_oversmoothing.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_woFEAT.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_woIIGRAPH.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_model_feature.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_samepos.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_G_first.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_G_first_auto.cpython-38.pyc
│   │   ├── Models2_sub_gene_co_weight_AD_0938_fake_click_G_first_auto_geneBPR.cpython-38.pyc
│   │   ├── Models3.cpython-38.pyc
│   │   ├── Models3.cpython-39.pyc
│   │   ├── Models3_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_5_28.cpython-38.pyc
│   │   ├── Models3_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_5_28_distribution.cpython-38.pyc
│   │   ├── Models3_sub_gene_co_weight_AD_0938_fake_click_D_first_auto_modality_CL_4_11_0945_0929_mem_5_28_distribution_noFeatTrans.cpython-38.pyc
│   │   ├── Models_AD.cpython-38.pyc
│   │   ├── Models_empower.cpython-39.pyc
│   │   ├── Models_mmlight.cpython-39.pyc
│   │   └── generator_discriminator.cpython-38.pyc
│   ├── main.py
│   ├── main_empower.py
│   ├── main_mmlight.py
│   └── utility
│       ├── __pycache__
│       │   ├── batch_test.cpython-38.pyc
│       │   ├── batch_test.cpython-39.pyc
│       │   ├── load_data.cpython-38.pyc
│       │   ├── load_data.cpython-39.pyc
│       │   ├── logging.cpython-38.pyc
│       │   ├── logging.cpython-39.pyc
│       │   ├── metrics.cpython-38.pyc
│       │   ├── metrics.cpython-39.pyc
│       │   ├── norm.cpython-38.pyc
│       │   ├── norm.cpython-39.pyc
│       │   ├── parser.cpython-38.pyc
│       │   └── parser.cpython-39.pyc
│       ├── batch_test.py
│       ├── load_data.py
│       ├── logging.py
│       ├── metrics.py
│       ├── norm.py
│       └── parser.py
└── decouple.png
/LATTICE/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Big Data and Multi-modal Computing Group, CRIPAC
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/LATTICE/README.md:
--------------------------------------------------------------------------------
1 | # LATTICE
2 |
3 | PyTorch implementation for ACM Multimedia 2021 paper: [Mining Latent Structures for Multimedia Recommendation](https://dl.acm.org/doi/10.1145/3474085.3475259)
4 |
5 |
6 |
7 | ## Dependencies
8 |
9 | - Python 3.6
10 | - torch==1.5.0
11 | - scikit-learn==0.24.2
12 |
13 |
14 |
15 | ## Dataset Preparation
16 |
17 | - Download **5-core reviews data**, **meta data**, and **image features** from [Amazon product dataset](http://jmcauley.ucsd.edu/data/amazon/links.html). Put data into the directory `data/meta-data/`.
18 |
19 | - Install [sentence-transformers](https://www.sbert.net/docs/installation.html) and download [pretrained models](https://www.sbert.net/docs/pretrained_models.html) to extract textual features. Unzip pretrained model into the directory `sentence-transformers/`:
20 |
21 | ```
22 | ├─ data/
23 |    ├── sports/
24 |       ├── meta-data/
25 |          ├── image_features_Sports_and_Outdoors.b
26 |          ├── meta-Sports_and_Outdoors.json.gz
27 |          ├── reviews_Sports_and_Outdoors_5.json.gz
28 | ├─ sentence-transformers/
29 |    ├── stsb-roberta-large
30 | ```
31 |
32 | - Run `python build_data.py` to preprocess data.
33 |
34 | - Run `python cold_start.py` to build cold-start data.
35 |
36 | - We provide processed data [Baidu Yun](https://pan.baidu.com/s/1SWe-XE23Nn0i4xSOXV_JyQ) (access code: m37q), [Google Drive](https://drive.google.com/drive/folders/1sFg9W2wCexWahjqtN6MVc4f4dMj5hyFp?usp=sharing).
37 |
38 | ## Usage
39 |
40 | Start training and inference as:
41 |
42 | ```
43 | cd codes
44 | python main.py --dataset {DATASET}
45 | ```
46 |
47 | For cold-start settings:
48 | ```
49 | python main.py --dataset {DATASET} --core 0 --verbose 1 --lr 1e-5
50 | ```
51 |
52 |
53 |
54 | ## Citation
55 |
56 | If you use our code in your research, please cite:
57 |
58 | ```
59 | @inproceedings{LATTICE21,
60 | title = {Mining Latent Structures for Multimedia Recommendation},
61 | author = {Zhang, Jinghao and
62 | Zhu, Yanqiao and
63 | Liu, Qiang and
64 | Wu, Shu and
65 | Wang, Shuhui and
66 | Wang, Liang},
67 | booktitle = {Proceedings of the 29th ACM International Conference on Multimedia},
68 | pages = {3872--3880},
69 | year = {2021}
70 | }
71 | ```
72 |
73 | ## Acknowledgement
74 |
75 | The structure of this code is largely based on [LightGCN](https://github.com/gusye1234/LightGCN-PyTorch). Thanks for their work.
76 |
77 |
--------------------------------------------------------------------------------
/LATTICE/codes/Models.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from time import time
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.sparse as sparse
8 | import torch.nn.functional as F
9 |
10 | from utility.parser import parse_args
11 | args = parse_args()
12 |
13 | def build_knn_neighbourhood(adj, topk):
14 | knn_val, knn_ind = torch.topk(adj, topk, dim=-1)
15 | weighted_adjacency_matrix = (torch.zeros_like(adj)).scatter_(-1, knn_ind, knn_val)
16 | return weighted_adjacency_matrix
17 | def compute_normalized_laplacian(adj):
18 | rowsum = torch.sum(adj, -1)
19 | d_inv_sqrt = torch.pow(rowsum, -0.5)
20 | d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
21 | d_mat_inv_sqrt = torch.diagflat(d_inv_sqrt)
22 | L_norm = torch.mm(torch.mm(d_mat_inv_sqrt, adj), d_mat_inv_sqrt)
23 | return L_norm
24 | def build_sim(context):
25 | context_norm = context.div(torch.norm(context, p=2, dim=-1, keepdim=True))
26 | sim = torch.mm(context_norm, context_norm.transpose(1, 0))
27 | return sim
28 |
29 | class LATTICE(nn.Module):
30 | def __init__(self, n_users, n_items, embedding_dim, weight_size, dropout_list, image_feats, text_feats):
31 | super().__init__()
32 | self.n_users = n_users
33 | self.n_items = n_items
34 | self.embedding_dim = embedding_dim
35 | self.weight_size = weight_size
36 | self.n_ui_layers = len(self.weight_size)
37 | self.weight_size = [self.embedding_dim] + self.weight_size
38 | self.user_embedding = nn.Embedding(n_users, self.embedding_dim)
39 | self.item_id_embedding = nn.Embedding(n_items, self.embedding_dim)
40 | nn.init.xavier_uniform_(self.user_embedding.weight)
41 | nn.init.xavier_uniform_(self.item_id_embedding.weight)
42 |
43 | if args.cf_model == 'ngcf':
44 | self.GC_Linear_list = nn.ModuleList()
45 | self.Bi_Linear_list = nn.ModuleList()
46 | self.dropout_list = nn.ModuleList()
47 | for i in range(self.n_ui_layers):
48 | self.GC_Linear_list.append(nn.Linear(self.weight_size[i], self.weight_size[i+1]))
49 | self.Bi_Linear_list.append(nn.Linear(self.weight_size[i], self.weight_size[i+1]))
50 | self.dropout_list.append(nn.Dropout(dropout_list[i]))
51 |
52 |
53 | self.image_embedding = nn.Embedding.from_pretrained(torch.Tensor(image_feats), freeze=False)
54 | self.text_embedding = nn.Embedding.from_pretrained(torch.Tensor(text_feats), freeze=False)
55 |
56 |
57 | if os.path.exists(args.data_path + 'image_adj_%d.pt'%(args.topk)):
58 | image_adj = torch.load(args.data_path + 'image_adj_%d.pt'%(args.topk))
59 | else:
60 | image_adj = build_sim(self.image_embedding.weight.detach())
61 | image_adj = build_knn_neighbourhood(image_adj, topk=args.topk)
62 | image_adj = compute_normalized_laplacian(image_adj)
63 | torch.save(image_adj, args.data_path + 'image_adj_%d.pt'%(args.topk))
64 |
65 | if os.path.exists(args.data_path + 'text_adj_%d.pt'%(args.topk)):
66 | text_adj = torch.load(args.data_path + 'text_adj_%d.pt'%(args.topk))
67 | else:
68 | text_adj = build_sim(self.text_embedding.weight.detach())
69 | text_adj = build_knn_neighbourhood(text_adj, topk=args.topk)
70 | text_adj = compute_normalized_laplacian(text_adj)
71 | torch.save(text_adj, args.data_path + 'text_adj_%d.pt'%(args.topk))
72 |
73 | self.text_original_adj = text_adj.cuda()
74 | self.image_original_adj = image_adj.cuda()
75 |
76 | self.image_trs = nn.Linear(image_feats.shape[1], args.feat_embed_dim)
77 | self.text_trs = nn.Linear(text_feats.shape[1], args.feat_embed_dim)
78 |
79 |
80 | self.modal_weight = nn.Parameter(torch.Tensor([0.5, 0.5]))
81 | self.softmax = nn.Softmax(dim=0)
82 |
83 | def forward(self, adj, build_item_graph=False):
84 | image_feats = self.image_trs(self.image_embedding.weight)
85 | text_feats = self.text_trs(self.text_embedding.weight)
86 | if build_item_graph:
87 | weight = self.softmax(self.modal_weight)
88 | self.image_adj = build_sim(image_feats)
89 | self.image_adj = build_knn_neighbourhood(self.image_adj, topk=args.topk)
90 |
91 | self.text_adj = build_sim(text_feats)
92 | self.text_adj = build_knn_neighbourhood(self.text_adj, topk=args.topk)
93 |
94 |
95 | learned_adj = weight[0] * self.image_adj + weight[1] * self.text_adj
96 | learned_adj = compute_normalized_laplacian(learned_adj)
97 | original_adj = weight[0] * self.image_original_adj + weight[1] * self.text_original_adj
98 | self.item_adj = (1 - args.lambda_coeff) * learned_adj + args.lambda_coeff * original_adj
99 | else:
100 | self.item_adj = self.item_adj.detach()
101 |
102 | h = self.item_id_embedding.weight
103 | for i in range(args.n_layers):
104 | h = torch.mm(self.item_adj, h)
105 |
106 | if args.cf_model == 'ngcf':
107 | ego_embeddings = torch.cat((self.user_embedding.weight, self.item_id_embedding.weight), dim=0)
108 | all_embeddings = [ego_embeddings]
109 | for i in range(self.n_ui_layers):
110 | side_embeddings = torch.sparse.mm(adj, ego_embeddings)
111 | sum_embeddings = F.leaky_relu(self.GC_Linear_list[i](side_embeddings))
112 | bi_embeddings = torch.mul(ego_embeddings, side_embeddings)
113 | bi_embeddings = F.leaky_relu(self.Bi_Linear_list[i](bi_embeddings))
114 | ego_embeddings = sum_embeddings + bi_embeddings
115 | ego_embeddings = self.dropout_list[i](ego_embeddings)
116 |
117 | norm_embeddings = F.normalize(ego_embeddings, p=2, dim=1)
118 | all_embeddings += [norm_embeddings]
119 |
120 | all_embeddings = torch.stack(all_embeddings, dim=1)
121 | all_embeddings = all_embeddings.mean(dim=1, keepdim=False)
122 | u_g_embeddings, i_g_embeddings = torch.split(all_embeddings, [self.n_users, self.n_items], dim=0)
123 | i_g_embeddings = i_g_embeddings + F.normalize(h, p=2, dim=1)
124 | return u_g_embeddings, i_g_embeddings
125 | elif args.cf_model == 'lightgcn':
126 | ego_embeddings = torch.cat((self.user_embedding.weight, self.item_id_embedding.weight), dim=0)
127 | all_embeddings = [ego_embeddings]
128 | for i in range(self.n_ui_layers):
129 | side_embeddings = torch.sparse.mm(adj, ego_embeddings)
130 | ego_embeddings = side_embeddings
131 | all_embeddings += [ego_embeddings]
132 | all_embeddings = torch.stack(all_embeddings, dim=1)
133 | all_embeddings = all_embeddings.mean(dim=1, keepdim=False)
134 | u_g_embeddings, i_g_embeddings = torch.split(all_embeddings, [self.n_users, self.n_items], dim=0)
135 | i_g_embeddings = i_g_embeddings + F.normalize(h, p=2, dim=1)
136 | return u_g_embeddings, i_g_embeddings
137 | elif args.cf_model == 'mf':
138 | return self.user_embedding.weight, self.item_id_embedding.weight + F.normalize(h, p=2, dim=1)
139 |
140 |
--------------------------------------------------------------------------------
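A minimal, self-contained sketch (not part of the repository; toy sizes, with the three helpers restated from `Models.py` above) of how the latent item graph is assembled: cosine similarity, then kNN sparsification, then symmetric normalization:

```python
import torch

def build_sim(context):
    # row-normalize, then cosine similarity between all item pairs
    context_norm = context.div(torch.norm(context, p=2, dim=-1, keepdim=True))
    return torch.mm(context_norm, context_norm.t())

def build_knn_neighbourhood(adj, topk):
    # keep only the topk largest similarities per row, zero out the rest
    knn_val, knn_ind = torch.topk(adj, topk, dim=-1)
    return torch.zeros_like(adj).scatter_(-1, knn_ind, knn_val)

def compute_normalized_laplacian(adj):
    # symmetric normalization D^-1/2 A D^-1/2
    d_inv_sqrt = adj.sum(-1).pow(-0.5)
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
    d = torch.diagflat(d_inv_sqrt)
    return d @ adj @ d

feats = torch.rand(6, 8)  # 6 toy items with 8-dim modality features
adj = compute_normalized_laplacian(build_knn_neighbourhood(build_sim(feats), topk=3))
print(adj.shape, (adj > 0).sum(-1))  # each row keeps at most topk neighbours
```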
/LATTICE/codes/__pycache__/Models.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/PromptMM/70da1002a35d6f2c7712c16cd0b2ca24c8813008/LATTICE/codes/__pycache__/Models.cpython-38.pyc
--------------------------------------------------------------------------------
/LATTICE/codes/main.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import math
3 | import os
4 | import random
5 | import sys
6 | from time import time
7 | from tqdm import tqdm
8 |
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import torch.optim as optim
14 | import torch.sparse as sparse
15 |
16 | from utility.parser import parse_args
17 | from Models import LATTICE
18 | from utility.batch_test import *
19 |
20 | args = parse_args()
21 |
22 |
23 | class Trainer(object):
24 | def __init__(self, data_config):
25 | # argument settings
26 | self.n_users = data_config['n_users']
27 | self.n_items = data_config['n_items']
28 |
29 | self.model_name = args.model_name
30 | self.mess_dropout = eval(args.mess_dropout)
31 | self.lr = args.lr
32 | self.emb_dim = args.embed_size
33 | self.batch_size = args.batch_size
34 | self.weight_size = eval(args.weight_size)
35 | self.n_layers = len(self.weight_size)
36 | self.regs = eval(args.regs)
37 | self.decay = self.regs[0]
38 |
39 | self.norm_adj = data_config['norm_adj']
40 | self.norm_adj = self.sparse_mx_to_torch_sparse_tensor(self.norm_adj).float().cuda()
41 |
42 | image_feats = np.load(args.data_path + '{}/image_feat.npy'.format(args.dataset))
43 | text_feats = np.load(args.data_path + '{}/text_feat.npy'.format(args.dataset))
44 |
45 | self.model = LATTICE(self.n_users, self.n_items, self.emb_dim, self.weight_size, self.mess_dropout, image_feats, text_feats)
46 | self.model = self.model.cuda()
47 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
48 | self.lr_scheduler = self.set_lr_scheduler()
49 |
50 | def set_lr_scheduler(self):
51 | fac = lambda epoch: 0.96 ** (epoch / 50)
52 | scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=fac)
53 | return scheduler
54 |
55 | def test(self, users_to_test, is_val):
56 | self.model.eval()
57 | with torch.no_grad():
58 | ua_embeddings, ia_embeddings = self.model(self.norm_adj, build_item_graph=True)
59 | result = test_torch(ua_embeddings, ia_embeddings, users_to_test, is_val)
60 | return result
61 |
62 | def train(self):
63 | training_time_list = []
64 | loss_loger, pre_loger, rec_loger, ndcg_loger, hit_loger = [], [], [], [], []
65 | stopping_step = 0
66 | should_stop = False
67 | cur_best_pre_0 = 0.
68 |
69 | n_batch = data_generator.n_train // args.batch_size + 1
70 | best_recall = 0
71 | for epoch in (range(args.epoch)):
72 | t1 = time()
73 | loss, mf_loss, emb_loss, reg_loss = 0., 0., 0., 0.
74 | n_batch = data_generator.n_train // args.batch_size + 1
75 | f_time, b_time, loss_time, opt_time, clip_time, emb_time = 0., 0., 0., 0., 0., 0.
76 | sample_time = 0.
77 | build_item_graph = True
78 | for idx in (range(n_batch)):
79 | self.model.train()
80 | self.optimizer.zero_grad()
81 | sample_t1 = time()
82 | users, pos_items, neg_items = data_generator.sample()
83 | sample_time += time() - sample_t1
84 | ua_embeddings, ia_embeddings = self.model(self.norm_adj, build_item_graph=build_item_graph)
85 | build_item_graph = False
86 | u_g_embeddings = ua_embeddings[users]
87 | pos_i_g_embeddings = ia_embeddings[pos_items]
88 | neg_i_g_embeddings = ia_embeddings[neg_items]
89 |
90 |
91 | batch_mf_loss, batch_emb_loss, batch_reg_loss = self.bpr_loss(u_g_embeddings, pos_i_g_embeddings,
92 | neg_i_g_embeddings)
93 |
94 | batch_loss = batch_mf_loss + batch_emb_loss + batch_reg_loss
95 |
96 | batch_loss.backward(retain_graph=True)
97 | self.optimizer.step()
98 |
99 | loss += float(batch_loss)
100 | mf_loss += float(batch_mf_loss)
101 | emb_loss += float(batch_emb_loss)
102 | reg_loss += float(batch_reg_loss)
103 |
104 |
105 | self.lr_scheduler.step()
106 |
107 | del ua_embeddings, ia_embeddings, u_g_embeddings, neg_i_g_embeddings, pos_i_g_embeddings
108 |
109 | if math.isnan(loss):
110 | print('ERROR: loss is nan.')
111 | sys.exit()
112 |
113 | perf_str = 'Epoch %d [%.1fs]: train==[%.5f=%.5f + %.5f]' % (
114 | epoch, time() - t1, loss, mf_loss, emb_loss)
115 | training_time_list.append(time() - t1)
116 | print(perf_str)
117 |
118 | if epoch % args.verbose != 0:
119 | continue
120 |
121 |
122 | t2 = time()
123 | users_to_test = list(data_generator.test_set.keys())
124 | users_to_val = list(data_generator.val_set.keys())
125 | ret = self.test(users_to_val, is_val=True)
126 | training_time_list.append(t2 - t1)
127 |
128 | t3 = time()
129 |
130 | loss_loger.append(loss)
131 | rec_loger.append(ret['recall'])
132 | pre_loger.append(ret['precision'])
133 | ndcg_loger.append(ret['ndcg'])
134 | hit_loger.append(ret['hit_ratio'])
135 | if args.verbose > 0:
136 | perf_str = 'Epoch %d [%.1fs + %.1fs]: val==[%.5f=%.5f + %.5f + %.5f], recall=[%.5f, %.5f], ' \
137 | 'precision=[%.5f, %.5f], hit=[%.5f, %.5f], ndcg=[%.5f, %.5f]' % \
138 | (epoch, t2 - t1, t3 - t2, loss, mf_loss, emb_loss, reg_loss, ret['recall'][0],
139 | ret['recall'][-1],
140 | ret['precision'][0], ret['precision'][-1], ret['hit_ratio'][0], ret['hit_ratio'][-1],
141 | ret['ndcg'][0], ret['ndcg'][-1])
142 | print(perf_str)
143 |
144 | if ret['recall'][1] > best_recall:
145 | best_recall = ret['recall'][1]
146 | test_ret = self.test(users_to_test, is_val=False)
147 | perf_str = 'Epoch %d [%.1fs + %.1fs]: test==[%.5f=%.5f + %.5f + %.5f], recall=[%.5f, %.5f], ' \
148 | 'precision=[%.5f, %.5f], hit=[%.5f, %.5f], ndcg=[%.5f, %.5f]' % \
149 | (epoch, t2 - t1, t3 - t2, loss, mf_loss, emb_loss, reg_loss, test_ret['recall'][0],
150 | test_ret['recall'][-1],
151 | test_ret['precision'][0], test_ret['precision'][-1], test_ret['hit_ratio'][0], test_ret['hit_ratio'][-1],
152 | test_ret['ndcg'][0], test_ret['ndcg'][-1])
153 | print(perf_str)
154 | stopping_step = 0
155 | elif stopping_step < args.early_stopping_patience:
156 | stopping_step += 1
157 | print('#####Early stopping steps: %d #####' % stopping_step)
158 | else:
159 | print('#####Early stop! #####')
160 | break
161 |
162 | print(test_ret)
163 |
164 | def bpr_loss(self, users, pos_items, neg_items):
165 | pos_scores = torch.sum(torch.mul(users, pos_items), dim=1)
166 | neg_scores = torch.sum(torch.mul(users, neg_items), dim=1)
167 |
168 | regularizer = 1./2*(users**2).sum() + 1./2*(pos_items**2).sum() + 1./2*(neg_items**2).sum()
169 | regularizer = regularizer / self.batch_size
170 |
171 | maxi = F.logsigmoid(pos_scores - neg_scores)
172 | mf_loss = -torch.mean(maxi)
173 |
174 | emb_loss = self.decay * regularizer
175 | reg_loss = 0.0
176 | return mf_loss, emb_loss, reg_loss
177 |
178 | def sparse_mx_to_torch_sparse_tensor(self, sparse_mx):
179 | """Convert a scipy sparse matrix to a torch sparse tensor."""
180 | sparse_mx = sparse_mx.tocoo().astype(np.float32)
181 | indices = torch.from_numpy(
182 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
183 | values = torch.from_numpy(sparse_mx.data)
184 | shape = torch.Size(sparse_mx.shape)
185 | return torch.sparse.FloatTensor(indices, values, shape)
186 |
187 | def set_seed(seed):
188 | np.random.seed(seed)
189 | random.seed(seed)
190 | torch.manual_seed(seed) # cpu
191 | torch.cuda.manual_seed_all(seed) # gpu
192 |
193 | if __name__ == '__main__':
194 | set_seed(args.seed)
195 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
196 |
197 | config = dict()
198 | config['n_users'] = data_generator.n_users
199 | config['n_items'] = data_generator.n_items
200 |
201 | plain_adj, norm_adj, mean_adj = data_generator.get_adj_mat()
202 | config['norm_adj'] = norm_adj
203 |
204 | trainer = Trainer(data_config=config)
205 | trainer.train()
206 |
207 |
--------------------------------------------------------------------------------
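For reference, a self-contained sketch (not repository code; toy tensors, and `decay` is a made-up stand-in for `regs[0]`) of the BPR objective computed by `Trainer.bpr_loss` above, i.e. `-log sigmoid(pos - neg)` plus an L2 penalty on the batch embeddings:

```python
import torch
import torch.nn.functional as F

users = torch.randn(4, 64)  # toy batch of user embeddings
pos = torch.randn(4, 64)    # embeddings of sampled positive items
neg = torch.randn(4, 64)    # embeddings of sampled negative items
decay = 1e-5                # plays the role of regs[0]

pos_scores = (users * pos).sum(dim=1)
neg_scores = (users * neg).sum(dim=1)
mf_loss = -F.logsigmoid(pos_scores - neg_scores).mean()
emb_loss = decay * (users.pow(2).sum() + pos.pow(2).sum() + neg.pow(2).sum()) / (2 * users.shape[0])
print(float(mf_loss), float(emb_loss))
```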
/LATTICE/codes/utility/__pycache__/batch_test.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/PromptMM/70da1002a35d6f2c7712c16cd0b2ca24c8813008/LATTICE/codes/utility/__pycache__/batch_test.cpython-38.pyc
--------------------------------------------------------------------------------
/LATTICE/codes/utility/__pycache__/load_data.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/PromptMM/70da1002a35d6f2c7712c16cd0b2ca24c8813008/LATTICE/codes/utility/__pycache__/load_data.cpython-38.pyc
--------------------------------------------------------------------------------
/LATTICE/codes/utility/__pycache__/metrics.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/PromptMM/70da1002a35d6f2c7712c16cd0b2ca24c8813008/LATTICE/codes/utility/__pycache__/metrics.cpython-38.pyc
--------------------------------------------------------------------------------
/LATTICE/codes/utility/__pycache__/parser.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/PromptMM/70da1002a35d6f2c7712c16cd0b2ca24c8813008/LATTICE/codes/utility/__pycache__/parser.cpython-38.pyc
--------------------------------------------------------------------------------
/LATTICE/codes/utility/batch_test.py:
--------------------------------------------------------------------------------
1 | import utility.metrics as metrics
2 | from utility.parser import parse_args
3 | from utility.load_data import Data
4 | import multiprocessing
5 | import heapq
6 | import torch
7 | import pickle
8 | import numpy as np
9 | from time import time
10 |
11 | cores = multiprocessing.cpu_count() // 5
12 |
13 | args = parse_args()
14 | Ks = eval(args.Ks)
15 |
16 | data_generator = Data(path=args.data_path + args.dataset, batch_size=args.batch_size)
17 | USR_NUM, ITEM_NUM = data_generator.n_users, data_generator.n_items
18 | N_TRAIN, N_TEST = data_generator.n_train, data_generator.n_test
19 | BATCH_SIZE = args.batch_size
20 |
21 | def ranklist_by_heapq(user_pos_test, test_items, rating, Ks):
22 | item_score = {}
23 | for i in test_items:
24 | item_score[i] = rating[i]
25 |
26 | K_max = max(Ks)
27 | K_max_item_score = heapq.nlargest(K_max, item_score, key=item_score.get)
28 |
29 | r = []
30 | for i in K_max_item_score:
31 | if i in user_pos_test:
32 | r.append(1)
33 | else:
34 | r.append(0)
35 | auc = 0.
36 | return r, auc
37 |
38 | def get_auc(item_score, user_pos_test):
39 | item_score = sorted(item_score.items(), key=lambda kv: kv[1])
40 | item_score.reverse()
41 | item_sort = [x[0] for x in item_score]
42 | posterior = [x[1] for x in item_score]
43 |
44 | r = []
45 | for i in item_sort:
46 | if i in user_pos_test:
47 | r.append(1)
48 | else:
49 | r.append(0)
50 | auc = metrics.auc(ground_truth=r, prediction=posterior)
51 | return auc
52 |
53 | def ranklist_by_sorted(user_pos_test, test_items, rating, Ks):
54 | item_score = {}
55 | for i in test_items:
56 | item_score[i] = rating[i]
57 |
58 | K_max = max(Ks)
59 | K_max_item_score = heapq.nlargest(K_max, item_score, key=item_score.get)
60 |
61 | r = []
62 | for i in K_max_item_score:
63 | if i in user_pos_test:
64 | r.append(1)
65 | else:
66 | r.append(0)
67 | auc = get_auc(item_score, user_pos_test)
68 | return r, auc
69 |
70 | def get_performance(user_pos_test, r, auc, Ks):
71 | precision, recall, ndcg, hit_ratio = [], [], [], []
72 |
73 | for K in Ks:
74 | precision.append(metrics.precision_at_k(r, K))
75 | recall.append(metrics.recall_at_k(r, K, len(user_pos_test)))
76 | ndcg.append(metrics.ndcg_at_k(r, K))
77 | hit_ratio.append(metrics.hit_at_k(r, K))
78 |
79 | return {'recall': np.array(recall), 'precision': np.array(precision),
80 | 'ndcg': np.array(ndcg), 'hit_ratio': np.array(hit_ratio), 'auc': auc}
81 |
82 |
83 | def test_one_user(x):
84 | # user u's ratings for user u
85 | is_val = x[-1]
86 | rating = x[0]
87 | #uid
88 | u = x[1]
89 | #user u's items in the training set
90 | try:
91 | training_items = data_generator.train_items[u]
92 | except Exception:
93 | training_items = []
94 | #user u's items in the test set
95 | if is_val:
96 | user_pos_test = data_generator.val_set[u]
97 | else:
98 | user_pos_test = data_generator.test_set[u]
99 |
100 | all_items = set(range(ITEM_NUM))
101 |
102 | test_items = list(all_items - set(training_items))
103 |
104 | if args.test_flag == 'part':
105 | r, auc = ranklist_by_heapq(user_pos_test, test_items, rating, Ks)
106 | else:
107 | r, auc = ranklist_by_sorted(user_pos_test, test_items, rating, Ks)
108 |
109 | return get_performance(user_pos_test, r, auc, Ks)
110 |
111 |
112 | def test_torch(ua_embeddings, ia_embeddings, users_to_test, is_val, drop_flag=False, batch_test_flag=False):
113 | result = {'precision': np.zeros(len(Ks)), 'recall': np.zeros(len(Ks)), 'ndcg': np.zeros(len(Ks)),
114 | 'hit_ratio': np.zeros(len(Ks)), 'auc': 0.}
115 | pool = multiprocessing.Pool(cores)
116 |
117 | u_batch_size = BATCH_SIZE * 2
118 | i_batch_size = BATCH_SIZE
119 |
120 | test_users = users_to_test
121 | n_test_users = len(test_users)
122 | n_user_batchs = n_test_users // u_batch_size + 1
123 | count = 0
124 |
125 | for u_batch_id in range(n_user_batchs):
126 | start = u_batch_id * u_batch_size
127 | end = (u_batch_id + 1) * u_batch_size
128 | user_batch = test_users[start: end]
129 | if batch_test_flag:
130 | n_item_batchs = ITEM_NUM // i_batch_size + 1
131 | rate_batch = np.zeros(shape=(len(user_batch), ITEM_NUM))
132 |
133 | i_count = 0
134 | for i_batch_id in range(n_item_batchs):
135 | i_start = i_batch_id * i_batch_size
136 | i_end = min((i_batch_id + 1) * i_batch_size, ITEM_NUM)
137 |
138 | item_batch = range(i_start, i_end)
139 | u_g_embeddings = ua_embeddings[user_batch]
140 | i_g_embeddings = ia_embeddings[item_batch]
141 | i_rate_batch = torch.matmul(u_g_embeddings, torch.transpose(i_g_embeddings, 0, 1))
142 |
143 | rate_batch[:, i_start: i_end] = i_rate_batch
144 | i_count += i_rate_batch.shape[1]
145 |
146 | assert i_count == ITEM_NUM
147 |
148 | else:
149 | item_batch = range(ITEM_NUM)
150 | u_g_embeddings = ua_embeddings[user_batch]
151 | i_g_embeddings = ia_embeddings[item_batch]
152 | rate_batch = torch.matmul(u_g_embeddings, torch.transpose(i_g_embeddings, 0, 1))
153 |
154 | rate_batch = rate_batch.detach().cpu().numpy()
155 | user_batch_rating_uid = zip(rate_batch, user_batch, [is_val] * len(user_batch))
156 |
157 | batch_result = pool.map(test_one_user, user_batch_rating_uid)
158 | count += len(batch_result)
159 |
160 | for re in batch_result:
161 | result['precision'] += re['precision'] / n_test_users
162 | result['recall'] += re['recall'] / n_test_users
163 | result['ndcg'] += re['ndcg'] / n_test_users
164 | result['hit_ratio'] += re['hit_ratio'] / n_test_users
165 | result['auc'] += re['auc'] / n_test_users
166 |
167 | assert count == n_test_users
168 | pool.close()
169 | return result
170 |
--------------------------------------------------------------------------------
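A toy walkthrough (made-up scores, not repository code) of the evaluation path in `batch_test.py`: rank candidates by predicted score, turn the top-K list into binary hits against the user's test items, then derive the metrics:

```python
import heapq
import numpy as np

rating = {0: 0.9, 1: 0.1, 2: 0.8, 3: 0.4}  # predicted scores for 4 candidate items
user_pos_test = {0, 3}                     # the user's ground-truth test items
K = 3

top_k = heapq.nlargest(K, rating, key=rating.get)    # [0, 2, 3]
r = [1 if i in user_pos_test else 0 for i in top_k]  # [1, 0, 1]
print(top_k, r)
print('recall@3 =', sum(r) / len(user_pos_test))     # 2/2 = 1.0
print('precision@3 =', float(np.mean(r)))            # 2/3 ~= 0.667
```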
/LATTICE/codes/utility/load_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random as rd
3 | import scipy.sparse as sp
4 | from time import time
5 | import json
6 | from utility.parser import parse_args
7 | args = parse_args()
8 |
9 | class Data(object):
10 | def __init__(self, path, batch_size):
11 | # self.path = path + '/%d-core' % args.core
12 | # self.batch_size = batch_size
13 |
14 | # train_file = path + '/%d-core/train.json' % (args.core)
15 | # val_file = path + '/%d-core/val.json' % (args.core)
16 | # test_file = path + '/%d-core/test.json' % (args.core)
17 |
18 | self.path = path #+ '/%d-core' % args.core
19 | self.batch_size = batch_size
20 |
21 | train_file = path + '/train.json'#+ '/%d-core/train.json' % (args.core)
22 | val_file = path + '/val.json' #+ '/%d-core/val.json' % (args.core)
23 | test_file = path + '/test.json' #+ '/%d-core/test.json' % (args.core)
24 |
25 | #get number of users and items
26 | self.n_users, self.n_items = 0, 0
27 | self.n_train, self.n_test, self.n_val = 0, 0, 0
28 | self.neg_pools = {}
29 |
30 | self.exist_users = []
31 |
32 | train = json.load(open(train_file))
33 | test = json.load(open(test_file))
34 | val = json.load(open(val_file))
35 | for uid, items in train.items():
36 | if len(items) == 0:
37 | continue
38 | uid = int(uid)
39 | self.exist_users.append(uid)
40 | self.n_items = max(self.n_items, max(items))
41 | self.n_users = max(self.n_users, uid)
42 | self.n_train += len(items)
43 |
44 | for uid, items in test.items():
45 | uid = int(uid)
46 | try:
47 | self.n_items = max(self.n_items, max(items))
48 | self.n_test += len(items)
49 | except:
50 | continue
51 |
52 | for uid, items in val.items():
53 | uid = int(uid)
54 | try:
55 | self.n_items = max(self.n_items, max(items))
56 | self.n_val += len(items)
57 | except:
58 | continue
59 |
60 | self.n_items += 1
61 | self.n_users += 1
62 |
63 | text_feats = np.load(args.data_path + '{}/text_feat.npy'.format(args.dataset))
64 | self.n_items = text_feats.shape[0]
65 |
66 | self.print_statistics()
67 |
68 | self.R = sp.dok_matrix((self.n_users, self.n_items), dtype=np.float32)
69 | self.R_Item_Interacts = sp.dok_matrix((self.n_items, self.n_items), dtype=np.float32)
70 |
71 | self.train_items, self.test_set, self.val_set = {}, {}, {}
72 | for uid, train_items in train.items():
73 | if len(train_items) == 0:
74 | continue
75 | uid = int(uid)
76 | for idx, i in enumerate(train_items):
77 | self.R[uid, i] = 1.
78 |
79 | self.train_items[uid] = train_items
80 |
81 | for uid, test_items in test.items():
82 | uid = int(uid)
83 | if len(test_items) == 0:
84 | continue
85 | try:
86 | self.test_set[uid] = test_items
87 | except:
88 | continue
89 |
90 | for uid, val_items in val.items():
91 | uid = int(uid)
92 | if len(val_items) == 0:
93 | continue
94 | try:
95 | self.val_set[uid] = val_items
96 | except:
97 | continue
98 |
99 | def get_adj_mat(self):
100 | try:
101 | t1 = time()
102 | adj_mat = sp.load_npz(self.path + '/s_adj_mat.npz')
103 | norm_adj_mat = sp.load_npz(self.path + '/s_norm_adj_mat.npz')
104 | mean_adj_mat = sp.load_npz(self.path + '/s_mean_adj_mat.npz')
105 | print('already load adj matrix', adj_mat.shape, time() - t1)
106 |
107 | except Exception:
108 | adj_mat, norm_adj_mat, mean_adj_mat = self.create_adj_mat()
109 | sp.save_npz(self.path + '/s_adj_mat.npz', adj_mat)
110 | sp.save_npz(self.path + '/s_norm_adj_mat.npz', norm_adj_mat)
111 | sp.save_npz(self.path + '/s_mean_adj_mat.npz', mean_adj_mat)
112 | return adj_mat, norm_adj_mat, mean_adj_mat
113 |
114 | def create_adj_mat(self):
115 | t1 = time()
116 | adj_mat = sp.dok_matrix((self.n_users + self.n_items, self.n_users + self.n_items), dtype=np.float32)
117 | adj_mat = adj_mat.tolil()
118 | R = self.R.tolil()
119 |
120 | adj_mat[:self.n_users, self.n_users:] = R
121 | adj_mat[self.n_users:, :self.n_users] = R.T
122 | adj_mat = adj_mat.todok()
123 | print('already create adjacency matrix', adj_mat.shape, time() - t1)
124 |
125 | t2 = time()
126 |
127 | def normalized_adj_single(adj):
128 | rowsum = np.array(adj.sum(1))
129 |
130 | d_inv = np.power(rowsum, -1).flatten()
131 | d_inv[np.isinf(d_inv)] = 0.
132 | d_mat_inv = sp.diags(d_inv)
133 |
134 | norm_adj = d_mat_inv.dot(adj)
135 | # norm_adj = adj.dot(d_mat_inv)
136 | print('generate single-normalized adjacency matrix.')
137 | return norm_adj.tocoo()
138 |
139 | def get_D_inv(adj):
140 | rowsum = np.array(adj.sum(1))
141 |
142 | d_inv = np.power(rowsum, -1).flatten()
143 | d_inv[np.isinf(d_inv)] = 0.
144 | d_mat_inv = sp.diags(d_inv)
145 | return d_mat_inv
146 |
147 | def check_adj_if_equal(adj):
148 | dense_A = np.array(adj.todense())
149 | degree = np.sum(dense_A, axis=1, keepdims=False)
150 |
151 | temp = np.dot(np.diag(np.power(degree, -1)), dense_A)
152 | print('check normalized adjacency matrix whether equal to this laplacian matrix.')
153 | return temp
154 |
155 | norm_adj_mat = normalized_adj_single(adj_mat + sp.eye(adj_mat.shape[0]))
156 | mean_adj_mat = normalized_adj_single(adj_mat)
157 |
158 | print('already normalize adjacency matrix', time() - t2)
159 | return adj_mat.tocsr(), norm_adj_mat.tocsr(), mean_adj_mat.tocsr()
160 |
161 |
162 | def sample(self):
163 | if self.batch_size <= self.n_users:
164 | users = rd.sample(self.exist_users, self.batch_size)
165 | else:
166 | users = [rd.choice(self.exist_users) for _ in range(self.batch_size)]
167 | # users = self.exist_users[:]
168 |
169 | def sample_pos_items_for_u(u, num):
170 | pos_items = self.train_items[u]
171 | n_pos_items = len(pos_items)
172 | pos_batch = []
173 | while True:
174 | if len(pos_batch) == num: break
175 | pos_id = np.random.randint(low=0, high=n_pos_items, size=1)[0]
176 | pos_i_id = pos_items[pos_id]
177 |
178 | if pos_i_id not in pos_batch:
179 | pos_batch.append(pos_i_id)
180 | return pos_batch
181 |
182 | def sample_neg_items_for_u(u, num):
183 | neg_items = []
184 | while True:
185 | if len(neg_items) == num: break
186 | neg_id = np.random.randint(low=0, high=self.n_items, size=1)[0]
187 | if neg_id not in self.train_items[u] and neg_id not in neg_items:
188 | neg_items.append(neg_id)
189 | return neg_items
190 |
191 | def sample_neg_items_for_u_from_pools(u, num):
192 | neg_items = list(set(self.neg_pools[u]) - set(self.train_items[u]))
193 | return rd.sample(neg_items, num)
194 |
195 | pos_items, neg_items = [], []
196 | for u in users:
197 | pos_items += sample_pos_items_for_u(u, 1)
198 | neg_items += sample_neg_items_for_u(u, 1)
199 | # neg_items += sample_neg_items_for_u(u, 3)
200 | return users, pos_items, neg_items
201 |
202 |
203 |
204 | def print_statistics(self):
205 | print('n_users=%d, n_items=%d' % (self.n_users, self.n_items))
206 | print('n_interactions=%d' % (self.n_train + self.n_test))
207 | print('n_train=%d, n_test=%d, sparsity=%.5f' % (self.n_train, self.n_test, (self.n_train + self.n_test)/(self.n_users * self.n_items)))
208 |
209 |
--------------------------------------------------------------------------------
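A tiny sketch (2 users, 3 items, invented interactions; not repository code) of what `create_adj_mat` builds: `R` and `R^T` placed off-diagonal in a `(n_users + n_items)`-square matrix, self-loops added, then row-normalized with `D^-1`:

```python
import numpy as np
import scipy.sparse as sp

n_users, n_items = 2, 3
R = sp.dok_matrix((n_users, n_items), dtype=np.float32)
R[0, 0] = R[0, 2] = R[1, 1] = 1.  # toy user-item interactions
R = R.tolil()

adj = sp.lil_matrix((n_users + n_items, n_users + n_items), dtype=np.float32)
adj[:n_users, n_users:] = R    # user -> item block
adj[n_users:, :n_users] = R.T  # item -> user block

adj_sl = adj + sp.eye(adj.shape[0])  # add self-loops (as done for norm_adj)
d_inv = np.power(np.array(adj_sl.sum(1)).flatten(), -1.0)
d_inv[np.isinf(d_inv)] = 0.
norm_adj = sp.diags(d_inv) @ adj_sl  # D^-1 (A + I)
print(norm_adj.toarray().round(2))   # every row sums to 1
```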
/LATTICE/codes/utility/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics import roc_auc_score
3 |
4 | def recall(rank, ground_truth, N):
5 | return len(set(rank[:N]) & set(ground_truth)) / float(len(set(ground_truth)))
6 |
7 |
8 | def precision_at_k(r, k):
9 | """Score is precision @ k
10 | Relevance is binary (nonzero is relevant).
11 | Returns:
12 | Precision @ k
13 | Raises:
14 | ValueError: len(r) must be >= k
15 | """
16 | assert k >= 1
17 | r = np.asarray(r)[:k]
18 | return np.mean(r)
19 |
20 |
21 | def average_precision(r,cut):
22 | """Score is average precision (area under PR curve)
23 | Relevance is binary (nonzero is relevant).
24 | Returns:
25 | Average precision
26 | """
27 | r = np.asarray(r)
28 | out = [precision_at_k(r, k + 1) for k in range(cut) if r[k]]
29 | if not out:
30 | return 0.
31 | return np.sum(out)/float(min(cut, np.sum(r)))
32 |
33 |
34 | def mean_average_precision(rs):
35 | """Score is mean average precision
36 | Relevance is binary (nonzero is relevant).
37 | Returns:
38 | Mean average precision
39 | """
40 | return np.mean([average_precision(r, len(r)) for r in rs])
41 |
42 |
43 | def dcg_at_k(r, k, method=1):
44 | """Score is discounted cumulative gain (dcg)
45 | Relevance is positive real values. Can use binary
46 | as the previous methods.
47 | Returns:
48 | Discounted cumulative gain
49 | """
50 | r = np.asfarray(r)[:k]
51 | if r.size:
52 | if method == 0:
53 | return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))
54 | elif method == 1:
55 | return np.sum(r / np.log2(np.arange(2, r.size + 2)))
56 | else:
57 | raise ValueError('method must be 0 or 1.')
58 | return 0.
59 |
60 |
61 | def ndcg_at_k(r, k, method=1):
62 | """Score is normalized discounted cumulative gain (ndcg)
63 | Relevance is positive real values. Can use binary
64 | as the previous methods.
65 | Returns:
66 | Normalized discounted cumulative gain
67 | """
68 | dcg_max = dcg_at_k(sorted(r, reverse=True), k, method)
69 | if not dcg_max:
70 | return 0.
71 | return dcg_at_k(r, k, method) / dcg_max
72 |
73 |
74 | def recall_at_k(r, k, all_pos_num):
75 | r = np.asfarray(r)[:k]
76 | if all_pos_num == 0:
77 | return 0
78 | else:
79 | return np.sum(r) / all_pos_num
80 |
81 |
82 | def hit_at_k(r, k):
83 | r = np.array(r)[:k]
84 | if np.sum(r) > 0:
85 | return 1.
86 | else:
87 | return 0.
88 |
89 | def F1(pre, rec):
90 | if pre + rec > 0:
91 | return (2.0 * pre * rec) / (pre + rec)
92 | else:
93 | return 0.
94 |
95 | def auc(ground_truth, prediction):
96 | try:
97 | res = roc_auc_score(y_true=ground_truth, y_score=prediction)
98 | except Exception:
99 | res = 0.
100 | return res
--------------------------------------------------------------------------------
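A worked toy example (invented hit list, not repository code) for `ndcg_at_k` with `method=1`, where `DCG = sum_i r_i / log2(i + 1)` and the ideal reordering of `r` gives the normalizer:

```python
import numpy as np

r = [1, 0, 1]  # hits at ranks 1 and 3
dcg = 1 / np.log2(2) + 0 / np.log2(3) + 1 / np.log2(4)  # 1.0 + 0 + 0.5 = 1.5
idcg = 1 / np.log2(2) + 1 / np.log2(3)                  # ideal order [1, 1, 0] -> ~1.631
print(dcg / idcg)                                       # ~0.920, matches ndcg_at_k(r, 3)
```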
/LATTICE/codes/utility/parser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def parse_args():
4 | parser = argparse.ArgumentParser(description="")
5 |
6 | parser.add_argument('--data_path', nargs='?', default='',
7 | help='Input data path.')
8 | parser.add_argument('--seed', type=int, default=123,
9 | help='Random seed')
10 | parser.add_argument('--dataset', nargs='?', default='cloth',
11 | help='Choose a dataset from {grocery, cloth, sport, netflix, sports, baby, clothing}')
12 | parser.add_argument('--verbose', type=int, default=5,
13 | help='Interval of evaluation.')
14 | parser.add_argument('--epoch', type=int, default=200,
15 | help='Number of epoch.')
16 | parser.add_argument('--batch_size', type=int, default=1024,
17 | help='Batch size.')
18 | parser.add_argument('--regs', nargs='?', default='[1e-5,1e-5,1e-2]',
19 | help='Regularizations.')
20 | parser.add_argument('--lr', type=float, default=0.0005,
21 | help='Learning rate.')
22 | parser.add_argument('--model_name', nargs='?', default='lattice',
23 | help='Specify the model name.')
24 |
25 | parser.add_argument('--embed_size', type=int, default=64,
26 | help='Embedding size.')
27 | parser.add_argument('--feat_embed_dim', type=int, default=64,
28 | help='')
29 | parser.add_argument('--weight_size', nargs='?', default='[64,64]',
30 | help='Output sizes of every layer')
31 | parser.add_argument('--core', type=int, default=5,
32 | help='5-core for warm-start; 0-core for cold start')
33 | parser.add_argument('--topk', type=int, default=10,
34 | help='K value of k-NN sparsification')
35 | parser.add_argument('--lambda_coeff', type=float, default=0.9,
36 | help='Lambda value of skip connection')
37 | parser.add_argument('--cf_model', nargs='?', default='lightgcn',
38 | help='Downstream Collaborative Filtering model {mf, ngcf, lightgcn}')
39 | parser.add_argument('--n_layers', type=int, default=1,
40 | help='Number of item graph conv layers')
41 | parser.add_argument('--mess_dropout', nargs='?', default='[0.1, 0.1]',
42 | help='Message dropout ratio for each deep layer (passed to nn.Dropout). 0: no dropout.')
43 |
44 | parser.add_argument('--early_stopping_patience', type=int, default=10,
45 | help='')
46 | parser.add_argument('--gpu_id', type=int, default=0,
47 | help='GPU id')
48 | parser.add_argument('--Ks', nargs='?', default='[10, 20, 50]',
49 | help='K value of ndcg/recall @ k')
50 | parser.add_argument('--test_flag', nargs='?', default='part',
51 | help='Specify the test type from {part, full}, indicating whether the inference is done in mini-batch')
52 |
53 |
54 | return parser.parse_args()
55 |
--------------------------------------------------------------------------------
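Since `parse_args()` reads `sys.argv`, the flags above can also be injected programmatically for quick experiments. A hedged sketch (arbitrary values; assumes it is run from the `codes/` directory so that `utility` is importable):

```python
import sys
from utility.parser import parse_args

# simulate a command line: python main.py --dataset sports --cf_model lightgcn ...
sys.argv = ['main.py', '--dataset', 'sports', '--cf_model', 'lightgcn',
            '--topk', '10', '--lr', '5e-4']
args = parse_args()
print(args.dataset, args.cf_model, args.topk, args.lr)
```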
/PromptMM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/PromptMM/70da1002a35d6f2c7712c16cd0b2ca24c8813008/PromptMM.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PromptMM: Multi-Modal Knowledge Distillation for Recommendation with Prompt-Tuning
2 |
3 | PyTorch implementation for WWW 2024 paper [PromptMM: Multi-Modal Knowledge Distillation for Recommendation with Prompt-Tuning](https://arxiv.org/html/2402.17188v1).
4 |
5 | [Wei Wei](#), [Jiabin Tang](https://tjb-tech.github.io/), [Yangqin Jiang](#), [Lianghao Xia](https://akaxlh.github.io/) and [Chao Huang](https://sites.google.com/view/chaoh/home)*.
6 | (*Correspondence)
7 |
8 |
9 |
10 |
11 |
12 |
13 | ## Dependencies
14 |
15 | * Python >= 3.9.13
16 | * [Pytorch](https://pytorch.org/) >= 1.13.0+cu116
17 | * [dgl-cuda11.6](https://www.dgl.ai/) >= 0.9.1post1
18 |
19 |
20 |
21 |
22 | ## Usage
23 |
24 | Start training and inference as:
25 |
26 | ```
27 | python ./main.py --dataset {DATASET}
28 | ```
29 | Supported datasets: `Amazon-Electronics`, `Netflix`, `Tiktok`
30 |
31 |
32 | ## Datasets
33 |
34 | ```
35 | ├─ MMSSL/
36 |    ├── data/
37 |       ├── tiktok/
38 | ...
39 | ```
40 | | Dataset     | Netflix |     | Tiktok  |     |     | Electronics |      |
41 | |:-----------:|:-------:|:---:|:-------:|:---:|:---:|:-----------:|:----:|
42 | | Modality    | V       | T   | V       | A   | T   | V           | T    |
43 | | Feat. Dim.  | 512     | 768 | 128     | 128 | 768 | 4096        | 1024 |
44 | | User        | 43,739  |     | 14,343  |     |     | 41,691      |      |
45 | | Item        | 17,239  |     | 8,690   |     |     | 21,479      |      |
46 | | Interaction | 609,341 |     | 276,637 |     |     | 359,165     |      |
47 | | Sparsity    | 99.919% |     | 99.778% |     |     | 99.960%     |      |
48 |
49 |
50 | - `2024.2.27 new multi-modal datasets uploaded`: 📢📢 🌹🌹 We provide new multi-modal datasets `Netflix` and `MovieLens` (i.e., CF training data, multi-modal data including `item text` and `posters`) from the new multi-modal work [LLMRec](https://github.com/HKUDS/LLMRec) on Google Drive. 🌹We hope to contribute to our community and facilitate your research~
51 |
52 | - `2023.2.27 update(all datasets uploaded)`: We provide the processed data at [Google Drive](https://drive.google.com/drive/folders/17vnX8S6a_68xzML1tAM5m9YsQyKZ1UKb?usp=share_link).
53 |
54 | 🚀🚀 The provided dataset is compatible with multi-modal recommender models such as [MMSSL](https://github.com/HKUDS/MMSSL), [LATTICE](https://github.com/CRIPAC-DIG/LATTICE), and [MICRO](https://github.com/CRIPAC-DIG/MICRO), and requires no additional preprocessing: it already contains (1) the basic user-item interactions and (2) the multi-modal features.
55 |
56 | ```
57 | # part of data preprocessing
58 | # #----json2mat--------------------------------------------------------------------------------------------------
59 | import json
60 | from scipy.sparse import csr_matrix
61 | import pickle
62 | import numpy as np
63 | n_user, n_item = 39387, 23033
64 | f = open('/home/weiw/Code/MM/MMSSL/data/clothing/train.json', 'r')
65 | train = json.load(f)
66 | row, col = [], []
67 | for uid, items in train.items():
68 |     for i in items:
69 |         row.append(int(uid))
70 |         col.append(i)
71 | data = np.ones(len(row))
72 | train_mat = csr_matrix((data, (row, col)), shape=(n_user, n_item))
73 | pickle.dump(train_mat, open('./train_mat', 'wb'))
74 | # # ----json2mat--------------------------------------------------------------------------------------------------
75 |
76 |
77 | # ----mat2json--------------------------------------------------------------------------------------------------
78 | # train_mat = pickle.load(open('./train_mat', 'rb'))
79 | test_mat = pickle.load(open('./test_mat', 'rb'))
80 | # val_mat = pickle.load(open('./val_mat', 'rb'))
81 |
82 | # total_mat = train_mat + test_mat + val_mat
83 | total_mat = test_mat
84 |
85 | # total_mat = pickle.load(open('./new_mat','rb'))
86 | # total_mat = pickle.load(open('./new_mat','rb'))
87 | total_array = total_mat.toarray()
88 | total_dict = {}
89 |
90 | for i in range(total_array.shape[0]):
91 | total_dict[str(i)] = [index for index, value in enumerate(total_array[i]) if value!=0]
92 |
93 | new_total_dict = {}
94 |
95 | for i in range(len(total_dict)):
96 | # if len(total_dict[str(i)])>1:
97 | new_total_dict[str(i)]=total_dict[str(i)]
98 |
99 | # train_dict, test_dict = {}, {}
100 |
101 | # for i in range(len(new_total_dict)):
102 | # train_dict[str(i)] = total_dict[str(i)][:-1]
103 | # test_dict[str(i)] = [total_dict[str(i)][-1]]
104 |
105 | # train_json_str = json.dumps(train_dict)
106 | test_json_str = json.dumps(new_total_dict)
107 |
108 | # with open('./new_train.json', 'w') as json_file:
109 | # # with open('./new_train_json', 'w') as json_file:
110 | # json_file.write(train_json_str)
111 | with open('./test.json', 'w') as test_file:
112 | # with open('./new_test_json', 'w') as test_file:
113 | test_file.write(test_json_str)
114 | # ----mat2json--------------------------------------------------------------------------------------------------
115 | ```
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 | ## Acknowledgement
124 | 
125 | The structure of this code is largely based on [LATTICE](https://github.com/CRIPAC-DIG/LATTICE) and [MICRO](https://github.com/CRIPAC-DIG/MICRO). Thanks for their work.
126 | 
--------------------------------------------------------------------------------
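After producing `train.json`/`test.json` with the snippet in the README above, a quick sanity check (paths are placeholders, not repository code) is that the interaction matrix lines up with the per-item multi-modal feature files that `utility/load_data.py` expects:

```python
import pickle
import numpy as np

train_mat = pickle.load(open('./train_mat', 'rb'))  # csr_matrix from the json2mat step
image_feat = np.load('./image_feat.npy')            # per-item visual features
text_feat = np.load('./text_feat.npy')              # per-item textual features

assert train_mat.shape[1] == image_feat.shape[0] == text_feat.shape[0], \
    'item count must match the number of feature rows'
print(train_mat.shape, image_feat.shape, text_feat.shape)
```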
/codes/Models.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from time import time
4 | import pickle
5 | 
6 | import scipy.sparse as sp
7 | from scipy.sparse import csr_matrix
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.nn import init
13 |
14 | from sklearn.decomposition import PCA, FastICA
15 | from sklearn import manifold
16 | from sklearn.manifold import TSNE
17 | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
18 |
19 | # from utility.parser import parse_args
20 | from utility.norm import build_sim, build_knn_normalized_graph
21 | # args = parse_args()
22 | from utility.parser import args
23 |
24 |
25 |
26 | class Teacher_Model(nn.Module):
27 | def __init__(self, n_users, n_items, embedding_dim, weight_size, dropout_list, image_feats, text_feats):
28 |
29 | super().__init__()
30 | self.n_users = n_users
31 | self.n_items = n_items
32 | self.embedding_dim = embedding_dim
33 | self.weight_size = weight_size
34 | self.n_ui_layers = len(self.weight_size)
35 | self.weight_size = [self.embedding_dim] + self.weight_size
36 |
37 | self.image_trans = nn.Linear(image_feats.shape[1], args.embed_size)
38 | self.text_trans = nn.Linear(text_feats.shape[1], args.embed_size)
39 | nn.init.xavier_uniform_(self.image_trans.weight)
40 | nn.init.xavier_uniform_(self.text_trans.weight)
41 | self.encoder = nn.ModuleDict()
42 | self.encoder['image_encoder'] = self.image_trans # ^-^
43 | self.encoder['text_encoder'] = self.text_trans # ^-^
44 |
45 | self.user_id_embedding = nn.Embedding(n_users, self.embedding_dim)
46 | self.item_id_embedding = nn.Embedding(n_items, self.embedding_dim)
47 |
48 | nn.init.xavier_uniform_(self.user_id_embedding.weight)
49 | nn.init.xavier_uniform_(self.item_id_embedding.weight)
50 | self.image_feats = torch.tensor(image_feats).float().cuda()
51 | self.text_feats = torch.tensor(text_feats).float().cuda()
52 | self.image_embedding = nn.Embedding.from_pretrained(torch.Tensor(image_feats), freeze=False)
53 | self.text_embedding = nn.Embedding.from_pretrained(torch.Tensor(text_feats), freeze=False)
54 |
55 | self.softmax = nn.Softmax(dim=-1)
56 | self.act = nn.Sigmoid()
57 | self.sigmoid = nn.Sigmoid()
58 | self.dropout = nn.Dropout(p=args.drop_rate)
59 | self.batch_norm = nn.BatchNorm1d(args.embed_size)
60 |
61 | def mm(self, x, y):
62 | if args.sparse:
63 | return torch.sparse.mm(x, y)
64 | else:
65 | return torch.mm(x, y)
66 | def sim(self, z1, z2):
67 | z1 = F.normalize(z1)
68 | z2 = F.normalize(z2)
69 | return torch.mm(z1, z2.t())
70 |
71 | def batched_contrastive_loss(self, z1, z2, batch_size=4096):
72 | device = z1.device
73 | num_nodes = z1.size(0)
74 | num_batches = (num_nodes - 1) // batch_size + 1
75 | f = lambda x: torch.exp(x / self.tau)
76 | indices = torch.arange(0, num_nodes).to(device)
77 | losses = []
78 |
79 | for i in range(num_batches):
80 | mask = indices[i * batch_size:(i + 1) * batch_size]
81 | refl_sim = f(self.sim(z1[mask], z1))
82 | between_sim = f(self.sim(z1[mask], z2))
83 |
84 | losses.append(-torch.log(
85 | between_sim[:, i * batch_size:(i + 1) * batch_size].diag()
86 | / (refl_sim.sum(1) + between_sim.sum(1)
87 | - refl_sim[:, i * batch_size:(i + 1) * batch_size].diag())))
88 |
89 | loss_vec = torch.cat(losses)
90 | return loss_vec.mean()
91 |
92 | def csr_norm(self, csr_mat, mean_flag=False):
93 | rowsum = np.array(csr_mat.sum(1))
94 | rowsum = np.power(rowsum+1e-8, -0.5).flatten()
95 | rowsum[np.isinf(rowsum)] = 0.
96 | rowsum_diag = sp.diags(rowsum)
97 |
98 | colsum = np.array(csr_mat.sum(0))
99 | colsum = np.power(colsum+1e-8, -0.5).flatten()
100 | colsum[np.isinf(colsum)] = 0.
101 | colsum_diag = sp.diags(colsum)
102 |
103 | if mean_flag == False:
104 | return rowsum_diag*csr_mat*colsum_diag
105 | else:
106 | return rowsum_diag*csr_mat
107 |
108 | def matrix_to_tensor(self, cur_matrix):
109 | if type(cur_matrix) != sp.coo_matrix:
110 | cur_matrix = cur_matrix.tocoo() #
111 | indices = torch.from_numpy(np.vstack((cur_matrix.row, cur_matrix.col)).astype(np.int64)) #
112 | values = torch.from_numpy(cur_matrix.data) #
113 | shape = torch.Size(cur_matrix.shape)
114 |
115 | return torch.sparse.FloatTensor(indices, values, shape).to(torch.float32).cuda() #
116 |
117 | def para_dict_to_tenser(self, para_dict):
118 | """
119 | :param para_dict: nn.ParameterDict()
120 | :return: tensor
121 | """
122 | tensors = []
123 |
124 | for beh in para_dict.keys():
125 | tensors.append(para_dict[beh])
126 | tensors = torch.stack(tensors, dim=0)
127 |
128 | return tensors
129 |
130 |
131 | def multi_head_self_attention(self, trans_w, embedding_t_1, embedding_t):
132 |
133 | q = self.para_dict_to_tenser(embedding_t)
134 | v = k = self.para_dict_to_tenser(embedding_t_1)
135 | beh, N, d_h = q.shape[0], q.shape[1], args.embed_size/args.head_num
136 |
137 | Q = torch.matmul(q, trans_w['w_q'])
138 | K = torch.matmul(k, trans_w['w_k'])
139 | V = v
140 |
141 | Q = Q.reshape(beh, N, args.head_num, int(d_h)).permute(2, 0, 1, 3)
142 | K = K.reshape(beh, N, args.head_num, int(d_h)).permute(2, 0, 1, 3)
143 |
144 | Q = torch.unsqueeze(Q, 2)
145 | K = torch.unsqueeze(K, 1)
146 | V = torch.unsqueeze(V, 1)
147 |
148 | att = torch.mul(Q, K) / torch.sqrt(torch.tensor(d_h))
149 | att = torch.sum(att, dim=-1)
150 | att = torch.unsqueeze(att, dim=-1)
151 | att = F.softmax(att, dim=2)
152 |
153 | Z = torch.mul(att, V)
154 | Z = torch.sum(Z, dim=2)
155 |
156 | Z_list = [value for value in Z]
157 | Z = torch.cat(Z_list, -1)
158 | Z = torch.matmul(Z, self.weight_dict['w_self_attention_cat'])
159 |
160 | 
161 | return Z, att.detach()
162 |
163 | # def prompt_tuning(self, soft_token_u, soft_token_i):
164 | # # self.user_id_embedding = nn.Embedding.from_pretrained(soft_token_u, freeze=False)
165 | # # self.item_id_embedding = nn.Embedding.from_pretrained(soft_token_u, freeze=False)
166 | # self.prompt_user = soft_token_u
167 | # self.prompt_item = soft_token_i
168 |
169 | def forward(self, ui_graph, iu_graph, prompt_module=None):
170 |
171 | # def forward(self, ui_graph, iu_graph):
172 |
173 | prompt_user, prompt_item = prompt_module() # [n*32]
174 | # ----feature prompt----
175 | # feat_prompt_user = torch.mean( torch.stack((torch.mm(prompt_user, torch.mm(prompt_user.T, self.image_feats)), torch.mm(prompt_user, torch.mm(prompt_user.T, self.text_feats)))), dim=0 )
176 | # feat_prompt_user = torch.mm(prompt_user, torch.mm(prompt_user.T, self.text_feats))
177 | feat_prompt_item_image = torch.mm(prompt_item, torch.mm(prompt_item.T, self.image_feats))
178 | feat_prompt_item_text = torch.mm(prompt_item, torch.mm(prompt_item.T, self.text_feats))
179 | # feat_prompt_image_item = torch.mm(prompt_item, torch.mm(prompt_item.T, self.image_feats))
180 | # feat_prompt_text_item = torch.mm(prompt_item, torch.mm(prompt_item.T, self.text_feats))
181 | # ----feature prompt----
182 | # image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats + feat_prompt_item_image ))
183 | # text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats + feat_prompt_item_text ))
184 |
185 | # image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats + F.normalize(feat_prompt_item_image, p=2, dim=1) ))
186 | # text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats + F.normalize(feat_prompt_item_text, p=2, dim=1) ))
187 |
188 | image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats + args.feat_soft_token_rate*F.normalize(feat_prompt_item_image, p=2, dim=1) ))
189 | text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats + args.feat_soft_token_rate*F.normalize(feat_prompt_item_text, p=2, dim=1) ))
190 | # args.feat_soft_token_rate*F.normalize(feat_prompt_item_image, p=2, dim=1)
191 | # args.feat_soft_token_rate*F.normalize(feat_prompt_item_text, p=2, dim=1)
192 |
193 |
194 | # image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats))
195 | # text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats))
196 |
197 | for i in range(args.layers):
198 | image_user_feats = self.mm(ui_graph, image_feats)
199 | image_item_feats = self.mm(iu_graph, image_user_feats)
200 | # image_user_id = self.mm(image_ui_graph, self.item_id_embedding.weight)
201 | # image_item_id = self.mm(image_iu_graph, self.user_id_embedding.weight)
202 |
203 | text_user_feats = self.mm(ui_graph, text_feats)
204 | text_item_feats = self.mm(iu_graph, text_user_feats)
205 |
206 | # text_user_id = self.mm(text_ui_graph, self.item_id_embedding.weight)
207 | # text_item_id = self.mm(text_iu_graph, self.user_id_embedding.weight)
208 |
209 | # self.embedding_dict['user']['image'] = image_user_id
210 | # self.embedding_dict['user']['text'] = text_user_id
211 | # self.embedding_dict['item']['image'] = image_item_id
212 | # self.embedding_dict['item']['text'] = text_item_id
213 | # user_z, att_u = self.multi_head_self_attention(self.weight_dict, self.embedding_dict['user'], self.embedding_dict['user'])
214 | # item_z, att_i = self.multi_head_self_attention(self.weight_dict, self.embedding_dict['item'], self.embedding_dict['item'])
215 | # user_emb = user_z.mean(0)
216 | # item_emb = item_z.mean(0)
217 | u_g_embeddings = self.user_id_embedding.weight + args.soft_token_rate*F.normalize(prompt_user, p=2, dim=1)
218 | i_g_embeddings = self.item_id_embedding.weight + args.soft_token_rate*F.normalize(prompt_item, p=2, dim=1)
219 | user_emb_list = [u_g_embeddings]
220 | item_emb_list = [i_g_embeddings]
221 | for i in range(self.n_ui_layers):
222 | if i == (self.n_ui_layers-1):
223 | u_g_embeddings = self.softmax( torch.mm(ui_graph, i_g_embeddings) )
224 | i_g_embeddings = self.softmax( torch.mm(iu_graph, u_g_embeddings) )
225 |
226 | else:
227 | u_g_embeddings = torch.mm(ui_graph, i_g_embeddings)
228 | i_g_embeddings = torch.mm(iu_graph, u_g_embeddings)
229 |
230 | user_emb_list.append(u_g_embeddings)
231 | item_emb_list.append(i_g_embeddings)
232 |
233 | u_g_embeddings = torch.mean(torch.stack(user_emb_list), dim=0)
234 | i_g_embeddings = torch.mean(torch.stack(item_emb_list), dim=0)
235 |
236 |
237 | u_g_embeddings = u_g_embeddings + args.model_cat_rate*F.normalize(image_user_feats, p=2, dim=1) + args.model_cat_rate*F.normalize(text_user_feats, p=2, dim=1)
238 | i_g_embeddings = i_g_embeddings + args.model_cat_rate*F.normalize(image_item_feats, p=2, dim=1) + args.model_cat_rate*F.normalize(text_item_feats, p=2, dim=1)
239 |
240 | return u_g_embeddings, i_g_embeddings, image_item_feats, text_item_feats, image_user_feats, text_user_feats, u_g_embeddings, i_g_embeddings , prompt_user, prompt_item
241 |
242 |
243 |
244 |
245 |
246 |
247 |
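# Usage sketch for the teacher (illustrative; the constructor arguments shown
# here are assumptions, not values taken from the training script):
#
#   prompt_module = PromptLearner(image_feats, text_feats, ui_graph)
#   teacher = Teacher_Model(n_users, n_items, args.embed_size, weight_size=[64, 64],
#                           dropout_list=[0.1, 0.1], image_feats=image_feats, text_feats=text_feats)
#   (u_emb, i_emb, img_item, txt_item, img_user, txt_user,
#    _, _, prompt_u, prompt_i) = teacher(ui_graph, iu_graph, prompt_module)
#
# u_emb/i_emb fuse three signals: ID embeddings shifted by the normalized soft
# prompts, mean-pooled graph propagation, and the normalized modality features.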
248 | class PromptLearner(nn.Module):
249 | def __init__(self, image_feats=None, text_feats=None, ui_graph=None):
250 | super().__init__()
251 | self.ui_graph = ui_graph
252 |
253 |
254 | if args.hard_token_type=='pca':
255 | try:
256 | t1 = time()
257 | hard_token_image = pickle.load(open(args.data_path + args.dataset + '/hard_token_image_pca','rb'))
258 | hard_token_text = pickle.load(open(args.data_path + args.dataset + '/hard_token_text_pca','rb'))
259 | print('loaded cached hard tokens in %.2fs' % (time() - t1))
260 | except Exception:
261 | hard_token_image = PCA(n_components=args.embed_size).fit_transform(image_feats)
262 | hard_token_text = PCA(n_components=args.embed_size).fit_transform(text_feats)
263 | pickle.dump(hard_token_image, open(args.data_path + args.dataset + '/hard_token_image_pca','wb'))
264 | pickle.dump(hard_token_text, open(args.data_path + args.dataset + '/hard_token_text_pca','wb'))
265 | elif args.hard_token_type=='ica':
266 | try:
267 | t1 = time()
268 | hard_token_image = pickle.load(open(args.data_path + args.dataset + '/hard_token_image_ica','rb'))
269 | hard_token_text = pickle.load(open(args.data_path + args.dataset + '/hard_token_text_ica','rb'))
270 | print('loaded cached hard tokens in %.2fs' % (time() - t1))
271 | except Exception:
272 | hard_token_image = FastICA(n_components=args.embed_size, random_state=12).fit_transform(image_feats)
273 | hard_token_text = FastICA(n_components=args.embed_size, random_state=12).fit_transform(text_feats)
274 | pickle.dump(hard_token_image, open(args.data_path + args.dataset + '/hard_token_image_ica','wb'))
275 | pickle.dump(hard_token_text, open(args.data_path + args.dataset + '/hard_token_text_ica','wb'))
276 | elif args.hard_token_type=='isomap':
277 | hard_token_image = manifold.Isomap(n_neighbors=5, n_components=args.embed_size, n_jobs=-1).fit_transform(image_feats)
278 | hard_token_text = manifold.Isomap(n_neighbors=5, n_components=args.embed_size, n_jobs=-1).fit_transform(text_feats)
279 | # elif args.hard_token_type=='tsne':
280 | # hard_token_image = TSNE(n_components=args.embed_size, n_iter=300).fit_transform(image_feats)
281 | # hard_token_text = TSNE(n_components=args.embed_size, n_iter=300).fit_transform(text_feats)
282 | # elif args.hard_token_type=='lda':
283 | # hard_token_image = LinearDiscriminantAnalysis(n_components=args.embed_size).fit_transform(image_feats)
284 | # hard_token_text = LinearDiscriminantAnalysis(n_components=args.embed_size).fit_transform(text_feats)
285 |
286 | # self.item_hard_token = nn.Embedding.from_pretrained(torch.mean((torch.stack((torch.tensor(hard_token_image).float(), torch.tensor(hard_token_text).float()))), dim=0), freeze=False).cuda().weight
287 | # self.user_hard_token = nn.Embedding.from_pretrained(torch.mm(ui_graph, self.item_hard_token), freeze=False).cuda().weight
288 |
289 | self.item_hard_token = torch.mean((torch.stack((torch.tensor(hard_token_image).float(), torch.tensor(hard_token_text).float()))), dim=0).cuda()
290 | self.user_hard_token = torch.mm(ui_graph, self.item_hard_token).cuda()
291 |
292 | self.trans_user = nn.Linear(args.embed_size, args.embed_size).cuda()
293 | self.trans_item = nn.Linear(args.embed_size, args.embed_size).cuda()
294 | # nn.init.xavier_uniform_(self.gnn_trans_user.weight)
295 | # nn.init.xavier_uniform_(self.gnn_trans_item.weight)
296 | # self.gnn_trans_user = self.gnn_trans_user.cuda()
297 | # self.gnn_trans_item = self.gnn_trans_item.cuda()
298 | # self.item_hard_token = torch.mean((torch.stack((torch.tensor(hard_token_image).float(), torch.tensor(hard_token_text).float()))), dim=0).cuda()
299 |
300 |
301 | def forward(self):
302 | # self.user_hard_token = self.gnn_trans_user(torch.mm(self.ui_graph, self.item_hard_token))
303 | # self.item_hard_token = self.gnn_trans_item(self.item_hard_token)
304 | # return self.user_hard_token , self.item_hard_token
305 | return F.dropout(self.trans_user(self.user_hard_token), args.prompt_dropout) , F.dropout(self.trans_item(self.item_hard_token), args.prompt_dropout)
306 |
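# PromptLearner in brief: the "hard tokens" are fixed PCA/ICA/Isomap projections
# of the raw item modality features, propagated to users through ui_graph; only
# trans_user/trans_item are trained. The equivalent computation, as a sketch:
#
#   item_hard = mean(reduce(image_feats), reduce(text_feats))      # [n_items, d]
#   user_hard = ui_graph @ item_hard                               # [n_users, d]
#   prompt_u  = dropout(trans_user(user_hard), args.prompt_dropout)
#   prompt_i  = dropout(trans_item(item_hard), args.prompt_dropout)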
307 |
308 |
309 |
310 | class Student_LightGCN(nn.Module):
311 | def __init__(self, n_users, n_items, embedding_dim, gnn_layer, dropout_list, image_feats=None, text_feats=None):
312 | super().__init__()
313 | self.n_users = n_users
314 | self.n_items = n_items
315 | self.embedding_dim = embedding_dim
316 | self.n_ui_layers = gnn_layer
317 |
318 | self.user_id_embedding = nn.Embedding(n_users, embedding_dim)
319 | self.item_id_embedding = nn.Embedding(n_items, embedding_dim)
320 | nn.init.xavier_uniform_(self.user_id_embedding.weight)
321 | nn.init.xavier_uniform_(self.item_id_embedding.weight)
322 |
323 | # self.feat_trans = nn.Linear(args.embed_size, args.student_embed_size)
324 | # # self.text_trans = nn.Linear(text_feats.shape[1], args.embed_size)
325 | # nn.init.xavier_uniform_(self.feat_trans.weight)
326 | # # nn.init.xavier_uniform_(self.text_trans.weight)
327 |
328 | def init_user_item_embed(self, pre_u_embed, pre_i_embed):
329 | self.user_id_embedding = nn.Embedding.from_pretrained(pre_u_embed, freeze=False)
330 | self.item_id_embedding = nn.Embedding.from_pretrained(pre_i_embed, freeze=False)
331 |
332 | self.user_id_embedding_pre = nn.Embedding.from_pretrained(pre_u_embed, freeze=False)
333 | self.item_id_embedding_pre = nn.Embedding.from_pretrained(pre_i_embed, freeze=False)
334 |
335 | def get_embedding(self):
336 | return self.user_id_embedding, self.item_id_embedding
337 |
338 | def forward(self, adj):
339 |
340 | # # teacher_feat_dict = { 'item_image':t_i_image_embed.deteach(),'item_text':t_i_text_embed.deteach(),'user_image':t_u_image_embed.deteach(),'user_text':t_u_text_embed.deteach() }
341 | # tmp_feat_dict = {}
342 | # for index,value in enumerate(teacher_feat_dict.keys()):
343 | # tmp_feat_dict[value] = self.feat_trans(teacher_feat_dict[value])
344 | # u_g_embeddings = self.user_id_embedding.weight + args.model_cat_rate*F.normalize(tmp_feat_dict['user_image'], p=2, dim=1) + args.model_cat_rate*F.normalize(tmp_feat_dict['user_text'], p=2, dim=1)
345 | # i_g_embeddings = self.item_id_embedding.weight + args.model_cat_rate*F.normalize(tmp_feat_dict['item_image'], p=2, dim=1) + args.model_cat_rate*F.normalize(tmp_feat_dict['item_text'], p=2, dim=1)
346 | # ego_embeddings = torch.cat((u_g_embeddings, i_g_embeddings), dim=0)
347 |
348 | # self.user_id_embedding_pre = nn.Embedding.from_pretrained(pre_u_embed, freeze=False)
349 | # self.item_id_embedding_pre = nn.Embedding.from_pretrained(pre_i_embed, freeze=False)
350 |
351 | ego_embeddings = torch.cat((self.user_id_embedding.weight+self.user_id_embedding_pre.weight, self.item_id_embedding.weight+self.item_id_embedding_pre.weight), dim=0)  # requires init_user_item_embed() to have been called first
352 | # ego_embeddings = torch.cat((self.user_id_embedding.weight, self.item_id_embedding.weight), dim=0)
353 | all_embeddings = [ego_embeddings]
354 | for i in range(self.n_ui_layers):
355 | side_embeddings = torch.sparse.mm(adj, ego_embeddings)
356 | ego_embeddings = side_embeddings
357 | all_embeddings += [ego_embeddings]
358 | all_embeddings = torch.stack(all_embeddings, dim=1)
359 | all_embeddings = all_embeddings.mean(dim=1, keepdim=False)
360 | u_g_embeddings, i_g_embeddings = torch.split(all_embeddings, [self.n_users, self.n_items], dim=0)
361 | # u_g_embeddings += teacher_feat_dict['user_image'] + teacher_feat_dict['user_text']
362 | # i_g_embeddings += teacher_feat_dict['item_image'] + teacher_feat_dict['item_text']
363 | # u_g_embeddings = u_g_embeddings + args.model_cat_rate*F.normalize(teacher_feat_dict['user_image'], p=2, dim=1) + args.model_cat_rate*F.normalize(teacher_feat_dict['user_text'], p=2, dim=1)
364 | # i_g_embeddings = i_g_embeddings + args.model_cat_rate*F.normalize(teacher_feat_dict['item_image'], p=2, dim=1) + args.model_cat_rate*F.normalize(teacher_feat_dict['item_text'], p=2, dim=1)
365 |
366 | return u_g_embeddings, i_g_embeddings
367 | # return self.user_id_embedding.weight, self.item_id_embedding.weight
368 |
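# Student_LightGCN follows standard LightGCN propagation with mean pooling:
#   E^(l+1) = A_hat @ E^(l),   E_final = mean(E^(0), ..., E^(L))
# e.g. with n_ui_layers=2: E_final = (E0 + A_hat@E0 + A_hat@(A_hat@E0)) / 3.
# The one twist is E^(0): the sum of the trainable ID tables and a second
# trainable copy initialized from the teacher (user/item_id_embedding_pre),
# which is why init_user_item_embed() must run before forward().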
369 |
370 |
371 | class Student_GCN(nn.Module):
372 | def __init__(self, n_users, n_items, embedding_dim, gnn_layer=2, drop_out=0., image_feats=None, text_feats=None):
373 | super(Student_GCN, self).__init__()
374 | self.embedding_dim = embedding_dim
375 |
376 | # self.layers = nn.Sequential(GraphConvolution(self.embedding_dim, self.embedding_dim, activation=F.relu, dropout=args.student_drop_rate, is_sparse_inputs=True),
377 | # GraphConvolution(self.embedding_dim, self.embedding_dim, activation=F.relu, dropout=args.student_drop_rate, is_sparse_inputs=False),
378 | # )
379 | # self.layer_list = nn.ModuleList()
380 | # for i in range(args.student_n_layers):
381 | # self.layer_list.append(GraphConvolution(self.embedding_dim, self.embedding_dim, activation=F.relu, dropout=args.student_drop_rate, is_sparse_inputs=False))
382 |
383 | self.trans_user = nn.Linear(args.embed_size, args.embed_size).cuda()
384 | self.trans_item = nn.Linear(args.embed_size, args.embed_size).cuda()
385 |
386 |
387 | def forward(self, user_x, item_x, ui_graph, iu_graph):
388 | # # x, support = inputs
389 | # # user_x, item_x = self.layers((user_x, item_x, ui_graph, iu_graph))
390 | # for i in range(args.student_n_layers):
391 | # user_x, item_x = self.layer_list[i](user_x, item_x, ui_graph, iu_graph)
392 | # return user_x, item_x
393 |
394 | return self.trans_user(user_x), self.trans_item(item_x)
395 | # self.user_id_embedding = nn.Embedding.from_pretrained(user_x, freeze=True)
396 | # self.item_id_embedding = nn.Embedding.from_pretrained(item_x, freeze=True)
397 | # return self.user_id_embedding.weight, self.item_id_embedding.weight
398 |
399 | def l2_loss(self):
400 | # sum of squared L2 norms over the trainable parameters; the original
401 | # read self.layers, which is commented out above and no longer exists
402 | loss = None
403 |
404 | for p in self.parameters():
405 | if loss is None:
406 | loss = p.pow(2).sum()
407 | else:
408 | loss += p.pow(2).sum()
409 |
410 | return loss
411 |
412 | class GraphConvolution(nn.Module):
413 | def __init__(self, input_dim, output_dim, dropout=0., is_sparse_inputs=False, bias=False, activation = F.relu,featureless=False):
414 | super(GraphConvolution, self).__init__()
415 | self.dropout = dropout
416 | self.bias = bias
417 | self.activation = activation
418 | self.is_sparse_inputs = is_sparse_inputs
419 | self.featureless = featureless
420 | # self.num_features_nonzero = num_features_nonzero
421 | # self.user_weight = nn.Parameter(torch.randn(input_dim, output_dim))
422 | # self.item_weight = nn.Parameter(torch.randn(input_dim, output_dim))
423 | self.user_weight = nn.Parameter(torch.empty(input_dim, output_dim))
424 | self.item_weight = nn.Parameter(torch.empty(input_dim, output_dim))
425 | nn.init.xavier_uniform_(self.user_weight)
426 | nn.init.xavier_uniform_(self.item_weight)
427 | self.bias = None
428 | if bias:
429 | self.bias = nn.Parameter(torch.zeros(output_dim))
430 |
431 |
432 | def forward(self, user_x, item_x, ui_graph, iu_graph):
433 | # print('inputs:', inputs)
434 | # x, support = inputs
435 | # if self.training and self.is_sparse_inputs:
436 | # x = sparse_dropout(x, self.dropout, self.num_features_nonzero)
437 | # elif self.training:
438 | user_x = F.dropout(user_x, self.dropout)
439 | item_x = F.dropout(item_x, self.dropout)
440 | # convolve
441 | if not self.featureless: # if it has features x
442 | if self.is_sparse_inputs:
443 | xw_user = torch.sparse.mm(user_x, self.user_weight)  # bug fix: the original overwrote a single xw and never defined xw_user/xw_item
444 | xw_item = torch.sparse.mm(item_x, self.item_weight)
445 | else:
446 | xw_user = torch.mm(user_x, self.user_weight)
447 | xw_item = torch.mm(item_x, self.item_weight)
448 | else:
449 | xw_user, xw_item = self.user_weight, self.item_weight  # featureless: use the weights directly (self.weight did not exist)
450 | out_user = torch.sparse.mm(ui_graph, xw_item)
451 | out_item = torch.sparse.mm(iu_graph, xw_user)
452 |
453 | if self.bias is not None:
454 | out_user, out_item = out_user + self.bias, out_item + self.bias  # bug fix: `out` was undefined
455 | return self.activation(out_user), self.activation(out_item)
456 |
457 |
458 | def sparse_dropout(x, rate, noise_shape):
459 | """
460 | :param x:
461 | :param rate:
462 | :param noise_shape: int scalar
463 | :return:
464 | """
465 | random_tensor = 1 - rate
466 | random_tensor += torch.rand(noise_shape).to(x.device)
467 | dropout_mask = torch.floor(random_tensor).bool()  # True where rand >= rate, i.e. the entry survives
468 | i = x._indices() # [2, 49216]
469 | v = x._values() # [49216]
470 | # keep only the surviving nonzeros in the COO indices and values
471 | i = i[:, dropout_mask]
472 | v = v[dropout_mask]
473 | out = torch.sparse_coo_tensor(i, v, x.shape).to(x.device)  # torch.sparse.FloatTensor is deprecated
474 | out = out * (1. / (1 - rate))  # rescale so the expected value is unchanged
475 | return out
476 |
477 |
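# sparse_dropout usage sketch (illustrative values): drop roughly 20% of the
# stored entries of a sparse adjacency while preserving its expectation.
#
#   adj = torch.sparse_coo_tensor(torch.tensor([[0, 1], [1, 0]]),
#                                 torch.tensor([1.0, 1.0]), (2, 2))
#   adj_dropped = sparse_dropout(adj, rate=0.2, noise_shape=adj._nnz())
#
# noise_shape must equal the number of stored values; survivors are rescaled
# by 1/(1-rate), mirroring dense dropout.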
478 | def dot(x, y, sparse=False):
479 | if sparse:
480 | res = torch.sparse.mm(x, y)
481 | else:
482 | res = torch.mm(x, y)
483 | return res
484 |
485 |
486 |
487 |
488 |
489 | class BLMLP(nn.Module):
490 | def __init__(self):
491 | super(BLMLP, self).__init__()
492 | self.W = nn.Parameter(nn.init.xavier_uniform_(torch.empty(args.student_embed_size, args.student_embed_size)))
493 | self.act = nn.LeakyReLU(negative_slope=0.5)
494 |
495 | def forward(self, embeds):
496 | return self.featureExtract(embeds)  # the original body was a bare `pass`, which silently returned None
497 |
498 | def featureExtract(self, embeds):
499 | return self.act(embeds @ self.W) + embeds
500 |
501 | def pairPred(self, embeds1, embeds2):
502 | return (self.featureExtract(embeds1) * self.featureExtract(embeds2)).sum(dim=-1)
503 |
504 | def crossPred(self, embeds1, embeds2):
505 | return self.featureExtract(embeds1) @ self.featureExtract(embeds2).T
506 |
507 |
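# BLMLP scores user-item pairs through a residual feature map
#   f(e) = LeakyReLU(e @ W) + e,   W: [d, d] with d = args.student_embed_size,
# so pairPred(u, i) = <f(u), f(i)> scores aligned pairs and
# crossPred(U, I) = f(U) @ f(I)^T yields the full score matrix.
# A sketch (embedding tensors are assumed):
#
#   mlp = BLMLP()
#   pos_scores = mlp.pairPred(u_embeds[ancs], i_embeds[poss])   # [batch]
#   all_scores = mlp.crossPred(u_embeds, i_embeds)              # [n_users, n_items]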
508 |
509 | class Student_MLP(nn.Module):
510 | def __init__(self):
511 | super(Student_MLP, self).__init__()
512 | # self.n_users = n_users
513 | # self.n_items = n_items
514 | # self.embedding_dim = embedding_dim
515 |
516 | # self.uEmbeds = nn.Parameter(init(torch.empty(args.user, args.latdim)))
517 | # self.iEmbeds = nn.Parameter(init(torch.empty(args.item, args.latdim)))
518 |
519 | self.user_trans = nn.Linear(args.embed_size, args.embed_size)
520 | self.item_trans = nn.Linear(args.embed_size, args.embed_size)
521 | nn.init.xavier_uniform_(self.user_trans.weight)
522 | nn.init.xavier_uniform_(self.item_trans.weight)
523 |
524 | self.MLP = BLMLP()
525 | # self.overallTime = datetime.timedelta(0)
526 |
527 |
528 | def get_embedding(self):
529 | return self.user_id_embedding, self.item_id_embedding
530 |
531 |
532 | def forward(self, pre_user, pre_item, ):
533 | # pre_user, pre_item = self.user_id_embedding.weight, self.item_id_embedding.weight
534 | user_embed = self.user_trans(pre_user)
535 | item_embed = self.item_trans(pre_item)  # bug fix: was self.user_trans, leaving item_trans unused
536 |
537 | return user_embed, item_embed
538 | # return pre_user, pre_item
539 |
540 | def init_user_item_embed(self, pre_u_embed, pre_i_embed):
541 | self.user_id_embedding = nn.Embedding.from_pretrained(pre_u_embed, freeze=False)
542 | self.item_id_embedding = nn.Embedding.from_pretrained(pre_i_embed, freeze=False)
543 |
544 | def pointPosPredictwEmbeds(self, uEmbeds, iEmbeds, ancs, poss):
545 | ancEmbeds = uEmbeds[ancs]
546 | posEmbeds = iEmbeds[poss]
547 | nume = self.MLP.pairPred(ancEmbeds, posEmbeds)
548 | return nume
549 |
550 | def pointNegPredictwEmbeds(self, embeds1, embeds2, nodes1, temp=1.0):
551 | pckEmbeds1 = embeds1[nodes1]
552 | preds = self.MLP.crossPred(pckEmbeds1, embeds2)
553 | return torch.exp(preds / temp).sum(-1)
554 |
555 | def pairPredictwEmbeds(self, uEmbeds, iEmbeds, ancs, poss, negs):
556 | ancEmbeds = uEmbeds[ancs]
557 | posEmbeds = iEmbeds[poss]
558 | negEmbeds = iEmbeds[negs]
559 | posPreds = self.MLP.pairPred(ancEmbeds, posEmbeds)
560 | negPreds = self.MLP.pairPred(ancEmbeds, negEmbeds)
561 | return posPreds - negPreds
562 |
563 | def predAll(self, pckUEmbeds, iEmbeds):
564 | return self.MLP.crossPred(pckUEmbeds, iEmbeds)
565 |
566 | def testPred(self, usr, trnMask):
567 | uEmbeds, iEmbeds = self.forward(self.user_id_embedding.weight, self.item_id_embedding.weight)  # forward() requires the pretrained embeddings
568 | allPreds = self.predAll(uEmbeds[usr], iEmbeds) * (1 - trnMask) - trnMask * 1e8  # mask out training interactions
569 | return allPreds
570 |
571 |
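# End-to-end sketch for the MLP student (illustrative; the batch index tensors
# ancs/poss/negs are assumptions): BPR-style training on pairPredictwEmbeds.
#
#   student = Student_MLP()
#   student.init_user_item_embed(teacher_u.detach(), teacher_i.detach())
#   u_emb, i_emb = student(student.user_id_embedding.weight,
#                          student.item_id_embedding.weight)
#   diff = student.pairPredictwEmbeds(u_emb, i_emb, ancs, poss, negs)
#   bpr_loss = -F.logsigmoid(diff).mean()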
--------------------------------------------------------------------------------
/codes/Models_empower.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from time import time
4 | import pickle
5 |
6 | import scipy.sparse as sp
7 | from scipy.sparse import csr_matrix
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.nn import init
13 |
14 | from sklearn.decomposition import PCA, FastICA
15 | from sklearn import manifold
16 | from sklearn.manifold import TSNE
17 | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
18 |
19 | # from utility.parser import parse_args
20 | from utility.norm import build_sim, build_knn_normalized_graph
21 | # args = parse_args()
22 | from utility.parser import args
23 |
24 |
25 |
26 | class Teacher_Model(nn.Module):
27 | def __init__(self, n_users, n_items, embedding_dim, weight_size, dropout_list, image_feats, text_feats):
28 |
29 | super().__init__()
30 | self.n_users = n_users
31 | self.n_items = n_items
32 | self.embedding_dim = embedding_dim
33 | self.weight_size = weight_size
34 | self.n_ui_layers = len(self.weight_size)
35 | self.weight_size = [self.embedding_dim] + self.weight_size
36 |
37 | self.image_trans = nn.Linear(image_feats.shape[1], args.embed_size)
38 | self.text_trans = nn.Linear(text_feats.shape[1], args.embed_size)
39 | nn.init.xavier_uniform_(self.image_trans.weight)
40 | nn.init.xavier_uniform_(self.text_trans.weight)
41 | self.encoder = nn.ModuleDict()
42 | self.encoder['image_encoder'] = self.image_trans  # expose the modality encoders by name
43 | self.encoder['text_encoder'] = self.text_trans
44 |
45 | self.user_id_embedding = nn.Embedding(n_users, self.embedding_dim)
46 | self.item_id_embedding = nn.Embedding(n_items, self.embedding_dim)
47 |
48 | nn.init.xavier_uniform_(self.user_id_embedding.weight)
49 | nn.init.xavier_uniform_(self.item_id_embedding.weight)
50 | self.image_feats = torch.tensor(image_feats).float().cuda()
51 | self.text_feats = torch.tensor(text_feats).float().cuda()
52 | self.image_embedding = nn.Embedding.from_pretrained(torch.Tensor(image_feats), freeze=False)
53 | self.text_embedding = nn.Embedding.from_pretrained(torch.Tensor(text_feats), freeze=False)
54 |
55 | self.softmax = nn.Softmax(dim=-1)
56 | self.act = nn.Sigmoid()
57 | self.sigmoid = nn.Sigmoid()
58 | self.dropout = nn.Dropout(p=args.drop_rate)
59 | self.batch_norm = nn.BatchNorm1d(args.embed_size)
60 |
61 | def mm(self, x, y):
62 | if args.sparse:
63 | return torch.sparse.mm(x, y)
64 | else:
65 | return torch.mm(x, y)
66 | def sim(self, z1, z2):
67 | z1 = F.normalize(z1)
68 | z2 = F.normalize(z2)
69 | return torch.mm(z1, z2.t())
70 |
71 | def batched_contrastive_loss(self, z1, z2, batch_size=4096):
72 | device = z1.device
73 | num_nodes = z1.size(0)
74 | num_batches = (num_nodes - 1) // batch_size + 1
75 | f = lambda x: torch.exp(x / getattr(self, 'tau', 0.5))  # temperature; self.tau is never set in this class, so fall back to an assumed 0.5
76 | indices = torch.arange(0, num_nodes).to(device)
77 | losses = []
78 |
79 | for i in range(num_batches):
80 | mask = indices[i * batch_size:(i + 1) * batch_size]
81 | refl_sim = f(self.sim(z1[mask], z1))
82 | between_sim = f(self.sim(z1[mask], z2))
83 |
84 | losses.append(-torch.log(
85 | between_sim[:, i * batch_size:(i + 1) * batch_size].diag()
86 | / (refl_sim.sum(1) + between_sim.sum(1)
87 | - refl_sim[:, i * batch_size:(i + 1) * batch_size].diag())))
88 |
89 | loss_vec = torch.cat(losses)
90 | return loss_vec.mean()
91 |
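# batched_contrastive_loss is InfoNCE computed in row blocks to bound memory.
# For node u, with f(x) = exp(x / tau):
#   loss_u = -log( f(sim(z1_u, z2_u))
#                  / (sum_v f(sim(z1_u, z1_v)) + sum_v f(sim(z1_u, z2_v))
#                     - f(sim(z1_u, z1_u))) )
# The positive is the same node in the other view; every other node in both
# views is a negative, and the self-similarity term is subtracted from the
# denominator. Each block materializes only a [batch_size, N] similarity.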
92 | def csr_norm(self, csr_mat, mean_flag=False):
93 | rowsum = np.array(csr_mat.sum(1))
94 | rowsum = np.power(rowsum+1e-8, -0.5).flatten()
95 | rowsum[np.isinf(rowsum)] = 0.
96 | rowsum_diag = sp.diags(rowsum)
97 |
98 | colsum = np.array(csr_mat.sum(0))
99 | colsum = np.power(colsum+1e-8, -0.5).flatten()
100 | colsum[np.isinf(colsum)] = 0.
101 | colsum_diag = sp.diags(colsum)
102 |
103 | if not mean_flag:
104 | return rowsum_diag*csr_mat*colsum_diag
105 | else:
106 | return rowsum_diag*csr_mat
107 |
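# csr_norm applies the symmetric normalization used by GCN-style models:
#   A_hat = D_r^{-1/2} A D_c^{-1/2}   (mean_flag=False)
#   A_hat = D_r^{-1/2} A              (mean_flag=True, row-normalization only)
# e.g. an edge between a user with 4 interactions and an item with 9 gets
# weight 1 / (sqrt(4) * sqrt(9)) = 1/6.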
108 | def matrix_to_tensor(self, cur_matrix):
109 | if not isinstance(cur_matrix, sp.coo_matrix):
110 | cur_matrix = cur_matrix.tocoo()
111 | indices = torch.from_numpy(np.vstack((cur_matrix.row, cur_matrix.col)).astype(np.int64))  # COO row/col index pairs
112 | values = torch.from_numpy(cur_matrix.data)  # nonzero entries
113 | shape = torch.Size(cur_matrix.shape)
114 |
115 | return torch.sparse_coo_tensor(indices, values, shape).to(torch.float32).cuda()  # torch.sparse.FloatTensor is deprecated
116 |
117 | def para_dict_to_tenser(self, para_dict):
118 | """
119 | :param para_dict: nn.ParameterDict()
120 | :return: tensor
121 | """
122 | tensors = []
123 |
124 | for beh in para_dict.keys():
125 | tensors.append(para_dict[beh])
126 | tensors = torch.stack(tensors, dim=0)
127 |
128 | return tensors
129 |
130 |
131 | def multi_head_self_attention(self, trans_w, embedding_t_1, embedding_t):
132 |
133 | q = self.para_dict_to_tenser(embedding_t)
134 | v = k = self.para_dict_to_tenser(embedding_t_1)
135 | beh, N, d_h = q.shape[0], q.shape[1], args.embed_size/args.head_num
136 |
137 | Q = torch.matmul(q, trans_w['w_q'])
138 | K = torch.matmul(k, trans_w['w_k'])
139 | V = v
140 |
141 | Q = Q.reshape(beh, N, args.head_num, int(d_h)).permute(2, 0, 1, 3)
142 | K = K.reshape(beh, N, args.head_num, int(d_h)).permute(2, 0, 1, 3)  # bug fix: reshape K rather than overwriting it with Q
143 |
144 | Q = torch.unsqueeze(Q, 2)
145 | K = torch.unsqueeze(K, 1)
146 | V = torch.unsqueeze(V, 1)
147 |
148 | att = torch.mul(Q, K) / torch.sqrt(torch.tensor(d_h))
149 | att = torch.sum(att, dim=-1)
150 | att = torch.unsqueeze(att, dim=-1)
151 | att = F.softmax(att, dim=2)
152 |
153 | Z = torch.mul(att, V)
154 | Z = torch.sum(Z, dim=2)
155 |
156 | Z_list = [value for value in Z]
157 | Z = torch.cat(Z_list, -1)
158 | Z = torch.matmul(Z, self.weight_dict['w_self_attention_cat'])
159 |
160 | Z = args.model_cat_rate * F.normalize(Z, p=2, dim=2)  # scale the fused output; the original statement discarded this result
161 | return Z, att.detach()
162 |
163 | # def prompt_tuning(self, soft_token_u, soft_token_i):
164 | # # self.user_id_embedding = nn.Embedding.from_pretrained(soft_token_u, freeze=False)
165 | # # self.item_id_embedding = nn.Embedding.from_pretrained(soft_token_u, freeze=False)
166 | # self.prompt_user = soft_token_u
167 | # self.prompt_item = soft_token_i
168 |
169 | def forward(self, ui_graph, iu_graph, prompt_module=None):
170 |
171 | # def forward(self, ui_graph, iu_graph):
172 |
173 | prompt_user, prompt_item = prompt_module() # [n*32]
174 | # ----feature prompt----
175 | # feat_prompt_user = torch.mean( torch.stack((torch.mm(prompt_user, torch.mm(prompt_user.T, self.image_feats)), torch.mm(prompt_user, torch.mm(prompt_user.T, self.text_feats)))), dim=0 )
176 | # feat_prompt_user = torch.mm(prompt_user, torch.mm(prompt_user.T, self.text_feats))
177 | feat_prompt_item_image = torch.mm(prompt_item, torch.mm(prompt_item.T, self.image_feats))
178 | feat_prompt_item_text = torch.mm(prompt_item, torch.mm(prompt_item.T, self.text_feats))
179 | # feat_prompt_image_item = torch.mm(prompt_item, torch.mm(prompt_item.T, self.image_feats))
180 | # feat_prompt_text_item = torch.mm(prompt_item, torch.mm(prompt_item.T, self.text_feats))
181 | # ----feature prompt----
182 | # image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats + feat_prompt_item_image ))
183 | # text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats + feat_prompt_item_text ))
184 |
185 | # image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats + F.normalize(feat_prompt_item_image, p=2, dim=1) ))
186 | # text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats + F.normalize(feat_prompt_item_text, p=2, dim=1) ))
187 |
188 | image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats + args.feat_soft_token_rate*F.normalize(feat_prompt_item_image, p=2, dim=1) ))
189 | text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats + args.feat_soft_token_rate*F.normalize(feat_prompt_item_text, p=2, dim=1) ))
190 | # args.feat_soft_token_rate*F.normalize(feat_prompt_item_image, p=2, dim=1)
191 | # args.feat_soft_token_rate*F.normalize(feat_prompt_item_text, p=2, dim=1)
192 |
193 |
194 | # image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats))
195 | # text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats))
196 |
197 | for i in range(args.layers):
198 | image_user_feats = self.mm(ui_graph, image_feats)
199 | image_item_feats = self.mm(iu_graph, image_user_feats)
200 | # image_user_id = self.mm(image_ui_graph, self.item_id_embedding.weight)
201 | # image_item_id = self.mm(image_iu_graph, self.user_id_embedding.weight)
202 |
203 | text_user_feats = self.mm(ui_graph, text_feats)
204 | text_item_feats = self.mm(iu_graph, text_user_feats)
205 |
206 | # text_user_id = self.mm(text_ui_graph, self.item_id_embedding.weight)
207 | # text_item_id = self.mm(text_iu_graph, self.user_id_embedding.weight)
208 |
209 | # self.embedding_dict['user']['image'] = image_user_id
210 | # self.embedding_dict['user']['text'] = text_user_id
211 | # self.embedding_dict['item']['image'] = image_item_id
212 | # self.embedding_dict['item']['text'] = text_item_id
213 | # user_z, att_u = self.multi_head_self_attention(self.weight_dict, self.embedding_dict['user'], self.embedding_dict['user'])
214 | # item_z, att_i = self.multi_head_self_attention(self.weight_dict, self.embedding_dict['item'], self.embedding_dict['item'])
215 | # user_emb = user_z.mean(0)
216 | # item_emb = item_z.mean(0)
217 | u_g_embeddings = self.user_id_embedding.weight + args.soft_token_rate*F.normalize(prompt_user, p=2, dim=1)
218 | i_g_embeddings = self.item_id_embedding.weight + args.soft_token_rate*F.normalize(prompt_item, p=2, dim=1)
219 | user_emb_list = [u_g_embeddings]
220 | item_emb_list = [i_g_embeddings]
221 | for i in range(self.n_ui_layers):
222 | if i == (self.n_ui_layers-1):
223 | u_g_embeddings = self.softmax( torch.mm(ui_graph, i_g_embeddings) )
224 | i_g_embeddings = self.softmax( torch.mm(iu_graph, u_g_embeddings) )
225 |
226 | else:
227 | u_g_embeddings = torch.mm(ui_graph, i_g_embeddings)
228 | i_g_embeddings = torch.mm(iu_graph, u_g_embeddings)
229 |
230 | user_emb_list.append(u_g_embeddings)
231 | item_emb_list.append(i_g_embeddings)
232 |
233 | u_g_embeddings = torch.mean(torch.stack(user_emb_list), dim=0)
234 | i_g_embeddings = torch.mean(torch.stack(item_emb_list), dim=0)
235 |
236 |
237 | u_g_embeddings = u_g_embeddings + args.model_cat_rate*F.normalize(image_user_feats, p=2, dim=1) + args.model_cat_rate*F.normalize(text_user_feats, p=2, dim=1)
238 | i_g_embeddings = i_g_embeddings + args.model_cat_rate*F.normalize(image_item_feats, p=2, dim=1) + args.model_cat_rate*F.normalize(text_item_feats, p=2, dim=1)
239 |
240 | return u_g_embeddings, i_g_embeddings, image_item_feats, text_item_feats, image_user_feats, text_user_feats, u_g_embeddings, i_g_embeddings , prompt_user, prompt_item
241 |
242 |
243 |
244 |
245 |
246 |
247 |
248 | class PromptLearner(nn.Module):
249 | def __init__(self, image_feats=None, text_feats=None, ui_graph=None):
250 | super().__init__()
251 | self.ui_graph = ui_graph
252 |
253 |
254 | if args.hard_token_type=='pca':
255 | try:
256 | t1 = time()
257 | hard_token_image = pickle.load(open(args.data_path + args.dataset + '/hard_token_image_pca','rb'))
258 | hard_token_text = pickle.load(open(args.data_path + args.dataset + '/hard_token_text_pca','rb'))
259 | print('loaded cached hard tokens in %.2fs' % (time() - t1))
260 | except Exception:
261 | hard_token_image = PCA(n_components=args.embed_size).fit_transform(image_feats)
262 | hard_token_text = PCA(n_components=args.embed_size).fit_transform(text_feats)
263 | pickle.dump(hard_token_image, open(args.data_path + args.dataset + '/hard_token_image_pca','wb'))
264 | pickle.dump(hard_token_text, open(args.data_path + args.dataset + '/hard_token_text_pca','wb'))
265 | elif args.hard_token_type=='ica':
266 | try:
267 | t1 = time()
268 | hard_token_image = pickle.load(open(args.data_path + args.dataset + '/hard_token_image_ica','rb'))
269 | hard_token_text = pickle.load(open(args.data_path + args.dataset + '/hard_token_text_ica','rb'))
270 | print('loaded cached hard tokens in %.2fs' % (time() - t1))
271 | except Exception:
272 | hard_token_image = FastICA(n_components=args.embed_size, random_state=12).fit_transform(image_feats)
273 | hard_token_text = FastICA(n_components=args.embed_size, random_state=12).fit_transform(text_feats)
274 | pickle.dump(hard_token_image, open(args.data_path + args.dataset + '/hard_token_image_ica','wb'))
275 | pickle.dump(hard_token_text, open(args.data_path + args.dataset + '/hard_token_text_ica','wb'))
276 | elif args.hard_token_type=='isomap':
277 | hard_token_image = manifold.Isomap(n_neighbors=5, n_components=args.embed_size, n_jobs=-1).fit_transform(image_feats)
278 | hard_token_text = manifold.Isomap(n_neighbors=5, n_components=args.embed_size, n_jobs=-1).fit_transform(text_feats)
279 | # elif args.hard_token_type=='tsne':
280 | # hard_token_image = TSNE(n_components=args.embed_size, n_iter=300).fit_transform(image_feats)
281 | # hard_token_text = TSNE(n_components=args.embed_size, n_iter=300).fit_transform(text_feats)
282 | # elif args.hard_token_type=='lda':
283 | # hard_token_image = LinearDiscriminantAnalysis(n_components=args.embed_size).fit_transform(image_feats)
284 | # hard_token_text = LinearDiscriminantAnalysis(n_components=args.embed_size).fit_transform(text_feats)
285 |
286 | # self.item_hard_token = nn.Embedding.from_pretrained(torch.mean((torch.stack((torch.tensor(hard_token_image).float(), torch.tensor(hard_token_text).float()))), dim=0), freeze=False).cuda().weight
287 | # self.user_hard_token = nn.Embedding.from_pretrained(torch.mm(ui_graph, self.item_hard_token), freeze=False).cuda().weight
288 |
289 | self.item_hard_token = torch.mean((torch.stack((torch.tensor(hard_token_image).float(), torch.tensor(hard_token_text).float()))), dim=0).cuda()
290 | self.user_hard_token = torch.mm(ui_graph, self.item_hard_token).cuda()
291 |
292 | self.trans_user = nn.Linear(args.embed_size, args.embed_size).cuda()
293 | self.trans_item = nn.Linear(args.embed_size, args.embed_size).cuda()
294 | # nn.init.xavier_uniform_(self.gnn_trans_user.weight)
295 | # nn.init.xavier_uniform_(self.gnn_trans_item.weight)
296 | # self.gnn_trans_user = self.gnn_trans_user.cuda()
297 | # self.gnn_trans_item = self.gnn_trans_item.cuda()
298 | # self.item_hard_token = torch.mean((torch.stack((torch.tensor(hard_token_image).float(), torch.tensor(hard_token_text).float()))), dim=0).cuda()
299 |
300 |
301 | def forward(self):
302 | # self.user_hard_token = self.gnn_trans_user(torch.mm(self.ui_graph, self.item_hard_token))
303 | # self.item_hard_token = self.gnn_trans_item(self.item_hard_token)
304 | # return self.user_hard_token , self.item_hard_token
305 | return F.dropout(self.trans_user(self.user_hard_token), args.prompt_dropout) , F.dropout(self.trans_item(self.item_hard_token), args.prompt_dropout)
306 |
307 |
308 |
309 |
310 | class Student_LightGCN(nn.Module):
311 | def __init__(self, n_users, n_items, embedding_dim, gnn_layer, dropout_list, image_feats=None, text_feats=None):
312 | super().__init__()
313 | self.n_users = n_users
314 | self.n_items = n_items
315 | self.embedding_dim = embedding_dim
316 | self.n_ui_layers = gnn_layer
317 |
318 | self.user_id_embedding = nn.Embedding(n_users, embedding_dim)
319 | self.item_id_embedding = nn.Embedding(n_items, embedding_dim)
320 | nn.init.xavier_uniform_(self.user_id_embedding.weight)
321 | nn.init.xavier_uniform_(self.item_id_embedding.weight)
322 |
323 | # self.feat_trans = nn.Linear(args.embed_size, args.student_embed_size)
324 | # # self.text_trans = nn.Linear(text_feats.shape[1], args.embed_size)
325 | # nn.init.xavier_uniform_(self.feat_trans.weight)
326 | # # nn.init.xavier_uniform_(self.text_trans.weight)
327 |
328 | def init_user_item_embed(self, pre_u_embed, pre_i_embed):
329 | self.user_id_embedding = nn.Embedding.from_pretrained(pre_u_embed, freeze=False)
330 | self.item_id_embedding = nn.Embedding.from_pretrained(pre_i_embed, freeze=False)
331 |
332 | self.user_id_embedding_pre = nn.Embedding.from_pretrained(pre_u_embed, freeze=False)
333 | self.item_id_embedding_pre = nn.Embedding.from_pretrained(pre_i_embed, freeze=False)
334 |
335 | def get_embedding(self):
336 | return self.user_id_embedding, self.item_id_embedding
337 |
338 | def forward(self, adj, image_item_embeds, text_item_embeds, image_user_embeds, text_user_embeds):
339 |
340 | # # teacher_feat_dict = { 'item_image':t_i_image_embed.deteach(),'item_text':t_i_text_embed.deteach(),'user_image':t_u_image_embed.deteach(),'user_text':t_u_text_embed.deteach() }
341 | # tmp_feat_dict = {}
342 | # for index,value in enumerate(teacher_feat_dict.keys()):
343 | # tmp_feat_dict[value] = self.feat_trans(teacher_feat_dict[value])
344 | # u_g_embeddings = self.user_id_embedding.weight + args.model_cat_rate*F.normalize(tmp_feat_dict['user_image'], p=2, dim=1) + args.model_cat_rate*F.normalize(tmp_feat_dict['user_text'], p=2, dim=1)
345 | # i_g_embeddings = self.item_id_embedding.weight + args.model_cat_rate*F.normalize(tmp_feat_dict['item_image'], p=2, dim=1) + args.model_cat_rate*F.normalize(tmp_feat_dict['item_text'], p=2, dim=1)
346 | # ego_embeddings = torch.cat((u_g_embeddings, i_g_embeddings), dim=0)
347 |
348 | # self.user_id_embedding_pre = nn.Embedding.from_pretrained(pre_u_embed, freeze=False)
349 | # self.item_id_embedding_pre = nn.Embedding.from_pretrained(pre_i_embed, freeze=False)
350 |
351 | ego_embeddings = torch.cat((self.user_id_embedding.weight+self.user_id_embedding_pre.weight, self.item_id_embedding.weight+self.item_id_embedding_pre.weight), dim=0)  # requires init_user_item_embed() to have been called first
352 | # ego_embeddings = torch.cat((self.user_id_embedding.weight, self.item_id_embedding.weight), dim=0)
353 | all_embeddings = [ego_embeddings]
354 | for i in range(self.n_ui_layers):
355 | side_embeddings = torch.sparse.mm(adj, ego_embeddings)
356 | ego_embeddings = side_embeddings
357 | all_embeddings += [ego_embeddings]
358 | all_embeddings = torch.stack(all_embeddings, dim=1)
359 | all_embeddings = all_embeddings.mean(dim=1, keepdim=False)
360 | u_g_embeddings, i_g_embeddings = torch.split(all_embeddings, [self.n_users, self.n_items], dim=0)
361 | # u_g_embeddings += teacher_feat_dict['user_image'] + teacher_feat_dict['user_text']
362 | # i_g_embeddings += teacher_feat_dict['item_image'] + teacher_feat_dict['item_text']
363 | u_g_embeddings = u_g_embeddings + args.model_cat_rate*F.normalize(image_user_embeds, p=2, dim=1) + args.model_cat_rate*F.normalize(text_user_embeds, p=2, dim=1)
364 | i_g_embeddings = i_g_embeddings + args.model_cat_rate*F.normalize(image_item_embeds, p=2, dim=1) + args.model_cat_rate*F.normalize(text_item_embeds, p=2, dim=1)
365 |
366 | return u_g_embeddings, i_g_embeddings
367 | # return self.user_id_embedding.weight, self.item_id_embedding.weight
368 |
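# This Models_empower variant of Student_LightGCN also injects the teacher's
# modality views at the output. A call sketch (tensor names are assumptions):
#
#   u_emb, i_emb = student(adj, t_img_item.detach(), t_txt_item.detach(),
#                          t_img_user.detach(), t_txt_user.detach())
#
# Each modality embedding is L2-normalized and added with weight
# args.model_cat_rate, mirroring the teacher's fusion rule.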
369 |
370 |
371 | class Student_GCN(nn.Module):
372 | def __init__(self, n_users, n_items, embedding_dim, gnn_layer=2, drop_out=0., image_feats=None, text_feats=None):
373 | super(Student_GCN, self).__init__()
374 | self.embedding_dim = embedding_dim
375 |
376 | # self.layers = nn.Sequential(GraphConvolution(self.embedding_dim, self.embedding_dim, activation=F.relu, dropout=args.student_drop_rate, is_sparse_inputs=True),
377 | # GraphConvolution(self.embedding_dim, self.embedding_dim, activation=F.relu, dropout=args.student_drop_rate, is_sparse_inputs=False),
378 | # )
379 | # self.layer_list = nn.ModuleList()
380 | # for i in range(args.student_n_layers):
381 | # self.layer_list.append(GraphConvolution(self.embedding_dim, self.embedding_dim, activation=F.relu, dropout=args.student_drop_rate, is_sparse_inputs=False))
382 |
383 | self.trans_user = nn.Linear(args.embed_size, args.embed_size).cuda()
384 | self.trans_item = nn.Linear(args.embed_size, args.embed_size).cuda()
385 |
386 |
387 | def forward(self, user_x, item_x, ui_graph, iu_graph):
388 | # # x, support = inputs
389 | # # user_x, item_x = self.layers((user_x, item_x, ui_graph, iu_graph))
390 | # for i in range(args.student_n_layers):
391 | # user_x, item_x = self.layer_list[i](user_x, item_x, ui_graph, iu_graph)
392 | # return user_x, item_x
393 |
394 | return self.trans_user(user_x), self.trans_item(item_x)
395 | # self.user_id_embedding = nn.Embedding.from_pretrained(user_x, freeze=True)
396 | # self.item_id_embedding = nn.Embedding.from_pretrained(item_x, freeze=True)
397 | # return self.user_id_embedding.weight, self.item_id_embedding.weight
398 |
399 | def l2_loss(self):
400 | # sum of squared L2 norms over the trainable parameters; the original
401 | # read self.layers, which is commented out above and no longer exists
402 | loss = None
403 |
404 | for p in self.parameters():
405 | if loss is None:
406 | loss = p.pow(2).sum()
407 | else:
408 | loss += p.pow(2).sum()
409 |
410 | return loss
411 |
412 | class GraphConvolution(nn.Module):
413 | def __init__(self, input_dim, output_dim, dropout=0., is_sparse_inputs=False, bias=False, activation = F.relu,featureless=False):
414 | super(GraphConvolution, self).__init__()
415 | self.dropout = dropout
416 | self.bias = bias
417 | self.activation = activation
418 | self.is_sparse_inputs = is_sparse_inputs
419 | self.featureless = featureless
420 | # self.num_features_nonzero = num_features_nonzero
421 | # self.user_weight = nn.Parameter(torch.randn(input_dim, output_dim))
422 | # self.item_weight = nn.Parameter(torch.randn(input_dim, output_dim))
423 | self.user_weight = nn.Parameter(torch.empty(input_dim, output_dim))
424 | self.item_weight = nn.Parameter(torch.empty(input_dim, output_dim))
425 | nn.init.xavier_uniform_(self.user_weight)
426 | nn.init.xavier_uniform_(self.item_weight)
427 | self.bias = None
428 | if bias:
429 | self.bias = nn.Parameter(torch.zeros(output_dim))
430 |
431 |
432 | def forward(self, user_x, item_x, ui_graph, iu_graph):
433 | # print('inputs:', inputs)
434 | # x, support = inputs
435 | # if self.training and self.is_sparse_inputs:
436 | # x = sparse_dropout(x, self.dropout, self.num_features_nonzero)
437 | # elif self.training:
438 | user_x = F.dropout(user_x, self.dropout)
439 | item_x = F.dropout(item_x, self.dropout)
440 | # convolve
441 | if not self.featureless: # if it has features x
442 | if self.is_sparse_inputs:
443 | xw_user = torch.sparse.mm(user_x, self.user_weight)  # bug fix: the original overwrote a single xw and never defined xw_user/xw_item
444 | xw_item = torch.sparse.mm(item_x, self.item_weight)
445 | else:
446 | xw_user = torch.mm(user_x, self.user_weight)
447 | xw_item = torch.mm(item_x, self.item_weight)
448 | else:
449 | xw_user, xw_item = self.user_weight, self.item_weight  # featureless: use the weights directly (self.weight did not exist)
450 | out_user = torch.sparse.mm(ui_graph, xw_item)
451 | out_item = torch.sparse.mm(iu_graph, xw_user)
452 |
453 | if self.bias is not None:
454 | out_user, out_item = out_user + self.bias, out_item + self.bias  # bug fix: `out` was undefined
455 | return self.activation(out_user), self.activation(out_item)
456 |
457 |
458 | def sparse_dropout(x, rate, noise_shape):
459 | """
460 | :param x:
461 | :param rate:
462 | :param noise_shape: int scalar
463 | :return:
464 | """
465 | random_tensor = 1 - rate
466 | random_tensor += torch.rand(noise_shape).to(x.device)
467 | dropout_mask = torch.floor(random_tensor).bool()  # True where rand >= rate, i.e. the entry survives
468 | i = x._indices() # [2, 49216]
469 | v = x._values() # [49216]
470 | # keep only the surviving nonzeros in the COO indices and values
471 | i = i[:, dropout_mask]
472 | v = v[dropout_mask]
473 | out = torch.sparse_coo_tensor(i, v, x.shape).to(x.device)  # torch.sparse.FloatTensor is deprecated
474 | out = out * (1. / (1 - rate))  # rescale so the expected value is unchanged
475 | return out
476 |
477 |
478 | def dot(x, y, sparse=False):
479 | if sparse:
480 | res = torch.sparse.mm(x, y)
481 | else:
482 | res = torch.mm(x, y)
483 | return res
484 |
485 |
486 |
487 |
488 |
489 | class BLMLP(nn.Module):
490 | def __init__(self):
491 | super(BLMLP, self).__init__()
492 | self.W = nn.Parameter(nn.init.xavier_uniform_(torch.empty(args.student_embed_size, args.student_embed_size)))
493 | self.act = nn.LeakyReLU(negative_slope=0.5)
494 |
495 | def forward(self, embeds):
496 | return self.featureExtract(embeds)  # the original body was a bare `pass`, which silently returned None
497 |
498 | def featureExtract(self, embeds):
499 | return self.act(embeds @ self.W) + embeds
500 |
501 | def pairPred(self, embeds1, embeds2):
502 | return (self.featureExtract(embeds1) * self.featureExtract(embeds2)).sum(dim=-1)
503 |
504 | def crossPred(self, embeds1, embeds2):
505 | return self.featureExtract(embeds1) @ self.featureExtract(embeds2).T
506 |
507 |
508 |
509 | class Student_MLP(nn.Module):
510 | def __init__(self):
511 | super(Student_MLP, self).__init__()
512 | # self.n_users = n_users
513 | # self.n_items = n_items
514 | # self.embedding_dim = embedding_dim
515 |
516 | # self.uEmbeds = nn.Parameter(init(torch.empty(args.user, args.latdim)))
517 | # self.iEmbeds = nn.Parameter(init(torch.empty(args.item, args.latdim)))
518 |
519 | self.user_trans = nn.Linear(args.embed_size, args.embed_size)
520 | self.item_trans = nn.Linear(args.embed_size, args.embed_size)
521 | nn.init.xavier_uniform_(self.user_trans.weight)
522 | nn.init.xavier_uniform_(self.item_trans.weight)
523 |
524 | self.MLP = BLMLP()
525 | # self.overallTime = datetime.timedelta(0)
526 |
527 |
528 | def get_embedding(self):
529 | return self.user_id_embedding, self.item_id_embedding
530 |
531 |
532 | def forward(self, pre_user, pre_item, ):
533 | # pre_user, pre_item = self.user_id_embedding.weight, self.item_id_embedding.weight
534 | user_embed = self.user_trans(pre_user)
535 | item_embed = self.item_trans(pre_item)  # bug fix: was self.user_trans, leaving item_trans unused
536 |
537 | return user_embed, item_embed
538 | # return pre_user, pre_item
539 |
540 | def init_user_item_embed(self, pre_u_embed, pre_i_embed):
541 | self.user_id_embedding = nn.Embedding.from_pretrained(pre_u_embed, freeze=False)
542 | self.item_id_embedding = nn.Embedding.from_pretrained(pre_i_embed, freeze=False)
543 |
544 | def pointPosPredictwEmbeds(self, uEmbeds, iEmbeds, ancs, poss):
545 | ancEmbeds = uEmbeds[ancs]
546 | posEmbeds = iEmbeds[poss]
547 | nume = self.MLP.pairPred(ancEmbeds, posEmbeds)
548 | return nume
549 |
550 | def pointNegPredictwEmbeds(self, embeds1, embeds2, nodes1, temp=1.0):
551 | pckEmbeds1 = embeds1[nodes1]
552 | preds = self.MLP.crossPred(pckEmbeds1, embeds2)
553 | return torch.exp(preds / temp).sum(-1)
554 |
555 | def pairPredictwEmbeds(self, uEmbeds, iEmbeds, ancs, poss, negs):
556 | ancEmbeds = uEmbeds[ancs]
557 | posEmbeds = iEmbeds[poss]
558 | negEmbeds = iEmbeds[negs]
559 | posPreds = self.MLP.pairPred(ancEmbeds, posEmbeds)
560 | negPreds = self.MLP.pairPred(ancEmbeds, negEmbeds)
561 | return posPreds - negPreds
562 |
563 | def predAll(self, pckUEmbeds, iEmbeds):
564 | return self.MLP.crossPred(pckUEmbeds, iEmbeds)
565 |
566 | def testPred(self, usr, trnMask):
567 | uEmbeds, iEmbeds = self.forward(self.user_id_embedding.weight, self.item_id_embedding.weight)  # forward() requires the pretrained embeddings
568 | allPreds = self.predAll(uEmbeds[usr], iEmbeds) * (1 - trnMask) - trnMask * 1e8  # mask out training interactions
569 | return allPreds
570 |
571 |
--------------------------------------------------------------------------------
/codes/Models_mmlight.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from time import time
4 | import pickle
5 |
6 | import scipy.sparse as sp
7 | from scipy.sparse import csr_matrix
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.nn import init
13 |
14 | from sklearn.decomposition import PCA, FastICA
15 | from sklearn import manifold
16 | from sklearn.manifold import TSNE
17 | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
18 |
19 | # from utility.parser import parse_args
20 | from utility.norm import build_sim, build_knn_normalized_graph
21 | # args = parse_args()
22 | from utility.parser import args
23 |
24 |
25 |
26 | class Teacher_Model(nn.Module):
27 | def __init__(self, n_users, n_items, embedding_dim, weight_size, dropout_list, image_feats, text_feats):
28 |
29 | super().__init__()
30 | self.n_users = n_users
31 | self.n_items = n_items
32 | self.embedding_dim = embedding_dim
33 | self.weight_size = weight_size
34 | self.n_ui_layers = len(self.weight_size)
35 | self.weight_size = [self.embedding_dim] + self.weight_size
36 |
37 | self.image_trans = nn.Linear(image_feats.shape[1], args.embed_size)
38 | self.text_trans = nn.Linear(text_feats.shape[1], args.embed_size)
39 | nn.init.xavier_uniform_(self.image_trans.weight)
40 | nn.init.xavier_uniform_(self.text_trans.weight)
41 | self.encoder = nn.ModuleDict()
42 | self.encoder['image_encoder'] = self.image_trans  # expose the modality encoders by name
43 | self.encoder['text_encoder'] = self.text_trans
44 |
45 | self.user_id_embedding = nn.Embedding(n_users, self.embedding_dim)
46 | self.item_id_embedding = nn.Embedding(n_items, self.embedding_dim)
47 |
48 | nn.init.xavier_uniform_(self.user_id_embedding.weight)
49 | nn.init.xavier_uniform_(self.item_id_embedding.weight)
50 | self.image_feats = torch.tensor(image_feats).float().cuda()
51 | self.text_feats = torch.tensor(text_feats).float().cuda()
52 | self.image_embedding = nn.Embedding.from_pretrained(torch.Tensor(image_feats), freeze=False)
53 | self.text_embedding = nn.Embedding.from_pretrained(torch.Tensor(text_feats), freeze=False)
54 |
55 | self.softmax = nn.Softmax(dim=-1)
56 | self.act = nn.Sigmoid()
57 | self.sigmoid = nn.Sigmoid()
58 | self.dropout = nn.Dropout(p=args.drop_rate)
59 | self.batch_norm = nn.BatchNorm1d(args.embed_size)
60 |
61 | def mm(self, x, y):
62 | if args.sparse:
63 | return torch.sparse.mm(x, y)
64 | else:
65 | return torch.mm(x, y)
66 | def sim(self, z1, z2):
67 | z1 = F.normalize(z1)
68 | z2 = F.normalize(z2)
69 | return torch.mm(z1, z2.t())
70 |
71 | def batched_contrastive_loss(self, z1, z2, batch_size=4096):
72 | device = z1.device
73 | num_nodes = z1.size(0)
74 | num_batches = (num_nodes - 1) // batch_size + 1
75 | f = lambda x: torch.exp(x / getattr(self, 'tau', 0.5))  # temperature; self.tau is never set in this class, so fall back to an assumed 0.5
76 | indices = torch.arange(0, num_nodes).to(device)
77 | losses = []
78 |
79 | for i in range(num_batches):
80 | mask = indices[i * batch_size:(i + 1) * batch_size]
81 | refl_sim = f(self.sim(z1[mask], z1))
82 | between_sim = f(self.sim(z1[mask], z2))
83 |
84 | losses.append(-torch.log(
85 | between_sim[:, i * batch_size:(i + 1) * batch_size].diag()
86 | / (refl_sim.sum(1) + between_sim.sum(1)
87 | - refl_sim[:, i * batch_size:(i + 1) * batch_size].diag())))
88 |
89 | loss_vec = torch.cat(losses)
90 | return loss_vec.mean()
91 |
92 | def csr_norm(self, csr_mat, mean_flag=False):
93 | rowsum = np.array(csr_mat.sum(1))
94 | rowsum = np.power(rowsum+1e-8, -0.5).flatten()
95 | rowsum[np.isinf(rowsum)] = 0.
96 | rowsum_diag = sp.diags(rowsum)
97 |
98 | colsum = np.array(csr_mat.sum(0))
99 | colsum = np.power(colsum+1e-8, -0.5).flatten()
100 | colsum[np.isinf(colsum)] = 0.
101 | colsum_diag = sp.diags(colsum)
102 |
103 | if not mean_flag:
104 | return rowsum_diag*csr_mat*colsum_diag
105 | else:
106 | return rowsum_diag*csr_mat
107 |
108 | def matrix_to_tensor(self, cur_matrix):
109 | if not isinstance(cur_matrix, sp.coo_matrix):
110 | cur_matrix = cur_matrix.tocoo()
111 | indices = torch.from_numpy(np.vstack((cur_matrix.row, cur_matrix.col)).astype(np.int64))  # COO row/col index pairs
112 | values = torch.from_numpy(cur_matrix.data)  # nonzero entries
113 | shape = torch.Size(cur_matrix.shape)
114 |
115 | return torch.sparse_coo_tensor(indices, values, shape).to(torch.float32).cuda()  # torch.sparse.FloatTensor is deprecated
116 |
117 | def para_dict_to_tenser(self, para_dict):
118 | """
119 | :param para_dict: nn.ParameterDict()
120 | :return: tensor
121 | """
122 | tensors = []
123 |
124 | for beh in para_dict.keys():
125 | tensors.append(para_dict[beh])
126 | tensors = torch.stack(tensors, dim=0)
127 |
128 | return tensors
129 |
130 |
131 | def multi_head_self_attention(self, trans_w, embedding_t_1, embedding_t):
132 |
133 | q = self.para_dict_to_tenser(embedding_t)
134 | v = k = self.para_dict_to_tenser(embedding_t_1)
135 | beh, N, d_h = q.shape[0], q.shape[1], args.embed_size/args.head_num
136 |
137 | Q = torch.matmul(q, trans_w['w_q'])
138 | K = torch.matmul(k, trans_w['w_k'])
139 | V = v
140 |
141 | Q = Q.reshape(beh, N, args.head_num, int(d_h)).permute(2, 0, 1, 3)
142 | K = K.reshape(beh, N, args.head_num, int(d_h)).permute(2, 0, 1, 3)  # bug fix: reshape K rather than overwriting it with Q
143 |
144 | Q = torch.unsqueeze(Q, 2)
145 | K = torch.unsqueeze(K, 1)
146 | V = torch.unsqueeze(V, 1)
147 |
148 | att = torch.mul(Q, K) / torch.sqrt(torch.tensor(d_h))
149 | att = torch.sum(att, dim=-1)
150 | att = torch.unsqueeze(att, dim=-1)
151 | att = F.softmax(att, dim=2)
152 |
153 | Z = torch.mul(att, V)
154 | Z = torch.sum(Z, dim=2)
155 |
156 | Z_list = [value for value in Z]
157 | Z = torch.cat(Z_list, -1)
158 | Z = torch.matmul(Z, self.weight_dict['w_self_attention_cat'])
159 |
160 | Z = args.model_cat_rate * F.normalize(Z, p=2, dim=2)  # scale the fused output; the original statement discarded this result
161 | return Z, att.detach()
162 |
163 | # def prompt_tuning(self, soft_token_u, soft_token_i):
164 | # # self.user_id_embedding = nn.Embedding.from_pretrained(soft_token_u, freeze=False)
165 | # # self.item_id_embedding = nn.Embedding.from_pretrained(soft_token_u, freeze=False)
166 | # self.prompt_user = soft_token_u
167 | # self.prompt_item = soft_token_i
168 |
169 | def forward(self, ui_graph, iu_graph, prompt_module=None):
170 |
171 | # def forward(self, ui_graph, iu_graph):
172 |
173 | prompt_user, prompt_item = prompt_module() # [n*32]
174 | # ----feature prompt----
175 | # feat_prompt_user = torch.mean( torch.stack((torch.mm(prompt_user, torch.mm(prompt_user.T, self.image_feats)), torch.mm(prompt_user, torch.mm(prompt_user.T, self.text_feats)))), dim=0 )
176 | # feat_prompt_user = torch.mm(prompt_user, torch.mm(prompt_user.T, self.text_feats))
177 | feat_prompt_item_image = torch.mm(prompt_item, torch.mm(prompt_item.T, self.image_feats))
178 | feat_prompt_item_text = torch.mm(prompt_item, torch.mm(prompt_item.T, self.text_feats))
179 | # feat_prompt_image_item = torch.mm(prompt_item, torch.mm(prompt_item.T, self.image_feats))
180 | # feat_prompt_text_item = torch.mm(prompt_item, torch.mm(prompt_item.T, self.text_feats))
181 | # ----feature prompt----
182 | # image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats + feat_prompt_item_image ))
183 | # text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats + feat_prompt_item_text ))
184 |
185 | # image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats + F.normalize(feat_prompt_item_image, p=2, dim=1) ))
186 | # text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats + F.normalize(feat_prompt_item_text, p=2, dim=1) ))
187 |
188 | image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats + args.feat_soft_token_rate*F.normalize(feat_prompt_item_image, p=2, dim=1) ))
189 | text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats + args.feat_soft_token_rate*F.normalize(feat_prompt_item_text, p=2, dim=1) ))
190 | # args.feat_soft_token_rate*F.normalize(feat_prompt_item_image, p=2, dim=1)
191 | # args.feat_soft_token_rate*F.normalize(feat_prompt_item_text, p=2, dim=1)
192 |
193 |
194 | # image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats))
195 | # text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats))
196 |
197 | for i in range(args.layers):
198 | image_user_feats = self.mm(ui_graph, image_feats)
199 | image_item_feats = self.mm(iu_graph, image_user_feats)
200 | # image_user_id = self.mm(image_ui_graph, self.item_id_embedding.weight)
201 | # image_item_id = self.mm(image_iu_graph, self.user_id_embedding.weight)
202 |
203 | text_user_feats = self.mm(ui_graph, text_feats)
204 | text_item_feats = self.mm(iu_graph, text_user_feats)
205 |
206 | # text_user_id = self.mm(text_ui_graph, self.item_id_embedding.weight)
207 | # text_item_id = self.mm(text_iu_graph, self.user_id_embedding.weight)
208 |
209 | # self.embedding_dict['user']['image'] = image_user_id
210 | # self.embedding_dict['user']['text'] = text_user_id
211 | # self.embedding_dict['item']['image'] = image_item_id
212 | # self.embedding_dict['item']['text'] = text_item_id
213 | # user_z, att_u = self.multi_head_self_attention(self.weight_dict, self.embedding_dict['user'], self.embedding_dict['user'])
214 | # item_z, att_i = self.multi_head_self_attention(self.weight_dict, self.embedding_dict['item'], self.embedding_dict['item'])
215 | # user_emb = user_z.mean(0)
216 | # item_emb = item_z.mean(0)
217 | u_g_embeddings = self.user_id_embedding.weight + args.soft_token_rate*F.normalize(prompt_user, p=2, dim=1)
218 | i_g_embeddings = self.item_id_embedding.weight + args.soft_token_rate*F.normalize(prompt_item, p=2, dim=1)
219 | user_emb_list = [u_g_embeddings]
220 | item_emb_list = [i_g_embeddings]
221 | for i in range(self.n_ui_layers):
222 | if i == (self.n_ui_layers-1):
223 | u_g_embeddings = self.softmax( torch.mm(ui_graph, i_g_embeddings) )
224 | i_g_embeddings = self.softmax( torch.mm(iu_graph, u_g_embeddings) )
225 |
226 | else:
227 | u_g_embeddings = torch.mm(ui_graph, i_g_embeddings)
228 | i_g_embeddings = torch.mm(iu_graph, u_g_embeddings)
229 |
230 | user_emb_list.append(u_g_embeddings)
231 | item_emb_list.append(i_g_embeddings)
232 |
233 | u_g_embeddings = torch.mean(torch.stack(user_emb_list), dim=0)
234 | i_g_embeddings = torch.mean(torch.stack(item_emb_list), dim=0)
235 |
236 |
237 | u_g_embeddings = u_g_embeddings + args.model_cat_rate*F.normalize(image_user_feats, p=2, dim=1) + args.model_cat_rate*F.normalize(text_user_feats, p=2, dim=1)
238 | i_g_embeddings = i_g_embeddings + args.model_cat_rate*F.normalize(image_item_feats, p=2, dim=1) + args.model_cat_rate*F.normalize(text_item_feats, p=2, dim=1)
239 |
240 | return u_g_embeddings, i_g_embeddings, image_item_feats, text_item_feats, image_user_feats, text_user_feats, u_g_embeddings, i_g_embeddings , prompt_user, prompt_item
241 |
242 |
243 |
244 |
245 |
246 |
247 |
248 | class PromptLearner(nn.Module):
249 | def __init__(self, image_feats=None, text_feats=None, ui_graph=None):
250 | super().__init__()
251 | self.ui_graph = ui_graph
252 |
253 |
254 | if args.hard_token_type=='pca':
255 | try:
256 | t1 = time()
257 | hard_token_image = pickle.load(open(args.data_path + args.dataset + '/hard_token_image_pca','rb'))
258 | hard_token_text = pickle.load(open(args.data_path + args.dataset + '/hard_token_text_pca','rb'))
259 | print('loaded cached hard tokens in %.2fs' % (time() - t1))
260 | except Exception:
261 | hard_token_image = PCA(n_components=args.embed_size).fit_transform(image_feats)
262 | hard_token_text = PCA(n_components=args.embed_size).fit_transform(text_feats)
263 | pickle.dump(hard_token_image, open(args.data_path + args.dataset + '/hard_token_image_pca','wb'))
264 | pickle.dump(hard_token_text, open(args.data_path + args.dataset + '/hard_token_text_pca','wb'))
265 | elif args.hard_token_type=='ica':
266 | try:
267 | t1 = time()
268 | hard_token_image = pickle.load(open(args.data_path + args.dataset + '/hard_token_image_ica','rb'))
269 | hard_token_text = pickle.load(open(args.data_path + args.dataset + '/hard_token_text_ica','rb'))
270 | print('loaded cached hard tokens', time() - t1)
271 | except Exception:
272 | hard_token_image = FastICA(n_components=args.embed_size, random_state=12).fit_transform(image_feats)
273 | hard_token_text = FastICA(n_components=args.embed_size, random_state=12).fit_transform(text_feats)
274 | pickle.dump(hard_token_image, open(args.data_path + args.dataset + '/hard_token_image_ica','wb'))
275 | pickle.dump(hard_token_text, open(args.data_path + args.dataset + '/hard_token_text_ica','wb'))
276 | elif args.hard_token_type=='isomap':
277 | hard_token_image = manifold.Isomap(n_neighbors=5, n_components=args.embed_size, n_jobs=-1).fit_transform(image_feats)
278 | hard_token_text = manifold.Isomap(n_neighbors=5, n_components=args.embed_size, n_jobs=-1).fit_transform(text_feats)
279 | # elif args.hard_token_type=='tsne':
280 | # hard_token_image = TSNE(n_components=args.embed_size, n_iter=300).fit_transform(image_feats)
281 | # hard_token_text = TSNE(n_components=args.embed_size, n_iter=300).fit_transform(text_feats)
282 | # elif args.hard_token_type=='lda':
283 | # hard_token_image = LinearDiscriminantAnalysis(n_components=args.embed_size).fit_transform(image_feats)
284 | # hard_token_text = LinearDiscriminantAnalysis(n_components=args.embed_size).fit_transform(text_feats)
285 |
286 | # self.item_hard_token = nn.Embedding.from_pretrained(torch.mean((torch.stack((torch.tensor(hard_token_image).float(), torch.tensor(hard_token_text).float()))), dim=0), freeze=False).cuda().weight
287 | # self.user_hard_token = nn.Embedding.from_pretrained(torch.mm(ui_graph, self.item_hard_token), freeze=False).cuda().weight
288 |
289 | self.item_hard_token = torch.mean((torch.stack((torch.tensor(hard_token_image).float(), torch.tensor(hard_token_text).float()))), dim=0).cuda()
290 | self.user_hard_token = torch.mm(ui_graph, self.item_hard_token).cuda()
291 |
292 | self.trans_user = nn.Linear(args.embed_size, args.embed_size).cuda()
293 | self.trans_item = nn.Linear(args.embed_size, args.embed_size).cuda()
294 | # nn.init.xavier_uniform_(self.gnn_trans_user.weight)
295 | # nn.init.xavier_uniform_(self.gnn_trans_item.weight)
296 | # self.gnn_trans_user = self.gnn_trans_user.cuda()
297 | # self.gnn_trans_item = self.gnn_trans_item.cuda()
298 | # self.item_hard_token = torch.mean((torch.stack((torch.tensor(hard_token_image).float(), torch.tensor(hard_token_text).float()))), dim=0).cuda()
299 |
300 |
301 | def forward(self):
302 | # self.user_hard_token = self.gnn_trans_user(torch.mm(self.ui_graph, self.item_hard_token))
303 | # self.item_hard_token = self.gnn_trans_item(self.item_hard_token)
304 | # return self.user_hard_token , self.item_hard_token
305 | return F.dropout(self.trans_user(self.user_hard_token), args.prompt_dropout) , F.dropout(self.trans_item(self.item_hard_token), args.prompt_dropout)
306 |
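    | # Illustrative usage (editor sketch, not part of the original file):
    | #   prompt_module = PromptLearner(image_feats, text_feats, ui_graph)
    | #   prompt_user, prompt_item = prompt_module()  # dropout-regularized soft prompts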
307 |
308 |
309 |
310 |
311 |
312 |
313 |
314 |
315 |
316 |
317 |
318 |
319 |
320 |
321 |
322 |
323 | class Student_MMLight(nn.Module):
324 | def __init__(self, n_users, n_items, embedding_dim, n_layers, dropout, image_feats, text_feats, ui_graph, iu_graph):
325 |
326 | super().__init__()
327 | self.n_users = n_users
328 | self.n_items = n_items
329 | self.ui_graph = ui_graph
330 | self.iu_graph = iu_graph
331 | self.embedding_dim = embedding_dim
332 | # self.weight_size = weight_size
333 | self.n_ui_layers = n_layers
334 | # self.weight_size = [self.embedding_dim] + self.weight_size
335 |
336 | # self.image_trans = nn.Linear(image_feats.shape[1], args.embed_size)
337 | # self.text_trans = nn.Linear(text_feats.shape[1], args.embed_size)
338 | self.image_trans = nn.Linear(args.embed_size, args.embed_size)
339 | self.text_trans = nn.Linear(args.embed_size, args.embed_size)
340 | nn.init.xavier_uniform_(self.image_trans.weight)
341 | nn.init.xavier_uniform_(self.text_trans.weight)
342 |
343 | self.user_id_embedding = nn.Embedding(n_users, self.embedding_dim)
344 | self.item_id_embedding = nn.Embedding(n_items, self.embedding_dim)
345 |
346 | nn.init.xavier_uniform_(self.user_id_embedding.weight)
347 | nn.init.xavier_uniform_(self.item_id_embedding.weight)
348 | self.image_feats = torch.tensor(image_feats).float().cuda()
349 | self.text_feats = torch.tensor(text_feats).float().cuda()
350 |
351 | self.softmax = nn.Softmax(dim=-1)
352 | self.act = nn.Sigmoid()
353 | self.sigmoid = nn.Sigmoid()
354 | self.dropout = nn.Dropout(args.drop_rate)
355 | self.batch_norm = nn.BatchNorm1d(args.embed_size)
356 | self.tau = 0.5
357 |
358 |
359 | def mm(self, x, y):
360 | if args.sparse:
361 | return torch.sparse.mm(x, y)
362 | else:
363 | return torch.mm(x, y)
364 | def sim(self, z1, z2):
365 | z1 = F.normalize(z1)
366 | z2 = F.normalize(z2)
367 | return torch.mm(z1, z2.t())
368 |
369 |
370 | def init_user_item_embed(self, pre_u_embed, pre_i_embed):
371 | self.user_id_embedding = nn.Embedding.from_pretrained(pre_u_embed, freeze=False)
372 | self.item_id_embedding = nn.Embedding.from_pretrained(pre_i_embed, freeze=False)
373 |
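    | # Student forward pass: transform the modality features, run LightGCN-style
    | # propagation over the ID embeddings, then add the L2-normalized item-side
    | # modality features scaled by args.model_cat_rate (the user-side term is
    | # commented out below).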
374 | def forward(self, image_feats, text_feats, image_user_embeds, text_user_embeds):
375 | image_feats = image_item_feats = self.dropout(self.image_trans(image_feats))
376 | text_feats = text_item_feats = self.dropout(self.text_trans(text_feats))
377 |
378 | # image_user_feats = self.mm(self.ui_graph, image_feats)
379 | # image_item_feats = self.mm(self.iu_graph, image_user_feats)
380 | # text_user_feats = self.mm(self.ui_graph, text_feats)
381 | # text_item_feats = self.mm(self.iu_graph, text_user_feats)
382 |
383 | u_g_embeddings = self.user_id_embedding.weight
384 | i_g_embeddings = self.item_id_embedding.weight
385 |
386 | user_emb_list = [u_g_embeddings]
387 | item_emb_list = [i_g_embeddings]
388 | for i in range(self.n_ui_layers):
389 | if i == (self.n_ui_layers-1):
390 | u_g_embeddings = self.softmax( torch.mm(self.ui_graph, i_g_embeddings) )
391 | i_g_embeddings = self.softmax( torch.mm(self.iu_graph, u_g_embeddings) )
392 | else:
393 | u_g_embeddings = torch.mm(self.ui_graph, i_g_embeddings)
394 | i_g_embeddings = torch.mm(self.iu_graph, u_g_embeddings)
395 |
396 | user_emb_list.append(u_g_embeddings)
397 | item_emb_list.append(i_g_embeddings)
398 |
399 | u_g_embeddings = torch.mean(torch.stack(user_emb_list), dim=0)
400 | i_g_embeddings = torch.mean(torch.stack(item_emb_list), dim=0)
401 |
402 |
403 | u_g_embeddings = u_g_embeddings #+ args.model_cat_rate*F.normalize(image_user_feats, p=2, dim=1) + args.model_cat_rate*F.normalize(text_user_feats, p=2, dim=1)
404 | i_g_embeddings = i_g_embeddings + args.model_cat_rate*F.normalize(image_item_feats, p=2, dim=1) + args.model_cat_rate*F.normalize(text_item_feats, p=2, dim=1)
405 |
406 | # return u_g_embeddings, i_g_embeddings, image_item_feats, text_item_feats, image_user_feats, text_user_feats, u_g_embeddings, i_g_embeddings
407 | return u_g_embeddings, i_g_embeddings
408 |
409 |
410 |
411 |
412 |
413 |
414 |
415 |
416 |
417 |
418 |
419 |
420 |
421 |
422 |
423 |
424 |
425 |
426 |
427 |
428 |
429 |
430 | class Student_LightGCN(nn.Module):
431 | def __init__(self, n_users, n_items, embedding_dim, gnn_layer, dropout_list, image_feats=None, text_feats=None):
432 | super().__init__()
433 | self.n_users = n_users
434 | self.n_items = n_items
435 | self.embedding_dim = embedding_dim
436 | self.n_ui_layers = gnn_layer
437 |
438 | self.user_id_embedding = nn.Embedding(n_users, embedding_dim)
439 | self.item_id_embedding = nn.Embedding(n_items, embedding_dim)
440 | nn.init.xavier_uniform_(self.user_id_embedding.weight)
441 | nn.init.xavier_uniform_(self.item_id_embedding.weight)
442 |
443 | # self.feat_trans = nn.Linear(args.embed_size, args.student_embed_size)
444 | # # self.text_trans = nn.Linear(text_feats.shape[1], args.embed_size)
445 | # nn.init.xavier_uniform_(self.feat_trans.weight)
446 | # # nn.init.xavier_uniform_(self.text_trans.weight)
447 |
448 | def init_user_item_embed(self, pre_u_embed, pre_i_embed):
449 | self.user_id_embedding = nn.Embedding.from_pretrained(pre_u_embed, freeze=False)
450 | self.item_id_embedding = nn.Embedding.from_pretrained(pre_i_embed, freeze=False)
451 |
452 | self.user_id_embedding_pre = nn.Embedding.from_pretrained(pre_u_embed, freeze=False)
453 | self.item_id_embedding_pre = nn.Embedding.from_pretrained(pre_i_embed, freeze=False)
454 |
455 | def get_embedding(self):
456 | return self.user_id_embedding, self.item_id_embedding
457 |
458 | def forward(self, adj, image_item_embeds, text_item_embeds, image_user_embeds, text_user_embeds):
459 |
460 | # # teacher_feat_dict = { 'item_image':t_i_image_embed.deteach(),'item_text':t_i_text_embed.deteach(),'user_image':t_u_image_embed.deteach(),'user_text':t_u_text_embed.deteach() }
461 | # tmp_feat_dict = {}
462 | # for index,value in enumerate(teacher_feat_dict.keys()):
463 | # tmp_feat_dict[value] = self.feat_trans(teacher_feat_dict[value])
464 | # u_g_embeddings = self.user_id_embedding.weight + args.model_cat_rate*F.normalize(tmp_feat_dict['user_image'], p=2, dim=1) + args.model_cat_rate*F.normalize(tmp_feat_dict['user_text'], p=2, dim=1)
465 | # i_g_embeddings = self.item_id_embedding.weight + args.model_cat_rate*F.normalize(tmp_feat_dict['item_image'], p=2, dim=1) + args.model_cat_rate*F.normalize(tmp_feat_dict['item_text'], p=2, dim=1)
466 | # ego_embeddings = torch.cat((u_g_embeddings, i_g_embeddings), dim=0)
467 |
468 | # self.user_id_embedding_pre = nn.Embedding.from_pretrained(pre_u_embed, freeze=False)
469 | # self.item_id_embedding_pre = nn.Embedding.from_pretrained(pre_i_embed, freeze=False)
470 |
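    | # Sum the current ID embeddings with the pretrained copies stored by
    | # init_user_item_embed, then stack users and items for graph propagation.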
471 | ego_embeddings = torch.cat((self.user_id_embedding.weight+self.user_id_embedding_pre.weight, self.item_id_embedding.weight+self.item_id_embedding_pre.weight), dim=0)
472 | # ego_embeddings = torch.cat((self.user_id_embedding.weight, self.item_id_embedding.weight), dim=0)
473 | all_embeddings = [ego_embeddings]
474 | for i in range(self.n_ui_layers):
475 | side_embeddings = torch.sparse.mm(adj, ego_embeddings)
476 | ego_embeddings = side_embeddings
477 | all_embeddings += [ego_embeddings]
478 | all_embeddings = torch.stack(all_embeddings, dim=1)
479 | all_embeddings = all_embeddings.mean(dim=1, keepdim=False)
480 | u_g_embeddings, i_g_embeddings = torch.split(all_embeddings, [self.n_users, self.n_items], dim=0)
481 | # u_g_embeddings += teacher_feat_dict['user_image'] + teacher_feat_dict['user_text']
482 | # i_g_embeddings += teacher_feat_dict['item_image'] + teacher_feat_dict['item_text']
483 | u_g_embeddings = u_g_embeddings + args.model_cat_rate*F.normalize(image_user_embeds, p=2, dim=1) + args.model_cat_rate*F.normalize(text_user_embeds, p=2, dim=1)
484 | i_g_embeddings = i_g_embeddings + args.model_cat_rate*F.normalize(image_item_embeds, p=2, dim=1) + args.model_cat_rate*F.normalize(text_item_embeds, p=2, dim=1)
485 |
486 | return u_g_embeddings, i_g_embeddings
487 | # return self.user_id_embedding.weight, self.item_id_embedding.weight
488 |
489 |
490 |
491 | class Student_GCN(nn.Module):
492 | def __init__(self, n_users, n_items, embedding_dim, gnn_layer=2, drop_out=0., image_feats=None, text_feats=None):
493 | super(Student_GCN, self).__init__()
494 | self.embedding_dim = embedding_dim
495 |
496 | # self.layers = nn.Sequential(GraphConvolution(self.embedding_dim, self.embedding_dim, activation=F.relu, dropout=args.student_drop_rate, is_sparse_inputs=True),
497 | # GraphConvolution(self.embedding_dim, self.embedding_dim, activation=F.relu, dropout=args.student_drop_rate, is_sparse_inputs=False),
498 | # )
499 | # self.layer_list = nn.ModuleList()
500 | # for i in range(args.student_n_layers):
501 | # self.layer_list.append(GraphConvolution(self.embedding_dim, self.embedding_dim, activation=F.relu, dropout=args.student_drop_rate, is_sparse_inputs=False))
502 |
503 | self.trans_user = nn.Linear(args.embed_size, args.embed_size).cuda()
504 | self.trans_item = nn.Linear(args.embed_size, args.embed_size).cuda()
505 |
506 |
507 | def forward(self, user_x, item_x, ui_graph, iu_graph):
508 | # # x, support = inputs
509 | # # user_x, item_x = self.layers((user_x, item_x, ui_graph, iu_graph))
510 | # for i in range(args.student_n_layers):
511 | # user_x, item_x = self.layer_list[i](user_x, item_x, ui_graph, iu_graph)
512 | # return user_x, item_x
513 |
514 | return self.trans_user(user_x), self.trans_item(item_x)
515 | # self.user_id_embedding = nn.Embedding.from_pretrained(user_x, freeze=True)
516 | # self.item_id_embedding = nn.Embedding.from_pretrained(item_x, freeze=True)
517 | # return self.user_id_embedding.weight, self.item_id_embedding.weight
518 |
519 | def l2_loss(self):
520 | # self.layers is commented out above, so accumulate the L2 penalty
521 | # over this module's own parameters instead
522 | loss = None
523 |
524 | for p in self.parameters():
525 | if loss is None:
526 | loss = p.pow(2).sum()
527 | else:
528 | loss += p.pow(2).sum()
529 |
530 | return loss
531 |
532 | class GraphConvolution(nn.Module):
533 | def __init__(self, input_dim, output_dim, dropout=0., is_sparse_inputs=False, bias=False, activation=F.relu, featureless=False):
534 | super(GraphConvolution, self).__init__()
535 | self.dropout = dropout
536 | self.bias = bias
537 | self.activation = activation
538 | self.is_sparse_inputs = is_sparse_inputs
539 | self.featureless = featureless
540 | # self.num_features_nonzero = num_features_nonzero
541 | # self.user_weight = nn.Parameter(torch.randn(input_dim, output_dim))
542 | # self.item_weight = nn.Parameter(torch.randn(input_dim, output_dim))
543 | self.user_weight = nn.Parameter(torch.empty(input_dim, output_dim))
544 | self.item_weight = nn.Parameter(torch.empty(input_dim, output_dim))
545 | nn.init.xavier_uniform_(self.user_weight)
546 | nn.init.xavier_uniform_(self.item_weight)
547 | self.bias = None
548 | if bias:
549 | self.bias = nn.Parameter(torch.zeros(output_dim))
550 |
551 |
552 | def forward(self, user_x, item_x, ui_graph, iu_graph):
553 | # print('inputs:', inputs)
554 | # x, support = inputs
555 | # if self.training and self.is_sparse_inputs:
556 | # x = sparse_dropout(x, self.dropout, self.num_features_nonzero)
557 | # elif self.training:
558 | user_x = F.dropout(user_x, self.dropout)
559 | item_x = F.dropout(item_x, self.dropout)
560 | # convolve
561 | if not self.featureless: # if it has input features x
562 | if self.is_sparse_inputs:
563 | xw_user = torch.sparse.mm(user_x, self.user_weight)
564 | xw_item = torch.sparse.mm(item_x, self.item_weight)
565 | else:
566 | xw_user = torch.mm(user_x, self.user_weight)
567 | xw_item = torch.mm(item_x, self.item_weight)
568 | else: # featureless: propagate the weight matrices directly
569 | xw_user, xw_item = self.user_weight, self.item_weight
570 | out_user = torch.sparse.mm(ui_graph, xw_item)
571 | out_item = torch.sparse.mm(iu_graph, xw_user)
572 |
573 | if self.bias is not None:
574 | out_user, out_item = out_user + self.bias, out_item + self.bias
575 | return self.activation(out_user), self.activation(out_item)
576 |
577 |
578 | def sparse_dropout(x, rate, noise_shape):
579 | """
580 | :param x:
581 | :param rate:
582 | :param noise_shape: int scalar
583 | :return:
584 | """
585 | random_tensor = 1 - rate
586 | random_tensor += torch.rand(noise_shape).to(x.device)
587 | dropout_mask = torch.floor(random_tensor).byte()
588 | i = x._indices() # [2, 49216]
589 | v = x._values() # [49216]
590 | # [2, 49216] => [49216, 2] => [remained node, 2] => [2, remained node]
591 | i = i[:, dropout_mask]
592 | v = v[dropout_mask]
593 | out = torch.sparse.FloatTensor(i, v, x.shape).to(x.device)
594 | out = out * (1./ (1-rate))
595 | return out
596 |
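    | # Illustrative usage (editor sketch, not part of the original file): for a
    | # sparse adjacency `adj`, sparse_dropout(adj, 0.1, adj._nnz()) drops ~10% of
    | # the stored entries and rescales the survivors by 1/0.9.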
597 |
598 | def dot(x, y, sparse=False):
599 | if sparse:
600 | res = torch.sparse.mm(x, y)
601 | else:
602 | res = torch.mm(x, y)
603 | return res
604 |
605 |
606 |
607 |
608 |
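    | # BLMLP: a one-layer residual MLP used by Student_MLP for scoring.
    | # pairPred scores aligned (user, item) pairs elementwise, while crossPred
    | # scores every user against every item with a dense matrix product.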
609 | class BLMLP(nn.Module):
610 | def __init__(self):
611 | super(BLMLP, self).__init__()
612 | self.W = nn.Parameter(nn.init.xavier_uniform_(torch.empty(args.student_embed_size, args.student_embed_size)))
613 | self.act = nn.LeakyReLU(negative_slope=0.5)
614 |
615 | def forward(self, embeds):
616 | pass
617 |
618 | def featureExtract(self, embeds):
619 | return self.act(embeds @ self.W) + embeds
620 |
621 | def pairPred(self, embeds1, embeds2):
622 | return (self.featureExtract(embeds1) * self.featureExtract(embeds2)).sum(dim=-1)
623 |
624 | def crossPred(self, embeds1, embeds2):
625 | return self.featureExtract(embeds1) @ self.featureExtract(embeds2).T
626 |
627 |
628 |
629 | class Student_MLP(nn.Module):
630 | def __init__(self):
631 | super(Student_MLP, self).__init__()
632 | # self.n_users = n_users
633 | # self.n_items = n_items
634 | # self.embedding_dim = embedding_dim
635 |
636 | # self.uEmbeds = nn.Parameter(init(torch.empty(args.user, args.latdim)))
637 | # self.iEmbeds = nn.Parameter(init(torch.empty(args.item, args.latdim)))
638 |
639 | self.user_trans = nn.Linear(args.embed_size, args.embed_size)
640 | self.item_trans = nn.Linear(args.embed_size, args.embed_size)
641 | nn.init.xavier_uniform_(self.user_trans.weight)
642 | nn.init.xavier_uniform_(self.item_trans.weight)
643 |
644 | self.MLP = BLMLP()
645 | # self.overallTime = datetime.timedelta(0)
646 |
647 |
648 | def get_embedding(self):
649 | return self.user_id_embedding, self.item_id_embedding
650 |
651 |
652 | def forward(self, pre_user, pre_item):
653 | # pre_user, pre_item = self.user_id_embedding.weight, self.item_id_embedding.weight
654 | user_embed = self.user_trans(pre_user)
655 | item_embed = self.item_trans(pre_item) # use item_trans here; the original applied user_trans to items, which looks like a typo
656 |
657 | return user_embed, item_embed
658 | # return pre_user, pre_item
659 |
660 | def init_user_item_embed(self, pre_u_embed, pre_i_embed):
661 | self.user_id_embedding = nn.Embedding.from_pretrained(pre_u_embed, freeze=False)
662 | self.item_id_embedding = nn.Embedding.from_pretrained(pre_i_embed, freeze=False)
663 |
664 | def pointPosPredictwEmbeds(self, uEmbeds, iEmbeds, ancs, poss):
665 | ancEmbeds = uEmbeds[ancs]
666 | posEmbeds = iEmbeds[poss]
667 | nume = self.MLP.pairPred(ancEmbeds, posEmbeds)
668 | return nume
669 |
670 | def pointNegPredictwEmbeds(self, embeds1, embeds2, nodes1, temp=1.0):
671 | pckEmbeds1 = embeds1[nodes1]
672 | preds = self.MLP.crossPred(pckEmbeds1, embeds2)
673 | return torch.exp(preds / temp).sum(-1)
674 |
675 | def pairPredictwEmbeds(self, uEmbeds, iEmbeds, ancs, poss, negs):
676 | ancEmbeds = uEmbeds[ancs]
677 | posEmbeds = iEmbeds[poss]
678 | negEmbeds = iEmbeds[negs]
679 | posPreds = self.MLP.pairPred(ancEmbeds, posEmbeds)
680 | negPreds = self.MLP.pairPred(ancEmbeds, negEmbeds)
681 | return posPreds - negPreds
682 |
683 | def predAll(self, pckUEmbeds, iEmbeds):
684 | return self.MLP.crossPred(pckUEmbeds, iEmbeds)
685 |
686 | def testPred(self, usr, trnMask):
687 | uEmbeds, iEmbeds = self.forward(self.user_id_embedding.weight, self.item_id_embedding.weight)  # forward needs the pretrained embeddings set by init_user_item_embed
688 | allPreds = self.predAll(uEmbeds[usr], iEmbeds) * (1 - trnMask) - trnMask * 1e8
689 | return allPreds
690 |
691 |
--------------------------------------------------------------------------------
/codes/utility/batch_test.py:
--------------------------------------------------------------------------------
1 | import utility.metrics as metrics
2 | # from utility.parser import parse_args
3 | from utility.load_data import Data
4 | import multiprocessing
5 | import heapq
6 | import torch
7 | import pickle
8 | import numpy as np
9 | from time import time
10 |
11 | cores = multiprocessing.cpu_count() // 5
12 |
13 | # args = parse_args()
14 | from utility.parser import args
15 |
16 | # Ks = eval(args.Ks)
17 |
18 | data_generator = Data(path=args.data_path + args.dataset, batch_size=args.batch_size)
19 | USR_NUM, ITEM_NUM = data_generator.n_users, data_generator.n_items
20 | N_TRAIN, N_TEST = data_generator.n_train, data_generator.n_test
21 | BATCH_SIZE = args.batch_size
22 |
23 | def ranklist_by_heapq(user_pos_test, test_items, rating, Ks):
24 | item_score = {}
25 | for i in test_items:
26 | item_score[i] = rating[i]
27 |
28 | K_max = max(Ks)
29 | K_max_item_score = heapq.nlargest(K_max, item_score, key=item_score.get)
30 |
31 | r = []
32 | for i in K_max_item_score:
33 | if i in user_pos_test:
34 | r.append(1)
35 | else:
36 | r.append(0)
37 | auc = 0.
38 | return r, auc
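   | # ranklist_by_heapq skips AUC (returns 0.) since only the top-K is ranked;
   | # ranklist_by_sorted below also computes AUC over the full item list.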
39 |
40 | def get_auc(item_score, user_pos_test):
41 | item_score = sorted(item_score.items(), key=lambda kv: kv[1])
42 | item_score.reverse()
43 | item_sort = [x[0] for x in item_score]
44 | posterior = [x[1] for x in item_score]
45 |
46 | r = []
47 | for i in item_sort:
48 | if i in user_pos_test:
49 | r.append(1)
50 | else:
51 | r.append(0)
52 | auc = metrics.auc(ground_truth=r, prediction=posterior)
53 | return auc
54 |
55 | def ranklist_by_sorted(user_pos_test, test_items, rating, Ks):
56 | item_score = {}
57 | for i in test_items:
58 | item_score[i] = rating[i]
59 |
60 | K_max = max(Ks)
61 | K_max_item_score = heapq.nlargest(K_max, item_score, key=item_score.get)
62 |
63 | r = []
64 | for i in K_max_item_score:
65 | if i in user_pos_test:
66 | r.append(1)
67 | else:
68 | r.append(0)
69 | auc = get_auc(item_score, user_pos_test)
70 | return r, auc
71 |
72 | def get_performance(user_pos_test, r, auc, Ks):
73 | precision, recall, ndcg, hit_ratio = [], [], [], []
74 |
75 | for K in Ks:
76 | precision.append(metrics.precision_at_k(r, K))
77 | recall.append(metrics.recall_at_k(r, K, len(user_pos_test)))
78 | ndcg.append(metrics.ndcg_at_k(r, K))
79 | hit_ratio.append(metrics.hit_at_k(r, K))
80 |
81 | return {'recall': np.array(recall), 'precision': np.array(precision),
82 | 'ndcg': np.array(ndcg), 'hit_ratio': np.array(hit_ratio), 'auc': auc}
83 |
84 |
85 | def test_one_user(x):
86 | # ratings of user u over all items
87 | is_val = x[-1]
88 | rating = x[0]
89 | #uid
90 | u = x[1]
91 | #user u's items in the training set
92 | try:
93 | training_items = data_generator.train_items[u]
94 | except Exception:
95 | training_items = []
96 | #user u's items in the test set
97 | if is_val:
98 | user_pos_test = data_generator.val_set[u]
99 | else:
100 | user_pos_test = data_generator.test_set[u]
101 |
102 | all_items = set(range(ITEM_NUM))
103 |
104 | test_items = list(all_items - set(training_items))
105 |
106 | if args.test_flag == 'part':
107 | r, auc = ranklist_by_heapq(user_pos_test, test_items, rating, eval(args.Ks))
108 | else:
109 | r, auc = ranklist_by_sorted(user_pos_test, test_items, rating, eval(args.Ks))
110 |
111 | return get_performance(user_pos_test, r, auc, eval(args.Ks))
112 |
113 |
114 | def test_torch(ua_embeddings, ia_embeddings, users_to_test, is_val, drop_flag=False, batch_test_flag=False):
115 | result = {'precision': np.zeros(len(eval(args.Ks))), 'recall': np.zeros(len(eval(args.Ks))), 'ndcg': np.zeros(len(eval(args.Ks))),
116 | 'hit_ratio': np.zeros(len(eval(args.Ks))), 'auc': 0.}
117 | pool = multiprocessing.Pool(cores)
118 |
119 | u_batch_size = BATCH_SIZE * 2
120 | i_batch_size = BATCH_SIZE
121 |
122 | test_users = users_to_test
123 | n_test_users = len(test_users)
124 | n_user_batchs = n_test_users // u_batch_size + 1
125 | count = 0
126 |
127 | for u_batch_id in range(n_user_batchs):
128 | start = u_batch_id * u_batch_size
129 | end = (u_batch_id + 1) * u_batch_size
130 | user_batch = test_users[start: end]
131 | if batch_test_flag:
132 | n_item_batchs = ITEM_NUM // i_batch_size + 1
133 | rate_batch = np.zeros(shape=(len(user_batch), ITEM_NUM))
134 |
135 | i_count = 0
136 | for i_batch_id in range(n_item_batchs):
137 | i_start = i_batch_id * i_batch_size
138 | i_end = min((i_batch_id + 1) * i_batch_size, ITEM_NUM)
139 |
140 | item_batch = range(i_start, i_end)
141 | u_g_embeddings = ua_embeddings[user_batch]
142 | i_g_embeddings = ia_embeddings[item_batch]
143 | i_rate_batch = torch.matmul(u_g_embeddings, torch.transpose(i_g_embeddings, 0, 1))
144 |
145 | rate_batch[:, i_start: i_end] = i_rate_batch
146 | i_count += i_rate_batch.shape[1]
147 |
148 | assert i_count == ITEM_NUM
149 |
150 | else:
151 | item_batch = range(ITEM_NUM)
152 | u_g_embeddings = ua_embeddings[user_batch]
153 | i_g_embeddings = ia_embeddings[item_batch]
154 | rate_batch = torch.matmul(u_g_embeddings, torch.transpose(i_g_embeddings, 0, 1))
155 |
156 | rate_batch = rate_batch.detach().cpu().numpy()
157 | user_batch_rating_uid = zip(rate_batch, user_batch, [is_val] * len(user_batch))
158 |
159 | batch_result = pool.map(test_one_user, user_batch_rating_uid)
160 | count += len(batch_result)
161 |
162 | for re in batch_result:
163 | result['precision'] += re['precision'] / n_test_users
164 | result['recall'] += re['recall'] / n_test_users
165 | result['ndcg'] += re['ndcg'] / n_test_users
166 | result['hit_ratio'] += re['hit_ratio'] / n_test_users
167 | result['auc'] += re['auc'] / n_test_users
168 |
169 | assert count == n_test_users
170 | pool.close()
171 | return result
172 |
--------------------------------------------------------------------------------
/codes/utility/load_data.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import random as rd
4 | import scipy.sparse as sp
5 | from time import time
6 | import json
7 | # from utility.parser import parse_args
8 | # args = parse_args()
9 | from utility.parser import args
10 |
11 |
12 | class Data(object):
13 | def __init__(self, path, batch_size):
14 | self.path = path #+ '/%d-core' % args.core
15 | self.batch_size = batch_size
16 |
17 | train_file = path + '/train.json'#+ '/%d-core/train.json' % (args.core)
18 | val_file = path + '/val.json' #+ '/%d-core/val.json' % (args.core)
19 | test_file = path + '/test.json' #+ '/%d-core/test.json' % (args.core)
20 |
21 | #get number of users and items
22 | self.n_users, self.n_items = 0, 0
23 | self.n_train, self.n_test, self.n_val = 0, 0, 0  # n_val must be initialized; it is incremented below
24 | self.neg_pools = {}
25 |
26 | self.exist_users = []
27 |
28 | train = json.load(open(train_file))
29 | test = json.load(open(test_file))
30 | val = json.load(open(val_file))
31 | for uid, items in train.items():
32 | if len(items) == 0:
33 | continue
34 | uid = int(uid)
35 | self.exist_users.append(uid)
36 | self.n_items = max(self.n_items, max(items))
37 | self.n_users = max(self.n_users, uid)
38 | self.n_train += len(items)
39 |
40 | for uid, items in test.items():
41 | uid = int(uid)
42 | try:
43 | self.n_items = max(self.n_items, max(items))
44 | self.n_test += len(items)
45 | except Exception:
46 | continue
47 |
48 | for uid, items in val.items():
49 | uid = int(uid)
50 | try:
51 | self.n_items = max(self.n_items, max(items))
52 | self.n_val += len(items)
53 | except Exception:
54 | continue
55 |
56 | self.n_items += 1
57 | self.n_users += 1
58 |
59 | self.print_statistics()
60 |
61 | self.R = sp.dok_matrix((self.n_users, self.n_items), dtype=np.float32)
62 | self.R_Item_Interacts = sp.dok_matrix((self.n_items, self.n_items), dtype=np.float32)
63 |
64 | self.train_items, self.test_set, self.val_set = {}, {}, {}
65 | for uid, train_items in train.items():
66 | if len(train_items) == 0:
67 | continue
68 | uid = int(uid)
69 | for idx, i in enumerate(train_items):
70 | self.R[uid, i] = 1.
71 |
72 | self.train_items[uid] = train_items
73 |
74 | for uid, test_items in test.items():
75 | uid = int(uid)
76 | if len(test_items) == 0:
77 | continue
78 | try:
79 | self.test_set[uid] = test_items
80 | except Exception:
81 | continue
82 |
83 | for uid, val_items in val.items():
84 | uid = int(uid)
85 | if len(val_items) == 0:
86 | continue
87 | try:
88 | self.val_set[uid] = val_items
89 | except Exception:
90 | continue
91 |
92 | def get_adj_mat(self):
93 | try:
94 | t1 = time()
95 | adj_mat = sp.load_npz(self.path + '/s_adj_mat.npz')
96 | norm_adj_mat = sp.load_npz(self.path + '/s_norm_adj_mat.npz')
97 | mean_adj_mat = sp.load_npz(self.path + '/s_mean_adj_mat.npz')
98 | print('loaded adjacency matrix', adj_mat.shape, time() - t1)
99 |
100 | except Exception:
101 | adj_mat, norm_adj_mat, mean_adj_mat = self.create_adj_mat()
102 | sp.save_npz(self.path + '/s_adj_mat.npz', adj_mat)
103 | sp.save_npz(self.path + '/s_norm_adj_mat.npz', norm_adj_mat)
104 | sp.save_npz(self.path + '/s_mean_adj_mat.npz', mean_adj_mat)
105 | return adj_mat, norm_adj_mat, mean_adj_mat
106 |
107 | def create_adj_mat(self):
108 | t1 = time()
109 | adj_mat = sp.dok_matrix((self.n_users + self.n_items, self.n_users + self.n_items), dtype=np.float32)
110 | adj_mat = adj_mat.tolil()
111 | R = self.R.tolil()
112 |
113 | adj_mat[:self.n_users, self.n_users:] = R
114 | adj_mat[self.n_users:, :self.n_users] = R.T
115 | adj_mat = adj_mat.todok()
116 |         print('created adjacency matrix', adj_mat.shape, time() - t1)
117 |
118 | t2 = time()
119 |
120 | def normalized_adj_single(adj):
121 | rowsum = np.array(adj.sum(1))
122 |
123 | d_inv = np.power(rowsum, -1).flatten()
124 | d_inv[np.isinf(d_inv)] = 0.
125 | d_mat_inv = sp.diags(d_inv)
126 |
127 | norm_adj = d_mat_inv.dot(adj)
128 | # norm_adj = adj.dot(d_mat_inv)
129 |             print('generated single-normalized adjacency matrix.')
130 | return norm_adj.tocoo()
131 |
132 | def get_D_inv(adj):
133 | rowsum = np.array(adj.sum(1))
134 |
135 | d_inv = np.power(rowsum, -1).flatten()
136 | d_inv[np.isinf(d_inv)] = 0.
137 | d_mat_inv = sp.diags(d_inv)
138 | return d_mat_inv
139 |
140 | def check_adj_if_equal(adj):
141 | dense_A = np.array(adj.todense())
142 | degree = np.sum(dense_A, axis=1, keepdims=False)
143 |
144 | temp = np.dot(np.diag(np.power(degree, -1)), dense_A)
145 |             print('checking whether the normalized adjacency matrix equals this Laplacian matrix.')
146 | return temp
147 |
148 | norm_adj_mat = normalized_adj_single(adj_mat + sp.eye(adj_mat.shape[0]))
149 | mean_adj_mat = normalized_adj_single(adj_mat)
150 |
151 |         print('normalized adjacency matrix', time() - t2)
152 | return adj_mat.tocsr(), norm_adj_mat.tocsr(), mean_adj_mat.tocsr()
153 |
154 |
155 | def sample(self):
156 | if self.batch_size <= self.n_users:
157 | users = rd.sample(self.exist_users, self.batch_size)
158 | else:
159 | users = [rd.choice(self.exist_users) for _ in range(self.batch_size)]
160 | # users = self.exist_users[:]
161 |
162 | def sample_pos_items_for_u(u, num):
163 | pos_items = self.train_items[u]
164 | n_pos_items = len(pos_items)
165 | pos_batch = []
166 | while True:
167 | if len(pos_batch) == num: break
168 | pos_id = np.random.randint(low=0, high=n_pos_items, size=1)[0]
169 | pos_i_id = pos_items[pos_id]
170 |
171 | if pos_i_id not in pos_batch:
172 | pos_batch.append(pos_i_id)
173 | return pos_batch
174 |
175 | def sample_neg_items_for_u(u, num):
176 | neg_items = []
177 | while True:
178 | if len(neg_items) == num: break
179 | neg_id = np.random.randint(low=0, high=self.n_items, size=1)[0]
180 | if neg_id not in self.train_items[u] and neg_id not in neg_items:
181 | neg_items.append(neg_id)
182 | return neg_items
183 |
184 | def sample_neg_items_for_u_from_pools(u, num):
185 | neg_items = list(set(self.neg_pools[u]) - set(self.train_items[u]))
186 | return rd.sample(neg_items, num)
187 |
188 | pos_items, neg_items = [], []
189 | for u in users:
190 | pos_items += sample_pos_items_for_u(u, 1)
191 | neg_items += sample_neg_items_for_u(u, 1)
192 | # neg_items += sample_neg_items_for_u(u, 3)
193 | return users, pos_items, neg_items
194 |
195 |
196 | def print_statistics(self):
197 | print('n_users=%d, n_items=%d' % (self.n_users, self.n_items))
198 | print('n_interactions=%d' % (self.n_train + self.n_test))
199 | print('n_train=%d, n_test=%d, sparsity=%.5f' % (self.n_train, self.n_test, (self.n_train + self.n_test)/(self.n_users * self.n_items)))
200 |
201 |
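A minimal usage sketch (directory name and sizes are hypothetical). Each split file is a JSON dict mapping user-id strings to item-id lists, which is the shape the constructor iterates over; get_adj_mat caches the raw bipartite adjacency together with its random-walk normalisations D^{-1}(A+I) and D^{-1}A on first call.

    # ./data/toy/train.json (and val.json / test.json) look like: {"0": [3, 7, 12], "1": [5], ...}
    from utility.load_data import Data

    data = Data(path='./data/toy', batch_size=1024)
    adj, norm_adj, mean_adj = data.get_adj_mat()  # cached to s_*.npz inside the dataset dir
    users, pos, neg = data.sample()               # one BPR (user, pos item, neg item) triple per sampled user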
--------------------------------------------------------------------------------
/codes/utility/logging.py:
--------------------------------------------------------------------------------
1 | import os
2 | from datetime import datetime
3 |
4 | class Logger:
5 | def __init__(self, filename, is_debug, path='/home/weiw/Code/MM/KDMM/logs/'):
6 | self.filename = filename
7 | self.path = path
8 | self.log_ = not is_debug
9 | def logging(self, s):
10 | s = str(s)
11 | print(datetime.now().strftime('%Y-%m-%d %H:%M: '), s)
12 | if self.log_:
13 |             with open(os.path.join(self.path, self.filename), 'a+') as f_log:
14 | f_log.write(str(datetime.now().strftime('%Y-%m-%d %H:%M: ')) + s + '\n')
15 |
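A usage sketch; the default log directory above is machine-specific, so pass your own path. With is_debug=True the logger only prints, otherwise each line is also appended to <path>/<filename>:

    log = Logger('train.log', is_debug=True, path='./logs/')  # print-only while debugging
    log.logging('epoch 1 | recall@20=0.0931')                 # timestamped on stdout (and in the file when enabled)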
--------------------------------------------------------------------------------
/codes/utility/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics import roc_auc_score
3 |
4 | def recall(rank, ground_truth, N):
5 | return len(set(rank[:N]) & set(ground_truth)) / float(len(set(ground_truth)))
6 |
7 |
8 | def precision_at_k(r, k):
9 | """Score is precision @ k
10 | Relevance is binary (nonzero is relevant).
11 | Returns:
12 | Precision @ k
13 | Raises:
14 |         AssertionError: if k < 1
15 | """
16 | assert k >= 1
17 | r = np.asarray(r)[:k]
18 | return np.mean(r)
19 |
20 |
21 | def average_precision(r, cut):
22 | """Score is average precision (area under PR curve)
23 |     Relevance is binary (nonzero is relevant); cut is the rank cutoff.
24 | Returns:
25 | Average precision
26 | """
27 | r = np.asarray(r)
28 | out = [precision_at_k(r, k + 1) for k in range(cut) if r[k]]
29 | if not out:
30 | return 0.
31 | return np.sum(out)/float(min(cut, np.sum(r)))
32 |
33 |
34 | def mean_average_precision(rs):
35 | """Score is mean average precision
36 | Relevance is binary (nonzero is relevant).
37 | Returns:
38 | Mean average precision
39 | """
40 |     return np.mean([average_precision(r, len(r)) for r in rs])  # average_precision takes a cutoff; evaluate the full list
41 |
42 |
43 | def dcg_at_k(r, k, method=1):
44 | """Score is discounted cumulative gain (dcg)
45 |     Relevance is positive real values. Binary relevance,
46 |     as in the previous methods, also works.
47 | Returns:
48 | Discounted cumulative gain
49 | """
50 | r = np.asfarray(r)[:k]
51 | if r.size:
52 | if method == 0:
53 | return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))
54 | elif method == 1:
55 | return np.sum(r / np.log2(np.arange(2, r.size + 2)))
56 | else:
57 | raise ValueError('method must be 0 or 1.')
58 | return 0.
59 |
60 |
61 | def ndcg_at_k(r, k, method=1):
62 | """Score is normalized discounted cumulative gain (ndcg)
63 |     Relevance is positive real values. Binary relevance,
64 |     as in the previous methods, also works.
65 | Returns:
66 | Normalized discounted cumulative gain
67 | """
68 | dcg_max = dcg_at_k(sorted(r, reverse=True), k, method)
69 | if not dcg_max:
70 | return 0.
71 | return dcg_at_k(r, k, method) / dcg_max
72 |
73 |
74 | def recall_at_k(r, k, all_pos_num):
75 | r = np.asfarray(r)[:k]
76 | if all_pos_num == 0:
77 | return 0
78 | else:
79 | return np.sum(r) / all_pos_num
80 |
81 |
82 | def hit_at_k(r, k):
83 | r = np.array(r)[:k]
84 | if np.sum(r) > 0:
85 | return 1.
86 | else:
87 | return 0.
88 |
89 | def F1(pre, rec):
90 | if pre + rec > 0:
91 | return (2.0 * pre * rec) / (pre + rec)
92 | else:
93 | return 0.
94 |
95 | def auc(ground_truth, prediction):
96 | try:
97 | res = roc_auc_score(y_true=ground_truth, y_score=prediction)
98 | except Exception:
99 | res = 0.
100 | return res
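A worked example for ndcg_at_k with the default method=1, to make the log-discounting concrete:

    from utility.metrics import ndcg_at_k

    # Binary relevance [1, 0, 1] at k=3:
    #   DCG  = 1/log2(2) + 0/log2(3) + 1/log2(4) = 1.0 + 0 + 0.5 = 1.5
    #   IDCG = DCG of the ideal order [1, 1, 0] = 1.0 + 1/log2(3) ~ 1.6309
    #   NDCG = 1.5 / 1.6309 ~ 0.9197
    print(ndcg_at_k([1, 0, 1], 3))  # ~0.9197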
--------------------------------------------------------------------------------
/codes/utility/norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.sparse import csr_matrix
4 |
5 | def build_sim(context):
6 | context_norm = context.div(torch.norm(context, p=2, dim=-1, keepdim=True))
7 |     sim = torch.sparse.mm(context_norm, context_norm.transpose(1, 0))  # pairwise cosine similarities (inputs here are dense)
8 | # a, b = context_norm.shape
9 | # b, c = context_norm.transpose(1, 0).shape
10 | # ab = context_norm.unsqueeze(-1) #.repeat(1,1,c)
11 | # bc = context_norm.transpose(1, 0).unsqueeze(0) #.repeat(a, 1,1)
12 | # sim = torch.mul(ab, bc).sum(dim=1, keepdim=False)
13 |
14 | return sim
15 |
16 | # def build_knn_normalized_graph(adj, topk, is_sparse, norm_type):
17 | # device = adj.device
18 | # knn_val, knn_ind = torch.topk(adj, topk, dim=-1)
19 | # if is_sparse:
20 | # tuple_list = [[row, int(col)] for row in range(len(knn_ind)) for col in knn_ind[row]]
21 | # row = [i[0] for i in tuple_list]
22 | # col = [i[1] for i in tuple_list]
23 | # i = torch.LongTensor([row, col]).to(device)
24 | # v = knn_val.flatten()
25 | # edge_index, edge_weight = get_sparse_laplacian(i, v, normalization=norm_type, num_nodes=adj.shape[0])
26 | # return torch.sparse_coo_tensor(edge_index, edge_weight, adj.shape)
27 | # else:
28 | # weighted_adjacency_matrix = (torch.zeros_like(adj)).scatter_(-1, knn_ind, knn_val)
29 | # return get_dense_laplacian(weighted_adjacency_matrix, normalization=norm_type)
30 |
31 | def build_knn_normalized_graph(adj, topk, is_sparse, norm_type):
32 | device = adj.device
33 | knn_val, knn_ind = torch.topk(adj, topk, dim=-1) #[7050, 10] [7050, 10]
34 | n_item = knn_val.shape[0]
35 | n_data = knn_val.shape[0]*knn_val.shape[1]
36 |     data = np.ones(n_data)  # binary edge weights; knn_val and norm_type go unused on the sparse path
37 | if is_sparse:
38 | tuple_list = [[row, int(col)] for row in range(len(knn_ind)) for col in knn_ind[row]] #[70500]
39 | # data = np.array(knn_val.flatten().cpu()) #args.topk_rate*
40 | row = [i[0] for i in tuple_list] #[70500]
41 | col = [i[1] for i in tuple_list] #[70500]
42 | # #-----------------------------------------------------------------------------------------------------
43 | # i = torch.LongTensor([row, col]).to(device)
44 | # v = knn_val.flatten()
45 | # edge_index, edge_weight = get_sparse_laplacian(i, v, normalization=norm_type, num_nodes=adj.shape[0])
46 | # #-----------------------------------------------------------------------------------------------------
47 | ii_graph = csr_matrix((data, (row, col)) ,shape=(n_item, n_item))
48 | # return torch.sparse_coo_tensor(edge_index, edge_weight, adj.shape)
49 | return ii_graph
50 | else:
51 | weighted_adjacency_matrix = (torch.zeros_like(adj)).scatter_(-1, knn_ind, knn_val)
52 | return get_dense_laplacian(weighted_adjacency_matrix, normalization=norm_type)
53 |
54 |
55 | def get_sparse_laplacian(edge_index, edge_weight, num_nodes, normalization='none'): #[2, 70500], [70500]
56 | from torch_scatter import scatter_add
57 | row, col = edge_index[0], edge_index[1] #[70500] [70500]
58 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) #[7050]
59 |
60 | if normalization == 'sym':
61 | deg_inv_sqrt = deg.pow_(-0.5)
62 | deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
63 | edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
64 | elif normalization == 'rw':
65 | deg_inv = 1.0 / deg
66 | deg_inv.masked_fill_(deg_inv == float('inf'), 0)
67 | edge_weight = deg_inv[row] * edge_weight
68 | return edge_index, edge_weight
69 |
70 |
71 | def get_dense_laplacian(adj, normalization='none'):
72 | if normalization == 'sym':
73 | rowsum = torch.sum(adj, -1)
74 | d_inv_sqrt = torch.pow(rowsum, -0.5)
75 | d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
76 | d_mat_inv_sqrt = torch.diagflat(d_inv_sqrt)
77 | L_norm = torch.mm(torch.mm(d_mat_inv_sqrt, adj), d_mat_inv_sqrt)
78 | elif normalization == 'rw':
79 | rowsum = torch.sum(adj, -1)
80 | d_inv = torch.pow(rowsum, -1)
81 | d_inv[torch.isinf(d_inv)] = 0.
82 | d_mat_inv = torch.diagflat(d_inv)
83 | L_norm = torch.mm(d_mat_inv, adj)
84 |     else:  # 'none' or any unrecognised value leaves the adjacency unnormalised (avoids an unbound L_norm)
85 |         L_norm = adj
86 | return L_norm
87 |
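A minimal sketch tying these pieces together (sizes are illustrative). Cosine similarity is computed with torch.mm, which matches what build_sim does for dense inputs; note that the sparse branch of build_knn_normalized_graph returns a binary scipy csr_matrix and ignores the similarity values and norm_type:

    import torch
    import torch.nn.functional as F
    from utility.norm import build_knn_normalized_graph

    feats = torch.randn(100, 64)             # 100 items with 64-d modality features
    feats = F.normalize(feats, p=2, dim=-1)  # same row-wise L2 normalisation as build_sim
    sim = torch.mm(feats, feats.t())         # dense cosine-similarity matrix
    ii_graph = build_knn_normalized_graph(sim, topk=10, is_sparse=True, norm_type='sym')
    print(ii_graph.shape, ii_graph.nnz)      # (100, 100) with 10 ones per row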
--------------------------------------------------------------------------------
/decouple.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/PromptMM/70da1002a35d6f2c7712c16cd0b2ca24c8813008/decouple.png
--------------------------------------------------------------------------------