├── LICENSE
├── README.md
├── instruction_finetuning
│   ├── dataset_generation
│   │   └── Ad_Copy_Dataset.ipynb
│   └── training
│       ├── __init__.py
│       ├── requirements.txt
│       ├── run_peft.sh
│       ├── train.py
│       └── utils.py
├── multimodal_instruction_finetuning
│   └── IDEFICS_Finetuning_demo.ipynb
└── personal_copilot
    ├── dataset_generation
    │   ├── README.md
    │   ├── clone_hf_repos.py
    │   ├── prepare_dataset.py
    │   └── requirements.txt
    └── training
        ├── fim.py
        ├── llama_flash_attn_monkey_patch.py
        ├── requirements.txt
        ├── run_peft.sh
        └── train.py
/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # peft-pytorch-conference 2 | Code for the examples presented in the talk "Training a Llama in your backyard: fine-tuning very large models on consumer hardware" given at PyTorch Conference 2023 3 | -------------------------------------------------------------------------------- /instruction_finetuning/dataset_generation/Ad_Copy_Dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 5, 6 | "id": "7a78203d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from transformers import AutoTokenizer\n", 11 | "\n", 12 | "tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Llama-2-7b-chat-hf\", use_auth_token=True)\n" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 6, 18 | "id": "84a5c544", 19 | "metadata": {}, 20 | "outputs": [ 21 | { 22 | "name": "stdout", 23 | "output_type": "stream", 24 | "text": [ 25 | "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\\n\\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don\\'t know the answer to a question, please don\\'t share false information.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}\n" 26 | ] 27 | } 28 | ], 29 | "source": [ 30 | "print(tokenizer.default_chat_template)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 19, 36 | "id": "f785dd97", 37 | "metadata": {}, 38 | "outputs": [ 39 | { 40 | "name": "stdout", 41 | "output_type": "stream", 42 | "text": [ 43 | "DatasetDict({\n", 44 | " train: Dataset({\n", 45 | " features: ['content'],\n", 46 | " num_rows: 1000\n", 47 | " })\n", 48 | " test: Dataset({\n", 49 | " features: ['content'],\n", 50 | " num_rows: 141\n", 51 | " })\n", 52 | "})\n", 53 | "{'content': '[INST] <<SYS>>\\nCreate a text ad given the following product and description.\\n<</SYS>>\\n\\nProduct: Fitness Magazine\\nDescription: Fitness magazine for staying active and achieving your fitness goals. [/INST] Ad: Stay active with a Fitness Magazine! 💪📖 Experience fitness tips and motivating stories. 
Perfect for fitness enthusiasts and reaching your health and wellness goals. Limited stock - achieve fitness with a touch of motivation! 🌟🌟🏋️\\u200d♀️ '}\n" 54 | ] 55 | } 56 | ], 57 | "source": [ 58 | "from datasets import load_dataset\n", 59 | "\n", 60 | "system_prompt = \"\"\"Create a text ad given the following product and description.\"\"\"\n", 61 | "\n", 62 | "def preprocess(samples):\n", 63 | " batch = []\n", 64 | " for product, desc, ad_copy in zip(samples[\"product\"],samples[\"description\"],samples[\"ad\"]):\n", 65 | " conversation = [\n", 66 | " {\"role\": \"system\", \"content\": system_prompt},\n", 67 | " {\"role\": \"user\", \"content\": f\"\"\"Product: {product}\\nDescription: {desc}\\n\"\"\"},\n", 68 | " {\"role\": \"assistant\", \"content\": f\"\"\"Ad: {ad_copy}\\n\"\"\"},\n", 69 | " ]\n", 70 | " batch.append(tokenizer.apply_chat_template(conversation, tokenize=False))\n", 71 | " return {\"content\": batch}\n", 72 | " \n", 73 | " \n", 74 | "\n", 75 | "\n", 76 | "dataset = load_dataset(\"jaykin01/advertisement-copy\")\n", 77 | "dataset\n", 78 | "dataset = dataset.map(\n", 79 | " preprocess,\n", 80 | " batched=True,\n", 81 | " remove_columns=dataset[\"train\"].column_names\n", 82 | ")\n", 83 | "\n", 84 | "dataset[\"train\"] = dataset[\"train\"].shuffle(100)\n", 85 | "dataset_subsets = dataset[\"train\"].train_test_split(141)\n", 86 | "dataset[\"train\"] = dataset_subsets[\"train\"]\n", 87 | "dataset[\"test\"] = dataset_subsets[\"test\"]\n", 88 | "print(dataset)\n", 89 | "print(dataset[\"train\"][0])" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 20, 95 | "id": "8cbcdfa1", 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "data": { 100 | "application/vnd.jupyter.widget-view+json": { 101 | "model_id": "a7dfe18f4a4d418da38237e6e98d1d91", 102 | "version_major": 2, 103 | "version_minor": 0 104 | }, 105 | "text/plain": [ 106 | "Pushing dataset shards to the dataset hub: 0%| | 0/1 [00:00= self.max_buffer_size: 85 | break 86 | try: 87 | buffer.append(next(iterator)[self.content_field]) 88 | buffer_len += len(buffer[-1]) 89 | except StopIteration: 90 | if self.infinite: 91 | iterator = iter(self.dataset) 92 | else: 93 | more_examples = False 94 | break 95 | tokenized_inputs = self.tokenizer(buffer, truncation=False, add_special_tokens=False)["input_ids"] 96 | all_token_ids = [] 97 | for tokenized_input in tokenized_inputs: 98 | if self.add_eos_token: 99 | tokenized_input = tokenized_input + [self.concat_token_id] 100 | all_token_ids.extend(tokenized_input) 101 | examples = [] 102 | for i in range(0, len(all_token_ids), self.seq_length): 103 | input_ids = all_token_ids[i : i + self.seq_length] 104 | if len(input_ids) == self.seq_length: 105 | examples.append(input_ids) 106 | if self.shuffle: 107 | random.shuffle(examples) 108 | for example in examples: 109 | self.current_size += 1 110 | yield { 111 | "input_ids": torch.LongTensor(example), 112 | "labels": torch.LongTensor(example), 113 | } 114 | 115 | 116 | def chars_token_ratio(dataset, tokenizer, data_column, nb_examples=400): 117 | """ 118 | Estimate the average number of characters per token in the dataset. 
119 | """ 120 | total_characters, total_tokens = 0, 0 121 | for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples): 122 | total_characters += len(example[data_column]) 123 | total_tokens += len(tokenizer(example[data_column]).tokens()) 124 | 125 | return total_characters / total_tokens 126 | 127 | 128 | def create_datasets(tokenizer, args): 129 | dataset = load_dataset(args.dataset_name, use_auth_token=True, num_proc=args.num_workers) 130 | train_data = dataset["train"] 131 | valid_data = dataset["test"] 132 | print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}") 133 | chars_per_token = chars_token_ratio(train_data, tokenizer, args.dataset_text_field) 134 | print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}") 135 | train_dataset = ConstantLengthDataset( 136 | tokenizer, 137 | train_data, 138 | infinite=True, 139 | seq_length=args.max_seq_length, 140 | chars_per_token=chars_per_token, 141 | content_field=args.dataset_text_field, 142 | shuffle=True, 143 | add_eos_token=False, 144 | ) 145 | valid_dataset = ConstantLengthDataset( 146 | tokenizer, 147 | valid_data, 148 | infinite=False, 149 | seq_length=args.max_seq_length, 150 | chars_per_token=chars_per_token, 151 | content_field=args.dataset_text_field, 152 | shuffle=False, 153 | add_eos_token=False, 154 | ) 155 | 156 | return train_dataset, valid_dataset 157 | 158 | 159 | def create_and_prepare_model(args): 160 | device_map = None 161 | bnb_config = None 162 | load_in_8bit = args.use_8bit_qunatization 163 | compute_dtype = None  # only set below when 4-bit quantization is enabled; avoids a NameError in from_pretrained otherwise 164 | if args.use_4bit_qunatization: 165 | compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype) 166 | 167 | bnb_config = BitsAndBytesConfig( 168 | load_in_4bit=args.use_4bit_qunatization, 169 | bnb_4bit_quant_type=args.bnb_4bit_quant_type, 170 | bnb_4bit_compute_dtype=compute_dtype, 171 | bnb_4bit_use_double_quant=args.use_nested_quant, 172 | ) 173 | 174 | if compute_dtype == torch.float16 and args.use_4bit_qunatization: 175 | major, _ = torch.cuda.get_device_capability() 176 | if major >= 8: 177 | print("=" * 80) 178 | print("Your GPU supports bfloat16, you can accelerate training with the argument --bf16") 179 | print("=" * 80) 180 | 181 | if args.use_4bit_qunatization or args.use_8bit_qunatization: 182 | device_map = "auto" # {"": 0} 183 | 184 | model = AutoModelForCausalLM.from_pretrained( 185 | args.model_name, 186 | torch_dtype=compute_dtype, 187 | load_in_8bit=load_in_8bit, 188 | quantization_config=bnb_config, 189 | device_map=device_map, 190 | use_cache=not args.use_gradient_checkpointing, 191 | trust_remote_code=True, 192 | use_flash_attention_2=args.use_flash_attn, 193 | ) 194 | 195 | peft_config = None 196 | if args.use_peft_lora: 197 | peft_config = LoraConfig( 198 | lora_alpha=args.lora_alpha, 199 | lora_dropout=args.lora_dropout, 200 | r=args.lora_r, 201 | bias="none", 202 | task_type="CAUSAL_LM", 203 | target_modules=args.lora_target_modules.split(","), 204 | ) 205 | 206 | if (args.use_4bit_qunatization or args.use_8bit_qunatization) and args.use_peft_lora: 207 | model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.use_gradient_checkpointing) 208 | 209 | if args.use_gradient_checkpointing: 210 | model.gradient_checkpointing_enable() 211 | 212 | model = get_peft_model(model, peft_config) 213 | model.print_trainable_parameters() 214 | 215 | tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True) 216 | tokenizer.pad_token = tokenizer.eos_token 217 | 218 | return model, 
peft_config, tokenizer 219 | 220 | 221 | def peft_module_casting_to_bf16(model, args): 222 | for name, module in model.named_modules(): 223 | if isinstance(module, LoraLayer): 224 | if args.bf16: 225 | module = module.to(torch.bfloat16) 226 | if "norm" in name: 227 | module = module.to(torch.float32) 228 | if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]): 229 | if hasattr(module, "weight"): 230 | if args.bf16 and module.weight.dtype == torch.float32: 231 | module = module.to(torch.bfloat16) 232 | -------------------------------------------------------------------------------- /multimodal_instruction_finetuning/IDEFICS_Finetuning_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "UNCNPVi8iAgw" 7 | }, 8 | "source": [ 9 | "# IDEFICS: A Flamingo-based model, trained at scale for the community\n", 10 | "# Finetuning Demo Notebook:\n", 11 | "\n", 12 | "
\n", 13 | "[Idefics image]\n", 16 | "
\n", 17 | "\n", 18 | "Credit: [Flamingo blog](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model)\n", 19 | "\n", 20 | "This Google Colab notebook shows how to run predictions with the 4-bit quantized 🤗 [Idefics-9B model](https://huggingface.co/HuggingFaceM4/idefics-9b) and finetune it on a specific dataset.\n", 21 | "\n", 22 | "[IDEFICS](https://huggingface.co/HuggingFaceM4/idefics-80b) is a multi-modal model based on the [Flamingo](https://arxiv.org/abs/2204.14198) architecture. It can take images and text as input and return text outputs, but it does not support image generation. \\\\\n", 23 | "IDEFICS is built on top of two unimodal open-access pre-trained models to connect the two modalities. Newly initialized parameters in the form of Transformer blocks bridge the gap between the vision encoder and the language model. The model is trained on a mixture of image/text pairs and unstructured multimodal web documents. \\\\\n", 24 | "The [finetuned versions](https://huggingface.co/HuggingFaceM4/idefics-80b-instruct) of IDEFICS behave like LLM chatbots while also understanding visual input. \\\\\n", 25 | "You can play with the [demo here](https://huggingface.co/spaces/HuggingFaceM4/idefics_playground).\n", 26 | "\n", 27 | "The code for this notebook was contributed by *Léo Tronchon, Younes Belkada, and Stas Bekman*; the IDEFICS model was contributed by *Lucile Saulnier, Léo Tronchon, Hugo Laurençon, Stas Bekman, Amanpreet Singh, Siddharth Karamcheti, and Victor Sanh*." 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "source": [ 33 | "# Install and import necessary libraries" 34 | ], 35 | "metadata": { 36 | "id": "7m9zw1wcCC8e" 37 | } 38 | }, 39 | { 40 | "cell_type": "code", 41 | "source": [ 42 | "!pip install -q datasets\n", 43 | "!pip install -q git+https://github.com/huggingface/transformers.git@add-model-idefics\n", 44 | "!pip install -q bitsandbytes sentencepiece accelerate loralib\n", 45 | "!pip install -q -U git+https://github.com/huggingface/peft.git" 46 | ], 47 | "metadata": { 48 | "colab": { 49 | "base_uri": "https://localhost:8080/" 50 | }, 51 | "id": "prXRsUiXCII9", 52 | "outputId": "3b9da6dd-365b-484d-9d37-a723eee947de" 53 | }, 54 | "execution_count": null, 55 | "outputs": [ 56 | { 57 | "output_type": "stream", 58 | "name": "stdout", 59 | "text": [ 60 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m519.3/519.3 kB\u001b[0m \u001b[31m6.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 61 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 62 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m9.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 63 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m8.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 64 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m268.8/268.8 kB\u001b[0m \u001b[31m10.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 65 | "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", 66 | " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", 67 | " Preparing metadata (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", 68 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m17.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 69 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m31.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 70 | "\u001b[?25h Building wheel for transformers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", 71 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m92.6/92.6 MB\u001b[0m \u001b[31m11.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 72 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m67.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 73 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m244.2/244.2 kB\u001b[0m \u001b[31m25.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 74 | "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", 75 | " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", 76 | " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", 77 | " Building wheel for peft (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n" 78 | ] 79 | } 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": { 86 | "id": "MxoHmx-HfAgf" 87 | }, 88 | "outputs": [], 89 | "source": [ 90 | "import torch\n", 91 | "from datasets import load_dataset\n", 92 | "from peft import LoraConfig, get_peft_model\n", 93 | "from PIL import Image\n", 94 | "from transformers import IdeficsForVisionText2Text, AutoProcessor, Trainer, TrainingArguments, BitsAndBytesConfig\n", 95 | "import torchvision.transforms as transforms" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": { 101 | "id": "DP_ilre6jI6l" 102 | }, 103 | "source": [ 104 | "# Load quantized model\n", 105 | "First get the quantized version of the model. 
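A minimal sketch of such a 4-bit load with `BitsAndBytesConfig` follows; the loading cell itself is not reproduced in full in this dump, so the exact options shown here (NF4, nested quantization, bfloat16 compute) are assumptions rather than the notebook's confirmed settings:

import torch
from transformers import IdeficsForVisionText2Text, AutoProcessor, BitsAndBytesConfig

device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = "HuggingFaceM4/idefics-9b"

# Assumed quantization settings: 4-bit NF4 with double quantization and bf16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

processor = AutoProcessor.from_pretrained(checkpoint)
model = IdeficsForVisionText2Text.from_pretrained(
    checkpoint,
    quantization_config=bnb_config,
    device_map="auto",  # let accelerate place the quantized weights on the GPU
)

At 4-bit precision the 9B weights occupy roughly 5-6GB of VRAM.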
This will allow us to use the 9B version of Idefics with a single 16GB gpu\n", 106 | "\n" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": { 113 | "colab": { 114 | "base_uri": "https://localhost:8080/", 115 | "height": 84, 116 | "referenced_widgets": [ 117 | "cf454254fbc74724a6909e60d82f86a3", 118 | "561b1b43dbc1484784ea2abed7278c08", 119 | "996e2ae7de594ccc968ce83382786365", 120 | "7e72c1fdf039470f8b14859034c7942f", 121 | "f34958207dca46fd9aa044912ec9fddb", 122 | "0fa55920c3a54b30aca74aa7247fe2ea", 123 | "119ec52a3ce54b0d9565a0d44e731850", 124 | "27e2b5c562174873bb966f1408727058", 125 | "008e6d4c958149819fd7e64e30f79e39", 126 | "9302d5fbae224b999a0c3fcb3f34beb3", 127 | "8c82d2f9f97047478d8399b2aee3389f" 128 | ] 129 | }, 130 | "id": "IRiT0q0Ck-3Y", 131 | "outputId": "52bc69ec-32ec-45d7-b1a2-1a7af0539506" 132 | }, 133 | "outputs": [ 134 | { 135 | "output_type": "stream", 136 | "name": "stderr", 137 | "text": [ 138 | "/usr/local/lib/python3.10/dist-packages/transformers/models/auto/processing_auto.py:203: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.\n", 139 | " warnings.warn(\n" 140 | ] 141 | }, 142 | { 143 | "output_type": "display_data", 144 | "data": { 145 | "text/plain": [ 146 | "Loading checkpoint shards: 0%| | 0/19 [00:00\", \"\"]\n", 329 | " if len(bad_words) > 0:\n", 330 | " bad_words_ids = tokenizer(bad_words, add_special_tokens=False).input_ids\n", 331 | "\n", 332 | " eos_token = \"\"\n", 333 | " eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)\n", 334 | "\n", 335 | " inputs = processor(prompts, return_tensors=\"pt\").to(device)\n", 336 | " generated_ids = model.generate(**inputs, eos_token_id=[eos_token_id], bad_words_ids=bad_words_ids, max_new_tokens=max_new_tokens, early_stopping=True)\n", 337 | " generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]\n", 338 | " print(generated_text)" 339 | ], 340 | "metadata": { 341 | "id": "J5MSZ3xdPF4f" 342 | }, 343 | "execution_count": null, 344 | "outputs": [] 345 | }, 346 | { 347 | "cell_type": "markdown", 348 | "metadata": { 349 | "id": "RYA2HKGC0n9d" 350 | }, 351 | "source": [ 352 | "\n", 353 | "Let's run prediction with the quantized model for the image below which pictures two kittens. \\\\\n", 354 | "" 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": null, 360 | "metadata": { 361 | "id": "6I_iDtQN03jE", 362 | "colab": { 363 | "base_uri": "https://localhost:8080/" 364 | }, 365 | "outputId": "a4a77c65-186a-45e0-f819-3ea3d9d319c0" 366 | }, 367 | "outputs": [ 368 | { 369 | "output_type": "stream", 370 | "name": "stdout", 371 | "text": [ 372 | "Question: What's on the picture? Answer: Two kittens.\n" 373 | ] 374 | } 375 | ], 376 | "source": [ 377 | "url = \"https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg\"\n", 378 | "prompts = [\n", 379 | " # \"Instruction: provide an answer to the question. Use the image to answer.\\n\",\n", 380 | " url,\n", 381 | " \"Question: What's on the picture? Answer:\",\n", 382 | "]\n", 383 | "check_inference(model, processor, prompts, max_new_tokens=5)\n" 384 | ] 385 | }, 386 | { 387 | "cell_type": "markdown", 388 | "source": [ 389 | "Now let's see how the model fares on pokemon knowledge before we try to finetune it further. 
\\\\\n", 390 | "\n" 391 | ], 392 | "metadata": { 393 | "id": "DLiwPnGBxiJf" 394 | } 395 | }, 396 | { 397 | "cell_type": "code", 398 | "source": [ 399 | "# check generation before finetuning\n", 400 | "\n", 401 | "url = \"https://images.pokemontcg.io/pop6/2_hires.png\"\n", 402 | "prompts = [\n", 403 | " url,\n", 404 | " \"Question: What's on the picture? Answer:\",\n", 405 | "]\n", 406 | "check_inference(model, processor, prompts, max_new_tokens=100)\n", 407 | "# It looks like the model is already aware of pokemon - but it could be more specific, and less repetitive" 408 | ], 409 | "metadata": { 410 | "colab": { 411 | "base_uri": "https://localhost:8080/" 412 | }, 413 | "id": "lDVDUE1ew7tZ", 414 | "outputId": "37ba5c61-c607-4282-e57b-25cada593391" 415 | }, 416 | "execution_count": null, 417 | "outputs": [ 418 | { 419 | "output_type": "stream", 420 | "name": "stdout", 421 | "text": [ 422 | "Question: What's on the picture? Answer: Lucario\n", 423 | "\n", 424 | "Lucario is a Pokémon that is a combination of a bear and a lion. It is a Pokémon that is a combination of a bear and a lion. It is a Pokémon that is a combination of a bear and a lion. It is a Pokémon that is a combination of a bear and a lion. It is a Pokémon that is a combination of a bear and a lion. It is a Pok\n" 425 | ] 426 | } 427 | ] 428 | }, 429 | { 430 | "cell_type": "markdown", 431 | "source": [ 432 | "# Finetuning dataset\n", 433 | "Prepare the dataset that will be used for finetuning\n" 434 | ], 435 | "metadata": { 436 | "id": "ydBhQT6SQiWy" 437 | } 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": null, 442 | "metadata": { 443 | "colab": { 444 | "base_uri": "https://localhost:8080/", 445 | "height": 177, 446 | "referenced_widgets": [ 447 | "eac0761e22a84275aaee5d7ec7929da6", 448 | "ba24eb82f1194ecab3514466eca8a2b8", 449 | "52997c23e16a4f8aa220909e99b5452e", 450 | "6b7767dc6c5b45a89f7becfe5fcf81d7", 451 | "050b365a82b0412b83918f9f9603bf2f", 452 | "39c0d7023e574db9a55eb7e82913d4ed", 453 | "9bdbd4871dcd49a5bbfaa86b813e9a36", 454 | "7bfee1d1c4134316af5b82cd354457ba", 455 | "594fd06a2b07443a9ce27200468d5fe3", 456 | "55de5af50af247cd93da17057661fd6c", 457 | "450f2b15f9df4f72b23c4f916bc18f3b", 458 | "22df8e4fce3b470b94fdce6e7b77a9ea", 459 | "cd89e195d2bb4537889ec8cc9e7a815e", 460 | "196951cc2fdd43d4a153de2666067cd0", 461 | "aaf9e7678c174fe8820c5c0bebb6bb1e", 462 | "699f568cedd846f590efa2500dd8b3a9", 463 | "a0f3836eb674483295fbb147065b74fc", 464 | "da922717666a496da59a4cf8840e6554", 465 | "cc9ddc6c56324dd59cfb8bc9649fea28", 466 | "830ec2345d9a4be88b486ad24bfc3b10", 467 | "05b63fe3c99c417fb6bcdb450081bff8", 468 | "e2872821e4e84271b32b8c8c8c093bfc", 469 | "782d656769144ef9b48a3a37de81abb5", 470 | "eb2f4bcb78534f4d9f9e2ccb52e738b7", 471 | "a2bcf8164d904dcbada2196189b332be", 472 | "99b5b2cd3f104c72b5ee880fe1d0e9b9", 473 | "3197f87aadd5422cbb9804b0843ffc48", 474 | "5cdf7a7b08cc46f5a4b2da143ba39bb6", 475 | "8f335a7d85574c11b183fb700aeac5c3", 476 | "6b96186a1ccb4e24b491b5849ac90c50", 477 | "3a845c0efe954da1a47e77740f8623ff", 478 | "4c8f47c325a54f52abab545362f36c43", 479 | "7c1dc629e6dc4048b1b88a224c9a352d", 480 | "da84172eaff34e61ac902681dbd364ca", 481 | "2796bada5f6748b6af59f6b14b0957af", 482 | "400b852ef365473cad76663421954c86", 483 | "fd58bb90108a4486967a217eb3bc4389", 484 | "b96a2d9afc324a4eb52f7a04caab630a", 485 | "7c20b8d8e3b14504bba903e68d043e79", 486 | "c8bc395e18e14492ae40ec6ff21a18d1", 487 | "6c85b036e1be434faa2d515bed62e228", 488 | "da15ec7761a847678dc696b214c67ada", 489 | 
"03d2d213eb2a4c819bbcf8457e11904b", 490 | "f651fffdc274473a85ed701097afaa1f", 491 | "3fbc282a30cc49b99f335216df028cd6", 492 | "651249802d0249479eb1700e600f9a5a", 493 | "31e2d7d5057a4dfa96a65888697e9923", 494 | "cbaf9ba59da24341a933c3c7473a3b7d", 495 | "ee26c8314e6742a88cd59429f3d5b745", 496 | "dd9e81eb4e3d45cca5c6e2b1e6cf335d", 497 | "a7c9efe8c49a43d0ba6929bada9f78c2", 498 | "15d3af1073fe4447847d0e6f3543f953", 499 | "e4daf9a3e9e14e93ab55b91da59ecc9b", 500 | "3557bb8fc4064fdf99ca2a1ec5469cff", 501 | "b2ccec96efa1415fa4623ec8fa0f2c21" 502 | ] 503 | }, 504 | "id": "5iZAz655m8Q9", 505 | "outputId": "6524cedf-f0f1-43fa-d5dc-2b4f2d8f6eb1" 506 | }, 507 | "outputs": [ 508 | { 509 | "output_type": "display_data", 510 | "data": { 511 | "text/plain": [ 512 | "Downloading readme: 0%| | 0.00/2.77k [00:00\",\n", 612 | " ],\n", 613 | " )\n", 614 | "\n", 615 | " inputs = processor(prompts, transform=image_transform, return_tensors=\"pt\").to(device)\n", 616 | "\n", 617 | " inputs[\"labels\"] = inputs[\"input_ids\"]\n", 618 | "\n", 619 | " return inputs\n", 620 | "\n", 621 | "\n", 622 | "# load and prepare dataset\n", 623 | "ds = load_dataset(\"TheFusion21/PokemonCards\")\n", 624 | "ds = ds[\"train\"].train_test_split(test_size=0.002)\n", 625 | "train_ds = ds[\"train\"]\n", 626 | "eval_ds = ds[\"test\"]\n", 627 | "train_ds.set_transform(ds_transforms)\n", 628 | "eval_ds.set_transform(ds_transforms)" 629 | ] 630 | }, 631 | { 632 | "cell_type": "markdown", 633 | "source": [ 634 | "# LoRA\n", 635 | "After specifying the low-rank adapters (LoRA) config, we load the PeftModel using the get_peft_model utility function" 636 | ], 637 | "metadata": { 638 | "id": "Kui4EkCmOQzd" 639 | } 640 | }, 641 | { 642 | "cell_type": "code", 643 | "execution_count": null, 644 | "metadata": { 645 | "id": "jKa5oTorp_A-" 646 | }, 647 | "outputs": [], 648 | "source": [ 649 | "model_name = checkpoint.split(\"/\")[1]\n", 650 | "config = LoraConfig(\n", 651 | " r=16,\n", 652 | " lora_alpha=32,\n", 653 | " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\"],\n", 654 | " lora_dropout=0.05,\n", 655 | " bias=\"none\",\n", 656 | ")\n", 657 | "model = get_peft_model(model, config)" 658 | ] 659 | }, 660 | { 661 | "cell_type": "code", 662 | "execution_count": null, 663 | "metadata": { 664 | "colab": { 665 | "base_uri": "https://localhost:8080/" 666 | }, 667 | "id": "ShuZJ5K2pYoL", 668 | "outputId": "6c22299b-5584-4994-c906-e9d031b40ad1" 669 | }, 670 | "outputs": [ 671 | { 672 | "output_type": "stream", 673 | "name": "stdout", 674 | "text": [ 675 | "trainable params: 19,750,912 || all params: 8,949,430,544 || trainable%: 0.2206946230030432\n" 676 | ] 677 | } 678 | ], 679 | "source": [ 680 | "model.print_trainable_parameters()" 681 | ] 682 | }, 683 | { 684 | "cell_type": "markdown", 685 | "source": [ 686 | "# Training\n", 687 | "Finally, using the Hugging Face Trainer, we can finetune the model! \\\\\n", 688 | "For the sake of the demo, we have set the max_steps at 40. That's about 0.05 epoch on this dataset, so feel free to tune further!" 689 | ], 690 | "metadata": { 691 | "id": "0Ok1sOZKQ29s" 692 | } 693 | }, 694 | { 695 | "cell_type": "code", 696 | "execution_count": null, 697 | "metadata": { 698 | "colab": { 699 | "base_uri": "https://localhost:8080/", 700 | "height": 155 701 | }, 702 | "id": "9cD3OuygpR5l", 703 | "outputId": "a8238139-59c3-49cb-c654-4aacb010dd7a" 704 | }, 705 | "outputs": [ 706 | { 707 | "output_type": "display_data", 708 | "data": { 709 | "text/plain": [ 710 | "" 711 | ], 712 | "text/html": [ 713 | "\n", 714 | "
\n", 715 | "[40/40 06:32, Epoch 0/1]\n", 716 | "\n", 717 | "Step | Training Loss | Validation Loss\n", 718 | "20 | 1.450000 | 0.880157\n", 719 | "40 | 0.702000 | 0.675355\n
" 740 | ] 741 | }, 742 | "metadata": {} 743 | }, 744 | { 745 | "output_type": "execute_result", 746 | "data": { 747 | "text/plain": [ 748 | "TrainOutput(global_step=40, training_loss=1.0759869813919067, metrics={'train_runtime': 403.1999, 'train_samples_per_second': 1.587, 'train_steps_per_second': 0.099, 'total_flos': 1445219210656320.0, 'train_loss': 1.0759869813919067, 'epoch': 0.05})" 749 | ] 750 | }, 751 | "metadata": {}, 752 | "execution_count": 23 753 | } 754 | ], 755 | "source": [ 756 | "training_args = TrainingArguments(\n", 757 | " output_dir=f\"{model_name}-pokemon\",\n", 758 | " learning_rate=2e-4,\n", 759 | " fp16=True,\n", 760 | " per_device_train_batch_size=2,\n", 761 | " per_device_eval_batch_size=2,\n", 762 | " gradient_accumulation_steps=8,\n", 763 | " dataloader_pin_memory=False,\n", 764 | " save_total_limit=3,\n", 765 | " evaluation_strategy=\"steps\",\n", 766 | " save_strategy=\"steps\",\n", 767 | " save_steps=40,\n", 768 | " eval_steps=20,\n", 769 | " logging_steps=20,\n", 770 | " max_steps=40,\n", 771 | " remove_unused_columns=False,\n", 772 | " push_to_hub=False,\n", 773 | " label_names=[\"labels\"],\n", 774 | " load_best_model_at_end=True,\n", 775 | " report_to=None,\n", 776 | " optim=\"paged_adamw_8bit\",\n", 777 | ")\n", 778 | "\n", 779 | "trainer = Trainer(\n", 780 | " model=model,\n", 781 | " args=training_args,\n", 782 | " train_dataset=train_ds,\n", 783 | " eval_dataset=eval_ds,\n", 784 | ")\n", 785 | "\n", 786 | "trainer.train()" 787 | ] 788 | }, 789 | { 790 | "cell_type": "code", 791 | "source": [ 792 | "# check generation again after finetuning\n", 793 | "check_inference(model, processor, prompts, max_new_tokens=100)" 794 | ], 795 | "metadata": { 796 | "colab": { 797 | "base_uri": "https://localhost:8080/" 798 | }, 799 | "id": "v6NZ47vYTr-z", 800 | "outputId": "8807a1dc-e37e-4c36-da02-507029a546ab" 801 | }, 802 | "execution_count": null, 803 | "outputs": [ 804 | { 805 | "output_type": "stream", 806 | "name": "stdout", 807 | "text": [ 808 | "Question: What's on the picture? Answer: This is Lucario. A Stage 2 Pokemon Card of type Fighting with the title Lucario and 90 HP of rarity Rare evolved from Pikachu from the set Neo Destiny and the flavor text: It can use its tail as a whip\n" 809 | ] 810 | } 811 | ] 812 | }, 813 | { 814 | "cell_type": "markdown", 815 | "source": [ 816 | "# Push your new model to the hub!\n" 817 | ], 818 | "metadata": { 819 | "id": "zgqonle8AdPs" 820 | } 821 | }, 822 | { 823 | "cell_type": "code", 824 | "source": [ 825 | "# Insert your \"write\" token. You should find it in the settings of your HF profile\n", 826 | "!huggingface-cli login" 827 | ], 828 | "metadata": { 829 | "colab": { 830 | "base_uri": "https://localhost:8080/" 831 | }, 832 | "id": "KrnB4kFxAjIA", 833 | "outputId": "8370ee48-9b3d-446b-b69a-c3cec93f61fd" 834 | }, 835 | "execution_count": null, 836 | "outputs": [ 837 | { 838 | "output_type": "stream", 839 | "name": "stdout", 840 | "text": [ 841 | "\n", 842 | " _| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|\n", 843 | " _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|\n", 844 | " _|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|\n", 845 | " _| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|\n", 846 | " _| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|\n", 847 | " \n", 848 | " A token is already saved on your machine. 
Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to log out.\n", 849 | " Setting a new token will erase the existing one.\n", 850 | " To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .\n", 851 | "Token: \n", 852 | "Add token as git credential? (Y/n) Y\n", 853 | "Token is valid (permission: write).\n", 854 | "\u001b[1m\u001b[31mCannot authenticate through git-credential as no helper is defined on your machine.\n", 855 | "You might have to re-authenticate when pushing to the Hugging Face Hub.\n", 856 | "Run the following command in your terminal in case you want to set the 'store' credential helper as default.\n", 857 | "\n", 858 | "git config --global credential.helper store\n", 859 | "\n", 860 | "Read https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more details.\u001b[0m\n", 861 | "Token has not been saved to git credential helper.\n", 862 | "Your token has been saved to /root/.cache/huggingface/token\n", 863 | "Login successful\n" 864 | ] 865 | } 866 | ] 867 | }, 868 | { 869 | "cell_type": "code", 870 | "source": [ 871 | "model.push_to_hub(f\"{model_name}-pokemon\", private=False)" 872 | ], 873 | "metadata": { 874 | "colab": { 875 | "base_uri": "https://localhost:8080/", 876 | "height": 66, 877 | "referenced_widgets": [ 878 | "73bdfdf8980d4c358c90d574eb91bef5", 879 | "3f49f9009fa14fe3b87bb123491a4b0f", 880 | "548fc33764964fe9a0498194df85b768", 881 | "e60ce018bf3a4b15941062300143e2a3", 882 | "d0a78497d9694dc6b7e903392daf6a26", 883 | "5b585e82891a40b0826679a79583ee7c", 884 | "db9f5a1c1a0a49b3b58a30f0a74c3329", 885 | "641cf05e799e4ae89ec84fdf8c225b93", 886 | "cf937013fade482f90bd599eced8bfb4", 887 | "ee49f8d2b11f43e2bb30d27407744ed3", 888 | "ef2f2655d7b9432f983ae508f6dd4e0b" 889 | ] 890 | }, 891 | "id": "_jFKg3iP172d", 892 | "outputId": "2b58ecb2-fe97-4a6c-bd2c-7fdaaf03e99a" 893 | }, 894 | "execution_count": null, 895 | "outputs": [ 896 | { 897 | "output_type": "display_data", 898 | "data": { 899 | "text/plain": [ 900 | "adapter_model.bin: 0%| | 0.00/79.2M [00:00 bool: 71 | """Filters a code cell w.r.t shell commands, etc.""" 72 | only_shell = cell["source"].startswith("!") 73 | only_magic = "%%capture" in cell["source"] 74 | if only_shell or only_magic: 75 | return False 76 | else: 77 | return True 78 | 79 | 80 | def process_file(directory_name: str, file_path: str) -> Dict[str, str]: 81 | """Processes a single file.""" 82 | try: 83 | with open(file_path, "r", encoding="utf-8") as file: 84 | content = file.read() 85 | if file_path.endswith("ipynb"): 86 | # Code courtesy: Chansung Park and Sayak Paul. 
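# For .ipynb files, the branch below parses the raw notebook JSON with
# nbformat's reads(content, NO_CONVERT), keeps only the code cells that pass
# filter_code_cell above (dropping shell-only and %%capture cells), and
# concatenates their sources so a notebook contributes plain Python text.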
87 | code_cell_str = "" 88 | notebook = reads(content, NO_CONVERT) 89 | 90 | code_cells = [c for c in notebook["cells"] if c["cell_type"] == "code" if filter_code_cell(c)] 91 | 92 | for cell in code_cells: 93 | code_cell_str += cell["source"] 94 | content = code_cell_str 95 | except Exception: 96 | content = "" 97 | 98 | return { 99 | "repo_id": directory_name, 100 | "file_path": file_path, 101 | "content": content, 102 | } 103 | 104 | 105 | def read_repository_files(directory) -> pd.DataFrame: 106 | """Reads the files from the locally cloned repositories.""" 107 | file_paths = [] 108 | df = pd.DataFrame(columns=["repo_id", "file_path", "content"]) 109 | chunk_flag = 0 110 | 111 | # Recursively find all files within the directory 112 | for root, _, files in os.walk(directory): 113 | for file in files: 114 | file_path = os.path.join(root, file) 115 | if not file_path.endswith(ANTI_FOMATS) and all( 116 | k not in file_path for k in [".git", "__pycache__", "xcodeproj"] 117 | ): 118 | file_paths.append((os.path.dirname(root), file_path)) 119 | 120 | # Process files sequentially. 121 | print(f"Total file paths: {len(file_paths)}.") 122 | print("Reading file contents...") 123 | 124 | for i, (directory_name, file_path) in enumerate(tqdm(file_paths)): 125 | file_content = process_file(directory_name, file_path) 126 | 127 | if file_content["content"] != "": 128 | temp_df = pd.DataFrame.from_dict([file_content]) 129 | df = pd.concat([df, temp_df]) 130 | 131 | if SERIALIZE_IN_CHUNKS and len(df) != 0 and (len(df) % SERIALIZE_IN_CHUNKS == 0): 132 | df_path = f"df_chunk_{chunk_flag}_{len(df)}.{FEATHER_FORMAT}" 133 | print(f"Serializing dataframe to {df_path}...") 134 | df.reset_index().to_parquet(df_path) 135 | del df 136 | df = pd.DataFrame(columns=["repo_id", "file_path", "content"]) 137 | chunk_flag += 1 138 | 139 | return df 140 | 141 | 142 | if __name__ == "__main__": 143 | df = read_repository_files(MIRROR_DIRECTORY) 144 | print("DataFrame created, creating dataset...") 145 | upload_to_hub(file_format=PARQUET_FORMAT, repo_id=DATASET_ID) 146 | print(f"{FEATHER_FORMAT} files uploaded to the Hub.") 147 | if not SERIALIZE_IN_CHUNKS: 148 | dataset = Dataset.from_pandas(df) 149 | dataset.push_to_hub(DATASET_ID, private=True) 150 | -------------------------------------------------------------------------------- /personal_copilot/dataset_generation/requirements.txt: -------------------------------------------------------------------------------- 1 | PyGithub 2 | datasets 3 | nbformat 4 | pandas -------------------------------------------------------------------------------- /personal_copilot/training/fim.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import numpy as np 4 | 5 | 6 | # this is expensive so we cache it 7 | @functools.lru_cache(maxsize=None) 8 | def get_fim_token_ids(tokenizer): 9 | return tokenizer.bos_token_id, tokenizer.suffix_id, tokenizer.prefix_id, tokenizer.middle_id, 0 10 | 11 | 12 | ## Adapted from https://github.com/bigcode-project/Megatron-LM/blob/6c4bf908df8fd86b4977f54bf5b8bd4b521003d1/megatron/data/gpt_dataset.py 13 | def permute( 14 | sample, 15 | np_rng, 16 | bos_token_id, 17 | suffix_tok_id, 18 | prefix_tok_id, 19 | middle_tok_id, 20 | pad_tok_id, 21 | fim_rate=0.5, 22 | fim_spm_rate=0.5, 23 | truncate_or_pad=False, 24 | ): 25 | """ 26 | Take in a sample (list of tokens) and perform a FIM transformation on it with a probability of fim_rate, using two FIM modes: 27 | PSM and SPM (with a probability of 
fim_spm_rate). 28 | """ 29 | 30 | if np_rng.binomial(1, fim_rate): 31 | boundaries = list(np_rng.randint(low=0, high=len(sample) + 1, size=2)) 32 | boundaries.sort() 33 | 34 | prefix = np.array(sample[: boundaries[0]], dtype=np.int64) 35 | middle = np.array(sample[boundaries[0] : boundaries[1]], dtype=np.int64) 36 | suffix = np.array(sample[boundaries[1] :], dtype=np.int64) 37 | 38 | if truncate_or_pad: 39 | new_length = suffix.shape[0] + prefix.shape[0] + middle.shape[0] + 3 40 | diff = new_length - len(sample) 41 | if diff > 0: 42 | if suffix.shape[0] <= diff: 43 | return sample, np_rng 44 | suffix = suffix[: suffix.shape[0] - diff] 45 | elif diff < 0: 46 | suffix = np.concatenate([suffix, np.full((-1 * diff), pad_tok_id)]) 47 | 48 | if np_rng.binomial(1, fim_spm_rate): 49 | # SPM (variant 2 from FIM paper) 50 | new_sample = np.concatenate( 51 | [ 52 | [bos_token_id, prefix_tok_id, suffix_tok_id], 53 | suffix, 54 | [middle_tok_id], 55 | prefix, 56 | middle, 57 | ] 58 | ) 59 | else: 60 | # PSM 61 | new_sample = np.concatenate( 62 | [ 63 | [bos_token_id, prefix_tok_id], 64 | prefix, 65 | [suffix_tok_id], 66 | suffix, 67 | [middle_tok_id], 68 | middle, 69 | ] 70 | ) 71 | else: 72 | # don't do FIM preproc 73 | new_sample = sample 74 | 75 | return list(new_sample), np_rng 76 | -------------------------------------------------------------------------------- /personal_copilot/training/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py 2 | 3 | from typing import List, Optional, Tuple, Union 4 | import logging 5 | 6 | import torch 7 | from torch import nn 8 | 9 | import transformers 10 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 11 | 12 | from einops import rearrange 13 | from flash_attn import flash_attn_func 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.LongTensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | **dummy_kwargs, 25 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 26 | bsz, q_len, _ = hidden_states.size() 27 | 28 | if self.config.pretraining_tp > 1: 29 | raise ValueError("pretraining_tp > 1 is not supported for flash attention") 30 | else: 31 | query_states = self.q_proj(hidden_states) 32 | key_states = self.k_proj(hidden_states) 33 | value_states = self.v_proj(hidden_states) 34 | 35 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 36 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 37 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 38 | 39 | kv_seq_len = key_states.shape[-2] 40 | 41 | if past_key_value is not None: 42 | kv_seq_len += past_key_value[0].shape[-2] 43 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 44 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 45 | 46 | if past_key_value is not None: 47 | # reuse k, v, self_attention 48 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 49 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 50 | 51 | past_key_value = (key_states, value_states) 
if use_cache else None 52 | 53 | query_states, key_states, value_states = [ 54 | rearrange(x, "b h s d -> b s h d") for x in [query_states, key_states, value_states] 55 | ] 56 | 57 | query_states, key_states, value_states = [x.to(torch.bfloat16) for x in [query_states, key_states, value_states]] 58 | # print(f"{query.shape=} {key.shape=} {value.shape=}") 59 | # below output will have shape (batch_size, seqlen, nheads, headdim) 60 | attn_output = flash_attn_func(query_states, key_states, value_states, causal=True) 61 | 62 | if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): 63 | raise ValueError( 64 | f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" 65 | f" {attn_output.size()}" 66 | ) 67 | 68 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 69 | attn_output = self.o_proj(attn_output) 70 | if output_attentions: 71 | raise NotImplementedError("`output_attentions` is not supported when `use_flash_attn` is True") 72 | attn_weights = None 73 | 74 | return attn_output, attn_weights, past_key_value 75 | 76 | 77 | # def forward( 78 | # self, 79 | # hidden_states: torch.Tensor, 80 | # attention_mask: Optional[torch.Tensor] = None, 81 | # position_ids: Optional[torch.Tensor] = None, 82 | # past_key_value: Optional[Tuple[torch.Tensor]] = None, 83 | # output_attentions: bool = False, 84 | # use_cache: bool = False, 85 | # ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 86 | # """Input shape: Batch x Time x Channel 87 | 88 | # attention_mask: [bsz, q_len] 89 | # """ 90 | # bsz, q_len, _ = hidden_states.size() 91 | 92 | # query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 93 | # key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 94 | # value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 95 | # # [bsz, q_len, nh, hd] 96 | # # [bsz, nh, q_len, hd] 97 | 98 | # kv_seq_len = key_states.shape[-2] 99 | # assert past_key_value is None, "past_key_value is not supported" 100 | 101 | # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 102 | # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 103 | # # [bsz, nh, t, hd] 104 | # assert not output_attentions, "output_attentions is not supported" 105 | # assert not use_cache, "use_cache is not supported" 106 | 107 | # # Flash attention codes from 108 | # # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 109 | 110 | # # transform the data into the format required by flash attention 111 | # qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd] 112 | # qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 113 | # # We have disabled _prepare_decoder_attention_mask in LlamaModel 114 | # # the attention_mask should be the same as the key_padding_mask 115 | # key_padding_mask = attention_mask 116 | 117 | # if key_padding_mask is None: 118 | # qkv = rearrange(qkv, "b s ... -> (b s) ...") 119 | # max_s = q_len 120 | # cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device) 121 | # output = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True) 122 | # output = rearrange(output, "(b s) ... 
-> b s ...", b=bsz) 123 | # else: 124 | # nheads = qkv.shape[-2] 125 | # x = rearrange(qkv, "b s three h d -> b s (three h d)") 126 | # x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 127 | # x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads) 128 | # output_unpad = flash_attn_unpadded_qkvpacked_func( 129 | # x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 130 | # ) 131 | # output = rearrange( 132 | # pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len), 133 | # "b s (h d) -> b s h d", 134 | # h=nheads, 135 | # ) 136 | # return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None 137 | 138 | 139 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 140 | # requires the attention mask to be the same as the key_padding_mask 141 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 142 | # [bsz, seq_len] 143 | return attention_mask 144 | 145 | 146 | def replace_llama_attn_with_flash_attn(): 147 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 148 | if cuda_major < 8: 149 | logging.warning( 150 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 151 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 152 | ) 153 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 154 | _prepare_decoder_attention_mask 155 | ) 156 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 157 | -------------------------------------------------------------------------------- /personal_copilot/training/requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/huggingface/transformers 2 | git+https://github.com/huggingface/accelerate 3 | git+https://github.com/huggingface/peft 4 | trl 5 | huggingface-hub 6 | bitsandbytes 7 | evaluate 8 | datasets 9 | einops 10 | wandb 11 | tiktoken 12 | deepspeed 13 | tqdm 14 | safetensors -------------------------------------------------------------------------------- /personal_copilot/training/run_peft.sh: -------------------------------------------------------------------------------- 1 | python train.py \ 2 | --model_path "codellama/CodeLlama-13b-Instruct-hf" \ 3 | --dataset_name "smangrul/hf-stack-v3" \ 4 | --subset "data" \ 5 | --data_column "content" \ 6 | --split "train" \ 7 | --seq_length 2048 \ 8 | --max_steps 2000 \ 9 | --batch_size 8 \ 10 | --gradient_accumulation_steps 2 \ 11 | --learning_rate 3e-4 \ 12 | --lr_scheduler_type "cosine" \ 13 | --weight_decay 0.01 \ 14 | --num_warmup_steps 30 \ 15 | --eval_freq 100 \ 16 | --save_freq 100 \ 17 | --log_freq 5 \ 18 | --push_to_hub \ 19 | --num_workers 4 \ 20 | --bf16 \ 21 | --no_fp16 \ 22 | --output_dir "codellama-13b-personal-copilot" \ 23 | --fim_rate 0.5 \ 24 | --fim_spm_rate 0.0 \ 25 | --use_peft_lora \ 26 | --lora_r 8 \ 27 | --lora_alpha 32 \ 28 | --lora_dropout 0.1 \ 29 | --lora_target_modules "q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj" \ 30 | --use_flash_attn \ 31 | --use_4bit_qunatization \ 32 | --use_nested_quant \ 33 | --bnb_4bit_compute_dtype "bfloat16" \ 34 | --seed 24 -------------------------------------------------------------------------------- /personal_copilot/training/train.py: -------------------------------------------------------------------------------- 1 | """ 2 
Fine-tune causal language models (e.g., StarCoder or CodeLlama) on a code/text dataset.
"""

import argparse
import os
import random
import warnings

import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import IterableDataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
    logging,
    set_seed,
)

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from peft.tuners.lora import LoraLayer

import fim


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="codellama/CodeLlama-13b-Instruct-hf")
    parser.add_argument("--dataset_name", type=str, default="smangrul/hf-stack-v2")
    parser.add_argument("--subset", type=str, default="data")
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--size_valid_set", type=int, default=4000)
    parser.add_argument("--test_size", type=float, default=0.005)
    parser.add_argument("--streaming", action="store_true")
    parser.add_argument("--shuffle_buffer", type=int, default=5000)
    parser.add_argument("--data_column", type=str, default="content")

    parser.add_argument("--seq_length", type=int, default=8192)
    parser.add_argument("--max_steps", type=int, default=10000)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
    parser.add_argument("--eos_token_id", type=int, default=49152)

    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    parser.add_argument("--num_warmup_steps", type=int, default=100)
    parser.add_argument("--weight_decay", type=float, default=0.05)

    parser.add_argument("--local_rank", type=int, default=0)
    # store_false flags: these default to True, so passing --no_fp16 disables fp16
    # and passing --no_gradient_checkpointing disables gradient checkpointing.
    parser.add_argument("--no_fp16", action="store_false")
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--no_gradient_checkpointing", action="store_false")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num_workers", type=int, default=None)
    parser.add_argument("--output_dir", type=str, default="./checkpoints")
    parser.add_argument("--log_freq", default=1, type=int)
    parser.add_argument("--eval_freq", default=1000, type=int)
    parser.add_argument("--save_freq", default=1000, type=int)

    parser.add_argument("--fim_rate", type=float, default=0)
    parser.add_argument("--fim_spm_rate", type=float, default=0)

    parser.add_argument("--use_peft_lora", action="store_true")
    parser.add_argument("--lora_r", type=int, default=0)
    parser.add_argument("--lora_alpha", type=int, default=0)
    parser.add_argument("--lora_dropout", type=float, default=0)
    parser.add_argument("--lora_target_modules", type=str, default=None)

    parser.add_argument("--use_flash_attn", action="store_true")

    parser.add_argument("--use_4bit_quantization", action="store_true")
    parser.add_argument("--use_nested_quant", action="store_true")
    parser.add_argument("--bnb_4bit_quant_type", type=str, default="nf4")
    parser.add_argument("--bnb_4bit_compute_dtype", type=str, default="float16")

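    # A rough sketch of how the 4-bit flags above are consumed by
    # create_and_prepare_model() further down (with run_peft.sh's values):
    #   BitsAndBytesConfig(
    #       load_in_4bit=True,
    #       bnb_4bit_quant_type="nf4",
    #       bnb_4bit_compute_dtype=torch.bfloat16,
    #       bnb_4bit_use_double_quant=True,
    #   )
    # 4-bit and 8-bit loading are alternatives; enable at most one of them.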
parser.add_argument("--use_8bit_qunatization", action="store_true") 84 | 85 | parser.add_argument("--push_to_hub", action="store_true") 86 | 87 | return parser.parse_args() 88 | 89 | 90 | def chars_token_ratio(dataset, tokenizer, data_column, nb_examples=400): 91 | """ 92 | Estimate the average number of characters per token in the dataset. 93 | """ 94 | total_characters, total_tokens = 0, 0 95 | for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples): 96 | total_characters += len(example[data_column]) 97 | total_tokens += len(tokenizer(example[data_column]).tokens()) 98 | 99 | return total_characters / total_tokens 100 | 101 | 102 | class ConstantLengthDataset(IterableDataset): 103 | """ 104 | Iterable dataset that returns constant length chunks of tokens from stream of text files. 105 | Args: 106 | tokenizer (Tokenizer): The processor used for proccessing the data. 107 | dataset (dataset.Dataset): Dataset with text files. 108 | infinite (bool): If True the iterator is reset after dataset reaches end else stops. 109 | seq_length (int): Length of token sequences to return. 110 | num_of_sequences (int): Number of token sequences to keep in buffer. 111 | chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer. 112 | fim_rate (float): Rate (0.0 to 1.0) that sample will be permuted with FIM. 113 | fim_spm_rate (float): Rate (0.0 to 1.0) of FIM permuations that will use SPM. 114 | seed (int): Seed for random number generator. 115 | """ 116 | 117 | def __init__( 118 | self, 119 | tokenizer, 120 | dataset, 121 | infinite=False, 122 | seq_length=1024, 123 | num_of_sequences=1024, 124 | chars_per_token=3.6, 125 | content_field="content", 126 | fim_rate=0.5, 127 | fim_spm_rate=0.5, 128 | seed=0, 129 | ): 130 | self.tokenizer = tokenizer 131 | self.concat_token_id = tokenizer.eos_token_id 132 | self.dataset = dataset 133 | self.seq_length = seq_length 134 | self.infinite = infinite 135 | self.current_size = 0 136 | self.max_buffer_size = seq_length * chars_per_token * num_of_sequences 137 | self.content_field = content_field 138 | self.fim_rate = fim_rate 139 | self.fim_spm_rate = fim_spm_rate 140 | self.seed = seed 141 | 142 | ( 143 | self.bos_token_id, 144 | self.suffix_tok_id, 145 | self.prefix_tok_id, 146 | self.middle_tok_id, 147 | self.pad_tok_id, 148 | ) = fim.get_fim_token_ids(self.tokenizer) 149 | if not self.suffix_tok_id and self.fim_rate > 0: 150 | print("FIM is not supported by tokenizer, disabling FIM") 151 | self.fim_rate = 0 152 | 153 | def __iter__(self): 154 | iterator = iter(self.dataset) 155 | more_examples = True 156 | while more_examples: 157 | buffer, buffer_len = [], 0 158 | while True: 159 | if buffer_len >= self.max_buffer_size: 160 | break 161 | try: 162 | buffer.append(next(iterator)[self.content_field]) 163 | buffer_len += len(buffer[-1]) 164 | except StopIteration: 165 | if self.infinite: 166 | iterator = iter(self.dataset) 167 | else: 168 | more_examples = False 169 | break 170 | tokenized_inputs = self.tokenizer(buffer, truncation=False, add_special_tokens=False)["input_ids"] 171 | all_token_ids = [] 172 | 173 | np_rng = np.random.RandomState(seed=self.seed) 174 | for tokenized_input in tokenized_inputs: 175 | # optionally do FIM permutations 176 | if self.fim_rate > 0: 177 | tokenized_input, np_rng = fim.permute( 178 | tokenized_input, 179 | np_rng, 180 | self.bos_token_id, 181 | self.suffix_tok_id, 182 | self.prefix_tok_id, 183 | self.middle_tok_id, 184 | self.pad_tok_id, 185 | 
                        fim_rate=self.fim_rate,
                        fim_spm_rate=self.fim_spm_rate,
                        truncate_or_pad=False,
                    )

                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            examples = []
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    examples.append(input_ids)
            random.shuffle(examples)
            for example in examples:
                self.current_size += 1
                yield {
                    "input_ids": torch.LongTensor(example),
                    "labels": torch.LongTensor(example),
                }


def create_datasets(tokenizer, args):
    dataset = load_dataset(
        args.dataset_name,
        split=args.split,
        token=True,  # `use_auth_token` was renamed to `token` in recent datasets releases
        num_proc=args.num_workers if not args.streaming else None,
        streaming=args.streaming,
    )
    if args.streaming:
        print("Loading the dataset in streaming mode")
        valid_data = dataset.take(args.size_valid_set)
        train_data = dataset.skip(args.size_valid_set)
        train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
    else:
        dataset = dataset.train_test_split(test_size=args.test_size, seed=args.seed, shuffle=True)
        train_data = dataset["train"]
        valid_data = dataset["test"]
        print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")
    chars_per_token = chars_token_ratio(train_data, tokenizer, args.data_column)
    print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")
    train_dataset = ConstantLengthDataset(
        tokenizer,
        train_data,
        infinite=True,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
        content_field=args.data_column,
        fim_rate=args.fim_rate,
        fim_spm_rate=args.fim_spm_rate,
        seed=args.seed,
    )
    valid_dataset = ConstantLengthDataset(
        tokenizer,
        valid_data,
        infinite=False,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
        content_field=args.data_column,
        fim_rate=args.fim_rate,
        fim_spm_rate=args.fim_spm_rate,
        seed=args.seed,
    )

    return train_dataset, valid_dataset


def create_and_prepare_model(args):
    device_map = None
    bnb_config = None

    load_in_8bit = args.use_8bit_quantization
    # Resolve the compute dtype up front; it is also used as torch_dtype below,
    # even when 4-bit loading is disabled (previously this raised a NameError
    # whenever --use_4bit_quantization was not passed).
    compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)

    if args.use_4bit_quantization:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=args.use_4bit_quantization,
            bnb_4bit_quant_type=args.bnb_4bit_quant_type,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=args.use_nested_quant,
        )

        if compute_dtype == torch.float16:
            major, _ = torch.cuda.get_device_capability()
            if major >= 8:
                print("=" * 80)
                print("Your GPU supports bfloat16; you can accelerate training with the argument --bf16")
                print("=" * 80)

    if args.use_4bit_quantization or args.use_8bit_quantization:
        device_map = {"": 0}

    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=compute_dtype,
        load_in_8bit=load_in_8bit,
        quantization_config=bnb_config,
        device_map=device_map,
        use_cache=not args.no_gradient_checkpointing,
        trust_remote_code=True,
        # use_flash_attention_2=args.use_flash_attn,
    )

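    # QLoRA-style preparation (a sketch of what prepare_model_for_kbit_training
    # does): it freezes the quantized base weights, upcasts layer norms to
    # float32 for numerical stability, and enables input gradients so that
    # gradient checkpointing works with frozen bases.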
    if (args.use_4bit_quantization or args.use_8bit_quantization) and args.use_peft_lora:
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.no_gradient_checkpointing)

    if args.use_peft_lora:
        peft_config = LoraConfig(
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            r=args.lora_r,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=args.lora_target_modules.split(","),
        )

        if args.no_gradient_checkpointing:
            model.gradient_checkpointing_enable()

        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
    return model


def run_training(args, train_data, val_data):
    train_data.start_iteration = 0

    print("Starting main loop")
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        dataloader_drop_last=True,
        evaluation_strategy="steps",
        save_strategy="steps",
        max_steps=args.max_steps,
        eval_steps=args.eval_freq,
        save_steps=args.save_freq,
        logging_steps=args.log_freq,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_steps=args.num_warmup_steps,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        # store_false flags: True unless --no_gradient_checkpointing / --no_fp16 are passed
        gradient_checkpointing=args.no_gradient_checkpointing,
        fp16=args.no_fp16,
        bf16=args.bf16,
        weight_decay=args.weight_decay,
        run_name="codellama-copilot",
        push_to_hub=args.push_to_hub,
    )

    print("Loading the model")
    model = create_and_prepare_model(args)
    print(model)
    if args.use_peft_lora:
        model.print_trainable_parameters()

    trainer = Trainer(model=model, args=training_args, train_dataset=train_data, eval_dataset=val_data)

    # post-process for faster training when using PEFT + INT4 quantization:
    # keep LoRA layers in bf16, norms in fp32, and cast fp32 embedding/output
    # head weights down to bf16
    if args.use_peft_lora:
        for name, module in trainer.model.named_modules():
            if isinstance(module, LoraLayer):
                if args.bf16:
                    module = module.to(torch.bfloat16)
            if "norm" in name:
                module = module.to(torch.float32)
            if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
                if hasattr(module, "weight"):
                    if args.bf16 and module.weight.dtype == torch.float32:
                        module = module.to(torch.bfloat16)

    print("Training...")
    trainer.train()
    if args.use_peft_lora:
        print("Saving last checkpoint of the model")
        model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/"))

    if args.push_to_hub:
        trainer.push_to_hub()
    else:
        trainer.save_model(args.output_dir)
    trainer.accelerator.print(f"Model saved to {args.output_dir}")

    # only push the adapter weights when the user asked for it
    if args.push_to_hub and args.use_peft_lora:
        trainer.model.push_to_hub(args.output_dir)


def main(args):
    if args.use_flash_attn:
        warnings.warn(
            "Flash V2 support implemented here ignores padding/attention_mask/custom_mask.\n"
            "It is meant for continued pre-training with packed inputs that consume the entire sequence length."
        )
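        # The patch (defined in llama_flash_attn_monkey_patch.py above) swaps
        # LlamaAttention.forward for a FlashAttention kernel and turns
        # _prepare_decoder_attention_mask into a no-op, which is why padded
        # batches would be attended to as if they were unpadded.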
        from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

        replace_llama_attn_with_flash_attn()
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, token=True, trust_remote_code=True)
    train_dataset, eval_dataset = create_datasets(tokenizer, args)
    run_training(args, train_dataset, eval_dataset)


if __name__ == "__main__":
    args = get_args()
    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)
    main(args)
--------------------------------------------------------------------------------