├── .figure
│   ├── adressa-result.png
│   ├── mind-leaderboard.png
│   └── mind-result.png
├── .gitignore
├── LICENSE
├── README.md
├── preprocess
│   ├── adressa_raw.py
│   ├── news_process.py
│   └── user_process.py
├── raw
│   ├── README.md
│   └── download.sh
└── src
    ├── agg.py
    ├── data.py
    ├── main.py
    ├── metrics.py
    └── model.py

/.figure/adressa-result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjw1029/Efficient-FedRec/839f967c1ed1c0cb0b1b4d670828437ffb712f29/.figure/adressa-result.png
--------------------------------------------------------------------------------

/.figure/mind-leaderboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjw1029/Efficient-FedRec/839f967c1ed1c0cb0b1b4d670828437ffb712f29/.figure/mind-leaderboard.png
--------------------------------------------------------------------------------

/.figure/mind-result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjw1029/Efficient-FedRec/839f967c1ed1c0cb0b1b4d670828437ffb712f29/.figure/mind-result.png
--------------------------------------------------------------------------------

/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints/
2 | __pycache__/
3 | checkpoint/
4 | wandb/
5 | amlt/
6 | .amltconfig
7 | config.yaml
8 |
9 | raw/*/
10 | data/*
11 | output/
--------------------------------------------------------------------------------

/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2021 yjw1029
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
1 | # Efficient-FedRec
2 | Python implementation of our paper "Efficient-FedRec: Efficient Federated Learning Framework for Privacy-Preserving News Recommendation" (EMNLP 2021).
3 |
4 | ## Introduction
5 | Directly applying federated learning to news recommendation models leads to high computation and communication costs on the user side.
6 | In this work, we propose Efficient-FedRec, in which we decompose the news recommendation model into a large news model maintained on the server and a light-weight user model computed on the user side.
7 | Experiments on two public datasets show the effectiveness of our method.
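In this repository one communication round (see `src/main.py` and `src/agg.py`) works as follows: the server encodes only the news articles the sampled clients need and sends those news vectors together with the current user model; each client computes gradients with respect to the received news vectors and the user-model parameters; the server applies the user-model gradients directly and back-propagates the news-vector gradients through the large news encoder, so the BERT-sized model never leaves the server. Below is a minimal toy sketch of this round; the module names, sizes, and the `EmbeddingBag` "news encoder" are stand-ins for illustration, not the project's actual classes.

```python
# Toy sketch of one Efficient-FedRec round (stand-in modules, not the repo's real classes).
import torch
import torch.nn as nn

NEWS_DIM, TITLE_LEN, VOCAB = 400, 30, 1000

class TinyNewsEncoder(nn.Module):
    """Stands in for the large BERT-based news model kept on the server."""
    def __init__(self):
        super().__init__()
        self.emb = nn.EmbeddingBag(VOCAB, NEWS_DIM)

    def forward(self, titles):            # titles: (num_news, TITLE_LEN) token ids
        return self.emb(titles)           # (num_news, NEWS_DIM)

class TinyUserEncoder(nn.Module):
    """Stands in for the light-weight user model trained on clients."""
    def __init__(self):
        super().__init__()
        self.att = nn.Linear(NEWS_DIM, 1)

    def forward(self, his_vecs):          # his_vecs: (batch, his_len, NEWS_DIM)
        w = torch.softmax(self.att(his_vecs), dim=1)
        return (w * his_vecs).sum(dim=1)  # (batch, NEWS_DIM)

server_news, server_user = TinyNewsEncoder(), TinyUserEncoder()
news_opt = torch.optim.Adam(server_news.parameters(), lr=5e-5)
user_opt = torch.optim.Adam(server_user.parameters(), lr=5e-5)

# Server: encode only the news the sampled clients need; ship vectors + user model.
titles = torch.randint(0, VOCAB, (8, TITLE_LEN))
with torch.no_grad():
    news_vecs = server_news(titles)                                 # (8, NEWS_DIM)

# Client: gradients w.r.t. the news vectors and the user model only (no BERT here).
client_user = TinyUserEncoder()
client_user.load_state_dict(server_user.state_dict())
his = news_vecs[:5].unsqueeze(0).clone().requires_grad_(True)       # clicked history
cand = news_vecs[5:].unsqueeze(0).clone().requires_grad_(True)      # 1 positive + 2 negatives
score = torch.bmm(cand, client_user(his).unsqueeze(-1)).squeeze(-1)
loss = nn.functional.cross_entropy(score, torch.zeros(1, dtype=torch.long))
loss.backward()
news_vec_grad = torch.cat([his.grad.squeeze(0), cand.grad.squeeze(0)])   # (8, NEWS_DIM)
user_grad = {n: p.grad.clone() for n, p in client_user.named_parameters()}

# Server: apply the user-model gradients, then push the news-vector gradients
# through the news encoder (chain rule into its weights).
for n, p in server_user.named_parameters():
    p.grad = user_grad[n]
user_opt.step()
user_opt.zero_grad()

news_opt.zero_grad()
server_news(titles).backward(news_vec_grad)
news_opt.step()
```

Only the user-model gradients and the gradients of the handful of needed news vectors leave the client, which is where the computation and communication savings come from.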
8 |
9 |
10 | ## Environment
11 | Requirements:
12 | ```
13 | numpy
14 | torch==1.9.1
15 | transformers==4.12.5
16 | tqdm
17 | scikit-learn
18 | wandb
19 | ```
20 |
21 | ## Getting Started
22 | * Download datasets
23 | ```bash
24 | cd raw
25 | chmod +x download.sh
26 | ./download.sh mind .
27 | ./download.sh adressa .
28 | ```
29 | * Preprocess datasets
30 | ```bash
31 | cd preprocess
32 | # convert adressa to the mind format
33 | python adressa_raw.py
34 |
35 | # preprocess the mind dataset
36 | python news_process.py --data mind
37 | python user_process.py --data mind
38 |
39 | # preprocess the adressa dataset
40 | python news_process.py --data adressa
41 | python user_process.py --data adressa
42 | ```
43 |
44 | * Run experiments
45 | ```bash
46 | # You may need to configure your wandb account first
47 | cd src
48 | python main.py --data mind
49 | # get the prediction result of the best checkpoint and submit it to Codalab
50 | python main.py --data mind --mode predict
51 |
52 | # train on adressa
53 | python main.py --data adressa --max_train_steps 500 --validation_steps 10 --bert_type NbAiLab/nb-bert-base
54 | # test on adressa
55 | python main.py --data adressa --mode test --bert_type NbAiLab/nb-bert-base
56 | ```
57 |
58 |
59 | ## Results
60 |
61 | ### MIND
62 | Wandb results on the MIND dataset:
63 | ![](./.figure/mind-result.png)
64 | Zip the prediction.txt file and upload it to the MIND competition. The test result is:
65 | ![](./.figure/mind-leaderboard.png)
66 |
67 |
68 | ### Adressa
69 | Wandb results on the Adressa dataset:
70 | ![](./.figure/adressa-result.png)
71 | The test result is:
72 | ```
73 | test auc: 0.7980, mrr: 0.4637, ndcg5: 0.4852, ndcg10: 0.5497
74 | ```
75 |
76 | ## Citing
77 | If you want to cite Efficient-FedRec in your papers (much appreciated!), you can cite it as follows:
78 | ```
79 | @inproceedings{yi-etal-2021-efficient,
80 |     title = "Efficient-{F}ed{R}ec: Efficient Federated Learning Framework for Privacy-Preserving News Recommendation",
81 |     author = "Yi, Jingwei and
82 |       Wu, Fangzhao and
83 |       Wu, Chuhan and
84 |       Liu, Ruixuan and
85 |       Sun, Guangzhong and
86 |       Xie, Xing",
87 |     booktitle = "EMNLP",
88 |     year = "2021",
89 |     pages = "2814--2824"
90 | }
91 | ```
92 |
--------------------------------------------------------------------------------
/preprocess/adressa_raw.py:
--------------------------------------------------------------------------------
1 | # This script is used to construct the training, validation and test datasets of Adressa.
2 | # We follow existing works [1][2] to split the dataset.
3 | # [1] 4 | # [2] 5 | 6 | import json 7 | import pickle 8 | import argparse 9 | 10 | import numpy as np 11 | from tqdm import tqdm 12 | from pathlib import Path 13 | from collections import defaultdict 14 | from sklearn.model_selection import train_test_split 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument( 20 | "--adressa_path", 21 | type=str, 22 | default="../raw/adressa/raw/one_week", 23 | help="path to downloaded raw adressa dataset", 24 | ) 25 | parser.add_argument( 26 | "--out_path", 27 | type=str, 28 | default="../raw/adressa/", 29 | help="path to save processed dataset, default in ../raw/adressa", 30 | ) 31 | parser.add_argument( 32 | "--neg_num", 33 | type=int, 34 | default=20, 35 | help="randomly sample neg_num negative impression for every positive behavior", 36 | ) 37 | 38 | args = parser.parse_args() 39 | return args 40 | 41 | 42 | def process_news(adressa_path): 43 | news_title = {} 44 | 45 | for file in adressa_path.iterdir(): 46 | with open(file, "r") as f: 47 | for l in tqdm(f): 48 | event_dict = json.loads(l.strip("\n")) 49 | if "id" in event_dict and "title" in event_dict: 50 | if event_dict["id"] not in news_title: 51 | news_title[event_dict["id"]] = event_dict["title"] 52 | else: 53 | assert news_title[event_dict["id"]] == event_dict["title"] 54 | 55 | nid2index = {k: v for k, v in zip(news_title.keys(), range(1, len(news_title) + 1))} 56 | return news_title, nid2index 57 | 58 | 59 | def write_news_files(news_title, nid2index, out_path): 60 | # Output with MIND format 61 | news_lines = [] 62 | for nid in tqdm(news_title): 63 | nindex = nid2index[nid] 64 | title = news_title[nid] 65 | news_line = "\t".join([str(nindex), "", "", title, "", "", "", ""]) + "\n" 66 | news_lines.append(news_line) 67 | 68 | for stage in ["train", "valid", "test"]: 69 | file_path = out_path / stage 70 | file_path.mkdir(exist_ok=True, parents=True) 71 | with open(out_path / stage / "news.tsv", "w", encoding="utf-8") as f: 72 | f.writelines(news_lines) 73 | 74 | 75 | class UserInfo: 76 | def __init__(self, train_day=6, test_day=7): 77 | self.click_news = [] 78 | self.click_time = [] 79 | self.click_days = [] 80 | 81 | self.train_news = [] 82 | self.train_time = [] 83 | self.train_days = [] 84 | 85 | self.test_news = [] 86 | self.test_time = [] 87 | self.test_days = [] 88 | 89 | self.train_day = train_day 90 | self.test_day = test_day 91 | 92 | def update(self, nindex, time, day): 93 | if day == self.train_day: 94 | self.train_news.append(nindex) 95 | self.train_time.append(time) 96 | self.train_days.append(day) 97 | elif day == self.test_day: 98 | self.test_news.append(nindex) 99 | self.test_time.append(time) 100 | self.test_days.append(day) 101 | else: 102 | self.click_news.append(nindex) 103 | self.click_time.append(time) 104 | self.click_days.append(day) 105 | 106 | def sort_click(self): 107 | self.click_news = np.array(self.click_news, dtype="int32") 108 | self.click_time = np.array(self.click_time, dtype="int32") 109 | self.click_days = np.array(self.click_days, dtype="int32") 110 | 111 | self.train_news = np.array(self.train_news, dtype="int32") 112 | self.train_time = np.array(self.train_time, dtype="int32") 113 | self.train_days = np.array(self.train_days, dtype="int32") 114 | 115 | self.test_news = np.array(self.test_news, dtype="int32") 116 | self.test_time = np.array(self.test_time, dtype="int32") 117 | self.test_days = np.array(self.test_days, dtype="int32") 118 | 119 | order = np.argsort(self.train_time) 120 | self.train_time = 
self.train_time[order] 121 | self.train_days = self.train_days[order] 122 | self.train_news = self.train_news[order] 123 | 124 | order = np.argsort(self.test_time) 125 | self.test_time = self.test_time[order] 126 | self.test_days = self.test_days[order] 127 | self.test_news = self.test_news[order] 128 | 129 | order = np.argsort(self.click_time) 130 | self.click_time = self.click_time[order] 131 | self.click_days = self.click_days[order] 132 | self.click_news = self.click_news[order] 133 | 134 | 135 | def process_users(adressa_path): 136 | uid2index = {} 137 | user_info = defaultdict(UserInfo) 138 | 139 | for file in adressa_path.iterdir(): 140 | with open(file, "r") as f: 141 | for l in tqdm(f): 142 | event_dict = json.loads(l.strip("\n")) 143 | if "id" in event_dict and "title" in event_dict: 144 | nindex = nid2index[event_dict["id"]] 145 | uid = event_dict["userId"] 146 | 147 | if uid not in uid2index: 148 | uid2index[uid] = len(uid2index) 149 | 150 | uindex = uid2index[uid] 151 | click_time = int(event_dict["time"]) 152 | day = int(file.name[-1]) 153 | user_info[uindex].update(nindex, click_time, day) 154 | 155 | return uid2index, user_info 156 | 157 | 158 | def construct_behaviors(uindex, click_news, train_news, test_news, neg_num): 159 | p = np.ones(len(news_title) + 1, dtype="float32") 160 | p[click_news] = 0 161 | p[train_news] = 0 162 | p[test_news] = 0 163 | p[0] = 0 164 | p /= p.sum() 165 | 166 | train_his_news = [str(i) for i in click_news.tolist()] 167 | train_his_line = " ".join(train_his_news) 168 | 169 | for nindex in train_news: 170 | neg_cand = np.random.choice( 171 | len(news_title) + 1, size=neg_num, replace=False, p=p 172 | ).tolist() 173 | cand_news = " ".join( 174 | [f"{str(nindex)}-1"] + [f"{str(nindex)}-0" for nindex in neg_cand] 175 | ) 176 | 177 | train_behavior_line = f"null\t{uindex}\tnull\t{train_his_line}\t{cand_news}\n" 178 | train_lines.append(train_behavior_line) 179 | 180 | test_his_news = [str(i) for i in click_news.tolist() + train_news.tolist()] 181 | test_his_line = " ".join(test_his_news) 182 | for nindex in test_news: 183 | neg_cand = np.random.choice( 184 | len(news_title) + 1, size=neg_num, replace=False, p=p 185 | ).tolist() 186 | cand_news = " ".join( 187 | [f"{str(nindex)}-1"] + [f"{str(nindex)}-0" for nindex in neg_cand] 188 | ) 189 | 190 | test_behavior_line = f"null\t{uindex}\tnull\t{test_his_line}\t{cand_news}\n" 191 | test_lines.append(test_behavior_line) 192 | 193 | 194 | if __name__ == "__main__": 195 | args = parse_args() 196 | adressa_path = Path(args.adressa_path) 197 | out_path = Path(args.out_path) 198 | 199 | news_title, nid2index = process_news(adressa_path) 200 | write_news_files(news_title, nid2index, out_path) 201 | 202 | uid2index, user_info = process_users(adressa_path) 203 | for uid in tqdm(user_info): 204 | user_info[uid].sort_click() 205 | 206 | train_lines = [] 207 | test_lines = [] 208 | for uindex in tqdm(user_info): 209 | uinfo = user_info[uindex] 210 | click_news = uinfo.click_news 211 | train_news = uinfo.train_news 212 | test_news = uinfo.test_news 213 | construct_behaviors(uindex, click_news, train_news, test_news, args.neg_num) 214 | 215 | test_split_lines, valid_split_lines = train_test_split( 216 | test_lines, test_size=0.2, random_state=2021 217 | ) 218 | with open(out_path / "train" / "behaviors.tsv", "w", encoding="utf-8") as f: 219 | f.writelines(train_lines) 220 | 221 | with open(out_path / "valid" / "behaviors.tsv", "w", encoding="utf-8") as f: 222 | f.writelines(valid_split_lines) 223 | 224 | with 
open(out_path / "test" / "behaviors.tsv", "w", encoding="utf-8") as f: 225 | f.writelines(test_split_lines) 226 | -------------------------------------------------------------------------------- /preprocess/news_process.py: -------------------------------------------------------------------------------- 1 | from transformers import BertTokenizer 2 | from pathlib import Path 3 | from tqdm import tqdm 4 | import numpy as np 5 | import os 6 | import pickle 7 | import argparse 8 | # config 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument( 13 | "--raw_path", 14 | type=str, 15 | default="../raw/", 16 | help="path to raw mind dataset or parsed ", 17 | ) 18 | parser.add_argument( 19 | "--out_path", 20 | type=str, 21 | default="../data/", 22 | help="path to save processed dataset, default in ../raw/mind/preprocess", 23 | ) 24 | parser.add_argument( 25 | "--data", 26 | type=str, 27 | default="mind", 28 | choices=["mind", "adressa"], 29 | help="decide which dataset for preprocess" 30 | ) 31 | parser.add_argument( 32 | "--npratio", 33 | type=int, 34 | default=4 35 | ) 36 | parser.add_argument( 37 | "--max_his_len", type=int, default=50 38 | ) 39 | parser.add_argument("--min_word_cnt", type=int, default=3) 40 | parser.add_argument("--max_title_len", type=int, default=30) 41 | 42 | args = parser.parse_args() 43 | return args 44 | 45 | 46 | if __name__ == "__main__": 47 | args = parse_args() 48 | raw_path = Path(args.raw_path) / args.data 49 | out_path = Path(args.out_path) / args.data 50 | 51 | if not raw_path.is_dir(): 52 | raise ValueError(f"{raw_path.name} does not exist.") 53 | 54 | out_path.mkdir(exist_ok=True, parents=True) 55 | 56 | if args.data == "mind": 57 | model_type = "bert-base-uncased" 58 | else: 59 | model_type = "NbAiLab/nb-bert-base" 60 | 61 | tokenizer = BertTokenizer.from_pretrained(model_type) 62 | 63 | 64 | # news preprocess 65 | nid2index = {"": 0} 66 | news_index = [[[0] * args.max_title_len, [0] * args.max_title_len]] 67 | 68 | for l in tqdm(open(raw_path / "train" / "news.tsv", "r", encoding='utf-8')): 69 | nid, vert, subvert, title, abst, url, ten, aen = l.strip("\n").split("\t") 70 | if nid in nid2index: 71 | continue 72 | tokens = tokenizer( 73 | title, 74 | max_length=args.max_title_len, 75 | truncation=True, 76 | padding="max_length", 77 | return_attention_mask=True, 78 | ) 79 | nid2index[nid] = len(nid2index) 80 | news_index.append([tokens.input_ids, tokens.attention_mask]) 81 | 82 | 83 | for l in tqdm(open(raw_path / "valid" / "news.tsv", "r", encoding='utf-8')): 84 | nid, vert, subvert, title, abst, url, ten, aen = l.strip("\n").split("\t") 85 | if nid in nid2index: 86 | continue 87 | tokens = tokenizer( 88 | title, 89 | max_length=args.max_title_len, 90 | truncation=True, 91 | padding="max_length", 92 | return_attention_mask=True, 93 | ) 94 | nid2index[nid] = len(nid2index) 95 | news_index.append([tokens.input_ids, tokens.attention_mask]) 96 | 97 | with open(out_path / "bert_nid2index.pkl", "wb") as f: 98 | pickle.dump(nid2index, f) 99 | 100 | news_index = np.array(news_index) 101 | np.save(out_path / "bert_news_index", news_index) 102 | 103 | if os.path.exists(raw_path / "test"): 104 | nid2index = {"": 0} 105 | news_index = [[[0] * args.max_title_len, [0] * args.max_title_len]] 106 | 107 | for l in tqdm(open(raw_path / "test" / "news.tsv", "r", encoding='utf-8')): 108 | nid, vert, subvert, title, abst, url, ten, aen = l.strip("\n").split("\t") 109 | if nid in nid2index: 110 | continue 111 | tokens = tokenizer( 112 | title, 
113 | max_length=args.max_title_len, 114 | truncation=True, 115 | padding="max_length", 116 | return_attention_mask=True, 117 | ) 118 | nid2index[nid] = len(nid2index) 119 | news_index.append([tokens.input_ids, tokens.attention_mask]) 120 | 121 | with open(out_path / "bert_test_nid2index.pkl", "wb") as f: 122 | pickle.dump(nid2index, f) 123 | 124 | news_index = np.array(news_index) 125 | np.save(out_path / "bert_test_news_index", news_index) 126 | -------------------------------------------------------------------------------- /preprocess/user_process.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tqdm import tqdm 3 | from collections import defaultdict 4 | import os 5 | import pickle 6 | import argparse 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument( 11 | "--raw_path", 12 | type=str, 13 | default="../raw/", 14 | help="path to raw mind dataset or parsed ", 15 | ) 16 | parser.add_argument( 17 | "--out_path", 18 | type=str, 19 | default="../data/", 20 | help="path to save processed dataset, default in ../raw/mind/preprocess", 21 | ) 22 | parser.add_argument( 23 | "--data", 24 | type=str, 25 | default="mind", 26 | choices=["mind", "adressa"], 27 | help="decide which dataset for preprocess" 28 | ) 29 | parser.add_argument( 30 | "--npratio", 31 | type=int, 32 | default=4 33 | ) 34 | parser.add_argument( 35 | "--max_his_len", type=int, default=50 36 | ) 37 | parser.add_argument("--min_word_cnt", type=int, default=3) 38 | parser.add_argument("--max_title_len", type=int, default=30) 39 | 40 | args = parser.parse_args() 41 | return args 42 | 43 | 44 | if __name__ == "__main__": 45 | args = parse_args() 46 | 47 | raw_path = Path(args.raw_path) / args.data 48 | out_path = Path(args.out_path) / args.data 49 | 50 | user_imprs = defaultdict(list) 51 | 52 | # read user impressions 53 | for l in tqdm(open(raw_path / "train" / "behaviors.tsv", "r")): 54 | imp_id, uid, t, his, imprs = l.strip("\n").split("\t") 55 | his = his.split() 56 | imprs = [i.split("-") for i in imprs.split(" ")] 57 | neg_imp = [i[0] for i in imprs if i[1] == "0"] 58 | pos_imp = [i[0] for i in imprs if i[1] == "1"] 59 | user_imprs[uid].append([imp_id, his, pos_imp, neg_imp, 0, uid]) 60 | 61 | for l in tqdm(open(raw_path / "valid" / "behaviors.tsv", "r")): 62 | imp_id, uid, t, his, imprs = l.strip("\n").split("\t") 63 | his = his.split() 64 | imprs = [i.split("-") for i in imprs.split(" ")] 65 | neg_imp = [i[0] for i in imprs if i[1] == "0"] 66 | pos_imp = [i[0] for i in imprs if i[1] == "1"] 67 | user_imprs[uid].append([imp_id, his, pos_imp, neg_imp, 1, uid]) 68 | 69 | if os.path.exists(raw_path / "test"): 70 | if args.data == "adressa": 71 | for l in tqdm(open(raw_path / "test" / "behaviors.tsv", "r")): 72 | imp_id, uid, t, his, imprs = l.strip("\n").split("\t") 73 | his = his.split() 74 | imprs = [i.split("-") for i in imprs.split(" ")] 75 | neg_imp = [i[0] for i in imprs if i[1] == "0"] 76 | pos_imp = [i[0] for i in imprs if i[1] == "1"] 77 | user_imprs[uid].append([imp_id, his, pos_imp, neg_imp, 2, uid]) 78 | else: 79 | # MIND test dataset do not contains labels, need to test on condalab 80 | for l in tqdm(open(raw_path / "test" / "behaviors.tsv", "r")): 81 | imp_id, uid, t, his, imprs = l.strip("\n").split("\t") 82 | his = his.split() 83 | imprs = imprs.split(" ") 84 | user_imprs[uid].append([imp_id, his, imprs, [], 2, uid]) 85 | 86 | 87 | train_samples = [] 88 | valid_samples = [] 89 | test_samples = [] 90 | 
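# (comment added for clarity) The structures built below are what data.py and main.py consume:
#   train_samples[i]  = [imp_id, pos_nid, neg_nids, history, uid]   -- one entry per positive impression
#   valid_samples[i]  = [imp_id, pos_nids, neg_nids, history, uid]  -- one entry per behavior line
#   test_samples[i]   = [imp_id, pos_nids, neg_nids, history, uid]  -- for MIND, neg_nids is empty and
#                       pos_nids holds the unlabeled candidate list
#   user_indices[uid] = positions of that user's rows in train_samples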
user_indices = defaultdict(list) 91 | 92 | index = 0 93 | for uid in tqdm(user_imprs): 94 | for impr in user_imprs[uid]: 95 | imp_id, his, poss, negs, is_valid, uid = impr 96 | his = his[-args.max_his_len:] 97 | if is_valid == 0: 98 | for pos in poss: 99 | train_samples.append([imp_id, pos, negs, his, uid]) 100 | user_indices[uid].append(index) 101 | index += 1 102 | elif is_valid == 1: 103 | valid_samples.append([imp_id, poss, negs, his, uid]) 104 | else: 105 | test_samples.append([imp_id, poss, negs, his, uid]) 106 | 107 | print(len(train_samples), len(valid_samples), len(test_samples)) 108 | 109 | with open(out_path / "train_sam_uid.pkl", "wb") as f: 110 | pickle.dump(train_samples, f) 111 | 112 | with open(out_path / "valid_sam_uid.pkl", "wb") as f: 113 | pickle.dump(valid_samples, f) 114 | 115 | with open(out_path / "test_sam_uid.pkl", "wb") as f: 116 | pickle.dump(test_samples, f) 117 | 118 | with open(out_path / "user_indices.pkl", "wb") as f: 119 | pickle.dump(user_indices, f) 120 | 121 | train_user_samples = 0 122 | 123 | for uid in tqdm(user_indices): 124 | train_user_samples += len(user_indices[uid]) 125 | 126 | print(train_user_samples / len(user_indices)) 127 | -------------------------------------------------------------------------------- /raw/README.md: -------------------------------------------------------------------------------- 1 | # Raw Dataset 2 | 3 | This directory is used to save the raw dataset of [MIND-small](https://msnews.github.io/)[1] and [Adressa-1week](http://reclab.idi.ntnu.no/dataset/)[2]. 4 | 5 | 6 | ## Download 7 | Since there is no released test dataset for MIND-small, we use the test dataset of MIND-large for test. 8 | 9 | ```bash 10 | # download mind-small 11 | ./download.sh mind to/your/path/ 12 | 13 | # download adressa-1week 14 | ./download.sh adressa to/your/path/ 15 | ``` 16 | -------------------------------------------------------------------------------- /raw/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATASET=$1 4 | DIR=$2 5 | 6 | if [ $DATASET == "mind" ] 7 | then 8 | mkdir $DIR/mind 9 | mkdir $DIR/mind/train 10 | mkdir $DIR/mind/valid 11 | wget https://mind201910small.blob.core.windows.net/release/MINDsmall_train.zip -P $DIR/mind/train 12 | unzip $DIR/mind/train/MINDsmall_train.zip -d $DIR/mind/train/ 13 | rm $DIR/mind/train/MINDsmall_train.zip 14 | wget https://mind201910small.blob.core.windows.net/release/MINDsmall_dev.zip -P $DIR/mind/valid 15 | unzip $DIR/mind/valid/MINDsmall_dev.zip -d $DIR/mind/valid/ 16 | rm $DIR/mind/valid/MINDsmall_dev.zip 17 | wget https://mind201910small.blob.core.windows.net/release/MINDlarge_test.zip -P $DIR/mind/test 18 | unzip $DIR/mind/test/MINDlarge_test.zip -d $DIR/mind/test/ 19 | rm $DIR/mind/test/MINDlarge_test.zip 20 | elif [ $DATASET == "adressa" ] 21 | then 22 | mkdir $DIR/adressa 23 | mkdir $DIR/adressa/raw 24 | wget --no-check-certificate http://reclab.idi.ntnu.no/dataset/one_week.tar.gz -P $DIR/adressa/raw 25 | tar zvfx $DIR/adressa/raw/one_week.tar.gz -C $DIR/adressa/raw 26 | rm $DIR/adressa/raw/one_week.tar.gz 27 | fi 28 | -------------------------------------------------------------------------------- /src/agg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | from model import TextEncoder, UserEncoder 5 | import torch.optim as optim 6 | 7 | from data import NewsPartDataset 8 | 9 | class 
NewsUpdatorDataset(Dataset): 10 | def __init__(self, news_index, news_ids, news_grads): 11 | self.news_index = news_index 12 | self.news_grads = news_grads 13 | self.news_ids = news_ids 14 | 15 | def __len__(self): 16 | return len(self.news_ids) 17 | 18 | def __getitem__(self, idx): 19 | nid = self.news_ids[idx] 20 | return self.news_index[nid], self.news_grads[idx] 21 | 22 | 23 | class Aggregator: 24 | def __init__(self, args, news_dataset, news_index, device): 25 | self.device = device 26 | 27 | self.text_encoder = TextEncoder(bert_type=args.bert_type).to(device) 28 | self.user_encoder = UserEncoder().to(device) 29 | 30 | self.news_optimizer = optim.Adam(self.text_encoder.parameters(), lr=args.news_lr) 31 | self.user_optimizer = optim.Adam(self.user_encoder.parameters(), lr=args.user_lr) 32 | 33 | for param in self.text_encoder.bert.parameters(): 34 | param.requires_grad = False 35 | 36 | for index, layer in enumerate(self.text_encoder.bert.encoder.layer): 37 | if index in args.trainable_layers: 38 | for param in layer.parameters(): 39 | param.requires_grad = True 40 | 41 | if -1 in args.trainable_layers: 42 | for param in self.text_encoder.bert.embeddings.parameters(): 43 | param.requires_grad = True 44 | 45 | self.news_dataset = news_dataset 46 | self.news_index = news_index 47 | 48 | self.time = 0 49 | self.cnt = 0 50 | 51 | self._init_grad_param() 52 | 53 | def _init_grad_param(self): 54 | self.news_grads = {} 55 | self.user_optimizer.zero_grad() 56 | self.news_optimizer.zero_grad() 57 | 58 | def gen_news_vecs(self, nids): 59 | self.text_encoder.eval() 60 | news_ds = NewsPartDataset(self.news_index, nids) 61 | news_dl = DataLoader(news_ds, batch_size=2048, shuffle=False, num_workers=0) 62 | news_vecs = np.zeros((len(self.news_index), 400), dtype='float32') 63 | with torch.no_grad(): 64 | for nids, news in news_dl: 65 | news = news.to(self.device) 66 | news_vec = self.text_encoder(news).detach().cpu().numpy() 67 | news_vecs[nids.numpy()] = news_vec 68 | if np.isnan(news_vecs).any(): 69 | raise ValueError("news_vecs contains nan") 70 | self.news_vecs = news_vecs 71 | return news_vecs 72 | 73 | def get_news_vecs(self, idx): 74 | return self.news_vecs[idx] 75 | 76 | def update(self): 77 | self.update_user_grad() 78 | self.update_news_grad() 79 | self._init_grad_param() 80 | self.cnt += 1 81 | 82 | def average_update_time(self): 83 | return self.time / self.cnt 84 | 85 | def update_news_grad(self): 86 | self.text_encoder.train() 87 | self.news_optimizer.zero_grad() 88 | 89 | news_ids, news_grads = [], [] 90 | for nid in self.news_grads: 91 | news_ids.append(nid) 92 | news_grads.append(self.news_grads[nid]) 93 | 94 | news_up_ds = NewsUpdatorDataset(self.news_index, news_ids, news_grads) 95 | news_up_dl = DataLoader(news_up_ds, batch_size=128, shuffle=False, num_workers=0) 96 | for news_index, news_grad in news_up_dl: 97 | news_index = news_index.to(self.device) 98 | news_grad = news_grad.to(self.device) 99 | news_vecs = self.text_encoder(news_index) 100 | news_vecs.backward(news_grad) 101 | 102 | self.news_optimizer.step() 103 | self.news_optimizer.zero_grad() 104 | 105 | def update_user_grad(self): 106 | self.user_optimizer.step() 107 | self.user_optimizer.zero_grad() 108 | 109 | def check_news_vec_same(self, nids, news_vecs): 110 | assert (self.get_news_vecs(nids) == news_vecs).all(), "News vecs are not the same" 111 | 112 | def collect(self, news_grad, user_grad): 113 | # update user model params 114 | for name, param in self.user_encoder.named_parameters(): 115 | if param.grad is None: 
116 | param.grad = user_grad[name] 117 | else: 118 | param.grad += user_grad[name] 119 | 120 | # update news model params 121 | for nid in news_grad: 122 | if nid in self.news_grads: 123 | self.news_grads[nid] += news_grad[nid] 124 | else: 125 | self.news_grads[nid] = news_grad[nid] -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from torch.utils.data import Dataset, DataLoader 4 | 5 | 6 | def newsample(nnn, ratio): 7 | if ratio > len(nnn): 8 | return nnn + [""] * (ratio - len(nnn)) 9 | else: 10 | return random.sample(nnn, ratio) 11 | 12 | class TrainDataset(Dataset): 13 | def __init__(self, args, samples, users, user_indices, nid2index, agg, news_index): 14 | self.news_index = news_index 15 | self.nid2index = nid2index 16 | self.agg = agg 17 | self.samples = [] 18 | self.args = args 19 | 20 | for user in users: 21 | self.samples.extend([samples[i] for i in user_indices[user]]) 22 | 23 | def __len__(self): 24 | return len(self.samples) 25 | 26 | def __getitem__(self, idx): 27 | # pos, neg, his, neg_his 28 | _, pos, neg, his, _ = self.samples[idx] 29 | neg = newsample(neg, self.args.npratio) 30 | candidate_news = np.array([self.nid2index[n] for n in [pos] + neg]) 31 | candidate_news_vecs = self.agg.get_news_vecs(candidate_news) 32 | his = np.array([self.nid2index[n] for n in his] + [0] * (self.args.max_his_len - len(his))) 33 | his_vecs = self.agg.get_news_vecs(his) 34 | label = np.array(0) 35 | 36 | return candidate_news, candidate_news_vecs, his, his_vecs, label 37 | 38 | 39 | class NewsDataset(Dataset): 40 | def __init__(self, news_index): 41 | self.news_index = news_index 42 | 43 | def __len__(self): 44 | return len(self.news_index) 45 | 46 | def __getitem__(self, idx): 47 | return self.news_index[idx] 48 | 49 | 50 | class NewsPartDataset(Dataset): 51 | def __init__(self, news_index, nids): 52 | self.news_index = news_index 53 | self.nids = nids 54 | 55 | def __len__(self): 56 | return len(self.nids) 57 | 58 | def __getitem__(self, idx): 59 | nid = self.nids[idx] 60 | return nid, self.news_index[nid] 61 | 62 | 63 | class UserDataset(Dataset): 64 | def __init__(self, 65 | args, 66 | samples, 67 | news_vecs, 68 | nid2index): 69 | self.samples = samples 70 | self.args = args 71 | self.news_vecs = news_vecs 72 | self.nid2index = nid2index 73 | 74 | def __len__(self): 75 | return len(self.samples) 76 | 77 | def __getitem__(self, idx): 78 | _, poss, negs, his, _ = self.samples[idx] 79 | his = [self.nid2index[n] for n in his] + [0] * (self.args.max_his_len - len(his)) 80 | his = self.news_vecs[his] 81 | return his 82 | 83 | 84 | class NewsUpdatorDataset(Dataset): 85 | def __init__(self, news_index, news_ids, news_grads): 86 | self.news_index = news_index 87 | self.news_grads = news_grads 88 | self.news_ids = news_ids 89 | 90 | def __len__(self): 91 | return len(self.news_ids) 92 | 93 | def __getitem__(self, idx): 94 | nid = self.news_ids[idx] 95 | return self.news_index[nid], self.news_grads[idx] -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | from tqdm import tqdm 5 | import random 6 | import wandb 7 | import numpy as np 8 | import os 9 | import pickle 10 | 11 | import torch 12 | from torch.utils.data import DataLoader 13 | import 
torch.optim as optim 14 | 15 | from agg import Aggregator 16 | from model import Model, TextEncoder, UserEncoder 17 | from data import TrainDataset, NewsDataset, UserDataset 18 | from metrics import evaluation_split 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--wandb_entity", type=str) 24 | parser.add_argument( 25 | "--mode", type=str, default="train", choices=["train", "test", "predict"] 26 | ) 27 | parser.add_argument( 28 | "--data_path", 29 | type=str, 30 | default=os.getenv("AMLT_DATA_DIR", "../data"), 31 | help="path to downloaded raw adressa dataset", 32 | ) 33 | parser.add_argument( 34 | "--out_path", 35 | type=str, 36 | default=os.getenv("AMLT_OUTPUT_DIR", "../output"), 37 | help="path to downloaded raw adressa dataset", 38 | ) 39 | parser.add_argument( 40 | "--data", 41 | type=str, 42 | default="mind", 43 | choices=["mind", "adressa"], 44 | help="decide which dataset for preprocess", 45 | ) 46 | parser.add_argument("--bert_type", type=str, default="bert-base-uncased") 47 | parser.add_argument( 48 | "--trainable_layers", type=int, nargs="+", default=[6, 7, 8, 9, 10, 11] 49 | ) 50 | parser.add_argument("--user_lr", type=float, default=0.00005) 51 | parser.add_argument("--news_lr", type=float, default=0.00005) 52 | parser.add_argument("--user_num", type=int, default=50) 53 | parser.add_argument("--max_his_len", type=float, default=50) 54 | parser.add_argument( 55 | "--npratio", 56 | type=int, 57 | default=20, 58 | help="randomly sample neg_num negative impression for every positive behavior", 59 | ) 60 | parser.add_argument("--max_train_steps", type=int, default=2000) 61 | parser.add_argument("--validation_steps", type=int, default=100) 62 | parser.add_argument("--name", type=str, default="efficient-fedrec") 63 | 64 | args = parser.parse_args() 65 | return args 66 | 67 | 68 | def process_news_grad(candidate_info, his_info): 69 | news_grad = {} 70 | candidate_news, candidate_vecs, candidate_grad = candidate_info 71 | his, his_vecs, his_grad = his_info 72 | 73 | candidate_news, candaidate_grad = ( 74 | candidate_news.reshape(-1,), 75 | candidate_grad.reshape(-1, 400), 76 | ) 77 | his, his_grad = his.reshape(-1,), his_grad.reshape(-1, 400) 78 | 79 | for nid, grad in zip(his, his_grad): 80 | if nid in news_grad: 81 | news_grad[nid] += grad 82 | else: 83 | news_grad[nid] = grad 84 | 85 | for nid, grad in zip(candidate_news, candaidate_grad): 86 | if nid in news_grad: 87 | news_grad[nid] += grad 88 | else: 89 | news_grad[nid] = grad 90 | return news_grad 91 | 92 | 93 | def process_user_grad(model_param, sample_num, user_sample): 94 | user_grad = {} 95 | for name, param in model_param: 96 | user_grad[name] = param.grad * (sample_num / user_sample) 97 | return user_grad 98 | 99 | 100 | def collect_users_nids(train_sam, users, user_indices, nid2index): 101 | user_nids = [0] 102 | user_sample = 0 103 | for user in users: 104 | sids = user_indices[user] 105 | user_sample += len(sids) 106 | for idx in sids: 107 | _, pos, neg, his, _ = train_sam[idx] 108 | user_nids.extend([nid2index[i] for i in list(set([pos] + neg + his))]) 109 | return list(set(user_nids)), user_sample 110 | 111 | 112 | def train_on_step( 113 | agg, model, args, user_indices, user_num, train_sam, nid2index, news_index, device 114 | ): 115 | # sample users 116 | users = random.sample(user_indices.keys(), user_num) 117 | nids, user_sample = collect_users_nids(train_sam, users, user_indices, nid2index) 118 | 119 | agg.gen_news_vecs(nids) 120 | train_ds = TrainDataset( 121 | 
args, train_sam, users, user_indices, nid2index, agg, news_index 122 | ) 123 | train_dl = DataLoader(train_ds, batch_size=16384, shuffle=True, num_workers=0) 124 | model.train() 125 | loss = 0 126 | 127 | for cnt, batch_sample in enumerate(train_dl): 128 | model.user_encoder.load_state_dict(agg.user_encoder.state_dict()) 129 | optimizer = optim.SGD(model.parameters(), lr=args.user_lr) 130 | 131 | candidate_news, candidate_news_vecs, his, his_vecs, label = batch_sample 132 | candidate_news_vecs = candidate_news_vecs.to(device) 133 | his_vecs = his_vecs.to(device) 134 | sample_num = his_vecs.shape[0] 135 | 136 | label = label.to(device) 137 | 138 | # compute gradients for user model and news representations 139 | candidate_news_vecs.requires_grad = True 140 | his_vecs.requires_grad = True 141 | bz_loss, y_hat = model(candidate_news_vecs, his_vecs, label) 142 | loss += bz_loss.detach().cpu().numpy() 143 | 144 | optimizer.zero_grad() 145 | bz_loss.backward() 146 | 147 | candaidate_grad = candidate_news_vecs.grad.detach().cpu() * ( 148 | sample_num / user_sample 149 | ) 150 | candidate_vecs = candidate_news_vecs.detach().cpu().numpy() 151 | candidate_news = candidate_news.numpy() 152 | 153 | his_grad = his_vecs.grad.detach().cpu() * (sample_num / user_sample) 154 | his_vecs = his_vecs.detach().cpu().numpy() 155 | his = his.numpy() 156 | 157 | news_grad = process_news_grad( 158 | [candidate_news, candidate_vecs, candaidate_grad], [his, his_vecs, his_grad] 159 | ) 160 | user_grad = process_user_grad( 161 | model.user_encoder.named_parameters(), sample_num, user_sample 162 | ) 163 | agg.collect(news_grad, user_grad) 164 | 165 | loss = loss / (cnt + 1) 166 | agg.update() 167 | return loss 168 | 169 | 170 | def validate(args, agg, valid_sam, nid2index, news_index, device): 171 | agg.gen_news_vecs(list(range(len(news_index)))) 172 | agg.user_encoder.eval() 173 | user_dataset = UserDataset(args, valid_sam, agg.news_vecs, nid2index) 174 | user_vecs = [] 175 | user_dl = DataLoader(user_dataset, batch_size=4096, shuffle=False, num_workers=0) 176 | with torch.no_grad(): 177 | for his in tqdm(user_dl): 178 | his = his.to(device) 179 | user_vec = agg.user_encoder(his).detach().cpu().numpy() 180 | user_vecs.append(user_vec) 181 | user_vecs = np.concatenate(user_vecs) 182 | 183 | val_scores = evaluation_split(agg.news_vecs, user_vecs, valid_sam, nid2index) 184 | val_auc, val_mrr, val_ndcg, val_ndcg10 = [ 185 | np.mean(i) for i in list(zip(*val_scores)) 186 | ] 187 | 188 | return val_auc, val_mrr, val_ndcg, val_ndcg10 189 | 190 | 191 | def test(args, data_path, out_model_path, out_path, device): 192 | with open(data_path / "test_sam_uid.pkl", "rb") as f: 193 | test_sam = pickle.load(f) 194 | 195 | with open(data_path / "bert_test_nid2index.pkl", "rb") as f: 196 | test_nid2index = pickle.load(f) 197 | 198 | test_news_index = np.load(data_path / "bert_test_news_index.npy", allow_pickle=True) 199 | 200 | text_encoder = TextEncoder(bert_type=args.bert_type).to(device) 201 | user_encoder = UserEncoder().to(device) 202 | ckpt = torch.load(out_model_path / f"{args.name}-{args.data}.pkl") 203 | text_encoder.load_state_dict(ckpt["text_encoder"]) 204 | user_encoder.load_state_dict(ckpt["user_encoder"]) 205 | 206 | test_news_dataset = NewsDataset(test_news_index) 207 | news_dl = DataLoader( 208 | test_news_dataset, batch_size=512, shuffle=False, num_workers=0 209 | ) 210 | news_vecs = [] 211 | text_encoder.eval() 212 | for news in tqdm(news_dl): 213 | news = news.to(device) 214 | news_vec = 
text_encoder(news).detach().cpu().numpy() 215 | news_vecs.append(news_vec) 216 | news_vecs = np.concatenate(news_vecs) 217 | 218 | user_dataset = UserDataset(args, test_sam, news_vecs, test_nid2index) 219 | user_vecs = [] 220 | user_dl = DataLoader(user_dataset, batch_size=4096, shuffle=False, num_workers=0) 221 | user_encoder.eval() 222 | for his in tqdm(user_dl): 223 | his = his.to(device) 224 | user_vec = user_encoder(his).detach().cpu().numpy() 225 | user_vecs.append(user_vec) 226 | user_vecs = np.concatenate(user_vecs) 227 | 228 | test_scores = evaluation_split(news_vecs, user_vecs, test_sam, test_nid2index) 229 | test_auc, test_mrr, test_ndcg, test_ndcg10 = [ 230 | np.mean(i) for i in list(zip(*test_scores)) 231 | ] 232 | 233 | with open(out_path / f"log.txt", "a") as f: 234 | f.write( 235 | f"test auc: {test_auc:.4f}, mrr: {test_mrr:.4f}, ndcg5: {test_ndcg:.4f}, ndcg10: {test_ndcg10:.4f}\n" 236 | ) 237 | 238 | 239 | def predict(args, data_path, out_model_path, out_path, device): 240 | with open(data_path / "test_sam_uid.pkl", "rb") as f: 241 | test_sam = pickle.load(f) 242 | 243 | with open(data_path / "bert_test_nid2index.pkl", "rb") as f: 244 | test_nid2index = pickle.load(f) 245 | 246 | test_news_index = np.load(data_path / "bert_test_news_index.npy", allow_pickle=True) 247 | 248 | text_encoder = TextEncoder(bert_type=args.bert_type).to(device) 249 | user_encoder = UserEncoder().to(device) 250 | ckpt = torch.load(out_model_path / f"{args.name}-{args.data}.pkl") 251 | text_encoder.load_state_dict(ckpt["text_encoder"]) 252 | user_encoder.load_state_dict(ckpt["user_encoder"]) 253 | 254 | test_news_dataset = NewsDataset(test_news_index) 255 | news_dl = DataLoader( 256 | test_news_dataset, batch_size=512, shuffle=False, num_workers=0 257 | ) 258 | news_vecs = [] 259 | text_encoder.eval() 260 | for news in tqdm(news_dl): 261 | news = news.to(device) 262 | news_vec = text_encoder(news).detach().cpu().numpy() 263 | news_vecs.append(news_vec) 264 | news_vecs = np.concatenate(news_vecs) 265 | 266 | user_dataset = UserDataset(args, test_sam, news_vecs, test_nid2index) 267 | user_vecs = [] 268 | user_dl = DataLoader(user_dataset, batch_size=4096, shuffle=False, num_workers=0) 269 | user_encoder.eval() 270 | for his in tqdm(user_dl): 271 | his = his.to(device) 272 | user_vec = user_encoder(his).detach().cpu().numpy() 273 | user_vecs.append(user_vec) 274 | user_vecs = np.concatenate(user_vecs) 275 | 276 | pred_lines = [] 277 | for i in tqdm(range(len(test_sam))): 278 | impr_id, poss, negs, _, _ = test_sam[i] 279 | user_vec = user_vecs[i] 280 | news_ids = [test_nid2index[i] for i in poss + negs] 281 | news_vec = news_vecs[news_ids] 282 | y_score = np.multiply(news_vec, user_vec) 283 | y_score = np.sum(y_score, axis=1) 284 | 285 | pred_rank = (np.argsort(np.argsort(y_score)[::-1]) + 1).tolist() 286 | pred_rank = '[' + ','.join([str(i) for i in pred_rank]) + ']' 287 | pred_lines.append((int(impr_id), ' '.join([impr_id, pred_rank])+ '\n')) 288 | 289 | pred_lines.sort(key=lambda x: x[0]) 290 | pred_lines = [x[1] for x in pred_lines] 291 | with open(out_path / 'prediction.txt', 'w') as f: 292 | f.writelines(pred_lines) 293 | 294 | 295 | if __name__ == "__main__": 296 | args = parse_args() 297 | 298 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 299 | device = torch.device("cuda:0") 300 | torch.cuda.set_device(device) 301 | 302 | if args.mode == "train": 303 | wandb.init( 304 | project=f"{args.name}-{args.data}", config=args, entity=args.wandb_entity 305 | ) 306 | 307 | data_path = Path(args.data_path) / 
args.data 308 | out_path = Path(args.out_path) / f"{args.name}-{args.data}" 309 | out_model_path = out_path / "model" 310 | 311 | out_model_path.mkdir(exist_ok=True, parents=True) 312 | 313 | # load preprocessed data 314 | with open(data_path / "bert_nid2index.pkl", "rb") as f: 315 | nid2index = pickle.load(f) 316 | 317 | news_index = np.load(data_path / "bert_news_index.npy", allow_pickle=True) 318 | 319 | with open(data_path / "train_sam_uid.pkl", "rb") as f: 320 | train_sam = pickle.load(f) 321 | 322 | with open(data_path / "valid_sam_uid.pkl", "rb") as f: 323 | valid_sam = pickle.load(f) 324 | 325 | with open(data_path / "user_indices.pkl", "rb") as f: 326 | user_indices = pickle.load(f) 327 | 328 | news_dataset = NewsDataset(news_index) 329 | 330 | agg = Aggregator(args, news_dataset, news_index, device) 331 | model = Model().to(device) 332 | best_auc = 0 333 | for step in range(args.max_train_steps): 334 | loss = train_on_step( 335 | agg, 336 | model, 337 | args, 338 | user_indices, 339 | args.user_num, 340 | train_sam, 341 | nid2index, 342 | news_index, 343 | device, 344 | ) 345 | 346 | wandb.log({"train loss": loss}, step=step + 1) 347 | 348 | if (step + 1) % args.validation_steps == 0: 349 | val_auc, val_mrr, val_ndcg, val_ndcg10 = validate( 350 | args, agg, valid_sam, nid2index, news_index, device 351 | ) 352 | 353 | wandb.log( 354 | { 355 | "valid auc": val_auc, 356 | "valid mrr": val_mrr, 357 | "valid ndcg@5": val_ndcg, 358 | "valid ndcg@10": val_ndcg10, 359 | }, 360 | step=step + 1, 361 | ) 362 | 363 | with open(out_path / f"log.txt", "a") as f: 364 | f.write( 365 | f"[{step}] round auc: {val_auc:.4f}, mrr: {val_mrr:.4f}, ndcg5: {val_ndcg:.4f}, ndcg10: {val_ndcg10:.4f}\n" 366 | ) 367 | 368 | if val_auc > best_auc: 369 | best_auc = val_auc 370 | wandb.run.summary["best_auc"] = best_auc 371 | torch.save( 372 | { 373 | "text_encoder": agg.text_encoder.state_dict(), 374 | "user_encoder": agg.user_encoder.state_dict(), 375 | }, 376 | out_model_path / f"{args.name}-{args.data}.pkl", 377 | ) 378 | 379 | with open(out_path / f"log.txt", "a") as f: 380 | f.write(f"[{step}] round save model\n") 381 | 382 | elif args.mode == "test": 383 | data_path = Path(args.data_path) / args.data 384 | out_path = Path(args.out_path) / f"{args.name}-{args.data}" 385 | out_model_path = out_path / "model" 386 | test(args, data_path, out_model_path, out_path, device) 387 | 388 | elif args.mode == "predict": 389 | data_path = Path(args.data_path) / args.data 390 | out_path = Path(args.out_path) / f"{args.name}-{args.data}" 391 | out_model_path = out_path / "model" 392 | predict(args, data_path, out_model_path, out_path, device) 393 | 394 | -------------------------------------------------------------------------------- /src/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | from sklearn.metrics import roc_auc_score 4 | 5 | def dcg_score(y_true, y_score, k=10): 6 | order = np.argsort(y_score)[::-1] 7 | y_true = np.take(y_true, order[:k]) 8 | gains = 2 ** y_true - 1 9 | discounts = np.log2(np.arange(len(y_true)) + 2) 10 | return np.sum(gains / discounts) 11 | 12 | 13 | def ndcg_score(y_true, y_score, k=10): 14 | best = dcg_score(y_true, y_true, k) 15 | actual = dcg_score(y_true, y_score, k) 16 | return actual / best 17 | 18 | 19 | def mrr_score(y_true, y_score): 20 | order = np.argsort(y_score)[::-1] 21 | y_true = np.take(y_true, order) 22 | rr_score = y_true / (np.arange(len(y_true)) + 1) 23 | return 
np.sum(rr_score) / np.sum(y_true) 24 | 25 | 26 | def compute_amn(y_true, y_score): 27 | auc = roc_auc_score(y_true,y_score) 28 | mrr = mrr_score(y_true,y_score) 29 | ndcg5 = ndcg_score(y_true,y_score,5) 30 | ndcg10 = ndcg_score(y_true,y_score,10) 31 | return auc, mrr, ndcg5, ndcg10 32 | 33 | def evaluation_split(news_vecs, user_vecs, samples, nid2index): 34 | all_rslt = [] 35 | for i in tqdm(range(len(samples))): 36 | _, poss, negs, _, _ = samples[i] 37 | user_vec = user_vecs[i] 38 | y_true = [1] * len(poss) + [0] * len(negs) 39 | news_ids = [nid2index[i] for i in poss + negs] 40 | news_vec = news_vecs[news_ids] 41 | y_score = np.multiply(news_vec, user_vec) 42 | y_score = np.sum(y_score, axis=1) 43 | try: 44 | all_rslt.append(compute_amn(y_true, y_score)) 45 | except Exception as e: 46 | print(e) 47 | return np.array(all_rslt) -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from transformers import BertModel 5 | 6 | import numpy as np 7 | 8 | class ScaledDotProductAttention(nn.Module): 9 | def __init__(self, d_k): 10 | super(ScaledDotProductAttention, self).__init__() 11 | self.d_k = d_k 12 | 13 | def forward(self, Q, K, V, attn_mask=None): 14 | scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k) 15 | scores = torch.exp(scores) 16 | if attn_mask is not None: 17 | scores = scores * attn_mask 18 | attn = scores / (torch.sum(scores, dim=-1, keepdim=True) + 1e-8) 19 | 20 | context = torch.matmul(attn, V) 21 | return context, attn 22 | 23 | class MultiHeadAttention(nn.Module): 24 | def __init__(self, d_model, n_heads, d_k, d_v): 25 | super(MultiHeadAttention, self).__init__() 26 | self.d_model = d_model # 300 27 | self.n_heads = n_heads # 20 28 | self.d_k = d_k # 20 29 | self.d_v = d_v # 20 30 | 31 | self.W_Q = nn.Linear(d_model, d_k * n_heads) # 300, 400 32 | self.W_K = nn.Linear(d_model, d_k * n_heads) # 300, 400 33 | self.W_V = nn.Linear(d_model, d_v * n_heads) # 300, 400 34 | 35 | self._initialize_weights() 36 | 37 | def _initialize_weights(self): 38 | for m in self.modules(): 39 | if isinstance(m, nn.Linear): 40 | nn.init.xavier_uniform_(m.weight, gain=1) 41 | 42 | def forward(self, Q, K, V, attn_mask=None): 43 | batch_size, seq_len, _ = Q.size() 44 | 45 | q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2) 46 | k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2) 47 | v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1,2) 48 | 49 | if attn_mask is not None: 50 | attn_mask = attn_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len) 51 | attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1) 52 | 53 | context, attn = ScaledDotProductAttention(self.d_k)(q_s, k_s, v_s, attn_mask) 54 | context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v) 55 | return context 56 | 57 | 58 | class AdditiveAttention(nn.Module): 59 | def __init__(self, d_h, hidden_size=200): 60 | super(AdditiveAttention, self).__init__() 61 | self.att_fc1 = nn.Linear(d_h, hidden_size) 62 | self.att_fc2 = nn.Linear(hidden_size, 1) 63 | 64 | def forward(self, x, attn_mask=None): 65 | bz = x.shape[0] 66 | e = self.att_fc1(x) 67 | e = nn.Tanh()(e) 68 | alpha = self.att_fc2(e) 69 | 70 | alpha = torch.exp(alpha) 71 | if attn_mask is not None: 72 | alpha = alpha * attn_mask.unsqueeze(2) 73 | 
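# (comment added for clarity) exp -> mask -> renormalize below is a manual masked softmax
# over the sequence dimension; the 1e-8 keeps the division finite if every position is masked.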
alpha = alpha / (torch.sum(alpha, dim=1, keepdim=True) + 1e-8) 74 | 75 | x = torch.bmm(x.permute(0, 2, 1), alpha) 76 | x = torch.reshape(x, (bz, -1)) # (bz, 400) 77 | return x 78 | 79 | class TextEncoder(nn.Module): 80 | def __init__(self, 81 | bert_type="bert-base-uncased", 82 | word_embedding_dim=400, 83 | dropout_rate=0.2, 84 | enable_gpu=True): 85 | super(TextEncoder, self).__init__() 86 | self.dropout_rate = 0.2 87 | self.bert = BertModel.from_pretrained(bert_type, 88 | hidden_dropout_prob=0, 89 | attention_probs_dropout_prob=0) 90 | self.additive_attention = AdditiveAttention(self.bert.config.hidden_size, 91 | self.bert.config.hidden_size//2) 92 | self.fc = nn.Linear(self.bert.config.hidden_size, word_embedding_dim) 93 | 94 | def forward(self, text): 95 | # text batch, 2, word 96 | tokens = text[:,0,:] 97 | atts = text[:,1,:] 98 | text_vector = self.bert(tokens, attention_mask=atts)[0] 99 | text_vector = self.additive_attention(text_vector) 100 | text_vector = self.fc(text_vector) 101 | return text_vector 102 | 103 | 104 | class UserEncoder(nn.Module): 105 | def __init__(self, 106 | news_embedding_dim=400, 107 | num_attention_heads=20, 108 | query_vector_dim=200 109 | ): 110 | super(UserEncoder, self).__init__() 111 | self.dropout_rate = 0.2 112 | self.multihead_attention = MultiHeadAttention(news_embedding_dim, 113 | num_attention_heads, 20, 20) 114 | self.additive_attention = AdditiveAttention(news_embedding_dim, 115 | query_vector_dim) 116 | 117 | def forward(self, clicked_news_vecs): 118 | clicked_news_vecs = F.dropout(clicked_news_vecs, p=self.dropout_rate, training=self.training) 119 | multi_clicked_vectors = self.multihead_attention( 120 | clicked_news_vecs, clicked_news_vecs, clicked_news_vecs 121 | ) 122 | pos_user_vector = self.additive_attention(multi_clicked_vectors) 123 | 124 | user_vector = pos_user_vector 125 | return user_vector 126 | 127 | class Model(nn.Module): 128 | def __init__(self): 129 | super(Model, self).__init__() 130 | self.user_encoder = UserEncoder() 131 | 132 | self.criterion = nn.CrossEntropyLoss() 133 | 134 | def forward(self, candidate_vecs, clicked_news_vecs, targets, compute_loss=True): 135 | user_vector = self.user_encoder(clicked_news_vecs) 136 | 137 | score = torch.bmm(candidate_vecs, user_vector.unsqueeze(-1)).squeeze(dim=-1) 138 | 139 | if compute_loss: 140 | loss = self.criterion(score, targets) 141 | return loss, score 142 | else: 143 | return score --------------------------------------------------------------------------------