├── .dockerignore ├── .gitignore ├── LICENSE ├── README.md ├── cog.yaml ├── config.py ├── ds_config ├── ds_flan_t5_z3_config_bf16_no_offload.json └── ds_z3_bf16_config.json ├── examples └── alpaca │ ├── README.md │ ├── process_data.py │ └── replicate_alpaca_data.json ├── predict.py ├── scripts ├── cog_push_all.sh ├── train_multi_gpu.sh └── train_single_gpu.sh ├── select_model.py ├── subclass.py ├── templates ├── cog_template.yaml └── config_template.py ├── train.py └── training ├── __init__.py └── trainer.py /.dockerignore: -------------------------------------------------------------------------------- 1 | flan-t5** 2 | checkpoints/** 3 | examples/** 4 | # modify pretrained_weights if you want to build weights into the image 5 | pretrained_weights/** 6 | tmp/** -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/** 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository is an implementation of a fine-tunable [FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5) as a Cog model. [Cog packages machine learning models as standard containers.](https://github.com/replicate/cog) 2 | 3 | The model can be fine tuned using `cog train`, and can serve predictions using `cog predict`. 
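Training data (described under Fine-tuning below) is a JSON list in which every record has a `prompt` and a `completion` field; the `examples/alpaca` folder contains a full dataset in this format. The snippet below is only an illustrative sketch of how such a file might be assembled; the records and the `my_train_data.json` filename are invented for this example and are not part of the repo.

```
import json

# Two invented records in the {"prompt": ..., "completion": ...} shape that cog train expects.
examples = [
    {"prompt": "Q: Is the sky blue on a clear day? Answer yes or no.", "completion": "yes"},
    {"prompt": "Translate to French: good morning", "completion": "bonjour"},
]

with open("my_train_data.json", "w") as f:
    json.dump(examples, f, indent=2)
```

A file like this (or a URL pointing to one) is what the `train_data` input in the training command below expects.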
4 | 5 | ### Fine-tuning 6 | 7 | All that `cog train` requires is an input dataset consisting of a JSON list where each example has a 'prompt' and 'completion' field. The model will be fine-tuned to produce 'completion' given 'prompt'. Here's an example command to train the model from the root directory: 8 | 9 | ``` 10 | cog train -i train_data="https://storage.googleapis.com/dan-scratch-public/fine-tuning/70k_samples_prompt.jsonl" -i gradient_accumulation_steps=8 -i learning_rate=2e-5 -i num_train_epochs=3 -i logging_steps=2 -i train_batch_size=4 11 | ``` 12 | 13 | Of the params above, the only required one is `train_data`; you can pass any of the others to adjust how the model is trained. See the 'examples' folder for an example dataset. 14 | 15 | ### Inference 16 | 17 | To generate text given input prompts, simply run the `cog predict` command below: 18 | ``` 19 | cog predict -i prompt="Q: Answer the following yes/no question by reasoning step-by-step. Can a dog drive a car?" 20 | ``` 21 | 22 | Note that the first prediction run will download weights for the selected model from Hugging Face to a local directory; subsequent predictions will be faster. 23 | 24 | ### Model selection 25 | 26 | This project can fine-tune or run inference for any of the FLAN family of models (note that larger models require high-performance GPUs). Just run `cog run python select_model.py --model_name ["flan-t5-small" "flan-t5-base" "flan-t5-large" "flan-t5-xl" "flan-t5-xxl" "flan-ul2"]`, and then you can run all other `cog` commands with the appropriate model. 27 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | # set to true if your model requires a GPU 6 | gpu: true 7 | cuda: "11.7" 8 | 9 | # python version in the form '3.8' or '3.8.12' 10 | python_version: "3.8" 11 | 12 | # a list of packages in the format <package-name>==<version> 13 | python_packages: 14 | - "numpy==1.24.2" 15 | - "torch==2.0" 16 | - "transformers==4.27.4" 17 | - "accelerate==0.18.0" 18 | - "peft==0.2.0" 19 | - "sentencepiece==0.1.97" 20 | - "tensorizer==1.0.1" 21 | - "jinja2==3.1.2" 22 | - "deepspeed" 23 | 24 | # predict.py defines how predictions are run on your model 25 | predict: "predict.py:Predictor" 26 | train: "train.py:train" -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from transformers import T5Tokenizer 2 | 3 | HUGGINGFACE_MODEL_NAME = "google/flan-t5-base" 4 | 5 | 6 | def load_tokenizer(): 7 | """Same tokenizer, agnostic to tensorized weights/etc""" 8 | return T5Tokenizer.from_pretrained( 9 | HUGGINGFACE_MODEL_NAME, cache_dir="pretrained_weights" 10 | ) 11 | -------------------------------------------------------------------------------- /ds_config/ds_flan_t5_z3_config_bf16_no_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "optimizer": { 6 | "type": "AdamW", 7 | "params": { 8 | "lr": "auto", 9 | "betas": "auto", 10 | "eps": "auto", 11 | "weight_decay": "auto" 12 | } 13 | }, 14 | "scheduler": { 15 | "type": "WarmupLR", 16 | "params": { 17 | "warmup_min_lr": "auto", 18 | "warmup_max_lr": "auto", 19 | "warmup_num_steps": "auto" 20 | } 21 | }, 22 |
"zero_optimization": { 23 | "stage": 3, 24 | "overlap_comm": true, 25 | "contiguous_gradients": true, 26 | "sub_group_size": 1e9, 27 | "reduce_bucket_size": "auto", 28 | "stage3_prefetch_bucket_size": "auto", 29 | "stage3_param_persistence_threshold": "auto", 30 | "stage3_max_live_parameters": 1e9, 31 | "stage3_max_reuse_distance": 1e9, 32 | "stage3_gather_16bit_weights_on_model_save": true 33 | }, 34 | "gradient_accumulation_steps": "auto", 35 | "gradient_clipping": "auto", 36 | "steps_per_print": 2000, 37 | "train_batch_size": "auto", 38 | "train_micro_batch_size_per_gpu": "auto", 39 | "wall_clock_breakdown": false 40 | } -------------------------------------------------------------------------------- /ds_config/ds_z3_bf16_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "optimizer": { 6 | "type": "AdamW", 7 | "params": { 8 | "lr": "auto", 9 | "betas": "auto", 10 | "eps": "auto", 11 | "weight_decay": "auto" 12 | } 13 | }, 14 | "scheduler": { 15 | "type": "WarmupLR", 16 | "params": { 17 | "warmup_min_lr": "auto", 18 | "warmup_max_lr": "auto", 19 | "warmup_num_steps": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 3, 24 | "offload_optimizer": { 25 | "device": "cpu", 26 | "pin_memory": true 27 | }, 28 | "offload_param": { 29 | "device": "cpu", 30 | "pin_memory": true 31 | }, 32 | "overlap_comm": true, 33 | "contiguous_gradients": true, 34 | "sub_group_size": 1e9, 35 | "reduce_bucket_size": "auto", 36 | "stage3_prefetch_bucket_size": "auto", 37 | "stage3_param_persistence_threshold": "auto", 38 | "stage3_max_live_parameters": 1e9, 39 | "stage3_max_reuse_distance": 1e9, 40 | "stage3_gather_16bit_weights_on_model_save": true 41 | }, 42 | "gradient_accumulation_steps": "auto", 43 | "gradient_clipping": "auto", 44 | "steps_per_print": 2000, 45 | "train_batch_size": "auto", 46 | "train_micro_batch_size_per_gpu": "auto", 47 | "wall_clock_breakdown": false 48 | } -------------------------------------------------------------------------------- /examples/alpaca/README.md: -------------------------------------------------------------------------------- 1 | Example code for parsing the dataset needed to train [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca). 2 | 3 | This contains both a function, `process_data.py`, which shows how to transform the [given alpaca data](https://github.com/gururise/AlpacaDataCleaned) into the format expected by `cog train`. It also contains an example parsed dataset as a reference for that `{'prompt': ..., 'completion':...}` format. -------------------------------------------------------------------------------- /examples/alpaca/process_data.py: -------------------------------------------------------------------------------- 1 | from transformers import T5Tokenizer 2 | import json 3 | 4 | PROMPT_DICT = { 5 | "prompt_input": ( 6 | "Below is an instruction that describes a task, paired with an input that provides further context. " 7 | "Write a response that appropriately completes the request.\n\n" 8 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 9 | ), 10 | "prompt_no_input": ( 11 | "Below is an instruction that describes a task. " 12 | "Write a response that appropriately completes the request.\n\n" 13 | "### Instruction:\n{instruction}\n\n### Response:" 14 | ), 15 | } 16 | 17 | class Preprocessor: 18 | """Simple class to parse alpaca data into format expected by trainer. 
Run this offline to build your dataset.""" 19 | 20 | def __init__(self, tokenizer): 21 | self.prompt_dict = PROMPT_DICT 22 | self.tokenizer = tokenizer 23 | 24 | def batch_tokenize(self, texts): 25 | """Tokenizes text. Presently doesn't pad inputs, just returns input ids.""" 26 | tokenized = [ 27 | self.tokenizer( 28 | prompt, 29 | return_tensors="pt", 30 | padding="longest", 31 | ).input_ids 32 | for prompt in texts 33 | ] 34 | return tokenized 35 | 36 | def make_prompt(self, input_row): 37 | if len(input_row["input"]) > 1: 38 | return self.prompt_dict["prompt_input"].format_map(input_row) 39 | return self.prompt_dict["prompt_no_input"].format_map(input_row) 40 | 41 | def make_short_prompt(self, input_row): 42 | if len(input_row["input"]) > 1: 43 | return f'''{input_row['instruction']}\n{input_row['input']}''' 44 | return input_row['instruction'] 45 | 46 | def construct_dataset(self, input_data): 47 | prompts = [self.make_short_prompt(val) for val in input_data] 48 | return [{'prompt':val[0], 'completion':val[1]} for val in zip(prompts, [val["output"] for val in input_data])] 49 | 50 | if __name__ == '__main__': 51 | proc = Preprocessor(T5Tokenizer.from_pretrained('google/flan-t5-xl')) 52 | with open('alpaca_data.json', 'r') as f: 53 | data = json.load(f) 54 | 55 | data_out = proc.construct_dataset(data) 56 | 57 | with open('short_alpaca_data.json', 'w') as f: 58 | json.dump(data_out, f, indent=2) 59 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | import subprocess 4 | import time 5 | from collections import OrderedDict 6 | from typing import Optional 7 | 8 | import torch 9 | from cog import BasePredictor, ConcatenateIterator, Input, Path 10 | from tensorizer import TensorDeserializer 11 | from tensorizer.utils import no_init_or_tensor 12 | from transformers import AutoConfig, T5ForConditionalGeneration 13 | 14 | from config import HUGGINGFACE_MODEL_NAME, load_tokenizer 15 | from subclass import YieldingT5 16 | 17 | 18 | class Predictor(BasePredictor): 19 | def setup(self, weights: Optional[Path] = None): 20 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 21 | if weights is not None and weights.name == "weights": 22 | # bugfix 23 | weights = None 24 | if weights is None: 25 | self.model = self.load_huggingface_model(weights=HUGGINGFACE_MODEL_NAME) 26 | elif hasattr(weights, "filename") and "tensors" in weights.filename: 27 | self.model = self.load_tensorizer(weights) 28 | elif hasattr(weights, "suffix") and "tensors" in weights.suffix: 29 | self.model = self.load_tensorizer(weights) 30 | else: 31 | self.model = self.load_huggingface_model(weights=weights) 32 | 33 | self.tokenizer = load_tokenizer() 34 | 35 | def load_huggingface_model(self, weights=None): 36 | st = time.time() 37 | print(f"loading weights from {weights} w/o tensorizer") 38 | model = YieldingT5.from_pretrained( 39 | weights, cache_dir="pretrained_weights", torch_dtype=torch.float16 40 | ) 41 | model.to(self.device) 42 | print(f"weights loaded in {time.time() - st}") 43 | return model 44 | 45 | def load_tensorizer(self, weights): 46 | st = time.time() 47 | weights = str(weights) 48 | pattern = r"https://pbxt\.replicate\.delivery/([^/]+/[^/]+)" 49 | match = re.search(pattern, weights) 50 | if match: 51 | weights = f"gs://replicate-files/{match.group(1)}" 52 | 53 | print(f"deserializing weights") 54 | local_weights = "/src/flan_tensors" 55 | 
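# Copy the tensorized weights to a local path with the gcloud CLI baked into the image,
# then deserialize them directly into the uninitialized module created further below.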
command = f"/gc/google-cloud-sdk/bin/gcloud storage cp {weights} {local_weights}".split() 56 | res = subprocess.run(command) 57 | if res.returncode != 0: 58 | raise Exception( 59 | f"gcloud storage cp command failed with return code {res.returncode}: {res.stderr.decode('utf-8')}" 60 | ) 61 | config = AutoConfig.from_pretrained(HUGGINGFACE_MODEL_NAME) 62 | 63 | logging.disable(logging.WARN) 64 | model = no_init_or_tensor( 65 | lambda: YieldingT5.from_pretrained( 66 | None, config=config, state_dict=OrderedDict() 67 | ) 68 | ) 69 | logging.disable(logging.NOTSET) 70 | 71 | des = TensorDeserializer(local_weights, plaid_mode=True) 72 | des.load_into_module(model) 73 | print(f"weights loaded in {time.time() - st}") 74 | return model 75 | 76 | def predict( 77 | self, 78 | prompt: str = Input(description=f"Prompt to send to FLAN-T5."), 79 | max_length: int = Input( 80 | description="Maximum number of tokens to generate. A word is generally 2-3 tokens", 81 | ge=1, 82 | default=50, 83 | ), 84 | temperature: float = Input( 85 | description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic, 0.75 is a good starting value.", 86 | ge=0.01, 87 | le=5, 88 | default=0.75, 89 | ), 90 | top_p: float = Input( 91 | description="When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens", 92 | ge=0.01, 93 | le=1.0, 94 | default=1.0, 95 | ), 96 | repetition_penalty: float = Input( 97 | description="Penalty for repeated words in generated text; 1 is no penalty, values greater than 1 discourage repetition, less than 1 encourage it.", 98 | ge=0.01, 99 | le=5, 100 | default=1, 101 | ), 102 | debug: bool = Input( 103 | description="provide debugging output in logs", default=False 104 | ), 105 | ) -> ConcatenateIterator[str]: 106 | input = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) 107 | 108 | with torch.inference_mode(): 109 | first_token_yielded = False 110 | prev_ids = [] 111 | for output in self.model.generate( 112 | input, 113 | max_length=max_length, 114 | do_sample=True, 115 | temperature=temperature, 116 | top_p=top_p, 117 | repetition_penalty=repetition_penalty, 118 | ): 119 | cur_id = output.item() 120 | 121 | # in order to properly handle spaces, we need to do our own tokenizing. Fun! 122 | # we're building up a buffer of sub-word / punctuation tokens until we hit a space, and then yielding whole words + punctuation. 123 | cur_token = self.tokenizer.convert_ids_to_tokens(cur_id) 124 | 125 | # skip initial newline, which this almost always yields. hack - newline id = 13. 126 | if not first_token_yielded and not prev_ids and cur_id == 13: 127 | continue 128 | 129 | # underscore means a space, means we yield previous tokens 130 | if cur_token.startswith("▁"): # this is not a standard underscore. 
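# e.g. SentencePiece splits "Can a dog" into pieces like "▁Can", "▁a", "▁dog";
# a leading "▁" marks the start of a new word, so the buffered ids can be flushed.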
131 | # first token 132 | if not prev_ids: 133 | prev_ids = [cur_id] 134 | continue 135 | 136 | # there are tokens to yield 137 | else: 138 | token = " " + self.tokenizer.decode(prev_ids) 139 | prev_ids = [cur_id] 140 | 141 | if not first_token_yielded: 142 | # no leading space for first token 143 | token = token.strip() 144 | first_token_yielded = True 145 | yield token 146 | else: 147 | prev_ids.append(cur_id) 148 | continue 149 | 150 | # remove any special tokens such as 151 | token = " " + self.tokenizer.decode(prev_ids, skip_special_tokens=True) 152 | if not first_token_yielded: 153 | # no leading space for first token 154 | token = token.strip() 155 | first_token_yielded = True 156 | yield token 157 | 158 | if debug: 159 | print(f"cur memory: {torch.cuda.memory_allocated()}") 160 | print(f"max allocated: {torch.cuda.max_memory_allocated()}") 161 | print(f"peak memory: {torch.cuda.max_memory_reserved()}") 162 | 163 | 164 | class EightBitPredictor(Predictor): 165 | """subclass s.t. we can configure whether a model is loaded in 8bit mode from cog.yaml""" 166 | 167 | def setup(self, weights: Optional[Path] = None): 168 | if weights is not None and weights.name == "weights": 169 | # bugfix 170 | weights = None 171 | # TODO: fine-tuned 8bit weights. 172 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 173 | self.model = T5ForConditionalGeneration.from_pretrained( 174 | HUGGINGFACE_MODEL_NAME, load_in_8bit=True, device_map="auto" 175 | ) 176 | self.tokenizer = load_tokenizer() 177 | -------------------------------------------------------------------------------- /scripts/cog_push_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model_names=("flan-t5-small" "flan-t5-base" "flan-t5-large" "flan-t5-xl" "flan-t5-xxl" "flan-ul2") 4 | 5 | for model_name in "${model_names[@]}"; do 6 | echo "Pushing model: $model_name" 7 | cog run python select_model.py --model_name $model_name 8 | cog login --token-stdin <<< "$COG_TOKEN" 9 | cog push r8.im/replicate/$model_name 10 | done 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /scripts/train_multi_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | torchrun --nproc_per_node=4 --master_port=9292 train.py \ 4 | --train_data ./short_alpaca_data.json \ 5 | --num_train_epochs 3 \ 6 | --learning_rate 5e-4 \ 7 | --train_batch_size 1 \ 8 | \ 9 | --gradient_accumulation_steps 64 \ 10 | --logging_steps 2 \ 11 | --warmup_ratio 0.03 12 | 13 | # xl - batch size = 6 14 | # xxl 15 | # deepspeed --num_gpus 4 --master_port=9292 train.py \ 16 | # --train_data ./short_alpaca_data.json \ 17 | # --num_train_epochs 3 \ 18 | # --learning_rate 5e-4 \ 19 | # --train_batch_size 8 \ 20 | # --gradient_accumulation_steps 8 \ 21 | # --logging_steps 2 \ 22 | # --warmup_ratio 0.03 -------------------------------------------------------------------------------- /scripts/train_single_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py \ 4 | --model_name_or_path google/flan-t5-base \ 5 | --data_path ./replicate_alpaca_data.json \ 6 | --num_train_epochs 3 \ 7 | --learning_rate 3e-4 \ 8 | --train_batch_size 8 \ 9 | --warmup_ratio 0.03 \ 10 | --max_steps 10 # number of steps before returning, mostly useful for testing performance 11 | -------------------------------------------------------------------------------- 
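Aside on the token streaming in predict.py above: the predictor buffers sub-word ids and only yields a whole word once the next SentencePiece piece begins with "▁". The sketch below shows that scheme in isolation; the file name and helper function are hypothetical, assume only a SentencePiece-based tokenizer such as FLAN-T5's, and are not part of this repository.

```
# streaming_sketch.py -- illustrative only; mirrors the word-buffering idea in predict.py.
from transformers import T5Tokenizer


def stream_words(tokenizer, token_ids):
    """Yield whole words from a stream of SentencePiece token ids."""
    buffered = []
    first = True
    for cur_id in token_ids:
        piece = tokenizer.convert_ids_to_tokens(cur_id)
        if piece.startswith("▁") and buffered:
            # A new word is starting: decode and emit everything buffered so far.
            word = tokenizer.decode(buffered)
            yield word if first else " " + word
            first = False
            buffered = [cur_id]
        else:
            buffered.append(cur_id)
    if buffered:
        word = tokenizer.decode(buffered, skip_special_tokens=True)
        yield word if first else " " + word


if __name__ == "__main__":
    tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
    ids = tok("hello streaming world", return_tensors="pt").input_ids[0].tolist()
    print(list(stream_words(tok, ids)))  # expected: ['hello', ' streaming', ' world']
```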
/select_model.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | import stat 4 | from jinja2 import Template 5 | 6 | # Flan configs, modify as needed 7 | CONFIGS = { 8 | "flan-t5-small": { 9 | "cog_yaml_parameters": {"predictor":"predict.py:Predictor"}, 10 | "config_py_parameters": {"model_name": "google/flan-t5-small"} 11 | }, 12 | "flan-t5-base": { 13 | "cog_yaml_parameters": {"predictor":"predict.py:Predictor"}, 14 | "config_py_parameters": {"model_name": "google/flan-t5-base"} 15 | }, 16 | "flan-t5-large": { 17 | "cog_yaml_parameters": {"predictor":"predict.py:Predictor"}, 18 | "config_py_parameters": {"model_name": "google/flan-t5-large"} 19 | }, 20 | "flan-t5-xl": { 21 | "cog_yaml_parameters": {"predictor":"predict.py:Predictor"}, 22 | "config_py_parameters": {"model_name": "google/flan-t5-xl"} 23 | }, 24 | "flan-t5-xxl": { 25 | "cog_yaml_parameters": {"predictor":"predict.py:Predictor"}, 26 | "config_py_parameters": {"model_name": "google/flan-t5-xxl"} 27 | }, 28 | "flan-ul2": { 29 | "cog_yaml_parameters": {"predictor":"predict.py:EightBitPredictor", "extra_deps":'''- "bitsandbytes==0.37.2"'''}, 30 | "config_py_parameters": {"model_name": "google/flan-ul2"} 31 | } 32 | } 33 | 34 | def _reset_file(file_path): 35 | if os.path.exists(file_path): 36 | os.remove(file_path) 37 | 38 | 39 | def write_one_config(template_fpath: str, fname_out: str, config: dict): 40 | with open(template_fpath, "r") as f: 41 | template_content = f.read() 42 | base_template = Template(template_content) 43 | 44 | _reset_file(fname_out) 45 | 46 | with open(fname_out, "w") as f: 47 | f.write(base_template.render(config)) 48 | 49 | # Give all users write access to the resulting generated file.
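# (S_IWUSR | S_IWGRP | S_IWOTH is 0o222, i.e. the same bits `chmod a+w` adds.)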
50 | current_permissions = os.stat(fname_out).st_mode 51 | new_permissions = current_permissions | stat.S_IWUSR | stat.S_IWGRP | stat.S_IWOTH 52 | os.chmod(fname_out, new_permissions) 53 | 54 | 55 | def write_configs(model_name): 56 | master_config = CONFIGS[model_name] 57 | write_one_config("templates/cog_template.yaml", "cog.yaml", master_config['cog_yaml_parameters']) 58 | write_one_config("templates/config_template.py", "config.py", master_config['config_py_parameters']) 59 | 60 | 61 | if __name__ == '__main__': 62 | parser = ArgumentParser() 63 | parser.add_argument("--model_name", default="flan-t5-base", help="name of the flan-t5 model you want to configure cog for") 64 | args = parser.parse_args() 65 | 66 | write_configs(args.model_name) -------------------------------------------------------------------------------- /subclass.py: -------------------------------------------------------------------------------- 1 | """sampling code pulled from Transformers & slightly modified to stream tokens""" 2 | import warnings 3 | from typing import List, Optional, Union 4 | 5 | import torch 6 | import torch.distributed as dist 7 | from torch import nn 8 | 9 | from transformers.generation.logits_process import LogitsProcessorList 10 | from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria 11 | from transformers.generation.utils import SampleOutput, SampleDecoderOnlyOutput, SampleEncoderDecoderOutput 12 | 13 | from transformers import T5ForConditionalGeneration 14 | 15 | class YieldingT5(T5ForConditionalGeneration): 16 | """Overriding sample to yield tokens""" 17 | def sample( 18 | self, 19 | input_ids: torch.LongTensor, 20 | logits_processor: Optional[LogitsProcessorList] = None, 21 | stopping_criteria: Optional[StoppingCriteriaList] = None, 22 | logits_warper: Optional[LogitsProcessorList] = None, 23 | max_length: Optional[int] = None, 24 | pad_token_id: Optional[int] = None, 25 | eos_token_id: Optional[Union[int, List[int]]] = None, 26 | output_attentions: Optional[bool] = None, 27 | output_hidden_states: Optional[bool] = None, 28 | output_scores: Optional[bool] = None, 29 | return_dict_in_generate: Optional[bool] = None, 30 | synced_gpus: Optional[bool] = False, 31 | **model_kwargs, 32 | ) -> Union[SampleOutput, torch.LongTensor]: 33 | r""" 34 | Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and 35 | can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. 36 | 37 | 38 | 39 | In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead. 40 | For an overview of generation strategies and code examples, check the [following 41 | guide](./generation_strategies). 42 | 43 | 44 | 45 | Parameters: 46 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 47 | The sequence used as a prompt for the generation. 48 | logits_processor (`LogitsProcessorList`, *optional*): 49 | An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] 50 | used to modify the prediction scores of the language modeling head applied at each generation step. 51 | stopping_criteria (`StoppingCriteriaList`, *optional*): 52 | An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] 53 | used to tell if the generation loop should stop. 54 | logits_warper (`LogitsProcessorList`, *optional*): 55 | An instance of [`LogitsProcessorList`]. 
List of instances of class derived from [`LogitsWarper`] used 56 | to warp the prediction score distribution of the language modeling head applied before multinomial 57 | sampling at each generation step. 58 | max_length (`int`, *optional*, defaults to 20): 59 | **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated 60 | tokens. The maximum length of the sequence to be generated. 61 | pad_token_id (`int`, *optional*): 62 | The id of the *padding* token. 63 | eos_token_id (`int`, *optional*): 64 | The id of the *end-of-sequence* token. 65 | output_attentions (`bool`, *optional*, defaults to `False`): 66 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 67 | returned tensors for more details. 68 | output_hidden_states (`bool`, *optional*, defaults to `False`): 69 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 70 | for more details. 71 | output_scores (`bool`, *optional*, defaults to `False`): 72 | Whether or not to return the prediction scores. See `scores` under returned tensors for more details. 73 | return_dict_in_generate (`bool`, *optional*, defaults to `False`): 74 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 75 | synced_gpus (`bool`, *optional*, defaults to `False`): 76 | Whether to continue running the while loop until max_length (needed for ZeRO stage 3) 77 | model_kwargs: 78 | Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is 79 | an encoder-decoder model the kwargs should include `encoder_outputs`. 80 | 81 | Return: 82 | [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`: 83 | A `torch.LongTensor` containing the generated tokens (default behaviour) or a 84 | [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and 85 | `return_dict_in_generate=True` or a [`~generation.SampleEncoderDecoderOutput`] if 86 | `model.config.is_encoder_decoder=True`. 87 | 88 | Examples: 89 | 90 | ```python 91 | >>> from transformers import ( 92 | ... AutoTokenizer, 93 | ... AutoModelForCausalLM, 94 | ... LogitsProcessorList, 95 | ... MinLengthLogitsProcessor, 96 | ... TopKLogitsWarper, 97 | ... TemperatureLogitsWarper, 98 | ... StoppingCriteriaList, 99 | ... MaxLengthCriteria, 100 | ... ) 101 | >>> import torch 102 | 103 | >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") 104 | >>> model = AutoModelForCausalLM.from_pretrained("gpt2") 105 | 106 | >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token 107 | >>> model.config.pad_token_id = model.config.eos_token_id 108 | >>> model.generation_config.pad_token_id = model.config.eos_token_id 109 | 110 | >>> input_prompt = "Today is a beautiful day, and" 111 | >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids 112 | 113 | >>> # instantiate logits processors 114 | >>> logits_processor = LogitsProcessorList( 115 | ... [ 116 | ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id), 117 | ... ] 118 | ... ) 119 | >>> # instantiate logits processors 120 | >>> logits_warper = LogitsProcessorList( 121 | ... [ 122 | ... TopKLogitsWarper(50), 123 | ... TemperatureLogitsWarper(0.7), 124 | ... ] 125 | ... 
) 126 | 127 | >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) 128 | 129 | >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT 130 | >>> outputs = model.sample( 131 | ... input_ids, 132 | ... logits_processor=logits_processor, 133 | ... logits_warper=logits_warper, 134 | ... stopping_criteria=stopping_criteria, 135 | ... ) 136 | 137 | >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) 138 | ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] 139 | ```""" 140 | # init values 141 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() 142 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() 143 | if max_length is not None: 144 | warnings.warn( 145 | "`max_length` is deprecated in this function, use" 146 | " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", 147 | UserWarning, 148 | ) 149 | stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) 150 | logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() 151 | pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id 152 | eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id 153 | if isinstance(eos_token_id, int): 154 | eos_token_id = [eos_token_id] 155 | output_scores = output_scores if output_scores is not None else self.generation_config.output_scores 156 | output_attentions = ( 157 | output_attentions if output_attentions is not None else self.generation_config.output_attentions 158 | ) 159 | output_hidden_states = ( 160 | output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states 161 | ) 162 | return_dict_in_generate = ( 163 | return_dict_in_generate 164 | if return_dict_in_generate is not None 165 | else self.generation_config.return_dict_in_generate 166 | ) 167 | 168 | # init attention / hidden states / scores tuples 169 | scores = () if (return_dict_in_generate and output_scores) else None 170 | decoder_attentions = () if (return_dict_in_generate and output_attentions) else None 171 | cross_attentions = () if (return_dict_in_generate and output_attentions) else None 172 | decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None 173 | 174 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states 175 | if return_dict_in_generate and self.config.is_encoder_decoder: 176 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None 177 | encoder_hidden_states = ( 178 | model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None 179 | ) 180 | 181 | # keep track of which sequences are already finished 182 | unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) 183 | 184 | this_peer_finished = False # used by synced_gpus only 185 | # auto-regressive generation 186 | while True: 187 | if synced_gpus: 188 | # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. 
189 | # The following logic allows an early break if all peers finished generating their sequence 190 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) 191 | # send 0.0 if we finished, 1.0 otherwise 192 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) 193 | # did all peers finish? the reduced sum will be 0.0 then 194 | if this_peer_finished_flag.item() == 0.0: 195 | break 196 | 197 | # prepare model inputs 198 | model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) 199 | 200 | # forward pass to get next token 201 | outputs = self( 202 | **model_inputs, 203 | return_dict=True, 204 | output_attentions=output_attentions, 205 | output_hidden_states=output_hidden_states, 206 | ) 207 | 208 | if synced_gpus and this_peer_finished: 209 | continue # don't waste resources running the code we don't need 210 | 211 | next_token_logits = outputs.logits[:, -1, :] 212 | 213 | # pre-process distribution 214 | next_token_scores = logits_processor(input_ids, next_token_logits) 215 | next_token_scores = logits_warper(input_ids, next_token_scores) 216 | 217 | # Store scores, attentions and hidden_states when required 218 | if return_dict_in_generate: 219 | if output_scores: 220 | scores += (next_token_scores,) 221 | if output_attentions: 222 | decoder_attentions += ( 223 | (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) 224 | ) 225 | if self.config.is_encoder_decoder: 226 | cross_attentions += (outputs.cross_attentions,) 227 | 228 | if output_hidden_states: 229 | decoder_hidden_states += ( 230 | (outputs.decoder_hidden_states,) 231 | if self.config.is_encoder_decoder 232 | else (outputs.hidden_states,) 233 | ) 234 | 235 | # sample 236 | probs = nn.functional.softmax(next_token_scores, dim=-1) 237 | next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) 238 | 239 | # finished sentences should have their next token be a padding token 240 | if eos_token_id is not None: 241 | if pad_token_id is None: 242 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") 243 | next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) 244 | 245 | # update generated ids, model inputs, and length for next step 246 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) 247 | model_kwargs = self._update_model_kwargs_for_generation( 248 | outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder 249 | ) 250 | 251 | # if eos_token was found in one sentence, set sentence to finished 252 | if eos_token_id is not None: 253 | unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) 254 | 255 | # stop when each sentence is finished, or if we exceed the maximum length 256 | if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): 257 | if not synced_gpus: 258 | break 259 | else: 260 | this_peer_finished = True 261 | else: 262 | yield next_tokens 263 | 264 | if return_dict_in_generate: 265 | if self.config.is_encoder_decoder: 266 | yield SampleEncoderDecoderOutput( 267 | sequences=input_ids, 268 | scores=scores, 269 | encoder_attentions=encoder_attentions, 270 | encoder_hidden_states=encoder_hidden_states, 271 | decoder_attentions=decoder_attentions, 272 | cross_attentions=cross_attentions, 273 | decoder_hidden_states=decoder_hidden_states, 274 | ) 275 | else: 276 | yield SampleDecoderOnlyOutput( 277 | sequences=input_ids, 278 | 
scores=scores, 279 | attentions=decoder_attentions, 280 | hidden_states=decoder_hidden_states, 281 | ) 282 | else: 283 | yield next_tokens -------------------------------------------------------------------------------- /templates/cog_template.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | # set to true if your model requires a GPU 6 | gpu: true 7 | cuda: "11.7" 8 | 9 | # python version in the form '3.8' or '3.8.12' 10 | python_version: "3.8" 11 | 12 | # a list of packages in the format == 13 | python_packages: 14 | - "numpy==1.24.2" 15 | - "torch==1.13.1" 16 | - "transformers==4.27.4" 17 | - "accelerate==0.18.0" 18 | - "peft==0.2.0" 19 | - "sentencepiece==0.1.97" 20 | - "tensorizer==1.0.1" 21 | - "jinja2==3.1.2" 22 | - "deepspeed==0.8.3" 23 | {{extra_deps}} 24 | 25 | run: 26 | - "mkdir /gc && cd /gc && curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-426.0.0-linux-x86_64.tar.gz && tar -xf google-cloud-cli-426.0.0-linux-x86_64.tar.gz && ./google-cloud-sdk/install.sh -q" 27 | 28 | # predict.py defines how predictions are run on your model 29 | predict: "{{predictor}}" 30 | train: "train.py:train" 31 | -------------------------------------------------------------------------------- /templates/config_template.py: -------------------------------------------------------------------------------- 1 | from transformers import T5Tokenizer 2 | 3 | HUGGINGFACE_MODEL_NAME = "{{model_name}}" 4 | 5 | 6 | def load_tokenizer(): 7 | """Same tokenizer, agnostic from tensorized weights/etc""" 8 | return T5Tokenizer.from_pretrained( 9 | HUGGINGFACE_MODEL_NAME, cache_dir="pretrained_weights" 10 | ) 11 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from subprocess import call 4 | from typing import Optional 5 | 6 | import torch 7 | from cog import BaseModel, Input, Path 8 | from tensorizer import TensorSerializer 9 | from transformers import T5ForConditionalGeneration 10 | 11 | from config import HUGGINGFACE_MODEL_NAME 12 | 13 | MODEL_OUT = "/src/tuned_weights.tensors" 14 | CHECKPOINT_DIR = "checkpoints" 15 | SAVE_STRATEGY = "epoch" 16 | DIST_OUT_DIR = "tmp/model" 17 | 18 | 19 | class TrainingOutput(BaseModel): 20 | weights: Path 21 | 22 | 23 | def train( 24 | train_data: Path = Input( 25 | description="path to data file to use for fine-tuning your model" 26 | ), 27 | eval_data: Path = Input( 28 | description="path to optional evaluation data file to use for model eval", 29 | default=None, 30 | ), 31 | weights: Path = Input( 32 | description="location of weights that are going to be fine-tuned", default=None 33 | ), 34 | train_batch_size: int = Input(description="batch size per GPU", default=4, ge=1), 35 | gradient_accumulation_steps: int = Input( 36 | description="number of training steps to update gradient for before performing a backward pass", 37 | default=1, 38 | ), 39 | learning_rate: float = Input( 40 | description="learning rate, for learning!", default=2e-5, ge=0 41 | ), 42 | warmup_ratio: float = Input( 43 | description="pct of steps for a linear learning rate warmup", 44 | ge=0, 45 | le=0.5, 46 | default=0.03, 47 | ), 48 | num_train_epochs: int = Input( 49 | description="number of training epochs", ge=1, default=1 50 | ), 51 | max_steps: int = 
Input( 52 | description="number of steps to run training for, supersedes num_train_epochs", 53 | default=-1 54 | ), 55 | logging_steps: int = Input( 56 | description="number of steps between logging epoch & loss", default=2 57 | ), 58 | gradient_checkpointing: bool = Input( 59 | description="whether to use gradient checkpointing to save memory at the cost of speed", 60 | default=True 61 | ), 62 | ) -> TrainingOutput: 63 | input_model = weights if weights is not None else HUGGINGFACE_MODEL_NAME 64 | 65 | root_path = os.getcwd() 66 | deepspeed_config = os.path.join(root_path, "ds_config/ds_flan_t5_z3_config_bf16_no_offload.json") 67 | 68 | output_dir = DIST_OUT_DIR 69 | os.makedirs(output_dir, exist_ok=True) 70 | 71 | num_gpus = torch.cuda.device_count() 72 | num_gpus_flag = f"--num_gpus={num_gpus}" 73 | 74 | print(f"Local Output Dir: {output_dir}") 75 | print(f"Number of GPUs: {num_gpus}") 76 | 77 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 78 | os.environ["HF_DATASETS_CACHE"] = "/src/.hf-cache" 79 | 80 | # TODO: use deepspeed's python api instead of subprocessing 81 | def _arg_if_present(var, var_name): 82 | """Need to wrap any arguments whose default value in train() is `None`""" 83 | if var: 84 | return f"--{var_name} {var}" 85 | return " " 86 | 87 | call( 88 | "deepspeed " 89 | + num_gpus_flag 90 | + " --module training.trainer --deepspeed " 91 | + deepspeed_config 92 | + f" --train_data={str(train_data)}" 93 | + f" --weights={input_model}" 94 | + f" --num_train_epochs={num_train_epochs}" 95 | + f" --max_steps={max_steps}" 96 | + _arg_if_present(eval_data, "eval_data") 97 | + f" --learning_rate {learning_rate}" 98 | + f" --train_batch_size {train_batch_size}" 99 | + f" --gradient_accumulation_steps {gradient_accumulation_steps}" 100 | + f" --logging_steps {logging_steps}" 101 | + f" --warmup_ratio {warmup_ratio}" 102 | + f" --gradient_checkpointing {gradient_checkpointing}" 103 | + " --local_output_dir " 104 | + output_dir, 105 | shell=True, 106 | ) 107 | 108 | if os.path.exists(MODEL_OUT): 109 | os.remove(MODEL_OUT) 110 | 111 | model = T5ForConditionalGeneration.from_pretrained(output_dir, torch_dtype=torch.float16) 112 | serializer = TensorSerializer(MODEL_OUT) 113 | serializer.write_module(model) 114 | serializer.close() 115 | 116 | return TrainingOutput(weights=Path(MODEL_OUT)) 117 | 118 | 119 | if __name__ == "__main__": 120 | parser = argparse.ArgumentParser( 121 | description="Fine-tune a language model on a text dataset" 122 | ) 123 | parser.add_argument( 124 | "--train_data", type=Path, required=True, help="Path to the json dataset" 125 | ) 126 | parser.add_argument( 127 | "--eval_data", 128 | type=Path, 129 | required=False, 130 | help="Path to the json dataset", 131 | default=None, 132 | ) 133 | parser.add_argument( 134 | "--weights", 135 | type=str, 136 | default=None, 137 | help="The model class to fine-tune on HF or as a local path (e.g. 
'google/flan-t5-xxl')", 138 | ) 139 | parser.add_argument( 140 | "--num_train_epochs", type=int, required=True, help="Number of training epochs" 141 | ) 142 | parser.add_argument( 143 | "--learning_rate", 144 | type=float, 145 | default=2e-5, 146 | help="Learning rate for the optimizer", 147 | ) 148 | parser.add_argument( 149 | "--train_batch_size", type=int, default=4, help="Batch size for training" 150 | ) 151 | parser.add_argument( 152 | "--warmup_ratio", 153 | type=float, 154 | default=0.03, 155 | help="Fraction of total steps used for the linear learning rate warmup", 156 | ) 157 | parser.add_argument( 158 | "--max_steps", 159 | type=int, 160 | default=0, 161 | help="Number of training steps to run, overrides num_train_epochs, useful for testing", 162 | ) 163 | parser.add_argument( 164 | "--gradient_accumulation_steps", 165 | type=int, 166 | default=1, 167 | help="Number of steps to accumulate gradients over before updating model weights", 168 | ) 169 | parser.add_argument("--logging_steps", type=int, default=100) 170 | parser.add_argument( 171 | "--lr_scheduler_type", 172 | type=str, 173 | default="cosine", 174 | ) 175 | parser.add_argument( 176 | "--gradient_checkpointing", 177 | type=bool, 178 | default=True, 179 | help="Whether to use gradient checkpointing to save memory at the cost of speed." 180 | ) 181 | some_args = parser.parse_args() 182 | train(**vars(some_args)) 183 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-flan-models/f35f162f03b2559a755c5ffbe47a5f17009ab867/training/__init__.py -------------------------------------------------------------------------------- /training/trainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import torch 5 | from cog import Input, Path 6 | from peft import (LoraConfig, TaskType, get_peft_model, 7 | prepare_model_for_int8_training) 8 | from torch.utils.data import Dataset 9 | from transformers import T5ForConditionalGeneration, Trainer, TrainingArguments 10 | 11 | from config import HUGGINGFACE_MODEL_NAME, load_tokenizer 12 | 13 | MODEL_OUT = "/src/tuned_weights.tensors" 14 | CHECKPOINT_DIR = "checkpoints" 15 | SAVE_STRATEGY = "epoch" 16 | DIST_OUT_DIR = "tmp/model" 17 | 18 | 19 | class DatasetBuilder: 20 | """Dataset agnostic class to take in input_ids and labels and spit out tokens""" 21 | 22 | def __init__(self, tokenizer): 23 | self.tokenizer = tokenizer 24 | 25 | def batch_tokenize(self, texts): 26 | """Tokenizes text. Presently doesn't pad inputs, just returns input ids.""" 27 | tokenized = [ 28 | self.tokenizer( 29 | prompt, return_tensors="pt", padding="longest", truncation=True 30 | ).input_ids 31 | for prompt in texts 32 | ] 33 | return tokenized 34 | 35 | def construct_dataset(self, input_data): 36 | prompts = [val["prompt"] for val in input_data] 37 | tokenized_input_ids = self.batch_tokenize(prompts) 38 | labels = [val["completion"] for val in input_data] 39 | tokenized_labels = self.batch_tokenize(labels) 40 | return TuneDataset(tokenized_input_ids, tokenized_labels) 41 | 42 | 43 | class TuneDataset(Dataset): 44 | """Dead simple torch dataset wrapper.
Attention masks are created in collator""" 45 | 46 | def __init__(self, input_ids, labels): 47 | self.input_ids = input_ids 48 | self.labels = labels 49 | 50 | def __len__(self): 51 | return len(self.input_ids) 52 | 53 | def __getitem__(self, i): 54 | return dict(input_ids=self.input_ids[i], labels=self.labels[i]) 55 | 56 | 57 | class CustomDataCollatorSeq2Seq: 58 | """Collate examples for dynamic batch construction in supervised fine-tuning.""" 59 | 60 | def __init__(self, tokenizer, multiple_of=None): 61 | self.tokenizer = tokenizer 62 | self.multiple_of = multiple_of 63 | 64 | def pad_to_multiple(self, tensor, value): 65 | # taking advantage of tensor cores, perhaps 66 | multiple = self.multiple_of 67 | target_length = (tensor.size(0) + multiple - 1) // multiple * multiple 68 | return torch.nn.functional.pad( 69 | tensor, (0, target_length - tensor.size(0)), value=value 70 | ) 71 | 72 | def __call__(self, instances): 73 | input_ids, labels = tuple( 74 | [instance[key][0] for instance in instances] 75 | for key in ("input_ids", "labels") 76 | ) 77 | if self.multiple_of: 78 | input_ids = [ 79 | self.pad_to_multiple(val, self.tokenizer.pad_token_id) 80 | for val in input_ids 81 | ] 82 | labels = [self.pad_to_multiple(val, -100) for val in labels] 83 | 84 | input_ids = torch.nn.utils.rnn.pad_sequence( 85 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 86 | ) 87 | labels = torch.nn.utils.rnn.pad_sequence( 88 | labels, batch_first=True, padding_value=-100 89 | ) # -100 tells torch to ignore these tokens in loss computation. 90 | 91 | return dict( 92 | input_ids=input_ids, 93 | labels=labels, 94 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 95 | ) 96 | 97 | 98 | def load_data(path): 99 | if path.suffix == ".json": 100 | return load_json(path) 101 | elif path.suffix == ".jsonl": 102 | return load_jsonl(path) 103 | else: 104 | raise Exception( 105 | f"file type {path} not supported. 
Currently supported types are json, jsonl" 106 | ) 107 | 108 | 109 | def load_jsonl(path): 110 | data = [] 111 | with open(path, "r") as f: 112 | for line in f: 113 | json_object = json.loads(line) 114 | data.append(json_object) 115 | return data 116 | 117 | 118 | def load_json(path): 119 | """Loads a single json blob""" 120 | with open(path, "r") as f: 121 | data = json.load(f) 122 | return data 123 | 124 | 125 | def load_model(model_name_or_path, gradient_checkpointing): 126 | if model_name_or_path is None: 127 | model_name_or_path = HUGGINGFACE_MODEL_NAME 128 | model = T5ForConditionalGeneration.from_pretrained( 129 | model_name_or_path, 130 | cache_dir="pretrained_weights", 131 | use_cache=False if gradient_checkpointing else True 132 | ) 133 | 134 | return model 135 | 136 | 137 | def load_peft_model( 138 | model_name_or_path, lora_rank: int, lora_alpha: int, lora_dropout: float 139 | ): 140 | if model_name_or_path is None: 141 | model_name_or_path = HUGGINGFACE_MODEL_NAME 142 | model = T5ForConditionalGeneration.from_pretrained( 143 | model_name_or_path, 144 | cache_dir="pretrained_weights", 145 | torch_dtype=torch.float16, 146 | load_in_8bit=True, 147 | device_map="auto", 148 | ) 149 | model = prepare_model_for_int8_training(model) 150 | config = LoraConfig( 151 | r=lora_rank, 152 | lora_alpha=lora_alpha, 153 | lora_dropout=lora_dropout, 154 | bias="none", 155 | task_type=TaskType.SEQ_2_SEQ_LM, 156 | ) 157 | model = get_peft_model(model, config) 158 | return model 159 | 160 | 161 | def train( 162 | train_data: Path = Input( 163 | description="path to data file to use for fine-tuning your model" 164 | ), 165 | eval_data: Path = Input( 166 | description="path to optional evaluation data file to use for model eval", 167 | default=None, 168 | ), 169 | weights: Path = Input( 170 | description="location of weights that are going to be fine-tuned", default=None 171 | ), 172 | train_batch_size: int = Input(description="batch size per GPU", default=8, ge=1), 173 | gradient_accumulation_steps: int = Input( 174 | description="number of training steps to update gradient for before performing a backward pass", 175 | default=1, 176 | ), 177 | lr_scheduler_type: str = Input( 178 | description="learning rate scheduler", 179 | default="cosine", 180 | choices=[ 181 | "linear", 182 | "cosine", 183 | "cosine_with_restarts", 184 | "polynomial", 185 | "inverse_sqrt", 186 | "constant", 187 | "constant_with_warmup", 188 | ], 189 | ), 190 | learning_rate: float = Input( 191 | description="learning rate, for learning!", default=2e-4, ge=0 192 | ), 193 | warmup_ratio: float = Input( 194 | description="pct of steps for a linear learning rate warmup", 195 | ge=0, 196 | le=0.5, 197 | default=0.03, 198 | ), 199 | num_train_epochs: int = Input( 200 | description="number of training epochs", ge=1, default=1 201 | ), 202 | max_steps: int = Input( 203 | description="number of steps to run training for, supersedes num_train_epochs", 204 | default=-1, 205 | ge=0, 206 | ), 207 | logging_steps: int = Input( 208 | description="number of steps between logging epoch & loss", default=1 209 | ), 210 | local_output_dir: str = None, 211 | deepspeed: str = None, 212 | local_rank: int = -1, 213 | gradient_checkpointing: bool = True, 214 | ) -> None: 215 | print("Loading model...") 216 | 217 | # if peft: 218 | # print("training lora!") 219 | # model = load_peft_model(weights, lora_rank, lora_alpha, lora_dropout) 220 | model = load_model(weights, gradient_checkpointing) 221 | tokenizer = load_tokenizer() 222 | 223 | print(f"Loading 
dataset {train_data}...") 224 | print(train_data) 225 | train_data = load_data(train_data) 226 | p = DatasetBuilder(tokenizer) 227 | train_dataset = p.construct_dataset(train_data) 228 | eval_dataset = None 229 | if eval_data: 230 | eval_data = load_json(eval_data) 231 | eval_dataset = p.construct_dataset(eval_data) 232 | 233 | print("Training...") 234 | trainer = Trainer( 235 | model=model, 236 | train_dataset=train_dataset, 237 | eval_dataset=eval_dataset, 238 | args=TrainingArguments( 239 | output_dir=CHECKPOINT_DIR, 240 | per_device_train_batch_size=train_batch_size, 241 | gradient_accumulation_steps=gradient_accumulation_steps, 242 | save_strategy="no", 243 | logging_steps=logging_steps, 244 | lr_scheduler_type=lr_scheduler_type, 245 | warmup_ratio=warmup_ratio, 246 | num_train_epochs=num_train_epochs, 247 | learning_rate=learning_rate, 248 | deepspeed=deepspeed, 249 | max_steps=max_steps, 250 | fp16=False, 251 | bf16=True, 252 | half_precision_backend="cuda_amp", 253 | local_rank=local_rank, 254 | gradient_checkpointing=gradient_checkpointing, 255 | ), 256 | data_collator=CustomDataCollatorSeq2Seq(tokenizer, 8), # depends on bf16 value 257 | ) 258 | trainer.train() 259 | trainer.save_model(output_dir=local_output_dir) 260 | return 261 | 262 | 263 | if __name__ == "__main__": 264 | parser = argparse.ArgumentParser( 265 | description="Fine-tune a language model on a text dataset" 266 | ) 267 | parser.add_argument( 268 | "--train_data", type=Path, required=True, help="Path to the json dataset" 269 | ) 270 | parser.add_argument( 271 | "--eval_data", 272 | type=Path, 273 | required=False, 274 | help="Path to the json dataset", 275 | default=None, 276 | ) 277 | parser.add_argument( 278 | "--weights", 279 | type=str, 280 | default=None, 281 | help="The model class to fine-tune on HF or as a local path (e.g. 'google/flan-t5-xxl'", 282 | ) 283 | parser.add_argument( 284 | "--num_train_epochs", type=int, required=True, help="Number of training epochs" 285 | ) 286 | parser.add_argument( 287 | "--learning_rate", 288 | type=float, 289 | default=2e-5, 290 | help="Learning rate for the optimizer", 291 | ) 292 | parser.add_argument( 293 | "--train_batch_size", type=int, default=4, help="Batch size for training" 294 | ) 295 | parser.add_argument( 296 | "--warmup_ratio", 297 | type=float, 298 | default=0.03, 299 | help="Number of warmup steps for the learning rate scheduler", 300 | ) 301 | parser.add_argument( 302 | "--max_steps", 303 | type=int, 304 | default=0, 305 | help="Number of training steps to run, overrides num_train_epochs, useful for testing", 306 | ) 307 | parser.add_argument( 308 | "--gradient_accumulation_steps", 309 | type=int, 310 | default=8, 311 | help="Number of training steps to run, overrides num_train_epochs, useful for testing", 312 | ) 313 | parser.add_argument("--logging_steps", type=int, default=1) 314 | parser.add_argument( 315 | "--lr_scheduler_type", 316 | type=str, 317 | default="cosine", 318 | ) 319 | parser.add_argument( 320 | "--deepspeed", type=str, default=None, help="Path to deepspeed config file." 
321 | ) 322 | parser.add_argument( 323 | "--local_output_dir", 324 | type=str, 325 | help="Local directory to write the fine-tuned model to", 326 | required=True, 327 | ) 328 | parser.add_argument( 329 | "--local_rank", 330 | type=int, 331 | default=-1, 332 | help="Provided by deepspeed to identify which instance this process is when performing multi-GPU training.", 333 | ) 334 | parser.add_argument( 335 | "--gradient_checkpointing", 336 | type=bool, 337 | default=True, 338 | help="Enable gradient checkpointing to reduce activation memory usage." 339 | ) 340 | some_args = parser.parse_args() 341 | train(**vars(some_args)) 342 | 343 | # parser.add_argument( 344 | # "--local_rank", 345 | # type=int, 346 | # default=0 347 | # ) 348 | # parser.add_argument( 349 | # "--peft", 350 | # action="store_true" 351 | # ) 352 | # parser.add_argument( 353 | # "--lora_rank", 354 | # type=int, 355 | # default=16, 356 | # help="Rank of the LoRA update matrices", 357 | # ) 358 | # parser.add_argument( 359 | # "--lora_alpha", 360 | # type=int, 361 | # default=16, 362 | # help="LoRA scaling factor (alpha)", 363 | # ) 364 | # parser.add_argument( 365 | # "--lora_dropout", 366 | # type=float, 367 | # default=0.4, 368 | # help="Dropout probability applied to the LoRA layers", 369 | # ) 370 | --------------------------------------------------------------------------------
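A minimal usage sketch (not a file in this repository) showing how the pieces in training/trainer.py fit together: DatasetBuilder tokenizes prompt/completion records, TuneDataset wraps the resulting tensors, and CustomDataCollatorSeq2Seq pads input_ids with the tokenizer's pad token and labels with -100 (so the loss ignores them) while building the attention mask. The record shape matches what construct_dataset expects; the "google/flan-t5-small" checkpoint and the example records are assumptions made only to keep the sketch light, and the imports assume the repo's own environment (cog, peft, transformers installed).

# Hypothetical sketch, not part of the repo: exercises DatasetBuilder and
# CustomDataCollatorSeq2Seq outside of train().
from transformers import AutoTokenizer

from training.trainer import CustomDataCollatorSeq2Seq, DatasetBuilder

# A small checkpoint keeps the example cheap; the repo normally resolves its
# tokenizer via config.load_tokenizer().
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")

# construct_dataset expects a list of {"prompt": ..., "completion": ...} records.
records = [
    {"prompt": "Translate to German: Good morning.", "completion": "Guten Morgen."},
    {"prompt": "Summarize: The cat sat on the mat all day.", "completion": "A cat sat on a mat."},
]

builder = DatasetBuilder(tokenizer)
dataset = builder.construct_dataset(records)

# multiple_of=8 mirrors the value train() passes to the collator.
collator = CustomDataCollatorSeq2Seq(tokenizer, multiple_of=8)
batch = collator([dataset[i] for i in range(len(dataset))])

# input_ids are padded with pad_token_id, labels with -100 (ignored by the
# loss), and attention_mask marks the non-pad positions.
print(batch["input_ids"].shape, batch["labels"].shape, batch["attention_mask"].shape)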