├── .env.example ├── .flake8 ├── .gitattributes ├── .gitignore ├── .gitmodules ├── .isort.cfg ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── poetry.lock ├── pyproject.toml └── smol_trainer ├── __init__.py ├── config ├── __init__.py ├── model.py └── train.py ├── data ├── __init__.py ├── data_config.py ├── data_configurations │ └── test.yaml └── data_maker.py ├── inference.py ├── model ├── __init__.py ├── gpt.py └── utils.py ├── runner.py ├── trainer ├── __init__.py ├── base.py ├── checkpointer.py ├── data_loader.py └── initializer.py └── utils.py /.env.example: -------------------------------------------------------------------------------- 1 | # For logging to WandB 2 | WANDB_API_KEY=your_wandb_key 3 | HF_TOKEN=your_hf_token -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E501, W503, E203, F541, W293, W291, E266, E402 3 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.csv linguist-generated=true 2 | *.json linguist-generated=true 3 | *.jsonl linguist-generated=true 4 | *.ipynb linguist-generated=true 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated Data 2 | results/ 3 | wandb/ 4 | 5 | # Python Cruft 6 | __pycache__/ 7 | 8 | # Hidden conf files 9 | .env 10 | .vscode 11 | .DS_Store 12 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emrgnt-cmplxty/SmolTrainer/b4659c8bb357f1e80495c5721fc01a674b88925c/.gitmodules -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | profile = black 3 | multi_line_output = 3 4 | include_trailing_comma = true 5 | force_grid_wrap = 0 6 | use_parentheses = true 7 | ensure_newline_before_comments = true 8 | line_length = 79 9 | sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: local 3 | hooks: 4 | - id: mypy-poetry 5 | name: mypy (via poetry) 6 | entry: poetry run mypy . 
7 | language: system 8 | pass_filenames: false 9 | types: [python] 10 | 11 | - id: black-poetry 12 | name: black (via poetry) 13 | entry: poetry run black 14 | language: system 15 | pass_filenames: true 16 | types: [python] 17 | 18 | - id: isort-poetry 19 | name: isort (via poetry) 20 | entry: poetry run isort 21 | language: system 22 | pass_filenames: true 23 | types: [python] 24 | 25 | - id: flake8-poetry 26 | name: flake8 (via poetry) 27 | entry: poetry run flake8 28 | language: system 29 | pass_filenames: true 30 | types: [python] 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Emergent AGI Inc. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SmolTrainer 2 | 3 | Welcome to SmolTrainer, an evolution of the popular nanoGPT repository, tailored to train additional models with as little overhead as possible. 
4 | --- 5 | 6 | ## Install and Setup 7 | 8 | ```bash 9 | # Clone the repository 10 | git clone git@github.com:emrgnt-cmplxty/SmolTrainer.git && cd SmolTrainer 11 | 12 | # Install poetry and the project 13 | pip3 install poetry && poetry install 14 | 15 | # Optional development tooling 16 | # pre-commit install 17 | ``` 18 | 19 | Great, now let's proceed onward to train the full UberSmol models. 20 | 21 | ```bash 22 | # Perform training run with pythia-410m 23 | # place dataset in smol_trainer/data/open-platypus/[train,val].bin 24 | export DATASET=open-platypus 25 | export MODEL_NAME=pythia-410m 26 | poetry run python smol_trainer/runner.py \ 27 | --model-name=$MODEL_NAME \ 28 | --dataset=$DATASET \ 29 | --batch-size=8 \ 30 | --block-size=1024 \ 31 | --eval-iters=250 \ 32 | --compile=True \ 33 | --device=cuda \ 34 | --eval-interval=100 \ 35 | --wandb-log \ 36 | --run-name=run_$DATASET 37 | `````` -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["poetry-core", "setuptools", "wheel"] 3 | build-backend = "poetry.core.masonry.api" 4 | 5 | [tool.poetry] 6 | name = "smol-trainer" 7 | version = "0.1.0" 8 | description = "A smol interface for training, built on top of nanoGPT." 9 | authors = ["Owen Colegrove "] 10 | license = "Apache-2.0" 11 | readme = "README.md" 12 | 13 | [tool.poetry.dependencies] 14 | python = ">=3.10,<3.11" 15 | torch = "2.1.0" 16 | numpy = "^1.25.2" 17 | requests = "^2.31.0" 18 | wandb = "^0.15.9" 19 | python-dotenv = "^1.0.0" 20 | datasets = "^2.14.5" 21 | transformers = "^4.34.0" 22 | triton = "^2.1.0" 23 | fire = "^0.5.0" 24 | 25 | [tool.poetry.group.dev.dependencies] 26 | black = "^23.3.0" 27 | flake8 = "6.1.0" 28 | isort = "5.12.0" 29 | pre-commit = "^3.3.3" 30 | mypy = "^1.5.1" 31 | sourcery = "^1.6.0" 32 | types-requests = "^2.31.0.2" 33 | types-attrs = "^19.1.0" 34 | yapf = "0.40.1" 35 | 36 | [tool.black] 37 | line-length = 79 38 | 39 | [tool.mypy] 40 | ignore_missing_imports = true 41 | exclude = 'smol_trainer/inference.py|playground/' 42 | 43 | [tool.flake8] 44 | ignore = ["E501", "W503"] 45 | exclude = 'smol_trainer/nano_gpt/' 46 | -------------------------------------------------------------------------------- /smol_trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | 3 | load_dotenv() 4 | -------------------------------------------------------------------------------- /smol_trainer/config/__init__.py: -------------------------------------------------------------------------------- 1 | from smol_trainer.config.model import ModelConfig 2 | from smol_trainer.config.train import LearningConfig, TrainConfig 3 | 4 | __all__ = ["LearningConfig", "ModelConfig", "TrainConfig"] 5 | -------------------------------------------------------------------------------- /smol_trainer/config/model.py: -------------------------------------------------------------------------------- 1 | # Derived from https://github.com/Lightning-AI/lit-gpt/tree/main. Apache-2 License. 
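# Illustrative usage sketch (override values here are hypothetical): the presets in the
# `configs` list at the bottom of this file are registered in `name_to_config`, and
# `ModelConfig.from_name` copies a preset and applies keyword overrides, e.g.
#
#     from smol_trainer.config import ModelConfig
#     cfg = ModelConfig.from_name("pythia-410m", block_size=1024)
#     assert cfg.n_layer == 24 and cfg.block_size == 1024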
2 | 3 | import json 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import Any, Literal, Optional, Type, Union 7 | 8 | import torch 9 | from typing_extensions import Self 10 | 11 | 12 | def find_multiple(n: int, k: int) -> int: 13 | assert k > 0 14 | if n % k == 0: 15 | return n 16 | return n + k - (n % k) 17 | 18 | 19 | @dataclass 20 | class ModelConfig: 21 | org: str = "Emergent-AGI" 22 | name: str = "SmolModel" 23 | # GPT Params 24 | block_size: int = 4096 25 | vocab_size: int = 32_000 26 | n_layer: int = 16 27 | n_head: int = 32 28 | n_embd: int = 768 29 | bias: bool = True 30 | 31 | # Unused 32 | padding_multiple: int = 512 33 | padded_vocab_size: Optional[int] = None 34 | rotary_percentage: float = 0.25 35 | parallel_residual: bool = True 36 | lm_head_bias: bool = False 37 | # to use multi-head attention (MHA), set this to `n_head` (default) 38 | # to use multi-query attention (MQA), set this to 1 39 | # to use grouped-query attention (GQA), set this to a value in between 40 | # Example with `n_head=4` 41 | # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ 42 | # │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │ 43 | # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ 44 | # │ │ │ │ │ │ │ 45 | # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ 46 | # │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │ 47 | # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ 48 | # │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐ 49 | # ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ 50 | # │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ 51 | # └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ 52 | # ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶ 53 | # MHA GQA MQA 54 | # n_query_groups=4 n_query_groups=2 n_query_groups=1 55 | # 56 | # credit https://arxiv.org/pdf/2305.13245.pdf 57 | n_query_groups: Optional[int] = None 58 | shared_attention_norm: bool = False 59 | _norm_class: Literal["LayerNorm", "RMSNorm"] = "LayerNorm" 60 | norm_eps: float = 1e-5 61 | _mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP" 62 | gelu_approximate: str = "none" 63 | intermediate_size: Optional[int] = None 64 | rope_condense_ratio: int = 1 65 | rope_base: int = 10000 66 | 67 | def __post_init__(self): 68 | assert self.n_embd % self.n_head == 0 69 | self.head_size = self.n_embd // self.n_head 70 | 71 | # vocab size should be a power of 2 to be optimal on hardware. 
compute the closest value 72 | if self.padded_vocab_size is None: 73 | self.padded_vocab_size = find_multiple( 74 | self.vocab_size, self.padding_multiple 75 | ) 76 | else: 77 | # vocab size shouldn't be larger than padded vocab size 78 | self.vocab_size = min(self.vocab_size, self.padded_vocab_size) 79 | 80 | # compute the number of query groups 81 | if self.n_query_groups is not None: 82 | assert self.n_head % self.n_query_groups == 0 83 | else: 84 | self.n_query_groups = self.n_head 85 | 86 | # compute the intermediate size for MLP if not set 87 | if self.intermediate_size is None: 88 | if self._mlp_class == "LLaMAMLP": 89 | raise ValueError( 90 | "The config needs to set the `intermediate_size`" 91 | ) 92 | self.intermediate_size = 4 * self.n_embd 93 | 94 | self.rope_n_elem = int(self.rotary_percentage * self.head_size) 95 | 96 | @classmethod 97 | def from_name(cls, name: str, **kwargs: Any) -> Self: 98 | conf_dict = name_to_config[name].copy() 99 | if "condense_ratio" in kwargs: # legacy name 100 | kwargs["rope_condense_ratio"] = kwargs.pop("condense_ratio") 101 | conf_dict.update(kwargs) 102 | return cls(**conf_dict) 103 | 104 | @classmethod 105 | def from_json(cls, path: Union[str, Path], **kwargs: Any) -> Self: 106 | with open(path, encoding="utf-8") as fp: 107 | json_kwargs = json.load(fp) 108 | if "condense_ratio" in json_kwargs: # legacy name 109 | json_kwargs["rope_condense_ratio"] = json_kwargs.pop( 110 | "condense_ratio" 111 | ) 112 | if "condense_ratio" in kwargs: # legacy name 113 | kwargs["rope_condense_ratio"] = kwargs.pop("condense_ratio") 114 | json_kwargs.update(kwargs) 115 | return cls(**json_kwargs) 116 | 117 | @property 118 | def mlp_class(self) -> Type: 119 | from smol_trainer import model 120 | 121 | # `self._mlp_class` cannot be the type to keep the config json serializable 122 | return getattr(model, self._mlp_class) 123 | 124 | @property 125 | def norm_class(self) -> Type: 126 | # `self._norm_class` cannot be the type to keep the config json serializable 127 | if self._norm_class == "RMSNorm": 128 | from smol_trainer.model import RMSNorm 129 | 130 | return RMSNorm 131 | return getattr(torch.nn, self._norm_class) 132 | 133 | 134 | #################### 135 | # EleutherAI Pythia 136 | #################### 137 | configs = [ 138 | # https://huggingface.co/EleutherAI/pythia-70m/blob/main/config.json 139 | dict( 140 | org="EleutherAI", 141 | name="pythia-70m", 142 | block_size=2048, 143 | n_layer=6, 144 | n_embd=512, 145 | n_head=8, 146 | padding_multiple=128, 147 | ), 148 | # https://huggingface.co/EleutherAI/pythia-160m/blob/main/config.json 149 | dict( 150 | org="EleutherAI", 151 | name="pythia-160m", 152 | block_size=2048, 153 | n_layer=12, 154 | n_embd=768, 155 | n_head=12, 156 | padding_multiple=128, 157 | ), 158 | # https://huggingface.co/EleutherAI/pythia-410m/blob/main/config.json 159 | dict( 160 | org="EleutherAI", 161 | name="pythia-410m", 162 | block_size=2048, 163 | n_layer=24, 164 | n_embd=1024, 165 | n_head=16, 166 | padding_multiple=128, 167 | ), 168 | # https://huggingface.co/EleutherAI/pythia-1b/blob/main/config.json 169 | dict( 170 | org="EleutherAI", 171 | name="pythia-1b", 172 | block_size=2048, 173 | n_embd=2048, 174 | n_head=8, 175 | padding_multiple=128, 176 | ), 177 | # https://huggingface.co/EleutherAI/pythia-1.4b/blob/main/config.json 178 | dict( 179 | org="EleutherAI", 180 | name="pythia-1.4b", 181 | block_size=2048, 182 | n_layer=24, 183 | n_embd=2048, 184 | n_head=16, 185 | padding_multiple=128, 186 | ), 187 | # 
https://huggingface.co/EleutherAI/pythia-2.8b/blob/main/config.json 188 | dict( 189 | org="EleutherAI", 190 | name="pythia-2.8b", 191 | block_size=2048, 192 | n_layer=32, 193 | n_embd=2560, 194 | padding_multiple=128, 195 | ), 196 | # https://huggingface.co/EleutherAI/pythia-6.9b/blob/main/config.json 197 | dict( 198 | org="EleutherAI", 199 | name="pythia-6.9b", 200 | block_size=2048, 201 | n_layer=32, 202 | padding_multiple=256, 203 | ), 204 | # https://huggingface.co/EleutherAI/pythia-12b/blob/main/config.json 205 | dict( 206 | org="EleutherAI", 207 | name="pythia-12b", 208 | block_size=2048, 209 | n_layer=36, 210 | n_embd=5120, 211 | n_head=40, 212 | ), 213 | ] 214 | configs.append( 215 | dict( 216 | org="TEST", 217 | name="test-local", 218 | block_size=1024, 219 | n_layer=4, 220 | n_embd=128, 221 | n_head=4, 222 | ) 223 | ) 224 | 225 | name_to_config = {config["name"]: config for config in configs} 226 | -------------------------------------------------------------------------------- /smol_trainer/config/train.py: -------------------------------------------------------------------------------- 1 | """Configurations for training.""" 2 | import logging 3 | import time 4 | from dataclasses import dataclass 5 | from enum import Enum 6 | 7 | 8 | @dataclass 9 | class LearningConfig: 10 | """Learning rate and optimizer configuration.""" 11 | 12 | # Learning rate arguments 13 | initial_lr: float 14 | decay_lr: float 15 | lr: float 16 | min_lr: float 17 | 18 | # Optimizer arguments 19 | grad_clip: float 20 | weight_decay: float 21 | beta1: float 22 | beta2: float 23 | do_flash_v2: bool 24 | 25 | # Iteration variables 26 | warmup_iters: int 27 | lr_decay_iters: int 28 | gradient_accumulation_steps: int 29 | 30 | 31 | @dataclass 32 | class TrainConfig: 33 | """Training configuration.""" 34 | 35 | # Logging support 36 | logger: logging.Logger 37 | lr_config: LearningConfig 38 | master_process: bool 39 | log_interval: int 40 | wandb_log: bool 41 | 42 | # Architecture 43 | model_name: str 44 | 45 | # Training Params 46 | eval_interval: int 47 | batch_size: int 48 | block_size: int 49 | max_iters: int 50 | eval_iters: int 51 | max_checkpoints: int 52 | 53 | # Run variables 54 | out_dir: str 55 | checkpoint_dir: str 56 | run_name: str 57 | ddp: bool 58 | device: str 59 | device_type: str 60 | always_save_checkpoint: bool 61 | 62 | # Run information 63 | iter_num: int = 0 64 | total_tokens_processed: int = 0 65 | best_val_loss: float = 1e9 66 | running_mfu: float = -1.0 67 | initial_time: float = time.time() 68 | total_time: float = 0.0 69 | training_loss: float = -1 70 | -------------------------------------------------------------------------------- /smol_trainer/data/__init__.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | 3 | load_dotenv() 4 | -------------------------------------------------------------------------------- /smol_trainer/data/data_config.py: -------------------------------------------------------------------------------- 1 | # datasets_config.py 2 | 3 | datasets_config = { 4 | "textbook_quality_programming": { 5 | "path": "vikp/textbook_quality_programming", 6 | "mappings": { 7 | "markdown": "", 8 | }, 9 | "tokenizer": "simple_tokenize", 10 | }, 11 | "open-platypus": { 12 | "path": "garage-bAInd/Open-Platypus", 13 | "mappings": { 14 | "instruction": "### Instruction:", 15 | "output": "### Output:", 16 | }, 17 | "tokenizer": "simple_tokenize", 18 | }, 19 | "sciphi-python-textbook": { 20 | "path": 
"emrgnt-cmplxty/sciphi-python-textbook", 21 | "mappings": { 22 | "formatted_prompt": "### Instruction:", 23 | "completion": "### Response:", 24 | }, 25 | "tokenizer": "simple_tokenize", 26 | }, 27 | "sciphi-textbooks-are-all-you-need": { 28 | "path": "emrgnt-cmplxty/sciphi-textbooks-are-all-you-need", 29 | "mappings": { 30 | "formatted_prompt": "### Instruction:", 31 | "completion": "### Response:", 32 | }, 33 | "tokenizer": "simple_tokenize", 34 | }, 35 | "open-phi-textbooks": { 36 | "path": "open-phi/textbooks", 37 | "mappings": {"markdown": ""}, 38 | "tokenizer": "simple_tokenize", 39 | }, 40 | "programming-books-llama": { 41 | "path": "open-phi/programming_books_llama", 42 | "mappings": {"markdown": ""}, 43 | "tokenizer": "simple_tokenize", 44 | }, 45 | "open-orca": { 46 | "path": "Open-Orca/OpenOrca", 47 | "mappings": { 48 | "system_prompt": "### System:", 49 | "question": "### Question:", 50 | "response": "### Response:", 51 | }, 52 | "tokenizer": "simple_tokenize", 53 | }, 54 | "tiny-stories": { 55 | "path": "roneneldan/TinyStories", 56 | "mappings": {"text": ""}, 57 | "tokenizer": "simple_tokenize", 58 | }, 59 | "tiny-codes": { 60 | "path": "nampdn-ai/tiny-codes", 61 | "mappings": { 62 | "prompt": "### Instruction:", 63 | "response": "### Response:", 64 | }, 65 | "tokenizer": "simple_tokenize", 66 | }, 67 | "tiny-orca": { 68 | "path": "nampdn-ai/tiny-orca-textbooks", 69 | "mappings": { 70 | "prompt": "### System:", 71 | "question": "### Question:", 72 | "textbook": "### Context:", 73 | "response": "### Response:", 74 | }, 75 | "tokenizer": "simple_tokenize", 76 | }, 77 | "tiny-textbooks": { 78 | "path": "nampdn-ai/tiny-textbooks", 79 | "mappings": { 80 | "text": "### Instruction:\nWrite a lesson based on the following content:", 81 | "textbook": "### Response:", 82 | }, 83 | "tokenizer": "simple_tokenize", 84 | }, 85 | "meta-math": { 86 | "path": "meta-math/MetaMathQA", 87 | "mappings": { 88 | "query": "### Question:", 89 | "response": "### Answer:", 90 | }, 91 | "tokenizer": "simple_tokenize", 92 | }, 93 | "evol-instruct": { 94 | "path": "nickrosh/Evol-Instruct-Code-80k-v1", 95 | "mappings": { 96 | "instruction": "### Instruction:", 97 | "output": "### Output:", 98 | }, 99 | "tokenizer": "simple_tokenize", 100 | }, 101 | } 102 | -------------------------------------------------------------------------------- /smol_trainer/data/data_configurations/test.yaml: -------------------------------------------------------------------------------- 1 | sciphi-python-textbook: 10 2 | sciphi-textbooks-are-all-you-need: 5 3 | open-phi-textbooks: 3 4 | -------------------------------------------------------------------------------- /smol_trainer/data/data_maker.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import logging 3 | import multiprocessing 4 | import os 5 | from concurrent.futures import ThreadPoolExecutor 6 | 7 | import fire as fire 8 | import numpy as np 9 | import yaml 10 | from datasets import load_dataset 11 | from tqdm import tqdm 12 | from transformers import AutoTokenizer 13 | 14 | from smol_trainer.data.data_config import datasets_config 15 | from smol_trainer.utils import get_root_py_fpath 16 | 17 | # Basic configuration for logging 18 | logging.basicConfig( 19 | level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" 20 | ) 21 | 22 | 23 | class TokenizationManager: 24 | """Class for managing tokenization of datasets.""" 25 | 26 | def __init__(self, encoding, num_proc=None): 27 | 
self.tokenizer = AutoTokenizer.from_pretrained(encoding) 28 | self.num_proc = num_proc or multiprocessing.cpu_count() 29 | 30 | def tokenize_dataset(self, dataset, formatters, raw_tokenize_function): 31 | tokenize_function = lambda x: raw_tokenize_function( 32 | x, formatters, self.tokenizer 33 | ) 34 | tokenized_dataset = dataset.map( 35 | tokenize_function, num_proc=self.num_proc 36 | ) 37 | return [ 38 | token 39 | for sublist in tokenized_dataset["input_ids"] 40 | for token in sublist[0] 41 | ] 42 | 43 | def split_and_save(self, flattened_tokens, val_frac, dataset_name): 44 | train_ids = flattened_tokens[ 45 | : int(len(flattened_tokens) * (1 - val_frac)) 46 | ] 47 | val_ids = flattened_tokens[ 48 | int(len(flattened_tokens) * (1 - val_frac)) : 49 | ] 50 | 51 | train_ids = np.array(train_ids, dtype=np.uint16) 52 | val_ids = np.array(val_ids, dtype=np.uint16) 53 | 54 | if not os.path.exists( 55 | os.path.join(os.path.dirname(__file__), dataset_name) 56 | ): 57 | os.mkdir(os.path.join(os.path.dirname(__file__), dataset_name)) 58 | 59 | train_ids.tofile( 60 | os.path.join( 61 | os.path.dirname(__file__), 62 | dataset_name, 63 | "train.bin", 64 | ) 65 | ) 66 | val_ids.tofile( 67 | os.path.join( 68 | os.path.dirname(__file__), 69 | dataset_name, 70 | "val.bin", 71 | ) 72 | ) 73 | 74 | 75 | class DatasetLoader: 76 | def __init__(self, split, token): 77 | self.datasets = {} 78 | self.split = split 79 | self.token = token 80 | 81 | def add_dataset(self, name, path, mappings, tokenizer): 82 | self.datasets[name] = { 83 | "path": path, 84 | "mappings": mappings, 85 | "tokenizer": tokenizer, 86 | } 87 | 88 | def load_all_datasets(self): 89 | loaded_data = {} 90 | for name, data in self.datasets.items(): 91 | loaded_data[name] = ( 92 | load_dataset(data["path"], split=self.split, token=self.token), 93 | data["mappings"], 94 | data["tokenizer"], 95 | ) 96 | return loaded_data 97 | 98 | 99 | def simple_tokenize(batch, formatters, tokenizer, add_bos_eos=True): 100 | combined_text = [] 101 | for column, prefix in formatters.items(): 102 | prefix = formatters[column] 103 | combined_text.append(f"{prefix}:\n{batch[column]}") 104 | if add_bos_eos: 105 | combined_text = ( 106 | [tokenizer.bos_token] + combined_text + [tokenizer.eos_token] 107 | ) 108 | return tokenizer.encode("\n\n".join(combined_text), return_tensors="np") 109 | 110 | 111 | def read_binary_dataset(dataset_name, split): 112 | file_path = os.path.join( 113 | os.path.dirname(__file__), dataset_name, f"{split}.bin" 114 | ) 115 | return np.fromfile(file_path, dtype=np.uint16) 116 | 117 | 118 | def load_weights_from_yaml(yaml_path): 119 | with open(yaml_path, "r") as file: 120 | dataset_weights = yaml.safe_load(file) 121 | return dataset_weights 122 | 123 | 124 | def tokenizer( 125 | hf_token=None, encoding="mistralai/Mistral-7B-v0.1", val_frac=0.01 126 | ): 127 | hf_token = hf_token or os.environ.get("HF_TOKEN") 128 | logging.info(f"Loading datasets with token: {hf_token}") 129 | 130 | if not hf_token: 131 | raise ValueError( 132 | "No HuggingFace token provided and HF_TOKEN environment variable is not set." 
133 | ) 134 | 135 | loader = DatasetLoader(split="train", token=hf_token) 136 | 137 | for dataset_name, config in datasets_config.items(): 138 | loader.add_dataset( 139 | dataset_name, 140 | config["path"], 141 | config["mappings"], 142 | globals()[config["tokenizer"]], 143 | ) 144 | 145 | datasets = loader.load_all_datasets() 146 | manager = TokenizationManager(encoding) 147 | 148 | for dataset_name, payload in datasets.items(): 149 | logging.info(f"Tokenizing dataset: {dataset_name}") 150 | 151 | dataset, formatters, raw_tokenize_function = payload 152 | if "train" in dataset: 153 | dataset = dataset["train"] 154 | 155 | flattened_tokens = manager.tokenize_dataset( 156 | dataset, formatters, raw_tokenize_function 157 | ) 158 | manager.split_and_save(flattened_tokens, val_frac, dataset_name) 159 | 160 | 161 | def process_dataset(dataset_name, fraction, all_data, chunk_size): 162 | # TODO - Move parallelism downstream to here. 163 | logging.info( 164 | f"Loading and remixing dataset: {dataset_name} with fraction: {fraction}" 165 | ) 166 | 167 | local_data = {"train": [], "val": []} 168 | 169 | for dataset_type in all_data.keys(): 170 | data = read_binary_dataset(dataset_name, dataset_type) 171 | subset_length = int(len(data) * fraction) 172 | local_data[dataset_type].extend(data[:subset_length]) 173 | 174 | for dataset_type, data in local_data.items(): 175 | # Chunking 176 | chunks = [ 177 | data[i : i + chunk_size] 178 | for i in tqdm( 179 | range(0, len(data), chunk_size), 180 | desc=f"Chunking {dataset_type}", 181 | ) 182 | ] 183 | 184 | # Shuffling merged chunks 185 | np.random.shuffle(chunks) 186 | 187 | # Unchunking 188 | local_data[dataset_type] = [ 189 | token for chunk in chunks for token in chunk 190 | ] 191 | 192 | return local_data 193 | 194 | 195 | def remixer(config_name, chunk_size=2_048, num_proc=None): 196 | config_path = os.path.join( 197 | get_root_py_fpath(), 198 | "data", 199 | "data_configurations", 200 | f"{config_name}.yaml" if ".yaml" not in config_name else config_name, 201 | ) 202 | 203 | logging.info(f"Loading dataset weights from: {config_path}") 204 | dataset_weights = load_weights_from_yaml(config_path) 205 | 206 | # Identify the highest weight 207 | max_weight = max(dataset_weights.values()) 208 | 209 | # Normalize weights 210 | fractions = {k: v / max_weight for k, v in dataset_weights.items()} 211 | 212 | remixed_name = config_name.replace(".yaml", "") 213 | logging.info( 214 | f"Saving datasets now to remixed dataset name = {remixed_name}" 215 | ) 216 | 217 | all_data = {"train": [], "val": []} 218 | 219 | num_proc = num_proc or multiprocessing.cpu_count() 220 | 221 | # Parallelizing dataset processing 222 | with ThreadPoolExecutor( 223 | max_workers=num_proc 224 | ) as executor: # Adjust max_workers as necessary 225 | futures = [] 226 | for dataset_name, fraction in fractions.items(): 227 | future = executor.submit( 228 | process_dataset, dataset_name, fraction, all_data, chunk_size 229 | ) 230 | futures.append(future) 231 | 232 | for future in concurrent.futures.as_completed(futures): 233 | # merge processed data back to all_data 234 | processed_data = future.result() 235 | for dataset_type in all_data.keys(): 236 | all_data[dataset_type].extend(processed_data[dataset_type]) 237 | 238 | # Construct path for saving 239 | save_path = os.path.join(get_root_py_fpath(), "data", remixed_name) 240 | if not os.path.exists(save_path): 241 | os.mkdir(save_path) 242 | 243 | logging.info(f"Saving remixed datasets to {save_path}") 244 | 245 | # Save the remixed 
datasets 246 | for dataset_type, data in all_data.items(): 247 | np.array(data, dtype=np.uint16).tofile( 248 | os.path.join(save_path, f"{dataset_type}.bin") 249 | ) 250 | 251 | 252 | if __name__ == "__main__": 253 | fire.Fire({"tokenizer": tokenizer, "remixer": remixer}) 254 | -------------------------------------------------------------------------------- /smol_trainer/inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sample from a trained model 3 | """ 4 | import os 5 | import pickle 6 | from contextlib import nullcontext 7 | 8 | import tiktoken 9 | import torch 10 | 11 | from smol_trainer.config import Model 12 | from smol_trainer.model import GPT, GPTConfig, MoEGPT 13 | 14 | # ----------------------------------------------------------------------------- 15 | init_from = "resume" # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl') 16 | out_dir = "results" # ignored if init_from is not 'resume' 17 | start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt" 18 | num_samples = 10 # number of samples to draw 19 | max_new_tokens = 500 # number of tokens generated in each sample 20 | temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions 21 | top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability 22 | seed = 1337 23 | device = ( 24 | "cuda" if torch.cuda.is_available() else "cpu" 25 | ) # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. 26 | dtype = ( 27 | "bfloat16" 28 | if torch.cuda.is_available() and torch.cuda.is_bf16_supported() 29 | else "float16" 30 | ) # 'float32' or 'bfloat16' or 'float16' 31 | compile = False # use PyTorch 2.0 to compile the model to be faster 32 | model_prefix = "checkpoint__mode_moe__n_layer_4__n_head_4__n_embd_128__n_experts_128__top_k_experts_16" 33 | iter_num = "2000" 34 | meta_path = "x" 35 | exec( 36 | open("nano_gpt/configurator.py").read() 37 | ) # overrides from command line or config file 38 | # ----------------------------------------------------------------------------- 39 | 40 | torch.manual_seed(seed) 41 | torch.cuda.manual_seed(seed) 42 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 43 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 44 | device_type = ( 45 | "cuda" if "cuda" in device else "cpu" 46 | ) # for later use in torch.autocast 47 | ptdtype = { 48 | "float32": torch.float32, 49 | "bfloat16": torch.bfloat16, 50 | "float16": torch.float16, 51 | }[dtype] 52 | ctx = ( 53 | nullcontext() 54 | if device_type == "cpu" 55 | else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 56 | ) 57 | 58 | # model 59 | if init_from == "resume": 60 | # init from a model saved in a specific directory 61 | ckpt_path = os.path.join( 62 | out_dir, model_prefix, f"{model_prefix}__iter_num_{iter_num}.pt" 63 | ) 64 | checkpoint = torch.load(ckpt_path, map_location=device) 65 | gptconf = GPTConfig( 66 | block_size=checkpoint["block_size"], 67 | n_layer=checkpoint["n_layer"], 68 | n_head=checkpoint["n_head"], 69 | n_embd=checkpoint["n_embd"], 70 | dropout=checkpoint["dropout"], 71 | bias=checkpoint["bias"], 72 | do_flash_v2=checkpoint["do_flash_v2"], 73 | ) 74 | model = GPT(gptconf) 75 | state_dict = checkpoint["model"] 76 | unwanted_prefix = "_orig_mod." 
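# Checkpoints saved from a torch.compile'd model prefix parameter names with
# `_orig_mod.`; stripping it lets the plain GPT module load the state dict.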
77 | for k, v in list(state_dict.items()): 78 | if k.startswith(unwanted_prefix): 79 | state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k) 80 | model.load_state_dict(state_dict) 81 | elif init_from.startswith("gpt2"): 82 | # init from a given GPT-2 model 83 | model = GPT.from_pretrained(init_from, dict(dropout=0.0)) 84 | 85 | model.eval() 86 | model.to(device) 87 | if compile: 88 | model = torch.compile(model) # requires PyTorch 2.0 (optional) 89 | 90 | # look for the meta pickle in case it is available in the dataset folder 91 | print("checkpoint = ", checkpoint.keys()) 92 | # ok let's assume gpt-2 encodings by default 93 | print("No meta.pkl found, assuming GPT-2 encodings...") 94 | enc = tiktoken.get_encoding("gpt2") 95 | encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"}) 96 | decode = lambda l: enc.decode(l) 97 | 98 | # encode the beginning of the prompt 99 | if start.startswith("FILE:"): 100 | with open(start[5:], "r", encoding="utf-8") as f: 101 | start = f.read() 102 | start_ids = encode(start) 103 | x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...] 104 | 105 | # run generation 106 | with torch.no_grad(): 107 | with ctx: 108 | for k in range(num_samples): 109 | y = model.generate( 110 | x, max_new_tokens, temperature=temperature, top_k=top_k 111 | ) 112 | print(decode(y[0].tolist())) 113 | print("---------------") 114 | -------------------------------------------------------------------------------- /smol_trainer/model/__init__.py: -------------------------------------------------------------------------------- 1 | from smol_trainer.model.gpt import GPT, GptNeoxMLP 2 | 3 | __all__ = ["GPT", "GptNeoxMLP"] 4 | -------------------------------------------------------------------------------- /smol_trainer/model/gpt.py: -------------------------------------------------------------------------------- 1 | """Full definition of a GPT NeoX Language Model, all of it in this single file. 2 | 3 | Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and 4 | https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model. 5 | 6 | Derived from https://github.com/Lightning-AI/lit-gpt/tree/main. Apache-2 License. 7 | """ 8 | import logging 9 | import math 10 | from typing import Any, Optional, Tuple 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torch.nn import functional as F 15 | from typing_extensions import Self 16 | 17 | from smol_trainer.config import ModelConfig 18 | from smol_trainer.model.utils import RequirementCache 19 | 20 | 21 | class RMSNorm(torch.nn.Module): 22 | """Root Mean Square Layer Normalization. 23 | 24 | Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License: 25 | https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE. 
26 | """ 27 | 28 | def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None: 29 | super().__init__() 30 | self.weight = torch.nn.Parameter(torch.ones(size)) 31 | self.eps = eps 32 | self.dim = dim 33 | 34 | def forward(self, x: torch.Tensor) -> torch.Tensor: 35 | # NOTE: the original RMSNorm paper implementation is not equivalent 36 | norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) 37 | x_normed = x * torch.rsqrt(norm_x + self.eps) 38 | return self.weight * x_normed 39 | 40 | def reset_parameters(self): 41 | torch.nn.init.ones_(self.weight) 42 | 43 | 44 | FlashAttention2Available = bool(RequirementCache("flash-attn>=2.0.0.post1")) 45 | 46 | 47 | class GPT(nn.Module): 48 | def __init__(self, config: ModelConfig) -> None: 49 | super().__init__() 50 | assert config.padded_vocab_size is not None 51 | self.config = config 52 | 53 | self.lm_head = nn.Linear( 54 | config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias 55 | ) 56 | self.transformer = nn.ModuleDict( 57 | dict( 58 | wte=nn.Embedding(config.padded_vocab_size, config.n_embd), 59 | h=nn.ModuleList(Block(config) for _ in range(config.n_layer)), 60 | ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), 61 | ) 62 | ) 63 | self.max_seq_length = self.config.block_size 64 | self.mask_cache: Optional[torch.Tensor] = None 65 | logging.info(f"FlashAttention2Available = {FlashAttention2Available}") 66 | 67 | @property 68 | def max_seq_length(self) -> int: 69 | return self._max_seq_length 70 | 71 | @max_seq_length.setter 72 | def max_seq_length(self, value: int) -> None: 73 | """ 74 | When doing inference, the sequences used might be shorter than the model's context length. 75 | This allows setting a smaller number to avoid allocating unused memory 76 | """ 77 | if value > self.config.block_size: 78 | raise ValueError( 79 | f"Cannot attend to {value}, block size is only {self.config.block_size}" 80 | ) 81 | self._max_seq_length = value 82 | if not hasattr(self, "cos"): 83 | # first call 84 | cos, sin = self.rope_cache() 85 | self.register_buffer("cos", cos, persistent=False) 86 | self.register_buffer("sin", sin, persistent=False) 87 | elif value != self.cos.size(0): 88 | # override 89 | self.cos, self.sin = self.rope_cache(device=self.cos.device) 90 | # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know 91 | # if the kv cache is expected 92 | 93 | def reset_parameters(self) -> None: 94 | # Trigger resetting the rope-cache 95 | self.max_seq_length = self.config.block_size 96 | 97 | def _init_weights(self, module: nn.Module) -> None: 98 | """Meant to be used with `gpt.apply(gpt._init_weights)`.""" 99 | if isinstance(module, nn.Linear): 100 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 101 | if module.bias is not None: 102 | torch.nn.init.zeros_(module.bias) 103 | elif isinstance(module, nn.Embedding): 104 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 105 | 106 | def forward( 107 | self, 108 | idx: torch.Tensor, 109 | targets: Optional[torch.Tensor] = None, 110 | input_pos: Optional[torch.Tensor] = None, 111 | ) -> torch.Tensor: 112 | T = idx.size(1) 113 | if self.max_seq_length < T: 114 | raise ValueError( 115 | f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}." 
116 | ) 117 | 118 | if input_pos is not None: # use the kv cache 119 | cos = self.cos.index_select(0, input_pos) 120 | sin = self.sin.index_select(0, input_pos) 121 | if self.mask_cache is None: 122 | raise TypeError("You need to call `gpt.set_kv_cache()`") 123 | mask = self.mask_cache.index_select(2, input_pos) 124 | else: 125 | cos = self.cos[:T] 126 | sin = self.sin[:T] 127 | mask = None 128 | 129 | x = self.transformer.wte( 130 | idx 131 | ) # token embeddings of shape (b, t, n_embd) 132 | for block in self.transformer.h: 133 | x = block(x, cos, sin, mask, input_pos) 134 | x = self.transformer.ln_f(x) 135 | if targets is not None: 136 | # if we are given some desired targets also calculate the loss 137 | logits = self.lm_head(x) 138 | loss = F.cross_entropy( 139 | logits.view(-1, logits.size(-1)), 140 | targets.view(-1), 141 | ignore_index=-1, 142 | ) 143 | else: 144 | # inference-time mini-optimization: only forward the lm_head on the very last position 145 | logits = self.lm_head( 146 | x[:, [-1], :] 147 | ) # note: using list [-1] to preserve the time dim 148 | loss = None 149 | 150 | return logits, loss 151 | 152 | @classmethod 153 | def from_name(cls, name: str, **kwargs: Any) -> Self: 154 | return cls(ModelConfig.from_name(name, **kwargs)) 155 | 156 | def rope_cache( 157 | self, device: Optional[torch.device] = None 158 | ) -> Tuple[torch.Tensor, torch.Tensor]: 159 | return build_rope_cache( 160 | seq_len=self.max_seq_length, 161 | n_elem=self.config.rope_n_elem, 162 | dtype=torch.get_default_dtype(), 163 | device=device, 164 | condense_ratio=self.config.rope_condense_ratio, 165 | base=self.config.rope_base, 166 | ) 167 | 168 | def set_kv_cache( 169 | self, 170 | batch_size: int, 171 | rope_cache_length: Optional[int] = None, 172 | device: Optional[torch.device] = None, 173 | dtype: Optional[torch.dtype] = None, 174 | ) -> None: 175 | if rope_cache_length is None: 176 | rope_cache_length = self.cos.size(-1) 177 | max_seq_length = self.max_seq_length 178 | 179 | # initialize the kv cache for all blocks 180 | for block in self.transformer.h: 181 | block.attn.kv_cache = block.attn.build_kv_cache( 182 | batch_size, max_seq_length, rope_cache_length, device, dtype 183 | ) 184 | 185 | if ( 186 | self.mask_cache is None 187 | or self.mask_cache.size(3) != max_seq_length 188 | ): 189 | # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask 190 | # for the kv-cache support (only during inference), we only create it in that situation 191 | # this will be resolved by https://github.com/pytorch/pytorch/issues/96099 192 | ones = torch.ones( 193 | (max_seq_length, max_seq_length), 194 | device=device, 195 | dtype=torch.bool, 196 | ) 197 | self.mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0) 198 | 199 | def clear_kv_cache(self) -> None: 200 | self.mask_cache = None 201 | for block in self.transformer.h: 202 | block.attn.kv_cache = None 203 | 204 | def get_num_params(self) -> int: 205 | """ 206 | Return the number of parameters in the model. 207 | For non-embedding count (default), the position embeddings get subtracted. 208 | The token embeddings would too, except due to the parameter sharing these 209 | params are actually used as weights in the final layer, so we include them. 
210 | """ 211 | return sum(p.numel() for p in self.parameters()) 212 | 213 | def estimate_mfu(self, fwdbwd_per_iter: int, dt: float) -> float: 214 | """estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS""" 215 | # first estimate the number of flops we do per iteration. 216 | # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 217 | N = self.get_num_params() 218 | cfg = self.config 219 | L, H, Q, T = ( 220 | cfg.n_layer, 221 | cfg.n_head, 222 | cfg.n_embd // cfg.n_head, 223 | cfg.block_size, 224 | ) 225 | flops_per_token = 6 * N + 12 * L * H * Q * T 226 | flops_per_fwdbwd = flops_per_token * T 227 | flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter 228 | # express our flops throughput as ratio of A100 bfloat16 peak flops 229 | flops_achieved = flops_per_iter * (1.0 / dt) # per second 230 | flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS 231 | mfu = flops_achieved / flops_promised 232 | return mfu 233 | 234 | 235 | class Block(nn.Module): 236 | def __init__(self, config: ModelConfig) -> None: 237 | super().__init__() 238 | self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) 239 | self.attn = CausalSelfAttention(config) 240 | self.norm_2 = ( 241 | None 242 | if config.shared_attention_norm 243 | else config.norm_class(config.n_embd, eps=config.norm_eps) 244 | ) 245 | self.mlp = config.mlp_class(config) 246 | 247 | self.config = config 248 | 249 | def forward( 250 | self, 251 | x: torch.Tensor, 252 | cos: torch.Tensor, 253 | sin: torch.Tensor, 254 | mask: Optional[torch.Tensor] = None, 255 | input_pos: Optional[torch.Tensor] = None, 256 | ) -> torch.Tensor: 257 | n_1 = self.norm_1(x) 258 | h = self.attn(n_1, cos, sin, mask, input_pos) 259 | if self.config.parallel_residual: 260 | n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x) 261 | x = x + h + self.mlp(n_2) 262 | else: 263 | if self.config.shared_attention_norm: 264 | raise NotImplementedError( 265 | "No checkpoint amongst the ones we support uses this configuration" 266 | " (non-parallel residual and shared attention norm)." 
267 | ) 268 | x = x + h 269 | x = x + self.mlp(self.norm_2(x)) 270 | return x 271 | 272 | 273 | class CausalSelfAttention(nn.Module): 274 | def __init__(self, config: ModelConfig) -> None: 275 | super().__init__() 276 | shape = (config.n_head + 2 * config.n_query_groups) * config.head_size 277 | # key, query, value projections for all heads, but in a batch 278 | self.attn = nn.Linear(config.n_embd, shape, bias=config.bias) 279 | # output projection 280 | self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 281 | # disabled by default 282 | self.kv_cache: Optional[KVCache] = None 283 | 284 | self.config = config 285 | 286 | def forward( 287 | self, 288 | x: torch.Tensor, 289 | cos: torch.Tensor, 290 | sin: torch.Tensor, 291 | mask: Optional[torch.Tensor] = None, 292 | input_pos: Optional[torch.Tensor] = None, 293 | ) -> torch.Tensor: 294 | ( 295 | B, 296 | T, 297 | C, 298 | ) = ( 299 | x.size() 300 | ) # batch size, sequence length, embedding dimensionality (n_embd) 301 | 302 | qkv = self.attn(x) 303 | 304 | # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) 305 | q_per_kv = self.config.n_head // self.config.n_query_groups 306 | total_qkv = ( 307 | q_per_kv + 2 308 | ) # each group has 1+ queries, 1 key, and 1 value 309 | qkv = qkv.view( 310 | B, T, self.config.n_query_groups, total_qkv, self.config.head_size 311 | ) 312 | qkv = qkv.permute( 313 | 0, 2, 3, 1, 4 314 | ) # (B, n_query_groups, total_qkv, T, hs) 315 | 316 | # split batched computation into three 317 | q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) 318 | 319 | # repeat k and v if necessary 320 | if ( 321 | self.config.n_query_groups != 1 322 | ): # doing this would require a full kv cache with MQA (inefficient!) 323 | # for MHA this is a no-op 324 | k = k.expand( 325 | B, 326 | self.config.n_query_groups, 327 | q_per_kv, 328 | T, 329 | self.config.head_size, 330 | ) 331 | v = v.expand( 332 | B, 333 | self.config.n_query_groups, 334 | q_per_kv, 335 | T, 336 | self.config.head_size, 337 | ) 338 | 339 | q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) 340 | k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) 341 | v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) 342 | 343 | q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) 344 | k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) 345 | q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) 346 | k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) 347 | 348 | if input_pos is not None: 349 | if not isinstance(self.kv_cache, KVCache): 350 | raise TypeError("You need to call `gpt.set_kv_cache()`") 351 | k, v = self.kv_cache(input_pos, k, v) 352 | 353 | y = self.scaled_dot_product_attention(q, k, v, mask) 354 | 355 | y = y.reshape(B, T, C) # re-assemble all head outputs side by side 356 | 357 | # output projection 358 | return self.proj(y) 359 | 360 | def scaled_dot_product_attention( 361 | self, 362 | q: torch.Tensor, 363 | k: torch.Tensor, 364 | v: torch.Tensor, 365 | mask: Optional[torch.Tensor] = None, 366 | ): 367 | scale = 1.0 / math.sqrt(self.config.head_size) 368 | if ( 369 | FlashAttention2Available 370 | and mask is None 371 | and q.device.type == "cuda" 372 | and q.dtype in (torch.float16, torch.bfloat16) 373 | ): 374 | from flash_attn import flash_attn_func 375 | 376 | # flash-attn requires (B, T, nh, hs) 377 | q = q.transpose(1, 2) 378 | k = k.transpose(1, 2) 379 | v = v.transpose(1, 2) 380 | 
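# flash_attn_func already returns tensors shaped (B, T, nh, hs), matching the
# reshape performed by the caller, so no transpose back is required here.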
return flash_attn_func( 381 | q, k, v, dropout_p=0.0, softmax_scale=scale, causal=True 382 | ) 383 | y = torch.nn.functional.scaled_dot_product_attention( 384 | q, 385 | k, 386 | v, 387 | attn_mask=mask, 388 | dropout_p=0.0, 389 | scale=scale, 390 | is_causal=mask is None, 391 | ) 392 | return y.transpose(1, 2) 393 | 394 | def build_kv_cache( 395 | self, 396 | batch_size: int, 397 | max_seq_length: int, 398 | rope_cache_length: Optional[int] = None, 399 | device: Optional[torch.device] = None, 400 | dtype: Optional[torch.dtype] = None, 401 | ) -> "KVCache": 402 | heads = 1 if self.config.n_query_groups == 1 else self.config.n_head 403 | v_shape = (batch_size, heads, max_seq_length, self.config.head_size) 404 | if rope_cache_length is None: 405 | if self.config.rotary_percentage != 1.0: 406 | raise TypeError( 407 | "Please pass the `rope_cache_length=gpt.cos.size(-1)` value" 408 | ) 409 | k_shape = v_shape 410 | else: 411 | k_shape = ( 412 | batch_size, 413 | heads, 414 | max_seq_length, 415 | rope_cache_length 416 | + self.config.head_size 417 | - self.config.rope_n_elem, 418 | ) 419 | return KVCache(k_shape, v_shape, device=device, dtype=dtype) 420 | 421 | 422 | class GptNeoxMLP(nn.Module): 423 | def __init__(self, config: ModelConfig) -> None: 424 | super().__init__() 425 | self.fc = nn.Linear( 426 | config.n_embd, config.intermediate_size, bias=config.bias 427 | ) 428 | self.proj = nn.Linear( 429 | config.intermediate_size, config.n_embd, bias=config.bias 430 | ) 431 | 432 | self.config = config 433 | 434 | def forward(self, x: torch.Tensor) -> torch.Tensor: 435 | x = self.fc(x) 436 | x = torch.nn.functional.gelu( 437 | x, approximate=self.config.gelu_approximate 438 | ) 439 | return self.proj(x) 440 | 441 | 442 | class LLaMAMLP(nn.Module): 443 | def __init__(self, config: ModelConfig) -> None: 444 | super().__init__() 445 | self.fc_1 = nn.Linear( 446 | config.n_embd, config.intermediate_size, bias=config.bias 447 | ) 448 | self.fc_2 = nn.Linear( 449 | config.n_embd, config.intermediate_size, bias=config.bias 450 | ) 451 | self.proj = nn.Linear( 452 | config.intermediate_size, config.n_embd, bias=config.bias 453 | ) 454 | 455 | def forward(self, x: torch.Tensor) -> torch.Tensor: 456 | x_fc_1 = self.fc_1(x) 457 | x_fc_2 = self.fc_2(x) 458 | x = torch.nn.functional.silu(x_fc_1) * x_fc_2 459 | return self.proj(x) 460 | 461 | 462 | def build_rope_cache( 463 | seq_len: int, 464 | n_elem: int, 465 | dtype: torch.dtype, 466 | device: Optional[torch.device] = None, 467 | base: int = 10000, 468 | condense_ratio: int = 1, 469 | ) -> Tuple[torch.Tensor, torch.Tensor]: 470 | """Enhanced Transformer with Rotary Position Embedding. 471 | 472 | Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ 473 | transformers/rope/__init__.py. MIT License: 474 | https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. 
475 | """ 476 | # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ 477 | theta = 1.0 / ( 478 | base ** (torch.arange(0, n_elem, 2, device=device) / n_elem) 479 | ) 480 | 481 | # Create position indexes `[0, 1, ..., seq_len - 1]` 482 | seq_idx = torch.arange(seq_len, device=device) / condense_ratio 483 | 484 | # Calculate the product of position index and $\theta_i$ 485 | idx_theta = torch.outer(seq_idx, theta).repeat(1, 2) 486 | 487 | cos, sin = torch.cos(idx_theta), torch.sin(idx_theta) 488 | 489 | # this is to mimic the behaviour of complex32, else we will get different results 490 | if dtype in (torch.float16, torch.bfloat16, torch.int8): 491 | return cos.half(), sin.half() 492 | return cos, sin 493 | 494 | 495 | def apply_rope( 496 | x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor 497 | ) -> torch.Tensor: 498 | head_size = x.size(-1) 499 | x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) 500 | x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) 501 | rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) 502 | roped = (x * cos) + (rotated * sin) 503 | return roped.type_as(x) 504 | 505 | 506 | class KVCache(nn.Module): 507 | def __init__( 508 | self, 509 | k_shape: Tuple[int, int, int, int], 510 | v_shape: Tuple[int, int, int, int], 511 | device: Optional[torch.device] = None, 512 | dtype: Optional[torch.dtype] = None, 513 | ) -> None: 514 | super().__init__() 515 | self.register_buffer( 516 | "k", 517 | torch.zeros(k_shape, device=device, dtype=dtype), 518 | persistent=False, 519 | ) 520 | self.register_buffer( 521 | "v", 522 | torch.zeros(v_shape, device=device, dtype=dtype), 523 | persistent=False, 524 | ) 525 | 526 | def forward( 527 | self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor 528 | ) -> Tuple[torch.Tensor, torch.Tensor]: 529 | # move the buffer to the activation dtype for when AMP is used 530 | self.k = self.k.to(k.dtype) 531 | self.v = self.v.to(v.dtype) 532 | # update the cache 533 | k = self.k.index_copy_(2, input_pos, k) 534 | v = self.v.index_copy_(2, input_pos, v) 535 | return k, v 536 | -------------------------------------------------------------------------------- /smol_trainer/model/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright The PyTorch Lightning team. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # http://www.apache.org/licenses/LICENSE-2.0 4 | 5 | import importlib 6 | from functools import lru_cache 7 | from importlib.util import find_spec 8 | from typing import Optional 9 | 10 | import pkg_resources # type: ignore 11 | 12 | 13 | @lru_cache() 14 | def package_available(package_name: str) -> bool: 15 | """Check if a package is available in your environment. 16 | 17 | >>> package_available('os') 18 | True 19 | >>> package_available('bla') 20 | False 21 | 22 | """ 23 | try: 24 | return find_spec(package_name) is not None 25 | except ModuleNotFoundError: 26 | return False 27 | 28 | 29 | @lru_cache() 30 | def module_available(module_path: str) -> bool: 31 | """Check if a module path is available in your environment. 
32 | 33 | >>> module_available('os') 34 | True 35 | >>> module_available('os.bla') 36 | False 37 | >>> module_available('bla.bla') 38 | False 39 | 40 | """ 41 | module_names = module_path.split(".") 42 | if not package_available(module_names[0]): 43 | return False 44 | try: 45 | importlib.import_module(module_path) 46 | except ImportError: 47 | return False 48 | return True 49 | 50 | 51 | class RequirementCache: 52 | """Boolean-like class to check for requirement and module availability. 53 | 54 | Args: 55 | requirement: The requirement to check, version specifiers are allowed. 56 | module: The optional module to try to import if the requirement check fails. 57 | 58 | >>> RequirementCache("torch>=0.1") 59 | Requirement 'torch>=0.1' met 60 | >>> bool(RequirementCache("torch>=0.1")) 61 | True 62 | >>> bool(RequirementCache("torch>100.0")) 63 | False 64 | >>> RequirementCache("torch") 65 | Requirement 'torch' met 66 | >>> bool(RequirementCache("torch")) 67 | True 68 | >>> bool(RequirementCache("unknown_package")) 69 | False 70 | >>> bool(RequirementCache(module="torch.utils")) 71 | True 72 | >>> bool(RequirementCache(module="unknown_package")) 73 | False 74 | >>> bool(RequirementCache(module="unknown.module.path")) 75 | False 76 | 77 | """ 78 | 79 | def __init__( 80 | self, requirement: Optional[str] = None, module: Optional[str] = None 81 | ) -> None: 82 | if not (requirement or module): 83 | raise ValueError("At least one arguments need to be set.") 84 | self.requirement = requirement 85 | self.module = module 86 | 87 | def _check_requirement(self) -> None: 88 | assert self.requirement # noqa: S101; needed for typing 89 | try: 90 | # first try the pkg_resources requirement 91 | pkg_resources.require(self.requirement) 92 | self.available = True 93 | self.message = f"Requirement {self.requirement!r} met" 94 | except Exception as ex: 95 | self.available = False 96 | self.message = f"{ex.__class__.__name__}: {ex}. HINT: Try running `pip install -U {self.requirement!r}`" 97 | req_include_version = any(c in self.requirement for c in "=<>") 98 | if not req_include_version or self.module is not None: 99 | module = ( 100 | self.requirement if self.module is None else self.module 101 | ) 102 | # sometimes `pkg_resources.require()` fails but the module is importable 103 | self.available = module_available(module) 104 | if self.available: 105 | self.message = f"Module {module!r} available" 106 | 107 | def _check_module(self) -> None: 108 | assert self.module # noqa: S101; needed for typing 109 | self.available = module_available(self.module) 110 | if self.available: 111 | self.message = f"Module {self.module!r} available" 112 | else: 113 | self.message = f"Module not found: {self.module!r}. 
HINT: Try running `pip install -U {self.module}`" 114 | 115 | def _check_available(self) -> None: 116 | if hasattr(self, "available"): 117 | return 118 | if self.requirement: 119 | self._check_requirement() 120 | if getattr(self, "available", True) and self.module: 121 | self._check_module() 122 | 123 | def __bool__(self) -> bool: 124 | """Format as bool.""" 125 | self._check_available() 126 | return self.available 127 | 128 | def __str__(self) -> str: 129 | """Format as string.""" 130 | self._check_available() 131 | return self.message 132 | 133 | def __repr__(self) -> str: 134 | """Format as string.""" 135 | return self.__str__() 136 | -------------------------------------------------------------------------------- /smol_trainer/runner.py: -------------------------------------------------------------------------------- 1 | """ 2 | This training script can be run both on a single gpu in debug mode, 3 | and also in a larger training run with distributed data parallel (ddp). 4 | 5 | To run on a single GPU, example: 6 | $ python train.py --batch_size=32 --compile=False 7 | 8 | To run with DDP on 4 gpus on 1 node, example: 9 | $ torchrun --standalone --nproc_per_node=4 train.py 10 | 11 | To run with DDP on 4 gpus across 2 nodes, example: 12 | - Run on the first (master) node with example IP 123.456.123.456: 13 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py 14 | - Run on the worker node: 15 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py 16 | (If your cluster does not have Infiniband interconnect prepend NCCL_IB_DISABLE=1) 17 | """ 18 | 19 | import argparse 20 | import logging 21 | import os 22 | 23 | # from contextlib import AbstractContextManager 24 | from contextlib import nullcontext 25 | from typing import Any, Union 26 | 27 | import torch 28 | from torch.distributed import destroy_process_group, init_process_group 29 | from torch.nn import Module 30 | from torch.nn.parallel import DistributedDataParallel as DDP 31 | 32 | import wandb 33 | from smol_trainer.config import LearningConfig, TrainConfig 34 | from smol_trainer.trainer import ( 35 | get_checkpoint_prefix, 36 | get_project_identifier, 37 | initialize_model_from_checkpoint, 38 | initialize_model_from_scratch, 39 | initialize_optimizer, 40 | load_data, 41 | train_model, 42 | ) 43 | from smol_trainer.utils import get_configured_logger, parse_args 44 | 45 | 46 | def load_config_and_overwrite_args( 47 | logger: logging.Logger, args: argparse.Namespace 48 | ) -> None: 49 | """Load config from file and overwrite args.""" 50 | if args.config_file: 51 | local_vars_before = locals().copy() 52 | config_load = open(args.config_file).read() 53 | logger.info(f"Reading config from {args.config_file}:\n{config_load}") 54 | 55 | local_namespace: dict = {} 56 | exec(config_load, globals(), local_namespace) 57 | 58 | new_vars = set(local_namespace.keys()) - set(local_vars_before.keys()) 59 | 60 | for var in new_vars: 61 | setattr(args, var, local_namespace[var]) 62 | 63 | 64 | def setup_run_args(logger: logging.Logger, args: argparse.Namespace) -> None: 65 | """Setup the arguments for the run.""" 66 | # sourcery skip: extract-method 67 | 68 | # various inits, derived attributes, I/O setup 69 | args.ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run? 
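# torchrun (and torch.distributed.launch) export RANK, LOCAL_RANK and WORLD_SIZE
# for every process they spawn, so a missing RANK (the -1 fallback here) means this
# is a plain single-process run rather than a DDP worker.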
70 | if args.ddp: 71 | init_process_group(backend=args.backend) 72 | args.ddp_rank = int(os.environ["RANK"]) 73 | args.ddp_local_rank = int(os.environ["LOCAL_RANK"]) 74 | args.ddp_world_size = int(os.environ["WORLD_SIZE"]) 75 | args.device = f"cuda:{args.ddp_local_rank}" 76 | torch.cuda.set_device(args.device) 77 | args.master_process = ( 78 | args.ddp_rank == 0 79 | ) # this process will do logging, checkpointing etc. 80 | args.seed_offset = args.ddp_rank # each process gets a different seed 81 | # world_size number of processes will be training simultaneously, so we can scale 82 | # down the desired gradient accumulation iterations per process proportionally 83 | assert args.gradient_accumulation_steps % args.ddp_world_size == 0 84 | args.gradient_accumulation_steps //= args.ddp_world_size 85 | else: 86 | # if not ddp, we are running on a single gpu, and one process 87 | args.master_process = True 88 | args.seed_offset = 0 89 | args.ddp_world_size = 1 90 | args.tokens_per_iter = ( 91 | args.gradient_accumulation_steps 92 | * args.ddp_world_size 93 | * args.batch_size 94 | * args.block_size 95 | ) 96 | args.device_type = ( 97 | "cuda" if "cuda" in args.device else "cpu" 98 | ) # for later use in torch.autocast 99 | 100 | # Initialize here so we can override if init_from='resume' (i.e. from a checkpoint) 101 | args.best_val_loss = 1e9 102 | 103 | prefix = get_checkpoint_prefix(vars(args)) 104 | args.tensorboard_path = os.path.join( 105 | args.out_dir, f"{prefix}__tensorboard" 106 | ) 107 | args.checkpoint_dir = os.path.join( 108 | args.out_dir, 109 | prefix, 110 | ) 111 | 112 | logger.info(f"tokens per iteration will be: {args.tokens_per_iter:,}") 113 | 114 | 115 | def setup_amp_context(args: argparse.Namespace) -> Any: 116 | """Sets up the autocast context.""" 117 | ptdtype = { 118 | "float32": torch.float32, 119 | "bfloat16": torch.bfloat16, 120 | "float16": torch.float16, 121 | }[args.dtype] 122 | return ( 123 | nullcontext() 124 | if args.device_type == "cpu" 125 | else torch.amp.autocast(device_type=args.device_type, dtype=ptdtype) 126 | ) 127 | 128 | 129 | def setup_ddp(args: argparse.Namespace, model: Module) -> Union[Module, Any]: 130 | # TODO - What is the appropriate type for DDP? 131 | """Sets up the DDP model.""" 132 | if args.ddp: 133 | model = DDP(model, device_ids=[args.ddp_local_rank]) 134 | return model 135 | 136 | 137 | def setup_training_environment(args: argparse.Namespace) -> Any: 138 | """Setup the training environment""" 139 | 140 | load_config_and_overwrite_args(logger, args) 141 | setup_run_args(logger, args) 142 | torch.manual_seed(1337 + args.seed_offset) 143 | torch.backends.cuda.matmul.allow_tf32 = True 144 | torch.backends.cudnn.allow_tf32 = True 145 | 146 | if ( 147 | ( 148 | os.path.exists(args.checkpoint_dir) 149 | or os.path.exists(args.tensorboard_path) 150 | ) 151 | and args.init_from == "scratch" 152 | and args.master_process 153 | ): 154 | raise ValueError( 155 | f"Checkpoint directory {args.checkpoint_dir} or {args.tensorboard_path} already exists, please move before re-running." 
156 | ) 157 | return setup_amp_context(args) 158 | 159 | 160 | def initialize_run_performance_logging( 161 | args: argparse.Namespace, 162 | ) -> None: 163 | """Initialize logging with WandB and TensorBoard.""" 164 | if args.wandb_log and args.master_process: 165 | config_dict = vars(args) 166 | wandb.init( 167 | project=get_project_identifier(config_dict), 168 | name=args.run_name, 169 | config=config_dict, 170 | ) 171 | 172 | 173 | if __name__ == "__main__": 174 | args = parse_args() 175 | config = vars(args) 176 | 177 | logger = get_configured_logger(__name__, args.log_level) 178 | logger.info(f"Running with passed in args:\n{args}") 179 | 180 | amp_context = setup_training_environment(args) 181 | 182 | logger.info(f"Running over dataset = {args.dataset}") 183 | train_data, val_data, meta_vocab_size = load_data(logger, args.dataset) 184 | 185 | checkpoint = None 186 | if args.init_from == "scratch": 187 | model: Module = initialize_model_from_scratch(args, logger) 188 | elif args.init_from == "resume": 189 | model, checkpoint = initialize_model_from_checkpoint(args, logger) 190 | else: 191 | raise ValueError( 192 | "Invalid initialization mode. Must be 'scratch' or 'resume'." 193 | ) 194 | 195 | model.to(args.device) 196 | 197 | optimizer = initialize_optimizer(args, model, checkpoint) 198 | 199 | # Initialize a GradScaler. If enabled=False scaler is a no-op 200 | # We must comment out this code which appears in nanoGPT 201 | # This is to avoid explosions of gradients when using Torch 2.0.x 202 | # The authors in nanoGPT claim this is a fault of Torch 203 | # TODO - Investigate this further 204 | # scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == "float16")) 205 | scaler = torch.cuda.amp.GradScaler( 206 | enabled=(args.dtype == "float16"), growth_interval=0 207 | ) 208 | # Compile the model 209 | if args.compile: 210 | logger.info("Compiling the model... 
(takes a ~minute)")
211 | model = torch.compile(model)  # requires PyTorch 2.0
212 | 
213 | # Wrap the model into DDP container
214 | model = setup_ddp(args, model)
215 | 
216 | if args.master_process:
217 | os.makedirs(args.out_dir, exist_ok=True)
218 | 
219 | raw_model = (
220 | model.module if args.ddp else model
221 | )  # unwrap DDP container if needed
222 | logger.info(f"Running with the following model:\n{model}")
223 | 
224 | lr_config = LearningConfig(
225 | # Learning rate settings
226 | lr=args.initial_lr,
227 | initial_lr=args.initial_lr,
228 | decay_lr=args.decay_lr,
229 | min_lr=args.min_lr,
230 | # Optimizer settings
231 | grad_clip=args.grad_clip,
232 | weight_decay=args.weight_decay,
233 | beta1=args.beta1,
234 | beta2=args.beta2,
235 | do_flash_v2=args.do_flash_v2,
236 | # Iteration variables
237 | lr_decay_iters=args.lr_decay_iters,
238 | warmup_iters=args.warmup_iters,
239 | gradient_accumulation_steps=args.gradient_accumulation_steps,
240 | )
241 | 
242 | initialize_run_performance_logging(args)
243 | 
244 | # Initialize the training config
245 | train_config = TrainConfig(
246 | # Logging support
247 | logger=logger,
248 | lr_config=lr_config,
249 | master_process=args.master_process,
250 | log_interval=args.log_interval,
251 | wandb_log=args.wandb_log,
252 | # Architecture
253 | model_name=args.model_name,
254 | # Training params
255 | eval_interval=args.eval_interval,
256 | batch_size=args.batch_size,
257 | block_size=args.block_size,
258 | max_iters=args.max_iters,
259 | eval_iters=args.eval_iters,
260 | max_checkpoints=args.max_checkpoints,
261 | # Run information
262 | out_dir=args.out_dir,
263 | checkpoint_dir=args.checkpoint_dir,
264 | run_name=args.run_name,
265 | ddp=args.ddp,
266 | device=args.device,
267 | device_type=args.device_type,
268 | always_save_checkpoint=args.always_save_checkpoint,
269 | iter_num=checkpoint["iter_num"] if checkpoint else 0,
270 | total_tokens_processed=checkpoint["total_tokens_processed"]
271 | if checkpoint
272 | else 0,
273 | )
274 | 
275 | train_model(
276 | model,
277 | optimizer,
278 | scaler,
279 | train_data,
280 | val_data,
281 | train_config,
282 | amp_context,
283 | raw_model,
284 | )
285 | 
286 | if args.ddp:
287 | destroy_process_group()
288 | 
--------------------------------------------------------------------------------
/smol_trainer/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from smol_trainer.trainer.base import train_model
2 | from smol_trainer.trainer.checkpointer import (
3 | get_checkpoint_prefix,
4 | get_project_identifier,
5 | )
6 | from smol_trainer.trainer.data_loader import get_batch, load_data
7 | from smol_trainer.trainer.initializer import (
8 | initialize_model_from_checkpoint,
9 | initialize_model_from_scratch,
10 | initialize_optimizer,
11 | )
12 | 
13 | __all__ = [
14 | "TrainConfig",
15 | "train_model",
16 | "load_data",
17 | "get_batch",
18 | "get_project_identifier",
19 | "get_checkpoint_prefix",
20 | "initialize_model_from_scratch",
21 | "initialize_model_from_checkpoint",
22 | "initialize_optimizer",
23 | ]
24 | 
--------------------------------------------------------------------------------
/smol_trainer/trainer/base.py:
--------------------------------------------------------------------------------
1 | """Base training loop for SmolTrainer."""
2 | 
3 | import math
4 | import threading
5 | import time
6 | 
7 | # from contextlib import AbstractContextManager
8 | from typing import Any
9 | 
10 | import numpy as np
11 | import torch
12 | from 
torch.cuda.amp import GradScaler 13 | from torch.nn import Module 14 | from torch.optim import Optimizer 15 | 16 | import wandb 17 | from smol_trainer.config.train import LearningConfig, TrainConfig 18 | from smol_trainer.trainer.checkpointer import ( 19 | manage_checkpoints, 20 | save_checkpoint, 21 | ) 22 | from smol_trainer.trainer.data_loader import get_batch 23 | from smol_trainer.utils import custom_asdict 24 | 25 | # ========================== Learning Rate Logic ========================== 26 | 27 | 28 | def linear_warmup_lr(lr_config: LearningConfig, it): 29 | """Calculate learning rate during the warmup phase.""" 30 | return lr_config.initial_lr * it / lr_config.warmup_iters 31 | 32 | 33 | def cosine_decay_lr(lr_config: LearningConfig, it): 34 | """Calculate learning rate using cosine decay.""" 35 | decay_ratio = (it - lr_config.warmup_iters) / ( 36 | lr_config.lr_decay_iters - lr_config.warmup_iters 37 | ) 38 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) 39 | return lr_config.min_lr + coeff * (lr_config.initial_lr - lr_config.min_lr) 40 | 41 | 42 | def get_lr(lr_config: LearningConfig, it: int): 43 | """Get the learning rate for the given iteration.""" 44 | if it < lr_config.warmup_iters: 45 | return linear_warmup_lr(lr_config, it) 46 | elif it > lr_config.lr_decay_iters: 47 | return lr_config.min_lr 48 | else: 49 | return cosine_decay_lr(lr_config, it) 50 | 51 | 52 | # ========================== Logging Logic ========================== 53 | 54 | 55 | def log_metrics( 56 | config: TrainConfig, 57 | lossf: float, 58 | dt: float, 59 | ): 60 | """Log metrics during training.""" 61 | config.logger.info( 62 | f"iter {config.iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {config.running_mfu*100:.2f}%, total_tokens_processed {config.total_tokens_processed}" 63 | ) 64 | 65 | 66 | # ========================== Evaluation Logic ========================== 67 | 68 | 69 | @torch.no_grad() 70 | def estimate_loss( 71 | config: TrainConfig, 72 | amp_context: Any, 73 | model: Module, 74 | train_data: np.memmap, 75 | val_data: np.memmap, 76 | ) -> dict: 77 | """Estimate the loss on the training and validation sets.""" 78 | 79 | out = {} 80 | model.eval() 81 | for split in ["train", "val"]: 82 | losses = torch.zeros(config.eval_iters) 83 | for k in range(config.eval_iters): 84 | X, Y = get_batch( 85 | config, train_data if split == "train" else val_data 86 | ) # fetch the very first batch 87 | 88 | with amp_context: 89 | _, loss = model(X, Y) 90 | losses[k] = loss.item() 91 | out[split] = losses.mean() 92 | model.train() 93 | return out 94 | 95 | 96 | def perform_evaluation( 97 | config: TrainConfig, 98 | optimizer: Optimizer, 99 | model: Module, 100 | raw_model: Module, 101 | train_data: np.memmap, 102 | val_data: np.memmap, 103 | amp_context: Any, 104 | ) -> None: 105 | """Evaluate the model and save checkpoints if necessary.""" 106 | losses = estimate_loss(config, amp_context, model, train_data, val_data) 107 | 108 | config.total_time = time.time() - config.initial_time 109 | config.logger.info( 110 | f"Eval @ iter = {config.iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, total time {config.total_time:.2f}" 111 | ) 112 | 113 | # Logging with WandB 114 | if config.wandb_log: 115 | wandb.log( 116 | { 117 | "iter": config.iter_num, 118 | "train/loss": losses["train"], 119 | "val/loss": losses["val"], 120 | "lr": config.lr_config.lr, 121 | "mfu": config.running_mfu * 100, # convert to percentage 122 | "tokens_processed": config.total_tokens_processed, 123 
| } 124 | ) 125 | 126 | # Save checkpoint and manage old checkpoints 127 | if losses["val"] < config.best_val_loss or config.always_save_checkpoint: 128 | config.best_val_loss, training_loss = losses["val"], losses["train"] 129 | config.training_loss = training_loss 130 | if config.iter_num > 0: 131 | output_config = custom_asdict(config) 132 | output_config.pop("logger") 133 | save_checkpoint( 134 | output_config, 135 | raw_model, 136 | optimizer, 137 | ) 138 | thread = threading.Thread( 139 | target=manage_checkpoints, args=(output_config,) 140 | ) 141 | thread.start() 142 | 143 | 144 | def train_model( 145 | model: Module, 146 | optimizer: Optimizer, 147 | scaler: GradScaler, 148 | train_data: np.memmap, 149 | val_data: np.memmap, 150 | # TODO - Track down the type for amp_context 151 | # Why does amp_context: AbstractContextManager fail? 152 | config: TrainConfig, 153 | amp_context: Any, 154 | raw_model: Module, 155 | ) -> None: 156 | """Train the model.""" 157 | # TODO - Break this function up into smaller functions 158 | 159 | # initial setup 160 | lr_config = config.lr_config 161 | 162 | # Fetch the very first batch 163 | X, Y = get_batch(config, train_data) 164 | t0 = time.time() 165 | 166 | for local_iter_num, iter_num in enumerate( 167 | range(config.iter_num, config.max_iters + 1) 168 | ): 169 | config.iter_num = iter_num 170 | 171 | # determine and set the learning rate for this iteration 172 | lr = ( 173 | get_lr(lr_config, iter_num) 174 | if lr_config.decay_lr 175 | else lr_config.initial_lr 176 | ) 177 | for param_group in optimizer.param_groups: 178 | param_group["lr"] = lr 179 | lr_config.lr = lr 180 | 181 | # evaluate the loss on train/val sets and write checkpoints 182 | if iter_num % config.eval_interval == 0 and config.master_process: 183 | perform_evaluation( 184 | config, 185 | optimizer, 186 | model, 187 | raw_model, 188 | train_data, 189 | val_data, 190 | amp_context, 191 | ) 192 | 193 | # forward backward update, with optional gradient accumulation to simulate larger batch size 194 | # and using the GradScaler if data type is float16 195 | for micro_step in range(lr_config.gradient_accumulation_steps): 196 | config.total_tokens_processed += X.numel() 197 | 198 | if config.ddp: 199 | # in DDP training we only need to sync gradients at the last micro step. 
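# (without this toggle, DDP would all-reduce gradients on every micro step instead of once per accumulation window)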
200 | # the official way to do this is with model.no_sync() context manager, but 201 | # I really dislike that this bloats the code and forces us to repeat code 202 | # looking at the source of that context manager, it just toggles this variable 203 | model.require_backward_grad_sync = ( 204 | micro_step == lr_config.gradient_accumulation_steps - 1 205 | ) 206 | with amp_context: 207 | _, loss = model(X, Y) 208 | loss = ( 209 | loss / lr_config.gradient_accumulation_steps 210 | ) # scale the loss to account for gradient accumulation 211 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 212 | X, Y = get_batch(config, train_data) # fetch the very first batch 213 | # backward pass, with gradient scaling if training in fp16 214 | scaler.scale(loss).backward() 215 | # clip the gradient 216 | if lr_config.grad_clip != 0.0: 217 | scaler.unscale_(optimizer) 218 | torch.nn.utils.clip_grad_norm_( 219 | model.parameters(), lr_config.grad_clip 220 | ) 221 | # step the optimizer and scaler if training in fp16 222 | scaler.step(optimizer) 223 | scaler.update() 224 | # flush the gradients as soon as we can, no need for this memory anymore 225 | optimizer.zero_grad(set_to_none=True) 226 | 227 | if iter_num % config.log_interval == 0 and config.master_process: 228 | # Calculate elapsed time 229 | dt = time.time() - t0 230 | t0 = time.time() 231 | 232 | # Get loss and update metrics 233 | lossf = loss.item() * lr_config.gradient_accumulation_steps 234 | if local_iter_num >= 5: 235 | mfu = raw_model.estimate_mfu( 236 | config.batch_size * lr_config.gradient_accumulation_steps, 237 | dt, 238 | ) 239 | config.running_mfu = ( 240 | mfu 241 | if config.running_mfu == -1.0 242 | else 0.9 * config.running_mfu + 0.1 * mfu 243 | ) 244 | 245 | log_metrics(config, lossf, dt) 246 | return 247 | -------------------------------------------------------------------------------- /smol_trainer/trainer/checkpointer.py: -------------------------------------------------------------------------------- 1 | """A module for checkpointing and saving models during training.""" 2 | import datetime 3 | import glob 4 | import os 5 | 6 | import torch 7 | from torch.nn import Module 8 | from torch.optim import Optimizer 9 | 10 | 11 | def get_project_identifier(output_config: dict) -> str: 12 | """Returns the name of the checkpoint file""" 13 | 14 | return f"run_name_{output_config['run_name']}__model_{output_config['model_name']}" 15 | 16 | 17 | def get_checkpoint_prefix(output_config: dict) -> str: 18 | """Returns the name of the checkpoint file""" 19 | return f"checkpoint_{get_project_identifier(output_config)}" 20 | 21 | 22 | def manage_checkpoints(output_config: dict) -> None: 23 | """Manage the checkpoints: save, delete old ones""" 24 | 25 | # List all checkpoints 26 | prefix = get_checkpoint_prefix(output_config) 27 | file_name = f"{prefix}__iter_num_*.pt" 28 | checkpoints = sorted( 29 | glob.glob(os.path.join(output_config["out_dir"], prefix, file_name)), 30 | key=lambda x: int(x.split("__iter_num_")[-1].split(".pt")[0]), 31 | ) 32 | 33 | # Remove older checkpoints 34 | for ckpt in checkpoints[: -output_config["max_checkpoints"]]: 35 | os.remove(ckpt) 36 | 37 | 38 | def save_checkpoint( 39 | output_config: dict, 40 | model: Module, 41 | optimizer: Optimizer, 42 | ) -> None: 43 | """Saves the checkpoint to the designated output location""" 44 | 45 | checkpoint = { 46 | **output_config, 47 | "model": model.state_dict(), 48 | "num_params": model.get_num_params(), 49 | "optimizer": 
optimizer.state_dict(), 50 | "timestamp": str(datetime.datetime.now()), 51 | "pytorch_version": torch.__version__, 52 | } 53 | 54 | prefix = get_checkpoint_prefix(output_config) 55 | os.makedirs(output_config["checkpoint_dir"], exist_ok=True) 56 | 57 | temp_checkpoint_path = os.path.join( 58 | output_config["checkpoint_dir"], 59 | f"{prefix}__iter_num_{output_config['iter_num']}.temp", 60 | ) 61 | checkpoint_path = temp_checkpoint_path.replace(".temp", ".pt") 62 | 63 | # Save to a temporary file to avoid data corruption 64 | torch.save(checkpoint, temp_checkpoint_path) 65 | os.rename(temp_checkpoint_path, checkpoint_path) 66 | -------------------------------------------------------------------------------- /smol_trainer/trainer/data_loader.py: -------------------------------------------------------------------------------- 1 | """Data loading utilities for training and validation.""" 2 | 3 | import logging 4 | import os 5 | import pickle 6 | from typing import Optional, Tuple 7 | 8 | import numpy as np 9 | import torch 10 | from utils import get_root_py_fpath 11 | 12 | from smol_trainer.config import TrainConfig 13 | 14 | 15 | def load_data( 16 | logger: logging.Logger, 17 | dataset: str, 18 | ) -> Tuple[np.memmap, np.memmap, Optional[int]]: 19 | """Load training and validation data.""" 20 | data_dir = os.path.join(get_root_py_fpath(), "data", dataset) 21 | train_data = np.memmap( 22 | os.path.join(get_root_py_fpath(), data_dir, "train.bin"), 23 | dtype=np.uint16, 24 | mode="r", 25 | ) 26 | val_data = np.memmap( 27 | os.path.join(get_root_py_fpath(), data_dir, "val.bin"), 28 | dtype=np.uint16, 29 | mode="r", 30 | ) 31 | # attempt to derive vocab_size from the dataset 32 | meta_path = os.path.join(data_dir, "meta.pkl") 33 | meta_vocab_size = None 34 | if os.path.exists(meta_path): 35 | with open(meta_path, "rb") as f: 36 | meta = pickle.load(f) 37 | meta_vocab_size = meta["vocab_size"] 38 | logger.info( 39 | f"Found vocab_size = {meta_vocab_size} (inside {meta_path})" 40 | ) 41 | 42 | return (train_data, val_data, meta_vocab_size) 43 | 44 | 45 | def get_batch( 46 | config: TrainConfig, data: np.memmap 47 | ) -> Tuple[torch.Tensor, torch.Tensor]: 48 | """Get a batch of data from either the training or validation set.""" 49 | ix = torch.randint(len(data) - config.block_size, (config.batch_size,)) 50 | x = torch.stack( 51 | [ 52 | torch.from_numpy( 53 | (data[i : i + config.block_size]).astype(np.int64) 54 | ) 55 | for i in ix 56 | ] 57 | ) 58 | y = torch.stack( 59 | [ 60 | torch.from_numpy( 61 | (data[i + 1 : i + 1 + config.block_size]).astype(np.int64) 62 | ) 63 | for i in ix 64 | ] 65 | ) 66 | if config.device_type == "cuda": 67 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 68 | x, y = x.pin_memory().to( 69 | config.device, non_blocking=True 70 | ), y.pin_memory().to(config.device, non_blocking=True) 71 | else: 72 | x, y = x.to(config.device), y.to(config.device) 73 | return x, y 74 | -------------------------------------------------------------------------------- /smol_trainer/trainer/initializer.py: -------------------------------------------------------------------------------- 1 | """Contains the logic for initializing a model.""" 2 | 3 | import argparse 4 | import logging 5 | import os 6 | from typing import Any, Optional, Tuple 7 | 8 | import torch 9 | from torch.nn import Module 10 | 11 | from smol_trainer.model import GPT 12 | from smol_trainer.trainer.checkpointer import get_checkpoint_prefix 13 | 14 | 15 | def 
configure_optimizers(model: Module, weight_decay, learning_rate, betas): 16 | # start with all of the candidate parameters 17 | param_dict = {pn: p for pn, p in model.named_parameters()} 18 | # filter out those that do not require grad 19 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 20 | # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. 21 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 22 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 23 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 24 | optim_groups = [ 25 | {"params": decay_params, "weight_decay": weight_decay}, 26 | {"params": nodecay_params, "weight_decay": 0.0}, 27 | ] 28 | num_decay_params = sum(p.numel() for p in decay_params) 29 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 30 | print( 31 | f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters" 32 | ) 33 | print( 34 | f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters" 35 | ) 36 | # --- DISABLING use_fused due to issue in nightly --- 37 | # Create AdamW optimizer and use the fused version if it is available 38 | # fused_available = ( 39 | # "fused" in inspect.signature(torch.optim.AdamW).parameters 40 | # ) 41 | # use_fused = fused_available and device_type == "cuda" 42 | # extra_args = dict(fused=True) if use_fused else dict() 43 | # optimizer = torch.optim.AdamW( 44 | # optim_groups, fused=False, lr=learning_rate, betas=betas, **extra_args 45 | # ) 46 | # print(f"using fused AdamW: {use_fused}") 47 | # optimizer = torch.optim.AdamW( 48 | # optim_groups, fused=False, lr=learning_rate, betas=betas, **extra_args 49 | # ) 50 | 51 | return torch.optim.AdamW( 52 | optim_groups, 53 | fused=False, 54 | lr=learning_rate, 55 | betas=betas, 56 | ) 57 | 58 | 59 | def initialize_optimizer( 60 | args: argparse.Namespace, model: Module, checkpoint: Any = None 61 | ): 62 | """Initialize optimizer and load its state if resuming from a checkpoint.""" 63 | 64 | optimizer = configure_optimizers( 65 | model, 66 | args.weight_decay, 67 | args.initial_lr, 68 | (args.beta1, args.beta2), 69 | ) 70 | if checkpoint and args.init_from == "resume": 71 | optimizer.load_state_dict(checkpoint["optimizer"]) 72 | return optimizer 73 | 74 | 75 | def initialize_model_from_scratch( 76 | args: argparse.Namespace, 77 | logger: logging.Logger, 78 | ) -> Module: 79 | """Initialize a new model from scratch.""" 80 | 81 | logger.info(f"Running model {args.model_name}") 82 | 83 | if args.iter_num != 0: 84 | raise ValueError("iter_num must be 0 to initialize from scratch") 85 | 86 | return GPT.from_name(args.model_name) 87 | 88 | 89 | def initialize_model_from_checkpoint( 90 | args: argparse.Namespace, logger: logging.Logger 91 | ) -> Tuple[Module, Any]: # TODO - Find a correct type for the checkpoint. 
92 | """Resume training from a checkpoint.""" 93 | 94 | logger.info(f"Resuming training from {args.out_dir}") 95 | 96 | if args.ckpt_path_override: 97 | checkpoint = torch.load( 98 | args.ckpt_path_override, map_location=args.device 99 | ) 100 | else: 101 | try: 102 | checkpoint_prefix = get_checkpoint_prefix( 103 | {"run_name": args.run_name, "model_name": args.model_name} 104 | ) 105 | model_path = os.path.join( 106 | args.out_dir, 107 | checkpoint_prefix, 108 | f"{checkpoint_prefix}__iter_num_{args.iter_num}.pt", 109 | ) 110 | if not os.path.exists(model_path): 111 | raise ValueError( 112 | f"Checkpoint path {model_path} does not exist" 113 | ) 114 | 115 | checkpoint = torch.load(model_path, map_location=args.device) 116 | args.iter_num = checkpoint["iter_num"] 117 | assert ( 118 | args.iter_num == checkpoint["iter_num"] 119 | ), "Iteration numbers do not match!" 120 | 121 | except Exception as e: 122 | logger.error( 123 | "Encountered an error {e} while attempting to load model checkpoint" 124 | ) 125 | raise e 126 | 127 | model = GPT.from_name(args.model_name, block_size=args.block_size) 128 | state_dict = checkpoint["model"] 129 | 130 | unwanted_prefix = "_orig_mod." 131 | for k, v in list(state_dict.items()): 132 | if k.startswith(unwanted_prefix): 133 | state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k) 134 | model.load_state_dict(state_dict) 135 | 136 | # Update args to reflect latest numbers from checkpoint 137 | args.best_val_loss = checkpoint["best_val_loss"] 138 | 139 | return model, checkpoint 140 | -------------------------------------------------------------------------------- /smol_trainer/utils.py: -------------------------------------------------------------------------------- 1 | """Utilities for the smol_trainer package.""" 2 | import argparse 3 | import logging 4 | import os 5 | from dataclasses import fields, is_dataclass 6 | 7 | 8 | def get_root_py_fpath() -> str: 9 | """Get the path to the root of the python code directory.""" 10 | 11 | return os.path.dirname(os.path.realpath(__file__)) 12 | 13 | 14 | def get_root_fpath() -> str: 15 | """Get the path to the root of the smol_trainer directory.""" 16 | 17 | return os.path.join(get_root_py_fpath(), "..") 18 | 19 | 20 | def get_configured_logger(name: str, log_level: str) -> logging.Logger: 21 | """Get a configured logger.""" 22 | 23 | log_level = getattr(logging, log_level.upper(), "INFO") 24 | logging.basicConfig( 25 | level=log_level, format="%(asctime)s - %(levelname)s - %(message)s" 26 | ) 27 | return logging.getLogger(name) 28 | 29 | 30 | # TODO - Break this into multiple functions and/or configs 31 | # Define the argument parser 32 | def parse_args(): 33 | parser = argparse.ArgumentParser( 34 | description="Training script for the GPT model" 35 | ) 36 | 37 | # I/O arguments 38 | parser.add_argument( 39 | "--config-file", default="", type=str, help="Configuration file" 40 | ) 41 | parser.add_argument( 42 | "--out-dir", default="results", type=str, help="Output directory" 43 | ) 44 | parser.add_argument( 45 | "--eval-interval", default=100, type=int, help="Evaluation interval" 46 | ) 47 | parser.add_argument( 48 | "--log-interval", default=1, type=int, help="Log interval" 49 | ) 50 | parser.add_argument( 51 | "--log-level", default="INFO", type=str, help="Log level" 52 | ) 53 | parser.add_argument( 54 | "--eval-iters", default=200, type=int, help="Evaluation iterations" 55 | ) 56 | parser.add_argument( 57 | "--always-save-checkpoint", 58 | default=True, 59 | action="store_true", 60 | help="Always save 
checkpoint after each evaluation", 61 | ) 62 | 63 | # Run parameters 64 | parser.add_argument( 65 | "--init-from", 66 | default="scratch", 67 | type=str, 68 | choices=["scratch", "resume"], 69 | help="Initialization mode: scratch, resume or gpt2*", 70 | ) 71 | parser.add_argument( 72 | "--ckpt-path-override", 73 | default=None, 74 | type=str, 75 | help="Path to the model", 76 | ) 77 | parser.add_argument( 78 | "--iter-num", 79 | default=0, 80 | type=int, 81 | help="Iteration number, used when resuming training", 82 | ) 83 | parser.add_argument( 84 | "--run-name", default="run_0", type=str, help="Specify the run name." 85 | ) 86 | 87 | # WandB logging 88 | parser.add_argument( 89 | "--wandb-log", 90 | default=False, 91 | action="store_true", 92 | help="Enable W&B logging", 93 | ) 94 | 95 | # Data arguments 96 | parser.add_argument( 97 | "--dataset", default="openwebtext", type=str, help="Dataset name" 98 | ) 99 | parser.add_argument( 100 | "--gradient-accumulation-steps", 101 | default=5 * 8, 102 | type=int, 103 | help="Steps for gradient accumulation", 104 | ) 105 | parser.add_argument( 106 | "--batch-size", default=12, type=int, help="Batch size" 107 | ) 108 | parser.add_argument( 109 | "--block-size", default=1024, type=int, help="Block size" 110 | ) 111 | 112 | # Model arguments 113 | parser.add_argument( 114 | "--model-name", 115 | default="pythia-70m", 116 | type=str, 117 | help="The name of the model to run with.", 118 | ) 119 | 120 | # Optimizer arguments 121 | parser.add_argument( 122 | "--initial-lr", 123 | default=6e-4, 124 | type=float, 125 | help="Learning rate", 126 | ) 127 | parser.add_argument( 128 | "--max-iters", 129 | default=600000, 130 | type=int, 131 | help="Maximum number of training iterations", 132 | ) 133 | parser.add_argument( 134 | "--weight-decay", 135 | default=1e-1, 136 | type=float, 137 | help="Weight decay for optimizer", 138 | ) 139 | parser.add_argument( 140 | "--beta1", default=0.9, type=float, help="Beta1 for optimizer" 141 | ) 142 | parser.add_argument( 143 | "--beta2", default=0.95, type=float, help="Beta2 for optimizer" 144 | ) 145 | parser.add_argument( 146 | "--grad-clip", default=1.0, type=float, help="Gradient clipping value" 147 | ) 148 | parser.add_argument( 149 | "--do-flash-v2", 150 | default=False, 151 | action="store_true", 152 | help="Use flash v2 calculation (Requires A100 or better).", 153 | ) 154 | 155 | # Learning rate decay settings 156 | parser.add_argument( 157 | "--decay-lr", 158 | default=True, 159 | action="store_true", 160 | help="Enable learning rate decay", 161 | ) 162 | parser.add_argument( 163 | "--warmup-iters", 164 | default=2000, 165 | type=int, 166 | help="Number of warmup iterations", 167 | ) 168 | parser.add_argument( 169 | "--lr-decay-iters", 170 | default=600000, 171 | type=int, 172 | help="Learning rate decay iterations", 173 | ) 174 | parser.add_argument( 175 | "--min-lr", default=6e-5, type=float, help="Minimum learning rate" 176 | ) 177 | parser.add_argument( 178 | "--max-checkpoints", 179 | default=5, 180 | type=int, 181 | help="Maximum checkpoints to hold.", 182 | ) 183 | 184 | # DDP settings 185 | parser.add_argument( 186 | "--backend", 187 | default="nccl", 188 | type=str, 189 | choices=["nccl", "gloo"], 190 | help="Backend for DDP", 191 | ) 192 | 193 | # System arguments 194 | def str2bool(v): 195 | if isinstance(v, bool): 196 | return v 197 | if v.lower() in ("yes", "true", "t", "y", "1"): 198 | return True 199 | elif v.lower() in ("no", "false", "f", "n", "0"): 200 | return False 201 | else: 202 | raise 
argparse.ArgumentTypeError("Boolean value expected.") 203 | 204 | # Other arguments 205 | parser.add_argument( 206 | "--device", default="cuda", type=str, help="Device to use for training" 207 | ) 208 | parser.add_argument( 209 | "--dtype", 210 | default="float16", 211 | type=str, 212 | choices=["float32", "bfloat16", "float16"], 213 | help="Data type for training", 214 | ) 215 | parser.add_argument( 216 | "--compile", 217 | type=str2bool, 218 | default=True, 219 | help="Use PyTorch 2.0 to compile the model to be faster.", 220 | ) 221 | 222 | return parser.parse_args() 223 | 224 | 225 | def custom_asdict(obj) -> dict: 226 | import _thread 227 | 228 | if is_dataclass(obj): 229 | result = {} 230 | for f in fields(obj): 231 | value = getattr(obj, f.name) 232 | # Check if value is a thread lock or any other non-pickleable type 233 | # Modify this check as per your needs 234 | if isinstance(value, _thread.RLock): 235 | result[f.name] = "Skipped due to non-pickleable type" 236 | else: 237 | result[f.name] = custom_asdict(value) 238 | return result 239 | elif isinstance(obj, (list, tuple)): 240 | return [custom_asdict(x) for x in obj] 241 | elif isinstance(obj, dict): 242 | return {k: custom_asdict(v) for k, v in obj.items()} 243 | else: 244 | return obj 245 | --------------------------------------------------------------------------------
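Note: the snippet below is not part of the repository; it is a minimal sketch of how `custom_asdict` from smol_trainer/utils.py handles non-pickleable fields such as thread locks, using a hypothetical `RunState` dataclass for illustration.

import threading
from dataclasses import dataclass, field

from smol_trainer.utils import custom_asdict


@dataclass
class RunState:
    # hypothetical stand-in for a config object that carries a lock
    run_name: str = "run_0"
    iter_num: int = 100
    lock: threading.RLock = field(default_factory=threading.RLock)


print(custom_asdict(RunState()))
# {'run_name': 'run_0', 'iter_num': 100, 'lock': 'Skipped due to non-pickleable type'}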