├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── configs
│   ├── demo.yaml
│   ├── pico-decoder-large.yaml
│   ├── pico-decoder-medium.yaml
│   ├── pico-decoder-small.yaml
│   └── pico-decoder-tiny.yaml
├── pyproject.toml
├── scripts
│   └── train.py
├── setup.sh
└── src
    ├── checkpointing
    │   ├── __init__.py
    │   ├── evaluation.py
    │   ├── learning_dynamics.py
    │   └── training.py
    ├── config
    │   ├── __init__.py
    │   ├── _constants.py
    │   ├── checkpointing_config.py
    │   ├── data_config.py
    │   ├── evaluation_config.py
    │   ├── model_config.py
    │   ├── monitoring_config.py
    │   └── training_config.py
    ├── evaluation
    │   ├── __init__.py
    │   └── tasks
    │       └── paloma.py
    ├── model
    │   ├── __init__.py
    │   └── pico_decoder.py
    └── training
        ├── trainer.py
        └── utils
            ├── __init__.py
            ├── data.py
            ├── initialization.py
            ├── io.py
            └── logging.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | poetry.lock 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control.
111 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 112 | .pdm.toml 113 | .pdm-python 114 | .pdm-build/ 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ 165 | 166 | # Data 167 | data/ 168 | 169 | # Checkpoint and Logging Directorries 170 | runs/ 171 | wandb/ 172 | configs/ 173 | 174 | .vscode/ 175 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.7.1 5 | hooks: 6 | # Run the linter. 7 | - id: ruff 8 | args: [ --fix, --extend-select, I ] 9 | # Run the formatter. 10 | - id: ruff-format -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🚀 **Pico Train** 2 | 3 | Pico Train is a lightweight framework for training language models—from tiny-scale (~1M parameters) to mid-scale (~1B parameters)—with built-in rich checkpointing that captures activations, gradients, and model states, enabling detailed learning dynamics research. 4 | 5 | Our **suite of pre-trained models** is already publicly available on our [Hugging Face organization](https://huggingface.co/pico-lm), and a dedicated companion library for advanced analysis—[**pico-analyze**](https://github.com/pico-lm/pico-analyze)—is fully released for deeper checkpoint studies. 6 | 7 | > For a **detailed run-through**, check out the **full tutorial** on our website at [picolm.io](https://picolm.io). 8 | 9 | --- 10 | 11 | ## **Key Features** 12 | 13 | 1. **Pico Decoder: LLAMA-style Transformer Architecture** 14 | - RMSNorm, RoPE, multi-head self-attention with KV-cache, and SwiGLU activations 15 | - Currently supports the **pico-decoder** model, with future expansions planned (pico-diffusion, pico-statespace, etc.) 16 | 17 | 2. **Comprehensive Checkpoints** 18 | - Saves model states, optimizer states, and training metadata 19 | - Enriched with **activation and gradient** snapshots for interpretability 20 | 21 | 3. **Focused Scale Range** 22 | - Optimized to train models from **1M to 1B parameters**, where learning dynamics research is most viable 23 | 24 | 4. **Clean, Pre-tokenized Data** 25 | - Uses a pre-tokenized, pre-shuffled version of [Dolma](https://allenai.org/dolma) that we make available on [Hugging Face](https://huggingface.co/datasets/pico-lm/pretokenized-dolma) 26 | - Facilitates training models using identical data for **consistency** and **comparability** 27 | 28 | 5. **Research Ready** 29 | - Minimal, well-documented code suitable for **forking and tailoring** 30 | - Logs essential metrics (e.g. perplexity) throughout training 31 | - Works seamlessly with [pico-analyze](https://github.com/pico-lm/pico-analyze) for advanced post-training interpretation 32 | 33 | --- 34 | 35 | ## **Training Philosophy** 36 | 37 | All models in the Pico suite (both pre-trained and user-trained): 38 | 39 | - Employ **identical architectures** and **optimizer settings** 40 | - **Share** the same data order and tokens 41 | - Automatically log **rich checkpoint data** (including activations, gradients) 42 | - Facilitate **direct cross-scale comparisons** 43 | 44 | This uniformity means you can isolate model size as the primary variable, giving you clearer insights into **how model capacity affects learning**.
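Since every run consumes the same pre-tokenized, pre-shuffled Dolma stream, you can sanity-check that data yourself before training. The snippet below is a minimal sketch (not part of the training code) that streams a few records with the `datasets` library; the dataset id comes from the Hugging Face link above, while the split name and record schema are assumptions to verify against the dataset card.

```python
# Minimal sketch: peek at the pre-tokenized Dolma stream shared by all Pico models.
# The repo id matches the README link; the "train" split name is an assumption,
# and the record fields are printed rather than assumed.
from datasets import load_dataset

stream = load_dataset("pico-lm/pretokenized-dolma", split="train", streaming=True)

for i, record in enumerate(stream):
    print(record.keys())  # inspect which columns (e.g. token ids) each record carries
    if i >= 2:            # only look at the first few records
        break
```

Streaming avoids downloading the full corpus, which keeps this kind of spot-check cheap even for the larger pre-training dataset.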
45 | 46 | --- 47 | 48 | ## **Resources** 49 | 50 | - **Pre-trained Models** (1M–1B parameters), publicly hosted on [Hugging Face](https://huggingface.co/pico-lm) 51 | - **Pre-tokenized Datasets** for straightforward streaming-based training 52 | - **Extensive Checkpoints** logging activation and gradient snapshots 53 | - **Evaluation Metrics** (perplexity and more) tracked at each checkpoint 54 | 55 | --- 56 | 57 | ## **Core Components** 58 | 59 | - **Pico-Decoder Model** 60 | - LLAMA-style auto-regressive transformer 61 | - RMSNorm 62 | - RoPE (Rotary Positional Embeddings) 63 | - Multi-head attention with KV-cache 64 | - SwiGLU activation 65 | 66 | *Future plans include additional architectures like pico-diffusion and pico-statespace.* 67 | 68 | - **Training & Checkpointing** 69 | - Automatic storage of model and optimizer states 70 | - Periodic hooks for saving **learning dynamics** (activations, gradients) 71 | - Optional logging to Weights & Biases 72 | 73 | - **Config-Driven Setup** 74 | - Specify architecture, optimizer, dataset, and logging settings in YAML 75 | - Straightforward to extend or modify 76 | 77 | --- 78 | 79 | ## **Quick Start** 80 | 81 | 1. **Clone the Repository** 82 | 83 | ```bash 84 | git clone https://github.com/pico-lm/pico-train 85 | cd pico-train 86 | ``` 87 | 88 | 2. **Configure Environment** 89 | 90 | Create a `.env` file at the root with your Hugging Face and Weights & Biases tokens: 91 | ```bash 92 | export HF_TOKEN=your_huggingface_token 93 | export WANDB_API_KEY=your_wandb_key 94 | ``` 95 | 96 | 3. **Install Dependencies** 97 | 98 | ```bash 99 | source setup.sh 100 | ``` 101 | This script checks your environment, installs necessary tools, and sets up a Poetry virtual environment. 102 | 103 | 4. **Train Your Model Suite** 104 | 105 | - Edit (or create) a config file (e.g., `configs/demo.yaml`) to specify your architecture and training preferences. 106 | - Then run: 107 | ```bash 108 | poetry run train --config_path configs/demo.yaml 109 | ``` 110 | - This launches training, automatically checkpointing states and saving learning dynamics data. 111 | 112 | 5. **Explore Checkpoints** 113 | - By default, checkpoints are stored under `runs/YOUR_RUN_NAME/checkpoints/`. 114 | - Each checkpoint contains: 115 | - **Model state** (PyTorch + Hugging Face formats) 116 | - **Optimizer state** 117 | - **Gradients and activations** for interpretability 118 | - **Evaluation logs** (e.g. perplexity) and metrics 119 | 120 | --- 121 | 122 | ## **Repository Structure** 123 | 124 | - **`src/model/pico_decoder.py`** 125 | - Core LLAMA-style decoder implementation (attention, RMSNorm, RoPE, etc.) 126 | 127 | - **`src/training/trainer.py`** 128 | - Main training loop 129 | - Manages distributed and multi-node settings 130 | - Collects/logs metrics 131 | - Orchestrates checkpoint saving 132 | 133 | - **`src/checkpointing`** 134 | - Logic for saving model states, gradients, activations 135 | - Tools for uploading checkpoints to Hugging Face 136 | 137 | - **`src/config`** 138 | - Flexible Dataclass-based config system (model and training hyperparameters, checkpointing, logging) 139 | 140 | - **`configs/demo.yaml`** 141 | - Example config with default values for quick experimentation 142 | 143 | --- 144 | 145 | ## **Advanced Analysis with Pico Analyze** 146 | 147 | For deeper checkpoint analysis—comparing gradients, tracking representation shifts, measuring sparsity—use our companion repository [**pico-analyze**](https://github.com/pico-lm/pico-analyze). 
It automatically processes **pico-train** checkpoints and applies advanced metrics like **CKA**, **PWCCA**, **Gini**, **Hoyer**, and more to reveal **how** your models learn over time. 148 | 149 | --- 150 | 151 | ## **License** 152 | 153 | Pico is open-source under the [Apache License 2.0](LICENSE). 154 | 155 | --- 156 | 157 | ## **Citation** 158 | 159 | If you use **Pico** in your research, please cite: 160 | 161 | ```bibtex 162 | @software{pico2025, 163 | author = {Diehl Martinez, Richard}, 164 | title = {Pico: A Lightweight Framework for Studying Language Model Learning Dynamics}, 165 | year = {2025}, 166 | url = {https://github.com/pico-lm} 167 | } 168 | ``` 169 | 170 | **Happy Training!** For more information and tutorials, visit our website at [picolm.io](https://picolm.io). 171 | -------------------------------------------------------------------------------- /configs/demo.yaml: -------------------------------------------------------------------------------- 1 | # Demo config file 2 | # You can follow this template to create your own config file 3 | # Refer to the config files in the configs/ directory to see all the available options 4 | 5 | data: 6 | dataloader: 7 | batch_size: 32 8 | 9 | checkpointing: 10 | run_name: "pico-decoder-demo-1" 11 | save_every_n_steps: 50 12 | 13 | save_to_hf: true 14 | hf_checkpoint: 15 | repo_id: "pico-lm/demo" 16 | 17 | learning_dynamics: 18 | batch_size: 16 19 | 20 | model: 21 | d_model: 96 22 | activation_hidden_dim: 384 23 | 24 | evaluation: 25 | paloma: 26 | batch_size: 32 27 | 28 | monitoring: 29 | 30 | save_to_wandb: true 31 | wandb: 32 | project: "pico-demo" 33 | entity: "pico-lm" 34 | 35 | logging: 36 | log_every_n_steps: 10 37 | 38 | training: 39 | max_steps: 100 40 | 41 | optimization: 42 | lr: 0.001 43 | lr_warmup_steps: 30 44 | 45 | gradient_accumulation_steps: 2 46 | 47 | fabric: 48 | num_devices: 1 49 | -------------------------------------------------------------------------------- /configs/pico-decoder-large.yaml: -------------------------------------------------------------------------------- 1 | # Demo config file 2 | # You can follow this template to create your own config file 3 | # Refer to the config files in the configs/ directory to see all the available options 4 | 5 | checkpointing: 6 | run_name: "pico-decoder-large-1" 7 | save_to_hf: true 8 | hf_checkpoint: 9 | repo_id: "pico-lm/pico-decoder-large" 10 | 11 | learning_dynamics: 12 | batch_size: 128 13 | 14 | model: 15 | d_model: 1536 16 | activation_hidden_dim: 6144 17 | 18 | monitoring: 19 | save_to_wandb: true 20 | wandb: 21 | project: "pico-decoder" 22 | entity: "pico-lm" 23 | 24 | training: 25 | optimization: 26 | gradient_accumulation_steps: 8 27 | 28 | fabric: 29 | num_nodes: 4 30 | num_devices: 4 31 | 32 | evaluation: 33 | paloma: 34 | batch_size: 16 35 | -------------------------------------------------------------------------------- /configs/pico-decoder-medium.yaml: -------------------------------------------------------------------------------- 1 | # Demo config file 2 | # You can follow this template to create your own config file 3 | # Refer to the config files in the configs/ directory to see all the available options 4 | 5 | checkpointing: 6 | run_name: "pico-decoder-medium-1" 7 | save_to_hf: true 8 | hf_checkpoint: 9 | repo_id: "pico-lm/pico-decoder-medium" 10 | 11 | learning_dynamics: 12 | batch_size: 128 13 | 14 | model: 15 | d_model: 768 16 | activation_hidden_dim: 3072 17 | 18 | monitoring: 19 | save_to_wandb: true 20 | wandb: 21 | project: 
"pico-decoder" 22 | entity: "pico-lm" 23 | 24 | training: 25 | optimization: 26 | gradient_accumulation_steps: 8 27 | 28 | fabric: 29 | num_nodes: 4 30 | num_devices: 4 31 | 32 | evaluation: 33 | paloma: 34 | batch_size: 16 35 | -------------------------------------------------------------------------------- /configs/pico-decoder-small.yaml: -------------------------------------------------------------------------------- 1 | # Demo config file 2 | # You can follow this template to create your own config file 3 | # Refer to the config files in the configs/ directory to see all the available options 4 | 5 | checkpointing: 6 | run_name: "pico-decoder-small-1" 7 | save_to_hf: true 8 | hf_checkpoint: 9 | repo_id: "pico-lm/pico-decoder-small" 10 | 11 | learning_dynamics: 12 | batch_size: 128 13 | 14 | model: 15 | d_model: 384 16 | activation_hidden_dim: 1536 17 | 18 | monitoring: 19 | save_to_wandb: true 20 | wandb: 21 | project: "pico-decoder" 22 | entity: "pico-lm" 23 | 24 | training: 25 | optimization: 26 | gradient_accumulation_steps: 8 27 | 28 | fabric: 29 | num_nodes: 4 30 | num_devices: 4 31 | 32 | evaluation: 33 | paloma: 34 | batch_size: 16 35 | -------------------------------------------------------------------------------- /configs/pico-decoder-tiny.yaml: -------------------------------------------------------------------------------- 1 | # Demo config file 2 | # You can follow this template to create your own config file 3 | # Refer to the config files in the configs/ directory to see all the available options 4 | 5 | checkpointing: 6 | run_name: "pico-decoder-tiny-1" 7 | save_to_hf: true 8 | hf_checkpoint: 9 | repo_id: "pico-lm/pico-decoder-tiny" 10 | 11 | learning_dynamics: 12 | batch_size: 256 13 | 14 | model: 15 | d_model: 96 16 | activation_hidden_dim: 384 17 | 18 | monitoring: 19 | save_to_wandb: true 20 | wandb: 21 | project: "pico-decoder" 22 | entity: "pico-lm" 23 | 24 | training: 25 | optimization: 26 | gradient_accumulation_steps: 4 27 | 28 | fabric: 29 | num_nodes: 4 30 | num_devices: 4 31 | 32 | evaluation: 33 | paloma: 34 | batch_size: 32 35 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "pico-train" 3 | version = "1.0.0" 4 | description = "A minimalistic framework for transparently training language models and storing comprehensive checkpoints for in-depth learning dynamics research" 5 | authors = ["Richard Diehl Martinez "] 6 | license = "Apache 2.0" 7 | readme = "README.md" 8 | packages = [{include = "src"}] 9 | 10 | [tool.poetry.scripts] 11 | train = "scripts.train:main" 12 | 13 | [tool.poetry.dependencies] 14 | python = "^3.10,<3.13" 15 | lightning = "^2.4.0" 16 | click = "^8.1.7" 17 | wandb = "^0.18.1" 18 | huggingface-hub = {extras = ["cli"], version = "^0.25.1"} 19 | datasets = "^3.0.1,<3.2.0" 20 | transformers = "^4.45.2" 21 | pre-commit = "^4.0.1" 22 | torch = "^2.5.1" 23 | evaluate = "^0.4.3" 24 | deepspeed = "^0.16.2" 25 | rich = "^13.9.4" 26 | 27 | [tool.poetry.group.dev.dependencies] 28 | ipykernel = "^6.29.5" 29 | jupyter = "^1.1.1" 30 | 31 | [build-system] 32 | requires = ["poetry-core"] 33 | build-backend = "poetry.core.masonry.api" 34 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | A minimal script to train the Pico 
language model. In practice, you should just use the 4 | `poetry run train` command to run the training pipeline. Doing so will invoke this script. 5 | Training logic is located in `src/training/trainer.py`. 6 | """ 7 | 8 | from pathlib import Path 9 | 10 | import click 11 | 12 | from src.training.trainer import Trainer 13 | 14 | 15 | @click.command() 16 | @click.option( 17 | "--config_path", 18 | "config_path", 19 | type=click.Path(exists=True, path_type=Path), 20 | help="Path to the training configuration file", 21 | ) 22 | def main(config_path: Path) -> None: 23 | """Train the Pico language model using the specified configuration.""" 24 | 25 | trainer = Trainer(config_path=str(config_path)) 26 | trainer.train() 27 | 28 | 29 | if __name__ == "__main__": 30 | main() 31 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script sets up the project by installing dependencies, checking for a poetry environment, 3 | # and installing pre-commit hooks. 4 | 5 | # Add color and formatting variables at the top 6 | GREEN='\033[0;32m' 7 | BLUE='\033[0;34m' 8 | YELLOW='\033[1;33m' 9 | RED='\033[0;31m' 10 | NC='\033[0m' # No Color 11 | BOLD='\033[1m' 12 | 13 | # Initialize error tracking 14 | ERRORS_FOUND=0 15 | 16 | # Function for section headers 17 | print_section() { 18 | echo -e "\n${BOLD}${BLUE}=== $1 ===${NC}\n" 19 | } 20 | 21 | # Function for success messages 22 | print_success() { 23 | echo -e "${GREEN}✓ $1${NC}" 24 | } 25 | 26 | # Function for warnings 27 | print_warning() { 28 | echo -e "${YELLOW}⚠ $1${NC}" 29 | } 30 | 31 | # --- GIT LFS SETUP --- # 32 | print_section "Git LFS Setup" 33 | if ! command -v git-lfs &> /dev/null; then 34 | print_warning "git-lfs is not installed. Some model checkpointing functionality may not work correctly." 35 | ERRORS_FOUND=$((ERRORS_FOUND + 1)) 36 | 37 | # Check the operating system 38 | if [[ "$OSTYPE" == "darwin"* ]]; then 39 | # macOS 40 | echo -e "${YELLOW} You can install it using Homebrew:${NC}" 41 | echo " brew install git-lfs" 42 | elif [[ "$OSTYPE" == "linux-gnu"* ]]; then 43 | # Linux 44 | echo -e "${YELLOW} You can install it using your package manager:${NC}" 45 | if command -v apt-get &> /dev/null; then 46 | # Ubuntu/Debian 47 | echo " curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash" 48 | echo " sudo apt-get install git-lfs" 49 | elif command -v yum &> /dev/null; then 50 | # CentOS/RHEL 51 | echo " curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.rpm.sh | sudo bash" 52 | echo " sudo yum install git-lfs" 53 | else 54 | print_warning "Could not detect package manager. Please install git-lfs manually." 55 | fi 56 | else 57 | print_warning "Unsupported operating system. Please install git-lfs manually." 58 | fi 59 | else 60 | git-lfs install 61 | print_success "git-lfs installed and initialized" 62 | fi 63 | 64 | # --- CUDA VERSION CHECK --- # 65 | print_section "CUDA Version Check" 66 | if command -v nvidia-smi &> /dev/null; then 67 | CUDA_VERSION=$(nvidia-smi | sed -n 's/.*CUDA Version: \([0-9.]*\).*/\1/p') 68 | 69 | if [[ -z "$CUDA_VERSION" ]]; then 70 | ERRORS_FOUND=$((ERRORS_FOUND + 1)) 71 | print_warning "nvidia-smi failed to communicate with the NVIDIA driver." 
72 | echo -e "${YELLOW} Ensure that the latest NVIDIA driver is installed and running.${NC}" 73 | else 74 | MAJOR_VERSION=${CUDA_VERSION%.*} 75 | MINOR_VERSION=${CUDA_VERSION#*.} 76 | 77 | if [ "$MAJOR_VERSION" -lt 12 ] || ([ "$MAJOR_VERSION" -eq 12 ] && [ "$MINOR_VERSION" -lt 1 ]); then 78 | ERRORS_FOUND=$((ERRORS_FOUND + 1)) 79 | print_warning "CUDA version ${MAJOR_VERSION}.${MINOR_VERSION} detected." 80 | echo -e "${YELLOW} Some multi-node communication GPU features may not work properly.${NC}" 81 | echo -e "${YELLOW} CUDA version 12.1 or newer is recommended.${NC}" 82 | else 83 | print_success "CUDA version ${MAJOR_VERSION}.${MINOR_VERSION} detected" 84 | fi 85 | fi 86 | else 87 | ERRORS_FOUND=$((ERRORS_FOUND + 1)) 88 | print_warning "nvidia-smi not found. Unable to check CUDA version." 89 | echo -e "${YELLOW} Ensure that NVIDIA drivers and CUDA version at 12.1 or newer are installed for GPU support.${NC}" 90 | fi 91 | 92 | 93 | # ---- ENVIRONMENT VARIABLES ---- # 94 | print_section "Environment Variables" 95 | if [ -f .env ]; then 96 | print_success "Loading environment variables from .env..." 97 | source .env 98 | if [[ -n "$HF_TOKEN" && -n "$WANDB_API_KEY" ]]; then 99 | print_success "Both HF_TOKEN and WANDB_API_KEY are set and loaded!" 100 | else 101 | print_warning "One or both of HF_TOKEN and WANDB_API_KEY are not set." 102 | fi 103 | else 104 | print_warning "No .env file found." 105 | echo -e "${YELLOW} You might need to create one with HF_TOKEN and WANDB_API_KEY${NC}" 106 | echo -e "${YELLOW} Example .env contents:${NC}" 107 | echo " export HF_TOKEN=your_huggingface_token" 108 | echo " export WANDB_API_KEY=your_wandb_key" 109 | ERRORS_FOUND=$((ERRORS_FOUND + 1)) 110 | fi 111 | 112 | # ---- POETRY SETUP ---- # 113 | print_section "Poetry Setup" 114 | 115 | # First check if Poetry is installed 116 | if ! command -v poetry &> /dev/null; then 117 | echo "Poetry not found. Installing..." 118 | 119 | # Run the installation command 120 | curl -sSL https://install.python-poetry.org | python3 - 121 | POETRY_INSTALL_STATUS=$? 122 | 123 | if [ $POETRY_INSTALL_STATUS -ne 0 ]; then 124 | print_warning "Poetry installation failed!" 125 | ERRORS_FOUND=$((ERRORS_FOUND + 1)) 126 | else 127 | export PATH="$HOME/.local/bin:$PATH" 128 | 129 | # Verify installation succeeded 130 | if ! command -v poetry &> /dev/null; then 131 | print_warning "Poetry was installed but cannot be found in PATH!" 132 | echo -e "${YELLOW} Try adding this to your shell profile:${NC}" 133 | echo " export PATH=\"\$HOME/.local/bin:\$PATH\"" 134 | ERRORS_FOUND=$((ERRORS_FOUND + 1)) 135 | else 136 | print_success "Poetry installed successfully" 137 | fi 138 | fi 139 | else 140 | print_success "Poetry already installed" 141 | fi 142 | 143 | # Then check for virtual environment 144 | if [ ! -d ".venv" ]; then 145 | echo "No virtual environment found. Creating one..." 146 | poetry config virtualenvs.in-project true 147 | 148 | # Create virtual environment and install dependencies 149 | poetry install --with dev 150 | POETRY_VENV_STATUS=$? 151 | 152 | if [ $POETRY_VENV_STATUS -ne 0 ]; then 153 | print_warning "Failed to create Poetry virtual environment!" 154 | ERRORS_FOUND=$((ERRORS_FOUND + 1)) 155 | else 156 | print_success "Poetry environment created successfully" 157 | fi 158 | else 159 | print_success "Poetry environment already exists" 160 | fi 161 | 162 | # ---- PRE-COMMIT SETUP ---- # 163 | print_section "Pre-commit Setup" 164 | 165 | # Install pre-commit hooks 166 | echo "Installing pre-commit hooks..." 
167 | poetry run pre-commit install 168 | if [ $? -ne 0 ]; then 169 | print_warning "Failed to install pre-commit hooks!" 170 | ERRORS_FOUND=$((ERRORS_FOUND + 1)) 171 | else 172 | print_success "Pre-commit hooks installed" 173 | fi 174 | 175 | # Run pre-commit hooks on all files 176 | echo "Running pre-commit hooks on all files..." 177 | poetry run pre-commit run --all-files 178 | if [ $? -ne 0 ]; then 179 | print_warning "Pre-commit encountered issues with some files" 180 | ERRORS_FOUND=$((ERRORS_FOUND + 1)) 181 | else 182 | print_success "Pre-commit initial run complete" 183 | fi 184 | 185 | # --- Final Status Message --- # 186 | 187 | # Final status message 188 | print_section "Setup Status" 189 | if [ $ERRORS_FOUND -eq 0 ]; then 190 | print_success "Setup Complete! 🎉" 191 | print_success "To activate the virtual environment, run: poetry env activate" 192 | else 193 | print_warning "Setup completed with warnings and errors! Please check the messages above." 194 | echo -e "${YELLOW} ${ERRORS_FOUND} issue(s) were detected that may affect functionality.${NC}" 195 | if [ -d ".venv" ]; then 196 | echo -e "${YELLOW} You can still activate the environment with: poetry env activate${NC}" 197 | else 198 | echo -e "${RED} The virtual environment setup failed. Fix the issues before proceeding.${NC}" 199 | fi 200 | fi -------------------------------------------------------------------------------- /src/checkpointing/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pico Checkpointing Package 3 | 4 | We subdivide the checkpointing into training, evaluation, and learning_dynamics. Training 5 | checkpoints store the model, optimizer, and learning rate scheduler. Evaluation checkpoints store 6 | the evaluation results on the defined metrics. Learning dynamics checkpoints store activations and gradients used for 7 | learning dynamics analysis. 8 | """ 9 | 10 | from .evaluation import save_evaluation_results 11 | from .learning_dynamics import ( 12 | compute_learning_dynamics_states, 13 | save_learning_dynamics_states, 14 | ) 15 | from .training import load_checkpoint, save_checkpoint 16 | 17 | __all__ = [ 18 | "compute_learning_dynamics_states", 19 | "load_checkpoint", 20 | "save_checkpoint", 21 | "save_evaluation_results", 22 | "save_learning_dynamics_states", 23 | ] 24 | -------------------------------------------------------------------------------- /src/checkpointing/evaluation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for checkpointing evaluation-related states (i.e. evaluation results, etc.) 3 | 4 | We save the evaluation results in a JSON file at the step-specific evaluation results directory. 5 | """ 6 | 7 | import json 8 | import os 9 | from typing import Any, Dict 10 | 11 | from huggingface_hub import upload_folder 12 | from lightning.fabric import Fabric 13 | from lightning.fabric.utilities.rank_zero import rank_zero_only 14 | 15 | from src.config import CheckpointingConfig 16 | from src.training.utils.io import use_backoff 17 | 18 | 19 | @rank_zero_only 20 | @use_backoff() 21 | def save_evaluation_results( 22 | checkpointing_config: CheckpointingConfig, 23 | checkpoint_step: int, 24 | fabric: Fabric, 25 | evaluation_results: Dict[str, Any], 26 | ) -> None: 27 | """Save evaluation results to disk and optionally to HuggingFace Hub. 
28 | 29 | The evaluation results are saved in the following directory structure: 30 | {checkpointing_config.runs_dir}/ 31 | └── {checkpointing_config.run_name}/ 32 | └── {checkpointing_config.eval_results_dir}/ 33 | └── step_{checkpoint_step}.json 34 | 35 | NOTE: this function is only called on rank 0 to avoid conflicts; assumes that the evaluation 36 | results are gathered on rank 0. 37 | 38 | Args: 39 | checkpointing_config: Configuration object containing checkpoint settings 40 | checkpoint_step: Current training checkpoint step (i.e. number of learning steps taken) 41 | fabric: Lightning Fabric instance 42 | evaluation_results: Dictionary containing evaluation metrics 43 | """ 44 | 45 | run_dir = os.path.join(checkpointing_config.runs_dir, checkpointing_config.run_name) 46 | eval_results_dir = os.path.join( 47 | run_dir, checkpointing_config.evaluation.eval_results_dir 48 | ) 49 | 50 | os.makedirs(eval_results_dir, exist_ok=True) 51 | 52 | curr_eval_results_path = os.path.join( 53 | eval_results_dir, f"step_{checkpoint_step}.json" 54 | ) 55 | 56 | # save out as json 57 | with open(curr_eval_results_path, "w") as f: 58 | json.dump(evaluation_results, f) 59 | 60 | if checkpointing_config.save_to_hf: 61 | upload_folder( 62 | folder_path=eval_results_dir, 63 | path_in_repo=checkpointing_config.evaluation.eval_results_dir, 64 | repo_id=checkpointing_config.hf_checkpoint.repo_id, 65 | commit_message=f"Saving Evaluation Results -- Step {checkpoint_step}", 66 | revision=checkpointing_config.run_name, 67 | token=os.getenv("HF_TOKEN"), 68 | ) 69 | -------------------------------------------------------------------------------- /src/checkpointing/learning_dynamics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for checkpointing learning dynamics-related states (i.e. activations, weights, grads, etc.) 3 | 4 | We save the learning dynamics states in a subdirectory of the checkpointing directory. 5 | """ 6 | 7 | import os 8 | import re 9 | from typing import Dict, Optional 10 | 11 | import deepspeed 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | from datasets import Dataset 16 | from huggingface_hub import upload_folder 17 | from lightning.fabric import Fabric 18 | from lightning.fabric.strategies import DeepSpeedStrategy 19 | from lightning.fabric.utilities.rank_zero import rank_zero_only 20 | from torch.nn import functional as F 21 | from torch.utils.data import DataLoader 22 | from transformers import PreTrainedTokenizerBase 23 | 24 | from src.config import CheckpointingConfig 25 | from src.config.checkpointing_config import LearningDynamicsCheckpointingConfig 26 | from src.training.utils.initialization import initialize_model 27 | from src.training.utils.io import use_backoff 28 | 29 | 30 | # NOTE: DeepSpeed requires a dummy optimizer to be passed in to the setup function 31 | class DummyOptimizer(optim.Optimizer): 32 | def __init__(self, params): 33 | super().__init__(params, defaults={}) 34 | 35 | 36 | class CheckpointStateExtractor: 37 | """ 38 | Class to extract and save the states of a model at a given checkpoint step for learning 39 | dynamics research. 
40 | """ 41 | 42 | def __init__( 43 | self, 44 | learning_dynamics_config: LearningDynamicsCheckpointingConfig, 45 | fabric: Fabric, 46 | model: nn.Module, 47 | ): 48 | self.learning_dynamics_config = learning_dynamics_config 49 | self.fabric = fabric 50 | self.model = model 51 | 52 | def extract_states(self, dataloader, compute_gradients: bool = False): 53 | """Extracts model states (activations, weights, and optionally gradients). 54 | 55 | Given a dataloader, this function will perform a forward pass of the model on each batch, 56 | and save the activations and weights at each layer. If compute_gradients is True, it will 57 | also compute the gradients of the model parameters. 58 | 59 | Args: 60 | dataloader: The dataloader containing the dataset to extract states from. 61 | compute_gradients: Whether to compute the gradients of the model parameters. 62 | 63 | Returns: 64 | A dictionary containing the activations, weights, and optionally gradients of the model. 65 | """ 66 | checkpoint_activations = {} 67 | checkpoint_weights = {} 68 | 69 | # NOTE: to extract activations and weights, we need to setup forward hooks on the layers 70 | # of the model that we are interested in. This is a good intro to forward hooks if you 71 | # are not familiar: https://web.stanford.edu/~nanbhas/blog/forward-hooks-pytorch/ 72 | forward_hooks = self._setup_forward_hooks( 73 | checkpoint_activations, 74 | checkpoint_weights, 75 | ) 76 | 77 | ######################################################## 78 | # 79 | # Forward Pass: Extract activations and weights; and compute gradients 80 | # 81 | ######################################################## 82 | 83 | for sub_batch in dataloader: 84 | _input_ids = torch.tensor(sub_batch["input_ids"], device=self.fabric.device) 85 | 86 | if compute_gradients: 87 | if "labels" in sub_batch: 88 | input_ids = _input_ids 89 | labels = torch.tensor( 90 | sub_batch["labels"], device=self.fabric.device 91 | ) 92 | else: 93 | input_ids = _input_ids[:, :-1] 94 | labels = _input_ids[:, 1:] 95 | else: 96 | input_ids = _input_ids 97 | labels = None 98 | 99 | if labels is None: 100 | # we can throw away the outputs, we are only interested in the hidden states 101 | with torch.no_grad(): 102 | _ = self.model(input_ids) 103 | else: 104 | # NOTE: if we are computing gradients, calling backwards will compute the gradients 105 | # of the model parameters. 106 | outputs, _ = self.model(input_ids) 107 | outputs = outputs.transpose(1, 2) 108 | loss = F.cross_entropy(outputs, labels) 109 | self.fabric.backward(loss, model=self.model) 110 | 111 | # cleanup forward hooks 112 | # NOTE this is not strictly necessary, since self.model is a deepcopy of the original model 113 | # but it is good practice to remove the hooks after the forward pass is complete. 
114 | for hook in forward_hooks: 115 | hook.remove() 116 | 117 | ######################################################## 118 | # 119 | # Extract gradients from the target tensors of the model 120 | # 121 | ######################################################## 122 | 123 | layer_suffixes = self.learning_dynamics_config.layer_suffixes 124 | checkpoint_gradients = {} 125 | if compute_gradients: 126 | for name, param in self.model.named_parameters(): 127 | # only do this for the weight matrix of the layer_suffixes 128 | if ( 129 | any(layer_suffix in name for layer_suffix in layer_suffixes) 130 | and "weight" in name 131 | ): 132 | if isinstance(self.fabric.strategy, DeepSpeedStrategy): 133 | _grad = deepspeed.utils.safe_get_full_grad(param) 134 | else: 135 | _grad = param.grad 136 | 137 | assert _grad is not None, f"Gradient is None for layer: {name}" 138 | name = re.sub(r"\.weight", "", name) 139 | checkpoint_gradients[name] = _grad.detach().cpu() 140 | 141 | # zero out the gradients 142 | self.model.zero_grad() 143 | 144 | return checkpoint_activations, checkpoint_weights, checkpoint_gradients 145 | 146 | ######################################################## 147 | # 148 | # Setup forward hooks to save activations and weights at each layer 149 | # 150 | ######################################################## 151 | 152 | def _setup_forward_hooks(self, checkpoint_activations, checkpoint_weights): 153 | """Setup forward hooks for the model to save activations and weights at each layer. 154 | 155 | This function will setup forward hooks on the layers of the model that we are interested in. 156 | The forward hooks will save the activations and weights at each layer whenever the forward pass 157 | is performed. 158 | 159 | Args: 160 | checkpoint_activations: A dictionary to store the activations at each layer. 161 | checkpoint_weights: A dictionary to store the weights at each layer. 162 | 163 | Returns: 164 | A list of forward hooks. We do this so that we can remove the hooks after the forward pass 165 | is complete. 166 | """ 167 | 168 | forward_hooks = [] 169 | layer_suffixes = self.learning_dynamics_config.layer_suffixes 170 | 171 | for name, module in self.model.named_modules(): 172 | if any(layer_suffix in name for layer_suffix in layer_suffixes): 173 | _forward_hook = module.register_forward_hook( 174 | self._get_forward_hook( 175 | name, checkpoint_activations, checkpoint_weights 176 | ) 177 | ) 178 | forward_hooks.append(_forward_hook) 179 | return forward_hooks 180 | 181 | def _get_forward_hook( 182 | self, module_name, checkpoint_activations, checkpoint_weights 183 | ): 184 | """Get a forward hook for a given module. 185 | 186 | This function is called by the _setup_forward_hooks function to setup a forward hook for a given 187 | module. This functions is a closure that captures the module_name, checkpoint_activations, and 188 | checkpoint_weights. 189 | 190 | Args: 191 | module_name: The name of the module to setup a forward hook for. 192 | checkpoint_activations: A dictionary to store the activations at each layer. 193 | checkpoint_weights: A dictionary to store the weights at each layer. 194 | 195 | Returns: 196 | A forward hook for the given module. 
197 | """ 198 | 199 | def _forward_hook(module, _, module_out): 200 | sequence_idx = self.learning_dynamics_config.sequence_idx 201 | 202 | local_activations = module_out[:, sequence_idx, :].detach() 203 | 204 | # Gather activations from all processes using fabric 205 | gathered_activations = self.fabric.all_gather(local_activations) 206 | 207 | # Reshape from [num_processes, batch_size, hidden_dim] to [total_batch_size, hidden_dim] 208 | # NOTE: transposing allows us to interleave the activations from each process so that 209 | # they are in the correct order. (i.e. activation N is from data sample N) 210 | gathered_activations = gathered_activations.transpose(0, 1).reshape( 211 | -1, gathered_activations.shape[-1] 212 | ) 213 | 214 | # check if there is already a key for the module name 215 | if module_name not in checkpoint_activations: 216 | # if there is no key, then we create a new key and store the hidden states 217 | checkpoint_activations[module_name] = ( 218 | gathered_activations.detach().cpu() 219 | ) 220 | 221 | # extract the weight matrix just once 222 | weight_matrix = module.weight.detach().cpu() 223 | checkpoint_weights[module_name] = weight_matrix 224 | else: 225 | # if there is already a key, then we concatenate the new hidden states to the existing ones 226 | checkpoint_activations[module_name] = torch.cat( 227 | ( 228 | checkpoint_activations[module_name], 229 | gathered_activations.detach().cpu(), 230 | ) 231 | ) 232 | 233 | return _forward_hook 234 | 235 | 236 | def compute_learning_dynamics_states( 237 | checkpointing_config: CheckpointingConfig, 238 | fabric: Fabric, 239 | model: nn.Module, 240 | dataset: Dataset, 241 | compute_gradients: bool = False, 242 | ) -> Dict[str, torch.Tensor]: 243 | """Computes the learning dynamics metrics for a given checkpoint step. 244 | 245 | Uses the CheckpointStateExtractor to extract the activations, weights, and optionally gradients 246 | of the model at a given checkpoint step. 247 | 248 | Args: 249 | checkpointing_config: The configuration object for checkpointing. 250 | fabric: The Fabric instance for distributed training. 251 | model: The model to extract states from. 252 | dataset: The dataset to extract states from. 253 | compute_gradients: Whether to compute the gradients of the model parameters. 254 | 255 | Returns: 256 | A dictionary containing the activations, weights, and optionally gradients of the model. 257 | """ 258 | 259 | # NOTE: Synchronizing processes for fabric dataloader setup 260 | fabric.barrier() 261 | model.to("cpu") # Offloading model to CPU 262 | 263 | # Setting up Dataloader for learning dynamics 264 | def _collate_fn(batch): 265 | return {"input_ids": [entry["input_ids"] for entry in batch]} 266 | 267 | batch_size = checkpointing_config.learning_dynamics.batch_size 268 | sub_batch_size = batch_size // fabric.world_size 269 | 270 | # NOTE: Make sure to set drop_last to False, otherwise the last batch will be dropped 271 | # and we will not have a complete set of activations for the last sample. Also, 272 | # we need to set shuffle to False, otherwise the activations will be shuffled across 273 | # processes and we will not be able to interleave them correctly. 
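    # NOTE: the global learning-dynamics batch is split evenly across processes
    # (`sub_batch_size` above); `fabric.setup_dataloaders(..., use_distributed_sampler=True)`
    # below shards the dataset across ranks, and the forward hooks all_gather the
    # per-rank activations so the full, correctly ordered batch is rebuilt on every process.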
274 | extractor_dataloader = DataLoader( 275 | dataset, 276 | batch_size=sub_batch_size, 277 | shuffle=False, 278 | collate_fn=_collate_fn, 279 | drop_last=False, 280 | ) 281 | extractor_dataloader = fabric.setup_dataloaders( 282 | extractor_dataloader, use_distributed_sampler=True 283 | ) 284 | 285 | # Create a new model instance with same parameters but zero gradients 286 | _model = initialize_model(model.config) 287 | _model.load_state_dict(model.state_dict()) 288 | 289 | if isinstance(fabric.strategy, DeepSpeedStrategy): 290 | _model, _ = fabric.setup(_model, DummyOptimizer(_model.parameters())) 291 | else: 292 | _model = fabric.setup(_model) 293 | 294 | _model.zero_grad() 295 | 296 | # setup forward hooks for the model to save activations and weights at each layer 297 | state_extractor = CheckpointStateExtractor( 298 | checkpointing_config.learning_dynamics, fabric, _model 299 | ) 300 | 301 | checkpoint_activations, checkpoint_weights, checkpoint_gradients = ( 302 | state_extractor.extract_states( 303 | extractor_dataloader, compute_gradients=compute_gradients 304 | ) 305 | ) 306 | 307 | del _model 308 | torch.cuda.empty_cache() 309 | 310 | # NOTE: Synchronizing processes for model setup 311 | fabric.barrier() 312 | 313 | model.to(fabric.device) 314 | 315 | # NOTE: Trimming down the activations to match the dataset size; 316 | # This is because the DataSampler might add extra samples to the dataset to make it evenly divisible 317 | # by the number of processes. We need to remove these extra samples. 318 | for layer_name, layer_activations in checkpoint_activations.items(): 319 | if len(layer_activations) > len(dataset): 320 | checkpoint_activations[layer_name] = layer_activations[: len(dataset)] 321 | elif len(layer_activations) < len(dataset): 322 | raise ValueError( 323 | f"Number of activations ({len(layer_activations)}) in layer {layer_name} does not match number of samples in dataset ({len(dataset)})" 324 | ) 325 | 326 | return { 327 | "activations": checkpoint_activations, 328 | "weights": checkpoint_weights, 329 | "gradients": checkpoint_gradients, 330 | } 331 | 332 | 333 | @rank_zero_only 334 | @use_backoff() 335 | def save_learning_dynamics_states( 336 | checkpointing_config: CheckpointingConfig, 337 | checkpoint_step: int, 338 | prefix: str, 339 | fabric: Fabric, 340 | learning_dynamics_states: Dict[str, torch.Tensor], 341 | learning_dynamics_dataset: Optional[Dataset] = None, 342 | tokenizer: Optional[PreTrainedTokenizerBase] = None, 343 | ) -> None: 344 | """Save the learning dynamics metrics to the checkpointing directory. 345 | 346 | By default only the learning dynamics states are saved. If the learning dynamics dataset 347 | is provided, it is also saved; if a tokenizer is provided, the dataset is also detokenized 348 | (i.e. a new column with the text is added to the dataset). 349 | 350 | The learning dynamics dataset is saved in the checkpointing directory as a HuggingFace 351 | dataset. 
352 | 353 | Creates a versioned checkpoint directory with the following structure: 354 | 355 | {checkpointing_config.runs_dir}/ 356 | └── {checkpointing_config.run_name}/ 357 | └── {checkpointing_config.checkpoints_dir}/ 358 | ├── step_{checkpoint_step}/ 359 | │ └── {checkpointing_config.learning_dynamics_dir}/ # Learning Dynamics files 360 | │ ├── {prefix}_activations.pt 361 | │ ├── {prefix}_weights.pt 362 | │ └── {prefix}_gradients.pt 363 | │ └── {prefix}_data/ # if learning_dynamics_dataset is provided 364 | └── latest -> step_{checkpoint_step}/ 365 | 366 | NOTE: this function is only called on rank 0 367 | 368 | Args: 369 | checkpointing_config: The configuration object for checkpointing. 370 | checkpoint_step: The checkpoint step at which the learning dynamics states were computed. 371 | prefix: The prefix for the learning dynamics states. 372 | fabric: The Fabric instance for distributed training. 373 | learning_dynamics_states: The learning dynamics states to save. 374 | learning_dynamics_dataset: The dataset containing learning dynamics data, 375 | including input IDs that need to be decoded. (optional) 376 | tokenizer: The tokenizer used to decode input IDs into text. (optional) 377 | """ 378 | 379 | runs_dir = checkpointing_config.runs_dir 380 | run_name = checkpointing_config.run_name 381 | checkpoints_dir = checkpointing_config.checkpoints_dir 382 | learning_dynamics_dir = checkpointing_config.learning_dynamics_dir 383 | 384 | run_path = os.path.join(runs_dir, run_name) 385 | root_checkpoint_path = os.path.join(run_path, checkpoints_dir) 386 | checkpoint_path = os.path.join(root_checkpoint_path, f"step_{checkpoint_step}") 387 | learning_dynamics_path = os.path.join(checkpoint_path, learning_dynamics_dir) 388 | os.makedirs(learning_dynamics_path, exist_ok=True) 389 | 390 | # save the learning dynamics states 391 | for key, value in learning_dynamics_states.items(): 392 | if value is not None and len(value) > 0: 393 | torch.save( 394 | value, os.path.join(learning_dynamics_path, f"{prefix}_{key}.pt") 395 | ) 396 | 397 | if learning_dynamics_dataset is not None: 398 | if tokenizer is not None: 399 | # go through dataset and decode the input ids; and add back into dataset 400 | detokenized_dataset = {"input_ids": [], "text": []} 401 | 402 | for entry in learning_dynamics_dataset: 403 | input_ids = entry["input_ids"] 404 | decoded_text = tokenizer.decode(input_ids, skip_special_tokens=True) 405 | detokenized_dataset["input_ids"].append(input_ids) 406 | detokenized_dataset["text"].append(decoded_text) 407 | 408 | learning_dynamics_dataset = Dataset.from_dict(detokenized_dataset) 409 | 410 | learning_dynamics_dataset_path = os.path.join( 411 | learning_dynamics_path, f"{prefix}_data" 412 | ) 413 | learning_dynamics_dataset.save_to_disk(learning_dynamics_dataset_path) 414 | 415 | if checkpointing_config.save_to_hf: 416 | # Upload the HF model 417 | upload_folder( 418 | folder_path=learning_dynamics_path, 419 | path_in_repo=learning_dynamics_dir, 420 | repo_id=checkpointing_config.hf_checkpoint.repo_id, 421 | commit_message=f"Saving Learning Dynamics Data ({prefix}) -- Step {checkpoint_step}", 422 | revision=checkpointing_config.run_name, 423 | token=os.getenv("HF_TOKEN"), 424 | ) 425 | -------------------------------------------------------------------------------- /src/checkpointing/training.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for checkpointing training-related states (i.e. model, optimizer, lr_scheduler, etc.) 
3 | 
4 | We save both a HuggingFace model and a Fabric-specific checkpoint. The HuggingFace model is
5 | saved at the step-specific checkpoint directory, while the Fabric-specific checkpoint is saved
6 | in a subdirectory. This is done to facilitate easier versioning of the HuggingFace model files
7 | (which are what gets uploaded to the Hub).
8 | """
9 | 
10 | import os
11 | from dataclasses import asdict
12 | from typing import Any, Dict, Optional, Tuple, Union
13 | 
14 | import yaml
15 | from huggingface_hub import upload_file, upload_folder
16 | from lightning.fabric import Fabric
17 | from lightning.fabric.strategies import DeepSpeedStrategy
18 | from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
19 | from torch import nn
20 | from torch.optim import Optimizer
21 | from torch.optim.lr_scheduler import LRScheduler
22 | from transformers import PreTrainedTokenizerBase
23 | 
24 | from src.config import CheckpointingConfig
25 | from src.training.utils.io import use_backoff
26 | 
27 | 
28 | @use_backoff()
29 | def load_checkpoint(
30 |     checkpointing_config: CheckpointingConfig,
31 |     checkpoint_step: Union[str, int],
32 |     fabric: Fabric,
33 |     model: nn.Module,
34 |     optimizer: Optimizer,
35 |     lr_scheduler: LRScheduler,
36 | ) -> Optional[Tuple[nn.Module, Optimizer, LRScheduler, int]]:
37 |     """Load model checkpoint and associated states from a given step.
38 | 
39 |     Args:
40 |         checkpointing_config: Configuration object containing checkpoint settings
41 |         checkpoint_step: The step at which to load the checkpoint
42 |         fabric: Lightning Fabric instance for distributed training support
43 |         model: The model instance to load weights into
44 |         optimizer: The optimizer instance to load states into
45 |         lr_scheduler: The learning rate scheduler to load states into
46 | 
47 |     Returns:
48 |         Tuple containing the model, optimizer, lr_scheduler, and checkpoint step.
49 |         Returns None if no checkpoint is found.
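
    Example (illustrative; assumes the training objects have already been initialized, and
    mirrors how the Trainer resumes from the `latest` symlink):

        resumed = load_checkpoint(
            checkpointing_config, "latest", fabric, model, optimizer, lr_scheduler
        )
        if resumed is not None:
            model, optimizer, lr_scheduler, checkpoint_step = resumed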
50 | """ 51 | 52 | if isinstance(checkpoint_step, int): 53 | checkpoint_step = f"step_{checkpoint_step}" 54 | 55 | checkpoint_path = os.path.join( 56 | checkpointing_config.runs_dir, 57 | checkpointing_config.run_name, 58 | checkpointing_config.checkpoints_dir, 59 | checkpoint_step, 60 | ) 61 | 62 | if not os.path.exists(checkpoint_path): 63 | return None 64 | 65 | # Load from specified fabric checkpoint subdirectory 66 | fabric_checkpoint_path = os.path.join( 67 | checkpoint_path, checkpointing_config.fabric_checkpoint_dir 68 | ) 69 | 70 | checkpoint_state = { 71 | "_model": model, 72 | "_optimizer": optimizer, 73 | "_lr_scheduler": lr_scheduler, 74 | } 75 | 76 | if not isinstance(fabric.strategy, DeepSpeedStrategy): 77 | fabric_load_file = os.path.join( 78 | fabric_checkpoint_path, checkpointing_config.fabric_checkpoint_filename 79 | ) 80 | else: 81 | # Deepspeed checkpoints create sub-directory with distributed checkpoint file 82 | fabric_load_file = fabric_checkpoint_path 83 | 84 | extra_state = fabric.load(os.path.join(fabric_load_file), state=checkpoint_state) 85 | 86 | # NOTE: extra_state will contain any additional states that were saved in the checkpoint 87 | checkpoint_step = extra_state["_checkpoint_step"] 88 | 89 | if "_rng_states" in extra_state: 90 | _rng_states = extra_state["_rng_states"] 91 | _set_rng_states(_rng_states) 92 | 93 | return model, optimizer, lr_scheduler, checkpoint_step 94 | 95 | 96 | @use_backoff() 97 | def save_checkpoint( 98 | configs: Dict[str, Any], 99 | checkpoint_step: int, 100 | fabric: Fabric, 101 | model: nn.Module, 102 | optimizer: Optimizer, 103 | lr_scheduler: LRScheduler, 104 | tokenizer: PreTrainedTokenizerBase, 105 | upload_logs: bool = False, 106 | ) -> None: 107 | """Save training checkpoint and associated states to disk and optionally to HuggingFace Hub. 108 | 109 | We save the following files: 110 | - HuggingFace model files (config.json, pytorch_model.bin) 111 | - Tokenizer files (vocab.json, merges.txt) 112 | - Fabric-specific files - fabric state of the model, optimizer, and lr_scheduler. If using 113 | DeepSpeed, the checkpoint is saved in a subdirectory, otherwise it is saved in a single file. 114 | 115 | Note that the HuggingFace model files are saved at the step-specific checkpoint directory, while the 116 | Fabric-specific files are saved in a subdirectory. This is done to facilitate easier 117 | versioning of the HuggingFace model files (which are what gets uploaded to the Hub). 118 | 119 | NOTE: Why do we save a HF model at all? We do this because it makes it easier to load the model 120 | in a separate script for evaluation and to play nicely with the HuggingFace Hub. 
121 | 122 | Creates a versioned checkpoint directory with the following structure: 123 | 124 | {checkpointing_config.runs_dir}/ 125 | └── {checkpointing_config.run_name}/ 126 | └── training_config.yaml # Training config 127 | └── {checkpointing_config.checkpoints_dir}/ 128 | ├── step_{checkpoint_step}/ 129 | │ ├── config.json # HuggingFace model config 130 | │ ├── model.safetensors # HuggingFace model weights 131 | │ ├── pico_{model_type}.py # HuggingFace custom model class 132 | │ ├── tokenizer.json # Tokenizer vocab 133 | │ ├── tokenizer_config.json # Tokenizer config 134 | │ └── {checkpointing_config.fabric_checkpoint_dir}/ # Fabric-specific files 135 | │ └── checkpoint/ # Distributed model checkpoint files (if using DeepSpeed) 136 | │ OR 137 | │ └── checkpoint.pt # Single checkpoint file (if using other strategies) 138 | └── latest -> step_{checkpoint_step}/ 139 | 140 | Args: 141 | configs: A dictionary containing the initialized configuration objects. 142 | checkpoint_step: The current training checkpoint step (i.e. number of learning steps taken) 143 | fabric: Lightning Fabric instance for distributed training support 144 | model: The model instance to save 145 | optimizer: The optimizer instance to save 146 | lr_scheduler: The learning rate scheduler to save 147 | tokenizer: The tokenizer to save 148 | upload_logs: Whether to upload training logs to HF Hub (default: False) 149 | 150 | """ 151 | 152 | checkpointing_config = configs["checkpointing"] 153 | 154 | # Get the directories from the training config 155 | runs_dir = checkpointing_config.runs_dir 156 | checkpoints_dir = checkpointing_config.checkpoints_dir 157 | fabric_checkpoint_dir = checkpointing_config.fabric_checkpoint_dir 158 | logs_dir = checkpointing_config.logs_dir 159 | 160 | run_path = os.path.join(runs_dir, checkpointing_config.run_name) 161 | root_checkpoint_path = os.path.join(run_path, checkpoints_dir) 162 | checkpoint_path = os.path.join(root_checkpoint_path, f"step_{checkpoint_step}") 163 | 164 | # Create directories 165 | os.makedirs(checkpoint_path, exist_ok=True) 166 | 167 | ######################################################## 168 | # 169 | # Save HuggingFace files 170 | # 171 | ######################################################## 172 | 173 | # NOTE: we convert the Pico model to a HuggingFace model before saving it. See `model.py` 174 | # for more details. 
175 | if fabric.global_rank == 0: 176 | hf_model = model.convert_to_hf_model() 177 | hf_model.save_pretrained(checkpoint_path) 178 | tokenizer.save_pretrained(checkpoint_path) 179 | 180 | ######################################################## 181 | # 182 | # Save Fabric-specific files 183 | # 184 | ######################################################## 185 | 186 | # Create fabric-specific subdirectory 187 | fabric_checkpoint_path = os.path.join(checkpoint_path, fabric_checkpoint_dir) 188 | os.makedirs(fabric_checkpoint_path, exist_ok=True) 189 | 190 | # Save model states (use underscore to avoid conflicts with third-party libraries) 191 | checkpoint_state = { 192 | "_model": model, 193 | "_optimizer": optimizer, 194 | "_lr_scheduler": lr_scheduler, 195 | "_checkpoint_step": checkpoint_step, 196 | } 197 | 198 | if not isinstance(fabric.strategy, DeepSpeedStrategy): 199 | checkpoint_state["_rng_states"] = _collect_rng_states() 200 | fabric_save_file = os.path.join( 201 | fabric_checkpoint_path, checkpointing_config.fabric_checkpoint_filename 202 | ) 203 | else: 204 | # Deepspeed checkpoints create sub-directory with distributed checkpoint file 205 | fabric_save_file = fabric_checkpoint_path 206 | 207 | fabric.save(fabric_save_file, checkpoint_state) 208 | 209 | if fabric.global_rank == 0: 210 | # Save config in fabric directory 211 | config_path = os.path.join(run_path, "training_config.yaml") 212 | if not os.path.exists(config_path): 213 | # Converting dataclasses to joined dicts and saving to file 214 | _training_config = {} 215 | for config_name, config in configs.items(): 216 | _training_config[config_name] = asdict(config) 217 | with open(config_path, "w") as f: 218 | yaml.dump(_training_config, f) 219 | 220 | # Update latest symlink 221 | latest_symlink_path = os.path.join(root_checkpoint_path, "latest") 222 | if os.path.lexists(latest_symlink_path): 223 | os.remove(latest_symlink_path) 224 | os.symlink( 225 | f"step_{checkpoint_step}", latest_symlink_path, target_is_directory=True 226 | ) 227 | 228 | ######################################################## 229 | # 230 | # Push to HuggingFace Hub (if configured) 231 | # 232 | ######################################################## 233 | 234 | if fabric.global_rank == 0: 235 | # Push only on rank zero thread 236 | 237 | if checkpointing_config.save_to_hf: 238 | repo_id = checkpointing_config.hf_checkpoint.repo_id 239 | 240 | # Upload the HF model 241 | hf_model.push_to_hub( 242 | repo_id=repo_id, 243 | commit_message=f"Saving HF Model -- Step {checkpoint_step}", 244 | revision=checkpointing_config.run_name, 245 | token=os.getenv("HF_TOKEN"), 246 | ) 247 | 248 | if checkpoint_step == 0: 249 | # Uploading Tokenizer during first step since it never changes 250 | tokenizer.push_to_hub( 251 | repo_id=repo_id, 252 | commit_message=f"Saving Tokenizer -- Step {checkpoint_step}", 253 | revision=checkpointing_config.run_name, 254 | token=os.getenv("HF_TOKEN"), 255 | ) 256 | 257 | # Upload training config, also only in first step 258 | upload_file( 259 | path_or_fileobj=config_path, 260 | path_in_repo="training_config.yaml", 261 | repo_id=repo_id, 262 | commit_message=f"Saving Training Config -- Step {checkpoint_step}", 263 | revision=checkpointing_config.run_name, 264 | token=os.getenv("HF_TOKEN"), 265 | ) 266 | 267 | # Upload the fabric checkpoint directory 268 | upload_folder( 269 | folder_path=fabric_checkpoint_path, 270 | path_in_repo=fabric_checkpoint_dir, 271 | repo_id=repo_id, 272 | commit_message=f"Saving Fabric Checkpoint -- Step 
{checkpoint_step}", 273 | revision=checkpointing_config.run_name, 274 | token=os.getenv("HF_TOKEN"), 275 | ) 276 | 277 | # Upload logs if requested 278 | if upload_logs: 279 | logs_path = os.path.join(run_path, logs_dir) 280 | upload_folder( 281 | folder_path=logs_path, 282 | path_in_repo=logs_dir, 283 | repo_id=repo_id, 284 | commit_message=f"Saving Logs -- Step {checkpoint_step}", 285 | revision=checkpointing_config.run_name, 286 | token=os.getenv("HF_TOKEN"), 287 | ) 288 | -------------------------------------------------------------------------------- /src/config/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pico Config Package 3 | 4 | The modules of this package are where you can specify the hyperparameters for the Pico model, 5 | the dataset, the training process, evaluation, etc. 6 | 7 | As with anything else in Pico, we've designed for the configuration setup to be as flexible 8 | as possible. By default the configs are implemented as vanilla dataclasses -- this makes it easy to 9 | switch to different config management systems if you want, like hydra. 10 | 11 | Some things to NOTE: 12 | - All hyperparameters are initialized with default values, which can be overridden. 13 | - The default vocab size is set to the size of the OLMo tokenizer. 14 | """ 15 | 16 | # For convenience, we export the config classes here 17 | from .checkpointing_config import CheckpointingConfig 18 | from .data_config import DataConfig 19 | from .evaluation_config import EvaluationConfig 20 | from .model_config import ModelConfig 21 | from .monitoring_config import MonitoringConfig 22 | from .training_config import TrainingConfig 23 | 24 | __all__ = [ 25 | "CheckpointingConfig", 26 | "DataConfig", 27 | "EvaluationConfig", 28 | "ModelConfig", 29 | "MonitoringConfig", 30 | "TrainingConfig", 31 | ] 32 | -------------------------------------------------------------------------------- /src/config/_constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | Constants used throughout the codebase 3 | """ 4 | 5 | # Basic Training Constants used throughout the codebase 6 | VOCAB_SIZE = 50304 7 | MAX_SEQ_LEN = 2048 8 | BATCH_SIZE = 1024 9 | GRADIENT_ACCUMULATION_STEPS = 128 10 | 11 | # Directories used to store training runs, checkpoints, logs, and evaluation results 12 | RUNS_DIR = "runs" 13 | CHECKPOINTS_DIR = "checkpoints" 14 | LOGS_DIR = "logs" 15 | FABRIC_CHECKPOINT_DIR = "fabric_state" 16 | FABRIC_CHECKPOINT_FILENAME = "checkpoint.pt" 17 | LEARNING_DYNAMICS_DIR = "learning_dynamics" 18 | EVAL_RESULTS_DIR = "eval_results" 19 | -------------------------------------------------------------------------------- /src/config/checkpointing_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Checkpointing Config 3 | 4 | Specifies the hyperparameters for the checkpointing process; checkpointing is used to save 5 | the model and optimizer states, as well as the learning dynamics metrics. 
6 | """ 7 | 8 | from dataclasses import dataclass, field 9 | from typing import List, Optional 10 | 11 | from ._constants import ( 12 | CHECKPOINTS_DIR, 13 | EVAL_RESULTS_DIR, 14 | FABRIC_CHECKPOINT_DIR, 15 | FABRIC_CHECKPOINT_FILENAME, 16 | LEARNING_DYNAMICS_DIR, 17 | LOGS_DIR, 18 | RUNS_DIR, 19 | ) 20 | 21 | 22 | @dataclass 23 | class TrainingCheckpointingConfig: 24 | # Automatically resume training from the most recent checkpoint 25 | auto_resume: bool = True 26 | 27 | 28 | @dataclass 29 | class EvaluationCheckpointingConfig: 30 | # Directory in which evaluation results are saved 31 | eval_results_dir: str = EVAL_RESULTS_DIR 32 | 33 | 34 | @dataclass 35 | class LearningDynamicsCheckpointingConfig: 36 | # Suffixes of the layers to compute learning dynamics for 37 | layer_suffixes: List[str] = field( 38 | default_factory=lambda: [ 39 | "attention.v_proj", 40 | "attention.o_proj", 41 | "swiglu.w_2", 42 | ] 43 | ) 44 | 45 | # Sequence index at which to extract hidden states; by default, we extract the hidden states 46 | # at the last token of the sequence (-1) 47 | sequence_idx: int = -1 48 | 49 | # size of the sub-batch used for extracting learning dynamics states 50 | batch_size: int = 8 51 | 52 | # Path to evaluation dataset - used across learning dynamics checkpointing for consistency 53 | # NOTE: set to None to disable extracting learning dynamics states for an eval_batch 54 | # NOTE: this dataset should be small, ideally just a batch of additional data 55 | eval_data: Optional[str] = "pico-lm/pretokenized-paloma-tinsy" 56 | 57 | 58 | @dataclass 59 | class HuggingFaceCheckpointingConfig: 60 | # Should be in the format of <(username or organization name)>/, e.g. pico-lm/demo 61 | repo_id: str = "" 62 | 63 | # HuggingFace Collection Slug (specifies a tag for the run) 64 | collection_slug: Optional[str] = None 65 | 66 | 67 | @dataclass 68 | class CheckpointingConfig: 69 | # Assign a name to the run 70 | run_name: Optional[str] = None 71 | 72 | # Defining checkpointing directories 73 | runs_dir: str = RUNS_DIR 74 | checkpoints_dir: str = CHECKPOINTS_DIR 75 | logs_dir: str = LOGS_DIR 76 | fabric_checkpoint_dir: str = FABRIC_CHECKPOINT_DIR 77 | fabric_checkpoint_filename: str = FABRIC_CHECKPOINT_FILENAME 78 | learning_dynamics_dir: str = LEARNING_DYNAMICS_DIR 79 | 80 | # How often to save checkpoints 81 | save_every_n_steps: int = 1000 82 | 83 | # Whether to save checkpoints to HuggingFace 84 | save_to_hf: Optional[bool] = False 85 | hf_checkpoint: HuggingFaceCheckpointingConfig = field( 86 | default_factory=HuggingFaceCheckpointingConfig 87 | ) 88 | 89 | training: TrainingCheckpointingConfig = field( 90 | default_factory=TrainingCheckpointingConfig 91 | ) 92 | evaluation: EvaluationCheckpointingConfig = field( 93 | default_factory=EvaluationCheckpointingConfig 94 | ) 95 | learning_dynamics: LearningDynamicsCheckpointingConfig = field( 96 | default_factory=LearningDynamicsCheckpointingConfig 97 | ) 98 | -------------------------------------------------------------------------------- /src/config/data_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data Config 3 | 4 | Specifies the hyperparameters for the dataset, dataloader, and tokenizer. 
5 | """ 6 | 7 | from dataclasses import dataclass, field 8 | 9 | from ._constants import BATCH_SIZE, VOCAB_SIZE 10 | 11 | 12 | @dataclass 13 | class DatasetConfig: 14 | # Defines the HuggingFace name of a dataset 15 | name: str = "pico-lm/pretokenized-dolma" 16 | 17 | 18 | @dataclass 19 | class DataLoaderConfig: 20 | # NOTE: You should only change these values jointly with the training config; so that the 21 | # sub-batch size is consistent with the gradient accumulation steps 22 | batch_size: int = BATCH_SIZE 23 | 24 | 25 | @dataclass 26 | class TokenizerConfig: 27 | # Specify a tokenizer to use 28 | name: str = "allenai/OLMo-7B-0724-hf" 29 | vocab_size: int = VOCAB_SIZE 30 | 31 | 32 | @dataclass 33 | class DataConfig: 34 | dataset: DatasetConfig = field(default_factory=DatasetConfig) 35 | dataloader: DataLoaderConfig = field(default_factory=DataLoaderConfig) 36 | tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig) 37 | -------------------------------------------------------------------------------- /src/config/evaluation_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluation Config 3 | 4 | Specifies the hyperparameters for the evaluation process, i.e. what metrics to compute, etc. 5 | """ 6 | 7 | from dataclasses import dataclass, field 8 | from typing import List, Optional 9 | 10 | from src.config._constants import MAX_SEQ_LEN 11 | 12 | 13 | @dataclass 14 | class PalomaEvaluationConfig: 15 | dataset_name: str = "pico-lm/pretokenized-paloma-tinsy" 16 | dataset_split: str = "val" 17 | max_length: int = MAX_SEQ_LEN 18 | batch_size: int = 16 19 | 20 | 21 | @dataclass 22 | class EvaluationConfig: 23 | # Evaluation metrics to compute: by default, we compute the perplexity of the model on the paloma dataset 24 | metrics: Optional[List[str]] = field(default_factory=lambda: ["paloma"]) 25 | 26 | # NOTE: Add other evaluation configs here 27 | # Each evaluation metric should have its own config 28 | paloma: PalomaEvaluationConfig = field(default_factory=PalomaEvaluationConfig) 29 | -------------------------------------------------------------------------------- /src/config/model_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model Config 3 | 4 | Specifies the hyperparameters for the Pico model/model architecture. 5 | """ 6 | 7 | from dataclasses import dataclass 8 | from typing import Optional 9 | 10 | from ._constants import BATCH_SIZE, MAX_SEQ_LEN, VOCAB_SIZE 11 | 12 | 13 | @dataclass 14 | class ModelConfig: 15 | model_type: str = "pico_decoder" 16 | 17 | # Pico Decoder default hyperparameters 18 | 19 | d_model: int = 768 20 | n_layers: int = 12 21 | 22 | vocab_size: int = VOCAB_SIZE 23 | batch_size: int = BATCH_SIZE 24 | max_seq_len: int = MAX_SEQ_LEN 25 | 26 | attention_n_heads: int = 12 27 | attention_n_kv_heads: Optional[int] = 4 28 | 29 | activation_hidden_dim: int = 3072 30 | 31 | norm_eps: float = 1e-6 32 | 33 | position_emb_theta: float = 10000.0 34 | -------------------------------------------------------------------------------- /src/config/monitoring_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Monitoring Config 3 | 4 | Specifies the monitoring process, e.g. how to log metrics and keep track of training progress. 
5 | """ 6 | 7 | from dataclasses import dataclass, field 8 | 9 | 10 | @dataclass 11 | class LoggingConfig: 12 | log_level: str = "INFO" 13 | log_every_n_steps: int = 100 14 | 15 | 16 | @dataclass 17 | class WandbConfig: 18 | # configure logging to Weights and Biases 19 | project: str = "" 20 | entity: str = "" 21 | 22 | 23 | @dataclass 24 | class MonitoringConfig: 25 | logging: LoggingConfig = field(default_factory=LoggingConfig) 26 | 27 | # Weights and Biases 28 | save_to_wandb: bool = False 29 | wandb: WandbConfig = field(default_factory=WandbConfig) 30 | -------------------------------------------------------------------------------- /src/config/training_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training Config 3 | 4 | Specifies the hyperparameters for the training process, i.e. the optimizer, learning rate, etc. 5 | """ 6 | 7 | from dataclasses import dataclass, field 8 | 9 | from ._constants import GRADIENT_ACCUMULATION_STEPS 10 | 11 | 12 | @dataclass 13 | class FabricConfig: 14 | # Configure nodes/devices for parallelised training 15 | num_nodes: int = 1 16 | num_devices: int = 1 17 | precision: str = "bf16-mixed" 18 | # Hardware accelerator to use, can be cpu/cuda/mps etc. 19 | accelerator: str = "cuda" 20 | 21 | 22 | @dataclass 23 | class OptimizationConfig: 24 | # Optimizer 25 | optimizer: str = "adamw" 26 | lr: float = 3e-4 27 | 28 | # Learning Rate Scheduler 29 | lr_scheduler: str = "linear_with_warmup" 30 | lr_warmup_steps: int = 2500 31 | 32 | # Define number of gradient accumulation steps 33 | gradient_accumulation_steps: int = GRADIENT_ACCUMULATION_STEPS 34 | 35 | 36 | @dataclass 37 | class TrainingConfig: 38 | fabric: FabricConfig = field(default_factory=FabricConfig) 39 | optimization: OptimizationConfig = field(default_factory=OptimizationConfig) 40 | max_steps: int = 200_000 41 | -------------------------------------------------------------------------------- /src/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pico Evaluation Package 3 | 4 | This package implements the evaluation pipeline for the Pico language model. It provides 5 | functionality to evaluate model performance using various metrics and handles the complete 6 | evaluation workflow. 7 | 8 | We recommend that each evaluation metric should have its own config, and should be 9 | implemented as a module in the `evaluation/tasks` directory that exposes a `run_` function. 10 | 11 | NOTE: Out of the box we only support Paloma, but the structure is designed to be flexible and 12 | you are meant to add whatever metrics you want. One of the main reasons we store out 13 | the model in the HuggingFace format is so that its easy to use third-party evaluation 14 | libraries/frameworks. 15 | """ 16 | 17 | import os 18 | 19 | import torch 20 | from lightning.fabric import Fabric 21 | from torch import nn 22 | 23 | from src.config import CheckpointingConfig, EvaluationConfig 24 | 25 | from .tasks.paloma import run_paloma_evaluation 26 | 27 | 28 | def run_evaluation( 29 | evaluation_config: EvaluationConfig, 30 | checkpointing_config: CheckpointingConfig, 31 | fabric: Fabric, 32 | model: nn.Module, 33 | ) -> None: 34 | """Run model evaluation using specified metrics in `evaluation_config`. 35 | 36 | This function orchestrates the complete evaluation pipeline by: 37 | 1. 
Resolving the model checkpoint path (either specified or latest) to load the model from; 38 | during training, this is the path to the latest checkpoint in the run directory. 39 | 2. Iterating over each evaluation metric, and running the corresponding evaluation function. 40 | NOTE: we suggest you follow the pattern of the Paloma evaluation function, and implement 41 | your own evaluation function for each metric in the `evaluation/tasks` directory. 42 | 3. Aggregating results across all metrics in a dictionary, and returning it. 43 | 44 | Args: 45 | evaluation_config (EvaluationConfig): Configuration object containing: 46 | - metrics (List[str]): Metrics to evaluate; each metric should have its 47 | own config. Currently supported: ["paloma"]; 48 | - paloma (PalomaConfig): Configuration for Paloma evaluation 49 | - max_length (int): Maximum sequence length 50 | - limit_eval_examples (Optional[int]): Number of examples to evaluate 51 | checkpointing_config (CheckpointingConfig): Configuration object containing: 52 | fabric (Fabric): Lightning Fabric instance 53 | model (nn.Module): Original model instance 54 | 55 | Returns: 56 | Dict[str, float]: Dictionary mapping metric names to their values 57 | Example: {"paloma": 3.45} 58 | 59 | Raises: 60 | ValueError: If an unsupported evaluation metric is requested 61 | 62 | Example: 63 | results = run_evaluation( 64 | EvaluationConfig( 65 | run_name="experiment_1", 66 | metrics=["paloma"], 67 | paloma=PalomaConfig(max_length=2048, batch_size=16) 68 | ) 69 | ) 70 | 71 | """ 72 | 73 | fabric.barrier() 74 | 75 | model.to("cpu") # Offloading model to CPU 76 | 77 | evaluation_results = {} 78 | 79 | # NOTE: Evaluation is only run on first processes to enable third-party evaluation libraries 80 | # to determine how to handle distributed evaluation. 81 | if fabric.global_rank == 0: 82 | run_name = checkpointing_config.run_name 83 | model_path = f"{os.getcwd()}/{checkpointing_config.runs_dir}/{run_name}/{checkpointing_config.checkpoints_dir}/latest" 84 | os.makedirs(model_path, exist_ok=True) 85 | 86 | for metric in evaluation_config.metrics: 87 | # NOTE: add your own metrics here 88 | if metric == "paloma": 89 | evaluation_result = run_paloma_evaluation( 90 | model_path, evaluation_config.paloma 91 | ) 92 | else: 93 | raise ValueError(f"Metric {metric} not supported") 94 | 95 | evaluation_results[metric] = evaluation_result 96 | 97 | torch.cuda.empty_cache() 98 | 99 | fabric.barrier() 100 | 101 | model.to(fabric.device) 102 | 103 | return evaluation_results 104 | -------------------------------------------------------------------------------- /src/evaluation/tasks/paloma.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paloma is a comprehensive evaluation benchmark for large language models (LLMs) that focuses 3 | on measuring perplexity across diverse text domains. 4 | 5 | To evaluate on Paloma, we use the huggingface evaluation framework. 6 | 7 | For more details, see: https://huggingface.co/datasets/allenai/paloma 8 | """ 9 | 10 | import evaluate 11 | from datasets import load_dataset 12 | from datasets.utils.logging import disable_progress_bar, enable_progress_bar 13 | 14 | from src.config.evaluation_config import PalomaEvaluationConfig 15 | 16 | 17 | def run_paloma_evaluation( 18 | model_path: str, 19 | paloma_config: PalomaEvaluationConfig, 20 | ) -> None: 21 | """Run Perplexity evaluation on the Paloma evaluation dataset. 
22 | 23 | We use the HuggingFace evaluate library to load in and compute the perplexity metric. 24 | 25 | Args: 26 | model_path (str): Path to the model checkpoint to be evaluated 27 | paloma_config (PalomaEvaluationConfig): Configuration for Paloma evaluation 28 | """ 29 | 30 | disable_progress_bar() 31 | 32 | # load custom evaluation space, see https://huggingface.co/spaces/pico-lm/perplexity 33 | perplexity = evaluate.load("pico-lm/perplexity") 34 | 35 | dataset = load_dataset( 36 | paloma_config.dataset_name, split=paloma_config.dataset_split 37 | )["text"] 38 | 39 | # compute perplexity score on Paloma dataset 40 | perplexity_result = perplexity.compute( 41 | model_id=model_path, 42 | predictions=dataset, 43 | add_start_token=False, 44 | max_length=paloma_config.max_length, 45 | batch_size=paloma_config.batch_size, 46 | trust_remote_code=True, 47 | ) 48 | 49 | mean_perplexity = perplexity_result["mean_perplexity"] 50 | 51 | enable_progress_bar() 52 | return mean_perplexity 53 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model Package 3 | 4 | This Package contains Pico models (currently only the Pico Decoder). We plan to implement other 5 | architectures in the future. 6 | 7 | If you have other models you'd like to implement, we recommend you add modules to this package. 8 | """ 9 | 10 | from .pico_decoder import PicoDecoder 11 | 12 | __all__ = ["PicoDecoder"] 13 | -------------------------------------------------------------------------------- /src/model/pico_decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pico Decoder: A Lightweight Causal Transformer Language Model 3 | 4 | Pico Decoder uses a simple LLAMA-style transformer architecture, written for clarity and educational purposes. 5 | 6 | Everything is written with a modular design for easy modification and experimentation. 7 | 8 | Key features: 9 | - RMSNorm for layer normalization 10 | - Rotary Positional Embeddings (RoPE) 11 | - Multi-head attention with KV-cache support 12 | - SwiGLU activation function 13 | - Residual connections throughout 14 | 15 | - KV-cache for faster autoregressive generation 16 | 17 | References: 18 | - RoPE: https://arxiv.org/abs/2104.09864 19 | - SwiGLU: https://arxiv.org/abs/2002.05202 20 | - LLAMA: https://arxiv.org/abs/2302.13971 21 | 22 | Adapted from: 23 | - OLMO: https://github.com/allenai/OLMo 24 | - LLAMA: https://github.com/meta/llama 25 | """ 26 | 27 | from dataclasses import asdict 28 | from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union 29 | 30 | import torch 31 | import torch.nn as nn 32 | import torch.nn.functional as F 33 | from torch.nn.attention import SDPBackend, sdpa_kernel 34 | from transformers import PretrainedConfig, PreTrainedModel 35 | from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast 36 | 37 | try: 38 | if TYPE_CHECKING: 39 | # We need to do this to avoid importing these when creating the HF-compatible models 40 | from src.config import ModelConfig 41 | except ImportError: 42 | pass 43 | 44 | ######################################################## 45 | # 46 | # Layer Normalization 47 | # 48 | ######################################################## 49 | 50 | 51 | class RMSNorm(torch.nn.Module): 52 | """Root Mean Square Layer Normalization. 
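
    In symbols (matching `_norm` and `forward` below, with the mean taken over the last
    dimension of x): y = x * rsqrt(mean(x**2) + eps) * weight.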
53 | 54 | A variant of Layer Normalization that uses RMS statistics instead of mean/variance, 55 | resulting in improved stability and performance. 56 | 57 | Args: 58 | config (Union[ModelConfig, PicoHFConfig]): Configuration object containing normalization parameters 59 | - config.norm_eps: Small constant for numerical stability 60 | - config.d_model: Model dimension for the weight parameter 61 | 62 | References: 63 | https://arxiv.org/abs/1910.07467 64 | """ 65 | 66 | def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]): 67 | super().__init__() 68 | self.eps = config.norm_eps 69 | self.weight = nn.Parameter(torch.ones(config.d_model)) 70 | 71 | def _norm(self, x: torch.Tensor) -> torch.Tensor: 72 | """ 73 | Normalizes the input tensor by its RMS value. 74 | """ 75 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 76 | 77 | def forward(self, x: torch.Tensor) -> torch.Tensor: 78 | """ 79 | Applies RMS normalization to the input tensor and scales it by the weight parameter. 80 | """ 81 | output = self._norm(x.float()).type_as(x) 82 | return output * self.weight 83 | 84 | 85 | ######################################################## 86 | # 87 | # Positional Embedding 88 | # 89 | ######################################################## 90 | 91 | 92 | class RoPE(nn.Module): 93 | """Rotary Positional Embeddings (RoPE). 94 | 95 | Implements position-dependent rotation of keys and queries in attention mechanism, 96 | allowing better modeling of relative positions in sequences. Uses complex number 97 | operations for efficient rotation. 98 | 99 | Args: 100 | config (Union[ModelConfig, PicoHFConfig]): Model configuration containing: 101 | - config.position_emb_theta: Base for frequency computation 102 | - config.d_model: Model dimension 103 | - config.attention_n_heads: Number of attention heads 104 | - config.max_seq_len: Maximum sequence length 105 | 106 | References: 107 | https://arxiv.org/abs/2104.09864 108 | """ 109 | 110 | _freqs_cis_tensor: torch.Tensor | None = None 111 | 112 | def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]): 113 | super().__init__() 114 | 115 | self.theta = config.position_emb_theta 116 | self.dim = config.d_model // config.attention_n_heads 117 | 118 | max_seq_len = config.max_seq_len 119 | 120 | # only gets set once, and then reused for all RoPE instances 121 | if RoPE._freqs_cis_tensor is None: 122 | RoPE._freqs_cis_tensor = self._setup_freqs_cis( 123 | max_seq_len, self.theta, self.dim 124 | ) 125 | 126 | # register _freqs_cis buffer 127 | # can be easily recomputed so persistent=False 128 | self.register_buffer("_freqs_cis", self._freqs_cis_tensor, persistent=False) 129 | 130 | @classmethod 131 | def _setup_freqs_cis(cls, seq_len: int, theta: float, dim: int) -> torch.Tensor: 132 | """Setup Frequency Tensor for RoPE Embeddings 133 | 134 | Initializes the complex frequency tensor that is used to compute the RoPE embeddings. 135 | 136 | Note other implementations will use cos and sin directly, but using the complex 137 | number representation is (probably?) 
more efficient: 138 | 139 | e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula] 140 | """ 141 | _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 142 | positions = torch.arange(seq_len) 143 | freqs = torch.outer(positions, _freqs) 144 | return torch.polar(torch.ones_like(freqs), freqs) # complex64 145 | 146 | def get_freqs_cis( 147 | self, input_shape: torch.Size, start_pos: int, end_pos: int 148 | ) -> torch.Tensor: 149 | """Reshape Frequency Tensor for RoPE Embeddings 150 | 151 | Makes the frequency tensor broadcastable with the input tensor. 152 | """ 153 | _freqs_cis = self._freqs_cis[start_pos:end_pos] 154 | ndim = len(input_shape) 155 | assert 0 <= 1 < ndim 156 | assert _freqs_cis.shape == (input_shape[1], input_shape[-1]) 157 | 158 | # TODO: Check whether this is correct (might be able to remove this) 159 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)] 160 | return _freqs_cis.view(*shape) 161 | 162 | def forward( 163 | self, 164 | queries: torch.Tensor, 165 | keys: torch.Tensor, 166 | start_pos: int = 0, 167 | ) -> Tuple[torch.Tensor, torch.Tensor]: 168 | """Apply RoPE Embeddings to Queries and Keys 169 | 170 | Applies the rotary positional embeddings to the input tensors via complex num multiplication 171 | 172 | NOTE: The start_pos is used if we want to use the kv_cache in the attention mechanism. 173 | """ 174 | queries_ = torch.view_as_complex( 175 | queries.float().reshape(*queries.shape[:-1], -1, 2) 176 | ) 177 | keys_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2)) 178 | 179 | input_shape = ( 180 | queries_.shape 181 | ) # same as keys: (batch_size, seq_len, n_heads, head_dim/2) 182 | freqs_start_pos = start_pos 183 | freqs_end_pos = freqs_start_pos + queries_.shape[1] 184 | 185 | freqs_cis = self.get_freqs_cis(input_shape, freqs_start_pos, freqs_end_pos) 186 | 187 | queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3) 188 | keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3) 189 | return queries_rotated.type_as(queries), keys_rotated.type_as(keys) 190 | 191 | 192 | ######################################################## 193 | # 194 | # Attention 195 | # 196 | ######################################################## 197 | 198 | 199 | class Attention(nn.Module): 200 | """Multi-head Attention with Group Query Attention support. 
201 | 202 | Implements scaled dot-product attention and supports: 203 | - Grouped Query Attention (GQA) 204 | - Key-Value caching for efficient inference 205 | - RoPE integration 206 | 207 | Args: 208 | config (Union[ModelConfig, PretrainedConfig]): Configuration containing: 209 | - config.attention_n_heads: Number of attention heads 210 | - config.attention_n_kv_heads: Number of key/value heads 211 | - config.d_model: Model dimension 212 | - config.batch_size: Maximum batch size 213 | - config.max_seq_len: Maximum sequence length 214 | 215 | Shape: 216 | - Input: (batch_size, seq_len, d_model) 217 | - Output: (batch_size, seq_len, d_model) 218 | """ 219 | 220 | def __init__( 221 | self, 222 | config: Union["ModelConfig", "PicoDecoderHFConfig"], 223 | ): 224 | super().__init__() 225 | 226 | self.n_heads = config.attention_n_heads 227 | self.n_kv_heads = config.attention_n_kv_heads 228 | 229 | self.batch_size = config.batch_size 230 | self.max_seq_len = config.max_seq_len 231 | 232 | d_model = config.d_model 233 | self.head_dim = d_model // self.n_heads 234 | 235 | self.n_rep = self.n_heads // self.n_kv_heads 236 | 237 | self.q_proj = nn.Linear(d_model, self.n_heads * self.head_dim, bias=False) 238 | self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False) 239 | self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False) 240 | self.o_proj = nn.Linear(self.n_heads * self.head_dim, d_model, bias=False) 241 | 242 | self.rope = RoPE(config) 243 | 244 | def forward( 245 | self, 246 | input: torch.Tensor, 247 | mask: Optional[torch.Tensor] = None, 248 | past_key_values: Optional[Tuple[torch.Tensor, ...]] = None, 249 | use_cache: bool = False, 250 | ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: 251 | """Forward pass for the attention mechanism. 252 | 253 | Computes queries, keys, and values for the attention mechanism. Applies rotary positional 254 | embeddings to the queries and keys, and then computes attention scores and outputs. 255 | 256 | For an introduction to the attention mechanism, see: 257 | https://arxiv.org/abs/1706.03762 258 | 259 | A few things to note: 260 | - The past_key_values is used to implement the KV cache, which is used to speed up 261 | generation by caching the KV pairs from previous forward passes. This is useful when doing 262 | tasks that require generating multiple tokens conditioned on previous tokens (e.g. language 263 | modeling, text generation, etc.). The way the KV cache is implemented is that each layer has 264 | its own KV cache - this KV cache is implemented as a tuple. 265 | """ 266 | bsz, seq_len, _ = input.shape 267 | _queries, _keys, _values = ( 268 | self.q_proj(input), 269 | self.k_proj(input), 270 | self.v_proj(input), 271 | ) 272 | 273 | # Reshaping for multi-head attention 274 | queries = _queries.view(bsz, seq_len, self.n_heads, self.head_dim) 275 | keys = _keys.view(bsz, seq_len, self.n_kv_heads, self.head_dim) 276 | values = _values.view(bsz, seq_len, self.n_kv_heads, self.head_dim) 277 | 278 | # The start position is used to apply the RoPE embeddings to only the new tokens 279 | # when using the kv_cache in the attention mechanism. 280 | # We want to start from the last position in the cache. 
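        # Worked example (illustrative shapes): if the cache already holds 10 tokens, then
        # past_key_values[0] has shape (batch_size, 10, n_kv_heads, head_dim), so start_pos
        # becomes 10 and RoPE rotates the new tokens at positions 10 .. 10 + seq_len - 1.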
281 | start_pos = past_key_values[0].shape[1] if past_key_values is not None else 0 282 | 283 | # apply rotary positional embeddings 284 | queries, keys = self.rope(queries, keys, start_pos) 285 | 286 | if past_key_values is not None: 287 | keys = torch.cat([past_key_values[0], keys], dim=1) 288 | values = torch.cat([past_key_values[1], values], dim=1) 289 | 290 | if use_cache: 291 | cached_keys = keys 292 | cached_values = values 293 | else: 294 | cached_keys = None 295 | cached_values = None 296 | 297 | queries = queries.transpose(1, 2) 298 | keys = keys.transpose(1, 2) 299 | values = values.transpose(1, 2) 300 | 301 | apply_gqa = self.n_rep > 1 302 | if apply_gqa and queries.device.type == "mps": 303 | # NOTE: MPS does not support GQA in the SDPA kernel, but we can repeat the keys and values 304 | # outside of the kernel to get the same effect. 305 | # See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html 306 | keys = keys.repeat_interleave(self.n_rep, dim=-3) 307 | values = values.repeat_interleave(self.n_rep, dim=-3) 308 | apply_gqa = False 309 | 310 | backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH] 311 | 312 | with sdpa_kernel(backends=backends): 313 | attn_output = F.scaled_dot_product_attention( 314 | queries.contiguous(), 315 | keys.contiguous(), 316 | values.contiguous(), 317 | attn_mask=mask.to(queries.dtype), 318 | enable_gqa=apply_gqa, 319 | ) 320 | 321 | attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1) 322 | output = self.o_proj(attn_output) 323 | 324 | return output, (cached_keys, cached_values) 325 | 326 | 327 | ######################################################## 328 | # 329 | # SwiGLU (Combines MLP and Activation) 330 | # 331 | ######################################################## 332 | 333 | 334 | class SwiGLU(nn.Module): 335 | """SwiGLU Activation Function with Linear Projections. 336 | 337 | Implements the SwiGLU activation function combined with linear transformations, 338 | serving as the feed-forward network in transformer blocks. 339 | 340 | Args: 341 | config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration containing: 342 | - config.d_model: Model dimension 343 | - config.activation_hidden_dim: Hidden dimension (typically 4 * d_model) 344 | 345 | References: 346 | https://arxiv.org/abs/2002.05202 347 | """ 348 | 349 | def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]): 350 | super().__init__() 351 | 352 | model_dim = config.d_model 353 | act_hidden_dim = config.activation_hidden_dim # usually 4 * d_model 354 | 355 | self.w_0 = nn.Linear(model_dim, act_hidden_dim, bias=False) 356 | self.w_1 = nn.Linear(model_dim, act_hidden_dim, bias=False) 357 | self.w_2 = nn.Linear(act_hidden_dim, model_dim, bias=False) 358 | 359 | def forward(self, x: torch.Tensor) -> torch.Tensor: 360 | return self.w_2(F.silu(self.w_0(x)) * self.w_1(x)) 361 | 362 | 363 | ######################################################## 364 | # 365 | # PicoDecoderBlock 366 | # 367 | ######################################################## 368 | 369 | 370 | class PicoDecoderBlock(nn.Module): 371 | """Single Transformer Block with Attention and Feed-forward layers. 
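
    In pseudo-code (matching `forward` below):
    h = x + attention(attention_norm(x)); out = h + swiglu(swiglu_norm(h)).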
372 | 373 | Implements a standard transformer block with: 374 | - Multi-head attention with normalization and residual connection 375 | - SwiGLU feed-forward network with normalization and residual connection 376 | 377 | Args: 378 | config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration; either a dataclass or 379 | a HuggingFace PicoDecoderHFConfig 380 | """ 381 | 382 | def __init__( 383 | self, 384 | config: Union["ModelConfig", "PicoDecoderHFConfig"], 385 | ): 386 | super().__init__() 387 | 388 | self.attention = Attention(config) 389 | self.swiglu = SwiGLU(config) 390 | self.attention_norm = RMSNorm(config) 391 | self.swiglu_norm = RMSNorm(config) 392 | 393 | def forward( 394 | self, 395 | input: torch.Tensor, 396 | mask: Optional[torch.Tensor] = None, 397 | past_key_values: Optional[Tuple[torch.Tensor]] = None, 398 | use_cache: bool = False, 399 | ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: 400 | attention_output, cached_key_values = self.attention( 401 | self.attention_norm(input), 402 | mask=mask, 403 | past_key_values=past_key_values, 404 | use_cache=use_cache, 405 | ) 406 | # NOTE: cached_key_values is None if use_cache is False 407 | 408 | h = input + attention_output 409 | out = h + self.swiglu(self.swiglu_norm(h)) 410 | return out, cached_key_values 411 | 412 | 413 | ######################################################## 414 | # 415 | # Pico Decoder (Causal Transformer Model) 416 | # 417 | ######################################################## 418 | 419 | 420 | class PicoDecoder(nn.Module): 421 | """ 422 | Pico Decoder: combines the embedding, causal decoder blocks, and output projection into a 423 | single autoregressive model. 424 | 425 | For more information on the model, see the classes for the modules that make up the model. 426 | """ 427 | 428 | def __init__( 429 | self, 430 | model_config: Union["ModelConfig", "PicoDecoderHFConfig"], 431 | ): 432 | super().__init__() 433 | self.config = model_config 434 | 435 | self.embedding_proj = nn.Embedding(self.config.vocab_size, self.config.d_model) 436 | self.layers = nn.ModuleList( 437 | [PicoDecoderBlock(self.config) for _ in range(self.config.n_layers)] 438 | ) 439 | self.output_norm = RMSNorm(self.config) 440 | self.de_embedding_proj = nn.Linear( 441 | self.config.d_model, self.config.vocab_size, bias=False 442 | ) 443 | 444 | def convert_to_hf_model(self) -> "PicoDecoderHF": 445 | """Convert the Lightning model to a HuggingFace model.""" 446 | # Create HF config without fabric-specific settings 447 | hf_config = PicoDecoderHFConfig.from_dataclass(self.config) 448 | 449 | # Create new HF model 450 | hf_model = PicoDecoderHF(hf_config) 451 | 452 | # Copy state dict, excluding fabric-specific keys 453 | hf_model.load_state_dict(self.state_dict(prefix="pico_decoder.")) 454 | 455 | return hf_model 456 | 457 | def forward( 458 | self, 459 | input_ids: torch.Tensor, 460 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 461 | use_cache: bool = False, 462 | ) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]]]: 463 | """ 464 | This is the forward pass for the entire Pico model. It boils down to: 465 | - Embedding the input ids 466 | - Creating a causal mask 467 | - Processing through the pico layers 468 | - Projecting the output to logits 469 | 470 | NOTE: One feature that might be confusing is the KV cache. The KV cache is used to speed up 471 | generation by caching the KV pairs from previous forward passes. 
This is useful when doing 472 | tasks that require generating multiple tokens conditioned on previous tokens (e.g. language 473 | modeling, text generation, etc.). The way the KV cache is implemented is that each layer has 474 | its own KV cache which is stored as a tuple. The whole model then stores a tuple of these 475 | KV caches (so a tuple of tuples). 476 | """ 477 | 478 | seq_len = input_ids.shape[-1] 479 | h = self.embedding_proj(input_ids) 480 | 481 | # Calculate start position from past cached KV pairs. Remember that each layer has its 482 | # own KV Cache. So when we index past_key_values, we need to index into the KV pairs for the 483 | # correct layer and then for either the keys or values. 484 | start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1] 485 | 486 | # Create causal mask for current sequence 487 | mask = None 488 | if seq_len > 1: 489 | mask = torch.full((seq_len, seq_len), float("-inf")) 490 | mask = torch.triu(mask, diagonal=1) 491 | 492 | # If using KV cache, extend mask to cover cached sequence length 493 | if past_key_values is not None: 494 | # Add zeros for cached tokens (we can attend to all of them) 495 | mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask]) 496 | 497 | mask = mask.to(h.device) 498 | 499 | # NOTE: If we are using the cache, we need to store the cached KV pairs for each layer 500 | # in a tuple. Each layer will have its own cached KV pair which we aggregate in a tuple. 501 | cached_key_values = () if use_cache else None 502 | 503 | # Process through transformer blocks 504 | for idx, layer in enumerate(self.layers): 505 | layer_past_key_values = ( 506 | past_key_values[idx] if past_key_values is not None else None 507 | ) 508 | 509 | h, layer_cached_key_values = layer( 510 | h, mask=mask, past_key_values=layer_past_key_values, use_cache=use_cache 511 | ) 512 | 513 | if use_cache: 514 | cached_key_values += (layer_cached_key_values,) 515 | 516 | # Final norm and projection 517 | h = self.output_norm(h) 518 | logits = self.de_embedding_proj(h).float() 519 | 520 | return logits, cached_key_values 521 | 522 | 523 | ######################################################## 524 | # 525 | # HuggingFace Wrapper for the Pico Decoder model. 526 | # 527 | ######################################################## 528 | 529 | 530 | class PicoDecoderHFConfig(PretrainedConfig): 531 | """Config class for the Pico Decoder HuggingFace wrapper.""" 532 | 533 | model_type = "pico_decoder" 534 | 535 | @classmethod 536 | def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig": 537 | """ 538 | Initialize config from a dictionary. Note that no kwargs are passed to the constructor -- 539 | this is because with some kwargs special handling is required and can make this class 540 | brittle. 541 | """ 542 | pico_config = cls(**config_dict) 543 | 544 | return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) 545 | unused_kwargs = { 546 | key: value for key, value in kwargs.items() if not hasattr(pico_config, key) 547 | } 548 | 549 | if return_unused_kwargs: 550 | return pico_config, unused_kwargs 551 | return pico_config 552 | 553 | @classmethod 554 | def from_dataclass(cls, model_config: "ModelConfig"): 555 | """Initialise from our custom config dataclass.""" 556 | return cls.from_dict(asdict(model_config)) 557 | 558 | 559 | class PicoDecoderHF(PreTrainedModel): 560 | """ 561 | HuggingFace wrapper for the Pico model. 
562 | 
563 |     Many evaluation frameworks require a model to be set up as a HuggingFace model, so we provide a simple
564 |     wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
565 |     Pico model as well as the model wrapped in this HuggingFace class.
566 | 
567 |     This also lets you do cool things like:
568 | 
569 |     `model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")`
570 |     """
571 | 
572 |     config_class = PicoDecoderHFConfig
573 |     _no_split_modules = ["PicoDecoderBlock", "Attention", "SwiGLU", "RMSNorm"]
574 | 
575 |     def __init__(self, config: PicoDecoderHFConfig):
576 |         super().__init__(config)
577 |         self.pico_decoder = PicoDecoder(config)
578 | 
579 |     def forward(
580 |         self,
581 |         input_ids: torch.Tensor,
582 |         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
583 |         use_cache: bool = False,
584 |         **kwargs,
585 |     ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
586 |         """HuggingFace forward pass wrapper.
587 | 
588 |         Forward pass for the HuggingFace version of the Pico model. Basic wrapper around the
589 |         Pico model's forward pass that returns the output as a HuggingFace CausalLMOutput.
590 |         """
591 |         logits, past_key_values = self.pico_decoder(
592 |             input_ids, past_key_values, use_cache
593 |         )
594 |         if use_cache:
595 |             return CausalLMOutputWithPast(
596 |                 logits=logits,
597 |                 past_key_values=past_key_values,
598 |             )
599 |         else:
600 |             return CausalLMOutput(
601 |                 logits=logits,
602 |             )
603 | 
604 | 
605 | # Register for auto classes
606 | PicoDecoderHFConfig.register_for_auto_class()
607 | PicoDecoderHF.register_for_auto_class("AutoModel")
608 | PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
609 | 
-------------------------------------------------------------------------------- /src/training/trainer.py: --------------------------------------------------------------------------------
1 | """
2 | Pico Language Model Trainer
3 | 
4 | This Trainer implements a minimalistic end-to-end training pipeline of the Pico language model with
5 | distributed training support via Lightning Fabric.
It provides a modular and configurable training 6 | pipeline with the features: 7 | 8 | - Configuration Management: YAML-based configuration for all aspects of training 9 | - Distributed Training: Multi-GPU support via Lightning Fabric 10 | - Checkpointing: Regular model saving and training state recovery 11 | - Evaluation: Periodic model evaluation on validation datasets 12 | - Logging: Comprehensive metric tracking and experiment monitoring 13 | - Optimization: Support for gradient accumulation, clipping, and LR scheduling 14 | """ 15 | 16 | import logging 17 | import os 18 | import platform 19 | from typing import Any, Dict 20 | 21 | import lightning as L 22 | import psutil 23 | import torch 24 | import torch.nn.functional as F 25 | import yaml 26 | from datasets import Dataset, load_dataset 27 | from lightning.fabric.utilities.rank_zero import rank_zero_only 28 | 29 | from src.checkpointing import ( 30 | compute_learning_dynamics_states, 31 | load_checkpoint, 32 | save_checkpoint, 33 | save_evaluation_results, 34 | save_learning_dynamics_states, 35 | ) 36 | from src.evaluation import run_evaluation 37 | from src.training.utils import ( 38 | initialize_configuration, 39 | initialize_dataloader, 40 | initialize_dataset, 41 | initialize_fabric, 42 | initialize_hf_checkpointing, 43 | initialize_logging, 44 | initialize_lr_scheduler, 45 | initialize_model, 46 | initialize_optimizer, 47 | initialize_run_dir, 48 | initialize_tokenizer, 49 | initialize_wandb, 50 | ) 51 | from src.training.utils.logging import pretty_print_yaml_config 52 | 53 | 54 | class Trainer: 55 | def __init__(self, config_path: str): 56 | """ 57 | Initializes the Trainer class. This Trainer class implements a `train` method, which is the 58 | main entry point for training the Pico model. Before calling `train`, the Trainer class 59 | initializes the following: 60 | 61 | - Configuration loading and validation 62 | - Model, optimizer, and dataset setup 63 | - Logging and experiment tracking setup 64 | - Checkpoint management 65 | 66 | Args: 67 | config_path (str): Path to the YAML configuration file containing any overrides. 68 | """ 69 | 70 | ######################################################## 71 | # 72 | # Basic Initialization of Configs, Fabric, Model, Optimizer, etc. 73 | # 74 | ######################################################## 75 | 76 | # Setup Config 77 | self.configs = initialize_configuration(config_path) 78 | 79 | # Setup Run Directory (i.e. where we store checkpoints, logs, etc.) 
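        # With the default checkpointing constants this resolves to a layout along the lines of
        # (illustrative run name):
        #   runs/<run_name>/
        #   ├── training_config.yaml
        #   ├── logs/
        #   └── checkpoints/
        #       ├── step_{N}/
        #       └── latest -> step_{N}/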
80 | initialize_run_dir(checkpointing_config=self.configs["checkpointing"]) 81 | 82 | # Setup Logger 83 | if self.configs["monitoring"].save_to_wandb: 84 | wandb_logger = initialize_wandb( 85 | monitoring_config=self.configs["monitoring"], 86 | checkpointing_config=self.configs["checkpointing"], 87 | ) 88 | else: 89 | wandb_logger = None 90 | 91 | # Setup Fabric 92 | self.fabric = initialize_fabric( 93 | training_config=self.configs["training"], 94 | wandb_logger=wandb_logger, 95 | ) 96 | L.seed_everything(42, verbose=False) 97 | 98 | # Set up logging 99 | self.logger = initialize_logging( 100 | monitoring_config=self.configs["monitoring"], 101 | checkpointing_config=self.configs["checkpointing"], 102 | fabric=self.fabric, 103 | ) 104 | 105 | # Setup Model, Optimizer, and Dataloaders 106 | self.model = initialize_model(model_config=self.configs["model"]) 107 | self.optimizer = initialize_optimizer( 108 | training_config=self.configs["training"], model=self.model 109 | ) 110 | self.lr_scheduler = initialize_lr_scheduler( 111 | training_config=self.configs["training"], optimizer=self.optimizer 112 | ) 113 | 114 | # Wrap model and optimizer with Fabric 115 | self.model, self.optimizer = self.fabric.setup(self.model, self.optimizer) 116 | 117 | # Setup HuggingFace Checkpointing 118 | if self.configs["checkpointing"].save_to_hf: 119 | initialize_hf_checkpointing( 120 | checkpointing_config=self.configs["checkpointing"], fabric=self.fabric 121 | ) 122 | 123 | ######################################################## 124 | # 125 | # Boilerplate to deal with loading/resuming from checkpoints 126 | # 127 | ######################################################## 128 | 129 | self.should_load_checkpoint = self.configs["checkpointing"].training.auto_resume 130 | 131 | # Possibly load a checkpoint 132 | if self.should_load_checkpoint: 133 | resume_checkpoint = load_checkpoint( 134 | checkpointing_config=self.configs["checkpointing"], 135 | checkpoint_step="latest", 136 | fabric=self.fabric, 137 | model=self.model, 138 | optimizer=self.optimizer, 139 | lr_scheduler=self.lr_scheduler, 140 | ) 141 | 142 | if resume_checkpoint: 143 | ( 144 | self.model, 145 | self.optimizer, 146 | self.lr_scheduler, 147 | self.initial_batch_step, 148 | ) = resume_checkpoint 149 | else: 150 | self.initial_batch_step = 0 151 | else: 152 | self.initial_batch_step = 0 153 | 154 | ######################################################## 155 | # 156 | # Initialization of Dataset & DataLoader (possibly fast-forwarding to correct batch) 157 | # 158 | ######################################################## 159 | 160 | self.train_dataset, fast_forward_steps = initialize_dataset( 161 | data_config=self.configs["data"], 162 | fabric=self.fabric, 163 | initial_batch_step=self.initial_batch_step, 164 | return_fast_forward_steps=True, 165 | ) 166 | 167 | self.train_dataloader = initialize_dataloader( 168 | data_config=self.configs["data"], 169 | training_config=self.configs["training"], 170 | fabric=self.fabric, 171 | dataset=self.train_dataset, 172 | ) 173 | self.train_dataloader = self.fabric.setup_dataloaders( 174 | self.train_dataloader, use_distributed_sampler=False 175 | ) 176 | 177 | self.tokenizer = initialize_tokenizer(data_config=self.configs["data"]) 178 | 179 | # NOTE: We may need to fast-forward the iterator to the correct step so that we can 180 | # continue from the correct batch of data we would have seen had training not 181 | # previously stopped. 
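        # For example (illustrative numbers): if initialize_dataset reports 100 batch steps to
        # fast-forward and gradient_accumulation_steps is 128, the loop below advances the
        # iterator past 100 * 128 = 12,800 sub-batches before training resumes.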
182 | train_iterator = iter(self.train_dataloader) 183 | if fast_forward_steps > 0: 184 | fast_forward_sub_steps = ( 185 | fast_forward_steps 186 | * self.configs["training"].optimization.gradient_accumulation_steps 187 | ) 188 | for _ in range(fast_forward_sub_steps): 189 | next(train_iterator) 190 | 191 | self.train_iterator = train_iterator 192 | 193 | # NOTE: Synchronizing processes after fast-forwarding iterator 194 | self.fabric.barrier() 195 | 196 | ######################################################## 197 | # 198 | # Helper flags used during training for checkpointing and evaluation 199 | # 200 | ######################################################## 201 | 202 | # Helper flag to determine if we should evaluate the model 203 | self.should_evaluate = ( 204 | self.configs["evaluation"].metrics is not None 205 | and len(self.configs["evaluation"].metrics) > 0 206 | ) 207 | 208 | self.should_compute_learning_dynamics = ( 209 | self.configs["checkpointing"].learning_dynamics.layer_suffixes is not None 210 | and len(self.configs["checkpointing"].learning_dynamics.layer_suffixes) > 0 211 | ) 212 | 213 | if self.should_compute_learning_dynamics: 214 | if self.configs["checkpointing"].learning_dynamics.eval_data is not None: 215 | self.learning_dynamics_eval_dataset = load_dataset( 216 | self.configs["checkpointing"].learning_dynamics.eval_data, 217 | split="val", 218 | ) 219 | else: 220 | self.learning_dynamics_eval_dataset = None 221 | 222 | def train(self) -> None: 223 | """Execute the main training pipeline. 224 | 225 | This method orchestrates the complete training process by: 226 | 1. Creating an initial checkpoint to save the starting state and evaluate the model as a 227 | baseline 228 | 2. Running the main training loop via `_training_loop` 229 | 3. Handling final checkpointing and evaluation 230 | 231 | The training progress is tracked through checkpoints and evaluations 232 | at intervals specified in the configuration.
233 | """ 234 | 235 | ######################################################## 236 | # 237 | # Initial Checkpointing and Evaluation 238 | # 239 | ######################################################## 240 | 241 | # Save Initial Checkpoint -- If the checkpoint already exists, this performs a no-op 242 | save_checkpoint( 243 | configs=self.configs, 244 | checkpoint_step=self.initial_batch_step, 245 | fabric=self.fabric, 246 | model=self.model, 247 | optimizer=self.optimizer, 248 | lr_scheduler=self.lr_scheduler, 249 | tokenizer=self.tokenizer, 250 | ) 251 | 252 | # Save Initial Evaluation Results 253 | if self.should_evaluate: 254 | if self.initial_batch_step == 0: 255 | evaluation_results = run_evaluation( 256 | evaluation_config=self.configs["evaluation"], 257 | checkpointing_config=self.configs["checkpointing"], 258 | fabric=self.fabric, 259 | model=self.model, 260 | ) 261 | self._log_evaluation_results( 262 | evaluation_results, self.initial_batch_step 263 | ) 264 | save_evaluation_results( 265 | checkpointing_config=self.configs["checkpointing"], 266 | fabric=self.fabric, 267 | evaluation_results=evaluation_results, 268 | checkpoint_step=self.initial_batch_step, 269 | ) 270 | else: 271 | # NOTE: If the run crashed while evaluating, we need to restart the evaluation 272 | eval_results_path = os.path.join( 273 | self.configs["checkpointing"].evaluation.eval_results_dir, 274 | f"step_{self.initial_batch_step}.json", 275 | ) 276 | if not os.path.exists(eval_results_path): 277 | evaluation_results = run_evaluation( 278 | evaluation_config=self.configs["evaluation"], 279 | checkpointing_config=self.configs["checkpointing"], 280 | fabric=self.fabric, 281 | model=self.model, 282 | ) 283 | self._log_evaluation_results( 284 | evaluation_results, self.initial_batch_step 285 | ) 286 | save_evaluation_results( 287 | checkpointing_config=self.configs["checkpointing"], 288 | fabric=self.fabric, 289 | evaluation_results=evaluation_results, 290 | checkpoint_step=self.initial_batch_step, 291 | ) 292 | 293 | ######################################################## 294 | # 295 | # Main Training Loop (see `_training_loop` for details) 296 | # 297 | ######################################################## 298 | 299 | if self.initial_batch_step < self.configs["training"].max_steps: 300 | self._log_training_configuration() 301 | final_step = self._training_loop() 302 | else: 303 | final_step = self.initial_batch_step 304 | 305 | ######################################################## 306 | # 307 | # Final Checkpointing and Evaluation 308 | # 309 | ######################################################## 310 | 311 | # Save Learning Dynamics States 312 | if self.should_compute_learning_dynamics: 313 | if self.learning_dynamics_eval_dataset is not None: 314 | self.log(f"Step {final_step} -- 📈 Saving Learning Dynamics") 315 | learning_dynamics_val_states = compute_learning_dynamics_states( 316 | checkpointing_config=self.configs["checkpointing"], 317 | fabric=self.fabric, 318 | model=self.model, 319 | dataset=self.learning_dynamics_eval_dataset, 320 | compute_gradients=True, 321 | ) 322 | save_learning_dynamics_states( 323 | checkpointing_config=self.configs["checkpointing"], 324 | fabric=self.fabric, 325 | learning_dynamics_states=learning_dynamics_val_states, 326 | checkpoint_step=final_step, 327 | prefix="val", 328 | ) 329 | 330 | # Handle checkpointing and final evaluation 331 | if final_step % self.configs["checkpointing"].save_every_n_steps != 0: 332 | self.log(f"Step {final_step} -- 💾 Saving Final 
Checkpoint") 333 | save_checkpoint( 334 | configs=self.configs, 335 | checkpoint_step=final_step, 336 | fabric=self.fabric, 337 | model=self.model, 338 | optimizer=self.optimizer, 339 | lr_scheduler=self.lr_scheduler, 340 | tokenizer=self.tokenizer, 341 | ) 342 | 343 | # Final evaluation 344 | if self.should_evaluate: 345 | evaluation_results = run_evaluation( 346 | evaluation_config=self.configs["evaluation"], 347 | checkpointing_config=self.configs["checkpointing"], 348 | fabric=self.fabric, 349 | model=self.model, 350 | ) 351 | self._log_evaluation_results(evaluation_results, final_step) 352 | save_evaluation_results( 353 | checkpointing_config=self.configs["checkpointing"], 354 | checkpoint_step=final_step, 355 | fabric=self.fabric, 356 | evaluation_results=evaluation_results, 357 | ) 358 | 359 | self.log(f"🎉 Training complete! Final step: {final_step}") 360 | 361 | if final_step < self.configs["training"].max_steps: 362 | self.log( 363 | f"\t Note: Training stopped before max steps ({self.configs['training'].max_steps})", 364 | level=logging.WARNING, 365 | ) 366 | 367 | # Cleanup distributed training 368 | self.fabric.barrier() 369 | if torch.cuda.is_available(): 370 | torch.cuda.empty_cache() 371 | if torch.distributed.is_initialized(): 372 | torch.distributed.destroy_process_group() 373 | 374 | del self.train_dataloader # NOTE: shutting down worker nodes 375 | 376 | self.fabric.barrier() 377 | 378 | def _training_loop(self) -> int: 379 | """Execute the main training loop. 380 | 381 | This method orchestrates the core training loop and includes the following features: 382 | - Gradient accumulation 383 | - Gradient clipping 384 | - Periodic model evaluation and checkpointing 385 | - Learning Dynamics Checkpointing 386 | - Learning rate scheduling 387 | - Logging of training metrics including loss and learning rate 388 | - Handling of infinite/NaN losses 389 | 390 | Returns: 391 | int: The final step count reached during training. 392 | NOTE: A complete training run should match the configured max_steps. 393 | """ 394 | # Setup training loop variables 395 | batch_step = self.initial_batch_step 396 | 397 | # NOTE: these are used to compute the average loss over a training interval. 398 | # This is more accurate than using the loss at the end of the interval. 399 | interval_loss = torch.tensor(0.0, device=self.fabric.device) 400 | interval_steps = torch.tensor(0, device=self.fabric.device) 401 | interval_inf_or_nan_count = torch.tensor(0, device=self.fabric.device) 402 | 403 | if self.should_compute_learning_dynamics: 404 | # NOTE: we basically re-construct the full batch here so that we can compute learning dynamics 405 | training_batch = {"input_ids": []} 406 | 407 | # NOTE: determine what sub-batch we should start from 408 | initial_sub_batch_step = ( 409 | batch_step 410 | * self.configs["training"].optimization.gradient_accumulation_steps 411 | ) 412 | 413 | ############################################################### 414 | # 415 | # Core loop starts here 416 | # NOTE: the ratio between sub_batch_step and batch_step 417 | # is the configured number of gradient_accumulation_steps 418 | # i.e. 
with 32 configured gradient accumulation steps, 419 | # there are 32 sub_batch_steps for each batch_step 420 | # 421 | ############################################################### 422 | 423 | for sub_batch_step, sub_batch in enumerate( 424 | self.train_iterator, start=initial_sub_batch_step 425 | ): 426 | # NOTE: We want to store the entire training batch whenever we are computing learning dynamics 427 | # and we are at a checkpointing step. 428 | should_store_training_batch = self.should_compute_learning_dynamics and ( 429 | batch_step % self.configs["checkpointing"].save_every_n_steps == 0 430 | ) 431 | 432 | ######################################################## 433 | # 434 | # Forward Pass 435 | # 436 | ######################################################## 437 | 438 | _input_ids = torch.tensor(sub_batch["input_ids"], device=self.fabric.device) 439 | input_ids = _input_ids[:, :-1] 440 | labels = _input_ids[:, 1:] 441 | 442 | if should_store_training_batch: 443 | gathered_input_ids = self.fabric.all_gather(_input_ids) 444 | 445 | # NOTE: On multi-GPU, we need to reshape the input_ids to be a 2D tensor; on 446 | # a single GPU, the input_ids are already a 2D tensor. 447 | if self.fabric.world_size > 1: 448 | gathered_input_ids = gathered_input_ids.reshape( 449 | -1, *gathered_input_ids.shape[2:] 450 | ) 451 | 452 | training_batch["input_ids"].extend(gathered_input_ids.tolist()) 453 | 454 | # Forward pass 455 | model_output, _ = self.model(input_ids) 456 | model_output = model_output.transpose(1, 2) 457 | 458 | ######################################################## 459 | # 460 | # Gradient accumulation 461 | # 462 | ######################################################## 463 | 464 | should_accumulate_gradients = (sub_batch_step + 1) % self.configs[ 465 | "training" 466 | ].optimization.gradient_accumulation_steps != 0 467 | 468 | with self.fabric.no_backward_sync( 469 | self.model, enabled=should_accumulate_gradients 470 | ): 471 | loss = F.cross_entropy(model_output, labels) 472 | self.fabric.backward( 473 | loss 474 | / self.configs["training"].optimization.gradient_accumulation_steps, 475 | model=self.model, 476 | ) 477 | 478 | if torch.isnan(loss) or torch.isinf(loss): 479 | interval_inf_or_nan_count += 1 480 | else: 481 | interval_loss += loss.item() 482 | interval_steps += 1 483 | 484 | # NOTE: if we are still accumulating gradients for this batch, we skip the logging and optimization steps 485 | if should_accumulate_gradients: 486 | continue 487 | 488 | ######################################################## 489 | # 490 | # Logging 491 | # 492 | ######################################################## 493 | 494 | if batch_step % self.configs["monitoring"].logging.log_every_n_steps == 0: 495 | self._log_training_metrics( 496 | interval_loss=interval_loss, 497 | interval_steps=interval_steps, 498 | interval_inf_or_nan_count=interval_inf_or_nan_count, 499 | batch_step=batch_step, 500 | ) 501 | interval_loss = torch.tensor(0.0, device=self.fabric.device) 502 | interval_steps = torch.tensor(0, device=self.fabric.device) 503 | interval_inf_or_nan_count = torch.tensor(0, device=self.fabric.device) 504 | 505 | ######################################################## 506 | # 507 | # Learning Dynamics Checkpointing 508 | # 509 | ######################################################## 510 | 511 | if batch_step % self.configs["checkpointing"].save_every_n_steps == 0: 512 | if self.should_compute_learning_dynamics: 513 | self.log(f"Step {batch_step} -- 📈 Saving Learning Dynamics") 514 | 515 |
# Training Batch Learning Dynamics 516 | training_batch_dataset = Dataset.from_dict(training_batch) 517 | 518 | learning_dynamics_train_states = compute_learning_dynamics_states( 519 | checkpointing_config=self.configs["checkpointing"], 520 | fabric=self.fabric, 521 | model=self.model, 522 | dataset=training_batch_dataset, 523 | compute_gradients=True, 524 | ) 525 | 526 | save_learning_dynamics_states( 527 | checkpointing_config=self.configs["checkpointing"], 528 | checkpoint_step=batch_step, 529 | prefix="train", 530 | fabric=self.fabric, 531 | learning_dynamics_states=learning_dynamics_train_states, 532 | learning_dynamics_dataset=training_batch_dataset, 533 | tokenizer=self.tokenizer, 534 | ) 535 | training_batch = { 536 | "input_ids": [] 537 | } # Resetting training_batch for next training batch 538 | 539 | # Validation Data Learning Dynamics 540 | if self.learning_dynamics_eval_dataset is not None: 541 | learning_dynamics_val_states = compute_learning_dynamics_states( 542 | checkpointing_config=self.configs["checkpointing"], 543 | fabric=self.fabric, 544 | model=self.model, 545 | dataset=self.learning_dynamics_eval_dataset, 546 | compute_gradients=True, 547 | ) 548 | save_learning_dynamics_states( 549 | checkpointing_config=self.configs["checkpointing"], 550 | checkpoint_step=batch_step, 551 | prefix="val", 552 | fabric=self.fabric, 553 | learning_dynamics_states=learning_dynamics_val_states, 554 | ) 555 | 556 | ######################################################## 557 | # 558 | # Optimization step 559 | # 560 | ######################################################## 561 | 562 | self.optimizer.step() 563 | self.optimizer.zero_grad() 564 | self.lr_scheduler.step() 565 | 566 | batch_step += 1 567 | 568 | ######################################################## 569 | # 570 | # Training Checkpointing and evaluation 571 | # 572 | ######################################################## 573 | 574 | if batch_step % self.configs["checkpointing"].save_every_n_steps == 0: 575 | self.log(f"Step {batch_step} -- 💾 Saving Checkpoint") 576 | save_checkpoint( 577 | configs=self.configs, 578 | checkpoint_step=batch_step, 579 | fabric=self.fabric, 580 | model=self.model, 581 | optimizer=self.optimizer, 582 | lr_scheduler=self.lr_scheduler, 583 | tokenizer=self.tokenizer, 584 | ) 585 | 586 | if self.should_evaluate: 587 | evaluation_results = run_evaluation( 588 | evaluation_config=self.configs["evaluation"], 589 | checkpointing_config=self.configs["checkpointing"], 590 | fabric=self.fabric, 591 | model=self.model, 592 | ) 593 | if evaluation_results is not None: 594 | self._log_evaluation_results(evaluation_results, batch_step) 595 | save_evaluation_results( 596 | checkpointing_config=self.configs["checkpointing"], 597 | fabric=self.fabric, 598 | evaluation_results=evaluation_results, 599 | checkpoint_step=batch_step, 600 | ) 601 | 602 | # Break if we've reached training steps 603 | if batch_step >= self.configs["training"].max_steps: 604 | break 605 | 606 | return batch_step 607 | 608 | ######################################################## 609 | # 610 | # Trainer Logging Functinalities 611 | # 612 | ######################################################## 613 | 614 | def _log_training_metrics( 615 | self, 616 | interval_loss: torch.Tensor, 617 | interval_steps: torch.Tensor, 618 | interval_inf_or_nan_count: torch.Tensor, 619 | batch_step: int, 620 | ): 621 | """ 622 | Gathers together the training metrics computed across all processes in distributed training 623 | and logs them in a tree-style 
format. 624 | """ 625 | gathered_interval_loss = self.fabric.all_reduce( 626 | interval_loss, reduce_op="sum" 627 | ).item() 628 | gathered_interval_inf_or_nan_count = self.fabric.all_reduce( 629 | interval_inf_or_nan_count, reduce_op="sum" 630 | ).item() 631 | gathered_interval_steps = self.fabric.all_reduce( 632 | interval_steps, reduce_op="sum" 633 | ).item() 634 | 635 | avg_loss = ( 636 | gathered_interval_loss / gathered_interval_steps 637 | if gathered_interval_steps > 0 638 | else float("inf") 639 | ) 640 | 641 | self.fabric.log("train/loss", avg_loss, step=batch_step) 642 | self.fabric.log( 643 | "trainer/inf_or_nan_count", 644 | gathered_interval_inf_or_nan_count, 645 | step=batch_step, 646 | ) 647 | self.fabric.log( 648 | "trainer/learning_rate", 649 | self.lr_scheduler.get_last_lr()[0], 650 | step=batch_step, 651 | ) 652 | 653 | # Log to console in tree format 654 | self.log(f"Step {batch_step} -- 🔄 Training Metrics") 655 | self.log(f"├── Loss: {avg_loss:.4f}") 656 | self.log(f"├── Learning Rate: {self.lr_scheduler.get_last_lr()[0]:.2e}") 657 | self.log(f"└── Inf/NaN count: {gathered_interval_inf_or_nan_count}") 658 | 659 | def _log_evaluation_results( 660 | self, evaluation_results: Dict[str, Any], batch_step: int 661 | ): 662 | """Log model evaluation metrics to experiment tracking system and console.""" 663 | self.log(f"Step {batch_step} -- 📊 Evaluation Results") 664 | for i, (metric, result) in enumerate(evaluation_results.items()): 665 | prefix = "└──" if i == len(evaluation_results) - 1 else "├──" 666 | self.log(f"{prefix} {metric}: {result}") 667 | self.fabric.log(f"eval/{metric}", result, step=batch_step) 668 | 669 | def _log_training_configuration(self): 670 | """ 671 | Log training configuration details as well as runtime information about the hardware, 672 | software, and batch settings. 673 | 674 | This function is called at the beginning of the training loop to provide a summary of the 675 | training configuration. 
676 | """ 677 | 678 | total_params = sum(p.numel() for p in self.model.parameters()) 679 | trainable_params = sum( 680 | p.numel() for p in self.model.parameters() if p.requires_grad 681 | ) 682 | global_batch_size = self.configs["data"].dataloader.batch_size 683 | per_device_batch_size = self.train_dataloader.batch_size 684 | gradient_accumulation_steps = self.configs[ 685 | "training" 686 | ].optimization.gradient_accumulation_steps 687 | 688 | device_type = "" 689 | fabric_device = str(self.fabric.device) 690 | if torch.cuda.is_available() and "cuda" in fabric_device: 691 | device_type = torch.cuda.get_device_name(self.fabric.device) 692 | elif torch.backends.mps.is_available() and "mps" in fabric_device: 693 | device_type = "MPS (Apple Silicon)" 694 | else: 695 | device_type = "CPU" 696 | 697 | training_config_path = os.path.join( 698 | self.configs["checkpointing"].runs_dir, 699 | self.configs["checkpointing"].run_name, 700 | "training_config.yaml", 701 | ) 702 | if os.path.exists(training_config_path): 703 | self.log("=" * 50) 704 | self.log("✨ Training Configuration") 705 | self.log("=" * 50) 706 | training_config = yaml.safe_load(open(training_config_path, "r")) 707 | pretty_print_yaml_config(self.logger, training_config) 708 | 709 | self.log("=" * 50) 710 | self.log("⛭ Runtime Summary:") 711 | self.log("=" * 50) 712 | self.log(f"Starting from step: {self.initial_batch_step}") 713 | 714 | self.log("Model Setup:") 715 | self.log(f"└─ Total Parameters: {total_params:,}") 716 | self.log(f"└─ Trainable Parameters: {trainable_params:,}") 717 | 718 | self.log("Distributed Setup:") 719 | self.log(f"└─ Number of Devices: {self.fabric.world_size}") 720 | self.log(f"└─ Device Type: {device_type}") 721 | self.log( 722 | f"└─ Available Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB" 723 | if torch.cuda.is_available() 724 | else f"└─ Available Memory: {psutil.virtual_memory().total / 1e9:.2f} GB" 725 | ) 726 | 727 | self.log("Software Setup:") 728 | self.log(f"└─ Python Version: {platform.python_version()}") 729 | self.log(f"└─ PyTorch Version: {torch.__version__}") 730 | self.log( 731 | f"└─ CUDA Version: {torch.version.cuda if torch.cuda.is_available() else 'N/A'}" 732 | ) 733 | self.log(f"└─ Operating System: {platform.system()} {platform.release()}") 734 | 735 | self.log("Batch Size Configuration:") 736 | self.log(f"└─ Global Batch Size: {global_batch_size}") 737 | self.log(f"└─ Per Device Batch Size: {per_device_batch_size}") 738 | self.log(f"└─ Gradient Accumulation Steps: {gradient_accumulation_steps}") 739 | self.log("=" * 50) 740 | 741 | @rank_zero_only 742 | def log(self, msg: str, level: int = logging.INFO) -> None: 743 | """NOTE: Log messages only from rank zero process.""" 744 | self.logger.log(level, msg) 745 | -------------------------------------------------------------------------------- /src/training/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility package that contains functions for the training process, e.g. initialization, logging, etc. 
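A rough usage sketch, mirroring the order in which `Trainer.__init__` in `src/training/trainer.py` calls these helpers; the config path points at the demo config shipped with the repository and is only illustrative:

configs = initialize_configuration("configs/demo.yaml")
initialize_run_dir(checkpointing_config=configs["checkpointing"])
fabric = initialize_fabric(training_config=configs["training"])
logger = initialize_logging(
    monitoring_config=configs["monitoring"],
    checkpointing_config=configs["checkpointing"],
    fabric=fabric,
)
model = initialize_model(model_config=configs["model"])
optimizer = initialize_optimizer(training_config=configs["training"], model=model)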
3 | """ 4 | 5 | # For convenience, we export the initialization functions here 6 | from .initialization import ( 7 | initialize_configuration, 8 | initialize_dataloader, 9 | initialize_dataset, 10 | initialize_fabric, 11 | initialize_hf_checkpointing, 12 | initialize_logging, 13 | initialize_lr_scheduler, 14 | initialize_model, 15 | initialize_optimizer, 16 | initialize_run_dir, 17 | initialize_tokenizer, 18 | initialize_wandb, 19 | ) 20 | 21 | __all__ = [ 22 | "initialize_configuration", 23 | "initialize_dataloader", 24 | "initialize_dataset", 25 | "initialize_fabric", 26 | "initialize_hf_checkpointing", 27 | "initialize_logging", 28 | "initialize_lr_scheduler", 29 | "initialize_model", 30 | "initialize_optimizer", 31 | "initialize_run_dir", 32 | "initialize_tokenizer", 33 | "initialize_wandb", 34 | ] 35 | -------------------------------------------------------------------------------- /src/training/utils/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for data loading and processing. 3 | """ 4 | 5 | from torch.utils.data import IterableDataset 6 | 7 | 8 | class ShardedIterableDataset(IterableDataset): 9 | """ 10 | A super simple implementation of a sharded iterable dataset that enables DataParallelism 11 | across multiple workers. Ensures that each worker gets a unique shard of the dataset. 12 | 13 | NOTE: Also works fine if there is only one worker. 14 | """ 15 | 16 | def __init__(self, dataset, rank, world_size): 17 | self.dataset = dataset 18 | self.rank = rank 19 | self.world_size = world_size 20 | 21 | def __iter__(self): 22 | iterator = iter(self.dataset) 23 | # NOTE: Start by skipping to this worker's shard 24 | for _ in range(self.rank): 25 | next(iterator) 26 | 27 | # NOTE: Yield every world_size-th item 28 | while True: 29 | try: 30 | yield next(iterator) 31 | # Skip other workers' samples 32 | for _ in range(self.world_size - 1): 33 | next(iterator) 34 | except StopIteration: 35 | break 36 | -------------------------------------------------------------------------------- /src/training/utils/initialization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for initializing components of the training process. 3 | 4 | Here, we initialize all of the components that are part of the learning process. From logging, 5 | and checkpointing to the optimizer to the dataset and the dataloader, this file contains the 6 | logic for setting up the classes and functions that are used in the training loop. 7 | 8 | As always, this code is meant to be basic. We hard-code the obvious defaults, and leave the 9 | more experimental stuff to you. 
10 | """ 11 | 12 | import logging 13 | import os 14 | import warnings 15 | from dataclasses import fields, is_dataclass 16 | from datetime import datetime 17 | from typing import Dict, Optional, Union 18 | 19 | import lightning as L 20 | import torch 21 | import wandb 22 | import yaml 23 | from datasets import Dataset, DownloadConfig, load_dataset 24 | from datasets import config as datasets_config 25 | from huggingface_hub import add_collection_item, create_branch, create_repo 26 | from lightning.fabric.loggers import Logger as FabricLogger 27 | from lightning.fabric.utilities.rank_zero import rank_zero_only 28 | from torch.utils.data import DataLoader 29 | from transformers import AutoTokenizer 30 | from wandb.integration.lightning.fabric import WandbLogger 31 | 32 | from src.config import ( 33 | CheckpointingConfig, 34 | DataConfig, 35 | EvaluationConfig, 36 | ModelConfig, 37 | MonitoringConfig, 38 | TrainingConfig, 39 | ) 40 | from src.model import PicoDecoder 41 | from src.training.utils.io import use_backoff 42 | 43 | warnings.filterwarnings( 44 | "ignore", 45 | message=".*This integration is tested and supported for lightning Fabric.*", 46 | ) 47 | warnings.filterwarnings( 48 | "ignore", 49 | message=".*Please report any issues to.*", 50 | ) 51 | 52 | ######################################################## 53 | # 54 | # Basic Initialization 55 | # 56 | ######################################################## 57 | 58 | 59 | def _apply_config_overrides(config, overrides: dict): 60 | """Recursively apply configuration overrides to a dataclass config object. 61 | 62 | Args: 63 | config: Base configuration object (must be a dataclass) 64 | overrides: Dictionary of override values matching config structure 65 | 66 | Returns: 67 | Modified config object with overrides to the config. 68 | """ 69 | for field in fields(config): 70 | field_value = getattr(config, field.name) 71 | if is_dataclass(field_value): 72 | _apply_config_overrides(field_value, overrides.get(field.name, {})) 73 | else: 74 | if field.name in overrides: 75 | setattr(config, field.name, overrides[field.name]) 76 | return config 77 | 78 | 79 | def initialize_configuration( 80 | config_path: Optional[str] = None, 81 | ) -> Dict[ 82 | str, 83 | Union[ 84 | DataConfig, 85 | ModelConfig, 86 | TrainingConfig, 87 | EvaluationConfig, 88 | MonitoringConfig, 89 | CheckpointingConfig, 90 | ], 91 | ]: 92 | """Initialize configuration objects with optional overrides from a YAML file. 93 | 94 | This function initializes all of the configuration objects, and then applies 95 | any overrides from the config_path file. If no config_path is provided, 96 | the function will use the default configuration objects. 97 | 98 | Args: 99 | config_path: Path to a YAML file containing configuration overrides. 100 | 101 | Returns: 102 | A dictionary containing the initialized configuration objects. 
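For illustration, an override file might look like the YAML sketch below. The top-level keys and the nesting mirror the `overrides.get(...)` calls and the attribute accesses in this module; the values are placeholders, and any field left out keeps its default from the corresponding dataclass:

training:
  max_steps: 20000
  optimization:
    optimizer: adamw
    lr: 0.0003
    lr_scheduler: linear_with_warmup
    lr_warmup_steps: 2000
    gradient_accumulation_steps: 4
data:
  dataset:
    name: pico-lm/pretokenized-dolma
monitoring:
  save_to_wandb: false
checkpointing:
  run_name: my-pico-run
  save_every_n_steps: 1000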
103 | """ 104 | data_config = DataConfig() 105 | model_config = ModelConfig() 106 | training_config = TrainingConfig() 107 | evaluation_config = EvaluationConfig() 108 | monitoring_config = MonitoringConfig() 109 | checkpointing_config = CheckpointingConfig() 110 | 111 | if config_path: 112 | overrides = yaml.safe_load(open(config_path, "r")) 113 | data_config = _apply_config_overrides(data_config, overrides.get("data", {})) 114 | model_config = _apply_config_overrides(model_config, overrides.get("model", {})) 115 | training_config = _apply_config_overrides( 116 | training_config, overrides.get("training", {}) 117 | ) 118 | evaluation_config = _apply_config_overrides( 119 | evaluation_config, overrides.get("evaluation", {}) 120 | ) 121 | monitoring_config = _apply_config_overrides( 122 | monitoring_config, overrides.get("monitoring", {}) 123 | ) 124 | checkpointing_config = _apply_config_overrides( 125 | checkpointing_config, overrides.get("checkpointing", {}) 126 | ) 127 | 128 | configs = { 129 | "data": data_config, 130 | "model": model_config, 131 | "training": training_config, 132 | "evaluation": evaluation_config, 133 | "monitoring": monitoring_config, 134 | "checkpointing": checkpointing_config, 135 | } 136 | 137 | return configs 138 | 139 | 140 | def initialize_run_dir(checkpointing_config: CheckpointingConfig) -> str: 141 | """Initialize a directory for the current training run. 142 | 143 | Creates a unique directory for storing training, evaluation, and logging artifacts. 144 | If no run name is specified in the config, generates a timestamp-based name. 145 | 146 | Args: 147 | checkpointing_config: Configuration object containing run settings. 148 | NOTE: Must have a 'run_name' attribute that can be None, in which case 149 | a timestamp-based name will be generated. 150 | 151 | Returns: 152 | str: The path to the run directory. 153 | """ 154 | run_name = checkpointing_config.run_name 155 | if run_name is None: 156 | run_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 157 | checkpointing_config.run_name = run_name 158 | 159 | run_dir = os.path.join(checkpointing_config.runs_dir, run_name) 160 | 161 | os.makedirs(run_dir, exist_ok=True) 162 | return run_dir 163 | 164 | 165 | def initialize_fabric( 166 | training_config: TrainingConfig, wandb_logger: Optional[FabricLogger] = None 167 | ): 168 | """Initialize Lightning Fabric for distributed training. 169 | 170 | Sets up a Lightning Fabric instance with the specified configuration for 171 | handling distributed training, mixed precision, and logging. 172 | 173 | Args: 174 | training_config: Configuration object containing fabric settings 175 | (accelerator, precision, devices, etc.). 176 | wandb_logger: Optional weights and biases logger instance for experiment tracking 177 | 178 | Returns: 179 | L.Fabric: Initialized Lightning Fabric instance. 180 | 181 | Example: 182 | >>> fabric = initialize_fabric(training_config, wandb_logger) 183 | """ 184 | 185 | total_devices = ( 186 | training_config.fabric.num_devices * training_config.fabric.num_nodes 187 | ) 188 | 189 | if total_devices > 1: 190 | strategy = "deepspeed_stage_2" 191 | else: 192 | strategy = "auto" # Sets up SingleDevice Strategy by default 193 | 194 | # NOTE: The strategy is set to use either DeepSpeed (Zero Stage 2) on multi-GPU, 195 | # or SingleDevice Strategy on single-GPU set ups. 
If you'd like to use a different strategy, 196 | # you can change the strategy flag in the fabric initialization, but be aware that this might 197 | # cause issues with checkpointing, evaluation, etc. 198 | 199 | fabric = L.Fabric( 200 | accelerator=training_config.fabric.accelerator, 201 | precision=training_config.fabric.precision, 202 | devices=training_config.fabric.num_devices, 203 | num_nodes=training_config.fabric.num_nodes, 204 | loggers=[wandb_logger] if wandb_logger is not None else None, 205 | strategy=strategy, 206 | ) 207 | 208 | fabric.launch() 209 | 210 | return fabric 211 | 212 | 213 | ######################################################## 214 | # 215 | # Dataset and Tokenization Initialization 216 | # 217 | ######################################################## 218 | 219 | 220 | @use_backoff(max_retries=20) 221 | def initialize_dataset( 222 | data_config: DataConfig, 223 | fabric: L.Fabric, 224 | initial_batch_step: Optional[int] = 0, 225 | return_fast_forward_steps: bool = False, 226 | ): 227 | """Initialize dataset based on the given config. 228 | 229 | This function will return a dataset object, and optionally a fast_forward_steps value. 230 | 231 | The fast_forward_steps value is the number of steps that we need to fast-forward an iterator by, 232 | so that we can continue from a certain batch of data we would have seen had training not previously 233 | stopped. Depending on how the dataset is loaded, the number of steps to fast-forward may be 234 | different from the initial_batch_step value. 235 | 236 | NOTE: This functionality is primarily useful for streaming datasets (which for large 237 | datasets is most of the time). 238 | 239 | Args: 240 | data_config: Configuration object containing dataset settings. 241 | fabric: A Lightning Fabric instance. 242 | initial_batch_step: The initial batch step to fast-forward to. 243 | return_fast_forward_steps: Whether to return the fast-forward steps value. 244 | 245 | Returns: 246 | Dataset: Initialized dataset object. 247 | Optional[int]: Number of steps to fast-forward the iterator by, if return_fast_forward_steps is True. 248 | """ 249 | 250 | datasets_config.STREAMING_READ_MAX_RETRIES = 40 # default is 20 251 | datasets_config.STREAMING_READ_RETRY_INTERVAL = 10 # default is 5 252 | download_config = DownloadConfig( 253 | max_retries=20, # default is 1 and can lead to premature HTTPS errors 254 | ) 255 | 256 | fast_forward_steps = 0 257 | 258 | if data_config.dataset.name == "pico-lm/pretokenized-dolma": 259 | # NOTE: We know that the dataset is sharded into 10,000 shards, so we can easily compute 260 | # the data file that we need to load in that contains the batch of data at 261 | # initial_batch_step.
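To make the shard bookkeeping implemented just below concrete, a worked example; the batch size and resume step are placeholders (in practice they come from `data_config.dataloader.batch_size` and the loaded checkpoint):

examples_per_shard = 20_480
batch_size = 1_024                 # illustrative value
initial_batch_step = 50            # illustrative resume point
batches_per_shard = examples_per_shard // batch_size          # 20 batches per shard
shard_idx = initial_batch_step // batches_per_shard           # resume from shard 2
fast_forward_steps = initial_batch_step % batches_per_shard   # skip 10 batches into that shard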
262 | 263 | if initial_batch_step is not None: 264 | examples_per_shard = 20_480 265 | total_shards = 10_000 266 | batches_per_shard = examples_per_shard // data_config.dataloader.batch_size 267 | shard_idx = initial_batch_step // batches_per_shard 268 | 269 | data_files = [ 270 | f"data/train-{str(_shard_idx).zfill(5)}-of-{total_shards}.parquet" 271 | for _shard_idx in range(shard_idx, total_shards) 272 | ] 273 | 274 | fast_forward_steps = initial_batch_step % batches_per_shard 275 | else: 276 | data_files = None 277 | 278 | base_dataset = load_dataset( 279 | data_config.dataset.name, 280 | split="train", 281 | streaming=True, 282 | data_files=data_files, 283 | download_config=download_config, 284 | ) 285 | else: 286 | # NOTE: For other datasets, you might want to add some custom loading logic, especially 287 | # to help with loading or fast-forwarding to the correct batch. 288 | 289 | base_dataset = load_dataset( 290 | data_config.dataset.name, 291 | split="train", 292 | streaming=True, 293 | download_config=download_config, 294 | ) 295 | 296 | if data_config.dataset.name == "pico-lm/pretokenized-dolma": 297 | from .data import ShardedIterableDataset 298 | 299 | # NOTE: We wrap the dataset in a ShardedIterableDataset, which is a custom class that 300 | # allows us to shard an iterable dataset across multiple processes. This is useful for 301 | # distributed training, where we want data-parallelism. 302 | dataset = ShardedIterableDataset( 303 | base_dataset, fabric.global_rank, fabric.world_size 304 | ) 305 | else: 306 | dataset = base_dataset 307 | 308 | if return_fast_forward_steps: 309 | return dataset, fast_forward_steps 310 | else: 311 | return dataset 312 | 313 | 314 | def initialize_tokenizer(data_config: DataConfig): 315 | """Initialize the tokenizer for text processing. 316 | 317 | This function can be extended to include custom tokenization logic. 318 | 319 | Args: 320 | data_config: Configuration object containing tokenizer settings. 321 | 322 | Returns: 323 | AutoTokenizer: A HuggingFace tokenizer instance. 324 | """ 325 | 326 | return AutoTokenizer.from_pretrained(data_config.tokenizer.name) 327 | 328 | 329 | def initialize_dataloader( 330 | data_config: DataConfig, 331 | training_config: TrainingConfig, 332 | fabric: L.Fabric, 333 | dataset: Dataset, 334 | ): 335 | """Initialize the DataLoader for efficient batch processing. 336 | 337 | Creates a PyTorch DataLoader that handles batching and data loading for training. 338 | Configured specifically for streaming tokenized text datasets. 339 | 340 | You might also want to extend this function to add a sampler, or some sort of custom 341 | collate function. For the default dataset, we don't need any of this, because the data are 342 | pre-shuffled, and pre-tokenized. 343 | 344 | Args: 345 | data_config: Configuration object containing dataloader settings. 346 | training_config: Configuration object containing training settings. 347 | fabric: A Lightning Fabric instance. 348 | dataset: A HuggingFace Dataset object containing tokenized text data. 349 | Expected to have 'input_ids' field in its items. 350 | 351 | Returns: 352 | DataLoader: PyTorch DataLoader instance configured for the dataset. 
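For illustration, the sub-batch arithmetic applied just below, with placeholder numbers: a global batch size of 1024 split across 2 devices and 4 gradient accumulation steps gives 128 sequences per dataloader batch, and 2 devices x 4 accumulation steps x 128 sequences recovers the global batch.

global_batch_size = 1024           # data_config.dataloader.batch_size (illustrative)
world_size = 2                     # fabric.world_size (illustrative)
gradient_accumulation_steps = 4    # training_config.optimization.gradient_accumulation_steps
sub_batch_size = global_batch_size // (world_size * gradient_accumulation_steps)
assert sub_batch_size == 128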
353 | """ 354 | 355 | def _collate_fn(batch): 356 | return {"input_ids": [entry["input_ids"] for entry in batch]} 357 | 358 | sub_batch_size = data_config.dataloader.batch_size // ( 359 | fabric.world_size * training_config.optimization.gradient_accumulation_steps 360 | ) 361 | 362 | # NOTE: We use the sub-batch size for the dataloader, which is the full batch size 363 | # divided by the gradient accumulation steps. This ensures that the effective batch size 364 | # is correct. 365 | 366 | return DataLoader( 367 | dataset, 368 | batch_size=sub_batch_size, 369 | shuffle=False, # Keep sequential for streaming datasets 370 | pin_memory=True, # Speeds up transfer to GPU 371 | collate_fn=_collate_fn, 372 | ) 373 | 374 | 375 | ######################################################## 376 | # 377 | # Model Initialization 378 | # 379 | ######################################################## 380 | 381 | 382 | def initialize_model(model_config: ModelConfig): 383 | """Initialize the model for training. 384 | 385 | Loads in a given model implemented in the `src.model` package and returns it. 386 | 387 | NOTE: out of the box we currently only support the PicoDecoder model (a causal transformer 388 | language model). If you'd like to implement your own model, you can do so by adding a new 389 | model class in the `src.model` package, and then adding a new entry here. 390 | 391 | Args: 392 | model_config: Configuration object containing model settings. 393 | 394 | Returns: 395 | PyTorch model instance. 396 | 397 | """ 398 | if model_config.model_type == "pico_decoder": 399 | return PicoDecoder(model_config) 400 | else: 401 | raise ValueError(f"Invalid model type: {model_config.model_type}") 402 | 403 | 404 | ######################################################## 405 | # 406 | # Optimizer and Scheduler 407 | # 408 | ######################################################## 409 | 410 | 411 | def initialize_optimizer(training_config: TrainingConfig, model: torch.nn.Module): 412 | """Initialize the optimizer for model training. 413 | 414 | Creates an optimizer instance based on the configuration settings. 415 | 416 | Add whatever other optimizers you want here. 417 | 418 | Args: 419 | training_config: Configuration object containing optimizer settings. 420 | Must have: 421 | - optimization.optimizer (str): Name of the optimizer ("adamw") 422 | - optimization.lr (float): Learning rate for the optimizer 423 | model: PyTorch model whose parameters will be optimized. 424 | 425 | Returns: 426 | torch.optim.Optimizer: Configured optimizer instance. 427 | 428 | """ 429 | 430 | if training_config.optimization.optimizer == "adamw": 431 | optimizer = torch.optim.AdamW( 432 | model.parameters(), lr=training_config.optimization.lr 433 | ) 434 | else: 435 | raise ValueError(f"Invalid optimizer: {training_config.optimization.optimizer}") 436 | 437 | return optimizer 438 | 439 | 440 | def initialize_lr_scheduler( 441 | training_config: TrainingConfig, optimizer: torch.optim.Optimizer 442 | ): 443 | """Initialize a learning rate scheduler with warmup and decay. 444 | 445 | The default is a learning rate scheduler that implements a linear warmup followed by 446 | linear decay. The learning rate increases linearly from 0 to the initial lr 447 | during warmup, then decreases linearly to 0 during the remaining steps. 448 | 449 | Add other types of learning rate schedulers here. 450 | 451 | Args: 452 | training_config: Configuration object containing optimizer and scheduler settings. 
453 | optimizer: PyTorch optimizer whose learning rate will be scheduled. 454 | 455 | Returns: 456 | torch.optim.lr_scheduler.LambdaLR: Learning rate scheduler instance. 457 | """ 458 | 459 | if training_config.optimization.lr_scheduler == "linear_with_warmup": 460 | # Credit where credit is due: 461 | # https://github.com/huggingface/transformers/blob/e71a01a104dd663c730e494eb0b6467bb51df357/src/transformers/optimization.py#L102 462 | def _lr_lambda(curr_step, num_warmup_steps, max_steps): 463 | if curr_step < num_warmup_steps: 464 | return float(curr_step) / float(max(1, num_warmup_steps)) 465 | else: 466 | return max( 467 | 0.0, 468 | float(max_steps - curr_step) 469 | / float(max(1, max_steps - num_warmup_steps)), 470 | ) 471 | 472 | lr_lambda = lambda step: _lr_lambda( # noqa: E731 473 | step, 474 | training_config.optimization.lr_warmup_steps, 475 | training_config.max_steps, 476 | ) 477 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR( 478 | optimizer, 479 | lr_lambda, 480 | ) 481 | else: 482 | raise ValueError( 483 | f"Invalid learning rate scheduler: {training_config.optimization.lr_scheduler}" 484 | ) 485 | 486 | return lr_scheduler 487 | 488 | 489 | ######################################################## 490 | # 491 | # Experiment Monitoring (Logging, Experiment Tracking, etc.) 492 | # 493 | ######################################################## 494 | 495 | 496 | def _initialize_log_file(checkpointing_config: CheckpointingConfig) -> str: 497 | """Create and initialize a timestamped log file in the run's log directory. 498 | 499 | Sets up a log file with a unique timestamp in the run's logging directory. 500 | Creates the necessary directory structure if it doesn't exist. 501 | 502 | Directory Structure: 503 | {checkpointing_config.runs_dir}/ 504 | └── {checkpointing_config.run_name}/ 505 | └── {checkpointing_config.logs_dir}/ 506 | └── log_YYYYMMDD_HHMMSS.txt 507 | 508 | Args: 509 | checkpointing_config: Configuration object containing checkpointing settings. 510 | 511 | Returns: 512 | str: Absolute path to the created log file. 513 | 514 | """ 515 | 516 | run_dir = os.path.join(checkpointing_config.runs_dir, checkpointing_config.run_name) 517 | logs_dir = os.path.join(run_dir, checkpointing_config.logs_dir) 518 | os.makedirs(logs_dir, exist_ok=True) 519 | 520 | # datetime stamp 521 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 522 | log_file_name = f"log_{timestamp}.log" 523 | log_file_path = os.path.join(logs_dir, log_file_name) 524 | 525 | open(log_file_path, "w").close() # Create an empty log file 526 | 527 | return log_file_path 528 | 529 | 530 | @use_backoff() 531 | def initialize_wandb( 532 | monitoring_config: MonitoringConfig, checkpointing_config: CheckpointingConfig 533 | ): 534 | """Initialize Weights and Biases. 535 | 536 | This function initializes Weights and Biases based on the configuration settings. 537 | 538 | Args: 539 | monitoring_config: Configuration object containing monitoring settings. 540 | checkpointing_config: Configuration object containing checkpointing settings. 541 | 542 | Returns: 543 | Optional[WandbLogger]: An experiment tracker instance. 544 | """ 545 | 546 | assert ( 547 | monitoring_config.wandb.project is not None 548 | and monitoring_config.wandb.project != "" 549 | ), "Wandb project must be provided if wandb is to be used." 550 | assert ( 551 | monitoring_config.wandb.entity is not None 552 | and monitoring_config.wandb.entity != "" 553 | ), "Wandb entity must be provided if wandb is to be used." 
554 | 555 | _run_id = None 556 | if checkpointing_config.training.auto_resume: 557 | # If we are loading a checkpoint, we can try to find the run id of the previous run 558 | previous_runs = wandb.Api().runs( 559 | path=f"{monitoring_config.wandb.entity}/{monitoring_config.wandb.project}", 560 | filters={"display_name": checkpointing_config.run_name}, 561 | ) 562 | try: 563 | if len(previous_runs) == 1: 564 | _run_id = previous_runs[0].id 565 | except ValueError: 566 | pass 567 | 568 | wandb_logger = WandbLogger( 569 | project=monitoring_config.wandb.project, 570 | entity=monitoring_config.wandb.entity, 571 | id=_run_id, 572 | name=checkpointing_config.run_name, 573 | ) 574 | 575 | return wandb_logger 576 | 577 | 578 | @rank_zero_only 579 | def initialize_logging( 580 | monitoring_config: MonitoringConfig, 581 | checkpointing_config: CheckpointingConfig, 582 | fabric: L.Fabric, 583 | ): 584 | """Initialize logging system with default logging, to file and console. 585 | 586 | The default logging system uses a file handler and a stream handler. 587 | 588 | NOTE: this function is only called on rank 0. 589 | 590 | Args: 591 | monitoring_config: Configuration object containing monitoring settings. 592 | checkpointing_config: Configuration object containing checkpointing settings. 593 | 594 | Returns: 595 | logger: Standard Python logger configured for file and console output 596 | """ 597 | 598 | # ---- Standard Local Logger ---- # 599 | logger = logging.getLogger("pico-train") 600 | logger.setLevel(logging.INFO) 601 | 602 | # Create file handler 603 | log_file_path = _initialize_log_file(checkpointing_config) 604 | file_handler = logging.FileHandler(log_file_path, encoding="utf-8") 605 | file_handler.setLevel(monitoring_config.logging.log_level) 606 | 607 | # Create formatter and add it to the handler 608 | formatter = logging.Formatter( 609 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s", 610 | datefmt="%Y-%m-%d %H:%M:%S", 611 | ) 612 | file_handler.setFormatter(formatter) 613 | 614 | # Add the handler to the logger 615 | logger.addHandler(file_handler) 616 | 617 | # Add a stream handler for console output 618 | stream_handler = logging.StreamHandler() 619 | stream_handler.setLevel(monitoring_config.logging.log_level) 620 | stream_handler.setFormatter(formatter) 621 | logger.addHandler(stream_handler) 622 | 623 | return logger 624 | 625 | 626 | ######################################################## 627 | # 628 | # HuggingFace/Remote Checkpointing 629 | # 630 | ######################################################## 631 | 632 | 633 | @rank_zero_only 634 | @use_backoff() 635 | def initialize_hf_checkpointing( 636 | checkpointing_config: CheckpointingConfig, fabric: L.Fabric 637 | ): 638 | """Initialize HuggingFace Checkpointing. 639 | 640 | Creates a HuggingFace repository if it doesn't exist, and creates a branch named after the run. 641 | 642 | NOTE: this function is only called on rank 0. 643 | 644 | Args: 645 | checkpointing_config: Configuration object containing checkpointing settings; must have 646 | a 'hf_checkpoint' attribute that specifies the HuggingFace repository id and 647 | collection slug (if applicable) to save the checkpoint to. 648 | 649 | Raises: 650 | RuntimeError: If unable to create HuggingFace repository after multiple attempts. 651 | """ 652 | 653 | huggingface_repo_id = checkpointing_config.hf_checkpoint.repo_id 654 | assert ( 655 | huggingface_repo_id is not None and huggingface_repo_id != "" 656 | ), "hf_checkpoint.repo_id must be provided." 
657 | 658 | repo = create_repo(huggingface_repo_id, exist_ok=True) 659 | 660 | # can create a repo without a specified namespace (will default to username) 661 | # however the rest of the HF calls need the fully qualified name 662 | # this is returned by create repo, so we update the config for later calls 663 | checkpointing_config.hf_checkpoint.repo_id = repo.repo_id 664 | huggingface_repo_id = repo.repo_id 665 | 666 | if checkpointing_config.hf_checkpoint.collection_slug: 667 | add_collection_item( 668 | checkpointing_config.hf_checkpoint.collection_slug, 669 | huggingface_repo_id, 670 | repo.repo_type, 671 | exists_ok=True, 672 | ) 673 | 674 | create_branch( 675 | repo_id=huggingface_repo_id, 676 | branch=checkpointing_config.run_name, 677 | exist_ok=True, 678 | ) 679 | -------------------------------------------------------------------------------- /src/training/utils/io.py: -------------------------------------------------------------------------------- 1 | """Defines a retry wrapper for io operations.""" 2 | 3 | import time 4 | from functools import wraps 5 | 6 | 7 | def use_backoff(max_retries=2, initial_delay=1, backoff_factor=2): 8 | """ 9 | Universal retry wrapper with exponential backoff for any function, but primarily for loading 10 | and storing HuggingFace datasets and objects. 11 | 12 | Example usage: 13 | 14 | >>> @use_backoff(max_retries=10, delay=1, backoff_factor=2) 15 | >>> def important_io_operation(x): 16 | >>> return x + 1 17 | 18 | Args: 19 | fn: Function to execute 20 | max_retries: Maximum number of retry attempts (default: 3) 21 | delay: Initial delay between retries in seconds (default: 1) 22 | backoff_factor: Multiplier for delay between retries (default: 2) 23 | 24 | Returns: 25 | A wrapper function that will retry the function fn up to max_retries times with exponential backoff 26 | 27 | Raises: 28 | Exception: If all retries fail 29 | """ 30 | 31 | def _decorator(fn): 32 | @wraps(fn) 33 | def wrapper(*args, **kwargs): 34 | current_delay = initial_delay 35 | last_exception = None 36 | 37 | for attempt in range(max_retries): 38 | try: 39 | return fn(*args, **kwargs) 40 | except Exception as e: 41 | last_exception = e 42 | if attempt < max_retries - 1: # Don't sleep on the last attempt 43 | time.sleep(current_delay) 44 | current_delay *= backoff_factor 45 | 46 | raise Exception( 47 | f"IO Operation failed after {max_retries} attempts: {str(last_exception)}" 48 | ) 49 | 50 | return wrapper 51 | 52 | return _decorator 53 | -------------------------------------------------------------------------------- /src/training/utils/logging.py: -------------------------------------------------------------------------------- 1 | """ 2 | Miscellaneous logging utilities. 3 | """ 4 | 5 | from io import StringIO 6 | 7 | import yaml 8 | from lightning.fabric.utilities.rank_zero import rank_zero_only 9 | from rich.console import Console 10 | from rich.panel import Panel 11 | 12 | 13 | @rank_zero_only 14 | def pretty_print_yaml_config(logger, config: dict) -> None: 15 | """ 16 | Pretty print config with rich formatting. Assumes that the config is already saved as a 17 | dictionary - this can be done by calling `asdict` on the dataclass or loading in the config 18 | from a yaml file. 19 | 20 | NOTE: this function is only called on rank 0. 21 | 22 | Args: 23 | logger: Logger object to log the formatted output to. 24 | config: Dictionary containing the config to pretty print. 
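A minimal usage sketch in the style of the other docstrings in this package; the `training_config` name is illustrative:

>>> from dataclasses import asdict
>>> pretty_print_yaml_config(logger, asdict(training_config))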
25 | """ 26 | # Create string buffer 27 | output = StringIO() 28 | console = Console(file=output, force_terminal=False) 29 | 30 | # Convert to YAML string first 31 | yaml_str = yaml.dump( 32 | config, default_flow_style=False, sort_keys=False, Dumper=yaml.SafeDumper 33 | ) 34 | 35 | # Create formatted panel 36 | panel = Panel( 37 | yaml_str, 38 | border_style="blue", 39 | padding=(0, 1), # Reduced padding 40 | expand=False, # Don't expand to terminal width 41 | ) 42 | 43 | # Print to buffer 44 | console.print(panel) 45 | 46 | # Log the formatted output 47 | for line in output.getvalue().splitlines(): 48 | logger.info(line) 49 | --------------------------------------------------------------------------------
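As a closing illustration, a minimal sanity check of the round-robin sharding implemented by `ShardedIterableDataset` in `src/training/utils/data.py`; the plain list stands in for a streaming dataset and is used only for demonstration:

from src.training.utils.data import ShardedIterableDataset

items = list(range(10))
# With two ranks, rank 0 yields the even-indexed items and rank 1 the odd-indexed ones.
rank0 = list(ShardedIterableDataset(items, rank=0, world_size=2))
rank1 = list(ShardedIterableDataset(items, rank=1, world_size=2))
assert rank0 == [0, 2, 4, 6, 8]
assert rank1 == [1, 3, 5, 7, 9]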