├── codes
│   ├── utility
│   │   ├── __pycache__
│   │   │   ├── parser.cpython-36.pyc
│   │   │   ├── parser.cpython-38.pyc
│   │   │   ├── load_data.cpython-36.pyc
│   │   │   ├── load_data.cpython-38.pyc
│   │   │   ├── metrics.cpython-36.pyc
│   │   │   ├── metrics.cpython-38.pyc
│   │   │   ├── batch_test.cpython-36.pyc
│   │   │   ├── batch_test.cpython-38.pyc
│   │   │   ├── batch_test.cpython-310-pytest-7.4.3.pyc
│   │   │   ├── batch_test.cpython-36-pytest-6.2.5.pyc
│   │   │   ├── batch_test.cpython-36-pytest-7.0.1.pyc
│   │   │   └── index.html.tmp
│   │   ├── metrics.py
│   │   ├── parser.py
│   │   ├── batch_test.py
│   │   └── load_data.py
│   ├── main.py
│   ├── Preliminaries.ipynb
│   ├── data
│   │   └── build_data.py
│   └── Models.py
├── requirements.txt
├── LICENSE
└── README.md
/codes/utility/__pycache__/parser.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eddie-liner/MONET/HEAD/codes/utility/__pycache__/parser.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/utility/__pycache__/parser.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eddie-liner/MONET/HEAD/codes/utility/__pycache__/parser.cpython-38.pyc
--------------------------------------------------------------------------------
/codes/utility/__pycache__/load_data.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eddie-liner/MONET/HEAD/codes/utility/__pycache__/load_data.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/utility/__pycache__/load_data.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eddie-liner/MONET/HEAD/codes/utility/__pycache__/load_data.cpython-38.pyc
--------------------------------------------------------------------------------
/codes/utility/__pycache__/metrics.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eddie-liner/MONET/HEAD/codes/utility/__pycache__/metrics.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/utility/__pycache__/metrics.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eddie-liner/MONET/HEAD/codes/utility/__pycache__/metrics.cpython-38.pyc
--------------------------------------------------------------------------------
/codes/utility/__pycache__/batch_test.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eddie-liner/MONET/HEAD/codes/utility/__pycache__/batch_test.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/utility/__pycache__/batch_test.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eddie-liner/MONET/HEAD/codes/utility/__pycache__/batch_test.cpython-38.pyc
--------------------------------------------------------------------------------
/codes/utility/__pycache__/batch_test.cpython-310-pytest-7.4.3.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eddie-liner/MONET/HEAD/codes/utility/__pycache__/batch_test.cpython-310-pytest-7.4.3.pyc
--------------------------------------------------------------------------------
/codes/utility/__pycache__/batch_test.cpython-36-pytest-6.2.5.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eddie-liner/MONET/HEAD/codes/utility/__pycache__/batch_test.cpython-36-pytest-6.2.5.pyc
--------------------------------------------------------------------------------
/codes/utility/__pycache__/batch_test.cpython-36-pytest-7.0.1.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eddie-liner/MONET/HEAD/codes/utility/__pycache__/batch_test.cpython-36-pytest-7.0.1.pyc
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gensim==3.8.3
2 | torch==1.10.2+cu113
3 | sentence_transformers==2.2.0
4 | pandas
5 | numpy
6 | tqdm
7 | torch-scatter
8 | torch-sparse
9 | torch-cluster
10 | torch-spline-conv
11 | torch-geometric
12 |
--------------------------------------------------------------------------------
/codes/utility/__pycache__/index.html.tmp:
--------------------------------------------------------------------------------
1 | Directory listing for /MONET/codes/utility/__pycache__/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Big Data and Multi-modal Computing Group, CRIPAC
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/codes/utility/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics import roc_auc_score
3 |
4 |
5 | def recall(rank, ground_truth, N):
6 | return len(set(rank[:N]) & set(ground_truth)) / float(len(set(ground_truth)))
7 |
8 |
9 | def precision_at_k(r, k):
10 | """Score is precision @ k.
11 |
12 | Relevance is binary (nonzero is relevant).
13 | Returns:
14 | Precision @ k
15 | Raises:
16 |         AssertionError: k must be >= 1
17 | """
18 | assert k >= 1
19 | r = np.asarray(r)[:k]
20 | return np.mean(r)
21 |
22 |
23 | def average_precision(r, cut):
24 | """Score is average precision (area under PR curve).
25 |
26 | Relevance is binary (nonzero is relevant).
27 | Returns:
28 | Average precision
29 | """
30 | r = np.asarray(r)
31 | out = [precision_at_k(r, k + 1) for k in range(cut) if r[k]]
32 | if not out:
33 | return 0.0
34 | return np.sum(out) / float(min(cut, np.sum(r)))
35 |
36 |
37 | def mean_average_precision(rs):
38 | """Score is mean average precision.
39 |
40 | Relevance is binary (nonzero is relevant).
41 | Returns:
42 | Mean average precision
43 | """
44 |     return np.mean([average_precision(r, len(r)) for r in rs])
45 |
46 |
47 | def dcg_at_k(r, k, method=1):
48 | """Score is discounted cumulative gain (dcg).
49 |
50 | Relevance is positive real values. Can use binary
51 | as the previous methods.
52 | Returns:
53 | Discounted cumulative gain
54 | """
55 | r = np.asfarray(r)[:k]
56 | if r.size:
57 | if method == 0:
58 | return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))
59 | elif method == 1:
60 | return np.sum(r / np.log2(np.arange(2, r.size + 2)))
61 | else:
62 | raise ValueError("method must be 0 or 1.")
63 | return 0.0
64 |
65 |
66 | def ndcg_at_k(r, k, method=1):
67 | """Score is normalized discounted cumulative gain (ndcg).
68 |
69 | Relevance is positive real values. Can use binary
70 | as the previous methods.
71 | Returns:
72 | Normalized discounted cumulative gain
73 | """
74 | dcg_max = dcg_at_k(sorted(r, reverse=True), k, method)
75 | if not dcg_max:
76 | return 0.0
77 | return dcg_at_k(r, k, method) / dcg_max
78 |
79 |
80 | def recall_at_k(r, k, all_pos_num):
81 | r = np.asfarray(r)[:k]
82 | if all_pos_num == 0:
83 | return 0
84 | else:
85 | return np.sum(r) / all_pos_num
86 |
87 |
88 | def hit_at_k(r, k):
89 | r = np.array(r)[:k]
90 | if np.sum(r) > 0:
91 | return 1.0
92 | else:
93 | return 0.0
94 |
95 |
96 | def F1(pre, rec):
97 | if pre + rec > 0:
98 | return (2.0 * pre * rec) / (pre + rec)
99 | else:
100 | return 0.0
101 |
102 |
103 | def auc(ground_truth, prediction):
104 | try:
105 | res = roc_auc_score(y_true=ground_truth, y_score=prediction)
106 | except Exception:
107 | res = 0.0
108 | return res
109 |
--------------------------------------------------------------------------------
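
A minimal sketch of how the metrics above compose on a toy binary relevance list (run from codes/ so the module resolves as utility.metrics; the numbers are illustrative only):

```python
# Toy check of the ranking metrics defined above.
import utility.metrics as metrics

r = [1, 0, 1, 0, 0]  # binary relevance of a ranked list: 1 = hit, 0 = miss

print(metrics.precision_at_k(r, 3))              # 0.667: 2 of the top-3 are hits
print(metrics.recall_at_k(r, 5, all_pos_num=4))  # 0.5: 2 of 4 positives retrieved
print(metrics.ndcg_at_k(r, 5))                   # ~0.92: DCG / ideal DCG
print(metrics.hit_at_k(r, 1))                    # 1.0: the top-ranked item is a hit
```
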
/codes/utility/parser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def parse_args(flags=False):
5 | parser = argparse.ArgumentParser(description="")
6 |
7 | parser.add_argument(
8 | "--data_path", nargs="?", default="data/", help="Input data path."
9 | )
10 | parser.add_argument("--seed", type=int, default=123, help="Random seed")
11 | parser.add_argument(
12 | "--dataset",
13 | nargs="?",
14 | default="MenClothing",
15 | help="Choose a dataset from {Toys_and_Games, Beauty, MenClothing, WomenClothing}",
16 | )
17 | parser.add_argument(
18 | "--verbose", type=int, default=5, help="Interval of evaluation."
19 | )
20 | parser.add_argument("--epoch", type=int, default=1000, help="Number of epoch.")
21 | parser.add_argument("--batch_size", type=int, default=1024, help="Batch size.")
22 | parser.add_argument(
23 | "--regs", nargs="?", default="[1e-5,1e-5]", help="Regularizations."
24 | )
25 | parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate.")
26 | parser.add_argument("--embed_size", type=int, default=64, help="Embedding size.")
27 | parser.add_argument(
28 | "--feat_embed_dim", type=int, default=64, help="Feature embedding size."
29 | )
30 | parser.add_argument(
31 | "--alpha", type=float, default=1.0, help="Coefficient of self node features."
32 | )
33 | parser.add_argument(
34 | "--beta",
35 | type=float,
36 | default=0.3,
37 | help="Coefficient of fine-grained interest matching.",
38 | )
39 | parser.add_argument(
40 | "--core",
41 | type=int,
42 | default=5,
43 | help="5-core for warm-start; 0-core for cold start.",
44 | )
45 | parser.add_argument(
46 | "--n_layers", type=int, default=2, help="Number of graph conv layers."
47 | )
48 | parser.add_argument("--has_norm", default=True, action="store_false")
49 | parser.add_argument("--target_aware", default=True, action="store_false")
50 | parser.add_argument(
51 | "--agg",
52 | type=str,
53 | default="concat",
54 |         help="Choose an aggregation function from {sum, weighted_sum, concat, fc}",
55 | )
56 | parser.add_argument("--cf", default=False, action="store_true")
57 | parser.add_argument(
58 | "--cf_gcn",
59 | type=str,
60 | default="LightGCN",
61 |         help="Choose a CF GCN backbone from {MeGCN, LightGCN}",
62 | )
63 | parser.add_argument("--lightgcn", default=False, action="store_true")
64 | parser.add_argument("--model_name", type=str)
65 | parser.add_argument("--early_stopping_patience", type=int, default=10, help="")
66 | parser.add_argument("--gpu_id", type=int, default=0, help="GPU id")
67 | parser.add_argument(
68 | "--Ks", nargs="?", default="[10, 20]", help="K value of ndcg/recall @ k"
69 | )
70 | parser.add_argument(
71 | "--test_flag",
72 | nargs="?",
73 | default="part",
74 |         help="Specify the test type from {part, full}, indicating whether inference is done in mini-batches",
75 | )
76 |
77 |     args = parser.parse_args()
78 |     if flags:
79 |         print("*" * 32 + " Experiment setting " + "*" * 32)
80 |         for k, v in vars(args).items():
81 |             print(k + " : " + str(v))
82 |         print("*" * 32 + " Experiment setting " + "*" * 32)
83 |     return args
84 |
--------------------------------------------------------------------------------
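
Note the flag pattern used for --has_norm and --target_aware above: with default=True and action="store_false", *passing* the flag turns the feature off. A standalone sketch of that behavior:

```python
# Demonstrates the default=True / action="store_false" pattern above:
# omitting the flag leaves the feature on; passing it switches it off.
import argparse

p = argparse.ArgumentParser()
p.add_argument("--target_aware", default=True, action="store_false")

print(p.parse_args([]).target_aware)                  # True
print(p.parse_args(["--target_aware"]).target_aware)  # False
```

This is why the README's MONET_wo_TA command passes --target_aware to disable target-aware attention.
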
/README.md:
--------------------------------------------------------------------------------
1 | # MONET: Modality-Embracing Graph Convolutional Network and Target-Aware Attention for Multimedia Recommendation
2 | This repository provides a reference implementation of *MONET* as described in the following paper:
3 | > MONET: Modality-Embracing Graph Convolutional Network and Target-Aware Attention for Multimedia Recommendation
4 | > Yungi Kim, Taeri Kim, Won-Yong Shin and Sang-Wook Kim
5 | > 17th ACM Int'l Conf. on Web Search and Data Mining (ACM WSDM 2024)
6 |
7 | ### Overview of MONET
8 | 
9 |
10 |
11 | ### Authors
12 | - Yungi Kim (gozj3319@hanyang.ac.kr)
13 | - Taeri Kim (taerik@hanyang.ac.kr)
14 | - Won-Yong Shin (wy.shin@yonsei.ac.kr)
15 | - Sang-Wook Kim (wook@hanyang.ac.kr)
16 |
17 | ### Requirements
18 | The code has been tested under Python 3.6.13. The required packages are as follows:
19 | - ```gensim==3.8.3```
20 | - ```torch==1.10.2+cu113```
21 | - ```torch_geometric==2.0.3```
22 | - ```sentence_transformers==2.2.0```
23 | - ```pandas```
24 | - ```numpy```
25 | - ```tqdm```
26 | - ```torch-scatter```
27 | - ```torch-sparse```
28 | - ```torch-cluster```
29 | - ```torch-spline-conv```
31 |
32 | ### Dataset Preparation
33 | #### Dataset Download
34 | *Men Clothing and Women Clothing*: Download the Amazon product dataset provided by [MAML](https://github.com/liufancs/MAML). Put the data folder into the directory data/.
35 |
36 | *Beauty and Toys & Games*: Download the 5-core review data, metadata, and image features from the [Amazon product dataset](http://jmcauley.ucsd.edu/data/amazon/links.html). Put them into the directory data/{folder}/meta-data/.
37 |
38 | #### Dataset Preprocessing
39 | Run ```python build_data.py --name={Dataset}```
40 |
41 | ### Usage
42 | #### For simplicity, we provide usage for the Women Clothing dataset.
43 | ------------------------------------
44 | - For MONET in RQ1,
45 | ```
46 | python main.py --agg=concat --n_layers=2 --alpha=1.0 --beta=0.3 --dataset=WomenClothing --model_name=MONET_2_10_3
47 | ```
48 | ------------------------------------
49 | - For RQ2, refer to the second cell in "Preliminaries.ipynb".
50 | ------------------------------------
51 | - For MONET_w/o_MeGCN and MONET_w/o_TA in RQ3 (note that passing ```--target_aware``` turns target-aware attention *off*, since the flag is declared with ```action="store_false"```),
52 | ```
53 | python main.py --agg=concat --n_layers=0 --alpha=1.0 --beta=0.3 --dataset=WomenClothing --model_name=MONET_wo_MeGCN
54 | python main.py --target_aware --agg=concat --n_layers=2 --alpha=1.0 --beta=0.3 --dataset=WomenClothing --model_name=MONET_wo_TA
55 | ```
56 | ------------------------------------
57 | - For RQ4 (sensitivity to the hyperparameters $\alpha$ and $\beta$),
58 | ```
59 | python main.py --agg=concat --n_layers=2 --alpha={value} --beta=0.3 --dataset=WomenClothing --model_name=MONET_2_{alpha}_3
60 | python main.py --agg=concat --n_layers=2 --alpha=1.0 --beta={value} --dataset=WomenClothing --model_name=MONET_2_10_{beta}
61 | ```
62 |
63 | ### Cite
64 | We encourage you to cite our paper if you have used the code in your work. You can use the following BibTeX citation:
65 | ```
66 | @inproceedings{kim24wsdm,
67 | author = {Yungi Kim and Taeri Kim and Won{-}Yong Shin and Sang{-}Wook Kim},
68 | title = {MONET: Modality-Embracing Graph Convolutional Network and Target-Aware Attention for Multimedia Recommendation},
69 | booktitle = {ACM International Conference on Web Search and Data Mining (ACM WSDM 2024)},
70 | year = {2024}
71 | }
72 | ```
73 |
74 | ### Acknowledgement
75 | The structure of this code is largely based on [LATTICE](https://github.com/CRIPAC-DIG/LATTICE). Thanks for their work.
76 |
--------------------------------------------------------------------------------
/codes/utility/batch_test.py:
--------------------------------------------------------------------------------
1 | import heapq
2 | import multiprocessing
5 |
6 | import numpy as np
7 | import torch
8 | import utility.metrics as metrics
9 | from tqdm import tqdm
10 | from utility.load_data import Data
11 | from utility.parser import parse_args
12 |
13 | cores = multiprocessing.cpu_count() // 5
14 |
15 | args = parse_args()
16 | Ks = eval(args.Ks)
17 |
18 | data_generator = Data(path=args.data_path + args.dataset, batch_size=args.batch_size)
19 | USR_NUM, ITEM_NUM = data_generator.n_users, data_generator.n_items
20 | N_TRAIN, N_TEST = data_generator.n_train, data_generator.n_test
21 | if args.target_aware:
22 | BATCH_SIZE = 16
23 | else:
24 | BATCH_SIZE = args.batch_size
25 |
26 |
27 | def ranklist_by_heapq(user_pos_test, test_items, rating, Ks):
28 | item_score = {}
29 | for i in test_items:
30 | item_score[i] = rating[i]
31 |
32 | K_max = max(Ks)
33 | K_max_item_score = heapq.nlargest(K_max, item_score, key=item_score.get)
34 |
35 | r = []
36 | for i in K_max_item_score:
37 | if i in user_pos_test:
38 | r.append(1)
39 | else:
40 | r.append(0)
41 | auc = 0.0
42 | return r, auc
43 |
44 |
45 | def get_auc(item_score, user_pos_test):
46 | item_score = sorted(item_score.items(), key=lambda kv: kv[1])
47 | item_score.reverse()
48 | item_sort = [x[0] for x in item_score]
49 | posterior = [x[1] for x in item_score]
50 |
51 | r = []
52 | for i in item_sort:
53 | if i in user_pos_test:
54 | r.append(1)
55 | else:
56 | r.append(0)
57 | auc = metrics.auc(ground_truth=r, prediction=posterior)
58 | return auc
59 |
60 |
61 | def ranklist_by_sorted(user_pos_test, test_items, rating, Ks):
62 | item_score = {}
63 | for i in test_items:
64 | item_score[i] = rating[i]
65 |
66 | K_max = max(Ks)
67 | K_max_item_score = heapq.nlargest(K_max, item_score, key=item_score.get)
68 |
69 | r = []
70 | for i in K_max_item_score:
71 | if i in user_pos_test:
72 | r.append(1)
73 | else:
74 | r.append(0)
75 | auc = get_auc(item_score, user_pos_test)
76 | return r, auc
77 |
78 |
79 | def get_performance(user_pos_test, r, auc, Ks):
80 | precision, recall, ndcg, hit_ratio = [], [], [], []
81 |
82 | for K in Ks:
83 | precision.append(metrics.precision_at_k(r, K))
84 | recall.append(metrics.recall_at_k(r, K, len(user_pos_test)))
85 | ndcg.append(metrics.ndcg_at_k(r, K))
86 | hit_ratio.append(metrics.hit_at_k(r, K))
87 |
88 | return {
89 | "recall": np.array(recall),
90 | "precision": np.array(precision),
91 | "ndcg": np.array(ndcg),
92 | "hit_ratio": np.array(hit_ratio),
93 | "auc": auc,
94 | }
95 |
96 |
97 | def test_one_user(x):
98 |     # predicted rating scores for user u
99 | is_val = x[-1]
100 | rating = x[0]
101 | # uid
102 | u = x[1]
103 | # user u's items in the training set
104 | try:
105 | training_items = data_generator.train_items[u]
106 | except Exception:
107 | training_items = []
108 | if is_val:
109 | user_pos_test = data_generator.val_set[u]
110 | else:
111 | user_pos_test = data_generator.test_set[u]
112 |
113 | all_items = set(range(ITEM_NUM))
114 |
115 | test_items = list(all_items - set(training_items))
116 |
117 | if args.test_flag == "part":
118 | r, auc = ranklist_by_heapq(user_pos_test, test_items, rating, Ks)
119 | else:
120 | r, auc = ranklist_by_sorted(user_pos_test, test_items, rating, Ks)
121 |
122 | return get_performance(user_pos_test, r, auc, Ks)
123 |
124 |
125 | def test_torch(
126 | ua_embeddings, ia_embeddings, users_to_test, is_val, adj, beta, target_aware
127 | ):
128 | result = {
129 | "precision": np.zeros(len(Ks)),
130 | "recall": np.zeros(len(Ks)),
131 | "ndcg": np.zeros(len(Ks)),
132 | "hit_ratio": np.zeros(len(Ks)),
133 | "auc": 0.0,
134 | }
135 | pool = multiprocessing.Pool(cores)
136 |
137 | u_batch_size = BATCH_SIZE * 2
138 | i_batch_size = BATCH_SIZE
139 |
140 | test_users = users_to_test
141 | n_test_users = len(test_users)
142 | n_user_batchs = n_test_users // u_batch_size + 1
143 | count = 0
144 |
145 | item_item = torch.mm(ia_embeddings, ia_embeddings.T)
146 |
147 | for u_batch_id in tqdm(range(n_user_batchs), position=1, leave=False):
148 | start = u_batch_id * u_batch_size
149 | end = (u_batch_id + 1) * u_batch_size
150 | user_batch = test_users[start:end]
151 | if target_aware:
152 | n_item_batchs = ITEM_NUM // i_batch_size + 1
153 | rate_batch = np.zeros(shape=(len(user_batch), ITEM_NUM))
154 |
155 | i_count = 0
156 | for i_batch_id in range(n_item_batchs):
157 | i_start = i_batch_id * i_batch_size
158 | i_end = min((i_batch_id + 1) * i_batch_size, ITEM_NUM)
159 |
160 | item_batch = range(i_start, i_end)
161 | u_g_embeddings = ua_embeddings[user_batch] # (batch_size, dim)
162 | i_g_embeddings = ia_embeddings[item_batch] # (batch_size, dim)
163 |
164 | # target-aware
165 | item_query = item_item[item_batch, :] # (item_batch_size, n_items)
166 | item_target_user_alpha = torch.softmax(
167 | torch.multiply(
168 | item_query.unsqueeze(1), adj[user_batch, :].unsqueeze(0)
169 | ).masked_fill(
170 | adj[user_batch, :].repeat(len(item_batch), 1, 1) == 0, -1e9
171 | ),
172 | dim=2,
173 | ) # (item_batch_size, user_batch_size, n_items)
174 | item_target_user = torch.matmul(
175 | item_target_user_alpha, ia_embeddings
176 | ) # (item_batch_size, user_batch_size, dim)
177 |
178 | # target-aware
179 | i_rate_batch = (1 - beta) * torch.matmul(
180 | u_g_embeddings, torch.transpose(i_g_embeddings, 0, 1)
181 | ) + beta * torch.sum(
182 | torch.mul(
183 | item_target_user.permute(1, 0, 2).contiguous(), i_g_embeddings
184 | ),
185 | dim=2,
186 | )
187 |
188 | rate_batch[:, i_start:i_end] = i_rate_batch.detach().cpu().numpy()
189 | i_count += i_rate_batch.shape[1]
190 |
191 | del (
192 | item_query,
193 | item_target_user_alpha,
194 | item_target_user,
195 | i_g_embeddings,
196 | u_g_embeddings,
197 | )
198 | torch.cuda.empty_cache()
199 |
200 | assert i_count == ITEM_NUM
201 |
202 | else:
203 | item_batch = range(ITEM_NUM)
204 | u_g_embeddings = ua_embeddings[user_batch]
205 | i_g_embeddings = ia_embeddings[item_batch]
206 |
207 | rate_batch = torch.matmul(
208 | u_g_embeddings, torch.transpose(i_g_embeddings, 0, 1)
209 | )
210 | rate_batch = rate_batch.detach().cpu().numpy()
211 |
212 | user_batch_rating_uid = zip(rate_batch, user_batch, [is_val] * len(user_batch))
213 |
214 | batch_result = pool.map(test_one_user, user_batch_rating_uid)
215 | count += len(batch_result)
216 |
217 | for re in batch_result:
218 | result["precision"] += re["precision"] / n_test_users
219 | result["recall"] += re["recall"] / n_test_users
220 | result["ndcg"] += re["ndcg"] / n_test_users
221 | result["hit_ratio"] += re["hit_ratio"] / n_test_users
222 | result["auc"] += re["auc"] / n_test_users
223 |
224 | assert count == n_test_users
225 | pool.close()
226 | return result
227 |
--------------------------------------------------------------------------------
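
The evaluation path above reduces to: rank the user's candidate items by predicted score, mark top-K hits, then average per-user metrics. A self-contained sketch of that logic on toy data (inlined here because importing utility.batch_test loads the dataset at module import time):

```python
# Toy version of ranklist_by_heapq + recall@K from this file.
import heapq
import numpy as np

rating = {0: 0.9, 1: 0.1, 2: 0.8, 3: 0.4}  # item id -> predicted score
user_pos_test = {0, 3}                     # ground-truth test items
Ks = [2]

top_k = heapq.nlargest(max(Ks), rating, key=rating.get)  # [0, 2]
r = [1 if i in user_pos_test else 0 for i in top_k]      # [1, 0]

recall_at_2 = np.sum(r[: Ks[0]]) / len(user_pos_test)
print(top_k, r, recall_at_2)  # [0, 2] [1, 0] 0.5
```
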
/codes/utility/load_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random as rd
4 | from collections import defaultdict
5 | 
6 | import numpy as np
7 | import pandas as pd
8 | import scipy.sparse as sp
9 | from gensim.models.doc2vec import Doc2Vec
14 |
15 |
16 | class Data(object):
17 | def __init__(self, path, batch_size):
18 | self.path = path + "/5-core"
19 | self.batch_size = batch_size
20 |
21 | train_file = path + "/5-core/train.json"
22 | val_file = path + "/5-core/val.json"
23 | test_file = path + "/5-core/test.json"
24 |
25 | # get number of users and items
26 | self.n_users, self.n_items = 0, 0
27 | self.n_train, self.n_test, self.n_val = 0, 0, 0
28 | self.neg_pools = {}
29 |
30 | self.exist_users = []
31 |
32 | train = json.load(open(train_file))
33 | test = json.load(open(test_file))
34 | val = json.load(open(val_file))
35 | for uid, items in train.items():
36 | if len(items) == 0:
37 | continue
38 | uid = int(uid)
39 | self.exist_users.append(uid)
40 | self.n_items = max(self.n_items, max(items))
41 | self.n_users = max(self.n_users, uid)
42 | self.n_train += len(items)
43 |
44 | for uid, items in test.items():
45 | uid = int(uid)
46 | try:
47 | self.n_items = max(self.n_items, max(items))
48 | self.n_test += len(items)
49 | except Exception:
50 | continue
51 |
52 | for uid, items in val.items():
53 | uid = int(uid)
54 | try:
55 | self.n_items = max(self.n_items, max(items))
56 | self.n_val += len(items)
57 | except Exception:
58 | continue
59 |
60 | self.n_items += 1
61 | self.n_users += 1
62 |
63 | self.print_statistics()
64 |
65 | self.R = sp.dok_matrix((self.n_users, self.n_items), dtype=np.float32)
66 |
67 | self.train_items, self.test_set, self.val_set = {}, {}, {}
68 | for uid, train_items in train.items():
69 | if len(train_items) == 0:
70 | continue
71 | uid = int(uid)
72 | for _, i in enumerate(train_items):
73 | self.R[uid, i] = 1.0
74 |
75 | self.train_items[uid] = train_items
76 |
77 | for uid, test_items in test.items():
78 | uid = int(uid)
79 | if len(test_items) == 0:
80 | continue
81 | try:
82 | self.test_set[uid] = test_items
83 | except Exception:
84 | continue
85 |
86 | for uid, val_items in val.items():
87 | uid = int(uid)
88 | if len(val_items) == 0:
89 | continue
90 | try:
91 | self.val_set[uid] = val_items
92 | except Exception:
93 | continue
94 |
95 | def nonzero_idx(self):
96 | r, c = self.R.nonzero()
97 | idx = list(zip(r, c))
98 | return idx
99 |
100 | def sample(self):
101 | if self.batch_size <= self.n_users:
102 | users = rd.sample(self.exist_users, self.batch_size)
103 | else:
104 | users = [rd.choice(self.exist_users) for _ in range(self.batch_size)]
105 | # users = self.exist_users[:]
106 |
107 | def sample_pos_items_for_u(u, num):
108 | pos_items = self.train_items[u]
109 | n_pos_items = len(pos_items)
110 | pos_batch = []
111 | while True:
112 | if len(pos_batch) == num:
113 | break
114 | pos_id = np.random.randint(low=0, high=n_pos_items, size=1)[0]
115 | pos_i_id = pos_items[pos_id]
116 |
117 | if pos_i_id not in pos_batch:
118 | pos_batch.append(pos_i_id)
119 | return pos_batch
120 |
121 | def sample_neg_items_for_u(u, num):
122 | neg_items = []
123 | while True:
124 | if len(neg_items) == num:
125 | break
126 | neg_id = np.random.randint(low=0, high=self.n_items, size=1)[0]
127 | if neg_id not in self.train_items[u] and neg_id not in neg_items:
128 | neg_items.append(neg_id)
129 | return neg_items
130 |
131 | pos_items, neg_items = [], []
132 | for u in users:
133 | pos_items += sample_pos_items_for_u(u, 1)
134 | neg_items += sample_neg_items_for_u(u, 1)
135 | return users, pos_items, neg_items
136 |
137 | def print_statistics(self):
138 | print("n_users=%d, n_items=%d" % (self.n_users, self.n_items))
139 | print("n_interactions=%d" % (self.n_train + self.n_val + self.n_test))
140 | print(
141 | "n_train=%d, n_val=%d, n_test=%d, sparsity=%.5f"
142 | % (
143 | self.n_train,
144 | self.n_val,
145 | self.n_test,
146 | (self.n_train + self.n_val + self.n_test)
147 | / (self.n_users * self.n_items),
148 | )
149 | )
150 |
151 |
152 | def dataset_merge_and_split(path):
153 | df = pd.read_csv(path + "/train.csv", index_col=None, usecols=None)
154 | # Construct matrix
155 | ui = defaultdict(list)
156 | for _, row in df.iterrows():
157 | user, item = int(row["userID"]), int(row["itemID"])
158 | ui[user].append(item)
159 |
160 | df = pd.read_csv(path + "/test.csv", index_col=None, usecols=None)
161 | for _, row in df.iterrows():
162 | user, item = int(row["userID"]), int(row["itemID"])
163 | ui[user].append(item)
164 |
165 | train_json = {}
166 | val_json = {}
167 | test_json = {}
168 | for u, items in ui.items():
169 | if len(items) < 10:
170 | testval = np.random.choice(len(items), 2, replace=False)
171 | else:
172 | testval = np.random.choice(len(items), int(len(items) * 0.2), replace=False)
173 |
174 | test = testval[: len(testval) // 2]
175 | val = testval[len(testval) // 2 :]
176 | train = [i for i in list(range(len(items))) if i not in testval]
177 | train_json[u] = [items[idx] for idx in train]
178 | val_json[u] = [items[idx] for idx in val.tolist()]
179 | test_json[u] = [items[idx] for idx in test.tolist()]
180 |
181 | with open(path + "/5-core/train.json", "w") as f:
182 | json.dump(train_json, f)
183 | with open(path + "/5-core/val.json", "w") as f:
184 | json.dump(val_json, f)
185 | with open(path + "/5-core/test.json", "w") as f:
186 | json.dump(test_json, f)
187 |
188 |
189 | def load_textual_image_features(data_path):
190 | asin_dict = json.load(open(os.path.join(data_path, "asin_sample.json"), "r"))
191 |
192 |     # Prepare textual feature data.
193 | doc2vec_model = Doc2Vec.load(os.path.join(data_path, "doc2vecFile"))
194 | vis_vec = np.load(
195 | os.path.join(data_path, "image_feature.npy"), allow_pickle=True
196 | ).item()
197 | text_vec = {}
198 | for asin in asin_dict:
199 | text_vec[asin] = doc2vec_model.docvecs[asin]
200 |
201 | all_dict = {}
202 | num_items = 0
203 | filename = data_path + "/train.csv"
204 | df = pd.read_csv(filename, index_col=None, usecols=None)
205 | for _, row in df.iterrows():
206 | asin, i = row["asin"], int(row["itemID"])
207 | all_dict[i] = asin
208 | num_items = max(num_items, i)
209 | filename = data_path + "/test.csv"
210 | df = pd.read_csv(filename, index_col=None, usecols=None)
211 | for _, row in df.iterrows():
212 | asin, i = row["asin"], int(row["itemID"])
213 | all_dict[i] = asin
214 | num_items = max(num_items, i)
215 |
216 | t_features = []
217 | v_features = []
218 | for i in range(num_items + 1):
219 | t_features.append(text_vec[all_dict[i]])
220 | v_features.append(vis_vec[all_dict[i]])
221 |
222 | np.save(data_path + "/text_feat.npy", np.asarray(t_features, dtype=np.float32))
223 | np.save(data_path + "/image_feat.npy", np.asarray(v_features, dtype=np.float32))
224 |
--------------------------------------------------------------------------------
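
The per-user split rule in dataset_merge_and_split holds out 20% of each user's items (at least 2), giving half to test and half to validation. A minimal sketch with one toy user:

```python
# Sketch of the train/val/test split rule used above (toy 12-item user).
import numpy as np

np.random.seed(123)
items = list(range(12))  # one user's interacted items

if len(items) < 10:
    testval = np.random.choice(len(items), 2, replace=False)
else:
    testval = np.random.choice(len(items), int(len(items) * 0.2), replace=False)

test = testval[: len(testval) // 2]
val = testval[len(testval) // 2:]
train = [i for i in range(len(items)) if i not in testval]
print(len(train), len(val), len(test))  # 10 1 1
```
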
/codes/main.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import sys
4 | from time import time
5 |
6 | import numpy as np
7 | import torch
8 | import torch.optim as optim
9 | from Models import MONET
10 | from utility.batch_test import data_generator, test_torch
11 | from utility.parser import parse_args
12 |
13 |
14 | class Trainer(object):
15 | def __init__(self, data_config, args):
16 | # argument settings
17 | self.n_users = data_config["n_users"]
18 | self.n_items = data_config["n_items"]
19 |
20 | self.feat_embed_dim = args.feat_embed_dim
21 | self.lr = args.lr
22 | self.emb_dim = args.embed_size
23 | self.batch_size = args.batch_size
24 | self.n_layers = args.n_layers
25 | self.has_norm = args.has_norm
26 | self.regs = eval(args.regs)
27 | self.decay = self.regs[0]
28 | self.lamb = self.regs[1]
29 | self.alpha = args.alpha
30 | self.beta = args.beta
31 | self.dataset = args.dataset
32 | self.model_name = args.model_name
33 | self.agg = args.agg
34 | self.target_aware = args.target_aware
35 | self.cf = args.cf
36 | self.cf_gcn = args.cf_gcn
37 | self.lightgcn = args.lightgcn
38 |
39 | self.nonzero_idx = data_config["nonzero_idx"]
40 |
41 | self.image_feats = np.load("data/{}/image_feat.npy".format(self.dataset))
42 | self.text_feats = np.load("data/{}/text_feat.npy".format(self.dataset))
43 |
44 | self.model = MONET(
45 | self.n_users,
46 | self.n_items,
47 | self.feat_embed_dim,
48 | self.nonzero_idx,
49 | self.has_norm,
50 | self.image_feats,
51 | self.text_feats,
52 | self.n_layers,
53 | self.alpha,
54 | self.beta,
55 | self.agg,
56 | self.cf,
57 | self.cf_gcn,
58 | self.lightgcn,
59 | )
60 |
61 | self.model = self.model.cuda()
62 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
63 | self.lr_scheduler = self.set_lr_scheduler()
64 |
65 | def set_lr_scheduler(self):
66 | fac = lambda epoch: 0.96 ** (epoch / 50)
67 | scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=fac)
68 | return scheduler
69 |
70 | def test(self, users_to_test, is_val):
71 | self.model.eval()
72 | with torch.no_grad():
73 | ua_embeddings, ia_embeddings = self.model()
74 | result = test_torch(
75 | ua_embeddings,
76 | ia_embeddings,
77 | users_to_test,
78 | is_val,
79 | self.adj,
80 | self.beta,
81 | self.target_aware,
82 | )
83 | return result
84 |
85 | def train(self):
86 | nonzero_idx = torch.tensor(self.nonzero_idx).cuda().long().T
87 | self.adj = (
88 | torch.sparse.FloatTensor(
89 | nonzero_idx,
90 | torch.ones((nonzero_idx.size(1))).cuda(),
91 | (self.n_users, self.n_items),
92 | )
93 | .to_dense()
94 | .cuda()
95 | )
96 | stopping_step = 0
97 |
99 | best_recall = 0
100 | for epoch in range(args.epoch):
101 | t1 = time()
102 | loss, mf_loss, emb_loss, reg_loss = 0.0, 0.0, 0.0, 0.0
103 | n_batch = data_generator.n_train // args.batch_size + 1
104 | for _ in range(n_batch):
105 | self.model.train()
106 | self.optimizer.zero_grad()
107 | user_emb, item_emb = self.model()
108 | users, pos_items, neg_items = data_generator.sample()
109 |
110 | batch_mf_loss, batch_emb_loss, batch_reg_loss = self.model.bpr_loss(
111 | user_emb, item_emb, users, pos_items, neg_items, self.target_aware
112 | )
113 |
114 | batch_emb_loss = self.decay * batch_emb_loss
115 | batch_loss = batch_mf_loss + batch_emb_loss + batch_reg_loss
116 |
117 | batch_loss.backward(retain_graph=True)
118 | self.optimizer.step()
119 |
120 | loss += float(batch_loss)
121 | mf_loss += float(batch_mf_loss)
122 | emb_loss += float(batch_emb_loss)
123 | reg_loss += float(batch_reg_loss)
124 |
125 | del user_emb, item_emb
126 | torch.cuda.empty_cache()
127 |
128 | self.lr_scheduler.step()
129 |
130 | if math.isnan(loss):
131 | print("ERROR: loss is nan.")
132 | sys.exit()
133 |
134 | perf_str = "Pre_Epoch %d [%.1fs]: train==[%.5f=%.5f + %.5f + %.5f]" % (
135 | epoch,
136 | time() - t1,
137 | loss,
138 | mf_loss,
139 | emb_loss,
140 | reg_loss,
141 | )
142 | print(perf_str)
143 |
144 | if epoch % args.verbose != 0:
145 | continue
146 |
147 | t2 = time()
148 | users_to_test = list(data_generator.test_set.keys())
149 | users_to_val = list(data_generator.val_set.keys())
150 | ret = self.test(users_to_val, is_val=True)
151 |
152 | t3 = time()
153 |
154 | if args.verbose > 0:
155 | perf_str = (
156 | "Pre_Epoch %d [%.1fs + %.1fs]: val==[%.5f=%.5f + %.5f + %.5f], recall=[%.5f, %.5f], "
157 | "precision=[%.5f, %.5f], hit=[%.5f, %.5f], ndcg=[%.5f, %.5f]"
158 | % (
159 | epoch,
160 | t2 - t1,
161 | t3 - t2,
162 | loss,
163 | mf_loss,
164 | emb_loss,
165 | reg_loss,
166 | ret["recall"][0],
167 | ret["recall"][-1],
168 | ret["precision"][0],
169 | ret["precision"][-1],
170 | ret["hit_ratio"][0],
171 | ret["hit_ratio"][-1],
172 | ret["ndcg"][0],
173 | ret["ndcg"][-1],
174 | )
175 | )
176 | print(perf_str)
177 |
178 | if ret["recall"][1] > best_recall:
179 | best_recall = ret["recall"][1]
180 | stopping_step = 0
181 | torch.save(
182 | {self.model_name: self.model.state_dict()},
183 | "./models/" + self.dataset + "_" + self.model_name,
184 | )
185 | elif stopping_step < args.early_stopping_patience:
186 | stopping_step += 1
187 | print("#####Early stopping steps: %d #####" % stopping_step)
188 | else:
189 | print("#####Early stop! #####")
190 | break
191 |
192 | self.model = MONET(
193 | self.n_users,
194 | self.n_items,
195 | self.feat_embed_dim,
196 | self.nonzero_idx,
197 | self.has_norm,
198 | self.image_feats,
199 | self.text_feats,
200 | self.n_layers,
201 | self.alpha,
202 | self.beta,
203 | self.agg,
204 | self.cf,
205 | self.cf_gcn,
206 | self.lightgcn,
207 | )
208 |
209 | self.model.load_state_dict(
210 | torch.load(
211 | "./models/" + self.dataset + "_" + self.model_name,
212 | map_location=torch.device("cpu"),
213 | )[self.model_name]
214 | )
215 | self.model.cuda()
216 | test_ret = self.test(users_to_test, is_val=False)
217 | print("Final ", test_ret)
218 |
219 |
220 | def set_seed(seed):
221 | np.random.seed(seed)
222 | random.seed(seed)
223 | torch.manual_seed(seed) # cpu
224 | torch.cuda.manual_seed_all(seed) # gpu
225 |
226 |
227 | if __name__ == "__main__":
228 | args = parse_args(True)
229 | set_seed(args.seed)
230 |
231 | config = dict()
232 | config["n_users"] = data_generator.n_users
233 | config["n_items"] = data_generator.n_items
234 |
235 | nonzero_idx = data_generator.nonzero_idx()
236 | config["nonzero_idx"] = nonzero_idx
237 |
238 | trainer = Trainer(config, args)
239 | trainer.train()
240 |
--------------------------------------------------------------------------------
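
Trainer.train materializes the user-item interaction matrix as a dense tensor from the nonzero (user, item) pairs before testing. A minimal CPU sketch of that construction (the original uses torch.sparse.FloatTensor on CUDA; torch.sparse_coo_tensor is the equivalent modern constructor):

```python
# Builds the dense user-item adjacency from interaction pairs, as in
# Trainer.train (kept on CPU here; toy sizes).
import torch

n_users, n_items = 3, 4
nonzero_idx = [(0, 1), (1, 3), (2, 0)]  # (user, item) interaction pairs

idx = torch.tensor(nonzero_idx).long().T  # shape (2, n_interactions)
adj = torch.sparse_coo_tensor(
    idx, torch.ones(idx.size(1)), (n_users, n_items)
).to_dense()
print(adj)
# tensor([[0., 1., 0., 0.],
#         [0., 0., 0., 1.],
#         [1., 0., 0., 0.]])
```
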
/codes/Preliminaries.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "n_users=19244, n_items=14596\n",
13 | "n_interactions=135326\n",
14 | "n_train=95629, n_val=20127, n_test=19570, sparsity=0.00048\n"
15 | ]
16 | },
17 | {
18 | "name": "stderr",
19 | "output_type": "stream",
20 | "text": [
21 | "100%|██████████| 19244/19244 [01:10<00:00, 273.56it/s]"
22 | ]
23 | },
24 | {
25 | "name": "stdout",
26 | "output_type": "stream",
27 | "text": [
28 | "text 0.10991018 0.086630285\n",
29 | "img 0.3009104 0.23256181\n"
30 | ]
31 | },
32 | {
33 | "name": "stderr",
34 | "output_type": "stream",
35 | "text": [
36 | "\n"
37 | ]
38 | }
39 | ],
40 | "source": [
41 | "# avg.sim (Figure 1)\n",
42 | "\n",
43 | "import json\n",
44 | "import os\n",
45 | "from utility.load_data import Data\n",
46 | "\n",
47 | "data_generator = Data(path='data/WomenClothing', batch_size=1024)\n",
48 | "\n",
49 | "from copy import deepcopy\n",
50 | "I_items = deepcopy(data_generator.train_items)\n",
51 | "\n",
52 | "for k in I_items.keys():\n",
53 | " I_items[k] = I_items[k] + data_generator.val_set[k] + data_generator.test_set[k]\n",
54 | "\n",
55 | "import numpy as np\n",
56 | "image_feats = np.load('data/WomenClothing/image_feat.npy')\n",
57 | "text_feats = np.load('data/WomenClothing/text_feat.npy')\n",
58 | "\n",
59 | "from collections import defaultdict\n",
60 | "from tqdm import tqdm\n",
61 | "\n",
62 | "img_cos = np.dot(image_feats, image_feats.T) / (np.linalg.norm(image_feats, axis=1)[:, np.newaxis] * np.linalg.norm(image_feats, axis=1)[:, np.newaxis].T)\n",
63 | "text_cos = np.dot(text_feats, text_feats.T) / (np.linalg.norm(text_feats, axis=1)[:, np.newaxis] * np.linalg.norm(text_feats, axis=1)[:, np.newaxis].T)\n",
64 | "\n",
65 | "seen_img = []\n",
66 | "seen_text = []\n",
67 | "unseen_img = []\n",
68 | "unseen_text = []\n",
69 | "for user, items in tqdm(I_items.items()):\n",
70 | " img = img_cos[items][:, items]\n",
71 | " text = text_cos[items][:, items]\n",
72 | "\n",
73 | " seen_img_result = []\n",
74 | " seen_text_result = []\n",
75 | " for i in range(len(items)):\n",
76 | " seen_img_result.append(np.concatenate([img[i, :i], img[i, i+1:]]))\n",
77 | " seen_text_result.append(np.concatenate([text[i, :i], text[i, i+1:]]))\n",
78 | " seen_img_result = np.array(seen_img_result) # .flatten()\n",
79 | " seen_text_result = np.array(seen_text_result) # .flatten()\n",
80 | "\n",
81 | " unseen_items = set(range(data_generator.n_items)) - set(items)\n",
82 | " unseen_items = list(unseen_items)\n",
83 | "\n",
84 | " unseen_img_result = img_cos[items][:, unseen_items].flatten()\n",
85 | " unseen_text_result = text_cos[items][:, unseen_items].flatten()\n",
86 | "\n",
87 | " seen_img.append(seen_img_result.mean())\n",
88 | " seen_text.append(seen_text_result.mean())\n",
89 | " unseen_img.append(unseen_img_result.mean())\n",
90 | " unseen_text.append(unseen_text_result.mean())\n",
91 | "\n",
92 | "print('text', np.mean(seen_text), np.mean(unseen_text))\n",
93 | "print('img', np.mean(seen_img), np.mean(unseen_img))"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": 1,
99 | "metadata": {},
100 | "outputs": [
101 | {
102 | "name": "stderr",
103 | "output_type": "stream",
104 | "text": [
105 | "/home/ubuntu/anaconda3/envs/yg/lib/python3.6/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
106 | " from .autonotebook import tqdm as notebook_tqdm\n"
107 | ]
108 | },
109 | {
110 | "name": "stdout",
111 | "output_type": "stream",
112 | "text": [
113 | "n_users=19244, n_items=14596\n",
114 | "n_interactions=135326\n",
115 | "n_train=95629, n_val=20127, n_test=19570, sparsity=0.00048\n",
116 | "Loads image_emb: torch.Size([33840, 64]) and text_emb: torch.Size([33840, 64])\n",
117 | "0.14228745 0.11385205\n",
118 | "Loads image_emb: torch.Size([33840, 64]) and text_emb: torch.Size([33840, 64])\n",
119 | "0.3034921 0.10677319\n",
120 | "Loads image_emb: torch.Size([33840, 64]) and text_emb: torch.Size([33840, 64])\n",
121 | "0.3312145 0.1141393\n",
122 | "Loads image_emb: torch.Size([33840, 64]) and text_emb: torch.Size([33840, 64])\n",
123 | "0.17027126 0.110791825\n",
124 | "Loads image_emb: torch.Size([33840, 64]) and text_emb: torch.Size([33840, 64])\n",
125 | "0.2663411 0.11235878\n"
126 | ]
127 | }
128 | ],
129 | "source": [
130 | "# avg.diff\n",
131 | "from Models import *\n",
132 | "\n",
133 | "import os\n",
134 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '1'\n",
135 | "import torch\n",
136 | "import numpy as np\n",
137 | "from utility.load_data import Data\n",
138 | "data_generator = Data(path='data/WomenClothing', batch_size=1024)\n",
139 | "\n",
140 | "def sparse_mx_to_torch_sparse_tensor(sparse_mx):\n",
141 | " \"\"\"Convert a scipy sparse matrix to a torch sparse tensor.\"\"\"\n",
142 | " sparse_mx = sparse_mx.tocoo().astype(np.float32)\n",
143 | " indices = torch.from_numpy(\n",
144 | " np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))\n",
145 | " values = torch.from_numpy(sparse_mx.data)\n",
146 | " shape = torch.Size(sparse_mx.shape)\n",
147 | " return torch.sparse.FloatTensor(indices, values, shape)\n",
148 | "\n",
149 | "model_name = 'MONET_concat_20_03'\n",
150 | "seed_list = ['123', '0', '42', '1024', '2048']\n",
151 | "nonzero_idx = data_generator.nonzero_idx()\n",
152 | "\n",
153 | "import numpy as np\n",
154 | "image_feats = np.load('data/WomenClothing/image_feat.npy')\n",
155 | "text_feats = np.load('data/WomenClothing/text_feat.npy')\n",
156 | " \n",
157 | "for seed in seed_list: \n",
158 | " model = MONET(data_generator.n_users, data_generator.n_items, 64, nonzero_idx, True, image_feats, text_feats, 2, 1.0, 0.3, 'concat', 's', False) \n",
159 | " model.load_state_dict(torch.load('./models/' + 'WomenClothing' + '_' + model_name + '_' + seed, map_location=torch.device('cpu'))[model_name + '_' + seed])\n",
160 | " model.cuda()\n",
161 | " image_emb, text_emb = model(eval=True)\n",
162 | " print('Loads image_emb: {} and text_emb: {}'.format(image_emb.shape, text_emb.shape))\n",
163 | "\n",
164 | " # user_emb = torch.load('data/{}/{}_user_emb.pt'.format('clothing', 'lightgcn_layer3_original')).cuda()\n",
165 | " # item_emb = torch.load('data/{}/{}_item_emb.pt'.format('clothing', 'lightgcn_layer3_original')).cuda()\n",
166 | " # print('Loads user_emb: {} and item_emb: {}'.format(user_emb.weight.shape, item_emb.weight.shape))\n",
167 | "\n",
168 | " # image_emb = image_emb.mean(dim=1, keepdim=False)\n",
169 | " # text_emb = text_emb.mean(dim=1, keepdim=False)\n",
170 | "\n",
171 | " # image_emb = image_emb[:, -1, :]\n",
172 | " # text_emb = text_emb[:, -1, :]\n",
173 | "\n",
174 | "\n",
175 | " final_image_preference, final_image_emb = torch.split(image_emb, [data_generator.n_users, data_generator.n_items], dim=0)\n",
176 | " final_text_preference, final_text_emb = torch.split(text_emb, [data_generator.n_users, data_generator.n_items], dim=0)\n",
177 | "\n",
178 | " final_text_emb, final_image_emb = final_text_emb.cpu().detach().numpy(), final_image_emb.cpu().detach().numpy()\n",
179 | "\n",
180 | " final_image_cos = np.dot(final_image_emb, final_image_emb.T) / (np.linalg.norm(final_image_emb, axis=1)[:, np.newaxis] * np.linalg.norm(final_image_emb, axis=1)[:, np.newaxis].T)\n",
181 | " final_text_cos = np.dot(final_text_emb, final_text_emb.T) / (np.linalg.norm(final_text_emb, axis=1)[:, np.newaxis] * np.linalg.norm(final_text_emb, axis=1)[:, np.newaxis].T)\n",
182 | "\n",
183 | " img_cos = np.dot(image_feats, image_feats.T) / (np.linalg.norm(image_feats, axis=1)[:, np.newaxis] * np.linalg.norm(image_feats, axis=1)[:, np.newaxis].T)\n",
184 | " text_cos = np.dot(text_feats, text_feats.T) / (np.linalg.norm(text_feats, axis=1)[:, np.newaxis] * np.linalg.norm(text_feats, axis=1)[:, np.newaxis].T)\n",
185 | "\n",
186 | " img_diff = np.abs(img_cos - final_image_cos)\n",
187 | " text_diff = np.abs(text_cos - final_text_cos)\n",
188 | "\n",
189 | " img = []\n",
190 | " for i in range(data_generator.n_items):\n",
191 | " img.append(np.concatenate([img_diff[i, :i], img_diff[i, i+1:]]))\n",
192 | " img = np.array(img) # .flatten()\n",
193 | "\n",
194 | " txt = []\n",
195 | " for i in range(data_generator.n_items):\n",
196 | " txt.append(np.concatenate([text_diff[i, :i], text_diff[i, i+1:]]))\n",
197 | " txt = np.array(txt) # .flatten()\n",
198 | "\n",
199 | " print(img[~np.isnan(img)].mean(), txt[~np.isnan(txt)].mean())"
200 | ]
201 | },
202 | {
203 | "cell_type": "code",
204 | "execution_count": null,
205 | "metadata": {},
206 | "outputs": [],
207 | "source": []
208 | }
209 | ],
210 | "metadata": {
211 | "interpreter": {
212 | "hash": "0aa7af790e1209bd084877485dad105a461ac2ebd38ac99cff72d3e7c0921c3c"
213 | },
214 | "kernelspec": {
215 | "display_name": "yg",
216 | "language": "python",
217 | "name": "yg"
218 | },
219 | "language_info": {
220 | "codemirror_mode": {
221 | "name": "ipython",
222 | "version": 3
223 | },
224 | "file_extension": ".py",
225 | "mimetype": "text/x-python",
226 | "name": "python",
227 | "nbconvert_exporter": "python",
228 | "pygments_lexer": "ipython3",
229 | "version": "3.6.13 |Anaconda, Inc.| (default, Jun 4 2021, 14:25:59) \n[GCC 7.5.0]"
230 | },
231 | "orig_nbformat": 4
232 | },
233 | "nbformat": 4,
234 | "nbformat_minor": 2
235 | }
236 |
--------------------------------------------------------------------------------
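
Both notebook cells rely on the same pairwise cosine-similarity construction: a dot-product matrix divided by the outer product of the row norms. A small self-contained version (toy features, not the dataset's):

```python
# Pairwise cosine similarity as computed in the notebook above.
import numpy as np

feats = np.random.rand(5, 8).astype(np.float32)  # toy (n_items, dim) matrix
norms = np.linalg.norm(feats, axis=1)[:, np.newaxis]
cos = np.dot(feats, feats.T) / (norms * norms.T)

assert np.allclose(np.diag(cos), 1.0, atol=1e-5)  # self-similarity is 1
print(cos.shape)  # (5, 5)
```
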
/codes/data/build_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import array
3 | import gzip
4 | import json
5 | import os
6 | from collections import defaultdict
7 |
8 | import numpy as np
9 | import pandas as pd
10 | from sentence_transformers import SentenceTransformer
11 |
12 |
13 | def dataset_merge_and_split(path, core):
14 |     if not os.path.exists(path + "%d-core" % core):
15 |         os.makedirs(path + "%d-core" % core)
16 |
17 | df = pd.read_csv(path + "/train.csv", index_col=None, usecols=None)
18 | # Construct matrix
19 | ui = defaultdict(list)
20 | for _, row in df.iterrows():
21 | user, item = int(row["userID"]), int(row["itemID"])
22 | ui[user].append(item)
23 |
24 | df = pd.read_csv(path + "/test.csv", index_col=None, usecols=None)
25 | for _, row in df.iterrows():
26 | user, item = int(row["userID"]), int(row["itemID"])
27 | ui[user].append(item)
28 |
29 | train_json = {}
30 | val_json = {}
31 | test_json = {}
32 | for u, items in ui.items():
33 | if len(items) < 10:
34 | testval = np.random.choice(len(items), 2, replace=False)
35 | else:
36 | testval = np.random.choice(len(items), int(len(items) * 0.2), replace=False)
37 |
38 | test = testval[: len(testval) // 2]
39 | val = testval[len(testval) // 2 :]
40 | train = [i for i in list(range(len(items))) if i not in testval]
41 | train_json[u] = [items[idx] for idx in train]
42 | val_json[u] = [items[idx] for idx in val.tolist()]
43 | test_json[u] = [items[idx] for idx in test.tolist()]
44 |
45 |     with open(path + "/%d-core/train.json" % core, "w") as f:
46 |         json.dump(train_json, f)
47 |     with open(path + "/%d-core/val.json" % core, "w") as f:
48 |         json.dump(val_json, f)
49 |     with open(path + "/%d-core/test.json" % core, "w") as f:
50 |         json.dump(test_json, f)
51 |
52 |
53 | def load_textual_image_features(data_path):
54 | import json
55 | import os
56 |
57 | from gensim.models.doc2vec import Doc2Vec
58 |
59 | asin_dict = json.load(open(os.path.join(data_path, "asin_sample.json"), "r"))
60 |
61 |     # Prepare textual feature data.
62 | doc2vec_model = Doc2Vec.load(os.path.join(data_path, "doc2vecFile"))
63 | vis_vec = np.load(
64 | os.path.join(data_path, "image_feature.npy"), allow_pickle=True
65 | ).item()
66 | text_vec = {}
67 | for asin in asin_dict:
68 | text_vec[asin] = doc2vec_model.docvecs[asin]
69 |
70 | all_dict = {}
71 | num_items = 0
72 | filename = data_path + "/train.csv"
73 | df = pd.read_csv(filename, index_col=None, usecols=None)
74 | for _, row in df.iterrows():
75 | asin, i = row["asin"], int(row["itemID"])
76 | all_dict[i] = asin
77 | num_items = max(num_items, i)
78 | filename = data_path + "/test.csv"
79 | df = pd.read_csv(filename, index_col=None, usecols=None)
80 | for _, row in df.iterrows():
81 | asin, i = row["asin"], int(row["itemID"])
82 | all_dict[i] = asin
83 | num_items = max(num_items, i)
84 |
85 | t_features = []
86 | v_features = []
87 | for i in range(num_items + 1):
88 | t_features.append(text_vec[all_dict[i]])
89 | v_features.append(vis_vec[all_dict[i]])
90 |
91 | np.save(data_path + "/text_feat.npy", np.asarray(t_features, dtype=np.float32))
92 | np.save(data_path + "/image_feat.npy", np.asarray(v_features, dtype=np.float32))
93 |
94 |
95 | parser = argparse.ArgumentParser(description="")
96 |
97 | parser.add_argument(
98 | "--name",
99 | nargs="?",
100 | default="MenClothing",
101 | help="Choose a dataset folder from {MenClothing, WomenClothing, Beauty, Toys_and_Games}.",
102 | )
103 |
104 | np.random.seed(123)
105 |
106 | args = parser.parse_args()
107 | folder = args.name + "/"
108 | name = args.name
109 | core = 5
110 | if folder in ["MenClothing/", "WomenClothing/"]:
111 | dataset_merge_and_split(folder, core)
112 | load_textual_image_features(folder)
113 | else:
114 | bert_path = "sentence-transformers/stsb-roberta-large"
115 | bert_model = SentenceTransformer(bert_path)
116 |
117 | if not os.path.exists(folder + "%d-core" % core):
118 | os.makedirs(folder + "%d-core" % core)
119 |
120 | def parse(path):
121 | g = gzip.open(path, "r")
122 | for line in g:
123 | yield json.dumps(eval(line))
124 |
125 | print("----------parse metadata----------")
126 | if not os.path.exists(folder + "meta-data/meta.json"):
127 | with open(folder + "meta-data/meta.json", "w") as f:
128 | for line in parse(folder + "meta-data/" + "meta_%s.json.gz" % (name)):
129 | f.write(line + "\n")
130 |
131 | print("----------parse data----------")
132 | if not os.path.exists(folder + "meta-data/%d-core.json" % core):
133 | with open(folder + "meta-data/%d-core.json" % core, "w") as f:
134 | for line in parse(
135 | folder + "meta-data/" + "reviews_%s_%d.json.gz" % (name, core)
136 | ):
137 | f.write(line + "\n")
138 |
139 | print("----------load data----------")
140 | jsons = []
141 | for line in open(folder + "meta-data/%d-core.json" % core).readlines():
142 | jsons.append(json.loads(line))
143 |
144 | print("----------Build dict----------")
145 | items = set()
146 | users = set()
147 | for j in jsons:
148 | items.add(j["asin"])
149 | users.add(j["reviewerID"])
150 | print("n_items:", len(items), "n_users:", len(users))
151 |
152 | item2id = {}
153 | with open(folder + "%d-core/item_list.txt" % core, "w") as f:
154 | for i, item in enumerate(items):
155 | item2id[item] = i
156 | f.writelines(item + "\t" + str(i) + "\n")
157 |
158 | user2id = {}
159 | with open(folder + "%d-core/user_list.txt" % core, "w") as f:
160 | for i, user in enumerate(users):
161 | user2id[user] = i
162 | f.writelines(user + "\t" + str(i) + "\n")
163 |
164 | ui = defaultdict(list)
165 | review2id = {}
166 | review_text = {}
167 | ratings = {}
168 | with open(folder + "%d-core/review_list.txt" % core, "w") as f:
169 | for j in jsons:
170 | u_id = user2id[j["reviewerID"]]
171 | i_id = item2id[j["asin"]]
172 |             ui[u_id].append(i_id)
173 | review_text[len(review2id)] = j["reviewText"].replace("\n", " ")
174 | ratings[len(review2id)] = int(j["overall"])
175 | f.writelines(str((u_id, i_id)) + "\t" + str(len(review2id)) + "\n")
176 | review2id[u_id, i_id] = len(review2id)
177 | with open(folder + "%d-core/user-item-dict.json" % core, "w") as f:
178 | f.write(json.dumps(ui))
179 | with open(folder + "%d-core/rating-dict.json" % core, "w") as f:
180 | f.write(json.dumps(ratings))
181 |
182 | review_texts = []
183 | with open(folder + "%d-core/review_text.txt" % core, "w") as f:
184 | for i, j in review2id:
185 | f.write(review_text[review2id[i, j]] + "\n")
186 | review_texts.append(review_text[review2id[i, j]] + "\n")
187 | review_embeddings = bert_model.encode(review_texts)
188 | assert review_embeddings.shape[0] == len(review2id)
189 | np.save(folder + "review_feat.npy", review_embeddings)
190 |
191 | print("----------Split Data----------")
192 | train_json = {}
193 | val_json = {}
194 | test_json = {}
195 | for u, items in ui.items():
196 | if len(items) < 10:
197 | testval = np.random.choice(len(items), 2, replace=False)
198 | else:
199 | testval = np.random.choice(len(items), int(len(items) * 0.2), replace=False)
200 |
201 | test = testval[: len(testval) // 2]
202 | val = testval[len(testval) // 2 :]
203 | train = [i for i in list(range(len(items))) if i not in testval]
204 | train_json[u] = [items[idx] for idx in train]
205 | val_json[u] = [items[idx] for idx in val.tolist()]
206 | test_json[u] = [items[idx] for idx in test.tolist()]
207 |
208 | with open(folder + "%d-core/train.json" % core, "w") as f:
209 | json.dump(train_json, f)
210 | with open(folder + "%d-core/val.json" % core, "w") as f:
211 | json.dump(val_json, f)
212 | with open(folder + "%d-core/test.json" % core, "w") as f:
213 | json.dump(test_json, f)
214 |
215 | jsons = []
216 | with open(folder + "meta-data/meta.json", "r") as f:
217 | for line in f.readlines():
218 | jsons.append(json.loads(line))
219 |
220 | print("----------Text Features----------")
221 | raw_text = {}
222 | for _json in jsons:
223 | if _json["asin"] in item2id:
224 | string = " "
225 | if "categories" in _json:
226 | for cates in _json["categories"]:
227 | for cate in cates:
228 | string += cate + " "
229 | if "title" in _json:
230 | string += _json["title"]
231 |             if "brand" in _json:
232 |                 string += _json["brand"]
233 | if "description" in _json:
234 | string += _json["description"]
235 | raw_text[item2id[_json["asin"]]] = string.replace("\n", " ")
236 | texts = []
237 | with open(folder + "%d-core/raw_text.txt" % core, "w") as f:
238 | for i in range(len(item2id)):
239 | f.write(raw_text[i] + "\n")
240 | texts.append(raw_text[i] + "\n")
241 | sentence_embeddings = bert_model.encode(texts)
242 | assert sentence_embeddings.shape[0] == len(item2id)
243 | np.save(folder + "text_feat.npy", sentence_embeddings)
244 |
245 | print("----------Image Features----------")
246 |
247 | def readImageFeatures(path):
248 | f = open(path, "rb")
249 | while True:
250 | asin = f.read(10).decode("UTF-8")
251 | if asin == "":
252 | break
253 | a = array.array("f")
254 | a.fromfile(f, 4096)
255 | yield asin, a.tolist()
256 |
257 | data = readImageFeatures(folder + "meta-data/" + "image_features_%s.b" % name)
258 | feats = {}
259 | avg = []
260 | for d in data:
261 | if d[0] in item2id:
262 | feats[int(item2id[d[0]])] = d[1]
263 | avg.append(d[1])
264 | avg = np.array(avg).mean(0).tolist()
265 |
266 | ret = []
267 | for i in range(len(item2id)):
268 | if i in feats:
269 | ret.append(feats[i])
270 | else:
271 | ret.append(avg)
272 |
273 | assert len(ret) == len(item2id)
274 | np.save(folder + "image_feat.npy", np.array(ret))
275 |
--------------------------------------------------------------------------------
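
readImageFeatures in build_data.py walks a binary file of fixed-size records: a 10-byte ASIN followed by 4096 float32 values. A round-trip sketch of that layout on an in-memory buffer (synthetic record, not real Amazon data):

```python
# Writes and reads one record in the layout readImageFeatures expects.
import array
import io

buf = io.BytesIO()
buf.write(b"B00000001A")                             # 10-byte ASIN
buf.write(array.array("f", [0.5] * 4096).tobytes())  # 4096 float32 values
buf.seek(0)

asin = buf.read(10).decode("UTF-8")
feats = array.array("f")
feats.frombytes(buf.read(4096 * 4))                  # 4 bytes per float32
print(asin, len(feats), feats[0])                    # B00000001A 4096 0.5
```
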
/codes/Models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import Parameter
5 | from torch_geometric.nn.conv import MessagePassing
6 | from torch_geometric.nn.inits import uniform
7 | from torch_geometric.utils.num_nodes import maybe_num_nodes
8 | from torch_scatter import scatter_add
8 |
9 |
10 | def normalize_laplacian(edge_index, edge_weight):
11 | num_nodes = maybe_num_nodes(edge_index)
12 | row, col = edge_index[0], edge_index[1]
13 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
14 |
15 | deg_inv_sqrt = deg.pow_(-0.5)
16 | deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float("inf"), 0)
17 | edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
18 | return edge_weight
19 |
20 |
21 | class Our_GCNs(MessagePassing):
22 | def __init__(self, in_channels, out_channels):
23 | super(Our_GCNs, self).__init__(aggr="add")
24 | self.in_channels = in_channels
25 | self.out_channels = out_channels
26 |
27 | def forward(self, x, edge_index, weight_vector, size=None):
28 | self.weight_vector = weight_vector
29 | return self.propagate(edge_index, size=size, x=x)
30 |
31 | def message(self, x_j):
32 | return x_j * self.weight_vector
33 |
34 | def update(self, aggr_out):
35 | return aggr_out
36 |
37 |
41 | class Nonlinear_GCNs(MessagePassing):
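    |     """Like Our_GCNs, but multiplies node features by a learnable weight
    |     matrix (uniform-initialized) before propagation."""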
42 | def __init__(self, in_channels, out_channels):
43 | super(Nonlinear_GCNs, self).__init__(aggr="add")
44 | self.in_channels = in_channels
45 | self.out_channels = out_channels
46 | self.weight = Parameter(torch.Tensor(self.in_channels, out_channels))
47 | self.reset_parameters()
48 |
49 | def reset_parameters(self):
50 | uniform(self.in_channels, self.weight)
51 |
52 | def forward(self, x, edge_index, weight_vector, size=None):
53 | x = torch.matmul(x, self.weight)
54 | self.weight_vector = weight_vector
55 | return self.propagate(edge_index, size=size, x=x)
56 |
57 | def message(self, x_j):
58 | return x_j * self.weight_vector
59 |
60 | def update(self, aggr_out):
61 | return aggr_out
62 |
63 |
64 | class MeGCN(nn.Module):
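    |     """Modality-embracing GCN. Propagates frozen image/text item features and
    |     learnable per-modality user preferences over the user-item graph, with an
    |     optional collaborative-filtering (cf) branch, then fuses modalities per
    |     `agg` (concat / sum / weighted_sum / fc)."""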
65 | def __init__(
66 | self,
67 | n_users,
68 | n_items,
69 | n_layers,
70 | has_norm,
71 | feat_embed_dim,
72 | nonzero_idx,
73 | image_feats,
74 | text_feats,
75 | alpha,
76 | agg,
77 | cf,
78 | cf_gcn,
79 | lightgcn,
80 | ):
81 | super(MeGCN, self).__init__()
82 | self.n_users = n_users
83 | self.n_items = n_items
84 | self.n_layers = n_layers
85 | self.has_norm = has_norm
86 | self.feat_embed_dim = feat_embed_dim
87 | self.nonzero_idx = torch.tensor(nonzero_idx).cuda().long().T
88 | self.alpha = alpha
89 | self.agg = agg
90 | self.cf = cf
91 | self.cf_gcn = cf_gcn
92 | self.lightgcn = lightgcn
93 |
94 | self.image_preference = nn.Embedding(self.n_users, self.feat_embed_dim)
95 | self.text_preference = nn.Embedding(self.n_users, self.feat_embed_dim)
96 | nn.init.xavier_uniform_(self.image_preference.weight)
97 | nn.init.xavier_uniform_(self.text_preference.weight)
98 |
99 | self.image_embedding = nn.Embedding.from_pretrained(
100 | torch.tensor(image_feats, dtype=torch.float), freeze=True
101 | ) # [# of items, 4096]
102 | self.text_embedding = nn.Embedding.from_pretrained(
103 | torch.tensor(text_feats, dtype=torch.float), freeze=True
104 | ) # [# of items, 1024]
105 |
106 | if self.cf:
107 | self.user_embedding = nn.Embedding(self.n_users, self.feat_embed_dim)
108 | self.item_embedding = nn.Embedding(self.n_items, self.feat_embed_dim)
109 | nn.init.xavier_uniform_(self.user_embedding.weight)
110 | nn.init.xavier_uniform_(self.item_embedding.weight)
111 |
112 | self.image_trs = nn.Linear(image_feats.shape[1], self.feat_embed_dim)
113 | self.text_trs = nn.Linear(text_feats.shape[1], self.feat_embed_dim)
114 |
115 | if not self.cf:
116 | if self.agg == "fc":
117 | self.transform = nn.Linear(self.feat_embed_dim * 2, self.feat_embed_dim)
118 | elif self.agg == "weighted_sum":
119 | self.modal_weight = nn.Parameter(torch.Tensor([0.5, 0.5]))
120 | self.softmax = nn.Softmax(dim=0)
121 | else:
122 | if self.agg == "fc":
123 | self.transform = nn.Linear(self.feat_embed_dim * 3, self.feat_embed_dim)
124 | elif self.agg == "weighted_sum":
125 | self.modal_weight = nn.Parameter(torch.Tensor([0.33, 0.33, 0.33]))
126 | self.softmax = nn.Softmax(dim=0)
127 |
128 | self.layers = nn.ModuleList(
129 | [
130 | Our_GCNs(self.feat_embed_dim, self.feat_embed_dim)
131 | for _ in range(self.n_layers)
132 | ]
133 | )
134 |
135 | def forward(self, edge_index, edge_weight, _eval=False):
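    |         """Return fused (user_preference, items); with _eval=True, return the
    |         propagated per-modality embeddings (users and items concatenated)."""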
136 | # transform
137 | image_emb = self.image_trs(
138 | self.image_embedding.weight
139 | ) # [# of items, feat_embed_dim]
140 | text_emb = self.text_trs(
141 | self.text_embedding.weight
142 | ) # [# of items, feat_embed_dim]
143 |
144 | if self.has_norm:
145 | image_emb = F.normalize(image_emb)
146 | text_emb = F.normalize(text_emb)
147 | image_preference = self.image_preference.weight
148 | text_preference = self.text_preference.weight
149 |
150 | # propagate
151 | ego_image_emb = torch.cat([image_preference, image_emb], dim=0)
152 | ego_text_emb = torch.cat([text_preference, text_emb], dim=0)
153 |
154 | if self.cf:
155 | user_emb = self.user_embedding.weight
156 | item_emb = self.item_embedding.weight
157 | ego_cf_emb = torch.cat([user_emb, item_emb], dim=0)
158 | if self.cf_gcn == "LightGCN":
159 | all_cf_emb = [ego_cf_emb]
160 |
161 | if self.lightgcn:
162 | all_image_emb = [ego_image_emb]
163 | all_text_emb = [ego_text_emb]
164 |
165 | for layer in self.layers:
166 | if not self.lightgcn:
167 | side_image_emb = layer(ego_image_emb, edge_index, edge_weight)
168 | side_text_emb = layer(ego_text_emb, edge_index, edge_weight)
169 |
170 | ego_image_emb = side_image_emb + self.alpha * ego_image_emb
171 | ego_text_emb = side_text_emb + self.alpha * ego_text_emb
172 | else:
173 | side_image_emb = layer(ego_image_emb, edge_index, edge_weight)
174 | side_text_emb = layer(ego_text_emb, edge_index, edge_weight)
175 | ego_image_emb = side_image_emb
176 | ego_text_emb = side_text_emb
177 | all_image_emb += [ego_image_emb]
178 | all_text_emb += [ego_text_emb]
179 | if self.cf:
180 | if self.cf_gcn == "MeGCN":
181 | side_cf_emb = layer(ego_cf_emb, edge_index, edge_weight)
182 | ego_cf_emb = side_cf_emb + self.alpha * ego_cf_emb
183 | elif self.cf_gcn == "LightGCN":
184 | side_cf_emb = layer(ego_cf_emb, edge_index, edge_weight)
185 | ego_cf_emb = side_cf_emb
186 | all_cf_emb += [ego_cf_emb]
187 |
188 | if not self.lightgcn:
189 | final_image_preference, final_image_emb = torch.split(
190 | ego_image_emb, [self.n_users, self.n_items], dim=0
191 | )
192 | final_text_preference, final_text_emb = torch.split(
193 | ego_text_emb, [self.n_users, self.n_items], dim=0
194 | )
195 | else:
196 | all_image_emb = torch.stack(all_image_emb, dim=1)
197 | all_image_emb = all_image_emb.mean(dim=1, keepdim=False)
198 | final_image_preference, final_image_emb = torch.split(
199 | all_image_emb, [self.n_users, self.n_items], dim=0
200 | )
201 |
202 | all_text_emb = torch.stack(all_text_emb, dim=1)
203 | all_text_emb = all_text_emb.mean(dim=1, keepdim=False)
204 | final_text_preference, final_text_emb = torch.split(
205 | all_text_emb, [self.n_users, self.n_items], dim=0
206 | )
207 |
208 | if self.cf:
209 | if self.cf_gcn == "MeGCN":
210 | final_cf_user_emb, final_cf_item_emb = torch.split(
211 | ego_cf_emb, [self.n_users, self.n_items], dim=0
212 | )
213 | elif self.cf_gcn == "LightGCN":
214 | all_cf_emb = torch.stack(all_cf_emb, dim=1)
215 | all_cf_emb = all_cf_emb.mean(dim=1, keepdim=False)
216 | final_cf_user_emb, final_cf_item_emb = torch.split(
217 | all_cf_emb, [self.n_users, self.n_items], dim=0
218 | )
219 |
220 | if _eval:
221 | return ego_image_emb, ego_text_emb
222 |
223 | if not self.cf:
224 | if self.agg == "concat":
225 | items = torch.cat(
226 | [final_image_emb, final_text_emb], dim=1
227 | ) # [# of items, feat_embed_dim * 2]
228 | user_preference = torch.cat(
229 | [final_image_preference, final_text_preference], dim=1
230 | ) # [# of users, feat_embed_dim * 2]
231 | elif self.agg == "sum":
232 | items = final_image_emb + final_text_emb # [# of items, feat_embed_dim]
233 | user_preference = (
234 | final_image_preference + final_text_preference
235 | ) # [# of users, feat_embed_dim]
236 | elif self.agg == "weighted_sum":
237 | weight = self.softmax(self.modal_weight)
238 | items = (
239 | weight[0] * final_image_emb + weight[1] * final_text_emb
240 | ) # [# of items, feat_embed_dim]
241 | user_preference = (
242 | weight[0] * final_image_preference
243 | + weight[1] * final_text_preference
244 | ) # [# of users, feat_embed_dim]
245 | elif self.agg == "fc":
246 | items = self.transform(
247 | torch.cat([final_image_emb, final_text_emb], dim=1)
248 | ) # [# of items, feat_embed_dim]
249 | user_preference = self.transform(
250 | torch.cat([final_image_preference, final_text_preference], dim=1)
251 | ) # [# of users, feat_embed_dim]
252 | else:
253 | if self.agg == "concat":
254 | items = torch.cat(
255 | [final_image_emb, final_text_emb, final_cf_item_emb], dim=1
256 |                 ) # [# of items, feat_embed_dim * 3]
257 | user_preference = torch.cat(
258 | [final_image_preference, final_text_preference, final_cf_user_emb],
259 | dim=1,
260 |                 ) # [# of users, feat_embed_dim * 3]
261 | elif self.agg == "sum":
262 | items = (
263 | final_image_emb + final_text_emb + final_cf_item_emb
264 | ) # [# of items, feat_embed_dim]
265 | user_preference = (
266 | final_image_preference + final_text_preference + final_cf_user_emb
267 | ) # [# of users, feat_embed_dim]
268 | elif self.agg == "weighted_sum":
269 | weight = self.softmax(self.modal_weight)
270 | items = (
271 | weight[0] * final_image_emb
272 | + weight[1] * final_text_emb
273 | + weight[2] * final_cf_item_emb
274 | ) # [# of items, feat_embed_dim]
275 | user_preference = (
276 | weight[0] * final_image_preference
277 | + weight[1] * final_text_preference
278 | + weight[2] * final_cf_user_emb
279 | ) # [# of users, feat_embed_dim]
280 | elif self.agg == "fc":
281 | items = self.transform(
282 | torch.cat(
283 | [final_image_emb, final_text_emb, final_cf_item_emb], dim=1
284 | )
285 | ) # [# of items, feat_embed_dim]
286 | user_preference = self.transform(
287 | torch.cat(
288 | [
289 | final_image_preference,
290 | final_text_preference,
291 | final_cf_user_emb,
292 | ],
293 | dim=1,
294 | )
295 | ) # [# of users, feat_embed_dim]
296 |
297 | return user_preference, items
298 |
299 |
300 | class MONET(nn.Module):
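    |     """Full model: builds the symmetrically normalized, undirected user-item
    |     graph once, delegates representation learning to MeGCN, and trains with a
    |     BPR loss that can mix in a target-aware user representation (beta)."""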
301 | def __init__(
302 | self,
303 | n_users,
304 | n_items,
305 | feat_embed_dim,
306 | nonzero_idx,
307 | has_norm,
308 | image_feats,
309 | text_feats,
310 | n_layers,
311 | alpha,
312 | beta,
313 | agg,
314 | cf,
315 | cf_gcn,
316 | lightgcn,
317 | ):
318 | super(MONET, self).__init__()
319 | self.n_users = n_users
320 | self.n_items = n_items
321 | self.feat_embed_dim = feat_embed_dim
322 | self.n_layers = n_layers
323 | self.nonzero_idx = nonzero_idx
324 | self.alpha = alpha
325 | self.beta = beta
326 | self.agg = agg
327 | self.image_feats = torch.tensor(image_feats, dtype=torch.float).cuda()
328 | self.text_feats = torch.tensor(text_feats, dtype=torch.float).cuda()
329 |
330 | self.megcn = MeGCN(
331 | self.n_users,
332 | self.n_items,
333 | self.n_layers,
334 | has_norm,
335 | self.feat_embed_dim,
336 | self.nonzero_idx,
337 | image_feats,
338 | text_feats,
339 | self.alpha,
340 | self.agg,
341 | cf,
342 | cf_gcn,
343 | lightgcn,
344 | )
345 |
346 |         nonzero_idx = torch.tensor(self.nonzero_idx).cuda().long().T
347 |         nonzero_idx[1] = nonzero_idx[1] + self.n_users  # offset item ids past users
348 |         self.edge_index = torch.cat(
349 |             [nonzero_idx, torch.stack([nonzero_idx[1], nonzero_idx[0]], dim=0)], dim=1
350 |         )  # add reverse edges: undirected bipartite user-item graph
351 | self.edge_weight = torch.ones((self.edge_index.size(1))).cuda().view(-1, 1)
352 | self.edge_weight = normalize_laplacian(self.edge_index, self.edge_weight)
353 |
354 | nonzero_idx = torch.tensor(self.nonzero_idx).cuda().long().T
355 |         self.adj = (
356 |             torch.sparse_coo_tensor(
357 |                 nonzero_idx,
358 |                 torch.ones((nonzero_idx.size(1))).cuda(),
359 |                 (self.n_users, self.n_items),
360 |             )
361 |             .to_dense()
362 |             .cuda()
363 |         )
364 |
365 | def forward(self, _eval=False):
366 | if _eval:
367 | img, txt = self.megcn(self.edge_index, self.edge_weight, _eval=True)
368 | return img, txt
369 |
370 | user, items = self.megcn(self.edge_index, self.edge_weight, _eval=False)
371 |
372 | return user, items
373 |
374 | def bpr_loss(self, user_emb, item_emb, users, pos_items, neg_items, target_aware):
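    |         """BPR loss with L2 embedding regularization. With target_aware=True,
    |         scores mix the GCN user embedding with a target-aware user vector that
    |         attends over the user's interacted items using the candidate item as
    |         query: score = (1 - beta) * <u, i> + beta * <u_target(i), i>."""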
375 | current_user_emb = user_emb[users]
376 | pos_item_emb = item_emb[pos_items]
377 | neg_item_emb = item_emb[neg_items]
378 |
379 | if target_aware:
380 | # target-aware
381 | item_item = torch.mm(item_emb, item_emb.T)
382 | pos_item_query = item_item[pos_items, :] # (batch_size, n_items)
383 | neg_item_query = item_item[neg_items, :] # (batch_size, n_items)
384 | pos_target_user_alpha = torch.softmax(
385 | torch.multiply(pos_item_query, self.adj[users, :]).masked_fill(
386 | self.adj[users, :] == 0, -1e9
387 | ),
388 | dim=1,
389 | ) # (batch_size, n_items)
390 | neg_target_user_alpha = torch.softmax(
391 | torch.multiply(neg_item_query, self.adj[users, :]).masked_fill(
392 | self.adj[users, :] == 0, -1e9
393 | ),
394 | dim=1,
395 | ) # (batch_size, n_items)
396 | pos_target_user = torch.mm(
397 | pos_target_user_alpha, item_emb
398 | ) # (batch_size, dim)
399 | neg_target_user = torch.mm(
400 | neg_target_user_alpha, item_emb
401 | ) # (batch_size, dim)
402 |
403 | # predictor
404 | pos_scores = (1 - self.beta) * torch.sum(
405 | torch.mul(current_user_emb, pos_item_emb), dim=1
406 | ) + self.beta * torch.sum(torch.mul(pos_target_user, pos_item_emb), dim=1)
407 | neg_scores = (1 - self.beta) * torch.sum(
408 | torch.mul(current_user_emb, neg_item_emb), dim=1
409 | ) + self.beta * torch.sum(torch.mul(neg_target_user, neg_item_emb), dim=1)
410 | else:
411 | pos_scores = torch.sum(torch.mul(current_user_emb, pos_item_emb), dim=1)
412 | neg_scores = torch.sum(torch.mul(current_user_emb, neg_item_emb), dim=1)
413 |
414 | maxi = F.logsigmoid(pos_scores - neg_scores)
415 | mf_loss = -torch.mean(maxi)
416 |
417 | regularizer = (
418 | 1.0 / 2 * (pos_item_emb**2).sum()
419 | + 1.0 / 2 * (neg_item_emb**2).sum()
420 | + 1.0 / 2 * (current_user_emb**2).sum()
421 | )
422 | emb_loss = regularizer / pos_item_emb.size(0)
423 |
424 | reg_loss = 0.0
425 |
426 | return mf_loss, emb_loss, reg_loss
427 |
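    | # Usage sketch (illustrative values, not from the repo; requires a CUDA device):
    | #   model = MONET(n_users, n_items, feat_embed_dim=64,
    | #                 nonzero_idx=train_pairs,  # list of [user_id, item_id] pairs
    | #                 has_norm=True, image_feats=img_np, text_feats=txt_np,
    | #                 n_layers=2, alpha=1.0, beta=0.5, agg="concat",
    | #                 cf=False, cf_gcn="MeGCN", lightgcn=False).cuda()
    | #   user_emb, item_emb = model()
    | #   mf_loss, emb_loss, reg_loss = model.bpr_loss(
    | #       user_emb, item_emb, users, pos_items, neg_items, target_aware=True)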
--------------------------------------------------------------------------------