├── .github
│   ├── Pythia_saturation.png
│   ├── TinyLlama_logo.png
│   └── llama2-training.png
├── .gitignore
├── EVAL.md
├── LICENSE
├── PRETRAIN.md
├── README.md
├── README_zh-CN.md
├── chat_gradio
│   ├── README.md
│   ├── app.py
│   └── requirements.txt
├── lit_gpt
│   ├── __init__.py
│   ├── adapter.py
│   ├── adapter_v2.py
│   ├── config.py
│   ├── fused_cross_entropy.py
│   ├── fused_rotary_embedding.py
│   ├── lora.py
│   ├── model.py
│   ├── packed_dataset.py
│   ├── rmsnorm.py
│   ├── speed_monitor.py
│   ├── tokenizer.py
│   └── utils.py
├── pretrain
│   ├── tinyllama.py
│   └── tinyllama_code.py
├── requirements.txt
├── script.sh
├── scripts
│   ├── convert_hf_checkpoint.py
│   ├── convert_lit_checkpoint.py
│   ├── prepare_redpajama.py
│   ├── prepare_slimpajama.py
│   └── prepare_starcoder.py
├── sft
│   ├── finetune.py
│   ├── script.sh
│   ├── simple_inference.py
│   └── simple_inference2.py
└── speculative_decoding
    ├── README.md
    └── instruct_hf_assisted_decoding.py
/.github/Pythia_saturation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jzhang38/TinyLlama/bf122247c486b6b897050e98cbb7bedae8eeba73/.github/Pythia_saturation.png
--------------------------------------------------------------------------------
/.github/TinyLlama_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jzhang38/TinyLlama/bf122247c486b6b897050e98cbb7bedae8eeba73/.github/TinyLlama_logo.png
--------------------------------------------------------------------------------
/.github/llama2-training.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jzhang38/TinyLlama/bf122247c486b6b897050e98cbb7bedae8eeba73/.github/llama2-training.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__
.idea
.DS_Store
*.egg-info
build
.venv
.vscode

# data
data
checkpoints
out
wandb

tests/original_falcon_40b.py
sft/output
sft/wandb
--------------------------------------------------------------------------------
/EVAL.md:
--------------------------------------------------------------------------------
## Evaluate TinyLlama

### GPT4All Benchmarks

We evaluate TinyLlama's commonsense reasoning ability following the [GPT4All](https://gpt4all.io/index.html) evaluation suite, with Pythia as our baseline. We report acc_norm by default.

Base models:

| Model | Pretrain Tokens | HellaSwag | Obqa | WinoGrande | ARC_c | ARC_e | boolq | piqa | avg |
|-------------------------------------------|-----------------|-----------|------|------------|-------|-------|-------|------|-----|
| Pythia-1.0B | 300B | 47.16 | 31.40| 53.43 | 27.05 | 48.99 | 60.83 | 69.21 | 48.30 |
| TinyLlama-1.1B-intermediate-step-50K-104b | 103B | 43.50 | 29.80| 53.28 | 24.32 | 44.91 | 59.66 | 67.30 | 46.11|
| TinyLlama-1.1B-intermediate-step-240k-503b| 503B | 49.56 |31.40 |55.80 |26.54 |48.32 |56.91 |69.42 | 48.28 |
| TinyLlama-1.1B-intermediate-step-480k-1007B | 1007B | 52.54 | 33.40 | 55.96 | 27.82 | 52.36 | 59.54 | 69.91 | 50.22 |
| TinyLlama-1.1B-intermediate-step-715k-1.5T | 1.5T | 53.68 | 35.20 | 58.33 | 29.18 | 51.89 | 59.08 | 71.65 | 51.29 |
| TinyLlama-1.1B-intermediate-step-955k-2T | 2T | 54.63 | 33.40 | 56.83 | 28.07 | 54.67 | 63.21 | 70.67 | 51.64 |
| TinyLlama-1.1B-intermediate-step-1195k-2.5T | 2.5T | 58.96 | 34.40 | 58.72 | 31.91 | 56.78 | 63.21 | 73.07 | 53.86|
| TinyLlama-1.1B-intermediate-step-1431k-3T | 3T | 59.20 | 36.00 | 59.12 | 30.12 | 55.25 | 57.83 | 73.29 | 52.99|


Chat models:

| Model | Pretrain Tokens | HellaSwag | Obqa | WinoGrande | ARC_c | ARC_e | boolq | piqa | avg |
|-------------------------------------------|-----------------|-----------|------|------------|-------|-------|-------|------|-----|
| [TinyLlama-1.1B-Chat-v0.1](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.1) | 503B | 53.81 |32.20 | 55.01 | 28.67 |49.62 | 58.04 | 69.64 | 49.57 |
| [TinyLlama-1.1B-Chat-v0.2](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.2) | 503B | 53.63 |32.80 | 54.85 | 28.75 |49.16 | 55.72 | 69.48 | 49.20 |
| [TinyLlama-1.1B-Chat-v0.3](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.3) | 1T | 56.81 |34.20 | 55.80 | 30.03 |53.20 | 59.57 | 69.91 | 51.36 |
| [TinyLlama-1.1B-Chat-v0.4](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4) | 1.5T | 58.59 |35.40 | 58.80 | 30.80 |54.04 | 57.31 | 71.16 | 52.30 |


We observe consistent improvements once the model is finetuned: every chat model outscores the base checkpoint trained on the same number of tokens, by roughly one point on average (e.g., 49.57 vs. 48.28 at 503B tokens). We attribute this to two factors:
1. The base model has not undergone learning-rate cool-down, and finetuning effectively cools down the learning rate.
2. The SFT stage better elicits the model's internal knowledge.

You can obtain the above scores by running [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness):
```bash
python main.py \
    --model hf-causal \
    --model_args pretrained=PY007/TinyLlama-1.1B-Chat-v0.1,dtype="float" \
    --tasks hellaswag,openbookqa,winogrande,arc_easy,arc_challenge,boolq,piqa \
    --device cuda:0 --batch_size 32
```



### Instruct-Eval Benchmarks
We evaluate TinyLlama's problem-solving ability on the [Instruct-Eval](https://github.com/declare-lab/instruct-eval) evaluation suite.

| Model | MMLU | BBH | HumanEval | DROP |
| ------------------------------------------------- | ----- | ----- | --------- | ----- |
| Pythia-1.0B | 25.70 | 28.19 | 1.83 | 4.25 |
| TinyLlama-1.1B-intermediate-step-50K-104b | 26.45 | 28.82 | 5.49 | 11.42 |
| TinyLlama-1.1B-intermediate-step-240k-503b | 26.16 | 28.83 | 4.88 | 12.43 |
| TinyLlama-1.1B-intermediate-step-480K-1T | 24.65 | 29.21 | 6.1 | 13.03 |
| TinyLlama-1.1B-intermediate-step-715k-1.5T | 24.85 | 28.2 | 7.93 | 14.43 |
| TinyLlama-1.1B-intermediate-step-955k-2T | 25.97 | 29.07 | 6.71 | 13.14 |
| TinyLlama-1.1B-intermediate-step-1195k-token-2.5T | 25.92 | 29.32 | 9.15 | 15.45 |

You can obtain the above scores by running [instruct-eval](https://github.com/declare-lab/instruct-eval):
```bash
CUDA_VISIBLE_DEVICES=0 python main.py mmlu --model_name llama --model_path PY007/TinyLlama-1.1B-intermediate-step-480K-1T
CUDA_VISIBLE_DEVICES=1 python main.py bbh --model_name llama --model_path PY007/TinyLlama-1.1B-intermediate-step-480K-1T
CUDA_VISIBLE_DEVICES=2 python main.py drop --model_name llama --model_path PY007/TinyLlama-1.1B-intermediate-step-480K-1T
CUDA_VISIBLE_DEVICES=3 python main.py humaneval --model_name llama --n_sample 1 --model_path PY007/TinyLlama-1.1B-intermediate-step-480K-1T
```
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [2023] Lightning AI

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

--------------------------------------------------------------------------------
/PRETRAIN.md:
--------------------------------------------------------------------------------
## Pretrain TinyLlama

### Installation
We assume you have CUDA 11.8 installed.
#### Install PyTorch Nightly
```bash
pip install --index-url https://download.pytorch.org/whl/nightly/cu118 --pre 'torch>=2.1.0dev'
```
#### Build xFormers from Source
Note: as of 2023/09/02, xformers does not provide pre-built binaries for torch 2.1, so you have to build it from source.
```bash
pip uninstall ninja -y && pip install ninja -U
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
```


#### Install Flash-Attention 2 and Other Fused Operators
```bash
git clone https://github.com/Dao-AILab/flash-attention
cd flash-attention
python setup.py install
cd csrc/rotary && pip install .
cd ../layer_norm && pip install .
cd ../xentropy && pip install .
# csrc/xentropy is three levels below the directory that holds the clone
cd ../../.. && rm -rf flash-attention
```
#### Install Remaining Dependencies
Install the remaining dependencies with:
```bash
pip install -r requirements.txt tokenizers sentencepiece
```
Building xformers/flash-attention may take 5 minutes or more. Do not worry if the process seems stalled or the terminal prints many warnings.

Then you are ready to go 🎉!

### Data Preparation

#### Download Datasets
Download the SlimPajama and Starcoderdata datasets to a directory of your choice.
```bash
cd /path/to/dataset
git lfs install
git clone https://huggingface.co/datasets/cerebras/SlimPajama-627B
git clone https://huggingface.co/datasets/bigcode/starcoderdata
```
The SlimPajama dataset takes 893GB of disk space and Starcoderdata takes 290GB.

#### Tokenize Data
Use the provided scripts to tokenize the datasets and divide them into chunks.
```bash
python scripts/prepare_starcoder.py --source_path /path/to/starcoderdata/ --tokenizer_path data/llama --destination_path data/slim_star_combined --split train --percentage 1.0
python scripts/prepare_slimpajama.py --source_path /path/to/SlimPajama --tokenizer_path data/llama --destination_path data/slim_star_combined --split validation --percentage 1.0
python scripts/prepare_slimpajama.py --source_path /path/to/SlimPajama --tokenizer_path data/llama --destination_path data/slim_star_combined --split train --percentage 1.0
```
The processed data takes about 1.8TB of storage.
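
As a rough cross-check of the storage figure above, the sketch below estimates the raw size of a flat array of token ids. It rests on assumptions that the scripts above do not confirm: each token id is stored in the smallest whole number of bytes that fits the vocabulary (2 bytes for the Llama 2 tokenizer's 32,000 entries), headers and padding are ignored, and the token count is the roughly 950B-token combined dataset size quoted in the README's Training Details table. `bytes_per_token` and `dataset_size_bytes` are hypothetical helpers, not part of this repository.

```python
import math

# Assumption: the Llama 2 tokenizer has a 32,000-entry vocabulary.
VOCAB_SIZE = 32_000
# From the README's Training Details table: combined dataset of around 950B tokens.
COMBINED_TOKENS = 950e9


def bytes_per_token(vocab_size: int) -> int:
    """Smallest whole number of bytes that can hold every token id in [0, vocab_size)."""
    bits = max(1, math.ceil(math.log2(vocab_size)))
    return math.ceil(bits / 8)


def dataset_size_bytes(num_tokens: float, vocab_size: int) -> float:
    """Raw size of a flat array of token ids, ignoring file headers and padding."""
    return num_tokens * bytes_per_token(vocab_size)


if __name__ == "__main__":
    size = dataset_size_bytes(COMBINED_TOKENS, VOCAB_SIZE)
    print(f"{bytes_per_token(VOCAB_SIZE)} bytes per token")
    print(f"~{size / 1e12:.2f} TB (~{size / 2**40:.2f} TiB)")
```

Under these assumptions the estimate comes to about 1.9 TB (about 1.73 TiB), in the same range as the figure quoted above; the real on-disk size also depends on how the preparation scripts chunk and pad the data.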

### Pretraining
If your setup comprises two nodes, each with 8 GPUs, you can launch pretraining with the following commands.

On node 1:
```bash
lightning run model \
    --node-rank=0 \
    --main-address=172.16.101.5 \
    --accelerator=cuda \
    --devices=8 \
    --num-nodes=2 \
    pretrain/tinyllama.py --devices 8 --train_data_dir data/slim_star_combined --val_data_dir data/slim_star_combined
```
On node 2:
```bash
lightning run model \
    --node-rank=1 \
    --main-address=172.16.101.5 \
    --accelerator=cuda \
    --devices=8 \
    --num-nodes=2 \
    pretrain/tinyllama.py --devices 8 --train_data_dir data/slim_star_combined --val_data_dir data/slim_star_combined
```
If you have a Slurm cluster, you can follow [these instructions](https://lightning.ai/docs/fabric/stable/guide/multi_node/slurm.html).

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# TinyLlama-1.1B
English | [中文](README_zh-CN.md)

[Chat Demo](https://huggingface.co/spaces/TinyLlama/tinyllama-chat) | [Discord](https://discord.gg/74Wcx4j5Nb)

The TinyLlama project aims to **pretrain** a **1.1B Llama model on 3 trillion tokens**. With proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. Training started on 2023-09-01.

We adopted exactly the same architecture and tokenizer as Llama 2, so TinyLlama can be used as a drop-in replacement in many open-source projects built upon Llama. Moreover, with only 1.1B parameters, TinyLlama is compact enough for the many applications that require a small compute and memory footprint.

#### News
- 2023-12-18: Added two notes ([1](https://whimsical-aphid-86d.notion.site/Release-of-TinyLlama-1-5T-Checkpoints-Postponed-01b266998c1c47f78f5ae1520196d194?pvs=4), [2](https://whimsical-aphid-86d.notion.site/Latest-Updates-from-TinyLlama-Team-7d30c01fff794da28ccc952f327c8d4f?pvs=4)) explaining changes in the training curves, the project schedule, and bug fixes.
- 2023-10-03: Added speculative decoding examples with llama.cpp. Check out [speculative_decoding/README.md](speculative_decoding/README.md).
- 2023-10-02: 1. The 1T-token checkpoint just dropped. 2. We document **all** intermediate checkpoints [here](https://huggingface.co/TinyLlama/tinyLlama-intermediate-checkpoints/tree/step-480k-token-1007B).
- 2023-09-28: Added a Discord server.
- 2023-09-18: We added a [chat demo](https://huggingface.co/spaces/PY007/TinyLlama-Chat) so that you can play with TinyLlama-Chat-V0.1 right away.
- 2023-09-16: 1. We released the intermediate checkpoint trained on 503B tokens. 2. We released a chat model finetuned on OpenAssistant and added simple [finetuning](sft) scripts. 3. We added more evaluation benchmarks, documented in [EVAL.md](EVAL.md).

#### Evaluation
You can find the evaluation results of TinyLlama in [EVAL.md](EVAL.md).

#### Release Schedule
We release intermediate checkpoints according to the schedule below.

Base models:

| Date | HF Checkpoint | Tokens | Step | Commonsense Avg |
|------------|-------------------------------------------------|--------|------| --------------- |
| 2023-09-01 | Pythia-1.0B | 300B | 143k | 48.30 |
| 2023-09-04 | [TinyLlama-1.1B-intermediate-step-50k-105b](https://huggingface.co/PY007/TinyLlama-1.1B-step-50K-105b) | 105B | 50k | 46.11 |
| 2023-09-16 | [TinyLlama-1.1B-intermediate-step-240k-503b](https://huggingface.co/PY007/TinyLlama-1.1B-intermediate-step-240k-503b) | 503B | 240k | 48.28 |
| 2023-10-01 | [TinyLlama-1.1B-intermediate-step-480k-1T](https://huggingface.co/PY007/TinyLlama-1.1B-intermediate-step-480k-1T) | 1T | 480k | 50.22 |
| 2023-11-04 | [TinyLlama-1.1B-intermediate-step-715k-1.5T](https://huggingface.co/PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T) | 1.5T | 715k | 51.28 |
| 2023-11-20 | [TinyLlama-1.1B-intermediate-step-955k-2T](https://huggingface.co/TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T) | 2T | 955k | 51.64 |
| 2023-12-11 | [TinyLlama-1.1B-intermediate-step-1195k-2.5T](https://huggingface.co/TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T) | 2.5T | 1195k | 53.86 |
| 2023-12-28 | [TinyLlama-1.1B-intermediate-step-1431k-3T](https://huggingface.co/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T) | 3T | 1431k | 52.99 |

We are writing a note that offers a possible explanation for the significant improvement from the 2T to the 2.5T checkpoint (it is related to the [bos_id issue](https://github.com/jzhang38/TinyLlama/issues/83)).

Chat models:

| Date | HF Checkpoint | Tokens | Step | Commonsense Avg |
|------------|-------------------------------------------------|--------|------| --------------- |
| 2023-09-16 | [TinyLlama-1.1B-Chat-V0.1](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.1) | 503B | 240k | 49.57 |
| 2023-10-01 | [TinyLlama-1.1B-Chat-V0.3](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.3) | 1T | 480k | 51.36 |
| 2023-11-04 | [TinyLlama-1.1B-Chat-V0.4](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4) | 1.5T | 715k | 52.30 |

Note that the learning rate of the base model has not cooled down yet, so we recommend also using the finetuned chat model.

Meanwhile, you can track the live cross-entropy loss [here](https://wandb.ai/lance777/lightning_logs/reports/metric-train_loss-23-09-04-23-38-15---Vmlldzo1MzA4MzIw?accessToken=5eu2sndit2mo6eqls8h38sklcgfwt660ek1f2czlgtqjv2c6tida47qm1oty8ik9).

## Potential Use Cases
Tiny but strong language models are useful for many applications. Here are some potential use cases:
- Assisting speculative decoding of larger models (see this [tutorial](https://twitter.com/karpathy/status/1697318534555336961) by Andrej Karpathy).
- Deployment on edge devices with restricted memory and compute, for functionalities such as real-time machine translation without an internet connection (the weights of the 4-bit-quantized TinyLlama-1.1B take up only 637 MB).
- Enabling real-time dialogue generation in video games.

Moreover, our code can serve as a **reference for enthusiasts keen on pretraining language models under 5 billion parameters** without diving too early into [Megatron-LM](https://github.com/NVIDIA/Megatron-LM).
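
To see where the 1.1B parameter count and the 637 MB figure come from, the sketch below counts the weights of a Llama-2-style decoder with the shape listed in the Training Details table (22 layers, 32 heads, 4 query groups, embedding size 2048, SwiGLU intermediate size 5632). It assumes a 32,000-token vocabulary, bias-free linear layers, weight-only RMSNorm, and an input embedding that is not tied to the output head. `LlamaShape`, `count_params`, and `weight_megabytes` are illustrative names, not the `lit_gpt` API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LlamaShape:
    """Hypothetical container for the model shape in the Training Details table."""
    n_layer: int = 22
    n_head: int = 32
    n_query_groups: int = 4          # grouped-query attention: 4 shared K/V heads
    n_embd: int = 2048
    intermediate_size: int = 5632    # SwiGLU hidden width
    vocab_size: int = 32_000         # assumption: Llama 2 tokenizer vocabulary


def count_params(c: LlamaShape) -> int:
    """Count weights, assuming bias-free linears, weight-only RMSNorm, and an untied LM head."""
    head_dim = c.n_embd // c.n_head              # 2048 / 32 = 64
    kv_dim = c.n_query_groups * head_dim         # 4 * 64 = 256
    attn = 2 * c.n_embd * c.n_embd               # query and output projections
    attn += 2 * c.n_embd * kv_dim                # key and value projections, shrunk by GQA
    mlp = 3 * c.n_embd * c.intermediate_size     # SwiGLU gate, up, and down projections
    norms = 2 * c.n_embd                         # pre-attention and pre-MLP RMSNorm weights
    blocks = c.n_layer * (attn + mlp + norms)
    final_norm = c.n_embd
    embeddings = 2 * c.vocab_size * c.n_embd     # token embedding plus a separate output head
    return blocks + final_norm + embeddings


def weight_megabytes(n_params: int, bits_per_weight: float) -> float:
    """Storage for the weights alone, in decimal megabytes."""
    return n_params * bits_per_weight / 8 / 1e6


if __name__ == "__main__":
    n = count_params(LlamaShape())
    print(f"parameters: {n:,}")
    for label, bits in [("16-bit", 16), ("4-bit", 4), ("4-bit + 16-bit scale per 32 weights", 4.5)]:
        print(f"{label:>36}: {weight_megabytes(n, bits):,.0f} MB")
```

Under these assumptions the count is 1,100,048,384 parameters, or about 2.2 GB at 16 bits. Pure 4-bit storage gives about 550 MB, and a block format that adds one 16-bit scale per 32 weights (4.5 bits per weight, an illustrative choice) gives about 619 MB. The quoted 637 MB is somewhat higher, which would be consistent with some tensors being kept at higher precision, but this sketch does not model any specific quantization format.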

## Training Details
Below are some details of our training setup:

| Setting | Description |
|---------------------------------|----------------------------------------------------------------|
| Parameters | 1.1B |
| Attention Variant | Grouped Query Attention |
| Model Size | Layers: 22, Heads: 32, Query Groups: 4, Embedding Size: 2048, Intermediate Size (SwiGLU): 5632 |
| Sequence Length | 2048 |
| Batch Size | 2 million tokens (2048 * 1024) |
| Learning Rate | 4e-4 |
| Learning Rate Schedule | Cosine with 2000 warmup steps. See [Issue 27](https://github.com/jzhang38/TinyLlama/issues/27) for a minor bug |
| Training Data | [SlimPajama](https://huggingface.co/datasets/cerebras/slimpajama-627b) & [Starcoderdata](https://huggingface.co/datasets/bigcode/starcoderdata) |
| Data Preprocessing | Excluded the GitHub subset of SlimPajama; sampled all code from Starcoderdata |
| Combined Dataset Size | Around 950B tokens |
| Total Tokens During Training | 3 trillion (slightly more than 3 epochs/1430k steps) |
| Natural Language to Code Ratio | 7:3 |
| Hardware | 16 A100-40G GPUs |

## Blazingly Fast
Our codebase supports the following features:
- multi-GPU and multi-node distributed training with FSDP.
- FlashAttention-2.
- fused layernorm.
- fused SwiGLU.
- fused cross-entropy loss.
- fused rotary positional embedding.

Credit: FlashAttention-2, fused layernorm, fused cross-entropy loss, and fused rotary positional embedding are from the [FlashAttention repo](https://github.com/Dao-AILab/flash-attention/). Fused SwiGLU is from [xformers](https://github.com/facebookresearch/xformers).

Thanks to these optimizations, we achieve a throughput of **24k** tokens per second per A100-40G GPU, which translates to **56% model FLOPs utilization** (MFU) without activation checkpointing (we expect the MFU to be even higher on A100-80G). This means you can train a Chinchilla-optimal TinyLlama (1.1B parameters, 22B tokens) in **32 hours with 8 A100s**. These optimizations also greatly reduce the memory footprint, allowing us to fit our 1.1B model into 40GB of GPU memory and train with a per-GPU batch size of 16k tokens. **You can also pretrain TinyLlama on 3090/4090 GPUs with a smaller per-GPU batch size**.
Below is a comparison of the training speed of our codebase with that of Pythia and MPT.


| Model | A100 GPU hours taken on 300B tokens|
|-----------------------------------|------------------------------------|
|TinyLlama-1.1B | 3456 |
|[Pythia-1.0B](https://huggingface.co/EleutherAI/pythia-1b) | 4830 |
|[MPT-1.3B](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b) | 7920 |

The Pythia number comes from their [paper](https://arxiv.org/abs/2304.01373). The MPT number comes from [here](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b), where they state that MPT-1.3B "was trained on 440 A100-40GBs for about half a day" on 200B tokens; scaling 440 GPUs × 12 hours from 200B to 300B tokens gives the 7920 GPU hours above.

Because TinyLlama is a relatively small model with grouped-query attention, it is also fast during inference. Below are some throughputs that we measured:

| Framework | Device | Settings | Throughput (tokens/sec) |
|-----------|--------------|-----|-----------|
|[Llama.cpp](https://github.com/ggerganov/llama.cpp) | Mac M2 16GB RAM | batch_size=1; 4-bit inference| 71.8 |
|[vLLM](https://github.com/vllm-project/vllm) | A40 GPU | batch_size=100, n=10 | 7094.5 |


## Pretrain
Please refer to [PRETRAIN.md](PRETRAIN.md) for instructions on how to pretrain TinyLlama.
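
For planning a run, the throughput figures from the Blazingly Fast section can be turned into a quick time budget. The sketch below assumes throughput stays at the reported 24k tokens/s per GPU and scales linearly with the number of GPUs, which real multi-node runs may not achieve. The helper names are hypothetical and not part of this repository.

```python
SECONDS_PER_HOUR = 3600
THROUGHPUT = 24_000  # tokens/s per A100-40G, as reported in the Blazingly Fast section


def gpu_hours(tokens: float, tokens_per_sec_per_gpu: float) -> float:
    """Total GPU hours needed to process `tokens` at a fixed per-GPU throughput."""
    return tokens / tokens_per_sec_per_gpu / SECONDS_PER_HOUR


def wall_clock_hours(tokens: float, tokens_per_sec_per_gpu: float, n_gpus: int) -> float:
    """Wall-clock hours, assuming throughput scales linearly with the GPU count."""
    return gpu_hours(tokens, tokens_per_sec_per_gpu) / n_gpus


def required_throughput(tokens: float, days: float, n_gpus: int) -> float:
    """Per-GPU tokens/s needed to finish `tokens` within `days` on `n_gpus` GPUs."""
    return tokens / (days * 24 * SECONDS_PER_HOUR * n_gpus)


if __name__ == "__main__":
    print(f"Chinchilla-optimal (22B tokens, 8 GPUs): {wall_clock_hours(22e9, THROUGHPUT, 8):.1f} h")
    print(f"300B tokens: {gpu_hours(300e9, THROUGHPUT):,.0f} GPU hours")
    print(f"3T tokens in 90 days on 16 GPUs: {required_throughput(3e12, 90, 16):,.0f} tokens/s/GPU")
    # MPT-1.3B: 440 GPUs for about 12 hours on 200B tokens, scaled linearly to 300B tokens
    print(f"MPT-1.3B scaled to 300B tokens: {440 * 12 * 300 / 200:,.0f} GPU hours")
```

The Chinchilla-optimal run comes to about 31.8 hours, matching the 32 hours quoted above, and the MPT line reproduces the 7920 GPU hours in the comparison table. With the rounded 24k figure, 300B tokens take about 3,472 GPU hours, slightly above the table's 3,456; that value is exactly what the 90-day, 16-GPU plan for 3T tokens implies when scaled down to 300B tokens (90 × 24 × 16 / 10).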

## Finetune
We include a simple full-parameter finetuning & inference script in [sft](sft). Our V0.1 chat model is finetuned using this script. The finetuning dataset we use is [openassistant-guanaco](https://huggingface.co/datasets/timdettmers/openassistant-guanaco).
For finetuning with less than 4GB of RAM, we refer you to the [QLoRA](https://github.com/artidoro/qlora) and [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) repos.
We did not perform extensive hyperparameter tuning or select more performant finetuning datasets. We hope the community will explore finetuning TinyLlama and come up with better chat models. We will include community-finetuned models in this repo.

## TODO
This project is still under active development. We are a really small team. Community feedback and contributions are highly appreciated. Here are some things we plan to work on:
- [ ] Add scripts for pretraining on other datasets.
- [ ] Sequence length extrapolation.
- [ ] Test out speculative decoding for Llama-2-7B.
- [ ] Test the throughput on RTX 3090/4090.
- [ ] Add fine-tuning scripts.
- [ ] Properly evaluate the model on downstream tasks.
- [ ] A demo running on mobile phones.
- [ ] Explore retrieval-augmentation.


## Acknowledgements
This repository is built upon [lit-gpt](https://github.com/Lightning-AI/lit-gpt) and [flash-attention](https://github.com/Dao-AILab/flash-attention). Be sure to explore these fantastic open-source projects if they are new to you!
145 | ``` 146 | @online{lit-gpt, 147 | author = {Lightning AI}, 148 | title = {Lit-GPT}, 149 | url = {https://github.com/Lightning-AI/lit-gpt}, 150 | year = {2023}, 151 | } 152 | @article{dao2023flashattention2, 153 | title ={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning}, 154 | author ={Dao, Tri}, 155 | year ={2023} 156 | } 157 | ``` 158 | 159 | ## Citation 160 | This project is currently developed by [Peiyuan Zhang](https://veiled-texture-20c.notion.site/Peiyuan-Zhang-ab24b48621c9491db767a76df860873a?pvs=4) *, [Guangtao Zeng](https://github.com/ChaosCodes) *, [Tianduo Wang](https://github.com/TianduoWang) and [Wei Lu](https://istd.sutd.edu.sg/people/faculty/lu-wei/) from the StatNLP Research Group of Singapore University of Technology and Design. 161 | 162 | If you find our work valuable, please cite: 163 | 164 | ``` 165 | @misc{zhang2024tinyllama, 166 | title={TinyLlama: An Open-Source Small Language Model}, 167 | author={Peiyuan Zhang and Guangtao Zeng and Tianduo Wang and Wei Lu}, 168 | year={2024}, 169 | eprint={2401.02385}, 170 | archivePrefix={arXiv}, 171 | primaryClass={cs.CL} 172 | } 173 | ``` 174 | 175 | ## Frequently Asked Questions 176 | 177 | #### 1. Why would pretraining a 1.1B model for so long make sense? Doesn't it contradict the Chinchilla Scaling Law? 178 | 179 | The training loss curve of Llama 2 180 | 181 | Above is the training loss curve taken from the Llama 2 paper. To quote that paper: "We observe that after pretraining on 2T Tokens, the models still did not show any sign of saturation". That is why we believe pretraining a 1.1B model for 3T tokens is a reasonable thing to do. Even if the loss eventually stops decreasing, we can still study the phenomenon of saturation and learn something from it. 182 | 183 | #### 2. What does "saturation" mean?
184 | Figure 10 of the Pythia paper 185 | 186 | The figure from the Pythia paper displays the LAMBADA accuracy plotted against the total training tokens (300B). The term "saturation" pertains specifically to the 70M and 160M models. Notably, even the 410M model does not saturate with 300B tokens, as it continues to show an increasing trend, similar to the trend of larger models. 187 | 188 | 189 | ## Star History 190 | 191 | [![Star History Chart](https://api.star-history.com/svg?repos=jzhang38/TinyLlama&type=Date)](https://star-history.com/#jzhang38/TinyLlama&Date) 192 | 193 | -------------------------------------------------------------------------------- /README_zh-CN.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # TinyLlama-1.1B 4 | [English](README.md) | 中文 5 | 6 | [Chat Demo](https://huggingface.co/spaces/TinyLlama/tinyllama-chat) 7 |
8 | 9 | TinyLlama项目旨在在3万亿tokens上进行预训练,构建一个拥有11亿参数的Llama模型。经过精心优化,我们"仅"需16块A100-40G的GPU,便可在90天内完成这个任务🚀🚀。训练已于2023-09-01开始。 10 | 11 | 12 |
13 | 14 |
15 | 我们采用了与Llama 2完全相同的架构和分词器。这意味着TinyLlama可以在许多基于Llama的开源项目中即插即用。此外,TinyLlama只有1.1B的参数,体积小巧,适用于需要限制计算和内存占用的多种应用。 16 | 17 | #### 新闻 18 | 19 | * 2023-12-18: 20 | * 添加两个文档 [1](https://whimsical-aphid-86d.notion.site/Release-of-TinyLlama-1-5T-Checkpoints-Postponed-01b266998c1c47f78f5ae1520196d194?pvs=4), [2](https://whimsical-aphid-86d.notion.site/Latest-Updates-from-TinyLlama-Team-7d30c01fff794da28ccc952f327c8d4f?pvs=4) 说明训练曲线、项目时间表和错误修复的变化。 21 | * 2023-10-03: 22 | * 在speculative decoding中添加llama.cpp的代码示例。具体请查看 [speculative_decoding/README.md](speculative_decoding/README.md)。 23 | * 2023-10-02: 1. 1T-token检查点刚发布。2. 我们在[huggingface](https://huggingface.co/TinyLlama/tinyLlama-intermediate-checkpoints/tree/step-480k-token-1007B)上记录了**所有**中间检查点。 24 | * 2023-09-28: 启用[Discord](https://discord.gg/74Wcx4j5Nb)服务器。 25 | * 2023-09-18: 26 | * 发布了一个 [chat demo](https://huggingface.co/spaces/TinyLlama/tinyllama-chat),欢迎点击链接来尝试我们的模型。 27 | * 2023-09-16: 28 | * 发布了目前已经训练了 5.03 亿个 token 的 [checkpoints 模型](https://huggingface.co/PY007/TinyLlama-1.1B-intermediate-step-240k-503b)。 29 | * 基于 5.03 亿 token 的 [checkpoints 模型](https://huggingface.co/PY007/TinyLlama-1.1B-intermediate-step-240k-503b) 在 OpenAssistant 数据集上微调并开源了聊天模型 [TinyLlama-Chat-V0.1](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.1) ,并添加了我们的 [微调脚本](sft) 。 30 | * 添加了更多的评测数据集,您可以通过 [EVAL.md](EVAL.md) 文件来查看我们各模型的结果。 31 | 32 | 33 | 34 | 35 | #### 发布时间表 36 | 37 | 我们会根据以下计划逐步发布中间checkpoint。我们也列了一些基线模型进行比较。 38 | 39 | 基座模型: 40 | 41 | | Date | 模型权重 | Tokens | Step | Commonsense Avg | 42 | | ---------- | ------------------------------------------------------------ | ------ | ---- | --------------- | 43 | | 2023-09-01 | Pythia-1.0B | 300B | 143k | 48.30 | 44 | | 2023-09-04 | [TinyLlama-1.1B-intermediate-step-50k-105b](https://huggingface.co/PY007/TinyLlama-1.1B-step-50K-105b) ([ModelScope](https://www.modelscope.cn/models/chaoscodes/TinyLlama-1.1B-step-50K-105b/files)) | 105B | 50k | 46.11 | 45 | | 2023-09-16 | 
[TinyLlama-1.1B-intermediate-step-240k-503b](https://huggingface.co/PY007/TinyLlama-1.1B-intermediate-step-240k-503b) ([ModelScope](https://www.modelscope.cn/models/chaoscodes/TinyLlama-1.1B-intermediate-step-240k-503b/files)) | 503B | 240K | 48.28 | 46 | | 2023-10-01 | [TinyLlama-1.1B-intermediate-step-480k-1T](https://huggingface.co/PY007/TinyLlama-1.1B-intermediate-step-480k-1T) | 1T | 480k | 50.22 | 47 | | 2023-11-04 | [TinyLlama-1.1B-intermediate-step-715k-1.5T](https://huggingface.co/PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T) | 1.5T |715k |51.28 | 48 | | 2023-11-20 | [TinyLlama-1.1B-intermediate-step-955k-2T](https://huggingface.co/TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T) | 2T |955k |51.64 | 49 | | 2023-12-11 | [TinyLlama-1.1B-intermediate-step-1195k-2.5T](https://huggingface.co/TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T) | 2.5T | 1195k |53.86 | 50 | | 2023-12-28 | [TinyLlama-1.1B-intermediate-step-1431k-3T](https://huggingface.co/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T) | 3T | 1431k | 52.99 | 51 | 52 | 对话模型: 53 | 54 | | Date | 模型权重 | Tokens | Step | Commonsense Avg | 55 | |------------|-------------------------------------------------|--------|------| --------------- | 56 | | 2023-09-16 | [TinyLlama-1.1B-Chat-V0.1](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.1) ([ModelScope](https://www.modelscope.cn/models/chaoscodes/TinyLlama-1.1B-Chat-v0.1/files)) | 503B | 240K | 49.57 | 57 | | 2023-10-01 | [TinyLlama-1.1B-Chat-V0.3](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.3) | 1T | 480K | 51.36 | 58 | | 2023-11-04 | [TinyLlama-1.1B-Chat-V0.4](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4) | 1.5T | 715K | 52.30 | 59 | 60 | 需要注意的是,由于我们现在的模型还处于训练初期,学习率并没有完全稳定下来,为了更好地体验我们的模型,您可以下载我们的 [聊天模型](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0) 或者通过 [chat demo](https://huggingface.co/spaces/TinyLlama/tinyllama-chat) 来尝试我们的模型。 61 | 62 | 63 |
你们也可以在[这里](https://api.wandb.ai/links/lance777/pgvhrsny)实时跟踪TinyLlama的训练损失。 64 | 65 | ## 潜在场景 66 | 小型但强大的语言模型对许多应用都很有用。以下是一些潜在的场景: 67 | - 帮助对大型模型进行speculative decoding。 68 | - 在边缘设备上运行,比如离线的实时机器翻译 (TinyLlama的4比特量化版本的模型权重只需要550MB的内存)。 69 | - 在游戏中实现实时对话生成(因为还得给游戏本身留显存所以模型要小)。 70 | 71 | 此外,我们的代码可以给初学者做一个**入门预训练的简洁参考**。如果你要训练50亿以下参数的语言模型, 你其实不需要Megatron-LM。 72 | 73 | ## 训练细节 74 | 以下是我们训练设置的一些细节: 75 | 76 | | Setting | Description | 77 | |---------------------------------|----------------------------------------------------------------| 78 | | Parameters | 1.1B | 79 | | Attention Variant | Grouped Query Attention | 80 | | Model Size | Layers: 22, Heads: 32, Query Groups: 4, Embedding Size: 2048, Intermediate Size (Swiglu): 5632| 81 | | Sequence Length | 2048 | 82 | | Batch Size | 2 million tokens (2048 * 1024) | 83 | | Learning Rate | 4e-4 | 84 | | Learning Rate Schedule | Cosine with 2000 warmup steps | 85 | | Training Data | [Slimpajama](https://huggingface.co/datasets/cerebras/slimpajama-627b) & [Starcoderdata](https://huggingface.co/datasets/bigcode/starcoderdata) | 86 | | Data Preprocessing | Excluded GitHub subset of Slimpajama; Sampled all code from Starcoderdata | 87 | | Combined Dataset Size | Around 950B tokens | 88 | | Total Tokens During Training | 3 trillion (slightly more than 3 epochs/1430k steps) | 89 | | Natural Language to Code Ratio | 7:3 | 90 | | Hardware | 16 A100-40G GPUs | 91 | 92 | 93 | 94 | 95 | 96 | 97 | ## 速度极快 98 | 我们的代码库支持以下特性: 99 | - 使用FSDP进行多GPU和多节点分布式训练 100 | - flash attention 2 101 | - 融合层归一化 (fused layernorm) 102 | - 融合swiglu (fused swiglu) 103 | - 融合交叉熵损失 (fused cross entropy loss) 104 | - 融合旋转位置嵌入 (fused rotary positional embedding) 105 | 106 | 致谢:flash attention 2、融合层归一化、融合交叉熵损失和融合旋转位置嵌入来自于[FlashAttention](https://github.com/Dao-AILab/flash-attention/)仓库;融合swiglu来自于[xformers](https://github.com/facebookresearch/xformers)。 107 | 108 | 有了这些优化, 我们可以达到**24k
tokens/秒/A100**的训练速度,也就是56%的MFU(在A100-80G上的MFU会更高)。这个速度可以让你在**8个A100上用32小时训练一个chinchilla-optimal的模型**(11亿参数,220亿token)。这些优化也大大减少了显存占用, 我们可以把11亿参数的模型塞入40GB的GPU里面还能同时维持16k tokens的per-gpu batch size。只需要把batch size改小一点, 你就可以在**RTX 3090/4090**上面训练TinyLlama。 109 | 下面是我们的代码库与Pythia和MPT的训练速度的比较。 110 | 111 | 112 | | Model | A100 GPU hours taken on 300B tokens| 113 | |-----------------------------------|------------------------------------| 114 | |TinyLlama-1.1B | 3456 | 115 | |[Pythia-1.0B](https://huggingface.co/EleutherAI/pythia-1b) | 4830 | 116 | |[MPT-1.3B](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b) | 7920 | 117 | 118 | Pythia的数字来自他们的论文。MPT的数字来自[这里](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b),作者说MPT-1.3B"was trained on 440 A100-40GBs for about half a day" on 200B tokens。 119 | 120 | TinyLlama是一个相对较小的模型, 同时我们用了GQA, 这意味着它在推理期间也很快。以下是我们测量的一些推理速度: 121 | 122 | | Framework | Device | Settings | Throughput (tokens/sec) | 123 | |-----------|--------------|-----|-----------| 124 | |[Llama.cpp](https://github.com/ggerganov/llama.cpp) | Mac M2 16GB RAM | batch_size=1; 4-bit inference| 71.8 | 125 | |[vLLM](https://github.com/vllm-project/vllm) | A40 GPU | batch_size=100, n=10 | 7094.5 | 126 | 127 | 128 | ## 开始预训练 129 | 请参考[PRETRAIN.md](PRETRAIN.md)。 130 | 131 | 132 | 133 | ## 微调 134 | 135 | * 我们在 [sft](sft) 中添加了我们进行微调和推理的代码。并且基于这个代码我们在[openassistant-guanaco](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) 数据集上进行了微调,得到了我们的第一版[聊天模型](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.1)。 136 | * 如果您希望在显存小于 4GB 的 GPU 上对我们的模型进行微调,可以参考并使用 [Qlora](https://github.com/artidoro/qlora) 和 [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) 项目。 137 | * 目前微调的时候我们并没有广泛对超参进行搜索,也没有选择潜在更优的 instruction 数据集。我们希望促进 NLP 社区对于我们的TinyLlama模型的开放研究,并开源更好的微调聊天模型。我们也会把这些模型放在这个项目中。 138 | 139 | 140 | 141 | ## TODO 142 | 该项目仍在积极开发中。我们团队很小,非常欢迎社区的反馈和贡献。以下是我们计划进行的一些工作: 143 | - [ ] Add scripts for pretraining on other datasets.
144 | - [ ] Sequence length extrapolation. 145 | - [ ] Test out speculative decoding for Llama-2-7B. 146 | - [ ] Test the throughput on RTX 3090/4090. 147 | - [ ] Add fine-tuning scripts. 148 | - [ ] Properly evaluate the model on downstream tasks. 149 | - [ ] A demo running on mobile phones. 150 | - [ ] Explore retrieval-augmentation. 151 | 152 | ## Star History 153 | 154 | [![Star History Chart](https://api.star-history.com/svg?repos=jzhang38/TinyLlama&type=Date)](https://star-history.com/#jzhang38/TinyLlama&Date) 155 | 156 | 157 | ## Acknowledgements 158 | 这个仓库基于出色的开源项目[lit-gpt](https://github.com/Lightning-AI/lit-gpt)和[flash-attention](https://github.com/Dao-AILab/flash-attention)构建。 159 | ``` 160 | @online{lit-gpt, 161 | author = {Lightning AI}, 162 | title = {Lit-GPT}, 163 | url = {https://github.com/Lightning-AI/lit-gpt}, 164 | year = {2023}, 165 | } 166 | @article{dao2023flashattention2, 167 | title ={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning}, 168 | author ={Dao, Tri}, 169 | year ={2023} 170 | } 171 | ``` 172 | 173 | ## Citation 174 | 此项目目前由[Peiyuan Zhang](https://github.com/jzhang38),[Guangtao Zeng](https://github.com/ChaosCodes),[Tianduo Wang](https://github.com/TianduoWang)和[Wei Lu](https://istd.sutd.edu.sg/people/faculty/lu-wei/)贡献。 175 | 176 | 如果您觉得我们的工作有价值, 可以引用: 177 | 178 | ``` 179 | @misc{zhang2024tinyllama, 180 | title={TinyLlama: An Open-Source Small Language Model}, 181 | author={Peiyuan Zhang and Guangtao Zeng and Tianduo Wang and Wei Lu}, 182 | year={2024}, 183 | eprint={2401.02385}, 184 | archivePrefix={arXiv}, 185 | primaryClass={cs.CL} 186 | } 187 | ``` 188 | 189 | -------------------------------------------------------------------------------- /chat_gradio/README.md: -------------------------------------------------------------------------------- 1 | ## TinyLlama Chatbot Implementation with Gradio 2 | 3 | We offer an easy way to interact with TinyLlama.
This guide explains how to set up a local Gradio demo for a chatbot using TinyLlama. 4 | (A demo is also available on the Hugging Face Space [TinyLlama/tinyllama-chat](https://huggingface.co/spaces/TinyLlama/tinyllama-chat) and on [Colab](https://colab.research.google.com/drive/1qAuL5wTIa-USaNBu8DH35KQtICTnuLsy?usp=sharing).) 5 | 6 | ### Requirements 7 | * Python>=3.8 8 | * PyTorch>=2.0 9 | * Transformers>=4.35.0 10 | * Gradio>=4.13.0 11 | 12 | ### Installation 13 | `pip install -r requirements.txt` 14 | 15 | ### Usage 16 | 17 | `python TinyLlama/chat_gradio/app.py` 18 | 19 | * After running it, open the local URL printed in your terminal in a web browser. (If you run the app on a remote server, use SSH local port forwarding: `ssh -L [local port]:localhost:[remote port] [username]@[server address]`.) 20 | * Interact with the chatbot by typing questions or commands. 21 | 22 | 23 | **Note:** The chatbot's performance may vary with your system's hardware. Make sure your system meets the above requirements for the best experience. 24 | -------------------------------------------------------------------------------- /chat_gradio/app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import torch 3 | from transformers import AutoModelForCausalLM, AutoTokenizer 4 | from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer 5 | from threading import Thread 6 | 7 | # Loading the tokenizer and model from Hugging Face's model hub. 8 | tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") 9 | model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") 10 | 11 | # Use the GPU if one is available; otherwise fall back to the CPU. 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | model = model.to(device) 14 | 15 | 16 | # Defining a custom stopping criteria class for the model's text generation.
17 | class StopOnTokens(StoppingCriteria): 18 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 19 | stop_ids = [2] # Token ID 2 is the Llama tokenizer's end-of-sequence token (</s>). 20 | for stop_id in stop_ids: 21 | if input_ids[0][-1] == stop_id: # Checking if the last generated token is a stop token. 22 | return True 23 | return False 24 | 25 | 26 | # Function to generate model predictions. 27 | def predict(message, history): 28 | history_transformer_format = history + [[message, ""]] 29 | stop = StopOnTokens() 30 | 31 | # Formatting the input for the model. 32 | messages = "".join(["".join(["\n<|user|>:" + item[0], "\n<|assistant|>:" + item[1]]) 33 | for item in history_transformer_format]) 34 | model_inputs = tokenizer([messages], return_tensors="pt").to(device) 35 | streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True) 36 | generate_kwargs = dict( 37 | model_inputs, 38 | streamer=streamer, 39 | max_new_tokens=1024, 40 | do_sample=True, 41 | top_p=0.95, 42 | top_k=50, 43 | temperature=0.7, 44 | num_beams=1, 45 | stopping_criteria=StoppingCriteriaList([stop]) 46 | ) 47 | t = Thread(target=model.generate, kwargs=generate_kwargs) 48 | t.start() # Starting the generation in a separate thread. 49 | partial_message = "" 50 | for new_token in streamer: 51 | partial_message += new_token 52 | if '</s>' in partial_message: # Defensive check: stop if the EOS text appears (skip_special_tokens=True normally strips it). 53 | break 54 | yield partial_message 55 | 56 | 57 | # Setting up the Gradio chat interface. 58 | gr.ChatInterface(predict, 59 | title="Tinyllama_chatBot", 60 | description="Ask TinyLlama any question", 61 | examples=['How to cook a fish?', 'Who is the president of US now?'] 62 | ).launch() # Launching the web interface.
63 | -------------------------------------------------------------------------------- /chat_gradio/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.0 2 | transformers>=4.35.0 3 | gradio>=4.13.0 4 | -------------------------------------------------------------------------------- /lit_gpt/__init__.py: -------------------------------------------------------------------------------- 1 | from lit_gpt.model import GPT 2 | from lit_gpt.config import Config 3 | from lit_gpt.tokenizer import Tokenizer 4 | from lit_gpt.fused_cross_entropy import FusedCrossEntropyLoss 5 | from lightning_utilities.core.imports import RequirementCache 6 | 7 | if not bool(RequirementCache("torch>=2.1.0dev")): 8 | raise ImportError( 9 | "Lit-GPT requires torch nightly (future torch 2.1). Please follow the installation instructions in the" 10 | " repository README.md" 11 | ) 12 | _LIGHTNING_AVAILABLE = RequirementCache("lightning>=2.1.0.dev0") 13 | if not bool(_LIGHTNING_AVAILABLE): 14 | raise ImportError( 15 | "Lit-GPT requires Lightning nightly (future lightning 2.1). 
Please run:\n" 16 | f" pip uninstall -y lightning; pip install -r requirements.txt\n{str(_LIGHTNING_AVAILABLE)}" 17 | ) 18 | 19 | 20 | __all__ = ["GPT", "Config", "Tokenizer"] 21 | -------------------------------------------------------------------------------- /lit_gpt/adapter.py: -------------------------------------------------------------------------------- 1 | """Implementation of the paper: 2 | 3 | LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention 4 | https://arxiv.org/abs/2303.16199 5 | 6 | Port for Lit-GPT 7 | """ 8 | from dataclasses import dataclass 9 | from typing import Any, Dict, List, Optional, Tuple, Union 10 | 11 | import torch 12 | import torch.nn as nn 13 | from typing_extensions import Self 14 | 15 | from lit_gpt.config import Config as BaseConfig 16 | from lit_gpt.model import GPT as BaseModel 17 | from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention 18 | from lit_gpt.model import KVCache, RoPECache, apply_rope 19 | 20 | 21 | @dataclass 22 | class Config(BaseConfig): 23 | adapter_prompt_length: int = 10 24 | adapter_start_layer: int = 2 25 | 26 | 27 | class GPT(BaseModel): 28 | """The implementation is identical to `lit_gpt.model.GPT` with the exception that 29 | the `Block` saves the layer index and passes it down to the attention layer.""" 30 | 31 | def __init__(self, config: Config) -> None: 32 | nn.Module.__init__(self) 33 | assert config.padded_vocab_size is not None 34 | self.config = config 35 | 36 | self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False) 37 | self.transformer = nn.ModuleDict( 38 | dict( 39 | wte=nn.Embedding(config.padded_vocab_size, config.n_embd), 40 | h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)), 41 | ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), 42 | ) 43 | ) 44 | 45 | self.rope_cache: Optional[RoPECache] = None 46 | self.mask_cache: Optional[torch.Tensor] = None 47 | self.kv_caches: List[KVCache] = [] 48 | 
self.adapter_kv_caches: List[KVCache] = [] 49 | 50 | def reset_cache(self) -> None: 51 | super().reset_cache() 52 | self.adapter_kv_caches.clear() 53 | 54 | def forward( 55 | self, 56 | idx: torch.Tensor, 57 | max_seq_length: Optional[int] = None, 58 | input_pos: Optional[torch.Tensor] = None, 59 | lm_head_chunk_size: int = 0, 60 | ) -> Union[torch.Tensor, List[torch.Tensor]]: 61 | B, T = idx.size() 62 | use_kv_cache = input_pos is not None 63 | 64 | block_size = self.config.block_size 65 | if max_seq_length is None: 66 | max_seq_length = block_size 67 | if use_kv_cache: # not relevant otherwise 68 | assert ( 69 | max_seq_length >= T 70 | ), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}" 71 | assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}" 72 | assert block_size >= T, f"Cannot forward sequence of length {T}, block size is only {block_size}" 73 | 74 | if self.rope_cache is None: 75 | self.rope_cache = self.build_rope_cache(idx) 76 | # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. 
since we only need the mask 77 | # for the kv-cache support (only during inference), we only create it in that situation 78 | # this will be resolved by https://github.com/pytorch/pytorch/issues/96099 79 | if use_kv_cache and self.mask_cache is None: 80 | self.mask_cache = self.build_mask_cache(idx) 81 | 82 | cos, sin = self.rope_cache 83 | if use_kv_cache: 84 | cos = cos.index_select(0, input_pos) 85 | sin = sin.index_select(0, input_pos) 86 | mask = self.mask_cache.index_select(2, input_pos) 87 | mask = mask[:, :, :, :max_seq_length] 88 | else: 89 | cos = cos[:T] 90 | sin = sin[:T] 91 | mask = None 92 | 93 | # forward the model itself 94 | x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 95 | 96 | if not use_kv_cache: 97 | for block in self.transformer.h: 98 | x, *_ = block(x, (cos, sin), max_seq_length) 99 | else: 100 | self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1)) 101 | self.adapter_kv_caches = self.adapter_kv_caches or [None for _ in range(self.config.n_layer)] 102 | for i, block in enumerate(self.transformer.h): 103 | x, self.kv_caches[i], self.adapter_kv_caches[i] = block( 104 | x, (cos, sin), max_seq_length, mask, input_pos, self.kv_caches[i], self.adapter_kv_caches[i] 105 | ) 106 | 107 | x = self.transformer.ln_f(x) 108 | 109 | if lm_head_chunk_size > 0: 110 | # chunk the lm head logits to reduce the peak memory used by autograd 111 | return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)] 112 | return self.lm_head(x) # (b, t, vocab_size) 113 | 114 | @classmethod 115 | def from_name(cls, name: str, **kwargs: Any) -> Self: 116 | return cls(Config.from_name(name, **kwargs)) 117 | 118 | def _init_weights(self, module: nn.Module) -> None: 119 | """Meant to be used with `gpt.apply(gpt._init_weights)`. 
Unused method left for completeness.""" 120 | super()._init_weights(module) 121 | if isinstance(module, CausalSelfAttention): 122 | module.reset_parameters() 123 | 124 | 125 | class Block(nn.Module): 126 | """The implementation is identical to `lit_gpt.model.Block` with the exception that 127 | we replace the attention layer where adaption is implemented.""" 128 | 129 | def __init__(self, config: Config, block_idx: int) -> None: 130 | super().__init__() 131 | self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) 132 | self.attn = CausalSelfAttention(config, block_idx) 133 | if not config.shared_attention_norm: 134 | self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) 135 | self.mlp = config.mlp_class(config) 136 | 137 | self.config = config 138 | 139 | def forward( 140 | self, 141 | x: torch.Tensor, 142 | rope: RoPECache, 143 | max_seq_length: int, 144 | mask: Optional[torch.Tensor] = None, 145 | input_pos: Optional[torch.Tensor] = None, 146 | kv_cache: Optional[KVCache] = None, 147 | adapter_kv_cache: Optional[KVCache] = None, 148 | ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]: 149 | n_1 = self.norm_1(x) 150 | h, new_kv_cache, new_adapter_kv_cache = self.attn( 151 | n_1, rope, max_seq_length, mask, input_pos, kv_cache, adapter_kv_cache 152 | ) 153 | if self.config.parallel_residual: 154 | n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x) 155 | x = x + h + self.mlp(n_2) 156 | else: 157 | if self.config.shared_attention_norm: 158 | raise NotImplementedError( 159 | "No checkpoint amongst the ones we support uses this configuration" 160 | " (non-parallel residual and shared attention norm)." 
161 | ) 162 | x = x + h 163 | x = x + self.mlp(self.norm_2(x)) 164 | return x, new_kv_cache, new_adapter_kv_cache 165 | 166 | 167 | class CausalSelfAttention(BaseCausalSelfAttention): 168 | """A modification of `lit_gpt.model.CausalSelfAttention` that adds the attention 169 | over the adaption prompt.""" 170 | 171 | def __init__(self, config: Config, block_idx: int) -> None: 172 | super().__init__(config) 173 | if block_idx >= config.adapter_start_layer: 174 | # adapter embedding layer 175 | self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd) 176 | # gate for adaption 177 | self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1)) 178 | self.reset_parameters() 179 | self.block_idx = block_idx 180 | 181 | def forward( 182 | self, 183 | x: torch.Tensor, 184 | rope: RoPECache, 185 | max_seq_length: int, 186 | mask: Optional[torch.Tensor] = None, 187 | input_pos: Optional[torch.Tensor] = None, 188 | kv_cache: Optional[KVCache] = None, 189 | adapter_kv_cache: Optional[KVCache] = None, 190 | ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]: 191 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 192 | 193 | qkv = self.attn(x) 194 | 195 | # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) 196 | q_per_kv = self.config.n_head // self.config.n_query_groups 197 | total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value 198 | qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) 199 | qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) 200 | 201 | # split batched computation into three 202 | q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) 203 | 204 | # repeat k and v if necessary 205 | if self.config.n_query_groups != 1: # doing this would require a full kv cache with MQA (inefficient!) 
206 | # for MHA this is a no-op 207 | k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) 208 | v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) 209 | 210 | q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) 211 | k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) 212 | v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) 213 | 214 | n_elem = int(self.config.rotary_percentage * self.config.head_size) 215 | 216 | cos, sin = rope 217 | q_roped = apply_rope(q[..., :n_elem], cos, sin) 218 | k_roped = apply_rope(k[..., :n_elem], cos, sin) 219 | q = torch.cat((q_roped, q[..., n_elem:]), dim=-1) 220 | k = torch.cat((k_roped, k[..., n_elem:]), dim=-1) 221 | 222 | if kv_cache is not None: 223 | cache_k, cache_v = kv_cache 224 | cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype) 225 | # check if reached token limit 226 | if input_pos[-1] >= max_seq_length: 227 | input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device) 228 | # shift 1 position to the left 229 | cache_k = torch.roll(cache_k, -1, dims=2) 230 | cache_v = torch.roll(cache_v, -1, dims=2) 231 | k = cache_k.index_copy_(2, input_pos, k) 232 | v = cache_v.index_copy_(2, input_pos, v) 233 | kv_cache = k, v 234 | 235 | y = self.scaled_dot_product_attention(q, k, v, mask=mask) 236 | 237 | if self.block_idx >= self.config.adapter_start_layer: 238 | aT = self.config.adapter_prompt_length 239 | if adapter_kv_cache is not None: 240 | ak, av = adapter_kv_cache 241 | else: 242 | prefix = self.adapter_wte.weight.reshape(1, aT, C) 243 | aqkv = self.attn(prefix) 244 | aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size) 245 | aqkv = aqkv.permute(0, 2, 3, 1, 4) 246 | _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2) 247 | if self.config.n_query_groups != 1: 248 | # for MHA this is a no-op 249 | ak = ak.repeat_interleave(q_per_kv, dim=2) 250 | av = 
av.repeat_interleave(q_per_kv, dim=2) 251 | ak = ak.view(1, -1, aT, self.config.head_size) # (1, nh_ak, aT, hs) 252 | av = av.view(1, -1, aT, self.config.head_size) # (1, nh_av, aT, hs) 253 | adapter_kv_cache = (ak, av) 254 | 255 | amask = torch.ones(T, aT, dtype=torch.bool, device=x.device) 256 | ay = self.scaled_dot_product_attention(q, ak, av, amask) 257 | y = y + self.gating_factor * ay 258 | 259 | y = y.reshape(B, T, C) # re-assemble all head outputs side by side 260 | 261 | # output projection 262 | y = self.proj(y) 263 | 264 | return y, kv_cache, adapter_kv_cache 265 | 266 | def reset_parameters(self) -> None: 267 | torch.nn.init.zeros_(self.gating_factor) 268 | 269 | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: 270 | """For compatibility with older checkpoints.""" 271 | if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head: 272 | state_dict[key] = state_dict[key].permute(0, 2, 1, 3) 273 | super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) 274 | 275 | 276 | def mark_only_adapter_as_trainable(model: GPT) -> None: 277 | """Sets `requires_grad=False` for all non-adapter weights.""" 278 | for name, param in model.named_parameters(): 279 | param.requires_grad = adapter_filter(name, param) 280 | 281 | 282 | def adapter_filter(key: str, value: Any) -> bool: 283 | return "adapter_wte" in key or "gating_factor" in key 284 | -------------------------------------------------------------------------------- /lit_gpt/adapter_v2.py: -------------------------------------------------------------------------------- 1 | """Implementation of the paper: 2 | 3 | LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model 4 | https://arxiv.org/abs/2304.15010 5 | 6 | Port for Lit-GPT 7 | """ 8 | from dataclasses import dataclass 9 | from typing import Any, Dict, List, Optional, Tuple, Type 10 | 11 | import torch 12 | import torch.nn as nn 13 | from 
typing_extensions import Self 14 | 15 | import lit_gpt 16 | from lit_gpt.adapter import GPT as BaseModel 17 | from lit_gpt.adapter import Block as BaseBlock 18 | from lit_gpt.adapter import Config as BaseConfig 19 | from lit_gpt.adapter import KVCache, RoPECache 20 | from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention 21 | from lit_gpt.model import apply_rope 22 | from lit_gpt.utils import map_old_state_dict_weights 23 | 24 | 25 | @dataclass 26 | class Config(BaseConfig): 27 | @property 28 | def mlp_class(self) -> Type: 29 | return getattr(lit_gpt.adapter_v2, self._mlp_class) 30 | 31 | 32 | def adapter_filter(key: str, value: Any) -> bool: 33 | adapter_substrings = ( 34 | # regular adapter v1 parameters 35 | "adapter_wte", 36 | "gating_factor", 37 | # adapter v2: new bias and scale used in Linear 38 | "adapter_scale", 39 | "adapter_bias", 40 | # adapter v2: Norm parameters are now trainable 41 | "norm_1", 42 | "norm_2", 43 | "ln_f", 44 | ) 45 | return any(s in key for s in adapter_substrings) 46 | 47 | 48 | class AdapterV2Linear(torch.nn.Module): 49 | def __init__(self, in_features: int, out_features: int, **kwargs) -> None: 50 | super().__init__() 51 | self.linear = torch.nn.Linear(in_features, out_features, **kwargs) 52 | self.adapter_bias = torch.nn.Parameter(torch.zeros(out_features), requires_grad=False) 53 | self.adapter_scale = torch.nn.Parameter(torch.ones(out_features), requires_grad=False) 54 | self.reset_parameters() 55 | 56 | def forward(self, x: torch.Tensor) -> torch.Tensor: 57 | return self.adapter_scale * (self.linear(x) + self.adapter_bias) 58 | 59 | def reset_parameters(self) -> None: 60 | nn.init.zeros_(self.adapter_bias) 61 | nn.init.ones_(self.adapter_scale) 62 | 63 | 64 | class GPT(BaseModel): 65 | def __init__(self, config: Config) -> None: 66 | # Skip the parent class __init__ altogether and replace it to avoid useless allocations 67 | nn.Module.__init__(self) 68 | assert config.padded_vocab_size is not None 69 | 
self.config = config 70 | 71 | self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=False) 72 | self.transformer = nn.ModuleDict( 73 | dict( 74 | wte=nn.Embedding(config.padded_vocab_size, config.n_embd), 75 | h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)), 76 | ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), 77 | ) 78 | ) 79 | 80 | self.rope_cache: Optional[RoPECache] = None 81 | self.mask_cache: Optional[torch.Tensor] = None 82 | self.kv_caches: List[KVCache] = [] 83 | self.adapter_kv_caches: List[KVCache] = [] 84 | 85 | @classmethod 86 | def from_name(cls, name: str, **kwargs: Any) -> Self: 87 | return cls(Config.from_name(name, **kwargs)) 88 | 89 | def _init_weights(self, module: nn.Module) -> None: 90 | """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness.""" 91 | super()._init_weights(module) 92 | if isinstance(module, CausalSelfAttention): 93 | module.reset_parameters() 94 | if isinstance(module, AdapterV2Linear): 95 | module.reset_parameters() 96 | 97 | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: 98 | """For compatibility with base checkpoints.""" 99 | mapping = {"lm_head.weight": "lm_head.linear.weight"} 100 | state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) 101 | super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) 102 | 103 | 104 | class Block(BaseBlock): 105 | """The implementation is identical to `lit_gpt.model.Block` with the exception that 106 | we replace the attention layer where adaption is implemented.""" 107 | 108 | def __init__(self, config: Config, block_idx: int) -> None: 109 | # Skip the parent class __init__ altogether and replace it to avoid useless allocations 110 | nn.Module.__init__(self) 111 | self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) 112 | self.attn = CausalSelfAttention(config, block_idx) 113 | if not 
config.shared_attention_norm: 114 | self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) 115 | self.mlp = config.mlp_class(config) 116 | 117 | self.config = config 118 | 119 | 120 | class CausalSelfAttention(BaseCausalSelfAttention): 121 | def __init__(self, config: Config, block_idx: int) -> None: 122 | """Causal self-attention that computes the qkv matrices with a single matrix* and adds LLaMA-Adapter V2 123 | prompt adaptation for parameter-efficient fine-tuning. 124 | 125 | *Instead of creating multiple heads and concatenating the result (in addition to creating separate matrices for 126 | query, key and value for each head), we can do this in a single pass with a single weight matrix. 127 | """ 128 | # Skip the parent class __init__ altogether and replace it to avoid useless allocations 129 | nn.Module.__init__(self) 130 | shape = (config.n_head + 2 * config.n_query_groups) * config.head_size 131 | # key, query, value projections for all heads, but in a batch 132 | self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias) 133 | # output projection 134 | self.proj = AdapterV2Linear(config.n_embd, config.n_embd, bias=config.bias) 135 | if block_idx >= config.adapter_start_layer: 136 | # adapter embedding layer 137 | self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd) 138 | # gate for adaptation 139 | self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1)) 140 | self.reset_parameters() 141 | self.block_idx = block_idx 142 | 143 | self.config = config 144 | 145 | def forward( 146 | self, 147 | x: torch.Tensor, 148 | rope: RoPECache, 149 | max_seq_length: int, 150 | mask: Optional[torch.Tensor] = None, 151 | input_pos: Optional[torch.Tensor] = None, 152 | kv_cache: Optional[KVCache] = None, 153 | adapter_kv_cache: Optional[KVCache] = None, 154 | ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]: 155 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality
(n_embd) 156 | 157 | qkv = self.attn(x) 158 | 159 | # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) 160 | q_per_kv = self.config.n_head // self.config.n_query_groups 161 | total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value 162 | qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) 163 | qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) 164 | 165 | # split batched computation into three 166 | q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) 167 | 168 | # repeat k and v if necessary 169 | if self.config.n_query_groups != 1: # doing this would require a full kv cache with MQA (inefficient!) 170 | # for MHA this is a no-op 171 | k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) 172 | v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) 173 | 174 | q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) 175 | k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) 176 | v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) 177 | 178 | n_elem = int(self.config.rotary_percentage * self.config.head_size) 179 | 180 | cos, sin = rope 181 | q_roped = apply_rope(q[..., :n_elem], cos, sin) 182 | k_roped = apply_rope(k[..., :n_elem], cos, sin) 183 | q = torch.cat((q_roped, q[..., n_elem:]), dim=-1) 184 | k = torch.cat((k_roped, k[..., n_elem:]), dim=-1) 185 | 186 | if kv_cache is not None: 187 | cache_k, cache_v = kv_cache 188 | cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype) 189 | # check if reached token limit 190 | if input_pos[-1] >= max_seq_length: 191 | input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device) 192 | # shift 1 position to the left 193 | cache_k = torch.roll(cache_k, -1, dims=2) 194 | cache_v = torch.roll(cache_v, -1, dims=2) 195 | k = cache_k.index_copy_(2, input_pos, k) 196 | v = cache_v.index_copy_(2, 
input_pos, v) 197 | kv_cache = k, v 198 | 199 | y = self.scaled_dot_product_attention(q, k, v, mask=mask) 200 | 201 | if self.block_idx >= self.config.adapter_start_layer: 202 | aT = self.config.adapter_prompt_length 203 | if adapter_kv_cache is not None: 204 | ak, av = adapter_kv_cache 205 | else: 206 | prefix = self.adapter_wte.weight.reshape(1, aT, C) 207 | aqkv = self.attn(prefix) 208 | aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size) 209 | aqkv = aqkv.permute(0, 2, 3, 1, 4) 210 | _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2) 211 | if self.config.n_query_groups != 1: 212 | # for MHA this is a no-op 213 | ak = ak.repeat_interleave(q_per_kv, dim=2) 214 | av = av.repeat_interleave(q_per_kv, dim=2) 215 | ak = ak.view(1, -1, aT, self.config.head_size) # (1, nh_ak, aT, hs) 216 | av = av.view(1, -1, aT, self.config.head_size) # (1, nh_av, aT, hs) 217 | adapter_kv_cache = (ak, av) 218 | 219 | amask = torch.ones(T, aT, dtype=torch.bool, device=x.device) 220 | ay = self.scaled_dot_product_attention(q, ak, av, amask) 221 | y = y + self.gating_factor * ay 222 | 223 | y = y.reshape(B, T, C) # re-assemble all head outputs side by side 224 | 225 | # output projection 226 | y = self.proj(y) 227 | 228 | return y, kv_cache, adapter_kv_cache 229 | 230 | def reset_parameters(self) -> None: 231 | torch.nn.init.zeros_(self.gating_factor) 232 | 233 | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: 234 | """For compatibility with base checkpoints.""" 235 | mapping = { 236 | "attn.weight": "attn.linear.weight", 237 | "attn.bias": "attn.linear.bias", 238 | "proj.weight": "proj.linear.weight", 239 | "proj.bias": "proj.linear.bias", 240 | } 241 | state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) 242 | # For compatibility with older checkpoints 243 | if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head: 244 | state_dict[key] = 
state_dict[key].permute(0, 2, 1, 3) 245 | super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) 246 | 247 | 248 | class GptNeoxMLP(lit_gpt.model.GptNeoxMLP): 249 | def __init__(self, config: Config) -> None: 250 | nn.Module.__init__(self) 251 | self.fc = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) 252 | self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias) 253 | 254 | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: 255 | """For compatibility with base checkpoints.""" 256 | mapping = { 257 | "fc.weight": "fc.linear.weight", 258 | "fc.bias": "fc.linear.bias", 259 | "proj.weight": "proj.linear.weight", 260 | "proj.bias": "proj.linear.bias", 261 | } 262 | state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) 263 | super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) 264 | 265 | 266 | class LLaMAMLP(lit_gpt.model.LLaMAMLP): 267 | def __init__(self, config: Config) -> None: 268 | nn.Module.__init__(self) 269 | self.fc_1 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) 270 | self.fc_2 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) 271 | self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias) 272 | 273 | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: 274 | """For compatibility with base checkpoints.""" 275 | mapping = { 276 | "fc_1.weight": "fc_1.linear.weight", 277 | "fc_1.bias": "fc_1.linear.bias", 278 | "fc_2.weight": "fc_2.linear.weight", 279 | "fc_2.bias": "fc_2.linear.bias", 280 | "proj.weight": "proj.linear.weight", 281 | "proj.bias": "proj.linear.bias", 282 | } 283 | state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) 284 | super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) 285 | 286 | 287 | def mark_only_adapter_v2_as_trainable(model: GPT) 
-> None: 288 | """Sets requires_grad=False for all non-adapter weights""" 289 | for name, param in model.named_parameters(): 290 | param.requires_grad = adapter_filter(name, param) 291 | -------------------------------------------------------------------------------- /lit_gpt/fused_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 2 | 3 | import torch 4 | import torch.nn as nn 5 | import xentropy_cuda_lib 6 | 7 | # `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for 8 | # `_all_gather_base` and `_reduce_scatter_base`. They require the most recent 9 | # version of PyTorch. The following 2 lines are for backward compatibility with 10 | # older PyTorch. 11 | if "all_gather_into_tensor" not in dir(torch.distributed): 12 | torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base 13 | 14 | 15 | class SoftmaxCrossEntropyLossFn(torch.autograd.Function): 16 | @staticmethod 17 | def forward( 18 | ctx, 19 | logits, 20 | labels, 21 | smoothing=0.0, 22 | ignored_index=-100, 23 | inplace_backward=False, 24 | process_group=None, 25 | ): 26 | """ 27 | logits: (batch, vocab_size) 28 | labels: (batch,) 29 | If process_group is not None, we're doing Tensor Parallel: each process is responsible for 30 | one part of the vocab. The loss needs to be aggregated across processes. 
31 | """ 32 | batch, vocab_size = logits.shape 33 | assert labels.shape == (batch,) 34 | world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) 35 | ctx.total_classes = world_size * vocab_size 36 | 37 | if world_size == 1: 38 | losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing) 39 | losses.masked_fill_(labels == ignored_index, 0) 40 | labels_local = labels 41 | else: 42 | rank = torch.distributed.get_rank(process_group) 43 | vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size 44 | 45 | # Create a mask of valid vocab ids (1 means it needs to be masked). 46 | labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index) 47 | ignored_mask = labels == ignored_index 48 | labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index) 49 | 50 | # For tensor parallel cross entropy with smoothing, we want to pass in the total number 51 | # of classes so that smoothing can be applied correctly. If total_classes=-1, use the 52 | # last dimension of the input tensor. 53 | losses, lse_local = xentropy_cuda_lib.forward( 54 | logits, labels_local, smoothing, world_size * vocab_size 55 | ) 56 | assert lse_local.shape == (batch,) 57 | assert losses.shape == (batch,) 58 | losses.masked_fill_(ignored_mask, 0) 59 | # For labels == ignored_index, the loss is always 0. 60 | # If there's no smoothing, if labels are in the vocab of this partition, losses contains 61 | # lse_local - predicted logit, and 0 otherwise. 62 | # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains 63 | # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes) 64 | # For labels not in the vocab of this partition, losses contains 65 | # 0.1 * (lse_local - sum logit / total_classes). 
66 | 67 | lse_allgather = torch.empty( 68 | world_size, batch, dtype=lse_local.dtype, device=lse_local.device 69 | ) 70 | torch.distributed.all_gather_into_tensor( 71 | lse_allgather, lse_local.contiguous(), group=process_group 72 | ) 73 | handle_losses = torch.distributed.all_reduce( 74 | losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True 75 | ) 76 | lse = torch.logsumexp(lse_allgather, dim=0) 77 | # If there's no smoothing, the total losses are lse_local - predicted_logit, 78 | # we just have to subtract the lse_local and add the lse (global). 79 | # If there's smoothing=0.1, the total losses are 80 | # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes) 81 | # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes). 82 | rank_per_sample = torch.div(labels, vocab_size, rounding_mode="floor") 83 | lse_local = lse_allgather[ 84 | rank_per_sample, torch.arange(batch, device=lse_allgather.device) 85 | ] 86 | 87 | handle_losses.wait() 88 | if smoothing == 0.0: 89 | losses += lse - lse_local 90 | else: 91 | losses += (1 - smoothing) * (lse - lse_local) + smoothing * ( 92 | lse - lse_allgather.sum(dim=0) 93 | ) 94 | losses.masked_fill_(ignored_mask, 0) 95 | 96 | ctx.save_for_backward(logits, lse, labels_local) 97 | ctx.smoothing = smoothing 98 | ctx.ignored_index = ignored_index 99 | ctx.inplace_backward = inplace_backward 100 | return losses 101 | 102 | @staticmethod 103 | def backward(ctx, grad_loss): 104 | logits, lse, labels = ctx.saved_tensors 105 | grad_loss = grad_loss.contiguous() 106 | grad_loss.masked_fill_(labels == ctx.ignored_index, 0) 107 | grad_logits = xentropy_cuda_lib.backward( 108 | grad_loss, logits, lse, labels, ctx.smoothing, ctx.inplace_backward, ctx.total_classes 109 | ) 110 | return grad_logits, None, None, None, None, None, None 111 | 112 | 113 | class FusedCrossEntropyLoss(nn.Module): 114 | def __init__( 115 | self, 116 | ignore_index=-100, 117 | 
reduction="mean", 118 | label_smoothing=0.0, 119 | inplace_backward=True, 120 | process_group=None, 121 | ): 122 | super().__init__() 123 | if reduction not in ["mean", "none"]: 124 | raise NotImplementedError("Only support reduction = 'mean' or 'none'") 125 | self.ignore_index = ignore_index 126 | self.reduction = reduction 127 | self.label_smoothing = label_smoothing 128 | self.inplace_backward = inplace_backward 129 | self.process_group = process_group 130 | 131 | def forward(self, input, target): 132 | assert input.is_cuda and target.is_cuda 133 | # SoftmaxCrossEntropyLoss implicitly casts to float 134 | if len(input.shape) == 3: 135 | input = input.view(-1, input.size(-1)) 136 | target = target.view(-1) 137 | loss = SoftmaxCrossEntropyLossFn.apply( 138 | input, 139 | target, 140 | self.label_smoothing, 141 | self.ignore_index, 142 | self.inplace_backward, 143 | self.process_group, 144 | ) 145 | if self.reduction == "mean": 146 | return loss.sum() / (target != self.ignore_index).sum() 147 | else: 148 | return loss -------------------------------------------------------------------------------- /lit_gpt/fused_rotary_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 2 | 3 | import math 4 | from typing import Optional, Tuple 5 | 6 | import rotary_emb 7 | import torch 8 | from einops import rearrange, repeat 9 | 10 | class ApplyRotaryEmb(torch.autograd.Function): 11 | @staticmethod 12 | def forward(ctx, x, cos, sin, interleaved=False, inplace=False): 13 | """ 14 | x: (batch_size, seqlen, nheads, headdim) 15 | cos, sin: (seqlen, rotary_dim / 2) 16 | interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead 17 | of 1st half and 2nd half (GPT-NeoX style). 18 | rotary_dim must be <= headdim 19 | Apply rotary embedding to the first rotary_dim of x. 
20 | """ 21 | batch, seqlen, nheads, headdim = x.shape 22 | rotary_seqlen, rotary_dim = cos.shape 23 | rotary_dim *= 2 24 | assert rotary_dim <= headdim 25 | assert seqlen <= rotary_seqlen 26 | assert sin.shape == (rotary_seqlen, rotary_dim // 2) 27 | x_ro = x[..., :rotary_dim] 28 | x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2]) 29 | out = torch.empty_like(x) if not inplace else x 30 | out_ro = out[..., :rotary_dim] 31 | if inplace: 32 | o1, o2 = x1, x2 33 | else: 34 | o1, o2 = ( 35 | out_ro.chunk(2, dim=-1) 36 | if not interleaved 37 | else (out_ro[..., ::2], out_ro[..., 1::2]) 38 | ) 39 | rotary_emb.apply_rotary( 40 | x1, 41 | x2, 42 | rearrange(cos[:seqlen], "s d -> s 1 d"), 43 | rearrange(sin[:seqlen], "s d -> s 1 d"), 44 | o1, 45 | o2, 46 | False, 47 | ) 48 | if not inplace and rotary_dim < headdim: 49 | out[..., rotary_dim:].copy_(x[..., rotary_dim:]) 50 | ctx.save_for_backward(cos, sin) 51 | ctx.interleaved = interleaved 52 | ctx.inplace = inplace 53 | return out if not inplace else x 54 | 55 | @staticmethod 56 | def backward(ctx, do): 57 | cos, sin = ctx.saved_tensors 58 | _, seqlen, _, headdim = do.shape 59 | rotary_dim = cos.shape[-1] 60 | rotary_dim *= 2 61 | inplace = ctx.inplace 62 | do_ro = do[..., :rotary_dim] 63 | do1, do2 = ( 64 | do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2]) 65 | ) 66 | dx = torch.empty_like(do) if not inplace else do 67 | if inplace: 68 | dx1, dx2 = do1, do2 69 | else: 70 | dx_ro = dx[..., :rotary_dim] 71 | dx1, dx2 = ( 72 | dx_ro.chunk(2, dim=-1) 73 | if not ctx.interleaved 74 | else (dx_ro[..., ::2], dx_ro[..., 1::2]) 75 | ) 76 | rotary_emb.apply_rotary( 77 | do1, 78 | do2, 79 | rearrange(cos[:seqlen], "s d -> s 1 d"), 80 | rearrange(sin[:seqlen], "s d -> s 1 d"), 81 | dx1, 82 | dx2, 83 | True, 84 | ) 85 | if not inplace and rotary_dim < headdim: 86 | dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) 87 | return dx, None, None, None, None 88 | 89 
| 90 | apply_rotary_emb_func = ApplyRotaryEmb.apply 91 | 92 | -------------------------------------------------------------------------------- /lit_gpt/model.py: -------------------------------------------------------------------------------- 1 | """Full definition of a GPT NeoX Language Model, all of it in this single file. 2 | 3 | Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and 4 | https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model. 5 | """ 6 | import math 7 | from typing import Any, List, Optional, Tuple 8 | 9 | import torch 10 | import torch.nn as nn 11 | from lightning_utilities.core.imports import RequirementCache 12 | from typing_extensions import Self 13 | from flash_attn import flash_attn_func 14 | from lit_gpt.config import Config 15 | from xformers.ops import SwiGLU 16 | from .fused_rotary_embedding import apply_rotary_emb_func 17 | RoPECache = Tuple[torch.Tensor, torch.Tensor] 18 | KVCache = Tuple[torch.Tensor, torch.Tensor] 19 | FlashAttention2Available = RequirementCache("flash-attn>=2.0.0.post1") 20 | 21 | 22 | class GPT(nn.Module): 23 | def __init__(self, config: Config) -> None: 24 | super().__init__() 25 | assert config.padded_vocab_size is not None 26 | self.config = config 27 | 28 | self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False) 29 | self.transformer = nn.ModuleDict( 30 | dict( 31 | wte=nn.Embedding(config.padded_vocab_size, config.n_embd), 32 | h=nn.ModuleList(Block(config) for _ in range(config.n_layer)), 33 | ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), 34 | ) 35 | ) 36 | self.rope_cache: Optional[RoPECache] = None 37 | self.mask_cache: Optional[torch.Tensor] = None 38 | self.kv_caches: List[KVCache] = [] 39 | 40 | def _init_weights(self, module: nn.Module, n_layer) -> None: 41 | """Meant to be used with `gpt.apply(gpt._init_weights)`.""" 42 | # GPT-NeoX https://arxiv.org/pdf/2204.06745.pdf 43 | if isinstance(module, nn.Embedding): 44 | 
torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd)) 45 | # RWKV: set it to 1e-4 46 | # torch.nn.init.uniform_(module.weight, -1e-4, 1e-4) 47 | elif isinstance(module, nn.Linear): 48 | torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd)) 49 | if module.bias is not None: 50 | torch.nn.init.zeros_(module.bias) 51 | # GPT-NeoX 52 | for name, p in module.named_parameters(): 53 | if (name == "proj.weight" and isinstance(module, LLaMAMLP)) or (name == "w3.weight" and isinstance(module, SwiGLU) or (name=="proj.weight" and isinstance(module, CausalSelfAttention))): # with the xformers SwiGLU, the MLP output projection (proj) is named w3 54 | nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / n_layer) 55 | 56 | 57 | def reset_cache(self) -> None: 58 | self.kv_caches.clear() 59 | if self.mask_cache is not None and self.mask_cache.device.type == "xla": 60 | # https://github.com/Lightning-AI/lit-gpt/pull/83#issuecomment-1558150179 61 | self.rope_cache = None 62 | self.mask_cache = None 63 | 64 | def forward( 65 | self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None 66 | ) -> torch.Tensor: 67 | B, T = idx.size() 68 | use_kv_cache = input_pos is not None 69 | 70 | block_size = self.config.block_size 71 | if max_seq_length is None: 72 | max_seq_length = block_size 73 | if use_kv_cache: # not relevant otherwise 74 | assert ( 75 | max_seq_length >= T 76 | ), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}" 77 | assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}" 78 | assert block_size >= T, f"Cannot forward sequence of length {T}, block size is only {block_size}" 79 | 80 | if self.rope_cache is None: 81 | self.rope_cache = self.build_rope_cache(idx) 82 | # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. 
Since we only need the mask
since we only need the mask 83 | # for the kv-cache support (only during inference), we only create it in that situation 84 | # this will be resolved by https://github.com/pytorch/pytorch/issues/96099 85 | if use_kv_cache and self.mask_cache is None: 86 | self.mask_cache = self.build_mask_cache(idx) 87 | 88 | cos, sin = self.rope_cache 89 | if use_kv_cache: 90 | 91 | cos = cos.index_select(0, input_pos) 92 | sin = sin.index_select(0, input_pos) 93 | mask = self.mask_cache.index_select(2, input_pos) 94 | mask = mask[:, :, :, :max_seq_length] 95 | else: 96 | cos = cos[:T] 97 | sin = sin[:T] 98 | mask = None 99 | 100 | # forward the model itself 101 | x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 102 | 103 | if not use_kv_cache: 104 | for block in self.transformer.h: 105 | x, *_ = block(x, (cos, sin), max_seq_length) 106 | else: 107 | self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1) * 2) 108 | for i, block in enumerate(self.transformer.h): 109 | x, self.kv_caches[i] = block(x, (cos, sin), max_seq_length, mask, input_pos, self.kv_caches[i]) 110 | 111 | x = self.transformer.ln_f(x) 112 | 113 | return self.lm_head(x) # (b, t, vocab_size) 114 | 115 | @classmethod 116 | def from_name(cls, name: str, **kwargs: Any) -> Self: 117 | return cls(Config.from_name(name, **kwargs)) 118 | 119 | def build_rope_cache(self, idx: torch.Tensor) -> RoPECache: 120 | return build_rope_cache( 121 | seq_len=self.config.block_size, 122 | n_elem=int(self.config.rotary_percentage * self.config.head_size), 123 | dtype=torch.bfloat16, 124 | device=idx.device, 125 | condense_ratio=self.config.condense_ratio, 126 | ) 127 | 128 | def build_mask_cache(self, idx: torch.Tensor) -> torch.Tensor: 129 | ones = torch.ones((self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool) 130 | return torch.tril(ones).unsqueeze(0).unsqueeze(0) 131 | 132 | def build_kv_caches(self, idx: torch.Tensor, max_seq_length: int, 
rope_cache_length: int) -> List[KVCache]: 133 | B = idx.size(0) 134 | heads = 1 if self.config.n_query_groups == 1 else self.config.n_query_groups 135 | 136 | k_cache_shape = ( 137 | B, 138 | max_seq_length, 139 | heads, 140 | rope_cache_length + self.config.head_size - int(self.config.rotary_percentage * self.config.head_size), 141 | ) 142 | v_cache_shape = (B, max_seq_length, heads, self.config.head_size) 143 | device = idx.device 144 | return [ 145 | (torch.zeros(k_cache_shape, device=device), torch.zeros(v_cache_shape, device=device)) 146 | for _ in range(self.config.n_layer) 147 | ] 148 | 149 | 150 | class Block(nn.Module): 151 | def __init__(self, config: Config) -> None: 152 | super().__init__() 153 | self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) 154 | self.attn = CausalSelfAttention(config) 155 | if not config.shared_attention_norm: 156 | self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) 157 | self.mlp = config.mlp_class(config) 158 | self.config = config 159 | def forward( 160 | self, 161 | x: torch.Tensor, 162 | rope: RoPECache, 163 | max_seq_length: int, 164 | mask: Optional[torch.Tensor] = None, 165 | input_pos: Optional[torch.Tensor] = None, 166 | kv_cache: Optional[KVCache] = None, 167 | ) -> Tuple[torch.Tensor, Optional[KVCache]]: 168 | 169 | n_1 = self.norm_1(x) 170 | h, new_kv_cache = self.attn(n_1, rope, max_seq_length, mask, input_pos, kv_cache) 171 | if self.config.parallel_residual: 172 | n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x) 173 | x = x + h + self.mlp(n_2) 174 | else: 175 | if self.config.shared_attention_norm: 176 | raise NotImplementedError( 177 | "No checkpoint amongst the ones we support uses this configuration" 178 | " (non-parallel residual and shared attention norm)." 
179 | ) 180 | 181 | x = x + h 182 | x = x + self.mlp(self.norm_2(x)) 183 | return x, new_kv_cache 184 | 185 | 186 | class CausalSelfAttention(nn.Module): 187 | def __init__(self, config: Config) -> None: 188 | super().__init__() 189 | shape = (config.n_head + 2 * config.n_query_groups) * config.head_size 190 | # key, query, value projections for all heads, but in a batch 191 | self.attn = nn.Linear(config.n_embd, shape, bias=config.bias) 192 | # output projection 193 | self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 194 | 195 | self.config = config 196 | 197 | def forward( 198 | self, 199 | x: torch.Tensor, 200 | rope: RoPECache, 201 | max_seq_length: int, 202 | mask: Optional[torch.Tensor] = None, 203 | input_pos: Optional[torch.Tensor] = None, 204 | kv_cache: Optional[KVCache] = None, 205 | ) -> Tuple[torch.Tensor, Optional[KVCache]]: 206 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 207 | 208 | qkv = self.attn(x) 209 | 210 | # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) 211 | q_per_kv = self.config.n_head // self.config.n_query_groups 212 | total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value 213 | qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) # (B, T, n_query_groups, total_qkv, hs) 214 | # qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) 215 | 216 | # split batched computation into three 217 | q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2) 218 | 219 | # repeat k and v if necessary 220 | # Peiyuan: we do not need to do this, as FlashAttention-2 already supports GQA 221 | # if self.config.n_query_groups != 1: # doing this would require a full kv cache with MQA (inefficient!)
222 | # # for MHA this is a no-op 223 | # k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) 224 | # v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) 225 | 226 | q = q.reshape(B, T, -1, self.config.head_size) # (B, T, nh_q, hs) 227 | k = k.reshape(B, T, -1, self.config.head_size) 228 | v = v.reshape(B, T, -1, self.config.head_size) 229 | 230 | cos, sin = rope 231 | 232 | # applying RoPE in fp32 significantly stabilizes training 233 | # the fused RoPE kernel expects (batch_size, seqlen, nheads, headdim) 234 | q = apply_rotary_emb_func(q, cos, sin, False, True) 235 | k = apply_rotary_emb_func(k, cos, sin, False, True) 236 | 237 | # n_elem = int(self.config.rotary_percentage * self.config.head_size) 238 | 239 | # q_roped = apply_rope(q[..., :n_elem], cos.repeat(1,2), sin.repeat(1,2)) 240 | # k_roped = apply_rope(k[..., :n_elem], cos.repeat(1,2), sin.repeat(1,2)) 241 | # print( (q_roped - q).sum()) 242 | # q = torch.cat((q_roped, q[..., n_elem:]), dim=-1) 243 | # k = torch.cat((k_roped, k[..., n_elem:]), dim=-1) 244 | 245 | if kv_cache is not None: 246 | cache_k, cache_v = kv_cache 247 | cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype) 248 | # check if reached token limit 249 | if input_pos[-1] >= max_seq_length: 250 | input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device) 251 | # shift 1 position to the left 252 | cache_k = torch.roll(cache_k, -1, dims=1) 253 | cache_v = torch.roll(cache_v, -1, dims=1) 254 | 255 | k = cache_k.index_copy_(1, input_pos, k) 256 | v = cache_v.index_copy_(1, input_pos, v) 257 | kv_cache = k, v 258 | 259 | y = self.scaled_dot_product_attention(q, k, v, mask=mask) 260 | 261 | y = y.reshape(B, T, C) # re-assemble all head outputs side by side 262 | 263 | # output projection 264 | y = self.proj(y) 265 | 266 | return y, kv_cache 267 | 268 | def scaled_dot_product_attention( 269 | self, q: torch.Tensor, k: torch.Tensor, v: 
torch.Tensor, mask: Optional[torch.Tensor]
= None 270 | ): 271 | scale = 1.0 / math.sqrt(self.config.head_size) 272 | 273 | if ( 274 | FlashAttention2Available 275 | and mask is None 276 | and q.device.type == "cuda" 277 | and q.dtype in (torch.float16, torch.bfloat16) 278 | ): 279 | from flash_attn import flash_attn_func 280 | 281 | return flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=scale, causal=True) 282 | q = q.transpose(1, 2) 283 | k = k.transpose(1, 2) 284 | v = v.transpose(1, 2) 285 | if q.size() != k.size(): 286 | k = k.repeat_interleave(q.shape[1]//k.shape[1], dim=1) 287 | v = v.repeat_interleave(q.shape[1]//v.shape[1], dim=1) 288 | y = torch.nn.functional.scaled_dot_product_attention( 289 | q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None 290 | ) 291 | return y.transpose(1, 2) 292 | 293 | 294 | class GptNeoxMLP(nn.Module): 295 | def __init__(self, config: Config) -> None: 296 | super().__init__() 297 | self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) 298 | self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) 299 | 300 | def forward(self, x: torch.Tensor) -> torch.Tensor: 301 | x = self.fc(x) 302 | x = torch.nn.functional.gelu(x) 303 | return self.proj(x) 304 | 305 | 306 | class LLaMAMLP(nn.Module): 307 | def __init__(self, config: Config) -> None: 308 | super().__init__() 309 | # self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) 310 | # self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) 311 | # self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) 312 | self.swiglu = SwiGLU(config.n_embd,config.intermediate_size, bias=False, _pack_weights=False) 313 | def forward(self, x: torch.Tensor) -> torch.Tensor: 314 | # x_fc_1 = self.fc_1(x) 315 | # x_fc_2 = self.fc_2(x) 316 | # x = torch.nn.functional.silu(x_fc_1) * x_fc_2 317 | # return self.proj(x) 318 | return self.swiglu(x) 319 | 320 | 321 | def build_rope_cache( 322 | 
seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000, condense_ratio: int = 1 323 | ) -> RoPECache: 324 | """Enhanced Transformer with Rotary Position Embedding. 325 | 326 | Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ 327 | transformers/rope/__init__.py. MIT License: 328 | https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. 329 | """ 330 | # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ 331 | theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device) / n_elem)) 332 | 333 | # Create position indexes `[0, 1, ..., seq_len - 1]`, divided by `condense_ratio` 334 | seq_idx = torch.arange(seq_len, device=device) / condense_ratio 335 | 336 | # Calculate the product of position index and $\theta_i$ 337 | idx_theta = torch.outer(seq_idx, theta) 338 | 339 | cos, sin = torch.cos(idx_theta), torch.sin(idx_theta) 340 | 341 | # added by peiyuan so that cos/sin share the dtype of q and k, as required by the fused rotary embedding 342 | if dtype == torch.bfloat16: 343 | return cos.bfloat16(), sin.bfloat16() 344 | # this is to mimic the behaviour of complex32, else we will get different results 345 | if dtype in (torch.float16, torch.bfloat16, torch.int8): 346 | return cos.half(), sin.half() 347 | return cos, sin 348 | 349 | 350 | def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: 351 | head_size = x.size(-1) 352 | x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) 353 | x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) 354 | rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) 355 | roped = (x * cos) + (rotated * sin) 356 | return roped.type_as(x) 357 | -------------------------------------------------------------------------------- /lit_gpt/packed_dataset.py: -------------------------------------------------------------------------------- 1 | # Very loosely inspired by indexed_dataset in Fairseq, Megatron 2 | #
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py 3 | 4 | 5 | import os 6 | import random 7 | import struct 8 | 9 | import numpy as np 10 | import torch 11 | from torch.utils.data import IterableDataset, get_worker_info 12 | 13 | dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float32, 7: np.float64, 8: np.uint16} 14 | 15 | 16 | def code(dtype): 17 | for k in dtypes: 18 | if dtypes[k] == dtype: 19 | return k 20 | raise ValueError(dtype) 21 | 22 | 23 | HDR_MAGIC = b"LITPKDS" 24 | HDR_SIZE = 24 # bytes 25 | 26 | 27 | class PackedDataset(IterableDataset): 28 | def __init__( 29 | self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0 30 | ): 31 | self._filenames = filenames 32 | self._n_chunks = n_chunks 33 | self._block_size = block_size 34 | self._seed = seed 35 | self._shuffle = shuffle 36 | self._wrap = wrap 37 | self._num_processes = num_processes 38 | self._process_rank = process_rank 39 | 40 | def __iter__(self): 41 | worker_info = get_worker_info() 42 | num_workers = worker_info.num_workers if worker_info is not None else 1 43 | worker_id = worker_info.id if worker_info is not None else 0 44 | num_shards = num_workers * self._num_processes 45 | shard_id = self._process_rank * num_workers + worker_id 46 | 47 | max_num_files = len(self._filenames) // num_shards * num_shards 48 | filenames = self._filenames[shard_id:max_num_files:num_shards] 49 | 50 | return PackedDatasetIterator( 51 | filenames=filenames, 52 | n_chunks=self._n_chunks, 53 | block_size=self._block_size, 54 | seed=self._seed, 55 | shuffle=self._shuffle, 56 | wrap=self._wrap, 57 | ) 58 | 59 | 60 | class PackedDatasetBuilder(object): 61 | def __init__(self, outdir, prefix, chunk_size, sep_token, dtype="auto", vocab_size=None): 62 | if dtype == "auto": 63 | if vocab_size is None: 64 | raise ValueError("vocab_size cannot be None when dtype='auto'") 65 | if vocab_size is not None and 
vocab_size < 65500: 66 | self._dtype = np.uint16 67 | else: 68 | self._dtype = np.int32 69 | else: 70 | self._dtype = dtype 71 | self._counter = 0 72 | self._chunk_size = chunk_size 73 | self._outdir = outdir 74 | self._prefix = prefix 75 | self._sep_token = sep_token 76 | self._arr = np.zeros(self._chunk_size, dtype=self._dtype) 77 | self._arr.fill(self._sep_token) 78 | self._idx = 0 79 | self._version = 1 80 | self._filenames = [] 81 | 82 | def _write_chunk(self): 83 | filename = f"{self._prefix}_{self._counter:010d}.bin" 84 | filename = os.path.join(self._outdir, filename) 85 | 86 | with open(filename, "wb") as f: 87 | f.write(HDR_MAGIC) 88 | f.write(struct.pack("<Q", self._version)) 89 | f.write(struct.pack("<B", code(self._dtype))) 90 | f.write(struct.pack("<Q", self._chunk_size)) 91 | f.write(self._arr.tobytes(order="C")) 92 | 93 | self._filenames.append(filename) 94 | self._counter += 1 95 | self._arr.fill(self._sep_token) 96 | self._idx = 0 97 | 98 | @property 99 | def dtype(self): 100 | return self._dtype 101 | 102 | @property 103 | def filenames(self): 104 | return self._filenames.copy() 105 | 106 | def add_array(self, arr): 107 | while self._idx + arr.shape[0] > self._chunk_size: 108 | part_len = self._chunk_size - self._idx 109 | self._arr[self._idx : self._idx + part_len] = arr[:part_len] 110 | self._write_chunk() 111 | arr = arr[part_len:] 112 | 113 | arr_len = arr.shape[0] 114 | self._arr[self._idx : self._idx + arr_len] = arr 115 | self._idx += arr_len 116 | 117 | def write_reminder(self): 118 | self._write_chunk() 119 | 120 | 121 | class PackedDatasetIterator: 122 | def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap): 123 | self._seed = seed 124 | self._shuffle = shuffle 125 | self._rng = np.random.default_rng(seed) if shuffle else None 126 | self._block_idxs = None 127 | 128 | self._wrap = wrap 129 | 130 | # TODO: instead of filenames, we could have a single text stream 131 | # (or text file) with the sequence of all files to be 132 | # fetched/loaded.
133 | self._filenames = filenames 134 | self._file_idx = 0 135 | 136 | self._n_chunks = n_chunks 137 | 138 | self._dtype = None 139 | self._block_size = block_size 140 | self._n_blocks = None 141 | 142 | self._mmaps = [] 143 | self._buffers = [] 144 | 145 | self._block_idxs = [] 146 | self._curr_idx = 0 147 | 148 | self._load_n_chunks() 149 | 150 | def _read_header(self, path): 151 | with open(path, "rb") as f: 152 | magic = f.read(len(HDR_MAGIC)) 153 | assert magic == HDR_MAGIC, "File doesn't match expected format." 154 | version = struct.unpack("<Q", f.read(8)) 155 | assert version == (1,) 156 | (dtype_code,) = struct.unpack("<B", f.read(1)) 157 | dtype = dtypes[dtype_code] 158 | (chunk_size,) = struct.unpack("<Q", f.read(8)) 159 | return dtype, chunk_size 160 | 161 | def _close_mmaps(self): 162 | for mmap in self._mmaps: 163 | mmap._mmap.close() 164 | 165 | def _load_n_chunks(self): 166 | self._close_mmaps() 167 | self._mmaps = [] 168 | self._buffers = [] 169 | 170 | if self._n_chunks > len(self._filenames[self._file_idx :]): 171 | # if not self._wrap: 172 | # raise StopIteration 173 | self._file_idx = 0 174 | 175 | for i in range(self._n_chunks): 176 | filename = self._filenames[self._file_idx + i] 177 | if self._dtype is None: 178 | self._dtype, self._chunk_size = self._read_header(filename) 179 | self._n_blocks = self._chunk_size // self._block_size 180 | # TODO: check header matches with previous files 181 | mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE) 182 | self._mmaps.append(mmap) 183 | self._buffers.append(memoryview(mmap)) 184 | 185 | self._file_idx += self._n_chunks 186 | n_all_blocks = self._n_chunks * self._n_blocks 187 | 188 | self._block_idxs = self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks) 189 | 190 | self._curr_idx = 0 191 | 192 | def __del__(self): 193 | self._close_mmaps() 194 | del self._mmaps 195 | del self._buffers 196 | 197 | def __iter__(self): 198 | return self 199 | 200 | def __next__(self): 201 | if self._curr_idx >= len(self._block_idxs): 202 | self._load_n_chunks() 203 | # TODO: trigger fetching next next n_chunks if remote 204 | block_idx = self._block_idxs[self._curr_idx] 205 | chunk_id = block_idx // self._n_blocks 206 | buffer = self._buffers[chunk_id] 207 | elem_id = (block_idx % self._n_blocks) * self._block_size 208 | offset = np.dtype(self._dtype).itemsize * elem_id 209 | arr = np.frombuffer(buffer, dtype=self._dtype,
count=self._block_size, offset=offset) 210 | self._curr_idx += 1 211 | return torch.from_numpy(arr.astype(np.int64)) 212 | 213 | 214 | class CombinedDataset(IterableDataset): 215 | def __init__(self, datasets, seed, weights=None): 216 | self._seed = seed 217 | self._datasets = datasets 218 | self._weights = weights 219 | n_datasets = len(datasets) 220 | if weights is None: 221 | self._weights = [1 / n_datasets] * n_datasets 222 | 223 | def __iter__(self): 224 | return CombinedDatasetIterator(self._datasets, self._seed, self._weights) 225 | 226 | 227 | class CombinedDatasetIterator: 228 | def __init__(self, datasets, seed, weights): 229 | self._datasets = [iter(el) for el in datasets] 230 | self._weights = weights 231 | self._rng = random.Random(seed) 232 | 233 | def __next__(self): 234 | (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1) 235 | return next(dataset) 236 | -------------------------------------------------------------------------------- /lit_gpt/tokenizer.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Optional 4 | 5 | import torch 6 | 7 | 8 | class Tokenizer: 9 | def __init__(self, checkpoint_dir: Path) -> None: 10 | # some checkpoints have both files, `.model` takes precedence 11 | if (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file(): 12 | from sentencepiece import SentencePieceProcessor 13 | 14 | self.processor = SentencePieceProcessor(model_file=str(vocabulary_path)) 15 | self.backend = "sentencepiece" 16 | self.bos_id = self.processor.bos_id() 17 | self.eos_id = self.processor.eos_id() 18 | elif (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file(): 19 | from tokenizers import Tokenizer as HFTokenizer 20 | 21 | self.processor = HFTokenizer.from_file(str(vocabulary_path)) 22 | self.backend = "huggingface" 23 | with open(checkpoint_dir / "tokenizer_config.json") as fp: 24 | config = json.load(fp) 
25 | bos_token = config.get("bos_token") 26 | self.bos_id = self.token_to_id(bos_token) if bos_token is not None else None 27 | self.eos_id = self.token_to_id(config["eos_token"]) 28 | else: 29 | raise NotImplementedError 30 | 31 | @property 32 | def vocab_size(self) -> int: 33 | if self.backend == "huggingface": 34 | return self.processor.get_vocab_size(with_added_tokens=False) 35 | if self.backend == "sentencepiece": 36 | return self.processor.vocab_size() 37 | raise RuntimeError 38 | 39 | def token_to_id(self, token: str) -> int: 40 | if self.backend == "huggingface": 41 | id_ = self.processor.token_to_id(token) 42 | elif self.backend == "sentencepiece": 43 | id_ = self.processor.piece_to_id(token) 44 | else: 45 | raise RuntimeError 46 | if id_ is None: 47 | raise ValueError(f"token {token!r} not found in the collection.") 48 | return id_ 49 | 50 | def encode( 51 | self, 52 | string: str, 53 | device: Optional[torch.device] = None, 54 | bos: bool = False, 55 | eos: bool = True, 56 | max_length: int = -1, 57 | ) -> torch.Tensor: 58 | if self.backend == "huggingface": 59 | tokens = self.processor.encode(string).ids 60 | elif self.backend == "sentencepiece": 61 | tokens = self.processor.encode(string) 62 | else: 63 | raise RuntimeError 64 | if bos: 65 | bos_id = self.bos_id 66 | if bos_id is None: 67 | raise NotImplementedError("This tokenizer does not define a BOS token") 68 | tokens = [bos_id] + tokens 69 | if eos: 70 | tokens = tokens + [self.eos_id] 71 | if max_length > 0: 72 | tokens = tokens[:max_length] 73 | return torch.tensor(tokens, dtype=torch.int, device=device) 74 | 75 | def decode(self, tensor: torch.Tensor) -> str: 76 | tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist() 77 | return self.processor.decode(tokens) 78 | -------------------------------------------------------------------------------- /lit_gpt/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for training and
inference.""" 2 | 3 | import pickle 4 | import sys 5 | import warnings 6 | from contextlib import contextmanager 7 | from functools import partial 8 | from io import BytesIO 9 | from pathlib import Path 10 | from types import MethodType 11 | from typing import Any, Dict, List, Mapping, Optional, Type, TypeVar, Union 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.utils._device 16 | from lightning.fabric.loggers import CSVLogger 17 | from torch.serialization import normalize_storage_type 18 | 19 | 20 | def find_multiple(n: int, k: int) -> int: 21 | assert k > 0 22 | if n % k == 0: 23 | return n 24 | return n + k - (n % k) 25 | 26 | 27 | def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int: 28 | return sum(p.numel() for p in module.parameters() if requires_grad is None or p.requires_grad == requires_grad) 29 | 30 | 31 | @contextmanager 32 | def quantization(mode: Optional[str] = None): 33 | if mode is None: 34 | yield 35 | return 36 | 37 | if mode == "bnb.int8": 38 | from quantize.bnb import InferenceLinear8bitLt 39 | 40 | quantized_linear_cls = InferenceLinear8bitLt 41 | elif mode == "bnb.fp4": 42 | from quantize.bnb import Linear4bit 43 | 44 | # Use a class instead of `functools.partial` to respect `isinstance` checks and attribute accesses 45 | class QuantizedLinear(Linear4bit): 46 | def __init__(self, *args, **kwargs): 47 | super().__init__(*args, quant_type="fp4", compress_statistics=False, **kwargs) 48 | 49 | quantized_linear_cls = QuantizedLinear 50 | elif mode == "bnb.fp4-dq": 51 | from quantize.bnb import Linear4bit 52 | 53 | class QuantizedLinear(Linear4bit): 54 | def __init__(self, *args, **kwargs): 55 | super().__init__(*args, quant_type="fp4", compress_statistics=True, **kwargs) 56 | 57 | quantized_linear_cls = QuantizedLinear 58 | elif mode == "bnb.nf4": 59 | from quantize.bnb import Linear4bit 60 | 61 | class QuantizedLinear(Linear4bit): 62 | def __init__(self, *args, **kwargs): 63 |
super().__init__(*args, quant_type="nf4", compress_statistics=False, **kwargs) 64 | 65 | quantized_linear_cls = QuantizedLinear 66 | elif mode == "bnb.nf4-dq": 67 | from quantize.bnb import Linear4bit 68 | 69 | class QuantizedLinear(Linear4bit): 70 | def __init__(self, *args, **kwargs): 71 | super().__init__(*args, quant_type="nf4", compress_statistics=True, **kwargs) 72 | 73 | quantized_linear_cls = QuantizedLinear 74 | elif mode == "gptq.int4": 75 | from quantize.gptq import ColBlockQuantizedLinear 76 | 77 | class QuantizedLinear(ColBlockQuantizedLinear): 78 | def __init__(self, *args, **kwargs): 79 | super().__init__(*args, bits=4, tile_cols=-1, **kwargs) 80 | 81 | quantized_linear_cls = QuantizedLinear 82 | else: 83 | raise ValueError(f"Unknown quantization mode: {mode}") 84 | 85 | torch_linear_cls = torch.nn.Linear 86 | torch.nn.Linear = quantized_linear_cls 87 | yield 88 | torch.nn.Linear = torch_linear_cls 89 | 90 | 91 | # this is taken from torchhacks https://github.com/lernapparat/torchhacks 92 | 93 | 94 | class NotYetLoadedTensor: 95 | def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args): 96 | self.metatensor = metatensor 97 | self.archiveinfo = archiveinfo 98 | self.storageinfo = storageinfo 99 | self.rebuild_args = rebuild_args 100 | 101 | @classmethod 102 | def rebuild_from_type_v2(cls, func, new_type, args, state, *, archiveinfo=None): 103 | ret = func(*args) 104 | if isinstance(ret, NotYetLoadedTensor): 105 | old_lt = ret._load_tensor 106 | 107 | def _load_tensor(): 108 | t = old_lt() 109 | return torch._tensor._rebuild_from_type_v2(lambda: t, new_type, (), state) 110 | 111 | ret._load_tensor = _load_tensor 112 | return ret 113 | return torch._tensor._rebuild_from_type_v2(func, new_type, args, state) 114 | 115 | @classmethod 116 | def rebuild_parameter(cls, data, requires_grad, backward_hooks, *, archiveinfo=None): 117 | if isinstance(data, NotYetLoadedTensor): 118 | old_lt = data._load_tensor 119 | 120 | def _load_tensor(): 121 | t 
= old_lt() 122 | return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks) 123 | 124 | data._load_tensor = _load_tensor 125 | return data 126 | return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks) 127 | 128 | @classmethod 129 | def rebuild_tensor_v2( 130 | cls, storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None, *, archiveinfo=None 131 | ): 132 | rebuild_args = (storage_offset, size, stride, requires_grad, backward_hooks, metadata) 133 | metatensor = torch._utils._rebuild_tensor_v2( 134 | storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata 135 | ) 136 | storageinfo = storage.archiveinfo 137 | return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args) 138 | 139 | def _load_tensor(self): 140 | name, storage_cls, fn, device, size = self.storageinfo 141 | dtype = self.metatensor.dtype 142 | 143 | uts = ( 144 | self.archiveinfo.zipfile_context.zf.get_storage_from_record( 145 | f"data/{fn}", size * torch._utils._element_size(dtype), torch.UntypedStorage 146 | ) 147 | ._typed_storage() 148 | ._untyped_storage 149 | ) 150 | with warnings.catch_warnings(): 151 | warnings.simplefilter("ignore") 152 | storage = torch.storage.TypedStorage(wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True) 153 | return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args) 154 | 155 | @classmethod 156 | def __torch_function__(cls, func, types, args=(), kwargs=None): 157 | if kwargs is None: 158 | kwargs = {} 159 | loaded_args = [(a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args] 160 | return func(*loaded_args, **kwargs) 161 | # gc.collect would be costly here, maybe do it optionally 162 | 163 | def __getattr__(self, name): 164 | # properties 165 | ## TODO: device, is_...?? 166 | ## TODO: mH, mT, H, T, data, imag, real 167 | ## name ??? 
168 | if name in { 169 | "dtype", 170 | "grad", 171 | "grad_fn", 172 | "layout", 173 | "names", 174 | "ndim", 175 | "output_nr", 176 | "requires_grad", 177 | "retains_grad", 178 | "shape", 179 | "volatile", 180 | }: 181 | return getattr(self.metatensor, name) 182 | if name in {"size"}: 183 | return getattr(self.metatensor, name) 184 | # materializing with contiguous is needed for quantization 185 | if name in {"contiguous"}: 186 | return getattr(self._load_tensor(), name) 187 | 188 | raise AttributeError(f"{type(self)} does not have {name}") 189 | 190 | def __repr__(self): 191 | return f"NotYetLoadedTensor({repr(self.metatensor)})" 192 | 193 | 194 | class LazyLoadingUnpickler(pickle.Unpickler): 195 | def __init__(self, file, zipfile_context): 196 | super().__init__(file) 197 | self.zipfile_context = zipfile_context 198 | 199 | def find_class(self, module, name): 200 | res = super().find_class(module, name) 201 | if module == "torch._utils" and name == "_rebuild_tensor_v2": 202 | return partial(NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self) 203 | if module == "torch._tensor" and name == "_rebuild_from_type_v2": 204 | return partial(NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self) 205 | if module == "torch._utils" and name == "_rebuild_parameter": 206 | return partial(NotYetLoadedTensor.rebuild_parameter, archiveinfo=self) 207 | return res 208 | 209 | def persistent_load(self, pid): 210 | name, cls, fn, device, size = pid 211 | with warnings.catch_warnings(): 212 | warnings.simplefilter("ignore") 213 | s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta") 214 | s.archiveinfo = pid 215 | return s 216 | 217 | 218 | class lazy_load: 219 | def __init__(self, fn): 220 | self.zf = torch._C.PyTorchFileReader(str(fn)) 221 | with BytesIO(self.zf.get_record("data.pkl")) as pkl: 222 | mup = LazyLoadingUnpickler(pkl, self) 223 | self.sd = mup.load() 224 | 225 | def __enter__(self): 226 | return self.sd 227 | 228 | def __exit__(self, exc_type, 
exc_val, exc_tb): 229 | del self.zf # I don't think there is a way to force closing... 230 | self.zf = None 231 | 232 | 233 | def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None: 234 | files = { 235 | "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(), 236 | "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(), 237 | "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or ( 238 | checkpoint_dir / "tokenizer.model" 239 | ).is_file(), 240 | "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(), 241 | } 242 | if checkpoint_dir.is_dir(): 243 | if all(files.values()): 244 | # we're good 245 | return 246 | problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}" 247 | else: 248 | problem = " is not a checkpoint directory" 249 | 250 | # list locally available checkpoints 251 | available = list(Path("checkpoints").glob("*/*")) 252 | if available: 253 | options = "\n --checkpoint_dir ".join([""] + [repr(str(p.resolve())) for p in available]) 254 | extra = f"\nYou have downloaded locally:{options}\n" 255 | else: 256 | extra = "" 257 | 258 | error_message = ( 259 | f"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}." 
"\nFind download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials\n" 261 | f"{extra}\nSee all download options by running:\n python scripts/download.py" 262 | ) 263 | print(error_message, file=sys.stderr) 264 | raise SystemExit(1) 265 | 266 | 267 | class SavingProxyForStorage: 268 | def __init__(self, obj, saver, protocol_version=5): 269 | self.protocol_version = protocol_version 270 | self.saver = saver 271 | if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)): 272 | raise TypeError(f"expected storage, not {type(obj)}") 273 | 274 | # this logic is taken from PyTorch 2.0+ torch/serialization.py 275 | if isinstance(obj, torch.storage.TypedStorage): 276 | # PT upstream wants to deprecate this eventually... 277 | storage = obj._untyped_storage 278 | storage_type_str = obj._pickle_storage_type() 279 | storage_type = getattr(torch, storage_type_str) 280 | storage_numel = obj._size() 281 | else: 282 | storage = obj 283 | storage_type = normalize_storage_type(type(obj)) 284 | storage_numel = storage.nbytes() 285 | 286 | storage_key = saver._write_storage_and_return_key(storage) 287 | location = torch.serialization.location_tag(storage) 288 | 289 | self.storage_info = ("storage", storage_type, storage_key, location, storage_numel) 290 | 291 | def __reduce_ex__(self, protocol_version): 292 | assert False, "this should be handled out of band" 293 | 294 | 295 | class SavingProxyForTensor: 296 | def __init__(self, tensor, saver, protocol_version=5): 297 | self.protocol_version = protocol_version 298 | self.reduce_ret_fn, (storage, *other_reduce_args) = tensor.__reduce_ex__(protocol_version) 299 | assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates" 300 | storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version) 301 | self.reduce_args = (storage_proxy, *other_reduce_args) 302 | 303 | def __reduce_ex__(self, protocol_version): 304 | if protocol_version !=
self.protocol_version: 305 | raise RuntimeError(f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}") 306 | return self.reduce_ret_fn, self.reduce_args 307 | 308 | 309 | class IncrementalPyTorchPickler(pickle.Pickler): 310 | def __init__(self, saver, *args, **kwargs): 311 | super().__init__(*args, **kwargs) 312 | self.storage_dtypes = {} 313 | self.saver = saver 314 | self.id_map = {} 315 | 316 | # this logic is taken from PyTorch 2.0+ torch/serialization.py 317 | def persistent_id(self, obj): 318 | # FIXME: the docs say that persistent_id should only return a string 319 | # but torch store returns tuples. This works only in the binary protocol 320 | # see 321 | # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects 322 | # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 323 | if isinstance(obj, SavingProxyForStorage): 324 | return obj.storage_info 325 | 326 | if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): 327 | if isinstance(obj, torch.storage.TypedStorage): 328 | # TODO: Once we decide to break serialization FC, this case 329 | # can be deleted 330 | storage = obj._untyped_storage 331 | storage_dtype = obj.dtype 332 | storage_type_str = obj._pickle_storage_type() 333 | storage_type = getattr(torch, storage_type_str) 334 | storage_numel = obj._size() 335 | 336 | else: 337 | storage = obj 338 | storage_dtype = torch.uint8 339 | storage_type = normalize_storage_type(type(obj)) 340 | storage_numel = storage.nbytes() 341 | 342 | # If storage is allocated, ensure that any other saved storages 343 | # pointing to the same data all have the same dtype. 
If storage is 344 | # not allocated, don't perform this check 345 | if storage.data_ptr() != 0: 346 | if storage.data_ptr() in self.storage_dtypes: 347 | if storage_dtype != self.storage_dtypes[storage.data_ptr()]: 348 | raise RuntimeError( 349 | "Cannot save multiple tensors or storages that view the same data as different types" 350 | ) 351 | else: 352 | self.storage_dtypes[storage.data_ptr()] = storage_dtype 353 | 354 | storage_key = self.id_map.get(storage._cdata) 355 | if storage_key is None: 356 | storage_key = self.saver._write_storage_and_return_key(storage) 357 | self.id_map[storage._cdata] = storage_key 358 | location = torch.serialization.location_tag(storage) 359 | 360 | return ("storage", storage_type, storage_key, location, storage_numel) 361 | 362 | return None 363 | 364 | 365 | class incremental_save: 366 | def __init__(self, name): 367 | self.name = name 368 | self.zipfile = torch._C.PyTorchFileWriter(str(name)) 369 | self.has_saved = False 370 | self.next_key = 0 371 | 372 | def __enter__(self): 373 | return self 374 | 375 | def store_early(self, tensor): 376 | if isinstance(tensor, torch.Tensor): 377 | return SavingProxyForTensor(tensor, self) 378 | raise TypeError(f"can only store tensors early, not {type(tensor)}") 379 | 380 | def save(self, obj): 381 | if self.has_saved: 382 | raise RuntimeError("have already saved") 383 | # Write the pickle data for `obj` 384 | data_buf = BytesIO() 385 | pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5) 386 | pickler.dump(obj) 387 | data_value = data_buf.getvalue() 388 | self.zipfile.write_record("data.pkl", data_value, len(data_value)) 389 | self.has_saved = True 390 | 391 | def _write_storage_and_return_key(self, storage): 392 | if self.has_saved: 393 | raise RuntimeError("have already saved") 394 | key = self.next_key 395 | self.next_key += 1 396 | name = f"data/{key}" 397 | if storage.device.type != "cpu": 398 | storage = storage.cpu() 399 | num_bytes = storage.nbytes() 400 | 
self.zipfile.write_record(name, storage.data_ptr(), num_bytes) 401 | return key 402 | 403 | def __exit__(self, type, value, traceback): 404 | self.zipfile.write_end_of_file() 405 | 406 | 407 | T = TypeVar("T") 408 | 409 | 410 | def step_csv_logger(*args: Any, cls: Type[T] = CSVLogger, **kwargs: Any) -> T: 411 | logger = cls(*args, **kwargs) 412 | 413 | def merge_by(dicts, key): 414 | from collections import defaultdict 415 | 416 | out = defaultdict(dict) 417 | for d in dicts: 418 | if key in d: 419 | out[d[key]].update(d) 420 | return [v for _, v in sorted(out.items())] 421 | 422 | def save(self) -> None: 423 | """Overridden to merge CSV by the step number.""" 424 | import csv 425 | 426 | if not self.metrics: 427 | return 428 | metrics = merge_by(self.metrics, "step") 429 | keys = sorted({k for m in metrics for k in m}) 430 | with self._fs.open(self.metrics_file_path, "w", newline="") as f: 431 | writer = csv.DictWriter(f, fieldnames=keys) 432 | writer.writeheader() 433 | writer.writerows(metrics) 434 | 435 | logger.experiment.save = MethodType(save, logger.experiment) 436 | 437 | return logger 438 | 439 | 440 | def chunked_cross_entropy( 441 | logits: Union[torch.Tensor, List[torch.Tensor]], targets: torch.Tensor, chunk_size: int = 128 442 | ) -> torch.Tensor: 443 | # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate 444 | # the memory usage in fine-tuning settings with low number of parameters. 
445 | # as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing 446 | # the memory spike's magnitude 447 | 448 | # lm_head was chunked (we are fine-tuning) 449 | if isinstance(logits, list): 450 | # don't want to chunk cross entropy 451 | if chunk_size == 0: 452 | logits = torch.cat(logits, dim=1) 453 | logits = logits.reshape(-1, logits.size(-1)) 454 | targets = targets.reshape(-1) 455 | return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1) 456 | 457 | # chunk cross entropy 458 | logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits] 459 | target_chunks = [target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)] 460 | loss_chunks = [ 461 | torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none") 462 | for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) 463 | ] 464 | return torch.cat(loss_chunks).mean() 465 | 466 | # no chunking at all 467 | logits = logits.reshape(-1, logits.size(-1)) 468 | targets = targets.reshape(-1) 469 | if chunk_size == 0: 470 | return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1) 471 | 472 | # lm_head wasn't chunked, chunk cross entropy 473 | logit_chunks = logits.split(chunk_size) 474 | target_chunks = targets.split(chunk_size) 475 | loss_chunks = [ 476 | torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none") 477 | for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) 478 | ] 479 | return torch.cat(loss_chunks).mean() 480 | 481 | 482 | def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict: 483 | for checkpoint_name, attribute_name in mapping.items(): 484 | full_checkpoint_name = prefix + checkpoint_name 485 | if full_checkpoint_name in state_dict: 486 | full_attribute_name = prefix + attribute_name 487 | state_dict[full_attribute_name] = 
state_dict.pop(full_checkpoint_name) 488 | return state_dict 489 | 490 | 491 | def get_default_supported_precision(training: bool, tpu: bool = False) -> str: 492 | """Return default precision that is supported by the hardware. 493 | 494 | Args: 495 | training: if True, return the `-mixed` variant of the precision; otherwise, the `-true` variant 496 | tpu: whether a TPU device is used 497 | 498 | Returns: 499 | default precision that is suitable for the task and is supported by the hardware 500 | """ 501 | if tpu: 502 | return "32-true" 503 | if not torch.cuda.is_available() or torch.cuda.is_bf16_supported(): 504 | return "bf16-mixed" if training else "bf16-true" 505 | return "16-mixed" if training else "16-true" 506 | -------------------------------------------------------------------------------- /pretrain/tinyllama.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import math 3 | import sys 4 | import time 5 | from pathlib import Path 6 | from typing import Optional, Tuple, Union 7 | import math 8 | import lightning as L 9 | import torch 10 | from lightning.fabric.strategies import FSDPStrategy, XLAStrategy 11 | from torch.utils.data import DataLoader 12 | from functools import partial 13 | # support running without installing as a package 14 | wd = Path(__file__).parent.parent.resolve() 15 | sys.path.append(str(wd)) 16 | # from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually 17 | from lit_gpt.model import GPT, Block, Config, CausalSelfAttention 18 | from lit_gpt.packed_dataset import CombinedDataset, PackedDataset 19 | from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor 20 | from lit_gpt.speed_monitor import estimate_flops, measure_flops 21 | from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load 22 | from pytorch_lightning.loggers import WandbLogger 23 | from lit_gpt import FusedCrossEntropyLoss 24 | import random
25 | 26 | model_name = "tiny_LLaMA_1b" 27 | name = "tinyllama_1b" 28 | out_dir = Path("out") / name 29 | 30 | # Hyperparameters 31 | num_of_devices = 8 32 | global_batch_size = 512 33 | learning_rate = 4e-4 34 | micro_batch_size = 8 35 | max_step = 715256 * 2 36 | warmup_steps = 2000 37 | log_step_interval = 10 38 | eval_iters = 100 39 | save_step_interval = 5000 40 | eval_step_interval = 5000 41 | 42 | 43 | weight_decay = 1e-1 44 | beta1 = 0.9 45 | beta2 = 0.95 46 | grad_clip = 1.0 47 | decay_lr = True 48 | min_lr = 4e-5 49 | 50 | batch_size = global_batch_size // num_of_devices 51 | gradient_accumulation_steps = batch_size // micro_batch_size 52 | assert gradient_accumulation_steps > 0 53 | warmup_iters = warmup_steps * gradient_accumulation_steps 54 | 55 | 56 | 57 | 58 | max_iters = max_step * gradient_accumulation_steps 59 | lr_decay_iters = max_iters 60 | log_iter_interval = log_step_interval * gradient_accumulation_steps 61 | 62 | 63 | # Treat all datasets equally, weighting each by its size. To use a different weight for a dataset, set it in the list below. 64 | train_data_config = [ 65 | ("train_slim", 0.693584), 66 | ("train_star", 0.306416), 67 | ] 68 | 69 | val_data_config = [ 70 | ("validation", 1.0), 71 | ] 72 | 73 | hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} 74 | logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) 75 | wandb_logger = WandbLogger() 76 | 77 | 78 | def setup( 79 | devices: int = 8, 80 | train_data_dir: Path = Path("data/redpajama_sample"), 81 | val_data_dir: Optional[Path] = None, 82 | precision: Optional[str] = None, 83 | tpu: bool = False, 84 | resume: Union[bool, Path] = False, 85 | ) -> None: 86 | precision = precision or get_default_supported_precision(training=True, tpu=tpu) 87 | 88 | if devices > 1: 89 | if tpu: 90 | # For multi-host TPU training, the device count for Fabric is limited to the count on a single host.
91 | devices = "auto" 92 | strategy = XLAStrategy(sync_module_states=False) 93 | else: 94 | strategy = FSDPStrategy( 95 | auto_wrap_policy={Block}, 96 | activation_checkpointing_policy=None, 97 | state_dict_type="full", 98 | limit_all_gathers=True, 99 | cpu_offload=False, 100 | ) 101 | else: 102 | strategy = "auto" 103 | 104 | fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) 105 | fabric.print(hparams) 106 | #fabric.launch(main, train_data_dir, val_data_dir, resume) 107 | main(fabric, train_data_dir, val_data_dir, resume) 108 | 109 | 110 | def main(fabric, train_data_dir, val_data_dir, resume): 111 | monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) 112 | 113 | if fabric.global_rank == 0: 114 | out_dir.mkdir(parents=True, exist_ok=True) 115 | 116 | config = Config.from_name(model_name) 117 | 118 | train_dataloader, val_dataloader = create_dataloaders( 119 | batch_size=micro_batch_size, 120 | block_size=config.block_size, 121 | fabric=fabric, 122 | train_data_dir=train_data_dir, 123 | val_data_dir=val_data_dir, 124 | seed=3407, 125 | ) 126 | if val_dataloader is None: 127 | train_dataloader = fabric.setup_dataloaders(train_dataloader) 128 | else: 129 | train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) 130 | 131 | fabric.seed_everything(3407) # same seed for every process to init model (FSDP) 132 | 133 | fabric.print(f"Loading model with {config.__dict__}") 134 | t0 = time.perf_counter() 135 | with fabric.init_module(empty_init=False): 136 | model = GPT(config) 137 | model.apply(partial(model._init_weights ,n_layer=config.n_layer)) 138 | 139 | 140 | fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") 141 | fabric.print(f"Total parameters {num_parameters(model):,}") 142 | 143 | model = fabric.setup(model) 144 | optimizer = torch.optim.AdamW( 145 | model.parameters(), lr=learning_rate, 
weight_decay=weight_decay, betas=(beta1, beta2), foreach=False 146 | ) 147 | # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) 148 | optimizer = fabric.setup_optimizers(optimizer) 149 | 150 | state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} 151 | 152 | if resume is True: 153 | resume = sorted(out_dir.glob("*.pth"))[-1] 154 | if resume: 155 | fabric.print(f"Resuming training from {resume}") 156 | fabric.load(resume, state) 157 | 158 | train_time = time.perf_counter() 159 | train(fabric, state, train_dataloader, val_dataloader, monitor, resume) 160 | fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") 161 | if fabric.device.type == "cuda": 162 | fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") 163 | 164 | 165 | def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): 166 | model = state["model"] 167 | optimizer = state["optimizer"] 168 | 169 | if val_dataloader is not None: 170 | validate(fabric, model, val_dataloader) # sanity check 171 | 172 | with torch.device("meta"): 173 | meta_model = GPT(model.config) 174 | # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. 175 | # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, 176 | # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead 177 | estimated_flops = estimate_flops(meta_model) * micro_batch_size 178 | fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") 179 | x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) 180 | # measure_flops runs the model on the meta device,
which triggers a fused RMSNorm error, so it is disabled 181 | #measured_flops = measure_flops(meta_model, x) 182 | #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") 183 | del meta_model, x 184 | 185 | total_lengths = 0 186 | total_t0 = time.perf_counter() 187 | 188 | if fabric.device.type == "xla": 189 | import torch_xla.core.xla_model as xm 190 | 191 | xm.mark_step() 192 | 193 | 194 | initial_iter = state["iter_num"] 195 | curr_iter = 0 196 | 197 | loss_func = FusedCrossEntropyLoss() 198 | for train_data in train_dataloader: 199 | # Restore the loader state on resume by skipping the batches already consumed. This is not elegant, but it works; it should be rewritten in the future. 200 | if resume: 201 | if curr_iter < initial_iter: 202 | curr_iter += 1 203 | continue 204 | else: 205 | resume = False 206 | curr_iter = -1 207 | fabric.barrier() 208 | fabric.print("resume finished, took {} seconds".format(time.perf_counter() - total_t0)) 209 | if state["iter_num"] >= max_iters: 210 | break 211 | 212 | # determine and set the learning rate for this iteration 213 | lr = get_lr(state["iter_num"]) if decay_lr else learning_rate 214 | for param_group in optimizer.param_groups: 215 | param_group["lr"] = lr 216 | 217 | iter_t0 = time.perf_counter() 218 | 219 | input_ids = train_data[:, 0 : model.config.block_size].contiguous() 220 | targets = train_data[:, 1 : model.config.block_size + 1].contiguous() 221 | is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 222 | with fabric.no_backward_sync(model, enabled=is_accumulating): 223 | logits = model(input_ids) 224 | loss = loss_func(logits, targets) 225 | # loss = chunked_cross_entropy(logits, targets, chunk_size=0) 226 | fabric.backward(loss / gradient_accumulation_steps) 227 | 228 | if not is_accumulating: 229 | fabric.clip_gradients(model, optimizer, max_norm=grad_clip) 230 | optimizer.step() 231 | optimizer.zero_grad() 232 | state["step_count"] += 1 233 | elif fabric.device.type == "xla": 
234 | xm.mark_step() 235 | state["iter_num"] += 1 236 | #
input_id: B L 237 | total_lengths += input_ids.size(1) 238 | t1 = time.perf_counter() 239 | fabric.print( 240 | f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" 241 | f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" 242 | f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " 243 | # print days as well 244 | f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. " 245 | ) 246 | 247 | monitor.on_train_batch_end( 248 | state["iter_num"] * micro_batch_size, 249 | t1 - total_t0, 250 | # this assumes that device FLOPs are the same and that all devices have the same batch size 251 | fabric.world_size, 252 | state["step_count"], 253 | flops_per_batch=estimated_flops, 254 | lengths=total_lengths, 255 | train_loss = loss.item() 256 | ) 257 | 258 | 259 | 260 | 261 | if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: 262 | 263 | t0 = time.perf_counter() 264 | val_loss = validate(fabric, model, val_dataloader) 265 | t1 = time.perf_counter() - t0 266 | monitor.eval_end(t1) 267 | fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") 268 | fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 269 | fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 270 | fabric.barrier() 271 | if not is_accumulating and state["step_count"] % save_step_interval == 0: 272 | checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" 273 | fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") 274 | 
fabric.save(checkpoint_path, state) 275 | 276 | 277 | @torch.no_grad() 278 | def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: 279 | fabric.print("Validating ...") 280 | model.eval() 281 | k = -1  # index of the last evaluated batch (stays -1 if the loader yields nothing) 282 | losses = torch.zeros(eval_iters, device=fabric.device) 283 | for k, val_data in enumerate(val_dataloader): 284 | if k >= eval_iters: 285 | break 286 | input_ids = val_data[:, 0 : model.config.block_size].contiguous() 287 | targets = val_data[:, 1 : model.config.block_size + 1].contiguous() 288 | logits = model(input_ids) 289 | loss = chunked_cross_entropy(logits, targets, chunk_size=0) 290 | 291 | # loss_func = FusedCrossEntropyLoss() 292 | # loss = loss_func(logits, targets) 293 | losses[k] = loss.item() 294 | 295 | out = losses[: k + 1].mean()  # average only over the batches actually evaluated, not the unused zero slots 296 | 297 | model.train() 298 | return out 299 | 300 | 301 | def create_dataloader( 302 | batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" 303 | ) -> DataLoader: 304 | datasets = [] 305 | data_config = train_data_config if split == "train" else val_data_config 306 | for prefix, _ in data_config: 307 | filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) 308 | random.seed(seed) 309 | random.shuffle(filenames) 310 | 311 | dataset = PackedDataset( 312 | filenames, 313 | # n_chunks controls the buffer size. 314 | # Note that the buffer size also affects the random shuffle 315 | # (PackedDataset is an IterableDataset, so shuffling is done by prefetching a buffer and shuffling within it) 316 | n_chunks=8, 317 | block_size=block_size, 318 | shuffle=shuffle, 319 | seed=seed+fabric.global_rank, 320 | num_processes=fabric.world_size, 321 | process_rank=fabric.global_rank, 322 | ) 323 | datasets.append(dataset) 324 | 325 | if not datasets: 326 | raise 
RuntimeError( 327 | f"No data found at {data_dir}. Make sure you ran the data preparation scripts (e.g. scripts/prepare_slimpajama.py and scripts/prepare_starcoder.py) to create the dataset."
328 | ) 329 | 330 | weights = [weight for _, weight in data_config] 331 | sum_weights = sum(weights) 332 | weights = [el / sum_weights for el in weights] 333 | 334 | combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) 335 | 336 | return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) 337 | 338 | 339 | def create_dataloaders( 340 | batch_size: int, 341 | block_size: int, 342 | fabric, 343 | train_data_dir: Path = Path("data/redpajama_sample"), 344 | val_data_dir: Optional[Path] = None, 345 | seed: int = 12345, 346 | ) -> Tuple[DataLoader, DataLoader]: 347 | # Increase by one because we need the next word as well 348 | effective_block_size = block_size + 1 349 | train_dataloader = create_dataloader( 350 | batch_size=batch_size, 351 | block_size=effective_block_size, 352 | fabric=fabric, 353 | data_dir=train_data_dir, 354 | shuffle=True, 355 | seed=seed, 356 | split="train" 357 | ) 358 | val_dataloader = ( 359 | create_dataloader( 360 | batch_size=batch_size, 361 | block_size=effective_block_size, 362 | fabric=fabric, 363 | data_dir=val_data_dir, 364 | shuffle=False, 365 | seed=seed, 366 | split="validation" 367 | ) 368 | if val_data_dir 369 | else None 370 | ) 371 | return train_dataloader, val_dataloader 372 | 373 | 374 | # learning rate decay scheduler (cosine with warmup) 375 | def get_lr(it): 376 | # 1) linear warmup for warmup_iters steps 377 | if it < warmup_iters: 378 | return learning_rate * it / warmup_iters 379 | # 2) if it > lr_decay_iters, return min learning rate 380 | if it > lr_decay_iters: 381 | return min_lr 382 | # 3) in between, use cosine decay down to min learning rate 383 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 384 | assert 0 <= decay_ratio <= 1 385 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 386 | return min_lr + coeff * (learning_rate - min_lr) 387 | 388 | 389 | if __name__ == "__main__": 390 | # Uncomment this line 
if you see an error: "Expected is_sm80 to be true, but got false" 391 | # torch.backends.cuda.enable_flash_sdp(False) 392 | torch.set_float32_matmul_precision("high") 393 | 394 | from jsonargparse import CLI 395 | 396 | CLI(setup) 397 | -------------------------------------------------------------------------------- /pretrain/tinyllama_code.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import math 3 | import sys 4 | import time 5 | from pathlib import Path 6 | from typing import Optional, Tuple, Union 7 | import math 8 | import lightning as L 9 | import torch 10 | from lightning.fabric.strategies import FSDPStrategy, XLAStrategy 11 | from torch.utils.data import DataLoader 12 | from functools import partial 13 | # support running without installing as a package 14 | wd = Path(__file__).parent.parent.resolve() 15 | sys.path.append(str(wd)) 16 | # from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually 17 | from lit_gpt.model import GPT, Block, Config, CausalSelfAttention 18 | from lit_gpt.packed_dataset import CombinedDataset, PackedDataset 19 | from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor 20 | from lit_gpt.speed_monitor import estimate_flops, measure_flops 21 | from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load 22 | from pytorch_lightning.loggers import WandbLogger 23 | from lit_gpt import FusedCrossEntropyLoss 24 | import random 25 | 26 | 27 | model_name = "tiny_LLaMA_1b" 28 | name = "tiny_LLaMA_1b" 29 | out_dir = Path("out") / name 30 | checkpoint_path = "out/TinyLlama-1.1B-intermediate-step-240k-503b/lit_model.pth" 31 | # Hyperparameters 32 | num_of_devices = 6 33 | global_batch_size = 360 34 | learning_rate = 2e-4 35 | min_lr = 2e-5 36 | micro_batch_size = 6 37 | max_step = 10000 38 | warmup_steps = 0 39 | log_step_interval = 1 40 | eval_iters = 1000000 41 | 
save_step_interval = 2000 42 | eval_step_interval = 2000 43 | 44 | weight_decay = 1e-1 45 | beta1 = 0.9 46 | beta2 = 0.95 47 | grad_clip = 1.0 48 | decay_lr = True 49 | 50 | batch_size = global_batch_size // num_of_devices 51 | gradient_accumulation_steps = batch_size // micro_batch_size 52 | assert gradient_accumulation_steps > 0 53 | warmup_iters = warmup_steps * gradient_accumulation_steps 54 | 55 | 56 | 57 | 58 | max_iters = max_step * gradient_accumulation_steps 59 | lr_decay_iters = max_iters 60 | log_iter_interval = log_step_interval * gradient_accumulation_steps 61 | 62 | 63 | # Sample each dataset in proportion to its size. To use a different weight for a dataset, add it to this list with that weight. 64 | train_data_config = [ 65 | ("train_starcoder", 1), 66 | ] 67 | 68 | val_data_config = [ 69 | ("validation", 1.0), 70 | ] 71 | 72 | hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} 73 | logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) 74 | wandb_logger = WandbLogger() 75 | 76 | 77 | def setup( 78 | devices: int = 8, 79 | train_data_dir: Path = Path("data/redpajama_sample"), 80 | val_data_dir: Optional[Path] = None, 81 | precision: Optional[str] = None, 82 | tpu: bool = False, 83 | resume: Union[bool, Path] = False, 84 | ) -> None: 85 | precision = precision or get_default_supported_precision(training=True, tpu=tpu) 86 | 87 | if devices > 1: 88 | if tpu: 89 | # For multi-host TPU training, the device count for Fabric is limited to the count on a single host.
90 | devices = "auto" 91 | strategy = XLAStrategy(sync_module_states=False) 92 | else: 93 | strategy = FSDPStrategy( 94 | auto_wrap_policy={Block}, 95 | activation_checkpointing_policy=None, 96 | state_dict_type="full", 97 | limit_all_gathers=True, 98 | cpu_offload=False, 99 | ) 100 | else: 101 | strategy = "auto" 102 | 103 | fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) 104 | fabric.print(hparams) 105 | fabric.launch(main, train_data_dir, val_data_dir, resume) 106 | # main(fabric, train_data_dir, val_data_dir, resume) 107 | 108 | 109 | def main(fabric, train_data_dir, val_data_dir, resume): 110 | monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) 111 | 112 | if fabric.global_rank == 0: 113 | out_dir.mkdir(parents=True, exist_ok=True) 114 | 115 | config = Config.from_name(model_name) 116 | 117 | train_dataloader, val_dataloader = create_dataloaders( 118 | batch_size=micro_batch_size, 119 | block_size=config.block_size, 120 | fabric=fabric, 121 | train_data_dir=train_data_dir, 122 | val_data_dir=val_data_dir, 123 | seed=3407, 124 | ) 125 | if val_dataloader is None: 126 | train_dataloader = fabric.setup_dataloaders(train_dataloader) 127 | else: 128 | train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) 129 | 130 | fabric.seed_everything(3407) # same seed for every process to init model (FSDP) 131 | 132 | fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}") 133 | t0 = time.perf_counter() 134 | with fabric.init_module(empty_init=True): 135 | model = GPT(config) 136 | 137 | 138 | model = fabric.setup(model) 139 | fabric.load_raw(checkpoint_path, model, strict=True) 140 | fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") 141 | fabric.print(f"Total parameters {num_parameters(model):,}") 142 | 143 | 144 | optimizer = torch.optim.AdamW( 145 | model.parameters(), 
lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False 146 | ) 147 | # import bitsandbytes as bnb 148 | # optimizer = bnb.optim.AdamW8bit( 149 | # model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2) 150 | # ) 151 | # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) 152 | optimizer = fabric.setup_optimizers(optimizer) 153 | 154 | state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} 155 | 156 | if resume is True: 157 | resume = sorted(out_dir.glob("*.pth"))[-1] 158 | if resume: 159 | fabric.print(f"Resuming training from {resume}") 160 | fabric.load(resume, state) 161 | 162 | train_time = time.perf_counter() 163 | train(fabric, state, train_dataloader, val_dataloader, monitor, resume) 164 | fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") 165 | if fabric.device.type == "cuda": 166 | fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") 167 | 168 | 169 | def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): 170 | model = state["model"] 171 | optimizer = state["optimizer"] 172 | 173 | if val_dataloader is not None: 174 | validate(fabric, model, val_dataloader) # sanity check 175 | 176 | with torch.device("meta"): 177 | meta_model = GPT(model.config) 178 | # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. 179 | # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, 180 | # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead 181 | estimated_flops = estimate_flops(meta_model) * micro_batch_size 182 | fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") 183 | x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) 184 | # measure_flops runs the model on the meta device,
which triggers a fused RMSNorm error, so it is disabled 185 | #measured_flops = measure_flops(meta_model, x) 186 | #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") 187 | del meta_model, x 188 | 189 | total_lengths = 0 190 | total_t0 = time.perf_counter() 191 | 192 | if fabric.device.type == "xla": 193 | import torch_xla.core.xla_model as xm 194 | 195 | xm.mark_step() 196 | 197 | 198 | initial_iter = state["iter_num"] 199 | curr_iter = 0 200 | 201 | loss_func = FusedCrossEntropyLoss() 202 | for train_data in train_dataloader: 203 | # Restore the loader state on resume by skipping the batches already consumed. This is not elegant, but it works; it should be rewritten in the future. 204 | if resume: 205 | if curr_iter < initial_iter: 206 | curr_iter += 1 207 | continue 208 | else: 209 | resume = False 210 | curr_iter = -1 211 | fabric.barrier() 212 | fabric.print("resume finished, took {} seconds".format(time.perf_counter() - total_t0)) 213 | if state["iter_num"] >= max_iters: 214 | break 215 | 216 | # determine and set the learning rate for this iteration 217 | lr = get_lr(state["iter_num"]) if decay_lr else learning_rate 218 | for param_group in optimizer.param_groups: 219 | param_group["lr"] = lr 220 | 221 | iter_t0 = time.perf_counter() 222 | input_ids = train_data[:, 0 : model.config.block_size].contiguous() 223 | targets = train_data[:, 1 : model.config.block_size + 1].contiguous() 224 | 225 | is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 226 | with fabric.no_backward_sync(model, enabled=is_accumulating): 227 | logits = model(input_ids) 228 | loss = loss_func(logits, targets) 229 | # loss = chunked_cross_entropy(logits, targets, chunk_size=0) 230 | fabric.backward(loss / gradient_accumulation_steps) 231 | 232 | if not is_accumulating: 233 | fabric.clip_gradients(model, optimizer, max_norm=grad_clip) 234 | optimizer.step() 235 | optimizer.zero_grad() 236 | state["step_count"] += 1 237 | elif fabric.device.type == "xla": 
238 | xm.mark_step() 239 | state["iter_num"] += 1 240 | #
input_id: B L 241 | total_lengths += input_ids.size(1) 242 | t1 = time.perf_counter() 243 | fabric.print( 244 | f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" 245 | f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" 246 | f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " 247 | # print days as well 248 | f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. " 249 | ) 250 | 251 | monitor.on_train_batch_end( 252 | state["iter_num"] * micro_batch_size, 253 | t1 - total_t0, 254 | # this assumes that device FLOPs are the same and that all devices have the same batch size 255 | fabric.world_size, 256 | state["step_count"], 257 | flops_per_batch=estimated_flops, 258 | lengths=total_lengths, 259 | train_loss = loss.item() 260 | ) 261 | 262 | 263 | 264 | 265 | if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: 266 | 267 | t0 = time.perf_counter() 268 | val_loss = validate(fabric, model, val_dataloader) 269 | t1 = time.perf_counter() - t0 270 | monitor.eval_end(t1) 271 | fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") 272 | fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 273 | fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 274 | fabric.barrier() 275 | if not is_accumulating and state["step_count"] % save_step_interval == 0: 276 | checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" 277 | fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") 278 | 
fabric.save(checkpoint_path, state) 279 | 280 | 281 | @torch.no_grad() 282 | def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: 283 | fabric.print("Validating ...") 284 | model.eval() 285 | k = -1  # index of the last evaluated batch (stays -1 if the loader yields nothing) 286 | losses = torch.zeros(eval_iters, device=fabric.device) 287 | for k, val_data in enumerate(val_dataloader): 288 | if k >= eval_iters: 289 | break 290 | input_ids = val_data[:, 0 : model.config.block_size].contiguous() 291 | targets = val_data[:, 1 : model.config.block_size + 1].contiguous() 292 | logits = model(input_ids) 293 | loss = chunked_cross_entropy(logits, targets, chunk_size=0) 294 | 295 | # loss_func = FusedCrossEntropyLoss() 296 | # loss = loss_func(logits, targets) 297 | losses[k] = loss.item() 298 | 299 | out = losses[: k + 1].mean()  # average only over the batches actually evaluated; with eval_iters = 1000000 most slots would otherwise stay zero 300 | 301 | model.train() 302 | return out 303 | 304 | 305 | def create_dataloader( 306 | batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" 307 | ) -> DataLoader: 308 | datasets = [] 309 | data_config = train_data_config if split == "train" else val_data_config 310 | for prefix, _ in data_config: 311 | filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) 312 | random.seed(seed) 313 | random.shuffle(filenames) 314 | 315 | dataset = PackedDataset( 316 | filenames, 317 | # n_chunks controls the buffer size. 318 | # Note that the buffer size also affects the random shuffle 319 | # (PackedDataset is an IterableDataset, so shuffling is done by prefetching a buffer and shuffling within it) 320 | n_chunks=8, 321 | block_size=block_size, 322 | shuffle=shuffle, 323 | seed=seed+fabric.global_rank, 324 | num_processes=fabric.world_size, 325 | process_rank=fabric.global_rank, 326 | ) 327 | datasets.append(dataset) 
328 | 329 | if not datasets: 330 | raise RuntimeError( 331 | f"No data found at {data_dir}. Make sure you ran scripts/prepare_starcoder.py to create the dataset."
332 | ) 333 | 334 | weights = [weight for _, weight in data_config] 335 | sum_weights = sum(weights) 336 | weights = [el / sum_weights for el in weights] 337 | 338 | combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) 339 | 340 | return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) 341 | 342 | 343 | def create_dataloaders( 344 | batch_size: int, 345 | block_size: int, 346 | fabric, 347 | train_data_dir: Path = Path("data/redpajama_sample"), 348 | val_data_dir: Optional[Path] = None, 349 | seed: int = 12345, 350 | ) -> Tuple[DataLoader, DataLoader]: 351 | # Increase by one because we need the next word as well 352 | effective_block_size = block_size + 1 353 | train_dataloader = create_dataloader( 354 | batch_size=batch_size, 355 | block_size=effective_block_size, 356 | fabric=fabric, 357 | data_dir=train_data_dir, 358 | shuffle=True, 359 | seed=seed, 360 | split="train" 361 | ) 362 | val_dataloader = ( 363 | create_dataloader( 364 | batch_size=batch_size, 365 | block_size=effective_block_size, 366 | fabric=fabric, 367 | data_dir=val_data_dir, 368 | shuffle=False, 369 | seed=seed, 370 | split="validation" 371 | ) 372 | if val_data_dir 373 | else None 374 | ) 375 | return train_dataloader, val_dataloader 376 | 377 | 378 | # learning rate decay scheduler (cosine with warmup) 379 | def get_lr(it): 380 | # 1) linear warmup for warmup_iters steps 381 | if it < warmup_iters: 382 | return learning_rate * it / warmup_iters 383 | # 2) if it > lr_decay_iters, return min learning rate 384 | if it > lr_decay_iters: 385 | return min_lr 386 | # 3) in between, use cosine decay down to min learning rate 387 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 388 | assert 0 <= decay_ratio <= 1 389 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 390 | return min_lr + coeff * (learning_rate - min_lr) 391 | 392 | 393 | if __name__ == "__main__": 394 | # Uncomment this line 
if you see an error: "Expected is_sm80 to be true, but got false" 395 | # torch.backends.cuda.enable_flash_sdp(False) 396 | torch.set_float32_matmul_precision("high") 397 | 398 | from jsonargparse import CLI 399 | 400 | CLI(setup) 401 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.1.0dev 2 | lightning==2.1.2 3 | lightning[app] 4 | jsonargparse[signatures] # CLI 5 | pandas 6 | pyarrow 7 | tokenizers 8 | sentencepiece 9 | wandb 10 | zstd 11 | 12 | # for finetuning 13 | bitsandbytes==0.40.0 14 | transformers==4.31.0 15 | peft==0.4.0 16 | accelerate==0.21.0 17 | einops==0.6.1 18 | evaluate==0.4.0 19 | scikit-learn==1.2.2 20 | sentencepiece==0.1.99 21 | wandb==0.15.3 22 | # other optional dependencies are 23 | # sentencepiece # pythia, falcon, redpajama 24 | # tokenizers # llama-based models 25 | # bitsandbytes>=0.41.1 # quantize/bnb.py 26 | # scipy # TODO: remove when https://github.com/TimDettmers/bitsandbytes/pull/525 is released 27 | # datasets # quantize/gptq.py 28 | # zstandard # scripts/prepare_redpajama.py 29 | # git+https://github.com/EleutherAI/lm-evaluation-harness.git@master # eval 30 | -------------------------------------------------------------------------------- /script.sh: -------------------------------------------------------------------------------- 1 | python scripts/convert_hf_checkpoint.py --checkpoint_dir out/TinyLlama-1.1B-900B --model_name tiny_LLaMA_1b 2 | 3 | python test_weight.py --checkpoint_dir out/TinyLlama-1.1B-intermediate-900B 4 | 5 | 6 | python pretrain/tinyllama_code.py --devices 8 --train_data_dir data/code_specialist_python_java_javascript_c_go_8192 7 | 8 | 9 | 10 | python scripts/prepare_starcoder.py --source_path data/starcoderdata/ --tokenizer_path data/llama --destination_path data/code_specialist_python_java_javascript_c_go_8192 --split train --percentage 1.0 --filenames_subset 
["python","cpp","go","java","javascript"] --chunk_size 4194816 11 | 12 | 13 | 14 | 15 | /data/TinyLlama/out/code_tiny_LLaMA_1b_python_java_go_cpp_javascript/iter-032000-ckpt.pth 16 | 17 | python scripts/convert_lit_checkpoint.py --out_dir /data/TinyLlama/out/tiny_LLaMA_1b/ --checkpoint_name iter-100000-ckpt.pth --model_name tiny_LLaMA_1b -------------------------------------------------------------------------------- /scripts/convert_hf_checkpoint.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import gc 3 | import json 4 | import sys 5 | from functools import partial 6 | from pathlib import Path 7 | from typing import Dict, List, Literal, Optional, Tuple, Union 8 | 9 | import torch 10 | 11 | # support running without installing as a package 12 | wd = Path(__file__).parent.parent.resolve() 13 | sys.path.append(str(wd)) 14 | 15 | from lit_gpt import Config 16 | from lit_gpt.utils import NotYetLoadedTensor, incremental_save, lazy_load 17 | 18 | 19 | def copy_weights_gpt_neox( 20 | state_dict: Dict[str, torch.Tensor], 21 | hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], 22 | saver: Optional[incremental_save] = None, 23 | dtype: Optional[torch.dtype] = None, 24 | ) -> None: 25 | weight_map = { 26 | "gpt_neox.embed_in.weight": "transformer.wte.weight", 27 | "gpt_neox.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", 28 | "gpt_neox.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", 29 | "gpt_neox.layers.{}.attention.query_key_value.bias": "transformer.h.{}.attn.attn.bias", 30 | "gpt_neox.layers.{}.attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight", 31 | "gpt_neox.layers.{}.attention.dense.bias": "transformer.h.{}.attn.proj.bias", 32 | "gpt_neox.layers.{}.attention.dense.weight": "transformer.h.{}.attn.proj.weight", 33 | "gpt_neox.layers.{}.attention.rotary_emb.inv_freq": None, 34 | "gpt_neox.layers.{}.attention.bias": None, 35 | 
"gpt_neox.layers.{}.attention.masked_bias": None, 36 | "gpt_neox.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.norm_2.bias", 37 | "gpt_neox.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", 38 | "gpt_neox.layers.{}.mlp.dense_h_to_4h.bias": "transformer.h.{}.mlp.fc.bias", 39 | "gpt_neox.layers.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight", 40 | "gpt_neox.layers.{}.mlp.dense_4h_to_h.bias": "transformer.h.{}.mlp.proj.bias", 41 | "gpt_neox.layers.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight", 42 | "gpt_neox.final_layer_norm.bias": "transformer.ln_f.bias", 43 | "gpt_neox.final_layer_norm.weight": "transformer.ln_f.weight", 44 | "embed_out.weight": "lm_head.weight", 45 | } 46 | 47 | for name, param in hf_weights.items(): 48 | if "gpt_neox.layers" in name: 49 | from_name, number = layer_template(name, 2) 50 | to_name = weight_map[from_name] 51 | if to_name is None: 52 | continue 53 | to_name = to_name.format(number) 54 | else: 55 | to_name = weight_map[name] 56 | param = load_param(param, name, dtype) 57 | if saver is not None: 58 | param = saver.store_early(param) 59 | state_dict[to_name] = param 60 | 61 | 62 | def copy_weights_falcon( 63 | size: Literal["7b", "40b"], 64 | state_dict: Dict[str, torch.Tensor], 65 | hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], 66 | saver: Optional[incremental_save] = None, 67 | dtype: Optional[torch.dtype] = None, 68 | ) -> None: 69 | weight_map = { 70 | "transformer.word_embeddings.weight": "transformer.wte.weight", 71 | "transformer.h.{}.self_attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight", 72 | "transformer.h.{}.self_attention.dense.weight": "transformer.h.{}.attn.proj.weight", 73 | "transformer.h.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight", 74 | "transformer.h.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight", 75 | "transformer.ln_f.bias": "transformer.ln_f.bias", 76 | 
"transformer.ln_f.weight": "transformer.ln_f.weight", 77 | "lm_head.weight": "lm_head.weight", 78 | } 79 | # the original model definition is different for each size 80 | if size == "7b": 81 | weight_map.update( 82 | { 83 | "transformer.h.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", 84 | "transformer.h.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", 85 | } 86 | ) 87 | elif size == "40b": 88 | weight_map.update( 89 | { 90 | "transformer.h.{}.ln_attn.bias": "transformer.h.{}.norm_1.bias", 91 | "transformer.h.{}.ln_attn.weight": "transformer.h.{}.norm_1.weight", 92 | "transformer.h.{}.ln_mlp.bias": "transformer.h.{}.norm_2.bias", 93 | "transformer.h.{}.ln_mlp.weight": "transformer.h.{}.norm_2.weight", 94 | } 95 | ) 96 | else: 97 | raise NotImplementedError 98 | 99 | for name, param in hf_weights.items(): 100 | if "transformer.h" in name: 101 | from_name, number = layer_template(name, 2) 102 | to_name = weight_map[from_name].format(number) 103 | else: 104 | to_name = weight_map[name] 105 | param = load_param(param, name, dtype) 106 | if saver is not None: 107 | param = saver.store_early(param) 108 | state_dict[to_name] = param 109 | 110 | 111 | def copy_weights_hf_llama( 112 | config: Config, 113 | qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], 114 | state_dict: Dict[str, torch.Tensor], 115 | hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], 116 | saver: Optional[incremental_save] = None, 117 | dtype: Optional[torch.dtype] = None, 118 | ) -> None: 119 | weight_map = { 120 | "model.embed_tokens.weight": "transformer.wte.weight", 121 | "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", 122 | "model.layers.{}.self_attn.q_proj.weight": None, 123 | "model.layers.{}.self_attn.k_proj.weight": None, 124 | "model.layers.{}.self_attn.v_proj.weight": None, 125 | "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight", 126 | "model.layers.{}.self_attn.rotary_emb.inv_freq": None, 
127 | "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", 128 | "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.swiglu.w1.weight", 129 | "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.swiglu.w2.weight", 130 | "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.swiglu.w3.weight", 131 | "model.norm.weight": "transformer.ln_f.weight", 132 | "lm_head.weight": "lm_head.weight", 133 | } 134 | 135 | for name, param in hf_weights.items(): 136 | if "model.layers" in name: 137 | from_name, number = layer_template(name, 2) 138 | qkv = qkv_weights.setdefault(number, [None, None, None]) 139 | if "q_proj" in name: 140 | qkv[0] = param 141 | elif "k_proj" in name: 142 | qkv[1] = param 143 | elif "v_proj" in name: 144 | qkv[2] = param 145 | to_name = weight_map[from_name] 146 | if to_name is None: 147 | continue 148 | to_name = to_name.format(number) 149 | else: 150 | to_name = weight_map[name] 151 | param = load_param(param, name, dtype) 152 | if saver is not None: 153 | param = saver.store_early(param) 154 | state_dict[to_name] = param 155 | 156 | for i, (q, k, v) in list(qkv_weights.items()): 157 | if q is None or k is None or v is None: 158 | # split across different .bin files 159 | continue 160 | q = load_param(q, f"layer {i} q", dtype) 161 | k = load_param(k, f"layer {i} k", dtype) 162 | v = load_param(v, f"layer {i} v", dtype) 163 | q_per_kv = config.n_head // config.n_query_groups 164 | qs = torch.split(q, config.head_size * q_per_kv) 165 | ks = torch.split(k, config.head_size) 166 | vs = torch.split(v, config.head_size) 167 | cycled = [t for group in zip(qs, ks, vs) for t in group] 168 | qkv = torch.cat(cycled) 169 | state_dict[f"transformer.h.{i}.attn.attn.weight"] = qkv 170 | del qkv_weights[i] 171 | 172 | 173 | def layer_template(layer_name: str, idx: int) -> Tuple[str, int]: 174 | split = layer_name.split(".") 175 | number = int(split[idx]) 176 | split[idx] = "{}" 177 | from_name = ".".join(split) 
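# Worked example (illustrative input, not from the original script): layer_template("model.layers.3.mlp.up_proj.weight", 2) returns ("model.layers.{}.mlp.up_proj.weight", 3). The layer index at position idx is swapped for "{}" so the name can be looked up in weight_map, and the target name is later re-filled with .format(number).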
178 | return from_name, number 179 | 180 | 181 | def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype]) -> torch.Tensor: 182 | if hasattr(param, "_load_tensor"): 183 | # support tensors loaded via `lazy_load()` 184 | print(f"Loading {name!r} into RAM") 185 | param = param._load_tensor() 186 | if dtype is not None and type(dtype) is not NotYetLoadedTensor and dtype != param.dtype: 187 | print(f"Converting {name!r} from {param.dtype} to {dtype}") 188 | param = param.to(dtype) 189 | return param 190 | 191 | 192 | @torch.inference_mode() 193 | def convert_hf_checkpoint( 194 | *, 195 | checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), 196 | model_name: Optional[str] = None, 197 | dtype: Optional[str] = None, 198 | ) -> None: 199 | if model_name is None: 200 | model_name = checkpoint_dir.name 201 | if dtype is not None: 202 | dtype = getattr(torch, dtype) 203 | 204 | config = Config.from_name(model_name) 205 | print(f"Model config {config.__dict__}") 206 | with open(checkpoint_dir / "lit_config.json", "w") as json_config: 207 | json.dump(config.__dict__, json_config) 208 | 209 | if "falcon" in model_name: 210 | copy_fn = partial(copy_weights_falcon, "40b" if config.n_embd == 8192 else "7b") 211 | elif config._mlp_class == "LLaMAMLP": 212 | # holder to reconstitute the split q, k, v 213 | qkv_weights = {} 214 | copy_fn = partial(copy_weights_hf_llama, config, qkv_weights) 215 | else: 216 | copy_fn = copy_weights_gpt_neox 217 | 218 | # initialize a new empty state dict to hold our new weights 219 | sd = {} 220 | 221 | # Load the json file containing weight mapping 222 | pytorch_bin_map_json_path = checkpoint_dir / "pytorch_model.bin.index.json" 223 | if pytorch_bin_map_json_path.is_file(): # not all checkpoints have this file 224 | with open(pytorch_bin_map_json_path) as json_map: 225 | bin_index = json.load(json_map) 226 | bin_files = {checkpoint_dir / bin for bin in 
bin_index["weight_map"].values()} 227 | else: 228 | bin_files = set(checkpoint_dir.glob("*.bin")) 229 | if not bin_files: 230 | raise ValueError(f"Expected {str(checkpoint_dir)!r} to contain .bin files") 231 | 232 | with incremental_save(checkpoint_dir / "lit_model.pth") as saver: 233 | # for checkpoints that split the QKV across several files, we need to keep all the bin files 234 | # open, so we use `ExitStack` to close them all together at the end 235 | with contextlib.ExitStack() as stack: 236 | for bin_file in sorted(bin_files): 237 | print("Processing", bin_file) 238 | hf_weights = stack.enter_context(lazy_load(bin_file)) 239 | copy_fn(sd, hf_weights, saver=None, dtype=dtype) 240 | gc.collect() 241 | print("Saving converted checkpoint") 242 | saver.save(sd) 243 | 244 | 245 | if __name__ == "__main__": 246 | from jsonargparse import CLI 247 | 248 | CLI(convert_hf_checkpoint) 249 | -------------------------------------------------------------------------------- /scripts/convert_lit_checkpoint.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import gc 3 | import sys 4 | from functools import partial 5 | from pathlib import Path 6 | from typing import Dict, Literal, Optional, Tuple, Union 7 | from dataclasses import asdict 8 | import json 9 | import torch 10 | 11 | # support running without installing as a package 12 | wd = Path(__file__).parent.parent.resolve() 13 | sys.path.append(str(wd)) 14 | 15 | from lit_gpt import Config 16 | from lit_gpt.utils import NotYetLoadedTensor, incremental_save, lazy_load 17 | # from scripts.convert_hf_checkpoint import layer_template, load_param 18 | 19 | 20 | def layer_template(layer_name: str, idx: int) -> Tuple[str, int]: 21 | split = layer_name.split(".") 22 | number = int(split[idx]) 23 | split[idx] = "{}" 24 | from_name = ".".join(split) 25 | return from_name, number 26 | 27 | 28 | def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: 
Optional[torch.dtype]) -> torch.Tensor: 29 | if hasattr(param, "_load_tensor"): 30 | # support tensors loaded via `lazy_load()` 31 | print(f"Loading {name!r} into RAM") 32 | param = param._load_tensor() 33 | if dtype is not None and type(dtype) is not NotYetLoadedTensor and dtype != param.dtype: 34 | print(f"Converting {name!r} from {param.dtype} to {dtype}") 35 | param = param.to(dtype) 36 | return param 37 | def copy_weights_falcon( 38 | size: Literal["7b", "40b"], 39 | state_dict: Dict[str, torch.Tensor], 40 | lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], 41 | saver: Optional[incremental_save] = None, 42 | ): 43 | weight_map = { 44 | "transformer.wte.weight": "transformer.word_embeddings.weight", 45 | "transformer.h.{}.attn.attn.weight": "transformer.h.{}.self_attention.query_key_value.weight", 46 | "transformer.h.{}.attn.proj.weight": "transformer.h.{}.self_attention.dense.weight", 47 | "transformer.h.{}.mlp.fc.weight": "transformer.h.{}.mlp.dense_h_to_4h.weight", 48 | "transformer.h.{}.mlp.proj.weight": "transformer.h.{}.mlp.dense_4h_to_h.weight", 49 | "transformer.ln_f.bias": "transformer.ln_f.bias", 50 | "transformer.ln_f.weight": "transformer.ln_f.weight", 51 | "lm_head.weight": "lm_head.weight", 52 | } 53 | # the original model definition is different for each size 54 | if size == "7b": 55 | weight_map.update( 56 | { 57 | "transformer.h.{}.norm_1.bias": "transformer.h.{}.input_layernorm.bias", 58 | "transformer.h.{}.norm_1.weight": "transformer.h.{}.input_layernorm.weight", 59 | } 60 | ) 61 | elif size == "40b": 62 | weight_map.update( 63 | { 64 | "transformer.h.{}.norm_1.bias": "transformer.h.{}.ln_attn.bias", 65 | "transformer.h.{}.norm_1.weight": "transformer.h.{}.ln_attn.weight", 66 | "transformer.h.{}.norm_2.bias": "transformer.h.{}.ln_mlp.bias", 67 | "transformer.h.{}.norm_2.weight": "transformer.h.{}.ln_mlp.weight", 68 | } 69 | ) 70 | else: 71 | raise NotImplementedError 72 | 73 | for name, param in lit_weights.items(): 74 | if 
"transformer.h" in name: 75 | from_name, number = layer_template(name, 2) 76 | to_name = weight_map[from_name].format(number) 77 | else: 78 | to_name = weight_map[name] 79 | param = load_param(param, name, None) 80 | if saver is not None: 81 | param = saver.store_early(param) 82 | state_dict[to_name] = param 83 | 84 | 85 | def copy_weights_gpt_neox( 86 | state_dict: Dict[str, torch.Tensor], 87 | lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], 88 | saver: Optional[incremental_save] = None, 89 | ) -> None: 90 | weight_map = { 91 | "transformer.wte.weight": "gpt_neox.embed_in.weight", 92 | "transformer.h.{}.norm_1.bias": "gpt_neox.layers.{}.input_layernorm.bias", 93 | "transformer.h.{}.norm_1.weight": "gpt_neox.layers.{}.input_layernorm.weight", 94 | "transformer.h.{}.attn.attn.bias": "gpt_neox.layers.{}.attention.query_key_value.bias", 95 | "transformer.h.{}.attn.attn.weight": "gpt_neox.layers.{}.attention.query_key_value.weight", 96 | "transformer.h.{}.attn.proj.bias": "gpt_neox.layers.{}.attention.dense.bias", 97 | "transformer.h.{}.attn.proj.weight": "gpt_neox.layers.{}.attention.dense.weight", 98 | "transformer.h.{}.norm_2.bias": "gpt_neox.layers.{}.post_attention_layernorm.bias", 99 | "transformer.h.{}.norm_2.weight": "gpt_neox.layers.{}.post_attention_layernorm.weight", 100 | "transformer.h.{}.mlp.fc.bias": "gpt_neox.layers.{}.mlp.dense_h_to_4h.bias", 101 | "transformer.h.{}.mlp.fc.weight": "gpt_neox.layers.{}.mlp.dense_h_to_4h.weight", 102 | "transformer.h.{}.mlp.proj.bias": "gpt_neox.layers.{}.mlp.dense_4h_to_h.bias", 103 | "transformer.h.{}.mlp.proj.weight": "gpt_neox.layers.{}.mlp.dense_4h_to_h.weight", 104 | "transformer.ln_f.bias": "gpt_neox.final_layer_norm.bias", 105 | "transformer.ln_f.weight": "gpt_neox.final_layer_norm.weight", 106 | "lm_head.weight": "embed_out.weight", 107 | } 108 | 109 | for name, param in lit_weights.items(): 110 | if "transformer.h" in name: 111 | from_name, number = layer_template(name, 2) 112 | to_name = 
weight_map[from_name].format(number) 113 | else: 114 | to_name = weight_map[name] 115 | param = load_param(param, name, None) 116 | if saver is not None: 117 | param = saver.store_early(param) 118 | state_dict[to_name] = param 119 | 120 | 121 | def copy_weights_llama( 122 | config: Config, 123 | state_dict: Dict[str, torch.Tensor], 124 | lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], 125 | saver: Optional[incremental_save] = None, 126 | ): 127 | weight_map = { 128 | "transformer.wte.weight": "model.embed_tokens.weight", 129 | "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", 130 | "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", 131 | "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", 132 | "transformer.h.{}.mlp.swiglu.w1.weight": "model.layers.{}.mlp.gate_proj.weight", 133 | "transformer.h.{}.mlp.swiglu.w2.weight": "model.layers.{}.mlp.up_proj.weight", 134 | "transformer.h.{}.mlp.swiglu.w3.weight": "model.layers.{}.mlp.down_proj.weight", 135 | "transformer.ln_f.weight": "model.norm.weight", 136 | "lm_head.weight": "lm_head.weight", 137 | } 138 | for name, param in lit_weights.items(): 139 | if name.endswith(".attn.attn.weight"): 140 | from_name, number = layer_template(name, 2) 141 | q = "model.layers.{}.self_attn.q_proj.weight".format(number) 142 | k = "model.layers.{}.self_attn.k_proj.weight".format(number) 143 | v = "model.layers.{}.self_attn.v_proj.weight".format(number) 144 | qkv = load_param(param, name,None) 145 | qp, kp, vp = tensor_split(qkv, config) 146 | for to_name, param in zip((q, k, v), (qp, kp, vp)): 147 | if saver is not None: 148 | param = saver.store_early(param) 149 | state_dict[to_name] = param 150 | elif "transformer.h" in name: 151 | from_name, number = layer_template(name, 2) 152 | to_name = weight_map[from_name] 153 | 154 | if to_name is None: 155 | continue 156 | to_name = to_name.format(number) 157 | param = load_param(param, 
name,None) 158 | if saver is not None: 159 | param = saver.store_early(param) 160 | state_dict[to_name] = param 161 | 162 | else: 163 | to_name = weight_map[name] 164 | param = load_param(param, name, None) 165 | if saver is not None: 166 | param = saver.store_early(param) 167 | state_dict[to_name] = param 168 | 169 | 170 | def tensor_split( 171 | param: Union[torch.Tensor, NotYetLoadedTensor], config: Config 172 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 173 | def kstart(start, blen, klen) -> int: 174 | """returns start index of keys in batch""" 175 | return start + (blen - (klen * 2)) 176 | 177 | def vstart(start, blen, klen) -> int: 178 | """returns start index of values in batch""" 179 | return start + blen - klen 180 | 181 | def vend(start, blen) -> int: 182 | """returns last index of values in batch""" 183 | return start + blen 184 | 185 | # num observations 186 | nobs = param.shape[0] 187 | # batch length 188 | blen = nobs // config.n_query_groups 189 | # key length in batch 190 | klen = config.head_size 191 | # value length in batch 192 | vlen = config.head_size 193 | # the starting index of each new batch 194 | starts = range(0, nobs, blen) 195 | # the indices to splice on 196 | splices = [(s, kstart(s, blen, klen), vstart(s, blen, vlen), vend(s, blen)) for s in starts] 197 | 198 | qc = () 199 | kc = () 200 | vc = () 201 | 202 | for splice in splices: 203 | qs, ks, vs, ve = splice 204 | qc += (param[qs:ks, :],) 205 | kc += (param[ks:vs, :],) 206 | vc += (param[vs:ve, :],) 207 | 208 | q = torch.cat(qc) 209 | k = torch.cat(kc) 210 | v = torch.cat(vc) 211 | 212 | return q, k, v 213 | 214 | 215 | def maybe_unwrap_state_dict(lit_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 216 | return lit_weights.get("model", lit_weights) 217 | 218 | 219 | def check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None: 220 | weight_names = {wk.split(".")[-1] for wk in lit_weights} 221 | # LoRA or QLoRA 222 | if any("lora" in wn for 
wn in weight_names): 223 | raise ValueError("Model weights must be merged using `lora.merge_lora_weights()` before conversion.") 224 | # adapter v2. adapter_bias will only be in adapter_v2 225 | elif "adapter_bias" in weight_names: 226 | raise NotImplementedError("Converting models finetuned with adapter_v2 not yet supported.") 227 | # adapter. gating_factor is in adapter and adapter_v2 228 | elif "gating_factor" in weight_names: 229 | raise NotImplementedError("Converting models finetuned with adapter not yet supported.") 230 | 231 | 232 | def get_tinyllama_init_hf_config() -> dict: 233 | return { 234 | "architectures": ["LlamaForCausalLM"], 235 | "bos_token_id": 1, 236 | "eos_token_id": 2, 237 | "hidden_act": "silu", 238 | "hidden_size": None, 239 | "initializer_range": 0.02, 240 | "intermediate_size": None, 241 | "max_position_embeddings": None, 242 | "model_type": "llama", 243 | "num_attention_heads": None, 244 | "num_hidden_layers": None, 245 | "num_key_value_heads": None, 246 | "pretraining_tp": 1, 247 | "rms_norm_eps": None, 248 | "rope_scaling": None, 249 | "tie_word_embeddings": False, 250 | "torch_dtype": "float32", 251 | "transformers_version": "4.31.0.dev0", 252 | "use_cache": True, 253 | "vocab_size": None, 254 | } 255 | 256 | 257 | def convert_config_lit_to_hf(lit_config_dict: dict) -> dict: 258 | lit_hf_mapping = { 259 | "block_size": "max_position_embeddings", 260 | "vocab_size": "vocab_size", 261 | "n_layer": "num_hidden_layers", 262 | "n_embd": "hidden_size", 263 | "n_head": "num_attention_heads", 264 | "n_query_groups": "num_key_value_heads", 265 | "intermediate_size": "intermediate_size", 266 | "norm_eps": "rms_norm_eps", 267 | 268 | } 269 | hf_config_dict = get_tinyllama_init_hf_config() 270 | 271 | for lit_key, hf_key in lit_hf_mapping.items(): 272 | hf_config_dict[hf_key] = lit_config_dict[lit_key] 273 | return hf_config_dict 274 | 275 | 276 | @torch.inference_mode() 277 | def convert_lit_checkpoint(*, 278 | checkpoint_name: str, 279 | 
out_dir: Path, 280 | model_name: str, 281 | model_only: bool = True) -> None: 282 | config = Config.from_name(model_name) 283 | 284 | if "falcon" in model_name: 285 | copy_fn = partial(copy_weights_falcon, "40b" if config.n_embd == 8192 else "7b") 286 | elif config._mlp_class == "LLaMAMLP": 287 | copy_fn = partial(copy_weights_llama, config) 288 | else: 289 | copy_fn = copy_weights_gpt_neox 290 | 291 | # initialize a new empty state dict to hold our new weights 292 | sd = {} 293 | 294 | # checkpoint_name cannot be hardcoded because there are several possible outputs, such as 295 | # ("lit_model_finetuned.pth", "lit_model_lora_finetuned.pth", "lit_model_adapter_finetuned.pth") 296 | pth_file = out_dir / checkpoint_name 297 | bin_file = pth_file.with_suffix(".bin") 298 | 299 | with incremental_save(bin_file) as saver: 300 | with contextlib.ExitStack() as stack: 301 | lit_weights = stack.enter_context(lazy_load(pth_file)) 302 | lit_weights = maybe_unwrap_state_dict(lit_weights) 303 | check_conversion_supported(lit_weights) 304 | # incremental saving (saver=saver) triggers an error here, so the weights are collected in `sd` and saved at once below 305 | copy_fn(sd, lit_weights, saver=None) 306 | gc.collect() 307 | saver.save(sd) 308 | 309 | # convert lit config file to hf-style 310 | if not model_only: 311 | print('Converting config file...') 312 | lit_config = asdict(config) 313 | hf_config = convert_config_lit_to_hf(lit_config) 314 | config_path = out_dir / "config.json" 315 | with open(config_path, "w") as f: 316 | json.dump(hf_config, f, indent=4) 317 | 318 | 319 | 320 | 321 | if __name__ == "__main__": 322 | from jsonargparse import CLI 323 | 324 | CLI(convert_lit_checkpoint, as_positional=False) 325 | -------------------------------------------------------------------------------- /scripts/prepare_redpajama.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import sys 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | #
support running without installing as a package 11 | wd = Path(__file__).parent.parent.resolve() 12 | sys.path.append(str(wd)) 13 | 14 | import lit_gpt.packed_dataset as packed_dataset 15 | from lit_gpt import Config, Tokenizer 16 | 17 | filenames_sample = [ 18 | "arxiv_sample.jsonl", 19 | "book_sample.jsonl", 20 | "c4_sample.jsonl", 21 | "cc_2019-30_sample.jsonl", 22 | "cc_2020-05_sample.jsonl", 23 | "cc_2021-04_sample.jsonl", 24 | "cc_2022-05_sample.jsonl", 25 | "cc_2023-06_sample.jsonl", 26 | "github_sample.jsonl", 27 | "stackexchange_sample.jsonl", 28 | "wikipedia_sample.jsonl", 29 | ] 30 | 31 | filename_sets = { 32 | "arxiv": "arxiv/arxiv*", 33 | "book": "book/book*", 34 | "c4": "c4/c4-train*", 35 | "common_crawl": "common_crawl/*", 36 | "github": "github/filtered*", 37 | "stackexchange": "stackexchange/stackexchange*", 38 | "wikipedia": "wikipedia/wiki*", 39 | } 40 | 41 | 42 | def prepare_sample( 43 | source_path: Path, checkpoint_dir: Path, destination_path: Path, chunk_size: int, match: str = "" 44 | ) -> None: 45 | """Prepare the "Red Pajama" dataset using the original tokenizer.""" 46 | destination_path.mkdir(parents=True, exist_ok=True) 47 | 48 | tokenizer = Tokenizer(checkpoint_dir) 49 | 50 | for name in filenames_sample: 51 | if match and match not in name: 52 | continue 53 | 54 | filepath = source_path / name 55 | 56 | if not filepath.is_file(): 57 | raise RuntimeError( 58 | f"Input file not found at {filepath}. \nMake sure you download the data, e.g. 
wget -i" 59 | " https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt or through" 60 | " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T" 61 | " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample \n" 62 | ) 63 | 64 | prefix, _ = os.path.splitext(name) 65 | 66 | builder = packed_dataset.PackedDatasetBuilder( 67 | outdir=destination_path, 68 | prefix=prefix, 69 | chunk_size=chunk_size, 70 | sep_token=tokenizer.eos_id, 71 | dtype="auto", 72 | vocab_size=tokenizer.vocab_size, 73 | ) 74 | 75 | print(f"Processing {name}") 76 | 77 | with open(filepath, encoding="utf-8") as f: 78 | for row in tqdm(f): 79 | text = json.loads(row)["text"] 80 | text_ids = tokenizer.encode(text) 81 | builder.add_array(np.array(text_ids, dtype=builder.dtype)) 82 | 83 | builder.write_reminder() 84 | 85 | 86 | def prepare_full( 87 | source_path: Path, checkpoint_dir: Path, destination_path: Path, chunk_size: int, match: str = "" 88 | ) -> None: 89 | """Prepare the "Red Pajama" dataset using the original tokenizer.""" 90 | import zstandard as zstd 91 | 92 | destination_path.mkdir(parents=True, exist_ok=True) 93 | 94 | tokenizer = Tokenizer(checkpoint_dir) 95 | 96 | for set_name, pattern in filename_sets.items(): 97 | if match and match not in set_name: 98 | continue 99 | 100 | is_cc = set_name == "common_crawl" 101 | 102 | filenames = glob.glob(os.path.join(source_path, pattern), recursive=True) 103 | 104 | if not filenames: 105 | raise RuntimeError( 106 | f"No files matching {pattern} found at {source_path}. \nMake sure you download the data, e.g. 
wget -i" 107 | " https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt or through" 108 | " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T" 109 | " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample \n" 110 | ) 111 | 112 | builder = packed_dataset.PackedDatasetBuilder( 113 | outdir=destination_path, 114 | prefix=set_name, 115 | chunk_size=chunk_size, 116 | sep_token=tokenizer.eos_id, 117 | dtype="auto", 118 | vocab_size=tokenizer.vocab_size, 119 | ) 120 | 121 | for name in filenames: 122 | filepath = Path(name) # glob already returns paths that include source_path 123 | 124 | print(f"Processing {name}") 125 | 126 | if is_cc: 127 | with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f: 128 | for row in tqdm(f): 129 | text = json.loads(row)["text"] 130 | text_ids = tokenizer.encode(text) 131 | builder.add_array(np.array(text_ids, dtype=builder.dtype)) 132 | else: 133 | with open(filepath, encoding="utf-8") as f: 134 | for row in tqdm(f): 135 | text = json.loads(row)["text"] 136 | text_ids = tokenizer.encode(text) 137 | builder.add_array(np.array(text_ids, dtype=builder.dtype)) 138 | 139 | builder.write_reminder() 140 | 141 | 142 | def prepare( 143 | source_path: Path = Path("data/RedPajama-Data-1T-Sample"), 144 | checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), 145 | destination_path: Path = Path("data/redpajama_sample"), 146 | sample: bool = True, 147 | match: str = "", 148 | ) -> None: 149 | """Prepare the "Red Pajama" dataset.
We assume the tokenizer has already been trained.""" 150 | with open(checkpoint_dir / "lit_config.json") as fp: 151 | config = Config(**json.load(fp)) 152 | 153 | prepare_fn = prepare_sample if sample else prepare_full 154 | prepare_fn( 155 | source_path=source_path, 156 | checkpoint_dir=checkpoint_dir, 157 | destination_path=destination_path, 158 | chunk_size=(config.block_size + 1) * 1024, # block_size + 1 tokens per sample (inputs plus the one-token-shifted targets), 1024 samples per chunk 159 | match=match, 160 | ) 161 | 162 | 163 | if __name__ == "__main__": 164 | from jsonargparse import CLI 165 | 166 | CLI(prepare) -------------------------------------------------------------------------------- /scripts/prepare_slimpajama.py: -------------------------------------------------------------------------------- 1 | import json 2 | import glob 3 | import os 4 | from pathlib import Path 5 | import sys 6 | from typing import List 7 | import numpy as np 8 | from tqdm import tqdm 9 | from multiprocessing import Process, cpu_count 10 | 11 | # support running without installing as a package 12 | wd = Path(__file__).parent.parent.resolve() 13 | sys.path.append(str(wd)) 14 | 15 | import lit_gpt.packed_dataset as packed_dataset 16 | from lit_gpt import Tokenizer 17 | 18 | # Glob patterns for each SlimPajama split 19 | slimpajama_sets = { 20 | "train": "train/chunk*/*", 21 | "validation": "validation/chunk*/*", 22 | "test": "test/chunk*/*", 23 | } 24 | 25 | 26 | def prepare_full( 27 | source_path: Path, 28 | tokenizer_path: Path, 29 | destination_path: Path, 30 | chunk_size: int, 31 | split: str="train", 32 | filenames_subset: List[str] = None, 33 | process_id: int = 0 34 | ) -> None: 35 | import zstandard as zstd 36 | 37 | destination_path.mkdir(parents=True, exist_ok=True) 38 | 39 | tokenizer = Tokenizer(tokenizer_path) 40 | 41 | # Use the subset of filenames assigned to this process 42 | filenames = filenames_subset 43 | 44 | if not filenames: 45 | raise RuntimeError( 46 | f"No files matching {slimpajama_sets[split]} found at {source_path}.
\n" 47 | "Make sure you download the data..." 48 | ) 49 | 50 | builder = packed_dataset.PackedDatasetBuilder( 51 | outdir=destination_path, 52 | prefix=f"{split}_slimpajama_{process_id}", # Use process_id to differentiate builders 53 | chunk_size=chunk_size, 54 | sep_token=tokenizer.bos_id, 55 | dtype="auto", 56 | vocab_size=tokenizer.vocab_size, 57 | ) 58 | 59 | for filepath in filenames: 60 | print(f"Processing {filepath}") 61 | with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f: 62 | for row in tqdm(f): 63 | record = json.loads(row) # parse each JSON line only once 64 | if record["meta"]["redpajama_set_name"] == "RedPajamaGithub": 65 | continue # we don't want to include the github data 66 | text_ids = tokenizer.encode(record["text"]) 67 | builder.add_array(np.array(text_ids, dtype=builder.dtype)) 68 | 69 | # we deliberately skip the final, partially filled chunk, which would otherwise be padded with bos_ids; see https://github.com/jzhang38/TinyLlama/issues/83 for more details 70 | # builder.write_reminder() 71 | 72 | 73 | def prepare( 74 | source_path: Path = Path("data/RedPajama-Data-1T-Sample"), 75 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 76 | destination_path: Path = Path("data/red_pajama_sample"), 77 | chunk_size: int = 2049 * 1024, 78 | split: str="train", 79 | percentage: float = 1.0, 80 | ) -> None: 81 | import time 82 | 83 | filenames = glob.glob(os.path.join(source_path, slimpajama_sets[split]), recursive=True) 84 | filenames = filenames[:int(len(filenames) * percentage)] 85 | 86 | num_processes = cpu_count() 87 | chunked_filenames = np.array_split(filenames, num_processes) 88 | 89 | processes = [] 90 | start_time = time.time() 91 | 92 | for i, subset in enumerate(chunked_filenames): 93 | p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i)) 94 | processes.append(p) 95 | p.start() 96 | 97 | for p in processes: 98 | p.join() 99 | end_time = time.time() 100 | elapsed_time = end_time -
start_time 101 | print(f"Time taken: {elapsed_time:.2f} seconds") 102 | 103 | 104 | if __name__ == "__main__": 105 | from jsonargparse import CLI 106 | CLI(prepare) -------------------------------------------------------------------------------- /scripts/prepare_starcoder.py: -------------------------------------------------------------------------------- 1 | import json 2 | import glob 3 | import os 4 | from pathlib import Path 5 | import sys 6 | from typing import List 7 | import numpy as np 8 | from tqdm import tqdm 9 | from multiprocessing import Process, cpu_count 10 | 11 | # support running without installing as a package 12 | wd = Path(__file__).parent.parent.resolve() 13 | sys.path.append(str(wd)) 14 | 15 | import lit_gpt.packed_dataset as packed_dataset 16 | from lit_gpt import Tokenizer 17 | 18 | import pandas as pd 19 | 20 | 21 | def prepare_full( 22 | source_path: Path, 23 | tokenizer_path: Path, 24 | destination_path: Path, 25 | chunk_size: int, 26 | split: str="train", 27 | filenames_subset: List[str] = None, 28 | process_id: int = 0 29 | ) -> None: 30 | import zstandard as zstd 31 | 32 | destination_path.mkdir(parents=True, exist_ok=True) 33 | 34 | tokenizer = Tokenizer(tokenizer_path) 35 | 36 | # Use the subset of filenames assigned to this process 37 | filenames = filenames_subset 38 | 39 | if not filenames: 40 | raise RuntimeError( 41 | f"No parquet files found at {source_path}. \n" 42 | "Make sure you download the data..."
43 | ) 44 | 45 | builder = packed_dataset.PackedDatasetBuilder( 46 | outdir=destination_path, 47 | prefix=f"{split}_starcoder_{process_id}", # Use process_id to differentiate builders 48 | chunk_size=chunk_size, 49 | sep_token=tokenizer.bos_id, 50 | dtype="auto", 51 | vocab_size=tokenizer.vocab_size, 52 | ) 53 | 54 | for filepath in filenames: 55 | print(f"Processing {filepath}") 56 | try: 57 | contents = pd.read_parquet(filepath, engine='pyarrow')['content'] 58 | except Exception as e: 59 | print(f"Error reading {filepath}: {e}") 60 | continue 61 | for text in contents: 62 | text_ids = tokenizer.encode(text) 63 | builder.add_array(np.array(text_ids, dtype=builder.dtype)) 64 | 65 | # we deliberately skip the final, partially filled chunk, which would otherwise be padded with bos_ids; see https://github.com/jzhang38/TinyLlama/issues/83 for more details 66 | # builder.write_reminder() 67 | 68 | 69 | def prepare( 70 | source_path: Path = Path("data/RedPajama-Data-1T-Sample"), 71 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 72 | destination_path: Path = Path("data/red_pajama_sample"), 73 | chunk_size: int = 2049 * 1024, 74 | split: str="train", 75 | percentage: float = 1.0, 76 | filenames_subset: List[str] = None, 77 | ) -> None: 78 | import time 79 | assert split == "train" # starcoder only has train data 80 | filenames = glob.glob(os.path.join(source_path, "*/*.parquet"), recursive=True) 81 | # only retain files whose path contains one of the prefixes in filenames_subset 82 | if filenames_subset: 83 | filenames = [f for f in filenames if any(prefix in f for prefix in filenames_subset)] 84 | filenames = filenames[:int(len(filenames) * percentage)] 85 | num_processes = 64 86 | chunked_filenames = np.array_split(filenames, num_processes) 87 | 88 | processes = [] 89 | start_time = time.time() 90 | 91 | for i, subset in enumerate(chunked_filenames): 92 | p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i)) 93 |
processes.append(p) 94 | p.start() 95 | 96 | for p in processes: 97 | p.join() 98 | end_time = time.time() 99 | elapsed_time = end_time - start_time 100 | print(f"Time taken: {elapsed_time:.2f} seconds") 101 | 102 | 103 | if __name__ == "__main__": 104 | from jsonargparse import CLI 105 | CLI(prepare) 106 | -------------------------------------------------------------------------------- /sft/script.sh: -------------------------------------------------------------------------------- 1 | # We include a simple full-parameter finetuning & inference script here. Our V0.1 chat model was finetuned using this script. 2 | # The FT dataset we use is openassistant-guanaco. For finetuning with less than 4GB of RAM, we refer you to the QLoRA and bitsandbytes repos. 3 | # We did not perform extensive hyperparameter tuning, nor did we search for more performant FT datasets. 4 | # We hope the community will explore finetuning TinyLlama and come up with better chat models. I will include community-finetuned models in this repo.
5 | 6 | # V0.1 7 | CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch --multi_gpu --num_processes 4 --main_process_port 1234 finetune.py \ 8 | --model_name_or_path PY007/TinyLlama-1.1B-intermediate-step-240k-503b \ 9 | --output_dir ./output/503B_FT_lr1e-5_ep5 \ 10 | --logging_steps 10 \ 11 | --save_strategy epoch \ 12 | --data_seed 42 \ 13 | --save_total_limit 6 \ 14 | --evaluation_strategy epoch \ 15 | --eval_dataset_size 512 \ 16 | --max_eval_samples 1000 \ 17 | --per_device_eval_batch_size 1 \ 18 | --max_new_tokens 32 \ 19 | --dataloader_num_workers 3 \ 20 | --group_by_length=False \ 21 | --logging_strategy steps \ 22 | --remove_unused_columns False \ 23 | --do_train \ 24 | --do_eval \ 25 | --warmup_ratio 0.05 \ 26 | --lr_scheduler_type constant \ 27 | --dataset oasst1 \ 28 | --source_max_len 16 \ 29 | --target_max_len 512 \ 30 | --per_device_train_batch_size 4 \ 31 | --max_steps 0 \ 32 | --num_train_epochs 5 \ 33 | --learning_rate 1e-5 \ 34 | --adam_beta2 0.999 \ 35 | --max_grad_norm 1.0 \ 36 | --weight_decay 0.0 \ 37 | --seed 0 \ 38 | --trust_remote_code \ 39 | --report_to wandb 40 | 41 | 42 | # V0.2 43 | CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --multi_gpu --num_processes 4 --main_process_port 1234 finetune.py \ 44 | --model_name_or_path PY007/TinyLlama-1.1B-intermediate-step-480k-1T \ 45 | --output_dir ./output/503B_FT_lr1e-5_ep5_top1_2023-08-25 \ 46 | --logging_steps 10 \ 47 | --save_strategy epoch \ 48 | --data_seed 42 \ 49 | --save_total_limit 6 \ 50 | --evaluation_strategy epoch \ 51 | --eval_dataset_size 512 \ 52 | --max_eval_samples 1000 \ 53 | --per_device_eval_batch_size 1 \ 54 | --max_new_tokens 32 \ 55 | --dataloader_num_workers 3 \ 56 | --group_by_length=False \ 57 | --logging_strategy steps \ 58 | --remove_unused_columns False \ 59 | --do_train \ 60 | --do_eval \ 61 | --warmup_ratio 0.05 \ 62 | --lr_scheduler_type constant \ 63 | --dataset OpenAssistant/oasst_top1_2023-08-25 \ 64 | --dataset_format oasst1 \ 65 | --source_max_len 16 \ 66 | 
--target_max_len 512 \ 67 | --per_device_train_batch_size 4 \ 68 | --max_steps 0 \ 69 | --num_train_epochs 5 \ 70 | --learning_rate 1e-5 \ 71 | --adam_beta2 0.999 \ 72 | --max_grad_norm 1.0 \ 73 | --weight_decay 0.0 \ 74 | --seed 0 \ 75 | --trust_remote_code \ 76 | --report_to wandb 77 | -------------------------------------------------------------------------------- /sft/simple_inference.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | import transformers 3 | import torch 4 | model = "PY007/TinyLlama-1.1B-Chat-v0.1" 5 | tokenizer = AutoTokenizer.from_pretrained(model) 6 | pipeline = transformers.pipeline( 7 | "text-generation", 8 | model=model, 9 | torch_dtype=torch.float16, 10 | device_map="auto", 11 | ) 12 | 13 | prompt = "Give me detailed info about Jeo Biden." 14 | formatted_prompt = ( 15 | f"### Human: {prompt} ### Assistant:" 16 | ) 17 | 18 | 19 | sequences = pipeline( 20 | formatted_prompt, 21 | do_sample=True, 22 | top_k=50, 23 | top_p = 0.9, 24 | num_return_sequences=1, 25 | repetition_penalty=1.1, 26 | max_new_tokens=1024, 27 | ) 28 | for seq in sequences: 29 | print(f"Result: {seq['generated_text']}") 30 | -------------------------------------------------------------------------------- /sft/simple_inference2.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from transformers import AutoTokenizer 4 | import transformers 5 | import torch 6 | model = "PY007/TinyLlama-1.1B-Chat-v0.2" 7 | tokenizer = AutoTokenizer.from_pretrained(model) 8 | pipeline = transformers.pipeline( 9 | "text-generation", 10 | model=model, 11 | torch_dtype=torch.float16, 12 | device_map="auto", 13 | ) 14 | 15 | prompt = "How to get in a good university?" 
16 | formatted_prompt = ( 17 | f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" 18 | ) 19 | 20 | 21 | sequences = pipeline( 22 | formatted_prompt, 23 | do_sample=True, 24 | top_k=50, 25 | top_p = 0.9, 26 | num_return_sequences=1, 27 | repetition_penalty=1.1, 28 | max_new_tokens=1024, 29 | ) 30 | for seq in sequences: 31 | print(f"Result: {seq['generated_text']}") -------------------------------------------------------------------------------- /speculative_decoding/README.md: -------------------------------------------------------------------------------- 1 | ## Speculative Decoding 2 | 3 | ### HuggingFace "Assisted Generation" 4 | 5 | 6 | | Large Model | Native Decoding | Assisted Decoding | 7 | | ----------- | --------------- | ------------------ | 8 | | guanaco-7b | 69 seconds | 38 seconds | 9 | | guanaco-13b | 84 seconds | 45 seconds | 10 | | guanaco-33b | 109 seconds | 62 seconds | 11 | 12 | We use PY007/TinyLlama-1.1B-Chat-v0.1 as the assistant model and vary the large model from guanaco-7B to guanaco-33B. Experiments are run on a single A40 GPU with the code in instruct_hf_assisted_decoding.py. TinyLlama is loaded in fp16, and the large models are loaded in 8-bit so that guanaco-33b fits in memory and the setup stays consistent across model sizes. The prompt used is "Give me detailed info about Jeo Biden." and max_new_tokens is set to 512. 13 | 14 | You can read this [article](https://huggingface.co/blog/assisted-generation) for more information about HuggingFace's Assisted Generation. 15 | 16 | Quote from HF: "due to INT8 quantization and the use of causal masking in assisted generation, the output of greedy decoding may differ in rare occasions." 17 | #### TODO 18 | - [ ] Thoroughly benchmark the average speedup on 52K Alpaca prompts. 19 | 20 | ### Llama.cpp Speculative Decoding 21 | We have continued pretraining a code TinyLlama from the 500B checkpoint on another 7B tokens of Python data [here](https://huggingface.co/PY007/TinyLlama-1.1B-python-v0.1).
22 | The code for continued pretraining can be found in pretrain/tinyllama_code.py. 23 | 24 | ``` 25 | ./speculative \ 26 | -m models/CodeLlama-7b-hf/ggml-model-f16.gguf \ 27 | -md models/TinyLlama-1.1B-500B-python/ggml-model-q4_0.gguf \ 28 | -p "# Quick-sort implementation in Python and sample usage:" \ 29 | -e -ngl 1 -t 4 -n 256 -s 20 --temp 0 --draft 8 30 | ``` 31 | This gives: 32 | 33 | ``` 34 | encoded 12 tokens in 0.247 seconds, speed: 48.638 t/s 35 | decoded 265 tokens in 7.909 seconds, speed: 33.507 t/s 36 | 37 | n_draft = 16 38 | n_predict = 265 39 | n_drafted = 317 40 | n_accept = 195 41 | accept = 61.514% 42 | 43 | draft: 44 | 45 | llama_print_timings: load time = 53.14 ms 46 | llama_print_timings: sample time = 652.62 ms / 1 runs ( 652.62 ms per token, 1.53 tokens per second) 47 | llama_print_timings: prompt eval time = 73.81 ms / 12 tokens ( 6.15 ms per token, 162.58 tokens per second) 48 | llama_print_timings: eval time = 2247.77 ms / 378 runs ( 5.95 ms per token, 168.17 tokens per second) 49 | llama_print_timings: total time = 8154.92 ms 50 | 51 | target: 52 | 53 | llama_print_timings: load time = 534.47 ms 54 | llama_print_timings: sample time = 208.12 ms / 265 runs ( 0.79 ms per token, 1273.32 tokens per second) 55 | llama_print_timings: prompt eval time = 4210.38 ms / 382 tokens ( 11.02 ms per token, 90.73 tokens per second) 56 | llama_print_timings: eval time = 682.80 ms / 16 runs ( 42.68 ms per token, 23.43 tokens per second) 57 | llama_print_timings: total time = 8214.11 ms 58 | ggml_metal_free: deallocating 59 | ggml_metal_free: deallocating 60 | ``` 61 | 62 | Even though the model's continued pretraining used only Python data, it retains its ability in other languages, such as C: 63 | ``` 64 | ./speculative \ 65 | -m models/CodeLlama-7b-hf/ggml-model-f16.gguf \ 66 | -md models/TinyLlama-1.1B-500B-python/ggml-model-q4_0.gguf \ 67 | -p "// Quick-sort implementation in C (4 spaces indentation + detailed comments) and sample usage:\n\n#include" \ 68
| -e -ngl 1 -t 4 -n 256 -s 20 --temp 0 --draft 8 69 | ``` 70 | 71 | This gives: 72 | 73 | ``` 74 | encoded 25 tokens in 0.278 seconds, speed: 89.900 t/s 75 | decoded 258 tokens in 6.432 seconds, speed: 40.112 t/s 76 | 77 | n_draft = 28 78 | n_predict = 258 79 | n_drafted = 278 80 | n_accept = 200 81 | accept = 71.942% 82 | 83 | draft: 84 | 85 | llama_print_timings: load time = 932.54 ms 86 | llama_print_timings: sample time = 583.50 ms / 1 runs ( 583.50 ms per token, 1.71 tokens per second) 87 | llama_print_timings: prompt eval time = 81.50 ms / 25 tokens ( 3.26 ms per token, 306.73 tokens per second) 88 | llama_print_timings: eval time = 1834.67 ms / 329 runs ( 5.58 ms per token, 179.32 tokens per second) 89 | llama_print_timings: total time = 6710.30 ms 90 | 91 | target: 92 | 93 | llama_print_timings: load time = 18568.44 ms 94 | llama_print_timings: sample time = 208.78 ms / 258 runs ( 0.81 ms per token, 1235.75 tokens per second) 95 | llama_print_timings: prompt eval time = 3164.84 ms / 342 tokens ( 9.25 ms per token, 108.06 tokens per second) 96 | llama_print_timings: eval time = 775.43 ms / 18 runs ( 43.08 ms per token, 23.21 tokens per second) 97 | llama_print_timings: total time = 7650.67 ms 98 | ggml_metal_free: deallocating 99 | ggml_metal_free: deallocating 100 | ``` 101 | 102 | 103 | I have not yet tried the 13B CodeLlama as the large model because my Mac does not have enough memory :).
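Both the HuggingFace assisted-generation runs and the llama.cpp `speculative` runs above follow a draft-then-verify pattern: a small model proposes several tokens, and the large model keeps the longest prefix it agrees with. The sketch below illustrates that pattern under simplifying assumptions: greedy decoding only, and two hypothetical toy "models" (`toy_target`, `toy_draft`) that map a token prefix to a next token. It is not the code used by either library.

```python
from typing import Callable, List, Tuple

# A "model" here is any function mapping a token prefix to its greedy next token.
# These are hypothetical stand-ins, not TinyLlama or guanaco.
NextToken = Callable[[List[int]], int]


def greedy(target: NextToken, prompt: List[int], max_new_tokens: int) -> List[int]:
    """Plain autoregressive greedy decoding with the large model only."""
    seq = list(prompt)
    for _ in range(max_new_tokens):
        seq.append(target(seq))
    return seq[len(prompt):]


def speculative_greedy(
    target: NextToken,
    draft: NextToken,
    prompt: List[int],
    max_new_tokens: int,
    n_draft: int = 8,
) -> Tuple[List[int], int, int]:
    """Draft-then-verify greedy decoding.

    Returns (new tokens, number of drafted tokens, number of accepted drafts).
    """
    seq = list(prompt)
    drafted = accepted = 0
    while len(seq) - len(prompt) < max_new_tokens:
        # 1. The cheap draft model proposes n_draft tokens autoregressively.
        proposal: List[int] = []
        for _ in range(n_draft):
            proposal.append(draft(seq + proposal))
        drafted += len(proposal)

        # 2. The target model checks the proposal position by position. Real
        #    implementations typically score all positions in one forward pass;
        #    this loop calls `target` once per position only for readability.
        all_accepted = True
        for tok in proposal:
            expected = target(seq)
            seq.append(expected)  # always keep the target's own choice
            if expected != tok:
                all_accepted = False  # first mismatch: discard the remaining drafts
                break
            accepted += 1

        # 3. If every draft was accepted, the verification pass also yields
        #    one extra ("bonus") token from the target.
        if all_accepted:
            seq.append(target(seq))

    return seq[len(prompt):][:max_new_tokens], drafted, accepted


def toy_target(seq: List[int]) -> int:
    return (3 * seq[-1] + 1) % 17


def toy_draft(seq: List[int]) -> int:
    # Agrees with the target except when the prefix length is a multiple of 5.
    guess = toy_target(seq)
    return (guess + 1) % 17 if len(seq) % 5 == 0 else guess


if __name__ == "__main__":
    out, n_drafted, n_accept = speculative_greedy(toy_target, toy_draft, [2], 20, n_draft=4)
    assert out == greedy(toy_target, [2], 20)  # identical to target-only greedy output
    print(f"n_drafted = {n_drafted}, n_accept = {n_accept}, accept = {100 * n_accept / n_drafted:.3f}%")
```

With greedy verification the returned tokens match what the target would produce on its own; the draft model only reduces how many target passes are needed. The `accept` figure in the llama.cpp logs above is consistent with n_accept / n_drafted: 195 / 317 ≈ 61.514% and 200 / 278 ≈ 71.942%.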
-------------------------------------------------------------------------------- /speculative_decoding/instruct_hf_assisted_decoding.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | import torch 3 | import time 4 | 5 | 6 | model_id = "huggyllama/llama-13b" 7 | peft_model_id = "timdettmers/guanaco-13b" 8 | assistant_checkpoint = "PY007/TinyLlama-1.1B-Chat-v0.1" 9 | 10 | 11 | device = "cuda" if torch.cuda.is_available() else "cpu" 12 | tokenizer = AutoTokenizer.from_pretrained(model_id) 13 | 14 | 15 | prompt = "Give me detailed info about Jeo Biden." 16 | formatted_prompt = f"### Human: {prompt}### Assistant:" 17 | inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device) 18 | model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True) 19 | model.load_adapter(peft_model_id) 20 | print("Large model loaded") 21 | model.config.use_cache = True 22 | assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint).half().to(device) 23 | assistant_model.config.use_cache = True 24 | print("Small model loaded") 25 | 26 | 27 | print("###Native Decoding Starts...\n") 28 | start = time.time() 29 | outputs = model.generate(**inputs, assistant_model=None, max_new_tokens=512) 30 | end = time.time() 31 | print(tokenizer.batch_decode(outputs, skip_special_tokens=True)) 32 | print("Time: ", end - start) 33 | 34 | print("###TinyLlama Assisted Decoding Starts...\n") 35 | start = time.time() 36 | outputs = model.generate(**inputs, assistant_model=assistant_model,max_new_tokens=512) 37 | end = time.time() 38 | print(tokenizer.batch_decode(outputs, skip_special_tokens=True)) 39 | # print time in seconds 40 | print("Time: ", end - start) 41 | 42 | --------------------------------------------------------------------------------