├── .gitignore ├── LICENSE ├── README.md ├── data ├── beauty │ ├── get_item_embedding.ipynb │ └── get_user_embedding.ipynb ├── convert_inter.ipynb ├── data_process.py ├── fashion │ ├── get_item_embedding.ipynb │ └── get_user_embedding.ipynb ├── pca.ipynb ├── retrieval_users.ipynb └── yelp │ ├── get_item_embedding.ipynb │ └── get_user_embedding.ipynb ├── environment.yml ├── experiments ├── beauty.bash ├── fashion.bash └── yelp.bash ├── generators ├── bert_generator.py ├── data.py └── generator.py ├── main.py ├── models ├── BaseModel.py ├── Bert4Rec.py ├── DualLLMSRS.py ├── GRU4Rec.py ├── LLMESR.py ├── SASRec.py └── utils.py ├── requirements.txt ├── trainers ├── sequence_trainer.py └── trainer.py └── utils ├── earlystop.py ├── logger.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | log/ 3 | saved/ 4 | *.txt 5 | *.pkl 6 | !requirements.txt -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Large Language Models Enhanced Sequential Recommendation for Long-tail User and Item 2 | 3 | This is the implementation of the paper "Large Language Models Enhanced Sequential Recommendation for Long-tail User and Item". 4 | 5 | ## Configure the environment 6 | 7 | To ease the configuration of the environment, I list versions of my hardware and software equipments: 8 | 9 | - Hardware: 10 | - GPU: Tesla V100 32GB 11 | - Cuda: 10.2 12 | - Driver version: 440.95.01 13 | - CPU: Intel Xeon Gold 6133 14 | - Software: 15 | - Python: 3.9.5 16 | - Pytorch: 1.12.0+cu102 17 | 18 | You can conda install the `environment.yml` or pip install the `requirements.txt` to configure the environment. 19 | 20 | ## Preprocess the dataset 21 | 22 | You can preprocess the dataset and get the LLMs embedding according to the following steps: 23 | 24 | 1. The raw dataset downloaded from website should be put into `/data//raw/`. The Yelp dataset can be obtained from [https://www.yelp.com/dataset](https://www.yelp.com/dataset). The fashion and beauty datasets can be obtained from [https://cseweb.ucsd.edu/~jmcauley/datasets.html\#amazon_reviews](https://cseweb.ucsd.edu/~jmcauley/datasets.html\#amazon_reviews). 25 | 2. Conduct the preprocessing code `data/data_process.py` to filter cold-start users and items. After the procedure, you will get the id file `/data//hdanled/id_map.json` and the interaction file `/data//handled/inter_seq.txt`. 26 | 3. Convert the interaction file to the format used in this repo by running `data/convert_inter.ipynb`. 27 | 4. To get the LLMs embedding for each dataset, please run the jupyter notebooks `/data//get_item_embedding.ipynb` and `/data//get_user_embedding.ipynb`. After the running, you will get the LLMs item embedding file `/data//handled/itm_emb_np.pkl` and LLMs user embedding file `/data//handled/usr_emb_np.pkl`. 28 | 5. For dual-view modeling module, we need to run the jupyter notebook `data/pca.ipynb` to get the dimension-reduced LLMs item embedding for initialization, i.e., `/data//handled/pca64_itm_emb_np.pkl`. 29 | 6. For retrieval augmented self-distillation, we need to run the jupyter notebook `data/retrieval_users.ipynb` to get the similar user set for each user. The output file in this step is `sim_user_100.pkl` 30 | 31 | In conclusion, the prerequisite files to run the code are as follows: `inter.txt`, `itm_emb_np.pkl`, `usr_emb_np.pkl`, `pca64_itm_emb_np.pkl` and `sim_user_100.pkl`. 32 | 33 | ⭐️ To ease the reproducibility of our paper, we also upload all preprocessed files to this [link](https://drive.google.com/file/d/1MpBUjCDLiFIEODTnopSCzDAnS8RzO9aV/view?usp=sharing). 34 | 35 | ## Run and test 36 | 37 | 1. You can reproduce all LLM-ESR experiments by running the bash as follows: 38 | 39 | ``` 40 | bash experiments/yelp.bash 41 | bash experiments/fashion.bash 42 | bash experiments/beauty.bash 43 | ``` 44 | 45 | 2. The log and results will be saved in the folder `log/`. The checkpoint will be saved in the folder `saved/`. 46 | 47 | ## Citation 48 | 49 | If the code and the paper are useful for you, it is appreciable to cite our paper: 50 | 51 | ``` 52 | @article{liu2024large, 53 | title={Large Language Models Enhanced Sequential Recommendation for Long-tail User and Item}, 54 | author={Liu, Qidong and Wu, Xian and Zhao, Xiangyu and Wang, Yejing and Zhang, Zijian and Tian, Feng and Zheng, Yefeng}, 55 | journal={arXiv preprint arXiv:2405.20646}, 56 | year={2024} 57 | } 58 | ``` 59 | 60 | ## Thanks 61 | 62 | The code refers to the repo [SASRec](https://github.com/kang205/SASRec) and [RLMRec](https://github.com/HKUDS/RLMRec). 63 | -------------------------------------------------------------------------------- /data/beauty/get_item_embedding.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import pickle\n", 11 | "import jsonlines\n", 12 | "import pandas as pd\n", 13 | "import numpy as np\n", 14 | "import json\n", 15 | "import copy\n", 16 | "from tqdm import tqdm" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "prompt_template = \"The beauty item has following attributes: \\n name is ; brand is <BRAND>; price is <PRICE>. \\n\"\n", 26 | "feat_template = \"The item has following features: <CATEGORIES>. \\n\"\n", 27 | "desc_template = \"The item has following descriptions: <DESCRIPTION>. \\n\"" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "data = json.load(open(\"./handled/item2attributes.json\", \"r\"))" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "all_feats = []\n", 46 | "\n", 47 | "for user, user_attris in data.items():\n", 48 | " for feat_name in user_attris.keys():\n", 49 | " if feat_name not in all_feats:\n", 50 | " all_feats.append(feat_name)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "def get_attri(item_str, attri, item_info):\n", 60 | "\n", 61 | " if attri not in item_info.keys():\n", 62 | " new_str = item_str.replace(f\"<{attri.upper()}>\", \"unknown\")\n", 63 | " else:\n", 64 | " new_str = item_str.replace(f\"<{attri.upper()}>\", str(item_info[attri]))\n", 65 | "\n", 66 | " return new_str" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "def get_feat(item_str, feat, item_info):\n", 76 | "\n", 77 | " if feat not in item_info.keys():\n", 78 | " return \"\"\n", 79 | " \n", 80 | " assert isinstance(item_info[feat], list)\n", 81 | " feat_str = \"\"\n", 82 | " for meta_feat in item_info[feat][0]:\n", 83 | " feat_str = feat_str + meta_feat + \"; \"\n", 84 | " new_str = item_str.replace(f\"<{feat.upper()}>\", feat_str)\n", 85 | "\n", 86 | " if len(new_str) > 2048: # avoid exceed the input length limitation\n", 87 | " return new_str[:2048]\n", 88 | "\n", 89 | " return new_str" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "item_data = {}\n", 99 | "for key, value in tqdm(data.items()):\n", 100 | " item_str = copy.deepcopy(prompt_template)\n", 101 | " item_str = get_attri(item_str, \"title\", value)\n", 102 | " item_str = get_attri(item_str, \"brand\", value)\n", 103 | " item_str = get_attri(item_str, \"date\", value)\n", 104 | " item_str = get_attri(item_str, \"price\", value)\n", 105 | "\n", 106 | " feat_str = copy.deepcopy(feat_template)\n", 107 | " feat_str = get_feat(feat_str, \"categories\", value)\n", 108 | " desc_str = copy.deepcopy(desc_template)\n", 109 | " desc_str = get_attri(desc_str, \"description\", value)\n", 110 | " \n", 111 | " item_data[key] = item_str + feat_str + desc_str" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "item_data[\"1304351475\"]" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "json.dump(item_data, open(\"./handled/item_str.json\", \"w\"))" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "item_data = json.load(open(\"./handled/item_str.json\", \"r\"))" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "import jsonlines\n", 148 | "\n", 149 | "def save_data(data_path, data):\n", 150 | " '''write all_data list to a new jsonl'''\n", 151 | " with jsonlines.open(\"./handled/\"+ data_path, \"w\") as w:\n", 152 | " for meta_data in data:\n", 153 | " w.write(meta_data)\n", 154 | "\n", 155 | "id_map = json.load(open(\"./handled/id_map.json\", \"r\"))[\"item2id\"]\n", 156 | "json_data = []\n", 157 | "for key, value in item_data.items():\n", 158 | " json_data.append({\"input\": value, \"target\": \"\", \"item\": key, \"item_id\": id_map[key]})\n", 159 | "\n", 160 | "save_data(\"item_str.jsonline\", json_data)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "import requests\n", 170 | "import json\n", 171 | "\n", 172 | "url = \"\"\n", 173 | "\n", 174 | "payload = json.dumps({\n", 175 | " \"model\": \"text-embedding-ada-002\",\n", 176 | " \"input\": \"The food was delicious and the waiter...\"\n", 177 | "})\n", 178 | "headers = {\n", 179 | " 'Authorization': '',\n", 180 | " 'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',\n", 181 | " 'Content-Type': 'application/json'\n", 182 | "}\n", 183 | "\n", 184 | "response = requests.request(\"POST\", url, headers=headers, data=payload)\n", 185 | "\n", 186 | "print(response.text)" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [ 195 | "def get_response(prompt):\n", 196 | " url = \"\"\n", 197 | "\n", 198 | " payload = json.dumps({\n", 199 | " \"model\": \"text-embedding-ada-002\",\n", 200 | " \"input\": prompt\n", 201 | " })\n", 202 | " headers = {\n", 203 | " 'Authorization': '',\n", 204 | " 'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',\n", 205 | " 'Content-Type': 'application/json'\n", 206 | " }\n", 207 | "\n", 208 | " response = requests.request(\"POST\", url, headers=headers, data=payload)\n", 209 | " re_json = json.loads(response.text)\n", 210 | "\n", 211 | " return re_json[\"data\"][0][\"embedding\"]" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "item_emb = {}" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "value_list = []\n", 230 | "\n", 231 | "for key, value in tqdm(item_data.items()):\n", 232 | " if len(value) > 4096:\n", 233 | " value_list.append(key)" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "if os.path.exists(\"./handled/item_emb.pkl\"): # check whether some item emb exist in cache\n", 243 | " item_emb = pickle.load(open(\"./handled/item_emb.pkl\", \"rb\"))" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "count = 1\n", 253 | "while 1: # avoid broken due to internet connection\n", 254 | " if len(item_emb) == len(item_data):\n", 255 | " break\n", 256 | " try:\n", 257 | " for key, value in tqdm(item_data.items()):\n", 258 | " if key not in item_emb.keys():\n", 259 | " if len(value) > 4096:\n", 260 | " value = value[:4095]\n", 261 | " item_emb[key] = get_response(value)\n", 262 | " count += 1\n", 263 | " except:\n", 264 | " pickle.dump(item_emb, open(\"./handled/item_emb.pkl\", \"wb\"))\n", 265 | " " 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "len(item_emb)" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": null, 280 | "metadata": {}, 281 | "outputs": [], 282 | "source": [ 283 | "id_map = json.load(open(\"./handled/id_map.json\", \"r\"))[\"id2item\"]\n", 284 | "emb_list = []\n", 285 | "for id in range(1, len(item_emb)+1):\n", 286 | " meta_emb = item_emb[id_map[str(id)]]\n", 287 | " emb_list.append(meta_emb)\n", 288 | "\n", 289 | "emb_list = np.array(emb_list)\n", 290 | "pickle.dump(emb_list, open(\"./handled/itm_emb_np.pkl\", \"wb\"))" 291 | ] 292 | } 293 | ], 294 | "metadata": { 295 | "kernelspec": { 296 | "display_name": "llm", 297 | "language": "python", 298 | "name": "python3" 299 | }, 300 | "language_info": { 301 | "codemirror_mode": { 302 | "name": "ipython", 303 | "version": 3 304 | }, 305 | "file_extension": ".py", 306 | "mimetype": "text/x-python", 307 | "name": "python", 308 | "nbconvert_exporter": "python", 309 | "pygments_lexer": "ipython3", 310 | "version": "3.9.5" 311 | }, 312 | "orig_nbformat": 4 313 | }, 314 | "nbformat": 4, 315 | "nbformat_minor": 2 316 | } 317 | -------------------------------------------------------------------------------- /data/beauty/get_user_embedding.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import pickle\n", 11 | "import jsonlines\n", 12 | "import pandas as pd\n", 13 | "import numpy as np\n", 14 | "import json\n", 15 | "import copy\n", 16 | "from tqdm import tqdm\n", 17 | "from collections import defaultdict\n", 18 | "import requests" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "id_map = json.load(open(\"./handled/id_map.json\"))\n", 28 | "item_dict = json.load(open(\"./handled/item2attributes.json\", \"r\"))" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "def load_dataset():\n", 38 | " '''Load train, validation, test dataset'''\n", 39 | "\n", 40 | " usernum = 0\n", 41 | " itemnum = 0\n", 42 | " User = defaultdict(list) # default value is a blank list\n", 43 | " user_train = {}\n", 44 | " user_valid = {}\n", 45 | " user_test = {}\n", 46 | " # assume user/item index starting from 1\n", 47 | " f = open('./handled/inter.txt', 'r')\n", 48 | " for line in f: # use a dict to save all seqeuces of each user\n", 49 | " u, i = line.rstrip().split(' ')\n", 50 | " u = int(u)\n", 51 | " i = int(i)\n", 52 | " usernum = max(u, usernum)\n", 53 | " itemnum = max(i, itemnum)\n", 54 | " User[u].append(i)\n", 55 | "\n", 56 | " for user in tqdm(User):\n", 57 | " nfeedback = len(User[user])\n", 58 | " #nfeedback = len(User[user])\n", 59 | " if nfeedback < 3:\n", 60 | " user_train[user] = User[user]\n", 61 | " user_valid[user] = []\n", 62 | " user_test[user] = []\n", 63 | " else:\n", 64 | " user_train[user] = User[user][:-2]\n", 65 | " user_valid[user] = []\n", 66 | " user_valid[user].append(User[user][-2])\n", 67 | " user_test[user] = []\n", 68 | " user_test[user].append(User[user][-1])\n", 69 | " \n", 70 | " return user_train\n", 71 | " " 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "inter = load_dataset()" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "prompt_template = \"The user has visited following fashions: \\n<HISTORY> \\nplease conclude the user's perference.\"" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "def get_user_prompt(history):\n", 99 | "\n", 100 | " user_str = copy.deepcopy(prompt_template)\n", 101 | " hist_str = \"\"\n", 102 | " for item in history:\n", 103 | " try: # some item does not have title\n", 104 | " item_str = item_dict[id_map[\"id2item\"][str(item)]][\"title\"]\n", 105 | " hist_str = hist_str + item_str + \", \"\n", 106 | " except:\n", 107 | " continue\n", 108 | "\n", 109 | " # limit the prompt length\n", 110 | " if len(hist_str) > 8000:\n", 111 | " hist_str = hist_str[-8000:]\n", 112 | "\n", 113 | " user_str = user_str.replace(\"<HISTORY>\", hist_str)\n", 114 | "\n", 115 | " return user_str" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "user_data = {}\n", 125 | "\n", 126 | "for user, history in tqdm(inter.items()):\n", 127 | " user_data[user] = get_user_prompt(history)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "# json.dump(user_data, open(\"./handled/user_str.json\", \"w\"))" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "user_data = json.load(open(\"./handled/user_str.json\", \"r\"))" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "url = \"\"\n", 155 | "\n", 156 | "payload = json.dumps({\n", 157 | " \"model\": \"text-embedding-ada-002\",\n", 158 | " \"input\": \"The food was delicious and the waiter...\"\n", 159 | "})\n", 160 | "headers = {\n", 161 | " 'Authorization': '',\n", 162 | " 'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',\n", 163 | " 'Content-Type': 'application/json'\n", 164 | "}\n", 165 | "\n", 166 | "response = requests.request(\"POST\", url, headers=headers, data=payload)\n", 167 | "\n", 168 | "print(response.text)" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "def get_response(prompt):\n", 178 | " url = \"\"\n", 179 | "\n", 180 | " payload = json.dumps({\n", 181 | " \"model\": \"text-embedding-ada-002\",\n", 182 | " \"input\": prompt\n", 183 | " })\n", 184 | " headers = {\n", 185 | " 'Authorization': '',\n", 186 | " 'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',\n", 187 | " 'Content-Type': 'application/json'\n", 188 | " }\n", 189 | "\n", 190 | " response = requests.request(\"POST\", url, headers=headers, data=payload)\n", 191 | " re_json = json.loads(response.text)\n", 192 | "\n", 193 | " return re_json[\"data\"][0][\"embedding\"]" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "user_emb = {}" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "if os.path.exists(\"./handled/user_emb.pkl\"): # check whether some item emb exist in cache\n", 212 | " user_emb = pickle.load(open(\"./handled/user_emb.pkl\", \"rb\"))" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "len(user_emb)" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "count = 1\n", 231 | "while 1: # avoid broken due to internet connection\n", 232 | " if len(user_emb) == len(user_data):\n", 233 | " break\n", 234 | " try:\n", 235 | " for key, value in tqdm(user_data.items()):\n", 236 | " if key not in user_emb.keys():\n", 237 | " if len(value) > 4096:\n", 238 | " value = value[:4095]\n", 239 | " user_emb[key] = get_response(value)\n", 240 | " count += 1\n", 241 | " except:\n", 242 | " pickle.dump(user_emb, open(\"./handled/user_emb.pkl\", \"wb\"))" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "len(user_emb)" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "emb_list = []\n", 261 | "for key, value in tqdm(user_emb.items()):\n", 262 | " emb_list.append(value)\n", 263 | "\n", 264 | "emb_list = np.array(emb_list)\n", 265 | "pickle.dump(emb_list, open(\"./handled/usr_emb_np.pkl\", \"wb\"))" 266 | ] 267 | } 268 | ], 269 | "metadata": { 270 | "kernelspec": { 271 | "display_name": "llm", 272 | "language": "python", 273 | "name": "python3" 274 | }, 275 | "language_info": { 276 | "codemirror_mode": { 277 | "name": "ipython", 278 | "version": 3 279 | }, 280 | "file_extension": ".py", 281 | "mimetype": "text/x-python", 282 | "name": "python", 283 | "nbconvert_exporter": "python", 284 | "pygments_lexer": "ipython3", 285 | "version": "3.9.5" 286 | }, 287 | "orig_nbformat": 4 288 | }, 289 | "nbformat": 4, 290 | "nbformat_minor": 2 291 | } 292 | -------------------------------------------------------------------------------- /data/convert_inter.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import pickle\n", 11 | "import numpy as np\n", 12 | "import pandas as pd\n", 13 | "from tqdm import tqdm" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "dataset = \"beauty\"" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "data = {}\n", 32 | "\n", 33 | "with open(f\"./{dataset}/handled/inter_seq.txt\", 'r') as f:\n", 34 | " for line in tqdm(f):\n", 35 | " line_data = line.rstrip().split(' ')\n", 36 | " user_id = line_data[0]\n", 37 | " line_data.pop(0) # delete user_id\n", 38 | " data[user_id] = line_data" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "with open(f\"./{dataset}/handled/inter.txt\", 'w') as f:\n", 48 | " for user, item_list in tqdm(data.items()):\n", 49 | " for item in item_list:\n", 50 | " u = int(user)\n", 51 | " i = int(item)\n", 52 | " f.write('%s %s\\n' % (u, i))" 53 | ] 54 | } 55 | ], 56 | "metadata": { 57 | "kernelspec": { 58 | "display_name": "llm", 59 | "language": "python", 60 | "name": "python3" 61 | }, 62 | "language_info": { 63 | "codemirror_mode": { 64 | "name": "ipython", 65 | "version": 3 66 | }, 67 | "file_extension": ".py", 68 | "mimetype": "text/x-python", 69 | "name": "python", 70 | "nbconvert_exporter": "python", 71 | "pygments_lexer": "ipython3", 72 | "version": "3.9.5" 73 | }, 74 | "orig_nbformat": 4 75 | }, 76 | "nbformat": 4, 77 | "nbformat_minor": 2 78 | } 79 | -------------------------------------------------------------------------------- /data/data_process.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import random 3 | import numpy as np 4 | import pandas as pd 5 | import json 6 | import pickle 7 | import gzip 8 | import tqdm 9 | import os 10 | from tqdm import tqdm 11 | 12 | true=True 13 | false=False 14 | def parse(path): # for Amazon 15 | g = gzip.open(path, 'rb') 16 | inter_list = [] 17 | for l in tqdm(g): 18 | inter_list.append(json.loads(l.decode())) 19 | # inter_list.append(eval(l)) 20 | 21 | return inter_list 22 | 23 | 24 | def parse_meta(path): # for Amazon 25 | g = gzip.open(path, 'rb') 26 | inter_list = [] 27 | for l in tqdm(g): 28 | inter_list.append(eval(l)) 29 | 30 | return inter_list 31 | 32 | 33 | # return (user item timestamp) sort in get_interaction 34 | def Amazon(dataset_name, rating_score): 35 | ''' 36 | reviewerID - ID of the reviewer, e.g. A2SUAM1J3GNN3B 37 | asin - ID of the product, e.g. 0000013714 38 | reviewerName - name of the reviewer 39 | helpful - helpfulness rating of the review, e.g. 2/3 40 | --"helpful": [2, 3], 41 | reviewText - text of the review 42 | --"reviewText": "I bought this for my husband who plays the piano. ..." 43 | overall - rating of the product 44 | --"overall": 5.0, 45 | summary - summary of the review 46 | --"summary": "Heavenly Highway Hymns", 47 | unixReviewTime - time of the review (unix time) 48 | --"unixReviewTime": 1252800000, 49 | reviewTime - time of the review (raw) 50 | --"reviewTime": "09 13, 2009" 51 | ''' 52 | datas = [] 53 | # older Amazon 54 | data_flie = './data/' + str(dataset_name) + '/raw/' + str(dataset_name) + '.json.gz' 55 | # latest Amazon 56 | # data_flie = '/home/hui_wang/data/new_Amazon/' + dataset_name + '.json.gz' 57 | for inter in parse(data_flie): 58 | if float(inter['overall']) <= rating_score: # 小于一定分数去掉 59 | continue 60 | user = inter['reviewerID'] 61 | item = inter['asin'] 62 | time = inter['unixReviewTime'] 63 | datas.append((user, item, int(time))) 64 | return datas 65 | 66 | 67 | def New_Amazon(dataset_name, rating_score): 68 | ''' 69 | reviewerID - ID of the reviewer, e.g. A2SUAM1J3GNN3B 70 | asin - ID of the product, e.g. 0000013714 71 | reviewerName - name of the reviewer 72 | helpful - helpfulness rating of the review, e.g. 2/3 73 | --"helpful": [2, 3], 74 | reviewText - text of the review 75 | --"reviewText": "I bought this for my husband who plays the piano. ..." 76 | overall - rating of the product 77 | --"overall": 5.0, 78 | summary - summary of the review 79 | --"summary": "Heavenly Highway Hymns", 80 | unixReviewTime - time of the review (unix time) 81 | --"unixReviewTime": 1252800000, 82 | reviewTime - time of the review (raw) 83 | --"reviewTime": "09 13, 2009" 84 | ''' 85 | datas = [] 86 | # older Amazon 87 | data_flie = './data/' + str(dataset_name) + '/raw/' + str(dataset_name) + '.json.gz' 88 | # latest Amazon 89 | # data_flie = '/home/hui_wang/data/new_Amazon/' + dataset_name + '.json.gz' 90 | for inter in parse(data_flie): 91 | if float(inter['overall']) <= rating_score: # 小于一定分数去掉 92 | continue 93 | user = inter['reviewerID'] 94 | item = inter['asin'] 95 | time = inter['unixReviewTime'] 96 | datas.append((user, item, int(time))) 97 | return datas 98 | 99 | 100 | def Amazon_meta(dataset_name, data_maps): 101 | ''' 102 | asin - ID of the product, e.g. 0000031852 103 | --"asin": "0000031852", 104 | title - name of the product 105 | --"title": "Girls Ballet Tutu Zebra Hot Pink", 106 | description 107 | price - price in US dollars (at time of crawl) 108 | --"price": 3.17, 109 | imUrl - url of the product image (str) 110 | --"imUrl": "http://ecx.images-amazon.com/images/I/51fAmVkTbyL._SY300_.jpg", 111 | related - related products (also bought, also viewed, bought together, buy after viewing) 112 | --"related":{ 113 | "also_bought": ["B00JHONN1S"], 114 | "also_viewed": ["B002BZX8Z6"], 115 | "bought_together": ["B002BZX8Z6"] 116 | }, 117 | salesRank - sales rank information 118 | --"salesRank": {"Toys & Games": 211836} 119 | brand - brand name 120 | --"brand": "Coxlures", 121 | categories - list of categories the product belongs to 122 | --"categories": [["Sports & Outdoors", "Other Sports", "Dance"]] 123 | ''' 124 | datas = {} 125 | meta_flie = './data/' + str(dataset_name) + '/raw/meta_' + str(dataset_name) + '.json.gz' 126 | item_asins = list(data_maps['item2id'].keys()) 127 | for info in tqdm(parse_meta(meta_flie)): 128 | if info['asin'] not in item_asins: 129 | continue 130 | datas[info['asin']] = info 131 | return datas 132 | 133 | def Yelp(date_min, date_max, rating_score): # take out inters in [date_min, date_max] and the score < rating_score 134 | datas = [] 135 | data_flie = './data/yelp/raw/yelp_academic_dataset_review.json' 136 | lines = open(data_flie).readlines() 137 | for line in tqdm(lines): 138 | review = json.loads(line.strip()) 139 | user = review['user_id'] 140 | item = review['business_id'] 141 | rating = review['stars'] 142 | # 2004-10-12 10:13:32 2019-12-13 15:51:19 143 | date = review['date'] 144 | # 剔除一些例子 145 | if date < date_min or date > date_max or float(rating) <= rating_score: 146 | continue 147 | time = date.replace('-','').replace(':','').replace(' ','') 148 | datas.append((user, item, int(time))) 149 | return datas 150 | 151 | 152 | def Yelp_meta(datamaps): 153 | meta_infos = {} 154 | meta_file = './data/yelp/raw/yelp_academic_dataset_business.json' 155 | item_ids = list(datamaps['item2id'].keys()) 156 | lines = open(meta_file).readlines() 157 | for line in tqdm(lines): 158 | info = json.loads(line) 159 | if info['business_id'] not in item_ids: 160 | continue 161 | meta_infos[info['business_id']] = info 162 | return meta_infos 163 | 164 | 165 | def add_comma(num): 166 | # 1000000 -> 1,000,000 167 | str_num = str(num) 168 | res_num = '' 169 | for i in range(len(str_num)): 170 | res_num += str_num[i] 171 | if (len(str_num)-i-1) % 3 == 0: 172 | res_num += ',' 173 | return res_num[:-1] 174 | 175 | # categories 和 brand is all attribute 176 | def get_attribute_Amazon(meta_infos, datamaps, attribute_core): 177 | 178 | attributes = defaultdict(int) 179 | # for iid, info in tqdm.tqdm(meta_infos.items()): 180 | # for cates in info['categories']: 181 | # for cate in cates[1:]: # 把主类删除 没有用 182 | # attributes[cate] +=1 183 | # try: 184 | # attributes[info['brand']] += 1 185 | # except: 186 | # pass 187 | 188 | # print(f'before delete, attribute num:{len(attributes)}') 189 | # new_meta = {} 190 | # for iid, info in tqdm.tqdm(meta_infos.items()): 191 | # new_meta[iid] = [] 192 | 193 | # try: 194 | # if attributes[info['brand']] >= attribute_core: 195 | # new_meta[iid].append(info['brand']) 196 | # except: 197 | # pass 198 | # for cates in info['categories']: 199 | # for cate in cates[1:]: 200 | # if attributes[cate] >= attribute_core: 201 | # new_meta[iid].append(cate) 202 | # 做映射 203 | attribute2id = {} 204 | id2attribute = {} 205 | attributeid2num = defaultdict(int) 206 | attribute_id = 1 207 | items2attributes = {} 208 | attribute_lens = [] 209 | 210 | for iid, attributes in meta_infos.items(): 211 | item_id = datamaps['item2id'][iid] 212 | items2attributes[item_id] = [] 213 | for attribute in attributes: 214 | if attribute not in attribute2id: 215 | attribute2id[attribute] = attribute_id 216 | id2attribute[attribute_id] = attribute 217 | attribute_id += 1 218 | attributeid2num[attribute2id[attribute]] += 1 219 | items2attributes[item_id].append(attribute2id[attribute]) 220 | attribute_lens.append(len(items2attributes[item_id])) 221 | print(f'before delete, attribute num:{len(attribute2id)}') 222 | print(f'attributes len, Min:{np.min(attribute_lens)}, Max:{np.max(attribute_lens)}, Avg.:{np.mean(attribute_lens):.4f}') 223 | # 更新datamap 224 | datamaps['attribute2id'] = attribute2id 225 | datamaps['id2attribute'] = id2attribute 226 | datamaps['attributeid2num'] = attributeid2num 227 | return len(attribute2id), np.mean(attribute_lens), datamaps, items2attributes 228 | 229 | 230 | def get_attribute_Yelp(meta_infos, datamaps, attribute_core): 231 | attributes = defaultdict(int) 232 | for iid, info in tqdm(meta_infos.items()): 233 | try: 234 | cates = [cate.strip() for cate in info['categories'].split(',')] 235 | for cate in cates: 236 | attributes[cate] +=1 237 | except: 238 | pass 239 | print(f'before delete, attribute num:{len(attributes)}') 240 | new_meta = {} 241 | for iid, info in tqdm(meta_infos.items()): 242 | new_meta[iid] = [] 243 | try: 244 | cates = [cate.strip() for cate in info['categories'].split(',') ] 245 | for cate in cates: 246 | if attributes[cate] >= attribute_core: 247 | new_meta[iid].append(cate) 248 | except: 249 | pass 250 | # 做映射 251 | attribute2id = {} 252 | id2attribute = {} 253 | attribute_id = 1 254 | items2attributes = {} 255 | attribute_lens = [] 256 | # load id map 257 | for iid, attributes in new_meta.items(): 258 | item_id = datamaps['item2id'][iid] 259 | items2attributes[item_id] = [] 260 | for attribute in attributes: 261 | if attribute not in attribute2id: 262 | attribute2id[attribute] = attribute_id 263 | id2attribute[attribute_id] = attribute 264 | attribute_id += 1 265 | items2attributes[item_id].append(attribute2id[attribute]) 266 | attribute_lens.append(len(items2attributes[item_id])) 267 | print(f'after delete, attribute num:{len(attribute2id)}') 268 | print(f'attributes len, Min:{np.min(attribute_lens)}, Max:{np.max(attribute_lens)}, Avg.:{np.mean(attribute_lens):.4f}') 269 | # 更新datamap 270 | datamaps['attribute2id'] = attribute2id 271 | datamaps['id2attribute'] = id2attribute 272 | return len(attribute2id), np.mean(attribute_lens), datamaps, items2attributes 273 | 274 | def get_interaction(datas): # sort the interactions based on timestamp 275 | user_seq = {} 276 | for data in datas: 277 | user, item, time = data 278 | if user in user_seq: 279 | user_seq[user].append((item, time)) 280 | else: 281 | user_seq[user] = [] 282 | user_seq[user].append((item, time)) 283 | 284 | for user, item_time in user_seq.items(): 285 | item_time.sort(key=lambda x: x[1]) # 对各个数据集得单独排序 286 | items = [] 287 | for t in item_time: 288 | items.append(t[0]) 289 | user_seq[user] = items 290 | return user_seq 291 | 292 | # K-core user_core item_core 293 | def check_Kcore(user_items, user_core, item_core): 294 | user_count = defaultdict(int) 295 | item_count = defaultdict(int) 296 | for user, items in user_items.items(): 297 | for item in items: 298 | user_count[user] += 1 299 | item_count[item] += 1 300 | 301 | for user, num in user_count.items(): 302 | if num < user_core: 303 | return user_count, item_count, False 304 | for item, num in item_count.items(): 305 | if num < item_core: 306 | return user_count, item_count, False 307 | return user_count, item_count, True # 已经保证Kcore 308 | 309 | # 循环过滤 K-core 310 | def filter_Kcore(user_items, user_core, item_core): # user 接所有items 311 | user_count, item_count, isKcore = check_Kcore(user_items, user_core, item_core) 312 | while not isKcore: 313 | for user, num in user_count.items(): 314 | if user_count[user] < user_core: # 直接把user 删除 315 | user_items.pop(user) 316 | else: 317 | for item in user_items[user]: 318 | if item_count[item] < item_core: 319 | user_items[user].remove(item) 320 | user_count, item_count, isKcore = check_Kcore(user_items, user_core, item_core) 321 | return user_items 322 | 323 | 324 | def filter_common(user_items, user_t, item_t): 325 | 326 | user_count = defaultdict(int) 327 | item_count = defaultdict(int) 328 | for user, item, _ in user_items: 329 | user_count[user] += 1 330 | item_count[item] += 1 331 | 332 | User = {} 333 | for user, item, timestamp in user_items: 334 | if user_count[user] < user_t or item_count[item] < item_t: 335 | continue 336 | if user not in User.keys(): 337 | User[user] = [] 338 | User[user].append((item, timestamp)) 339 | 340 | new_User = {} 341 | for userid in User.keys(): 342 | User[userid].sort(key=lambda x: x[1]) 343 | new_hist = [i for i, t in User[userid]] 344 | new_User[userid] = new_hist 345 | 346 | return new_User 347 | 348 | 349 | 350 | def id_map(user_items): # user_items dict 351 | 352 | user2id = {} # raw 2 uid 353 | item2id = {} # raw 2 iid 354 | id2user = {} # uid 2 raw 355 | id2item = {} # iid 2 raw 356 | user_id = 1 357 | item_id = 1 358 | final_data = {} 359 | for user, items in user_items.items(): 360 | if user not in user2id: 361 | user2id[user] = str(user_id) 362 | id2user[str(user_id)] = user 363 | user_id += 1 364 | iids = [] # item id lists 365 | for item in items: 366 | if item not in item2id: 367 | item2id[item] = str(item_id) 368 | id2item[str(item_id)] = item 369 | item_id += 1 370 | iids.append(item2id[item]) 371 | uid = user2id[user] 372 | final_data[uid] = iids 373 | data_maps = { 374 | 'user2id': user2id, 375 | 'item2id': item2id, 376 | 'id2user': id2user, 377 | 'id2item': id2item 378 | } 379 | return final_data, user_id-1, item_id-1, data_maps 380 | 381 | 382 | def get_counts(user_items): 383 | 384 | user_count = {} 385 | item_count = {} 386 | 387 | for user, items in user_items.items(): 388 | user_count[user] = len(items) 389 | for item in items: 390 | if item not in item_count.keys(): 391 | item_count[item] = 1 392 | else: 393 | item_count[item] += 1 394 | 395 | return user_count, item_count 396 | 397 | 398 | def filter_minmum(user_items, min_len=3): 399 | 400 | new_user_items = {} 401 | for user, items in user_items.items(): 402 | if len(items) >= min_len: 403 | new_user_items[user] = items 404 | 405 | return new_user_items 406 | 407 | 408 | 409 | def main(data_name, data_type='Amazon', user_core=3, item_core=3): 410 | assert data_type in {'Amazon', 'Yelp', 'New_Amazon'} 411 | np.random.seed(12345) 412 | rating_score = 0.0 # rating score smaller than this score would be deleted 413 | # user 5-core item 5-core 414 | attribute_core = 0 415 | 416 | if data_type == 'Yelp': 417 | date_max = '2019-12-31 00:00:00' 418 | date_min = '2000-01-01 00:00:00' 419 | datas = Yelp(date_min, date_max, rating_score) 420 | elif data_type == "New_Amazon": 421 | datas = New_Amazon(data_name, rating_score=rating_score) 422 | else: 423 | datas = Amazon(data_name, rating_score=rating_score) 424 | 425 | # datas = datas[:int(len(datas)*0.1)] # for electronics and game 426 | if data_type != "New_Amazon": 427 | user_items = get_interaction(datas) 428 | print(f'{data_name} Raw data has been processed! Lower than {rating_score} are deleted!') 429 | # raw_id user: [item1, item2, item3...] 430 | user_items = filter_common(datas, user_t=user_core, item_t=item_core) 431 | # user_items = filter_Kcore(user_items, user_core=user_core, item_core=item_core) 432 | print(f'User {user_core}-core complete! Item {item_core}-core complete!') 433 | 434 | user_items, user_num, item_num, data_maps = id_map(user_items) # new_num_id 435 | user_items = filter_minmum(user_items, min_len=3) 436 | # user_count, item_count, _ = check_Kcore(user_items, user_core=user_core, item_core=item_core) 437 | user_count, item_count = get_counts(user_items) 438 | user_count_list = list(user_count.values()) 439 | user_avg, user_min, user_max = np.mean(user_count_list), np.min(user_count_list), np.max(user_count_list) 440 | item_count_list = list(item_count.values()) 441 | item_avg, item_min, item_max = np.mean(item_count_list), np.min(item_count_list), np.max(item_count_list) 442 | interact_num = np.sum([x for x in user_count_list]) 443 | sparsity = (1 - interact_num / (user_num * item_num)) * 100 444 | show_info = f'Total User: {user_num}, Avg User: {user_avg:.4f}, Min Len: {user_min}, Max Len: {user_max}\n' + \ 445 | f'Total Item: {item_num}, Avg Item: {item_avg:.4f}, Min Inter: {item_min}, Max Inter: {item_max}\n' + \ 446 | f'Iteraction Num: {interact_num}, Sparsity: {sparsity:.2f}%' 447 | print(show_info) 448 | 449 | 450 | print('Begin extracting meta infos...') 451 | 452 | if data_type == 'Amazon': 453 | meta_infos = Amazon_meta(data_name, data_maps) 454 | attribute_num, avg_attribute, datamaps, item2attributes = get_attribute_Amazon(meta_infos, data_maps, attribute_core) 455 | elif data_type == "New_Amazon": 456 | meta_infos = Amazon_meta(data_name, data_maps) 457 | attribute_num, avg_attribute, datamaps, item2attributes = get_attribute_Amazon(meta_infos, data_maps, attribute_core) 458 | else: 459 | meta_infos = Yelp_meta(data_maps) 460 | attribute_num, avg_attribute, datamaps, item2attributes = get_attribute_Yelp(meta_infos, data_maps, attribute_core) 461 | 462 | print(f'{data_name} & {add_comma(user_num)}& {add_comma(item_num)} & {user_avg:.1f}' 463 | f'& {item_avg:.1f}& {add_comma(interact_num)}& {sparsity:.2f}\%&{add_comma(attribute_num)}&' 464 | f'{avg_attribute:.1f} \\') 465 | 466 | # -------------- Save Data --------------- 467 | handled_path = 'data/' + data_name + '/handled/' 468 | if not os.path.exists(handled_path): 469 | os.makedirs(handled_path) 470 | 471 | data_file = handled_path + 'inter_seq.txt' 472 | item2attributes_file = handled_path + 'item2attributes.json' 473 | id_file = handled_path + "id_map.json" 474 | 475 | with open(data_file, 'w') as out: 476 | for user, items in user_items.items(): 477 | out.write(user + ' ' + ' '.join(items) + '\n') 478 | json_str = json.dumps(meta_infos) 479 | with open(item2attributes_file, 'w') as out: 480 | out.write(json_str) 481 | with open(id_file, "w") as f: 482 | json.dump(data_maps, f) 483 | 484 | 485 | 486 | def LastFM(): 487 | user_core = 5 488 | item_core = 5 489 | datas = [] 490 | data_file = '/path/lastfm/2k/user_attributegedartists-timestamps.dat' 491 | lines = open(data_file).readlines() 492 | for line in tqdm.tqdm(lines[1:]): 493 | user, item, attribute, timestamp = line.strip().split('\t') 494 | datas.append((user, item, int(timestamp))) 495 | 496 | # 有重复item 497 | user_seq = {} 498 | user_seq_notime = {} 499 | for data in datas: 500 | user, item, time = data 501 | if user in user_seq: 502 | if item not in user_seq_notime[user]: 503 | user_seq[user].append((item, time)) 504 | user_seq_notime[user].append(item) 505 | else: 506 | continue 507 | else: 508 | user_seq[user] = [] 509 | user_seq_notime[user] = [] 510 | 511 | user_seq[user].append((item, time)) 512 | user_seq_notime[user].append(item) 513 | 514 | for user, item_time in user_seq.items(): 515 | item_time.sort(key=lambda x: x[1]) # 对各个数据集得单独排序 516 | items = [] 517 | for t in item_time: 518 | items.append(t[0]) 519 | user_seq[user] = items 520 | 521 | user_items = filter_Kcore(user_seq, user_core=user_core, item_core=item_core) 522 | print(f'User {user_core}-core complete! Item {item_core}-core complete!') 523 | 524 | user_items, user_num, item_num, data_maps = id_map(user_items) # new_num_id 525 | user_count, item_count, _ = check_Kcore(user_items, user_core=user_core, item_core=item_core) 526 | user_count_list = list(user_count.values()) 527 | user_avg, user_min, user_max = np.mean(user_count_list), np.min(user_count_list), np.max(user_count_list) 528 | item_count_list = list(item_count.values()) 529 | item_avg, item_min, item_max = np.mean(item_count_list), np.min(item_count_list), np.max(item_count_list) 530 | interact_num = np.sum([x for x in user_count_list]) 531 | sparsity = (1 - interact_num / (user_num * item_num)) * 100 532 | show_info = f'Total User: {user_num}, Avg User: {user_avg:.4f}, Min Len: {user_min}, Max Len: {user_max}\n' + \ 533 | f'Total Item: {item_num}, Avg Item: {item_avg:.4f}, Min Inter: {item_min}, Max Inter: {item_max}\n' + \ 534 | f'Iteraction Num: {interact_num}, Sparsity: {sparsity:.2f}%' 535 | print(show_info) 536 | 537 | attribute_file = './data_path/artist2attributes.json' 538 | 539 | meta_item2attribute = json.loads(open(attribute_file).readline()) 540 | 541 | # 做映射 542 | attribute2id = {} 543 | id2attribute = {} 544 | attribute_id = 1 545 | item2attributes = {} 546 | attribute_lens = [] 547 | # load id map 548 | for iid, attributes in meta_item2attribute.items(): 549 | if iid in list(data_maps['item2id'].keys()): 550 | item_id = data_maps['item2id'][iid] 551 | item2attributes[item_id] = [] 552 | for attribute in attributes: 553 | if attribute not in attribute2id: 554 | attribute2id[attribute] = attribute_id 555 | id2attribute[attribute_id] = attribute 556 | attribute_id += 1 557 | item2attributes[item_id].append(attribute2id[attribute]) 558 | attribute_lens.append(len(item2attributes[item_id])) 559 | print(f'after delete, attribute num:{len(attribute2id)}') 560 | print(f'attributes len, Min:{np.min(attribute_lens)}, Max:{np.max(attribute_lens)}, Avg.:{np.mean(attribute_lens):.4f}') 561 | # 更新datamap 562 | data_maps['attribute2id'] = attribute2id 563 | data_maps['id2attribute'] = id2attribute 564 | 565 | data_name = 'LastFM' 566 | print(f'{data_name} & {add_comma(user_num)}& {add_comma(item_num)} & {user_avg:.1f}' 567 | f'& {item_avg:.1f}& {add_comma(interact_num)}& {sparsity:.2f}\%&{add_comma(len(attribute2id))}&' 568 | f'{np.mean(attribute_lens):.1f} \\') 569 | 570 | # -------------- Save Data --------------- 571 | # one user one line 572 | data_file = 'data/' + data_name + '.txt' 573 | item2attributes_file = 'data/' + data_name + '_item2attributes.json' 574 | 575 | with open(data_file, 'w') as out: 576 | for user, items in user_items.items(): 577 | out.write(user + ' ' + ' '.join(items) + '\n') 578 | 579 | json_str = json.dumps(item2attributes) 580 | with open(item2attributes_file, 'w') as out: 581 | out.write(json_str) 582 | 583 | amazon_datas = ['Beauty', 'Sports_and_Outdoors', 'Toys_and_Games'] 584 | 585 | 586 | if __name__ == "__main__": 587 | 588 | main('yelp', data_type='Yelp', user_core=3, item_core=3) 589 | main("fashion", data_type="Amazon") 590 | main("beauty", data_type="Amazon", user_core=3, item_core=3) 591 | -------------------------------------------------------------------------------- /data/fashion/get_item_embedding.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import pickle\n", 11 | "import jsonlines\n", 12 | "import pandas as pd\n", 13 | "import numpy as np\n", 14 | "import json\n", 15 | "import copy\n", 16 | "from tqdm import tqdm" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "prompt_template = \"The fashion item has following attributes: \\n name is <TITLE>; brand is <BRAND>; score is <DATE>; price is <PRICE>. \\n\"\n", 26 | "feat_template = \"The item has following features: <FEATURE>. \\n\"\n", 27 | "desc_template = \"The item has following descriptions: <DESCRIPTION>. \\n\"" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "data = json.load(open(\"./handled/item2attributes.json\", \"r\"))" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "def get_attri(item_str, attri, item_info):\n", 46 | "\n", 47 | " if attri not in item_info.keys() or len(item_info[attri]) > 100:\n", 48 | " new_str = item_str.replace(f\"<{attri.upper()}>\", \"unknown\")\n", 49 | " else:\n", 50 | " new_str = item_str.replace(f\"<{attri.upper()}>\", item_info[attri])\n", 51 | "\n", 52 | " return new_str" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "def get_feat(item_str, feat, item_info):\n", 62 | "\n", 63 | " if feat not in item_info.keys():\n", 64 | " return \"\"\n", 65 | " \n", 66 | " assert isinstance(item_info[feat], list)\n", 67 | " feat_str = \"\"\n", 68 | " for meta_feat in item_info[feat]:\n", 69 | " feat_str = feat_str + meta_feat + \"; \"\n", 70 | " new_str = item_str.replace(f\"<{feat.upper()}>\", feat_str)\n", 71 | "\n", 72 | " if len(new_str) > 2048: # avoid exceed the input length limitation\n", 73 | " return new_str[:2048]\n", 74 | "\n", 75 | " return new_str" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "item_data = {}\n", 85 | "for key, value in tqdm(data.items()):\n", 86 | " item_str = copy.deepcopy(prompt_template)\n", 87 | " item_str = get_attri(item_str, \"title\", value)\n", 88 | " item_str = get_attri(item_str, \"brand\", value)\n", 89 | " item_str = get_attri(item_str, \"date\", value)\n", 90 | " item_str = get_attri(item_str, \"price\", value)\n", 91 | "\n", 92 | " feat_str = copy.deepcopy(feat_template)\n", 93 | " feat_str = get_feat(feat_str, \"feature\", value)\n", 94 | " desc_str = copy.deepcopy(desc_template)\n", 95 | " desc_str = get_feat(desc_str, \"description\", value)\n", 96 | " \n", 97 | " item_data[key] = item_str + feat_str + desc_str" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "item_data[\"B0002Z1JNK\"]" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "json.dump(item_data, open(\"./handled/item_str.json\", \"w\"))" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "item_data = json.load(open(\"./handled/item_str.json\", \"r\"))" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "import jsonlines\n", 134 | "\n", 135 | "def save_data(data_path, data):\n", 136 | " '''write all_data list to a new jsonl'''\n", 137 | " with jsonlines.open(\"./handled/\"+ data_path, \"w\") as w:\n", 138 | " for meta_data in data:\n", 139 | " w.write(meta_data)\n", 140 | "\n", 141 | "id_map = json.load(open(\"./handled/id_map.json\", \"r\"))[\"item2id\"]\n", 142 | "json_data = []\n", 143 | "for key, value in item_data.items():\n", 144 | " json_data.append({\"input\": value, \"target\": \"\", \"item\": key, \"item_id\": id_map[key]})\n", 145 | "\n", 146 | "save_data(\"item_str.jsonline\", json_data)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "import requests\n", 156 | "import json\n", 157 | "\n", 158 | "url = \"\"\n", 159 | "\n", 160 | "payload = json.dumps({\n", 161 | " \"model\": \"text-embedding-ada-002\",\n", 162 | " \"input\": \"The food was delicious and the waiter...\"\n", 163 | "})\n", 164 | "headers = {\n", 165 | " 'Authorization': '',\n", 166 | " 'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',\n", 167 | " 'Content-Type': 'application/json'\n", 168 | "}\n", 169 | "\n", 170 | "response = requests.request(\"POST\", url, headers=headers, data=payload)\n", 171 | "\n", 172 | "print(response.text)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "def get_response(prompt):\n", 182 | " url = \"\"\n", 183 | "\n", 184 | " payload = json.dumps({\n", 185 | " \"model\": \"text-embedding-ada-002\",\n", 186 | " \"input\": prompt\n", 187 | " })\n", 188 | " headers = {\n", 189 | " 'Authorization': '',\n", 190 | " 'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',\n", 191 | " 'Content-Type': 'application/json'\n", 192 | " }\n", 193 | "\n", 194 | " response = requests.request(\"POST\", url, headers=headers, data=payload)\n", 195 | " re_json = json.loads(response.text)\n", 196 | "\n", 197 | " return re_json[\"data\"][0][\"embedding\"]" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "item_emb = {}" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "value_list = []\n", 216 | "\n", 217 | "for key, value in tqdm(item_data.items()):\n", 218 | " if len(value) > 4096:\n", 219 | " value_list.append(key)" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "value_list" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "while 1: # avoid broken due to internet connection\n", 238 | " try:\n", 239 | " for key, value in tqdm(item_data.items()):\n", 240 | " if key not in item_emb.keys():\n", 241 | " item_emb[key] = get_response(value)\n", 242 | " except:\n", 243 | " continue\n", 244 | " if len(item_emb) == len(item_data):\n", 245 | " break" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "len(item_emb)" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "id_map = json.load(open(\"./handled/id_map.json\", \"r\"))[\"id2item\"]\n", 264 | "emb_list = []\n", 265 | "for id in range(1, len(item_emb)+1):\n", 266 | " meta_emb = item_emb[id_map[str(id)]]\n", 267 | " emb_list.append(meta_emb)\n", 268 | "\n", 269 | "emb_list = np.array(emb_list)\n", 270 | "pickle.dump(emb_list, open(\"./handled/itm_emb_np.pkl\", \"wb\"))" 271 | ] 272 | } 273 | ], 274 | "metadata": { 275 | "kernelspec": { 276 | "display_name": "llm", 277 | "language": "python", 278 | "name": "python3" 279 | }, 280 | "language_info": { 281 | "codemirror_mode": { 282 | "name": "ipython", 283 | "version": 3 284 | }, 285 | "file_extension": ".py", 286 | "mimetype": "text/x-python", 287 | "name": "python", 288 | "nbconvert_exporter": "python", 289 | "pygments_lexer": "ipython3", 290 | "version": "3.9.5" 291 | }, 292 | "orig_nbformat": 4 293 | }, 294 | "nbformat": 4, 295 | "nbformat_minor": 2 296 | } 297 | -------------------------------------------------------------------------------- /data/fashion/get_user_embedding.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import pickle\n", 11 | "import jsonlines\n", 12 | "import pandas as pd\n", 13 | "import numpy as np\n", 14 | "import json\n", 15 | "import copy\n", 16 | "from tqdm import tqdm\n", 17 | "from collections import defaultdict\n", 18 | "import requests" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "id_map = json.load(open(\"./handled/id_map.json\"))\n", 28 | "item_dict = json.load(open(\"./handled/item2attributes.json\", \"r\"))" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "def load_dataset():\n", 38 | " '''Load train, validation, test dataset'''\n", 39 | "\n", 40 | " usernum = 0\n", 41 | " itemnum = 0\n", 42 | " User = defaultdict(list) # default value is a blank list\n", 43 | " user_train = {}\n", 44 | " user_valid = {}\n", 45 | " user_test = {}\n", 46 | " # assume user/item index starting from 1\n", 47 | " f = open('./handled/inter.txt', 'r')\n", 48 | " for line in f: # use a dict to save all seqeuces of each user\n", 49 | " u, i = line.rstrip().split(' ')\n", 50 | " u = int(u)\n", 51 | " i = int(i)\n", 52 | " usernum = max(u, usernum)\n", 53 | " itemnum = max(i, itemnum)\n", 54 | " User[u].append(i)\n", 55 | "\n", 56 | " for user in tqdm(User):\n", 57 | " nfeedback = len(User[user])\n", 58 | " #nfeedback = len(User[user])\n", 59 | " if nfeedback < 3:\n", 60 | " user_train[user] = User[user]\n", 61 | " user_valid[user] = []\n", 62 | " user_test[user] = []\n", 63 | " else:\n", 64 | " user_train[user] = User[user][:-2]\n", 65 | " user_valid[user] = []\n", 66 | " user_valid[user].append(User[user][-2])\n", 67 | " user_test[user] = []\n", 68 | " user_test[user].append(User[user][-1])\n", 69 | " \n", 70 | " return user_train\n", 71 | " " 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "inter = load_dataset()" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "prompt_template = \"The user has visited following fashions: \\n<HISTORY> \\nplease conclude the user's perference.\"" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "def get_user_prompt(history):\n", 99 | "\n", 100 | " user_str = copy.deepcopy(prompt_template)\n", 101 | " hist_str = \"\"\n", 102 | " for item in history:\n", 103 | " try: # some item does not have title\n", 104 | " item_str = item_dict[id_map[\"id2item\"][str(item)]][\"title\"]\n", 105 | " hist_str = hist_str + item_str + \", \"\n", 106 | " except:\n", 107 | " continue\n", 108 | "\n", 109 | " # limit the prompt length\n", 110 | " if len(hist_str) > 8000:\n", 111 | " hist_str = hist_str[-8000:]\n", 112 | "\n", 113 | " user_str = user_str.replace(\"<HISTORY>\", hist_str)\n", 114 | "\n", 115 | " return user_str" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "user_data = {}\n", 125 | "\n", 126 | "for user, history in tqdm(inter.items()):\n", 127 | " user_data[user] = get_user_prompt(history)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "json.dump(user_data, open(\"./handled/user_str.json\", \"w\"))" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "user_data = json.load(open(\"./handled/user_str.json\", \"r\"))" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "url = \"\"\n", 155 | "\n", 156 | "payload = json.dumps({\n", 157 | " \"model\": \"text-embedding-ada-002\",\n", 158 | " \"input\": \"The food was delicious and the waiter...\"\n", 159 | "})\n", 160 | "headers = {\n", 161 | " 'Authorization': '',\n", 162 | " 'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',\n", 163 | " 'Content-Type': 'application/json'\n", 164 | "}\n", 165 | "\n", 166 | "response = requests.request(\"POST\", url, headers=headers, data=payload)\n", 167 | "\n", 168 | "print(response.text)" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "def get_response(prompt):\n", 178 | " url = \"\"\n", 179 | "\n", 180 | " payload = json.dumps({\n", 181 | " \"model\": \"text-embedding-ada-002\",\n", 182 | " \"input\": prompt\n", 183 | " })\n", 184 | " headers = {\n", 185 | " 'Authorization': '',\n", 186 | " 'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',\n", 187 | " 'Content-Type': 'application/json'\n", 188 | " }\n", 189 | "\n", 190 | " response = requests.request(\"POST\", url, headers=headers, data=payload)\n", 191 | " re_json = json.loads(response.text)\n", 192 | "\n", 193 | " return re_json[\"data\"][0][\"embedding\"]" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "user_emb = {}" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "while 1: # avoid broken due to internet connection\n", 212 | " try:\n", 213 | " for key, value in tqdm(user_data.items()):\n", 214 | " if key not in user_emb.keys():\n", 215 | " user_emb[key] = get_response(value)\n", 216 | " except:\n", 217 | " continue\n", 218 | " if len(user_emb) == len(user_data):\n", 219 | " break" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "len(user_emb)" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "emb_list = []\n", 238 | "for key, value in tqdm(user_emb.items()):\n", 239 | " emb_list.append(value)\n", 240 | "\n", 241 | "emb_list = np.array(emb_list)\n", 242 | "pickle.dump(emb_list, open(\"./handled/usr_emb_np.pkl\", \"wb\"))" 243 | ] 244 | } 245 | ], 246 | "metadata": { 247 | "kernelspec": { 248 | "display_name": "llm", 249 | "language": "python", 250 | "name": "python3" 251 | }, 252 | "language_info": { 253 | "codemirror_mode": { 254 | "name": "ipython", 255 | "version": 3 256 | }, 257 | "file_extension": ".py", 258 | "mimetype": "text/x-python", 259 | "name": "python", 260 | "nbconvert_exporter": "python", 261 | "pygments_lexer": "ipython3", 262 | "version": "3.9.5" 263 | }, 264 | "orig_nbformat": 4 265 | }, 266 | "nbformat": 4, 267 | "nbformat_minor": 2 268 | } 269 | -------------------------------------------------------------------------------- /data/pca.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 8, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pickle\n", 10 | "import os\n", 11 | "from sklearn.decomposition import PCA" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 9, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "dataset = \"beauty\"" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 10, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "llm_item_emb = pickle.load(open(os.path.join(dataset+\"/handled/\", \"itm_emb_np.pkl\"), \"rb\"))" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "pca = PCA(n_components=64)\n", 39 | "pca_item_emb = pca.fit_transform(llm_item_emb)" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "with open(os.path.join(dataset+\"/handled/\", \"pca64_itm_emb_np.pkl\"), \"wb\") as f:\n", 49 | " pickle.dump(pca_item_emb, f)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "pca = PCA(n_components=128)\n", 59 | "pca_item_emb = pca.fit_transform(llm_item_emb)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "with open(os.path.join(dataset+\"/handled/\", \"pca_itm_emb_np.pkl\"), \"wb\") as f:\n", 69 | " pickle.dump(pca_item_emb, f)" 70 | ] 71 | } 72 | ], 73 | "metadata": { 74 | "kernelspec": { 75 | "display_name": "llm", 76 | "language": "python", 77 | "name": "python3" 78 | }, 79 | "language_info": { 80 | "codemirror_mode": { 81 | "name": "ipython", 82 | "version": 3 83 | }, 84 | "file_extension": ".py", 85 | "mimetype": "text/x-python", 86 | "name": "python", 87 | "nbconvert_exporter": "python", 88 | "pygments_lexer": "ipython3", 89 | "version": "3.9.5" 90 | }, 91 | "orig_nbformat": 4 92 | }, 93 | "nbformat": 4, 94 | "nbformat_minor": 2 95 | } 96 | -------------------------------------------------------------------------------- /data/retrieval_users.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import pickle\n", 11 | "import numpy as np\n", 12 | "np.random.seed(42)\n", 13 | "import pandas as pd\n", 14 | "from collections import defaultdict\n", 15 | "from sklearn.metrics.pairwise import cosine_similarity\n", 16 | "import matplotlib.pyplot as plt" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "dataset = \"yelp_new\"\n", 26 | "sim_metric = \"cos\"\n", 27 | "topk = 100" 28 | ] 29 | }, 30 | { 31 | "attachments": {}, 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "### Get the topk similar user" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "# load the llm user embedding\n", 45 | "user_emb = pickle.load(open(os.path.join(dataset+\"/handled/\", \"usr_emb_np.pkl\"), \"rb\"))" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "# calculate the similarity score between users based on llm user embedding\n", 55 | "if sim_metric == \"sin\":\n", 56 | " score_matrix = np.dot(user_emb, user_emb.T)\n", 57 | "elif sim_metric == \"cos\":\n", 58 | " score_matrix = cosine_similarity(user_emb, user_emb)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "plt.hist(score_matrix[0], bins=10)\n", 68 | "plt.show()" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "rank_matrix = np.argsort(-score_matrix, axis=-1) # user id starts from 0" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "final_rank_matrix = rank_matrix[:, 1:]\n", 87 | "final_rank_matrix = final_rank_matrix[:, :topk]" 88 | ] 89 | }, 90 | { 91 | "attachments": {}, 92 | "cell_type": "markdown", 93 | "metadata": {}, 94 | "source": [ 95 | "### Get the sequence length of each user" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "User = defaultdict(list)\n", 105 | "seq_len = []\n", 106 | "usernum, itemnum = 0, 0\n", 107 | "f = open('./%s/handled/%s.txt' % (dataset, \"inter\"), 'r')\n", 108 | "for line in f: # use a dict to save all seqeuces of each user\n", 109 | " u, i = line.rstrip().split(' ')\n", 110 | " u = int(u)\n", 111 | " i = int(i)\n", 112 | " usernum = max(u, usernum)\n", 113 | " itemnum = max(i, itemnum)\n", 114 | " User[u].append(i)\n", 115 | "\n", 116 | "for user, seq in User.items():\n", 117 | " seq_len.append(len(seq))" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "sim_user_len = []\n", 127 | "for sim_user_list in final_rank_matrix:\n", 128 | " avg_len = 0\n", 129 | " for sim_user in sim_user_list:\n", 130 | " avg_len += seq_len[sim_user] / topk\n", 131 | " sim_user_len.append(avg_len)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "np.mean(sim_user_len), np.mean(seq_len)" 141 | ] 142 | }, 143 | { 144 | "attachments": {}, 145 | "cell_type": "markdown", 146 | "metadata": {}, 147 | "source": [ 148 | "### Select the similar user" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "sim_users = []\n", 158 | "for sim_user_list in final_rank_matrix:\n", 159 | " sim_users.append(np.random.choice(sim_user_list, 1)[0])" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "final_rank_matrix.shape" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "## Save llm embedding based similar users\n", 178 | "pickle.dump(final_rank_matrix, open(os.path.join(dataset+\"/handled/\", \"sim_user_100.pkl\"), \"wb\"))" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [] 187 | } 188 | ], 189 | "metadata": { 190 | "kernelspec": { 191 | "display_name": "llm", 192 | "language": "python", 193 | "name": "python3" 194 | }, 195 | "language_info": { 196 | "codemirror_mode": { 197 | "name": "ipython", 198 | "version": 3 199 | }, 200 | "file_extension": ".py", 201 | "mimetype": "text/x-python", 202 | "name": "python", 203 | "nbconvert_exporter": "python", 204 | "pygments_lexer": "ipython3", 205 | "version": "3.9.5" 206 | }, 207 | "orig_nbformat": 4 208 | }, 209 | "nbformat": 4, 210 | "nbformat_minor": 2 211 | } 212 | -------------------------------------------------------------------------------- /data/yelp/get_item_embedding.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import pickle\n", 11 | "import jsonlines\n", 12 | "import pandas as pd\n", 13 | "import numpy as np\n", 14 | "import json\n", 15 | "import copy\n", 16 | "from tqdm import tqdm" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "prompt_template = \"The point of interest has following attributes: \\n name is <NAME>; category is <CATEGORY>; type is <TYPE>; open status is <OPEN>; review count is <COUNT>; city is <CITY>; average score is <STARS>.\"" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "data = json.load(open(\"./handled/item2attributes.json\", \"r\"))" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "data[\"PzOqRohWw7F7YEPBz6AubA\"]" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "item_data = {}\n", 53 | "for key, value in tqdm(data.items()):\n", 54 | " item_str = copy.deepcopy(prompt_template)\n", 55 | " item_str = item_str.replace(\"<NAME>\", value[\"name\"])\n", 56 | " cate_str = \"\"\n", 57 | " for cate in value[\"categories\"]:\n", 58 | " cate_str += (cate + \" \")\n", 59 | " item_str = item_str.replace(\"<CATEGORY>\", cate_str)\n", 60 | " item_str = item_str.replace(\"<TYPE>\", value[\"type\"])\n", 61 | " item_str = item_str.replace(\"<OPEN>\", str(value[\"open\"]))\n", 62 | " item_str = item_str.replace(\"<COUNT>\", str(value[\"review_count\"]))\n", 63 | " item_str = item_str.replace(\"<CITY>\", value[\"city\"])\n", 64 | " item_str = item_str.replace(\"<STARS>\", str(value[\"stars\"]))\n", 65 | " \n", 66 | " item_data[key] = item_str" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "item_data['PzOqRohWw7F7YEPBz6AubA']" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "json.dump(item_data, open(\"./handled/item_str.json\", \"w\"))" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "item_data = json.load(open(\"./handled/item_str.json\", \"r\"))" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "import jsonlines\n", 103 | "\n", 104 | "def save_data(data_path, data):\n", 105 | " '''write all_data list to a new jsonl'''\n", 106 | " with jsonlines.open(\"./handled/\"+ data_path, \"w\") as w:\n", 107 | " for meta_data in data:\n", 108 | " w.write(meta_data)\n", 109 | "\n", 110 | "id_map = json.load(open(\"./handled/id_map.json\", \"r\"))[\"item2id\"]\n", 111 | "json_data = []\n", 112 | "for key, value in item_data.items():\n", 113 | " json_data.append({\"input\": value, \"target\": \"\", \"item\": key, \"item_id\": id_map[key]})\n", 114 | "\n", 115 | "save_data(\"item_str.jsonline\", json_data)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "import requests\n", 125 | "import json\n", 126 | "\n", 127 | "url = \"\"\n", 128 | "\n", 129 | "payload = json.dumps({\n", 130 | " \"model\": \"text-embedding-ada-002\",\n", 131 | " \"input\": \"The food was delicious and the waiter...\"\n", 132 | "})\n", 133 | "headers = {\n", 134 | " 'Authorization': '',\n", 135 | " 'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',\n", 136 | " 'Content-Type': 'application/json'\n", 137 | "}\n", 138 | "\n", 139 | "response = requests.request(\"POST\", url, headers=headers, data=payload)\n", 140 | "\n", 141 | "print(response.text)" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "re_json = json.loads(response.text)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "len(re_json[\"data\"][0][\"embedding\"])" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "def get_response(prompt):\n", 169 | " url = \"\"\n", 170 | "\n", 171 | " payload = json.dumps({\n", 172 | " \"model\": \"text-embedding-ada-002\",\n", 173 | " \"input\": prompt\n", 174 | " })\n", 175 | " headers = {\n", 176 | " 'Authorization': '',\n", 177 | " 'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',\n", 178 | " 'Content-Type': 'application/json'\n", 179 | " }\n", 180 | "\n", 181 | " response = requests.request(\"POST\", url, headers=headers, data=payload)\n", 182 | " re_json = json.loads(response.text)\n", 183 | "\n", 184 | " return re_json[\"data\"][0][\"embedding\"]" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "item_emb = {}" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "while 1: # avoid broken due to internet connection\n", 203 | " try:\n", 204 | " for key, value in tqdm(item_data.items()):\n", 205 | " if key not in item_emb.keys():\n", 206 | " item_emb[key] = get_response(value)\n", 207 | " except:\n", 208 | " continue\n", 209 | " if len(item_emb) == len(item_data):\n", 210 | " break" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "len(item_emb)" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "id_map = json.load(open(\"./handled/id_map.json\", \"r\"))[\"id2item\"]\n", 229 | "emb_list = []\n", 230 | "for id in range(1, len(item_emb)+1):\n", 231 | " meta_emb = item_emb[id_map[str(id)]]\n", 232 | " emb_list.append(meta_emb)\n", 233 | "\n", 234 | "emb_list = np.array(emb_list)\n", 235 | "pickle.dump(emb_list, open(\"./handled/itm_emb_np.pkl\", \"wb\"))" 236 | ] 237 | } 238 | ], 239 | "metadata": { 240 | "kernelspec": { 241 | "display_name": "llm", 242 | "language": "python", 243 | "name": "python3" 244 | }, 245 | "language_info": { 246 | "codemirror_mode": { 247 | "name": "ipython", 248 | "version": 3 249 | }, 250 | "file_extension": ".py", 251 | "mimetype": "text/x-python", 252 | "name": "python", 253 | "nbconvert_exporter": "python", 254 | "pygments_lexer": "ipython3", 255 | "version": "3.9.5" 256 | }, 257 | "orig_nbformat": 4 258 | }, 259 | "nbformat": 4, 260 | "nbformat_minor": 2 261 | } 262 | -------------------------------------------------------------------------------- /data/yelp/get_user_embedding.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import pickle\n", 11 | "import jsonlines\n", 12 | "import pandas as pd\n", 13 | "import numpy as np\n", 14 | "import json\n", 15 | "import copy\n", 16 | "from tqdm import tqdm\n", 17 | "from collections import defaultdict\n", 18 | "import requests" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "id_map = json.load(open(\"./handled/id_map.json\"))\n", 28 | "item_dict = json.load(open(\"./handled/item2attributes.json\", \"r\"))" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "def load_dataset():\n", 38 | " '''Load train, validation, test dataset'''\n", 39 | "\n", 40 | " usernum = 0\n", 41 | " itemnum = 0\n", 42 | " User = defaultdict(list) # default value is a blank list\n", 43 | " user_train = {}\n", 44 | " user_valid = {}\n", 45 | " user_test = {}\n", 46 | " # assume user/item index starting from 1\n", 47 | " f = open('./handled/inter.txt', 'r')\n", 48 | " for line in f: # use a dict to save all seqeuces of each user\n", 49 | " u, i = line.rstrip().split(' ')\n", 50 | " u = int(u)\n", 51 | " i = int(i)\n", 52 | " usernum = max(u, usernum)\n", 53 | " itemnum = max(i, itemnum)\n", 54 | " User[u].append(i)\n", 55 | "\n", 56 | " for user in tqdm(User):\n", 57 | " nfeedback = len(User[user])\n", 58 | " #nfeedback = len(User[user])\n", 59 | " if nfeedback < 3:\n", 60 | " user_train[user] = User[user]\n", 61 | " user_valid[user] = []\n", 62 | " user_test[user] = []\n", 63 | " else:\n", 64 | " user_train[user] = User[user][:-2]\n", 65 | " user_valid[user] = []\n", 66 | " user_valid[user].append(User[user][-2])\n", 67 | " user_test[user] = []\n", 68 | " user_test[user].append(User[user][-1])\n", 69 | " \n", 70 | " return user_train\n", 71 | " " 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "inter = load_dataset()" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "prompt_template = \"The user has visited following point of interests: \\n<HISTORY> \\nplease conclude the user's perference.\"" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "def get_user_prompt(history):\n", 99 | "\n", 100 | " user_str = copy.deepcopy(prompt_template)\n", 101 | " hist_str = \"\"\n", 102 | " for item in history:\n", 103 | " item_str = item_dict[id_map[\"id2item\"][str(item)]][\"name\"]\n", 104 | " hist_str = hist_str + item_str + \", \"\n", 105 | "\n", 106 | " # limit the prompt length\n", 107 | " if len(hist_str) > 8000:\n", 108 | " hist_str = hist_str[-8000:]\n", 109 | "\n", 110 | " user_str = user_str.replace(\"<HISTORY>\", hist_str)\n", 111 | "\n", 112 | " return user_str" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "user_data = {}\n", 122 | "\n", 123 | "for user, history in tqdm(inter.items()):\n", 124 | " user_data[user] = get_user_prompt(history)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "json.dump(user_data, open(\"./handled/user_str.json\", \"w\"))" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "user_data = json.load(open(\"./handled/user_str.json\", \"r\"))" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "url = \"\"\n", 152 | "\n", 153 | "payload = json.dumps({\n", 154 | " \"model\": \"text-embedding-ada-002\",\n", 155 | " \"input\": \"The food was delicious and the waiter...\"\n", 156 | "})\n", 157 | "headers = {\n", 158 | " 'Authorization': '',\n", 159 | " 'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',\n", 160 | " 'Content-Type': 'application/json'\n", 161 | "}\n", 162 | "\n", 163 | "response = requests.request(\"POST\", url, headers=headers, data=payload)\n", 164 | "\n", 165 | "print(response.text)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "def get_response(prompt):\n", 175 | " url = \"\"\n", 176 | "\n", 177 | " payload = json.dumps({\n", 178 | " \"model\": \"text-embedding-ada-002\",\n", 179 | " \"input\": prompt\n", 180 | " })\n", 181 | " headers = {\n", 182 | " 'Authorization': '',\n", 183 | " 'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',\n", 184 | " 'Content-Type': 'application/json'\n", 185 | " }\n", 186 | "\n", 187 | " response = requests.request(\"POST\", url, headers=headers, data=payload)\n", 188 | " re_json = json.loads(response.text)\n", 189 | "\n", 190 | " return re_json[\"data\"][0][\"embedding\"]" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "user_emb = {}" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "while 1: # avoid broken due to internet connection\n", 209 | " try:\n", 210 | " for key, value in tqdm(user_data.items()):\n", 211 | " if key not in user_emb.keys():\n", 212 | " user_emb[key] = get_response(value)\n", 213 | " except:\n", 214 | " continue\n", 215 | " if len(user_emb) == len(user_data):\n", 216 | " break" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "len(user_emb)" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "emb_list = []\n", 235 | "for key, value in tqdm(user_emb.items()):\n", 236 | " emb_list.append(value)\n", 237 | "\n", 238 | "emb_list = np.array(emb_list)\n", 239 | "pickle.dump(emb_list, open(\"./handled/usr_emb_np.pkl\", \"wb\"))" 240 | ] 241 | } 242 | ], 243 | "metadata": { 244 | "kernelspec": { 245 | "display_name": "llm", 246 | "language": "python", 247 | "name": "python3" 248 | }, 249 | "language_info": { 250 | "codemirror_mode": { 251 | "name": "ipython", 252 | "version": 3 253 | }, 254 | "file_extension": ".py", 255 | "mimetype": "text/x-python", 256 | "name": "python", 257 | "nbconvert_exporter": "python", 258 | "pygments_lexer": "ipython3", 259 | "version": "3.9.5" 260 | }, 261 | "orig_nbformat": 4 262 | }, 263 | "nbformat": 4, 264 | "nbformat_minor": 2 265 | } 266 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: llm 2 | channels: 3 | - moussi 4 | - psi4 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - asttokens=2.2.1=pyhd8ed1ab_0 11 | - backcall=0.2.0=pyh9f0ad1d_0 12 | - backports=1.0=pyhd8ed1ab_3 13 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 14 | - binutils_impl_linux-64=2.31.1=h6176602_1 15 | - binutils_linux-64=2.31.1=h6176602_9 16 | - ca-certificates=2023.5.7=hbcca054_0 17 | - cachetools=5.3.0=pyhd8ed1ab_0 18 | - cloog=0.18.0=0 19 | - decorator=5.1.1=pyhd8ed1ab_0 20 | - entrypoints=0.4=pyhd8ed1ab_0 21 | - executing=1.2.0=pyhd8ed1ab_0 22 | - gcc-5=5.2.0=1 23 | - gcc_impl_linux-64=7.3.0=hd420e75_5 24 | - gcc_linux-64=7.3.0=h553295d_9 25 | - gmp=6.2.1=h2531618_2 26 | - ipykernel=5.5.5=py39hef51801_0 27 | - ipython=8.14.0=pyh41d4057_0 28 | - ipython_genutils=0.2.0=py_1 29 | - isl=0.12.2=0 30 | - jedi=0.18.2=pyhd8ed1ab_0 31 | - jupyter_client=7.1.2=pyhd8ed1ab_0 32 | - jupyter_core=5.3.1=py39hf3d152e_0 33 | - ld_impl_linux-64=2.33.1=h53a641e_7 34 | - libffi=3.3=he6710b0_2 35 | - libgcc-ng=7.3.0=hdf63c60_0 36 | - libgomp=11.2.0=h1234567_1 37 | - libsodium=1.0.18=h516909a_0 38 | - libstdcxx-ng=11.2.0=h1234567_1 39 | - matplotlib-inline=0.1.6=pyhd8ed1ab_0 40 | - mpc=1.1.0=h10f8cd9_1 41 | - mpfr=4.0.2=hb69a4c5_1 42 | - ncurses=6.2=he6710b0_1 43 | - nest-asyncio=1.5.6=pyhd8ed1ab_0 44 | - nvidia-ml-py=11.525.112=pyhd8ed1ab_0 45 | - nvitop=1.1.2=py39hf3d152e_0 46 | - openssl=1.1.1k=h27cfd23_0 47 | - parso=0.8.3=pyhd8ed1ab_0 48 | - pexpect=4.8.0=pyh1a96a4e_2 49 | - pickleshare=0.7.5=py_1003 50 | - pip=23.1.2=py39h06a4308_0 51 | - platformdirs=3.5.3=pyhd8ed1ab_0 52 | - prompt-toolkit=3.0.38=pyha770c72_0 53 | - prompt_toolkit=3.0.38=hd8ed1ab_0 54 | - ptyprocess=0.7.0=pyhd3deb0d_0 55 | - pure_eval=0.2.2=pyhd8ed1ab_0 56 | - pygments=2.15.1=pyhd8ed1ab_0 57 | - python=3.9.5=hdb3f193_3 58 | - python-dateutil=2.8.2=pyhd8ed1ab_0 59 | - python_abi=3.9=2_cp39 60 | - pyzmq=20.0.0=py39h2531618_1 61 | - readline=8.1=h27cfd23_0 62 | - setuptools=67.8.0=py39h06a4308_0 63 | - six=1.16.0=pyh6c4a22f_0 64 | - sqlite=3.35.4=hdfb4753_0 65 | - stack_data=0.6.2=pyhd8ed1ab_0 66 | - termcolor=2.3.0=pyhd8ed1ab_0 67 | - tk=8.6.10=hbc83047_0 68 | - tornado=6.1=py39h27cfd23_0 69 | - traitlets=5.9.0=pyhd8ed1ab_0 70 | - typing-extensions=4.6.3=hd8ed1ab_0 71 | - typing_extensions=4.6.3=pyha770c72_0 72 | - wcwidth=0.2.6=pyhd8ed1ab_0 73 | - wheel=0.38.4=py39h06a4308_0 74 | - xz=5.2.5=h7b6447c_0 75 | - zeromq=4.3.4=h2531618_0 76 | - zlib=1.2.11=h7b6447c_3 77 | - pip: 78 | - absl-py==2.1.0 79 | - accelerate==0.18.0 80 | - aiohttp==3.8.4 81 | - aiosignal==1.3.1 82 | - args==0.1.0 83 | - async-timeout==4.0.2 84 | - attrs==23.1.0 85 | - bayesian-optimization==1.4.3 86 | - certifi==2023.5.7 87 | - charset-normalizer==3.1.0 88 | - click==8.1.3 89 | - clint==0.5.1 90 | - cma==3.3.0 91 | - colorama==0.4.6 92 | - conda-pack==0.7.1 93 | - contourpy==1.1.0 94 | - coverage==7.4.4 95 | - cpm-kernels==1.0.11 96 | - cycler==0.11.0 97 | - datasets==2.12.0 98 | - deepspeed==0.9.4 99 | - dill==0.3.6 100 | - dnc==1.1.0 101 | - docopt==0.6.2 102 | - docstring-parser==0.15 103 | - einops==0.7.0 104 | - et-xmlfile==1.1.0 105 | - filelock==3.12.0 106 | - flann==1.6.13 107 | - fonttools==4.42.0 108 | - frozenlist==1.3.3 109 | - fsspec==2023.5.0 110 | - future==0.18.3 111 | - grpcio==1.62.0 112 | - hjson==3.1.0 113 | - huggingface-hub==0.20.1 114 | - icetk==0.0.7 115 | - idna==3.4 116 | - importlib-metadata==7.0.1 117 | - importlib-resources==6.0.1 118 | - jieba==0.42.1 119 | - joblib==1.2.0 120 | - jsonlines==3.1.0 121 | - kiwisolver==1.4.4 122 | - mamba==0.11.3 123 | - markdown==3.5.2 124 | - markdown-it-py==3.0.0 125 | - markupsafe==2.1.5 126 | - matplotlib==3.7.2 127 | - mdurl==0.1.2 128 | - multidict==6.0.4 129 | - multiprocess==0.70.14 130 | - nevergrad==0.13.0 131 | - ninja==1.11.1 132 | - nltk==3.8.1 133 | - numpy==1.24.3 134 | - openai==0.28.0 135 | - openpyxl==3.1.2 136 | - openyxl==0.1 137 | - packaging==23.1 138 | - pandas==2.0.1 139 | - pillow==9.5.0 140 | - pipreqs==0.4.13 141 | - protobuf==4.25.3 142 | - psutil==5.9.5 143 | - py-cpuinfo==9.0.0 144 | - pyarrow==12.0.0 145 | - pydantic==1.10.9 146 | - pyparsing==3.0.9 147 | - pytz==2023.3 148 | - pyyaml==6.0 149 | - rdkit==2023.9.1 150 | - regex==2023.6.3 151 | - requests==2.31.0 152 | - responses==0.18.0 153 | - rich==13.7.0 154 | - rouge-chinese==1.0.3 155 | - safetensors==0.3.1 156 | - scikit-learn==1.3.0 157 | - scipy==1.11.1 158 | - sentencepiece==0.1.99 159 | - setproctitle==1.3.3 160 | - shtab==1.6.5 161 | - sklearn==0.0.post5 162 | - tensorboard==2.16.2 163 | - tensorboard-data-server==0.7.2 164 | - threadpoolctl==3.1.0 165 | - tiktoken==0.6.0 166 | - tokenizers==0.13.3 167 | - torch==1.12.0+cu102 168 | - torchaudio==0.12.0+cu102 169 | - torchvision==0.13.0+cu102 170 | - tqdm==4.65.0 171 | - transformers==4.28.1 172 | - trl==0.7.6 173 | - tyro==0.6.1 174 | - tzdata==2023.3 175 | - urllib3==2.0.3 176 | - werkzeug==3.0.1 177 | - xxhash==3.2.0 178 | - yarg==0.1.9 179 | - yarl==1.9.2 180 | - zipp==3.16.2 181 | prefix: /data/anaconda3/envs/llm 182 | -------------------------------------------------------------------------------- /experiments/beauty.bash: -------------------------------------------------------------------------------- 1 | ## LLM-ESR -- SASRec, Bert4Rec, GRU4Rec 2 | gpu_id=0 3 | dataset="beauty" 4 | seed_list=(42 43 44) 5 | ts_user=9 6 | ts_item=4 7 | 8 | model_name="llmesr_sasrec" 9 | for seed in ${seed_list[@]} 10 | do 11 | python main.py --dataset ${dataset} \ 12 | --model_name ${model_name} \ 13 | --hidden_size 64 \ 14 | --train_batch_size 128 \ 15 | --max_len 200 \ 16 | --gpu_id ${gpu_id} \ 17 | --num_workers 8 \ 18 | --num_train_epochs 200 \ 19 | --seed ${seed} \ 20 | --check_path "" \ 21 | --patience 20 \ 22 | --ts_user ${ts_user} \ 23 | --ts_item ${ts_item} \ 24 | --freeze \ 25 | --log \ 26 | --user_sim_func kd \ 27 | --alpha 0.1 \ 28 | --use_cross_att 29 | done 30 | 31 | 32 | model_name="llmesr_bert4rec" 33 | mask_prob=0.6 34 | for seed in ${seed_list[@]} 35 | do 36 | python main.py --dataset ${dataset} \ 37 | --model_name ${model_name} \ 38 | --hidden_size 64 \ 39 | --train_batch_size 128 \ 40 | --max_len 200 \ 41 | --gpu_id ${gpu_id} \ 42 | --num_workers 8 \ 43 | --mask_prob ${mask_prob} \ 44 | --num_train_epochs 200 \ 45 | --seed ${seed} \ 46 | --check_path "" \ 47 | --patience 20 \ 48 | --ts_user ${ts_user} \ 49 | --ts_item ${ts_item} \ 50 | --freeze \ 51 | --log \ 52 | --user_sim_func kd \ 53 | --alpha 0.1 \ 54 | --use_cross_att 55 | done 56 | 57 | 58 | model_name="llmesr_gru4rec" 59 | for seed in ${seed_list[@]} 60 | do 61 | python main.py --dataset ${dataset} \ 62 | --model_name ${model_name} \ 63 | --hidden_size 64 \ 64 | --train_batch_size 128 \ 65 | --max_len 200 \ 66 | --gpu_id ${gpu_id} \ 67 | --num_workers 8 \ 68 | --num_train_epochs 200 \ 69 | --seed ${seed} \ 70 | --check_path "" \ 71 | --patience 20 \ 72 | --ts_user ${ts_user} \ 73 | --ts_item ${ts_item} \ 74 | --freeze \ 75 | --log \ 76 | --user_sim_func kd \ 77 | --alpha 0.1 \ 78 | --use_cross_att 79 | done -------------------------------------------------------------------------------- /experiments/fashion.bash: -------------------------------------------------------------------------------- 1 | ## LLM-ESR -- SASRec, Bert4Rec, GRU4Rec 2 | gpu_id=0 3 | dataset="fashion" 4 | seed_list=(42 43 44) 5 | ts_user=3 6 | ts_item=4 7 | 8 | model_name="llmesr_sasrec" 9 | for seed in ${seed_list[@]} 10 | do 11 | python main.py --dataset ${dataset} \ 12 | --model_name ${model_name} \ 13 | --hidden_size 64 \ 14 | --train_batch_size 128 \ 15 | --max_len 200 \ 16 | --gpu_id ${gpu_id} \ 17 | --num_workers 8 \ 18 | --num_train_epochs 200 \ 19 | --seed ${seed} \ 20 | --check_path "" \ 21 | --patience 20 \ 22 | --ts_user ${ts_user} \ 23 | --ts_item ${ts_item} \ 24 | --freeze \ 25 | --log \ 26 | --user_sim_func kd \ 27 | --alpha 0.1 \ 28 | --use_cross_att 29 | done 30 | 31 | 32 | model_name="llmesr_bert4rec" 33 | mask_prob=0.6 34 | for seed in ${seed_list[@]} 35 | do 36 | python main.py --dataset ${dataset} \ 37 | --model_name ${model_name} \ 38 | --hidden_size 64 \ 39 | --train_batch_size 128 \ 40 | --max_len 200 \ 41 | --gpu_id ${gpu_id} \ 42 | --num_workers 8 \ 43 | --mask_prob ${mask_prob} \ 44 | --num_train_epochs 200 \ 45 | --seed ${seed} \ 46 | --check_path "" \ 47 | --patience 20 \ 48 | --ts_user ${ts_user} \ 49 | --ts_item ${ts_item} \ 50 | --freeze \ 51 | --log \ 52 | --user_sim_func kd \ 53 | --alpha 0.1 \ 54 | --use_cross_att 55 | done 56 | 57 | 58 | model_name="llmesr_gru4rec" 59 | for seed in ${seed_list[@]} 60 | do 61 | python main.py --dataset ${dataset} \ 62 | --model_name ${model_name} \ 63 | --hidden_size 64 \ 64 | --train_batch_size 128 \ 65 | --max_len 200 \ 66 | --gpu_id ${gpu_id} \ 67 | --num_workers 8 \ 68 | --num_train_epochs 200 \ 69 | --seed ${seed} \ 70 | --check_path "" \ 71 | --patience 20 \ 72 | --ts_user ${ts_user} \ 73 | --ts_item ${ts_item} \ 74 | --freeze \ 75 | --log \ 76 | --user_sim_func kd \ 77 | --alpha 0.1 \ 78 | --use_cross_att 79 | done -------------------------------------------------------------------------------- /experiments/yelp.bash: -------------------------------------------------------------------------------- 1 | ## LLM-ESR -- SASRec, Bert4Rec, GRU4Rec 2 | gpu_id=0 3 | dataset="yelp" 4 | seed_list=(42 43 44) 5 | 6 | model_name="llmesr_sasrec" 7 | for seed in ${seed_list[@]} 8 | do 9 | python main.py --dataset ${dataset} \ 10 | --model_name ${model_name} \ 11 | --hidden_size 64 \ 12 | --train_batch_size 128 \ 13 | --max_len 200 \ 14 | --gpu_id ${gpu_id} \ 15 | --num_workers 8 \ 16 | --num_train_epochs 200 \ 17 | --seed ${seed} \ 18 | --check_path "" \ 19 | --patience 20 \ 20 | --ts_user 12 \ 21 | --ts_item 13 \ 22 | --freeze \ 23 | --log \ 24 | --user_sim_func kd \ 25 | --alpha 0.1 \ 26 | --use_cross_att 27 | done 28 | 29 | 30 | model_name="llmesr_bert4rec" 31 | mask_prob=0.6 32 | for seed in ${seed_list[@]} 33 | do 34 | python main.py --dataset ${dataset} \ 35 | --model_name ${model_name} \ 36 | --hidden_size 64 \ 37 | --train_batch_size 128 \ 38 | --max_len 200 \ 39 | --gpu_id ${gpu_id} \ 40 | --num_workers 8 \ 41 | --mask_prob ${mask_prob} \ 42 | --num_train_epochs 200 \ 43 | --seed ${seed} \ 44 | --check_path "" \ 45 | --patience 20 \ 46 | --ts_user 12 \ 47 | --ts_item 13 \ 48 | --freeze \ 49 | --log \ 50 | --user_sim_func kd \ 51 | --alpha 0.1 \ 52 | --use_cross_att 53 | done 54 | 55 | 56 | model_name="llmesr_gru4rec" 57 | for seed in ${seed_list[@]} 58 | do 59 | python main.py --dataset ${dataset} \ 60 | --model_name ${model_name} \ 61 | --hidden_size 64 \ 62 | --train_batch_size 128 \ 63 | --max_len 200 \ 64 | --gpu_id ${gpu_id} \ 65 | --num_workers 8 \ 66 | --num_train_epochs 200 \ 67 | --seed ${seed} \ 68 | --check_path "" \ 69 | --patience 20 \ 70 | --ts_user 12 \ 71 | --ts_item 13 \ 72 | --freeze \ 73 | --log \ 74 | --user_sim_func kd \ 75 | --alpha 0.1 \ 76 | --use_cross_att 77 | done -------------------------------------------------------------------------------- /generators/bert_generator.py: -------------------------------------------------------------------------------- 1 | # here put the import lib 2 | from generators.generator import Generator 3 | from generators.data import BertRecTrainDatasetAllUser 4 | from torch.utils.data import DataLoader, RandomSampler 5 | from utils.utils import unzip_data 6 | 7 | 8 | 9 | class BertGeneratorAllUser(Generator): 10 | 11 | def __init__(self, args, logger, device): 12 | 13 | super().__init__(args, logger, device) 14 | 15 | 16 | def make_trainloader(self): 17 | 18 | train_dataset = unzip_data(self.train, aug=self.args.aug) 19 | self.train_dataset = BertRecTrainDatasetAllUser(self.args, train_dataset, self.item_num, self.args.max_len) 20 | 21 | train_dataloader = DataLoader(self.train_dataset, 22 | sampler=RandomSampler(self.train_dataset), 23 | batch_size=self.bs, 24 | num_workers=self.num_workers) 25 | 26 | return train_dataloader 27 | -------------------------------------------------------------------------------- /generators/data.py: -------------------------------------------------------------------------------- 1 | # here put the import lib 2 | import os 3 | import copy 4 | import random 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | from utils.utils import random_neq 8 | import pickle 9 | 10 | 11 | class SeqDataset(Dataset): 12 | '''The train dataset for Sequential recommendation''' 13 | 14 | def __init__(self, data, item_num, max_len, neg_num=1): 15 | 16 | super().__init__() 17 | self.data = data 18 | self.item_num = item_num 19 | self.max_len = max_len 20 | self.neg_num = neg_num 21 | self.var_name = ["seq", "pos", "neg", "positions"] 22 | 23 | 24 | def __len__(self): 25 | 26 | return len(self.data) 27 | 28 | def __getitem__(self, index): 29 | 30 | inter = self.data[index] 31 | non_neg = copy.deepcopy(inter) 32 | pos = inter[-1] 33 | neg = [] 34 | for _ in range(self.neg_num): 35 | per_neg = random_neq(1, self.item_num+1, non_neg) 36 | neg.append(per_neg) 37 | non_neg.append(per_neg) 38 | neg = np.array(neg) 39 | #neg = random_neq(1, self.item_num+1, inter) 40 | 41 | seq = np.zeros([self.max_len], dtype=np.int32) 42 | idx = self.max_len - 1 43 | for i in reversed(inter[:-1]): 44 | seq[idx] = i 45 | idx -= 1 46 | if idx == -1: 47 | break 48 | 49 | if len(inter) > self.max_len: 50 | mask_len = 0 51 | positions = list(range(1, self.max_len+1)) 52 | else: 53 | mask_len = self.max_len - (len(inter) - 1) 54 | positions = list(range(1, len(inter)-1+1)) 55 | 56 | positions= positions[-self.max_len:] 57 | positions = [0] * mask_len + positions 58 | positions = np.array(positions) 59 | 60 | return seq, pos, neg, positions 61 | 62 | 63 | 64 | class SeqDatasetAllUser(SeqDataset): 65 | '''The train dataset for Sequential recommendation''' 66 | 67 | def __init__(self, args, data, item_num, max_len, neg_num=1): 68 | 69 | super().__init__(data, item_num, max_len, neg_num) 70 | self.sim_user_num = args.sim_user_num 71 | self.sim_users = pickle.load(open(os.path.join("./data/"+args.dataset+"/handled/", "sim_user_100.pkl"), "rb")) 72 | self.var_name = ["seq", "pos", "neg", "positions", "user_id", "sim_seq", "sim_positions"] 73 | 74 | 75 | def __len__(self): 76 | 77 | return len(self.data) 78 | 79 | def __getitem__(self, index): 80 | 81 | inter = self.data[index] 82 | non_neg = copy.deepcopy(inter) 83 | pos = inter[-1] 84 | neg = [] 85 | for _ in range(self.neg_num): 86 | per_neg = random_neq(1, self.item_num+1, non_neg) 87 | neg.append(per_neg) 88 | non_neg.append(per_neg) 89 | neg = np.array(neg) 90 | #neg = random_neq(1, self.item_num+1, inter) 91 | 92 | seq = np.zeros([self.max_len], dtype=np.int32) 93 | idx = self.max_len - 1 94 | for i in reversed(inter[:-1]): 95 | seq[idx] = i 96 | idx -= 1 97 | if idx == -1: 98 | break 99 | 100 | if len(inter) > self.max_len: 101 | mask_len = 0 102 | positions = list(range(1, self.max_len+1)) 103 | else: 104 | mask_len = self.max_len - (len(inter) - 1) 105 | positions = list(range(1, len(inter)-1+1)) 106 | 107 | positions= positions[-self.max_len:] 108 | positions = [0] * mask_len + positions 109 | positions = np.array(positions) 110 | 111 | ### get the sequence of similar user 112 | sim_users = self.sim_users[index][:self.sim_user_num] 113 | sim_seq, sim_positions = [], [] 114 | for sim_user in sim_users: 115 | meta_seq, meta_positions = self._get_user_seq(sim_user) 116 | sim_seq.append(meta_seq) 117 | sim_positions.append(meta_positions) 118 | 119 | sim_seq = np.array(sim_seq) 120 | sim_positions = np.array(sim_positions) 121 | 122 | return seq, pos, neg, positions, index, sim_seq, sim_positions 123 | 124 | 125 | def _get_user_seq(self, user): 126 | 127 | ### get the sequence of required user 128 | inter = self.data[user] 129 | seq = np.zeros([self.max_len], dtype=np.int32) 130 | idx = self.max_len - 1 131 | for i in reversed(inter[:-1]): 132 | seq[idx] = i 133 | idx -= 1 134 | if idx == -1: 135 | break 136 | 137 | if len(inter) > self.max_len: 138 | mask_len = 0 139 | positions = list(range(1, self.max_len+1)) 140 | else: 141 | mask_len = self.max_len - (len(inter) - 1) 142 | positions = list(range(1, len(inter)-1+1)) 143 | 144 | positions = positions[-self.max_len:] 145 | positions = [0] * mask_len + positions 146 | positions = np.array(positions) 147 | 148 | return seq, positions 149 | 150 | 151 | 152 | class Seq2SeqDataset(Dataset): 153 | '''The train dataset for Sequential recommendation with seq-to-seq loss''' 154 | 155 | def __init__(self, args, data, item_num, max_len, neg_num=1): 156 | 157 | super().__init__() 158 | self.data = data 159 | self.item_num = item_num 160 | self.max_len = max_len 161 | self.neg_num = neg_num 162 | self.aug_seq = args.aug_seq 163 | self.aug_seq_len = args.aug_seq_len 164 | self.var_name = ["seq", "pos", "neg", "positions"] 165 | 166 | 167 | def __len__(self): 168 | 169 | return len(self.data) 170 | 171 | def __getitem__(self, index): 172 | 173 | inter = self.data[index] 174 | non_neg = copy.deepcopy(inter) 175 | 176 | seq = np.zeros([self.max_len], dtype=np.int32) 177 | pos = np.zeros([self.max_len], dtype=np.int32) 178 | neg = np.zeros([self.max_len], dtype=np.int32) 179 | nxt = inter[-1] 180 | idx = self.max_len - 1 181 | for i in reversed(inter[:-1]): 182 | seq[idx] = i 183 | pos[idx] = nxt 184 | neg[idx] = random_neq(1, self.item_num+1, non_neg) 185 | nxt = i 186 | idx -= 1 187 | if idx == -1: 188 | break 189 | 190 | if self.aug_seq: 191 | seq_len = len(inter) 192 | pos[:- (seq_len - self.aug_seq_len) + 1] = 0 193 | neg[:- (seq_len - self.aug_seq_len) + 1] = 0 194 | 195 | if len(inter) > self.max_len: 196 | mask_len = 0 197 | positions = list(range(1, self.max_len+1)) 198 | else: 199 | mask_len = self.max_len - (len(inter) - 1) 200 | positions = list(range(1, len(inter)-1+1)) 201 | 202 | positions= positions[-self.max_len:] 203 | positions = [0] * mask_len + positions 204 | positions = np.array(positions) 205 | 206 | return seq, pos, neg, positions 207 | 208 | 209 | 210 | class Seq2SeqDatasetAllUser(Seq2SeqDataset): 211 | 212 | def __init__(self, args, data, item_num, max_len, neg_num=1): 213 | 214 | super().__init__(args, data, item_num, max_len, neg_num) 215 | self.sim_user_num = args.sim_user_num 216 | self.sim_users = pickle.load(open(os.path.join("./data/"+args.dataset+"/handled/", "sim_user_100.pkl"), "rb")) 217 | self.var_name = ["seq", "pos", "neg", "positions", "user_id", "sim_seq", "sim_positions"] 218 | 219 | 220 | def __getitem__(self, index): 221 | 222 | inter = self.data[index] 223 | non_neg = copy.deepcopy(inter) 224 | 225 | seq = np.zeros([self.max_len], dtype=np.int32) 226 | pos = np.zeros([self.max_len], dtype=np.int32) 227 | neg = np.zeros([self.max_len], dtype=np.int32) 228 | nxt = inter[-1] 229 | idx = self.max_len - 1 230 | for i in reversed(inter[:-1]): 231 | seq[idx] = i 232 | pos[idx] = nxt 233 | neg[idx] = random_neq(1, self.item_num+1, non_neg) 234 | nxt = i 235 | idx -= 1 236 | if idx == -1: 237 | break 238 | 239 | if self.aug_seq: 240 | seq_len = len(inter) 241 | pos[:- (seq_len - self.aug_seq_len) + 1] = 0 242 | neg[:- (seq_len - self.aug_seq_len) + 1] = 0 243 | 244 | if len(inter) > self.max_len: 245 | mask_len = 0 246 | positions = list(range(1, self.max_len+1)) 247 | else: 248 | mask_len = self.max_len - (len(inter) - 1) 249 | positions = list(range(1, len(inter)-1+1)) 250 | 251 | positions = positions[-self.max_len:] 252 | positions = [0] * mask_len + positions 253 | positions = np.array(positions) 254 | 255 | ### get the sequence of similar user 256 | sim_users = self.sim_users[index][:self.sim_user_num] 257 | sim_seq, sim_positions = [], [] 258 | for sim_user in sim_users: 259 | meta_seq, meta_positions = self._get_user_seq(sim_user) 260 | sim_seq.append(meta_seq) 261 | sim_positions.append(meta_positions) 262 | 263 | sim_seq = np.array(sim_seq) 264 | sim_positions = np.array(sim_positions) 265 | 266 | return seq, pos, neg, positions, index, sim_seq, sim_positions 267 | 268 | 269 | def _get_user_seq(self, user): 270 | 271 | ### get the sequence of required user 272 | inter = self.data[user] 273 | seq = np.zeros([self.max_len], dtype=np.int32) 274 | idx = self.max_len - 1 275 | for i in reversed(inter[:-1]): 276 | seq[idx] = i 277 | idx -= 1 278 | if idx == -1: 279 | break 280 | 281 | if len(inter) > self.max_len: 282 | mask_len = 0 283 | positions = list(range(1, self.max_len+1)) 284 | else: 285 | mask_len = self.max_len - (len(inter) - 1) 286 | positions = list(range(1, len(inter)-1+1)) 287 | 288 | positions = positions[-self.max_len:] 289 | positions = [0] * mask_len + positions 290 | positions = np.array(positions) 291 | 292 | return seq, positions 293 | 294 | 295 | 296 | class BertRecTrainDatasetAllUser(Dataset): 297 | '''The train dataset for Bert4Rec''' 298 | 299 | def __init__(self, args, data, item_num, max_len, neg_num=1): 300 | 301 | super().__init__() 302 | self.data = data 303 | self.item_num = item_num 304 | self.max_len = max_len 305 | self.neg_num = neg_num 306 | self.mask_prob = args.mask_prob 307 | self.sim_user_num = args.sim_user_num 308 | self.mask_token = item_num + 1 309 | self.sim_users = pickle.load(open(os.path.join("./data/"+args.dataset+"/handled/", "sim_user_100.pkl"), "rb")) 310 | self.var_name = ["seq", "pos", "neg", "positions", "user_id", "sim_seq", "sim_positions"] 311 | 312 | 313 | def __len__(self): 314 | 315 | return 2 * len(self.data) 316 | 317 | def __getitem__(self, index): 318 | 319 | tokens = [] 320 | labels, neg_labels = [], [] 321 | 322 | if index >= len(self.data): 323 | seq = self.data[index - len(self.data)] 324 | for s in seq: 325 | tokens.append(s) 326 | labels.append(0) 327 | neg_labels.append(0) 328 | labels[-1] = tokens[-1] 329 | neg_labels[-1] = random_neq(1, self.item_num+1, seq) 330 | tokens[-1] = self.mask_token 331 | 332 | else: 333 | seq = self.data[index] 334 | 335 | for s in seq: 336 | prob = random.random() 337 | if prob < self.mask_prob: 338 | prob /= self.mask_prob 339 | 340 | if prob < 0.8: 341 | tokens.append(self.mask_token) 342 | elif prob < 0.9: 343 | tokens.append(random.randint(1, self.item_num)) 344 | else: 345 | tokens.append(s) 346 | 347 | labels.append(s) 348 | neg = random_neq(1, self.item_num+1, seq) 349 | neg_labels.append(neg) 350 | 351 | else: 352 | tokens.append(s) 353 | labels.append(0) 354 | neg_labels.append(0) 355 | 356 | tokens = tokens[-self.max_len:] 357 | labels = labels[-self.max_len:] 358 | neg_labels = neg_labels[-self.max_len:] 359 | pos = list(range(1, len(tokens)+1)) 360 | pos= pos[-self.max_len:] 361 | 362 | mask_len = self.max_len - len(tokens) 363 | 364 | tokens = [0] * mask_len + tokens 365 | labels = [0] * mask_len + labels 366 | neg_labels = [0] * mask_len + neg_labels 367 | pos = [0] * mask_len + pos 368 | 369 | if index >= len(self.data): 370 | user_id = index - len(self.data) 371 | else: 372 | user_id = index 373 | 374 | ### get the sequence of similar user 375 | sim_users = self.sim_users[user_id][:self.sim_user_num] 376 | sim_seq, sim_positions = [], [] 377 | for sim_user in sim_users: 378 | meta_seq, meta_positions = self._get_user_seq(sim_user) 379 | sim_seq.append(meta_seq) 380 | sim_positions.append(meta_positions) 381 | 382 | sim_seq = np.array(sim_seq) 383 | sim_positions = np.array(sim_positions) 384 | 385 | return np.array(tokens), np.array(labels), np.array(neg_labels), np.array(pos), user_id, sim_seq, sim_positions 386 | 387 | 388 | def _get_user_seq(self, user): 389 | 390 | ### get the sequence of required user 391 | inter = self.data[user] 392 | seq = np.zeros([self.max_len], dtype=np.int32) 393 | idx = self.max_len - 1 394 | for i in reversed(inter[:-1]): 395 | seq[idx] = i 396 | idx -= 1 397 | if idx == -1: 398 | break 399 | 400 | if len(inter) > self.max_len: 401 | mask_len = 0 402 | positions = list(range(1, self.max_len+1)) 403 | else: 404 | mask_len = self.max_len - (len(inter) - 1) 405 | positions = list(range(1, len(inter)-1+1)) 406 | 407 | positions = positions[-self.max_len:] 408 | positions = [0] * mask_len + positions 409 | positions = np.array(positions) 410 | 411 | return seq, positions 412 | 413 | -------------------------------------------------------------------------------- /generators/generator.py: -------------------------------------------------------------------------------- 1 | # here put the import lib 2 | import os 3 | import time 4 | import pickle 5 | import numpy as np 6 | import pandas as pd 7 | from tqdm import tqdm 8 | from collections import defaultdict 9 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 10 | from generators.data import SeqDataset, SeqDatasetAllUser, Seq2SeqDatasetAllUser 11 | from utils.utils import unzip_data, concat_data 12 | 13 | 14 | class Generator(object): 15 | 16 | def __init__(self, args, logger, device): 17 | 18 | self.args = args 19 | self.aug_file = args.aug_file 20 | self.inter_file = args.inter_file 21 | self.dataset = args.dataset 22 | self.num_workers = args.num_workers 23 | self.bs = args.train_batch_size 24 | self.logger = logger 25 | self.device = device 26 | self.aug_seq = args.aug_seq 27 | 28 | self.logger.info("Loading dataset ... ") 29 | start = time.time() 30 | self._load_dataset() 31 | end = time.time() 32 | self.logger.info("Dataset is loaded: consume %.3f s" % (end - start)) 33 | 34 | 35 | def _load_dataset(self): 36 | '''Load train, validation, test dataset''' 37 | 38 | usernum = 0 39 | itemnum = 0 40 | User = defaultdict(list) # default value is a blank list 41 | user_train = {} 42 | user_valid = {} 43 | user_test = {} 44 | # assume user/item index starting from 1 45 | f = open('./data/%s/handled/%s.txt' % (self.dataset, self.inter_file), 'r') 46 | for line in f: # use a dict to save all seqeuces of each user 47 | u, i = line.rstrip().split(' ') 48 | u = int(u) 49 | i = int(i) 50 | usernum = max(u, usernum) 51 | itemnum = max(i, itemnum) 52 | User[u].append(i) 53 | 54 | self.user_num = usernum 55 | self.item_num = itemnum 56 | 57 | for user in tqdm(User): 58 | nfeedback = len(User[user]) - self.args.aug_seq_len 59 | #nfeedback = len(User[user]) 60 | if nfeedback < 3: 61 | user_train[user] = User[user] 62 | user_valid[user] = [] 63 | user_test[user] = [] 64 | else: 65 | user_train[user] = User[user][:-2] 66 | user_valid[user] = [] 67 | user_valid[user].append(User[user][-2]) 68 | user_test[user] = [] 69 | user_test[user].append(User[user][-1]) 70 | 71 | self.train = user_train 72 | self.valid = user_valid 73 | self.test = user_test 74 | 75 | 76 | 77 | def make_trainloader(self): 78 | 79 | train_dataset = unzip_data(self.train, aug=self.args.aug, aug_num=self.args.aug_seq_len) 80 | self.train_dataset = SeqDataset(train_dataset, self.item_num, self.args.max_len, self.args.train_neg) 81 | 82 | train_dataloader = DataLoader(self.train_dataset, 83 | sampler=RandomSampler(self.train_dataset), 84 | batch_size=self.bs, 85 | num_workers=self.num_workers) 86 | 87 | 88 | return train_dataloader 89 | 90 | 91 | def make_evalloader(self, test=False): 92 | 93 | if test: 94 | eval_dataset = concat_data([self.train, self.valid, self.test]) 95 | 96 | else: 97 | eval_dataset = concat_data([self.train, self.valid]) 98 | 99 | self.eval_dataset = SeqDataset(eval_dataset, self.item_num, self.args.max_len, self.args.test_neg) 100 | eval_dataloader = DataLoader(self.eval_dataset, 101 | sampler=SequentialSampler(self.eval_dataset), 102 | batch_size=100, 103 | num_workers=self.num_workers) 104 | 105 | return eval_dataloader 106 | 107 | 108 | def get_user_item_num(self): 109 | 110 | return self.user_num, self.item_num 111 | 112 | 113 | def get_item_pop(self): 114 | """get item popularity according to item index. return a np-array""" 115 | all_data = concat_data([self.train, self.valid, self.test]) 116 | pop = np.zeros(self.item_num+1) # item index starts from 0 117 | 118 | for items in all_data: 119 | pop[items] += 1 120 | 121 | return pop 122 | 123 | 124 | def get_user_len(self): 125 | """get sequence length according to user index. return a np-array""" 126 | all_data = concat_data([self.train, self.valid]) 127 | lens = [] 128 | 129 | for user in all_data: 130 | lens.append(len(user)) 131 | 132 | return np.array(lens) 133 | 134 | 135 | 136 | class GeneratorAllUser(Generator): 137 | 138 | def __init__(self, args, logger, device): 139 | 140 | super().__init__(args, logger, device) 141 | 142 | 143 | def make_trainloader(self): 144 | 145 | train_dataset = unzip_data(self.train, aug=self.args.aug, aug_num=self.args.aug_seq_len) 146 | self.train_dataset = SeqDatasetAllUser(self.args, train_dataset, self.item_num, self.args.max_len, self.args.train_neg) 147 | 148 | train_dataloader = DataLoader(self.train_dataset, 149 | sampler=RandomSampler(self.train_dataset), 150 | batch_size=self.bs, 151 | num_workers=self.num_workers) 152 | 153 | return train_dataloader 154 | 155 | 156 | 157 | class Seq2SeqGeneratorAllUser(Generator): 158 | 159 | def __init__(self, args, logger, device): 160 | 161 | super().__init__(args, logger, device) 162 | 163 | 164 | def make_trainloader(self): 165 | 166 | train_dataset = unzip_data(self.train, aug=self.args.aug, aug_num=self.args.aug_seq_len) 167 | self.train_dataset = Seq2SeqDatasetAllUser(self.args, train_dataset, self.item_num, self.args.max_len, self.args.train_neg) 168 | 169 | train_dataloader = DataLoader(self.train_dataset, 170 | sampler=RandomSampler(self.train_dataset), 171 | batch_size=self.bs, 172 | num_workers=self.num_workers) 173 | 174 | return train_dataloader 175 | 176 | 177 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # here put the import lib 2 | import os 3 | import argparse 4 | import torch 5 | 6 | from generators.generator import Seq2SeqGeneratorAllUser 7 | from generators.generator import GeneratorAllUser 8 | from generators.bert_generator import BertGeneratorAllUser 9 | from trainers.sequence_trainer import SeqTrainer 10 | from utils.utils import set_seed 11 | from utils.logger import Logger 12 | 13 | 14 | parser = argparse.ArgumentParser() 15 | 16 | # Required parameters 17 | parser.add_argument("--model_name", 18 | default='llmesr_sasrec', 19 | choices=[ 20 | "llmesr_sasrec", "llmesr_bert4rec", "llmesr_gru4rec", 21 | ], 22 | type=str, 23 | required=False, 24 | help="model name") 25 | parser.add_argument("--dataset", 26 | default="yelp", 27 | choices=["yelp", "fashion", "beauty",], # preprocess by myself 28 | help="Choose the dataset") 29 | parser.add_argument("--inter_file", 30 | default="inter", 31 | type=str, 32 | help="the name of interaction file") 33 | parser.add_argument("--demo", 34 | default=False, 35 | action='store_true', 36 | help='whether run demo') 37 | parser.add_argument("--pretrain_dir", 38 | type=str, 39 | default="sasrec_seq", 40 | help="the path that pretrained model saved in") 41 | parser.add_argument("--output_dir", 42 | default='./saved/', 43 | type=str, 44 | required=False, 45 | help="The output directory where the model checkpoints will be written.") 46 | parser.add_argument("--check_path", 47 | default='', 48 | type=str, 49 | help="the save path of checkpoints for different running") 50 | parser.add_argument("--do_test", 51 | default=False, 52 | action="store_true", 53 | help="whehther run the test on the well-trained model") 54 | parser.add_argument("--do_emb", 55 | default=False, 56 | action="store_true", 57 | help="save the user embedding derived from the SRS model") 58 | parser.add_argument("--do_group", 59 | default=False, 60 | action="store_true", 61 | help="conduct the group test") 62 | parser.add_argument("--keepon", 63 | default=False, 64 | action="store_true", 65 | help="whether keep on training based on a trained model") 66 | parser.add_argument("--keepon_path", 67 | type=str, 68 | default="normal", 69 | help="the path of trained model for keep on training") 70 | parser.add_argument("--clip_path", 71 | type=str, 72 | default="", 73 | help="the path to save the CLIP-pretrained embedding and adapter") 74 | parser.add_argument("--ts_user", 75 | type=int, 76 | default=10, 77 | help="the threshold to split the short and long seq") 78 | parser.add_argument("--ts_item", 79 | type=int, 80 | default=20, 81 | help="the threshold to split the long-tail and popular items") 82 | 83 | # Model parameters 84 | parser.add_argument("--hidden_size", 85 | default=64, 86 | type=int, 87 | help="the hidden size of embedding") 88 | parser.add_argument("--trm_num", 89 | default=2, 90 | type=int, 91 | help="the number of transformer layer") 92 | parser.add_argument("--num_heads", 93 | default=1, 94 | type=int, 95 | help="the number of heads in Trm layer") 96 | parser.add_argument("--num_layers", 97 | default=1, 98 | type=int, 99 | help="the number of GRU layers") 100 | parser.add_argument("--cl_scale", 101 | type=float, 102 | default=0.1, 103 | help="the scale for contastive loss") 104 | parser.add_argument("--mask_crop_ratio", 105 | type=float, 106 | default=0.3, 107 | help="the mask/crop ratio for CL4SRec") 108 | parser.add_argument("--tau", 109 | default=1, 110 | type=float, 111 | help="the temperature for contrastive loss") 112 | parser.add_argument("--sse_ratio", 113 | default=0.4, 114 | type=float, 115 | help="the sse ratio for SSE-PT model") 116 | parser.add_argument("--dropout_rate", 117 | default=0.5, 118 | type=float, 119 | help="the dropout rate") 120 | parser.add_argument("--max_len", 121 | default=200, 122 | type=int, 123 | help="the max length of input sequence") 124 | parser.add_argument("--mask_prob", 125 | type=float, 126 | default=0.4, 127 | help="the mask probability for training Bert model") 128 | parser.add_argument("--aug", 129 | default=False, 130 | action="store_true", 131 | help="whether augment the sequence data") 132 | parser.add_argument("--aug_seq", 133 | default=False, 134 | action="store_true", 135 | help="whether use the augmented data") 136 | parser.add_argument("--aug_seq_len", 137 | default=0, 138 | type=int, 139 | help="the augmented length for each sequence") 140 | parser.add_argument("--aug_file", 141 | default="inter", 142 | type=str, 143 | help="the augmentation file name") 144 | parser.add_argument("--train_neg", 145 | default=1, 146 | type=int, 147 | help="the number of negative samples for training") 148 | parser.add_argument("--test_neg", 149 | default=100, 150 | type=int, 151 | help="the number of negative samples for test") 152 | parser.add_argument("--suffix_num", 153 | default=5, 154 | type=int, 155 | help="the suffix number for augmented sequence") 156 | parser.add_argument("--prompt_num", 157 | default=2, 158 | type=int, 159 | help="the number of prompts") 160 | parser.add_argument("--freeze", 161 | default=False, 162 | action="store_true", 163 | help="whether freeze the pretrained architecture when finetuning") 164 | parser.add_argument("--pg", 165 | default="length", 166 | choices=['length', 'attention'], 167 | type=str, 168 | help="choose the prompt generator") 169 | parser.add_argument("--use_cross_att", 170 | default=False, 171 | action="store_true", 172 | help="whether add a cross-attention to interact the dual-view") 173 | parser.add_argument("--alpha", 174 | default=0.1, 175 | type=float, 176 | help="the weight of auxiliary loss") 177 | parser.add_argument("--user_sim_func", 178 | default="kd", 179 | type=str, 180 | help="the type of user similarity function to derive the loss") 181 | parser.add_argument("--item_reg", 182 | default=False, 183 | action="store_true", 184 | help="whether regularize the item embedding by CL") 185 | parser.add_argument("--beta", 186 | default=0.1, 187 | type=float, 188 | help="the weight of regulation loss") 189 | parser.add_argument("--sim_user_num", 190 | default=10, 191 | type=int, 192 | help="the number of similar users for enhancement") 193 | parser.add_argument("--split_backbone", 194 | default=False, 195 | action="store_true", 196 | help="whether use a split backbone") 197 | parser.add_argument("--co_view", 198 | default=False, 199 | action="store_true", 200 | help="only use the collaborative view") 201 | parser.add_argument("--se_view", 202 | default=False, 203 | action="store_true", 204 | help="only use the semantic view") 205 | 206 | 207 | # Other parameters 208 | parser.add_argument("--train_batch_size", 209 | default=512, 210 | type=int, 211 | help="Total batch size for training.") 212 | parser.add_argument("--lr", 213 | default=0.001, 214 | type=float, 215 | help="The initial learning rate for Adam.") 216 | parser.add_argument("--l2", 217 | default=0, 218 | type=float, 219 | help='The L2 regularization') 220 | parser.add_argument("--num_train_epochs", 221 | default=100, 222 | type=float, 223 | help="Total number of training epochs to perform.") 224 | parser.add_argument("--lr_dc_step", 225 | default=1000, 226 | type=int, 227 | help='every n step, decrease the lr') 228 | parser.add_argument("--lr_dc", 229 | default=0, 230 | type=float, 231 | help='how many learning rate to decrease') 232 | parser.add_argument("--patience", 233 | type=int, 234 | default=20, 235 | help='How many steps to tolerate the performance decrease while training') 236 | parser.add_argument("--watch_metric", 237 | type=str, 238 | default='NDCG@10', 239 | help="which metric is used to select model.") 240 | parser.add_argument('--seed', 241 | type=int, 242 | default=42, 243 | help="random seed for different data split") 244 | parser.add_argument("--no_cuda", 245 | action='store_true', 246 | help="Whether not to use CUDA when available") 247 | parser.add_argument('--gpu_id', 248 | default=0, 249 | type=int, 250 | help='The device id.') 251 | parser.add_argument('--num_workers', 252 | default=0, 253 | type=int, 254 | help='The number of workers in dataloader') 255 | parser.add_argument("--log", 256 | default=False, 257 | action="store_true", 258 | help="whether create a new log file") 259 | 260 | torch.autograd.set_detect_anomaly(True) 261 | 262 | args = parser.parse_args() 263 | set_seed(args.seed) # fix the random seed 264 | args.output_dir = os.path.join(args.output_dir, args.dataset) 265 | args.pretrain_dir = os.path.join(args.output_dir, args.pretrain_dir) 266 | args.output_dir = os.path.join(args.output_dir, args.model_name) 267 | args.keepon_path = os.path.join(args.output_dir, args.keepon_path) 268 | args.output_dir = os.path.join(args.output_dir, args.check_path) # if check_path is none, then without check_path 269 | 270 | 271 | def main(): 272 | 273 | log_manager = Logger(args) # initialize the log manager 274 | logger, writer = log_manager.get_logger() # get the logger 275 | args.now_str = log_manager.get_now_str() 276 | 277 | device = torch.device("cuda:"+str(args.gpu_id) if torch.cuda.is_available() 278 | and not args.no_cuda else "cpu") 279 | 280 | 281 | os.makedirs(args.output_dir, exist_ok=True) 282 | 283 | # generator is used to manage dataset 284 | if args.model_name in ['llmesr_gru4rec']: 285 | generator = GeneratorAllUser(args, logger, device) 286 | elif args.model_name in ["llmesr_bert4rec"]: 287 | generator = BertGeneratorAllUser(args, logger, device) 288 | elif args.model_name in ["llmesr_sasrec"]: 289 | generator = Seq2SeqGeneratorAllUser(args, logger, device) 290 | else: 291 | raise ValueError 292 | 293 | trainer = SeqTrainer(args, logger, writer, device, generator) 294 | 295 | if args.do_test: 296 | trainer.test() 297 | elif args.do_emb: 298 | trainer.save_user_emb() 299 | elif args.do_group: 300 | trainer.test_group() 301 | else: 302 | trainer.train() 303 | 304 | log_manager.end_log() # delete the logger threads 305 | 306 | 307 | 308 | if __name__ == "__main__": 309 | 310 | main() 311 | 312 | 313 | 314 | -------------------------------------------------------------------------------- /models/BaseModel.py: -------------------------------------------------------------------------------- 1 | # here put the import lib 2 | import torch.nn as nn 3 | 4 | 5 | 6 | class BaseSeqModel(nn.Module): 7 | 8 | def __init__(self, user_num, item_num, device, args) -> None: 9 | 10 | super().__init__() 11 | 12 | self.user_num = user_num 13 | self.item_num = item_num 14 | self.dev = device 15 | self.freeze_modules = [] 16 | self.filter_init_modules = [] # all modules should be initialized 17 | 18 | 19 | def _freeze(self): 20 | 21 | for name, param in self.named_parameters(): 22 | try: 23 | flag = False 24 | for fm in self.freeze_modules: 25 | if fm in name: # if the param in freeze_modules, freeze it 26 | flag = True 27 | if flag: 28 | param.requires_grad = False 29 | except: 30 | pass 31 | 32 | 33 | def _init_weights(self): 34 | 35 | for name, param in self.named_parameters(): 36 | try: 37 | flag = True # denote initialize this param 38 | for fm in self.filter_init_modules: 39 | if fm in name: # if the param in filter_modules, do not initialize 40 | flag = False 41 | if flag: 42 | nn.init.xavier_normal_(param.data) 43 | except: 44 | pass 45 | 46 | 47 | def _get_embedding(self, log_seqs): 48 | 49 | raise NotImplementedError("The function for sequence embedding is missed") 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /models/Bert4Rec.py: -------------------------------------------------------------------------------- 1 | # here put the import lib 2 | import os 3 | import pickle 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from models.utils import * 9 | from models.BaseModel import BaseSeqModel 10 | 11 | 12 | 13 | class BertBackbone(nn.Module): 14 | 15 | def __init__(self, device, args) -> None: 16 | 17 | super().__init__() 18 | 19 | self.dev = device 20 | self.attention_layernorms = torch.nn.ModuleList() # to be Q for self-attention 21 | self.attention_layers = torch.nn.ModuleList() 22 | self.forward_layernorms = torch.nn.ModuleList() 23 | self.forward_layers = torch.nn.ModuleList() 24 | 25 | self.last_layernorm = torch.nn.LayerNorm(args.hidden_size, eps=1e-8) 26 | 27 | for _ in range(args.trm_num): 28 | new_attn_layernorm = torch.nn.LayerNorm(args.hidden_size, eps=1e-8) 29 | self.attention_layernorms.append(new_attn_layernorm) 30 | 31 | new_attn_layer = torch.nn.MultiheadAttention(args.hidden_size, 32 | args.num_heads, 33 | args.dropout_rate) 34 | self.attention_layers.append(new_attn_layer) 35 | 36 | new_fwd_layernorm = torch.nn.LayerNorm(args.hidden_size, eps=1e-8) 37 | self.forward_layernorms.append(new_fwd_layernorm) 38 | 39 | new_fwd_layer = PointWiseFeedForward(args.hidden_size, args.dropout_rate) 40 | self.forward_layers.append(new_fwd_layer) 41 | 42 | 43 | def forward(self, seqs, log_seqs): 44 | 45 | timeline_mask = (log_seqs == 0) 46 | seqs *= ~timeline_mask.unsqueeze(-1) # broadcast in last dim 47 | 48 | for i in range(len(self.attention_layers)): 49 | seqs = torch.transpose(seqs, 0, 1) 50 | Q = self.attention_layernorms[i](seqs) 51 | mha_outputs, _ = self.attention_layers[i](Q, seqs, seqs,) 52 | seqs = Q + mha_outputs 53 | seqs = torch.transpose(seqs, 0, 1) 54 | 55 | seqs = self.forward_layernorms[i](seqs) 56 | seqs = self.forward_layers[i](seqs) 57 | seqs *= ~timeline_mask.unsqueeze(-1) 58 | 59 | log_feats = self.last_layernorm(seqs) # (U, T, C) -> (U, -1, C) 60 | 61 | return log_feats 62 | 63 | 64 | 65 | class Bert4Rec(BaseSeqModel): 66 | 67 | def __init__(self, user_num, item_num, device, args): 68 | 69 | super(Bert4Rec, self).__init__(user_num, item_num, device, args) 70 | 71 | self.mask_token = item_num + 1 72 | self.item_emb = torch.nn.Embedding(self.item_num+2, args.hidden_size, padding_idx=0) 73 | self.pos_emb = torch.nn.Embedding(args.max_len+100, args.hidden_size) # TO IMPROVE 74 | self.emb_dropout = torch.nn.Dropout(p=args.dropout_rate) 75 | 76 | self.backbone = BertBackbone(self.dev, args) 77 | 78 | self.loss_func = torch.nn.BCEWithLogitsLoss() 79 | self._init_weights() 80 | 81 | 82 | def _get_embedding(self, log_seqs): 83 | 84 | item_seq_emb = self.item_emb(log_seqs) 85 | 86 | return item_seq_emb 87 | 88 | 89 | def log2feats(self, log_seqs, positions): 90 | 91 | seqs = self._get_embedding(log_seqs) 92 | seqs *= self.item_emb.embedding_dim ** 0.5 # QKV/sqrt(D) 93 | seqs += self.pos_emb(positions.long()) 94 | seqs = self.emb_dropout(seqs) 95 | 96 | log_feats = self.backbone(seqs, log_seqs) 97 | 98 | return log_feats 99 | 100 | 101 | def forward(self, 102 | seq, 103 | pos, 104 | neg, 105 | positions, 106 | **kwargs): # for training 107 | 108 | log_feats = self.log2feats(seq, positions) # (bs, max_len, hidden_size) 109 | mask_index = torch.where(pos>0) 110 | log_feats = log_feats[mask_index] # (bs, mask_num, hidden_size) 111 | 112 | pos_embs = self._get_embedding(pos) # (bs, mask_num, hidden_size) 113 | neg_embs = self._get_embedding(neg) # (bs, mask_num, hidden_size) 114 | pos_embs = pos_embs[mask_index] 115 | neg_embs = neg_embs[mask_index] 116 | 117 | pos_logits = torch.mul(log_feats, pos_embs).sum(dim=-1) # (bs, mask_num) 118 | neg_logits = torch.mul(log_feats, neg_embs).sum(dim=-1) # (bs, mask_num) 119 | 120 | pos_labels, neg_labels = torch.ones(pos_logits.shape, device=self.dev), torch.zeros(neg_logits.shape, device=self.dev) 121 | pos_loss, neg_loss = self.loss_func(pos_logits, pos_labels), self.loss_func(neg_logits, neg_labels) 122 | loss = pos_loss + neg_loss 123 | 124 | return loss # loss 125 | 126 | 127 | def predict(self, 128 | seq, 129 | item_indices, 130 | positions, 131 | **kwargs): # for inference 132 | 133 | log_seqs = torch.cat([seq, self.mask_token * torch.ones(seq.shape[0], 1, device=self.dev)], dim=1) 134 | pred_position = positions[:, -1] + 1 135 | positions = torch.cat([positions, pred_position.unsqueeze(1)], dim=1) 136 | log_feats = self.log2feats(log_seqs[:, 1:].long(), positions[:, 1:].long()) # user_ids hasn't been used yet 137 | 138 | final_feat = log_feats[:, -1, :] # only use last QKV classifier, a waste 139 | item_embs = self._get_embedding(item_indices) # (U, I, C) 140 | logits = item_embs.matmul(final_feat.unsqueeze(-1)).squeeze(-1) 141 | 142 | return logits # preds # (U, I) 143 | 144 | 145 | 146 | -------------------------------------------------------------------------------- /models/DualLLMSRS.py: -------------------------------------------------------------------------------- 1 | # here put the import lib 2 | import os 3 | import pickle 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from models.GRU4Rec import GRU4Rec 8 | from models.SASRec import SASRec_seq 9 | from models.Bert4Rec import Bert4Rec 10 | from models.utils import Multi_CrossAttention 11 | 12 | 13 | 14 | class DualLLMGRU4Rec(GRU4Rec): 15 | 16 | def __init__(self, user_num, item_num, device, args): 17 | 18 | super().__init__(user_num, item_num, device, args) 19 | 20 | self.mask_token = item_num + 1 21 | self.num_heads = args.num_heads 22 | self.use_cross_att = args.use_cross_att 23 | 24 | # load llm embedding as item embedding 25 | llm_item_emb = pickle.load(open(os.path.join("data/"+args.dataset+"/handled/", "itm_emb_np.pkl"), "rb")) 26 | llm_item_emb = np.insert(llm_item_emb, 0, values=np.zeros((1, llm_item_emb.shape[1])), axis=0) 27 | llm_item_emb = np.concatenate([llm_item_emb, np.zeros((1, llm_item_emb.shape[1]))], axis=0) 28 | self.llm_item_emb = nn.Embedding.from_pretrained(torch.Tensor(llm_item_emb)) 29 | self.llm_item_emb.weight.requires_grad = True # the grad is false in default 30 | self.adapter = nn.Sequential( 31 | nn.Linear(llm_item_emb.shape[1], int(llm_item_emb.shape[1] / 2)), 32 | nn.Linear(int(llm_item_emb.shape[1] / 2), args.hidden_size) 33 | ) 34 | 35 | id_item_emb = pickle.load(open(os.path.join("data/"+args.dataset+"/handled/", "pca64_itm_emb_np.pkl"), "rb")) 36 | id_item_emb = np.insert(id_item_emb, 0, values=np.zeros((1, id_item_emb.shape[1])), axis=0) 37 | id_item_emb = np.concatenate([id_item_emb, np.zeros((1, id_item_emb.shape[1]))], axis=0) 38 | self.id_item_emb = nn.Embedding.from_pretrained(torch.Tensor(id_item_emb)) 39 | self.id_item_emb.weight.requires_grad = True # the grad is false in default 40 | # self.id_item_emb = torch.nn.Embedding(self.item_num+2, args.hidden_size, padding_idx=0) 41 | 42 | self.pos_emb = torch.nn.Embedding(args.max_len+100, args.hidden_size) # TO IMPROVE 43 | self.emb_dropout = torch.nn.Dropout(p=args.dropout_rate) 44 | 45 | if self.use_cross_att: 46 | self.llm2id = Multi_CrossAttention(args.hidden_size, args.hidden_size, 2) 47 | self.id2llm = Multi_CrossAttention(args.hidden_size, args.hidden_size, 2) 48 | 49 | if args.freeze: # freeze the llm embedding 50 | self.freeze_modules = ["llm_item_emb"] 51 | self._freeze() 52 | 53 | self.filter_init_modules = ["llm_item_emb", "id_item_emb"] 54 | self._init_weights() 55 | 56 | 57 | def _get_embedding(self, log_seqs): 58 | 59 | id_seq_emb = self.id_item_emb(log_seqs) 60 | llm_seq_emb = self.llm_item_emb(log_seqs) 61 | llm_seq_emb = self.adapter(llm_seq_emb) 62 | 63 | item_seq_emb = torch.cat([id_seq_emb, llm_seq_emb], dim=-1) 64 | 65 | return item_seq_emb 66 | 67 | 68 | def log2feats(self, log_seqs): 69 | 70 | id_seqs = self.id_item_emb(log_seqs) 71 | llm_seqs = self.llm_item_emb(log_seqs) 72 | llm_seqs = self.adapter(llm_seqs) 73 | 74 | if self.use_cross_att: 75 | cross_id_seqs = self.llm2id(llm_seqs, id_seqs, log_seqs) 76 | cross_llm_seqs = self.id2llm(id_seqs, llm_seqs, log_seqs) 77 | else: 78 | cross_id_seqs = id_seqs 79 | cross_llm_seqs = llm_seqs 80 | 81 | id_log_feats = self.backbone(cross_id_seqs, log_seqs) 82 | llm_log_feats = self.backbone(cross_llm_seqs, log_seqs) 83 | 84 | log_feats = torch.cat([id_log_feats, llm_log_feats], dim=-1) 85 | 86 | return log_feats 87 | 88 | 89 | 90 | class DualLLMSASRec(SASRec_seq): 91 | 92 | def __init__(self, user_num, item_num, device, args): 93 | 94 | super().__init__(user_num, item_num, device, args) 95 | 96 | # self.user_num = user_num 97 | # self.item_num = item_num 98 | # self.dev = device 99 | self.mask_token = item_num + 1 100 | self.num_heads = args.num_heads 101 | self.use_cross_att = args.use_cross_att 102 | 103 | # load llm embedding as item embedding 104 | # llm_item_emb = pickle.load(open(os.path.join("data/"+args.dataset, "pca_itm_emb_np.pkl"), "rb")) 105 | llm_item_emb = pickle.load(open(os.path.join("data/"+args.dataset+"/handled/", "itm_emb_np.pkl"), "rb")) 106 | llm_item_emb = np.insert(llm_item_emb, 0, values=np.zeros((1, llm_item_emb.shape[1])), axis=0) 107 | llm_item_emb = np.concatenate([llm_item_emb, np.zeros((1, llm_item_emb.shape[1]))], axis=0) 108 | self.llm_item_emb = nn.Embedding.from_pretrained(torch.Tensor(llm_item_emb)) 109 | self.llm_item_emb.weight.requires_grad = True # the grad is false in default 110 | # self.adapter = nn.Linear(llm_item_emb.shape[1], args.hidden_size) 111 | self.adapter = nn.Sequential( 112 | nn.Linear(llm_item_emb.shape[1], int(llm_item_emb.shape[1] / 2)), 113 | nn.Linear(int(llm_item_emb.shape[1] / 2), args.hidden_size) 114 | ) 115 | 116 | id_item_emb = pickle.load(open(os.path.join("data/"+args.dataset+"/handled/", "pca64_itm_emb_np.pkl"), "rb")) 117 | id_item_emb = np.insert(id_item_emb, 0, values=np.zeros((1, id_item_emb.shape[1])), axis=0) 118 | id_item_emb = np.concatenate([id_item_emb, np.zeros((1, id_item_emb.shape[1]))], axis=0) 119 | self.id_item_emb = nn.Embedding.from_pretrained(torch.Tensor(id_item_emb)) 120 | self.id_item_emb.weight.requires_grad = True # the grad is false in default 121 | # self.id_item_emb = torch.nn.Embedding(self.item_num+2, args.hidden_size, padding_idx=0) 122 | 123 | self.pos_emb = torch.nn.Embedding(args.max_len+100, args.hidden_size) # TO IMPROVE 124 | self.emb_dropout = torch.nn.Dropout(p=args.dropout_rate) 125 | 126 | if self.use_cross_att: 127 | self.llm2id = Multi_CrossAttention(args.hidden_size, args.hidden_size, 2) 128 | self.id2llm = Multi_CrossAttention(args.hidden_size, args.hidden_size, 2) 129 | 130 | if args.freeze: # freeze the llm embedding 131 | self.freeze_modules = ["llm_item_emb"] 132 | self._freeze() 133 | 134 | self.filter_init_modules = ["llm_item_emb", "id_item_emb"] 135 | self._init_weights() 136 | 137 | 138 | def _get_embedding(self, log_seqs): 139 | 140 | id_seq_emb = self.id_item_emb(log_seqs) 141 | llm_seq_emb = self.llm_item_emb(log_seqs) 142 | llm_seq_emb = self.adapter(llm_seq_emb) 143 | 144 | item_seq_emb = torch.cat([id_seq_emb, llm_seq_emb], dim=-1) 145 | 146 | return item_seq_emb 147 | 148 | 149 | def log2feats(self, log_seqs, positions): 150 | 151 | id_seqs = self.id_item_emb(log_seqs) 152 | id_seqs *= self.id_item_emb.embedding_dim ** 0.5 # QKV/sqrt(D) 153 | id_seqs += self.pos_emb(positions.long()) 154 | id_seqs = self.emb_dropout(id_seqs) 155 | 156 | llm_seqs = self.llm_item_emb(log_seqs) 157 | llm_seqs = self.adapter(llm_seqs) 158 | llm_seqs *= self.id_item_emb.embedding_dim ** 0.5 # QKV/sqrt(D) 159 | llm_seqs += self.pos_emb(positions.long()) 160 | llm_seqs = self.emb_dropout(llm_seqs) 161 | 162 | if self.use_cross_att: 163 | cross_id_seqs = self.llm2id(llm_seqs, id_seqs, log_seqs) 164 | cross_llm_seqs = self.id2llm(id_seqs, llm_seqs, log_seqs) 165 | cross_id_seqs = 1 * cross_id_seqs + 0 * id_seqs 166 | cross_llm_seqs = 1 * cross_llm_seqs + 0 * llm_seqs 167 | else: 168 | cross_id_seqs = id_seqs 169 | cross_llm_seqs = llm_seqs 170 | 171 | id_log_feats = self.backbone(cross_id_seqs, log_seqs) 172 | llm_log_feats = self.backbone(cross_llm_seqs, log_seqs) 173 | 174 | log_feats = torch.cat([id_log_feats, llm_log_feats], dim=-1) 175 | 176 | return log_feats 177 | 178 | 179 | 180 | class DualLLMBert4Rec(Bert4Rec): 181 | 182 | def __init__(self, user_num, item_num, device, args): 183 | 184 | super().__init__(user_num, item_num, device, args) 185 | 186 | # self.user_num = user_num 187 | # self.item_num = item_num 188 | # self.dev = device 189 | self.mask_token = item_num + 1 190 | self.num_heads = args.num_heads 191 | self.use_cross_att = args.use_cross_att 192 | 193 | # load llm embedding as item embedding 194 | # llm_item_emb = pickle.load(open(os.path.join("data/"+args.dataset, "pca_itm_emb_np.pkl"), "rb")) 195 | llm_item_emb = pickle.load(open(os.path.join("data/"+args.dataset+"/handled/", "itm_emb_np.pkl"), "rb")) 196 | llm_item_emb = np.insert(llm_item_emb, 0, values=np.zeros((1, llm_item_emb.shape[1])), axis=0) 197 | llm_item_emb = np.concatenate([llm_item_emb, np.zeros((1, llm_item_emb.shape[1]))], axis=0) 198 | self.llm_item_emb = nn.Embedding.from_pretrained(torch.Tensor(llm_item_emb)) 199 | self.llm_item_emb.weight.requires_grad = True # the grad is false in default 200 | # self.adapter = nn.Linear(llm_item_emb.shape[1], args.hidden_size) 201 | self.adapter = nn.Sequential( 202 | nn.Linear(llm_item_emb.shape[1], int(llm_item_emb.shape[1] / 2)), 203 | nn.Linear(int(llm_item_emb.shape[1] / 2), args.hidden_size) 204 | ) 205 | 206 | id_item_emb = pickle.load(open(os.path.join("data/"+args.dataset+"/handled/", "pca64_itm_emb_np.pkl"), "rb")) 207 | id_item_emb = np.insert(id_item_emb, 0, values=np.zeros((1, id_item_emb.shape[1])), axis=0) 208 | id_item_emb = np.concatenate([id_item_emb, np.zeros((1, id_item_emb.shape[1]))], axis=0) 209 | self.id_item_emb = nn.Embedding.from_pretrained(torch.Tensor(id_item_emb)) 210 | self.id_item_emb.weight.requires_grad = True # the grad is false in default 211 | # self.id_item_emb = torch.nn.Embedding(self.item_num+2, args.hidden_size, padding_idx=0) 212 | 213 | self.pos_emb = torch.nn.Embedding(args.max_len+100, args.hidden_size) # TO IMPROVE 214 | self.emb_dropout = torch.nn.Dropout(p=args.dropout_rate) 215 | 216 | if self.use_cross_att: 217 | self.llm2id = Multi_CrossAttention(args.hidden_size, args.hidden_size, 2) 218 | self.id2llm = Multi_CrossAttention(args.hidden_size, args.hidden_size, 2) 219 | 220 | if args.freeze: # freeze the llm embedding 221 | self.freeze_modules = ["llm_item_emb"] 222 | self._freeze() 223 | 224 | self.filter_init_modules = ["llm_item_emb", "id_item_emb"] 225 | self._init_weights() 226 | 227 | 228 | def _get_embedding(self, log_seqs): 229 | 230 | id_seq_emb = self.id_item_emb(log_seqs) 231 | llm_seq_emb = self.llm_item_emb(log_seqs) 232 | llm_seq_emb = self.adapter(llm_seq_emb) 233 | 234 | item_seq_emb = torch.cat([id_seq_emb, llm_seq_emb], dim=-1) 235 | 236 | return item_seq_emb 237 | 238 | 239 | def log2feats(self, log_seqs, positions): 240 | 241 | id_seqs = self.id_item_emb(log_seqs) 242 | id_seqs *= self.id_item_emb.embedding_dim ** 0.5 # QKV/sqrt(D) 243 | id_seqs += self.pos_emb(positions.long()) 244 | id_seqs = self.emb_dropout(id_seqs) 245 | 246 | llm_seqs = self.llm_item_emb(log_seqs) 247 | llm_seqs = self.adapter(llm_seqs) 248 | llm_seqs *= self.id_item_emb.embedding_dim ** 0.5 # QKV/sqrt(D) 249 | llm_seqs += self.pos_emb(positions.long()) 250 | llm_seqs = self.emb_dropout(llm_seqs) 251 | 252 | if self.use_cross_att: 253 | cross_id_seqs = self.llm2id(llm_seqs, id_seqs, log_seqs) 254 | cross_llm_seqs = self.id2llm(id_seqs, llm_seqs, log_seqs) 255 | else: 256 | cross_id_seqs = id_seqs 257 | cross_llm_seqs = llm_seqs 258 | 259 | id_log_feats = self.backbone(cross_id_seqs, log_seqs) 260 | llm_log_feats = self.backbone(cross_llm_seqs, log_seqs) 261 | 262 | log_feats = torch.cat([id_log_feats, llm_log_feats], dim=-1) 263 | 264 | return log_feats 265 | -------------------------------------------------------------------------------- /models/GRU4Rec.py: -------------------------------------------------------------------------------- 1 | # here put the import lib 2 | import torch 3 | import torch.nn as nn 4 | from models.BaseModel import BaseSeqModel 5 | 6 | 7 | 8 | class GRU4RecBackbone(nn.Module): 9 | 10 | def __init__(self, device, args) -> None: 11 | 12 | super().__init__() 13 | 14 | self.dev = device 15 | self.gru = nn.GRU( 16 | input_size=args.hidden_size, 17 | hidden_size=args.hidden_size, 18 | num_layers=args.num_layers, 19 | bias=False, 20 | batch_first=True 21 | ) 22 | 23 | 24 | def forward(self, seqs, log_seqs): 25 | 26 | log_feats, _ = self.gru(seqs) 27 | 28 | return log_feats 29 | 30 | 31 | 32 | class GRU4Rec(BaseSeqModel): 33 | 34 | def __init__(self, user_num, item_num, device, args) -> None: 35 | 36 | super(GRU4Rec, self).__init__(user_num, item_num, device, args) 37 | 38 | self.dev = device 39 | self.item_emb = nn.Embedding(self.item_num+2, args.hidden_size, padding_idx=0) 40 | self.backbone = GRU4RecBackbone(device, args) 41 | 42 | self.loss_func = nn.BCEWithLogitsLoss() 43 | 44 | self._init_weights() 45 | 46 | 47 | def _get_embedding(self, log_seqs): 48 | 49 | item_seq_emb = self.item_emb(log_seqs) 50 | 51 | return item_seq_emb 52 | 53 | 54 | def log2feats(self, log_seqs): 55 | 56 | seqs = self.item_emb(log_seqs) 57 | mask = (log_seqs > 0).unsqueeze(-1) 58 | # seqs *= mask # the padding input is 0 59 | log_feats = self.backbone(seqs, log_seqs) 60 | 61 | return log_feats 62 | 63 | 64 | def forward(self, 65 | seq, 66 | pos, 67 | neg, 68 | positions, 69 | **kwargs): 70 | # inputs: (bs, max_seq_len, hidden_size), mask: (bs, max_seq_len) 71 | log_feats = self.log2feats(seq) 72 | log_feats = log_feats[:, -1, :].unsqueeze(1) 73 | 74 | pos_embs = self._get_embedding(pos.unsqueeze(1)) # (bs, 1, hidden_size) 75 | neg_embs = self._get_embedding(neg) # (bs, neg_num, hidden_size) 76 | 77 | pos_logits = torch.mul(log_feats, pos_embs).sum(dim=-1) # (bs, 1) 78 | neg_logits = torch.mul(log_feats, neg_embs).sum(dim=-1) # (bs, neg_num) 79 | pos_labels, neg_labels = torch.ones(pos_logits.shape, device=self.dev), torch.zeros(neg_logits.shape, device=self.dev) 80 | 81 | indices = (pos != 0) # do not calculate the padding units 82 | pos_loss, neg_loss = self.loss_func(pos_logits[indices], pos_labels[indices]), self.loss_func(neg_logits[indices], neg_labels[indices]) 83 | loss = pos_loss + neg_loss 84 | 85 | return loss 86 | 87 | 88 | def predict(self, 89 | seq, 90 | item_indices, 91 | positions, 92 | **kwargs): # for inference 93 | '''Used to predict the score of item_indices given log_seqs''' 94 | log_feats = self.log2feats(seq) 95 | final_feat = log_feats[:, -1, :] 96 | item_embs = self._get_embedding(item_indices) # (U, I, C) 97 | logits = item_embs.matmul(final_feat.unsqueeze(-1)).squeeze(-1) 98 | 99 | return logits # preds # (U, I) 100 | 101 | 102 | 103 | class GRU4Rec_seq(GRU4Rec): 104 | 105 | def __init__(self, user_num, item_num, device, args) -> None: 106 | 107 | super().__init__(user_num, item_num, device, args) 108 | 109 | 110 | def forward(self, 111 | seq, 112 | pos, 113 | neg, 114 | positions, 115 | **kwargs): 116 | '''apply the seq-to-seq loss''' 117 | log_feats = self.log2feats(seq) 118 | pos_embs = self._get_embedding(pos) # (bs, max_seq_len, hidden_size) 119 | neg_embs = self._get_embedding(neg) # (bs, max_seq_len, hidden_size) 120 | 121 | pos_logits = (log_feats * pos_embs).sum(dim=-1) 122 | neg_logits = (log_feats * neg_embs).sum(dim=-1) 123 | 124 | pos_labels, neg_labels = torch.ones(pos_logits.shape, device=self.dev), torch.zeros(neg_logits.shape, device=self.dev) 125 | indices = (pos != 0) # do not calculate the padding units 126 | pos_loss, neg_loss = self.loss_func(pos_logits[indices], pos_labels[indices]), self.loss_func(neg_logits[indices], neg_labels[indices]) 127 | loss = pos_loss + neg_loss 128 | 129 | return loss 130 | 131 | -------------------------------------------------------------------------------- /models/LLMESR.py: -------------------------------------------------------------------------------- 1 | # here put the import lib 2 | import torch 3 | import torch.nn as nn 4 | from models.DualLLMSRS import DualLLMSASRec, DualLLMGRU4Rec, DualLLMBert4Rec 5 | from models.utils import Contrastive_Loss2 6 | 7 | 8 | 9 | class LLMESR_SASRec(DualLLMSASRec): 10 | 11 | def __init__(self, user_num, item_num, device, args): 12 | 13 | super().__init__(user_num, item_num, device, args) 14 | self.alpha = args.alpha 15 | self.user_sim_func = args.user_sim_func 16 | self.item_reg = args.item_reg 17 | 18 | if self.user_sim_func == "cl": 19 | self.align = Contrastive_Loss2() 20 | elif self.user_sim_func == "kd": 21 | self.align = nn.MSELoss() 22 | else: 23 | raise ValueError 24 | 25 | self.projector1 = nn.Linear(2*args.hidden_size, 2*args.hidden_size) 26 | self.projector2 = nn.Linear(2*args.hidden_size, 2*args.hidden_size) 27 | 28 | if self.item_reg: 29 | self.beta = args.beta 30 | self.reg = Contrastive_Loss2() 31 | 32 | self._init_weights() 33 | 34 | 35 | def forward(self, 36 | seq, 37 | pos, 38 | neg, 39 | positions, 40 | **kwargs): 41 | 42 | loss = super().forward(seq, pos, neg, positions, **kwargs) # get the original loss 43 | 44 | log_feats = self.log2feats(seq, positions)[:, -1, :] 45 | sim_seq, sim_positions = kwargs["sim_seq"].view(-1, seq.shape[1]), kwargs["sim_positions"].view(-1, seq.shape[1]) 46 | sim_num = kwargs["sim_seq"].shape[1] 47 | sim_log_feats = self.log2feats(sim_seq, sim_positions)[:, -1, :] # (bs*sim_num, hidden_size) 48 | sim_log_feats = sim_log_feats.detach().view(seq.shape[0], sim_num, -1) # (bs, sim_num, hidden_size) 49 | sim_log_feats = torch.mean(sim_log_feats, dim=1) 50 | 51 | if self.user_sim_func == "cl": 52 | # align_loss = self.align(self.projector1(log_feats), self.projector2(sim_log_feats)) 53 | align_loss = self.align(log_feats, sim_log_feats) 54 | elif self.user_sim_func == "kd": 55 | align_loss = self.align(log_feats, sim_log_feats) 56 | 57 | if self.item_reg: 58 | unfold_item_id = torch.masked_select(seq, seq>0) 59 | llm_item_emb = self.adapter(self.llm_item_emb(unfold_item_id)) 60 | id_item_emb = self.id_item_emb(unfold_item_id) 61 | reg_loss = self.reg(llm_item_emb, id_item_emb) 62 | loss += self.beta * reg_loss 63 | 64 | loss += self.alpha * align_loss 65 | 66 | return loss 67 | 68 | 69 | 70 | class LLMESR_GRU4Rec(DualLLMGRU4Rec): 71 | 72 | def __init__(self, user_num, item_num, device, args): 73 | 74 | super().__init__(user_num, item_num, device, args) 75 | self.alpha = args.alpha 76 | self.user_sim_func = args.user_sim_func 77 | self.item_reg = args.item_reg 78 | 79 | if self.user_sim_func == "cl": 80 | self.align = Contrastive_Loss2() 81 | elif self.user_sim_func == "kd": 82 | self.align = nn.MSELoss() 83 | else: 84 | raise ValueError 85 | 86 | self.projector1 = nn.Linear(2*args.hidden_size, 2*args.hidden_size) 87 | self.projector2 = nn.Linear(2*args.hidden_size, 2*args.hidden_size) 88 | 89 | if self.item_reg: 90 | self.beta = args.beta 91 | self.reg = Contrastive_Loss2() 92 | 93 | self._init_weights() 94 | 95 | 96 | def forward(self, 97 | seq, 98 | pos, 99 | neg, 100 | positions, 101 | **kwargs): 102 | 103 | loss = super().forward(seq, pos, neg, positions, **kwargs) # get the original loss 104 | 105 | log_feats = self.log2feats(seq)[:, -1, :] 106 | sim_seq, sim_positions = kwargs["sim_seq"].view(-1, seq.shape[1]), kwargs["sim_positions"].view(-1, seq.shape[1]) 107 | sim_num = kwargs["sim_seq"].shape[1] 108 | sim_log_feats = self.log2feats(sim_seq)[:, -1, :] # (bs*sim_num, hidden_size) 109 | sim_log_feats = sim_log_feats.detach().view(seq.shape[0], sim_num, -1) # (bs, sim_num, hidden_size) 110 | sim_log_feats = torch.mean(sim_log_feats, dim=1) 111 | 112 | if self.user_sim_func == "cl": 113 | # align_loss = self.align(self.projector1(log_feats), self.projector2(sim_log_feats)) 114 | align_loss = self.align(log_feats, sim_log_feats) 115 | elif self.user_sim_func == "kd": 116 | align_loss = self.align(log_feats, sim_log_feats) 117 | 118 | if self.item_reg: 119 | unfold_item_id = torch.masked_select(seq, seq>0) 120 | llm_item_emb = self.adapter(self.llm_item_emb(unfold_item_id)) 121 | id_item_emb = self.id_item_emb(unfold_item_id) 122 | reg_loss = self.reg(llm_item_emb, id_item_emb) 123 | loss += self.beta * reg_loss 124 | 125 | loss += self.alpha * align_loss 126 | 127 | return loss 128 | 129 | 130 | 131 | class LLMESR_Bert4Rec(DualLLMBert4Rec): 132 | 133 | def __init__(self, user_num, item_num, device, args): 134 | 135 | super().__init__(user_num, item_num, device, args) 136 | self.alpha = args.alpha 137 | self.user_sim_func = args.user_sim_func 138 | self.item_reg = args.item_reg 139 | 140 | if self.user_sim_func == "cl": 141 | self.align = Contrastive_Loss2() 142 | elif self.user_sim_func == "kd": 143 | self.align = nn.MSELoss() 144 | else: 145 | raise ValueError 146 | 147 | self.projector1 = nn.Linear(2*args.hidden_size, 2*args.hidden_size) 148 | self.projector2 = nn.Linear(2*args.hidden_size, 2*args.hidden_size) 149 | 150 | if self.item_reg: 151 | self.reg = Contrastive_Loss2() 152 | 153 | self._init_weights() 154 | 155 | 156 | def forward(self, 157 | seq, 158 | pos, 159 | neg, 160 | positions, 161 | **kwargs): 162 | 163 | loss = super().forward(seq, pos, neg, positions, **kwargs) # get the original loss 164 | 165 | log_feats = self.log2feats(seq, positions)[:, -1, :] 166 | sim_seq, sim_positions = kwargs["sim_seq"].view(-1, seq.shape[1]), kwargs["sim_positions"].view(-1, seq.shape[1]) 167 | sim_num = kwargs["sim_seq"].shape[1] 168 | sim_log_feats = self.log2feats(sim_seq, sim_positions)[:, -1, :] 169 | sim_log_feats = sim_log_feats.detach().view(seq.shape[0], sim_num, -1) # (bs, sim_num, hidden_size) 170 | sim_log_feats = torch.mean(sim_log_feats, dim=1) 171 | 172 | if self.user_sim_func == "cl": 173 | # align_loss = self.align(self.projector1(log_feats), self.projector2(sim_log_feats)) 174 | align_loss = self.align(log_feats, sim_log_feats) 175 | elif self.user_sim_func == "kd": 176 | align_loss = self.align(log_feats, sim_log_feats) 177 | 178 | loss += self.alpha * align_loss 179 | 180 | return loss 181 | 182 | 183 | 184 | -------------------------------------------------------------------------------- /models/SASRec.py: -------------------------------------------------------------------------------- 1 | # here put the import lib 2 | import numpy as np 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | import pickle 7 | from models.BaseModel import BaseSeqModel 8 | from models.utils import PointWiseFeedForward 9 | 10 | 11 | 12 | class SASRecBackbone(nn.Module): 13 | 14 | def __init__(self, device, args) -> None: 15 | 16 | super().__init__() 17 | 18 | self.dev = device 19 | self.attention_layernorms = torch.nn.ModuleList() # to be Q for self-attention 20 | self.attention_layers = torch.nn.ModuleList() 21 | self.forward_layernorms = torch.nn.ModuleList() 22 | self.forward_layers = torch.nn.ModuleList() 23 | 24 | self.last_layernorm = torch.nn.LayerNorm(args.hidden_size, eps=1e-8) 25 | 26 | for _ in range(args.trm_num): 27 | new_attn_layernorm = torch.nn.LayerNorm(args.hidden_size, eps=1e-8) 28 | self.attention_layernorms.append(new_attn_layernorm) 29 | 30 | new_attn_layer = torch.nn.MultiheadAttention(args.hidden_size, 31 | args.num_heads, 32 | args.dropout_rate) 33 | self.attention_layers.append(new_attn_layer) 34 | 35 | new_fwd_layernorm = torch.nn.LayerNorm(args.hidden_size, eps=1e-8) 36 | self.forward_layernorms.append(new_fwd_layernorm) 37 | 38 | new_fwd_layer = PointWiseFeedForward(args.hidden_size, args.dropout_rate) 39 | self.forward_layers.append(new_fwd_layer) 40 | 41 | 42 | def forward(self, seqs, log_seqs): 43 | 44 | #timeline_mask = torch.BoolTensor(log_seqs == 0).to(self.dev) 45 | timeline_mask = (log_seqs == 0) 46 | seqs *= ~timeline_mask.unsqueeze(-1) # broadcast in last dim 47 | 48 | tl = seqs.shape[1] # time dim len for enforce causality 49 | attention_mask = ~torch.tril(torch.ones((tl, tl), dtype=torch.bool, device=self.dev)) 50 | 51 | for i in range(len(self.attention_layers)): 52 | seqs = torch.transpose(seqs, 0, 1) 53 | Q = self.attention_layernorms[i](seqs) 54 | mha_outputs, _ = self.attention_layers[i](Q, seqs, seqs,) 55 | # attn_mask=attention_mask) 56 | seqs = Q + mha_outputs 57 | seqs = torch.transpose(seqs, 0, 1) 58 | 59 | seqs = self.forward_layernorms[i](seqs) 60 | seqs = self.forward_layers[i](seqs) 61 | seqs *= ~timeline_mask.unsqueeze(-1) 62 | 63 | log_feats = self.last_layernorm(seqs) # (U, T, C) -> (U, -1, C) 64 | 65 | return log_feats 66 | 67 | 68 | 69 | class SASRec(BaseSeqModel): 70 | 71 | def __init__(self, user_num, item_num, device, args): 72 | 73 | super(SASRec, self).__init__(user_num, item_num, device, args) 74 | 75 | # self.user_num = user_num 76 | # self.item_num = item_num 77 | # self.dev = device 78 | 79 | self.item_emb = torch.nn.Embedding(self.item_num+2, args.hidden_size, padding_idx=0) 80 | self.pos_emb = torch.nn.Embedding(args.max_len+100, args.hidden_size) # TO IMPROVE 81 | self.emb_dropout = torch.nn.Dropout(p=args.dropout_rate) 82 | 83 | self.backbone = SASRecBackbone(device, args) 84 | 85 | self.loss_func = torch.nn.BCEWithLogitsLoss() 86 | 87 | # self.filter_init_modules = [] 88 | self._init_weights() 89 | 90 | 91 | def _get_embedding(self, log_seqs): 92 | 93 | item_seq_emb = self.item_emb(log_seqs) 94 | 95 | return item_seq_emb 96 | 97 | 98 | def log2feats(self, log_seqs, positions): 99 | '''Get the representation of given sequence''' 100 | seqs = self._get_embedding(log_seqs) 101 | seqs *= self.item_emb.embedding_dim ** 0.5 102 | seqs += self.pos_emb(positions.long()) 103 | seqs = self.emb_dropout(seqs) 104 | 105 | log_feats = self.backbone(seqs, log_seqs) 106 | 107 | return log_feats 108 | 109 | 110 | def forward(self, 111 | seq, 112 | pos, 113 | neg, 114 | positions): # for training 115 | '''Used to calculate pos and neg logits for loss''' 116 | log_feats = self.log2feats(seq, positions) # (bs, max_len, hidden_size) 117 | log_feats = log_feats[:, -1, :].unsqueeze(1) # (bs, hidden_size) 118 | 119 | pos_embs = self._get_embedding(pos.unsqueeze(1)) # (bs, 1, hidden_size) 120 | neg_embs = self._get_embedding(neg) # (bs, neg_num, hidden_size) 121 | 122 | pos_logits = torch.mul(log_feats, pos_embs).sum(dim=-1) # (bs, 1) 123 | neg_logits = torch.mul(log_feats, neg_embs).sum(dim=-1) # (bs, neg_num) 124 | 125 | pos_labels, neg_labels = torch.ones(pos_logits.shape, device=self.dev), torch.zeros(neg_logits.shape, device=self.dev) 126 | indices = (pos != 0) # do not calculate the padding units 127 | pos_loss, neg_loss = self.loss_func(pos_logits[indices], pos_labels[indices]), self.loss_func(neg_logits[indices], neg_labels[indices]) 128 | loss = pos_loss + neg_loss 129 | 130 | return loss # loss 131 | 132 | 133 | def predict(self, 134 | seq, 135 | item_indices, 136 | positions, 137 | **kwargs): # for inference 138 | '''Used to predict the score of item_indices given log_seqs''' 139 | log_feats = self.log2feats(seq, positions) # user_ids hasn't been used yet 140 | final_feat = log_feats[:, -1, :] # only use last QKV classifier, a waste 141 | item_embs = self._get_embedding(item_indices) # (U, I, C) 142 | logits = item_embs.matmul(final_feat.unsqueeze(-1)).squeeze(-1) 143 | 144 | return logits # preds # (U, I) 145 | 146 | 147 | def get_user_emb(self, 148 | seq, 149 | positions, 150 | **kwargs): 151 | log_feats = self.log2feats(seq, positions) # user_ids hasn't been used yet 152 | final_feat = log_feats[:, -1, :] # only use last QKV classifier, a waste 153 | 154 | return final_feat 155 | 156 | 157 | 158 | class SASRec_seq(SASRec): 159 | 160 | def __init__(self, user_num, item_num, device, args): 161 | 162 | super().__init__(user_num, item_num, device, args) 163 | 164 | 165 | def forward(self, 166 | seq, 167 | pos, 168 | neg, 169 | positions, 170 | **kwargs): 171 | '''apply the seq-to-seq loss''' 172 | log_feats = self.log2feats(seq, positions) 173 | pos_embs = self._get_embedding(pos) # (bs, max_seq_len, hidden_size) 174 | neg_embs = self._get_embedding(neg) # (bs, max_seq_len, hidden_size) 175 | 176 | pos_logits = (log_feats * pos_embs).sum(dim=-1) 177 | neg_logits = (log_feats * neg_embs).sum(dim=-1) 178 | 179 | pos_labels, neg_labels = torch.ones(pos_logits.shape, device=self.dev), torch.zeros(neg_logits.shape, device=self.dev) 180 | indices = (pos != 0) # do not calculate the padding units 181 | pos_loss, neg_loss = self.loss_func(pos_logits[indices], pos_labels[indices]), self.loss_func(neg_logits[indices], neg_labels[indices]) 182 | loss = pos_loss + neg_loss 183 | 184 | return loss 185 | 186 | 187 | 188 | 189 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | # here put the import lib 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from math import sqrt 7 | 8 | 9 | class PointWiseFeedForward(torch.nn.Module): 10 | def __init__(self, hidden_units, dropout_rate): 11 | 12 | super(PointWiseFeedForward, self).__init__() 13 | 14 | self.conv1 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1) 15 | self.dropout1 = torch.nn.Dropout(p=dropout_rate) 16 | self.relu = torch.nn.ReLU() 17 | self.conv2 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1) 18 | self.dropout2 = torch.nn.Dropout(p=dropout_rate) 19 | 20 | def forward(self, inputs): 21 | outputs = self.dropout2(self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2)))))) 22 | outputs = outputs.transpose(-1, -2) # as Conv1D requires (N, C, Length) 23 | outputs += inputs 24 | return outputs 25 | 26 | 27 | 28 | class Contrastive_Loss2(nn.Module): 29 | 30 | def __init__(self, tau=1) -> None: 31 | super().__init__() 32 | 33 | self.temperature = tau 34 | 35 | 36 | def forward(self, X, Y): 37 | 38 | logits = (X @ Y.T) / self.temperature 39 | X_similarity = Y @ Y.T 40 | Y_similarity = X @ X.T 41 | targets = F.softmax( 42 | (X_similarity + Y_similarity) / 2 * self.temperature, dim=-1 43 | ) 44 | X_loss = self.cross_entropy(logits, targets, reduction='none') 45 | Y_loss = self.cross_entropy(logits.T, targets.T, reduction='none') 46 | loss = (Y_loss + X_loss) / 2.0 # shape: (batch_size) 47 | return loss.mean() 48 | 49 | 50 | def cross_entropy(self, preds, targets, reduction='none'): 51 | 52 | log_softmax = nn.LogSoftmax(dim=-1) 53 | loss = (-targets * log_softmax(preds)).sum(1) 54 | if reduction == "none": 55 | return loss 56 | elif reduction == "mean": 57 | return loss.mean() 58 | 59 | 60 | 61 | class CalculateAttention(nn.Module): 62 | 63 | def __init__(self): 64 | super().__init__() 65 | 66 | 67 | def forward(self, Q, K, V, mask): 68 | 69 | attention = torch.matmul(Q,torch.transpose(K, -1, -2)) 70 | # use mask 71 | attention = attention.masked_fill_(mask, -1e9) 72 | attention = torch.softmax(attention / sqrt(Q.size(-1)), dim=-1) 73 | attention = torch.matmul(attention,V) 74 | return attention 75 | 76 | 77 | 78 | class Multi_CrossAttention(nn.Module): 79 | """ 80 | forward时,第一个参数用于计算query,第二个参数用于计算key和value 81 | """ 82 | def __init__(self,hidden_size,all_head_size,head_num): 83 | super().__init__() 84 | self.hidden_size = hidden_size # 输入维度 85 | self.all_head_size = all_head_size # 输出维度 86 | self.num_heads = head_num # 注意头的数量 87 | self.h_size = all_head_size // head_num 88 | 89 | assert all_head_size % head_num == 0 90 | 91 | # W_Q,W_K,W_V (hidden_size,all_head_size) 92 | self.linear_q = nn.Linear(hidden_size, all_head_size, bias=False) 93 | self.linear_k = nn.Linear(hidden_size, all_head_size, bias=False) 94 | self.linear_v = nn.Linear(hidden_size, all_head_size, bias=False) 95 | self.linear_output = nn.Linear(all_head_size, hidden_size) 96 | 97 | # normalization 98 | self.norm = sqrt(all_head_size) 99 | 100 | 101 | def print(self): 102 | print(self.hidden_size,self.all_head_size) 103 | print(self.linear_k,self.linear_q,self.linear_v) 104 | 105 | 106 | def forward(self,x,y,log_seqs): 107 | """ 108 | cross-attention: x,y是两个模型的隐藏层,将x作为q的输入,y作为k和v的输入 109 | """ 110 | 111 | batch_size = x.size(0) 112 | # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W) 113 | 114 | # q_s: [batch_size, num_heads, seq_length, h_size] 115 | q_s = self.linear_q(x).view(batch_size, -1, self.num_heads, self.h_size).transpose(1,2) 116 | 117 | # k_s: [batch_size, num_heads, seq_length, h_size] 118 | k_s = self.linear_k(y).view(batch_size, -1, self.num_heads, self.h_size).transpose(1,2) 119 | 120 | # v_s: [batch_size, num_heads, seq_length, h_size] 121 | v_s = self.linear_v(y).view(batch_size, -1, self.num_heads, self.h_size).transpose(1,2) 122 | 123 | # attention_mask = attention_mask.eq(0) 124 | attention_mask = (log_seqs == 0).unsqueeze(1).repeat(1, log_seqs.size(1), 1).unsqueeze(1) 125 | 126 | attention = CalculateAttention()(q_s,k_s,v_s,attention_mask) 127 | # attention : [batch_size , seq_length , num_heads * h_size] 128 | attention = attention.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.h_size) 129 | 130 | # output : [batch_size , seq_length , hidden_size] 131 | output = self.linear_output(attention) 132 | 133 | return output 134 | 135 | 136 | 137 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.24.3 2 | pandas==2.0.1 3 | torch==1.12.0+cu102 4 | tqdm==4.65.0 5 | -------------------------------------------------------------------------------- /trainers/sequence_trainer.py: -------------------------------------------------------------------------------- 1 | # here put the import lib 2 | import os 3 | import time 4 | import torch 5 | import numpy as np 6 | from tqdm import tqdm 7 | from trainers.trainer import Trainer 8 | from utils.utils import metric_report, metric_len_report, record_csv, metric_pop_report 9 | from utils.utils import metric_len_5group, metric_pop_5group 10 | 11 | 12 | class SeqTrainer(Trainer): 13 | 14 | def __init__(self, args, logger, writer, device, generator): 15 | 16 | super().__init__(args, logger, writer, device, generator) 17 | 18 | 19 | def _train_one_epoch(self, epoch): 20 | 21 | tr_loss = 0 22 | nb_tr_examples, nb_tr_steps = 0, 0 23 | train_time = [] 24 | 25 | self.model.train() 26 | prog_iter = tqdm(self.train_loader, leave=False, desc='Training') 27 | 28 | for batch in prog_iter: 29 | 30 | batch = tuple(t.to(self.device) for t in batch) 31 | 32 | train_start = time.time() 33 | inputs = self._prepare_train_inputs(batch) 34 | loss = self.model(**inputs) 35 | loss.backward() 36 | 37 | tr_loss += loss.item() 38 | nb_tr_examples += 1 39 | nb_tr_steps += 1 40 | 41 | # Display loss 42 | prog_iter.set_postfix(loss='%.4f' % (tr_loss / nb_tr_steps)) 43 | 44 | self.optimizer.step() 45 | self.optimizer.zero_grad() 46 | 47 | train_end = time.time() 48 | train_time.append(train_end-train_start) 49 | 50 | self.writer.add_scalar('train/loss', tr_loss / nb_tr_steps, epoch) 51 | 52 | 53 | 54 | def eval(self, epoch=0, test=False): 55 | 56 | print('') 57 | if test: 58 | self.logger.info("\n----------------------------------------------------------------") 59 | self.logger.info("********** Running test **********") 60 | desc = 'Testing' 61 | model_state_dict = torch.load(os.path.join(self.args.output_dir, 'pytorch_model.bin')) 62 | self.model.load_state_dict(model_state_dict['state_dict']) 63 | self.model.to(self.device) 64 | test_loader = self.test_loader 65 | 66 | else: 67 | self.logger.info("\n----------------------------------") 68 | self.logger.info("********** Epoch: %d eval **********" % epoch) 69 | desc = 'Evaluating' 70 | test_loader = self.valid_loader 71 | 72 | self.model.eval() 73 | pred_rank = torch.empty(0).to(self.device) 74 | seq_len = torch.empty(0).to(self.device) 75 | target_items = torch.empty(0).to(self.device) 76 | 77 | for batch in tqdm(test_loader, desc=desc): 78 | 79 | batch = tuple(t.to(self.device) for t in batch) 80 | inputs = self._prepare_eval_inputs(batch) 81 | seq_len = torch.cat([seq_len, torch.sum(inputs["seq"]>0, dim=1)]) 82 | target_items = torch.cat([target_items, inputs["pos"]]) 83 | 84 | with torch.no_grad(): 85 | 86 | inputs["item_indices"] = torch.cat([inputs["pos"].unsqueeze(1), inputs["neg"]], dim=1) 87 | pred_logits = -self.model.predict(**inputs) 88 | 89 | per_pred_rank = torch.argsort(torch.argsort(pred_logits))[:, 0] 90 | pred_rank = torch.cat([pred_rank, per_pred_rank]) 91 | 92 | self.logger.info('') 93 | res_dict = metric_report(pred_rank.detach().cpu().numpy()) 94 | res_len_dict = metric_len_report(pred_rank.detach().cpu().numpy(), seq_len.detach().cpu().numpy(), aug_len=self.args.aug_seq_len, args=self.args) 95 | res_pop_dict = metric_pop_report(pred_rank.detach().cpu().numpy(), self.item_pop, target_items.detach().cpu().numpy(), args=self.args) 96 | 97 | self.logger.info("Overall Performance:") 98 | for k, v in res_dict.items(): 99 | if not test: 100 | self.writer.add_scalar('Test/{}'.format(k), v, epoch) 101 | self.logger.info('\t %s: %.5f' % (k, v)) 102 | 103 | if test: 104 | self.logger.info("User Group Performance:") 105 | for k, v in res_len_dict.items(): 106 | if not test: 107 | self.writer.add_scalar('Test/{}'.format(k), v, epoch) 108 | self.logger.info('\t %s: %.5f' % (k, v)) 109 | self.logger.info("Item Group Performance:") 110 | for k, v in res_pop_dict.items(): 111 | if not test: 112 | self.writer.add_scalar('Test/{}'.format(k), v, epoch) 113 | self.logger.info('\t %s: %.5f' % (k, v)) 114 | 115 | res_dict = {**res_dict, **res_len_dict, **res_pop_dict} 116 | 117 | if test: 118 | record_csv(self.args, res_dict) 119 | 120 | return res_dict 121 | 122 | 123 | 124 | def save_user_emb(self): 125 | 126 | model_state_dict = torch.load(os.path.join(self.args.output_dir, 'pytorch_model.bin')) 127 | try: 128 | self.model.load_state_dict(model_state_dict['state_dict']) 129 | except: 130 | self.model.load_state_dict(model_state_dict) 131 | self.model.to(self.device) 132 | test_loader = self.test_loader 133 | 134 | self.model.eval() 135 | user_emb = torch.empty(0).to(self.device) 136 | desc = 'Running' 137 | 138 | for batch in tqdm(test_loader, desc=desc): 139 | 140 | batch = tuple(t.to(self.device) for t in batch) 141 | inputs = self._prepare_eval_inputs(batch) 142 | 143 | with torch.no_grad(): 144 | 145 | per_user_emb = self.model.get_user_emb(**inputs) 146 | user_emb = torch.cat([user_emb, per_user_emb], dim=0) 147 | 148 | user_emb = user_emb.detach().cpu().numpy() 149 | import pickle 150 | pickle.dump(user_emb, open("./usr_emb_sasrec.pkl", "wb")) 151 | 152 | 153 | 154 | def test_group(self): 155 | 156 | print('') 157 | self.logger.info("\n----------------------------------------------------------------") 158 | self.logger.info("********** Running Group test **********") 159 | desc = 'Testing' 160 | model_state_dict = torch.load(os.path.join(self.args.output_dir, 'pytorch_model.bin')) 161 | self.model.load_state_dict(model_state_dict['state_dict']) 162 | self.model.to(self.device) 163 | test_loader = self.test_loader 164 | 165 | self.model.eval() 166 | pred_rank = torch.empty(0).to(self.device) 167 | seq_len = torch.empty(0).to(self.device) 168 | target_items = torch.empty(0).to(self.device) 169 | 170 | for batch in tqdm(test_loader, desc=desc): 171 | 172 | batch = tuple(t.to(self.device) for t in batch) 173 | inputs = self._prepare_eval_inputs(batch) 174 | seq_len = torch.cat([seq_len, torch.sum(inputs["seq"]>0, dim=1)]) 175 | target_items = torch.cat([target_items, inputs["pos"]]) 176 | 177 | with torch.no_grad(): 178 | 179 | inputs["item_indices"] = torch.cat([inputs["pos"].unsqueeze(1), inputs["neg"]], dim=1) 180 | pred_logits = -self.model.predict(**inputs) 181 | 182 | per_pred_rank = torch.argsort(torch.argsort(pred_logits))[:, 0] 183 | pred_rank = torch.cat([pred_rank, per_pred_rank]) 184 | 185 | self.logger.info('') 186 | res_dict = metric_report(pred_rank.detach().cpu().numpy()) 187 | # res_len_dict = metric_len_report(pred_rank.detach().cpu().numpy(), seq_len.detach().cpu().numpy(), aug_len=self.args.aug_seq_len, args=self.args) 188 | # res_pop_dict = metric_pop_report(pred_rank.detach().cpu().numpy(), self.item_pop, target_items.detach().cpu().numpy(), args=self.args) 189 | hr_len, ndcg_len, count_len = metric_len_5group(pred_rank.detach().cpu().numpy(), seq_len.detach().cpu().numpy(), [5, 10, 15, 20]) 190 | hr_pop, ndcg_pop, count_pop = metric_pop_5group(pred_rank.detach().cpu().numpy(), self.item_pop, target_items.detach().cpu().numpy(), [10, 30, 60, 100]) 191 | 192 | self.logger.info("Overall Performance:") 193 | for k, v in res_dict.items(): 194 | self.logger.info('\t %s: %.5f' % (k, v)) 195 | 196 | self.logger.info("User Group Performance:") 197 | for i, (hr, ndcg) in enumerate(zip(hr_len, ndcg_len)): 198 | self.logger.info('The %d Group: HR %.4f, NDCG %.4f' % (i, hr, ndcg)) 199 | self.logger.info("Item Group Performance:") 200 | for i, (hr, ndcg) in enumerate(zip(hr_pop, ndcg_pop)): 201 | self.logger.info('The %d Group: HR %.4f, NDCG %.4f' % (i, hr, ndcg)) 202 | 203 | 204 | return res_dict 205 | 206 | 207 | 208 | class CL4SRecTrainer(SeqTrainer): 209 | 210 | def __init__(self, args, logger, writer, device, generator): 211 | 212 | super().__init__(args, logger, writer, device, generator) 213 | 214 | 215 | def _train_one_epoch(self, epoch): 216 | 217 | tr_loss = 0 218 | nb_tr_examples, nb_tr_steps = 0, 0 219 | train_time = [] 220 | 221 | self.model.train() 222 | prog_iter = tqdm(self.train_loader, leave=False, desc='Training') 223 | 224 | for batch in prog_iter: 225 | 226 | batch = tuple(t.to(self.device) for t in batch) 227 | 228 | train_start = time.time() 229 | seq, pos, neg, positions, aug1, aug2 = batch 230 | seq, pos, neg, positions, aug1, aug2 = seq.long(), pos.long(), neg.long(), positions.long(), aug1.long(), aug2.long() 231 | aug = (aug1, aug2) 232 | loss = self.model(seq, pos, neg, positions, aug) 233 | loss.backward() 234 | 235 | tr_loss += loss.item() 236 | nb_tr_examples += 1 237 | nb_tr_steps += 1 238 | 239 | # Display loss 240 | prog_iter.set_postfix(loss='%.4f' % (tr_loss / nb_tr_steps)) 241 | 242 | self.optimizer.step() 243 | self.optimizer.zero_grad() 244 | 245 | train_end = time.time() 246 | train_time.append(train_end-train_start) 247 | 248 | self.writer.add_scalar('train/loss', tr_loss / nb_tr_steps, epoch) 249 | 250 | 251 | 252 | class SSEPTTrainer(Trainer): 253 | 254 | def __init__(self, args, logger, writer, device, generator): 255 | 256 | super().__init__(args, logger, writer, device, generator) 257 | 258 | 259 | def _train_one_epoch(self, epoch): 260 | 261 | tr_loss = 0 262 | nb_tr_examples, nb_tr_steps = 0, 0 263 | train_time = [] 264 | 265 | self.model.train() 266 | prog_iter = tqdm(self.train_loader, leave=False, desc='Training') 267 | 268 | for batch in prog_iter: 269 | 270 | batch = tuple(t.to(self.device) for t in batch) 271 | 272 | train_start = time.time() 273 | seq_user, pos_user, neg_user, seq, pos, neg, positions = batch 274 | seq, pos, neg, positions = seq.long(), pos.long(), neg.long(), positions.long() 275 | seq_user, pos_user, neg_user = seq_user.long(), pos_user.long(), neg_user.long() 276 | loss = self.model(seq_user, pos_user, neg_user, seq, pos, neg, positions) 277 | loss.backward() 278 | 279 | tr_loss += loss.item() 280 | nb_tr_examples += 1 281 | nb_tr_steps += 1 282 | 283 | # Display loss 284 | prog_iter.set_postfix(loss='%.4f' % (tr_loss / nb_tr_steps)) 285 | 286 | self.optimizer.step() 287 | self.optimizer.zero_grad() 288 | 289 | train_end = time.time() 290 | train_time.append(train_end-train_start) 291 | 292 | self.writer.add_scalar('train/loss', tr_loss / nb_tr_steps, epoch) 293 | 294 | 295 | 296 | def eval(self, epoch=0, test=False): 297 | 298 | print('') 299 | if test: 300 | self.logger.info("\n----------------------------------------------------------------") 301 | self.logger.info("********** Running test **********") 302 | desc = 'Testing' 303 | model_state_dict = torch.load(os.path.join(self.args.output_dir, 'pytorch_model.bin')) 304 | try: 305 | self.model.load_state_dict(model_state_dict['state_dict']) 306 | except: 307 | self.model.load_state_dict(model_state_dict) 308 | self.model.to(self.device) 309 | test_loader = self.test_loader 310 | 311 | else: 312 | self.logger.info("\n----------------------------------") 313 | self.logger.info("********** Epoch: %d eval **********" % epoch) 314 | desc = 'Evaluating' 315 | test_loader = self.valid_loader 316 | 317 | self.model.eval() 318 | pred_rank = torch.empty(0).to(self.device) 319 | seq_len = torch.empty(0).to(self.device) 320 | 321 | for batch in tqdm(test_loader, desc=desc): 322 | 323 | batch = tuple(t.to(self.device) for t in batch) 324 | seq_user, pos_user, neg_user, seq, pos, neg, positions = batch 325 | seq, pos, neg, positions = seq.long(), pos.long(), neg.long(), positions.long() 326 | seq_user, pos_user, neg_user = seq_user.long(), pos_user.long(), neg_user.long() 327 | seq_len = torch.cat([seq_len, torch.sum(seq>0, dim=1)]) 328 | 329 | with torch.no_grad(): 330 | 331 | pred_logits = -self.model.predict(seq_user, seq, torch.cat([pos_user.unsqueeze(1), neg_user], dim=1), torch.cat([pos.unsqueeze(1), neg], dim=1), positions) 332 | 333 | per_pred_rank = torch.argsort(torch.argsort(pred_logits))[:, 0] 334 | pred_rank = torch.cat([pred_rank, per_pred_rank]) 335 | 336 | self.logger.info('') 337 | res_dict = metric_report(pred_rank.detach().cpu().numpy()) 338 | res_len_dict = metric_len_report(pred_rank.detach().cpu().numpy(), seq_len.detach().cpu().numpy(), aug_len=self.args.aug_seq_len) 339 | 340 | for k, v in res_dict.items(): 341 | if not test: 342 | self.writer.add_scalar('Test/{}'.format(k), v, epoch) 343 | self.logger.info('%s: %.5f' % (k, v)) 344 | for k, v in res_len_dict.items(): 345 | if not test: 346 | self.writer.add_scalar('Test/{}'.format(k), v, epoch) 347 | self.logger.info('%s: %.5f' % (k, v)) 348 | 349 | res_dict = {**res_dict, **res_len_dict} 350 | 351 | if test: 352 | record_csv(self.args, res_dict) 353 | 354 | return res_dict 355 | -------------------------------------------------------------------------------- /trainers/trainer.py: -------------------------------------------------------------------------------- 1 | # here put the import lib 2 | import os 3 | import torch 4 | from tqdm import trange 5 | from utils.earlystop import EarlyStoppingNew 6 | from utils.utils import get_n_params 7 | from models.LLMESR import LLMESR_SASRec, LLMESR_Bert4Rec, LLMESR_GRU4Rec 8 | 9 | 10 | 11 | class Trainer(object): 12 | 13 | def __init__(self, args, logger, writer, device, generator): 14 | 15 | self.args = args 16 | self.logger = logger 17 | self.writer = writer 18 | self.device = device 19 | self.user_num, self.item_num = generator.get_user_item_num() 20 | self.start_epoch = 0 # define the start epoch for keepon training 21 | 22 | self.logger.info('Loading Model: ' + args.model_name) 23 | self._create_model() 24 | logger.info('# of model parameters: ' + str(get_n_params(self.model))) 25 | 26 | self._set_optimizer() 27 | self._set_scheduler() 28 | self._set_stopper() 29 | 30 | if args.keepon: 31 | self._load_pretrained_model() 32 | 33 | self.loss_func = torch.nn.BCEWithLogitsLoss() 34 | 35 | self.train_loader = generator.make_trainloader() 36 | self.valid_loader = generator.make_evalloader() 37 | self.test_loader = generator.make_evalloader(test=True) 38 | self.generator = generator 39 | 40 | # get item pop and user len 41 | self.item_pop = generator.get_item_pop() 42 | self.user_len = generator.get_user_len() 43 | 44 | #self.watch_metric = 'NDCG@10' # use which metric to select model 45 | self.watch_metric = args.watch_metric 46 | 47 | 48 | def _create_model(self): 49 | '''create your model''' 50 | if self.args.model_name == "llmesr_sasrec": 51 | self.model = LLMESR_SASRec(self.user_num, self.item_num, self.device, self.args) 52 | elif self.args.model_name == "llmesr_gru4rec": 53 | self.model = LLMESR_GRU4Rec(self.user_num, self.item_num, self.device, self.args) 54 | elif self.args.model_name == "llmesr_bert4rec": 55 | self.model = LLMESR_Bert4Rec(self.user_num, self.item_num, self.device, self.args) 56 | else: 57 | raise ValueError 58 | 59 | self.model.to(self.device) 60 | 61 | 62 | def _load_pretrained_model(self): 63 | 64 | self.logger.info("Loading the trained model for keep on training ... ") 65 | checkpoint_path = os.path.join(self.args.keepon_path, 'pytorch_model.bin') 66 | 67 | model_dict = self.model.state_dict() 68 | checkpoint = torch.load(checkpoint_path, map_location=self.device) 69 | pretrained_dict = checkpoint['state_dict'] 70 | 71 | # filter out required parameters 72 | new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()} 73 | model_dict.update(new_dict) 74 | # Print: how many parameters are loaded from the checkpoint 75 | self.logger.info('Total loaded parameters: {}, update: {}'.format(len(pretrained_dict), len(new_dict))) 76 | self.model.load_state_dict(model_dict) # load model parameters 77 | self.optimizer.load_state_dict(checkpoint['optimizer']) # load optimizer 78 | self.scheduler.load_state_dict(checkpoint['scheduler']) # load scheduler 79 | self.start_epoch = checkpoint['epoch'] # load epoch 80 | 81 | 82 | def _set_optimizer(self): 83 | 84 | self.optimizer = torch.optim.Adam(self.model.parameters(), 85 | lr=self.args.lr, 86 | weight_decay=self.args.l2, 87 | ) 88 | 89 | 90 | def _set_scheduler(self): 91 | 92 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, 93 | step_size=self.args.lr_dc_step, 94 | gamma=self.args.lr_dc) 95 | 96 | 97 | def _set_stopper(self): 98 | 99 | self.stopper = EarlyStoppingNew(patience=self.args.patience, 100 | verbose=False, 101 | path=self.args.output_dir, 102 | trace_func=self.logger) 103 | 104 | 105 | def _train_one_epoch(self, epoch): 106 | 107 | return NotImplementedError 108 | 109 | 110 | def _prepare_train_inputs(self, data): 111 | """Prepare the inputs as a dict for training""" 112 | assert len(self.generator.train_dataset.var_name) == len(data) 113 | inputs = {} 114 | for i, var_name in enumerate(self.generator.train_dataset.var_name): 115 | inputs[var_name] = data[i] 116 | 117 | return inputs 118 | 119 | 120 | def _prepare_eval_inputs(self, data): 121 | """Prepare the inputs as a dict for evaluation""" 122 | inputs = {} 123 | assert len(self.generator.eval_dataset.var_name) == len(data) 124 | for i, var_name in enumerate(self.generator.eval_dataset.var_name): 125 | inputs[var_name] = data[i] 126 | 127 | return inputs 128 | 129 | 130 | def eval(self, epoch=0, test=False): 131 | 132 | return NotImplementedError 133 | 134 | 135 | def train(self): 136 | 137 | model_to_save = self.model.module if hasattr(self.model, 'module') else self.model # Only save the model it-self 138 | self.logger.info("\n----------------------------------------------------------------") 139 | self.logger.info("********** Running training **********") 140 | self.logger.info(" Batch size = %d", self.args.train_batch_size) 141 | res_list = [] 142 | train_time = [] 143 | 144 | for epoch in trange(self.start_epoch, self.start_epoch + int(self.args.num_train_epochs), desc="Epoch"): 145 | 146 | t = self._train_one_epoch(epoch) 147 | 148 | train_time.append(t) 149 | 150 | # evluate on validation per 20 epochs 151 | if (epoch % 1) == 0: 152 | 153 | metric_dict = self.eval(epoch=epoch) 154 | res_list.append(metric_dict) 155 | #self.scheduler.step() 156 | self.stopper(metric_dict[self.watch_metric], epoch, model_to_save, self.optimizer, self.scheduler) 157 | 158 | if self.stopper.early_stop: 159 | 160 | break 161 | 162 | best_epoch = self.stopper.best_epoch 163 | best_res = res_list[best_epoch - self.start_epoch] 164 | self.logger.info('') 165 | self.logger.info('The best epoch is %d' % best_epoch) 166 | self.logger.info('The best results are NDCG@10: %.5f, HR@10: %.5f' % 167 | (best_res['NDCG@10'], best_res['HR@10'])) 168 | 169 | res = self.eval(test=True) 170 | 171 | return res, best_epoch 172 | 173 | 174 | 175 | def test(self): 176 | """Do test directly. Set the output dir as the path that save the checkpoint""" 177 | res = self.eval(test=True) 178 | 179 | return res, -1 180 | 181 | 182 | 183 | def get_model(self): 184 | 185 | return self.model 186 | 187 | 188 | def get_model_param_num(self): 189 | 190 | total_num = sum(p.numel() for p in self.model.parameters()) 191 | trainable_num = sum(p.numel() for p in self.model.parameters() if p.requires_grad) 192 | freeze_num = total_num - trainable_num 193 | 194 | return freeze_num, trainable_num 195 | 196 | 197 | -------------------------------------------------------------------------------- /utils/earlystop.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code based on https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py 3 | The origin code conduct early stopping according to validation loss, 4 | I alter it to early stop according to validation performance. 5 | ''' 6 | 7 | 8 | import numpy as np 9 | import torch 10 | import os 11 | 12 | 13 | class EarlyStopping(): 14 | """Early stops the training if validation performance doesn't improve after a given patience.""" 15 | def __init__(self, patience=7, verbose=False, delta=0, path='./checkpoint/', trace_func=print, model='checkpoint'): 16 | """ 17 | Args: 18 | patience (int): How long to wait after last time validation loss improved. 19 | Default: 7 20 | verbose (bool): If True, prints a message for each validation loss improvement. 21 | Default: False 22 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 23 | Default: 0 24 | path (str): Path for the checkpoint to be saved to. 25 | Default: 'checkpoint.pt' 26 | trace_func (function): trace print function. 27 | Default: print 28 | """ 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | 32 | self.patience = patience 33 | self.verbose = verbose 34 | self.counter = 0 35 | self.best_score = None # record the best score 36 | self.best_epoch = 0 # record the best epoch 37 | self.early_stop = False 38 | self.val_loss_min = np.Inf 39 | self.delta = delta 40 | self.path = os.path.join(path, "pytorch_model.bin") 41 | self.trace_func = trace_func 42 | 43 | 44 | def __call__(self, indicator, epoch, model): 45 | 46 | score = indicator 47 | 48 | if self.best_score is None: # for the first epoch 49 | self.best_score = score 50 | self.best_epoch = epoch 51 | self.save_checkpoint(score, model) 52 | elif score <= self.best_score + self.delta: 53 | self.counter += 1 54 | #self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}') 55 | if self.counter >= self.patience: 56 | self.early_stop = True 57 | else: 58 | self.best_score = score 59 | self.best_epoch = epoch 60 | self.save_checkpoint(score, model) 61 | self.counter = 0 62 | 63 | def save_checkpoint(self, val_loss, model): 64 | '''Saves model when validation loss decrease.''' 65 | if self.verbose: 66 | self.trace_func(f'The best score is ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') 67 | torch.save(model.state_dict(), self.path) 68 | #self.val_loss_min = val_loss 69 | 70 | 71 | 72 | class EarlyStoppingNew(): 73 | """Early stops the training if validation performance doesn't improve after a given patience.""" 74 | def __init__(self, patience=7, verbose=False, delta=0, path='./checkpoint/', trace_func=print, model='checkpoint'): 75 | """ 76 | Args: 77 | patience (int): How long to wait after last time validation loss improved. 78 | Default: 7 79 | verbose (bool): If True, prints a message for each validation loss improvement. 80 | Default: False 81 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 82 | Default: 0 83 | path (str): Path for the checkpoint to be saved to. 84 | Default: 'checkpoint.pt' 85 | trace_func (function): trace print function. 86 | Default: print 87 | """ 88 | if not os.path.exists(path): 89 | os.makedirs(path) 90 | 91 | self.patience = patience 92 | self.verbose = verbose 93 | self.counter = 0 94 | self.best_score = None # record the best score 95 | self.best_epoch = 0 # record the best epoch 96 | self.early_stop = False 97 | self.val_loss_min = np.Inf 98 | self.delta = delta 99 | self.path = os.path.join(path, "pytorch_model.bin") 100 | self.trace_func = trace_func 101 | 102 | 103 | def __call__(self, indicator, epoch, model, optimizer=None, scheduler=None): 104 | 105 | score = indicator 106 | 107 | if self.best_score is None: # for the first epoch 108 | self.best_score = score 109 | self.best_epoch = epoch 110 | self.save_checkpoint(score, model, optimizer, scheduler, epoch) 111 | elif score <= self.best_score + self.delta: 112 | self.counter += 1 113 | #self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}') 114 | if self.counter >= self.patience: 115 | self.early_stop = True 116 | else: 117 | self.best_score = score 118 | self.best_epoch = epoch 119 | self.save_checkpoint(score, model, optimizer, scheduler, epoch) 120 | self.counter = 0 121 | 122 | def save_checkpoint(self, val_loss, model, optimizer, scheduler, epoch): 123 | '''Saves model when validation loss decrease.''' 124 | if self.verbose: 125 | self.trace_func(f'The best score is ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') 126 | torch.save({'epoch': epoch, 127 | 'state_dict': model.state_dict(), 128 | 'optimizer': optimizer.state_dict(), 129 | 'scheduler': scheduler.state_dict()} 130 | , self.path) 131 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # here put the import lib 2 | import logging 3 | from torch.utils.tensorboard import SummaryWriter 4 | import os 5 | import time 6 | 7 | 8 | 9 | class Logger(object): 10 | '''base logger''' 11 | 12 | def __init__(self, args): 13 | 14 | self.args = args 15 | self._create_logger() 16 | 17 | 18 | def _create_logger(self): 19 | ''' 20 | Initialize the logging module. Concretely, initialize the 21 | tensorboard and logging 22 | ''' 23 | # If demo, use default log 24 | if self.args.demo: 25 | self.args.log = False 26 | 27 | # judge whether the folder exits 28 | main_path = r'./log/' + str(self.args.dataset) + '/' + str(self.args.model_name) + '/' 29 | if not os.path.exists(main_path): 30 | os.makedirs(main_path) 31 | 32 | # get the current time string 33 | now_str = time.strftime("%m%d%H%M%S", time.localtime()) 34 | 35 | # Initialize tensorboard. Set the save folder. 36 | if self.args.log: 37 | os.makedirs(main_path + now_str + '/') 38 | folder_name = main_path + now_str + '/tensorboard/' 39 | file_path = now_str + '/bs' + str(self.args.train_batch_size) + '_lr' + str(self.args.lr) + '.txt' 40 | else: 41 | folder_name = main_path + '/default/' 42 | file_path = 'default/log.txt' 43 | self.writer = SummaryWriter(folder_name) 44 | 45 | # Initialize logging. Create console and file handler 46 | self.logger = logging.getLogger(self.args.model_name) 47 | self.logger.setLevel(logging.DEBUG) # must set 48 | 49 | # create file handler 50 | log_path = main_path + file_path 51 | self.fh = logging.FileHandler(log_path, mode='w', encoding='utf-8') 52 | self.fh.setLevel(logging.DEBUG) 53 | fm = logging.Formatter("%(asctime)s-%(message)s") 54 | self.fh.setFormatter(fm) 55 | self.logger.addHandler(self.fh) 56 | 57 | # record the hyper parameters in the text 58 | self.logger.info('The parameters are as below:') 59 | for kv in self.args._get_kwargs(): 60 | self.logger.info('%s: %s' % (kv[0], str(kv[1]))) 61 | #self.logger.info('\nStart Training:') 62 | 63 | #create console handler 64 | self.ch = logging.StreamHandler() 65 | self.ch.setLevel(logging.DEBUG) 66 | self.logger.addHandler(self.ch) 67 | 68 | self.now_str = now_str 69 | 70 | 71 | def end_log(self): 72 | 73 | self.logger.removeHandler(self.fh) 74 | self.logger.removeHandler(self.ch) 75 | 76 | 77 | def log_metrics(self, epoch, metrics, metric_values): 78 | '''Write results of experiments according to your code''' 79 | self.logger.info('epoch: %d' % epoch) 80 | 81 | if self.logger: 82 | log_str = "Overall Results: " 83 | for m in metrics: 84 | log_str = log_str + "\t" + m.upper() + "@" + str(self.args.topk) + ": %.4f" 85 | 86 | self.logger.info(log_str % tuple(metric_values)) 87 | 88 | if self.writer: 89 | 90 | for m, mv in zip(metrics, metric_values): 91 | 92 | self.writer.add_scalar(m.upper()+'@'+str(self.args.topk), mv, epoch) 93 | 94 | 95 | def get_logger(self): 96 | 97 | try: 98 | return self.logger, self.writer 99 | except: 100 | raise ValueError("Please check your logger creater") 101 | 102 | 103 | def get_now_str(self): 104 | 105 | try: 106 | return self.now_str 107 | except: 108 | raise ValueError("An error occurs in logger") 109 | 110 | 111 | 112 | class AugLogger(Logger): 113 | '''create your own logger''' 114 | 115 | def __init__(self, args): 116 | 117 | super(AugLogger, self).__init__(args) 118 | 119 | 120 | def _create_logger(self): 121 | ''' 122 | Initialize the logging module for sequence augmentation. 123 | ''' 124 | # judge whether the folder exits 125 | main_path = r'./log/' + str(self.args.dataset) + '/augmemt/' 126 | if not os.path.exists(main_path): 127 | os.makedirs(main_path) 128 | 129 | # get the current time string 130 | now_str = time.strftime("%m%d%H%M%S", time.localtime()) 131 | 132 | # Initialize tensorboard. Set the save folder. 133 | if self.args.log: 134 | os.makedirs(main_path + now_str + '/') 135 | file_path = now_str + '/log.txt' 136 | else: 137 | file_path = 'default_log.txt' 138 | 139 | # Initialize logging. Create console and file handler 140 | self.logger = logging.getLogger(self.args.model_name) 141 | self.logger.setLevel(logging.DEBUG) # must set 142 | 143 | # create file handler 144 | log_path = main_path + file_path 145 | self.fh = logging.FileHandler(log_path, mode='w', encoding='utf-8') 146 | self.fh.setLevel(logging.DEBUG) 147 | fm = logging.Formatter("%(asctime)s-%(message)s") 148 | self.fh.setFormatter(fm) 149 | self.logger.addHandler(self.fh) 150 | 151 | # record the hyper parameters in the text 152 | self.logger.info('The parameters are as below:') 153 | for kv in self.args._get_kwargs(): 154 | self.logger.info('%s: %s' % (kv[0], str(kv[1]))) 155 | #self.logger.info('\nStart Training:') 156 | 157 | #create console handler 158 | self.ch = logging.StreamHandler() 159 | self.ch.setLevel(logging.DEBUG) 160 | self.logger.addHandler(self.ch) 161 | 162 | self.now_str = now_str 163 | 164 | 165 | def get_logger(self): 166 | 167 | return self.logger 168 | 169 | 170 | 171 | 172 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # here put the import lib 2 | import os 3 | import random 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from tqdm import tqdm 8 | 9 | 10 | def set_seed(seed): 11 | '''Fix all of random seed for reproducible training''' 12 | random.seed(seed) 13 | np.random.seed(seed) 14 | torch.manual_seed(seed) 15 | torch.cuda.manual_seed(seed) 16 | torch.cuda.manual_seed_all(seed) 17 | torch.backends.cudnn.deterministic = True # only add when conv in your model 18 | 19 | 20 | def get_n_params(model): 21 | '''Get the number of parameters of model''' 22 | pp = 0 23 | for p in list(model.parameters()): 24 | nn = 1 25 | for s in list(p.size()): 26 | nn = nn*s 27 | pp += nn 28 | return pp 29 | 30 | 31 | def get_n_params_(parameter_list): 32 | '''Get the number of parameters of model''' 33 | pp = 0 34 | for p in list(parameter_list): 35 | nn = 1 36 | for s in list(p.size()): 37 | nn = nn*s 38 | pp += nn 39 | return pp 40 | 41 | 42 | def unzip_data(data, aug=True, aug_num=0): 43 | 44 | res = [] 45 | 46 | if aug: 47 | for user in tqdm(data): 48 | 49 | user_seq = data[user] 50 | seq_len = len(user_seq) 51 | 52 | for i in range(aug_num+2, seq_len+1): 53 | 54 | res.append(user_seq[:i]) 55 | else: 56 | for user in tqdm(data): 57 | 58 | user_seq = data[user] 59 | res.append(user_seq) 60 | 61 | return res 62 | 63 | 64 | def unzip_data_with_user(data, aug=True, aug_num=0): 65 | 66 | res = [] 67 | users = [] 68 | user_id = 1 69 | 70 | if aug: 71 | for user in tqdm(data): 72 | 73 | user_seq = data[user] 74 | seq_len = len(user_seq) 75 | 76 | for i in range(aug_num+2, seq_len+1): 77 | 78 | res.append(user_seq[:i]) 79 | users.append(user_id) 80 | 81 | user_id += 1 82 | 83 | else: 84 | for user in tqdm(data): 85 | 86 | user_seq = data[user] 87 | res.append(user_seq) 88 | users.append(user_id) 89 | user_id += 1 90 | 91 | return res, users 92 | 93 | 94 | def concat_data(data_list): 95 | 96 | res = [] 97 | 98 | if len(data_list) == 2: 99 | 100 | train = data_list[0] 101 | valid = data_list[1] 102 | 103 | for user in train: 104 | 105 | res.append(train[user]+valid[user]) 106 | 107 | elif len(data_list) == 3: 108 | 109 | train = data_list[0] 110 | valid = data_list[1] 111 | test = data_list[2] 112 | 113 | for user in train: 114 | 115 | res.append(train[user]+valid[user]+test[user]) 116 | 117 | else: 118 | 119 | raise ValueError 120 | 121 | return res 122 | 123 | 124 | def concat_aug_data(data_list): 125 | 126 | res = [] 127 | 128 | train = data_list[0] 129 | valid = data_list[1] 130 | 131 | for user in train: 132 | 133 | if len(valid[user]) == 0: 134 | res.append([train[user][0]]) 135 | 136 | else: 137 | res.append(train[user]+valid[user]) 138 | 139 | return res 140 | 141 | 142 | def concat_data_with_user(data_list): 143 | 144 | res = [] 145 | users = [] 146 | user_id = 1 147 | 148 | if len(data_list) == 2: 149 | 150 | train = data_list[0] 151 | valid = data_list[1] 152 | 153 | for user in train: 154 | 155 | res.append(train[user]+valid[user]) 156 | users.append(user_id) 157 | user_id += 1 158 | 159 | elif len(data_list) == 3: 160 | 161 | train = data_list[0] 162 | valid = data_list[1] 163 | test = data_list[2] 164 | 165 | for user in train: 166 | 167 | res.append(train[user]+valid[user]+test[user]) 168 | users.append(user_id) 169 | user_id += 1 170 | 171 | else: 172 | 173 | raise ValueError 174 | 175 | return res, users 176 | 177 | 178 | def filter_data(data, thershold=5): 179 | '''Filter out the sequence shorter than threshold''' 180 | res = [] 181 | 182 | for user in data: 183 | 184 | if len(user) > thershold: 185 | res.append(user) 186 | else: 187 | continue 188 | 189 | return res 190 | 191 | 192 | 193 | def random_neq(l, r, s=[]): # 在l-r之间随机采样一个数,这个数不能在列表s中 194 | 195 | t = np.random.randint(l, r) 196 | while t in s: 197 | t = np.random.randint(l, r) 198 | return t 199 | 200 | 201 | 202 | def metric_report(data_rank, topk=10): 203 | 204 | NDCG, HT = 0, 0 205 | 206 | for rank in data_rank: 207 | 208 | if rank < topk: 209 | NDCG += 1 / np.log2(rank + 2) 210 | HT += 1 211 | 212 | return {'NDCG@10': NDCG / len(data_rank), 213 | 'HR@10': HT / len(data_rank)} 214 | 215 | 216 | 217 | # def metric_len_report(data_rank, data_len, topk=10, aug_len=0, args=None): 218 | 219 | # if args is not None: 220 | # ts_short = args.ts_short 221 | # ts_long = args.ts_long 222 | # else: 223 | # ts_short = 10 224 | # ts_long = 20 225 | 226 | # NDCG_s, HT_s = 0, 0 227 | # NDCG_m, HT_m = 0, 0 228 | # NDCG_l, HT_l = 0, 0 229 | # count_s = len(data_len[data_len<ts_short+aug_len]) 230 | # count_l = len(data_len[data_len>=ts_long+aug_len]) 231 | # count_m = len(data_len) - count_s - count_l 232 | 233 | # for i, rank in enumerate(data_rank): 234 | 235 | # if rank < topk: 236 | 237 | # if data_len[i] < ts_short+aug_len: 238 | # NDCG_s += 1 / np.log2(rank + 2) 239 | # HT_s += 1 240 | # elif data_len[i] < ts_long+aug_len: 241 | # NDCG_m += 1 / np.log2(rank + 2) 242 | # HT_m += 1 243 | # else: 244 | # NDCG_l += 1 / np.log2(rank + 2) 245 | # HT_l += 1 246 | 247 | # return {'Short NDCG@10': NDCG_s / count_s if count_s!=0 else 0, # avoid division of 0 248 | # 'Short HR@10': HT_s / count_s if count_s!=0 else 0, 249 | # 'Medium NDCG@10': NDCG_m / count_m if count_m!=0 else 0, 250 | # 'Medium HR@10': HT_m / count_m if count_m!=0 else 0, 251 | # 'Long NDCG@10': NDCG_l / count_l if count_l!=0 else 0, 252 | # 'Long HR@10': HT_l / count_l if count_l!=0 else 0,} 253 | 254 | 255 | def metric_len_report(data_rank, data_len, topk=10, aug_len=0, args=None): 256 | 257 | if args is not None: 258 | ts_user = args.ts_user 259 | else: 260 | ts_user = 10 261 | 262 | NDCG_s, HT_s = 0, 0 263 | NDCG_l, HT_l = 0, 0 264 | count_s = len(data_len[data_len<ts_user+aug_len]) 265 | count_l = len(data_len[data_len>=ts_user+aug_len]) 266 | 267 | for i, rank in enumerate(data_rank): 268 | 269 | if rank < topk: 270 | 271 | if data_len[i] < ts_user+aug_len: 272 | NDCG_s += 1 / np.log2(rank + 2) 273 | HT_s += 1 274 | else: 275 | NDCG_l += 1 / np.log2(rank + 2) 276 | HT_l += 1 277 | 278 | return {'Short NDCG@10': NDCG_s / count_s if count_s!=0 else 0, # avoid division of 0 279 | 'Short HR@10': HT_s / count_s if count_s!=0 else 0, 280 | 'Long NDCG@10': NDCG_l / count_l if count_l!=0 else 0, 281 | 'Long HR@10': HT_l / count_l if count_l!=0 else 0,} 282 | 283 | 284 | def metric_pop_report(data_rank, pop_dict, target_items, topk=10, aug_pop=0, args=None): 285 | """ 286 | Report the metrics according to target item's popularity 287 | item_pop: the array of the target item's popularity 288 | """ 289 | if args is not None: 290 | ts_tail = args.ts_item 291 | else: 292 | ts_tail = 20 293 | 294 | NDCG_s, HT_s = 0, 0 295 | NDCG_l, HT_l = 0, 0 296 | item_pop = pop_dict[target_items.astype("int64")] 297 | count_s = len(item_pop[item_pop<ts_tail+aug_pop]) 298 | count_l = len(item_pop[item_pop>=ts_tail+aug_pop]) 299 | 300 | for i, rank in enumerate(data_rank): 301 | 302 | if i == 0: # skip the padding index 303 | continue 304 | 305 | if rank < topk: 306 | 307 | if item_pop[i] < ts_tail+aug_pop: 308 | NDCG_s += 1 / np.log2(rank + 2) 309 | HT_s += 1 310 | else: 311 | NDCG_l += 1 / np.log2(rank + 2) 312 | HT_l += 1 313 | 314 | return {'Tail NDCG@10': NDCG_s / count_s if count_s!=0 else 0, 315 | 'Tail HR@10': HT_s / count_s if count_s!=0 else 0, 316 | 'Popular NDCG@10': NDCG_l / count_l if count_l!=0 else 0, 317 | 'Popular HR@10': HT_l / count_l if count_l!=0 else 0,} 318 | 319 | 320 | 321 | def metric_len_5group(pred_rank, 322 | seq_len, 323 | thresholds=[5, 10, 15, 20], 324 | topk=10): 325 | 326 | NDCG = np.zeros(5) 327 | HR = np.zeros(5) 328 | for i, rank in enumerate(pred_rank): 329 | 330 | target_len = seq_len[i] 331 | if rank < topk: 332 | 333 | if target_len < thresholds[0]: 334 | NDCG[0] += 1 / np.log2(rank + 2) 335 | HR[0] += 1 336 | 337 | elif target_len < thresholds[1]: 338 | NDCG[1] += 1 / np.log2(rank + 2) 339 | HR[1] += 1 340 | 341 | elif target_len < thresholds[2]: 342 | NDCG[2] += 1 / np.log2(rank + 2) 343 | HR[2] += 1 344 | 345 | elif target_len < thresholds[3]: 346 | NDCG[3] += 1 / np.log2(rank + 2) 347 | HR[3] += 1 348 | 349 | else: 350 | NDCG[4] += 1 / np.log2(rank + 2) 351 | HR[4] += 1 352 | 353 | count = np.zeros(5) 354 | count[0] = len(seq_len[seq_len>=0]) - len(seq_len[seq_len>=thresholds[0]]) 355 | count[1] = len(seq_len[seq_len>=thresholds[0]]) - len(seq_len[seq_len>=thresholds[1]]) 356 | count[2] = len(seq_len[seq_len>=thresholds[1]]) - len(seq_len[seq_len>=thresholds[2]]) 357 | count[3] = len(seq_len[seq_len>=thresholds[2]]) - len(seq_len[seq_len>=thresholds[3]]) 358 | count[4] = len(seq_len[seq_len>=thresholds[3]]) 359 | 360 | for j in range(5): 361 | NDCG[j] = NDCG[j] / count[j] 362 | HR[j] = HR[j] / count[j] 363 | 364 | return HR, NDCG, count 365 | 366 | 367 | 368 | def metric_pop_5group(pred_rank, 369 | pop_dict, 370 | target_items, 371 | thresholds=[10, 30, 60, 100], 372 | topk=10): 373 | 374 | NDCG = np.zeros(5) 375 | HR = np.zeros(5) 376 | for i, rank in enumerate(pred_rank): 377 | 378 | target_pop = pop_dict[int(target_items[i])] 379 | if rank < topk: 380 | 381 | if target_pop < thresholds[0]: 382 | NDCG[0] += 1 / np.log2(rank + 2) 383 | HR[0] += 1 384 | 385 | elif target_pop < thresholds[1]: 386 | NDCG[1] += 1 / np.log2(rank + 2) 387 | HR[1] += 1 388 | 389 | elif target_pop < thresholds[2]: 390 | NDCG[2] += 1 / np.log2(rank + 2) 391 | HR[2] += 1 392 | 393 | elif target_pop < thresholds[3]: 394 | NDCG[3] += 1 / np.log2(rank + 2) 395 | HR[3] += 1 396 | 397 | else: 398 | NDCG[4] += 1 / np.log2(rank + 2) 399 | HR[4] += 1 400 | 401 | count = np.zeros(5) 402 | pop = pop_dict[target_items.astype("int64")] 403 | count[0] = len(pop[pop>=0]) - len(pop[pop>=thresholds[0]]) 404 | count[1] = len(pop[pop>=thresholds[0]]) - len(pop[pop>=thresholds[1]]) 405 | count[2] = len(pop[pop>=thresholds[1]]) - len(pop[pop>=thresholds[2]]) 406 | count[3] = len(pop[pop>=thresholds[2]]) - len(pop[pop>=thresholds[3]]) 407 | count[4] = len(pop[pop>=thresholds[3]]) 408 | 409 | for j in range(5): 410 | NDCG[j] = NDCG[j] / count[j] 411 | HR[j] = HR[j] / count[j] 412 | 413 | return HR, NDCG, count 414 | 415 | 416 | 417 | def seq_acc(true, pred): 418 | 419 | true_num = np.sum((true==pred)) 420 | total_num = true.shape[0] * true.shape[1] 421 | 422 | return {'acc': true_num / total_num} 423 | 424 | 425 | def load_pretrained_model(pretrain_dir, model, logger, device): 426 | 427 | logger.info("Loading pretrained model ... ") 428 | checkpoint_path = os.path.join(pretrain_dir, 'pytorch_model.bin') 429 | 430 | model_dict = model.state_dict() 431 | 432 | # To be compatible with the new and old version of model saver 433 | try: 434 | pretrained_dict = torch.load(checkpoint_path, map_location=device)['state_dict'] 435 | except: 436 | pretrained_dict = torch.load(checkpoint_path, map_location=device) 437 | 438 | # filter out required parameters 439 | new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()} 440 | model_dict.update(new_dict) 441 | # 打印出来,更新了多少的参数 442 | logger.info('Total loaded parameters: {}, update: {}'.format(len(pretrained_dict), len(new_dict))) 443 | model.load_state_dict(model_dict) 444 | 445 | return model 446 | 447 | 448 | def record_csv(args, res_dict, path='log'): 449 | 450 | path = os.path.join(path, args.dataset) 451 | 452 | if not os.path.exists(path): 453 | os.makedirs(path) 454 | 455 | record_file = args.model_name + '.csv' 456 | csv_path = os.path.join(path, record_file) 457 | model_name = args.aug_file + '-' + args.now_str 458 | columns = list(res_dict.keys()) 459 | columns.insert(0, "model_name") 460 | res_dict["model_name"] = model_name 461 | # columns = ["model_name", "HR@10", "NDCG@10", "Short HR@10", "Short NDCG@10", "Medium HR@10", "Medium NDCG@10", "Long HR@10", "Long NDCG@10",] 462 | new_res_dict = {key: [value] for key, value in res_dict.items()} 463 | 464 | if not os.path.exists(csv_path): 465 | 466 | df = pd.DataFrame(new_res_dict) 467 | df = df[columns] # reindex the columns 468 | df.to_csv(csv_path, index=False) 469 | 470 | else: 471 | 472 | df = pd.read_csv(csv_path) 473 | add_df = pd.DataFrame(new_res_dict) 474 | df = pd.concat([df, add_df]) 475 | df.to_csv(csv_path, index=False) 476 | 477 | 478 | 479 | def record_group(args, res_dict, path='log'): 480 | 481 | path = os.path.join(path, args.dataset) 482 | 483 | if not os.path.exists(path): 484 | os.makedirs(path) 485 | 486 | record_file = args.model_name + '.csv' 487 | csv_path = os.path.join(path, record_file) 488 | model_name = args.aug_file + '-' + args.now_str 489 | columns = list(res_dict.keys()) 490 | columns.insert(0, "model_name") 491 | res_dict["model_name"] = model_name 492 | # columns = ["model_name", "HR@10", "NDCG@10", "Short HR@10", "Short NDCG@10", "Medium HR@10", "Medium NDCG@10", "Long HR@10", "Long NDCG@10",] 493 | new_res_dict = {key: [value] for key, value in res_dict.items()} 494 | 495 | if not os.path.exists(csv_path): 496 | 497 | df = pd.DataFrame(new_res_dict) 498 | df = df[columns] # reindex the columns 499 | df.to_csv(csv_path, index=False) 500 | 501 | else: 502 | 503 | df = pd.read_csv(csv_path) 504 | add_df = pd.DataFrame(new_res_dict) 505 | df = pd.concat([df, add_df]) 506 | df.to_csv(csv_path, index=False) 507 | --------------------------------------------------------------------------------