├── .ipynb_checkpoints ├── LICENSE-checkpoint ├── README-checkpoint.md ├── evaluation-checkpoint.py ├── generate-checkpoint.py ├── pyproject-checkpoint.toml └── rec-checkpoint.py ├── DATA_LICENSE ├── LICENSE ├── README.md ├── checkpoint ├── movies │ ├── adapter_config.json │ └── adapter_model.bin └── toys │ ├── adapter_config.json │ └── adapter_model.bin ├── data ├── .ipynb_checkpoints │ └── movies-checkpoint.json ├── movies.json ├── testset │ ├── .ipynb_checkpoints │ │ └── movies_test-checkpoint.json │ ├── movies_test.json │ └── toys_test.json └── toys.json ├── evaluation.py ├── export_hf_checkpoint.py ├── export_state_dict_checkpoint.py ├── generate.py ├── pyproject.toml ├── rec.py └── requirements.txt /.ipynb_checkpoints/LICENSE-checkpoint: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /.ipynb_checkpoints/README-checkpoint.md: -------------------------------------------------------------------------------- 1 | # GenRec 2 | Large Language Model for Generative Recommendation 3 | 4 | ## Install dependencies 5 | 6 | 7 | pip install -r requirements.txt 8 | 9 | 10 | ### Training (`rec.py`) 11 | 12 | 13 | python rec.py \ 14 | --base_model 'decapoda-research/llama-7b-hf' \ 15 | --data_path './data/movies.json' \ 16 | --output_dir './checkpoint' 17 | 18 | 19 | 20 | ### Inference (`generate.py`) 21 | 22 | 23 | 24 | python generate.py \ 25 | --load_8bit \ 26 | --base_model 'decapoda-research/llama-7b-hf' \ 27 | --lora_weights './checkpoint/movies' 28 | 29 | 30 | This project is built on top of alpaca-lora (https://github.com/tloen/alpaca-lora). -------------------------------------------------------------------------------- /.ipynb_checkpoints/evaluation-checkpoint.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import fire 4 | import gradio as gr 5 | import torch 6 | import transformers 7 | from peft import PeftModel 8 | from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer 9 | import json 10 | import random 11 | from torch.utils.data import DataLoader, Sampler 12 | from tqdm import tqdm 13 | import math 14 | 15 | from datasets import load_dataset 16 | 17 | #import os 18 | #os.environ["CUDA_VISIBLE_DEVICES"] = "7" 19 | 20 | if torch.cuda.is_available(): 21 | device = "cuda" 22 | else: 23 | device = "cpu" 24 | 25 | try: 26 | if torch.backends.mps.is_available(): 27 | device = "mps" 28 | except: # noqa: E722 29 | pass 30 | 31 | 32 | load_8bit: bool = False 33 | base_model = '/common/users/jj635/llama/llama-7b/' 34 | lora_weights = './checkpoint/movies' 35 | 36 | 37 | """ 38 | model = LlamaForCausalLM.from_pretrained( 39 | base_model, 40 | load_in_8bit=load_8bit, 41 | torch_dtype=torch.float16, 42 | device_map='auto', 43 | ) 44 | 45 | model = PeftModel.from_pretrained( 46 | model, 47 | lora_weights, 48 | torch_dtype=torch.float16, 49 | device_map='auto', 50 | ) 51 | """ 52 | 53 | 54 | model = LlamaForCausalLM.from_pretrained( 55 | base_model, 56 | load_in_8bit=load_8bit, 57 | torch_dtype=torch.float16, 58 | device_map={'':0},#could be 1 59 | ) 60 | model = PeftModel.from_pretrained( 61 | model, 62 | lora_weights, 63 | torch_dtype=torch.float16, 64 | device_map={'':0},#could be 1 65 | ) 66 | 67 | tokenizer = LlamaTokenizer.from_pretrained(base_model) 68 | 69 | 70 | def generate_prompt(data_point): 71 | # sorry about the formatting disaster gotta move fast 72 | if data_point["input"]: 73 | return f""" # noqa: E501 74 | {data_point["instruction"]} 75 | 76 | ### input: 77 | {data_point["input"]} 78 | 79 | ### Response: 80 | {data_point["output"]}""" 81 | else: 82 | return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
# noqa: E501 83 | 84 | ### Instruction: 85 | {data_point["instruction"]} 86 | 87 | ### Response: 88 | {data_point["output"]}""" 89 | 90 | 91 | def generate_and_tokenize_prompt(data_point): 92 | full_prompt = generate_prompt(data_point) 93 | tokenized_full_prompt = tokenize(full_prompt) 94 | user_prompt = generate_prompt({**data_point, "output": ""}) 95 | tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False) 96 | user_prompt_len = len(tokenized_user_prompt["input_ids"]) 97 | 98 | tokenized_full_prompt["labels"] = [ 99 | -100 100 | ] * user_prompt_len + tokenized_full_prompt["labels"][ 101 | user_prompt_len: 102 | ] # could be sped up, probably 103 | return tokenized_full_prompt 104 | 105 | 106 | def tokenize(prompt, add_eos_token=True): 107 | # there's probably a way to do this with the tokenizer settings 108 | # but again, gotta move fast 109 | cutoff_len = 256 110 | result = tokenizer( 111 | prompt, 112 | truncation=True, 113 | max_length=cutoff_len, 114 | padding=False, 115 | return_tensors=None, 116 | ) 117 | if ( 118 | result["input_ids"][-1] != tokenizer.eos_token_id 119 | and len(result["input_ids"]) < cutoff_len 120 | and add_eos_token 121 | ): 122 | result["input_ids"].append(tokenizer.eos_token_id) 123 | result["attention_mask"].append(1) 124 | 125 | result["labels"] = result["input_ids"].copy() 126 | 127 | return result 128 | 129 | #with open("movie.json",'r', encoding='UTF-8') as f: 130 | # data = json.load(f) 131 | 132 | generation_config = GenerationConfig( 133 | temperature=0.1, 134 | top_p=0.75, 135 | top_k=10, 136 | num_beams=10, 137 | num_return_sequences=10, 138 | ) 139 | data = load_dataset('./data/testset/',data_files="toys_test.json") 140 | print(data) 141 | 142 | hit5 = 0 143 | hit10 = 0 144 | ndcg5 = 0 145 | ndcg10 = 0 146 | total = 0 147 | res = [] 148 | 149 | import pdb 150 | for i, cur in tqdm(enumerate(data['train'])): 151 | label = cur['output'] 152 | inputs = generate_prompt({**cur, "output": ""}) 153 | inputs = tokenizer(inputs, return_tensors="pt") 154 | input_ids = inputs['input_ids'].to('cuda:0') 155 | #pdb.set_trace() 156 | res = [] 157 | with torch.no_grad(): 158 | generation_output = model.generate( 159 | input_ids=input_ids, 160 | generation_config=generation_config, 161 | return_dict_in_generate=True, 162 | output_scores=False,#used to be True 163 | max_new_tokens=64,#used to be 128 164 | ) 165 | 166 | with torch.no_grad(): 167 | for i in range(10): 168 | temp = generation_output.sequences[i] 169 | cur = tokenizer.decode(temp,skip_special_tokens=True).split("### Response:")[1].strip() 170 | cur = cur.split("⁇")[0].strip() 171 | res.append(cur) 172 | #print(label) 173 | #print(res) 174 | 175 | if label in res[:5]: 176 | hit5 += 1 177 | pos = res[:5].index(label) 178 | ndcg5 += 1.0 / (math.log(pos + 2) / math.log(2)) / 1.0 179 | #print(res) 180 | #print(label) 181 | 182 | if label in res: 183 | hit10 += 1 184 | pos = res.index(label) 185 | ndcg10 += 1.0 / (math.log(pos + 2) / math.log(2)) / 1.0 186 | #print(res) 187 | #print(label) 188 | 189 | total += 1 190 | 191 | if total % 100 == 0: 192 | print('The Hit@5 is:',hit5/total) 193 | print('The Hit@10 is:',hit10/total) 194 | print('The NDCG@5 is:',ndcg5/total) 195 | print('The NDCG@10 is:',ndcg10/total) 196 | 197 | 198 | print('The Hit@5 is:',hit5/total) 199 | print('The Hit@10 is:',hit10/total) 200 | print('The NDCG@5 is:',ndcg5/total) 201 | print('The NDCG@10 is:',ndcg10/total) -------------------------------------------------------------------------------- 
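The evaluation loop in evaluation-checkpoint.py above ranks the ten beam outputs against the ground-truth item and accumulates Hit@K and NDCG@K. As a reference for the metric update it performs, here is a minimal, self-contained sketch; the helper name is illustrative and not part of the repository:

import math

def hit_and_ndcg_at_k(ranked, label, k):
    # ranked: generated item titles ordered by beam score; label: ground-truth next item
    topk = ranked[:k]
    if label not in topk:
        return 0.0, 0.0  # miss: contributes nothing to Hit@K or NDCG@K
    pos = topk.index(label)  # 0-based rank of the correct item
    # With a single relevant item per example the ideal DCG is 1,
    # so NDCG@K reduces to 1 / log2(pos + 2), matching the update in the loop above.
    return 1.0, 1.0 / math.log2(pos + 2)

Averaging the returned values over the test set reproduces the Hit@5/Hit@10 and NDCG@5/NDCG@10 figures printed by the script.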
/.ipynb_checkpoints/generate-checkpoint.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import fire 4 | import gradio as gr 5 | import torch 6 | import transformers 7 | from peft import PeftModel 8 | from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer 9 | 10 | if torch.cuda.is_available(): 11 | device = "cuda" 12 | else: 13 | device = "cpu" 14 | 15 | try: 16 | if torch.backends.mps.is_available(): 17 | device = "mps" 18 | except: # noqa: E722 19 | pass 20 | 21 | 22 | def main( 23 | load_8bit: bool = False, 24 | base_model: str = "", 25 | lora_weights: str = "tloen/alpaca-lora-7b", 26 | share_gradio: bool = True, 27 | ): 28 | assert ( 29 | base_model 30 | ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'" 31 | 32 | tokenizer = LlamaTokenizer.from_pretrained(base_model) 33 | if device == "cuda": 34 | model = LlamaForCausalLM.from_pretrained( 35 | base_model, 36 | load_in_8bit=load_8bit, 37 | torch_dtype=torch.float16, 38 | device_map="auto", 39 | ) 40 | model = PeftModel.from_pretrained( 41 | model, 42 | lora_weights, 43 | torch_dtype=torch.float16, 44 | ) 45 | elif device == "mps": 46 | model = LlamaForCausalLM.from_pretrained( 47 | base_model, 48 | device_map={"": device}, 49 | torch_dtype=torch.float16, 50 | ) 51 | model = PeftModel.from_pretrained( 52 | model, 53 | lora_weights, 54 | device_map={"": device}, 55 | torch_dtype=torch.float16, 56 | ) 57 | else: 58 | model = LlamaForCausalLM.from_pretrained( 59 | base_model, device_map={"": device}, low_cpu_mem_usage=True 60 | ) 61 | model = PeftModel.from_pretrained( 62 | model, 63 | lora_weights, 64 | device_map={"": device}, 65 | ) 66 | 67 | # unwind broken decapoda-research config 68 | model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk 69 | model.config.bos_token_id = 1 70 | model.config.eos_token_id = 2 71 | 72 | if not load_8bit: 73 | model.half() # seems to fix bugs for some users. 
74 | 75 | model.eval() 76 | if torch.__version__ >= "2" and sys.platform != "win32": 77 | model = torch.compile(model) 78 | 79 | def evaluate( 80 | instruction, 81 | input=None, 82 | temperature=0.1, 83 | top_p=0.75, 84 | top_k=40, 85 | num_beams=4, 86 | max_new_tokens=128, 87 | **kwargs, 88 | ): 89 | prompt = generate_prompt(instruction, input) 90 | inputs = tokenizer(prompt, return_tensors="pt") 91 | input_ids = inputs["input_ids"].to(device) 92 | generation_config = GenerationConfig( 93 | temperature=temperature, 94 | top_p=top_p, 95 | top_k=top_k, 96 | num_beams=num_beams, 97 | num_return_sequences=num_beams, 98 | **kwargs, 99 | ) 100 | with torch.no_grad(): 101 | generation_output = model.generate( 102 | input_ids=input_ids, 103 | generation_config=generation_config, 104 | return_dict_in_generate=True, 105 | output_scores=True, 106 | max_new_tokens=max_new_tokens, 107 | ) 108 | 109 | s = [] 110 | 111 | for i in range(num_beams): 112 | temp = generation_output.sequences[i] 113 | s.append(tokenizer.decode(temp,skip_special_tokens=True)) 114 | 115 | output = '' 116 | 117 | for cur in s: 118 | output += cur.split("### Response:")[1].strip() + '\n' 119 | 120 | 121 | 122 | return output 123 | 124 | gr.Interface( 125 | fn=evaluate, 126 | inputs=[ 127 | gr.components.Textbox( 128 | lines=2, 129 | label="Instruction", 130 | placeholder="Tell me about alpacas.", 131 | ), 132 | gr.components.Textbox(lines=2, label="Input", placeholder="none"), 133 | gr.components.Slider( 134 | minimum=0, maximum=1, value=0.1, label="Temperature" 135 | ), 136 | gr.components.Slider( 137 | minimum=0, maximum=1, value=0.75, label="Top p" 138 | ), 139 | gr.components.Slider( 140 | minimum=0, maximum=100, step=1, value=40, label="Top k" 141 | ), 142 | gr.components.Slider( 143 | minimum=1, maximum=4, step=1, value=4, label="Beams" 144 | ), 145 | gr.components.Slider( 146 | minimum=1, maximum=2000, step=1, value=128, label="Max tokens" 147 | ), 148 | ], 149 | outputs=[ 150 | gr.inputs.Textbox( 151 | lines=10, 152 | label="Output", 153 | ) 154 | ], 155 | title="Wiselab-LoRA", 156 | description="Wiselab-LoRA is a 7B-parameter LLaMA-LoRA based model.", # noqa: E501 157 | ).launch(share=share_gradio) 158 | # Old testing code follows. 159 | 160 | """ 161 | # testing code for readme 162 | for instruction in [ 163 | "Tell me about alpacas.", 164 | "Tell me about the president of Mexico in 2019.", 165 | "Tell me about the king of France in 2019.", 166 | "List all Canadian provinces in alphabetical order.", 167 | "Write a Python program that prints the first 10 Fibonacci numbers.", 168 | "Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.", # noqa: E501 169 | "Tell me five words that rhyme with 'shock'.", 170 | "Translate the sentence 'I have no mouth but I must scream' into Spanish.", 171 | "Count up from 1 to 500.", 172 | ]: 173 | print("Instruction:", instruction) 174 | print("Response:", evaluate(instruction)) 175 | print() 176 | """ 177 | 178 | 179 | def generate_prompt(instruction, input=None): 180 | if input: 181 | return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 
# noqa: E501 182 | 183 | ### Instruction: 184 | {instruction} 185 | 186 | ### Input: 187 | {input} 188 | 189 | ### Response: 190 | """ 191 | else: 192 | return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. # noqa: E501 193 | 194 | ### Instruction: 195 | {instruction} 196 | 197 | ### Response: 198 | """ 199 | 200 | 201 | if __name__ == "__main__": 202 | fire.Fire(main) 203 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/pyproject-checkpoint.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 79 3 | 4 | [tool.isort] 5 | include_trailing_comma = true 6 | line_length = 79 7 | multi_line_output = 3 8 | profile = "black" 9 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/rec-checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import List 4 | 5 | import fire 6 | import torch 7 | import transformers 8 | from datasets import load_dataset 9 | 10 | """ 11 | Unused imports: 12 | import torch.nn as nn 13 | import bitsandbytes as bnb 14 | """ 15 | 16 | from peft import ( # noqa: E402 17 | LoraConfig, 18 | get_peft_model, 19 | get_peft_model_state_dict, 20 | prepare_model_for_int8_training, 21 | set_peft_model_state_dict, 22 | ) 23 | from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: F402 24 | 25 | 26 | def train( 27 | # model/data params 28 | base_model: str = "", # the only required argument 29 | data_path: str = "yahma/alpaca-cleaned", 30 | output_dir: str = "/common/users/jj635/llama/mycheckpoint/", 31 | # training hyperparams 32 | batch_size: int = 128,#used to be 128 33 | micro_batch_size: int = 4, 34 | num_epochs: int = 3, 35 | learning_rate: float = 3e-4, 36 | cutoff_len: int = 256, 37 | val_set_size: int = 0, 38 | # lora hyperparams 39 | lora_r: int = 8, 40 | lora_alpha: int = 16, 41 | lora_dropout: float = 0.05, 42 | lora_target_modules: List[str] = [ 43 | "q_proj", 44 | "v_proj", 45 | ], 46 | # llm hyperparams 47 | train_on_inputs: bool = True, # if False, masks out inputs in loss 48 | group_by_length: bool = False, # faster, but produces an odd training loss curve 49 | # wandb params 50 | wandb_project: str = "", 51 | wandb_run_name: str = "", 52 | wandb_watch: str = "", # options: false | gradients | all 53 | wandb_log_model: str = "", # options: false | true 54 | resume_from_checkpoint: str = None, # either training checkpoint or final adapter 55 | ): 56 | print( 57 | f"Training Alpaca-LoRA model with params:\n" 58 | f"base_model: {base_model}\n" 59 | f"data_path: {data_path}\n" 60 | f"output_dir: {output_dir}\n" 61 | f"batch_size: {batch_size}\n" 62 | f"micro_batch_size: {micro_batch_size}\n" 63 | f"num_epochs: {num_epochs}\n" 64 | f"learning_rate: {learning_rate}\n" 65 | f"cutoff_len: {cutoff_len}\n" 66 | f"val_set_size: {val_set_size}\n" 67 | f"lora_r: {lora_r}\n" 68 | f"lora_alpha: {lora_alpha}\n" 69 | f"lora_dropout: {lora_dropout}\n" 70 | f"lora_target_modules: {lora_target_modules}\n" 71 | f"train_on_inputs: {train_on_inputs}\n" 72 | f"group_by_length: {group_by_length}\n" 73 | f"wandb_project: {wandb_project}\n" 74 | f"wandb_run_name: {wandb_run_name}\n" 75 | f"wandb_watch: {wandb_watch}\n" 76 | f"wandb_log_model: {wandb_log_model}\n" 77 | f"resume_from_checkpoint: {resume_from_checkpoint}\n" 78 | ) 79 | assert ( 80 | base_model 81 | ), 
"Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'" 82 | gradient_accumulation_steps = batch_size // micro_batch_size 83 | 84 | device_map = "auto" 85 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 86 | ddp = world_size != 1 87 | if ddp: 88 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} 89 | #device_map = {"":0} 90 | gradient_accumulation_steps = gradient_accumulation_steps // world_size 91 | 92 | # Check if parameter passed or if set within environ 93 | use_wandb = len(wandb_project) > 0 or ( 94 | "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0 95 | ) 96 | # Only overwrite environ if wandb param passed 97 | if len(wandb_project) > 0: 98 | os.environ["WANDB_PROJECT"] = wandb_project 99 | if len(wandb_watch) > 0: 100 | os.environ["WANDB_WATCH"] = wandb_watch 101 | if len(wandb_log_model) > 0: 102 | os.environ["WANDB_LOG_MODEL"] = wandb_log_model 103 | 104 | model = LlamaForCausalLM.from_pretrained( 105 | base_model, 106 | load_in_8bit=True, 107 | torch_dtype=torch.float16, 108 | device_map=device_map, 109 | ) 110 | 111 | tokenizer = LlamaTokenizer.from_pretrained(base_model) 112 | 113 | tokenizer.pad_token_id = ( 114 | 0 # unk. we want this to be different from the eos token 115 | ) 116 | tokenizer.padding_side = "left" # Allow batched inference 117 | 118 | def tokenize(prompt, add_eos_token=True): 119 | # there's probably a way to do this with the tokenizer settings 120 | # but again, gotta move fast 121 | result = tokenizer( 122 | prompt, 123 | truncation=True, 124 | max_length=cutoff_len, 125 | padding=False, 126 | return_tensors=None, 127 | ) 128 | if ( 129 | result["input_ids"][-1] != tokenizer.eos_token_id 130 | and len(result["input_ids"]) < cutoff_len 131 | and add_eos_token 132 | ): 133 | result["input_ids"].append(tokenizer.eos_token_id) 134 | result["attention_mask"].append(1) 135 | 136 | result["labels"] = result["input_ids"].copy() 137 | 138 | return result 139 | 140 | def generate_and_tokenize_prompt(data_point): 141 | full_prompt = generate_prompt(data_point) 142 | tokenized_full_prompt = tokenize(full_prompt) 143 | if not train_on_inputs: 144 | user_prompt = generate_prompt({**data_point, "output": ""}) 145 | tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False) 146 | user_prompt_len = len(tokenized_user_prompt["input_ids"]) 147 | 148 | tokenized_full_prompt["labels"] = [ 149 | -100 150 | ] * user_prompt_len + tokenized_full_prompt["labels"][ 151 | user_prompt_len: 152 | ] # could be sped up, probably 153 | return tokenized_full_prompt 154 | 155 | model = prepare_model_for_int8_training(model) 156 | 157 | config = LoraConfig( 158 | r=lora_r, 159 | lora_alpha=lora_alpha, 160 | target_modules=lora_target_modules, 161 | lora_dropout=lora_dropout, 162 | bias="none", 163 | task_type="CAUSAL_LM", 164 | ) 165 | model = get_peft_model(model, config) 166 | 167 | if data_path.endswith(".json") or data_path.endswith(".jsonl"): 168 | data = load_dataset("json", data_files=data_path) 169 | else: 170 | data = load_dataset(data_path) 171 | 172 | if resume_from_checkpoint: 173 | # Check the available weights and load them 174 | checkpoint_name = os.path.join( 175 | resume_from_checkpoint, "pytorch_model.bin" 176 | ) # Full checkpoint 177 | if not os.path.exists(checkpoint_name): 178 | checkpoint_name = os.path.join( 179 | resume_from_checkpoint, "adapter_model.bin" 180 | ) # only LoRA model - LoRA config above has to fit 181 | resume_from_checkpoint = ( 182 | False # So the trainer won't try loading its 
state 183 | ) 184 | # The two files above have a different name depending on how they were saved, but are actually the same. 185 | if os.path.exists(checkpoint_name): 186 | print(f"Restarting from {checkpoint_name}") 187 | adapters_weights = torch.load(checkpoint_name) 188 | model = set_peft_model_state_dict(model, adapters_weights) 189 | else: 190 | print(f"Checkpoint {checkpoint_name} not found") 191 | 192 | model.print_trainable_parameters() # Be more transparent about the % of trainable params. 193 | 194 | if val_set_size > 0: 195 | train_val = data["train"].train_test_split( 196 | test_size=val_set_size, shuffle=True, seed=42 197 | ) 198 | train_data = ( 199 | train_val["train"].shuffle().map(generate_and_tokenize_prompt) 200 | ) 201 | val_data = ( 202 | train_val["test"].shuffle().map(generate_and_tokenize_prompt) 203 | ) 204 | else: 205 | train_data = data["train"].shuffle().map(generate_and_tokenize_prompt) 206 | val_data = None 207 | 208 | if not ddp and torch.cuda.device_count() > 1: 209 | # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available 210 | model.is_parallelizable = True 211 | model.model_parallel = True 212 | 213 | #changed by myself 214 | #model.is_parallelizable = False 215 | #model.model_parallel = False 216 | 217 | trainer = transformers.Trainer( 218 | model=model, 219 | train_dataset=train_data, 220 | eval_dataset=val_data, 221 | args=transformers.TrainingArguments( 222 | per_device_train_batch_size=micro_batch_size, 223 | gradient_accumulation_steps=gradient_accumulation_steps, 224 | warmup_steps=100, 225 | num_train_epochs=num_epochs, 226 | learning_rate=learning_rate, 227 | fp16=True, 228 | logging_steps=10, 229 | optim="adamw_torch", 230 | evaluation_strategy="steps" if val_set_size > 0 else "no", 231 | save_strategy="steps", 232 | eval_steps=200 if val_set_size > 0 else None, 233 | save_steps=200, 234 | output_dir=output_dir, 235 | save_total_limit=3, 236 | load_best_model_at_end=True if val_set_size > 0 else False, 237 | ddp_find_unused_parameters=False if ddp else None, 238 | group_by_length=group_by_length, 239 | report_to="wandb" if use_wandb else None, 240 | run_name=wandb_run_name if use_wandb else None, 241 | ), 242 | data_collator=transformers.DataCollatorForSeq2Seq( 243 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True 244 | ), 245 | ) 246 | model.config.use_cache = False 247 | 248 | old_state_dict = model.state_dict 249 | model.state_dict = ( 250 | lambda self, *_, **__: get_peft_model_state_dict( 251 | self, old_state_dict() 252 | ) 253 | ).__get__(model, type(model)) 254 | 255 | if torch.__version__ >= "2" and sys.platform != "win32": 256 | model = torch.compile(model) 257 | 258 | trainer.train(resume_from_checkpoint=resume_from_checkpoint) 259 | 260 | model.save_pretrained(output_dir) 261 | 262 | print( 263 | "\n If there's a warning about missing keys above, please disregard :)" 264 | ) 265 | 266 | 267 | def generate_prompt(data_point): 268 | # sorry about the formatting disaster gotta move fast 269 | if data_point["input"]: 270 | return f""" # noqa: E501 271 | {data_point["instruction"]} 272 | 273 | ### input: 274 | {data_point["input"]} 275 | 276 | ### Response: 277 | {data_point["output"]}""" 278 | else: 279 | return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. 
# noqa: E501 280 | 281 | ### Instruction: 282 | {data_point["instruction"]} 283 | 284 | ### Response: 285 | {data_point["output"]}""" 286 | 287 | 288 | if __name__ == "__main__": 289 | fire.Fire(train) 290 | -------------------------------------------------------------------------------- /DATA_LICENSE: -------------------------------------------------------------------------------- 1 | Attribution License (ODC-By) 2 | PREAMBLE 3 | The Open Data Commons Attribution License is a license agreement intended to allow users to freely share, modify, and use this Database subject only to the attribution requirements set out in Section 4. 4 | 5 | Databases can contain a wide variety of types of content (images, audiovisual material, and sounds all in the same database, for example), and so this license only governs the rights over the Database, and not the contents of the Database individually. Licensors may therefore wish to use this license together with another license for the contents. 6 | 7 | Sometimes the contents of a database, or the database itself, can be covered by other rights not addressed here (such as private contracts, trademark over the name, or privacy rights / data protection rights over information in the contents), and so you are advised that you may have to consult other documents or clear other rights before doing activities not covered by this License. 8 | 9 | The Licensor (as defined below) 10 | 11 | and 12 | 13 | You (as defined below) 14 | 15 | agree as follows: 16 | 17 | 1.0 DEFINITIONS OF CAPITALISED WORDS 18 | “Collective Database” – Means this Database in unmodified form as part of a collection of independent databases in themselves that together are assembled into a collective whole. A work that constitutes a Collective Database will not be considered a Derivative Database. 19 | 20 | “Convey” – As a verb, means Using the Database, a Derivative Database, or the Database as part of a Collective Database in any way that enables a Person to make or receive copies of the Database or a Derivative Database. Conveying does not include interaction with a user through a computer network, or creating and Using a Produced Work, where no transfer of a copy of the Database or a Derivative Database occurs. 21 | 22 | “Contents” – The contents of this Database, which includes the information, independent works, or other material collected into the Database. For example, the contents of the Database could be factual data or works such as images, audiovisual material, text, or sounds. 23 | 24 | “Database” – A collection of material (the Contents) arranged in a systematic or methodical way and individually accessible by electronic or other means offered under the terms of this License. 25 | 26 | “Database Directive” – Means Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended or succeeded. 27 | 28 | “Database Right” – Means rights resulting from the Chapter III (“sui generis”) rights in the Database Directive (as amended and as transposed by member states), which includes the Extraction and Re-utilisation of the whole or a Substantial part of the Contents, as well as any similar rights available in the relevant jurisdiction under Section 10.4. 29 | 30 | “Derivative Database” – Means a database based upon the Database, and includes any translation, adaptation, arrangement, modification, or any other alteration of the Database or of a Substantial part of the Contents. 
This includes, but is not limited to, Extracting or Re-utilising the whole or a Substantial part of the Contents in a new Database. 31 | 32 | “Extraction” – Means the permanent or temporary transfer of all or a Substantial part of the Contents to another medium by any means or in any form. 33 | 34 | “License” – Means this license agreement and is both a license of rights such as copyright and Database Rights and an agreement in contract. 35 | 36 | “Licensor” – Means the Person that offers the Database under the terms of this License. 37 | 38 | “Person” – Means a natural or legal person or a body of persons corporate or incorporate. 39 | 40 | “Produced Work” – a work (such as an image, audiovisual material, text, or sounds) resulting from using the whole or a Substantial part of the Contents (via a search or other query) from this Database, a Derivative Database, or this Database as part of a Collective Database. 41 | 42 | “Publicly” – means to Persons other than You or under Your control by either more than 50% ownership or by the power to direct their activities (such as contracting with an independent consultant). 43 | 44 | “Re-utilisation” – means any form of making available to the public all or a Substantial part of the Contents by the distribution of copies, by renting, by online or other forms of transmission. 45 | 46 | “Substantial” – Means substantial in terms of quantity or quality or a combination of both. The repeated and systematic Extraction or Re-utilisation of insubstantial parts of the Contents may amount to the Extraction or Re-utilisation of a Substantial part of the Contents. 47 | 48 | “Use” – As a verb, means doing any act that is restricted by copyright or Database Rights whether in the original medium or any other; and includes without limitation distributing, copying, publicly performing, publicly displaying, and preparing derivative works of the Database, as well as modifying the Database as may be technically necessary to use it in a different mode or format. 49 | 50 | “You” – Means a Person exercising rights under this License who has not previously violated the terms of this License with respect to the Database, or who has received express permission from the Licensor to exercise rights under this License despite a previous violation. 51 | 52 | Words in the singular include the plural and vice versa. 53 | 54 | 2.0 WHAT THIS LICENSE COVERS 55 | 2.1. Legal effect of this document. This License is: 56 | 57 | a. A license of applicable copyright and neighbouring rights; 58 | 59 | b. A license of the Database Right; and 60 | 61 | c. An agreement in contract between You and the Licensor. 62 | 63 | 2.2 Legal rights covered. This License covers the legal rights in the Database, including: 64 | 65 | a. Copyright. Any copyright or neighbouring rights in the Database. The copyright licensed includes any individual elements of the Database, but does not cover the copyright over the Contents independent of this Database. See Section 2.4 for details. Copyright law varies between jurisdictions, but is likely to cover: the Database model or schema, which is the structure, arrangement, and organisation of the Database, and can also include the Database tables and table indexes; the data entry and output sheets; and the Field names of Contents stored in the Database; 66 | 67 | b. Database Rights. Database Rights only extend to the Extraction and Re-utilisation of the whole or a Substantial part of the Contents. 
Database Rights can apply even when there is no copyright over the Database. Database Rights can also apply when the Contents are removed from the Database and are selected and arranged in a way that would not infringe any applicable copyright; and 68 | 69 | c. Contract. This is an agreement between You and the Licensor for access to the Database. In return you agree to certain conditions of use on this access as outlined in this License. 70 | 71 | 2.3 Rights not covered. 72 | 73 | a. This License does not apply to computer programs used in the making or operation of the Database; 74 | 75 | b. This License does not cover any patents over the Contents or the Database; and 76 | 77 | c. This License does not cover any trademarks associated with the Database. 78 | 79 | 2.4 Relationship to Contents in the Database. The individual items of the Contents contained in this Database may be covered by other rights, including copyright, patent, data protection, privacy, or personality rights, and this License does not cover any rights (other than Database Rights or in contract) in individual Contents contained in the Database. 80 | 81 | For example, if used on a Database of images (the Contents), this License would not apply to copyright over individual images, which could have their own separate licenses, or one single license covering all of the rights over the images. 82 | 83 | 3.0 RIGHTS GRANTED 84 | 3.1 Subject to the terms and conditions of this License, the Licensor grants to You a worldwide, royalty-free, non-exclusive, terminable (but only under Section 9) license to Use the Database for the duration of any applicable copyright and Database Rights. These rights explicitly include commercial use, and do not exclude any field of endeavour. To the extent possible in the relevant jurisdiction, these rights may be exercised in all media and formats whether now known or created in the future. 85 | 86 | The rights granted cover, for example: 87 | 88 | a. Extraction and Re-utilisation of the whole or a Substantial part of the Contents; 89 | 90 | b. Creation of Derivative Databases; 91 | 92 | c. Creation of Collective Databases; 93 | 94 | d. Creation of temporary or permanent reproductions by any means and in any form, in whole or in part, including of any Derivative Databases or as a part of Collective Databases; and 95 | 96 | e. Distribution, communication, display, lending, making available, or performance to the public by any means and in any form, in whole or in part, including of any Derivative Database or as a part of Collective Databases. 97 | 98 | 3.2 Compulsory license schemes. For the avoidance of doubt: 99 | 100 | a. Non-waivable compulsory license schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme cannot be waived, the Licensor reserves the exclusive right to collect such royalties for any exercise by You of the rights granted under this License; 101 | 102 | b. Waivable compulsory license schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme can be waived, the Licensor waives the exclusive right to collect such royalties for any exercise by You of the rights granted under this License; and, 103 | 104 | c. Voluntary license schemes. 
The Licensor waives the right to collect royalties, whether individually or, in the event that the Licensor is a member of a collecting society that administers voluntary licensing schemes, via that society, from any exercise by You of the rights granted under this License. 105 | 106 | 3.3 The right to release the Database under different terms, or to stop distributing or making available the Database, is reserved. Note that this Database may be multiple-licensed, and so You may have the choice of using alternative licenses for this Database. Subject to Section 10.4, all other rights not expressly granted by Licensor are reserved. 107 | 108 | 4.0 CONDITIONS OF USE 109 | 4.1 The rights granted in Section 3 above are expressly made subject to Your complying with the following conditions of use. These are important conditions of this License, and if You fail to follow them, You will be in material breach of its terms. 110 | 111 | 4.2 Notices. If You Publicly Convey this Database, any Derivative Database, or the Database as part of a Collective Database, then You must: 112 | 113 | a. Do so only under the terms of this License; 114 | 115 | b. Include a copy of this License or its Uniform Resource Identifier (URI) with the Database or Derivative Database, including both in the Database or Derivative Database and in any relevant documentation; 116 | 117 | c. Keep intact any copyright or Database Right notices and notices that refer to this License; and 118 | 119 | d. If it is not possible to put the required notices in a particular file due to its structure, then You must include the notices in a location (such as a relevant directory) where users would be likely to look for it. 120 | 121 | 4.3 Notice for using output (Contents). Creating and Using a Produced Work does not require the notice in Section 4.2. However, if you Publicly Use a Produced Work, You must include a notice associated with the Produced Work reasonably calculated to make any Person that uses, views, accesses, interacts with, or is otherwise exposed to the Produced Work aware that Content was obtained from the Database, Derivative Database, or the Database as part of a Collective Database, and that it is available under this License. 122 | 123 | a. Example notice. The following text will satisfy notice under Section 4.3: 124 | 125 | Contains information from DATABASE NAME which is made available 126 | under the ODC Attribution License. 127 | DATABASE NAME should be replaced with the name of the Database and a hyperlink to the location of the Database. “ODC Attribution License” should contain a hyperlink to the URI of the text of this License. If hyperlinks are not possible, You should include the plain text of the required URI’s with the above notice. 128 | 129 | 4.4 Licensing of others. You may not sublicense the Database. Each time You communicate the Database, the whole or Substantial part of the Contents, or any Derivative Database to anyone else in any way, the Licensor offers to the recipient a license to the Database on the same terms and conditions as this License. You are not responsible for enforcing compliance by third parties with this License, but You may enforce any rights that You have over a Derivative Database. You are solely responsible for any modifications of a Derivative Database made by You or another Person at Your direction. You may not impose any further restrictions on the exercise of the rights granted or affirmed under this License. 130 | 131 | 5.0 MORAL RIGHTS 132 | 5.1 Moral rights. 
This section covers moral rights, including any rights to be identified as the author of the Database or to object to treatment that would otherwise prejudice the author’s honour and reputation, or any other derogatory treatment: 133 | 134 | a. For jurisdictions allowing waiver of moral rights, Licensor waives all moral rights that Licensor may have in the Database to the fullest extent possible by the law of the relevant jurisdiction under Section 10.4; 135 | 136 | b. If waiver of moral rights under Section 5.1 a in the relevant jurisdiction is not possible, Licensor agrees not to assert any moral rights over the Database and waives all claims in moral rights to the fullest extent possible by the law of the relevant jurisdiction under Section 10.4; and 137 | 138 | c. For jurisdictions not allowing waiver or an agreement not to assert moral rights under Section 5.1 a and b, the author may retain their moral rights over certain aspects of the Database. 139 | 140 | Please note that some jurisdictions do not allow for the waiver of moral rights, and so moral rights may still subsist over the Database in some jurisdictions. 141 | 142 | 6.0 FAIR DEALING, DATABASE EXCEPTIONS, AND OTHER RIGHTS NOT AFFECTED 143 | 6.1 This License does not affect any rights that You or anyone else may independently have under any applicable law to make any use of this Database, including without limitation: 144 | 145 | a. Exceptions to the Database Right including: Extraction of Contents from non-electronic Databases for private purposes, Extraction for purposes of illustration for teaching or scientific research, and Extraction or Re-utilisation for public security or an administrative or judicial procedure. 146 | 147 | b. Fair dealing, fair use, or any other legally recognised limitation or exception to infringement of copyright or other applicable laws. 148 | 149 | 6.2 This License does not affect any rights of lawful users to Extract and Re-utilise insubstantial parts of the Contents, evaluated quantitatively or qualitatively, for any purposes whatsoever, including creating a Derivative Database (subject to other rights over the Contents, see Section 2.4). The repeated and systematic Extraction or Re-utilisation of insubstantial parts of the Contents may however amount to the Extraction or Re-utilisation of a Substantial part of the Contents. 150 | 151 | 7.0 WARRANTIES AND DISCLAIMER 152 | 7.1 The Database is licensed by the Licensor “as is” and without any warranty of any kind, either express, implied, or arising by statute, custom, course of dealing, or trade usage. Licensor specifically disclaims any and all implied warranties or conditions of title, non-infringement, accuracy or completeness, the presence or absence of errors, fitness for a particular purpose, merchantability, or otherwise. Some jurisdictions do not allow the exclusion of implied warranties, so this exclusion may not apply to You. 153 | 154 | 8.0 LIMITATION OF LIABILITY 155 | 8.1 Subject to any liability that may not be excluded or limited by law, the Licensor is not liable for, and expressly excludes, all liability for loss or damage however and whenever caused to anyone by any use under this License, whether by You or by anyone else, and whether caused by any fault on the part of the Licensor or not. This exclusion of liability includes, but is not limited to, any special, incidental, consequential, punitive, or exemplary damages such as loss of revenue, data, anticipated profits, and lost business. 
This exclusion applies even if the Licensor has been advised of the possibility of such damages. 156 | 157 | 8.2 If liability may not be excluded by law, it is limited to actual and direct financial loss to the extent it is caused by proved negligence on the part of the Licensor. 158 | 159 | 9.0 TERMINATION OF YOUR RIGHTS UNDER THIS LICENSE 160 | 9.1 Any breach by You of the terms and conditions of this License automatically terminates this License with immediate effect and without notice to You. For the avoidance of doubt, Persons who have received the Database, the whole or a Substantial part of the Contents, Derivative Databases, or the Database as part of a Collective Database from You under this License will not have their licenses terminated provided their use is in full compliance with this License or a license granted under Section 4.8 of this License. Sections 1, 2, 7, 8, 9 and 10 will survive any termination of this License. 161 | 162 | 9.2 If You are not in breach of the terms of this License, the Licensor will not terminate Your rights under it. 163 | 164 | 9.3 Unless terminated under Section 9.1, this License is granted to You for the duration of applicable rights in the Database. 165 | 166 | 9.4 Reinstatement of rights. If you cease any breach of the terms and conditions of this License, then your full rights under this License will be reinstated: 167 | 168 | a. Provisionally and subject to permanent termination until the 60th day after cessation of breach; 169 | 170 | b. Permanently on the 60th day after cessation of breach unless otherwise reasonably notified by the Licensor; or 171 | 172 | c. Permanently if reasonably notified by the Licensor of the violation, this is the first time You have received notice of violation of this License from the Licensor, and You cure the violation prior to 30 days after your receipt of the notice. 173 | 174 | 9.5 Notwithstanding the above, Licensor reserves the right to release the Database under different license terms or to stop distributing or making available the Database. Releasing the Database under different license terms or stopping the distribution of the Database will not withdraw this License (or any other license that has been, or is required to be, granted under the terms of this License), and this License will continue in full force and effect unless terminated as stated above. 175 | 176 | 10.0 GENERAL 177 | 10.1 If any provision of this License is held to be invalid or unenforceable, that must not affect the validity or enforceability of the remainder of the terms and conditions of this License and each remaining provision of this License shall be valid and enforced to the fullest extent permitted by law. 178 | 179 | 10.2 This License is the entire agreement between the parties with respect to the rights granted here over the Database. It replaces any earlier understandings, agreements or representations with respect to the Database. 180 | 181 | 10.3 If You are in breach of the terms of this License, You will not be entitled to rely on the terms of this License or to complain of any breach by the Licensor. 182 | 183 | 10.4 Choice of law. This License takes effect in and will be governed by the laws of the relevant jurisdiction in which the License terms are sought to be enforced. 
If the standard suite of rights granted under applicable copyright law and Database Rights in the relevant jurisdiction includes additional rights not granted under this License, these additional rights are granted in this License in order to meet the terms of this License. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GenRec: Large Language Model for Generative Recommendation 2 | 3 | ## Install dependencies 4 | 5 | 6 | pip install -r requirements.txt 7 | 8 | 9 | ### Training (`rec.py`) 10 | 11 | 12 | python rec.py \ 13 | --base_model 'decapoda-research/llama-7b-hf' \ 14 | --data_path './data/movies.json' \ 15 | --output_dir './checkpoint' 16 | 17 | 18 | 19 | ### Inference (`generate.py`) 20 | 21 | 22 | 23 | python generate.py \ 24 | --load_8bit \ 25 | --base_model 'decapoda-research/llama-7b-hf' \ 26 | --lora_weights './checkpoint/movies' 27 | 28 | 29 | This project is built on top of alpaca-lora (https://github.com/tloen/alpaca-lora). 30 | 31 | ## Citation 32 | 33 | ``` 34 | @inproceedings{ji2024genrec, 35 | title={GenRec: Large Language Model for Generative Recommendation}, 36 | author={Jianchao Ji and Zelong Li and Shuyuan Xu and Wenyue Hua and Yingqiang Ge and Juntao Tan and Yongfeng Zhang}, 37 | booktitle={European Conference on Information Retrieval (ECIR)}, 38 | year={2024} 39 | } 40 | ``` 41 | -------------------------------------------------------------------------------- /checkpoint/movies/adapter_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name_or_path": "/common/users/jj635/llama/llama-7b/", 3 | "bias": "none", 4 | "enable_lora": null, 5 | "fan_in_fan_out": false, 6 | "inference_mode": true, 7 | "lora_alpha": 16, 8 | "lora_dropout": 0.05, 9 | "merge_weights": false, 10 | "modules_to_save": null, 11 | "peft_type": "LORA", 12 | "r": 8, 13 | "target_modules": [ 14 | "q_proj", 15 | "v_proj" 16 | ], 17 | "task_type": "CAUSAL_LM" 18 | } -------------------------------------------------------------------------------- /checkpoint/movies/adapter_model.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rutgerswiselab/GenRec/e6fb9bb504078fdadf5825bcb2262387aab18e3b/checkpoint/movies/adapter_model.bin -------------------------------------------------------------------------------- /checkpoint/toys/adapter_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name_or_path": "/common/users/jj635/llama/llama-7b/", 3 | "bias": "none", 4 | "enable_lora": null, 5 | "fan_in_fan_out": false, 6 | "inference_mode": true, 7 | "lora_alpha": 16, 8 | "lora_dropout": 0.05, 9 | "merge_weights": false, 10 | "modules_to_save": null, 11 | "peft_type": "LORA", 12 | "r": 8, 13 | "target_modules": [ 14 | "q_proj", 15 | "v_proj" 16 | ], 17 | "task_type": "CAUSAL_LM" 18 | } -------------------------------------------------------------------------------- /checkpoint/toys/adapter_model.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rutgerswiselab/GenRec/e6fb9bb504078fdadf5825bcb2262387aab18e3b/checkpoint/toys/adapter_model.bin -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import fire 4 | import gradio as gr 5 | import torch 6 | import transformers 7 | from peft import PeftModel 8 | from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer 9 | import json 10 | import random 11 | from torch.utils.data import DataLoader, Sampler 12 | from tqdm import tqdm 13 | import math 14 | 15 | from
datasets import load_dataset 16 | 17 | #import os 18 | #os.environ["CUDA_VISIBLE_DEVICES"] = "7" 19 | 20 | if torch.cuda.is_available(): 21 | device = "cuda" 22 | else: 23 | device = "cpu" 24 | 25 | try: 26 | if torch.backends.mps.is_available(): 27 | device = "mps" 28 | except: # noqa: E722 29 | pass 30 | 31 | 32 | load_8bit: bool = False # load the base model in fp16; set True for 8-bit inference 33 | base_model = '/common/users/jj635/llama/llama-7b/' 34 | lora_weights = './checkpoint/movies' 35 | 36 | 37 | """ 38 | model = LlamaForCausalLM.from_pretrained( 39 | base_model, 40 | load_in_8bit=load_8bit, 41 | torch_dtype=torch.float16, 42 | device_map='auto', 43 | ) 44 | 45 | model = PeftModel.from_pretrained( 46 | model, 47 | lora_weights, 48 | torch_dtype=torch.float16, 49 | device_map='auto', 50 | ) 51 | """ 52 | 53 | 54 | model = LlamaForCausalLM.from_pretrained( 55 | base_model, 56 | load_in_8bit=load_8bit, 57 | torch_dtype=torch.float16, 58 | device_map={'':0},#could be 1 59 | ) 60 | model = PeftModel.from_pretrained( 61 | model, 62 | lora_weights, 63 | torch_dtype=torch.float16, 64 | device_map={'':0},#could be 1 65 | ) 66 | 67 | tokenizer = LlamaTokenizer.from_pretrained(base_model) 68 | 69 | 70 | def generate_prompt(data_point): 71 | # sorry about the formatting disaster gotta move fast 72 | if data_point["input"]: 73 | return f""" # noqa: E501 74 | {data_point["instruction"]} 75 | 76 | ### input: 77 | {data_point["input"]} 78 | 79 | ### Response: 80 | {data_point["output"]}""" 81 | else: 82 | return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. # noqa: E501 83 | 84 | ### Instruction: 85 | {data_point["instruction"]} 86 | 87 | ### Response: 88 | {data_point["output"]}""" 89 | 90 | 91 | def generate_and_tokenize_prompt(data_point): 92 | full_prompt = generate_prompt(data_point) 93 | tokenized_full_prompt = tokenize(full_prompt) 94 | user_prompt = generate_prompt({**data_point, "output": ""}) 95 | tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False) 96 | user_prompt_len = len(tokenized_user_prompt["input_ids"]) 97 | 98 | tokenized_full_prompt["labels"] = [ 99 | -100 100 | ] * user_prompt_len + tokenized_full_prompt["labels"][ 101 | user_prompt_len: 102 | ] # could be sped up, probably 103 | return tokenized_full_prompt 104 | 105 | 106 | def tokenize(prompt, add_eos_token=True): 107 | # there's probably a way to do this with the tokenizer settings 108 | # but again, gotta move fast 109 | cutoff_len = 256 110 | result = tokenizer( 111 | prompt, 112 | truncation=True, 113 | max_length=cutoff_len, 114 | padding=False, 115 | return_tensors=None, 116 | ) 117 | if ( 118 | result["input_ids"][-1] != tokenizer.eos_token_id 119 | and len(result["input_ids"]) < cutoff_len 120 | and add_eos_token 121 | ): 122 | result["input_ids"].append(tokenizer.eos_token_id) 123 | result["attention_mask"].append(1) 124 | 125 | result["labels"] = result["input_ids"].copy() 126 | 127 | return result 128 | 129 | #with open("movie.json",'r', encoding='UTF-8') as f: 130 | # data = json.load(f) 131 | 132 | generation_config = GenerationConfig( 133 | temperature=0.1, 134 | top_p=0.75, 135 | top_k=10, 136 | num_beams=10, 137 | num_return_sequences=10, 138 | ) 139 | data = load_dataset('./data/testset/',data_files="toys_test.json") 140 | print(data) 141 | 142 | hit5 = 0 143 | hit10 = 0 144 | ndcg5 = 0 145 | ndcg10 = 0 146 | total = 0 147 | res = [] 148 | 149 | import pdb 150 | for i, cur in tqdm(enumerate(data['train'])): 151 | label = cur['output'] 152 | inputs = generate_prompt({**cur, "output": 
""}) 153 | inputs = tokenizer(inputs, return_tensors="pt") 154 | input_ids = inputs['input_ids'].to('cuda:0') 155 | #pdb.set_trace() 156 | res = [] 157 | with torch.no_grad(): 158 | generation_output = model.generate( 159 | input_ids=input_ids, 160 | generation_config=generation_config, 161 | return_dict_in_generate=True, 162 | output_scores=False,#used to be True 163 | max_new_tokens=64,#used to be 128 164 | ) 165 | 166 | with torch.no_grad(): 167 | for i in range(10): 168 | temp = generation_output.sequences[i] 169 | cur = tokenizer.decode(temp,skip_special_tokens=True).split("### Response:")[1].strip() 170 | cur = cur.split("⁇")[0].strip() 171 | res.append(cur) 172 | #print(label) 173 | #print(res) 174 | 175 | if label in res[:5]: 176 | hit5 += 1 177 | pos = res[:5].index(label) 178 | ndcg5 += 1.0 / (math.log(pos + 2) / math.log(2)) / 1.0 179 | #print(res) 180 | #print(label) 181 | 182 | if label in res: 183 | hit10 += 1 184 | pos = res.index(label) 185 | ndcg10 += 1.0 / (math.log(pos + 2) / math.log(2)) / 1.0 186 | #print(res) 187 | #print(label) 188 | 189 | total += 1 190 | 191 | if total % 100 == 0: 192 | print('The Hit@5 is:',hit5/total) 193 | print('The Hit@10 is:',hit10/total) 194 | print('The NDCG@5 is:',ndcg5/total) 195 | print('The NDCG@10 is:',ndcg10/total) 196 | 197 | 198 | print('The Hit@5 is:',hit5/total) 199 | print('The Hit@10 is:',hit10/total) 200 | print('The NDCG@5 is:',ndcg5/total) 201 | print('The NDCG@10 is:',ndcg10/total) -------------------------------------------------------------------------------- /export_hf_checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import transformers 5 | from peft import PeftModel 6 | from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: F402 7 | 8 | BASE_MODEL = os.environ.get("BASE_MODEL", None) 9 | assert ( 10 | BASE_MODEL 11 | ), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=decapoda-research/llama-7b-hf`" # noqa: E501 12 | 13 | tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL) 14 | 15 | base_model = LlamaForCausalLM.from_pretrained( 16 | BASE_MODEL, 17 | load_in_8bit=False, 18 | torch_dtype=torch.float16, 19 | device_map={"": "cpu"}, 20 | ) 21 | 22 | first_weight = base_model.model.layers[0].self_attn.q_proj.weight 23 | first_weight_old = first_weight.clone() 24 | 25 | lora_model = PeftModel.from_pretrained( 26 | base_model, 27 | "tloen/alpaca-lora-7b", 28 | device_map={"": "cpu"}, 29 | torch_dtype=torch.float16, 30 | ) 31 | 32 | lora_weight = lora_model.base_model.model.model.layers[ 33 | 0 34 | ].self_attn.q_proj.weight 35 | 36 | assert torch.allclose(first_weight_old, first_weight) 37 | 38 | # merge weights 39 | for layer in lora_model.base_model.model.model.layers: 40 | layer.self_attn.q_proj.merge_weights = True 41 | layer.self_attn.v_proj.merge_weights = True 42 | 43 | lora_model.train(False) 44 | 45 | # did we do anything? 
46 | assert not torch.allclose(first_weight_old, first_weight) 47 | 48 | lora_model_sd = lora_model.state_dict() 49 | deloreanized_sd = { 50 | k.replace("base_model.model.", ""): v 51 | for k, v in lora_model_sd.items() 52 | if "lora" not in k 53 | } 54 | 55 | LlamaForCausalLM.save_pretrained( 56 | base_model, "./hf_ckpt", state_dict=deloreanized_sd, max_shard_size="400MB" 57 | ) 58 | -------------------------------------------------------------------------------- /export_state_dict_checkpoint.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import torch 5 | import transformers 6 | from peft import PeftModel 7 | from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: E402 8 | 9 | BASE_MODEL = os.environ.get("BASE_MODEL", None) 10 | assert ( 11 | BASE_MODEL 12 | ), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=decapoda-research/llama-7b-hf`" # noqa: E501 13 | 14 | tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL) 15 | 16 | base_model = LlamaForCausalLM.from_pretrained( 17 | BASE_MODEL, 18 | load_in_8bit=False, 19 | torch_dtype=torch.float16, 20 | device_map={"": "cpu"}, 21 | ) 22 | 23 | lora_model = PeftModel.from_pretrained( 24 | base_model, 25 | "tloen/alpaca-lora-7b", 26 | device_map={"": "cpu"}, 27 | torch_dtype=torch.float16, 28 | ) 29 | 30 | # merge weights 31 | for layer in lora_model.base_model.model.model.layers: 32 | layer.self_attn.q_proj.merge_weights = True 33 | layer.self_attn.v_proj.merge_weights = True 34 | 35 | lora_model.train(False) 36 | 37 | lora_model_sd = lora_model.state_dict() 38 | 39 | params = { 40 | "dim": 4096, 41 | "multiple_of": 256, 42 | "n_heads": 32, 43 | "n_layers": 32, 44 | "norm_eps": 1e-06, 45 | "vocab_size": -1, 46 | } 47 | n_layers = params["n_layers"] 48 | n_heads = params["n_heads"] 49 | dim = params["dim"] 50 | dims_per_head = dim // n_heads 51 | base = 10000.0 52 | inv_freq = 1.0 / ( 53 | base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head) 54 | ) 55 | 56 | 57 | def permute(w): 58 | return ( 59 | w.view(n_heads, dim // n_heads // 2, 2, dim) 60 | .transpose(1, 2) 61 | .reshape(dim, dim) 62 | ) 63 | 64 | 65 | def unpermute(w): 66 | return ( 67 | w.view(n_heads, 2, dim // n_heads // 2, dim) 68 | .transpose(1, 2) 69 | .reshape(dim, dim) 70 | ) 71 | 72 | 73 | def translate_state_dict_key(k): # noqa: C901 74 | k = k.replace("base_model.model.", "") 75 | if k == "model.embed_tokens.weight": 76 | return "tok_embeddings.weight" 77 | elif k == "model.norm.weight": 78 | return "norm.weight" 79 | elif k == "lm_head.weight": 80 | return "output.weight" 81 | elif k.startswith("model.layers."): 82 | layer = k.split(".")[2] 83 | if k.endswith(".self_attn.q_proj.weight"): 84 | return f"layers.{layer}.attention.wq.weight" 85 | elif k.endswith(".self_attn.k_proj.weight"): 86 | return f"layers.{layer}.attention.wk.weight" 87 | elif k.endswith(".self_attn.v_proj.weight"): 88 | return f"layers.{layer}.attention.wv.weight" 89 | elif k.endswith(".self_attn.o_proj.weight"): 90 | return f"layers.{layer}.attention.wo.weight" 91 | elif k.endswith(".mlp.gate_proj.weight"): 92 | return f"layers.{layer}.feed_forward.w1.weight" 93 | elif k.endswith(".mlp.down_proj.weight"): 94 | return f"layers.{layer}.feed_forward.w2.weight" 95 | elif k.endswith(".mlp.up_proj.weight"): 96 | return f"layers.{layer}.feed_forward.w3.weight" 97 | elif k.endswith(".input_layernorm.weight"): 98 | return f"layers.{layer}.attention_norm.weight" 99 | elif 
k.endswith(".post_attention_layernorm.weight"): 100 | return f"layers.{layer}.ffn_norm.weight" 101 | elif k.endswith("rotary_emb.inv_freq") or "lora" in k: 102 | return None 103 | else: 104 | print(layer, k) 105 | raise NotImplementedError 106 | else: 107 | print(k) 108 | raise NotImplementedError 109 | 110 | 111 | new_state_dict = {} 112 | for k, v in lora_model_sd.items(): 113 | new_k = translate_state_dict_key(k) 114 | if new_k is not None: 115 | if "wq" in new_k or "wk" in new_k: 116 | new_state_dict[new_k] = unpermute(v) 117 | else: 118 | new_state_dict[new_k] = v 119 | 120 | os.makedirs("./ckpt", exist_ok=True) 121 | 122 | torch.save(new_state_dict, "./ckpt/consolidated.00.pth") 123 | 124 | with open("./ckpt/params.json", "w") as f: 125 | json.dump(params, f) 126 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import fire 4 | import gradio as gr 5 | import torch 6 | import transformers 7 | from peft import PeftModel 8 | from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer 9 | 10 | if torch.cuda.is_available(): 11 | device = "cuda" 12 | else: 13 | device = "cpu" 14 | 15 | try: 16 | if torch.backends.mps.is_available(): 17 | device = "mps" 18 | except: # noqa: E722 19 | pass 20 | 21 | 22 | def main( 23 | load_8bit: bool = False, 24 | base_model: str = "", 25 | lora_weights: str = "tloen/alpaca-lora-7b", 26 | share_gradio: bool = True, 27 | ): 28 | assert ( 29 | base_model 30 | ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'" 31 | 32 | tokenizer = LlamaTokenizer.from_pretrained(base_model) 33 | if device == "cuda": 34 | model = LlamaForCausalLM.from_pretrained( 35 | base_model, 36 | load_in_8bit=load_8bit, 37 | torch_dtype=torch.float16, 38 | device_map="auto", 39 | ) 40 | model = PeftModel.from_pretrained( 41 | model, 42 | lora_weights, 43 | torch_dtype=torch.float16, 44 | ) 45 | elif device == "mps": 46 | model = LlamaForCausalLM.from_pretrained( 47 | base_model, 48 | device_map={"": device}, 49 | torch_dtype=torch.float16, 50 | ) 51 | model = PeftModel.from_pretrained( 52 | model, 53 | lora_weights, 54 | device_map={"": device}, 55 | torch_dtype=torch.float16, 56 | ) 57 | else: 58 | model = LlamaForCausalLM.from_pretrained( 59 | base_model, device_map={"": device}, low_cpu_mem_usage=True 60 | ) 61 | model = PeftModel.from_pretrained( 62 | model, 63 | lora_weights, 64 | device_map={"": device}, 65 | ) 66 | 67 | # unwind broken decapoda-research config 68 | model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk 69 | model.config.bos_token_id = 1 70 | model.config.eos_token_id = 2 71 | 72 | if not load_8bit: 73 | model.half() # seems to fix bugs for some users. 
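# The cast above converts any remaining fp32 parameters to fp16; it is skipped in
# 8-bit mode because the quantized weights are managed by bitsandbytes.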
74 | 75 | model.eval() 76 | if torch.__version__ >= "2" and sys.platform != "win32": 77 | model = torch.compile(model) 78 | 79 | def evaluate( 80 | instruction, 81 | input=None, 82 | temperature=0.1, 83 | top_p=0.75, 84 | top_k=40, 85 | num_beams=4, 86 | max_new_tokens=128, 87 | **kwargs, 88 | ): 89 | prompt = generate_prompt(instruction, input) 90 | inputs = tokenizer(prompt, return_tensors="pt") 91 | input_ids = inputs["input_ids"].to(device) 92 | generation_config = GenerationConfig( 93 | temperature=temperature, 94 | top_p=top_p, 95 | top_k=top_k, 96 | num_beams=num_beams, 97 | num_return_sequences=num_beams, 98 | **kwargs, 99 | ) 100 | with torch.no_grad(): 101 | generation_output = model.generate( 102 | input_ids=input_ids, 103 | generation_config=generation_config, 104 | return_dict_in_generate=True, 105 | output_scores=True, 106 | max_new_tokens=max_new_tokens, 107 | ) 108 | 109 | s = [] 110 | 111 | for i in range(num_beams): 112 | temp = generation_output.sequences[i] 113 | s.append(tokenizer.decode(temp,skip_special_tokens=True)) 114 | 115 | output = '' 116 | 117 | for cur in s: 118 | output += cur.split("### Response:")[1].strip() + '\n' 119 | 120 | 121 | 122 | return output 123 | 124 | gr.Interface( 125 | fn=evaluate, 126 | inputs=[ 127 | gr.components.Textbox( 128 | lines=2, 129 | label="Instruction", 130 | placeholder="Tell me about alpacas.", 131 | ), 132 | gr.components.Textbox(lines=2, label="Input", placeholder="none"), 133 | gr.components.Slider( 134 | minimum=0, maximum=1, value=0.1, label="Temperature" 135 | ), 136 | gr.components.Slider( 137 | minimum=0, maximum=1, value=0.75, label="Top p" 138 | ), 139 | gr.components.Slider( 140 | minimum=0, maximum=100, step=1, value=40, label="Top k" 141 | ), 142 | gr.components.Slider( 143 | minimum=1, maximum=4, step=1, value=4, label="Beams" 144 | ), 145 | gr.components.Slider( 146 | minimum=1, maximum=2000, step=1, value=128, label="Max tokens" 147 | ), 148 | ], 149 | outputs=[ 150 | gr.inputs.Textbox( 151 | lines=10, 152 | label="Output", 153 | ) 154 | ], 155 | title="Wiselab-LoRA", 156 | description="Wiselab-LoRA is a 7B-parameter LLaMA-LoRA based model.", # noqa: E501 157 | ).launch(share=share_gradio) 158 | # Old testing code follows. 159 | 160 | """ 161 | # testing code for readme 162 | for instruction in [ 163 | "Tell me about alpacas.", 164 | "Tell me about the president of Mexico in 2019.", 165 | "Tell me about the king of France in 2019.", 166 | "List all Canadian provinces in alphabetical order.", 167 | "Write a Python program that prints the first 10 Fibonacci numbers.", 168 | "Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.", # noqa: E501 169 | "Tell me five words that rhyme with 'shock'.", 170 | "Translate the sentence 'I have no mouth but I must scream' into Spanish.", 171 | "Count up from 1 to 500.", 172 | ]: 173 | print("Instruction:", instruction) 174 | print("Response:", evaluate(instruction)) 175 | print() 176 | """ 177 | 178 | 179 | def generate_prompt(instruction, input=None): 180 | if input: 181 | return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 
# noqa: E501 182 | 183 | ### Instruction: 184 | {instruction} 185 | 186 | ### Input: 187 | {input} 188 | 189 | ### Response: 190 | """ 191 | else: 192 | return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. # noqa: E501 193 | 194 | ### Instruction: 195 | {instruction} 196 | 197 | ### Response: 198 | """ 199 | 200 | 201 | if __name__ == "__main__": 202 | fire.Fire(main) 203 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 79 3 | 4 | [tool.isort] 5 | include_trailing_comma = true 6 | line_length = 79 7 | multi_line_output = 3 8 | profile = "black" 9 | -------------------------------------------------------------------------------- /rec.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import List 4 | 5 | import fire 6 | import torch 7 | import transformers 8 | from datasets import load_dataset 9 | 10 | """ 11 | Unused imports: 12 | import torch.nn as nn 13 | import bitsandbytes as bnb 14 | """ 15 | 16 | from peft import ( # noqa: E402 17 | LoraConfig, 18 | get_peft_model, 19 | get_peft_model_state_dict, 20 | prepare_model_for_int8_training, 21 | set_peft_model_state_dict, 22 | ) 23 | from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: F402 24 | 25 | 26 | def train( 27 | # model/data params 28 | base_model: str = "", # the only required argument 29 | data_path: str = "yahma/alpaca-cleaned", 30 | output_dir: str = "/common/users/jj635/llama/mycheckpoint/", 31 | # training hyperparams 32 | batch_size: int = 128,#used to be 128 33 | micro_batch_size: int = 4, 34 | num_epochs: int = 3, 35 | learning_rate: float = 3e-4, 36 | cutoff_len: int = 256, 37 | val_set_size: int = 0, 38 | # lora hyperparams 39 | lora_r: int = 8, 40 | lora_alpha: int = 16, 41 | lora_dropout: float = 0.05, 42 | lora_target_modules: List[str] = [ 43 | "q_proj", 44 | "v_proj", 45 | ], 46 | # llm hyperparams 47 | train_on_inputs: bool = True, # if False, masks out inputs in loss 48 | group_by_length: bool = False, # faster, but produces an odd training loss curve 49 | # wandb params 50 | wandb_project: str = "", 51 | wandb_run_name: str = "", 52 | wandb_watch: str = "", # options: false | gradients | all 53 | wandb_log_model: str = "", # options: false | true 54 | resume_from_checkpoint: str = None, # either training checkpoint or final adapter 55 | ): 56 | print( 57 | f"Training Alpaca-LoRA model with params:\n" 58 | f"base_model: {base_model}\n" 59 | f"data_path: {data_path}\n" 60 | f"output_dir: {output_dir}\n" 61 | f"batch_size: {batch_size}\n" 62 | f"micro_batch_size: {micro_batch_size}\n" 63 | f"num_epochs: {num_epochs}\n" 64 | f"learning_rate: {learning_rate}\n" 65 | f"cutoff_len: {cutoff_len}\n" 66 | f"val_set_size: {val_set_size}\n" 67 | f"lora_r: {lora_r}\n" 68 | f"lora_alpha: {lora_alpha}\n" 69 | f"lora_dropout: {lora_dropout}\n" 70 | f"lora_target_modules: {lora_target_modules}\n" 71 | f"train_on_inputs: {train_on_inputs}\n" 72 | f"group_by_length: {group_by_length}\n" 73 | f"wandb_project: {wandb_project}\n" 74 | f"wandb_run_name: {wandb_run_name}\n" 75 | f"wandb_watch: {wandb_watch}\n" 76 | f"wandb_log_model: {wandb_log_model}\n" 77 | f"resume_from_checkpoint: {resume_from_checkpoint}\n" 78 | ) 79 | assert ( 80 | base_model 81 | ), "Please specify a --base_model, e.g. 
--base_model='decapoda-research/llama-7b-hf'" 82 | gradient_accumulation_steps = batch_size // micro_batch_size 83 | 84 | device_map = "auto" 85 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 86 | ddp = world_size != 1 87 | if ddp: 88 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} 89 | #device_map = {"":0} 90 | gradient_accumulation_steps = gradient_accumulation_steps // world_size 91 | 92 | # Check if parameter passed or if set within environ 93 | use_wandb = len(wandb_project) > 0 or ( 94 | "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0 95 | ) 96 | # Only overwrite environ if wandb param passed 97 | if len(wandb_project) > 0: 98 | os.environ["WANDB_PROJECT"] = wandb_project 99 | if len(wandb_watch) > 0: 100 | os.environ["WANDB_WATCH"] = wandb_watch 101 | if len(wandb_log_model) > 0: 102 | os.environ["WANDB_LOG_MODEL"] = wandb_log_model 103 | 104 | model = LlamaForCausalLM.from_pretrained( 105 | base_model, 106 | load_in_8bit=True, 107 | torch_dtype=torch.float16, 108 | device_map=device_map, 109 | ) 110 | 111 | tokenizer = LlamaTokenizer.from_pretrained(base_model) 112 | 113 | tokenizer.pad_token_id = ( 114 | 0 # unk. we want this to be different from the eos token 115 | ) 116 | tokenizer.padding_side = "left" # Allow batched inference 117 | 118 | def tokenize(prompt, add_eos_token=True): 119 | # there's probably a way to do this with the tokenizer settings 120 | # but again, gotta move fast 121 | result = tokenizer( 122 | prompt, 123 | truncation=True, 124 | max_length=cutoff_len, 125 | padding=False, 126 | return_tensors=None, 127 | ) 128 | if ( 129 | result["input_ids"][-1] != tokenizer.eos_token_id 130 | and len(result["input_ids"]) < cutoff_len 131 | and add_eos_token 132 | ): 133 | result["input_ids"].append(tokenizer.eos_token_id) 134 | result["attention_mask"].append(1) 135 | 136 | result["labels"] = result["input_ids"].copy() 137 | 138 | return result 139 | 140 | def generate_and_tokenize_prompt(data_point): 141 | full_prompt = generate_prompt(data_point) 142 | tokenized_full_prompt = tokenize(full_prompt) 143 | if not train_on_inputs: 144 | user_prompt = generate_prompt({**data_point, "output": ""}) 145 | tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False) 146 | user_prompt_len = len(tokenized_user_prompt["input_ids"]) 147 | 148 | tokenized_full_prompt["labels"] = [ 149 | -100 150 | ] * user_prompt_len + tokenized_full_prompt["labels"][ 151 | user_prompt_len: 152 | ] # could be sped up, probably 153 | return tokenized_full_prompt 154 | 155 | model = prepare_model_for_int8_training(model) 156 | 157 | config = LoraConfig( 158 | r=lora_r, 159 | lora_alpha=lora_alpha, 160 | target_modules=lora_target_modules, 161 | lora_dropout=lora_dropout, 162 | bias="none", 163 | task_type="CAUSAL_LM", 164 | ) 165 | model = get_peft_model(model, config) 166 | 167 | if data_path.endswith(".json") or data_path.endswith(".jsonl"): 168 | data = load_dataset("json", data_files=data_path) 169 | else: 170 | data = load_dataset(data_path) 171 | 172 | if resume_from_checkpoint: 173 | # Check the available weights and load them 174 | checkpoint_name = os.path.join( 175 | resume_from_checkpoint, "pytorch_model.bin" 176 | ) # Full checkpoint 177 | if not os.path.exists(checkpoint_name): 178 | checkpoint_name = os.path.join( 179 | resume_from_checkpoint, "adapter_model.bin" 180 | ) # only LoRA model - LoRA config above has to fit 181 | resume_from_checkpoint = ( 182 | False # So the trainer won't try loading its state 183 | ) 184 | # The two files 
above have a different name depending on how they were saved, but are actually the same. 185 | if os.path.exists(checkpoint_name): 186 | print(f"Restarting from {checkpoint_name}") 187 | adapters_weights = torch.load(checkpoint_name) 188 | model = set_peft_model_state_dict(model, adapters_weights) 189 | else: 190 | print(f"Checkpoint {checkpoint_name} not found") 191 | 192 | model.print_trainable_parameters() # Be more transparent about the % of trainable params. 193 | 194 | if val_set_size > 0: 195 | train_val = data["train"].train_test_split( 196 | test_size=val_set_size, shuffle=True, seed=42 197 | ) 198 | train_data = ( 199 | train_val["train"].shuffle().map(generate_and_tokenize_prompt) 200 | ) 201 | val_data = ( 202 | train_val["test"].shuffle().map(generate_and_tokenize_prompt) 203 | ) 204 | else: 205 | train_data = data["train"].shuffle().map(generate_and_tokenize_prompt) 206 | val_data = None 207 | 208 | if not ddp and torch.cuda.device_count() > 1: 209 | # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available 210 | model.is_parallelizable = True 211 | model.model_parallel = True 212 | 213 | #changed by myself 214 | #model.is_parallelizable = False 215 | #model.model_parallel = False 216 | 217 | trainer = transformers.Trainer( 218 | model=model, 219 | train_dataset=train_data, 220 | eval_dataset=val_data, 221 | args=transformers.TrainingArguments( 222 | per_device_train_batch_size=micro_batch_size, 223 | gradient_accumulation_steps=gradient_accumulation_steps, 224 | warmup_steps=100, 225 | num_train_epochs=num_epochs, 226 | learning_rate=learning_rate, 227 | fp16=True, 228 | logging_steps=10, 229 | optim="adamw_torch", 230 | evaluation_strategy="steps" if val_set_size > 0 else "no", 231 | save_strategy="steps", 232 | eval_steps=200 if val_set_size > 0 else None, 233 | save_steps=200, 234 | output_dir=output_dir, 235 | save_total_limit=3, 236 | load_best_model_at_end=True if val_set_size > 0 else False, 237 | ddp_find_unused_parameters=False if ddp else None, 238 | group_by_length=group_by_length, 239 | report_to="wandb" if use_wandb else None, 240 | run_name=wandb_run_name if use_wandb else None, 241 | ), 242 | data_collator=transformers.DataCollatorForSeq2Seq( 243 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True 244 | ), 245 | ) 246 | model.config.use_cache = False 247 | 248 | old_state_dict = model.state_dict 249 | model.state_dict = ( 250 | lambda self, *_, **__: get_peft_model_state_dict( 251 | self, old_state_dict() 252 | ) 253 | ).__get__(model, type(model)) 254 | 255 | if torch.__version__ >= "2" and sys.platform != "win32": 256 | model = torch.compile(model) 257 | 258 | trainer.train(resume_from_checkpoint=resume_from_checkpoint) 259 | 260 | model.save_pretrained(output_dir) 261 | 262 | print( 263 | "\n If there's a warning about missing keys above, please disregard :)" 264 | ) 265 | 266 | 267 | def generate_prompt(data_point): 268 | # sorry about the formatting disaster gotta move fast 269 | if data_point["input"]: 270 | return f""" # noqa: E501 271 | {data_point["instruction"]} 272 | 273 | ### input: 274 | {data_point["input"]} 275 | 276 | ### Response: 277 | {data_point["output"]}""" 278 | else: 279 | return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. 
# noqa: E501 280 | 281 | ### Instruction: 282 | {data_point["instruction"]} 283 | 284 | ### Response: 285 | {data_point["output"]}""" 286 | 287 | 288 | if __name__ == "__main__": 289 | fire.Fire(train) 290 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | appdirs 3 | bitsandbytes 4 | black 5 | black[jupyter] 6 | datasets 7 | fire 8 | git+https://github.com/huggingface/peft.git 9 | git+https://github.com/huggingface/transformers.git 10 | gradio 11 | --------------------------------------------------------------------------------
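A note on the data format: `rec.py` and `evaluation.py` both build prompts from records with `instruction`, `input`, and `output` fields (see `generate_prompt` in each script). The sketch below shows a hypothetical record and a simplified version of the template used when `input` is non-empty (it omits the stray `# noqa: E501` text that the original f-strings embed); the instruction wording and item titles are illustrative only and may differ from the real entries in `data/movies.json`.

```python
# Hypothetical record in the shape expected by generate_prompt(); the actual
# instruction wording and item titles in data/movies.json may differ.
record = {
    "instruction": "Given the movies the user has watched, recommend a new movie for the user.",
    "input": "The Matrix, Inception, Interstellar",
    "output": "Blade Runner 2049",
}


def build_prompt(data_point, include_output=True):
    """Simplified version of the template used when data_point['input'] is non-empty."""
    output = data_point["output"] if include_output else ""
    return (
        f"\n{data_point['instruction']}\n\n"
        f"### input:\n{data_point['input']}\n\n"
        f"### Response:\n{output}"
    )


print(build_prompt(record))                         # training prompt (target appended)
print(build_prompt(record, include_output=False))   # inference prompt (model completes it)
```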
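On evaluation: `evaluation.py` decodes 10 beam-search candidates per test example and counts a hit when the ground-truth item appears among the top 5 or top 10 responses, taking the candidate list as already ranked by beam score. The snippet below is a minimal restatement of that metric logic (Hit@K and NDCG@K with a single relevant item per example), offered as a sketch for clarity.

```python
import math
from typing import List, Tuple


def hit_and_ndcg_at_k(ranked: List[str], label: str, k: int) -> Tuple[float, float]:
    """Hit@K and NDCG@K for a single relevant item, mirroring evaluation.py."""
    top_k = ranked[:k]
    if label not in top_k:
        return 0.0, 0.0
    pos = top_k.index(label)               # 0-based rank of the ground-truth item
    return 1.0, 1.0 / math.log2(pos + 2)   # DCG of one hit at rank pos; ideal DCG is 1


# Toy example: ten decoded candidates with the ground truth ranked third.
candidates = [f"item_{i}" for i in range(10)]
label = "item_2"

hit5, ndcg5 = hit_and_ndcg_at_k(candidates, label, 5)
hit10, ndcg10 = hit_and_ndcg_at_k(candidates, label, 10)
print(f"Hit@5={hit5}, NDCG@5={ndcg5:.3f}, Hit@10={hit10}, NDCG@10={ndcg10:.3f}")
```

Averaging these per-example values over the test set reproduces the running Hit@5/Hit@10/NDCG@5/NDCG@10 numbers printed by `evaluation.py`.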