├── codes
│   ├── utility
│   │   ├── metrics.py
│   │   ├── parser.py
│   │   ├── batch_test.py
│   │   └── load_data.py
│   ├── main.py
│   ├── Preliminaries.ipynb
│   ├── data
│   │   └── build_data.py
│   └── Models.py
├── requirements.txt
├── LICENSE
└── README.md
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gensim==3.8.3
2 | torch==1.10.2+cu113
3 | sentence_transformers==2.2.0
4 | pandas
5 | numpy
6 | tqdm
7 | torch-scatter
8 | torch-sparse
9 | torch-cluster
10 | torch-spline-conv
11 | torch-geometric==2.0.3
12 |
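Note that the package is published on PyPI as ```torch```, and the PyG companion wheels (```torch-scatter```, ```torch-sparse```, ```torch-cluster```, ```torch-spline-conv```) must be built against the same torch 1.10.2/cu113 wheel. A quick sanity check of the resolved environment (an editor's sketch, not a file in this repository):

```python
# Editor's sketch (not part of the repository): verify the pinned environment.
import torch
import torch_geometric
import gensim
import sentence_transformers

print("torch:", torch.__version__)                    # expected: 1.10.2+cu113
print("CUDA available:", torch.cuda.is_available())   # the code calls .cuda(), so a GPU is assumed
print("torch_geometric:", torch_geometric.__version__)  # expected: 2.0.3
print("gensim:", gensim.__version__)                  # expected: 3.8.3
print("sentence-transformers:", sentence_transformers.__version__)  # expected: 2.2.0
```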

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Big Data and Multi-modal Computing Group, CRIPAC
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/codes/utility/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics import roc_auc_score
3 |
4 |
5 | def recall(rank, ground_truth, N):
6 |     return len(set(rank[:N]) & set(ground_truth)) / float(len(set(ground_truth)))
7 |
8 |
9 | def precision_at_k(r, k):
10 |     """Score is precision @ k.
11 |
12 |     Relevance is binary (nonzero is relevant).
13 |     Returns:
14 |         Precision @ k
15 |     Raises:
16 |         AssertionError: if k < 1
17 |     """
18 |     assert k >= 1
19 |     r = np.asarray(r)[:k]
20 |     return np.mean(r)
21 |
22 |
23 | def average_precision(r, cut):
24 |     """Score is average precision (area under PR curve).
25 |
26 |     Relevance is binary (nonzero is relevant).
27 |     Returns:
28 |         Average precision
29 |     """
30 |     r = np.asarray(r)
31 |     out = [precision_at_k(r, k + 1) for k in range(cut) if r[k]]
32 |     if not out:
33 |         return 0.0
34 |     return np.sum(out) / float(min(cut, np.sum(r)))
35 |
36 |
37 | def mean_average_precision(rs):
38 |     """Score is mean average precision.
39 |
40 |     Relevance is binary (nonzero is relevant).
41 |     Returns:
42 |         Mean average precision
43 |     """
44 |     return np.mean([average_precision(r, len(r)) for r in rs])
45 |
46 |
47 | def dcg_at_k(r, k, method=1):
48 |     """Score is discounted cumulative gain (dcg).
49 |
50 |     Relevance is positive real values. Can use binary
51 |     as the previous methods.
52 |     Returns:
53 |         Discounted cumulative gain
54 |     """
55 |     r = np.asfarray(r)[:k]
56 |     if r.size:
57 |         if method == 0:
58 |             return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))
59 |         elif method == 1:
60 |             return np.sum(r / np.log2(np.arange(2, r.size + 2)))
61 |         else:
62 |             raise ValueError("method must be 0 or 1.")
63 |     return 0.0
64 |
65 |
66 | def ndcg_at_k(r, k, method=1):
67 |     """Score is normalized discounted cumulative gain (ndcg).
68 |
69 |     Relevance is positive real values. Can use binary
70 |     as the previous methods.
71 |     Returns:
72 |         Normalized discounted cumulative gain
73 |     """
74 |     dcg_max = dcg_at_k(sorted(r, reverse=True), k, method)
75 |     if not dcg_max:
76 |         return 0.0
77 |     return dcg_at_k(r, k, method) / dcg_max
78 |
79 |
80 | def recall_at_k(r, k, all_pos_num):
81 |     r = np.asfarray(r)[:k]
82 |     if all_pos_num == 0:
83 |         return 0.0
84 |     else:
85 |         return np.sum(r) / all_pos_num
86 |
87 |
88 | def hit_at_k(r, k):
89 |     r = np.array(r)[:k]
90 |     if np.sum(r) > 0:
91 |         return 1.0
92 |     else:
93 |         return 0.0
94 |
95 |
96 | def F1(pre, rec):
97 |     if pre + rec > 0:
98 |         return (2.0 * pre * rec) / (pre + rec)
99 |     else:
100 |         return 0.0
101 |
102 |
103 | def auc(ground_truth, prediction):
104 |     try:
105 |         res = roc_auc_score(y_true=ground_truth, y_score=prediction)
106 |     except Exception:
107 |         res = 0.0
108 |     return res
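A tiny worked example of these ranking metrics (an editor's sketch, not repository code; it assumes being run from ```codes/``` so that ```utility.metrics``` is importable):

```python
# Editor's sketch: sanity-check the ranking metrics above on a toy list.
from utility import metrics

# Binary relevance of a ranked top-4 list; 3 relevant items exist in total.
r = [1, 0, 1, 0]
print(metrics.precision_at_k(r, 2))              # mean([1, 0]) = 0.5
print(metrics.recall_at_k(r, 4, all_pos_num=3))  # 2 hits / 3 positives ~= 0.667
print(metrics.dcg_at_k(r, 4))                    # 1/log2(2) + 1/log2(4) = 1.5
print(metrics.ndcg_at_k(r, 4))                   # 1.5 / (1 + 1/log2(3)) ~= 0.920
```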
--------------------------------------------------------------------------------
/codes/utility/parser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def parse_args(flags=False):
5 |     parser = argparse.ArgumentParser(description="")
6 |
7 |     parser.add_argument(
8 |         "--data_path", nargs="?", default="data/", help="Input data path."
9 |     )
10 |     parser.add_argument("--seed", type=int, default=123, help="Random seed")
11 |     parser.add_argument(
12 |         "--dataset",
13 |         nargs="?",
14 |         default="MenClothing",
15 |         help="Choose a dataset from {Toys_and_Games, Beauty, MenClothing, WomenClothing}",
16 |     )
17 |     parser.add_argument(
18 |         "--verbose", type=int, default=5, help="Interval of evaluation."
19 |     )
20 |     parser.add_argument("--epoch", type=int, default=1000, help="Number of epochs.")
21 |     parser.add_argument("--batch_size", type=int, default=1024, help="Batch size.")
22 |     parser.add_argument(
23 |         "--regs", nargs="?", default="[1e-5,1e-5]", help="Regularizations."
24 |     )
25 |     parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate.")
26 |     parser.add_argument("--embed_size", type=int, default=64, help="Embedding size.")
27 |     parser.add_argument(
28 |         "--feat_embed_dim", type=int, default=64, help="Feature embedding size."
29 |     )
30 |     parser.add_argument(
31 |         "--alpha", type=float, default=1.0, help="Coefficient of self node features."
32 |     )
33 |     parser.add_argument(
34 |         "--beta",
35 |         type=float,
36 |         default=0.3,
37 |         help="Coefficient of fine-grained interest matching.",
38 |     )
39 |     parser.add_argument(
40 |         "--core",
41 |         type=int,
42 |         default=5,
43 |         help="5-core for warm-start; 0-core for cold start.",
44 |     )
45 |     parser.add_argument(
46 |         "--n_layers", type=int, default=2, help="Number of graph conv layers."
47 |     )
48 |     parser.add_argument("--has_norm", default=True, action="store_false", help="Pass this flag to disable feature normalization.")
49 |     parser.add_argument("--target_aware", default=True, action="store_false", help="Pass this flag to disable target-aware attention.")
50 |     parser.add_argument(
51 |         "--agg",
52 |         type=str,
53 |         default="concat",
54 |         help="Choose an aggregation scheme from {sum, weighted_sum, concat, fc}",
55 |     )
56 |     parser.add_argument("--cf", default=False, action="store_true")
57 |     parser.add_argument(
58 |         "--cf_gcn",
59 |         type=str,
60 |         default="LightGCN",
61 |         help="Choose a CF GCN backbone from {MeGCN, LightGCN}",
62 |     )
63 |     parser.add_argument("--lightgcn", default=False, action="store_true")
64 |     parser.add_argument("--model_name", type=str)
65 |     parser.add_argument("--early_stopping_patience", type=int, default=10, help="Evaluations to wait before early stopping.")
66 |     parser.add_argument("--gpu_id", type=int, default=0, help="GPU id")
67 |     parser.add_argument(
68 |         "--Ks", nargs="?", default="[10, 20]", help="K value of ndcg/recall @ k"
69 |     )
70 |     parser.add_argument(
71 |         "--test_flag",
72 |         nargs="?",
73 |         default="part",
74 |         help="Specify the test type from {part, full}, indicating whether the inference is done in mini-batch",
75 |     )
76 |
77 |     if flags:
78 |         attribute_dict = dict(vars(parser.parse_args()))
79 |         print("*" * 32 + " Experiment setting " + "*" * 32)
80 |         for k, v in attribute_dict.items():
81 |             print(k + " : " + str(v))
82 |         print("*" * 32 + " Experiment setting " + "*" * 32)
83 |     return parser.parse_args()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MONET: Modality-Embracing Graph Convolutional Network and Target-Aware Attention for Multimedia Recommendation
2 | This repository provides a reference implementation of *MONET* as described in the following paper:
3 | > MONET: Modality-Embracing Graph Convolutional Network and Target-Aware Attention for Multimedia Recommendation
4 | > Yungi Kim, Taeri Kim, Won-Yong Shin and Sang-Wook Kim
5 | > 17th ACM Int'l Conf. on Web Search and Data Mining (ACM WSDM 2024)
6 |
7 | ### Overview of MONET
8 | ![monet](https://github.com/Kimyungi/MONET/assets/28508383/6723ccd1-8a8e-4710-ba7b-6a7bee928301)
9 |
10 |
11 | ### Authors
12 | - Yungi Kim (gozj3319@hanyang.ac.kr)
13 | - Taeri Kim (taerik@hanyang.ac.kr)
14 | - Won-Yong Shin (wy.shin@yonsei.ac.kr)
15 | - Sang-Wook Kim (wook@hanyang.ac.kr)
16 |
17 | ### Requirements
18 | The code has been tested running under Python 3.6.13. The required packages are as follows:
19 | - ```gensim==3.8.3```
20 | - ```torch==1.10.2+cu113```
21 | - ```torch-geometric==2.0.3```
22 | - ```sentence_transformers==2.2.0```
23 | - ```pandas```
24 | - ```numpy```
25 | - ```tqdm```
26 | - ```torch-scatter```
27 | - ```torch-sparse```
28 | - ```torch-cluster```
29 | - ```torch-spline-conv```
30 |
31 | ### Dataset Preparation
32 | #### Dataset Download
33 | *Men Clothing and Women Clothing*: Download the Amazon product dataset provided by [MAML](https://github.com/liufancs/MAML). Put the data folder into the directory data/.
34 |
35 | *Beauty and Toys & Games*: Download the 5-core reviews data, meta data, and image features from the [Amazon product dataset](http://jmcauley.ucsd.edu/data/amazon/links.html). Put the data into the directory data/{folder}/meta-data/.
36 |
37 | #### Dataset Preprocessing
38 | Run ```python build_data.py --name={Dataset}``` from the ```codes/data/``` directory.
39 |
40 | ### Usage
41 | #### For simplicity, we provide usage for the Women Clothing dataset.
42 | ------------------------------------
43 | - For MONET in RQ1,
44 | ```
45 | python main.py --agg=concat --n_layers=2 --alpha=1.0 --beta=0.3 --dataset=WomenClothing --model_name=MONET_2_10_3
46 | ```
47 | ------------------------------------
48 | - For RQ2, refer to the second cell of "Preliminaries.ipynb".
49 | ------------------------------------
50 | - For MONET_w/o_MeGCN and MONET_w/o_TA in RQ3 (note that ```--target_aware``` is a store_false flag, so passing it *disables* target-aware attention),
51 | ```
52 | python main.py --agg=concat --n_layers=0 --alpha=1.0 --beta=0.3 --dataset=WomenClothing --model_name=MONET_wo_MeGCN
53 | python main.py --target_aware --agg=concat --n_layers=2 --alpha=1.0 --beta=0.3 --dataset=WomenClothing --model_name=MONET_wo_TA
54 | ```
55 | ------------------------------------
56 | - For RQ4 (hyperparameters $\alpha$, $\beta$ sensitivity),
57 | ```
58 | python main.py --agg=concat --n_layers=2 --alpha={value} --beta=0.3 --dataset=WomenClothing --model_name=MONET_2_{alpha}_3
59 | python main.py --agg=concat --n_layers=2 --alpha=1.0 --beta={value} --dataset=WomenClothing --model_name=MONET_2_10_{beta}
60 | ```
61 |
62 | ### Cite
63 | We encourage you to cite our paper if you have used the code in your work. You can use the following BibTeX citation:
64 | ```
65 | @inproceedings{kim24wsdm,
66 |   author = {Yungi Kim and Taeri Kim and Won{-}Yong Shin and Sang{-}Wook Kim},
67 |   title = {MONET: Modality-Embracing Graph Convolutional Network and Target-Aware Attention for Multimedia Recommendation},
68 |   booktitle = {ACM International Conference on Web Search and Data Mining (ACM WSDM 2024)},
69 |   year = {2024}
70 | }
71 | ```
72 |
73 | ### Acknowledgement
74 | The structure of this code is largely based on [LATTICE](https://github.com/CRIPAC-DIG/LATTICE). Thanks for their work.
75 |
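The evaluation code that follows implements the paper's target-aware scoring: the score of a user for a target item mixes a plain inner product with an attention-weighted match between the target item and the items in the user's history, balanced by ```--beta```. A minimal self-contained sketch of that computation (an editor's illustration on random tensors, not repository code):

```python
# Editor's sketch of the target-aware score computed in utility/batch_test.py.
import torch

n_users, n_items, dim, beta = 4, 6, 8, 0.3
u = torch.randn(n_users, dim)                        # user embeddings
i = torch.randn(n_items, dim)                        # item embeddings
adj = (torch.rand(n_users, n_items) > 0.5).float()   # user-item interaction history

item_item = i @ i.T                                  # item-item affinities (n_items, n_items)
# Attention of each target item over every user's historical items;
# non-interacted entries are masked out before the softmax.
alpha = torch.softmax(
    (item_item.unsqueeze(1) * adj.unsqueeze(0)).masked_fill(
        adj.repeat(n_items, 1, 1) == 0, -1e9
    ),
    dim=2,
)                                                    # (n_items, n_users, n_items)
target_user = alpha @ i                              # history summary per (item, user)
score = (1 - beta) * u @ i.T + beta * (
    target_user.permute(1, 0, 2) * i
).sum(dim=2)                                         # (n_users, n_items)
print(score.shape)                                   # torch.Size([4, 6])
```

This mirrors the ```i_rate_batch``` computation in ```test_torch``` below, which additionally batches over items to bound memory.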
76 | -------------------------------------------------------------------------------- /codes/utility/batch_test.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | import multiprocessing 3 | import pickle 4 | from time import time 5 | 6 | import numpy as np 7 | import torch 8 | import utility.metrics as metrics 9 | from tqdm import tqdm 10 | from utility.load_data import Data 11 | from utility.parser import parse_args 12 | 13 | cores = multiprocessing.cpu_count() // 5 14 | 15 | args = parse_args() 16 | Ks = eval(args.Ks) 17 | 18 | data_generator = Data(path=args.data_path + args.dataset, batch_size=args.batch_size) 19 | USR_NUM, ITEM_NUM = data_generator.n_users, data_generator.n_items 20 | N_TRAIN, N_TEST = data_generator.n_train, data_generator.n_test 21 | if args.target_aware: 22 | BATCH_SIZE = 16 23 | else: 24 | BATCH_SIZE = args.batch_size 25 | 26 | 27 | def ranklist_by_heapq(user_pos_test, test_items, rating, Ks): 28 | item_score = {} 29 | for i in test_items: 30 | item_score[i] = rating[i] 31 | 32 | K_max = max(Ks) 33 | K_max_item_score = heapq.nlargest(K_max, item_score, key=item_score.get) 34 | 35 | r = [] 36 | for i in K_max_item_score: 37 | if i in user_pos_test: 38 | r.append(1) 39 | else: 40 | r.append(0) 41 | auc = 0.0 42 | return r, auc 43 | 44 | 45 | def get_auc(item_score, user_pos_test): 46 | item_score = sorted(item_score.items(), key=lambda kv: kv[1]) 47 | item_score.reverse() 48 | item_sort = [x[0] for x in item_score] 49 | posterior = [x[1] for x in item_score] 50 | 51 | r = [] 52 | for i in item_sort: 53 | if i in user_pos_test: 54 | r.append(1) 55 | else: 56 | r.append(0) 57 | auc = metrics.auc(ground_truth=r, prediction=posterior) 58 | return auc 59 | 60 | 61 | def ranklist_by_sorted(user_pos_test, test_items, rating, Ks): 62 | item_score = {} 63 | for i in test_items: 64 | item_score[i] = rating[i] 65 | 66 | K_max = max(Ks) 67 | K_max_item_score = heapq.nlargest(K_max, item_score, key=item_score.get) 68 | 69 | r = [] 70 | for i in K_max_item_score: 71 | if i in user_pos_test: 72 | r.append(1) 73 | else: 74 | r.append(0) 75 | auc = get_auc(item_score, user_pos_test) 76 | return r, auc 77 | 78 | 79 | def get_performance(user_pos_test, r, auc, Ks): 80 | precision, recall, ndcg, hit_ratio = [], [], [], [] 81 | 82 | for K in Ks: 83 | precision.append(metrics.precision_at_k(r, K)) 84 | recall.append(metrics.recall_at_k(r, K, len(user_pos_test))) 85 | ndcg.append(metrics.ndcg_at_k(r, K)) 86 | hit_ratio.append(metrics.hit_at_k(r, K)) 87 | 88 | return { 89 | "recall": np.array(recall), 90 | "precision": np.array(precision), 91 | "ndcg": np.array(ndcg), 92 | "hit_ratio": np.array(hit_ratio), 93 | "auc": auc, 94 | } 95 | 96 | 97 | def test_one_user(x): 98 | # user u's ratings for user u 99 | is_val = x[-1] 100 | rating = x[0] 101 | # uid 102 | u = x[1] 103 | # user u's items in the training set 104 | try: 105 | training_items = data_generator.train_items[u] 106 | except Exception: 107 | training_items = [] 108 | if is_val: 109 | user_pos_test = data_generator.val_set[u] 110 | else: 111 | user_pos_test = data_generator.test_set[u] 112 | 113 | all_items = set(range(ITEM_NUM)) 114 | 115 | test_items = list(all_items - set(training_items)) 116 | 117 | if args.test_flag == "part": 118 | r, auc = ranklist_by_heapq(user_pos_test, test_items, rating, Ks) 119 | else: 120 | r, auc = ranklist_by_sorted(user_pos_test, test_items, rating, Ks) 121 | 122 | return get_performance(user_pos_test, r, auc, Ks) 123 | 124 | 125 | def 
test_torch( 126 | ua_embeddings, ia_embeddings, users_to_test, is_val, adj, beta, target_aware 127 | ): 128 | result = { 129 | "precision": np.zeros(len(Ks)), 130 | "recall": np.zeros(len(Ks)), 131 | "ndcg": np.zeros(len(Ks)), 132 | "hit_ratio": np.zeros(len(Ks)), 133 | "auc": 0.0, 134 | } 135 | pool = multiprocessing.Pool(cores) 136 | 137 | u_batch_size = BATCH_SIZE * 2 138 | i_batch_size = BATCH_SIZE 139 | 140 | test_users = users_to_test 141 | n_test_users = len(test_users) 142 | n_user_batchs = n_test_users // u_batch_size + 1 143 | count = 0 144 | 145 | item_item = torch.mm(ia_embeddings, ia_embeddings.T) 146 | 147 | for u_batch_id in tqdm(range(n_user_batchs), position=1, leave=False): 148 | start = u_batch_id * u_batch_size 149 | end = (u_batch_id + 1) * u_batch_size 150 | user_batch = test_users[start:end] 151 | if target_aware: 152 | n_item_batchs = ITEM_NUM // i_batch_size + 1 153 | rate_batch = np.zeros(shape=(len(user_batch), ITEM_NUM)) 154 | 155 | i_count = 0 156 | for i_batch_id in range(n_item_batchs): 157 | i_start = i_batch_id * i_batch_size 158 | i_end = min((i_batch_id + 1) * i_batch_size, ITEM_NUM) 159 | 160 | item_batch = range(i_start, i_end) 161 | u_g_embeddings = ua_embeddings[user_batch] # (batch_size, dim) 162 | i_g_embeddings = ia_embeddings[item_batch] # (batch_size, dim) 163 | 164 | # target-aware 165 | item_query = item_item[item_batch, :] # (item_batch_size, n_items) 166 | item_target_user_alpha = torch.softmax( 167 | torch.multiply( 168 | item_query.unsqueeze(1), adj[user_batch, :].unsqueeze(0) 169 | ).masked_fill( 170 | adj[user_batch, :].repeat(len(item_batch), 1, 1) == 0, -1e9 171 | ), 172 | dim=2, 173 | ) # (item_batch_size, user_batch_size, n_items) 174 | item_target_user = torch.matmul( 175 | item_target_user_alpha, ia_embeddings 176 | ) # (item_batch_size, user_batch_size, dim) 177 | 178 | # target-aware 179 | i_rate_batch = (1 - beta) * torch.matmul( 180 | u_g_embeddings, torch.transpose(i_g_embeddings, 0, 1) 181 | ) + beta * torch.sum( 182 | torch.mul( 183 | item_target_user.permute(1, 0, 2).contiguous(), i_g_embeddings 184 | ), 185 | dim=2, 186 | ) 187 | 188 | rate_batch[:, i_start:i_end] = i_rate_batch.detach().cpu().numpy() 189 | i_count += i_rate_batch.shape[1] 190 | 191 | del ( 192 | item_query, 193 | item_target_user_alpha, 194 | item_target_user, 195 | i_g_embeddings, 196 | u_g_embeddings, 197 | ) 198 | torch.cuda.empty_cache() 199 | 200 | assert i_count == ITEM_NUM 201 | 202 | else: 203 | item_batch = range(ITEM_NUM) 204 | u_g_embeddings = ua_embeddings[user_batch] 205 | i_g_embeddings = ia_embeddings[item_batch] 206 | 207 | rate_batch = torch.matmul( 208 | u_g_embeddings, torch.transpose(i_g_embeddings, 0, 1) 209 | ) 210 | rate_batch = rate_batch.detach().cpu().numpy() 211 | 212 | user_batch_rating_uid = zip(rate_batch, user_batch, [is_val] * len(user_batch)) 213 | 214 | batch_result = pool.map(test_one_user, user_batch_rating_uid) 215 | count += len(batch_result) 216 | 217 | for re in batch_result: 218 | result["precision"] += re["precision"] / n_test_users 219 | result["recall"] += re["recall"] / n_test_users 220 | result["ndcg"] += re["ndcg"] / n_test_users 221 | result["hit_ratio"] += re["hit_ratio"] / n_test_users 222 | result["auc"] += re["auc"] / n_test_users 223 | 224 | assert count == n_test_users 225 | pool.close() 226 | return result 227 | -------------------------------------------------------------------------------- /codes/utility/load_data.py: -------------------------------------------------------------------------------- 1 
| import json 2 | import os 3 | 4 | # args = parse_args() 5 | import random as rd 6 | 7 | # from utility.parser import parse_args 8 | from collections import defaultdict 9 | 10 | import numpy as np 11 | import pandas as pd 12 | import scipy.sparse as sp 13 | from gensim.models.doc2vec import Doc2Vec 14 | 15 | 16 | class Data(object): 17 | def __init__(self, path, batch_size): 18 | self.path = path + "/5-core" 19 | self.batch_size = batch_size 20 | 21 | train_file = path + "/5-core/train.json" 22 | val_file = path + "/5-core/val.json" 23 | test_file = path + "/5-core/test.json" 24 | 25 | # get number of users and items 26 | self.n_users, self.n_items = 0, 0 27 | self.n_train, self.n_test, self.n_val = 0, 0, 0 28 | self.neg_pools = {} 29 | 30 | self.exist_users = [] 31 | 32 | train = json.load(open(train_file)) 33 | test = json.load(open(test_file)) 34 | val = json.load(open(val_file)) 35 | for uid, items in train.items(): 36 | if len(items) == 0: 37 | continue 38 | uid = int(uid) 39 | self.exist_users.append(uid) 40 | self.n_items = max(self.n_items, max(items)) 41 | self.n_users = max(self.n_users, uid) 42 | self.n_train += len(items) 43 | 44 | for uid, items in test.items(): 45 | uid = int(uid) 46 | try: 47 | self.n_items = max(self.n_items, max(items)) 48 | self.n_test += len(items) 49 | except Exception: 50 | continue 51 | 52 | for uid, items in val.items(): 53 | uid = int(uid) 54 | try: 55 | self.n_items = max(self.n_items, max(items)) 56 | self.n_val += len(items) 57 | except Exception: 58 | continue 59 | 60 | self.n_items += 1 61 | self.n_users += 1 62 | 63 | self.print_statistics() 64 | 65 | self.R = sp.dok_matrix((self.n_users, self.n_items), dtype=np.float32) 66 | 67 | self.train_items, self.test_set, self.val_set = {}, {}, {} 68 | for uid, train_items in train.items(): 69 | if len(train_items) == 0: 70 | continue 71 | uid = int(uid) 72 | for _, i in enumerate(train_items): 73 | self.R[uid, i] = 1.0 74 | 75 | self.train_items[uid] = train_items 76 | 77 | for uid, test_items in test.items(): 78 | uid = int(uid) 79 | if len(test_items) == 0: 80 | continue 81 | try: 82 | self.test_set[uid] = test_items 83 | except Exception: 84 | continue 85 | 86 | for uid, val_items in val.items(): 87 | uid = int(uid) 88 | if len(val_items) == 0: 89 | continue 90 | try: 91 | self.val_set[uid] = val_items 92 | except Exception: 93 | continue 94 | 95 | def nonzero_idx(self): 96 | r, c = self.R.nonzero() 97 | idx = list(zip(r, c)) 98 | return idx 99 | 100 | def sample(self): 101 | if self.batch_size <= self.n_users: 102 | users = rd.sample(self.exist_users, self.batch_size) 103 | else: 104 | users = [rd.choice(self.exist_users) for _ in range(self.batch_size)] 105 | # users = self.exist_users[:] 106 | 107 | def sample_pos_items_for_u(u, num): 108 | pos_items = self.train_items[u] 109 | n_pos_items = len(pos_items) 110 | pos_batch = [] 111 | while True: 112 | if len(pos_batch) == num: 113 | break 114 | pos_id = np.random.randint(low=0, high=n_pos_items, size=1)[0] 115 | pos_i_id = pos_items[pos_id] 116 | 117 | if pos_i_id not in pos_batch: 118 | pos_batch.append(pos_i_id) 119 | return pos_batch 120 | 121 | def sample_neg_items_for_u(u, num): 122 | neg_items = [] 123 | while True: 124 | if len(neg_items) == num: 125 | break 126 | neg_id = np.random.randint(low=0, high=self.n_items, size=1)[0] 127 | if neg_id not in self.train_items[u] and neg_id not in neg_items: 128 | neg_items.append(neg_id) 129 | return neg_items 130 | 131 | pos_items, neg_items = [], [] 132 | for u in users: 133 | pos_items += 
sample_pos_items_for_u(u, 1)
134 |             neg_items += sample_neg_items_for_u(u, 1)
135 |         return users, pos_items, neg_items
136 |
137 |     def print_statistics(self):
138 |         print("n_users=%d, n_items=%d" % (self.n_users, self.n_items))
139 |         print("n_interactions=%d" % (self.n_train + self.n_val + self.n_test))
140 |         print(
141 |             "n_train=%d, n_val=%d, n_test=%d, sparsity=%.5f"
142 |             % (
143 |                 self.n_train,
144 |                 self.n_val,
145 |                 self.n_test,
146 |                 (self.n_train + self.n_val + self.n_test)
147 |                 / (self.n_users * self.n_items),
148 |             )
149 |         )
150 |
151 |
152 | def dataset_merge_and_split(path):
153 |     df = pd.read_csv(path + "/train.csv", index_col=None, usecols=None)
154 |     # Construct matrix
155 |     ui = defaultdict(list)
156 |     for _, row in df.iterrows():
157 |         user, item = int(row["userID"]), int(row["itemID"])
158 |         ui[user].append(item)
159 |
160 |     df = pd.read_csv(path + "/test.csv", index_col=None, usecols=None)
161 |     for _, row in df.iterrows():
162 |         user, item = int(row["userID"]), int(row["itemID"])
163 |         ui[user].append(item)
164 |
165 |     train_json = {}
166 |     val_json = {}
167 |     test_json = {}
168 |     for u, items in ui.items():
169 |         if len(items) < 10:
170 |             testval = np.random.choice(len(items), 2, replace=False)
171 |         else:
172 |             testval = np.random.choice(len(items), int(len(items) * 0.2), replace=False)
173 |
174 |         test = testval[: len(testval) // 2]
175 |         val = testval[len(testval) // 2 :]
176 |         train = [i for i in list(range(len(items))) if i not in testval]
177 |         train_json[u] = [items[idx] for idx in train]
178 |         val_json[u] = [items[idx] for idx in val.tolist()]
179 |         test_json[u] = [items[idx] for idx in test.tolist()]
180 |
181 |     with open(path + "/5-core/train.json", "w") as f:
182 |         json.dump(train_json, f)
183 |     with open(path + "/5-core/val.json", "w") as f:
184 |         json.dump(val_json, f)
185 |     with open(path + "/5-core/test.json", "w") as f:
186 |         json.dump(test_json, f)
187 |
188 |
189 | def load_textual_image_features(data_path):
190 |     asin_dict = json.load(open(os.path.join(data_path, "asin_sample.json"), "r"))
191 |
192 |     # Prepare textual feature data.
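    # (Editor's note, inferred from the surrounding code:) `doc2vecFile` is a
    # gensim Doc2Vec model mapping each product asin to a text vector, and
    # `image_feature.npy` holds a pickled {asin: CNN feature} dict; both are
    # re-indexed by itemID below and saved as dense .npy feature matrices.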
193 | doc2vec_model = Doc2Vec.load(os.path.join(data_path, "doc2vecFile")) 194 | vis_vec = np.load( 195 | os.path.join(data_path, "image_feature.npy"), allow_pickle=True 196 | ).item() 197 | text_vec = {} 198 | for asin in asin_dict: 199 | text_vec[asin] = doc2vec_model.docvecs[asin] 200 | 201 | all_dict = {} 202 | num_items = 0 203 | filename = data_path + "/train.csv" 204 | df = pd.read_csv(filename, index_col=None, usecols=None) 205 | for _, row in df.iterrows(): 206 | asin, i = row["asin"], int(row["itemID"]) 207 | all_dict[i] = asin 208 | num_items = max(num_items, i) 209 | filename = data_path + "/test.csv" 210 | df = pd.read_csv(filename, index_col=None, usecols=None) 211 | for _, row in df.iterrows(): 212 | asin, i = row["asin"], int(row["itemID"]) 213 | all_dict[i] = asin 214 | num_items = max(num_items, i) 215 | 216 | t_features = [] 217 | v_features = [] 218 | for i in range(num_items + 1): 219 | t_features.append(text_vec[all_dict[i]]) 220 | v_features.append(vis_vec[all_dict[i]]) 221 | 222 | np.save(data_path + "/text_feat.npy", np.asarray(t_features, dtype=np.float32)) 223 | np.save(data_path + "/image_feat.npy", np.asarray(v_features, dtype=np.float32)) 224 | -------------------------------------------------------------------------------- /codes/main.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import sys 4 | from time import time 5 | 6 | import numpy as np 7 | import torch 8 | import torch.optim as optim 9 | from Models import MONET 10 | from utility.batch_test import data_generator, test_torch 11 | from utility.parser import parse_args 12 | 13 | 14 | class Trainer(object): 15 | def __init__(self, data_config, args): 16 | # argument settings 17 | self.n_users = data_config["n_users"] 18 | self.n_items = data_config["n_items"] 19 | 20 | self.feat_embed_dim = args.feat_embed_dim 21 | self.lr = args.lr 22 | self.emb_dim = args.embed_size 23 | self.batch_size = args.batch_size 24 | self.n_layers = args.n_layers 25 | self.has_norm = args.has_norm 26 | self.regs = eval(args.regs) 27 | self.decay = self.regs[0] 28 | self.lamb = self.regs[1] 29 | self.alpha = args.alpha 30 | self.beta = args.beta 31 | self.dataset = args.dataset 32 | self.model_name = args.model_name 33 | self.agg = args.agg 34 | self.target_aware = args.target_aware 35 | self.cf = args.cf 36 | self.cf_gcn = args.cf_gcn 37 | self.lightgcn = args.lightgcn 38 | 39 | self.nonzero_idx = data_config["nonzero_idx"] 40 | 41 | self.image_feats = np.load("data/{}/image_feat.npy".format(self.dataset)) 42 | self.text_feats = np.load("data/{}/text_feat.npy".format(self.dataset)) 43 | 44 | self.model = MONET( 45 | self.n_users, 46 | self.n_items, 47 | self.feat_embed_dim, 48 | self.nonzero_idx, 49 | self.has_norm, 50 | self.image_feats, 51 | self.text_feats, 52 | self.n_layers, 53 | self.alpha, 54 | self.beta, 55 | self.agg, 56 | self.cf, 57 | self.cf_gcn, 58 | self.lightgcn, 59 | ) 60 | 61 | self.model = self.model.cuda() 62 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr) 63 | self.lr_scheduler = self.set_lr_scheduler() 64 | 65 | def set_lr_scheduler(self): 66 | fac = lambda epoch: 0.96 ** (epoch / 50) 67 | scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=fac) 68 | return scheduler 69 | 70 | def test(self, users_to_test, is_val): 71 | self.model.eval() 72 | with torch.no_grad(): 73 | ua_embeddings, ia_embeddings = self.model() 74 | result = test_torch( 75 | ua_embeddings, 76 | ia_embeddings, 77 | users_to_test, 78 
| is_val, 79 | self.adj, 80 | self.beta, 81 | self.target_aware, 82 | ) 83 | return result 84 | 85 | def train(self): 86 | nonzero_idx = torch.tensor(self.nonzero_idx).cuda().long().T 87 | self.adj = ( 88 | torch.sparse.FloatTensor( 89 | nonzero_idx, 90 | torch.ones((nonzero_idx.size(1))).cuda(), 91 | (self.n_users, self.n_items), 92 | ) 93 | .to_dense() 94 | .cuda() 95 | ) 96 | stopping_step = 0 97 | 98 | n_batch = data_generator.n_train // args.batch_size + 1 99 | best_recall = 0 100 | for epoch in range(args.epoch): 101 | t1 = time() 102 | loss, mf_loss, emb_loss, reg_loss = 0.0, 0.0, 0.0, 0.0 103 | n_batch = data_generator.n_train // args.batch_size + 1 104 | for _ in range(n_batch): 105 | self.model.train() 106 | self.optimizer.zero_grad() 107 | user_emb, item_emb = self.model() 108 | users, pos_items, neg_items = data_generator.sample() 109 | 110 | batch_mf_loss, batch_emb_loss, batch_reg_loss = self.model.bpr_loss( 111 | user_emb, item_emb, users, pos_items, neg_items, self.target_aware 112 | ) 113 | 114 | batch_emb_loss = self.decay * batch_emb_loss 115 | batch_loss = batch_mf_loss + batch_emb_loss + batch_reg_loss 116 | 117 | batch_loss.backward(retain_graph=True) 118 | self.optimizer.step() 119 | 120 | loss += float(batch_loss) 121 | mf_loss += float(batch_mf_loss) 122 | emb_loss += float(batch_emb_loss) 123 | reg_loss += float(batch_reg_loss) 124 | 125 | del user_emb, item_emb 126 | torch.cuda.empty_cache() 127 | 128 | self.lr_scheduler.step() 129 | 130 | if math.isnan(loss): 131 | print("ERROR: loss is nan.") 132 | sys.exit() 133 | 134 | perf_str = "Pre_Epoch %d [%.1fs]: train==[%.5f=%.5f + %.5f + %.5f]" % ( 135 | epoch, 136 | time() - t1, 137 | loss, 138 | mf_loss, 139 | emb_loss, 140 | reg_loss, 141 | ) 142 | print(perf_str) 143 | 144 | if epoch % args.verbose != 0: 145 | continue 146 | 147 | t2 = time() 148 | users_to_test = list(data_generator.test_set.keys()) 149 | users_to_val = list(data_generator.val_set.keys()) 150 | ret = self.test(users_to_val, is_val=True) 151 | 152 | t3 = time() 153 | 154 | if args.verbose > 0: 155 | perf_str = ( 156 | "Pre_Epoch %d [%.1fs + %.1fs]: val==[%.5f=%.5f + %.5f + %.5f], recall=[%.5f, %.5f], " 157 | "precision=[%.5f, %.5f], hit=[%.5f, %.5f], ndcg=[%.5f, %.5f]" 158 | % ( 159 | epoch, 160 | t2 - t1, 161 | t3 - t2, 162 | loss, 163 | mf_loss, 164 | emb_loss, 165 | reg_loss, 166 | ret["recall"][0], 167 | ret["recall"][-1], 168 | ret["precision"][0], 169 | ret["precision"][-1], 170 | ret["hit_ratio"][0], 171 | ret["hit_ratio"][-1], 172 | ret["ndcg"][0], 173 | ret["ndcg"][-1], 174 | ) 175 | ) 176 | print(perf_str) 177 | 178 | if ret["recall"][1] > best_recall: 179 | best_recall = ret["recall"][1] 180 | stopping_step = 0 181 | torch.save( 182 | {self.model_name: self.model.state_dict()}, 183 | "./models/" + self.dataset + "_" + self.model_name, 184 | ) 185 | elif stopping_step < args.early_stopping_patience: 186 | stopping_step += 1 187 | print("#####Early stopping steps: %d #####" % stopping_step) 188 | else: 189 | print("#####Early stop! 
#####") 190 | break 191 | 192 | self.model = MONET( 193 | self.n_users, 194 | self.n_items, 195 | self.feat_embed_dim, 196 | self.nonzero_idx, 197 | self.has_norm, 198 | self.image_feats, 199 | self.text_feats, 200 | self.n_layers, 201 | self.alpha, 202 | self.beta, 203 | self.agg, 204 | self.cf, 205 | self.cf_gcn, 206 | self.lightgcn, 207 | ) 208 | 209 | self.model.load_state_dict( 210 | torch.load( 211 | "./models/" + self.dataset + "_" + self.model_name, 212 | map_location=torch.device("cpu"), 213 | )[self.model_name] 214 | ) 215 | self.model.cuda() 216 | test_ret = self.test(users_to_test, is_val=False) 217 | print("Final ", test_ret) 218 | 219 | 220 | def set_seed(seed): 221 | np.random.seed(seed) 222 | random.seed(seed) 223 | torch.manual_seed(seed) # cpu 224 | torch.cuda.manual_seed_all(seed) # gpu 225 | 226 | 227 | if __name__ == "__main__": 228 | args = parse_args(True) 229 | set_seed(args.seed) 230 | 231 | config = dict() 232 | config["n_users"] = data_generator.n_users 233 | config["n_items"] = data_generator.n_items 234 | 235 | nonzero_idx = data_generator.nonzero_idx() 236 | config["nonzero_idx"] = nonzero_idx 237 | 238 | trainer = Trainer(config, args) 239 | trainer.train() 240 | -------------------------------------------------------------------------------- /codes/Preliminaries.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "n_users=19244, n_items=14596\n", 13 | "n_interactions=135326\n", 14 | "n_train=95629, n_val=20127, n_test=19570, sparsity=0.00048\n" 15 | ] 16 | }, 17 | { 18 | "name": "stderr", 19 | "output_type": "stream", 20 | "text": [ 21 | "100%|██████████| 19244/19244 [01:10<00:00, 273.56it/s]" 22 | ] 23 | }, 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "text 0.10991018 0.086630285\n", 29 | "img 0.3009104 0.23256181\n" 30 | ] 31 | }, 32 | { 33 | "name": "stderr", 34 | "output_type": "stream", 35 | "text": [ 36 | "\n" 37 | ] 38 | } 39 | ], 40 | "source": [ 41 | "# avg.sim (Figure 1)\n", 42 | "\n", 43 | "import json\n", 44 | "import os\n", 45 | "from utility.load_data import Data\n", 46 | "\n", 47 | "data_generator = Data(path='data/WomenClothing', batch_size=1024)\n", 48 | "\n", 49 | "from copy import deepcopy\n", 50 | "I_items = deepcopy(data_generator.train_items)\n", 51 | "\n", 52 | "for k in I_items.keys():\n", 53 | " I_items[k] = I_items[k] + data_generator.val_set[k] + data_generator.test_set[k]\n", 54 | "\n", 55 | "import numpy as np\n", 56 | "image_feats = np.load('data/WomenClothing/image_feat.npy')\n", 57 | "text_feats = np.load('data/WomenClothing/text_feat.npy')\n", 58 | "\n", 59 | "from collections import defaultdict\n", 60 | "from tqdm import tqdm\n", 61 | "\n", 62 | "img_cos = np.dot(image_feats, image_feats.T) / (np.linalg.norm(image_feats, axis=1)[:, np.newaxis] * np.linalg.norm(image_feats, axis=1)[:, np.newaxis].T)\n", 63 | "text_cos = np.dot(text_feats, text_feats.T) / (np.linalg.norm(text_feats, axis=1)[:, np.newaxis] * np.linalg.norm(text_feats, axis=1)[:, np.newaxis].T)\n", 64 | "\n", 65 | "seen_img = []\n", 66 | "seen_text = []\n", 67 | "unseen_img = []\n", 68 | "unseen_text = []\n", 69 | "for user, items in tqdm(I_items.items()):\n", 70 | " img = img_cos[items][:, items]\n", 71 | " text = text_cos[items][:, items]\n", 72 | "\n", 73 | " seen_img_result = []\n", 74 | " 
seen_text_result = []\n", 75 | " for i in range(len(items)):\n", 76 | " seen_img_result.append(np.concatenate([img[i, :i], img[i, i+1:]]))\n", 77 | " seen_text_result.append(np.concatenate([text[i, :i], text[i, i+1:]]))\n", 78 | " seen_img_result = np.array(seen_img_result) # .flatten()\n", 79 | " seen_text_result = np.array(seen_text_result) # .flatten()\n", 80 | "\n", 81 | " unseen_items = set(range(data_generator.n_items)) - set(items)\n", 82 | " unseen_items = list(unseen_items)\n", 83 | "\n", 84 | " unseen_img_result = img_cos[items][:, unseen_items].flatten()\n", 85 | " unseen_text_result = text_cos[items][:, unseen_items].flatten()\n", 86 | "\n", 87 | " seen_img.append(seen_img_result.mean())\n", 88 | " seen_text.append(seen_text_result.mean())\n", 89 | " unseen_img.append(unseen_img_result.mean())\n", 90 | " unseen_text.append(unseen_text_result.mean())\n", 91 | "\n", 92 | "print('text', np.mean(seen_text), np.mean(unseen_text))\n", 93 | "print('img', np.mean(seen_img), np.mean(unseen_img))" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 1, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "name": "stderr", 103 | "output_type": "stream", 104 | "text": [ 105 | "/home/ubuntu/anaconda3/envs/yg/lib/python3.6/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 106 | " from .autonotebook import tqdm as notebook_tqdm\n" 107 | ] 108 | }, 109 | { 110 | "name": "stdout", 111 | "output_type": "stream", 112 | "text": [ 113 | "n_users=19244, n_items=14596\n", 114 | "n_interactions=135326\n", 115 | "n_train=95629, n_val=20127, n_test=19570, sparsity=0.00048\n", 116 | "Loads image_emb: torch.Size([33840, 64]) and text_emb: torch.Size([33840, 64])\n", 117 | "0.14228745 0.11385205\n", 118 | "Loads image_emb: torch.Size([33840, 64]) and text_emb: torch.Size([33840, 64])\n", 119 | "0.3034921 0.10677319\n", 120 | "Loads image_emb: torch.Size([33840, 64]) and text_emb: torch.Size([33840, 64])\n", 121 | "0.3312145 0.1141393\n", 122 | "Loads image_emb: torch.Size([33840, 64]) and text_emb: torch.Size([33840, 64])\n", 123 | "0.17027126 0.110791825\n", 124 | "Loads image_emb: torch.Size([33840, 64]) and text_emb: torch.Size([33840, 64])\n", 125 | "0.2663411 0.11235878\n" 126 | ] 127 | } 128 | ], 129 | "source": [ 130 | "# avg.diff\n", 131 | "from Models import *\n", 132 | "\n", 133 | "import os\n", 134 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '1'\n", 135 | "import torch\n", 136 | "import numpy as np\n", 137 | "from utility.load_data import Data\n", 138 | "data_generator = Data(path='data/WomenClothing', batch_size=1024)\n", 139 | "\n", 140 | "def sparse_mx_to_torch_sparse_tensor(sparse_mx):\n", 141 | " \"\"\"Convert a scipy sparse matrix to a torch sparse tensor.\"\"\"\n", 142 | " sparse_mx = sparse_mx.tocoo().astype(np.float32)\n", 143 | " indices = torch.from_numpy(\n", 144 | " np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))\n", 145 | " values = torch.from_numpy(sparse_mx.data)\n", 146 | " shape = torch.Size(sparse_mx.shape)\n", 147 | " return torch.sparse.FloatTensor(indices, values, shape)\n", 148 | "\n", 149 | "model_name = 'MONET_concat_20_03'\n", 150 | "seed_list = ['123', '0', '42', '1024', '2048']\n", 151 | "nonzero_idx = data_generator.nonzero_idx()\n", 152 | "\n", 153 | "import numpy as np\n", 154 | "image_feats = np.load('data/WomenClothing/image_feat.npy')\n", 155 | "text_feats = np.load('data/WomenClothing/text_feat.npy')\n", 156 | 
" \n", 157 | "for seed in seed_list: \n", 158 | " model = MONET(data_generator.n_users, data_generator.n_items, 64, nonzero_idx, True, image_feats, text_feats, 2, 1.0, 0.3, 'concat', 's', False) \n", 159 | " model.load_state_dict(torch.load('./models/' + 'WomenClothing' + '_' + model_name + '_' + seed, map_location=torch.device('cpu'))[model_name + '_' + seed])\n", 160 | " model.cuda()\n", 161 | " image_emb, text_emb = model(eval=True)\n", 162 | " print('Loads image_emb: {} and text_emb: {}'.format(image_emb.shape, text_emb.shape))\n", 163 | "\n", 164 | " # user_emb = torch.load('data/{}/{}_user_emb.pt'.format('clothing', 'lightgcn_layer3_original')).cuda()\n", 165 | " # item_emb = torch.load('data/{}/{}_item_emb.pt'.format('clothing', 'lightgcn_layer3_original')).cuda()\n", 166 | " # print('Loads user_emb: {} and item_emb: {}'.format(user_emb.weight.shape, item_emb.weight.shape))\n", 167 | "\n", 168 | " # image_emb = image_emb.mean(dim=1, keepdim=False)\n", 169 | " # text_emb = text_emb.mean(dim=1, keepdim=False)\n", 170 | "\n", 171 | " # image_emb = image_emb[:, -1, :]\n", 172 | " # text_emb = text_emb[:, -1, :]\n", 173 | "\n", 174 | "\n", 175 | " final_image_preference, final_image_emb = torch.split(image_emb, [data_generator.n_users, data_generator.n_items], dim=0)\n", 176 | " final_text_preference, final_text_emb = torch.split(text_emb, [data_generator.n_users, data_generator.n_items], dim=0)\n", 177 | "\n", 178 | " final_text_emb, final_image_emb = final_text_emb.cpu().detach().numpy(), final_image_emb.cpu().detach().numpy()\n", 179 | "\n", 180 | " final_image_cos = np.dot(final_image_emb, final_image_emb.T) / (np.linalg.norm(final_image_emb, axis=1)[:, np.newaxis] * np.linalg.norm(final_image_emb, axis=1)[:, np.newaxis].T)\n", 181 | " final_text_cos = np.dot(final_text_emb, final_text_emb.T) / (np.linalg.norm(final_text_emb, axis=1)[:, np.newaxis] * np.linalg.norm(final_text_emb, axis=1)[:, np.newaxis].T)\n", 182 | "\n", 183 | " img_cos = np.dot(image_feats, image_feats.T) / (np.linalg.norm(image_feats, axis=1)[:, np.newaxis] * np.linalg.norm(image_feats, axis=1)[:, np.newaxis].T)\n", 184 | " text_cos = np.dot(text_feats, text_feats.T) / (np.linalg.norm(text_feats, axis=1)[:, np.newaxis] * np.linalg.norm(text_feats, axis=1)[:, np.newaxis].T)\n", 185 | "\n", 186 | " img_diff = np.abs(img_cos - final_image_cos)\n", 187 | " text_diff = np.abs(text_cos - final_text_cos)\n", 188 | "\n", 189 | " img = []\n", 190 | " for i in range(data_generator.n_items):\n", 191 | " img.append(np.concatenate([img_diff[i, :i], img_diff[i, i+1:]]))\n", 192 | " img = np.array(img) # .flatten()\n", 193 | "\n", 194 | " txt = []\n", 195 | " for i in range(data_generator.n_items):\n", 196 | " txt.append(np.concatenate([text_diff[i, :i], text_diff[i, i+1:]]))\n", 197 | " txt = np.array(txt) # .flatten()\n", 198 | "\n", 199 | " print(img[~np.isnan(img)].mean(), txt[~np.isnan(txt)].mean())" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [] 208 | } 209 | ], 210 | "metadata": { 211 | "interpreter": { 212 | "hash": "0aa7af790e1209bd084877485dad105a461ac2ebd38ac99cff72d3e7c0921c3c" 213 | }, 214 | "kernelspec": { 215 | "display_name": "yg", 216 | "language": "python", 217 | "name": "yg" 218 | }, 219 | "language_info": { 220 | "codemirror_mode": { 221 | "name": "ipython", 222 | "version": 3 223 | }, 224 | "file_extension": ".py", 225 | "mimetype": "text/x-python", 226 | "name": "python", 227 | "nbconvert_exporter": "python", 
228 |    "pygments_lexer": "ipython3",
229 |    "version": "3.6.13 |Anaconda, Inc.| (default, Jun  4 2021, 14:25:59) \n[GCC 7.5.0]"
230 |   },
231 |   "orig_nbformat": 4
232 |  },
233 |  "nbformat": 4,
234 |  "nbformat_minor": 2
235 | }
--------------------------------------------------------------------------------
/codes/data/build_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import array
3 | import gzip
4 | import json
5 | import os
6 | from collections import defaultdict
7 |
8 | import numpy as np
9 | import pandas as pd
10 | from sentence_transformers import SentenceTransformer
11 |
12 |
13 | def dataset_merge_and_split(path, core):
14 |     if not os.path.exists(path + "%d-core" % core):
15 |         os.makedirs(path + "%d-core" % core)
16 |
17 |     df = pd.read_csv(path + "/train.csv", index_col=None, usecols=None)
18 |     # Construct matrix
19 |     ui = defaultdict(list)
20 |     for _, row in df.iterrows():
21 |         user, item = int(row["userID"]), int(row["itemID"])
22 |         ui[user].append(item)
23 |
24 |     df = pd.read_csv(path + "/test.csv", index_col=None, usecols=None)
25 |     for _, row in df.iterrows():
26 |         user, item = int(row["userID"]), int(row["itemID"])
27 |         ui[user].append(item)
28 |
29 |     train_json = {}
30 |     val_json = {}
31 |     test_json = {}
32 |     for u, items in ui.items():
33 |         if len(items) < 10:
34 |             testval = np.random.choice(len(items), 2, replace=False)
35 |         else:
36 |             testval = np.random.choice(len(items), int(len(items) * 0.2), replace=False)
37 |
38 |         test = testval[: len(testval) // 2]
39 |         val = testval[len(testval) // 2 :]
40 |         train = [i for i in list(range(len(items))) if i not in testval]
41 |         train_json[u] = [items[idx] for idx in train]
42 |         val_json[u] = [items[idx] for idx in val.tolist()]
43 |         test_json[u] = [items[idx] for idx in test.tolist()]
44 |
45 |     with open(path + "/5-core/train.json", "w") as f:
46 |         json.dump(train_json, f)
47 |     with open(path + "/5-core/val.json", "w") as f:
48 |         json.dump(val_json, f)
49 |     with open(path + "/5-core/test.json", "w") as f:
50 |         json.dump(test_json, f)
51 |
52 |
53 | def load_textual_image_features(data_path):
54 |     import json
55 |     import os
56 |
57 |     from gensim.models.doc2vec import Doc2Vec
58 |
59 |     asin_dict = json.load(open(os.path.join(data_path, "asin_sample.json"), "r"))
60 |
61 |     # Prepare textual feature data.
62 | doc2vec_model = Doc2Vec.load(os.path.join(data_path, "doc2vecFile")) 63 | vis_vec = np.load( 64 | os.path.join(data_path, "image_feature.npy"), allow_pickle=True 65 | ).item() 66 | text_vec = {} 67 | for asin in asin_dict: 68 | text_vec[asin] = doc2vec_model.docvecs[asin] 69 | 70 | all_dict = {} 71 | num_items = 0 72 | filename = data_path + "/train.csv" 73 | df = pd.read_csv(filename, index_col=None, usecols=None) 74 | for _, row in df.iterrows(): 75 | asin, i = row["asin"], int(row["itemID"]) 76 | all_dict[i] = asin 77 | num_items = max(num_items, i) 78 | filename = data_path + "/test.csv" 79 | df = pd.read_csv(filename, index_col=None, usecols=None) 80 | for _, row in df.iterrows(): 81 | asin, i = row["asin"], int(row["itemID"]) 82 | all_dict[i] = asin 83 | num_items = max(num_items, i) 84 | 85 | t_features = [] 86 | v_features = [] 87 | for i in range(num_items + 1): 88 | t_features.append(text_vec[all_dict[i]]) 89 | v_features.append(vis_vec[all_dict[i]]) 90 | 91 | np.save(data_path + "/text_feat.npy", np.asarray(t_features, dtype=np.float32)) 92 | np.save(data_path + "/image_feat.npy", np.asarray(v_features, dtype=np.float32)) 93 | 94 | 95 | parser = argparse.ArgumentParser(description="") 96 | 97 | parser.add_argument( 98 | "--name", 99 | nargs="?", 100 | default="MenClothing", 101 | help="Choose a dataset folder from {MenClothing, WomenClothing, Beauty, Toys_and_Games}.", 102 | ) 103 | 104 | np.random.seed(123) 105 | 106 | args = parser.parse_args() 107 | folder = args.name + "/" 108 | name = args.name 109 | core = 5 110 | if folder in ["MenClothing/", "WomenClothing/"]: 111 | dataset_merge_and_split(folder, core) 112 | load_textual_image_features(folder) 113 | else: 114 | bert_path = "sentence-transformers/stsb-roberta-large" 115 | bert_model = SentenceTransformer(bert_path) 116 | 117 | if not os.path.exists(folder + "%d-core" % core): 118 | os.makedirs(folder + "%d-core" % core) 119 | 120 | def parse(path): 121 | g = gzip.open(path, "r") 122 | for line in g: 123 | yield json.dumps(eval(line)) 124 | 125 | print("----------parse metadata----------") 126 | if not os.path.exists(folder + "meta-data/meta.json"): 127 | with open(folder + "meta-data/meta.json", "w") as f: 128 | for line in parse(folder + "meta-data/" + "meta_%s.json.gz" % (name)): 129 | f.write(line + "\n") 130 | 131 | print("----------parse data----------") 132 | if not os.path.exists(folder + "meta-data/%d-core.json" % core): 133 | with open(folder + "meta-data/%d-core.json" % core, "w") as f: 134 | for line in parse( 135 | folder + "meta-data/" + "reviews_%s_%d.json.gz" % (name, core) 136 | ): 137 | f.write(line + "\n") 138 | 139 | print("----------load data----------") 140 | jsons = [] 141 | for line in open(folder + "meta-data/%d-core.json" % core).readlines(): 142 | jsons.append(json.loads(line)) 143 | 144 | print("----------Build dict----------") 145 | items = set() 146 | users = set() 147 | for j in jsons: 148 | items.add(j["asin"]) 149 | users.add(j["reviewerID"]) 150 | print("n_items:", len(items), "n_users:", len(users)) 151 | 152 | item2id = {} 153 | with open(folder + "%d-core/item_list.txt" % core, "w") as f: 154 | for i, item in enumerate(items): 155 | item2id[item] = i 156 | f.writelines(item + "\t" + str(i) + "\n") 157 | 158 | user2id = {} 159 | with open(folder + "%d-core/user_list.txt" % core, "w") as f: 160 | for i, user in enumerate(users): 161 | user2id[user] = i 162 | f.writelines(user + "\t" + str(i) + "\n") 163 | 164 | ui = defaultdict(list) 165 | review2id = {} 166 | review_text = {} 167 | 
ratings = {}
168 |     with open(folder + "%d-core/review_list.txt" % core, "w") as f:
169 |         for j in jsons:
170 |             u_id = user2id[j["reviewerID"]]
171 |             i_id = item2id[j["asin"]]
172 |             ui[u_id].append(i_id)
173 |             review_text[len(review2id)] = j["reviewText"].replace("\n", " ")
174 |             ratings[len(review2id)] = int(j["overall"])
175 |             f.writelines(str((u_id, i_id)) + "\t" + str(len(review2id)) + "\n")
176 |             review2id[u_id, i_id] = len(review2id)
177 |     with open(folder + "%d-core/user-item-dict.json" % core, "w") as f:
178 |         f.write(json.dumps(ui))
179 |     with open(folder + "%d-core/rating-dict.json" % core, "w") as f:
180 |         f.write(json.dumps(ratings))
181 |
182 |     review_texts = []
183 |     with open(folder + "%d-core/review_text.txt" % core, "w") as f:
184 |         for i, j in review2id:
185 |             f.write(review_text[review2id[i, j]] + "\n")
186 |             review_texts.append(review_text[review2id[i, j]] + "\n")
187 |     review_embeddings = bert_model.encode(review_texts)
188 |     assert review_embeddings.shape[0] == len(review2id)
189 |     np.save(folder + "review_feat.npy", review_embeddings)
190 |
191 |     print("----------Split Data----------")
192 |     train_json = {}
193 |     val_json = {}
194 |     test_json = {}
195 |     for u, items in ui.items():
196 |         if len(items) < 10:
197 |             testval = np.random.choice(len(items), 2, replace=False)
198 |         else:
199 |             testval = np.random.choice(len(items), int(len(items) * 0.2), replace=False)
200 |
201 |         test = testval[: len(testval) // 2]
202 |         val = testval[len(testval) // 2 :]
203 |         train = [i for i in list(range(len(items))) if i not in testval]
204 |         train_json[u] = [items[idx] for idx in train]
205 |         val_json[u] = [items[idx] for idx in val.tolist()]
206 |         test_json[u] = [items[idx] for idx in test.tolist()]
207 |
208 |     with open(folder + "%d-core/train.json" % core, "w") as f:
209 |         json.dump(train_json, f)
210 |     with open(folder + "%d-core/val.json" % core, "w") as f:
211 |         json.dump(val_json, f)
212 |     with open(folder + "%d-core/test.json" % core, "w") as f:
213 |         json.dump(test_json, f)
214 |
215 |     jsons = []
216 |     with open(folder + "meta-data/meta.json", "r") as f:
217 |         for line in f.readlines():
218 |             jsons.append(json.loads(line))
219 |
220 |     print("----------Text Features----------")
221 |     raw_text = {}
222 |     for _json in jsons:
223 |         if _json["asin"] in item2id:
224 |             string = " "
225 |             if "categories" in _json:
226 |                 for cates in _json["categories"]:
227 |                     for cate in cates:
228 |                         string += cate + " "
229 |             if "title" in _json:
230 |                 string += _json["title"]
231 |             if "brand" in _json:
232 |                 string += _json["brand"]
233 |             if "description" in _json:
234 |                 string += _json["description"]
235 |             raw_text[item2id[_json["asin"]]] = string.replace("\n", " ")
236 |     texts = []
237 |     with open(folder + "%d-core/raw_text.txt" % core, "w") as f:
238 |         for i in range(len(item2id)):
239 |             f.write(raw_text[i] + "\n")
240 |             texts.append(raw_text[i] + "\n")
241 |     sentence_embeddings = bert_model.encode(texts)
242 |     assert sentence_embeddings.shape[0] == len(item2id)
243 |     np.save(folder + "text_feat.npy", sentence_embeddings)
244 |
245 |     print("----------Image Features----------")
246 |
247 |     def readImageFeatures(path):
248 |         f = open(path, "rb")
249 |         while True:
250 |             asin = f.read(10).decode("UTF-8")
251 |             if asin == "":
252 |                 break
253 |             a = array.array("f")
254 |             a.fromfile(f, 4096)
255 |             yield asin, a.tolist()
256 |
257 |     data = readImageFeatures(folder + "meta-data/" + "image_features_%s.b" % name)
258 |     feats = {}
259 |     avg = []
260 |     for d in data:
261 |         if d[0] in 
item2id: 262 | feats[int(item2id[d[0]])] = d[1] 263 | avg.append(d[1]) 264 | avg = np.array(avg).mean(0).tolist() 265 | 266 | ret = [] 267 | for i in range(len(item2id)): 268 | if i in feats: 269 | ret.append(feats[i]) 270 | else: 271 | ret.append(avg) 272 | 273 | assert len(ret) == len(item2id) 274 | np.save(folder + "image_feat.npy", np.array(ret)) 275 | -------------------------------------------------------------------------------- /codes/Models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Parameter 5 | from torch_geometric.nn.conv import MessagePassing 6 | from torch_geometric.utils.num_nodes import maybe_num_nodes 7 | from torch_scatter import scatter_add 8 | 9 | 10 | def normalize_laplacian(edge_index, edge_weight): 11 | num_nodes = maybe_num_nodes(edge_index) 12 | row, col = edge_index[0], edge_index[1] 13 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 14 | 15 | deg_inv_sqrt = deg.pow_(-0.5) 16 | deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float("inf"), 0) 17 | edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 18 | return edge_weight 19 | 20 | 21 | class Our_GCNs(MessagePassing): 22 | def __init__(self, in_channels, out_channels): 23 | super(Our_GCNs, self).__init__(aggr="add") 24 | self.in_channels = in_channels 25 | self.out_channels = out_channels 26 | 27 | def forward(self, x, edge_index, weight_vector, size=None): 28 | self.weight_vector = weight_vector 29 | return self.propagate(edge_index, size=size, x=x) 30 | 31 | def message(self, x_j): 32 | return x_j * self.weight_vector 33 | 34 | def update(self, aggr_out): 35 | return aggr_out 36 | 37 | 38 | from torch_geometric.nn.inits import uniform 39 | 40 | 41 | class Nonlinear_GCNs(MessagePassing): 42 | def __init__(self, in_channels, out_channels): 43 | super(Nonlinear_GCNs, self).__init__(aggr="add") 44 | self.in_channels = in_channels 45 | self.out_channels = out_channels 46 | self.weight = Parameter(torch.Tensor(self.in_channels, out_channels)) 47 | self.reset_parameters() 48 | 49 | def reset_parameters(self): 50 | uniform(self.in_channels, self.weight) 51 | 52 | def forward(self, x, edge_index, weight_vector, size=None): 53 | x = torch.matmul(x, self.weight) 54 | self.weight_vector = weight_vector 55 | return self.propagate(edge_index, size=size, x=x) 56 | 57 | def message(self, x_j): 58 | return x_j * self.weight_vector 59 | 60 | def update(self, aggr_out): 61 | return aggr_out 62 | 63 | 64 | class MeGCN(nn.Module): 65 | def __init__( 66 | self, 67 | n_users, 68 | n_items, 69 | n_layers, 70 | has_norm, 71 | feat_embed_dim, 72 | nonzero_idx, 73 | image_feats, 74 | text_feats, 75 | alpha, 76 | agg, 77 | cf, 78 | cf_gcn, 79 | lightgcn, 80 | ): 81 | super(MeGCN, self).__init__() 82 | self.n_users = n_users 83 | self.n_items = n_items 84 | self.n_layers = n_layers 85 | self.has_norm = has_norm 86 | self.feat_embed_dim = feat_embed_dim 87 | self.nonzero_idx = torch.tensor(nonzero_idx).cuda().long().T 88 | self.alpha = alpha 89 | self.agg = agg 90 | self.cf = cf 91 | self.cf_gcn = cf_gcn 92 | self.lightgcn = lightgcn 93 | 94 | self.image_preference = nn.Embedding(self.n_users, self.feat_embed_dim) 95 | self.text_preference = nn.Embedding(self.n_users, self.feat_embed_dim) 96 | nn.init.xavier_uniform_(self.image_preference.weight) 97 | nn.init.xavier_uniform_(self.text_preference.weight) 98 | 99 | self.image_embedding = nn.Embedding.from_pretrained( 100 | 
        self.image_embedding = nn.Embedding.from_pretrained(
            torch.tensor(image_feats, dtype=torch.float), freeze=True
        )  # [# of items, 4096]
        self.text_embedding = nn.Embedding.from_pretrained(
            torch.tensor(text_feats, dtype=torch.float), freeze=True
        )  # [# of items, 1024]

        if self.cf:
            self.user_embedding = nn.Embedding(self.n_users, self.feat_embed_dim)
            self.item_embedding = nn.Embedding(self.n_items, self.feat_embed_dim)
            nn.init.xavier_uniform_(self.user_embedding.weight)
            nn.init.xavier_uniform_(self.item_embedding.weight)

        self.image_trs = nn.Linear(image_feats.shape[1], self.feat_embed_dim)
        self.text_trs = nn.Linear(text_feats.shape[1], self.feat_embed_dim)

        if not self.cf:
            if self.agg == "fc":
                self.transform = nn.Linear(self.feat_embed_dim * 2, self.feat_embed_dim)
            elif self.agg == "weighted_sum":
                self.modal_weight = nn.Parameter(torch.Tensor([0.5, 0.5]))
                self.softmax = nn.Softmax(dim=0)
        else:
            if self.agg == "fc":
                self.transform = nn.Linear(self.feat_embed_dim * 3, self.feat_embed_dim)
            elif self.agg == "weighted_sum":
                self.modal_weight = nn.Parameter(torch.Tensor([0.33, 0.33, 0.33]))
                self.softmax = nn.Softmax(dim=0)

        self.layers = nn.ModuleList(
            [
                Our_GCNs(self.feat_embed_dim, self.feat_embed_dim)
                for _ in range(self.n_layers)
            ]
        )

    def forward(self, edge_index, edge_weight, _eval=False):
        # transform
        image_emb = self.image_trs(
            self.image_embedding.weight
        )  # [# of items, feat_embed_dim]
        text_emb = self.text_trs(
            self.text_embedding.weight
        )  # [# of items, feat_embed_dim]

        if self.has_norm:
            image_emb = F.normalize(image_emb)
            text_emb = F.normalize(text_emb)

        image_preference = self.image_preference.weight
        text_preference = self.text_preference.weight

        # propagate
        ego_image_emb = torch.cat([image_preference, image_emb], dim=0)
        ego_text_emb = torch.cat([text_preference, text_emb], dim=0)

        if self.cf:
            user_emb = self.user_embedding.weight
            item_emb = self.item_embedding.weight
            ego_cf_emb = torch.cat([user_emb, item_emb], dim=0)
            if self.cf_gcn == "LightGCN":
                all_cf_emb = [ego_cf_emb]

        if self.lightgcn:
            all_image_emb = [ego_image_emb]
            all_text_emb = [ego_text_emb]

        for layer in self.layers:
            if not self.lightgcn:
                side_image_emb = layer(ego_image_emb, edge_index, edge_weight)
                side_text_emb = layer(ego_text_emb, edge_index, edge_weight)

                ego_image_emb = side_image_emb + self.alpha * ego_image_emb
                ego_text_emb = side_text_emb + self.alpha * ego_text_emb
            else:
                side_image_emb = layer(ego_image_emb, edge_index, edge_weight)
                side_text_emb = layer(ego_text_emb, edge_index, edge_weight)
                ego_image_emb = side_image_emb
                ego_text_emb = side_text_emb
                all_image_emb += [ego_image_emb]
                all_text_emb += [ego_text_emb]
            if self.cf:
                if self.cf_gcn == "MeGCN":
                    side_cf_emb = layer(ego_cf_emb, edge_index, edge_weight)
                    ego_cf_emb = side_cf_emb + self.alpha * ego_cf_emb
                elif self.cf_gcn == "LightGCN":
                    side_cf_emb = layer(ego_cf_emb, edge_index, edge_weight)
                    ego_cf_emb = side_cf_emb
                    all_cf_emb += [ego_cf_emb]

        if not self.lightgcn:
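            # MeGCN readout: use the last layer's embeddings directly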
            final_image_preference, final_image_emb = torch.split(
                ego_image_emb, [self.n_users, self.n_items], dim=0
            )
            final_text_preference, final_text_emb = torch.split(
                ego_text_emb, [self.n_users, self.n_items], dim=0
            )
        else:
            # LightGCN-style readout: average the embeddings over all layers
            all_image_emb = torch.stack(all_image_emb, dim=1)
            all_image_emb = all_image_emb.mean(dim=1, keepdim=False)
            final_image_preference, final_image_emb = torch.split(
                all_image_emb, [self.n_users, self.n_items], dim=0
            )

            all_text_emb = torch.stack(all_text_emb, dim=1)
            all_text_emb = all_text_emb.mean(dim=1, keepdim=False)
            final_text_preference, final_text_emb = torch.split(
                all_text_emb, [self.n_users, self.n_items], dim=0
            )

        if self.cf:
            if self.cf_gcn == "MeGCN":
                final_cf_user_emb, final_cf_item_emb = torch.split(
                    ego_cf_emb, [self.n_users, self.n_items], dim=0
                )
            elif self.cf_gcn == "LightGCN":
                all_cf_emb = torch.stack(all_cf_emb, dim=1)
                all_cf_emb = all_cf_emb.mean(dim=1, keepdim=False)
                final_cf_user_emb, final_cf_item_emb = torch.split(
                    all_cf_emb, [self.n_users, self.n_items], dim=0
                )

        if _eval:
            return ego_image_emb, ego_text_emb

        if not self.cf:
            if self.agg == "concat":
                items = torch.cat(
                    [final_image_emb, final_text_emb], dim=1
                )  # [# of items, feat_embed_dim * 2]
                user_preference = torch.cat(
                    [final_image_preference, final_text_preference], dim=1
                )  # [# of users, feat_embed_dim * 2]
            elif self.agg == "sum":
                items = final_image_emb + final_text_emb  # [# of items, feat_embed_dim]
                user_preference = (
                    final_image_preference + final_text_preference
                )  # [# of users, feat_embed_dim]
            elif self.agg == "weighted_sum":
                weight = self.softmax(self.modal_weight)
                items = (
                    weight[0] * final_image_emb + weight[1] * final_text_emb
                )  # [# of items, feat_embed_dim]
                user_preference = (
                    weight[0] * final_image_preference
                    + weight[1] * final_text_preference
                )  # [# of users, feat_embed_dim]
            elif self.agg == "fc":
                items = self.transform(
                    torch.cat([final_image_emb, final_text_emb], dim=1)
                )  # [# of items, feat_embed_dim]
                user_preference = self.transform(
                    torch.cat([final_image_preference, final_text_preference], dim=1)
                )  # [# of users, feat_embed_dim]
        else:
            if self.agg == "concat":
                items = torch.cat(
                    [final_image_emb, final_text_emb, final_cf_item_emb], dim=1
                )  # [# of items, feat_embed_dim * 3]
                user_preference = torch.cat(
                    [final_image_preference, final_text_preference, final_cf_user_emb],
                    dim=1,
                )  # [# of users, feat_embed_dim * 3]
            elif self.agg == "sum":
                items = (
                    final_image_emb + final_text_emb + final_cf_item_emb
                )  # [# of items, feat_embed_dim]
                user_preference = (
                    final_image_preference + final_text_preference + final_cf_user_emb
                )  # [# of users, feat_embed_dim]
            elif self.agg == "weighted_sum":
                weight = self.softmax(self.modal_weight)
                items = (
                    weight[0] * final_image_emb
                    + weight[1] * final_text_emb
                    + weight[2] * final_cf_item_emb
                )  # [# of items, feat_embed_dim]
                user_preference = (
                    weight[0] * final_image_preference
                    + weight[1] * final_text_preference
                    + weight[2] * final_cf_user_emb
                )  # [# of users, feat_embed_dim]
            elif self.agg == "fc":
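                # late fusion: one linear layer maps the concatenated modality
                # and CF embeddings back to feat_embed_dim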
                items = self.transform(
                    torch.cat(
                        [final_image_emb, final_text_emb, final_cf_item_emb], dim=1
                    )
                )  # [# of items, feat_embed_dim]
                user_preference = self.transform(
                    torch.cat(
                        [
                            final_image_preference,
                            final_text_preference,
                            final_cf_user_emb,
                        ],
                        dim=1,
                    )
                )  # [# of users, feat_embed_dim]

        return user_preference, items


class MONET(nn.Module):
    def __init__(
        self,
        n_users,
        n_items,
        feat_embed_dim,
        nonzero_idx,
        has_norm,
        image_feats,
        text_feats,
        n_layers,
        alpha,
        beta,
        agg,
        cf,
        cf_gcn,
        lightgcn,
    ):
        super(MONET, self).__init__()
        self.n_users = n_users
        self.n_items = n_items
        self.feat_embed_dim = feat_embed_dim
        self.n_layers = n_layers
        self.nonzero_idx = nonzero_idx
        self.alpha = alpha
        self.beta = beta
        self.agg = agg
        self.image_feats = torch.tensor(image_feats, dtype=torch.float).cuda()
        self.text_feats = torch.tensor(text_feats, dtype=torch.float).cuda()

        self.megcn = MeGCN(
            self.n_users,
            self.n_items,
            self.n_layers,
            has_norm,
            self.feat_embed_dim,
            self.nonzero_idx,
            image_feats,
            text_feats,
            self.alpha,
            self.agg,
            cf,
            cf_gcn,
            lightgcn,
        )

        # symmetric bipartite edge index: user -> item and item -> user edges,
        # with item ids offset by n_users
        nonzero_idx = torch.tensor(self.nonzero_idx).cuda().long().T
        nonzero_idx[1] = nonzero_idx[1] + self.n_users
        self.edge_index = torch.cat(
            [nonzero_idx, torch.stack([nonzero_idx[1], nonzero_idx[0]], dim=0)], dim=1
        )
        self.edge_weight = torch.ones((self.edge_index.size(1))).cuda().view(-1, 1)
        self.edge_weight = normalize_laplacian(self.edge_index, self.edge_weight)

        # dense user-item interaction matrix used by the target-aware attention
        nonzero_idx = torch.tensor(self.nonzero_idx).cuda().long().T
        self.adj = (
            torch.sparse.FloatTensor(
                nonzero_idx,
                torch.ones((nonzero_idx.size(1))).cuda(),
                (self.n_users, self.n_items),
            )
            .to_dense()
            .cuda()
        )

    def forward(self, _eval=False):
        if _eval:
            img, txt = self.megcn(self.edge_index, self.edge_weight, _eval=True)
            return img, txt

        user, items = self.megcn(self.edge_index, self.edge_weight, _eval=False)
        return user, items

    def bpr_loss(self, user_emb, item_emb, users, pos_items, neg_items, target_aware):
        current_user_emb = user_emb[users]
        pos_item_emb = item_emb[pos_items]
        neg_item_emb = item_emb[neg_items]

        if target_aware:
            # target-aware attention over each user's interacted items
            item_item = torch.mm(item_emb, item_emb.T)
            pos_item_query = item_item[pos_items, :]  # (batch_size, n_items)
            neg_item_query = item_item[neg_items, :]  # (batch_size, n_items)
            pos_target_user_alpha = torch.softmax(
                torch.multiply(pos_item_query, self.adj[users, :]).masked_fill(
                    self.adj[users, :] == 0, -1e9
                ),
                dim=1,
            )  # (batch_size, n_items)
            neg_target_user_alpha = torch.softmax(
                torch.multiply(neg_item_query, self.adj[users, :]).masked_fill(
                    self.adj[users, :] == 0, -1e9
                ),
                dim=1,
            )  # (batch_size, n_items)
            pos_target_user = torch.mm(
                pos_target_user_alpha, item_emb
            )  # (batch_size, dim)
            neg_target_user = torch.mm(
                neg_target_user_alpha, item_emb
            )  # (batch_size, dim)

            # predictor
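            # beta mixes the direct user-item score with the target-aware score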
            pos_scores = (1 - self.beta) * torch.sum(
                torch.mul(current_user_emb, pos_item_emb), dim=1
            ) + self.beta * torch.sum(torch.mul(pos_target_user, pos_item_emb), dim=1)
            neg_scores = (1 - self.beta) * torch.sum(
                torch.mul(current_user_emb, neg_item_emb), dim=1
            ) + self.beta * torch.sum(torch.mul(neg_target_user, neg_item_emb), dim=1)
        else:
            pos_scores = torch.sum(torch.mul(current_user_emb, pos_item_emb), dim=1)
            neg_scores = torch.sum(torch.mul(current_user_emb, neg_item_emb), dim=1)

        maxi = F.logsigmoid(pos_scores - neg_scores)
        mf_loss = -torch.mean(maxi)

        regularizer = (
            1.0 / 2 * (pos_item_emb**2).sum()
            + 1.0 / 2 * (neg_item_emb**2).sum()
            + 1.0 / 2 * (current_user_emb**2).sum()
        )
        emb_loss = regularizer / pos_item_emb.size(0)

        reg_loss = 0.0

        return mf_loss, emb_loss, reg_loss
--------------------------------------------------------------------------------
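Usage sketch (illustrative only, not a file in this repository): the snippet below wires MONET together with random stand-in features to show the expected shapes and call order. The toy sizes, the random nonzero_idx interaction list, and the hyperparameter values are placeholder assumptions; in the actual pipeline they come from utility/load_data.py and utility/parser.py. A CUDA device is required because the model hard-codes .cuda().

import numpy as np
import torch

from Models import MONET

# hypothetical toy sizes; real values come from the dataset loaders
n_users, n_items, dim = 8, 12, 64
image_feats = np.random.rand(n_items, 4096).astype(np.float32)
text_feats = np.random.rand(n_items, 1024).astype(np.float32)
# one [user_id, item_id] pair per user (placeholder interactions)
nonzero_idx = [[u, int(np.random.randint(n_items))] for u in range(n_users)]

model = MONET(
    n_users, n_items, dim, nonzero_idx,
    has_norm=True, image_feats=image_feats, text_feats=text_feats,
    n_layers=2, alpha=0.6, beta=0.4, agg="concat",
    cf=False, cf_gcn="MeGCN", lightgcn=False,
).cuda()

# with agg="concat", embeddings are [n_users, 2 * dim] and [n_items, 2 * dim]
user_emb, item_emb = model()
users = torch.tensor([0, 1]).cuda()
pos_items = torch.tensor([2, 3]).cuda()
neg_items = torch.tensor([4, 5]).cuda()
mf_loss, emb_loss, reg_loss = model.bpr_loss(
    user_emb, item_emb, users, pos_items, neg_items, target_aware=True
)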