├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── alpaca-lora-7B ├── adapter_config.json └── adapter_model.bin ├── data ├── book │ ├── test.json │ ├── train.json │ └── valid.json └── movie │ ├── test.json │ ├── train.json │ └── valid.json ├── evaluate.py ├── export_hf_checkpoint.py ├── export_state_dict_checkpoint.py ├── finetune.py ├── finetune_multi_rec.py ├── finetune_rec.py ├── preprocess_book.py ├── preprocess_movie.py ├── requirements.txt └── shell ├── evaluate.sh ├── instruct_7B.sh └── instruct_multi_7B.sh /.gitattributes: -------------------------------------------------------------------------------- 1 | *.bin filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAI990323/TALLRec/c1db29ce6501bce32cc8c3e343c4eb14155beeb1/.gitignore -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | The weights of our instruction-tuned model are uploaded [here](https://drive.google.com/file/d/1teUwLm4BOqhngfCKKXE1tiMhJPf_FvRJ/view?usp=sharing). 2 | 3 | **TALLRec: An Effective and Efficient Tuning Framework to Align Large Language Model with Recommendation** is available at https://arxiv.org/abs/2305.00447.
4 | 5 | **A line in evaluate.py was accidentally deleted; it has now been restored.** 6 | 7 | We introduce a novel framework (TALLRec) that enables the efficient and effective adaptation of LLMs to recommendation tasks. 8 | 9 | # Main results 10 | |  | movie |  |  | book |  |  | 11 | |------------------------------- | ----- | ----- | ----- | ----- | ----- | ----- | 12 | | Few-shot | 16 | 64 | 256 | 16 | 64 | 256 | 13 | | GRU | 49.07 | 49.87 | 52.89 | 48.95 | 49.64 | 49.86 | 14 | | Caser | 49.68 | 51.06 | 54.20 | 49.84 | 49.72 | 49.57 | 15 | | SASRec | 50.43 | 50.48 | 52.25 | 49.48 | 50.06 | 50.20 | 16 | | DROS | 50.76 | 51.54 | 54.07 | 49.28 | 49.13 | 49.13 | 17 | | GRU-BERT | 50.85 | 51.65 | 53.44 | 50.07 | 49.64 | 49.79 | 18 | | DROS-BERT | 50.21 | 51.71 | 53.94 | 50.07 | 48.98 | 50.20 | 19 | | TALLRec (ours) | **67.24** | **67.48** | **71.98** | **56.36** | **60.39** | **64.38** | 20 | 21 | Table 1. AUC results of the baseline models and our framework on the movie and book scenarios. 22 | 23 | Train TALLRec based on LLaMA-7B: 24 | ``` 25 | bash ./shell/instruct_7B.sh gpu_id random_seed 26 | ``` 27 | To run it in your own environment, adjust the following variables in the shell script: 28 | - output_dir: Model save path; the seed and sample size are automatically appended to the path for each experiment. 29 | - base_model: Path to the LLaMA weights in Hugging Face format. 30 | - train_data: Training data path, e.g. "./data/movie/train.json" for the movie dataset. 31 | - val_data: Validation data path, e.g. "./data/movie/valid.json" for the movie dataset. 32 | - instruction_model: The LoRA weights obtained from instruction tuning, e.g. the LoRA weights from alpaca-lora. 33 | 34 | After training, evaluate the test results with the best model selected on the validation set. 35 | ``` 36 | bash ./shell/evaluate.sh gpu_id output_dir 37 | ``` 38 | To run it in your own environment, adjust the following variables in the shell script: 39 | - base_model: Path to the LLaMA weights in Hugging Face format. 40 | - test_data: Test data path, e.g. "./data/movie/test.json" for the movie dataset. 41 | 42 | Note that we automatically detect all the different seed and sample results in the output_dir directory and aggregate them into the output_dir.json file. A sketch of the underlying Python invocations is given at the end of this README. 43 | 44 | Our project is developed based on the Alpaca-LoRA [repo](https://github.com/tloen/alpaca-lora); thanks for their contributions. 45 | 46 | For "Environment setting sharing for CUDA 12.0", please see [here](https://github.com/SAI990323/TALLRec/issues/46).
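
The commands below are a minimal sketch of what the shell scripts boil down to when the Python entry points are called directly; they are provided for orientation only. The argument names are taken from `finetune_rec.py` and `evaluate.py`, while the GPU id, paths, seed, and sample size are placeholders. The adapter directory name `movie_0_64` is an assumed example chosen so that `evaluate.py` can parse the seed and sample size from its last two underscore-separated fields, and the instruction-tuned LoRA weights are passed via `--resume_from_checkpoint`, which `finetune_rec.py` uses to load `adapter_model.bin`.

```
# Fine-tune on the movie data, starting from the instruction-tuned adapter (placeholder paths).
CUDA_VISIBLE_DEVICES=0 python finetune_rec.py \
    --base_model /path/to/llama-7b-hf \
    --train_data_path ./data/movie/train.json \
    --val_data_path ./data/movie/valid.json \
    --output_dir ./output/movie_0_64 \
    --sample 64 \
    --seed 0 \
    --resume_from_checkpoint ./alpaca-lora-7B  # instruction-tuned LoRA adapter shipped in this repo

# Evaluate the trained adapter on the test set and write the AUC into the result file.
CUDA_VISIBLE_DEVICES=0 python evaluate.py \
    --base_model /path/to/llama-7b-hf \
    --lora_weights ./output/movie_0_64 \
    --test_data_path ./data/movie/test.json \
    --result_json_data ./output/movie.json
```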
47 | -------------------------------------------------------------------------------- /alpaca-lora-7B/adapter_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name_or_path": "/data/zhangjz/alpaca-lora/weights/", 3 | "bias": "none", 4 | "enable_lora": null, 5 | "fan_in_fan_out": false, 6 | "inference_mode": true, 7 | "lora_alpha": 16, 8 | "lora_dropout": 0.05, 9 | "merge_weights": false, 10 | "modules_to_save": null, 11 | "peft_type": "LORA", 12 | "r": 8, 13 | "target_modules": [ 14 | "q_proj", 15 | "v_proj" 16 | ], 17 | "task_type": "CAUSAL_LM" 18 | } -------------------------------------------------------------------------------- /alpaca-lora-7B/adapter_model.bin: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5675d68c6a5123c4df97fe77d07e3c9708494d7b065a56db422e3c046eac3ee7 3 | size 16822989 4 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import fire 4 | import gradio as gr 5 | import torch 6 | torch.set_num_threads(1) 7 | import transformers 8 | import json 9 | import os 10 | os.environ['OPENBLAS_NUM_THREADS'] = '1' 11 | os.environ['OMP_NUM_THREADS'] = '1' 12 | from peft import PeftModel 13 | from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer 14 | from sklearn.metrics import roc_auc_score 15 | if torch.cuda.is_available(): 16 | device = "cuda" 17 | else: 18 | device = "cpu" 19 | 20 | try: 21 | if torch.backends.mps.is_available(): 22 | device = "mps" 23 | except: # noqa: E722 24 | pass 25 | 26 | 27 | def main( 28 | load_8bit: bool = False, 29 | base_model: str = "", 30 | lora_weights: str = "tloen/alpaca-lora-7b", 31 | test_data_path: str = "data/test.json", 32 | result_json_data: str = "temp.json", 33 | batch_size: int = 32, 34 | share_gradio: bool = False, 35 | ): 36 | assert ( 37 | base_model 38 | ), "Please specify a --base_model, e.g. 
--base_model='decapoda-research/llama-7b-hf'" 39 | 40 | model_type = lora_weights.split('/')[-1] 41 | model_name = '_'.join(model_type.split('_')[:2]) 42 | 43 | if model_type.find('book') > -1: 44 | train_sce = 'book' 45 | else: 46 | train_sce = 'movie' 47 | 48 | if test_data_path.find('book') > -1: 49 | test_sce = 'book' 50 | else: 51 | test_sce = 'movie' 52 | 53 | temp_list = model_type.split('_') 54 | seed = temp_list[-2] 55 | sample = temp_list[-1] 56 | 57 | if os.path.exists(result_json_data): 58 | f = open(result_json_data, 'r') 59 | data = json.load(f) 60 | f.close() 61 | else: 62 | data = dict() 63 | 64 | if not data.__contains__(train_sce): 65 | data[train_sce] = {} 66 | if not data[train_sce].__contains__(test_sce): 67 | data[train_sce][test_sce] = {} 68 | if not data[train_sce][test_sce].__contains__(model_name): 69 | data[train_sce][test_sce][model_name] = {} 70 | if not data[train_sce][test_sce][model_name].__contains__(seed): 71 | data[train_sce][test_sce][model_name][seed] = {} 72 | if data[train_sce][test_sce][model_name][seed].__contains__(sample): 73 | exit(0) 74 | # data[train_sce][test_sce][model_name][seed][sample] = 75 | 76 | 77 | tokenizer = LlamaTokenizer.from_pretrained(base_model) 78 | if device == "cuda": 79 | model = LlamaForCausalLM.from_pretrained( 80 | base_model, 81 | load_in_8bit=load_8bit, 82 | torch_dtype=torch.float16, 83 | device_map="auto", 84 | ) 85 | model = PeftModel.from_pretrained( 86 | model, 87 | lora_weights, 88 | torch_dtype=torch.float16, 89 | device_map={'': 0} 90 | ) 91 | elif device == "mps": 92 | model = LlamaForCausalLM.from_pretrained( 93 | base_model, 94 | device_map={"": device}, 95 | torch_dtype=torch.float16, 96 | ) 97 | model = PeftModel.from_pretrained( 98 | model, 99 | lora_weights, 100 | device_map={"": device}, 101 | torch_dtype=torch.float16, 102 | ) 103 | else: 104 | model = LlamaForCausalLM.from_pretrained( 105 | base_model, device_map={"": device}, low_cpu_mem_usage=True 106 | ) 107 | model = PeftModel.from_pretrained( 108 | model, 109 | lora_weights, 110 | device_map={"": device}, 111 | ) 112 | 113 | 114 | tokenizer.padding_side = "left" 115 | # unwind broken decapoda-research config 116 | model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk 117 | model.config.bos_token_id = 1 118 | model.config.eos_token_id = 2 119 | 120 | if not load_8bit: 121 | model.half() # seems to fix bugs for some users. 
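    # The evaluate() helper defined below scores each prompt from the probability
    # distribution of the first generated token: it keeps only the entries for the
    # "Yes" (id 8241) and "No" (id 3782) tokens and renormalizes them, and the
    # resulting "Yes" probability is later used as the prediction score for the AUC.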
122 | 123 | model.eval() 124 | if torch.__version__ >= "2" and sys.platform != "win32": 125 | model = torch.compile(model) 126 | 127 | def evaluate( 128 | instructions, 129 | inputs=None, 130 | temperature=0, 131 | top_p=1.0, 132 | top_k=40, 133 | num_beams=1, 134 | max_new_tokens=128, 135 | batch_size=1, 136 | **kwargs, 137 | ): 138 | prompt = [generate_prompt(instruction, input) for instruction, input in zip(instructions, inputs)] 139 | inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device) 140 | generation_config = GenerationConfig( 141 | temperature=temperature, 142 | top_p=top_p, 143 | top_k=top_k, 144 | num_beams=num_beams, 145 | **kwargs, 146 | ) 147 | with torch.no_grad(): 148 | generation_output = model.generate( 149 | **inputs, 150 | generation_config=generation_config, 151 | return_dict_in_generate=True, 152 | output_scores=True, 153 | max_new_tokens=max_new_tokens, 154 | # batch_size=batch_size, 155 | ) 156 | s = generation_output.sequences 157 | scores = generation_output.scores[0].softmax(dim=-1) 158 | logits = torch.tensor(scores[:,[8241, 3782]], dtype=torch.float32).softmax(dim=-1) 159 | input_ids = inputs["input_ids"].to(device) 160 | L = input_ids.shape[1] 161 | s = generation_output.sequences 162 | output = tokenizer.batch_decode(s, skip_special_tokens=True) 163 | output = [_.split('Response:\n')[-1] for _ in output] 164 | 165 | return output, logits.tolist() 166 | 167 | # testing code for readme 168 | logit_list = [] 169 | gold_list= [] 170 | outputs = [] 171 | logits = [] 172 | from tqdm import tqdm 173 | gold = [] 174 | pred = [] 175 | 176 | with open(test_data_path, 'r') as f: 177 | test_data = json.load(f) 178 | instructions = [_['instruction'] for _ in test_data] 179 | inputs = [_['input'] for _ in test_data] 180 | gold = [int(_['output'] == 'Yes.') for _ in test_data] 181 | def batch(list, batch_size=32): 182 | chunk_size = (len(list) - 1) // batch_size + 1 183 | for i in range(chunk_size): 184 | yield list[batch_size * i: batch_size * (i + 1)] 185 | for i, batch in tqdm(enumerate(zip(batch(instructions), batch(inputs)))): 186 | instructions, inputs = batch 187 | output, logit = evaluate(instructions, inputs) 188 | outputs = outputs + output 189 | logits = logits + logit 190 | for i, test in tqdm(enumerate(test_data)): 191 | test_data[i]['predict'] = outputs[i] 192 | test_data[i]['logits'] = logits[i] 193 | pred.append(logits[i][0]) 194 | 195 | from sklearn.metrics import roc_auc_score 196 | 197 | data[train_sce][test_sce][model_name][seed][sample] = roc_auc_score(gold, pred) 198 | f = open(result_json_data, 'w') 199 | json.dump(data, f, indent=4) 200 | f.close() 201 | 202 | def generate_prompt(instruction, input=None): 203 | if input: 204 | return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. # noqa: E501 205 | 206 | ### Instruction: 207 | {instruction} 208 | 209 | ### Input: 210 | {input} 211 | 212 | ### Response: 213 | """ 214 | else: 215 | return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. 
# noqa: E501 216 | 217 | ### Instruction: 218 | {instruction} 219 | 220 | ### Response: 221 | """ 222 | 223 | 224 | if __name__ == "__main__": 225 | fire.Fire(main) 226 | -------------------------------------------------------------------------------- /export_hf_checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import transformers 5 | from peft import PeftModel 6 | from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: F402 7 | 8 | BASE_MODEL = os.environ.get("BASE_MODEL", None) 9 | assert ( 10 | BASE_MODEL 11 | ), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=decapoda-research/llama-7b-hf`" # noqa: E501 12 | 13 | tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL) 14 | 15 | base_model = LlamaForCausalLM.from_pretrained( 16 | BASE_MODEL, 17 | load_in_8bit=False, 18 | torch_dtype=torch.float16, 19 | device_map={"": "cpu"}, 20 | ) 21 | 22 | first_weight = base_model.model.layers[0].self_attn.q_proj.weight 23 | first_weight_old = first_weight.clone() 24 | 25 | lora_model = PeftModel.from_pretrained( 26 | base_model, 27 | "tloen/alpaca-lora-7b", 28 | device_map={"": "cpu"}, 29 | torch_dtype=torch.float16, 30 | ) 31 | 32 | lora_weight = lora_model.base_model.model.model.layers[ 33 | 0 34 | ].self_attn.q_proj.weight 35 | 36 | assert torch.allclose(first_weight_old, first_weight) 37 | 38 | # merge weights 39 | for layer in lora_model.base_model.model.model.layers: 40 | layer.self_attn.q_proj.merge_weights = True 41 | layer.self_attn.v_proj.merge_weights = True 42 | 43 | lora_model.train(False) 44 | 45 | # did we do anything? 46 | assert not torch.allclose(first_weight_old, first_weight) 47 | 48 | lora_model_sd = lora_model.state_dict() 49 | deloreanized_sd = { 50 | k.replace("base_model.model.", ""): v 51 | for k, v in lora_model_sd.items() 52 | if "lora" not in k 53 | } 54 | 55 | LlamaForCausalLM.save_pretrained( 56 | base_model, "./hf_ckpt", state_dict=deloreanized_sd, max_shard_size="400MB" 57 | ) 58 | -------------------------------------------------------------------------------- /export_state_dict_checkpoint.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import torch 5 | import transformers 6 | from peft import PeftModel 7 | from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: E402 8 | 9 | BASE_MODEL = os.environ.get("BASE_MODEL", None) 10 | assert ( 11 | BASE_MODEL 12 | ), "Please specify a value for BASE_MODEL environment variable, e.g. 
`export BASE_MODEL=decapoda-research/llama-7b-hf`" # noqa: E501 13 | 14 | tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL) 15 | 16 | base_model = LlamaForCausalLM.from_pretrained( 17 | BASE_MODEL, 18 | load_in_8bit=False, 19 | torch_dtype=torch.float16, 20 | device_map={"": "cpu"}, 21 | ) 22 | 23 | lora_model = PeftModel.from_pretrained( 24 | base_model, 25 | "tloen/alpaca-lora-7b", 26 | device_map={"": "cpu"}, 27 | torch_dtype=torch.float16, 28 | ) 29 | 30 | # merge weights 31 | for layer in lora_model.base_model.model.model.layers: 32 | layer.self_attn.q_proj.merge_weights = True 33 | layer.self_attn.v_proj.merge_weights = True 34 | 35 | lora_model.train(False) 36 | 37 | lora_model_sd = lora_model.state_dict() 38 | 39 | params = { 40 | "dim": 4096, 41 | "multiple_of": 256, 42 | "n_heads": 32, 43 | "n_layers": 32, 44 | "norm_eps": 1e-06, 45 | "vocab_size": -1, 46 | } 47 | n_layers = params["n_layers"] 48 | n_heads = params["n_heads"] 49 | dim = params["dim"] 50 | dims_per_head = dim // n_heads 51 | base = 10000.0 52 | inv_freq = 1.0 / ( 53 | base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head) 54 | ) 55 | 56 | 57 | def permute(w): 58 | return ( 59 | w.view(n_heads, dim // n_heads // 2, 2, dim) 60 | .transpose(1, 2) 61 | .reshape(dim, dim) 62 | ) 63 | 64 | 65 | def unpermute(w): 66 | return ( 67 | w.view(n_heads, 2, dim // n_heads // 2, dim) 68 | .transpose(1, 2) 69 | .reshape(dim, dim) 70 | ) 71 | 72 | 73 | def translate_state_dict_key(k): # noqa: C901 74 | k = k.replace("base_model.model.", "") 75 | if k == "model.embed_tokens.weight": 76 | return "tok_embeddings.weight" 77 | elif k == "model.norm.weight": 78 | return "norm.weight" 79 | elif k == "lm_head.weight": 80 | return "output.weight" 81 | elif k.startswith("model.layers."): 82 | layer = k.split(".")[2] 83 | if k.endswith(".self_attn.q_proj.weight"): 84 | return f"layers.{layer}.attention.wq.weight" 85 | elif k.endswith(".self_attn.k_proj.weight"): 86 | return f"layers.{layer}.attention.wk.weight" 87 | elif k.endswith(".self_attn.v_proj.weight"): 88 | return f"layers.{layer}.attention.wv.weight" 89 | elif k.endswith(".self_attn.o_proj.weight"): 90 | return f"layers.{layer}.attention.wo.weight" 91 | elif k.endswith(".mlp.gate_proj.weight"): 92 | return f"layers.{layer}.feed_forward.w1.weight" 93 | elif k.endswith(".mlp.down_proj.weight"): 94 | return f"layers.{layer}.feed_forward.w2.weight" 95 | elif k.endswith(".mlp.up_proj.weight"): 96 | return f"layers.{layer}.feed_forward.w3.weight" 97 | elif k.endswith(".input_layernorm.weight"): 98 | return f"layers.{layer}.attention_norm.weight" 99 | elif k.endswith(".post_attention_layernorm.weight"): 100 | return f"layers.{layer}.ffn_norm.weight" 101 | elif k.endswith("rotary_emb.inv_freq") or "lora" in k: 102 | return None 103 | else: 104 | print(layer, k) 105 | raise NotImplementedError 106 | else: 107 | print(k) 108 | raise NotImplementedError 109 | 110 | 111 | new_state_dict = {} 112 | for k, v in lora_model_sd.items(): 113 | new_k = translate_state_dict_key(k) 114 | if new_k is not None: 115 | if "wq" in new_k or "wk" in new_k: 116 | new_state_dict[new_k] = unpermute(v) 117 | else: 118 | new_state_dict[new_k] = v 119 | 120 | os.makedirs("./ckpt", exist_ok=True) 121 | 122 | torch.save(new_state_dict, "./ckpt/consolidated.00.pth") 123 | 124 | with open("./ckpt/params.json", "w") as f: 125 | json.dump(params, f) 126 | -------------------------------------------------------------------------------- /finetune.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | os.environ['LD_LIBRARY_PATH'] = '/data/baokq/miniconda3/envs/alpaca_lora/lib/' 3 | import sys 4 | from typing import List 5 | 6 | import fire 7 | import torch 8 | import transformers 9 | from datasets import load_dataset 10 | 11 | 12 | from peft import ( # noqa: E402 13 | LoraConfig, 14 | get_peft_model, 15 | get_peft_model_state_dict, 16 | prepare_model_for_int8_training, 17 | set_peft_model_state_dict, 18 | ) 19 | from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: F402 20 | 21 | def train( 22 | # model/data params 23 | base_model: str = "", # the only required argument 24 | data_path: str = "yahma/alpaca-cleaned", 25 | output_dir: str = "./lora-alpaca", 26 | # training hyperparams 27 | batch_size: int = 128, 28 | micro_batch_size: int = 4, 29 | num_epochs: int = 3, 30 | learning_rate: float = 3e-4, 31 | cutoff_len: int = 256, 32 | val_set_size: int = 2000, 33 | # lora hyperparams 34 | lora_r: int = 8, 35 | lora_alpha: int = 16, 36 | lora_dropout: float = 0.05, 37 | lora_target_modules: List[str] = [ 38 | "q_proj", 39 | "v_proj", 40 | ], 41 | # llm hyperparams 42 | train_on_inputs: bool = True, # if False, masks out inputs in loss 43 | group_by_length: bool = False, # faster, but produces an odd training loss curve 44 | # wandb params 45 | wandb_project: str = "", 46 | wandb_run_name: str = "", 47 | wandb_watch: str = "", # options: false | gradients | all 48 | wandb_log_model: str = "", # options: false | true 49 | resume_from_checkpoint: str = None, # either training checkpoint or final adapter 50 | ): 51 | print( 52 | f"Training Alpaca-LoRA model with params:\n" 53 | f"base_model: {base_model}\n" 54 | f"data_path: {data_path}\n" 55 | f"output_dir: {output_dir}\n" 56 | f"batch_size: {batch_size}\n" 57 | f"micro_batch_size: {micro_batch_size}\n" 58 | f"num_epochs: {num_epochs}\n" 59 | f"learning_rate: {learning_rate}\n" 60 | f"cutoff_len: {cutoff_len}\n" 61 | f"val_set_size: {val_set_size}\n" 62 | f"lora_r: {lora_r}\n" 63 | f"lora_alpha: {lora_alpha}\n" 64 | f"lora_dropout: {lora_dropout}\n" 65 | f"lora_target_modules: {lora_target_modules}\n" 66 | f"train_on_inputs: {train_on_inputs}\n" 67 | f"group_by_length: {group_by_length}\n" 68 | f"wandb_project: {wandb_project}\n" 69 | f"wandb_run_name: {wandb_run_name}\n" 70 | f"wandb_watch: {wandb_watch}\n" 71 | f"wandb_log_model: {wandb_log_model}\n" 72 | f"resume_from_checkpoint: {resume_from_checkpoint}\n" 73 | ) 74 | assert ( 75 | base_model 76 | ), "Please specify a --base_model, e.g. 
--base_model='decapoda-research/llama-7b-hf'" 77 | gradient_accumulation_steps = batch_size // micro_batch_size 78 | # print(f"gradient_accumulation_steps: {gradient_accumulation_steps}") 79 | 80 | device_map = "auto" 81 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 82 | ddp = world_size != 1 83 | if ddp: 84 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} 85 | gradient_accumulation_steps = gradient_accumulation_steps // world_size 86 | 87 | # Check if parameter passed or if set within environ 88 | use_wandb = len(wandb_project) > 0 or ( 89 | "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0 90 | ) 91 | # Only overwrite environ if wandb param passed 92 | if len(wandb_project) > 0: 93 | os.environ["WANDB_PROJECT"] = wandb_project 94 | if len(wandb_watch) > 0: 95 | os.environ["WANDB_WATCH"] = wandb_watch 96 | if len(wandb_log_model) > 0: 97 | os.environ["WANDB_LOG_MODEL"] = wandb_log_model 98 | 99 | model = LlamaForCausalLM.from_pretrained( 100 | base_model, 101 | load_in_8bit=True, 102 | torch_dtype=torch.float16, 103 | device_map=device_map, 104 | ) 105 | 106 | os.environ["WANDB_DISABLED"] = "true" 107 | 108 | tokenizer = LlamaTokenizer.from_pretrained(base_model) 109 | 110 | tokenizer.pad_token_id = ( 111 | 0 # unk. we want this to be different from the eos token 112 | ) 113 | tokenizer.padding_side = "left" # Allow batched inference 114 | 115 | def tokenize(prompt, add_eos_token=True): 116 | # there's probably a way to do this with the tokenizer settings 117 | # but again, gotta move fast 118 | result = tokenizer( 119 | prompt, 120 | truncation=True, 121 | max_length=cutoff_len, 122 | padding=False, 123 | return_tensors=None, 124 | ) 125 | if ( 126 | result["input_ids"][-1] != tokenizer.eos_token_id 127 | and len(result["input_ids"]) < cutoff_len 128 | and add_eos_token 129 | ): 130 | result["input_ids"].append(tokenizer.eos_token_id) 131 | result["attention_mask"].append(1) 132 | 133 | result["labels"] = result["input_ids"].copy() 134 | 135 | return result 136 | 137 | def generate_and_tokenize_prompt(data_point): 138 | full_prompt = generate_prompt(data_point) 139 | tokenized_full_prompt = tokenize(full_prompt) 140 | if not train_on_inputs: 141 | user_prompt = generate_prompt({**data_point, "output": ""}) 142 | tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False) 143 | user_prompt_len = len(tokenized_user_prompt["input_ids"]) 144 | 145 | tokenized_full_prompt["labels"] = [ 146 | -100 147 | ] * user_prompt_len + tokenized_full_prompt["labels"][ 148 | user_prompt_len: 149 | ] # could be sped up, probably 150 | return tokenized_full_prompt 151 | 152 | model = prepare_model_for_int8_training(model) 153 | 154 | config = LoraConfig( 155 | r=lora_r, 156 | lora_alpha=lora_alpha, 157 | target_modules=lora_target_modules, 158 | lora_dropout=lora_dropout, 159 | bias="none", 160 | task_type="CAUSAL_LM", 161 | ) 162 | model = get_peft_model(model, config) 163 | 164 | if data_path.endswith(".json"): # todo: support jsonl 165 | data = load_dataset("json", data_files=data_path) 166 | else: 167 | data = load_dataset(data_path) 168 | 169 | if resume_from_checkpoint: 170 | # Check the available weights and load them 171 | checkpoint_name = os.path.join( 172 | resume_from_checkpoint, "pytorch_model.bin" 173 | ) # Full checkpoint 174 | if not os.path.exists(checkpoint_name): 175 | checkpoint_name = os.path.join( 176 | resume_from_checkpoint, "adapter_model.bin" 177 | ) # only LoRA model - LoRA config above has to fit 178 | resume_from_checkpoint = ( 179 | 
False # So the trainer won't try loading its state 180 | ) 181 | # The two files above have a different name depending on how they were saved, but are actually the same. 182 | if os.path.exists(checkpoint_name): 183 | print(f"Restarting from {checkpoint_name}") 184 | adapters_weights = torch.load(checkpoint_name) 185 | model = set_peft_model_state_dict(model, adapters_weights) 186 | else: 187 | print(f"Checkpoint {checkpoint_name} not found") 188 | 189 | model.print_trainable_parameters() # Be more transparent about the % of trainable params. 190 | 191 | if val_set_size > 0: 192 | train_val = data["train"].train_test_split( 193 | test_size=val_set_size, shuffle=True, seed=42 194 | ) 195 | train_data = ( 196 | train_val["train"].shuffle().map(generate_and_tokenize_prompt) 197 | ) 198 | val_data = ( 199 | train_val["test"].shuffle().map(generate_and_tokenize_prompt) 200 | ) 201 | else: 202 | train_data = data["train"].shuffle().map(generate_and_tokenize_prompt) 203 | val_data = None 204 | 205 | if not ddp and torch.cuda.device_count() > 1: 206 | model.is_parallelizable = True 207 | model.model_parallel = True 208 | 209 | trainer = transformers.Trainer( 210 | model=model, 211 | train_dataset=train_data, 212 | eval_dataset=val_data, 213 | args=transformers.TrainingArguments( 214 | per_device_train_batch_size=micro_batch_size, 215 | gradient_accumulation_steps=gradient_accumulation_steps, 216 | warmup_steps=100, 217 | num_train_epochs=num_epochs, 218 | learning_rate=learning_rate, 219 | fp16=True, 220 | logging_steps=10, 221 | optim="adamw_torch", 222 | evaluation_strategy="steps" if val_set_size > 0 else "no", 223 | save_strategy="steps", 224 | eval_steps=200 if val_set_size > 0 else None, 225 | save_steps=200, 226 | output_dir=output_dir, 227 | save_total_limit=3, 228 | load_best_model_at_end=True if val_set_size > 0 else False, 229 | ddp_find_unused_parameters=False if ddp else None, 230 | group_by_length=group_by_length, 231 | report_to="wandb" if use_wandb else None, 232 | run_name=wandb_run_name if use_wandb else None, 233 | ), 234 | data_collator=transformers.DataCollatorForSeq2Seq( 235 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True 236 | ), 237 | ) 238 | model.config.use_cache = False 239 | 240 | old_state_dict = model.state_dict 241 | model.state_dict = ( 242 | lambda self, *_, **__: get_peft_model_state_dict( 243 | self, old_state_dict() 244 | ) 245 | ).__get__(model, type(model)) 246 | 247 | if torch.__version__ >= "2" and sys.platform != "win32": 248 | model = torch.compile(model) 249 | 250 | trainer.train(resume_from_checkpoint=resume_from_checkpoint) 251 | 252 | model.save_pretrained(output_dir) 253 | 254 | print( 255 | "\n If there's a warning about missing keys above, please disregard :)" 256 | ) 257 | 258 | 259 | def generate_prompt(data_point): 260 | # sorry about the formatting disaster gotta move fast 261 | if data_point["input"]: 262 | return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. # noqa: E501 263 | 264 | ### Instruction: 265 | {data_point["instruction"]} 266 | 267 | ### Input: 268 | {data_point["input"]} 269 | 270 | ### Response: 271 | {data_point["output"]}""" 272 | else: 273 | return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. 
274 | 275 | ### Instruction: 276 | {data_point["instruction"]} 277 | 278 | ### Response: 279 | {data_point["output"]}""" 280 | 281 | 282 | if __name__ == "__main__": 283 | fire.Fire(train) 284 | -------------------------------------------------------------------------------- /finetune_multi_rec.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['LD_LIBRARY_PATH'] = '/data/baokq/miniconda3/envs/alpaca_lora/lib/' 3 | import sys 4 | from typing import List 5 | 6 | import numpy as np 7 | import fire 8 | import torch 9 | import transformers 10 | from datasets import load_dataset, concatenate_datasets 11 | from transformers import EarlyStoppingCallback 12 | 13 | """ 14 | Unused imports: 15 | import torch.nn as nn 16 | import bitsandbytes as bnb 17 | """ 18 | 19 | from peft import ( # noqa: E402 20 | LoraConfig, 21 | get_peft_model, 22 | get_peft_model_state_dict, 23 | prepare_model_for_int8_training, 24 | set_peft_model_state_dict, 25 | ) 26 | from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: F402 27 | from sklearn.metrics import roc_auc_score 28 | 29 | 30 | def train( 31 | # model/data params 32 | base_model: str = "", # the only required argument 33 | train_data_path: str = "", 34 | train_data_path2: str = "", 35 | val_data_path: str = "", 36 | val_data_path2: str = "", 37 | output_dir: str = "./lora-alpaca", 38 | sample: int = -1, 39 | seed: int = 0, 40 | # training hyperparams 41 | batch_size: int = 128, 42 | micro_batch_size: int = 4, 43 | num_epochs: int = 3, 44 | learning_rate: float = 3e-4, 45 | cutoff_len: int = 256, 46 | # lora hyperparams 47 | lora_r: int = 8, 48 | lora_alpha: int = 16, 49 | lora_dropout: float = 0.05, 50 | lora_target_modules: List[str] = [ 51 | "q_proj", 52 | "v_proj", 53 | ], 54 | # llm hyperparams 55 | train_on_inputs: bool = True, # if False, masks out inputs in loss 56 | group_by_length: bool = False, # faster, but produces an odd training loss curve 57 | # wandb params 58 | wandb_project: str = "", 59 | wandb_run_name: str = "", 60 | wandb_watch: str = "", # options: false | gradients | all 61 | wandb_log_model: str = "", # options: false | true 62 | resume_from_checkpoint: str = None, # either training checkpoint or final adapter 63 | 64 | ): 65 | print( 66 | f"Training Alpaca-LoRA model with params:\n" 67 | f"base_model: {base_model}\n" 68 | f"train_data_path: {train_data_path}\n" 69 | f"val_data_path: {val_data_path}\n" 70 | f"sample: {sample}\n" 71 | f"seed: {seed}\n" 72 | f"output_dir: {output_dir}\n" 73 | f"batch_size: {batch_size}\n" 74 | f"micro_batch_size: {micro_batch_size}\n" 75 | f"num_epochs: {num_epochs}\n" 76 | f"learning_rate: {learning_rate}\n" 77 | f"cutoff_len: {cutoff_len}\n" 78 | f"lora_r: {lora_r}\n" 79 | f"lora_alpha: {lora_alpha}\n" 80 | f"lora_dropout: {lora_dropout}\n" 81 | f"lora_target_modules: {lora_target_modules}\n" 82 | f"train_on_inputs: {train_on_inputs}\n" 83 | f"group_by_length: {group_by_length}\n" 84 | f"wandb_project: {wandb_project}\n" 85 | f"wandb_run_name: {wandb_run_name}\n" 86 | f"wandb_watch: {wandb_watch}\n" 87 | f"wandb_log_model: {wandb_log_model}\n" 88 | f"resume_from_checkpoint: {resume_from_checkpoint}\n" 89 | ) 90 | assert ( 91 | base_model 92 | ), "Please specify a --base_model, e.g. 
--base_model='decapoda-research/llama-7b-hf'" 93 | gradient_accumulation_steps = batch_size // micro_batch_size 94 | # print(f"gradient_accumulation_steps: {gradient_accumulation_steps}") 95 | 96 | device_map = "auto" 97 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 98 | ddp = world_size != 1 99 | if ddp: 100 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} 101 | gradient_accumulation_steps = gradient_accumulation_steps // world_size 102 | 103 | # Check if parameter passed or if set within environ 104 | use_wandb = len(wandb_project) > 0 or ( 105 | "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0 106 | ) 107 | # Only overwrite environ if wandb param passed 108 | if len(wandb_project) > 0: 109 | os.environ["WANDB_PROJECT"] = wandb_project 110 | if len(wandb_watch) > 0: 111 | os.environ["WANDB_WATCH"] = wandb_watch 112 | if len(wandb_log_model) > 0: 113 | os.environ["WANDB_LOG_MODEL"] = wandb_log_model 114 | 115 | model = LlamaForCausalLM.from_pretrained( 116 | base_model, 117 | load_in_8bit=True, 118 | torch_dtype=torch.float16, 119 | device_map=device_map, 120 | ) 121 | 122 | tokenizer = LlamaTokenizer.from_pretrained(base_model) 123 | 124 | tokenizer.pad_token_id = ( 125 | 0 # unk. we want this to be different from the eos token 126 | ) 127 | tokenizer.padding_side = "left" # Allow batched inference 128 | 129 | def tokenize(prompt, add_eos_token=True): 130 | # there's probably a way to do this with the tokenizer settings 131 | # but again, gotta move fast 132 | result = tokenizer( 133 | prompt, 134 | truncation=True, 135 | max_length=cutoff_len, 136 | padding=False, 137 | return_tensors=None, 138 | ) 139 | if ( 140 | result["input_ids"][-1] != tokenizer.eos_token_id 141 | and len(result["input_ids"]) < cutoff_len 142 | and add_eos_token 143 | ): 144 | result["input_ids"].append(tokenizer.eos_token_id) 145 | result["attention_mask"].append(1) 146 | 147 | result["labels"] = result["input_ids"].copy() 148 | 149 | return result 150 | 151 | def generate_and_tokenize_prompt(data_point): 152 | full_prompt = generate_prompt(data_point) 153 | tokenized_full_prompt = tokenize(full_prompt) 154 | if not train_on_inputs: 155 | user_prompt = generate_prompt({**data_point, "output": ""}) 156 | tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False) 157 | user_prompt_len = len(tokenized_user_prompt["input_ids"]) 158 | 159 | tokenized_full_prompt["labels"] = [ 160 | -100 161 | ] * user_prompt_len + tokenized_full_prompt["labels"][ 162 | user_prompt_len: 163 | ] # could be sped up, probably 164 | return tokenized_full_prompt 165 | 166 | model = prepare_model_for_int8_training(model) 167 | 168 | config = LoraConfig( 169 | r=lora_r, 170 | lora_alpha=lora_alpha, 171 | target_modules=lora_target_modules, 172 | lora_dropout=lora_dropout, 173 | bias="none", 174 | task_type="CAUSAL_LM", 175 | ) 176 | model = get_peft_model(model, config) 177 | 178 | 179 | 180 | if train_data_path.endswith(".json"): # todo: support jsonl 181 | train_data = load_dataset("json", data_files=train_data_path) 182 | else: 183 | train_data = load_dataset(train_data_path) 184 | 185 | if val_data_path.endswith(".json"): # todo: support jsonl 186 | val_data = load_dataset("json", data_files=val_data_path) 187 | else: 188 | val_data = load_dataset(val_data_path) 189 | 190 | if train_data_path2.endswith(".json"): # todo: support jsonl 191 | train_data2 = load_dataset("json", data_files=train_data_path2) 192 | else: 193 | train_data2 = load_dataset(train_data_path2) 194 | 195 | if 
val_data_path2.endswith(".json"): # todo: support jsonl 196 | val_data2 = load_dataset("json", data_files=val_data_path2) 197 | else: 198 | val_data2 = load_dataset(val_data_path2) 199 | 200 | 201 | 202 | # train_data = train_data.shuffle(seed=42)[:sample] if sample > -1 else train_data 203 | # print(len(train_data)) 204 | if resume_from_checkpoint: 205 | # Check the available weights and load them 206 | checkpoint_name = os.path.join( 207 | resume_from_checkpoint, "pytorch_model.bin" 208 | ) # Full checkpoint 209 | if not os.path.exists(checkpoint_name): 210 | checkpoint_name = os.path.join( 211 | resume_from_checkpoint, "adapter_model.bin" 212 | ) # only LoRA model - LoRA config above has to fit 213 | resume_from_checkpoint = ( 214 | False # So the trainer won't try loading its state 215 | ) 216 | # The two files above have a different name depending on how they were saved, but are actually the same. 217 | if os.path.exists(checkpoint_name): 218 | print(f"Restarting from {checkpoint_name}") 219 | adapters_weights = torch.load(checkpoint_name) 220 | model = set_peft_model_state_dict(model, adapters_weights) 221 | else: 222 | print(f"Checkpoint {checkpoint_name} not found") 223 | 224 | model.print_trainable_parameters() # Be more transparent about the % of trainable params. 225 | 226 | train_data["train"] = train_data["train"].shuffle(seed=seed).select(range(sample)) if sample > -1 else train_data["train"].shuffle(seed=seed) 227 | train_data["train"] = train_data["train"].shuffle(seed=seed) 228 | train_data2["train"] = train_data2["train"].shuffle(seed=seed).select(range(sample)) if sample > -1 else train_data2["train"].shuffle(seed=seed) 229 | train_data2["train"] = train_data2["train"].shuffle(seed=seed) 230 | train_data["train"] = concatenate_datasets([train_data["train"], train_data2["train"]]) 231 | # print(train_data) 232 | train_data = (train_data["train"].map(generate_and_tokenize_prompt)) 233 | val_data = (val_data["train"].map(generate_and_tokenize_prompt)) 234 | if not ddp and torch.cuda.device_count() > 1: 235 | model.is_parallelizable = True 236 | model.model_parallel = True 237 | 238 | def compute_metrics(eval_preds): 239 | pre, labels = eval_preds 240 | auc = roc_auc_score(pre[1], pre[0]) 241 | return {'auc': auc} 242 | 243 | def preprocess_logits_for_metrics(logits, labels): 244 | """ 245 | Original Trainer may have a memory leak. 246 | This is a workaround to avoid storing too many tensors that are not needed. 
247 | """ 248 | labels_index = torch.argwhere(torch.bitwise_or(labels == 8241, labels == 3782)) 249 | gold = torch.where(labels[labels_index[:, 0], labels_index[:, 1]] == 3782, 0, 1) 250 | labels_index[: , 1] = labels_index[: , 1] - 1 251 | logits = logits.softmax(dim=-1) 252 | logits = torch.softmax(logits[labels_index[:, 0], labels_index[:, 1]][:,[3782, 8241]], dim = -1) 253 | return logits[:, 1][2::3], gold[2::3] 254 | 255 | os.environ["WANDB_DISABLED"] = "true" 256 | 257 | if sample > -1: 258 | if sample <= 128 : 259 | eval_step = 10 260 | else: 261 | eval_step = sample / 128 * 5 262 | 263 | trainer = transformers.Trainer( 264 | model=model, 265 | train_dataset=train_data, 266 | eval_dataset=val_data, 267 | args=transformers.TrainingArguments( 268 | per_device_train_batch_size=micro_batch_size, 269 | gradient_accumulation_steps=gradient_accumulation_steps, 270 | warmup_steps=20, 271 | num_train_epochs=num_epochs, 272 | learning_rate=learning_rate, 273 | fp16=True, 274 | logging_steps=8, 275 | optim="adamw_torch", 276 | evaluation_strategy="steps", 277 | save_strategy="steps", 278 | eval_steps=eval_step, 279 | save_steps=eval_step, 280 | output_dir=output_dir, 281 | save_total_limit=1, 282 | load_best_model_at_end=True, 283 | metric_for_best_model="eval_auc", 284 | ddp_find_unused_parameters=False if ddp else None, 285 | group_by_length=group_by_length, 286 | report_to=None, 287 | # report_to="wandb" if use_wandb else None, 288 | # run_name=wandb_run_name if use_wandb else None, 289 | # eval_accumulation_steps=10, 290 | ), 291 | data_collator=transformers.DataCollatorForSeq2Seq( 292 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True 293 | ), 294 | compute_metrics=compute_metrics, 295 | preprocess_logits_for_metrics=preprocess_logits_for_metrics, 296 | callbacks = [EarlyStoppingCallback(early_stopping_patience=10)] 297 | ) 298 | model.config.use_cache = False 299 | 300 | old_state_dict = model.state_dict 301 | model.state_dict = ( 302 | lambda self, *_, **__: get_peft_model_state_dict( 303 | self, old_state_dict() 304 | ) 305 | ).__get__(model, type(model)) 306 | 307 | if torch.__version__ >= "2" and sys.platform != "win32": 308 | model = torch.compile(model) 309 | 310 | trainer.train(resume_from_checkpoint=resume_from_checkpoint) 311 | 312 | model.save_pretrained(output_dir) 313 | 314 | print( 315 | "\n If there's a warning about missing keys above, please disregard :)" 316 | ) 317 | 318 | 319 | def generate_prompt(data_point): 320 | # sorry about the formatting disaster gotta move fast 321 | if data_point["input"]: 322 | return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. # noqa: E501 323 | 324 | ### Instruction: 325 | {data_point["instruction"]} 326 | 327 | ### Input: 328 | {data_point["input"]} 329 | 330 | ### Response: 331 | {data_point["output"]}""" 332 | else: 333 | return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. 
334 | 335 | ### Instruction: 336 | {data_point["instruction"]} 337 | 338 | ### Response: 339 | {data_point["output"]}""" 340 | 341 | 342 | if __name__ == "__main__": 343 | fire.Fire(train) 344 | -------------------------------------------------------------------------------- /finetune_rec.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['LD_LIBRARY_PATH'] = '/data/baokq/miniconda3/envs/alpaca_lora/lib/' 3 | import sys 4 | from typing import List 5 | 6 | import numpy as np 7 | import fire 8 | import torch 9 | import transformers 10 | from datasets import load_dataset 11 | from transformers import EarlyStoppingCallback 12 | 13 | """ 14 | Unused imports: 15 | import torch.nn as nn 16 | import bitsandbytes as bnb 17 | """ 18 | 19 | from peft import ( # noqa: E402 20 | LoraConfig, 21 | get_peft_model, 22 | get_peft_model_state_dict, 23 | prepare_model_for_int8_training, 24 | set_peft_model_state_dict, 25 | ) 26 | from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: F402 27 | from sklearn.metrics import roc_auc_score 28 | 29 | def train( 30 | # model/data params 31 | base_model: str = "", # the only required argument 32 | train_data_path: str = "", 33 | val_data_path: str = "", 34 | output_dir: str = "./lora-alpaca", 35 | sample: int = -1, 36 | seed: int = 0, 37 | # training hyperparams 38 | batch_size: int = 128, 39 | micro_batch_size: int = 4, 40 | num_epochs: int = 3, 41 | learning_rate: float = 3e-4, 42 | cutoff_len: int = 256, 43 | # lora hyperparams 44 | lora_r: int = 8, 45 | lora_alpha: int = 16, 46 | lora_dropout: float = 0.05, 47 | lora_target_modules: List[str] = [ 48 | "q_proj", 49 | "v_proj", 50 | ], 51 | # llm hyperparams 52 | train_on_inputs: bool = True, # if False, masks out inputs in loss 53 | group_by_length: bool = False, # faster, but produces an odd training loss curve 54 | # wandb params 55 | wandb_project: str = "", 56 | wandb_run_name: str = "", 57 | wandb_watch: str = "", # options: false | gradients | all 58 | wandb_log_model: str = "", # options: false | true 59 | resume_from_checkpoint: str = None, # either training checkpoint or final adapter 60 | 61 | ): 62 | print( 63 | f"Training Alpaca-LoRA model with params:\n" 64 | f"base_model: {base_model}\n" 65 | f"train_data_path: {train_data_path}\n" 66 | f"val_data_path: {val_data_path}\n" 67 | f"sample: {sample}\n" 68 | f"seed: {seed}\n" 69 | f"output_dir: {output_dir}\n" 70 | f"batch_size: {batch_size}\n" 71 | f"micro_batch_size: {micro_batch_size}\n" 72 | f"num_epochs: {num_epochs}\n" 73 | f"learning_rate: {learning_rate}\n" 74 | f"cutoff_len: {cutoff_len}\n" 75 | f"lora_r: {lora_r}\n" 76 | f"lora_alpha: {lora_alpha}\n" 77 | f"lora_dropout: {lora_dropout}\n" 78 | f"lora_target_modules: {lora_target_modules}\n" 79 | f"train_on_inputs: {train_on_inputs}\n" 80 | f"group_by_length: {group_by_length}\n" 81 | f"wandb_project: {wandb_project}\n" 82 | f"wandb_run_name: {wandb_run_name}\n" 83 | f"wandb_watch: {wandb_watch}\n" 84 | f"wandb_log_model: {wandb_log_model}\n" 85 | f"resume_from_checkpoint: {resume_from_checkpoint}\n" 86 | ) 87 | assert ( 88 | base_model 89 | ), "Please specify a --base_model, e.g. 
--base_model='decapoda-research/llama-7b-hf'" 90 | gradient_accumulation_steps = batch_size // micro_batch_size 91 | # print(f"gradient_accumulation_steps: {gradient_accumulation_steps}") 92 | 93 | device_map = "auto" 94 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 95 | ddp = world_size != 1 96 | if ddp: 97 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} 98 | gradient_accumulation_steps = gradient_accumulation_steps // world_size 99 | 100 | # Check if parameter passed or if set within environ 101 | use_wandb = len(wandb_project) > 0 or ( 102 | "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0 103 | ) 104 | # Only overwrite environ if wandb param passed 105 | if len(wandb_project) > 0: 106 | os.environ["WANDB_PROJECT"] = wandb_project 107 | if len(wandb_watch) > 0: 108 | os.environ["WANDB_WATCH"] = wandb_watch 109 | if len(wandb_log_model) > 0: 110 | os.environ["WANDB_LOG_MODEL"] = wandb_log_model 111 | 112 | model = LlamaForCausalLM.from_pretrained( 113 | base_model, 114 | load_in_8bit=True, 115 | torch_dtype=torch.float16, 116 | device_map=device_map, 117 | ) 118 | 119 | tokenizer = LlamaTokenizer.from_pretrained(base_model) 120 | 121 | tokenizer.pad_token_id = ( 122 | 0 # unk. we want this to be different from the eos token 123 | ) 124 | tokenizer.padding_side = "left" # Allow batched inference 125 | 126 | def tokenize(prompt, add_eos_token=True): 127 | # there's probably a way to do this with the tokenizer settings 128 | # but again, gotta move fast 129 | result = tokenizer( 130 | prompt, 131 | truncation=True, 132 | max_length=cutoff_len, 133 | padding=False, 134 | return_tensors=None, 135 | ) 136 | if ( 137 | result["input_ids"][-1] != tokenizer.eos_token_id 138 | and len(result["input_ids"]) < cutoff_len 139 | and add_eos_token 140 | ): 141 | result["input_ids"].append(tokenizer.eos_token_id) 142 | result["attention_mask"].append(1) 143 | 144 | result["labels"] = result["input_ids"].copy() 145 | 146 | return result 147 | 148 | def generate_and_tokenize_prompt(data_point): 149 | full_prompt = generate_prompt(data_point) 150 | tokenized_full_prompt = tokenize(full_prompt) 151 | if not train_on_inputs: 152 | user_prompt = generate_prompt({**data_point, "output": ""}) 153 | tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False) 154 | user_prompt_len = len(tokenized_user_prompt["input_ids"]) 155 | 156 | tokenized_full_prompt["labels"] = [ 157 | -100 158 | ] * user_prompt_len + tokenized_full_prompt["labels"][ 159 | user_prompt_len: 160 | ] # could be sped up, probably 161 | return tokenized_full_prompt 162 | 163 | model = prepare_model_for_int8_training(model) 164 | 165 | config = LoraConfig( 166 | r=lora_r, 167 | lora_alpha=lora_alpha, 168 | target_modules=lora_target_modules, 169 | lora_dropout=lora_dropout, 170 | bias="none", 171 | task_type="CAUSAL_LM", 172 | ) 173 | model = get_peft_model(model, config) 174 | 175 | 176 | 177 | if train_data_path.endswith(".json"): # todo: support jsonl 178 | train_data = load_dataset("json", data_files=train_data_path) 179 | else: 180 | train_data = load_dataset(train_data_path) 181 | 182 | if val_data_path.endswith(".json"): # todo: support jsonl 183 | val_data = load_dataset("json", data_files=val_data_path) 184 | else: 185 | val_data = load_dataset(val_data_path) 186 | 187 | 188 | # train_data = train_data.shuffle(seed=42)[:sample] if sample > -1 else train_data 189 | # print(len(train_data)) 190 | if resume_from_checkpoint: 191 | # Check the available weights and load them 192 | checkpoint_name 
= os.path.join( 193 | resume_from_checkpoint, "pytorch_model.bin" 194 | ) # Full checkpoint 195 | if not os.path.exists(checkpoint_name): 196 | checkpoint_name = os.path.join( 197 | resume_from_checkpoint, "adapter_model.bin" 198 | ) # only LoRA model - LoRA config above has to fit 199 | resume_from_checkpoint = ( 200 | False # So the trainer won't try loading its state 201 | ) 202 | # The two files above have a different name depending on how they were saved, but are actually the same. 203 | if os.path.exists(checkpoint_name): 204 | print(f"Restarting from {checkpoint_name}") 205 | adapters_weights = torch.load(checkpoint_name) 206 | model = set_peft_model_state_dict(model, adapters_weights) 207 | else: 208 | print(f"Checkpoint {checkpoint_name} not found") 209 | 210 | model.print_trainable_parameters() # Be more transparent about the % of trainable params. 211 | 212 | train_data["train"] = train_data["train"].shuffle(seed=seed).select(range(sample)) if sample > -1 else train_data["train"].shuffle(seed=seed) 213 | train_data["train"] = train_data["train"].shuffle(seed=seed) 214 | train_data = (train_data["train"].map(generate_and_tokenize_prompt)) 215 | val_data = (val_data["train"].map(generate_and_tokenize_prompt)) 216 | if not ddp and torch.cuda.device_count() > 1: 217 | model.is_parallelizable = True 218 | model.model_parallel = True 219 | 220 | def compute_metrics(eval_preds): 221 | pre, labels = eval_preds 222 | auc = roc_auc_score(pre[1], pre[0]) 223 | return {'auc': auc} 224 | 225 | def preprocess_logits_for_metrics(logits, labels): 226 | """ 227 | Original Trainer may have a memory leak. 228 | This is a workaround to avoid storing too many tensors that are not needed. 229 | """ 230 | labels_index = torch.argwhere(torch.bitwise_or(labels == 8241, labels == 3782)) 231 | gold = torch.where(labels[labels_index[:, 0], labels_index[:, 1]] == 3782, 0, 1) 232 | labels_index[: , 1] = labels_index[: , 1] - 1 233 | logits = logits.softmax(dim=-1) 234 | logits = torch.softmax(logits[labels_index[:, 0], labels_index[:, 1]][:,[3782, 8241]], dim = -1) 235 | return logits[:, 1][2::3], gold[2::3] 236 | 237 | os.environ["WANDB_DISABLED"] = "true" 238 | 239 | if sample > -1: 240 | if sample <= 128 : 241 | eval_step = 10 242 | else: 243 | eval_step = sample / 128 * 5 244 | 245 | trainer = transformers.Trainer( 246 | model=model, 247 | train_dataset=train_data, 248 | eval_dataset=val_data, 249 | args=transformers.TrainingArguments( 250 | per_device_train_batch_size=micro_batch_size, 251 | gradient_accumulation_steps=gradient_accumulation_steps, 252 | warmup_steps=20, 253 | num_train_epochs=num_epochs, 254 | learning_rate=learning_rate, 255 | fp16=True, 256 | logging_steps=8, 257 | optim="adamw_torch", 258 | evaluation_strategy="steps", 259 | save_strategy="steps", 260 | eval_steps=eval_step, 261 | save_steps=eval_step, 262 | output_dir=output_dir, 263 | save_total_limit=1, 264 | load_best_model_at_end=True, 265 | metric_for_best_model="eval_auc", 266 | ddp_find_unused_parameters=False if ddp else None, 267 | group_by_length=group_by_length, 268 | report_to=None, 269 | # report_to="wandb" if use_wandb else None, 270 | # run_name=wandb_run_name if use_wandb else None, 271 | # eval_accumulation_steps=10, 272 | ), 273 | data_collator=transformers.DataCollatorForSeq2Seq( 274 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True 275 | ), 276 | compute_metrics=compute_metrics, 277 | preprocess_logits_for_metrics=preprocess_logits_for_metrics, 278 | callbacks = 
[EarlyStoppingCallback(early_stopping_patience=10)]
279 |     )
280 |     model.config.use_cache = False
281 | 
282 |     old_state_dict = model.state_dict
283 |     model.state_dict = (
284 |         lambda self, *_, **__: get_peft_model_state_dict(
285 |             self, old_state_dict()
286 |         )
287 |     ).__get__(model, type(model))
288 | 
289 |     if torch.__version__ >= "2" and sys.platform != "win32":
290 |         model = torch.compile(model)
291 | 
292 |     trainer.train(resume_from_checkpoint=resume_from_checkpoint)
293 | 
294 |     model.save_pretrained(output_dir)
295 | 
296 |     print(
297 |         "\n If there's a warning about missing keys above, please disregard :)"
298 |     )
299 | 
300 | 
301 | def generate_prompt(data_point):
302 |     # sorry about the formatting disaster gotta move fast
303 |     if data_point["input"]:
304 |         return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. # noqa: E501
305 | 
306 | ### Instruction:
307 | {data_point["instruction"]}
308 | 
309 | ### Input:
310 | {data_point["input"]}
311 | 
312 | ### Response:
313 | {data_point["output"]}"""
314 |     else:
315 |         return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. # noqa: E501
316 | 
317 | ### Instruction:
318 | {data_point["instruction"]}
319 | 
320 | ### Response:
321 | {data_point["output"]}"""
322 | 
323 | 
324 | if __name__ == "__main__":
325 |     fire.Fire(train)
326 | 
--------------------------------------------------------------------------------
/preprocess_book.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | rating = pd.read_csv('BX-Book-Ratings.csv', sep=';', encoding="latin-1")
3 | users = pd.read_csv('BX-Users.csv', sep=';', encoding="latin-1")
4 | books = pd.read_csv('BX-Books.csv', sep=';', encoding="latin-1", error_bad_lines=False)  # error_bad_lines was removed in pandas 2.0; this script expects pandas < 2.0
5 | rating = pd.merge(rating, books, on='ISBN', how='inner')
6 | books.to_csv('book_item_mapping.csv', index=True)
7 | 
8 | from tqdm import tqdm
9 | user_dict = {}
10 | item_id = {}
11 | for index, row in tqdm(books.iterrows()):
12 |     item_id[row['ISBN']] = index
13 | for index, row in tqdm(rating.iterrows()):
14 |     userid = row['User-ID']
15 |     if userid not in user_dict:
16 |         user_dict[userid] = {
17 |             'ISBN': [],
18 |             'Book-Rating': [],
19 |             'Book-Title': [],
20 |             'Book-Author': [],
21 |             'Year-Of-Publication': [],
22 |         }
23 |     user_dict[userid]['ISBN'].append(item_id[row['ISBN']])
24 |     user_dict[userid]['Book-Rating'].append(float(row['Book-Rating']))
25 |     user_dict[userid]['Book-Title'].append(row['Book-Title'])
26 |     user_dict[userid]['Book-Author'].append(row['Book-Author'])
27 |     user_dict[userid]['Year-Of-Publication'].append(row['Year-Of-Publication'])
28 | 
29 | mx = 0  # longest interaction history seen, tracked for reference only
30 | new_user_dict = {}
31 | for key in user_dict.keys():
32 |     mx = max(mx, len(user_dict[key]['ISBN']))
33 |     # keep only users with more than 3 rated books
34 |     if len(user_dict[key]['ISBN']) > 3:
35 |         new_user_dict[key] = user_dict[key]
36 | 
37 | import random
38 | import json
39 | user_list = list(new_user_dict.keys())
40 | random.shuffle(user_list)
41 | train_user = user_list[:int(len(user_list) * 0.8)]
42 | valid_user = user_list[int(len(user_list) * 0.8):int(len(user_list) * 0.9)]
43 | test_user = user_list[int(len(user_list) * 0.9):]
44 | 
45 | def generate_csv(user_list, output_csv, output_json):
46 |     nrows = []
47 |     for user in user_list:
48 |         item_id = user_dict[user]['ISBN']
49 |         rating = [int(_ > 5) for _ in user_dict[user]['Book-Rating']]  # binarize: ratings above 5 count as a like
50 |         random.seed(42)  # reseeding before each shuffle applies the same permutation, keeping the lists aligned
51 |         random.shuffle(item_id)
52 |         random.seed(42)
53 |         random.shuffle(rating)
54 |         nrows.append([user, item_id[:-1][:10], rating[:-1][:10], item_id[-1], rating[-1]])
55 |     with open(output_csv, 'w') as f:
56 |         import csv
57 |         writer = csv.writer(f)
58 |         writer.writerow(['user', 'history_item_id','history_rating','item_id','rating'])
59 |         writer.writerows(nrows)
60 |     Prompt_json = []
61 |     for user in user_list:
62 |         item_id = user_dict[user]['ISBN']
63 |         rating = [int(_ > 5) for _ in user_dict[user]['Book-Rating']]
64 |         book_title = user_dict[user]['Book-Title']
65 |         book_author = user_dict[user]['Book-Author']
66 |         random.seed(42)
67 |         random.shuffle(item_id)
68 |         random.seed(42)
69 |         random.shuffle(rating)
70 |         random.seed(42)
71 |         random.shuffle(book_title)
72 |         random.seed(42)
73 |         random.shuffle(book_author)
74 |         preference = []
75 |         unpreference = []
76 |         for i in range(min(len(item_id) - 1, 10)):
77 |             if rating[i] == 1:
78 |                 preference.append("\"" + book_title[i] + "\"" + " written by " + book_author[i])
79 |             else:
80 |                 unpreference.append("\"" + book_title[i] + "\"" + " written by " + book_author[i])
81 |         preference_str = ""
82 |         unpreference_str = ""
83 |         for i in range(len(preference)):
84 |             if i == 0:
85 |                 preference_str += preference[i]
86 |             else:
87 |                 preference_str += ", " + preference[i]
88 |         for i in range(len(unpreference)):
89 |             if i == 0:
90 |                 unpreference_str += unpreference[i]
91 |             else:
92 |                 unpreference_str += ", " + unpreference[i]
93 |         target_preference_str = "Yes." if rating[-1] == 1 else "No."
94 |         target_book_str = "\"" + book_title[-1] + "\"" + " written by " + book_author[-1]
95 |         Prompt_json.append({
96 |             "instruction": "Given the user's preference and unpreference, identify whether the user will like the target book by answering \"Yes.\" or \"No.\".",
97 |             "input": f"User Preference: {preference_str}\nUser Unpreference: {unpreference_str}\nWhether the user will like the target book {target_book_str}?",
98 |             "output": target_preference_str,
99 |         })
100 |     with open(output_json, 'w') as f:
101 |         json.dump(Prompt_json, f, indent=4)
102 | 
103 | generate_csv(train_user, 'train_book.csv', 'train_book.json')
104 | generate_csv(valid_user, 'valid_book.csv', 'valid_book.json')
105 | generate_csv(test_user, 'test_book.csv', 'test_book.json')
--------------------------------------------------------------------------------
/preprocess_movie.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pandas as pd
3 | f = open('u.data', 'r')
4 | data = f.readlines()
5 | f = open('u.item', 'r', encoding='ISO-8859-1')
6 | movies = f.readlines()
7 | f = open('u.user', 'r')
8 | users = f.readlines()
9 | 
10 | movie_names = [_.split('|')[1] for _ in movies]  # movie_names[0] = 'Toy Story (1995)'
11 | user_ids = [_.split('|')[0] for _ in users]  # user_ids[0] = '1'
12 | movie_ids = [_.split('|')[0] for _ in movies]  # movie_ids[0] = '1'
13 | interaction_dicts = dict()
14 | for line in data:
15 |     user_id, movie_id, rating, timestamp = line.split('\t')
16 |     if user_id not in interaction_dicts:
17 |         interaction_dicts[user_id] = {
18 |             'movie_id': [],
19 |             'rating': [],
20 |             'timestamp': [],
21 |         }
22 |     interaction_dicts[user_id]['movie_id'].append(movie_id)
23 |     interaction_dicts[user_id]['rating'].append(int(int(rating) > 3))  # binarize: ratings above 3 count as a like
24 |     interaction_dicts[user_id]['timestamp'].append(timestamp)
25 | 
26 | with open('item_mapping.csv', 'w') as f:
27 |     import csv
28 |     writer = csv.writer(f)
29 |     writer.writerow(['movie_id', 'movie_name'])
30 | 
for i, name in enumerate(movie_names): 31 | writer.writerow([i + 1, name]) 32 | 33 | sequential_interaction_list = [] 34 | seq_len = 10 35 | for user_id in interaction_dicts: 36 | temp = zip(interaction_dicts[user_id]['movie_id'], interaction_dicts[user_id]['rating'], interaction_dicts[user_id]['timestamp']) 37 | temp = sorted(temp, key=lambda x: x[2]) 38 | result = zip(*temp) 39 | interaction_dicts[user_id]['movie_id'], interaction_dicts[user_id]['rating'], interaction_dicts[user_id]['timestamp'] = [list(_) for _ in result] 40 | for i in range(10, len(interaction_dicts[user_id]['movie_id'])): 41 | sequential_interaction_list.append( 42 | [user_id, interaction_dicts[user_id]['movie_id'][i-seq_len:i], interaction_dicts[user_id]['rating'][i-seq_len:i], interaction_dicts[user_id]['movie_id'][i], interaction_dicts[user_id]['rating'][i], interaction_dicts[user_id]['timestamp'][i].strip('\n')] 43 | ) 44 | 45 | sequential_interaction_list = sequential_interaction_list[-10000:] # 10000 records 46 | 47 | 48 | import csv 49 | # save the csv file for baselines 50 | with open('./data/train.csv', 'w') as f: 51 | writer = csv.writer(f) 52 | writer.writerow(['user_id', 'history_movie_id', 'history_rating', 'movie_id', 'rating', 'timestamp']) 53 | writer.writerows(sequential_interaction_list[:int(len(sequential_interaction_list)*0.8)]) 54 | with open('./data/valid.csv', 'w') as f: 55 | writer = csv.writer(f) 56 | writer.writerow(['user_id', 'history_movie_id', 'history_rating', 'movie_id', 'rating', 'timestamp']) 57 | writer.writerows(sequential_interaction_list[int(len(sequential_interaction_list)*0.8):int(len(sequential_interaction_list)*0.9)]) 58 | with open('./data/test.csv', 'w') as f: 59 | writer = csv.writer(f) 60 | writer.writerow(['user_id', 'history_movie_id', 'history_rating', 'movie_id', 'rating', 'timestamp']) 61 | writer.writerows(sequential_interaction_list[int(len(sequential_interaction_list)*0.9):]) 62 | 63 | def csv_to_json(input_path, output_path): 64 | data = pd.read_csv(input_path) 65 | json_list = [] 66 | for index, row in data.iterrows(): 67 | row['history_movie_id'] = eval(row['history_movie_id']) 68 | row['history_rating'] = eval(row['history_rating']) 69 | L = len(row['history_movie_id']) 70 | preference = [] 71 | unpreference = [] 72 | for i in range(L): 73 | if int(row['history_rating'][i]) == 1: 74 | preference.append(movie_names[int(row['history_movie_id'][i]) - 1]) 75 | else: 76 | unpreference.append(movie_names[int(row['history_movie_id'][i]) - 1]) 77 | target_movie = movie_names[int(row['movie_id']) - 1] 78 | preference_str = "" 79 | unpreference_str = "" 80 | for i in range(len(preference)): 81 | if i == 0: 82 | preference_str += "\"" + preference[i] + "\"" 83 | else: 84 | preference_str += ", \"" + preference[i] + "\"" 85 | for i in range(len(unpreference)): 86 | if i == 0: 87 | unpreference_str += "\"" + unpreference[i] + "\"" 88 | else: 89 | unpreference_str += ", \"" + unpreference[i] + "\"" 90 | target_preference = int(row['rating']) 91 | target_movie_str = "\"" + target_movie + "\"" 92 | target_preference_str = "Yes." if target_preference == 1 else "No." 
93 |         json_list.append({
94 |             "instruction": "Given the user's preference and unpreference, identify whether the user will like the target movie by answering \"Yes.\" or \"No.\".",
95 |             "input": f"User Preference: {preference_str}\nUser Unpreference: {unpreference_str}\nWhether the user will like the target movie {target_movie_str}?",
96 |             "output": target_preference_str,
97 |         })
98 | 
99 |     with open(output_path, 'w') as f:
100 |         json.dump(json_list, f, indent=4)
101 | 
102 | # generate the json files for TALLRec
103 | csv_to_json('./data/train.csv', './data/train.json')
104 | csv_to_json('./data/valid.csv', './data/valid.json')
105 | csv_to_json('./data/test.csv', './data/test.json')
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | appdirs
3 | loralib
4 | bitsandbytes==0.37.2
5 | black
6 | black[jupyter]
7 | datasets
8 | fire
9 | peft==0.3.0
10 | transformers==4.28.0
11 | sentencepiece
12 | gradio
13 | 
--------------------------------------------------------------------------------
/shell/evaluate.sh:
--------------------------------------------------------------------------------
1 | CUDA_ID=$1
2 | output_dir=$2
3 | model_path=$(ls -d $output_dir*)  # loop over every output directory that matches the prefix
4 | base_model=XXX
5 | test_data=XXX
6 | for path in $model_path
7 | do
8 |     echo $path
9 |     CUDA_VISIBLE_DEVICES=$CUDA_ID python evaluate.py \
10 |         --base_model $base_model \
11 |         --lora_weights $path \
12 |         --test_data_path $test_data \
13 |         --result_json_data $2.json
14 | done
15 | 
--------------------------------------------------------------------------------
/shell/instruct_7B.sh:
--------------------------------------------------------------------------------
1 | echo $1, $2
2 | seed=$2
3 | output_dir=XXX
4 | base_model=XXX
5 | train_data=XXX
6 | val_data=XXX
7 | instruction_model=XXX
8 | for lr in 1e-4
9 | do
10 |     for dropout in 0.05
11 |     do
12 |         for sample in 64
13 |         do
14 |             mkdir -p $output_dir
15 |             echo "lr: $lr, dropout: $dropout , seed: $seed, sample: $sample"
16 |             CUDA_VISIBLE_DEVICES=$1 python -u finetune_rec.py \
17 |                 --base_model $base_model \
18 |                 --train_data_path $train_data \
19 |                 --val_data_path $val_data \
20 |                 --output_dir ${output_dir}_${seed}_${sample} \
21 |                 --batch_size 128 \
22 |                 --micro_batch_size 32 \
23 |                 --num_epochs 200 \
24 |                 --learning_rate $lr \
25 |                 --cutoff_len 512 \
26 |                 --lora_r 8 \
27 |                 --lora_alpha 16 \
28 |                 --lora_dropout $dropout \
29 |                 --lora_target_modules '[q_proj,v_proj]' \
30 |                 --train_on_inputs \
31 |                 --group_by_length \
32 |                 --resume_from_checkpoint $instruction_model \
33 |                 --sample $sample \
34 |                 --seed $2
35 |         done
36 |     done
37 | done
38 | 
39 | 
--------------------------------------------------------------------------------
/shell/instruct_multi_7B.sh:
--------------------------------------------------------------------------------
1 | echo $1, $2
2 | seed=$2
3 | output_dir=XXX
4 | base_model=XXX
5 | train_data=XXX
6 | train_data2=XXX
7 | val_data=XXX
8 | val_data2=XXX
9 | instruction_model=XXX
10 | for lr in 1e-4
11 | do
12 |     for dropout in 0.05
13 |     do
14 |         for sample in 1 2 4 8 16 32 64 128 256 512
15 |         do
16 |             mkdir -p $output_dir
17 |             echo "lr: $lr, dropout: $dropout , seed: $seed, sample: $sample"
18 |             CUDA_VISIBLE_DEVICES=$1 python -u finetune_rec.py \
19 |                 --base_model $base_model \
20 |                 --train_data_path $train_data \
21 |                 --train_data_path2 $train_data2 \
22 |                 --val_data_path $val_data \
23 |                 --val_data_path2 $val_data2 \
24 | 
                --output_dir ${output_dir}_${seed}_${sample} \
25 |                 --batch_size 128 \
26 |                 --micro_batch_size 64 \
27 |                 --num_epochs 200 \
28 |                 --learning_rate $lr \
29 |                 --cutoff_len 512 \
30 |                 --lora_r 8 \
31 |                 --lora_alpha 16 \
32 |                 --lora_dropout $dropout \
33 |                 --lora_target_modules '[q_proj,v_proj]' \
34 |                 --train_on_inputs \
35 |                 --group_by_length \
36 |                 --resume_from_checkpoint $instruction_model \
37 |                 --sample $sample \
38 |                 --seed $2
39 |         done
40 |     done
41 | done
42 | 
43 | 
--------------------------------------------------------------------------------