├── LICENSE
├── README.md
├── requirements.txt
└── scripts
    ├── benchmark
    │   ├── README.md
    │   ├── benchmark.py
    │   ├── data_utils.py
    │   ├── model_utils.py
    │   ├── performance_evaluator.py
    │   ├── requirements.txt
    │   ├── scripts
    │   │   └── benchmark_7B
    │   │       ├── gemini.sh
    │   │       └── gemini_auto.sh
    │   └── test_ci.sh
    ├── convert_hf_to_gguf.py
    ├── finetune
    │   ├── chat_finetune
    │   │   └── finetune.py
    │   ├── instruct_finetune
    │   │   ├── dpo_finetune.sh
    │   │   └── sft_finetune.sh
    │   └── reason_finetune
    │       └── train_7b.sh
    ├── inference
    │   └── inference.py
    └── train
        ├── attn.py
        ├── pretrain.py
        └── requirements.txt

/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 | 
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 | 
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 | 
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 | 
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 | 
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 moxin-org 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
# Moxin LLM

Moxin is a family of fully open-source and reproducible LLMs

[Technical Report](https://arxiv.org/abs/2412.06845)  |  [Base Model](https://huggingface.co/moxin-org/moxin-llm-7b)  |  [Chat Model](https://huggingface.co/moxin-org/moxin-chat-7b)  |  [Instruct Model](https://huggingface.co/moxin-org/moxin-instruct-7b)  |  [Reasoning Model](https://huggingface.co/moxin-org/moxin-reasoning-7b)
4 | 5 | 6 | ## Introduction 7 | 8 | Generative AI (GAI) offers unprecedented opportunities for research and innovation, but its commercialization has raised concerns about transparency, reproducibility, and safety. Many open GAI models lack the necessary components for full understanding and reproducibility, and some use restrictive licenses whilst claiming to be “open-source”. To address these concerns, we follow the [Model Openness Framework (MOF)](https://arxiv.org/pdf/2403.13784), a ranked classification system that rates machine learning models based on their completeness and openness, following principles of open science, open source, open data, and open access. 9 | 10 | By promoting transparency and reproducibility, the MOF combats “openwashing” practices and establishes completeness and openness as primary criteria alongside the core tenets of responsible AI. Wide adoption of the MOF will foster a more open AI ecosystem, benefiting research, innovation, and adoption of state-of-the-art models. 11 | 12 | We follow MOF to release the datasets during training, the training scripts, and the trained models. 13 | 14 | 15 | 16 | ## Model 17 | You can download our [Moxin-7B-Base](https://huggingface.co/moxin-org/moxin-llm-7b), [Moxin-7B-Chat](https://huggingface.co/moxin-org/moxin-chat-7b), [Moxin-7B-Instruct](https://huggingface.co/moxin-org/moxin-instruct-7b) and [Moxin-7B-Reasoning](https://huggingface.co/moxin-org/moxin-reasoning-7b) models. 18 | 19 | 20 | 21 | 22 | 23 | ## Evaluation 24 | 25 | ### Base Model Evaluation 26 | 27 | We test the performance of our base model with [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness). The evaluation results on common datasets are shown below. We test on AI2 Reasoning Challenge (25-shot), HellaSwag (10-shot), MMLU (5-shot), and Winogrande (5-shot). We release the Moxin-7B-Enhanced as our base model. We further finetune our base model on Tulu v2 to obtain our chat model. 28 | 29 | | Models | ARC-C | Hellaswag | MMLU | WinoGrade | Ave | 30 | |:----------------------:|:-----:|:---------:|:-----:|:---------:|:-----:| 31 | | Mistral-7B | 57.59 | 83.25 | 62.42 | 78.77 | 70.51 | 32 | | LLaMA 3.1-8B | 54.61 | 81.95 | 65.16 | 77.35 | 69.77 | 33 | | LLaMA 3-8B | 55.46 | 82.09 | 65.29 | 77.82 | 70.17 | 34 | | LLaMA 2-7B | 49.74 | 78.94 | 45.89 | 74.27 | 62.21 | 35 | | Qwen 2-7B | 57.68 | 80.76 | 70.42 | 77.43 | 71.57 | 36 | | Gemma-7b | 56.48 | 82.31 | 63.02 | 78.3 | 70.03 | 37 | | Internlm2.5-7b | 54.78 | 79.7 | 68.17 | 80.9 | 70.89 | 38 | | Baichuan2-7B | 47.87 | 73.89 | 54.13 | 70.8 | 61.67 | 39 | | Yi-1.5-9B | 58.36 | 80.36 | 69.54 | 77.53 | 71.48 | 40 | | Moxin-7B-Original | 53.75 | 75.46 | 59.43 | 70.32 | 64.74 | 41 | | Moxin-7B-Enhanced (Moxin-7B-Base)| 59.47 | 83.08 | 60.97 | 78.69 | 70.55 | 42 | 43 | 44 | We also test the zero shot performance on AI2 Reasoning Challenge (0-shot), AI2 Reasoning Easy (0-shot), HellaSwag (0-shot), PIQA (0-shot) and Winogrande (0-shot). The results are shown below. 
45 | 
46 | | Models | HellaSwag | WinoGrade | PIQA | ARC-E | ARC-C | Ave |
47 | |:----------------:|:---------:|:---------:|:-----:|:-----:|:-----:|:-----:|
48 | | Mistral-7B | 80.39 | 73.4 | 82.15 | 78.28 | 52.22 | 73.29 |
49 | | LLaMA 2-7B | 75.99 | 69.06 | 79.11 | 74.54 | 46.42 | 69.02 |
50 | | LLaMA 2-13B | 79.37 | 72.22 | 80.52 | 77.4 | 49.06 | 71.71 |
51 | | LLaMA 3.1-8B | 78.92 | 74.19 | 81.12 | 81.06 | 53.67 | 73.79 |
52 | | Gemma-7b | 80.45 | 73.72 | 80.9 | 79.97 | 54.1 | 73.83 |
53 | | Qwen v2-7B | 78.9 | 72.38 | 79.98 | 74.71 | 50.09 | 71.21 |
54 | | Internlm2.5-7b | 79.14 | 77.9 | 80.52 | 76.16 | 51.37 | 73.02 |
55 | | Baichuan2-7B | 72.25 | 67.17 | 77.26 | 72.98 | 42.15 | 66.36 |
56 | | Yi-1.5-9B | 77.86 | 73.01 | 80.74 | 79.04 | 55.03 | 73.14 |
57 | | Deepseek-7b | 76.13 | 69.77 | 79.76 | 71.04 | 44.8 | 68.3 |
58 | | Moxin-7B-Original | 72.06 | 66.31 | 78.07 | 71.47 | 48.15 | 67.21 |
59 | | Moxin-7B-Enhanced (Moxin-7B-Base) | 80.03 | 75.17 | 82.24 | 81.12 | 58.64 | 75.44 |
60 | 
61 | 
62 | 
63 | ### Instruct Model Evaluation
64 | 
65 | Our instruct model is trained with [Tulu 3](https://allenai.org/blog/tulu-3-technical). The evaluation results are shown below. We evaluate with [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) and [OLMES](https://github.com/allenai/olmes).
66 | 
67 | We test on AI2 Reasoning Challenge (25-shot), HellaSwag (10-shot), MMLU (5-shot), and Winogrande (5-shot).
68 | | Model | ARC-C | Hellaswag | MMLU | WinoGrade | Ave |
69 | |:----------------:|:-----:|:---------:|:-----:|:---------:|:-----:|
70 | | Mistral 8B Instruct | 62.63 | 80.61 | 64.16 | 79.08 | 71.62 |
71 | | Llama3.1 8B Instruct | 60.32 | 80 | 68.18 | 77.27 | 71.44 |
72 | | Qwen2.5 7B Instruct | 66.72 | 81.54 | 71.3 | 74.59 | 73.54 |
73 | | Moxin-7B-SFT | 60.11 | 83.43 | 60.56 | 77.56 | 70.42 |
74 | | Moxin-7B-DPO (Moxin-7B-Instruct) | 64.76 | 87.19 | 58.36 | 76.32 | 71.66 |
75 | 
76 | 
77 | We also test the zero shot performance on AI2 Reasoning Challenge (0-shot), AI2 Reasoning Easy (0-shot), HellaSwag (0-shot), PIQA (0-shot) and Winogrande (0-shot). The results are shown below.
78 | | Models | HellaSwag | WinoGrade | PIQA | ARC-E | ARC-C | Ave |
79 | |:----------------:|:---------:|:---------:|:-----:|:-----:|:-----:|:-----:|
80 | | Mistral 8B Instruct | 79.08 | 73.56 | 82.26 | 79.88 | 56.57 | 74.27 |
81 | | Llama3.1 8B Instruct | 79.21 | 74.19 | 80.79 | 79.71 | 55.03 | 73.79 |
82 | | Qwen2.5 7B Instruct | 80.5 | 71.03 | 80.47 | 81.31 | 55.12 | 73.69 |
83 | | Moxin-7B-SFT | 81.44 | 73.09 | 81.07 | 79.8 | 54.67 | 74.01 |
84 | | Moxin-7B-DPO (Moxin-7B-Instruct) | 85.7 | 73.24 | 81.56 | 81.1 | 58.02 | 75.92 |
85 | 
86 | 
87 | 
88 | 
89 | The evaluation results with OLMES are shown below.
90 | | Models/Datasets | GSM8K | MATH | Humaneval | Humaneval plus | MMLU | PopQA | BBH | TruthfulQA | Ave |
91 | |:----------------:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|
92 | | Qwen2.5 7B Instruct | 83.8 | 14.8 | 93.1 | 89.7 | 76.6 | 18.1 | 21.7 | 63.1 | 57.61 |
93 | | Gemma2 9B Instruct | 79.7 | 29.8 | 71.7 | 67 | 74.6 | 28.3 | 2.5 | 61.4 | 51.88 |
94 | | Moxin-7B-DPO (Moxin-7B-Instruct) | 81.19 | 36.42 | 82.86 | 77.18 | 60.85 | 23.85 | 57.44 | 55.27 | 59.38 |
95 | 
96 | 
97 | ### Reasoning Model Evaluation
98 | 
99 | Our reasoning model is trained with [DeepScaleR](https://github.com/agentica-project/rllm). The evaluation results on math datasets are shown below.
100 | 
101 | | Models/Datasets | MATH 500 | AMC | Minerva Math | OlympiadBench | Ave |
102 | |:----------------:|:---------:|:---------:|:-----:|:-----:|:-----:|
103 | | Qwen2.5-Math-7B-Base | 52.4 | 52.5 | 12.9 | 16.4 | 33.55 |
104 | | Qwen2.5-Math-7B-Base + 8K MATH SFT | 54.6 | 22.5 | 32.7 | 19.6 | 32.35 |
105 | | Llama-3.1-70B-Instruct | 64.6 | 30.1 | 35.3 | 31.9 | 40.48 |
106 | | Moxin-7B-RL-DeepScaleR | 68 | 57.5 | 16.9 | 30.4 | 43.2 |
107 | 
108 | 
109 | ## Inference
110 | 
111 | You can use the following code to run inference with the model. The model is saved under the './model/' directory; change the model directory accordingly, or use the Hugging Face model ID directly.
112 | 
113 | ```
114 | import torch
115 | from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
116 | 
117 | torch.backends.cuda.enable_mem_efficient_sdp(False)
118 | torch.backends.cuda.enable_flash_sdp(False)
119 | 
120 | model_name = 'moxin-org/moxin-llm-7b'
121 | tokenizer = AutoTokenizer.from_pretrained(model_name)
122 | model = AutoModelForCausalLM.from_pretrained(
123 |     model_name,
124 |     torch_dtype=torch.bfloat16,
125 |     device_map="auto",
126 |     trust_remote_code=True,
127 | )
128 | 
129 | pipe = pipeline(
130 |     "text-generation",
131 |     model=model,
132 |     tokenizer=tokenizer,
133 |     torch_dtype=torch.bfloat16,
134 |     device_map="auto"
135 | )
136 | 
137 | prompt = "Can you explain the concept of regularization in machine learning?"
138 | 
139 | sequences = pipe(
140 |     prompt,
141 |     do_sample=True,
142 |     max_new_tokens=100,
143 |     temperature=0.7,
144 |     top_k=50,
145 |     top_p=0.95,
146 |     num_return_sequences=1,
147 | )
148 | print(sequences[0]['generated_text'])
149 | ```
150 | 
151 | ### Convert to GGUF
152 | 
153 | 
154 | Build a typical deep learning environment with PyTorch. Then use the script convert_hf_to_gguf.py to convert the HF model to GGUF.
155 | ```
156 | python convert_hf_to_gguf.py path_to_model_directory/
157 | ```
158 | Then, you can experiment with this GGUF model following [llama.cpp](https://github.com/ggerganov/llama.cpp).
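As a minimal illustration of that workflow (the file names below are hypothetical, and binary names vary across llama.cpp builds), you could quantize the converted file and then chat with it:

```bash
# Hypothetical file names; adjust to the output of convert_hf_to_gguf.py.
./llama-quantize moxin-7b-f16.gguf moxin-7b-q4_k_m.gguf Q4_K_M
./llama-cli -m moxin-7b-q4_k_m.gguf -p "Explain regularization in machine learning." -n 128
```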
162 | ## Reinforcement Learning with GRPO
163 | 
164 | To enhance the CoT capabilities of our model, we adopt RL techniques similar to DeepSeek R1. We first use high-quality reasoning data to SFT our instruct model. The reasoning data mainly includes OpenThoughts and OpenR1-Math-220k. Next, we adopt the RL technique from DeepSeek R1, i.e., GRPO, to finetune our model with RL. We adopt [DeepScaleR](https://github.com/agentica-project/rllm) as our RL training framework.
165 | 
166 | We first use high-quality reasoning data to SFT our instruct (DPO) model.
167 | + Dataset: [OpenThoughts](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k) and [OpenR1-Math-220k](https://huggingface.co/datasets/open-r1/OpenR1-Math-220k)
168 | + Framework: [open-instruct](https://github.com/allenai/open-instruct)
169 | + Configuration: [Llama-3.1-Tulu-3-8B-SFT](https://github.com/allenai/open-instruct/blob/main/docs/tulu3.md)
170 | 
171 | Refer to 'scripts/finetune/instruct_finetune/sft_finetune.sh' for more details.
172 | 
173 | Next, we adopt GRPO to finetune our model with RL.
174 | + Framework, configuration, and dataset: [DeepScaleR](https://github.com/agentica-project/rllm)
175 | 
176 | Refer to 'scripts/finetune/reason_finetune/train_7b.sh' for more details.
177 | 
178 | ## Post-Training with Tülu 3
179 | 
180 | The open-source Tülu 3 dataset and framework are adopted for model post-training. Starting from our base model, we follow Tülu 3 to perform supervised finetuning (SFT) and then Direct Preference Optimization (DPO).
181 | 
182 | Specifically, we use the Tülu 3 SFT Mixture dataset to train our base model with the SFT training method for two epochs and obtain our SFT model, following the default training configuration of the Tülu 3 8B SFT model.
183 | + Dataset: [Tülu 3 SFT Mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture)
184 | + Framework: [open-instruct](https://github.com/allenai/open-instruct)
185 | + Configuration: [Llama-3.1-Tulu-3-8B-SFT](https://github.com/allenai/open-instruct/blob/main/docs/tulu3.md)
186 | 
187 | Refer to 'scripts/finetune/instruct_finetune/sft_finetune.sh' for more details.
188 | 
189 | Next, we continue to train our SFT model on the Tülu 3 8B Preference Mixture dataset with the DPO training method to obtain our DPO model, following the same training configuration as the Tülu 3 8B DPO model.
190 | + Dataset: [Tülu 3 8B Preference Mixture](https://huggingface.co/datasets/allenai/llama-3.1-tulu-3-8b-preference-mixture)
191 | + Framework: [open-instruct](https://github.com/allenai/open-instruct)
192 | + Configuration: [Llama-3.1-Tulu-3-8B-DPO](https://github.com/allenai/open-instruct/blob/main/docs/tulu3.md)
193 | 
194 | Refer to 'scripts/finetune/instruct_finetune/dpo_finetune.sh' for more details.
195 | 
196 | 
197 | ## Pre-Training Environment
198 | 
199 | #### 1. Dataset config
200 | To prepare the dataset, install the following package:
201 | 
202 | 
203 | ```
204 | pip install datasets
205 | ```
206 | 
207 | #### 2. CUDA install
208 | 
209 | We use CUDA 11.7. Other CUDA versions may also work.
210 | ```
211 | wget https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run
212 | sudo sh cuda_11.7.0_515.43.04_linux.run
213 | ```
214 | 
215 | #### 3. Install PyTorch
216 | 
217 | We use PyTorch 2.0.0.
218 | ```
219 | conda create --name llm_train python==3.10
220 | conda activate llm_train
221 | pip install torch==2.0.0 torchvision==0.15.1 torchaudio==2.0.1
222 | ```
223 | 
224 | #### 4. Install other packages
225 | 
226 | Install the remaining packages from requirements.txt:
227 | ```
228 | pip install -r requirements.txt
229 | ```
230 | 
231 | #### 5. Install flash attention
232 | 
233 | We use flash-attention 2.2.1.
234 | ```
235 | git clone https://github.com/Dao-AILab/flash-attention.git
236 | cd flash-attention/
237 | git checkout a1576ad ## flash-attention 2.2.1
238 | python setup.py install
239 | cd ./csrc
240 | cd fused_dense_lib && pip install -v .
241 | cd ../xentropy && pip install -v .
242 | cd ../rotary && pip install -v .
243 | cd ../layer_norm && pip install -v .
244 | ```
245 | 
246 | 
247 | ## Pretrain Datasets
248 | 
249 | 
250 | To use the [SlimPajama dataset](https://huggingface.co/datasets/cerebras/SlimPajama-627B) for pretraining, you can download the dataset using Hugging Face datasets:
251 | ```
252 | import datasets
253 | ds = datasets.load_dataset("cerebras/SlimPajama-627B")
254 | ```
255 | SlimPajama is the largest extensively deduplicated, multi-corpora, open-source dataset for training large language models. SlimPajama was created by cleaning and deduplicating the 1.2T-token RedPajama dataset from Together. By filtering out low-quality data and duplicates, it removes 49.6% of bytes, slimming the RedPajama dataset down from 1210B to 627B tokens. SlimPajama offers the highest quality and most compute-efficient data to train on for runs up to 627B tokens. When upsampled, SlimPajama is expected to perform equal to or better than RedPajama-1T when training at trillion-token scale.
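Because the full dataset is several hundred gigabytes on disk, streaming it is often more practical than a full download; a small sketch (each record's document text is in the `text` field):

```python
import datasets

# Stream shards on demand instead of downloading the whole dataset first.
ds = datasets.load_dataset("cerebras/SlimPajama-627B", split="train", streaming=True)
for sample in ds:
    print(sample["text"][:200])  # peek at the first document
    break
```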
258 | To use the [stack-dedup dataset](https://huggingface.co/datasets/bigcode/the-stack-dedup) for pretraining, you can download the dataset using Hugging Face datasets:
259 | ```
260 | from datasets import load_dataset
261 | 
262 | # full dataset (3TB of data)
263 | ds = load_dataset("bigcode/the-stack-dedup", split="train")
264 | 
265 | # specific language (e.g. Dockerfiles)
266 | ds = load_dataset("bigcode/the-stack-dedup", data_dir="data/dockerfile", split="train")
267 | 
268 | # dataset streaming (will only download the data as needed)
269 | ds = load_dataset("bigcode/the-stack-dedup", streaming=True, split="train")
270 | for sample in iter(ds): print(sample["content"])
271 | ```
272 | The Stack contains over 6TB of permissively-licensed source code files covering 358 programming languages. The dataset was created as part of the BigCode Project, an open scientific collaboration working on the responsible development of Large Language Models for Code (Code LLMs). The Stack serves as a pre-training dataset for Code LLMs, i.e., code-generating AI systems which enable the synthesis of programs from natural language descriptions as well as from other code snippets. This is the near-deduplicated version with 3TB of data.
273 | 
274 | You can find more details about the DCLM-baseline dataset on its [homepage](https://huggingface.co/datasets/mlfoundations/dclm-baseline-1.0).
275 | 
276 | ## Pre-Training
277 | 
278 | We follow the [ColossalAI](https://github.com/hpcaitech/ColossalAI) framework to train the LLM. Colossal-AI provides a collection of parallel components for training and aims to let you write distributed deep learning models the same way you would write a model for your laptop, with user-friendly tools to kickstart distributed training and inference in a few lines.
279 | 
280 | We provide a few examples to show how to run benchmarks or pretraining based on Colossal-AI.
281 | 
282 | ### 1. Training LLM
283 | 
284 | You can find the shell scripts in the 'scripts/train_7B' directory. The main command should be in the format of:
285 | ```
286 | colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
287 | pretrain.py --OTHER_CONFIGURATIONS
288 | ```
289 | 
290 | #### a. Running on a single node
291 | We provide an example of running the training on a single node below:
292 | ```
293 | colossalai run --nproc_per_node 1 pretrain.py \
294 |     --config 7b \
295 |     --dataset togethercomputer/RedPajama-Data-1T-Sample \
296 |     --batch_size 1 \
297 |     --num_epochs 5 \
298 |     --save_interval 5000 \
299 |     --max_length 2048 \
300 |     --save_dir output-checkpoints \
301 |     --plugin zero2_cpu \
302 |     --lr 2e-5 \
303 |     --expanded_model hpcai-tech/Colossal-LLaMA-2-7b-base
304 | ```
305 | This example uses the sample dataset 'togethercomputer/RedPajama-Data-1T-Sample' and trains the 7B model 'hpcai-tech/Colossal-LLaMA-2-7b-base'. Refer to the main file 'run.sh' and 'pretrain.py' for more details. To start the training, run the following:
306 | ```bash
307 | bash run.sh
308 | ```
309 | 
310 | #### b. Running on multiple nodes
311 | 
312 | We provide an example of running the training on multiple nodes below:
313 | ```
314 | srun colossalai run --num_nodes 8 --nproc_per_node 8 pretrain.py \
315 |     --config 7b \
316 |     --dataset cerebras/SlimPajama-627B \
317 |     --batch_size 1 \
318 |     --num_epochs 10 \
319 |     --save_interval 50000 \
320 |     --max_length 2048 \
321 |     --save_dir output-checkpoints \
322 |     --flash_attention \
323 |     --plugin zero2_cpu \
324 |     --lr 1e-5 \
325 |     --expanded_model hpcai-tech/Colossal-LLaMA-2-7b-base
326 | ```
327 | It uses 8 nodes. Put your host file (`hosts.txt`) in this directory with your real host IPs or host names.
328 | Here is a sample `hosts.txt`:
329 | ```text
330 | hostname1
331 | hostname2
332 | hostname3
333 | ...
334 | hostname8
335 | ```
336 | You can refer to the main file 'run-multi-server.sh' and 'pretrain.py' for more details. To start the training, run the following:
337 | 
338 | ```bash
339 | bash run-multi-server.sh
340 | ```
341 | 
342 | ### 2. Benchmark
343 | 
344 | 
345 | You can find the shell scripts in the 'scripts/benchmark_7B' directory. The benchmark mainly tests the throughput of the LLM, without actual model training. The main command should be in the format of:
346 | ```
347 | colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
348 | benchmark.py --OTHER_CONFIGURATIONS
349 | ```
350 | 
351 | Here we will show an example of how to run LLaMA pretraining with 'gemini, batch_size=16, sequence_length=4096, gradient_checkpoint=True, flash_attn=True'.
352 | 
353 | #### a. Running environment
354 | 
355 | This experiment was performed on 4 computing nodes with 32 L40S GPUs in total for LLaMA-2 7B. The nodes are connected with RDMA and GPUs within one node are fully connected with NVLink.
356 | 
357 | #### b. Running command
358 | 
359 | ```bash
360 | cd scripts/benchmark_7B
361 | ```
362 | 
363 | First, put your host file (`hosts.txt`) in this directory with your real host IPs or host names.
364 | 
365 | Here is a sample `hosts.txt`:
366 | ```text
367 | hostname1
368 | hostname2
369 | hostname3
370 | hostname4
371 | ```
372 | 
373 | Then add environment variables to the script if needed, for example as below.
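An illustrative set (the right variables depend on your cluster's interconnect; adjust or omit as needed):

```bash
export OMP_NUM_THREADS=8        # CPU threads per process
export NCCL_SOCKET_IFNAME=eth0  # network interface NCCL should use (cluster-specific)
export NCCL_DEBUG=WARN          # raise to INFO when debugging connectivity
```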
374 | 375 | Finally, run the following command to start training: 376 | 377 | ```bash 378 | bash gemini.sh 379 | ``` 380 | 381 | 382 | ## Citation 383 | 384 | ``` 385 | @article{zhao2024fully, 386 | title={Fully Open Source Moxin-7B Technical Report}, 387 | author={Zhao, Pu and Shen, Xuan and Kong, Zhenglun and Shen, Yixin and Chang, Sung-En and Rupprecht, Timothy and Lu, Lei and Nan, Enfu and Yang, Changdi and He, Yumei and others}, 388 | journal={arXiv preprint arXiv:2412.06845}, 389 | year={2024} 390 | } 391 | ``` 392 | 393 | 394 | 395 | 396 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | aiohttp==3.9.3 3 | aiosignal==1.3.1 4 | annotated-types==0.6.0 5 | async-timeout==4.0.3 6 | attrs==23.2.0 7 | bcrypt==4.1.2 8 | beautifulsoup4==4.12.3 9 | cachetools==5.3.3 10 | certifi==2024.2.2 11 | cffi==1.16.0 12 | cfgv==3.4.0 13 | charset-normalizer==3.3.2 14 | click==8.1.7 15 | cmake==3.29.0.1 16 | colossalai==0.3.6 17 | contexttimer==0.3.3 18 | cryptography==42.0.5 19 | datasets==2.18.0 20 | decorator==5.1.1 21 | Deprecated==1.2.14 22 | dill==0.3.8 23 | distlib==0.3.8 24 | einops==0.7.0 25 | fabric==3.2.2 26 | filelock==3.13.3 27 | flash-attn==2.2.1 28 | frozenlist==1.4.1 29 | fsspec==2024.2.0 30 | google==3.0.0 31 | google-auth==2.29.0 32 | google-auth-oauthlib==1.0.0 33 | grpcio==1.62.1 34 | huggingface-hub==0.22.2 35 | identify==2.5.35 36 | idna==3.6 37 | invoke==2.2.0 38 | Jinja2==3.1.3 39 | jsonschema==4.21.1 40 | jsonschema-specifications==2023.12.1 41 | lit==18.1.2 42 | Markdown==3.6 43 | markdown-it-py==3.0.0 44 | MarkupSafe==2.1.5 45 | mdurl==0.1.2 46 | mpmath==1.3.0 47 | msgpack==1.0.8 48 | multidict==6.0.5 49 | multiprocess==0.70.16 50 | networkx== 3.3 51 | ninja==1.11.1.1 52 | nodeenv==1.8.0 53 | numpy==1.26.4 54 | nvidia-cublas-cu11==11.10.3.66 55 | nvidia-cublas-cu12==12.1.3.1 56 | nvidia-cuda-cupti-cu11==11.7.101 57 | nvidia-cuda-cupti-cu12==12.1.105 58 | nvidia-cuda-nvrtc-cu11==11.7.99 59 | nvidia-cuda-nvrtc-cu12==12.1.105 60 | nvidia-cuda-runtime-cu11==11.7.99 61 | nvidia-cuda-runtime-cu12==12.1.105 62 | nvidia-cudnn-cu11==8.5.0.96 63 | nvidia-cudnn-cu12==8.9.2.26 64 | nvidia-cufft-cu11==10.9.0.58 65 | nvidia-cufft-cu12==11.0.2.54 66 | nvidia-curand-cu11==10.2.10.91 67 | nvidia-curand-cu12==10.3.2.106 68 | nvidia-cusolver-cu11==11.4.0.1 69 | nvidia-cusolver-cu12==11.4.5.107 70 | nvidia-cusparse-cu11==11.7.4.91 71 | nvidia-cusparse-cu12==12.1.0.106 72 | nvidia-nccl-cu11==2.14.3 73 | nvidia-nccl-cu12==2.19.3 74 | nvidia-nvjitlink-cu12==12.4.127 75 | nvidia-nvtx-cu11==11.7.91 76 | nvidia-nvtx-cu12==12.1.105 77 | oauthlib==3.2.2 78 | packaging==24.0 79 | pandas==2.2.1 80 | paramiko==3.4.0 81 | pip==23.3.1 82 | platformdirs==4.2.0 83 | pre-commit==3.7.0 84 | protobuf==5.26.1 85 | psutil==5.9.8 86 | pyarrow==15.0.2 87 | pyarrow-hotfix==0.6 88 | pyasn1==0.6.0 89 | pyasn1_modules==0.4.0 90 | pycparser==2.22 91 | pydantic==2.6.4 92 | pydantic_core==2.16.3 93 | Pygments==2.17.2 94 | PyNaCl==1.5.0 95 | python-dateutil==2.9.0.post0 96 | pytz==2024.1 97 | PyYAML==6.0.1 98 | ray==2.10.0 99 | referencing==0.34.0 100 | regex==2023.12.25 101 | requests==2.31.0 102 | requests-oauthlib==2.0.0 103 | rich==13.7.1 104 | rpds-py==0.18.0 105 | rsa==4.9 106 | safetensors==0.4.2 107 | sentencepiece==0.1.99 108 | setuptools==68.2.2 109 | six==1.16.0 110 | soupsieve==2.5 111 | sympy==1.12 112 | tensorboard==2.14.0 113 | tensorboard-data-server==0.7.2 
114 | tokenizers==0.13.3 115 | torch==2.0.0 116 | tqdm==4.66.2 117 | transformers==4.34.0 118 | triton==2.0.0 119 | typing_extensions==4.11.0 120 | tzdata==2024.1 121 | urllib3==2.2.1 122 | virtualenv==20.25.1 123 | Werkzeug==3.0.2 124 | wheel==0.41.2 125 | wrapt==1.16.0 126 | xxhash==3.4.1 127 | yarl==1.9.4 128 | -------------------------------------------------------------------------------- /scripts/benchmark/README.md: -------------------------------------------------------------------------------- 1 | # Pretraining LLaMA-1/2/3: best practices for building LLaMA-1/2/3-like base models 2 | ### LLaMA3 3 |
7 | - 70 billion parameter LLaMA3 model training accelerated by 18%
8 | 
9 | ### LLaMA2
14 | - 70 billion parameter LLaMA2 model training accelerated by 195%
15 | [[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training)
16 | 
17 | ### LLaMA1
21 | 
22 | - 65-billion-parameter large model pretraining accelerated by 38%
23 | [[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)
24 | 
25 | ## Usage
26 | 
27 | > ⚠ This example only contains a benchmarking script. For training/finetuning, please refer to [applications/Colossal-LLaMA](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA).
28 | 
29 | ### 1. Installation
30 | 
31 | Please install the latest ColossalAI from source.
32 | 
33 | ```bash
34 | BUILD_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI
35 | ```
36 | 
37 | Then install other dependencies.
38 | 
39 | ```bash
40 | pip install -r requirements.txt
41 | ```
42 | 
43 | ### 2. Shell Script Examples
44 | 
45 | For your convenience, we provide some shell scripts to run benchmarks with various configurations.
46 | 
47 | You can find them in the `scripts/benchmark_7B` directory. The main command should be in the format of:
48 | ```bash
49 | colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
50 | benchmark.py --OTHER_CONFIGURATIONS
51 | ```
52 | Here we will show an example of how to run LLaMA
53 | pretraining with `gemini, batch_size=16, sequence_length=4096, gradient_checkpoint=True, flash_attn=True`.
54 | 
55 | #### a. Running environment
56 | This experiment was performed on 4 computing nodes with 32 A800/H800 80GB GPUs in total for LLaMA-1 65B or LLaMA-2 70B. The nodes are
57 | connected with RDMA and GPUs within one node are fully connected with NVLink.
58 | 
59 | #### b. Running command
60 | 
61 | ```bash
62 | cd scripts/benchmark_7B
63 | ```
64 | 
65 | First, put your host file (`hosts.txt`) in this directory with your real host IPs or host names.
66 | 
67 | Here is a sample `hosts.txt`:
68 | ```text
69 | hostname1
70 | hostname2
71 | hostname3
72 | hostname4
73 | ```
74 | 
75 | Then add environment variables to the script if needed.
76 | 
77 | Finally, run the following command to start training:
78 | 
79 | ```bash
80 | bash gemini.sh
81 | ```
82 | 
83 | If you encounter an out-of-memory (OOM) error during training with the script `gemini.sh`, changing to the script `gemini_auto.sh` might be a solution, since gemini_auto will set an upper limit on GPU memory usage by offloading part of the model parameters and optimizer states back to CPU memory. But there's a trade-off: `gemini_auto.sh` will be a bit slower, since more data is transmitted between CPU and GPU.
84 | 
85 | #### c. Results
86 | If you run the above command successfully, you will get the following results:
87 | `max memory usage: 55491.10 MB, throughput: 24.26 samples/s, TFLOPS/GPU: 167.43`.
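As a rough cross-check, the reported TFLOPS/GPU follows the simple dense-model estimate used in `performance_evaluator.py` (2 × parameters × tokens per step, times a forward/backward/recompute factor of 4 under gradient checkpointing). A minimal sketch, assuming a 7B-parameter model and the batch/sequence configuration above, lands within a few percent of the reported figure:

```python
# Dense-model FLOP estimate mirroring performance_evaluator.py.
params = 7e9                    # assumed 7B-parameter model
batch_size, seq_len = 16, 4096  # per-GPU batch and sequence length from gemini.sh
ckpt_factor = 3 + 1             # fwd (1x) + bwd (2x) + recompute (1x) with grad ckpt
flops_per_gpu_step = 2 * params * batch_size * seq_len * ckpt_factor

samples_per_sec, num_gpus = 24.26, 32                  # figures reported above
step_time_s = batch_size * num_gpus / samples_per_sec  # time per optimizer step
print(f"~{flops_per_gpu_step / step_time_s / 1e12:.0f} TFLOPS/GPU")
```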
88 | 89 | 90 | ## Reference 91 | ``` 92 | @article{bian2021colossal, 93 | title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training}, 94 | author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang}, 95 | journal={arXiv preprint arXiv:2110.14883}, 96 | year={2021} 97 | } 98 | ``` 99 | 100 | ```bibtex 101 | @software{openlm2023openllama, 102 | author = {Geng, Xinyang and Liu, Hao}, 103 | title = {OpenLLaMA: An Open Reproduction of LLaMA}, 104 | month = May, 105 | year = 2023, 106 | url = {https://github.com/openlm-research/open_llama} 107 | } 108 | ``` 109 | 110 | ```bibtex 111 | @software{together2023redpajama, 112 | author = {Together Computer}, 113 | title = {RedPajama-Data: An Open Source Recipe to Reproduce LLaMA training dataset}, 114 | month = April, 115 | year = 2023, 116 | url = {https://github.com/togethercomputer/RedPajama-Data} 117 | } 118 | ``` 119 | 120 | ```bibtex 121 | @article{touvron2023llama, 122 | title={Llama: Open and efficient foundation language models}, 123 | author={Touvron, Hugo and Lavril, Thibaut and Izacard, Gautier and Martinet, Xavier and Lachaux, Marie-Anne and Lacroix, Timoth{\'e}e and Rozi{\`e}re, Baptiste and Goyal, Naman and Hambro, Eric and Azhar, Faisal and others}, 124 | journal={arXiv preprint arXiv:2302.13971}, 125 | year={2023} 126 | } 127 | ``` 128 | -------------------------------------------------------------------------------- /scripts/benchmark/benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import resource 3 | import time 4 | from contextlib import nullcontext 5 | 6 | import torch 7 | from data_utils import RandomDataset 8 | from model_utils import format_numel_str, get_model_numel 9 | from performance_evaluator import PerformanceEvaluator, get_profile_context 10 | from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision 11 | from tqdm import tqdm 12 | from transformers import AutoConfig, AutoModelForCausalLM 13 | from transformers.models.llama.configuration_llama import LlamaConfig 14 | 15 | import colossalai 16 | from colossalai.accelerator import get_accelerator 17 | from colossalai.booster import Booster 18 | from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin 19 | from colossalai.cluster import DistCoordinator 20 | from colossalai.lazy import LazyInitContext 21 | from colossalai.nn.optimizer import HybridAdam 22 | from colossalai.shardformer import PipelineGradientCheckpointConfig 23 | 24 | # ============================== 25 | # Constants 26 | # ============================== 27 | 28 | MODEL_CONFIGS = { 29 | "7b": LlamaConfig(max_position_embeddings=4096), 30 | "13b": LlamaConfig( 31 | hidden_size=5120, 32 | intermediate_size=13824, 33 | num_hidden_layers=40, 34 | num_attention_heads=40, 35 | max_position_embeddings=4096, 36 | ), 37 | "70b": LlamaConfig( 38 | hidden_size=8192, 39 | intermediate_size=28672, 40 | num_hidden_layers=80, 41 | num_attention_heads=64, 42 | max_position_embeddings=4096, 43 | num_key_value_heads=8, 44 | ), 45 | } 46 | 47 | 48 | def main(): 49 | # ============================== 50 | # Parse Arguments 51 | # ============================== 52 | parser = argparse.ArgumentParser() 53 | parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration") 54 | parser.add_argument( 55 | "-p", 56 | "--plugin", 57 | choices=["gemini", "gemini_auto", "fsdp", 
"fsdp_cpu", "3d", "3d_cpu"], 58 | default="gemini", 59 | help="Choose which plugin to use", 60 | ) 61 | parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size") 62 | parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run") 63 | parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore") 64 | parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing") 65 | parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length") 66 | parser.add_argument( 67 | "-w", "--warmup_ratio", type=float, default=0.8, help="warm up ratio of non-model data. Only for gemini-auto" 68 | ) 69 | parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in mb") 70 | parser.add_argument("-x", "--xformers", action="store_true", help="Use xformers") 71 | parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini") 72 | parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini") 73 | parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini") 74 | parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size") 75 | parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini") 76 | parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size") 77 | parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel") 78 | parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled") 79 | parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) 80 | parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False) 81 | parser.add_argument( 82 | "--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation", default=False 83 | ) 84 | parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") 85 | args = parser.parse_args() 86 | 87 | colossalai.launch_from_torch() 88 | coordinator = DistCoordinator() 89 | 90 | def empty_init(): 91 | pass 92 | 93 | # ckpt config for LLaMA3-70B on 64 H100 GPUs 94 | hybrid_kwargs = ( 95 | { 96 | "gradient_checkpoint_config": PipelineGradientCheckpointConfig( 97 | num_ckpt_layers_per_stage=[19, 19, 19, 13], 98 | ), 99 | "num_layers_per_stage": [19, 20, 20, 21], 100 | } 101 | if args.custom_ckpt 102 | else {} 103 | ) 104 | 105 | # ============================== 106 | # Initialize Booster 107 | # ============================== 108 | use_empty_init = True 109 | if args.plugin == "gemini": 110 | plugin = GeminiPlugin( 111 | precision="bf16", 112 | shard_param_frac=args.shard_param_frac, 113 | offload_optim_frac=args.offload_optim_frac, 114 | offload_param_frac=args.offload_param_frac, 115 | tp_size=args.tp, 116 | extra_dp_size=args.extra_dp, 117 | enable_fused_normalization=torch.cuda.is_available(), 118 | enable_flash_attention=args.xformers, 119 | max_prefetch=args.prefetch_num, 120 | enable_async_reduce=not args.disable_async_reduce, 121 | ) 122 | elif args.plugin == "gemini_auto": 123 | plugin = GeminiPlugin( 124 | placement_policy="auto", 125 | precision="bf16", 126 | warmup_non_model_data_ratio=args.warmup_ratio, 127 | tp_size=args.tp, 128 | 
extra_dp_size=args.extra_dp, 129 | enable_fused_normalization=torch.cuda.is_available(), 130 | max_prefetch=args.prefetch_num, 131 | enable_async_reduce=not args.disable_async_reduce, 132 | enable_flash_attention=args.xformers, 133 | ) 134 | elif args.plugin == "fsdp": 135 | if use_empty_init: 136 | plugin = TorchFSDPPlugin( 137 | mixed_precision=MixedPrecision( 138 | param_dtype=torch.float16, 139 | reduce_dtype=torch.float16, 140 | buffer_dtype=torch.float16, 141 | ), 142 | param_init_fn=empty_init(), 143 | ) 144 | else: 145 | plugin = TorchFSDPPlugin( 146 | mixed_precision=MixedPrecision( 147 | param_dtype=torch.float16, 148 | reduce_dtype=torch.float16, 149 | buffer_dtype=torch.float16, 150 | ) 151 | ) 152 | elif args.plugin == "fsdp_cpu": 153 | if use_empty_init: 154 | plugin = TorchFSDPPlugin( 155 | mixed_precision=MixedPrecision( 156 | param_dtype=torch.float16, 157 | reduce_dtype=torch.float16, 158 | buffer_dtype=torch.float16, 159 | ), 160 | cpu_offload=CPUOffload(offload_params=True), 161 | param_init_fn=empty_init(), 162 | ) 163 | else: 164 | plugin = TorchFSDPPlugin( 165 | mixed_precision=MixedPrecision( 166 | param_dtype=torch.float16, 167 | reduce_dtype=torch.float16, 168 | buffer_dtype=torch.float16, 169 | ), 170 | cpu_offload=CPUOffload(offload_params=True), 171 | ) 172 | elif args.plugin == "3d": 173 | plugin = HybridParallelPlugin( 174 | tp_size=args.tp, 175 | pp_size=args.pp, 176 | zero_stage=args.zero, 177 | enable_fused_normalization=torch.cuda.is_available(), 178 | enable_flash_attention=args.xformers, 179 | microbatch_size=args.mbs, 180 | precision="bf16", 181 | dp_outside=False, 182 | **hybrid_kwargs, 183 | ) 184 | elif args.plugin == "3d_cpu": 185 | plugin = HybridParallelPlugin( 186 | tp_size=args.tp, 187 | pp_size=args.pp, 188 | zero_stage=args.zero, 189 | cpu_offload=True, 190 | enable_fused_normalization=torch.cuda.is_available(), 191 | enable_flash_attention=args.xformers, 192 | microbatch_size=args.mbs, 193 | initial_scale=2**8, 194 | precision="bf16", 195 | ) 196 | else: 197 | raise ValueError(f"Unknown plugin {args.plugin}") 198 | 199 | booster = Booster(plugin=plugin) 200 | 201 | # ============================== 202 | # Initialize Dataset and Dataloader 203 | # ============================== 204 | dp_size = getattr(plugin, "dp_size", coordinator.world_size) 205 | 206 | if args.config in MODEL_CONFIGS: 207 | config = MODEL_CONFIGS[args.config] 208 | else: 209 | config = AutoConfig.from_pretrained(args.config, trust_remote_code=True) 210 | dataset = RandomDataset( 211 | num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size 212 | ) 213 | dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) 214 | 215 | # ============================== 216 | # Initialize Model and Optimizer 217 | # ============================== 218 | init_ctx = ( 219 | LazyInitContext(default_device=get_accelerator().get_current_device()) 220 | if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) 221 | else nullcontext() 222 | ) 223 | 224 | init_kwargs = {} 225 | if config.model_type == "chatglm": 226 | init_kwargs["empty_init"] = False 227 | 228 | with init_ctx: 229 | model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, **init_kwargs) 230 | 231 | if args.grad_checkpoint: 232 | model.gradient_checkpointing_enable() 233 | if config.model_type == "chatglm": 234 | model.transformer.encoder.gradient_checkpointing = True 235 | 236 | model_numel = 
get_model_numel(model) 237 | coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") 238 | performance_evaluator = PerformanceEvaluator( 239 | model_numel, 240 | model.config.num_hidden_layers, 241 | model.config.hidden_size, 242 | model.config.vocab_size, 243 | args.grad_checkpoint, 244 | args.ignore_steps, 245 | dp_world_size=dp_size, 246 | ) 247 | 248 | optimizer = HybridAdam(model.parameters()) 249 | torch.set_default_dtype(torch.bfloat16) 250 | model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) 251 | torch.set_default_dtype(torch.float) 252 | coordinator.print_on_master( 253 | f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB" 254 | ) 255 | coordinator.print_on_master( 256 | f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" 257 | ) 258 | 259 | with get_profile_context( 260 | args.profile, 261 | 1, 262 | len(dataloader) - 1, 263 | save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}", 264 | ) as prof: 265 | if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: 266 | data_iter = iter(dataloader) 267 | for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()): 268 | performance_evaluator.on_step_start(step) 269 | booster.execute_pipeline( 270 | data_iter, 271 | model, 272 | criterion=lambda outputs, inputs: outputs[0], 273 | optimizer=optimizer, 274 | return_loss=False, 275 | ) 276 | optimizer.step() 277 | optimizer.zero_grad() 278 | performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length)) 279 | prof.step() 280 | else: 281 | for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())): 282 | performance_evaluator.on_step_start(step) 283 | outputs = model(**batch) 284 | loss = outputs[0] 285 | booster.backward(loss, optimizer) 286 | optimizer.step() 287 | optimizer.zero_grad() 288 | performance_evaluator.on_step_end(**batch) 289 | prof.step() 290 | 291 | performance_evaluator.on_fit_end() 292 | coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB") 293 | 294 | 295 | if __name__ == "__main__": 296 | main() 297 | -------------------------------------------------------------------------------- /scripts/benchmark/data_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from typing import Iterator, Optional 4 | 5 | import numpy as np 6 | import torch 7 | from torch.distributed import ProcessGroup 8 | from torch.distributed.distributed_c10d import _get_default_group 9 | from torch.utils.data import DataLoader, Dataset, DistributedSampler 10 | 11 | from colossalai.accelerator import get_accelerator 12 | 13 | 14 | class StatefulDistributedSampler(DistributedSampler): 15 | def __init__( 16 | self, 17 | dataset: Dataset, 18 | num_replicas: Optional[int] = None, 19 | rank: Optional[int] = None, 20 | shuffle: bool = True, 21 | seed: int = 0, 22 | drop_last: bool = False, 23 | ) -> None: 24 | super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last) 25 | self.start_index: int = 0 26 | 27 | def __iter__(self) -> Iterator: 28 | iterator = super().__iter__() 29 | indices = list(iterator) 30 | indices = indices[self.start_index :] 31 | return iter(indices) 32 | 33 | def __len__(self) -> int: 34 | return self.num_samples - self.start_index 35 | 36 | def 
set_start_index(self, start_index: int) -> None: 37 | self.start_index = start_index 38 | 39 | 40 | def prepare_dataloader( 41 | dataset, 42 | batch_size, 43 | shuffle=False, 44 | seed=1024, 45 | drop_last=False, 46 | pin_memory=False, 47 | num_workers=0, 48 | process_group: Optional[ProcessGroup] = None, 49 | **kwargs, 50 | ): 51 | r""" 52 | Prepare a dataloader for distributed training. The dataloader will be wrapped by 53 | `torch.utils.data.DataLoader` and `StatefulDistributedSampler`. 54 | 55 | 56 | Args: 57 | dataset (`torch.utils.data.Dataset`): The dataset to be loaded. 58 | shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. 59 | seed (int, optional): Random worker seed for sampling, defaults to 1024. 60 | add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. 61 | drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size 62 | is not divisible by the batch size. If False and the size of dataset is not divisible by 63 | the batch size, then the last batch will be smaller, defaults to False. 64 | pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. 65 | num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. 66 | kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in 67 | `DataLoader `_. 68 | 69 | Returns: 70 | :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. 71 | """ 72 | _kwargs = kwargs.copy() 73 | process_group = process_group or _get_default_group() 74 | sampler = StatefulDistributedSampler( 75 | dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle 76 | ) 77 | 78 | # Deterministic dataloader 79 | def seed_worker(worker_id): 80 | worker_seed = seed 81 | np.random.seed(worker_seed) 82 | torch.manual_seed(worker_seed) 83 | random.seed(worker_seed) 84 | 85 | return DataLoader( 86 | dataset, 87 | batch_size=batch_size, 88 | sampler=sampler, 89 | worker_init_fn=seed_worker, 90 | drop_last=drop_last, 91 | pin_memory=pin_memory, 92 | num_workers=num_workers, 93 | **_kwargs, 94 | ) 95 | 96 | 97 | def load_json(file_path: str): 98 | with open(file_path, "r") as f: 99 | return json.load(f) 100 | 101 | 102 | def save_json(data, file_path: str): 103 | with open(file_path, "w") as f: 104 | json.dump(data, f, indent=4) 105 | 106 | 107 | class RandomDataset(Dataset): 108 | def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): 109 | self.num_samples = num_samples 110 | self.max_length = max_length 111 | self.input_ids = torch.randint( 112 | 0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device() 113 | ) 114 | self.attention_mask = torch.ones_like(self.input_ids) 115 | 116 | def __len__(self): 117 | return self.num_samples 118 | 119 | def __getitem__(self, idx): 120 | return { 121 | "input_ids": self.input_ids[idx], 122 | "attention_mask": self.attention_mask[idx], 123 | "labels": self.input_ids[idx], 124 | } 125 | -------------------------------------------------------------------------------- /scripts/benchmark/model_utils.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | @contextmanager 8 | def low_precision_init(target_dtype: torch.dtype = torch.float16): 9 | dtype = torch.get_default_dtype() 10 | 
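    # Parameters created inside this context are allocated directly in
    # `target_dtype` (e.g. fp16), avoiding a full-precision copy during
    # initialization; the previous default dtype is restored on exit.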
    try:
11 |         torch.set_default_dtype(target_dtype)
12 |         yield
13 |     finally:
14 |         torch.set_default_dtype(dtype)
15 | 
16 | 
17 | def get_model_numel(model: nn.Module) -> int:
18 |     return sum(p.numel() for p in model.parameters())
19 | 
20 | 
21 | def format_numel_str(numel: int) -> str:
22 |     B = 1024**3
23 |     M = 1024**2
24 |     K = 1024
25 |     if numel >= B:
26 |         return f"{numel / B:.2f} B"
27 |     elif numel >= M:
28 |         return f"{numel / M:.2f} M"
29 |     elif numel >= K:
30 |         return f"{numel / K:.2f} K"
31 |     else:
32 |         return f"{numel}"
33 | 
--------------------------------------------------------------------------------
/scripts/benchmark/performance_evaluator.py:
--------------------------------------------------------------------------------
1 | from time import time
2 | from typing import Optional
3 | 
4 | import torch
5 | import torch.distributed as dist
6 | from torch import Tensor
7 | from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
8 | 
9 | from colossalai.accelerator import get_accelerator
10 | from colossalai.cluster import DistCoordinator
11 | 
12 | 
13 | def divide(x: float, y: float) -> float:
14 |     if y == 0:
15 |         return float("inf")
16 |     elif y == float("inf"):
17 |         return float("nan")
18 |     return x / y
19 | 
20 | 
21 | @torch.no_grad()
22 | def all_reduce_mean(x: float, world_size: int) -> float:
23 |     if world_size == 1:
24 |         return x
25 |     tensor = torch.tensor([x], device=get_accelerator().get_current_device())
26 |     dist.all_reduce(tensor)
27 |     tensor = tensor / world_size
28 |     return tensor.item()
29 | 
30 | 
31 | def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
32 |     class DummyProfiler:
33 |         def __init__(self):
34 |             self.step_number = 0
35 | 
36 |         def step(self):
37 |             self.step_number += 1
38 | 
39 |         def __enter__(self):
40 |             return self
41 | 
42 |         def __exit__(self, exc_type, exc_value, traceback):
43 |             pass
44 | 
45 |     if enable_flag:
46 |         return profile(
47 |             activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
48 |             schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
49 |             on_trace_ready=tensorboard_trace_handler(save_dir),
50 |             record_shapes=True,
51 |             profile_memory=True,
52 |             with_stack=True,
53 |         )
54 |     else:
55 |         return DummyProfiler()
56 | 
57 | 
58 | class Timer:
59 |     def __init__(self) -> None:
60 |         self.start_time: Optional[float] = None
61 |         self.duration: float = 0.0
62 | 
63 |     def start(self) -> None:
64 |         self.start_time = time()
65 | 
66 |     def end(self) -> None:
67 |         assert self.start_time is not None
68 |         self.duration += time() - self.start_time
69 |         self.start_time = None
70 | 
71 |     def reset(self) -> None:
72 |         self.duration = 0.0
73 | 
74 | 
75 | class PerformanceEvaluator:
76 |     """
77 |     Callback for evaluating the performance of the model.
78 |     Args:
79 |         model_numel: The number of parameters of the model.
80 |         num_layers: The number of transformer layers, used for the Megatron-style FLOP estimate.
81 |         hidden_size: The hidden size of the model.
82 |         vocab_size: The vocabulary size of the model.
83 |         enable_grad_checkpoint: Whether gradient checkpointing is enabled.
84 |         ignore_steps: The number of warmup steps to ignore when measuring performance.
         dp_world_size: The data-parallel world size; defaults to the global world size.
85 | """ 86 | 87 | def __init__( 88 | self, 89 | model_numel: int, 90 | num_layers: int, 91 | hidden_size: int, 92 | vocab_size: int, 93 | enable_grad_checkpoint: bool = False, 94 | ignore_steps: int = 0, 95 | dp_world_size: Optional[int] = None, 96 | ) -> None: 97 | self.model_numel = model_numel 98 | self.enable_grad_checkpoint = enable_grad_checkpoint 99 | self.ignore_steps = ignore_steps 100 | self.num_layers = num_layers 101 | self.hidden_size = hidden_size 102 | self.vocab_size = vocab_size 103 | 104 | self.coordinator = DistCoordinator() 105 | self.dp_world_size = dp_world_size or self.coordinator.world_size 106 | self.disable: bool = False 107 | self.timer = Timer() 108 | self.num_samples: int = 0 109 | self.flop_megatron = 0 110 | self.flop: int = 0 111 | 112 | def on_step_start(self, step: int) -> None: 113 | self.disable = self.ignore_steps > 0 and step < self.ignore_steps 114 | if self.disable: 115 | return 116 | get_accelerator().synchronize() 117 | self.timer.start() 118 | 119 | def on_step_end(self, input_ids: Tensor, **kwargs) -> None: 120 | if self.disable: 121 | return 122 | get_accelerator().synchronize() 123 | self.timer.end() 124 | 125 | batch_size, seq_len = input_ids.shape 126 | 127 | self.num_samples += batch_size 128 | checkpoint_activations_factor = 3 + int(self.enable_grad_checkpoint) 129 | self.flop_megatron += ( 130 | 24 * checkpoint_activations_factor * batch_size * seq_len * self.num_layers * (self.hidden_size**2) 131 | ) * ( 132 | 1.0 + (seq_len / (6.0 * self.hidden_size)) + (self.vocab_size / (16.0 * self.num_layers * self.hidden_size)) 133 | ) 134 | self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint)) 135 | 136 | def on_fit_end(self) -> None: 137 | avg_duration = all_reduce_mean(self.timer.duration, self.coordinator.world_size) 138 | avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12) 139 | mp_world_size = self.coordinator.world_size // self.dp_world_size 140 | avg_tflops_per_gpu_megatron = self.flop_megatron / 1e12 / (avg_duration + 1e-12) / mp_world_size 141 | avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size 142 | self.coordinator.print_on_master( 143 | f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop_megatron: {self.flop_megatron}, flop: {self.flop}, avg_duration: {avg_duration}, " 144 | f"avg_throughput: {avg_throughput}" 145 | ) 146 | self.coordinator.print_on_master( 147 | f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU by Megatron: {avg_tflops_per_gpu_megatron:.2f}, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}" 148 | ) 149 | -------------------------------------------------------------------------------- /scripts/benchmark/requirements.txt: -------------------------------------------------------------------------------- 1 | colossalai>=0.3.6 2 | datasets 3 | numpy 4 | tqdm 5 | transformers 6 | flash-attn>=2.0.0 7 | SentencePiece==0.1.99 8 | tensorboard==2.14.0 9 | -------------------------------------------------------------------------------- /scripts/benchmark/scripts/benchmark_7B/gemini.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ################ 4 | #Load your environments and modules here 5 | ################ 6 | 7 | HOSTFILE=$(realpath hosts.txt) 8 | 9 | cd ../.. 
10 | 11 | export OMP_NUM_THREADS=8 12 | 13 | colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -g -x -b 16 14 | -------------------------------------------------------------------------------- /scripts/benchmark/scripts/benchmark_7B/gemini_auto.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ################ 4 | #Load your environments and modules here 5 | ################ 6 | 7 | HOSTFILE=$(realpath hosts.txt) 8 | 9 | cd ../.. 10 | 11 | export OMP_NUM_THREADS=8 12 | 13 | colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -p gemini_auto -g -x -b 16 14 | -------------------------------------------------------------------------------- /scripts/benchmark/test_ci.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/moxin-org/Moxin-LLM/2ff78a7875c613d1fa2198f6d398813e9caf402e/scripts/benchmark/test_ci.sh -------------------------------------------------------------------------------- /scripts/finetune/chat_finetune/finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import resource 4 | from contextlib import nullcontext 5 | from functools import partial 6 | from typing import Optional, Tuple 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.nn as nn 11 | from attn import replace_with_flash_attention 12 | 13 | # xuan ======================================== 14 | import sys; sys.path.append("..") 15 | # ======================================== 16 | 17 | # datasets 2.18.0 18 | # fsspec 2024.2.0 19 | 20 | from data_utils import load_json, prepare_dataloader, save_json 21 | from datasets import load_dataset, load_from_disk # , save_to_disk 22 | from torch.optim import Optimizer 23 | from torch.optim.lr_scheduler import _LRScheduler 24 | from torch.utils.tensorboard import SummaryWriter 25 | from tqdm import tqdm 26 | from transformers.models.llama.configuration_llama import LlamaConfig 27 | from transformers.models.llama.modeling_llama import LlamaForCausalLM 28 | from transformers.models.llama.tokenization_llama import LlamaTokenizer 29 | 30 | import colossalai 31 | from colossalai.accelerator import get_accelerator 32 | from colossalai.booster import Booster 33 | from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin 34 | from colossalai.cluster import DistCoordinator 35 | from colossalai.lazy import LazyInitContext 36 | from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR 37 | from colossalai.nn.optimizer import HybridAdam 38 | 39 | MODEL_CONFIGS = { 40 | "7b": LlamaConfig(max_position_embeddings=4096), 41 | "13b": LlamaConfig( 42 | hidden_size=5120, 43 | intermediate_size=13824, 44 | num_hidden_layers=40, 45 | num_attention_heads=40, 46 | max_position_embeddings=4096, 47 | ), 48 | "70b": LlamaConfig( 49 | hidden_size=8192, 50 | intermediate_size=28672, 51 | num_hidden_layers=80, 52 | num_attention_heads=64, 53 | max_position_embeddings=4096, 54 | num_key_value_heads=8, 55 | ), 56 | } 57 | 58 | 59 | def get_model_numel(model: nn.Module) -> int: 60 | return sum(p.numel() for p in model.parameters()) 61 | 62 | 63 | def format_numel_str(numel: int) -> str: 64 | B = 1024 ** 3 65 | M = 1024 ** 2 66 | K = 1024 67 | if numel >= B: 68 | return f"{numel / B:.2f} B" 69 | elif numel >= M: 70 | return f"{numel / M:.2f} M" 71 | elif numel >= K: 72 | return f"{numel / K:.2f} K" 73 | else: 74 | return 
f"{numel}" 75 | 76 | 77 | def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048): 78 | texts = [sample["text"] for sample in batch] 79 | data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length) 80 | data = {k: v.cuda() for k, v in data.items()} 81 | data["labels"] = data["input_ids"].clone() 82 | return data 83 | 84 | 85 | def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: 86 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 87 | tensor = tensor.data 88 | tensor.div_(dist.get_world_size()) 89 | return tensor 90 | 91 | 92 | def save( 93 | booster: Booster, 94 | model: nn.Module, 95 | optimizer: Optimizer, 96 | lr_scheduler: _LRScheduler, 97 | epoch: int, 98 | step: int, 99 | batch_size: int, 100 | coordinator: DistCoordinator, 101 | save_dir: str, 102 | ): 103 | save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}") 104 | os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) 105 | 106 | booster.save_model(model, os.path.join(save_dir, "model"), shard=True) 107 | booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True) 108 | booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) 109 | running_states = { 110 | "epoch": epoch, 111 | "step": step, 112 | "sample_start_index": step * batch_size, 113 | } 114 | if coordinator.is_master(): 115 | save_json(running_states, os.path.join(save_dir, "running_states.json")) 116 | 117 | 118 | def load( 119 | booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str 120 | ) -> Tuple[int, int, int]: 121 | booster.load_model(model, os.path.join(load_dir, "model")) 122 | # booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer")) 123 | # booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler")) 124 | running_states = load_json(os.path.join(load_dir, "running_states.json")) 125 | return running_states["epoch"], running_states["step"], running_states["sample_start_index"] 126 | 127 | 128 | def _criterion(outputs, inputs): 129 | return outputs.loss 130 | 131 | 132 | def encode_with_prompt_completion_format(example, tokenizer, max_seq_length): 133 | ''' 134 | Here we assume each example has 'prompt' and 'completion' fields. 135 | We concatenate prompt and completion and tokenize them together because otherwise prompt will be padded/trancated 136 | and it doesn't make sense to follow directly with the completion. 
137 |     '''
138 |     # if prompt doesn't end with space and completion doesn't start with space, add space
139 |     if not example['prompt'].endswith((' ', '\n', '\t')) and not example['completion'].startswith((' ', '\n', '\t')):
140 |         example_text = example['prompt'] + ' ' + example['completion']
141 |     else:
142 |         example_text = example['prompt'] + example['completion']
143 |     example_text = example_text + tokenizer.eos_token
144 |     tokenized_example = tokenizer(example_text, return_tensors='pt', max_length=max_seq_length, truncation=True)
145 |     input_ids = tokenized_example.input_ids
146 |     labels = input_ids.clone()
147 |     tokenized_prompt = tokenizer(example['prompt'], return_tensors='pt', max_length=max_seq_length, truncation=True)
148 |     # mask the prompt part so that it does not contribute to the loss
149 |     labels[:, :tokenized_prompt.input_ids.shape[1]] = -100
150 |     attention_mask = torch.ones_like(input_ids)
151 |     return {
152 |         'input_ids': input_ids.flatten(),
153 |         'labels': labels.flatten(),
154 |         'attention_mask': attention_mask.flatten(),
155 |     }
156 | 
157 | 
158 | def encode_with_messages_format(example, tokenizer, max_seq_length):
159 |     '''
160 |     Here we assume each example has a 'messages' field. Each message is a dict with 'role' and 'content' fields.
161 |     We concatenate all messages with the roles as delimiters and tokenize them together.
162 |     '''
163 |     messages = example['messages']
164 |     if len(messages) == 0:
165 |         raise ValueError('messages field is empty.')
166 | 
167 |     def _concat_messages(messages):
168 |         message_text = ""
169 |         for message in messages:
170 |             if message["role"] == "system":
171 |                 message_text += "<|system|>\n" + message["content"].strip() + "\n"
172 |             elif message["role"] == "user":
173 |                 message_text += "<|user|>\n" + message["content"].strip() + "\n"
174 |             elif message["role"] == "assistant":
175 |                 message_text += "<|assistant|>\n" + message["content"].strip() + tokenizer.eos_token + "\n"
176 |             else:
177 |                 raise ValueError("Invalid role: {}".format(message["role"]))
178 |         return message_text
179 | 
180 |     example_text = _concat_messages(messages).strip()
181 |     tokenized_example = tokenizer(example_text, return_tensors='pt', max_length=max_seq_length, truncation=True)
182 |     input_ids = tokenized_example.input_ids
183 |     labels = input_ids.clone()
184 | 
185 |     # mask the non-assistant parts so that they do not contribute to the loss
186 |     for message_idx, message in enumerate(messages):
187 |         if message["role"] != "assistant":
188 |             if message_idx == 0:
189 |                 message_start_idx = 0
190 |             else:
191 |                 message_start_idx = tokenizer(
192 |                     _concat_messages(messages[:message_idx]), return_tensors='pt', max_length=max_seq_length,
193 |                     truncation=True
194 |                 ).input_ids.shape[1]
195 |             if message_idx < len(messages) - 1 and messages[message_idx + 1]["role"] == "assistant":
196 |                 # also mask the "<|assistant|>\n" header, so the loss starts at the assistant's content
197 |                 messages_so_far = _concat_messages(messages[:message_idx + 1]) + "<|assistant|>\n"
198 |             else:
199 |                 messages_so_far = _concat_messages(messages[:message_idx + 1])
200 |             message_end_idx = tokenizer(
201 |                 messages_so_far,
202 |                 return_tensors='pt',
203 |                 max_length=max_seq_length,
204 |                 truncation=True
205 |             ).input_ids.shape[1]
206 |             labels[:, message_start_idx:message_end_idx] = -100
207 | 
208 |             if message_end_idx >= max_seq_length:
209 |                 break
210 | 
211 |     attention_mask = torch.ones_like(input_ids)
212 |     return {
213 |         'input_ids': input_ids.flatten(),
214 |         'labels': labels.flatten(),
215 |         'attention_mask': attention_mask.flatten(),
216 |     }
217 | 
218 | 
219 | def main():
220 |     # 
==============================
221 |     # Parse Arguments
222 |     # ==============================
223 |     parser = argparse.ArgumentParser()
224 |     parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration")
225 |     parser.add_argument(
226 |         "-p",
227 |         "--plugin",
228 |         choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "hybrid_parallel"],
229 |         default="gemini",
230 |         help="Choose which plugin to use",
231 |     )
232 |     parser.add_argument("-e", "--num_epochs", type=int, default=1, help="Number of epochs")
233 |     parser.add_argument("-b", "--batch_size", type=int, default=2, help="Local batch size")
234 |     parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
235 |     parser.add_argument("--data_path", type=str, default="workspace/datasets/tulu_v2.jsonl", help="dataset path")
236 |     parser.add_argument("-w", "--weigth_decay", type=float, default=0.1, help="Weight decay")
237 |     parser.add_argument("-s", "--warmup_steps", type=int, default=2000, help="Warmup steps")
238 |     parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
239 |     parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
240 |     parser.add_argument("-x", "--mixed_precision", default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
241 |     parser.add_argument("-i", "--save_interval", type=int, default=1000, help="Save interval")
242 |     parser.add_argument("-o", "--save_dir", type=str, default="checkpoint", help="Checkpoint directory")
243 |     parser.add_argument("-f", "--load", type=str, default=None, help="Load checkpoint")
244 |     parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping")
245 |     parser.add_argument("-t", "--tensorboard_dir", type=str, default="tb_logs", help="Tensorboard directory")
246 |     parser.add_argument("-a", "--flash_attention", action="store_true", help="Use Flash Attention")
247 | 
248 |     # xuan ================================================================================
249 |     parser.add_argument("--accumulation_steps", type=int, default=16, help="Gradient accumulation steps")
250 |     parser.add_argument("--expanded_model", default="", help="model path")
251 |     parser.add_argument("--tokenizer_path", default="", help="tokenizer path")
252 |     # ================================================================================
253 | 
254 |     args = parser.parse_args()
255 | 
256 |     # ==============================
257 |     # Initialize Distributed Training
258 |     # ==============================
259 |     colossalai.launch_from_torch({})
260 |     coordinator = DistCoordinator()
261 | 
262 |     # ==============================
263 |     # Initialize Booster
264 |     # ==============================
265 |     if args.plugin == "gemini":
266 |         plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2 ** 16, max_norm=args.grad_clip)
267 |     elif args.plugin == "gemini_auto":
268 |         plugin = GeminiPlugin(
269 |             precision=args.mixed_precision, placement_policy="auto", initial_scale=2 ** 16, max_norm=args.grad_clip
270 |         )
271 |     elif args.plugin == "zero2":
272 |         plugin = LowLevelZeroPlugin(
273 |             stage=2, precision=args.mixed_precision, initial_scale=2 ** 16, max_norm=args.grad_clip
274 |         )
275 |     elif args.plugin == "zero2_cpu":
276 |         plugin = LowLevelZeroPlugin(
277 |             stage=2, precision=args.mixed_precision, initial_scale=2 ** 16, cpu_offload=True, max_norm=args.grad_clip
278 |         )
279 |     elif args.plugin == "hybrid_parallel":
280 |         plugin = HybridParallelPlugin(
281 |             tp_size=4,
282 |             pp_size=2,
283 | 
num_microbatches=None, 284 | microbatch_size=1, 285 | enable_jit_fused=False, 286 | zero_stage=0, 287 | precision=args.mixed_precision, 288 | initial_scale=1, 289 | ) 290 | else: 291 | raise ValueError(f"Unknown plugin {args.plugin}") 292 | 293 | booster = Booster(plugin=plugin) 294 | 295 | use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 296 | is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() 297 | print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage) 298 | 299 | # ============================== 300 | # Initialize Tensorboard 301 | # ============================== 302 | if print_flag: 303 | os.makedirs(args.tensorboard_dir, exist_ok=True) 304 | writer = SummaryWriter(args.tensorboard_dir) 305 | 306 | # ============================== 307 | # Initialize Tokenizer, Dataset and Dataloader 308 | # ============================== 309 | from transformers import AutoTokenizer 310 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, use_fast=False) 311 | tokenizer.pad_token = tokenizer.unk_token 312 | # ================================================================ 313 | 314 | 315 | train_file = args.data_path 316 | 317 | data_files = {} 318 | dataset_args = {} 319 | if train_file is not None: 320 | data_files["train"] = train_file 321 | raw_datasets = load_dataset( 322 | "json", 323 | data_files=data_files, 324 | **dataset_args, 325 | ) 326 | 327 | # Preprocessing the datasets. 328 | if "prompt" in raw_datasets["train"].column_names and "completion" in raw_datasets["train"].column_names: 329 | encode_function = partial( 330 | encode_with_prompt_completion_format, 331 | tokenizer=tokenizer, 332 | max_seq_length=args.max_length, 333 | ) 334 | elif "messages" in raw_datasets["train"].column_names: 335 | encode_function = partial( 336 | encode_with_messages_format, 337 | tokenizer=tokenizer, 338 | max_seq_length=args.max_length, 339 | ) 340 | else: 341 | raise ValueError("You need to have either 'prompt'&'completion' or 'messages' in your column names.") 342 | 343 | # with accelerator.main_process_first(): 344 | # if coordinator.is_master(): 345 | lm_datasets = raw_datasets.map( 346 | encode_function, 347 | batched=False, 348 | num_proc=16, 349 | load_from_cache_file=True, 350 | remove_columns=[name for name in raw_datasets["train"].column_names if 351 | name not in ["input_ids", "labels", "attention_mask"]], 352 | desc="Tokenizing and reformatting instruction data", 353 | ) 354 | lm_datasets.set_format(type="pt") 355 | lm_datasets = lm_datasets.filter(lambda example: (example['labels'] != -100).any()) 356 | 357 | 358 | dist.barrier() 359 | 360 | train_dataset = lm_datasets["train"] 361 | 362 | train_ds = train_dataset 363 | 364 | # ============================== 365 | # Initialize Model, Optimizer and LR Scheduler 366 | # ============================== 367 | config = MODEL_CONFIGS[args.config] 368 | 369 | init_ctx = ( 370 | LazyInitContext(default_device=get_accelerator().get_current_device()) 371 | if isinstance(plugin, GeminiPlugin) 372 | else nullcontext() 373 | ) 374 | 375 | with init_ctx: 376 | 377 | from transformers import AutoModelForCausalLM, AutoConfig 378 | model = AutoModelForCausalLM.from_pretrained(args.expanded_model, torch_dtype=torch.float16)#, device_map='cpu') 379 | 380 | for name, weight in model.named_parameters(): 381 | weight.requires_grad = True 382 | 383 | from transformers import DataCollatorForSeq2Seq 384 | dataloader = prepare_dataloader( 
385 | lm_datasets["train"], 386 | batch_size=args.batch_size, 387 | shuffle=True, 388 | drop_last=True, 389 | collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest"), 390 | ) 391 | # ======================================================================= 392 | 393 | if args.grad_checkpoint: 394 | model.gradient_checkpointing_enable() 395 | 396 | model_numel = get_model_numel(model) 397 | coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") 398 | 399 | optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay) 400 | 401 | lr_scheduler = CosineAnnealingWarmupLR( 402 | optimizer, total_steps=args.num_epochs * len(dataloader), warmup_steps=args.warmup_steps, eta_min=0.1 * args.lr 403 | ) 404 | default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 405 | torch.set_default_dtype(default_dtype) 406 | model, optimizer, _, dataloader, lr_scheduler = booster.boost( 407 | model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler 408 | ) 409 | torch.set_default_dtype(torch.float) 410 | 411 | coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") 412 | coordinator.print_on_master( 413 | f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" 414 | ) 415 | 416 | # load checkpoint if specified 417 | start_epoch = 0 418 | start_step = 0 419 | sampler_start_idx = 0 420 | if args.load is not None: 421 | coordinator.print_on_master("Loading checkpoint") 422 | start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load) 423 | coordinator.print_on_master(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}") 424 | 425 | num_steps_per_epoch = len(dataloader) 426 | 427 | # if resume training, set the sampler start index to the correct value 428 | dataloader.sampler.set_start_index(sampler_start_idx) 429 | for epoch in range(start_epoch, args.num_epochs): 430 | dataloader.sampler.set_epoch(epoch) 431 | dataloader_iter = iter(dataloader) 432 | 433 | for step in range(start_step, num_steps_per_epoch): 434 | if use_pipeline: 435 | outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True) 436 | loss = outputs["loss"] 437 | else: 438 | batch = next(dataloader_iter) 439 | batch = batch.to(get_accelerator().get_current_device()) 440 | outputs = model(**batch) 441 | loss = outputs[0] 442 | booster.backward(loss, optimizer) 443 | 444 | if (step + 1) % args.accumulation_steps == 0: 445 | optimizer.step() # Update parameters 446 | lr_scheduler.step() 447 | optimizer.zero_grad() # Reset gradients 448 | 449 | if step + 1 == num_steps_per_epoch - 1: 450 | optimizer.step() 451 | lr_scheduler.step() 452 | optimizer.zero_grad() 453 | 454 | if not use_pipeline: 455 | all_reduce_mean(loss) 456 | 457 | if print_flag: 458 | writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step) 459 | print("Epoch: {}, step: {}, loss: {:.3f}".format( 460 | epoch, 461 | epoch * num_steps_per_epoch + step, 462 | loss.item() 463 | )) 464 | 465 | if args.save_interval > 0 and (step + 1) % args.save_interval == 0: 466 | coordinator.print_on_master(f"Saving checkpoint") 467 | save( 468 | booster, 469 | model, 470 | optimizer, 471 | lr_scheduler, 472 | epoch, 473 | step + 1, 474 | args.batch_size, 475 | coordinator, 476 | args.save_dir, 477 | ) 478 | coordinator.print_on_master(f"Saved checkpoint 
at epoch {epoch} step {step + 1}")
479 |         # subsequent epochs always start from the beginning, so reset the sampler start index and the start step
480 |         dataloader.sampler.set_start_index(0)
481 |         start_step = 0
482 | 
483 |     coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
484 | 
485 | 
486 | if __name__ == "__main__":
487 |     main()
488 | 
--------------------------------------------------------------------------------
/scripts/finetune/instruct_finetune/dpo_finetune.sh:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | accelerate launch \
4 |     --mixed_precision bf16 \
5 |     --num_machines 1 \
6 |     --num_processes 8 \
7 |     --use_deepspeed \
8 |     --deepspeed_config_file configs/ds_configs/stage3_no_offloading_accelerate.conf open_instruct/dpo_tune.py \
9 |     --model_name_or_path output/sft_7b \
10 |     --use_flash_attn \
11 |     --tokenizer_name output/sft_7b \
12 |     --max_seq_length 2048 \
13 |     --preprocessing_num_workers 16 \
14 |     --per_device_train_batch_size 1 \
15 |     --gradient_accumulation_steps 16 \
16 |     --learning_rate 5e-07 \
17 |     --lr_scheduler_type linear \
18 |     --warmup_ratio 0.1 \
19 |     --weight_decay 0.0 \
20 |     --num_train_epochs 1 \
21 |     --output_dir output/dpo_7b \
22 |     --with_tracking \
23 |     --report_to wandb \
24 |     --logging_steps 1 \
25 |     --model_revision main \
26 |     --gradient_checkpointing \
27 |     --dataset_mixer_list allenai/llama-3.1-tulu-3-8b-preference-mixture 1.0 \
28 |     --use_slow_tokenizer \
29 |     --use_lora False \
30 |     --dpo_loss_type dpo_norm \
31 |     --dpo_beta 5 \
32 |     --checkpointing_steps 1000 \
33 |     --exp_name tulu-3-7b-dpo
34 | 
35 | 
36 | 
--------------------------------------------------------------------------------
/scripts/finetune/instruct_finetune/sft_finetune.sh:
--------------------------------------------------------------------------------
1 | 
2 | # modify the following `MACHINE_RANK`, `MAIN_PROCESS_IP`,
3 | # `NUM_MACHINES`, `NUM_PROCESSES`, `PER_DEVICE_TRAIN_BATCH_SIZE`,
4 | # `GRADIENT_ACCUMULATION_STEPS` according to your setup
5 | MACHINE_RANK=0
6 | MAIN_PROCESS_IP=localhost
7 | NUM_MACHINES=8
8 | NUM_PROCESSES=64
9 | PER_DEVICE_TRAIN_BATCH_SIZE=1
10 | GRADIENT_ACCUMULATION_STEPS=2
11 | accelerate launch \
12 |     --mixed_precision bf16 \
13 |     --num_machines $NUM_MACHINES \
14 |     --num_processes $NUM_PROCESSES \
15 |     --machine_rank $MACHINE_RANK \
16 |     --main_process_ip $MAIN_PROCESS_IP \
17 |     --main_process_port 29400 \
18 |     --use_deepspeed \
19 |     --deepspeed_config_file configs/ds_configs/stage3_no_offloading_accelerate.conf \
20 |     --deepspeed_multinode_launcher standard open_instruct/finetune.py \
21 |     --model_name_or_path moxin-org/moxin-llm-7b \
22 |     --tokenizer_name moxin-org/moxin-llm-7b \
23 |     --use_slow_tokenizer \
24 |     --use_flash_attn \
25 |     --max_seq_length 4096 \
26 |     --preprocessing_num_workers 128 \
27 |     --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \
28 |     --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
29 |     --learning_rate 5e-06 \
30 |     --lr_scheduler_type linear \
31 |     --warmup_ratio 0.03 \
32 |     --weight_decay 0.0 \
33 |     --num_train_epochs 2 \
34 |     --output_dir output/sft_7b \
35 |     --with_tracking \
36 |     --report_to wandb \
37 |     --logging_steps 1 \
38 |     --reduce_loss sum \
39 |     --model_revision main \
40 |     --dataset_mixer_list allenai/tulu-3-sft-mixture 1.0 \
41 |     --checkpointing_steps epoch \
42 |     --dataset_mix_dir output/sft_7b \
43 |     --exp_name tulu-3-7b-sft \
44 |     --seed 123
45 | 
46 | 
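# Sanity check on the defaults above: the effective global batch size is
# NUM_PROCESSES * PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
# = 64 * 1 * 2 = 128 sequences per optimizer step. When scaling the cluster down
# (e.g. NUM_MACHINES=1, NUM_PROCESSES=8), raising GRADIENT_ACCUMULATION_STEPS to 16
# keeps the effective batch size unchanged.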
--------------------------------------------------------------------------------
/scripts/finetune/reason_finetune/train_7b.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 | 
4 | # Warning: Export VLLM_ATTENTION_BACKEND on every machine before starting the Ray cluster.
5 | # vLLM without XFORMERS will result in CUDA errors.
6 | export VLLM_ATTENTION_BACKEND=XFORMERS
7 | 
8 | # Parse command line arguments
9 | while [[ $# -gt 0 ]]; do
10 |     case $1 in
11 |         --model)
12 |             MODEL_PATH="$2"
13 |             shift 2
14 |             ;;
15 |         *)
16 |             break
17 |             ;;
18 |     esac
19 | done
20 | 
21 | # Set default model path if not provided
22 | if [ -z "$MODEL_PATH" ]; then
23 |     MODEL_PATH="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
24 | fi
25 | 
26 | # Train over a single node, 8 A100-80GB GPUs.
27 | python3 -m verl.trainer.main_ppo \
28 |     algorithm.adv_estimator=grpo \
29 |     data.train_files=$HOME/workspace/data/train.parquet \
30 |     data.val_files=$HOME/workspace/data/aime.parquet \
31 |     data.train_batch_size=64 \
32 |     data.val_batch_size=256 \
33 |     data.max_prompt_length=1024 \
34 |     data.max_response_length=16384 \
35 |     actor_rollout_ref.model.path=$MODEL_PATH \
36 |     actor_rollout_ref.actor.optim.lr=1e-6 \
37 |     actor_rollout_ref.model.use_remove_padding=True \
38 |     actor_rollout_ref.actor.ppo_mini_batch_size=64 \
39 |     actor_rollout_ref.actor.ppo_micro_batch_size=32 \
40 |     actor_rollout_ref.actor.use_dynamic_bsz=True \
41 |     actor_rollout_ref.actor.ppo_max_token_len_per_gpu=32768 \
42 |     actor_rollout_ref.actor.use_kl_loss=True \
43 |     actor_rollout_ref.actor.kl_loss_coef=0.001 \
44 |     actor_rollout_ref.actor.kl_loss_type=low_var_kl \
45 |     actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
46 |     actor_rollout_ref.model.enable_gradient_checkpointing=True \
47 |     actor_rollout_ref.actor.fsdp_config.param_offload=False \
48 |     actor_rollout_ref.actor.fsdp_config.grad_offload=False \
49 |     actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
50 |     actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
51 |     actor_rollout_ref.rollout.name=vllm \
52 |     actor_rollout_ref.rollout.temperature=0.6 \
53 |     actor_rollout_ref.rollout.val_temperature=0.6 \
54 |     actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
55 |     actor_rollout_ref.rollout.n=8 \
56 |     actor_rollout_ref.rollout.n_val=8 \
57 |     actor_rollout_ref.ref.fsdp_config.param_offload=True \
58 |     algorithm.kl_ctrl.kl_coef=0.001 \
59 |     trainer.critic_warmup=0 \
60 |     trainer.logger=['console','wandb'] \
61 |     trainer.project_name='deepscaler' \
62 |     trainer.experiment_name='deepscaler-7b-16k' \
63 |     +trainer.val_before_train=True \
64 |     trainer.n_gpus_per_node=8 \
65 |     trainer.nnodes=1 \
66 |     trainer.save_freq=20 \
67 |     trainer.test_freq=20 \
68 |     trainer.default_hdfs_dir=null \
69 |     trainer.total_epochs=30 "${@:1}"
70 | 
71 | 
--------------------------------------------------------------------------------
/scripts/inference/inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
3 | 
4 | model_name = './model'
5 | tokenizer = AutoTokenizer.from_pretrained(model_name)
6 | model = AutoModelForCausalLM.from_pretrained(
7 |     model_name,
8 |     torch_dtype=torch.bfloat16,
9 |     device_map="auto",
10 |     trust_remote_code=True,
11 | )
12 | 
13 | pipe = pipeline(
14 |     "text-generation",
15 |     model=model,
16 |     tokenizer=tokenizer,
17 |     torch_dtype=torch.bfloat16,
18 |     device_map="auto"
19 | )
20 | 
21 | 
22 | 
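# Note on the sampling settings below: temperature=0.7 moderately flattens the
# next-token distribution, top_k=50 restricts sampling to the 50 most likely tokens,
# and top_p=0.95 further truncates to the smallest set of tokens whose cumulative
# probability exceeds 0.95; do_sample=False would give greedy, deterministic decoding.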
23 | prompt = "Can you explain the concept of regularization in machine learning?"
24 | 
25 | sequences = pipe(
26 |     prompt,
27 |     do_sample=True,
28 |     max_new_tokens=100,
29 |     temperature=0.7,
30 |     top_k=50,
31 |     top_p=0.95,
32 |     num_return_sequences=1,
33 | )
34 | print(sequences[0]['generated_text'])
35 | 
36 | 
--------------------------------------------------------------------------------
/scripts/train/attn.py:
--------------------------------------------------------------------------------
1 | import math
2 | from types import MethodType
3 | from typing import Optional, Tuple
4 | 
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from einops import rearrange
9 | from transformers.models.llama.configuration_llama import LlamaConfig
10 | from transformers.models.llama.modeling_llama import (
11 |     LlamaAttention,
12 |     LlamaForCausalLM,
13 |     LlamaModel,
14 |     LlamaRMSNorm,
15 |     apply_rotary_pos_emb,
16 |     repeat_kv,
17 | )
18 | 
19 | from colossalai.accelerator import get_accelerator
20 | from colossalai.logging import get_dist_logger
21 | 
22 | logger = get_dist_logger()
23 | 
24 | if get_accelerator().name == "cuda":
25 |     from flash_attn.bert_padding import pad_input, unpad_input
26 |     from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
27 |     from flash_attn.ops.rms_norm import rms_norm
28 | 
29 |     def _prepare_decoder_attention_mask(
30 |         self: LlamaModel,
31 |         attention_mask: torch.BoolTensor,
32 |         input_shape: torch.Size,
33 |         inputs_embeds: torch.Tensor,
34 |         past_key_values_length: int,
35 |     ) -> Optional[torch.Tensor]:
36 |         """
37 |         Build the decoder attention mask; return None when no position is masked (faster path).
38 |         """
39 |         if past_key_values_length > 0 and attention_mask is not None:
40 |             attention_mask = torch.cat(
41 |                 tensors=(
42 |                     torch.full(
43 |                         size=(input_shape[0], past_key_values_length),
44 |                         fill_value=True,
45 |                         dtype=attention_mask.dtype,
46 |                         device=attention_mask.device,
47 |                     ),
48 |                     attention_mask,
49 |                 ),
50 |                 dim=-1,
51 |             )  # (bsz, past_key_values_length + q_len)
52 |         if attention_mask is not None and torch.all(attention_mask):
53 |             return None  # Faster
54 |         return attention_mask
55 | 
56 |     def attention_forward(
57 |         self: LlamaAttention,
58 |         hidden_states: torch.Tensor,
59 |         attention_mask: Optional[torch.Tensor] = None,
60 |         position_ids: Optional[torch.LongTensor] = None,
61 |         past_key_value: Optional[Tuple[torch.Tensor]] = None,
62 |         output_attentions: bool = False,
63 |         use_cache: bool = False,
64 |         **kwargs,
65 |     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
66 |         """
67 |         Re-define the LLaMA-2 `LlamaAttention` forward method using flash-attention.
68 |         """
69 |         if output_attentions:
70 |             logger.warning(
71 |                 "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
72 |                 "return `None` instead."
73 | ) 74 | 75 | bsz, q_len, _ = hidden_states.size() 76 | 77 | if self.config.pretraining_tp > 1: 78 | q_slicing, kv_slicing = ( 79 | dim // self.config.pretraining_tp 80 | for dim in ( 81 | self.num_heads * self.head_dim, 82 | self.num_key_value_heads * self.head_dim, 83 | ) 84 | ) # `Tuple[int, int]` 85 | q_slices, k_slices, v_slices = ( 86 | proj.weight.split(slicing, dim=0) 87 | for proj, slicing in ( 88 | (self.q_proj, q_slicing), 89 | (self.k_proj, kv_slicing), 90 | (self.v_proj, kv_slicing), 91 | ) 92 | ) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]] 93 | q, k, v = ( 94 | torch.cat( 95 | [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)], 96 | dim=-1, 97 | ) 98 | for slices in (q_slices, k_slices, v_slices) 99 | ) 100 | # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape: 101 | # (bsz, q_len, num_heads * head_dim), 102 | # (bsz, q_len, num_key_value_heads * head_dim), 103 | # (bsz, q_len, num_key_value_heads * head_dim) 104 | else: 105 | q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj)) 106 | # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape: 107 | # (bsz, q_len, num_heads * head_dim), 108 | # (bsz, q_len, num_key_value_heads * head_dim), 109 | # (bsz, q_len, num_key_value_heads * head_dim) 110 | 111 | # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim); 112 | # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim); 113 | # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim) 114 | q, k, v = ( 115 | states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2) 116 | for states, num_heads in ( 117 | (q, self.num_heads), 118 | (k, self.num_key_value_heads), 119 | (v, self.num_key_value_heads), 120 | ) 121 | ) 122 | kv_len = k.shape[-2] # initially, `kv_len` == `q_len` 123 | past_kv_len = 0 124 | if past_key_value is not None: 125 | # if `past_key_value` is not None, `kv_len` > `q_len`. 
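            # The cache is a (k, v) tuple holding tensors of shape
            # (bsz, num_key_value_heads, past_kv_len, head_dim), so the rotary
            # embeddings computed below must cover the full kv_len, not just q_len.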
126 | past_kv_len = past_key_value[0].shape[-2] 127 | kv_len += past_kv_len 128 | 129 | # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim) 130 | cos, sin = self.rotary_emb(v, seq_len=kv_len) 131 | # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim) 132 | q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids) 133 | if past_key_value is not None: 134 | # reuse k, v, self_attention 135 | k = torch.cat([past_key_value[0], k], dim=2) 136 | v = torch.cat([past_key_value[1], v], dim=2) 137 | 138 | past_key_value = (k, v) if use_cache else None 139 | 140 | # repeat k/v heads if n_kv_heads < n_heads 141 | k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups) 142 | # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim) 143 | v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups) 144 | # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim) 145 | 146 | key_padding_mask = attention_mask 147 | # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim) 148 | q, k, v = (states.transpose(1, 2) for states in (q, k, v)) 149 | 150 | if past_kv_len > 0: 151 | q = torch.cat( 152 | tensors=( 153 | torch.full( 154 | size=(bsz, past_kv_len, self.num_heads, self.head_dim), 155 | fill_value=0.0, 156 | dtype=q.dtype, 157 | device=q.device, 158 | ), 159 | q, 160 | ), 161 | dim=1, 162 | ) # (bsz, past_kv_len + q_len, num_heads, head_dim) 163 | 164 | if key_padding_mask is None: 165 | # (bsz, past_kv_len + q_len, num_heads, head_dim) 166 | output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, ) 167 | output = rearrange( 168 | output, pattern="... h d -> ... (h d)" 169 | ) # (bsz, past_kv_len + q_len, num_heads * head_dim) 170 | else: 171 | q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask) 172 | kv, _, cu_kv_lens, max_kv_len = unpad_input( 173 | hidden_states=torch.stack(tensors=(k, v), dim=2), 174 | attention_mask=key_padding_mask, 175 | ) 176 | output_unpad = flash_attn_varlen_kvpacked_func( 177 | q=q, 178 | kv=kv, 179 | cu_seqlens_q=cu_q_lens, 180 | cu_seqlens_k=cu_kv_lens, 181 | max_seqlen_q=max_q_len, 182 | max_seqlen_k=max_kv_len, 183 | dropout_p=0.0, 184 | softmax_scale=None, 185 | causal=True, 186 | ) 187 | output = pad_input( 188 | hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"), 189 | indices=indices, 190 | batch=bsz, 191 | seqlen=past_kv_len + q_len, 192 | ) # (bsz, past_kv_len + q_len, num_heads * head_dim) 193 | 194 | if past_kv_len > 0: 195 | # Strip off the zero query outputs. 196 | output = output[:, past_kv_len:, ...] 
# (bsz, q_len, num_heads * head_dim)
197 |         output = self.o_proj(output)  # (bsz, q_len, hidden_size)
198 |         return output, None, past_key_value
199 | 
200 |     def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
201 |         """
202 |         Forward function for RMS Norm
203 |         """
204 |         return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
205 | 
206 |     def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
207 |         for name, module in model.named_modules():
208 |             if isinstance(module, LlamaAttention):
209 |                 module.forward = MethodType(attention_forward, module)
210 |             if isinstance(module, LlamaModel):
211 |                 module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
212 |             if isinstance(module, LlamaRMSNorm):
213 |                 module.forward = MethodType(rms_norm_forward, module)
214 | 
215 | elif get_accelerator().name == "npu":
216 |     import torch_npu
217 | 
218 |     class NPULlamaAttention(LlamaAttention):
219 |         use_flash: bool = True
220 | 
221 |         def __init__(self, config: LlamaConfig):
222 |             super().__init__(config)
223 |             self.setup()
224 | 
225 |         def setup(self):
226 |             self._softmax_scale = 1 / math.sqrt(self.head_dim)
227 | 
228 |         def forward(
229 |             self,
230 |             hidden_states: torch.Tensor,
231 |             attention_mask: Optional[torch.Tensor] = None,
232 |             position_ids: Optional[torch.LongTensor] = None,
233 |             past_key_value: Optional[Tuple[torch.Tensor]] = None,
234 |             output_attentions: bool = False,
235 |             use_cache: bool = False,
236 |         ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
237 |             bsz, q_len, _ = hidden_states.size()
238 | 
239 |             if self.config.pretraining_tp > 1:
240 |                 key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
241 |                 query_slices = self.q_proj.weight.split(
242 |                     (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
243 |                 )
244 |                 key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
245 |                 value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
246 | 
247 |                 query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
248 |                 query_states = torch.cat(query_states, dim=-1)
249 | 
250 |                 key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
251 |                 key_states = torch.cat(key_states, dim=-1)
252 | 
253 |                 value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
254 |                 value_states = torch.cat(value_states, dim=-1)
255 | 
256 |             else:
257 |                 query_states = self.q_proj(hidden_states)
258 |                 key_states = self.k_proj(hidden_states)
259 |                 value_states = self.v_proj(hidden_states)
260 | 
261 |             query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
262 |             key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
263 |             value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
264 | 
265 |             kv_seq_len = key_states.shape[-2]
266 |             if past_key_value is not None:
267 |                 kv_seq_len += past_key_value[0].shape[-2]
268 |             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
269 |             query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
270 | 
271 |             if past_key_value is not None:
272 |                 # reuse k, v, self_attention
273 |                 key_states = torch.cat([past_key_value[0], key_states], dim=2)
274 |                 value_states = torch.cat([past_key_value[1], value_states], dim=2)
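            # The cache now holds keys/values for the full sequence; repeat_kv below
            # expands the num_key_value_heads axis to num_heads (grouped-query attention).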
275 | 276 | past_key_value = (key_states, value_states) if use_cache else None 277 | 278 | key_states = repeat_kv(key_states, self.num_key_value_groups) 279 | value_states = repeat_kv(value_states, self.num_key_value_groups) 280 | 281 | if not self.use_flash: 282 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 283 | 284 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 285 | raise ValueError( 286 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 287 | f" {attn_weights.size()}" 288 | ) 289 | 290 | if attention_mask is not None: 291 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 292 | raise ValueError( 293 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 294 | ) 295 | attn_weights = attn_weights + attention_mask 296 | 297 | # upcast attention to fp32 298 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 299 | attn_output = torch.matmul(attn_weights, value_states) 300 | else: 301 | attn_output, *_ = torch_npu.npu_fusion_attention( 302 | query_states, 303 | key_states, 304 | value_states, 305 | self.num_heads, 306 | "BNSD", 307 | atten_mask=attention_mask.bool(), 308 | scale=self._softmax_scale, 309 | padding_mask=None, 310 | pre_tockens=65535, 311 | next_tockens=0, 312 | keep_prob=1.0, 313 | inner_precise=0, 314 | ) 315 | 316 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 317 | raise ValueError( 318 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 319 | f" {attn_output.size()}" 320 | ) 321 | 322 | attn_output = attn_output.transpose(1, 2).contiguous() 323 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 324 | 325 | if self.config.pretraining_tp > 1: 326 | attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) 327 | o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) 328 | attn_output = sum( 329 | [F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)] 330 | ) 331 | else: 332 | attn_output = self.o_proj(attn_output) 333 | 334 | if not output_attentions: 335 | attn_weights = None 336 | 337 | return attn_output, attn_weights, past_key_value 338 | 339 | class NPURMSNorm(LlamaRMSNorm): 340 | def forward(self, hidden_states): 341 | return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0] 342 | 343 | def replace_with_flash_attention(model: LlamaForCausalLM) -> None: 344 | for name, module in model.named_modules(): 345 | if isinstance(module, LlamaAttention): 346 | module.__class__ = NPULlamaAttention 347 | module.setup() 348 | if isinstance(module, LlamaRMSNorm): 349 | module.__class__ = NPURMSNorm 350 | -------------------------------------------------------------------------------- /scripts/train/pretrain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import resource 4 | from contextlib import nullcontext 5 | from functools import partial 6 | from typing import Optional, Tuple 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.nn as nn 11 | from attn import replace_with_flash_attention 12 | 13 | 14 | import sys 15 | sys.path.append("..") 16 | 17 | # datasets 2.18.0 18 | # fsspec 2024.2.0 19 | 20 | from data_utils import load_json, 
prepare_dataloader, save_json 21 | from datasets import load_dataset, load_from_disk 22 | from torch.optim import Optimizer 23 | from torch.optim.lr_scheduler import _LRScheduler 24 | from torch.utils.tensorboard import SummaryWriter 25 | from tqdm import tqdm 26 | from transformers.models.llama.configuration_llama import LlamaConfig 27 | from transformers.models.llama.modeling_llama import LlamaForCausalLM 28 | from transformers.models.llama.tokenization_llama import LlamaTokenizer 29 | 30 | import colossalai 31 | from colossalai.accelerator import get_accelerator 32 | from colossalai.booster import Booster 33 | from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin 34 | from colossalai.cluster import DistCoordinator 35 | from colossalai.lazy import LazyInitContext 36 | from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR 37 | from colossalai.nn.optimizer import HybridAdam 38 | 39 | MODEL_CONFIGS = { 40 | "7b": LlamaConfig(max_position_embeddings=4096), 41 | "13b": LlamaConfig( 42 | hidden_size=5120, 43 | intermediate_size=13824, 44 | num_hidden_layers=40, 45 | num_attention_heads=40, 46 | max_position_embeddings=4096, 47 | ), 48 | "70b": LlamaConfig( 49 | hidden_size=8192, 50 | intermediate_size=28672, 51 | num_hidden_layers=80, 52 | num_attention_heads=64, 53 | max_position_embeddings=4096, 54 | num_key_value_heads=8, 55 | ), 56 | } 57 | 58 | 59 | def get_model_numel(model: nn.Module) -> int: 60 | return sum(p.numel() for p in model.parameters()) 61 | 62 | 63 | def format_numel_str(numel: int) -> str: 64 | B = 1024**3 65 | M = 1024**2 66 | K = 1024 67 | if numel >= B: 68 | return f"{numel / B:.2f} B" 69 | elif numel >= M: 70 | return f"{numel / M:.2f} M" 71 | elif numel >= K: 72 | return f"{numel / K:.2f} K" 73 | else: 74 | return f"{numel}" 75 | 76 | 77 | def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048): 78 | texts = [sample["text"] for sample in batch] 79 | data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length) 80 | data = {k: v.cuda() for k, v in data.items()} 81 | data["labels"] = data["input_ids"].clone() 82 | return data 83 | 84 | 85 | def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: 86 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 87 | tensor = tensor.data 88 | tensor.div_(dist.get_world_size()) 89 | return tensor 90 | 91 | 92 | def save( 93 | booster: Booster, 94 | model: nn.Module, 95 | optimizer: Optimizer, 96 | lr_scheduler: _LRScheduler, 97 | epoch: int, 98 | step: int, 99 | batch_size: int, 100 | coordinator: DistCoordinator, 101 | save_dir: str, 102 | ): 103 | save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}") 104 | os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) 105 | 106 | booster.save_model(model, os.path.join(save_dir, "model"), shard=True) 107 | booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True) 108 | booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) 109 | running_states = { 110 | "epoch": epoch, 111 | "step": step, 112 | "sample_start_index": step * batch_size, 113 | } 114 | if coordinator.is_master(): 115 | save_json(running_states, os.path.join(save_dir, "running_states.json")) 116 | 117 | 118 | def load( 119 | booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str 120 | ) -> Tuple[int, int, int]: 121 | booster.load_model(model, os.path.join(load_dir, "model")) 122 | 123 
| running_states = load_json(os.path.join(load_dir, "running_states.json"))
124 |     return running_states["epoch"], running_states["step"], running_states["sample_start_index"]
125 | 
126 | 
127 | def _criterion(outputs, inputs):
128 |     return outputs.loss
129 | 
130 | 
131 | def main():
132 |     # ==============================
133 |     # Parse Arguments
134 |     # ==============================
135 |     parser = argparse.ArgumentParser()
136 |     parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration")
137 |     parser.add_argument(
138 |         "-p",
139 |         "--plugin",
140 |         choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "hybrid_parallel"],
141 |         default="gemini",
142 |         help="Choose which plugin to use",
143 |     )
144 |     parser.add_argument(
145 |         "-d", "--dataset", type=str, default="togethercomputer/RedPajama-Data-1T-Sample", help="Data set path"
146 |     )
147 |     parser.add_argument("--cache_path", type=str, default="workspace/.cache/huggingface/datasets", help="cache path")
148 |     parser.add_argument("-e", "--num_epochs", type=int, default=1, help="Number of epochs")
149 |     parser.add_argument("-b", "--batch_size", type=int, default=2, help="Local batch size")
150 |     parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
151 |     parser.add_argument("-w", "--weigth_decay", type=float, default=0.1, help="Weight decay")
152 |     parser.add_argument("-s", "--warmup_steps", type=int, default=2000, help="Warmup steps")
153 |     parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
154 |     parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
155 |     parser.add_argument("-x", "--mixed_precision", default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
156 |     parser.add_argument("-i", "--save_interval", type=int, default=1000, help="Save interval")
157 |     parser.add_argument("-o", "--save_dir", type=str, default="checkpoint", help="Checkpoint directory")
158 |     parser.add_argument("-f", "--load", type=str, default=None, help="Load checkpoint")
159 |     parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping")
160 |     parser.add_argument("-t", "--tensorboard_dir", type=str, default="tb_logs", help="Tensorboard directory")
161 |     parser.add_argument("-a", "--flash_attention", action="store_true", help="Use Flash Attention")
162 | 
163 |     parser.add_argument("--accumulation_steps", type=int, default=8, help="Gradient accumulation steps")
164 |     parser.add_argument("--expanded_model", default="", help="model path")
165 | 
166 | 
167 |     args = parser.parse_args()
168 | 
169 |     # ==============================
170 |     # Initialize Distributed Training
171 |     # ==============================
172 |     colossalai.launch_from_torch({})
173 |     coordinator = DistCoordinator()
174 | 
175 |     # ==============================
176 |     # Initialize Booster
177 |     # ==============================
178 |     if args.plugin == "gemini":
179 |         plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip)
180 |     elif args.plugin == "gemini_auto":
181 |         plugin = GeminiPlugin(
182 |             precision=args.mixed_precision, placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip
183 |         )
184 |     elif args.plugin == "zero2":
185 |         plugin = LowLevelZeroPlugin(
186 |             stage=2, precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip
187 |         )
188 |     elif args.plugin == "zero2_cpu":
189 |         plugin = LowLevelZeroPlugin(
190 |             stage=2, precision=args.mixed_precision, initial_scale=2**16,
cpu_offload=True, max_norm=args.grad_clip 191 | ) 192 | elif args.plugin == "hybrid_parallel": 193 | plugin = HybridParallelPlugin( 194 | tp_size=4, 195 | pp_size=2, 196 | num_microbatches=None, 197 | microbatch_size=1, 198 | enable_jit_fused=False, 199 | zero_stage=0, 200 | precision=args.mixed_precision, 201 | initial_scale=1, 202 | ) 203 | else: 204 | raise ValueError(f"Unknown plugin {args.plugin}") 205 | 206 | booster = Booster(plugin=plugin) 207 | 208 | use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 209 | is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() 210 | print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage) 211 | 212 | # ============================== 213 | # Initialize Tensorboard 214 | # ============================== 215 | if print_flag: 216 | os.makedirs(args.tensorboard_dir, exist_ok=True) 217 | writer = SummaryWriter(args.tensorboard_dir) 218 | 219 | # ============================== 220 | # Initialize Tokenizer, Dataset and Dataloader 221 | # ============================== 222 | 223 | from transformers import AutoTokenizer 224 | tokenizer = AutoTokenizer.from_pretrained(args.expanded_model, use_fast=False) 225 | tokenizer.pad_token = tokenizer.unk_token 226 | tokenizer.padding_side = 'left' 227 | # ================================================================ 228 | 229 | dataset = load_dataset(args.dataset, cache_dir=args.cache_path) 230 | 231 | train_ds = dataset["train"] 232 | dataloader = prepare_dataloader( 233 | train_ds, 234 | batch_size=args.batch_size, 235 | shuffle=True, 236 | drop_last=True, 237 | collate_fn=partial(tokenize_batch_for_pretrain, tokenizer=tokenizer, max_length=args.max_length), 238 | ) 239 | 240 | # ============================== 241 | # Initialize Model, Optimizer and LR Scheduler 242 | # ============================== 243 | config = MODEL_CONFIGS[args.config] 244 | init_ctx = ( 245 | LazyInitContext(default_device=get_accelerator().get_current_device()) 246 | if isinstance(plugin, GeminiPlugin) 247 | else nullcontext() 248 | ) 249 | 250 | with init_ctx: 251 | 252 | model = LlamaForCausalLM.from_pretrained(args.expanded_model, torch_dtype=torch.float16) 253 | 254 | if args.grad_checkpoint: 255 | model.gradient_checkpointing_enable() 256 | if args.flash_attention: 257 | replace_with_flash_attention(model) 258 | 259 | model_numel = get_model_numel(model) 260 | coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") 261 | 262 | optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay) 263 | 264 | lr_scheduler = CosineAnnealingWarmupLR( 265 | optimizer, total_steps=args.num_epochs * len(dataloader), warmup_steps=args.warmup_steps, eta_min=0.1 * args.lr 266 | ) 267 | default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 268 | torch.set_default_dtype(default_dtype) 269 | model, optimizer, _, dataloader, lr_scheduler = booster.boost( 270 | model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler 271 | ) 272 | torch.set_default_dtype(torch.float) 273 | 274 | coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") 275 | coordinator.print_on_master( 276 | f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" 277 | ) 278 | 279 | # load checkpoint if specified 280 | start_epoch = 0 281 | start_step = 0 282 | sampler_start_idx 
= 0 283 | if args.load is not None: 284 | coordinator.print_on_master("Loading checkpoint") 285 | start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load) 286 | coordinator.print_on_master(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}") 287 | 288 | num_steps_per_epoch = len(dataloader) 289 | 290 | dataloader.sampler.set_start_index(sampler_start_idx) 291 | for epoch in range(start_epoch, args.num_epochs): 292 | dataloader.sampler.set_epoch(epoch) 293 | dataloader_iter = iter(dataloader) 294 | 295 | for step in range(start_step, num_steps_per_epoch): 296 | if use_pipeline: 297 | outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True) 298 | loss = outputs["loss"] 299 | else: 300 | batch = next(dataloader_iter) 301 | outputs = model(**batch) 302 | loss = outputs[0] 303 | booster.backward(loss, optimizer) 304 | 305 | if (step + 1) % args.accumulation_steps == 0: 306 | optimizer.step() # Update parameters 307 | lr_scheduler.step() 308 | optimizer.zero_grad() # Reset gradients 309 | 310 | if step + 1 == num_steps_per_epoch - 1: 311 | optimizer.step() 312 | lr_scheduler.step() 313 | optimizer.zero_grad() 314 | 315 | if not use_pipeline: 316 | all_reduce_mean(loss) 317 | 318 | if print_flag: 319 | writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step) 320 | print("Epoch: {}, step: {}, loss: {:.3f}".format( 321 | epoch, 322 | epoch * num_steps_per_epoch + step, 323 | loss.item() 324 | )) 325 | 326 | if args.save_interval > 0 and (step + 1) % args.save_interval == 0: 327 | coordinator.print_on_master(f"Saving checkpoint") 328 | save( 329 | booster, 330 | model, 331 | optimizer, 332 | lr_scheduler, 333 | epoch, 334 | step + 1, 335 | args.batch_size, 336 | coordinator, 337 | args.save_dir, 338 | ) 339 | coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}") 340 | dataloader.sampler.set_start_index(0) 341 | start_step = 0 342 | 343 | coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") 344 | 345 | 346 | if __name__ == "__main__": 347 | main() 348 | -------------------------------------------------------------------------------- /scripts/train/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | aiohttp==3.9.3 3 | aiosignal==1.3.1 4 | annotated-types==0.6.0 5 | async-timeout==4.0.3 6 | attrs==23.2.0 7 | bcrypt==4.1.2 8 | beautifulsoup4==4.12.3 9 | cachetools==5.3.3 10 | certifi==2024.2.2 11 | cffi==1.16.0 12 | cfgv==3.4.0 13 | charset-normalizer==3.3.2 14 | click==8.1.7 15 | cmake==3.29.0.1 16 | colossalai==0.3.6 17 | contexttimer==0.3.3 18 | cryptography==42.0.5 19 | datasets==2.18.0 20 | decorator==5.1.1 21 | Deprecated==1.2.14 22 | dill==0.3.8 23 | distlib==0.3.8 24 | einops==0.7.0 25 | fabric==3.2.2 26 | filelock==3.13.3 27 | flash-attn==2.2.1 28 | frozenlist==1.4.1 29 | fsspec==2024.2.0 30 | google==3.0.0 31 | google-auth==2.29.0 32 | google-auth-oauthlib==1.0.0 33 | grpcio==1.62.1 34 | huggingface-hub==0.22.2 35 | identify==2.5.35 36 | idna==3.6 37 | invoke==2.2.0 38 | Jinja2==3.1.3 39 | jsonschema==4.21.1 40 | jsonschema-specifications==2023.12.1 41 | lit==18.1.2 42 | Markdown==3.6 43 | markdown-it-py==3.0.0 44 | MarkupSafe==2.1.5 45 | mdurl==0.1.2 46 | mpmath==1.3.0 47 | msgpack==1.0.8 48 | multidict==6.0.5 49 | multiprocess==0.70.16 50 | networkx== 3.3 51 | ninja==1.11.1.1 52 | nodeenv==1.8.0 53 | numpy==1.26.4 54 
| nvidia-cublas-cu11==11.10.3.66 55 | nvidia-cublas-cu12==12.1.3.1 56 | nvidia-cuda-cupti-cu11==11.7.101 57 | nvidia-cuda-cupti-cu12==12.1.105 58 | nvidia-cuda-nvrtc-cu11==11.7.99 59 | nvidia-cuda-nvrtc-cu12==12.1.105 60 | nvidia-cuda-runtime-cu11==11.7.99 61 | nvidia-cuda-runtime-cu12==12.1.105 62 | nvidia-cudnn-cu11==8.5.0.96 63 | nvidia-cudnn-cu12==8.9.2.26 64 | nvidia-cufft-cu11==10.9.0.58 65 | nvidia-cufft-cu12==11.0.2.54 66 | nvidia-curand-cu11==10.2.10.91 67 | nvidia-curand-cu12==10.3.2.106 68 | nvidia-cusolver-cu11==11.4.0.1 69 | nvidia-cusolver-cu12==11.4.5.107 70 | nvidia-cusparse-cu11==11.7.4.91 71 | nvidia-cusparse-cu12==12.1.0.106 72 | nvidia-nccl-cu11==2.14.3 73 | nvidia-nccl-cu12==2.19.3 74 | nvidia-nvjitlink-cu12==12.4.127 75 | nvidia-nvtx-cu11==11.7.91 76 | nvidia-nvtx-cu12==12.1.105 77 | oauthlib==3.2.2 78 | packaging==24.0 79 | pandas==2.2.1 80 | paramiko==3.4.0 81 | pip==23.3.1 82 | platformdirs==4.2.0 83 | pre-commit==3.7.0 84 | protobuf==5.26.1 85 | psutil==5.9.8 86 | pyarrow==15.0.2 87 | pyarrow-hotfix==0.6 88 | pyasn1==0.6.0 89 | pyasn1_modules==0.4.0 90 | pycparser==2.22 91 | pydantic==2.6.4 92 | pydantic_core==2.16.3 93 | Pygments==2.17.2 94 | PyNaCl==1.5.0 95 | python-dateutil==2.9.0.post0 96 | pytz==2024.1 97 | PyYAML==6.0.1 98 | ray==2.10.0 99 | referencing==0.34.0 100 | regex==2023.12.25 101 | requests==2.31.0 102 | requests-oauthlib==2.0.0 103 | rich==13.7.1 104 | rpds-py==0.18.0 105 | rsa==4.9 106 | safetensors==0.4.2 107 | sentencepiece==0.1.99 108 | setuptools==68.2.2 109 | six==1.16.0 110 | soupsieve==2.5 111 | sympy==1.12 112 | tensorboard==2.14.0 113 | tensorboard-data-server==0.7.2 114 | tokenizers==0.13.3 115 | torch==2.0.0 116 | tqdm==4.66.2 117 | transformers==4.33.3 118 | triton==2.0.0 119 | typing_extensions==4.11.0 120 | tzdata==2024.1 121 | urllib3==2.2.1 122 | virtualenv==20.25.1 123 | Werkzeug==3.0.2 124 | wheel==0.41.2 125 | wrapt==1.16.0 126 | xxhash==3.4.1 127 | yarl==1.9.4 128 | --------------------------------------------------------------------------------