├── .github ├── Pythia_saturation.png ├── TinyLlama_logo.png └── llama2-training.png ├── .gitignore ├── EVAL.md ├── LICENSE ├── PRETRAIN.md ├── README.md ├── README_zh-CN.md ├── chat_gradio ├── README.md ├── app.py └── requirements.txt ├── lit_gpt ├── __init__.py ├── adapter.py ├── adapter_v2.py ├── config.py ├── fused_cross_entropy.py ├── fused_rotary_embedding.py ├── lora.py ├── model.py ├── packed_dataset.py ├── rmsnorm.py ├── speed_monitor.py ├── tokenizer.py └── utils.py ├── pretrain ├── tinyllama.py └── tinyllama_code.py ├── requirements.txt ├── script.sh ├── scripts ├── convert_hf_checkpoint.py ├── convert_lit_checkpoint.py ├── prepare_redpajama.py ├── prepare_slimpajama.py └── prepare_starcoder.py ├── sft ├── finetune.py ├── script.sh ├── simple_inference.py └── simple_inference2.py └── speculative_decoding ├── README.md └── instruct_hf_assisted_decoding.py /.github/Pythia_saturation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jzhang38/TinyLlama/bf122247c486b6b897050e98cbb7bedae8eeba73/.github/Pythia_saturation.png -------------------------------------------------------------------------------- /.github/TinyLlama_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jzhang38/TinyLlama/bf122247c486b6b897050e98cbb7bedae8eeba73/.github/TinyLlama_logo.png -------------------------------------------------------------------------------- /.github/llama2-training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jzhang38/TinyLlama/bf122247c486b6b897050e98cbb7bedae8eeba73/.github/llama2-training.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .idea 3 | .DS_Store 4 | *.egg-info 5 | build 6 | .venv 7 | .vscode 8 | 9 | # data 10 | data 11 | checkpoints 12 | out 13 | wandb 14 | 15 | tests/original_falcon_40b.py 16 | sft/output 17 | sft/wandb -------------------------------------------------------------------------------- /EVAL.md: -------------------------------------------------------------------------------- 1 | ## Evaluate TinyLlama 2 | 3 | ### GPT4All Benchmarks 4 | 5 | We evaluate TinyLlama's commonsense reasoning ability following the [GPT4All](https://gpt4all.io/index.html) evaluation suite. We include Pythia as our baseline. We report the acc_norm by default. 
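The `avg` column in the tables below is the arithmetic mean of the seven per-task acc_norm scores. A minimal sketch of that calculation (the numbers are the Pythia-1.0B row; the task keys are illustrative):

```python
# Average acc_norm across the seven GPT4All tasks, as reported in the "avg" column.
scores = {
    "hellaswag": 47.16, "openbookqa": 31.40, "winogrande": 53.43,
    "arc_challenge": 27.05, "arc_easy": 48.99, "boolq": 60.83, "piqa": 69.21,
}
avg = sum(scores.values()) / len(scores)
print(f"{avg:.2f}")  # 48.30
```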
6 | 7 | Base models: 8 | 9 | | Model | Pretrain Tokens | HellaSwag | Obqa | WinoGrande | ARC_c | ARC_e | boolq | piqa | avg | 10 | |-------------------------------------------|-----------------|-----------|------|------------|-------|-------|-------|------|-----| 11 | | Pythia-1.0B | 300B | 47.16 | 31.40| 53.43 | 27.05 | 48.99 | 60.83 | 69.21 | 48.30 | 12 | | TinyLlama-1.1B-intermediate-step-50K-104b | 103B | 43.50 | 29.80| 53.28 | 24.32 | 44.91 | 59.66 | 67.30 | 46.11| 13 | | TinyLlama-1.1B-intermediate-step-240k-503b| 503B | 49.56 |31.40 |55.80 |26.54 |48.32 |56.91 |69.42 | 48.28 | 14 | | TinyLlama-1.1B-intermediate-step-480k-1007B | 1007B | 52.54 | 33.40 | 55.96 | 27.82 | 52.36 | 59.54 | 69.91 | 50.22 | 15 | | TinyLlama-1.1B-intermediate-step-715k-1.5T | 1.5T | 53.68 | 35.20 | 58.33 | 29.18 | 51.89 | 59.08 | 71.65 | 51.29 | 16 | | TinyLlama-1.1B-intermediate-step-955k-2T | 2T | 54.63 | 33.40 | 56.83 | 28.07 | 54.67 | 63.21 | 70.67 | 51.64 | 17 | | TinyLlama-1.1B-intermediate-step-1195k-2.5T | 2.5T | 58.96 | 34.40 | 58.72 | 31.91 | 56.78 | 63.21 | 73.07 | 53.86| 18 | | TinyLlama-1.1B-intermediate-step-1431k-3T | 3T | 59.20 | 36.00 | 59.12 | 30.12 | 55.25 | 57.83 | 73.29 | 52.99| 19 | 20 | 21 | Chat models: 22 | | Model | Pretrain Tokens | HellaSwag | Obqa | WinoGrande | ARC_c | ARC_e | boolq | piqa | avg | 23 | |-------------------------------------------|-----------------|-----------|------|------------|-------|-------|-------|------|-----| 24 | | [TinyLlama-1.1B-Chat-v0.1](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.1) | 503B | 53.81 |32.20 | 55.01 | 28.67 |49.62 | 58.04 | 69.64 | 49.57 | 25 | | [TinyLlama-1.1B-Chat-v0.2](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.2) | 503B | 53.63 |32.80 | 54.85 | 28.75 |49.16 | 55.72 | 69.48 | 49.20 | 26 | | [TinyLlama-1.1B-Chat-v0.3](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.3) | 1T | 56.81 |34.20 | 55.80 | 30.03 |53.20 | 59.57 | 69.91 | 51.36 | 27 | | [TinyLlama-1.1B-Chat-v0.4](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4) | 1.5T | 58.59 |35.40 | 58.80 | 30.80 |54.04 | 57.31 | 71.16 | 52.30 | 28 | 29 | 30 | We observed huge improvements once we finetuned the model. We attribute this phenomenon to: 1. the base model has not undergone lr cool-down and FT helps to cool down the lr. 2. the SFT stage better elicits the model's internal knowledge. 31 | 32 | You can obtain the above scores by running [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness): 33 | ```bash 34 | python main.py \ 35 | --model hf-causal \ 36 | --model_args pretrained=PY007/TinyLlama-1.1B-Chat-v0.1,dtype="float" \ 37 | --tasks hellaswag,openbookqa,winogrande,arc_easy,arc_challenge,boolq,piqa\ 38 | --device cuda:0 --batch_size 32 39 | ``` 40 | 41 | 42 | 43 | ### Instruct-Eval Benchmarks 44 | We evaluate TinyLlama's ability in problem-solving on the [Instruct-Eval](https://github.com/declare-lab/instruct-eval) evaluation suite. 
45 | 46 | 47 | | Model | MMLU | BBH | HumanEval | DROP | 48 | | ------------------------------------------------- | ----- | ----- | --------- | ----- | 49 | | Pythia-1.0B | 25.70 | 28.19 | 1.83 | 4.25 | 50 | | TinyLlama-1.1B-intermediate-step-50K-104b | 26.45 | 28.82 | 5.49 | 11.42 | 51 | | TinyLlama-1.1B-intermediate-step-240k-503b | 26.16 | 28.83 | 4.88 | 12.43 | 52 | | TinyLlama-1.1B-intermediate-step-480K-1T | 24.65 | 29.21 | 6.1 | 13.03 | 53 | | TinyLlama-1.1B-intermediate-step-715k-1.5T | 24.85 | 28.2 | 7.93 | 14.43 | 54 | | TinyLlama-1.1B-intermediate-step-955k-2T | 25.97 | 29.07 | 6.71 | 13.14 | 55 | | TinyLlama-1.1B-intermediate-step-1195k-token-2.5T | 25.92 | 29.32 | 9.15 | 15.45 | 56 | 57 | You can obtain above scores by running [instruct-eval](https://github.com/declare-lab/instruct-eval): 58 | ```bash 59 | CUDA_VISIBLE_DEVICES=0 python main.py mmlu --model_name llama --model_path PY007/TinyLlama-1.1B-intermediate-step-480K-1T 60 | CUDA_VISIBLE_DEVICES=1 python main.py bbh --model_name llama --model_path PY007/TinyLlama-1.1B-intermediate-step-480K-1T 61 | CUDA_VISIBLE_DEVICES=2 python main.py drop --model_name llama --model_path PY007/TinyLlama-1.1B-intermediate-step-480K-1T 62 | CUDA_VISIBLE_DEVICES=3 python main.py humaneval --model_name llama --n_sample 1 --model_path PY007/TinyLlama-1.1B-intermediate-step-480K-1T 63 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2023] Lightning AI 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /PRETRAIN.md: -------------------------------------------------------------------------------- 1 | ## Pretrain TinyLlama 2 | 3 | ### Installation 4 | We expect you have CUDA 11.8 installed. 5 | #### Install Pytorch Nightly. 6 | ```bash 7 | pip install --index-url https://download.pytorch.org/whl/nightly/cu118 --pre 'torch>=2.1.0dev' 8 | ``` 9 | #### Build XFormers from Source 10 | Note: as of 2023/09/02, xformers does not provide pre-built binaries for torch 2.1. You have to build it from source. 
11 | ```bash 12 | pip uninstall ninja -y && pip install ninja -U 13 | pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers 14 | ``` 15 | 16 | 17 | #### Install Flash-Attention 2 and other fused operators: 18 | ```bash 19 | git clone https://github.com/Dao-AILab/flash-attention 20 | cd flash-attention 21 | python setup.py install 22 | cd csrc/rotary && pip install . 23 | cd ../layer_norm && pip install . 24 | cd ../xentropy && pip install . 25 | cd ../.. && rm -rf flash-attention 26 | ``` 27 | #### Install Remaining Dependencies 28 | ```bash 29 | pip install -r requirements.txt tokenizers sentencepiece 30 | ``` 31 | This installs the remaining Python dependencies. 32 | Building xformers and flash-attention may take 5 minutes or more. Do not worry if the process seems stagnant or the terminal prints many warnings. 33 | 34 | Then you are ready to go 🎉! 35 | 36 | ### Data Preparation 37 | 38 | #### Download Datasets 39 | Download the SlimPajama and Starcoderdata datasets to a directory of your choice. 40 | ```bash 41 | cd /path/to/dataset 42 | git lfs install 43 | git clone https://huggingface.co/datasets/cerebras/SlimPajama-627B 44 | git clone https://huggingface.co/datasets/bigcode/starcoderdata 45 | ``` 46 | The SlimPajama dataset takes 893GB of disk space and Starcoderdata takes 290GB. 47 | 48 | #### Tokenize data 49 | Use the provided scripts to tokenize the datasets and divide them into chunks. 50 | ```bash 51 | python scripts/prepare_starcoder.py --source_path /path/to/starcoderdata/ --tokenizer_path data/llama --destination_path data/slim_star_combined --split train --percentage 1.0 52 | python scripts/prepare_slimpajama.py --source_path /path/to/SlimPajama --tokenizer_path data/llama --destination_path data/slim_star_combined --split validation --percentage 1.0 53 | python scripts/prepare_slimpajama.py --source_path /path/to/SlimPajama --tokenizer_path data/llama --destination_path data/slim_star_combined --split train --percentage 1.0 54 | ``` 55 | The processed data will take about 1.8TB of storage. 56 | 57 | ### Pretraining 58 | If your setup comprises two nodes, each with 8 GPUs, you can start pretraining with the following commands: 59 | 60 | On node 1: 61 | ```bash 62 | lightning run model \ 63 | --node-rank=0 \ 64 | --main-address=172.16.101.5 \ 65 | --accelerator=cuda \ 66 | --devices=8 \ 67 | --num-nodes=2 \ 68 | pretrain/tinyllama.py --devices 8 --train_data_dir data/slim_star --val_data_dir data/slim_star 69 | ``` 70 | On node 2: 71 | ```bash 72 | lightning run model \ 73 | --node-rank=1 \ 74 | --main-address=172.16.101.5 \ 75 | --accelerator=cuda \ 76 | --devices=8 \ 77 | --num-nodes=2 \ 78 | pretrain/tinyllama.py --devices 8 --train_data_dir data/slim_star --val_data_dir data/slim_star 79 | ``` 80 | You can follow [these instructions](https://lightning.ai/docs/fabric/stable/guide/multi_node/slurm.html) if you have a Slurm cluster. 81 | 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 
self._chunk_size: 108 | part_len = self._chunk_size - self._idx 109 | self._arr[self._idx : self._idx + part_len] = arr[:part_len] 110 | self._write_chunk() 111 | arr = arr[part_len:] 112 | 113 | arr_len = arr.shape[0] 114 | self._arr[self._idx : self._idx + arr_len] = arr 115 | self._idx += arr_len 116 | 117 | def write_reminder(self): 118 | self._write_chunk() 119 | 120 | 121 | class PackedDatasetIterator: 122 | def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap): 123 | self._seed = seed 124 | self._shuffle = shuffle 125 | self._rng = np.random.default_rng(seed) if shuffle else None 126 | self._block_idxs = None 127 | 128 | self._wrap = wrap 129 | 130 | # TODO: instead of filenames, we could have a single text stream 131 | # (or text file) with the sequence of all files to be 132 | # fetched/loaded. 133 | self._filenames = filenames 134 | self._file_idx = 0 135 | 136 | self._n_chunks = n_chunks 137 | 138 | self._dtype = None 139 | self._block_size = block_size 140 | self._n_blocks = None 141 | 142 | self._mmaps = [] 143 | self._buffers = [] 144 | 145 | self._block_idxs = [] 146 | self._curr_idx = 0 147 | 148 | self._load_n_chunks() 149 | 150 | def _read_header(self, path): 151 | with open(path, "rb") as f: 152 | magic = f.read(len(HDR_MAGIC)) 153 | assert magic == HDR_MAGIC, "File doesn't match expected format." 154 | version = struct.unpack("len(self._filenames[self._file_idx :]): 171 | # if not self._wrap: 172 | # raise StopIteration 173 | self._file_idx = 0 174 | 175 | for i in range(self._n_chunks): 176 | filename = self._filenames[self._file_idx + i] 177 | if self._dtype is None: 178 | self._dtype, self._chunk_size = self._read_header(filename) 179 | self._n_blocks = self._chunk_size // self._block_size 180 | # TODO: check header matches with previous files 181 | mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE) 182 | self._mmaps.append(mmap) 183 | self._buffers.append(memoryview(mmap)) 184 | 185 | self._file_idx += self._n_chunks 186 | n_all_blocks = self._n_chunks * self._n_blocks 187 | 188 | self._block_idxs = self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks) 189 | 190 | self._curr_idx = 0 191 | 192 | def __del__(self): 193 | self._close_mmaps() 194 | del self._mmaps 195 | del self._buffers 196 | 197 | def __iter__(self): 198 | return self 199 | 200 | def __next__(self): 201 | if self._curr_idx >= len(self._block_idxs): 202 | self._load_n_chunks() 203 | # TODO: trigger fetching next next n_chunks if remote 204 | block_idx = self._block_idxs[self._curr_idx] 205 | chunk_id = block_idx // self._n_blocks 206 | buffer = self._buffers[chunk_id] 207 | elem_id = (block_idx % self._n_blocks) * self._block_size 208 | offset = np.dtype(self._dtype).itemsize * elem_id 209 | arr = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset) 210 | self._curr_idx += 1 211 | return torch.from_numpy(arr.astype(np.int64)) 212 | 213 | 214 | class CombinedDataset(IterableDataset): 215 | def __init__(self, datasets, seed, weights=None): 216 | self._seed = seed 217 | self._datasets = datasets 218 | self._weights = weights 219 | n_datasets = len(datasets) 220 | if weights is None: 221 | self._weights = [1 / n_datasets] * n_datasets 222 | 223 | def __iter__(self): 224 | return CombinedDatasetIterator(self._datasets, self._seed, self._weights) 225 | 226 | 227 | class CombinedDatasetIterator: 228 | def __init__(self, datasets, seed, weights): 229 | self._datasets = [iter(el) for el in datasets] 230 | self._weights = weights 
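        # The seeded RNG below is what __next__ uses to pick a source dataset via weighted random.choices.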
231 | self._rng = random.Random(seed) 232 | 233 | def __next__(self): 234 | (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1) 235 | return next(dataset) 236 | -------------------------------------------------------------------------------- /lit_gpt/tokenizer.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Optional 4 | 5 | import torch 6 | 7 | 8 | class Tokenizer: 9 | def __init__(self, checkpoint_dir: Path) -> None: 10 | # some checkpoints have both files, `.model` takes precedence 11 | if (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file(): 12 | from sentencepiece import SentencePieceProcessor 13 | 14 | self.processor = SentencePieceProcessor(model_file=str(vocabulary_path)) 15 | self.backend = "sentencepiece" 16 | self.bos_id = self.processor.bos_id() 17 | self.eos_id = self.processor.eos_id() 18 | elif (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file(): 19 | from tokenizers import Tokenizer as HFTokenizer 20 | 21 | self.processor = HFTokenizer.from_file(str(vocabulary_path)) 22 | self.backend = "huggingface" 23 | with open(checkpoint_dir / "tokenizer_config.json") as fp: 24 | config = json.load(fp) 25 | bos_token = config.get("bos_token") 26 | self.bos_id = self.token_to_id(bos_token) if bos_token is not None else None 27 | self.eos_id = self.token_to_id(config["eos_token"]) 28 | else: 29 | raise NotImplementedError 30 | 31 | @property 32 | def vocab_size(self) -> int: 33 | if self.backend == "huggingface": 34 | return self.processor.get_vocab_size(with_added_tokens=False) 35 | if self.backend == "sentencepiece": 36 | return self.processor.vocab_size() 37 | raise RuntimeError 38 | 39 | def token_to_id(self, token: str) -> int: 40 | if self.backend == "huggingface": 41 | id_ = self.processor.token_to_id(token) 42 | elif self.backend == "sentencepiece": 43 | id_ = self.processor.piece_to_id(token) 44 | else: 45 | raise RuntimeError 46 | if id_ is None: 47 | raise ValueError(f"token {token!r} not found in the collection.") 48 | return id_ 49 | 50 | def encode( 51 | self, 52 | string: str, 53 | device: Optional[torch.device] = None, 54 | bos: bool = False, 55 | eos: bool = True, 56 | max_length: int = -1, 57 | ) -> torch.Tensor: 58 | if self.backend == "huggingface": 59 | tokens = self.processor.encode(string).ids 60 | elif self.backend == "sentencepiece": 61 | tokens = self.processor.encode(string) 62 | else: 63 | raise RuntimeError 64 | if bos: 65 | bos_id = self.bos_id 66 | if bos_id is None: 67 | raise NotImplementedError("This tokenizer does not defined a bos token") 68 | tokens = [bos_id] + tokens 69 | if eos: 70 | tokens = tokens + [self.eos_id] 71 | if max_length > 0: 72 | tokens = tokens[:max_length] 73 | return torch.tensor(tokens, dtype=torch.int, device=device) 74 | 75 | def decode(self, tensor: torch.Tensor) -> str: 76 | tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist() 77 | return self.processor.decode(tokens) 78 | -------------------------------------------------------------------------------- /lit_gpt/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for training and inference.""" 2 | 3 | import pickle 4 | import sys 5 | import warnings 6 | from contextlib import contextmanager 7 | from functools import partial 8 | from io import BytesIO 9 | from pathlib import Path 10 | from types import MethodType 11 | from typing import Any, Dict, List, 
Mapping, Optional, Type, TypeVar, Union 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.utils._device 16 | from lightning.fabric.loggers import CSVLogger 17 | from torch.serialization import normalize_storage_type 18 | 19 | 20 | def find_multiple(n: int, k: int) -> int: 21 | assert k > 0 22 | if n % k == 0: 23 | return n 24 | return n + k - (n % k) 25 | 26 | 27 | def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int: 28 | return sum(p.numel() for p in module.parameters() if requires_grad is None or p.requires_grad == requires_grad) 29 | 30 | 31 | @contextmanager 32 | def quantization(mode: Optional[str] = None): 33 | if mode is None: 34 | yield 35 | return 36 | 37 | if mode == "bnb.int8": 38 | from quantize.bnb import InferenceLinear8bitLt 39 | 40 | quantized_linear_cls = InferenceLinear8bitLt 41 | elif mode == "bnb.fp4": 42 | from quantize.bnb import Linear4bit 43 | 44 | # Use a class instead `functools.partial` to respect `isinstance` checks and attribute accesses 45 | class QuantizedLinear(Linear4bit): 46 | def __init__(self, *args, **kwargs): 47 | super().__init__(*args, quant_type="fp4", compress_statistics=False, **kwargs) 48 | 49 | quantized_linear_cls = QuantizedLinear 50 | elif mode == "bnb.fp4-dq": 51 | from quantize.bnb import Linear4bit 52 | 53 | class QuantizedLinear(Linear4bit): 54 | def __init__(self, *args, **kwargs): 55 | super().__init__(*args, quant_type="fp4", compress_statistics=True, **kwargs) 56 | 57 | quantized_linear_cls = QuantizedLinear 58 | elif mode == "bnb.nf4": 59 | from quantize.bnb import Linear4bit 60 | 61 | class QuantizedLinear(Linear4bit): 62 | def __init__(self, *args, **kwargs): 63 | super().__init__(*args, quant_type="nf4", compress_statistics=False, **kwargs) 64 | 65 | quantized_linear_cls = QuantizedLinear 66 | elif mode == "bnb.nf4-dq": 67 | from quantize.bnb import Linear4bit 68 | 69 | class QuantizedLinear(Linear4bit): 70 | def __init__(self, *args, **kwargs): 71 | super().__init__(*args, quant_type="nf4", compress_statistics=True, **kwargs) 72 | 73 | quantized_linear_cls = QuantizedLinear 74 | elif mode == "gptq.int4": 75 | from quantize.gptq import ColBlockQuantizedLinear 76 | 77 | class QuantizedLinear(ColBlockQuantizedLinear): 78 | def __init__(self, *args, **kwargs): 79 | super().__init__(*args, bits=4, tile_cols=-1, **kwargs) 80 | 81 | quantized_linear_cls = QuantizedLinear 82 | else: 83 | raise ValueError(f"Unknown quantization mode: {mode}") 84 | 85 | torch_linear_cls = torch.nn.Linear 86 | torch.nn.Linear = quantized_linear_cls 87 | yield 88 | torch.nn.Linear = torch_linear_cls 89 | 90 | 91 | # this is taken from torchhacks https://github.com/lernapparat/torchhacks 92 | 93 | 94 | class NotYetLoadedTensor: 95 | def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args): 96 | self.metatensor = metatensor 97 | self.archiveinfo = archiveinfo 98 | self.storageinfo = storageinfo 99 | self.rebuild_args = rebuild_args 100 | 101 | @classmethod 102 | def rebuild_from_type_v2(cls, func, new_type, args, state, *, archiveinfo=None): 103 | ret = func(*args) 104 | if isinstance(ret, NotYetLoadedTensor): 105 | old_lt = ret._load_tensor 106 | 107 | def _load_tensor(): 108 | t = old_lt() 109 | return torch._tensor._rebuild_from_type_v2(lambda: t, new_type, (), state) 110 | 111 | ret._load_tensor = _load_tensor 112 | return ret 113 | return torch._tensor._rebuild_from_type_v2(func, new_type, args, state) 114 | 115 | @classmethod 116 | def rebuild_parameter(cls, data, requires_grad, 
backward_hooks, *, archiveinfo=None): 117 | if isinstance(data, NotYetLoadedTensor): 118 | old_lt = data._load_tensor 119 | 120 | def _load_tensor(): 121 | t = old_lt() 122 | return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks) 123 | 124 | data._load_tensor = _load_tensor 125 | return data 126 | return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks) 127 | 128 | @classmethod 129 | def rebuild_tensor_v2( 130 | cls, storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None, *, archiveinfo=None 131 | ): 132 | rebuild_args = (storage_offset, size, stride, requires_grad, backward_hooks, metadata) 133 | metatensor = torch._utils._rebuild_tensor_v2( 134 | storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata 135 | ) 136 | storageinfo = storage.archiveinfo 137 | return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args) 138 | 139 | def _load_tensor(self): 140 | name, storage_cls, fn, device, size = self.storageinfo 141 | dtype = self.metatensor.dtype 142 | 143 | uts = ( 144 | self.archiveinfo.zipfile_context.zf.get_storage_from_record( 145 | f"data/{fn}", size * torch._utils._element_size(dtype), torch.UntypedStorage 146 | ) 147 | ._typed_storage() 148 | ._untyped_storage 149 | ) 150 | with warnings.catch_warnings(): 151 | warnings.simplefilter("ignore") 152 | storage = torch.storage.TypedStorage(wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True) 153 | return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args) 154 | 155 | @classmethod 156 | def __torch_function__(cls, func, types, args=(), kwargs=None): 157 | if kwargs is None: 158 | kwargs = {} 159 | loaded_args = [(a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args] 160 | return func(*loaded_args, **kwargs) 161 | # gc.collect would be costly here, maybe do it optionally 162 | 163 | def __getattr__(self, name): 164 | # properties 165 | ## TODO: device, is_...?? 166 | ## TODO: mH, mT, H, T, data, imag, real 167 | ## name ??? 
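        # Lightweight metadata lookups below are answered by the meta tensor; only 'contiguous' materializes the real storage via _load_tensor().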
168 | if name in { 169 | "dtype", 170 | "grad", 171 | "grad_fn", 172 | "layout", 173 | "names", 174 | "ndim", 175 | "output_nr", 176 | "requires_grad", 177 | "retains_grad", 178 | "shape", 179 | "volatile", 180 | }: 181 | return getattr(self.metatensor, name) 182 | if name in {"size"}: 183 | return getattr(self.metatensor, name) 184 | # materializing with contiguous is needed for quantization 185 | if name in {"contiguous"}: 186 | return getattr(self._load_tensor(), name) 187 | 188 | raise AttributeError(f"{type(self)} does not have {name}") 189 | 190 | def __repr__(self): 191 | return f"NotYetLoadedTensor({repr(self.metatensor)})" 192 | 193 | 194 | class LazyLoadingUnpickler(pickle.Unpickler): 195 | def __init__(self, file, zipfile_context): 196 | super().__init__(file) 197 | self.zipfile_context = zipfile_context 198 | 199 | def find_class(self, module, name): 200 | res = super().find_class(module, name) 201 | if module == "torch._utils" and name == "_rebuild_tensor_v2": 202 | return partial(NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self) 203 | if module == "torch._tensor" and name == "_rebuild_from_type_v2": 204 | return partial(NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self) 205 | if module == "torch._utils" and name == "_rebuild_parameter": 206 | return partial(NotYetLoadedTensor.rebuild_parameter, archiveinfo=self) 207 | return res 208 | 209 | def persistent_load(self, pid): 210 | name, cls, fn, device, size = pid 211 | with warnings.catch_warnings(): 212 | warnings.simplefilter("ignore") 213 | s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta") 214 | s.archiveinfo = pid 215 | return s 216 | 217 | 218 | class lazy_load: 219 | def __init__(self, fn): 220 | self.zf = torch._C.PyTorchFileReader(str(fn)) 221 | with BytesIO(self.zf.get_record("data.pkl")) as pkl: 222 | mup = LazyLoadingUnpickler(pkl, self) 223 | self.sd = mup.load() 224 | 225 | def __enter__(self): 226 | return self.sd 227 | 228 | def __exit__(self, exc_type, exc_val, exc_tb): 229 | del self.zf # I don't think there is a way to force closing... 230 | self.zf = None 231 | 232 | 233 | def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None: 234 | files = { 235 | "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(), 236 | "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(), 237 | "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or ( 238 | checkpoint_dir / "tokenizer.model" 239 | ).is_file(), 240 | "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(), 241 | } 242 | if checkpoint_dir.is_dir(): 243 | if all(files.values()): 244 | # we're good 245 | return 246 | problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}" 247 | else: 248 | problem = " is not a checkpoint directory" 249 | 250 | # list locally available checkpoints 251 | available = list(Path("checkpoints").glob("*/*")) 252 | if available: 253 | options = "\n --checkpoint_dir ".join([""] + [repr(str(p.resolve())) for p in available]) 254 | extra = f"\nYou have downloaded locally:{options}\n" 255 | else: 256 | extra = "" 257 | 258 | error_message = ( 259 | f"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}." 
260 | "\nFind download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials\n" 261 | f"{extra}\nSee all download options by running:\n python scripts/download.py" 262 | ) 263 | print(error_message, file=sys.stderr) 264 | raise SystemExit(1) 265 | 266 | 267 | class SavingProxyForStorage: 268 | def __init__(self, obj, saver, protocol_version=5): 269 | self.protocol_version = protocol_version 270 | self.saver = saver 271 | if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)): 272 | raise TypeError(f"expected storage, not {type(obj)}") 273 | 274 | # this logic is taken from PyTorch 2.0+ torch/serialization.py 275 | if isinstance(obj, torch.storage.TypedStorage): 276 | # PT upstream wants to deprecate this eventually... 277 | storage = obj._untyped_storage 278 | storage_type_str = obj._pickle_storage_type() 279 | storage_type = getattr(torch, storage_type_str) 280 | storage_numel = obj._size() 281 | else: 282 | storage = obj 283 | storage_type = normalize_storage_type(type(obj)) 284 | storage_numel = storage.nbytes() 285 | 286 | storage_key = saver._write_storage_and_return_key(storage) 287 | location = torch.serialization.location_tag(storage) 288 | 289 | self.storage_info = ("storage", storage_type, storage_key, location, storage_numel) 290 | 291 | def __reduce_ex__(self, protocol_version): 292 | assert False, "this should be handled with out of band" 293 | 294 | 295 | class SavingProxyForTensor: 296 | def __init__(self, tensor, saver, protocol_version=5): 297 | self.protocol_version = protocol_version 298 | self.reduce_ret_fn, (storage, *other_reduce_args) = tensor.__reduce_ex__(protocol_version) 299 | assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates" 300 | storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version) 301 | self.reduce_args = (storage_proxy, *other_reduce_args) 302 | 303 | def __reduce_ex__(self, protocol_version): 304 | if protocol_version != self.protocol_version: 305 | raise RuntimeError(f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}") 306 | return self.reduce_ret_fn, self.reduce_args 307 | 308 | 309 | class IncrementalPyTorchPickler(pickle.Pickler): 310 | def __init__(self, saver, *args, **kwargs): 311 | super().__init__(*args, **kwargs) 312 | self.storage_dtypes = {} 313 | self.saver = saver 314 | self.id_map = {} 315 | 316 | # this logic is taken from PyTorch 2.0+ torch/serialization.py 317 | def persistent_id(self, obj): 318 | # FIXME: the docs say that persistent_id should only return a string 319 | # but torch store returns tuples. 
This works only in the binary protocol 320 | # see 321 | # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects 322 | # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 323 | if isinstance(obj, SavingProxyForStorage): 324 | return obj.storage_info 325 | 326 | if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): 327 | if isinstance(obj, torch.storage.TypedStorage): 328 | # TODO: Once we decide to break serialization FC, this case 329 | # can be deleted 330 | storage = obj._untyped_storage 331 | storage_dtype = obj.dtype 332 | storage_type_str = obj._pickle_storage_type() 333 | storage_type = getattr(torch, storage_type_str) 334 | storage_numel = obj._size() 335 | 336 | else: 337 | storage = obj 338 | storage_dtype = torch.uint8 339 | storage_type = normalize_storage_type(type(obj)) 340 | storage_numel = storage.nbytes() 341 | 342 | # If storage is allocated, ensure that any other saved storages 343 | # pointing to the same data all have the same dtype. If storage is 344 | # not allocated, don't perform this check 345 | if storage.data_ptr() != 0: 346 | if storage.data_ptr() in self.storage_dtypes: 347 | if storage_dtype != self.storage_dtypes[storage.data_ptr()]: 348 | raise RuntimeError( 349 | "Cannot save multiple tensors or storages that view the same data as different types" 350 | ) 351 | else: 352 | self.storage_dtypes[storage.data_ptr()] = storage_dtype 353 | 354 | storage_key = self.id_map.get(storage._cdata) 355 | if storage_key is None: 356 | storage_key = self.saver._write_storage_and_return_key(storage) 357 | self.id_map[storage._cdata] = storage_key 358 | location = torch.serialization.location_tag(storage) 359 | 360 | return ("storage", storage_type, storage_key, location, storage_numel) 361 | 362 | return None 363 | 364 | 365 | class incremental_save: 366 | def __init__(self, name): 367 | self.name = name 368 | self.zipfile = torch._C.PyTorchFileWriter(str(name)) 369 | self.has_saved = False 370 | self.next_key = 0 371 | 372 | def __enter__(self): 373 | return self 374 | 375 | def store_early(self, tensor): 376 | if isinstance(tensor, torch.Tensor): 377 | return SavingProxyForTensor(tensor, self) 378 | raise TypeError(f"can only store tensors early, not {type(tensor)}") 379 | 380 | def save(self, obj): 381 | if self.has_saved: 382 | raise RuntimeError("have already saved") 383 | # Write the pickle data for `obj` 384 | data_buf = BytesIO() 385 | pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5) 386 | pickler.dump(obj) 387 | data_value = data_buf.getvalue() 388 | self.zipfile.write_record("data.pkl", data_value, len(data_value)) 389 | self.has_saved = True 390 | 391 | def _write_storage_and_return_key(self, storage): 392 | if self.has_saved: 393 | raise RuntimeError("have already saved") 394 | key = self.next_key 395 | self.next_key += 1 396 | name = f"data/{key}" 397 | if storage.device.type != "cpu": 398 | storage = storage.cpu() 399 | num_bytes = storage.nbytes() 400 | self.zipfile.write_record(name, storage.data_ptr(), num_bytes) 401 | return key 402 | 403 | def __exit__(self, type, value, traceback): 404 | self.zipfile.write_end_of_file() 405 | 406 | 407 | T = TypeVar("T") 408 | 409 | 410 | def step_csv_logger(*args: Any, cls: Type[T] = CSVLogger, **kwargs: Any) -> T: 411 | logger = cls(*args, **kwargs) 412 | 413 | def merge_by(dicts, key): 414 | from collections import defaultdict 415 | 416 | out = defaultdict(dict) 417 | for d in dicts: 418 | if key in d: 419 | out[d[key]].update(d) 
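        # Return one merged dict per step value, ordered by step.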
420 | return [v for _, v in sorted(out.items())] 421 | 422 | def save(self) -> None: 423 | """Overridden to merge CSV by the step number.""" 424 | import csv 425 | 426 | if not self.metrics: 427 | return 428 | metrics = merge_by(self.metrics, "step") 429 | keys = sorted({k for m in metrics for k in m}) 430 | with self._fs.open(self.metrics_file_path, "w", newline="") as f: 431 | writer = csv.DictWriter(f, fieldnames=keys) 432 | writer.writeheader() 433 | writer.writerows(metrics) 434 | 435 | logger.experiment.save = MethodType(save, logger.experiment) 436 | 437 | return logger 438 | 439 | 440 | def chunked_cross_entropy( 441 | logits: Union[torch.Tensor, List[torch.Tensor]], targets: torch.Tensor, chunk_size: int = 128 442 | ) -> torch.Tensor: 443 | # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate 444 | # the memory usage in fine-tuning settings with low number of parameters. 445 | # as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing 446 | # the memory spike's magnitude 447 | 448 | # lm_head was chunked (we are fine-tuning) 449 | if isinstance(logits, list): 450 | # don't want to chunk cross entropy 451 | if chunk_size == 0: 452 | logits = torch.cat(logits, dim=1) 453 | logits = logits.reshape(-1, logits.size(-1)) 454 | targets = targets.reshape(-1) 455 | return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1) 456 | 457 | # chunk cross entropy 458 | logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits] 459 | target_chunks = [target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)] 460 | loss_chunks = [ 461 | torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none") 462 | for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) 463 | ] 464 | return torch.cat(loss_chunks).mean() 465 | 466 | # no chunking at all 467 | logits = logits.reshape(-1, logits.size(-1)) 468 | targets = targets.reshape(-1) 469 | if chunk_size == 0: 470 | return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1) 471 | 472 | # lm_head wasn't chunked, chunk cross entropy 473 | logit_chunks = logits.split(chunk_size) 474 | target_chunks = targets.split(chunk_size) 475 | loss_chunks = [ 476 | torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none") 477 | for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) 478 | ] 479 | return torch.cat(loss_chunks).mean() 480 | 481 | 482 | def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict: 483 | for checkpoint_name, attribute_name in mapping.items(): 484 | full_checkpoint_name = prefix + checkpoint_name 485 | if full_checkpoint_name in state_dict: 486 | full_attribute_name = prefix + attribute_name 487 | state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name) 488 | return state_dict 489 | 490 | 491 | def get_default_supported_precision(training: bool, tpu: bool = False) -> str: 492 | """Return default precision that is supported by the hardware. 
493 | 494 | Args: 495 | training: `-mixed` or `-true` version of the precision to use 496 | tpu: whether TPU device is used 497 | 498 | Returns: 499 | default precision that is suitable for the task and is supported by the hardware 500 | """ 501 | if tpu: 502 | return "32-true" 503 | if not torch.cuda.is_available() or torch.cuda.is_bf16_supported(): 504 | return "bf16-mixed" if training else "bf16-true" 505 | return "16-mixed" if training else "16-true" 506 | -------------------------------------------------------------------------------- /pretrain/tinyllama.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import math 3 | import sys 4 | import time 5 | from pathlib import Path 6 | from typing import Optional, Tuple, Union 7 | import math 8 | import lightning as L 9 | import torch 10 | from lightning.fabric.strategies import FSDPStrategy, XLAStrategy 11 | from torch.utils.data import DataLoader 12 | from functools import partial 13 | # support running without installing as a package 14 | wd = Path(__file__).parent.parent.resolve() 15 | sys.path.append(str(wd)) 16 | # from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually 17 | from lit_gpt.model import GPT, Block, Config, CausalSelfAttention 18 | from lit_gpt.packed_dataset import CombinedDataset, PackedDataset 19 | from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor 20 | from lit_gpt.speed_monitor import estimate_flops, measure_flops 21 | from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load 22 | from pytorch_lightning.loggers import WandbLogger 23 | from lit_gpt import FusedCrossEntropyLoss 24 | import random 25 | 26 | model_name = "tiny_LLaMA_1b" 27 | name = "tinyllama_1b" 28 | out_dir = Path("out") / name 29 | 30 | # Hyperparameters 31 | num_of_devices = 8 32 | global_batch_size = 512 33 | learning_rate = 4e-4 34 | micro_batch_size = 8 35 | max_step = 715256 * 2 36 | warmup_steps = 2000 37 | log_step_interval = 10 38 | eval_iters = 100 39 | save_step_interval = 5000 40 | eval_step_interval = 5000 41 | 42 | 43 | weight_decay = 1e-1 44 | beta1 = 0.9 45 | beta2 = 0.95 46 | grad_clip = 1.0 47 | decay_lr = True 48 | min_lr = 4e-5 49 | 50 | batch_size = global_batch_size // num_of_devices 51 | gradient_accumulation_steps = batch_size // micro_batch_size 52 | assert gradient_accumulation_steps > 0 53 | warmup_iters = warmup_steps * gradient_accumulation_steps 54 | 55 | 56 | 57 | 58 | max_iters = max_step * gradient_accumulation_steps 59 | lr_decay_iters = max_iters 60 | log_iter_interval = log_step_interval * gradient_accumulation_steps 61 | 62 | 63 | # Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. 
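# The weights below are relative sampling probabilities (roughly a 70/30 SlimPajama/Starcoder mix);
# create_dataloader normalizes them to sum to 1 before building the CombinedDataset.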
64 | train_data_config = [ 65 | ("train_slim", 0.693584), 66 | ("train_star", 0.306416), 67 | ] 68 | 69 | val_data_config = [ 70 | ("validation", 1.0), 71 | ] 72 | 73 | hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} 74 | logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) 75 | wandb_logger = WandbLogger() 76 | 77 | 78 | def setup( 79 | devices: int = 8, 80 | train_data_dir: Path = Path("data/redpajama_sample"), 81 | val_data_dir: Optional[Path] = None, 82 | precision: Optional[str] = None, 83 | tpu: bool = False, 84 | resume: Union[bool, Path] = False, 85 | ) -> None: 86 | precision = precision or get_default_supported_precision(training=True, tpu=tpu) 87 | 88 | if devices > 1: 89 | if tpu: 90 | # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. 91 | devices = "auto" 92 | strategy = XLAStrategy(sync_module_states=False) 93 | else: 94 | strategy = FSDPStrategy( 95 | auto_wrap_policy={Block}, 96 | activation_checkpointing_policy=None, 97 | state_dict_type="full", 98 | limit_all_gathers=True, 99 | cpu_offload=False, 100 | ) 101 | else: 102 | strategy = "auto" 103 | 104 | fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) 105 | fabric.print(hparams) 106 | #fabric.launch(main, train_data_dir, val_data_dir, resume) 107 | main(fabric, train_data_dir, val_data_dir, resume) 108 | 109 | 110 | def main(fabric, train_data_dir, val_data_dir, resume): 111 | monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) 112 | 113 | if fabric.global_rank == 0: 114 | out_dir.mkdir(parents=True, exist_ok=True) 115 | 116 | config = Config.from_name(model_name) 117 | 118 | train_dataloader, val_dataloader = create_dataloaders( 119 | batch_size=micro_batch_size, 120 | block_size=config.block_size, 121 | fabric=fabric, 122 | train_data_dir=train_data_dir, 123 | val_data_dir=val_data_dir, 124 | seed=3407, 125 | ) 126 | if val_dataloader is None: 127 | train_dataloader = fabric.setup_dataloaders(train_dataloader) 128 | else: 129 | train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) 130 | 131 | fabric.seed_everything(3407) # same seed for every process to init model (FSDP) 132 | 133 | fabric.print(f"Loading model with {config.__dict__}") 134 | t0 = time.perf_counter() 135 | with fabric.init_module(empty_init=False): 136 | model = GPT(config) 137 | model.apply(partial(model._init_weights ,n_layer=config.n_layer)) 138 | 139 | 140 | fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") 141 | fabric.print(f"Total parameters {num_parameters(model):,}") 142 | 143 | model = fabric.setup(model) 144 | optimizer = torch.optim.AdamW( 145 | model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False 146 | ) 147 | # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) 148 | optimizer = fabric.setup_optimizers(optimizer) 149 | 150 | state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} 151 | 152 | if resume is True: 153 | resume = sorted(out_dir.glob("*.pth"))[-1] 154 | if resume : 155 | fabric.print(f"Resuming training from {resume}") 156 | fabric.load(resume, state) 157 | 158 | train_time = time.perf_counter() 159 | train(fabric, state, train_dataloader, 
val_dataloader, monitor, resume) 160 | fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") 161 | if fabric.device.type == "cuda": 162 | fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") 163 | 164 | 165 | def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): 166 | model = state["model"] 167 | optimizer = state["optimizer"] 168 | 169 | if val_dataloader is not None: 170 | validate(fabric, model, val_dataloader) # sanity check 171 | 172 | with torch.device("meta"): 173 | meta_model = GPT(model.config) 174 | # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. 175 | # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, 176 | # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead 177 | estimated_flops = estimate_flops(meta_model) * micro_batch_size 178 | fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") 179 | x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) 180 | # measured_flos run in meta. Will trigger fusedRMSNorm error 181 | #measured_flops = measure_flops(meta_model, x) 182 | #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") 183 | del meta_model, x 184 | 185 | total_lengths = 0 186 | total_t0 = time.perf_counter() 187 | 188 | if fabric.device.type == "xla": 189 | import torch_xla.core.xla_model as xm 190 | 191 | xm.mark_step() 192 | 193 | 194 | initial_iter = state["iter_num"] 195 | curr_iter = 0 196 | 197 | loss_func = FusedCrossEntropyLoss() 198 | for train_data in train_dataloader: 199 | # resume loader state. This is not elegant but it works. Should rewrite it in the future. 200 | if resume: 201 | if curr_iter < initial_iter: 202 | curr_iter += 1 203 | continue 204 | else: 205 | resume = False 206 | curr_iter = -1 207 | fabric.barrier() 208 | fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) 209 | if state["iter_num"] >= max_iters: 210 | break 211 | 212 | # determine and set the learning rate for this iteration 213 | lr = get_lr(state["iter_num"]) if decay_lr else learning_rate 214 | for param_group in optimizer.param_groups: 215 | param_group["lr"] = lr 216 | 217 | iter_t0 = time.perf_counter() 218 | 219 | input_ids = train_data[:, 0 : model.config.block_size].contiguous() 220 | targets = train_data[:, 1 : model.config.block_size + 1].contiguous() 221 | is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 222 | with fabric.no_backward_sync(model, enabled=is_accumulating): 223 | logits = model(input_ids) 224 | loss = loss_func(logits, targets) 225 | # loss = chunked_cross_entropy(logits, targets, chunk_size=0) 226 | fabric.backward(loss / gradient_accumulation_steps) 227 | 228 | if not is_accumulating: 229 | fabric.clip_gradients(model, optimizer, max_norm=grad_clip) 230 | optimizer.step() 231 | optimizer.zero_grad() 232 | state["step_count"] += 1 233 | elif fabric.device.type == "xla": 234 | xm.mark_step() 235 | state["iter_num"] += 1 236 | # input_id: B L 237 | total_lengths += input_ids.size(1) 238 | t1 = time.perf_counter() 239 | fabric.print( 240 | f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" 241 | f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" 242 | f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. 
" 243 | # print days as well 244 | f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. " 245 | ) 246 | 247 | monitor.on_train_batch_end( 248 | state["iter_num"] * micro_batch_size, 249 | t1 - total_t0, 250 | # this assumes that device FLOPs are the same and that all devices have the same batch size 251 | fabric.world_size, 252 | state["step_count"], 253 | flops_per_batch=estimated_flops, 254 | lengths=total_lengths, 255 | train_loss = loss.item() 256 | ) 257 | 258 | 259 | 260 | 261 | if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: 262 | 263 | t0 = time.perf_counter() 264 | val_loss = validate(fabric, model, val_dataloader) 265 | t1 = time.perf_counter() - t0 266 | monitor.eval_end(t1) 267 | fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") 268 | fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 269 | fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 270 | fabric.barrier() 271 | if not is_accumulating and state["step_count"] % save_step_interval == 0: 272 | checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" 273 | fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") 274 | fabric.save(checkpoint_path, state) 275 | 276 | 277 | @torch.no_grad() 278 | def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: 279 | fabric.print("Validating ...") 280 | model.eval() 281 | 282 | losses = torch.zeros(eval_iters, device=fabric.device) 283 | for k, val_data in enumerate(val_dataloader): 284 | if k >= eval_iters: 285 | break 286 | input_ids = val_data[:, 0 : model.config.block_size].contiguous() 287 | targets = val_data[:, 1 : model.config.block_size + 1].contiguous() 288 | logits = model(input_ids) 289 | loss = chunked_cross_entropy(logits, targets, chunk_size=0) 290 | 291 | # loss_func = FusedCrossEntropyLoss() 292 | # loss = loss_func(logits, targets) 293 | losses[k] = loss.item() 294 | 295 | out = losses.mean() 296 | 297 | model.train() 298 | return out 299 | 300 | 301 | def create_dataloader( 302 | batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" 303 | ) -> DataLoader: 304 | datasets = [] 305 | data_config = train_data_config if split == "train" else val_data_config 306 | for prefix, _ in data_config: 307 | filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) 308 | random.seed(seed) 309 | random.shuffle(filenames) 310 | 311 | dataset = PackedDataset( 312 | filenames, 313 | # n_chunks control the buffer size. 314 | # Note that the buffer size also impacts the random shuffle 315 | # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) 316 | n_chunks=8, 317 | block_size=block_size, 318 | shuffle=shuffle, 319 | seed=seed+fabric.global_rank, 320 | num_processes=fabric.world_size, 321 | process_rank=fabric.global_rank, 322 | ) 323 | datasets.append(dataset) 324 | 325 | if not datasets: 326 | raise RuntimeError( 327 | f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 
328 | ) 329 | 330 | weights = [weight for _, weight in data_config] 331 | sum_weights = sum(weights) 332 | weights = [el / sum_weights for el in weights] 333 | 334 | combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) 335 | 336 | return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) 337 | 338 | 339 | def create_dataloaders( 340 | batch_size: int, 341 | block_size: int, 342 | fabric, 343 | train_data_dir: Path = Path("data/redpajama_sample"), 344 | val_data_dir: Optional[Path] = None, 345 | seed: int = 12345, 346 | ) -> Tuple[DataLoader, DataLoader]: 347 | # Increase by one because we need the next word as well 348 | effective_block_size = block_size + 1 349 | train_dataloader = create_dataloader( 350 | batch_size=batch_size, 351 | block_size=effective_block_size, 352 | fabric=fabric, 353 | data_dir=train_data_dir, 354 | shuffle=True, 355 | seed=seed, 356 | split="train" 357 | ) 358 | val_dataloader = ( 359 | create_dataloader( 360 | batch_size=batch_size, 361 | block_size=effective_block_size, 362 | fabric=fabric, 363 | data_dir=val_data_dir, 364 | shuffle=False, 365 | seed=seed, 366 | split="validation" 367 | ) 368 | if val_data_dir 369 | else None 370 | ) 371 | return train_dataloader, val_dataloader 372 | 373 | 374 | # learning rate decay scheduler (cosine with warmup) 375 | def get_lr(it): 376 | # 1) linear warmup for warmup_iters steps 377 | if it < warmup_iters: 378 | return learning_rate * it / warmup_iters 379 | # 2) if it > lr_decay_iters, return min learning rate 380 | if it > lr_decay_iters: 381 | return min_lr 382 | # 3) in between, use cosine decay down to min learning rate 383 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 384 | assert 0 <= decay_ratio <= 1 385 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 386 | return min_lr + coeff * (learning_rate - min_lr) 387 | 388 | 389 | if __name__ == "__main__": 390 | # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" 391 | # torch.backends.cuda.enable_flash_sdp(False) 392 | torch.set_float32_matmul_precision("high") 393 | 394 | from jsonargparse import CLI 395 | 396 | CLI(setup) 397 | -------------------------------------------------------------------------------- /pretrain/tinyllama_code.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import math 3 | import sys 4 | import time 5 | from pathlib import Path 6 | from typing import Optional, Tuple, Union 7 | import math 8 | import lightning as L 9 | import torch 10 | from lightning.fabric.strategies import FSDPStrategy, XLAStrategy 11 | from torch.utils.data import DataLoader 12 | from functools import partial 13 | # support running without installing as a package 14 | wd = Path(__file__).parent.parent.resolve() 15 | sys.path.append(str(wd)) 16 | # from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually 17 | from lit_gpt.model import GPT, Block, Config, CausalSelfAttention 18 | from lit_gpt.packed_dataset import CombinedDataset, PackedDataset 19 | from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor 20 | from lit_gpt.speed_monitor import estimate_flops, measure_flops 21 | from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load 22 | from pytorch_lightning.loggers import WandbLogger 23 | from lit_gpt import FusedCrossEntropyLoss 24 | import random 25 | 26 
| 27 | model_name = "tiny_LLaMA_1b" 28 | name = "tiny_LLaMA_1b" 29 | out_dir = Path("out") / name 30 | checkpoint_path = "out/TinyLlama-1.1B-intermediate-step-240k-503b/lit_model.pth" 31 | # Hyperparameters 32 | num_of_devices = 6 33 | global_batch_size = 360 34 | learning_rate = 2e-4 35 | min_lr = 2e-5 36 | micro_batch_size = 6 37 | max_step = 10000 38 | warmup_steps = 0 39 | log_step_interval = 1 40 | eval_iters = 1000000 41 | save_step_interval = 2000 42 | eval_step_interval = 2000 43 | 44 | weight_decay = 1e-1 45 | beta1 = 0.9 46 | beta2 = 0.95 47 | grad_clip = 1.0 48 | decay_lr = True 49 | 50 | batch_size = global_batch_size // num_of_devices 51 | gradient_accumulation_steps = batch_size // micro_batch_size 52 | assert gradient_accumulation_steps > 0 53 | warmup_iters = warmup_steps * gradient_accumulation_steps 54 | 55 | 56 | 57 | 58 | max_iters = max_step * gradient_accumulation_steps 59 | lr_decay_iters = max_iters 60 | log_iter_interval = log_step_interval * gradient_accumulation_steps 61 | 62 | 63 | # Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. 64 | train_data_config = [ 65 | ("train_starcoder", 1), 66 | ] 67 | 68 | val_data_config = [ 69 | ("validation", 1.0), 70 | ] 71 | 72 | hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} 73 | logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) 74 | wandb_logger = WandbLogger() 75 | 76 | 77 | def setup( 78 | devices: int = 8, 79 | train_data_dir: Path = Path("data/redpajama_sample"), 80 | val_data_dir: Optional[Path] = None, 81 | precision: Optional[str] = None, 82 | tpu: bool = False, 83 | resume: Union[bool, Path] = False, 84 | ) -> None: 85 | precision = precision or get_default_supported_precision(training=True, tpu=tpu) 86 | 87 | if devices > 1: 88 | if tpu: 89 | # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. 
90 | devices = "auto" 91 | strategy = XLAStrategy(sync_module_states=False) 92 | else: 93 | strategy = FSDPStrategy( 94 | auto_wrap_policy={Block}, 95 | activation_checkpointing_policy=None, 96 | state_dict_type="full", 97 | limit_all_gathers=True, 98 | cpu_offload=False, 99 | ) 100 | else: 101 | strategy = "auto" 102 | 103 | fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) 104 | fabric.print(hparams) 105 | fabric.launch(main, train_data_dir, val_data_dir, resume) 106 | # main(fabric, train_data_dir, val_data_dir, resume) 107 | 108 | 109 | def main(fabric, train_data_dir, val_data_dir, resume): 110 | monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) 111 | 112 | if fabric.global_rank == 0: 113 | out_dir.mkdir(parents=True, exist_ok=True) 114 | 115 | config = Config.from_name(model_name) 116 | 117 | train_dataloader, val_dataloader = create_dataloaders( 118 | batch_size=micro_batch_size, 119 | block_size=config.block_size, 120 | fabric=fabric, 121 | train_data_dir=train_data_dir, 122 | val_data_dir=val_data_dir, 123 | seed=3407, 124 | ) 125 | if val_dataloader is None: 126 | train_dataloader = fabric.setup_dataloaders(train_dataloader) 127 | else: 128 | train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) 129 | 130 | fabric.seed_everything(3407) # same seed for every process to init model (FSDP) 131 | 132 | fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}") 133 | t0 = time.perf_counter() 134 | with fabric.init_module(empty_init=True): 135 | model = GPT(config) 136 | 137 | 138 | model = fabric.setup(model) 139 | fabric.load_raw(checkpoint_path, model, strict=True) 140 | fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") 141 | fabric.print(f"Total parameters {num_parameters(model):,}") 142 | 143 | 144 | optimizer = torch.optim.AdamW( 145 | model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False 146 | ) 147 | # import bitsandbytes as bnb 148 | # optimizer = bnb.optim.AdamW8bit( 149 | # model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2) 150 | # ) 151 | # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) 152 | optimizer = fabric.setup_optimizers(optimizer) 153 | 154 | state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} 155 | 156 | if resume is True: 157 | resume = sorted(out_dir.glob("*.pth"))[-1] 158 | if resume : 159 | fabric.print(f"Resuming training from {resume}") 160 | fabric.load(resume, state) 161 | 162 | train_time = time.perf_counter() 163 | train(fabric, state, train_dataloader, val_dataloader, monitor, resume) 164 | fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") 165 | if fabric.device.type == "cuda": 166 | fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") 167 | 168 | 169 | def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): 170 | model = state["model"] 171 | optimizer = state["optimizer"] 172 | 173 | if val_dataloader is not None: 174 | validate(fabric, model, val_dataloader) # sanity check 175 | 176 | with torch.device("meta"): 177 | meta_model = GPT(model.config) 178 | # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. 
179 | # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, 180 | # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead 181 | estimated_flops = estimate_flops(meta_model) * micro_batch_size 182 | fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") 183 | x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) 184 | # measured_flos run in meta. Will trigger fusedRMSNorm error 185 | #measured_flops = measure_flops(meta_model, x) 186 | #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") 187 | del meta_model, x 188 | 189 | total_lengths = 0 190 | total_t0 = time.perf_counter() 191 | 192 | if fabric.device.type == "xla": 193 | import torch_xla.core.xla_model as xm 194 | 195 | xm.mark_step() 196 | 197 | 198 | initial_iter = state["iter_num"] 199 | curr_iter = 0 200 | 201 | loss_func = FusedCrossEntropyLoss() 202 | for train_data in train_dataloader: 203 | # resume loader state. This is not elegant but it works. Should rewrite it in the future. 204 | if resume: 205 | if curr_iter < initial_iter: 206 | curr_iter += 1 207 | continue 208 | else: 209 | resume = False 210 | curr_iter = -1 211 | fabric.barrier() 212 | fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) 213 | if state["iter_num"] >= max_iters: 214 | break 215 | 216 | # determine and set the learning rate for this iteration 217 | lr = get_lr(state["iter_num"]) if decay_lr else learning_rate 218 | for param_group in optimizer.param_groups: 219 | param_group["lr"] = lr 220 | 221 | iter_t0 = time.perf_counter() 222 | input_ids = train_data[:, 0 : model.config.block_size].contiguous() 223 | targets = train_data[:, 1 : model.config.block_size + 1].contiguous() 224 | 225 | is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 226 | with fabric.no_backward_sync(model, enabled=is_accumulating): 227 | logits = model(input_ids) 228 | loss = loss_func(logits, targets) 229 | # loss = chunked_cross_entropy(logits, targets, chunk_size=0) 230 | fabric.backward(loss / gradient_accumulation_steps) 231 | 232 | if not is_accumulating: 233 | fabric.clip_gradients(model, optimizer, max_norm=grad_clip) 234 | optimizer.step() 235 | optimizer.zero_grad() 236 | state["step_count"] += 1 237 | elif fabric.device.type == "xla": 238 | xm.mark_step() 239 | state["iter_num"] += 1 240 | # input_id: B L 241 | total_lengths += input_ids.size(1) 242 | t1 = time.perf_counter() 243 | fabric.print( 244 | f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" 245 | f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" 246 | f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " 247 | # print days as well 248 | f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. 
" 249 | ) 250 | 251 | monitor.on_train_batch_end( 252 | state["iter_num"] * micro_batch_size, 253 | t1 - total_t0, 254 | # this assumes that device FLOPs are the same and that all devices have the same batch size 255 | fabric.world_size, 256 | state["step_count"], 257 | flops_per_batch=estimated_flops, 258 | lengths=total_lengths, 259 | train_loss = loss.item() 260 | ) 261 | 262 | 263 | 264 | 265 | if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: 266 | 267 | t0 = time.perf_counter() 268 | val_loss = validate(fabric, model, val_dataloader) 269 | t1 = time.perf_counter() - t0 270 | monitor.eval_end(t1) 271 | fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") 272 | fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 273 | fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 274 | fabric.barrier() 275 | if not is_accumulating and state["step_count"] % save_step_interval == 0: 276 | checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" 277 | fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") 278 | fabric.save(checkpoint_path, state) 279 | 280 | 281 | @torch.no_grad() 282 | def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: 283 | fabric.print("Validating ...") 284 | model.eval() 285 | 286 | losses = torch.zeros(eval_iters, device=fabric.device) 287 | for k, val_data in enumerate(val_dataloader): 288 | if k >= eval_iters: 289 | break 290 | input_ids = val_data[:, 0 : model.config.block_size].contiguous() 291 | targets = val_data[:, 1 : model.config.block_size + 1].contiguous() 292 | logits = model(input_ids) 293 | loss = chunked_cross_entropy(logits, targets, chunk_size=0) 294 | 295 | # loss_func = FusedCrossEntropyLoss() 296 | # loss = loss_func(logits, targets) 297 | losses[k] = loss.item() 298 | 299 | out = losses.mean() 300 | 301 | model.train() 302 | return out 303 | 304 | 305 | def create_dataloader( 306 | batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" 307 | ) -> DataLoader: 308 | datasets = [] 309 | data_config = train_data_config if split == "train" else val_data_config 310 | for prefix, _ in data_config: 311 | filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) 312 | random.seed(seed) 313 | random.shuffle(filenames) 314 | 315 | dataset = PackedDataset( 316 | filenames, 317 | # n_chunks control the buffer size. 318 | # Note that the buffer size also impacts the random shuffle 319 | # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) 320 | n_chunks=8, 321 | block_size=block_size, 322 | shuffle=shuffle, 323 | seed=seed+fabric.global_rank, 324 | num_processes=fabric.world_size, 325 | process_rank=fabric.global_rank, 326 | ) 327 | datasets.append(dataset) 328 | 329 | if not datasets: 330 | raise RuntimeError( 331 | f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 
332 | ) 333 | 334 | weights = [weight for _, weight in data_config] 335 | sum_weights = sum(weights) 336 | weights = [el / sum_weights for el in weights] 337 | 338 | combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) 339 | 340 | return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) 341 | 342 | 343 | def create_dataloaders( 344 | batch_size: int, 345 | block_size: int, 346 | fabric, 347 | train_data_dir: Path = Path("data/redpajama_sample"), 348 | val_data_dir: Optional[Path] = None, 349 | seed: int = 12345, 350 | ) -> Tuple[DataLoader, DataLoader]: 351 | # Increase by one because we need the next word as well 352 | effective_block_size = block_size + 1 353 | train_dataloader = create_dataloader( 354 | batch_size=batch_size, 355 | block_size=effective_block_size, 356 | fabric=fabric, 357 | data_dir=train_data_dir, 358 | shuffle=True, 359 | seed=seed, 360 | split="train" 361 | ) 362 | val_dataloader = ( 363 | create_dataloader( 364 | batch_size=batch_size, 365 | block_size=effective_block_size, 366 | fabric=fabric, 367 | data_dir=val_data_dir, 368 | shuffle=False, 369 | seed=seed, 370 | split="validation" 371 | ) 372 | if val_data_dir 373 | else None 374 | ) 375 | return train_dataloader, val_dataloader 376 | 377 | 378 | # learning rate decay scheduler (cosine with warmup) 379 | def get_lr(it): 380 | # 1) linear warmup for warmup_iters steps 381 | if it < warmup_iters: 382 | return learning_rate * it / warmup_iters 383 | # 2) if it > lr_decay_iters, return min learning rate 384 | if it > lr_decay_iters: 385 | return min_lr 386 | # 3) in between, use cosine decay down to min learning rate 387 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 388 | assert 0 <= decay_ratio <= 1 389 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 390 | return min_lr + coeff * (learning_rate - min_lr) 391 | 392 | 393 | if __name__ == "__main__": 394 | # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" 395 | # torch.backends.cuda.enable_flash_sdp(False) 396 | torch.set_float32_matmul_precision("high") 397 | 398 | from jsonargparse import CLI 399 | 400 | CLI(setup) 401 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.1.0dev 2 | lightning==2.1.2 3 | lightning[app] 4 | jsonargparse[signatures] # CLI 5 | pandas 6 | pyarrow 7 | tokenizers 8 | sentencepiece 9 | wandb 10 | zstd 11 | 12 | # for finetuning 13 | bitsandbytes==0.40.0 14 | transformers==4.31.0 15 | peft==0.4.0 16 | accelerate==0.21.0 17 | einops==0.6.1 18 | evaluate==0.4.0 19 | scikit-learn==1.2.2 20 | sentencepiece==0.1.99 21 | wandb==0.15.3 22 | # other optional dependencies are 23 | # sentencepiece # pythia, falcon, redpajama 24 | # tokenizers # llama-based models 25 | # bitsandbytes>=0.41.1 # quantize/bnb.py 26 | # scipy # TODO: remove when https://github.com/TimDettmers/bitsandbytes/pull/525 is released 27 | # datasets # quantize/gptq.py 28 | # zstandard # scripts/prepare_redpajama.py 29 | # git+https://github.com/EleutherAI/lm-evaluation-harness.git@master # eval 30 | -------------------------------------------------------------------------------- /script.sh: -------------------------------------------------------------------------------- 1 | python scripts/convert_hf_checkpoint.py --checkpoint_dir out/TinyLlama-1.1B-900B --model_name tiny_LLaMA_1b 2 | 3 | 
python test_weight.py --checkpoint_dir out/TinyLlama-1.1B-intermediate-900B 4 | 5 | 6 | python pretrain/tinyllama_code.py --devices 8 --train_data_dir data/code_specialist_python_java_javascript_c_go_8192 7 | 8 | 9 | 10 | python scripts/prepare_starcoder.py --source_path data/starcoderdata/ --tokenizer_path data/llama --destination_path data/code_specialist_python_java_javascript_c_go_8192 --split train --percentage 1.0 --filenames_subset ["python","cpp","go","java","javascript"] --chunk_size 4194816 11 | 12 | 13 | 14 | 15 | /data/TinyLlama/out/code_tiny_LLaMA_1b_python_java_go_cpp_javascript/iter-032000-ckpt.pth 16 | 17 | python scripts/convert_lit_checkpoint.py --out_dir /data/TinyLlama/out/tiny_LLaMA_1b/ --checkpoint_name iter-100000-ckpt.pth --model_name tiny_LLaMA_1b -------------------------------------------------------------------------------- /scripts/convert_hf_checkpoint.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import gc 3 | import json 4 | import sys 5 | from functools import partial 6 | from pathlib import Path 7 | from typing import Dict, List, Literal, Optional, Tuple, Union 8 | 9 | import torch 10 | 11 | # support running without installing as a package 12 | wd = Path(__file__).parent.parent.resolve() 13 | sys.path.append(str(wd)) 14 | 15 | from lit_gpt import Config 16 | from lit_gpt.utils import NotYetLoadedTensor, incremental_save, lazy_load 17 | 18 | 19 | def copy_weights_gpt_neox( 20 | state_dict: Dict[str, torch.Tensor], 21 | hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], 22 | saver: Optional[incremental_save] = None, 23 | dtype: Optional[torch.dtype] = None, 24 | ) -> None: 25 | weight_map = { 26 | "gpt_neox.embed_in.weight": "transformer.wte.weight", 27 | "gpt_neox.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", 28 | "gpt_neox.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", 29 | "gpt_neox.layers.{}.attention.query_key_value.bias": "transformer.h.{}.attn.attn.bias", 30 | "gpt_neox.layers.{}.attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight", 31 | "gpt_neox.layers.{}.attention.dense.bias": "transformer.h.{}.attn.proj.bias", 32 | "gpt_neox.layers.{}.attention.dense.weight": "transformer.h.{}.attn.proj.weight", 33 | "gpt_neox.layers.{}.attention.rotary_emb.inv_freq": None, 34 | "gpt_neox.layers.{}.attention.bias": None, 35 | "gpt_neox.layers.{}.attention.masked_bias": None, 36 | "gpt_neox.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.norm_2.bias", 37 | "gpt_neox.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", 38 | "gpt_neox.layers.{}.mlp.dense_h_to_4h.bias": "transformer.h.{}.mlp.fc.bias", 39 | "gpt_neox.layers.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight", 40 | "gpt_neox.layers.{}.mlp.dense_4h_to_h.bias": "transformer.h.{}.mlp.proj.bias", 41 | "gpt_neox.layers.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight", 42 | "gpt_neox.final_layer_norm.bias": "transformer.ln_f.bias", 43 | "gpt_neox.final_layer_norm.weight": "transformer.ln_f.weight", 44 | "embed_out.weight": "lm_head.weight", 45 | } 46 | 47 | for name, param in hf_weights.items(): 48 | if "gpt_neox.layers" in name: 49 | from_name, number = layer_template(name, 2) 50 | to_name = weight_map[from_name] 51 | if to_name is None: 52 | continue 53 | to_name = to_name.format(number) 54 | else: 55 | to_name = weight_map[name] 56 | param = load_param(param, name, dtype) 57 | if saver is not None: 58 | 
param = saver.store_early(param) 59 | state_dict[to_name] = param 60 | 61 | 62 | def copy_weights_falcon( 63 | size: Literal["7b", "40b"], 64 | state_dict: Dict[str, torch.Tensor], 65 | hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], 66 | saver: Optional[incremental_save] = None, 67 | dtype: Optional[torch.dtype] = None, 68 | ) -> None: 69 | weight_map = { 70 | "transformer.word_embeddings.weight": "transformer.wte.weight", 71 | "transformer.h.{}.self_attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight", 72 | "transformer.h.{}.self_attention.dense.weight": "transformer.h.{}.attn.proj.weight", 73 | "transformer.h.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight", 74 | "transformer.h.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight", 75 | "transformer.ln_f.bias": "transformer.ln_f.bias", 76 | "transformer.ln_f.weight": "transformer.ln_f.weight", 77 | "lm_head.weight": "lm_head.weight", 78 | } 79 | # the original model definition is different for each size 80 | if size == "7b": 81 | weight_map.update( 82 | { 83 | "transformer.h.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", 84 | "transformer.h.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", 85 | } 86 | ) 87 | elif size == "40b": 88 | weight_map.update( 89 | { 90 | "transformer.h.{}.ln_attn.bias": "transformer.h.{}.norm_1.bias", 91 | "transformer.h.{}.ln_attn.weight": "transformer.h.{}.norm_1.weight", 92 | "transformer.h.{}.ln_mlp.bias": "transformer.h.{}.norm_2.bias", 93 | "transformer.h.{}.ln_mlp.weight": "transformer.h.{}.norm_2.weight", 94 | } 95 | ) 96 | else: 97 | raise NotImplementedError 98 | 99 | for name, param in hf_weights.items(): 100 | if "transformer.h" in name: 101 | from_name, number = layer_template(name, 2) 102 | to_name = weight_map[from_name].format(number) 103 | else: 104 | to_name = weight_map[name] 105 | param = load_param(param, name, dtype) 106 | if saver is not None: 107 | param = saver.store_early(param) 108 | state_dict[to_name] = param 109 | 110 | 111 | def copy_weights_hf_llama( 112 | config: Config, 113 | qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], 114 | state_dict: Dict[str, torch.Tensor], 115 | hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], 116 | saver: Optional[incremental_save] = None, 117 | dtype: Optional[torch.dtype] = None, 118 | ) -> None: 119 | weight_map = { 120 | "model.embed_tokens.weight": "transformer.wte.weight", 121 | "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", 122 | "model.layers.{}.self_attn.q_proj.weight": None, 123 | "model.layers.{}.self_attn.k_proj.weight": None, 124 | "model.layers.{}.self_attn.v_proj.weight": None, 125 | "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight", 126 | "model.layers.{}.self_attn.rotary_emb.inv_freq": None, 127 | "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", 128 | "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.swiglu.w1.weight", 129 | "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.swiglu.w2.weight", 130 | "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.swiglu.w3.weight", 131 | "model.norm.weight": "transformer.ln_f.weight", 132 | "lm_head.weight": "lm_head.weight", 133 | } 134 | 135 | for name, param in hf_weights.items(): 136 | if "model.layers" in name: 137 | from_name, number = layer_template(name, 2) 138 | qkv = qkv_weights.setdefault(number, [None, None, None]) 139 | if "q_proj" in name: 140 | qkv[0] = 
param 141 | elif "k_proj" in name: 142 | qkv[1] = param 143 | elif "v_proj" in name: 144 | qkv[2] = param 145 | to_name = weight_map[from_name] 146 | if to_name is None: 147 | continue 148 | to_name = to_name.format(number) 149 | else: 150 | to_name = weight_map[name] 151 | param = load_param(param, name, dtype) 152 | if saver is not None: 153 | param = saver.store_early(param) 154 | state_dict[to_name] = param 155 | 156 | for i, (q, k, v) in list(qkv_weights.items()): 157 | if q is None or k is None or v is None: 158 | # split across different .bin files 159 | continue 160 | q = load_param(q, f"layer {i} q", dtype) 161 | k = load_param(k, f"layer {i} k", dtype) 162 | v = load_param(v, f"layer {i} v", dtype) 163 | q_per_kv = config.n_head // config.n_query_groups 164 | qs = torch.split(q, config.head_size * q_per_kv) 165 | ks = torch.split(k, config.head_size) 166 | vs = torch.split(v, config.head_size) 167 | cycled = [t for group in zip(qs, ks, vs) for t in group] 168 | qkv = torch.cat(cycled) 169 | state_dict[f"transformer.h.{i}.attn.attn.weight"] = qkv 170 | del qkv_weights[i] 171 | 172 | 173 | def layer_template(layer_name: str, idx: int) -> Tuple[str, int]: 174 | split = layer_name.split(".") 175 | number = int(split[idx]) 176 | split[idx] = "{}" 177 | from_name = ".".join(split) 178 | return from_name, number 179 | 180 | 181 | def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype]) -> torch.Tensor: 182 | if hasattr(param, "_load_tensor"): 183 | # support tensors loaded via `lazy_load()` 184 | print(f"Loading {name!r} into RAM") 185 | param = param._load_tensor() 186 | if dtype is not None and type(dtype) is not NotYetLoadedTensor and dtype != param.dtype: 187 | print(f"Converting {name!r} from {param.dtype} to {dtype}") 188 | param = param.to(dtype) 189 | return param 190 | 191 | 192 | @torch.inference_mode() 193 | def convert_hf_checkpoint( 194 | *, 195 | checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), 196 | model_name: Optional[str] = None, 197 | dtype: Optional[str] = None, 198 | ) -> None: 199 | if model_name is None: 200 | model_name = checkpoint_dir.name 201 | if dtype is not None: 202 | dtype = getattr(torch, dtype) 203 | 204 | config = Config.from_name(model_name) 205 | print(f"Model config {config.__dict__}") 206 | with open(checkpoint_dir / "lit_config.json", "w") as json_config: 207 | json.dump(config.__dict__, json_config) 208 | 209 | if "falcon" in model_name: 210 | copy_fn = partial(copy_weights_falcon, "40b" if config.n_embd == 8192 else "7b") 211 | elif config._mlp_class == "LLaMAMLP": 212 | # holder to reconstitute the split q, k, v 213 | qkv_weights = {} 214 | copy_fn = partial(copy_weights_hf_llama, config, qkv_weights) 215 | else: 216 | copy_fn = copy_weights_gpt_neox 217 | 218 | # initialize a new empty state dict to hold our new weights 219 | sd = {} 220 | 221 | # Load the json file containing weight mapping 222 | pytorch_bin_map_json_path = checkpoint_dir / "pytorch_model.bin.index.json" 223 | if pytorch_bin_map_json_path.is_file(): # not all checkpoints have this file 224 | with open(pytorch_bin_map_json_path) as json_map: 225 | bin_index = json.load(json_map) 226 | bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()} 227 | else: 228 | bin_files = set(checkpoint_dir.glob("*.bin")) 229 | if not bin_files: 230 | raise ValueError(f"Expected {str(checkpoint_dir)!r} to contain .bin files") 231 | 232 | with incremental_save(checkpoint_dir / "lit_model.pth") as 
saver: 233 | # for checkpoints that split the QKV across several files, we need to keep all the bin files 234 | # open, so we use `ExitStack` to close them all together at the end 235 | with contextlib.ExitStack() as stack: 236 | for bin_file in sorted(bin_files): 237 | print("Processing", bin_file) 238 | hf_weights = stack.enter_context(lazy_load(bin_file)) 239 | copy_fn(sd, hf_weights, saver=None, dtype=dtype) 240 | gc.collect() 241 | print("Saving converted checkpoint") 242 | saver.save(sd) 243 | 244 | 245 | if __name__ == "__main__": 246 | from jsonargparse import CLI 247 | 248 | CLI(convert_hf_checkpoint) 249 | -------------------------------------------------------------------------------- /scripts/convert_lit_checkpoint.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import gc 3 | import sys 4 | from functools import partial 5 | from pathlib import Path 6 | from typing import Dict, Literal, Optional, Tuple, Union 7 | from dataclasses import asdict 8 | import json 9 | import torch 10 | 11 | # support running without installing as a package 12 | wd = Path(__file__).parent.parent.resolve() 13 | sys.path.append(str(wd)) 14 | 15 | from lit_gpt import Config 16 | from lit_gpt.utils import NotYetLoadedTensor, incremental_save, lazy_load 17 | # from scripts.convert_hf_checkpoint import layer_template, load_param 18 | 19 | 20 | def layer_template(layer_name: str, idx: int) -> Tuple[str, int]: 21 | split = layer_name.split(".") 22 | number = int(split[idx]) 23 | split[idx] = "{}" 24 | from_name = ".".join(split) 25 | return from_name, number 26 | 27 | 28 | def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype]) -> torch.Tensor: 29 | if hasattr(param, "_load_tensor"): 30 | # support tensors loaded via `lazy_load()` 31 | print(f"Loading {name!r} into RAM") 32 | param = param._load_tensor() 33 | if dtype is not None and type(dtype) is not NotYetLoadedTensor and dtype != param.dtype: 34 | print(f"Converting {name!r} from {param.dtype} to {dtype}") 35 | param = param.to(dtype) 36 | return param 37 | def copy_weights_falcon( 38 | size: Literal["7b", "40b"], 39 | state_dict: Dict[str, torch.Tensor], 40 | lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], 41 | saver: Optional[incremental_save] = None, 42 | ): 43 | weight_map = { 44 | "transformer.wte.weight": "transformer.word_embeddings.weight", 45 | "transformer.h.{}.attn.attn.weight": "transformer.h.{}.self_attention.query_key_value.weight", 46 | "transformer.h.{}.attn.proj.weight": "transformer.h.{}.self_attention.dense.weight", 47 | "transformer.h.{}.mlp.fc.weight": "transformer.h.{}.mlp.dense_h_to_4h.weight", 48 | "transformer.h.{}.mlp.proj.weight": "transformer.h.{}.mlp.dense_4h_to_h.weight", 49 | "transformer.ln_f.bias": "transformer.ln_f.bias", 50 | "transformer.ln_f.weight": "transformer.ln_f.weight", 51 | "lm_head.weight": "lm_head.weight", 52 | } 53 | # the original model definition is different for each size 54 | if size == "7b": 55 | weight_map.update( 56 | { 57 | "transformer.h.{}.norm_1.bias": "transformer.h.{}.input_layernorm.bias", 58 | "transformer.h.{}.norm_1.weight": "transformer.h.{}.input_layernorm.weight", 59 | } 60 | ) 61 | elif size == "40b": 62 | weight_map.update( 63 | { 64 | "transformer.h.{}.norm_1.bias": "transformer.h.{}.ln_attn.bias", 65 | "transformer.h.{}.norm_1.weight": "transformer.h.{}.ln_attn.weight", 66 | "transformer.h.{}.norm_2.bias": "transformer.h.{}.ln_mlp.bias", 67 | 
"transformer.h.{}.norm_2.weight": "transformer.h.{}.ln_mlp.weight", 68 | } 69 | ) 70 | else: 71 | raise NotImplementedError 72 | 73 | for name, param in lit_weights.items(): 74 | if "transformer.h" in name: 75 | from_name, number = layer_template(name, 2) 76 | to_name = weight_map[from_name].format(number) 77 | else: 78 | to_name = weight_map[name] 79 | param = load_param(param, name, None) 80 | if saver is not None: 81 | param = saver.store_early(param) 82 | state_dict[to_name] = param 83 | 84 | 85 | def copy_weights_gpt_neox( 86 | state_dict: Dict[str, torch.Tensor], 87 | lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], 88 | saver: Optional[incremental_save] = None, 89 | ) -> None: 90 | weight_map = { 91 | "transformer.wte.weight": "gpt_neox.embed_in.weight", 92 | "transformer.h.{}.norm_1.bias": "gpt_neox.layers.{}.input_layernorm.bias", 93 | "transformer.h.{}.norm_1.weight": "gpt_neox.layers.{}.input_layernorm.weight", 94 | "transformer.h.{}.attn.attn.bias": "gpt_neox.layers.{}.attention.query_key_value.bias", 95 | "transformer.h.{}.attn.attn.weight": "gpt_neox.layers.{}.attention.query_key_value.weight", 96 | "transformer.h.{}.attn.proj.bias": "gpt_neox.layers.{}.attention.dense.bias", 97 | "transformer.h.{}.attn.proj.weight": "gpt_neox.layers.{}.attention.dense.weight", 98 | "transformer.h.{}.norm_2.bias": "gpt_neox.layers.{}.post_attention_layernorm.bias", 99 | "transformer.h.{}.norm_2.weight": "gpt_neox.layers.{}.post_attention_layernorm.weight", 100 | "transformer.h.{}.mlp.fc.bias": "gpt_neox.layers.{}.mlp.dense_h_to_4h.bias", 101 | "transformer.h.{}.mlp.fc.weight": "gpt_neox.layers.{}.mlp.dense_h_to_4h.weight", 102 | "transformer.h.{}.mlp.proj.bias": "gpt_neox.layers.{}.mlp.dense_4h_to_h.bias", 103 | "transformer.h.{}.mlp.proj.weight": "gpt_neox.layers.{}.mlp.dense_4h_to_h.weight", 104 | "transformer.ln_f.bias": "gpt_neox.final_layer_norm.bias", 105 | "transformer.ln_f.weight": "gpt_neox.final_layer_norm.weight", 106 | "lm_head.weight": "embed_out.weight", 107 | } 108 | 109 | for name, param in lit_weights.items(): 110 | if "transformer.h" in name: 111 | from_name, number = layer_template(name, 2) 112 | to_name = weight_map[from_name].format(number) 113 | else: 114 | to_name = weight_map[name] 115 | param = load_param(param, name, None) 116 | if saver is not None: 117 | param = saver.store_early(param) 118 | state_dict[to_name] = param 119 | 120 | 121 | def copy_weights_llama( 122 | config: Config, 123 | state_dict: Dict[str, torch.Tensor], 124 | lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], 125 | saver: Optional[incremental_save] = None, 126 | ): 127 | weight_map = { 128 | "transformer.wte.weight": "model.embed_tokens.weight", 129 | "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", 130 | "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", 131 | "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", 132 | "transformer.h.{}.mlp.swiglu.w1.weight": "model.layers.{}.mlp.gate_proj.weight", 133 | "transformer.h.{}.mlp.swiglu.w2.weight": "model.layers.{}.mlp.up_proj.weight", 134 | "transformer.h.{}.mlp.swiglu.w3.weight": "model.layers.{}.mlp.down_proj.weight", 135 | "transformer.ln_f.weight": "model.norm.weight", 136 | "lm_head.weight": "lm_head.weight", 137 | } 138 | for name, param in lit_weights.items(): 139 | if name.endswith(".attn.attn.weight"): 140 | from_name, number = layer_template(name, 2) 141 | q = "model.layers.{}.self_attn.q_proj.weight".format(number) 142 
| k = "model.layers.{}.self_attn.k_proj.weight".format(number) 143 | v = "model.layers.{}.self_attn.v_proj.weight".format(number) 144 | qkv = load_param(param, name,None) 145 | qp, kp, vp = tensor_split(qkv, config) 146 | for to_name, param in zip((q, k, v), (qp, kp, vp)): 147 | if saver is not None: 148 | param = saver.store_early(param) 149 | state_dict[to_name] = param 150 | elif "transformer.h" in name: 151 | from_name, number = layer_template(name, 2) 152 | to_name = weight_map[from_name] 153 | 154 | if to_name is None: 155 | continue 156 | to_name = to_name.format(number) 157 | param = load_param(param, name,None) 158 | if saver is not None: 159 | param = saver.store_early(param) 160 | state_dict[to_name] = param 161 | 162 | else: 163 | to_name = weight_map[name] 164 | param = load_param(param, name, None) 165 | if saver is not None: 166 | param = saver.store_early(param) 167 | state_dict[to_name] = param 168 | 169 | 170 | def tensor_split( 171 | param: Union[torch.Tensor, NotYetLoadedTensor], config: Config 172 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 173 | def kstart(start, blen, klen) -> int: 174 | """returns start index of keys in batch""" 175 | return start + (blen - (klen * 2)) 176 | 177 | def vstart(start, blen, klen) -> int: 178 | """returns start index of values in batch""" 179 | return start + blen - klen 180 | 181 | def vend(start, blen) -> int: 182 | """returns last index of values in batch""" 183 | return start + blen 184 | 185 | # num observations 186 | nobs = param.shape[0] 187 | # batch length 188 | blen = nobs // config.n_query_groups 189 | # key length in batch 190 | klen = config.head_size 191 | # value length in batch 192 | vlen = config.head_size 193 | # the starting index of each new batch 194 | starts = range(0, nobs, blen) 195 | # the indices to splice on 196 | splices = [(s, kstart(s, blen, klen), vstart(s, blen, vlen), vend(s, blen)) for s in starts] 197 | 198 | qc = () 199 | kc = () 200 | vc = () 201 | 202 | for splice in splices: 203 | qs, ks, vs, ve = splice 204 | qc += (param[qs:ks, :],) 205 | kc += (param[ks:vs, :],) 206 | vc += (param[vs:ve, :],) 207 | 208 | q = torch.cat(qc) 209 | k = torch.cat(kc) 210 | v = torch.cat(vc) 211 | 212 | return q, k, v 213 | 214 | 215 | def maybe_unwrap_state_dict(lit_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 216 | return lit_weights.get("model", lit_weights) 217 | 218 | 219 | def check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None: 220 | weight_names = {wk.split(".")[-1] for wk in lit_weights} 221 | # LoRA or QLoRA 222 | if any("lora" in wn for wn in weight_names): 223 | raise ValueError("Model weights must be merged using `lora.merge_lora_weights()` before conversion.") 224 | # adapter v2. adapter_bias will only be in adapter_v2 225 | elif "adapter_bias" in weight_names: 226 | raise NotImplementedError("Converting models finetuned with adapter_v2 not yet supported.") 227 | # adapter. 
gating_factor is in adapter and adapter_v2 228 | elif "gating_factor" in weight_names: 229 | raise NotImplementedError("Converting models finetuned with adapter not yet supported.") 230 | 231 | 232 | def get_tinyllama_init_hf_config() -> dict: 233 | return { 234 | "architectures": ["LlamaForCausalLM"], 235 | "bos_token_id": 1, 236 | "eos_token_id": 2, 237 | "hidden_act": "silu", 238 | "hidden_size": None, 239 | "initializer_range": 0.02, 240 | "intermediate_size": None, 241 | "max_position_embeddings": None, 242 | "model_type": "llama", 243 | "num_attention_heads": None, 244 | "num_hidden_layers": None, 245 | "num_key_value_heads": None, 246 | "pretraining_tp": 1, 247 | "rms_norm_eps": None, 248 | "rope_scaling": None, 249 | "tie_word_embeddings": False, 250 | "torch_dtype": "float32", 251 | "transformers_version": "4.31.0.dev0", 252 | "use_cache": True, 253 | "vocab_size": None, 254 | } 255 | 256 | 257 | def convert_config_lit_to_hf(lit_config_dict: dict) -> dict: 258 | lit_hf_mapping = { 259 | "block_size": "max_position_embeddings", 260 | "vocab_size": "vocab_size", 261 | "n_layer": "num_hidden_layers", 262 | "n_embd": "hidden_size", 263 | "n_head": "num_attention_heads", 264 | "n_query_groups": "num_key_value_heads", 265 | "intermediate_size": "intermediate_size", 266 | "norm_eps": "rms_norm_eps", 267 | 268 | } 269 | hf_config_dict = get_tinyllama_init_hf_config() 270 | 271 | for lit_key, hf_key in lit_hf_mapping.items(): 272 | hf_config_dict[hf_key] = lit_config_dict[lit_key] 273 | return hf_config_dict 274 | 275 | 276 | @torch.inference_mode() 277 | def convert_lit_checkpoint(*, 278 | checkpoint_name: str, 279 | out_dir: Path, 280 | model_name: str, 281 | model_only: bool = True) -> None: 282 | config = Config.from_name(model_name) 283 | 284 | if "falcon" in model_name: 285 | copy_fn = partial(copy_weights_falcon, "40b" if config.n_embd == 8192 else "7b") 286 | elif config._mlp_class == "LLaMAMLP": 287 | copy_fn = partial(copy_weights_llama, config) 288 | else: 289 | copy_fn = copy_weights_gpt_neox 290 | 291 | # initialize a new empty state dict to hold our new weights 292 | sd = {} 293 | 294 | # checkpoint_name cannot be hardcoded because there exists different outputs such as 295 | # ("lit_model_finetuned.pth", "lit_model_lora_finetuned.pth", "lit_model_adapter_finetuned.pth"") 296 | pth_file = out_dir / checkpoint_name 297 | bin_file = pth_file.with_suffix(".bin") 298 | 299 | with incremental_save(bin_file) as saver: 300 | with contextlib.ExitStack() as stack: 301 | lit_weights = stack.enter_context(lazy_load(pth_file)) 302 | lit_weights = maybe_unwrap_state_dict(lit_weights) 303 | check_conversion_supported(lit_weights) 304 | # Incremental save will trigger error 305 | copy_fn(sd, lit_weights, saver=None) 306 | gc.collect() 307 | saver.save(sd) 308 | 309 | # convert lit config file to hf-style 310 | if not model_only: 311 | print('Converting config file...') 312 | lit_config = asdict(config) 313 | hf_config = convert_config_lit_to_hf(lit_config) 314 | config_path = out_dir / "config.json" 315 | with open(config_path, "w") as f: 316 | json.dump(hf_config, f, indent=4) 317 | 318 | 319 | 320 | 321 | if __name__ == "__main__": 322 | from jsonargparse import CLI 323 | 324 | CLI(convert_lit_checkpoint, as_positional=False) 325 | -------------------------------------------------------------------------------- /scripts/prepare_redpajama.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import sys 5 | 
from pathlib import Path 6 | 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | # support running without installing as a package 11 | wd = Path(__file__).parent.parent.resolve() 12 | sys.path.append(str(wd)) 13 | 14 | import lit_gpt.packed_dataset as packed_dataset 15 | from lit_gpt import Config, Tokenizer 16 | 17 | filenames_sample = [ 18 | "arxiv_sample.jsonl", 19 | "book_sample.jsonl", 20 | "c4_sample.jsonl", 21 | "cc_2019-30_sample.jsonl", 22 | "cc_2020-05_sample.jsonl", 23 | "cc_2021-04_sample.jsonl", 24 | "cc_2022-05_sample.jsonl", 25 | "cc_2023-06_sample.jsonl", 26 | "github_sample.jsonl", 27 | "stackexchange_sample.jsonl", 28 | "wikipedia_sample.jsonl", 29 | ] 30 | 31 | filename_sets = { 32 | "arxiv": "arxiv/arxiv*", 33 | "book": "book/book*", 34 | "c4": "c4/c4-train*", 35 | "common_crawl": "common_crawl/*", 36 | "github": "github/filtered*", 37 | "stackexchange": "stackexchange/stackexchange*", 38 | "wikipedia": "wikipedia/wiki*", 39 | } 40 | 41 | 42 | def prepare_sample( 43 | source_path: Path, checkpoint_dir: Path, destination_path: Path, chunk_size: int, match: str = "" 44 | ) -> None: 45 | """Prepare the "Red Pajama" dataset using the original tokenizer.""" 46 | destination_path.mkdir(parents=True, exist_ok=True) 47 | 48 | tokenizer = Tokenizer(checkpoint_dir) 49 | 50 | for name in filenames_sample: 51 | if match and match not in name: 52 | continue 53 | 54 | filepath = source_path / name 55 | 56 | if not filepath.is_file(): 57 | raise RuntimeError( 58 | f"Input file not found at {filepath}. \nMake sure you download the data, e.g. wget -i" 59 | " https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt or through" 60 | " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T" 61 | " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample \n" 62 | ) 63 | 64 | prefix, _ = os.path.splitext(name) 65 | 66 | builder = packed_dataset.PackedDatasetBuilder( 67 | outdir=destination_path, 68 | prefix=prefix, 69 | chunk_size=chunk_size, 70 | sep_token=tokenizer.eos_id, 71 | dtype="auto", 72 | vocab_size=tokenizer.vocab_size, 73 | ) 74 | 75 | print(f"Processing {name}") 76 | 77 | with open(filepath, encoding="utf-8") as f: 78 | for row in tqdm(f): 79 | text = json.loads(row)["text"] 80 | text_ids = tokenizer.encode(text) 81 | builder.add_array(np.array(text_ids, dtype=builder.dtype)) 82 | 83 | builder.write_reminder() 84 | 85 | 86 | def prepare_full( 87 | source_path: Path, checkpoint_dir: Path, destination_path: Path, chunk_size: int, match: str = "" 88 | ) -> None: 89 | """Prepare the "Red Pajama" dataset using the original tokenizer.""" 90 | import zstandard as zstd 91 | 92 | destination_path.mkdir(parents=True, exist_ok=True) 93 | 94 | tokenizer = Tokenizer(checkpoint_dir) 95 | 96 | for set_name, pattern in filename_sets.items(): 97 | if match and match not in set_name: 98 | continue 99 | 100 | is_cc = set_name == "common_crawl" 101 | 102 | filenames = glob.glob(os.path.join(source_path, pattern), recursive=True) 103 | 104 | if not filenames: 105 | raise RuntimeError( 106 | f"No files matching {pattern} found at {source_path}. \nMake sure you download the data, e.g. 
wget -i" 107 | " https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt or through" 108 | " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T" 109 | " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample \n" 110 | ) 111 | 112 | builder = packed_dataset.PackedDatasetBuilder( 113 | outdir=destination_path, 114 | prefix=set_name, 115 | chunk_size=chunk_size, 116 | sep_token=tokenizer.eos_id, 117 | dtype="auto", 118 | vocab_size=tokenizer.vocab_size, 119 | ) 120 | 121 | for name in filenames: 122 | filepath = source_path / name 123 | 124 | print(f"Processing {name}") 125 | 126 | if is_cc: 127 | with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f: 128 | for row in tqdm(f): 129 | text = json.loads(row)["text"] 130 | text_ids = tokenizer.encode(text) 131 | builder.add_array(np.array(text_ids, dtype=builder.dtype)) 132 | else: 133 | with open(filepath, encoding="utf-8") as f: 134 | for row in tqdm(f): 135 | text = json.loads(row)["text"] 136 | text_ids = tokenizer.encode(text) 137 | builder.add_array(np.array(text_ids, dtype=builder.dtype)) 138 | 139 | builder.write_reminder() 140 | 141 | 142 | def prepare( 143 | source_path: Path = Path("data/RedPajama-Data-1T-Sample"), 144 | checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), 145 | destination_path: Path = Path("data/redpajama_sample"), 146 | sample: bool = True, 147 | match: str = "", 148 | ) -> None: 149 | """Prepare the "Red Pajama" dataset. We assume tokenizer has been trained.""" 150 | with open(checkpoint_dir / "lit_config.json") as fp: 151 | config = Config(**json.load(fp)) 152 | 153 | prepare_fn = prepare_sample if sample else prepare_full 154 | prepare_fn( 155 | source_path=source_path, 156 | checkpoint_dir=checkpoint_dir, 157 | destination_path=destination_path, 158 | chunk_size=(config.block_size + 1) * 1024, # block size + 1 for causal, 1024 blocks 159 | match=match, 160 | ) 161 | 162 | 163 | if __name__ == "__main__": 164 | from jsonargparse import CLI 165 | 166 | CLI(prepare) -------------------------------------------------------------------------------- /scripts/prepare_slimpajama.py: -------------------------------------------------------------------------------- 1 | import json 2 | import glob 3 | import os 4 | from pathlib import Path 5 | import sys 6 | from typing import List 7 | import numpy as np 8 | from tqdm import tqdm 9 | from multiprocessing import Process, cpu_count 10 | 11 | # support running without installing as a package 12 | wd = Path(__file__).parent.parent.resolve() 13 | sys.path.append(str(wd)) 14 | 15 | import lit_gpt.packed_dataset as packed_dataset 16 | from lit_gpt import Tokenizer 17 | 18 | # Filename for SlimPajama 19 | slimpajama_sets = { 20 | "train": "train/chunk*/*", 21 | "validation": "validation/chunk*/*", 22 | "test": "test/chunk*/*", 23 | } 24 | 25 | 26 | def prepare_full( 27 | source_path: Path, 28 | tokenizer_path: Path, 29 | destination_path: Path, 30 | chunk_size: int, 31 | split: str="train", 32 | filenames_subset: List[str] = None, 33 | process_id: int = 0 34 | ) -> None: 35 | import zstandard as zstd 36 | 37 | destination_path.mkdir(parents=True, exist_ok=True) 38 | 39 | tokenizer = Tokenizer(tokenizer_path) 40 | 41 | # Use the provided filenames_subset or default to all filenames 42 | filenames = filenames_subset 43 | 44 | if not filenames: 45 | raise RuntimeError( 46 | f"No files matching {slimpajama_sets[split]} found at {source_path}. \n" 47 | "Make sure you download the data..." 
48 | ) 49 | 50 | builder = packed_dataset.PackedDatasetBuilder( 51 | outdir=destination_path, 52 | prefix=f"{split}_slimpajama_{process_id}", # Use process_id to differentiate builders 53 | chunk_size=chunk_size, 54 | sep_token=tokenizer.bos_id, 55 | dtype="auto", 56 | vocab_size=tokenizer.vocab_size, 57 | ) 58 | 59 | for filepath in filenames: 60 | print(f"Processing {filepath}") 61 | with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f: 62 | for row in tqdm(f): 63 | text = json.loads(row)["text"] 64 | if json.loads(row)["meta"]["redpajama_set_name"] == "RedPajamaGithub": 65 | continue # we don't want to include the github data 66 | text_ids = tokenizer.encode(text) 67 | builder.add_array(np.array(text_ids, dtype=builder.dtype)) 68 | 69 | # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details 70 | # builder.write_reminder() 71 | 72 | 73 | def prepare( 74 | source_path: Path = Path("data/RedPajama-Data-1T-Sample"), 75 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 76 | destination_path: Path = Path("data/red_pajama_sample"), 77 | chunk_size: int = 2049 * 1024, 78 | split: str="train", 79 | percentage: float = 1.0, 80 | ) -> None: 81 | import time 82 | 83 | filenames = glob.glob(os.path.join(source_path, slimpajama_sets[split]), recursive=True) 84 | filenames = filenames[:int(len(filenames) * percentage)] 85 | 86 | num_processes = cpu_count() 87 | chunked_filenames = np.array_split(filenames, num_processes) 88 | 89 | processes = [] 90 | start_time = time.time() 91 | 92 | for i, subset in enumerate(chunked_filenames): 93 | p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i)) 94 | processes.append(p) 95 | p.start() 96 | 97 | for p in processes: 98 | p.join() 99 | end_time = time.time() 100 | elapsed_time = end_time - start_time 101 | print(f"Time taken: {elapsed_time:.2f} seconds") 102 | 103 | 104 | if __name__ == "__main__": 105 | from jsonargparse import CLI 106 | CLI(prepare) -------------------------------------------------------------------------------- /scripts/prepare_starcoder.py: -------------------------------------------------------------------------------- 1 | import json 2 | import glob 3 | import os 4 | from pathlib import Path 5 | import sys 6 | from typing import List 7 | import numpy as np 8 | from tqdm import tqdm 9 | from multiprocessing import Process, cpu_count 10 | 11 | # support running without installing as a package 12 | wd = Path(__file__).parent.parent.resolve() 13 | sys.path.append(str(wd)) 14 | 15 | import lit_gpt.packed_dataset as packed_dataset 16 | from lit_gpt import Tokenizer 17 | 18 | import pandas as pd 19 | 20 | 21 | def prepare_full( 22 | source_path: Path, 23 | tokenizer_path: Path, 24 | destination_path: Path, 25 | chunk_size: int, 26 | split: str="train", 27 | filenames_subset: List[str] = None, 28 | process_id: int = 0 29 | ) -> None: 30 | import zstandard as zstd 31 | 32 | destination_path.mkdir(parents=True, exist_ok=True) 33 | 34 | tokenizer = Tokenizer(tokenizer_path) 35 | 36 | # Use the provided filenames_subset or default to all filenames 37 | filenames = filenames_subset 38 | 39 | if not filenames: 40 | raise RuntimeError( 41 | f"No files matching found at {source_path}. \n" 42 | "Make sure you download the data..." 
43 | ) 44 | 45 | builder = packed_dataset.PackedDatasetBuilder( 46 | outdir=destination_path, 47 | prefix=f"{split}_starcoder_{process_id}", # Use process_id to differentiate builders 48 | chunk_size=chunk_size, 49 | sep_token=tokenizer.bos_id, 50 | dtype="auto", 51 | vocab_size=tokenizer.vocab_size, 52 | ) 53 | 54 | for filepath in filenames: 55 | print(f"Processing {filepath}") 56 | try: 57 | contents = pd.read_parquet(filepath, engine='pyarrow')['content'] 58 | except Exception: 59 | print(f"Error reading {filepath}!!") 60 | continue 61 | for text in contents: 62 | text_ids = tokenizer.encode(text) 63 | builder.add_array(np.array(text_ids, dtype=builder.dtype)) 64 | 65 | # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details 66 | # builder.write_reminder() 67 | 68 | 69 | def prepare( 70 | source_path: Path = Path("data/RedPajama-Data-1T-Sample"), 71 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 72 | destination_path: Path = Path("data/red_pajama_sample"), 73 | chunk_size: int = 2049 * 1024, 74 | split: str="train", 75 | percentage: float = 1.0, 76 | filenames_subset: List[str] = None, 77 | ) -> None: 78 | import time 79 | assert split == "train" # starcoder only has train data 80 | filenames = glob.glob(os.path.join(source_path, "*/*.parquet"), recursive=True) 81 | # only retain files whose paths contain one of the prefixes in filenames_subset 82 | if filenames_subset: 83 | filenames = [f for f in filenames if any([prefix in f for prefix in filenames_subset])] 84 | filenames = filenames[:int(len(filenames) * percentage)] 85 | num_processes = 64 86 | chunked_filenames = np.array_split(filenames, num_processes) 87 | 88 | processes = [] 89 | start_time = time.time() 90 | 91 | for i, subset in enumerate(chunked_filenames): 92 | p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i)) 93 | processes.append(p) 94 | p.start() 95 | 96 | for p in processes: 97 | p.join() 98 | end_time = time.time() 99 | elapsed_time = end_time - start_time 100 | print(f"Time taken: {elapsed_time:.2f} seconds") 101 | 102 | 103 | if __name__ == "__main__": 104 | from jsonargparse import CLI 105 | CLI(prepare) 106 | -------------------------------------------------------------------------------- /sft/script.sh: -------------------------------------------------------------------------------- 1 | # We include a simple full-parameter finetuning & inference script here. Our V0.1 chat model is finetuned using this script. 2 | # The FT dataset we use is openassistant-guanaco. For finetuning with less than 4GB RAM, we refer you to the QLoRA and bitsandbytes repos. 3 | # We did not do extensive hyperparameter tuning, nor did we search for more performant FT datasets. 4 | # We hope the community can explore finetuning TinyLlama and come up with better chat models. We will include community-finetuned models in this repo.
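# Note: the two launch commands below differ mainly in the base checkpoint and the SFT dataset: the V0.1 run starts from the 503B-token checkpoint and trains on oasst1, while the V0.2 run starts from the 1T-token checkpoint and trains on OpenAssistant/oasst_top1_2023-08-25. The matching chat templates are shown in sft/simple_inference.py and sft/simple_inference2.py.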
5 | 6 | # V0.1 7 | CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch --multi_gpu --num_processes 4 --main_process_port 1234 finetune.py \ 8 | --model_name_or_path PY007/TinyLlama-1.1B-intermediate-step-240k-503b \ 9 | --output_dir ./output/503B_FT_lr1e-5_ep5 \ 10 | --logging_steps 10 \ 11 | --save_strategy epoch \ 12 | --data_seed 42 \ 13 | --save_total_limit 6 \ 14 | --evaluation_strategy epoch \ 15 | --eval_dataset_size 512 \ 16 | --max_eval_samples 1000 \ 17 | --per_device_eval_batch_size 1 \ 18 | --max_new_tokens 32 \ 19 | --dataloader_num_workers 3 \ 20 | --group_by_length=False \ 21 | --logging_strategy steps \ 22 | --remove_unused_columns False \ 23 | --do_train \ 24 | --do_eval \ 25 | --warmup_ratio 0.05 \ 26 | --lr_scheduler_type constant \ 27 | --dataset oasst1 \ 28 | --source_max_len 16 \ 29 | --target_max_len 512 \ 30 | --per_device_train_batch_size 4 \ 31 | --max_steps 0 \ 32 | --num_train_epochs 5 \ 33 | --learning_rate 1e-5 \ 34 | --adam_beta2 0.999 \ 35 | --max_grad_norm 1.0 \ 36 | --weight_decay 0.0 \ 37 | --seed 0 \ 38 | --trust_remote_code \ 39 | --report_to wandb 40 | 41 | 42 | # V0.2 43 | CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --multi_gpu --num_processes 4 --main_process_port 1234 finetune.py \ 44 | --model_name_or_path PY007/TinyLlama-1.1B-intermediate-step-480k-1T \ 45 | --output_dir ./output/503B_FT_lr1e-5_ep5_top1_2023-08-25 \ 46 | --logging_steps 10 \ 47 | --save_strategy epoch \ 48 | --data_seed 42 \ 49 | --save_total_limit 6 \ 50 | --evaluation_strategy epoch \ 51 | --eval_dataset_size 512 \ 52 | --max_eval_samples 1000 \ 53 | --per_device_eval_batch_size 1 \ 54 | --max_new_tokens 32 \ 55 | --dataloader_num_workers 3 \ 56 | --group_by_length=False \ 57 | --logging_strategy steps \ 58 | --remove_unused_columns False \ 59 | --do_train \ 60 | --do_eval \ 61 | --warmup_ratio 0.05 \ 62 | --lr_scheduler_type constant \ 63 | --dataset OpenAssistant/oasst_top1_2023-08-25 \ 64 | --dataset_format oasst1 \ 65 | --source_max_len 16 \ 66 | --target_max_len 512 \ 67 | --per_device_train_batch_size 4 \ 68 | --max_steps 0 \ 69 | --num_train_epochs 5 \ 70 | --learning_rate 1e-5 \ 71 | --adam_beta2 0.999 \ 72 | --max_grad_norm 1.0 \ 73 | --weight_decay 0.0 \ 74 | --seed 0 \ 75 | --trust_remote_code \ 76 | --report_to wandb 77 | -------------------------------------------------------------------------------- /sft/simple_inference.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | import transformers 3 | import torch 4 | model = "PY007/TinyLlama-1.1B-Chat-v0.1" 5 | tokenizer = AutoTokenizer.from_pretrained(model) 6 | pipeline = transformers.pipeline( 7 | "text-generation", 8 | model=model, 9 | torch_dtype=torch.float16, 10 | device_map="auto", 11 | ) 12 | 13 | prompt = "Give me detailed info about Jeo Biden." 
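# Note: TinyLlama-1.1B-Chat-v0.1 was finetuned on openassistant-guanaco, so it expects the
# "### Human: ... ### Assistant:" prompt template used below; the v0.2 model (see simple_inference2.py)
# uses a chatml-style template with <|im_start|>/<|im_end|> markers instead.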
14 | formatted_prompt = (
15 |     f"### Human: {prompt} ### Assistant:"
16 | )
17 | 
18 | 
19 | sequences = pipeline(
20 |     formatted_prompt,
21 |     do_sample=True,
22 |     top_k=50,
23 |     top_p=0.9,
24 |     num_return_sequences=1,
25 |     repetition_penalty=1.1,
26 |     max_new_tokens=1024,
27 | )
28 | for seq in sequences:
29 |     print(f"Result: {seq['generated_text']}")
30 | 
--------------------------------------------------------------------------------
/sft/simple_inference2.py:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | from transformers import AutoTokenizer
4 | import transformers
5 | import torch
6 | model = "PY007/TinyLlama-1.1B-Chat-v0.2"
7 | tokenizer = AutoTokenizer.from_pretrained(model)
8 | pipeline = transformers.pipeline(
9 |     "text-generation",
10 |     model=model,
11 |     torch_dtype=torch.float16,
12 |     device_map="auto",
13 | )
14 | 
15 | prompt = "How to get into a good university?"
16 | formatted_prompt = (
17 |     f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
18 | )
19 | 
20 | 
21 | sequences = pipeline(
22 |     formatted_prompt,
23 |     do_sample=True,
24 |     top_k=50,
25 |     top_p=0.9,
26 |     num_return_sequences=1,
27 |     repetition_penalty=1.1,
28 |     max_new_tokens=1024,
29 | )
30 | for seq in sequences:
31 |     print(f"Result: {seq['generated_text']}")
--------------------------------------------------------------------------------
/speculative_decoding/README.md:
--------------------------------------------------------------------------------
1 | ## Speculative Decoding
2 | 
3 | ### HuggingFace "Assisted Generation"
4 | 
5 | 
6 | | Large Model | Native Decoding | Assisted Decoding |
7 | | ----------- | --------------- | ----------------- |
8 | | guanaco-7b  | 69 seconds      | 38 seconds        |
9 | | guanaco-13b | 84 seconds      | 45 seconds        |
10 | | guanaco-33b | 109 seconds     | 62 seconds        |
11 | 
12 | We use PY007/TinyLlama-1.1B-Chat-v0.1 as the assistant model and vary the large model from guanaco-7B to guanaco-33B. Experiments are run on a single A40 GPU using the code in instruct_hf_assisted_decoding.py. TinyLlama is loaded in fp16 and the large models are loaded in 8-bit, both to fit guanaco-33b in memory and to keep the setup consistent across model sizes. The prompt used is "Give me detailed info about Jeo Biden." and max_new_tokens is set to 512.
13 | 
14 | You can read this [article](https://huggingface.co/blog/assisted-generation) for more information about HuggingFace's Assisted Generation.
15 | 
16 | Quote from HF: "due to INT8 quantization and the use of causal masking in assisted generation, the output of greedy decoding may differ in rare occasions."
17 | #### TODO
18 | - [ ] Thoroughly benchmark the average speedup on 52K Alpaca prompts.
19 | 
20 | ### Llama.cpp Speculative Decoding
21 | We have continued pretraining a code TinyLlama from the 500B-token checkpoint on an additional 7B tokens of Python data; the resulting model is available [here](https://huggingface.co/PY007/TinyLlama-1.1B-python-v0.1).
22 | The code for continue-pretraining can be found in pretrain/tinyllama_code.py 23 | 24 | ``` 25 | ./speculative \ 26 | -m models/CodeLlama-7b-hf/ggml-model-f16.gguf \ 27 | -md models/TinyLlama-1.1B-500B-python/ggml-model-q4_0.gguf \ 28 | -p "# Quick-sort implementation in Python and sample usage:" \ 29 | -e -ngl 1 -t 4 -n 256 -s 20 --temp 0 --draft 8 30 | ``` 31 | This gives: 32 | 33 | ``` 34 | encoded 12 tokens in 0.247 seconds, speed: 48.638 t/s 35 | decoded 265 tokens in 7.909 seconds, speed: 33.507 t/s 36 | 37 | n_draft = 16 38 | n_predict = 265 39 | n_drafted = 317 40 | n_accept = 195 41 | accept = 61.514% 42 | 43 | draft: 44 | 45 | llama_print_timings: load time = 53.14 ms 46 | llama_print_timings: sample time = 652.62 ms / 1 runs ( 652.62 ms per token, 1.53 tokens per second) 47 | llama_print_timings: prompt eval time = 73.81 ms / 12 tokens ( 6.15 ms per token, 162.58 tokens per second) 48 | llama_print_timings: eval time = 2247.77 ms / 378 runs ( 5.95 ms per token, 168.17 tokens per second) 49 | llama_print_timings: total time = 8154.92 ms 50 | 51 | target: 52 | 53 | llama_print_timings: load time = 534.47 ms 54 | llama_print_timings: sample time = 208.12 ms / 265 runs ( 0.79 ms per token, 1273.32 tokens per second) 55 | llama_print_timings: prompt eval time = 4210.38 ms / 382 tokens ( 11.02 ms per token, 90.73 tokens per second) 56 | llama_print_timings: eval time = 682.80 ms / 16 runs ( 42.68 ms per token, 23.43 tokens per second) 57 | llama_print_timings: total time = 8214.11 ms 58 | ggml_metal_free: deallocating 59 | ggml_metal_free: deallocating 60 | ``` 61 | 62 | Even though the model is continue-pretrained exclusively on Python, it retains its ability in other languages, such as C: 63 | ``` 64 | ./speculative \ 65 | -m models/CodeLlama-7b-hf/ggml-model-f16.gguf \ 66 | -md models/TinyLlama-1.1B-500B-python/ggml-model-q4_0.gguf \ 67 | -p "// Quick-sort implementation in C (4 spaces indentation + detailed comments) and sample usage:\n\n#include" \ 68 | -e -ngl 1 -t 4 -n 256 -s 20 --temp 0 --draft 8 69 | ``` 70 | 71 | This gives: 72 | 73 | ``` 74 | encoded 25 tokens in 0.278 seconds, speed: 89.900 t/s 75 | decoded 258 tokens in 6.432 seconds, speed: 40.112 t/s 76 | 77 | n_draft = 28 78 | n_predict = 258 79 | n_drafted = 278 80 | n_accept = 200 81 | accept = 71.942% 82 | 83 | draft: 84 | 85 | llama_print_timings: load time = 932.54 ms 86 | llama_print_timings: sample time = 583.50 ms / 1 runs ( 583.50 ms per token, 1.71 tokens per second) 87 | llama_print_timings: prompt eval time = 81.50 ms / 25 tokens ( 3.26 ms per token, 306.73 tokens per second) 88 | llama_print_timings: eval time = 1834.67 ms / 329 runs ( 5.58 ms per token, 179.32 tokens per second) 89 | llama_print_timings: total time = 6710.30 ms 90 | 91 | target: 92 | 93 | llama_print_timings: load time = 18568.44 ms 94 | llama_print_timings: sample time = 208.78 ms / 258 runs ( 0.81 ms per token, 1235.75 tokens per second) 95 | llama_print_timings: prompt eval time = 3164.84 ms / 342 tokens ( 9.25 ms per token, 108.06 tokens per second) 96 | llama_print_timings: eval time = 775.43 ms / 18 runs ( 43.08 ms per token, 23.21 tokens per second) 97 | llama_print_timings: total time = 7650.67 ms 98 | ggml_metal_free: deallocating 99 | ggml_metal_free: deallocating 100 | ``` 101 | 102 | 103 | I have not tried 13B CodeLlama as the large model yet because my Mac memory is not enough :). 
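To reproduce the llama.cpp runs above you need GGUF versions of both the target and the draft model. A minimal conversion/quantization sketch is shown below; the converter script and binary names differ between llama.cpp versions (newer trees ship `convert_hf_to_gguf.py` and `llama-quantize`), so treat the exact commands and paths as assumptions to adapt to your checkout.

```
# Convert the HF checkpoint of the draft model to an fp16 GGUF file, then quantize it to q4_0.
# Script/binary names vary across llama.cpp versions; adjust to match your checkout.
python convert.py models/TinyLlama-1.1B-500B-python --outtype f16
./quantize models/TinyLlama-1.1B-500B-python/ggml-model-f16.gguf \
           models/TinyLlama-1.1B-500B-python/ggml-model-q4_0.gguf q4_0
```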
-------------------------------------------------------------------------------- /speculative_decoding/instruct_hf_assisted_decoding.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | import torch 3 | import time 4 | 5 | 6 | model_id = "huggyllama/llama-13b" 7 | peft_model_id = "timdettmers/guanaco-13b" 8 | assistant_checkpoint = "PY007/TinyLlama-1.1B-Chat-v0.1" 9 | 10 | 11 | device = "cuda" if torch.cuda.is_available() else "cpu" 12 | tokenizer = AutoTokenizer.from_pretrained(model_id) 13 | 14 | 15 | prompt = "Give me detailed info about Jeo Biden." 16 | formatted_prompt = f"### Human: {prompt}### Assistant:" 17 | inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device) 18 | model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True) 19 | model.load_adapter(peft_model_id) 20 | print("Large model loaded") 21 | model.config.use_cache = True 22 | assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint).half().to(device) 23 | assistant_model.config.use_cache = True 24 | print("Small model loaded") 25 | 26 | 27 | print("###Native Decoding Starts...\n") 28 | start = time.time() 29 | outputs = model.generate(**inputs, assistant_model=None, max_new_tokens=512) 30 | end = time.time() 31 | print(tokenizer.batch_decode(outputs, skip_special_tokens=True)) 32 | print("Time: ", end - start) 33 | 34 | print("###TinyLlama Assisted Decoding Starts...\n") 35 | start = time.time() 36 | outputs = model.generate(**inputs, assistant_model=assistant_model,max_new_tokens=512) 37 | end = time.time() 38 | print(tokenizer.batch_decode(outputs, skip_special_tokens=True)) 39 | # print time in seconds 40 | print("Time: ", end - start) 41 | 42 | --------------------------------------------------------------------------------
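A note on running instruct_hf_assisted_decoding.py above: besides transformers and torch, the script needs accelerate and bitsandbytes for the 8-bit load of the large model, and peft for `load_adapter`. A minimal invocation sketch follows; the package list is an assumption, so pin versions to match your environment.

```
pip install transformers accelerate peft bitsandbytes  # assumed dependency set; pin versions as needed
python speculative_decoding/instruct_hf_assisted_decoding.py
```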