├── .dockerignore ├── .gitignore ├── LICENSE.txt ├── README.md ├── cog.yaml ├── config.py ├── ds_config ├── ds_z3_bf16_config.json └── ds_z3_fp16_config.json ├── examples └── alpaca │ ├── README.md │ ├── process_data.py │ └── replicate_alpaca_data.json ├── llama_weights ├── llama-13b │ ├── config.json │ └── generation_config.json ├── llama-7b │ ├── config.json │ └── generation_config.json └── tokenizer │ ├── special_tokens_map.json │ └── tokenizer_config.json ├── predict.py ├── scripts ├── cog_push_all.sh ├── train_multi_gpu.sh └── train_single_gpu.sh ├── select_model.py ├── subclass.py ├── templates └── config_template.py ├── test_deserialization.py ├── train.py └── training ├── __init__.py └── trainer.py /.dockerignore: -------------------------------------------------------------------------------- 1 | flan-t5** 2 | checkpoints/** 3 | examples/** 4 | weights_13/** 5 | tmp/** 6 | **.jsonl 7 | **.tensors -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/** 2 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022, Replicate, Inc. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LLaMA Cog template 🦙 2 | 3 | LLaMA is a [new open-source language model from Meta Research](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) that performs as well as closed-source models. 4 | 5 | Similar to Stable Diffusion, this has created a wealth of experiments and innovation. 
[As Simon Willison articulated](https://simonwillison.net/2023/Mar/11/llama/), it's easy to run on your own hardware, large enough to be useful, and open-source enough to be tinkered with. 6 | 7 | This is a guide to running LLaMA in the cloud using Replicate. You'll use the [Cog](https://github.com/replicate/cog) command-line tool to package the model and push it to Replicate as a web interface and API. 8 | 9 | This model can be used to run the `7B` version of LLaMA, and it also works with fine-tuned models. 10 | 11 | **Note: LLaMA is for research purposes only. It is not intended for commercial use.** 12 | 13 | ## Prerequisites 14 | 15 | - **LLaMA weights**. The weights for LLaMA have not yet been released publicly. To apply for access, fill out [this Meta Research form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform). 16 | - **GPU machine**. You'll need a Linux machine with an NVIDIA GPU attached and the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker) installed. If you don't already have access to a machine with a GPU, check out our [guide to getting a GPU machine](https://replicate.com/docs/guides/get-a-gpu-machine). 17 | - **Docker**. You'll be using the [Cog](https://github.com/replicate/cog) command-line tool to build and push a model. Cog uses Docker to create containers for models. 18 | 19 | ## Step 0: Install Cog 20 | 21 | First, [install Cog](https://github.com/replicate/cog#install): 22 | 23 | ``` 24 | sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m)" 25 | sudo chmod +x /usr/local/bin/cog 26 | ``` 27 | 28 | ## Step 1: Set up weights 29 | 30 | Replicate currently supports the `7B` model size. 31 | 32 | Put your downloaded weights in a folder called `unconverted-weights`. The folder hierarchy should look something like this: 33 | 34 | ``` 35 | unconverted-weights 36 | ├── 7B 37 | │ ├── checklist.chk 38 | │ ├── consolidated.00.pth 39 | │ └── params.json 40 | ├── tokenizer.model 41 | └── tokenizer_checklist.chk 42 | ``` 43 | 44 | Convert the weights from a PyTorch checkpoint to a transformers-compatible format using this command: 45 | 46 | ``` 47 | cog run python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir unconverted-weights --model_size 7B --output_dir weights 48 | ``` 49 | 50 | Your final directory structure should look like this: 51 | 52 | ``` 53 | weights 54 | ├── config.json 55 | ├── generation_config.json 56 | ├── pytorch_model-00001-of-00002.bin 57 | ├── pytorch_model-00002-of-00002.bin 58 | ├── pytorch_model.bin.index.json 59 | ├── special_tokens_map.json 60 | ├── tokenizer.model 61 | └── tokenizer_config.json 62 | ``` 63 | 64 | Once you've done this, you should add `unconverted-weights` to your `.dockerignore` file. This ensures that `unconverted-weights` isn't built into the resulting Cog image. 65 | 66 | ## Step 2: Run the model 67 | 68 | You can run the model locally to test it: 69 | 70 | ``` 71 | cog predict -i prompt="Simply put, the theory of relativity states that" 72 | ``` 73 | 74 | LLaMA is not fine-tuned to answer questions. You should construct your prompt so that the expected answer is the natural continuation of your prompt.
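You can also adjust the generation settings when testing locally. As a rough sketch, the input names below mirror the sampling parameters defined in `predict.py` (`max_length`, `temperature`, `top_p`, `repetition_penalty`), and the values shown are just illustrative:

```
cog predict \
  -i prompt="Simply put, the theory of relativity states that" \
  -i max_length=200 \
  -i temperature=0.75 \
  -i top_p=1.0 \
  -i repetition_penalty=1.0
```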
75 | 76 | Here are a few examples from the [LLaMA FAQ](https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/FAQ.md#2-generations-are-bad): 77 | 78 | - Do not prompt with "What is the meaning of life? Be concise and do not repeat yourself." but with "I believe the meaning of life is" 79 | - Do not prompt with "Explain the theory of relativity." but with "Simply put, the theory of relativity states that" 80 | - Do not prompt with "Ten easy steps to build a website..." but with "Building a website can be done in 10 simple steps:\n" 81 | 82 | ## Step 3: Create a model on Replicate 83 | 84 | Go to [replicate.com/create](https://replicate.com/create) to create a Replicate model. 85 | 86 | Make sure to specify "private" to keep the model private. 87 | 88 | ## Step 4: Configure the model to run on A100 GPUs 89 | 90 | Replicate supports running models on a variety of GPUs. The default GPU type is a T4, but for best performance you'll want to configure your model to run on an A100. 91 | 92 | Click on the "Settings" tab on your model page, scroll down to "GPU hardware", and select "A100". Then click "Save". 93 | 94 | ## Step 5: Push the model to Replicate 95 | 96 | Log in to Replicate: 97 | 98 | ``` 99 | cog login 100 | ``` 101 | 102 | Push the contents of your current directory to Replicate, using the model name you specified in step 3: 103 | 104 | ``` 105 | cog push r8.im/username/modelname 106 | ``` 107 | 108 | [Learn more about pushing models to Replicate.](https://replicate.com/docs/guides/push-a-model) 109 | 110 | 111 | ## Step 6: Run the model on Replicate 112 | 113 | Now that you've pushed the model to Replicate, you can run it from the website or with an API. 114 | 115 | To use your model in the browser, go to your model page. 116 | 117 | To use your model with an API, click on the "API" tab on your model page. You'll see commands to run the model with cURL, Python, etc. 118 | 119 | To learn more about how to use Replicate, [check out our documentation](https://replicate.com/docs). 
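For example, here is a minimal sketch of calling the model from Python with the [`replicate` client library](https://github.com/replicate/replicate-python) (`pip install replicate`). The model name and version below are placeholders; substitute the values from your model's API tab, and make sure `REPLICATE_API_TOKEN` is set in your environment:

```
import replicate

# Placeholder identifier: replace "username/modelname" and "version" with the
# values shown on your model's API tab on Replicate.
output = replicate.run(
    "username/modelname:version",
    input={"prompt": "Simply put, the theory of relativity states that"},
)

# This model streams tokens, so iterate over the output as it is generated.
for token in output:
    print(token, end="")
```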
-------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | # set to true if your model requires a GPU 6 | gpu: true 7 | cuda: "11.7" 8 | 9 | # python version in the form '3.8' or '3.8.12' 10 | python_version: "3.8" 11 | 12 | # a list of packages in the format <package-name>==<version> 13 | python_packages: 14 | - "numpy==1.24.2" 15 | - "torch==2.0.0" 16 | - "accelerate==0.18.0" 17 | - "peft==0.2.0" 18 | - "sentencepiece==0.1.97" 19 | - "tensorizer==1.0.1" 20 | - "jinja2==3.1.2" 21 | - "deepspeed==0.8.3" 22 | 23 | 24 | run: 25 | - "pip install git+https://github.com/huggingface/transformers.git@786092a35e18154cacad62c30fe92bac2c27a1e1" 26 | - "mkdir /gc && cd /gc && curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-426.0.0-linux-x86_64.tar.gz && tar -xf google-cloud-cli-426.0.0-linux-x86_64.tar.gz && ./google-cloud-sdk/install.sh -q" 27 | - "pip install google-cloud-storage" 28 | 29 | 30 | # predict.py defines how predictions are run on your model 31 | predict: "predict.py:Predictor" 32 | train: "train.py:train" -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import logging 3 | import re 4 | import time 5 | from transformers import LlamaTokenizer, AutoConfig, LlamaForCausalLM 6 | import torch 7 | import subprocess 8 | from subprocess import DEVNULL, STDOUT 9 | from tensorizer import TensorDeserializer 10 | from tensorizer.utils import no_init_or_tensor 11 | 12 | from subclass import YieldingLlama 13 | 14 | DEFAULT_MODEL_NAME = "llama_weights/llama-7b" # path from which we pull weights when there's no COG_WEIGHTS environment variable 15 | TOKENIZER_NAME = "llama_weights/tokenizer" 16 | CONFIG_LOCATION = "llama_weights/llama-7b" 17 | 18 | DEFAULT_PAD_TOKEN = "[PAD]" 19 | DEFAULT_EOS_TOKEN = "</s>" 20 | DEFAULT_BOS_TOKEN = "<s>" 21 | DEFAULT_UNK_TOKEN = "<unk>" 22 | 23 | 24 | def load_tokenizer(): 25 | """Same tokenizer, agnostic to tensorized weights/etc""" 26 | tok = LlamaTokenizer.from_pretrained(TOKENIZER_NAME, cache_dir="pretrained_weights") 27 | tok.add_special_tokens( 28 | { 29 | "eos_token": DEFAULT_EOS_TOKEN, 30 | "bos_token": DEFAULT_BOS_TOKEN, 31 | "unk_token": DEFAULT_UNK_TOKEN, 32 | "pad_token": DEFAULT_PAD_TOKEN, 33 | } 34 | ) 35 | return tok 36 | 37 | def pull_gcp_file(weights, local_filename): 38 | """Pulls weights from GCP to local storage""" 39 | pattern = r'https://pbxt\.replicate\.delivery/([^/]+/[^/]+)' 40 | match = re.search(pattern, weights) 41 | if match: 42 | weights = f"gs://replicate-files/{match.group(1)}" 43 | 44 | command = ( 45 | f"/gc/google-cloud-sdk/bin/gcloud storage cp {weights} {local_filename}".split() 46 | ) 47 | res = subprocess.run(command) 48 | if res.returncode != 0: 49 | raise Exception( 50 | f"gcloud storage cp command failed with return code {res.returncode}: {res.stderr.decode('utf-8')}" 51 | ) 52 | return 53 | 54 | 55 | 56 | def load_tensorizer( 57 | weights, plaid_mode: bool = True, cls: LlamaForCausalLM = YieldingLlama 58 | ): 59 | st = time.time() 60 | weights = str(weights) 61 | local_weights = "/src/llama_tensors" 62 | print("Deserializing weights...") 63 | if 'http' in weights or 'gs' in weights: 64 | pull_gcp_file(weights, local_weights)
65 | else: 66 | local_weights = weights 67 | 68 | config = AutoConfig.from_pretrained(CONFIG_LOCATION) 69 | 70 | logging.disable(logging.WARN) 71 | model = no_init_or_tensor( 72 | lambda: cls.from_pretrained( 73 | None, config=config, state_dict=OrderedDict(), torch_dtype=torch.float16 74 | ) 75 | ) 76 | logging.disable(logging.NOTSET) 77 | 78 | des = TensorDeserializer(local_weights, plaid_mode=plaid_mode) 79 | des.load_into_module(model) 80 | print(f"weights loaded in {time.time() - st}") 81 | return model 82 | -------------------------------------------------------------------------------- /ds_config/ds_z3_bf16_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "optimizer": { 6 | "type": "AdamW", 7 | "params": { 8 | "lr": "auto", 9 | "betas": "auto", 10 | "eps": "auto", 11 | "weight_decay": "auto" 12 | } 13 | }, 14 | "scheduler": { 15 | "type": "WarmupLR", 16 | "params": { 17 | "warmup_min_lr": "auto", 18 | "warmup_max_lr": "auto", 19 | "warmup_num_steps": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 3, 24 | "offload_optimizer": { 25 | "device": "cpu", 26 | "pin_memory": true 27 | }, 28 | "offload_param": { 29 | "device": "cpu", 30 | "pin_memory": true 31 | }, 32 | "overlap_comm": true, 33 | "contiguous_gradients": true, 34 | "sub_group_size": 1e9, 35 | "reduce_bucket_size": "auto", 36 | "stage3_prefetch_bucket_size": "auto", 37 | "stage3_param_persistence_threshold": "auto", 38 | "stage3_max_live_parameters": 1e9, 39 | "stage3_max_reuse_distance": 1e9, 40 | "stage3_gather_16bit_weights_on_model_save": true 41 | }, 42 | "gradient_accumulation_steps": "auto", 43 | "gradient_clipping": "auto", 44 | "steps_per_print": 2000, 45 | "train_batch_size": "auto", 46 | "train_micro_batch_size_per_gpu": "auto", 47 | "wall_clock_breakdown": false 48 | } -------------------------------------------------------------------------------- /ds_config/ds_z3_fp16_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "initial_scale_power": 16, 6 | "loss_scale_window": 1000, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "optimizer": { 11 | "type": "AdamW", 12 | "params": { 13 | "lr": "auto", 14 | "betas": "auto", 15 | "eps": "auto", 16 | "weight_decay": "auto" 17 | } 18 | }, 19 | "scheduler": { 20 | "type": "WarmupLR", 21 | "params": { 22 | "warmup_min_lr": "auto", 23 | "warmup_max_lr": "auto", 24 | "warmup_num_steps": "auto" 25 | } 26 | }, 27 | "zero_optimization": { 28 | "stage": 1, 29 | "overlap_comm": true, 30 | "contiguous_gradients": true, 31 | "sub_group_size": 1e9, 32 | "reduce_bucket_size": "auto" 33 | }, 34 | "gradient_accumulation_steps": "auto", 35 | "gradient_clipping": "auto", 36 | "steps_per_print": 2000, 37 | "train_batch_size": "auto", 38 | "train_micro_batch_size_per_gpu": "auto", 39 | "wall_clock_breakdown": false 40 | } -------------------------------------------------------------------------------- /examples/alpaca/README.md: -------------------------------------------------------------------------------- 1 | Example code for parsing the dataset needed to train [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca). 2 | 3 | This contains both a function, `process_data.py`, which shows how to transform the [given alpaca data](https://github.com/gururise/AlpacaDataCleaned) into the format expected by `cog train`. 
It also contains an example parsed dataset as a reference for that `{'prompt': ..., 'completion':...}` format. -------------------------------------------------------------------------------- /examples/alpaca/process_data.py: -------------------------------------------------------------------------------- 1 | from transformers import T5Tokenizer 2 | import json 3 | 4 | PROMPT_DICT = { 5 | "prompt_input": ( 6 | "Below is an instruction that describes a task, paired with an input that provides further context. " 7 | "Write a response that appropriately completes the request.\n\n" 8 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 9 | ), 10 | "prompt_no_input": ( 11 | "Below is an instruction that describes a task. " 12 | "Write a response that appropriately completes the request.\n\n" 13 | "### Instruction:\n{instruction}\n\n### Response:" 14 | ), 15 | } 16 | 17 | class Preprocessor: 18 | """Simple class to parse alpaca data into format expected by trainer. Run this offline to build your dataset.""" 19 | 20 | def __init__(self, tokenizer): 21 | self.prompt_dict = PROMPT_DICT 22 | self.tokenizer = tokenizer 23 | 24 | def batch_tokenize(self, texts): 25 | """Tokenizes text. Presently doesn't pad inputs, just returns input ids.""" 26 | tokenized = [ 27 | self.tokenizer( 28 | prompt, 29 | return_tensors="pt", 30 | padding="longest", 31 | ).input_ids 32 | for prompt in texts 33 | ] 34 | return tokenized 35 | 36 | def make_prompt(self, input_row): 37 | if len(input_row["input"]) > 1: 38 | return self.prompt_dict["prompt_input"].format_map(input_row) 39 | return self.prompt_dict["prompt_no_input"].format_map(input_row) 40 | 41 | def make_short_prompt(self, input_row): 42 | if len(input_row["input"]) > 1: 43 | return f'''{input_row['instruction']}\n{input_row['input']}''' 44 | return input_row['instruction'] 45 | 46 | def construct_dataset(self, input_data): 47 | prompts = [self.make_short_prompt(val) for val in input_data] 48 | return [{'prompt':val[0], 'completion':val[1]} for val in zip(prompts, [val["output"] for val in input_data])] 49 | 50 | if __name__ == '__main__': 51 | proc = Preprocessor(T5Tokenizer.from_pretrained('google/flan-t5-xl')) 52 | with open('alpaca_data.json', 'r') as f: 53 | data = json.load(f) 54 | 55 | data_out = proc.construct_dataset(data) 56 | 57 | with open('short_alpaca_data.json', 'w') as f: 58 | json.dump(data_out, f, indent=2) 59 | -------------------------------------------------------------------------------- /llama_weights/llama-13b/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LlamaForCausalLM" 4 | ], 5 | "bos_token_id": 1, 6 | "eos_token_id": 2, 7 | "hidden_act": "silu", 8 | "hidden_size": 5120, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 13824, 11 | "model_type": "llama", 12 | "num_attention_heads": 40, 13 | "num_hidden_layers": 40, 14 | "pad_token_id": 0, 15 | "rms_norm_eps": 1e-06, 16 | "tie_word_embeddings": false, 17 | "torch_dtype": "float16", 18 | "transformers_version": "4.28.0.dev0", 19 | "use_cache": true, 20 | "vocab_size": 32000 21 | } 22 | -------------------------------------------------------------------------------- /llama_weights/llama-13b/generation_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_from_model_config": true, 3 | "bos_token_id": 1, 4 | "eos_token_id": 2, 5 | "pad_token_id": 0, 6 | "transformers_version": "4.28.0.dev0" 7 | } 8 | 
-------------------------------------------------------------------------------- /llama_weights/llama-7b/config.json: -------------------------------------------------------------------------------- 1 | {"architectures": ["LLaMAForCausalLM"], "bos_token_id": 1, "eos_token_id": 2, "hidden_act": "silu", "hidden_size": 4096, "intermediate_size": 11008, "initializer_range": 0.02, "max_sequence_length": 2048, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "pad_token_id": 0, "rms_norm_eps": 1e-06, "torch_dtype": "float16", "transformers_version": "4.27.0.dev0", "use_cache": true, "vocab_size": 32000} -------------------------------------------------------------------------------- /llama_weights/llama-7b/generation_config.json: -------------------------------------------------------------------------------- 1 | {"_from_model_config": true, "bos_token_id": 1, "eos_token_id": 2, "pad_token_id": 0, "transformers_version": "4.27.0.dev0"} -------------------------------------------------------------------------------- /llama_weights/tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /llama_weights/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "", "eos_token": "", "model_max_length": 1000000000000000019884624838656, "tokenizer_class": "LlamaTokenizer", "unk_token": ""} -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import time 3 | from typing import Optional 4 | import zipfile 5 | 6 | import torch 7 | from cog import BasePredictor, ConcatenateIterator, Input, Path 8 | 9 | from config import DEFAULT_MODEL_NAME, load_tokenizer, load_tensorizer, pull_gcp_file 10 | from subclass import YieldingLlama 11 | from peft import PeftModel 12 | import os 13 | 14 | 15 | class Predictor(BasePredictor): 16 | def setup(self, weights: Optional[Path] = None): 17 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 18 | if weights is not None and weights.name == "weights": 19 | # bugfix 20 | weights = None 21 | 22 | if weights is None: 23 | self.model = load_tensorizer(weights=DEFAULT_MODEL_NAME, plaid_mode=True, cls=YieldingLlama) 24 | else: 25 | weights = str(weights) 26 | if '.zip' in weights: 27 | self.model = self.load_peft(weights) 28 | elif "tensors" in weights: 29 | self.model = load_tensorizer(weights, plaid_mode=True, cls=YieldingLlama) 30 | else: 31 | self.model = self.load_huggingface_model(weights=weights) 32 | 33 | self.tokenizer = load_tokenizer() 34 | 35 | def load_peft(self, weights): 36 | st = time.time() 37 | if 'tensors' in DEFAULT_MODEL_NAME: 38 | model = load_tensorizer(DEFAULT_MODEL_NAME, plaid_mode=False, cls=YieldingLlama) 39 | else: 40 | model = self.load_huggingface_model(DEFAULT_MODEL_NAME) 41 | if 'https' in weights: # weights are in the cloud 42 | local_weights = 'local_weights.zip' 43 | pull_gcp_file(weights, local_weights) 44 | weights = local_weights 45 | out = '/src/peft_dir' 46 | if os.path.exists(out): 47 | shutil.rmtree(out) 48 | with zipfile.ZipFile(weights, 'r') as zip_ref: 49 | zip_ref.extractall(out) 50 | model = PeftModel.from_pretrained(model, out) 51 | print(f"peft model loaded in {time.time() - st}") 52 | return model.to('cuda') 53 | 54 | def 
load_huggingface_model(self, weights=None): 55 | st = time.time() 56 | print(f"loading weights from {weights} w/o tensorizer") 57 | model = YieldingLlama.from_pretrained( 58 | weights, cache_dir="pretrained_weights", torch_dtype=torch.float16 59 | ) 60 | model.to(self.device) 61 | print(f"weights loaded in {time.time() - st}") 62 | return model 63 | 64 | def predict( 65 | self, 66 | prompt: str = Input(description=f"Prompt to send to Llama."), 67 | max_length: int = Input( 68 | description="Maximum number of tokens to generate. A word is generally 2-3 tokens", 69 | ge=1, 70 | default=500, 71 | ), 72 | temperature: float = Input( 73 | description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic, 0.75 is a good starting value.", 74 | ge=0.01, 75 | le=5, 76 | default=0.75, 77 | ), 78 | top_p: float = Input( 79 | description="When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens", 80 | ge=0.01, 81 | le=1.0, 82 | default=1.0, 83 | ), 84 | repetition_penalty: float = Input( 85 | description="Penalty for repeated words in generated text; 1 is no penalty, values greater than 1 discourage repetition, less than 1 encourage it.", 86 | ge=0.01, 87 | le=5, 88 | default=1, 89 | ), 90 | debug: bool = Input( 91 | description="provide debugging output in logs", default=False 92 | ), 93 | ) -> ConcatenateIterator[str]: 94 | input = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) 95 | 96 | with torch.inference_mode() and torch.autocast("cuda"): 97 | first_token_yielded = False 98 | prev_ids = [] 99 | for output in self.model.generate( 100 | input_ids=input, 101 | max_length=max_length, 102 | do_sample=True, 103 | temperature=temperature, 104 | top_p=top_p, 105 | repetition_penalty=repetition_penalty, 106 | ): 107 | cur_id = output.item() 108 | 109 | # in order to properly handle spaces, we need to do our own tokenizing. Fun! 110 | # we're building up a buffer of sub-word / punctuation tokens until we hit a space, and then yielding whole words + punctuation. 111 | cur_token = self.tokenizer.convert_ids_to_tokens(cur_id) 112 | 113 | # skip initial newline, which this almost always yields. hack - newline id = 13. 114 | if not first_token_yielded and not prev_ids and cur_id == 13: 115 | continue 116 | 117 | # underscore means a space, means we yield previous tokens 118 | if cur_token.startswith("▁"): # this is not a standard underscore. 119 | # first token 120 | if not prev_ids: 121 | prev_ids = [cur_id] 122 | continue 123 | 124 | # there are tokens to yield 125 | else: 126 | token = self.tokenizer.decode(prev_ids) 127 | prev_ids = [cur_id] 128 | 129 | if not first_token_yielded: 130 | # no leading space for first token 131 | token = token.strip() 132 | first_token_yielded = True 133 | yield token 134 | else: 135 | prev_ids.append(cur_id) 136 | continue 137 | 138 | # remove any special tokens such as 139 | token = self.tokenizer.decode(prev_ids, skip_special_tokens=True) 140 | if not first_token_yielded: 141 | # no leading space for first token 142 | token = token.strip() 143 | first_token_yielded = True 144 | yield token 145 | 146 | if debug: 147 | print(f"cur memory: {torch.cuda.memory_allocated()}") 148 | print(f"max allocated: {torch.cuda.max_memory_allocated()}") 149 | print(f"peak memory: {torch.cuda.max_memory_reserved()}") 150 | 151 | 152 | class EightBitPredictor(Predictor): 153 | """subclass s.t. 
we can configure whether a model is loaded in 8bit mode from cog.yaml""" 154 | 155 | def setup(self, weights: Optional[Path] = None): 156 | if weights is not None and weights.name == "weights": 157 | # bugfix 158 | weights = None 159 | # TODO: fine-tuned 8bit weights. 160 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 161 | self.model = YieldingLlama.from_pretrained( 162 | DEFAULT_MODEL_NAME, load_in_8bit=True, device_map="auto" 163 | ) 164 | self.tokenizer = load_tokenizer() 165 | -------------------------------------------------------------------------------- /scripts/cog_push_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model_names=("llama-7b") 4 | 5 | for model_name in "${model_names[@]}"; do 6 | echo "Pushing model: $model_name" 7 | cog run python select_model.py --model_name $model_name 8 | cog login --token-stdin <<< "$COG_TOKEN" 9 | cog push r8.im/replicate/$model_name 10 | done 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /scripts/train_multi_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py \ 4 | --train_data 70k_samples_prompt.jsonl \ 5 | --num_train_epochs 1 \ 6 | --learning_rate 2e-5 \ 7 | --train_batch_size 2 \ 8 | --gradient_accumulation_steps 4 \ 9 | --logging_steps 2 \ 10 | --warmup_ratio 0.03 \ 11 | --weights /src/weights_13 -------------------------------------------------------------------------------- /scripts/train_single_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py \ 4 | --model_name_or_path google/flan-t5-base \ 5 | --data_path ./replicate_alpaca_data.json \ 6 | --num_train_epochs 3 \ 7 | --learning_rate 3e-4 \ 8 | --train_batch_size 8 \ 9 | --warmup_ratio 0.03 \ 10 | --max_steps 10 # number of steps before returning, mostly useful for testing performance 11 | -------------------------------------------------------------------------------- /select_model.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | import stat 4 | from jinja2 import Template 5 | 6 | # llama configs, can modify as needed. 7 | CONFIGS = { 8 | "llama-7b": { 9 | "cog_yaml_parameters": {"predictor":"predict.py:Predictor"}, 10 | "config_py_parameters": {"model_name": "SET_ME", "config_location": "llama_weights/llama-7b"} 11 | }, 12 | "llama-13b": { 13 | "cog_yaml_parameters": {"predictor":"predict.py:Predictor"}, 14 | "config_py_parameters": {"model_name": "SET_ME", "config_location": "llama_weights/llama-13b"} 15 | }, 16 | } 17 | 18 | def _reset_file(file_path): 19 | if os.path.exists(file_path): 20 | os.remove(file_path) 21 | 22 | 23 | def write_one_config(template_fpath: str, fname_out: str, config: dict): 24 | with open(template_fpath, "r") as f: 25 | template_content = f.read() 26 | base_template = Template(template_content) 27 | 28 | _reset_file(fname_out) 29 | 30 | with open(fname_out, "w") as f: 31 | f.write(base_template.render(config)) 32 | 33 | # Give all users write access to resulting generated file. 
34 | current_permissions = os.stat(fname_out).st_mode 35 | new_permissions = current_permissions | stat.S_IWUSR | stat.S_IWGRP | stat.S_IWOTH 36 | os.chmod(fname_out, new_permissions) 37 | 38 | 39 | def write_configs(model_name): 40 | master_config = CONFIGS[model_name] 41 | #write_one_config("templates/cog_template.yaml", "cog.yaml", master_config['cog_yaml_parameters']) 42 | write_one_config("templates/config_template.py", "config.py", master_config['config_py_parameters']) 43 | 44 | 45 | if __name__ == '__main__': 46 | parser = ArgumentParser() 47 | parser.add_argument("--model_name", default="llama-7b", help="name of the llama model you want to configure cog for") 48 | args = parser.parse_args() 49 | 50 | write_configs(args.model_name) -------------------------------------------------------------------------------- /subclass.py: -------------------------------------------------------------------------------- 1 | """sampling code pulled from Transformers & slightly modified to stream tokens""" 2 | import warnings 3 | from typing import List, Optional, Union 4 | 5 | import torch 6 | import torch.distributed as dist 7 | from torch import nn 8 | 9 | from transformers.generation.logits_process import LogitsProcessorList 10 | from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria 11 | from transformers.generation.utils import SampleOutput, SampleDecoderOnlyOutput, SampleEncoderDecoderOutput 12 | 13 | from transformers import LlamaForCausalLM 14 | 15 | class YieldingLlama(LlamaForCausalLM): 16 | """Overriding sample to yield tokens""" 17 | def sample( 18 | self, 19 | input_ids: torch.LongTensor, 20 | logits_processor: Optional[LogitsProcessorList] = None, 21 | stopping_criteria: Optional[StoppingCriteriaList] = None, 22 | logits_warper: Optional[LogitsProcessorList] = None, 23 | max_length: Optional[int] = None, 24 | pad_token_id: Optional[int] = None, 25 | eos_token_id: Optional[Union[int, List[int]]] = None, 26 | output_attentions: Optional[bool] = None, 27 | output_hidden_states: Optional[bool] = None, 28 | output_scores: Optional[bool] = None, 29 | return_dict_in_generate: Optional[bool] = None, 30 | synced_gpus: Optional[bool] = False, 31 | **model_kwargs, 32 | ) -> Union[SampleOutput, torch.LongTensor]: 33 | r""" 34 | Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and 35 | can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. 36 | 37 | 38 | 39 | In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead. 40 | For an overview of generation strategies and code examples, check the [following 41 | guide](./generation_strategies). 42 | 43 | 44 | 45 | Parameters: 46 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 47 | The sequence used as a prompt for the generation. 48 | logits_processor (`LogitsProcessorList`, *optional*): 49 | An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] 50 | used to modify the prediction scores of the language modeling head applied at each generation step. 51 | stopping_criteria (`StoppingCriteriaList`, *optional*): 52 | An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] 53 | used to tell if the generation loop should stop. 54 | logits_warper (`LogitsProcessorList`, *optional*): 55 | An instance of [`LogitsProcessorList`].
List of instances of class derived from [`LogitsWarper`] used 56 | to warp the prediction score distribution of the language modeling head applied before multinomial 57 | sampling at each generation step. 58 | max_length (`int`, *optional*, defaults to 20): 59 | **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated 60 | tokens. The maximum length of the sequence to be generated. 61 | pad_token_id (`int`, *optional*): 62 | The id of the *padding* token. 63 | eos_token_id (`int`, *optional*): 64 | The id of the *end-of-sequence* token. 65 | output_attentions (`bool`, *optional*, defaults to `False`): 66 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 67 | returned tensors for more details. 68 | output_hidden_states (`bool`, *optional*, defaults to `False`): 69 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 70 | for more details. 71 | output_scores (`bool`, *optional*, defaults to `False`): 72 | Whether or not to return the prediction scores. See `scores` under returned tensors for more details. 73 | return_dict_in_generate (`bool`, *optional*, defaults to `False`): 74 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 75 | synced_gpus (`bool`, *optional*, defaults to `False`): 76 | Whether to continue running the while loop until max_length (needed for ZeRO stage 3) 77 | model_kwargs: 78 | Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is 79 | an encoder-decoder model the kwargs should include `encoder_outputs`. 80 | 81 | Return: 82 | [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`: 83 | A `torch.LongTensor` containing the generated tokens (default behaviour) or a 84 | [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and 85 | `return_dict_in_generate=True` or a [`~generation.SampleEncoderDecoderOutput`] if 86 | `model.config.is_encoder_decoder=True`. 87 | 88 | Examples: 89 | 90 | ```python 91 | >>> from transformers import ( 92 | ... AutoTokenizer, 93 | ... AutoModelForCausalLM, 94 | ... LogitsProcessorList, 95 | ... MinLengthLogitsProcessor, 96 | ... TopKLogitsWarper, 97 | ... TemperatureLogitsWarper, 98 | ... StoppingCriteriaList, 99 | ... MaxLengthCriteria, 100 | ... ) 101 | >>> import torch 102 | 103 | >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") 104 | >>> model = AutoModelForCausalLM.from_pretrained("gpt2") 105 | 106 | >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token 107 | >>> model.config.pad_token_id = model.config.eos_token_id 108 | >>> model.generation_config.pad_token_id = model.config.eos_token_id 109 | 110 | >>> input_prompt = "Today is a beautiful day, and" 111 | >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids 112 | 113 | >>> # instantiate logits processors 114 | >>> logits_processor = LogitsProcessorList( 115 | ... [ 116 | ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id), 117 | ... ] 118 | ... ) 119 | >>> # instantiate logits processors 120 | >>> logits_warper = LogitsProcessorList( 121 | ... [ 122 | ... TopKLogitsWarper(50), 123 | ... TemperatureLogitsWarper(0.7), 124 | ... ] 125 | ... 
) 126 | 127 | >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) 128 | 129 | >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT 130 | >>> outputs = model.sample( 131 | ... input_ids, 132 | ... logits_processor=logits_processor, 133 | ... logits_warper=logits_warper, 134 | ... stopping_criteria=stopping_criteria, 135 | ... ) 136 | 137 | >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) 138 | ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] 139 | ```""" 140 | # init values 141 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() 142 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() 143 | if max_length is not None: 144 | warnings.warn( 145 | "`max_length` is deprecated in this function, use" 146 | " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", 147 | UserWarning, 148 | ) 149 | stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) 150 | logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() 151 | pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id 152 | eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id 153 | if isinstance(eos_token_id, int): 154 | eos_token_id = [eos_token_id] 155 | output_scores = output_scores if output_scores is not None else self.generation_config.output_scores 156 | output_attentions = ( 157 | output_attentions if output_attentions is not None else self.generation_config.output_attentions 158 | ) 159 | output_hidden_states = ( 160 | output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states 161 | ) 162 | return_dict_in_generate = ( 163 | return_dict_in_generate 164 | if return_dict_in_generate is not None 165 | else self.generation_config.return_dict_in_generate 166 | ) 167 | 168 | # init attention / hidden states / scores tuples 169 | scores = () if (return_dict_in_generate and output_scores) else None 170 | decoder_attentions = () if (return_dict_in_generate and output_attentions) else None 171 | cross_attentions = () if (return_dict_in_generate and output_attentions) else None 172 | decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None 173 | 174 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states 175 | if return_dict_in_generate and self.config.is_encoder_decoder: 176 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None 177 | encoder_hidden_states = ( 178 | model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None 179 | ) 180 | 181 | # keep track of which sequences are already finished 182 | unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) 183 | 184 | this_peer_finished = False # used by synced_gpus only 185 | # auto-regressive generation 186 | while True: 187 | if synced_gpus: 188 | # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. 
189 | # The following logic allows an early break if all peers finished generating their sequence 190 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) 191 | # send 0.0 if we finished, 1.0 otherwise 192 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) 193 | # did all peers finish? the reduced sum will be 0.0 then 194 | if this_peer_finished_flag.item() == 0.0: 195 | break 196 | 197 | # prepare model inputs 198 | model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) 199 | 200 | # forward pass to get next token 201 | with torch.inference_mode(): 202 | outputs = self( 203 | **model_inputs, 204 | return_dict=True, 205 | output_attentions=output_attentions, 206 | output_hidden_states=output_hidden_states, 207 | ) 208 | 209 | if synced_gpus and this_peer_finished: 210 | continue # don't waste resources running the code we don't need 211 | 212 | next_token_logits = outputs.logits[:, -1, :] 213 | 214 | # pre-process distribution 215 | next_token_scores = logits_processor(input_ids, next_token_logits) 216 | next_token_scores = logits_warper(input_ids, next_token_scores) 217 | 218 | # Store scores, attentions and hidden_states when required 219 | if return_dict_in_generate: 220 | if output_scores: 221 | scores += (next_token_scores,) 222 | if output_attentions: 223 | decoder_attentions += ( 224 | (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) 225 | ) 226 | if self.config.is_encoder_decoder: 227 | cross_attentions += (outputs.cross_attentions,) 228 | 229 | if output_hidden_states: 230 | decoder_hidden_states += ( 231 | (outputs.decoder_hidden_states,) 232 | if self.config.is_encoder_decoder 233 | else (outputs.hidden_states,) 234 | ) 235 | 236 | # sample 237 | probs = nn.functional.softmax(next_token_scores, dim=-1) 238 | next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) 239 | 240 | # finished sentences should have their next token be a padding token 241 | if eos_token_id is not None: 242 | if pad_token_id is None: 243 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") 244 | next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) 245 | 246 | # update generated ids, model inputs, and length for next step 247 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) 248 | model_kwargs = self._update_model_kwargs_for_generation( 249 | outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder 250 | ) 251 | 252 | # if eos_token was found in one sentence, set sentence to finished 253 | if eos_token_id is not None: 254 | unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) 255 | 256 | # stop when each sentence is finished, or if we exceed the maximum length 257 | if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): 258 | if not synced_gpus: 259 | break 260 | else: 261 | this_peer_finished = True 262 | else: 263 | yield next_tokens 264 | 265 | if return_dict_in_generate: 266 | if self.config.is_encoder_decoder: 267 | yield SampleEncoderDecoderOutput( 268 | sequences=input_ids, 269 | scores=scores, 270 | encoder_attentions=encoder_attentions, 271 | encoder_hidden_states=encoder_hidden_states, 272 | decoder_attentions=decoder_attentions, 273 | cross_attentions=cross_attentions, 274 | decoder_hidden_states=decoder_hidden_states, 275 | ) 276 | else: 277 | yield SampleDecoderOnlyOutput( 278 | 
sequences=input_ids, 279 | scores=scores, 280 | attentions=decoder_attentions, 281 | hidden_states=decoder_hidden_states, 282 | ) 283 | else: 284 | yield next_tokens -------------------------------------------------------------------------------- /templates/config_template.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import logging 3 | import re 4 | import time 5 | from transformers import LlamaTokenizer, AutoConfig, LlamaForCausalLM 6 | import torch 7 | import subprocess 8 | from subprocess import DEVNULL, STDOUT 9 | from tensorizer import TensorDeserializer 10 | from tensorizer.utils import no_init_or_tensor 11 | 12 | from subclass import YieldingLlama 13 | 14 | DEFAULT_MODEL_NAME = "{{model_name}}" # path from which we pull weights when there's no COG_WEIGHTS environment variable 15 | TOKENIZER_NAME = "llama_weights/tokenizer" 16 | CONFIG_LOCATION = "{{config_location}}" 17 | 18 | DEFAULT_PAD_TOKEN = "[PAD]" 19 | DEFAULT_EOS_TOKEN = "</s>" 20 | DEFAULT_BOS_TOKEN = "<s>" 21 | DEFAULT_UNK_TOKEN = "<unk>" 22 | 23 | 24 | def load_tokenizer(): 25 | """Same tokenizer, agnostic to tensorized weights/etc""" 26 | tok = LlamaTokenizer.from_pretrained(TOKENIZER_NAME, cache_dir="pretrained_weights") 27 | tok.add_special_tokens( 28 | { 29 | "eos_token": DEFAULT_EOS_TOKEN, 30 | "bos_token": DEFAULT_BOS_TOKEN, 31 | "unk_token": DEFAULT_UNK_TOKEN, 32 | "pad_token": DEFAULT_PAD_TOKEN, 33 | } 34 | ) 35 | return tok 36 | 37 | 38 | def pull_gcp_file(weights, local_filename): 39 | """Pulls weights from GCP to local storage""" 40 | pattern = r'https://pbxt\.replicate\.delivery/([^/]+/[^/]+)' 41 | match = re.search(pattern, weights) 42 | if match: 43 | weights = f"gs://replicate-files/{match.group(1)}" 44 | 45 | command = ( 46 | f"/gc/google-cloud-sdk/bin/gcloud storage cp {weights} {local_filename}".split() 47 | ) 48 | res = subprocess.run(command) 49 | if res.returncode != 0: 50 | raise Exception( 51 | f"gcloud storage cp command failed with return code {res.returncode}: {res.stderr.decode('utf-8')}" 52 | ) 53 | return 54 | 55 | 56 | def load_tensorizer( 57 | weights, plaid_mode: bool = True, cls: LlamaForCausalLM = YieldingLlama 58 | ): 59 | st = time.time() 60 | weights = str(weights) 61 | local_weights = "/src/llama_tensors" 62 | print("Deserializing weights...") 63 | if 'http' in weights: 64 | pull_gcp_file(weights, local_weights) 65 | else: 66 | local_weights = weights 67 | 68 | config = AutoConfig.from_pretrained(CONFIG_LOCATION) 69 | 70 | logging.disable(logging.WARN) 71 | model = no_init_or_tensor( 72 | lambda: cls.from_pretrained( 73 | None, config=config, state_dict=OrderedDict(), torch_dtype=torch.float16 74 | ) 75 | ) 76 | logging.disable(logging.NOTSET) 77 | 78 | des = TensorDeserializer(local_weights, plaid_mode=plaid_mode) 79 | des.load_into_module(model) 80 | print(f"weights loaded in {time.time() - st}") 81 | return model 82 | -------------------------------------------------------------------------------- /test_deserialization.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ProcessPoolExecutor 2 | import concurrent.futures 3 | import os 4 | import time 5 | from collections import OrderedDict 6 | from typing import Optional, List 7 | import io 8 | import subprocess 9 | 10 | 11 | import torch 12 | from tensorizer import TensorDeserializer 13 | from tensorizer.utils import no_init_or_tensor 14 | from transformers import AutoConfig 15 | from
google.cloud import storage 16 | 17 | 18 | from config import CONFIG_LOCATION, load_tokenizer 19 | from subclass import YieldingLlama 20 | import logging 21 | 22 | MODEL_OUT = "/src/tuned_weights.tensors" 23 | CHECKPOINT_DIR = "checkpoints" 24 | SAVE_STRATEGY = "epoch" 25 | DIST_OUT_DIR = "tmp/model" 26 | 27 | def test_cli_deserialization(): 28 | 29 | path = "path_to_weights" 30 | 31 | st = time.time() 32 | print("downloading weights") 33 | # don't love the whole /gc/google-cloud-sdk/bin/gcloud path but don't think there's an easy way to update PATH from cog, so might as well do this. 34 | weights = "/src/llama_tensors" 35 | os.system(f"/gc/google-cloud-sdk/bin/gcloud storage cp {path} {weights}") 36 | print(f"weignts downloaded in {time.time() - st}") 37 | print(f"deserializing weights from {weights}") 38 | config = AutoConfig.from_pretrained(CONFIG_LOCATION) 39 | 40 | logging.disable(logging.WARN) 41 | model = no_init_or_tensor( 42 | lambda: YieldingLlama.from_pretrained( 43 | None, config=config, state_dict=OrderedDict(), torch_dtype=torch.float16 44 | ) 45 | ) 46 | logging.disable(logging.NOTSET) 47 | des = TensorDeserializer(weights, plaid_mode=False) 48 | des.load_into_module(model) 49 | print(f"zero to weights in {time.time() - st}") 50 | 51 | 52 | def test_in_memory_cli_deserialization(): 53 | """This is quite slow, turns out that gcloud storage streaming into memory (-) runs in series.""" 54 | path = "path/to/weights" 55 | st = time.time() 56 | print("downloading weights") 57 | # don't love the whole /gc/google-cloud-sdk/bin/gcloud path but don't think there's an easy way to update PATH from cog, so might as well do this. 58 | command = f"/gc/google-cloud-sdk/bin/gcloud storage cp {path} -".split() 59 | result = subprocess.run(command, stdout=subprocess.PIPE, text=False) 60 | if result.returncode != 0: 61 | raise Exception(f"gcloud storage cp command failed with return code {result.returncode}: {result.stderr.decode('utf-8')}") 62 | 63 | in_memory_file = io.BytesIO(result.stdout) 64 | in_memory_file.seek(0) 65 | 66 | print(f"weignts downloaded in {time.time() - st}") 67 | config = AutoConfig.from_pretrained(CONFIG_LOCATION) 68 | 69 | logging.disable(logging.WARN) 70 | model = no_init_or_tensor( 71 | lambda: YieldingLlama.from_pretrained( 72 | None, config=config, state_dict=OrderedDict(), torch_dtype=torch.float16 73 | ) 74 | ) 75 | logging.disable(logging.NOTSET) 76 | des = TensorDeserializer(in_memory_file, plaid_mode=False) 77 | des.load_into_module(model) 78 | print(f"zero to weights in {time.time() - st}") 79 | 80 | 81 | def download_chunk(dl_cfg): 82 | """Submittable function to python process pool for downloading byte chunk""" 83 | storage_client = storage.Client() 84 | bucket = storage_client.bucket(dl_cfg['bucket']) 85 | blob = bucket.get_blob(dl_cfg['blob']) 86 | in_memory_file = io.BytesIO() 87 | blob.download_to_file(in_memory_file, start=dl_cfg['start'], end=dl_cfg['end']) 88 | return in_memory_file 89 | 90 | def download_blob_to_stream(bucket_name: str, source_blob_name: str, n: int = 4): 91 | """Downloads a blob to a stream or other file-like object.""" 92 | 93 | storage_client = storage.Client() 94 | 95 | bucket = storage_client.bucket(bucket_name) 96 | 97 | # Need to call get_blob to get metadata 98 | blob = bucket.get_blob(source_blob_name) 99 | 100 | def _partition_file(size: int, bucket: str, blob: str, n: int) -> List[dict]: 101 | partitions = [] 102 | split = int(size/n) 103 | start = 0 104 | end = split 105 | 106 | for i in range(n): 107 | if i == n - 1: # 
If it's the last partition 108 | end = size - 1 # Set the endpoint to the last byte of the file 109 | partitions.append({"start": start, "end": end, "bucket": bucket, "blob": blob}) 110 | start = end + 1 111 | end += split 112 | 113 | return partitions 114 | 115 | partitions = _partition_file(blob.size, bucket_name, source_blob_name, n) 116 | 117 | print('submitting tasks') 118 | res = [] 119 | with ProcessPoolExecutor(n) as ex: 120 | res = list(ex.map(download_chunk, partitions)) 121 | # results = [ex.submit(download_chunk, partition) for partition in partitions] 122 | 123 | # for future in concurrent.futures.as_completed(results): 124 | # res.append(future.result()) 125 | print('all downloads finished') 126 | 127 | concatenated_bytes = b''.join(result.getvalue() for result in res) 128 | 129 | # Create a new in memory file w/all bytes concatenated 130 | in_memory_file = io.BytesIO(concatenated_bytes) 131 | in_memory_file.seek(0) 132 | return in_memory_file 133 | 134 | 135 | def test_python_deserialization(): 136 | st = time.time() 137 | print("downloading weights") 138 | bucket_name = "CHANGEME" 139 | source_name = "CHANGEME" 140 | 141 | obj = download_blob_to_stream(bucket_name=bucket_name, source_blob_name=source_name, n=24) 142 | 143 | print(f"weights downloaded in {time.time() - st}") 144 | 145 | print("deserializing weights from memory") 146 | config = AutoConfig.from_pretrained(CONFIG_LOCATION) 147 | 148 | logging.disable(logging.WARN) # turns off long message about not training the model 149 | model = no_init_or_tensor( 150 | lambda: YieldingLlama.from_pretrained( 151 | None, config=config, state_dict=OrderedDict() 152 | ) 153 | ) 154 | logging.disable(logging.NOTSET) 155 | des = TensorDeserializer(obj, plaid_mode=False) 156 | des.load_into_module(model) 157 | print(f"zero to weights in {time.time() - st}") 158 | 159 | 160 | if __name__ == '__main__': 161 | #test_python_deserialization() 162 | test_cli_deserialization() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | from subprocess import call 5 | import logging 6 | from typing import Optional 7 | from zipfile import ZipFile 8 | 9 | import torch 10 | from cog import BaseModel, Input, Path 11 | from tensorizer import TensorSerializer 12 | from transformers import LlamaForCausalLM 13 | 14 | from config import DEFAULT_MODEL_NAME, pull_gcp_file 15 | 16 | MODEL_OUT = "/src/tuned_weights.tensors" 17 | CHECKPOINT_DIR = "checkpoints" 18 | SAVE_STRATEGY = "epoch" 19 | DIST_OUT_DIR = "tmp/model" 20 | 21 | 22 | class TrainingOutput(BaseModel): 23 | weights: Path 24 | 25 | 26 | def train( 27 | train_data: Path = Input( 28 | description="path to data file to use for fine-tuning your model" 29 | ), 30 | eval_data: Path = Input( 31 | description="path to optional evaluation data file to use for model eval", 32 | default=None, 33 | ), 34 | weights: Path = Input( 35 | description="location of weights that are going to be fine-tuned", default=None 36 | ), 37 | train_batch_size: int = Input(description="batch size per GPU", default=1, ge=1), 38 | gradient_accumulation_steps: int = Input( 39 | description="number of steps to accumulate gradients over before performing an optimizer update", 40 | default=8, 41 | ), 42 | learning_rate: float = Input( 43 | description="learning rate, for learning!", default=2e-5, ge=0 44 | ), 45 | warmup_ratio: float = Input( 46 |
description="pct of steps for a linear learning rate warmup", 47 | ge=0, 48 | le=0.5, 49 | default=0.03, 50 | ), 51 | num_train_epochs: int = Input( 52 | description="number of training epochs", ge=1, default=1 53 | ), 54 | max_steps: int = Input( 55 | description="number of steps to run training for, supersedes num_train_epochs", 56 | default=-1, 57 | ), 58 | logging_steps: int = Input( 59 | description="number of steps between logging epoch & loss", default=1 60 | ), 61 | lora_rank: int = Input( 62 | description="Rank of the lora matrices", default=8, ge=1), 63 | lora_alpha: int = Input(description="Alpha parameter for scaling lora weights; weights are scaled by alpha/rank", default=16, ge=1), 64 | lora_dropout: float = Input(description="Dropout for lora training", default=0.1, ge=0.0, le=1.0), 65 | lora_target_modules: str = Input(description="Comma-separated list of lora modules to target, i.e. 'q_proj,v_proj'. Leave blank for default.", default="q_proj,v_proj") 66 | ) -> TrainingOutput: 67 | input_weights = weights if weights is not None else DEFAULT_MODEL_NAME 68 | 69 | if 'http' in input_weights or 'gs' in input_weights: 70 | # doing this once instead of 4x 71 | local_weights = '/src/llama.tensors' 72 | pull_gcp_file(input_weights, local_weights) 73 | input_weights = local_weights 74 | 75 | root_path = os.getcwd() 76 | deepspeed_config = os.path.join(root_path, "ds_config/ds_z3_fp16_config.json") 77 | 78 | output_dir = DIST_OUT_DIR 79 | if os.path.exists(output_dir): 80 | shutil.rmtree(output_dir) 81 | os.makedirs(output_dir) 82 | 83 | num_gpus = torch.cuda.device_count() 84 | num_gpus_flag = f"--num_gpus={num_gpus}" 85 | 86 | print(f"Local Output Dir: {output_dir}") 87 | print(f"Number of GPUs: {num_gpus}") 88 | 89 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 90 | os.environ["HF_DATASETS_CACHE"] = "/src/.hf-cache" 91 | 92 | def _arg_if_present(var, var_name): 93 | """Need to wrap any arguments whose default value in train() is `None`""" 94 | if var: 95 | return f" --{var_name} {var}" 96 | return " " 97 | 98 | res = call( 99 | "deepspeed " 100 | + num_gpus_flag 101 | + " --master_port=9292" 102 | + " --module training.trainer" 103 | + f" --deepspeed {deepspeed_config}" 104 | + f" --train_data={str(train_data)}" 105 | + f" --weights={input_weights}" 106 | + f" --num_train_epochs={num_train_epochs}" 107 | + f" --max_steps={max_steps}" 108 | + _arg_if_present(eval_data, "eval_data") 109 | + f" --learning_rate {learning_rate}" 110 | + f" --train_batch_size {train_batch_size}" 111 | + f" --gradient_accumulation_steps {gradient_accumulation_steps}" 112 | + f" --logging_steps {logging_steps}" 113 | + f" --warmup_ratio {warmup_ratio}" 114 | + f" --lora_rank {lora_rank}" 115 | + f" --lora_alpha {lora_alpha}" 116 | + f" --lora_dropout {lora_dropout}" 117 | + _arg_if_present(lora_target_modules, "lora_target_modules") 118 | + " --local_output_dir " 119 | + output_dir, 120 | shell=True, 121 | ) 122 | if res != 0: 123 | raise Exception(f"Training failed! Process returned error code {res}. 
Check the logs for details.") 124 | 125 | out_path = "training_output.zip" 126 | 127 | directory = Path(output_dir) 128 | with ZipFile(out_path, "w") as zip: 129 | for file_path in directory.rglob("*"): 130 | print(file_path) 131 | zip.write(file_path, arcname=file_path.relative_to(directory)) 132 | 133 | return TrainingOutput(weights=Path(out_path)) 134 | 135 | 136 | if __name__ == "__main__": 137 | parser = argparse.ArgumentParser( 138 | description="Fine-tune a language model on a text dataset" 139 | ) 140 | parser.add_argument( 141 | "--train_data", type=Path, required=True, help="Path to the json dataset" 142 | ) 143 | parser.add_argument( 144 | "--eval_data", 145 | type=Path, 146 | required=False, 147 | help="Path to the json dataset", 148 | default=None, 149 | ) 150 | parser.add_argument( 151 | "--weights", 152 | type=str, 153 | default=None, 154 | help="The model class to fine-tune on HF or as a local path (e.g. 'google/flan-t5-xxl')", 155 | ) 156 | parser.add_argument( 157 | "--num_train_epochs", type=int, required=True, help="Number of training epochs" 158 | ) 159 | parser.add_argument( 160 | "--learning_rate", 161 | type=float, 162 | default=2e-5, 163 | help="Learning rate for the optimizer", 164 | ) 165 | parser.add_argument( 166 | "--train_batch_size", type=int, default=4, help="Batch size for training" 167 | ) 168 | parser.add_argument( 169 | "--warmup_ratio", 170 | type=float, 171 | default=0.03, 172 | help="Fraction of total training steps used for linear learning rate warmup", 173 | ) 174 | parser.add_argument( 175 | "--max_steps", 176 | type=int, 177 | default=0, 178 | help="Number of training steps to run, overrides num_train_epochs, useful for testing", 179 | ) 180 | parser.add_argument( 181 | "--gradient_accumulation_steps", 182 | type=int, 183 | default=8, 184 | help="Number of steps to accumulate gradients over before performing an optimizer update", 185 | ) 186 | parser.add_argument("--logging_steps", type=int, default=1) 187 | some_args = parser.parse_args() 188 | train(**vars(some_args)) 189 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-llama/f85037aed5269b6d046efee1f8552695c11912d2/training/__init__.py -------------------------------------------------------------------------------- /training/trainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import json 4 | import os 5 | import time 6 | import logging 7 | from collections import OrderedDict 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | from cog import Input, Path 12 | from peft import (LoraConfig, get_peft_model) 13 | from torch.utils.data import Dataset 14 | from transformers import LlamaForCausalLM, Trainer, TrainingArguments, AutoConfig 15 | from tensorizer import TensorDeserializer 16 | from tensorizer.utils import no_init_or_tensor 17 | import sys 18 | sys.path.append('/src/') 19 | 20 | from config import DEFAULT_MODEL_NAME, load_tokenizer, load_tensorizer 21 | 22 | MODEL_OUT = "/src/tuned_weights.tensors" 23 | CHECKPOINT_DIR = "checkpoints" 24 | SAVE_STRATEGY = "epoch" 25 | DIST_OUT_DIR = "tmp/model" 26 | IGNORE_INDEX = -100 27 | 28 | 29 | class DatasetBuilder: 30 | """Dataset agnostic class to take in input_ids and labels and spit out tokens""" 31 | 32 | def __init__(self, tokenizer): 33 | self.tokenizer = tokenizer 34 | 35 | def
batch_tokenize(self, texts): 36 | """Tokenizes text. Presently doesn't pad inputs, just returns input ids.""" 37 | tokenized = [ 38 | self.tokenizer( 39 | prompt, return_tensors="pt", padding="longest", truncation=True 40 | ).input_ids 41 | for prompt in texts 42 | ] 43 | return tokenized 44 | 45 | def construct_dataset(self, input_data): 46 | prompts = [val["prompt"] for val in input_data] 47 | tokenized_input_ids = self.batch_tokenize(prompts) 48 | labels = [val["completion"] for val in input_data] 49 | tokenized_labels = self.batch_tokenize(labels) 50 | return TuneDataset(tokenized_input_ids, tokenized_labels) 51 | 52 | 53 | class CausalDatasetBuilder(DatasetBuilder): 54 | """Builds generative dataset for Causal LM.""" 55 | 56 | def __init__(self, tokenizer, train_on_prompt=True): 57 | super().__init__(tokenizer) 58 | self.train_on_prompt = train_on_prompt 59 | 60 | def construct_dataset(self, input_data): 61 | labels = [ 62 | val["prompt"] + "\n" + val["completion"] + self.tokenizer.eos_token 63 | for val in input_data 64 | ] 65 | input_ids = [val.squeeze() for val in self.batch_tokenize(labels)] 66 | labels = copy.deepcopy(input_ids) 67 | if self.train_on_prompt: 68 | return TuneDataset(input_ids, labels) 69 | # masking prompt 70 | prompts = [val["prompt"] for val in input_data] 71 | tokenized_prompts = self.batch_tokenize(prompts) 72 | prompt_lens = [val.shape[1] for val in tokenized_prompts] 73 | 74 | for label, source_len in zip(labels, prompt_lens): 75 | label[:source_len] = IGNORE_INDEX 76 | return TuneDataset(input_ids, labels) 77 | 78 | 79 | class TuneDataset(Dataset): 80 | """Dead simple torch dataset wrapper. Attention masks are created in collator""" 81 | 82 | def __init__(self, input_ids, labels): 83 | self.input_ids = input_ids 84 | self.labels = labels 85 | 86 | def __len__(self): 87 | return len(self.input_ids) 88 | 89 | def __getitem__(self, i): 90 | return dict(input_ids=self.input_ids[i], labels=self.labels[i]) 91 | 92 | 93 | class SequenceDataCollator: 94 | """Collate examples for dynamic batch construction in supervised fine-tuning.""" 95 | 96 | def __init__(self, tokenizer, multiple_of=None): 97 | self.tokenizer = tokenizer 98 | self.multiple_of = multiple_of 99 | self.cache_count = 0 100 | 101 | def pad_to_multiple(self, tensor, value): 102 | # taking advantage of tensor cores, perhaps 103 | multiple = self.multiple_of 104 | target_length = (tensor.size(0) + multiple - 1) // multiple * multiple 105 | return torch.nn.functional.pad( 106 | tensor, (0, target_length - tensor.size(0)), value=value 107 | ) 108 | 109 | def __call__(self, instances): 110 | input_ids, labels = tuple( 111 | [instance[key] for instance in instances] for key in ("input_ids", "labels") 112 | ) 113 | if self.multiple_of: 114 | input_ids = [ 115 | self.pad_to_multiple(val, self.tokenizer.pad_token_id) 116 | for val in input_ids 117 | ] 118 | labels = [self.pad_to_multiple(val, -100) for val in labels] 119 | 120 | input_ids = torch.nn.utils.rnn.pad_sequence( 121 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 122 | ) 123 | labels = torch.nn.utils.rnn.pad_sequence( 124 | labels, batch_first=True, padding_value=-100 125 | ) # -100 tells torch to ignore these tokens in loss computation. 
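# Note: -100 matches IGNORE_INDEX above and the default ignore_index of PyTorch's CrossEntropyLoss (which Hugging Face causal LM models use internally), so padded label positions contribute nothing to the loss.
# input_ids are padded with pad_token_id instead, so the attention mask built below can be recovered with .ne(pad_token_id).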
126 | 127 | if self.cache_count < 1: 128 | torch.cuda.empty_cache() 129 | self.cache_count += 1 130 | 131 | return dict( 132 | input_ids=input_ids, 133 | labels=labels, 134 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 135 | ) 136 | 137 | 138 | def load_data(path): 139 | if path.suffix == ".json": 140 | return load_json(path) 141 | elif path.suffix == ".jsonl": 142 | return load_jsonl(path) 143 | else: 144 | raise Exception( 145 | f"file type {path} not supported. Currently supported types are json, jsonl" 146 | ) 147 | 148 | 149 | def load_jsonl(path): 150 | data = [] 151 | with open(path, "r") as f: 152 | for line in f: 153 | json_object = json.loads(line) 154 | data.append(json_object) 155 | return data 156 | 157 | 158 | def load_json(path): 159 | """Loads a single json blob""" 160 | with open(path, "r") as f: 161 | data = json.load(f) 162 | return data 163 | 164 | 165 | def load_model(model_name_or_path): 166 | print(f"Rank : {os.environ['RANK']}, device: {torch.cuda.current_device()}") 167 | torch.cuda.set_device(int(os.environ['RANK'])) 168 | if model_name_or_path is None: 169 | model_name_or_path = DEFAULT_MODEL_NAME 170 | model = load_tensorizer(model_name_or_path, plaid_mode=False, cls=LlamaForCausalLM) 171 | return model 172 | 173 | 174 | def load_peft_model( 175 | model_name_or_path, lora_rank: int, lora_alpha: int, lora_dropout: float, lora_target_modules: Optional[Union[List[str], str]] 176 | ): 177 | if lora_target_modules: 178 | lora_target_modules = lora_target_modules.split(",") 179 | print("Using LoRA...") 180 | model = load_model(model_name_or_path) 181 | 182 | config = LoraConfig( 183 | r=lora_rank, 184 | lora_alpha=lora_alpha, 185 | lora_dropout=lora_dropout, 186 | bias="none", 187 | target_modules=lora_target_modules, 188 | task_type="CAUSAL_LM", 189 | ) 190 | print(f"LoRA config: {config}") 191 | model = get_peft_model(model, config) 192 | return model 193 | 194 | 195 | def train( 196 | train_data: Path = Input( 197 | description="path to data file to use for fine-tuning your model" 198 | ), 199 | eval_data: Path = Input( 200 | description="path to optional evaluation data file to use for model eval", 201 | default=None, 202 | ), 203 | weights: Path = Input( 204 | description="location of weights that are going to be fine-tuned", default=None 205 | ), 206 | train_batch_size: int = Input(description="batch size per GPU", default=4, ge=1), 207 | gradient_accumulation_steps: int = Input( 208 | description="number of training steps to update gradient for before performing a backward pass", 209 | default=8, 210 | ), 211 | lr_scheduler_type: str = Input( 212 | description="learning rate scheduler", 213 | default="cosine", 214 | choices=[ 215 | "linear", 216 | "cosine", 217 | "cosine_with_restarts", 218 | "polynomial", 219 | "inverse_sqrt", 220 | "constant", 221 | "constant_with_warmup", 222 | ], 223 | ), 224 | learning_rate: float = Input( 225 | description="learning rate, for learning!", default=2e-4, ge=0 226 | ), 227 | warmup_ratio: float = Input( 228 | description="pct of steps for a linear learning rate warmup", 229 | ge=0, 230 | le=0.5, 231 | default=0.03, 232 | ), 233 | num_train_epochs: int = Input( 234 | description="number of training epochs", ge=1, default=1 235 | ), 236 | max_steps: int = Input( 237 | description="number of steps to run training for, supersedes num_train_epochs", 238 | default=-1, 239 | ge=0, 240 | ), 241 | logging_steps: int = Input( 242 | description="number of steps between logging epoch & loss", default=1 243 | ), 244 | 
lora_rank: int = 8, 245 | lora_alpha: int = 16, 246 | lora_dropout: float = 0.1, 247 | lora_target_modules: Optional[Union[List[str], str]] = None, 248 | local_output_dir: str = None, 249 | local_rank: int = -1, 250 | deepspeed: str = None 251 | ) -> None: 252 | print("Loading model...") 253 | model = load_peft_model(weights, lora_rank, lora_alpha, lora_dropout, lora_target_modules) 254 | tokenizer = load_tokenizer() 255 | 256 | print(f"Loading dataset {train_data}...") 257 | print(train_data) 258 | train_data = load_data(train_data) 259 | p = CausalDatasetBuilder(tokenizer) 260 | train_dataset = p.construct_dataset(train_data) 261 | eval_dataset = None 262 | if eval_data: 263 | eval_data = load_json(eval_data) 264 | eval_dataset = p.construct_dataset(eval_data) 265 | 266 | torch.cuda.empty_cache() 267 | torch.set_float32_matmul_precision("high") 268 | 269 | 270 | print("Training...") 271 | trainer = Trainer( 272 | model=model, 273 | train_dataset=train_dataset, 274 | eval_dataset=eval_dataset, 275 | args=TrainingArguments( 276 | output_dir=CHECKPOINT_DIR, 277 | per_device_train_batch_size=train_batch_size, 278 | gradient_accumulation_steps=gradient_accumulation_steps, 279 | save_strategy="no", 280 | logging_steps=logging_steps, 281 | lr_scheduler_type=lr_scheduler_type, 282 | warmup_ratio=warmup_ratio, 283 | num_train_epochs=num_train_epochs, 284 | learning_rate=learning_rate, 285 | max_steps=max_steps, 286 | tf32=True, 287 | fp16=True, 288 | half_precision_backend="cuda_amp", 289 | deepspeed=deepspeed, 290 | local_rank=local_rank 291 | ), 292 | data_collator=SequenceDataCollator(tokenizer, 8), # depends on bf16 value 293 | ) 294 | trainer.train() 295 | print("model saving!") 296 | model.save_pretrained(local_output_dir) 297 | return 298 | 299 | 300 | if __name__ == "__main__": 301 | parser = argparse.ArgumentParser( 302 | description="Fine-tune a language model on a text dataset" 303 | ) 304 | parser.add_argument( 305 | "--train_data", type=Path, required=True, help="Path to the json dataset" 306 | ) 307 | parser.add_argument( 308 | "--eval_data", 309 | type=Path, 310 | required=False, 311 | help="Path to the json dataset", 312 | default=None, 313 | ) 314 | parser.add_argument( 315 | "--weights", 316 | type=str, 317 | default=None, 318 | help="The model class to fine-tune on HF or as a local path (e.g. 'google/flan-t5-xxl'", 319 | ) 320 | parser.add_argument( 321 | "--num_train_epochs", type=int, required=True, help="Number of training epochs" 322 | ) 323 | parser.add_argument( 324 | "--deepspeed", type=str, default=None, help="Path to deepspeed config file." 
325 | ) 326 | parser.add_argument( 327 | "--learning_rate", 328 | type=float, 329 | default=2e-5, 330 | help="Learning rate for the optimizer", 331 | ) 332 | parser.add_argument( 333 | "--train_batch_size", type=int, default=4, help="Batch size for training" 334 | ) 335 | parser.add_argument( 336 | "--warmup_ratio", 337 | type=float, 338 | default=0.03, 339 | help="Fraction of total training steps used for linear learning rate warmup", 340 | ) 341 | parser.add_argument( 342 | "--max_steps", 343 | type=int, 344 | default=0, 345 | help="Number of training steps to run, overrides num_train_epochs, useful for testing", 346 | ) 347 | parser.add_argument( 348 | "--gradient_accumulation_steps", 349 | type=int, 350 | default=8, 351 | help="Number of steps to accumulate gradients over before performing an optimizer update", 352 | ) 353 | parser.add_argument("--logging_steps", type=int, default=1) 354 | parser.add_argument( 355 | "--lr_scheduler_type", 356 | type=str, 357 | default="cosine", 358 | ) 359 | parser.add_argument( 360 | "--local_output_dir", 361 | type=str, 362 | help="Write directly to this local path", 363 | required=True, 364 | ) 365 | parser.add_argument( 366 | "--local_rank", 367 | type=int, 368 | default=-1, 369 | help="Provided by deepspeed to identify which instance this process is when performing multi-GPU training.", 370 | ) 371 | parser.add_argument( 372 | "--lora_rank", 373 | type=int, 374 | default=8, 375 | help="Rank of the LoRA update matrices", 376 | ) 377 | parser.add_argument( 378 | "--lora_alpha", 379 | type=int, 380 | default=16, 381 | help="Alpha parameter for scaling LoRA weights; weights are scaled by alpha/rank", 382 | ) 383 | parser.add_argument( 384 | "--lora_dropout", 385 | type=float, 386 | default=0.4, 387 | help="Dropout probability for the LoRA layers", 388 | ) 389 | parser.add_argument( 390 | "--lora_target_modules", 391 | type=str, 392 | default=None, 393 | help="Comma-separated list of lora modules to target, i.e. 'q_proj,v_proj'. Leave blank for default" 394 | ) 395 | some_args = parser.parse_args() 396 | train(**vars(some_args)) 397 | --------------------------------------------------------------------------------
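Usage sketch (not a file in this repository): a minimal, hedged example of how CausalDatasetBuilder in training/trainer.py masks prompt tokens when train_on_prompt=False. It assumes the repo root is on sys.path and that the llama_weights/tokenizer files are available locally; the sample data below is illustrative only.

import sys
sys.path.append("/src/")  # repo root inside the cog image; adjust for a local checkout

from config import load_tokenizer
from training.trainer import CausalDatasetBuilder, IGNORE_INDEX

tokenizer = load_tokenizer()
builder = CausalDatasetBuilder(tokenizer, train_on_prompt=False)

# prompt/completion pairs, the format the trainer expects from the training json
data = [{"prompt": "Translate to French: cat", "completion": "chat"}]
dataset = builder.construct_dataset(data)

example = dataset[0]
# prompt positions in labels are IGNORE_INDEX (-100), so only completion (and EOS) tokens are scored
num_masked = (example["labels"] == IGNORE_INDEX).sum().item()
print(example["input_ids"].shape, num_masked)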