├── .gitignore
├── LICENSE
├── README.md
├── SASRecModules_ori.py
├── data
│   ├── data_interface.py
│   ├── lastfm_data.py
│   ├── movielens_data.py
│   └── steam_data.py
├── main.py
├── model
│   ├── mlp_projector.py
│   └── model_interface.py
├── optims.py
├── prompt
│   ├── artist.txt
│   ├── game.txt
│   └── movie.txt
├── rec_model
│   ├── lastfm.pt
│   ├── movielens.pt
│   └── steam.pt
├── recommender
│   └── A_SASRec_final_bce_llm.py
├── requirements.txt
├── test_lastfm.sh
├── test_movielens.sh
├── test_steam.sh
├── train_lastfm.sh
├── train_movielens.sh
└── train_steam.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | log/
2 | checkpoints/
3 | *.pyc
4 | output/
5 | *.df
6 | __pycache__
7 | ref
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 | 
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 | 
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 | 
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 | 
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 | 
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Jiayi Liao 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LLaRA 2 | 3 | - *2024.7*: We have resolved several bugs within our code. Below are the most recent results of LLaRA. 
4 | 
5 | | Model          | MovieLens ValidRatio | MovieLens HitRatio@1 | Steam ValidRatio | Steam HitRatio@1 | LastFM ValidRatio | LastFM HitRatio@1 |
6 | |----------------|----------------------|----------------------|------------------|------------------|-------------------|-------------------|
7 | | LLaRA(GRU4Rec) | 0.9684               | 0.4000               | 0.9840           | 0.4916           | 0.9672            | 0.4918            |
8 | | LLaRA(Caser)   | 0.9684               | 0.4211               | 0.9519           | 0.4621           | 0.9754            | 0.4836            |
9 | | LLaRA(SASRec)  | 0.9789               | 0.4526               | 0.9958           | 0.5051           | 0.9754            | 0.5246            |
10 | 
11 | - *2024.5*: We have updated the Steam dataset to a new version, fixing an issue that caused the last interacted item of some sequences to be duplicated.
12 | - 🔥 *2024.3*: Our paper is accepted by SIGIR'24! Thanks to all collaborators! 🎉🎉
13 | - 🔥 *2024.3*: Our [datasets](https://huggingface.co/datasets/joyliao7777/LLaRA) and [checkpoints](https://huggingface.co/joyliao7777/LLaRA) are released on Hugging Face.
14 | 
15 | ##### Preparation
16 | 
17 | 1. Prepare the environment:
18 | 
19 |    ```sh
20 |    git clone https://github.com/ljy0ustc/LLaRA.git
21 |    cd LLaRA
22 |    pip install -r requirements.txt
23 |    ```
24 | 
25 | 2. Download the pre-trained Llama-2-7B model from Hugging Face (https://huggingface.co/meta-llama/Llama-2-7b-hf).
26 | 
27 | 3. Download the data and checkpoints.
28 | 
29 | 4. Prepare the data and checkpoints:
30 | 
31 |    Put the data in `data/ref/` and the checkpoints in `checkpoints/`.
32 | 
33 | ##### Train LLaRA
34 | 
35 | Train LLaRA on the MovieLens dataset with a single A100 GPU:
36 | 
37 | ```sh
38 | sh train_movielens.sh
39 | ```
40 | 
41 | Train LLaRA on the Steam dataset with a single A100 GPU:
42 | 
43 | ```sh
44 | sh train_steam.sh
45 | ```
46 | 
47 | Train LLaRA on the LastFM dataset with a single A100 GPU:
48 | 
49 | ```sh
50 | sh train_lastfm.sh
51 | ```
52 | 
53 | Note: set the `llm_path` argument to the directory of your own Llama-2 model, as in the sketch below.
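
For reference, the `train_*.sh` and `test_*.sh` scripts are assumed to be thin wrappers around `main.py`. A bare-bones direct invocation might look like the following sketch; the flag names come from `main.py`'s argument parser, and every path here is a placeholder to adapt to your setup:

```sh
python main.py --mode train \
    --dataset movielens_data \
    --data_dir data/ref/movielens \
    --prompt_path ./prompt/movie.txt \
    --rec_model_path ./rec_model/movielens.pt \
    --rec_embed SASRec \
    --llm_path /path/to/Llama-2-7b-hf

# Evaluation takes the same flags plus a trained checkpoint:
python main.py --mode test \
    --dataset movielens_data \
    --data_dir data/ref/movielens \
    --prompt_path ./prompt/movie.txt \
    --rec_model_path ./rec_model/movielens.pt \
    --rec_embed SASRec \
    --llm_path /path/to/Llama-2-7b-hf \
    --ckpt_path ./checkpoints/your_checkpoint.ckpt
```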
54 | 55 | ##### Evaluate LLaRA 56 | 57 | Test LLaRA with a single A100 GPU on MovieLens dataset: 58 | 59 | ```sh 60 | sh test_movielens.sh 61 | ``` 62 | 63 | Test LLaRA with a single A100 GPU on Steam dataset: 64 | 65 | ```sh 66 | sh test_steam.sh 67 | ``` 68 | 69 | Test LLaRA with a single A100 GPU on LastFM dataset: 70 | 71 | ```sh 72 | sh test_lastfm.sh 73 | ``` -------------------------------------------------------------------------------- /SASRecModules_ori.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | 7 | class PositionwiseFeedForward(nn.Module): 8 | def __init__(self, d_in, d_hid, dropout=0.1): 9 | super().__init__() 10 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) 11 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) 12 | self.layer_norm = nn.LayerNorm(d_in) 13 | self.dropout = nn.Dropout(dropout) 14 | 15 | def forward(self, x): 16 | residual = x 17 | output = x.transpose(1, 2) 18 | output = self.w_2(F.relu(self.w_1(output))) 19 | output = output.transpose(1, 2) 20 | output = self.dropout(output) 21 | output = self.layer_norm(output + residual) 22 | return output 23 | 24 | 25 | 26 | class MultiHeadAttention(nn.Module): 27 | def __init__(self, hidden_size, num_units, num_heads, dropout_rate): 28 | super().__init__() 29 | self.hidden_size = hidden_size 30 | self.num_heads = num_heads 31 | assert hidden_size % num_heads == 0 32 | 33 | self.linear_q = nn.Linear(hidden_size, num_units) 34 | self.linear_k = nn.Linear(hidden_size, num_units) 35 | self.linear_v = nn.Linear(hidden_size, num_units) 36 | self.dropout = nn.Dropout(dropout_rate) 37 | self.softmax = nn.Softmax(dim=-1) 38 | 39 | 40 | def forward(self, queries, keys): 41 | """ 42 | :param queries: A 3d tensor with shape of [N, T_q, C_q] 43 | :param keys: A 3d tensor with shape of [N, T_k, C_k] 44 | 45 | :return: A 3d tensor with shape of (N, T_q, C) 46 | 47 | """ 48 | Q = self.linear_q(queries) # (N, T_q, C) 49 | K = self.linear_k(keys) # (N, T_k, C) 50 | V = self.linear_v(keys) # (N, T_k, C) 51 | 52 | # Split and Concat 53 | split_size = self.hidden_size // self.num_heads 54 | Q_ = torch.cat(torch.split(Q, split_size, dim=2), dim=0) # (h*N, T_q, C/h) 55 | K_ = torch.cat(torch.split(K, split_size, dim=2), dim=0) # (h*N, T_k, C/h) 56 | V_ = torch.cat(torch.split(V, split_size, dim=2), dim=0) # (h*N, T_k, C/h) 57 | 58 | # Multiplication 59 | matmul_output = torch.bmm(Q_, K_.transpose(1, 2)) / self.hidden_size ** 0.5 # (h*N, T_q, T_k) 60 | 61 | # Key Masking 62 | key_mask = torch.sign(torch.abs(keys.sum(dim=-1))).repeat(self.num_heads, 1) # (h*N, T_k) 63 | key_mask_reshaped = key_mask.unsqueeze(1).repeat(1, queries.shape[1], 1) # (h*N, T_q, T_k) 64 | key_paddings = torch.ones_like(matmul_output) * (-2 ** 32 + 1) 65 | matmul_output_m1 = torch.where(torch.eq(key_mask_reshaped, 0), key_paddings, matmul_output) # (h*N, T_q, T_k) 66 | 67 | # Causality - Future Blinding 68 | diag_vals = torch.ones_like(matmul_output[0, :, :]) # (T_q, T_k) 69 | tril = torch.tril(diag_vals) # (T_q, T_k) 70 | causality_mask = tril.unsqueeze(0).repeat(matmul_output.shape[0], 1, 1) # (h*N, T_q, T_k) 71 | causality_paddings = torch.ones_like(causality_mask) * (-2 ** 32 + 1) 72 | matmul_output_m2 = torch.where(torch.eq(causality_mask, 0), causality_paddings, matmul_output_m1) # (h*N, T_q, T_k) 73 | 74 | # Activation 75 | matmul_output_sm = self.softmax(matmul_output_m2) # (h*N, T_q, T_k) 76 | 77 | # Query Masking 78 | query_mask = 
torch.sign(torch.abs(queries.sum(dim=-1))).repeat(self.num_heads, 1)  # (h*N, T_q)
79 |         query_mask = query_mask.unsqueeze(-1).repeat(1, 1, keys.shape[1])  # (h*N, T_q, T_k)
80 |         matmul_output_qm = matmul_output_sm * query_mask
81 | 
82 |         # Dropout
83 |         matmul_output_dropout = self.dropout(matmul_output_qm)
84 | 
85 |         # Weighted Sum
86 |         output_ws = torch.bmm(matmul_output_dropout, V_)  # (h*N, T_q, C/h)
87 | 
88 |         # Restore Shape
89 |         output = torch.cat(torch.split(output_ws, output_ws.shape[0] // self.num_heads, dim=0), dim=2)  # (N, T_q, C)
90 | 
91 |         # Residual Connection
92 |         output_res = output + queries
93 | 
94 |         return output_res
95 | 
--------------------------------------------------------------------------------
/data/data_interface.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import importlib
3 | import pickle as pkl
4 | import pytorch_lightning as pl
5 | from torch.utils.data import DataLoader
6 | from torch.utils.data.sampler import WeightedRandomSampler
7 | 
8 | import random
9 | import torch
10 | import argparse
11 | from transformers import LlamaForCausalLM, LlamaTokenizer
12 | import os
13 | 
14 | 
15 | 
16 | class TrainCollater:
17 |     def __init__(self,
18 |                  prompt_list=None,
19 |                  llm_tokenizer=None,
20 |                  train=False,
21 |                  terminator="\n",
22 |                  max_step=1):
23 |         self.prompt_list = prompt_list
24 |         self.llm_tokenizer = llm_tokenizer
25 |         self.train=train
26 |         self.terminator = terminator
27 |         self.max_step = max_step
28 |         self.cur_step = 1
29 | 
30 |     def __call__(self, batch):
31 |         if isinstance(self.prompt_list,list):
32 |             instruction = random.choice(self.prompt_list)
33 |             inputs_text = instruction if isinstance(instruction, list) else [instruction] * len(batch)
34 |         else:
35 |             instruction = batch[0].get("instruction_input", None)  # per-sample instruction supplied by the dataset, if any
36 |             inputs_text = instruction if isinstance(instruction, list) else [instruction] * len(batch)
37 | 
38 |         threshold = self.cur_step/self.max_step
39 |         p = random.random()
40 |         if p < threshold or not self.train:
41 |             for i, sample in enumerate(batch):
42 |                 input_text=inputs_text[i]
43 |                 if '[HistoryHere]' in input_text:
44 |                     insert_prompt=", ".join([seq_title+' [HistoryEmb]' for seq_title in sample['seq_name']])
45 |                     input_text=input_text.replace('[HistoryHere]',insert_prompt)
46 |                 if '[CansHere]' in input_text:
47 |                     insert_prompt=", ".join([can_title+' [CansEmb]' for can_title in sample['cans_name']])
48 |                     input_text=input_text.replace('[CansHere]',insert_prompt)
49 |                 inputs_text[i]=input_text
50 |             flag = False
51 |         else:
52 |             for i, sample in enumerate(batch):
53 |                 input_text=inputs_text[i]
54 |                 if '[HistoryHere]' in input_text:
55 |                     insert_prompt=", ".join([seq_title+' [PH]' for seq_title in sample['seq_name']])
56 |                     input_text=input_text.replace('[HistoryHere]',insert_prompt)
57 |                 if '[CansHere]' in input_text:
58 |                     insert_prompt=", ".join([can_title+' [PH]' for can_title in sample['cans_name']])
59 |                     input_text=input_text.replace('[CansHere]',insert_prompt)
60 |                 inputs_text[i]=input_text
61 |             flag = True
62 |         self.cur_step += 1
63 | 
64 |         targets_text = [sample['correct_answer'] for sample in batch]
65 | 
66 |         if self.train:
67 |             targets_text=[target_text+self.terminator for target_text in targets_text]
68 |             inputs_pair = [[p, t] for p, t in zip(inputs_text, targets_text)]
69 |             batch_tokens = self.llm_tokenizer(
70 |                 inputs_pair,
71 |                 return_tensors="pt",
72 |                 padding="longest",
73 |                 truncation=False,
74 |                 add_special_tokens=True,
75 | 
return_attention_mask=True, 76 | return_token_type_ids=True) 77 | new_batch={"tokens":batch_tokens, 78 | "seq":torch.stack([torch.tensor(sample['seq']) for sample in batch], dim=0), 79 | "cans":torch.stack([torch.tensor(sample['cans']) for sample in batch], dim=0), 80 | "len_seq":torch.stack([torch.tensor(sample['len_seq']) for sample in batch], dim=0), 81 | "len_cans":torch.stack([torch.tensor(sample['len_cans']) for sample in batch], dim=0), 82 | "item_id": torch.stack([torch.tensor(sample['item_id']) for sample in batch], dim=0), 83 | "flag":flag, 84 | } 85 | else: 86 | batch_tokens = self.llm_tokenizer( 87 | inputs_text, 88 | return_tensors="pt", 89 | padding="longest", 90 | truncation=False, 91 | add_special_tokens=True, 92 | return_attention_mask=True) 93 | cans_name=[sample['cans_name'] for sample in batch] 94 | new_batch={"tokens":batch_tokens, 95 | "seq":torch.stack([torch.tensor(sample['seq']) for sample in batch], dim=0), 96 | "cans":torch.stack([torch.tensor(sample['cans']) for sample in batch], dim=0), 97 | "len_seq":torch.stack([torch.tensor(sample['len_seq']) for sample in batch], dim=0), 98 | "len_cans":torch.stack([torch.tensor(sample['len_cans']) for sample in batch], dim=0), 99 | "item_id": torch.stack([torch.tensor(sample['item_id']) for sample in batch], dim=0), 100 | "correct_answer": targets_text, 101 | "cans_name": cans_name, 102 | } 103 | return new_batch 104 | 105 | class DInterface(pl.LightningDataModule): 106 | 107 | def __init__(self, 108 | llm_tokenizer=None, 109 | num_workers=8, 110 | dataset='', 111 | **kwargs): 112 | super().__init__() 113 | self.num_workers = num_workers 114 | self.llm_tokenizer=llm_tokenizer 115 | self.dataset = dataset 116 | self.kwargs = kwargs 117 | self.batch_size = kwargs['batch_size'] 118 | self.max_epochs = kwargs['max_epochs'] 119 | self.load_data_module() 120 | self.load_prompt(kwargs['prompt_path']) 121 | 122 | self.trainset = self.instancialize(stage='train') 123 | self.valset = self.instancialize(stage='val') 124 | self.testset = self.instancialize(stage='test') 125 | self.max_steps = self.max_epochs*(len(self.trainset)//self.batch_size)//self.num_workers 126 | 127 | def train_dataloader(self): 128 | return DataLoader(self.trainset, 129 | batch_size=self.batch_size, 130 | num_workers=self.num_workers, 131 | shuffle=True, 132 | drop_last=True, 133 | collate_fn=TrainCollater(prompt_list=self.prompt_list,llm_tokenizer=self.llm_tokenizer,train=True, max_step=self.max_steps)) 134 | 135 | def val_dataloader(self): 136 | return DataLoader(self.valset, 137 | batch_size=self.batch_size, 138 | num_workers=self.num_workers, 139 | shuffle=False, 140 | collate_fn=TrainCollater(prompt_list=self.prompt_list,llm_tokenizer=self.llm_tokenizer,train=False)) 141 | 142 | def test_dataloader(self): 143 | return DataLoader(self.testset, 144 | batch_size=self.batch_size, 145 | num_workers=self.num_workers, 146 | shuffle=False, 147 | collate_fn=TrainCollater(prompt_list=self.prompt_list,llm_tokenizer=self.llm_tokenizer,train=False)) 148 | 149 | def load_data_module(self): 150 | name = self.dataset 151 | camel_name = ''.join([i.capitalize() for i in name.split('_')]) 152 | try: 153 | self.data_module = getattr(importlib.import_module( 154 | '.'+name, package=__package__), camel_name) 155 | except: 156 | raise ValueError( 157 | f'Invalid Dataset File Name or Invalid Class Name data.{name}.{camel_name}') 158 | 159 | def instancialize(self, **other_args): 160 | """ Instancialize a model using the corresponding parameters 161 | from self.hparams 
dictionary. You can also input any args
162 |         to overwrite the corresponding value in self.kwargs.
163 |         """
164 |         class_args = inspect.getfullargspec(self.data_module.__init__).args[1:]
165 |         inkeys = self.kwargs.keys()
166 |         args1 = {}
167 |         for arg in class_args:
168 |             if arg in inkeys:
169 |                 args1[arg] = self.kwargs[arg]
170 |         args1.update(other_args)
171 |         return self.data_module(**args1)
172 | 
173 |     def load_prompt(self,prompt_path):
174 |         if os.path.isfile(prompt_path):
175 |             with open(prompt_path, 'r') as f:
176 |                 raw_prompts = f.read().splitlines()
177 |             self.prompt_list = [p.strip() for p in raw_prompts]
178 |             print('Load {} training prompts'.format(len(self.prompt_list)))
179 |             print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
180 |         else:
181 |             self.prompt_list = []
--------------------------------------------------------------------------------
/data/lastfm_data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os.path as op
3 | import numpy as np
4 | import pickle as pkl
5 | import torch.utils.data as data
6 | 
7 | import pandas as pd
8 | import random
9 | 
10 | class LastfmData(data.Dataset):
11 |     def __init__(self, data_dir=r'data/ref/lastfm',
12 |                  stage=None,
13 |                  cans_num=10,
14 |                  sep=", ",
15 |                  no_augment=True):
16 |         self.__dict__.update(locals())
17 |         self.aug = (stage=='train') and not no_augment
18 |         self.padding_item_id=4606
19 |         self.check_files()
20 | 
21 |     def __len__(self):
22 |         return len(self.session_data['seq'])
23 | 
24 |     def __getitem__(self, i):
25 |         temp = self.session_data.iloc[i]
26 |         candidates = self.negative_sampling(temp['seq_unpad'],temp['next'])
27 |         cans_name=[self.item_id2name[can] for can in candidates]
28 |         sample = {
29 |             'seq': temp['seq'],
30 |             'seq_name': temp['seq_title'],
31 |             'len_seq': temp['len_seq'],
32 |             'seq_str': self.sep.join(temp['seq_title']),
33 |             'cans': candidates,
34 |             'cans_name': cans_name,
35 |             'cans_str': self.sep.join(cans_name),
36 |             'len_cans': self.cans_num,
37 |             'item_id': temp['next'],
38 |             'item_name': temp['next_item_name'],
39 |             'correct_answer': temp['next_item_name']
40 |         }
41 |         return sample
42 | 
43 |     def negative_sampling(self,seq_unpad,next_item):
44 |         canset=[i for i in list(self.item_id2name.keys()) if i not in seq_unpad and i!=next_item]
45 |         candidates=random.sample(canset, self.cans_num-1)+[next_item]
46 |         random.shuffle(candidates)
47 |         return candidates
48 | 
49 |     def check_files(self):
50 |         self.item_id2name=self.get_music_id2name()
51 |         if self.stage=='train':
52 |             filename="train_data.df"
53 |         elif self.stage=='val':
54 |             filename="Val_data.df"
55 |         elif self.stage=='test':
56 |             filename="Test_data.df"
57 |         data_path=op.join(self.data_dir, filename)
58 |         self.session_data = self.session_data4frame(data_path, self.item_id2name)
59 | 
60 | 
61 |     def get_music_id2name(self):
62 |         music_id2name = dict()
63 |         item_path=op.join(self.data_dir, 'id2name.txt')
64 |         with open(item_path, 'r') as f:
65 |             for l in f.readlines():
66 |                 ll = l.strip('\n').split('::')
67 |                 music_id2name[int(ll[0])] = ll[1].strip()
68 |         return music_id2name
69 | 
70 |     def session_data4frame(self, datapath, music_id2name):
71 |         train_data = pd.read_pickle(datapath)
72 |         train_data = train_data[train_data['len_seq'] >= 3]
73 |         def remove_padding(xx):
74 |             x = xx[:]
75 |             for i in range(10):
76 |                 try:
77 |                     x.remove(self.padding_item_id)
78 |                 except ValueError:
79 |                     break
80 |             return x
81 |         train_data['seq_unpad'] = train_data['seq'].apply(remove_padding)
82 |         def
seq_to_title(x): 83 | return [music_id2name[x_i] for x_i in x] 84 | train_data['seq_title'] = train_data['seq_unpad'].apply(seq_to_title) 85 | def next_item_title(x): 86 | return music_id2name[x] 87 | train_data['next_item_name'] = train_data['next'].apply(next_item_title) 88 | return train_data -------------------------------------------------------------------------------- /data/movielens_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os.path as op 3 | import numpy as np 4 | import pickle as pkl 5 | import torch.utils.data as data 6 | 7 | import pandas as pd 8 | import random 9 | 10 | class MovielensData(data.Dataset): 11 | def __init__(self, data_dir=r'data/ref/movielens', 12 | stage=None, 13 | cans_num=10, 14 | sep=", ", 15 | no_augment=True): 16 | self.__dict__.update(locals()) 17 | self.aug = (stage=='train') and not no_augment 18 | self.padding_item_id=1682 19 | self.padding_rating=0 20 | self.check_files() 21 | 22 | def __len__(self): 23 | return len(self.session_data['seq']) 24 | 25 | def __getitem__(self, i): 26 | temp = self.session_data.iloc[i] 27 | candidates = self.negative_sampling(temp['seq_unpad'],temp['next']) 28 | cans_name=[self.item_id2name[can] for can in candidates] 29 | sample = { 30 | 'seq': temp['seq'], 31 | 'seq_name': temp['seq_title'], 32 | 'len_seq': temp['len_seq'], 33 | 'seq_str': self.sep.join(temp['seq_title']), 34 | 'cans': candidates, 35 | 'cans_name': cans_name, 36 | 'cans_str': self.sep.join(cans_name), 37 | 'len_cans': self.cans_num, 38 | 'item_id': temp['next'], 39 | 'item_name': temp['next_item_name'], 40 | 'correct_answer': temp['next_item_name'] 41 | } 42 | return sample 43 | 44 | def negative_sampling(self,seq_unpad,next_item): 45 | canset=[i for i in list(self.item_id2name.keys()) if i not in seq_unpad and i!=next_item] 46 | candidates=random.sample(canset, self.cans_num-1)+[next_item] 47 | random.shuffle(candidates) 48 | return candidates 49 | 50 | def check_files(self): 51 | self.item_id2name=self.get_movie_id2name() 52 | if self.stage=='train': 53 | filename="train_data.df" 54 | elif self.stage=='val': 55 | filename="Val_data.df" 56 | elif self.stage=='test': 57 | filename="Test_data.df" 58 | data_path=op.join(self.data_dir, filename) 59 | self.session_data = self.session_data4frame(data_path, self.item_id2name) 60 | 61 | def get_mv_title(self,s): 62 | sub_list=[", The", ", A", ", An"] 63 | for sub_s in sub_list: 64 | if sub_s in s: 65 | return sub_s[2:]+" "+s.replace(sub_s,"") 66 | return s 67 | 68 | def get_movie_id2name(self): 69 | movie_id2name = dict() 70 | item_path=op.join(self.data_dir, 'u.item') 71 | with open(item_path, 'r', encoding = "ISO-8859-1") as f: 72 | for l in f.readlines(): 73 | ll = l.strip('\n').split('|') 74 | movie_id2name[int(ll[0]) - 1] = self.get_mv_title(ll[1][:-7]) 75 | return movie_id2name 76 | 77 | def session_data4frame(self, datapath, movie_id2name): 78 | train_data = pd.read_pickle(datapath) 79 | train_data = train_data[train_data['len_seq'] >= 3] 80 | def remove_padding(xx): 81 | x = xx[:] 82 | for i in range(10): 83 | try: 84 | x.remove((self.padding_item_id,self.padding_rating)) 85 | except: 86 | break 87 | return x 88 | train_data['seq_unpad'] = train_data['seq'].apply(remove_padding) 89 | def seq_to_title(x): 90 | return [movie_id2name[x_i[0]] for x_i in x] 91 | train_data['seq_title'] = train_data['seq_unpad'].apply(seq_to_title) 92 | def next_item_title(x): 93 | return movie_id2name[x[0]] 94 | train_data['next_item_name'] = 
train_data['next'].apply(next_item_title) 95 | def get_id_from_tumple(x): 96 | return x[0] 97 | def get_id_from_list(x): 98 | return [i[0] for i in x] 99 | train_data['next'] = train_data['next'].apply(get_id_from_tumple) 100 | train_data['seq'] = train_data['seq'].apply(get_id_from_list) 101 | train_data['seq_unpad']=train_data['seq_unpad'].apply(get_id_from_list) 102 | return train_data -------------------------------------------------------------------------------- /data/steam_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os.path as op 3 | import numpy as np 4 | import pickle as pkl 5 | import torch.utils.data as data 6 | 7 | import pandas as pd 8 | import random 9 | 10 | class SteamData(data.Dataset): 11 | def __init__(self, data_dir=r'data/ref/steam', 12 | stage=None, 13 | cans_num=10, 14 | sep=", ", 15 | no_augment=True): 16 | self.__dict__.update(locals()) 17 | self.aug = (stage=='train') and not no_augment 18 | self.padding_item_id=3581 19 | self.check_files() 20 | 21 | def __len__(self): 22 | return len(self.session_data['seq']) 23 | 24 | def __getitem__(self, i): 25 | temp = self.session_data.iloc[i] 26 | candidates = self.negative_sampling(temp['seq_unpad'],temp['next']) 27 | cans_name=[self.item_id2name[can] for can in candidates] 28 | sample = { 29 | 'seq': temp['seq'], 30 | 'seq_name': temp['seq_title'], 31 | 'len_seq': temp['len_seq'], 32 | 'seq_str': self.sep.join(temp['seq_title']), 33 | 'cans': candidates, 34 | 'cans_name': cans_name, 35 | 'cans_str': self.sep.join(cans_name), 36 | 'len_cans': self.cans_num, 37 | 'item_id': temp['next'], 38 | 'item_name': temp['next_item_name'], 39 | 'correct_answer': temp['next_item_name'] 40 | } 41 | return sample 42 | 43 | def negative_sampling(self,seq_unpad,next_item): 44 | canset=[i for i in list(self.item_id2name.keys()) if i not in seq_unpad and i!=next_item] 45 | candidates=random.sample(canset, self.cans_num-1)+[next_item] 46 | random.shuffle(candidates) 47 | return candidates 48 | 49 | def check_files(self): 50 | self.item_id2name=self.get_game_id2name() 51 | if self.stage=='train': 52 | filename="train_data.df" 53 | elif self.stage=='val': 54 | filename="Val_data.df" 55 | elif self.stage=='test': 56 | filename="Test_data.df" 57 | data_path=op.join(self.data_dir, filename) 58 | self.session_data = self.session_data4frame(data_path, self.item_id2name) 59 | 60 | 61 | def get_game_id2name(self): 62 | game_id2name = dict() 63 | item_path=op.join(self.data_dir, 'id2name.txt') 64 | with open(item_path, 'r') as f: 65 | for l in f.readlines(): 66 | ll = l.strip('\n').split('::') 67 | game_id2name[int(ll[0])] = ll[1].strip() 68 | return game_id2name 69 | 70 | def session_data4frame(self, datapath, game_id2name): 71 | train_data = pd.read_pickle(datapath) 72 | train_data = train_data[train_data['len_seq'] >= 3] 73 | def remove_padding(xx): 74 | x = xx[:] 75 | for i in range(10): 76 | try: 77 | x.remove(self.padding_item_id) 78 | except: 79 | break 80 | return x 81 | train_data['seq_unpad'] = train_data['seq'].apply(remove_padding) 82 | def seq_to_title(x): 83 | return [game_id2name[x_i] for x_i in x] 84 | train_data['seq_title'] = train_data['seq_unpad'].apply(seq_to_title) 85 | def next_item_title(x): 86 | return game_id2name[x] 87 | train_data['next_item_name'] = train_data['next'].apply(next_item_title) 88 | return train_data -------------------------------------------------------------------------------- /main.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import pytorch_lightning as pl 3 | from argparse import ArgumentParser 4 | from pytorch_lightning import Trainer 5 | import pytorch_lightning.callbacks as plc 6 | from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger 7 | 8 | from model.model_interface import MInterface 9 | from data.data_interface import DInterface 10 | from recommender.A_SASRec_final_bce_llm import SASRec, Caser, GRU 11 | from SASRecModules_ori import * 12 | from transformers import LlamaForCausalLM, LlamaTokenizer 13 | 14 | def load_callbacks(args): 15 | callbacks = [] 16 | callbacks.append(plc.EarlyStopping( 17 | monitor='metric', 18 | mode='max', 19 | patience=10, 20 | min_delta=0.001 21 | )) 22 | 23 | callbacks.append(plc.ModelCheckpoint( 24 | monitor='metric', 25 | dirpath=args.ckpt_dir, 26 | filename='{epoch:02d}-{metric:.3f}', 27 | save_top_k=-1, 28 | mode='max', 29 | save_last=True, 30 | #train_time_interval=args.val_check_interval 31 | every_n_epochs=1 32 | )) 33 | 34 | if args.lr_scheduler: 35 | callbacks.append(plc.LearningRateMonitor( 36 | logging_interval='step')) 37 | return callbacks 38 | 39 | def main(args): 40 | pl.seed_everything(args.seed) 41 | model = MInterface(**vars(args)) 42 | if args.ckpt_path: 43 | ckpt = torch.load(args.ckpt_path, map_location='cpu') 44 | model.load_state_dict(ckpt['state_dict'], strict=False) 45 | print("load checkpoints from {}".format(args.ckpt_path)) 46 | 47 | data_module = DInterface(llm_tokenizer=model.llama_tokenizer,**vars(args)) 48 | 49 | args.max_steps=len(data_module.trainset) * args.max_epochs // (args.accumulate_grad_batches * args.batch_size) 50 | 51 | logger = TensorBoardLogger(save_dir='./log/', name=args.log_dir) 52 | args.callbacks = load_callbacks(args) 53 | args.logger = logger 54 | if not os.path.exists(args.ckpt_dir): 55 | os.makedirs(args.ckpt_dir) 56 | 57 | trainer = Trainer.from_argparse_args(args) 58 | 59 | if args.auto_lr_find: 60 | lr_finder=trainer.tuner.lr_find(model=model, datamodule=data_module, min_lr=1e-10, max_lr=1e-3, num_training=100) 61 | fig=lr_finder.plot(suggest=True) 62 | fig_path="lr_finder.png" 63 | fig.savefig(fig_path) 64 | print("Saving to {}".format(fig_path)) 65 | model.hparams.lr=lr_finder.suggestion() 66 | 67 | if args.mode == 'train': 68 | trainer.fit(model=model, datamodule=data_module) 69 | else: 70 | trainer.test(model=model, datamodule=data_module) 71 | 72 | 73 | if __name__ == '__main__': 74 | torch.multiprocessing.set_start_method('spawn') 75 | parser = ArgumentParser() 76 | 77 | parser.add_argument('--accelerator', default='gpu', type=str) 78 | parser.add_argument('--devices', default=-1, type=int) 79 | parser.add_argument('--precision', default='bf16', type=str) 80 | parser.add_argument('--amp_backend', default="native", type=str) 81 | 82 | parser.add_argument('--batch_size', default=8, type=int) 83 | parser.add_argument('--num_workers', default=8, type=int) 84 | parser.add_argument('--seed', default=1234, type=int) 85 | parser.add_argument('--lr', default=1e-3, type=float) 86 | parser.add_argument('--accumulate_grad_batches', default=8, type=int) 87 | parser.add_argument('--check_val_every_n_epoch', default=1, type=int) 88 | 89 | parser.add_argument('--lr_scheduler', default='cosine', choices=['cosine'], type=str) 90 | parser.add_argument('--lr_decay_min_lr', default=1e-9, type=float) 91 | parser.add_argument('--lr_warmup_start_lr', default=1e-7, type=float) 92 | 93 | parser.add_argument('--load_best', 
action='store_true')
94 |     parser.add_argument('--load_dir', default=None, type=str)
95 |     parser.add_argument('--load_ver', default=None, type=str)
96 |     parser.add_argument('--load_v_num', default=None, type=int)
97 | 
98 |     parser.add_argument('--dataset', default='movielens_data', type=str)
99 |     parser.add_argument('--data_dir', default='data/ref/movielens1m', type=str)
100 |     parser.add_argument('--model_name', default='mlp_projector', type=str)
101 |     parser.add_argument('--loss', default='lm', type=str)
102 |     parser.add_argument('--weight_decay', default=1e-5, type=float)
103 |     parser.add_argument('--no_augment', action='store_true')
104 |     parser.add_argument('--ckpt_dir', default='./checkpoints/', type=str)
105 |     parser.add_argument('--log_dir', default='movielens_logs', type=str)
106 | 
107 |     parser.add_argument('--rec_size', default=64, type=int)
108 |     parser.add_argument('--padding_item_id', default=1682, type=int)
109 |     parser.add_argument('--llm_path', type=str)
110 |     parser.add_argument('--rec_model_path', default='./rec_model/SASRec_ml1m.pt', type=str)
111 |     parser.add_argument('--prompt_path', default='./prompt/movie.txt', type=str)  # must be a prompt file, see DInterface.load_prompt
112 |     parser.add_argument('--output_dir', default='./output/', type=str)
113 |     parser.add_argument('--ckpt_path', type=str)
114 |     parser.add_argument('--rec_embed', default="SASRec", choices=['SASRec', 'Caser','GRU'], type=str)
115 | 
116 |     parser.add_argument('--aug_prob', default=0.5, type=float)
117 |     parser.add_argument('--mode', default='train', choices=['train', 'test'], type=str)
118 |     parser.add_argument('--auto_lr_find', default=False, action='store_true')
119 |     parser.add_argument('--metric', default='hr', choices=['hr'], type=str)
120 |     parser.add_argument('--max_epochs', default=10, type=int)
121 |     parser.add_argument('--save', default='part', choices=['part', 'all'], type=str)
122 |     parser.add_argument('--cans_num', default=10, type=int)
123 | 
124 |     # Finetuning
125 |     parser.add_argument('--llm_tuning', default='lora', choices=['lora', 'freeze','freeze_lora'], type=str)
126 |     parser.add_argument('--peft_dir', default=None, type=str)
127 |     parser.add_argument('--peft_config', default=None, type=str)
128 |     parser.add_argument('--lora_r', default=8, type=int)
129 |     parser.add_argument('--lora_alpha', default=32, type=float)
130 |     parser.add_argument('--lora_dropout', default=0.1, type=float)
131 | 
132 |     args = parser.parse_args()
133 | 
134 |     if 'movielens' in args.data_dir:
135 |         args.padding_item_id = 1682
136 |     elif 'steam' in args.data_dir:
137 |         args.padding_item_id = 3581
138 |     elif 'lastfm' in args.data_dir:
139 |         args.padding_item_id = 4606
140 | 
141 |     main(args)
142 | 
--------------------------------------------------------------------------------
/model/mlp_projector.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | 
3 | class MlpProjector(nn.Module):
4 |     def __init__(self, rec_size=64, llm_size=4096):
5 |         super().__init__()
6 |         self.mlp_proj = nn.Sequential(
7 |             nn.Linear(rec_size, llm_size),
8 |             nn.GELU(),
9 |             nn.Linear(llm_size, llm_size)
10 |         )
11 | 
12 |     def forward(self, x):
13 |         x = self.mlp_proj(x)
14 |         return x
--------------------------------------------------------------------------------
/model/model_interface.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import torch
3 | import importlib
4 | from torch import nn
5 | from torch.nn import functional as F
6 | import torch.optim.lr_scheduler as lrs
7 | 
8 | 
import pytorch_lightning as pl 9 | 10 | from transformers import LlamaForCausalLM, LlamaTokenizer 11 | import random 12 | from pandas.core.frame import DataFrame 13 | import os.path as op 14 | import os 15 | from optims import LinearWarmupCosineLRScheduler 16 | import numpy as np 17 | from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType, PeftModel 18 | 19 | class MInterface(pl.LightningModule): 20 | def __init__(self, 21 | **kargs): 22 | super().__init__() 23 | self.save_hyperparameters() 24 | self.load_llm(self.hparams.llm_path) 25 | self.load_rec_model(self.hparams.rec_model_path) 26 | self.load_projector() 27 | 28 | def forward(self, batch): 29 | targets = batch["tokens"].input_ids.masked_fill( 30 | batch["tokens"].input_ids == self.llama_tokenizer.pad_token_id, -100 31 | ) # [batch_size, max_len] 32 | targets = targets.masked_fill((batch["tokens"].token_type_ids == 0)[:,1:], -100) 33 | input_embeds = self.wrap_emb(batch) 34 | outputs = self.llama_model( 35 | inputs_embeds=input_embeds, 36 | attention_mask=batch["tokens"].attention_mask, 37 | return_dict=True, 38 | labels=targets, 39 | use_cache=False 40 | ) 41 | return outputs 42 | 43 | def generate(self, batch,temperature=0.8,do_sample=False,num_beams=1,max_gen_length=64,min_gen_length=1,repetition_penalty=1.0,length_penalty=1.0, num_return_sequences=1): 44 | input_embeds = self.wrap_emb(batch) 45 | generate_ids = self.llama_model.generate( 46 | inputs_embeds=input_embeds, 47 | attention_mask=batch["tokens"].attention_mask, 48 | temperature=temperature, 49 | do_sample=do_sample, 50 | num_beams=num_beams, 51 | max_new_tokens=max_gen_length, 52 | min_new_tokens=min_gen_length, 53 | pad_token_id=self.llama_tokenizer.pad_token_id, 54 | repetition_penalty=repetition_penalty, 55 | length_penalty=length_penalty, 56 | num_return_sequences=num_return_sequences 57 | ) 58 | output_text=self.llama_tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) 59 | outputs=[text.strip() for text in output_text] 60 | return outputs 61 | 62 | def training_step(self, batch, batch_idx): 63 | if self.scheduler: 64 | self.scheduler.step(self.trainer.global_step, self.current_epoch, self.trainer.max_steps) 65 | if batch["flag"]: 66 | for name, param in self.projector.named_parameters(): 67 | param.requires_grad = False 68 | else: 69 | for name, param in self.projector.named_parameters(): 70 | param.requires_grad = True 71 | out = self(batch) 72 | loss = self.configure_loss(out) 73 | self.log('loss', loss, on_step=True, on_epoch=True, prog_bar=True) 74 | self.log('lr', self.scheduler.optimizer.param_groups[0]['lr'], on_step=True, on_epoch=True, prog_bar=True) 75 | self.log('global_step_num', self.trainer.global_step, on_step=True, on_epoch=True, prog_bar=True) 76 | return loss 77 | 78 | def on_validation_epoch_start(self): 79 | self.val_content={ 80 | "generate":[], 81 | "real":[], 82 | "cans":[], 83 | } 84 | 85 | @torch.no_grad() 86 | def validation_step(self, batch, batch_idx): 87 | generate_output = self.generate(batch) 88 | output=[] 89 | for i,generate in enumerate(generate_output): 90 | real=batch['correct_answer'][i] 91 | cans=batch['cans_name'][i] 92 | generate=generate.strip().split("\n")[0] 93 | output.append((generate,real,cans)) 94 | return output 95 | 96 | def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx): 97 | for generate,real,cans in outputs: 98 | self.val_content["generate"].append(generate) 99 | 
self.val_content["real"].append(real) 100 | self.val_content["cans"].append(cans) 101 | 102 | def on_validation_epoch_end(self): 103 | df=DataFrame(self.val_content) 104 | if not os.path.exists(self.hparams.output_dir): 105 | os.makedirs(self.hparams.output_dir) 106 | df.to_csv(op.join(self.hparams.output_dir, 'valid.csv')) 107 | prediction_valid_ratio,hr=self.calculate_hr1(self.val_content) 108 | metric=hr*prediction_valid_ratio 109 | self.log('val_prediction_valid', prediction_valid_ratio, on_step=False, on_epoch=True, prog_bar=True) 110 | self.log('val_hr', hr, on_step=False, on_epoch=True, prog_bar=True) 111 | self.log('metric', metric, on_step=False, on_epoch=True, prog_bar=True) 112 | 113 | def on_test_epoch_start(self): 114 | self.test_content={ 115 | "generate":[], 116 | "real":[], 117 | "cans":[], 118 | } 119 | 120 | @torch.no_grad() 121 | def test_step(self, batch, batch_idx): 122 | generate_output = self.generate(batch) 123 | output=[] 124 | for i,generate in enumerate(generate_output): 125 | real=batch['correct_answer'][i] 126 | cans=batch['cans_name'][i] 127 | generate=generate.strip().split("\n")[0] 128 | output.append((generate,real,cans)) 129 | return output 130 | 131 | def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx): 132 | for generate,real,cans in outputs: 133 | self.test_content["generate"].append(generate) 134 | self.test_content["real"].append(real) 135 | self.test_content["cans"].append(cans) 136 | 137 | def on_test_epoch_end(self): 138 | df=DataFrame(self.test_content) 139 | if not os.path.exists(self.hparams.output_dir): 140 | os.makedirs(self.hparams.output_dir) 141 | df.to_csv(op.join(self.hparams.output_dir, 'test.csv')) 142 | prediction_valid_ratio,hr=self.calculate_hr1(self.test_content) 143 | metric=hr*prediction_valid_ratio 144 | self.log('test_prediction_valid', prediction_valid_ratio, on_step=False, on_epoch=True, prog_bar=True) 145 | self.log('test_hr', hr, on_step=False, on_epoch=True, prog_bar=True) 146 | self.log('metric', metric, on_step=False, on_epoch=True, prog_bar=True) 147 | 148 | def configure_optimizers(self): 149 | if hasattr(self.hparams, 'weight_decay'): 150 | weight_decay = self.hparams.weight_decay 151 | else: 152 | weight_decay = 0 153 | optimizer = torch.optim.Adam([ 154 | {'params': self.projector.parameters(), 'lr': self.hparams.lr, 'weight_decay':weight_decay}, 155 | {'params': self.llama_model.parameters(), 'lr': self.hparams.lr} 156 | ]) 157 | 158 | if self.hparams.lr_scheduler is None: 159 | return optimizer 160 | else: 161 | max_step = self.trainer.max_steps 162 | warmup_steps = max_step // 20 163 | print(f'max_step: {max_step}') 164 | print(f'warmup_steps: {warmup_steps}') 165 | if self.hparams.lr_scheduler == 'cosine': 166 | self.scheduler = LinearWarmupCosineLRScheduler(optimizer, 167 | max_step=max_step, 168 | min_lr=self.hparams.lr_decay_min_lr, 169 | init_lr=self.hparams.lr, 170 | warmup_steps=warmup_steps, 171 | warmup_start_lr=self.hparams.lr_warmup_start_lr) 172 | else: 173 | self.scheduler = None 174 | raise ValueError('Invalid lr_scheduler type!') 175 | return optimizer 176 | 177 | def configure_loss(self, out, labels=None): 178 | loss = self.hparams.loss.lower() 179 | if loss == 'lm': 180 | return out.loss 181 | else: 182 | raise ValueError("Invalid Loss Type!") 183 | 184 | def on_save_checkpoint(self, checkpoint): 185 | if self.hparams.save == 'part': 186 | checkpoint.pop('optimizer_states') 187 | to_be_removed = [] 188 | for key, value in checkpoint['state_dict'].items(): 189 | try: 190 | if not 
self.get_parameter(key).requires_grad:
191 |                         to_be_removed.append(key)
192 |                 except AttributeError:
193 |                     to_be_removed.append(key)
194 |             for key in to_be_removed:
195 |                 checkpoint['state_dict'].pop(key)
196 |         elif self.hparams.save == 'all':
197 |             pass
198 | 
199 |     def load_llm(self, llm_path):
200 |         print('Loading LLAMA')
201 |         self.llama_tokenizer = LlamaTokenizer.from_pretrained(llm_path, use_fast=False)
202 |         self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
203 |         self.llama_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
204 |         self.llama_tokenizer.padding_side = "right"
205 |         self.llama_tokenizer.add_special_tokens({'additional_special_tokens': ['[PH]','[HistoryEmb]','[CansEmb]','[ItemEmb]']})
206 |         self.llama_model = LlamaForCausalLM.from_pretrained(llm_path, torch_dtype=torch.bfloat16)
207 |         self.llama_model.resize_token_embeddings(len(self.llama_tokenizer))
208 |         if self.hparams.llm_tuning == 'lora':
209 |             if self.hparams.peft_dir:
210 |                 self.llama_model = PeftModel.from_pretrained(self.llama_model, self.hparams.peft_dir, is_trainable=True)
211 |             else:
212 |                 if self.hparams.peft_config:
213 |                     peft_config = LoraConfig(**LoraConfig.from_json_file(self.hparams.peft_config))
214 |                 else:
215 |                     peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
216 |                                              inference_mode=False,
217 |                                              r=self.hparams.lora_r,
218 |                                              lora_alpha=self.hparams.lora_alpha,
219 |                                              lora_dropout=self.hparams.lora_dropout,
220 |                                              target_modules=['k_proj', 'v_proj', 'q_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'])
221 |                 self.peft_config = peft_config
222 |                 self.llama_model = get_peft_model(self.llama_model, peft_config)
223 |             self.llama_model.print_trainable_parameters()
224 |         elif self.hparams.llm_tuning == 'freeze':
225 |             for name, param in self.llama_model.named_parameters():
226 |                 param.requires_grad = False
227 |         elif self.hparams.llm_tuning == 'freeze_lora':
228 |             if self.hparams.peft_dir:
229 |                 self.llama_model = PeftModel.from_pretrained(self.llama_model, self.hparams.peft_dir, is_trainable=True)
230 |             else:
231 |                 if self.hparams.peft_config:
232 |                     peft_config = LoraConfig(**LoraConfig.from_json_file(self.hparams.peft_config))
233 |                 else:
234 |                     peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
235 |                                              inference_mode=False,
236 |                                              r=self.hparams.lora_r,
237 |                                              lora_alpha=self.hparams.lora_alpha,
238 |                                              lora_dropout=self.hparams.lora_dropout,
239 |                                              target_modules=['k_proj', 'v_proj', 'q_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'])
240 |                 self.peft_config = peft_config
241 |                 self.llama_model = get_peft_model(self.llama_model, peft_config)
242 |             for name, param in self.llama_model.named_parameters():
243 |                 param.requires_grad = False
244 |             self.llama_model.print_trainable_parameters()
245 |         else:
246 |             raise NotImplementedError()
247 | 
248 |         print('Loading LLAMA Done')
249 | 
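    # The special tokens registered in load_llm() above are placeholders:
    # TrainCollater writes them into the prompt text, and wrap_emb() below swaps
    # the [HistoryEmb]/[CansEmb]/[ItemEmb] positions for recommender item
    # embeddings projected into the LLM space, while [PH] stays plain text.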
250 |     def load_projector(self):
251 |         name = self.hparams.model_name
252 |         camel_name = ''.join([i.capitalize() for i in name.split('_')])
253 |         try:
254 |             Model = getattr(importlib.import_module(
255 |                 '.'+name, package=__package__), camel_name)
256 |         except:
257 |             raise ValueError(
258 |                 f'Invalid Module File Name or Invalid Class Name {name}.{camel_name}!')
259 |         self.projector = self.instancialize(Model, rec_size=self.hparams.rec_size, llm_size=self.llama_model.config.hidden_size)
260 | 
261 |     def instancialize(self, Model, **other_args):
262 |         class_args = inspect.getfullargspec(Model.__init__).args[1:]
263 |         inkeys = self.hparams.keys()
264 |         args1 = {}
265 |         for arg in class_args:
266 |             if arg in inkeys:
267 |                 args1[arg] = getattr(self.hparams, arg)
268 |         args1.update(other_args)
269 |         return Model(**args1)
270 | 
271 |     def load_rec_model(self, rec_model_path):
272 |         print('Loading Rec Model')
273 |         self.rec_model = torch.load(rec_model_path, map_location="cpu")
274 |         self.rec_model.eval()
275 |         for name, param in self.rec_model.named_parameters():
276 |             param.requires_grad = False
277 |         print('Loading Rec Model Done')
278 | 
279 |     def encode_items(self, seq):
280 |         if self.hparams.rec_embed=="SASRec":
281 |             item_rec_embs=self.rec_model.cacu_x(seq)
282 |         elif self.hparams.rec_embed in ['Caser','GRU']:
283 |             item_rec_embs=self.rec_model.item_embeddings(seq)
284 |         item_txt_embs=self.projector(item_rec_embs)
285 |         return item_txt_embs
286 | 
287 |     def embed_tokens(self, token_ids):
288 |         embeds = self.llama_model.base_model.embed_tokens(token_ids)
289 |         return embeds
290 | 
291 |     def wrap_emb(self, batch):
292 |         input_embeds = self.llama_model.get_input_embeddings()(batch["tokens"].input_ids)
293 | 
294 |         his_token_id=self.llama_tokenizer("[HistoryEmb]", return_tensors="pt",add_special_tokens=False).input_ids.item()
295 |         cans_token_id=self.llama_tokenizer("[CansEmb]", return_tensors="pt",add_special_tokens=False).input_ids.item()
296 |         item_token_id=self.llama_tokenizer("[ItemEmb]", return_tensors="pt",add_special_tokens=False).input_ids.item()
297 |         his_item_embeds= self.encode_items(batch["seq"])
298 |         cans_item_embeds= self.encode_items(batch["cans"])
299 |         item_embeds=self.encode_items(batch["item_id"])
300 | 
301 |         for i in range(len(batch["len_seq"])):
302 |             if (batch["tokens"].input_ids[i]==his_token_id).nonzero().shape[0]>0:
303 |                 idx_tensor=(batch["tokens"].input_ids[i]==his_token_id).nonzero().view(-1)
304 |                 for idx, item_emb in zip(idx_tensor,his_item_embeds[i,:batch["len_seq"][i].item()]):
305 |                     input_embeds[i,idx]=item_emb
306 |             if (batch["tokens"].input_ids[i]==cans_token_id).nonzero().shape[0]>0:
307 |                 idx_tensor=(batch["tokens"].input_ids[i]==cans_token_id).nonzero().view(-1)
308 |                 for idx, item_emb in zip(idx_tensor,cans_item_embeds[i,:batch["len_cans"][i].item()]):
309 |                     input_embeds[i,idx]=item_emb
310 |             if (batch["tokens"].input_ids[i]==item_token_id).nonzero().shape[0]>0:
311 |                 idx=(batch["tokens"].input_ids[i]==item_token_id).nonzero().item()
312 |                 input_embeds[i,idx]=item_embeds[i]
313 |         return input_embeds
314 | 
315 |     def calculate_hr1(self,eval_content):
316 |         correct_num=0
317 |         valid_num=0
318 |         total_num=0
319 |         for i,generate in enumerate(eval_content["generate"]):
320 |             real=eval_content["real"][i]
321 |             cans=eval_content["cans"][i]
322 |             total_num+=1
323 |             generate=generate.strip().lower()
324 |             real=real.strip().lower()
325 |             cans=[item.strip().lower() for item in cans]
326 |             gen_cans_list=[]
327 |             for cans_item in cans:
328 |                 if cans_item in generate:
329 |                     gen_cans_list.append(cans_item)
330 |             if len(gen_cans_list)==1:
331 |                 valid_num+=1
332 |                 if real == gen_cans_list[0]:
333 |                     correct_num+=1
334 |         valid_ratio=valid_num/total_num
335 |         if valid_num>0:
336 |             hr1=correct_num/valid_num
337 |         else:
338 |             hr1=0
339 |         return valid_ratio,hr1
--------------------------------------------------------------------------------
/optims.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | class LinearWarmupCosineLRScheduler:
4 |     def __init__(
5 |         self,
6 |         optimizer,
7 |         min_lr,
8 |         init_lr,
9 |         warmup_steps=0,
10 |         warmup_start_lr=-1,
11 |         **kwargs
12 |     ):
13 |         self.optimizer = optimizer
14 | 
15 | 
self.min_lr = min_lr 16 | 17 | self.init_lr = init_lr 18 | self.warmup_steps = warmup_steps 19 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr 20 | 21 | def step(self, cur_step, cur_epoch, max_step): 22 | # assuming the warmup iters less than one epoch 23 | if cur_epoch==0 and cur_step < self.warmup_steps: 24 | warmup_lr_schedule( 25 | step=cur_step, 26 | optimizer=self.optimizer, 27 | max_step=self.warmup_steps, 28 | init_lr=self.warmup_start_lr, 29 | max_lr=self.init_lr, 30 | ) 31 | else: 32 | cosine_lr_schedule( 33 | step=cur_step, 34 | optimizer=self.optimizer, 35 | max_step=max_step, 36 | init_lr=self.init_lr, 37 | min_lr=self.min_lr, 38 | ) 39 | 40 | def state_dict(self): 41 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 42 | 43 | def load_state_dict(self, state_dict): 44 | self.__dict__.update(state_dict) 45 | 46 | def cosine_lr_schedule(optimizer, step, max_step, init_lr, min_lr): 47 | """Decay the learning rate""" 48 | lr = (init_lr - min_lr) * 0.5 * ( 49 | 1.0 + math.cos(math.pi * step / max_step) 50 | ) + min_lr 51 | for param_group in optimizer.param_groups: 52 | param_group["lr"] = lr 53 | 54 | 55 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): 56 | """Warmup the learning rate""" 57 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1)) 58 | for param_group in optimizer.param_groups: 59 | param_group["lr"] = lr -------------------------------------------------------------------------------- /prompt/artist.txt: -------------------------------------------------------------------------------- 1 | This user has listened to [HistoryHere] in the previous. Please predict the next music artist this user will listen to. The artist name candidates are [CansHere]. Choose only one artist from the candidates. The answer is 2 | This user has listened to [HistoryHere] in the previous. Given the following 10 music artist names, [CansHere], recommend one artist for this user to listen to next. The artist name you recommend is 3 | The listening history of this user is [HistoryHere]. Recommend the next music artist for this user to listen to from the following artist name set, [CansHere]. The recommendation should contain one artist's name only. The recommendation is -------------------------------------------------------------------------------- /prompt/game.txt: -------------------------------------------------------------------------------- 1 | This user has played [HistoryHere] in the previous. Please predict the next game this user will play. The game title candidates are [CansHere]. Choose only one game from the candidates. The answer is 2 | This user has played [HistoryHere] in the previous. Given the following 10 game titles, [CansHere], recommend one game for this user to play next. The game title you recommend is 3 | The visit history of this user is [HistoryHere]. Recommend a next game for this user to play from the following game title set, [CansHere]. The recommendation should contain one game title only. The recommendation is -------------------------------------------------------------------------------- /prompt/movie.txt: -------------------------------------------------------------------------------- 1 | This user has watched [HistoryHere] in the previous. Please predict the next movie this user will watch. Choose the answer from the following 10 movie titles: [CansHere]. Answer: 2 | This user has watched [HistoryHere] in the previous. 
Given the following 10 movie titles: [CansHere], recommend one movie for this user to watch next. The movie title you recommend is:
3 | The viewing history of this user is: [HistoryHere]. Recommend the next movie for this user to watch from the following movie title set: [CansHere]. The recommendation should contain one movie title only. Recommendation:
--------------------------------------------------------------------------------
/rec_model/lastfm.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ljy0ustc/LLaRA/9733f732399e95e236f4f2e1d1139c1f0437090e/rec_model/lastfm.pt
--------------------------------------------------------------------------------
/rec_model/movielens.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ljy0ustc/LLaRA/9733f732399e95e236f4f2e1d1139c1f0437090e/rec_model/movielens.pt
--------------------------------------------------------------------------------
/rec_model/steam.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ljy0ustc/LLaRA/9733f732399e95e236f4f2e1d1139c1f0437090e/rec_model/steam.pt
--------------------------------------------------------------------------------
/recommender/A_SASRec_final_bce_llm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import argparse
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | import os
8 | import logging
9 | import time as Time
10 | from collections import Counter
11 | from SASRecModules_ori import *
12 | 
13 | def extract_axis_1(data, indices):  # gather data[i, indices[i], :] for each i -> (B, 1, hidden)
14 |     res = []
15 |     for i in range(data.shape[0]):
16 |         res.append(data[i, indices[i], :])
17 |     res = torch.stack(res, dim=0).unsqueeze(1)
18 |     return res
19 | 
20 | class GRU(nn.Module):
21 |     def __init__(self, hidden_size, item_num, state_size, gru_layers=1):
22 |         super(GRU, self).__init__()
23 |         self.hidden_size = hidden_size
24 |         self.item_num = item_num
25 |         self.state_size = state_size
26 |         self.item_embeddings = nn.Embedding(
27 |             num_embeddings=item_num + 1,
28 |             embedding_dim=self.hidden_size,
29 |         )
30 |         nn.init.normal_(self.item_embeddings.weight, 0, 0.01)
31 |         self.gru = nn.GRU(
32 |             input_size=self.hidden_size,
33 |             hidden_size=self.hidden_size,
34 |             num_layers=gru_layers,
35 |             batch_first=True
36 |         )
37 |         self.s_fc = nn.Linear(self.hidden_size, self.item_num)
38 | 
39 |     def forward(self, states, len_states):
40 |         # Supervised Head
41 |         emb = self.item_embeddings(states)
42 |         emb_packed = torch.nn.utils.rnn.pack_padded_sequence(emb, len_states, batch_first=True, enforce_sorted=False)
43 |         emb_packed, hidden = self.gru(emb_packed)
44 |         hidden = hidden.view(-1, hidden.shape[2])
45 |         supervised_output = self.s_fc(hidden)
46 |         return supervised_output
47 | 
48 |     def forward_eval(self, states, len_states):
49 |         # Supervised Head
50 |         emb = self.item_embeddings(states)
51 |         emb_packed = torch.nn.utils.rnn.pack_padded_sequence(emb, len_states, batch_first=True, enforce_sorted=False)
52 |         emb_packed, hidden = self.gru(emb_packed)
53 |         hidden = hidden.view(-1, hidden.shape[2])
54 |         supervised_output = self.s_fc(hidden)
55 | 
56 |         return supervised_output
57 | 
58 | 
59 | class Caser(nn.Module):
60 |     def __init__(self, hidden_size, item_num, state_size, num_filters, filter_sizes,
61 |                  dropout_rate):
62 |         super(Caser, self).__init__()
63 |         self.hidden_size = hidden_size
64 |         self.item_num = int(item_num)
65 |         self.state_size = state_size
66 |         self.filter_sizes = eval(filter_sizes)  # parses a string such as "[2,3,4]"
67 |         self.num_filters = num_filters
68 |         self.dropout_rate = dropout_rate
69 |         self.item_embeddings = nn.Embedding(
70 |             num_embeddings=item_num + 1,
71 |             embedding_dim=self.hidden_size,
72 |         )
73 | 
74 |         # init embedding
75 |         nn.init.normal_(self.item_embeddings.weight, 0, 0.01)
76 | 
77 |         # Horizontal Convolutional Layers
78 |         self.horizontal_cnn = nn.ModuleList(
79 |             [nn.Conv2d(1, self.num_filters, (i, self.hidden_size)) for i in self.filter_sizes])
80 |         # Initialize weights and biases
81 |         for cnn in self.horizontal_cnn:
82 |             nn.init.xavier_normal_(cnn.weight)
83 |             nn.init.constant_(cnn.bias, 0.1)
84 | 
85 |         # Vertical Convolutional Layer
86 |         self.vertical_cnn = nn.Conv2d(1, 1, (self.state_size, 1))
87 |         nn.init.xavier_normal_(self.vertical_cnn.weight)
88 |         nn.init.constant_(self.vertical_cnn.bias, 0.1)
89 | 
90 |         # Fully Connected Layer
91 |         self.num_filters_total = self.num_filters * len(self.filter_sizes)
92 |         final_dim = self.hidden_size + self.num_filters_total
93 |         self.s_fc = nn.Linear(final_dim, item_num)
94 | 
95 |         # dropout
96 |         self.dropout = nn.Dropout(self.dropout_rate)
97 | 
98 |     def forward(self, states, len_states):
99 |         input_emb = self.item_embeddings(states)
100 |         mask = torch.ne(states, self.item_num).float().unsqueeze(-1)
101 |         input_emb *= mask
102 |         input_emb = input_emb.unsqueeze(1)
103 |         pooled_outputs = []
104 |         for cnn in self.horizontal_cnn:
105 |             h_out = nn.functional.relu(cnn(input_emb))
106 |             h_out = h_out.squeeze(3)  # squeeze only the width dim; a bare squeeze() also drops the batch dim when batch_size == 1
107 |             p_out = nn.functional.max_pool1d(h_out, h_out.shape[2])
108 |             pooled_outputs.append(p_out)
109 | 
110 |         h_pool = torch.cat(pooled_outputs, 1)
111 |         h_pool_flat = h_pool.view(-1, self.num_filters_total)
112 | 
113 |         v_out = nn.functional.relu(self.vertical_cnn(input_emb))
114 |         v_flat = v_out.view(-1, self.hidden_size)
115 | 
116 |         out = torch.cat([h_pool_flat, v_flat], 1)
117 |         out = self.dropout(out)
118 |         supervised_output = self.s_fc(out)
119 | 
120 |         return supervised_output
121 | 
122 |     def forward_eval(self, states, len_states):
123 |         input_emb = self.item_embeddings(states)
124 |         mask = torch.ne(states, self.item_num).float().unsqueeze(-1)
125 |         input_emb *= mask
126 |         input_emb = input_emb.unsqueeze(1)
127 |         pooled_outputs = []
128 |         for cnn in self.horizontal_cnn:
129 |             h_out = nn.functional.relu(cnn(input_emb))
130 |             h_out = h_out.squeeze(3)  # see forward(): keep the batch dim intact
131 |             p_out = nn.functional.max_pool1d(h_out, h_out.shape[2])
132 |             pooled_outputs.append(p_out)
133 | 
134 |         h_pool = torch.cat(pooled_outputs, 1)
135 |         h_pool_flat = h_pool.view(-1, self.num_filters_total)
136 | 
137 |         v_out = nn.functional.relu(self.vertical_cnn(input_emb))
138 |         v_flat = v_out.view(-1, self.hidden_size)
139 | 
140 |         out = torch.cat([h_pool_flat, v_flat], 1)
141 |         out = self.dropout(out)
142 |         supervised_output = self.s_fc(out)
143 | 
144 |         return supervised_output
145 | 
146 | 
147 | class SASRec(nn.Module):
148 |     def __init__(self, hidden_size, item_num, state_size, dropout, device, num_heads=1):
149 |         super().__init__()
150 |         self.state_size = state_size
151 |         self.hidden_size = hidden_size
152 |         self.item_num = int(item_num)
153 |         self.dropout = nn.Dropout(dropout)
154 |         self.device = device
155 |         self.item_embeddings = nn.Embedding(
156 |             num_embeddings=item_num + 1,
157 |             embedding_dim=hidden_size,
158 |         )
159 |         nn.init.normal_(self.item_embeddings.weight, 0, 1)
160 |         self.positional_embeddings = nn.Embedding(
161 |             num_embeddings=state_size,
162 |             embedding_dim=hidden_size
163 |         )
164 |         self.emb_dropout = nn.Dropout(dropout)
165 |         self.ln_1 = nn.LayerNorm(hidden_size)
166 |         self.ln_2 = nn.LayerNorm(hidden_size)
167 |         self.ln_3 = nn.LayerNorm(hidden_size)
168 |         self.mh_attn = MultiHeadAttention(hidden_size, hidden_size, num_heads, dropout)
169 |         self.feed_forward = PositionwiseFeedForward(hidden_size, hidden_size, dropout)
170 |         self.s_fc = nn.Linear(hidden_size, item_num)
171 | 
172 |     def forward(self, states, len_states):
173 |         inputs_emb = self.item_embeddings(states)
174 |         inputs_emb += self.positional_embeddings(torch.arange(self.state_size).to(self.device))
175 |         seq = self.emb_dropout(inputs_emb)
176 |         mask = torch.ne(states, self.item_num).float().unsqueeze(-1).to(self.device)
177 |         seq *= mask
178 |         seq_normalized = self.ln_1(seq)
179 |         mh_attn_out = self.mh_attn(seq_normalized, seq)
180 |         ff_out = self.feed_forward(self.ln_2(mh_attn_out))
181 |         ff_out *= mask
182 |         ff_out = self.ln_3(ff_out)
183 |         state_hidden = extract_axis_1(ff_out, len_states - 1)
184 |         supervised_output = self.s_fc(state_hidden).squeeze(1)  # (B, 1, item_num) -> (B, item_num); a bare squeeze() breaks when batch_size == 1
185 |         return supervised_output
186 | 
187 |     def forward_eval(self, states, len_states):
188 |         inputs_emb = self.item_embeddings(states)
189 |         inputs_emb += self.positional_embeddings(torch.arange(self.state_size).to(self.device))
190 |         seq = self.emb_dropout(inputs_emb)
191 |         mask = torch.ne(states, self.item_num).float().unsqueeze(-1).to(self.device)
192 |         seq *= mask
193 |         seq_normalized = self.ln_1(seq)
194 |         mh_attn_out = self.mh_attn(seq_normalized, seq)
195 |         ff_out = self.feed_forward(self.ln_2(mh_attn_out))
196 |         ff_out *= mask
197 |         ff_out = self.ln_3(ff_out)
198 |         state_hidden = extract_axis_1(ff_out, len_states - 1)
199 |         supervised_output = self.s_fc(state_hidden).squeeze(1)  # see forward(): keep the batch dim intact
200 |         return supervised_output
201 | 
202 |     def cacul_h(self, states, len_states):
203 |         inputs_emb = self.item_embeddings(states)
204 |         inputs_emb += self.positional_embeddings(torch.arange(self.state_size).to(self.device))
205 |         seq = self.emb_dropout(inputs_emb)
206 |         mask = torch.ne(states, self.item_num).float().unsqueeze(-1).to(self.device)
207 |         seq *= mask
208 |         seq_normalized = self.ln_1(seq)
209 |         mh_attn_out = self.mh_attn(seq_normalized, seq)
210 |         ff_out = self.feed_forward(self.ln_2(mh_attn_out))
211 |         ff_out *= mask
212 |         ff_out = self.ln_3(ff_out)
213 |         state_hidden = extract_axis_1(ff_out, len_states - 1)
214 | 
215 |         return state_hidden
216 | 
217 |     def cacu_x(self, x):
218 |         x = self.item_embeddings(x)  # item-ID tensor -> item embeddings (consumed by encode_items / the projector)
219 | 
220 |         return x
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.31.0
2 | aiohttp==3.9.5
3 | aiosignal==1.3.1
4 | async-timeout==4.0.3
5 | attrs==23.2.0
6 | bitsandbytes==0.43.1
7 | certifi==2024.6.2
8 | charset-normalizer==3.3.2
9 | cmake==3.29.5
10 | filelock==3.14.0
11 | frozenlist==1.4.1
12 | fsspec==2024.6.0
13 | huggingface-hub==0.23.3
14 | idna==3.7
15 | Jinja2==3.1.4
16 | lightning-utilities==0.11.2
17 | lit==18.1.7
18 | MarkupSafe==2.1.5
19 | mpmath==1.3.0
20 | multidict==6.0.5
21 | networkx==3.1
22 | numpy==1.24.4
23 | nvidia-cublas-cu11==11.10.3.66
24 | nvidia-cuda-cupti-cu11==11.7.101
25 | nvidia-cuda-nvrtc-cu11==11.7.99
26 | nvidia-cuda-runtime-cu11==11.7.99
27 | nvidia-cudnn-cu11==8.5.0.96
28 | nvidia-cufft-cu11==10.9.0.58
29 | nvidia-curand-cu11==10.2.10.91
30 | nvidia-cusolver-cu11==11.4.0.1
31 | nvidia-cusparse-cu11==11.7.4.91
32 | nvidia-nccl-cu11==2.14.3
33 | nvidia-nvtx-cu11==11.7.91 34 | packaging==24.1 35 | pandas==2.0.3 36 | peft==0.11.1 37 | protobuf==5.27.1 38 | psutil==5.9.8 39 | python-dateutil==2.9.0.post0 40 | pytorch-lightning==1.8.6 41 | pytz==2024.1 42 | PyYAML==6.0.1 43 | regex==2024.5.15 44 | requests==2.32.3 45 | safetensors==0.4.3 46 | sentencepiece==0.2.0 47 | six==1.16.0 48 | sympy==1.12.1 49 | tensorboardX==2.6.2.2 50 | tokenizers==0.13.3 51 | torch==2.0.0 52 | torchmetrics==1.4.0.post0 53 | tqdm==4.66.4 54 | transformers==4.28.0 55 | triton==2.0.0 56 | typing_extensions==4.12.2 57 | tzdata==2024.1 58 | urllib3==2.2.1 59 | yarl==1.9.4 60 | -------------------------------------------------------------------------------- /test_lastfm.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --mode test \ 3 | --batch_size 8 \ 4 | --accumulate_grad_batches 16 \ 5 | --dataset lastfm_data \ 6 | --data_dir data/ref/lastfm \ 7 | --cans_num 20 \ 8 | --prompt_path ./prompt/artist.txt \ 9 | --rec_embed SASRec \ 10 | --llm_tuning lora \ 11 | --llm_path xxx \ 12 | --rec_model_path ./rec_model/lastfm.pt \ 13 | --ckpt_path ./checkpoints/lastfm.ckpt \ 14 | --output_dir ./output/lastfm/ \ 15 | --log_dir lastfm_logs \ 16 | --lr_warmup_start_lr 7e-6 \ 17 | --lr 7e-4 \ 18 | --lr_decay_min_lr 7e-6 \ 19 | --max_epochs 5 -------------------------------------------------------------------------------- /test_movielens.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --mode test \ 3 | --batch_size 8 \ 4 | --accumulate_grad_batches 16 \ 5 | --dataset movielens_data \ 6 | --data_dir data/ref/movielens \ 7 | --cans_num 20 \ 8 | --prompt_path ./prompt/movie.txt \ 9 | --rec_embed SASRec \ 10 | --llm_tuning lora \ 11 | --llm_path xxx \ 12 | --rec_model_path ./rec_model/movielens.pt \ 13 | --ckpt_path ./checkpoints/movielens.ckpt \ 14 | --output_dir ./output/movielens/ \ 15 | --log_dir movielens_logs \ 16 | --lr_warmup_start_lr 8e-6 \ 17 | --lr 8e-4 \ 18 | --lr_decay_min_lr 8e-6 \ 19 | --max_epochs 5 -------------------------------------------------------------------------------- /test_steam.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --mode test \ 3 | --batch_size 4 \ 4 | --accumulate_grad_batches 32 \ 5 | --dataset steam_data \ 6 | --data_dir data/ref/steam \ 7 | --cans_num 20 \ 8 | --prompt_path ./prompt/game.txt \ 9 | --rec_embed SASRec \ 10 | --llm_tuning lora \ 11 | --llm_path xxx \ 12 | --rec_model_path ./rec_model/steam.pt \ 13 | --ckpt_path ./checkpoints/steam.ckpt \ 14 | --output_dir ./output/steam/ \ 15 | --log_dir steam_logs \ 16 | --lr_warmup_start_lr 5e-6 \ 17 | --lr 5e-4 \ 18 | --lr_decay_min_lr 5e-6 \ 19 | --max_epochs 5 -------------------------------------------------------------------------------- /train_lastfm.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --mode train \ 3 | --batch_size 8 \ 4 | --accumulate_grad_batches 16 \ 5 | --dataset lastfm_data \ 6 | --data_dir data/ref/lastfm \ 7 | --cans_num 20 \ 8 | --prompt_path ./prompt/artist.txt \ 9 | --rec_embed SASRec \ 10 | --llm_tuning lora \ 11 | --llm_path xxx \ 12 | --rec_model_path ./rec_model/lastfm.pt \ 13 | --ckpt_dir ./checkpoints/lastfm/ \ 14 | --output_dir ./output/lastfm/ \ 15 | --log_dir lastfm_logs \ 16 | --lr_warmup_start_lr 7e-6 \ 17 | --lr 7e-4 \ 18 | --lr_decay_min_lr 7e-6 \ 19 | --max_epochs 5 
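# Note: in all six train/test scripts, --lr_warmup_start_lr and --lr_decay_min_lr
# are set to lr / 100 (7e-6 for lastfm, 8e-6 for movielens, 5e-6 for steam),
# matching LinearWarmupCosineLRScheduler in optims.py; `--llm_path xxx` is a
# placeholder to be replaced with the path of a local Llama checkpoint.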
-------------------------------------------------------------------------------- /train_movielens.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --mode train \ 3 | --batch_size 8 \ 4 | --accumulate_grad_batches 16 \ 5 | --dataset movielens_data \ 6 | --data_dir data/ref/movielens \ 7 | --cans_num 20 \ 8 | --prompt_path ./prompt/movie.txt \ 9 | --rec_embed SASRec \ 10 | --llm_tuning lora \ 11 | --llm_path xxx \ 12 | --rec_model_path ./rec_model/movielens.pt \ 13 | --ckpt_dir ./checkpoints/movielens/ \ 14 | --output_dir ./output/movielens/ \ 15 | --log_dir movielens_logs \ 16 | --lr_warmup_start_lr 8e-6 \ 17 | --lr 8e-4 \ 18 | --lr_decay_min_lr 8e-6 \ 19 | --max_epochs 5 -------------------------------------------------------------------------------- /train_steam.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --mode train \ 3 | --batch_size 4 \ 4 | --accumulate_grad_batches 32 \ 5 | --dataset steam_data \ 6 | --data_dir data/ref/steam \ 7 | --cans_num 20 \ 8 | --prompt_path ./prompt/game.txt \ 9 | --rec_embed SASRec \ 10 | --llm_tuning lora \ 11 | --llm_path xxx \ 12 | --rec_model_path ./rec_model/steam.pt \ 13 | --ckpt_dir ./checkpoints/steam/ \ 14 | --output_dir ./output/steam/ \ 15 | --log_dir steam_logs \ 16 | --lr_warmup_start_lr 5e-6 \ 17 | --lr 5e-4 \ 18 | --lr_decay_min_lr 5e-6 \ 19 | --max_epochs 5 --------------------------------------------------------------------------------
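The scheduler in optims.py is self-contained, so its warmup-then-cosine behavior can be checked in isolation. The sketch below is not part of the repository; it assumes it is run from the repository root so that `optims` is importable, and it drives LinearWarmupCosineLRScheduler with a throwaway SGD optimizer using the lastfm-style settings (lr 7e-4, warmup start and decay floor 7e-6):

import torch

from optims import LinearWarmupCosineLRScheduler

# Throwaway parameter/optimizer purely to give the scheduler param_groups.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=7e-4)

scheduler = LinearWarmupCosineLRScheduler(
    optimizer, min_lr=7e-6, init_lr=7e-4, warmup_steps=5, warmup_start_lr=7e-6
)

steps_per_epoch = 20
for epoch in range(2):
    for step in range(steps_per_epoch):
        # Warmup applies only in epoch 0 while step < warmup_steps;
        # afterwards the lr follows the cosine curve over step/max_step.
        scheduler.step(cur_step=step, cur_epoch=epoch, max_step=steps_per_epoch)
        print(f"epoch {epoch} step {step:2d} lr {optimizer.param_groups[0]['lr']:.2e}")

The HR@1 bookkeeping in MInterface.calculate_hr1 can be sanity-checked the same way. `hr_at_1` below is a hypothetical standalone mirror of that method's logic, not repository code:

def hr_at_1(generates, reals, cans_lists):
    # As in calculate_hr1: a generation is "valid" only if exactly one
    # candidate title occurs in it, and HR@1 is computed over valid ones.
    correct, valid, total = 0, 0, 0
    for gen, real, cans in zip(generates, reals, cans_lists):
        total += 1
        gen = gen.strip().lower()
        real = real.strip().lower()
        hits = [c for c in (x.strip().lower() for x in cans) if c in gen]
        if len(hits) == 1:
            valid += 1
            correct += int(hits[0] == real)
    return valid / total, (correct / valid if valid else 0)

# One correct answer, one invalid output (echoes two candidates), one miss:
print(hr_at_1(
    ["I recommend Radiohead", "Radiohead or Muse", "Muse"],
    ["Radiohead", "Radiohead", "Radiohead"],
    [["Radiohead", "Muse"]] * 3,
))  # -> (0.666..., 0.5)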