├── .gitignore ├── LICENSE ├── README.md ├── assets ├── .DS_Store ├── cos_vs_cooldown.png ├── savings.png └── scaling_curves.png ├── flops.ipynb ├── requirements.txt └── src ├── config ├── __init__.py └── base.py ├── data ├── arxiv.py ├── openwebtext2.py ├── redpajama.py ├── shakespeare.py ├── slimpajama.py ├── utils.py └── wikitext.py ├── distributed ├── __init__.py ├── backend.py ├── ddp.py └── single.py ├── logger ├── logger.py └── rotational_logger.yaml ├── main.py ├── models ├── base.py ├── llama.py └── utils.py └── optim ├── base.py ├── utils.py └── weight_averaging.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Datasets, assets, logging 2 | datasets/ 3 | wandb/ 4 | exps/ 5 | scripts/ 6 | cluster/ 7 | assets/*.csv 8 | assets/*.json 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | pip-wheel-metadata/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | 140 | # vscode 141 | .vscode/ 142 | 143 | # cluster stuff 144 | user.yaml 145 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 EPFL Machine Learning and Optimization Laboratory 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Codebase: Scaling Laws and Compute-Optimal Training Beyond Fixed Training Durations 2 | This is the codebase accompanying the paper [*Scaling Laws and Compute-Optimal Training Beyond Fixed Training Durations*](https://arxiv.org/abs/2405.18392). The code is largely based on our framework [llm-baselines](https://github.com/epfml/llm-baselines) to do research on training LLMs as an extension of [NanoGPT](https://github.com/karpathy/nanogpt). 3 | 4 | **Abstract:** 5 | > Scale has become a main ingredient in obtaining strong machine learning models. As a result, understanding a model's scaling properties is key to effectively designing both the right training setup as well as future generations of architectures. In this work, we argue that scale and training research has been needlessly complex due to reliance on the cosine schedule, which prevents training across different lengths for the same model size. We investigate the training behavior of a direct alternative - constant learning rate and cooldowns - and find that it scales predictably and reliably similar to cosine. Additionally, we show that stochastic weight averaging yields improved performance along the training trajectory, without additional training costs, across different scales. 
Importantly, with these findings we demonstrate that scaling experiments can be performed with significantly reduced compute and GPU hours by utilizing fewer but reusable training runs. 6 | 7 |
8 | *(Figure: Cosine vs. Cooldown Schedules; see assets/cos_vs_cooldown.png)* 9 |
10 | 11 | **Figure:** Whereas the cosine learning rate follows a slow annealing, the alternative schedule of constant LR + cooldown is characterized by a fast drop towards the end of training. This cooldown phase initiates a sharp decrease in loss to match cosine; the training perplexity follows the same behavior. 12 | 13 |
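For intuition, below is a minimal sketch of such a constant-LR + cooldown schedule, matching the shape described in the caption above: a short warmup, a long constant phase, and a decay over the final fraction of steps. The function name and the linear decay form are illustrative assumptions for this sketch rather than the repository's implementation (the actual schedule is selected via `--scheduler wsd`, with the cooldown fraction set by `--wsd-fract-decay` and the decay shape by `--decay-type`, see the commands further down); the default values mirror those in `src/config/base.py`.

```python
# Illustrative sketch only; not the repository's implementation.
# Constant learning rate with warmup and a linear cooldown over the
# final `fract_decay` fraction of training, as depicted above.
def wsd_lr(step, max_steps, base_lr=1e-3, warmup_steps=300,
           fract_decay=0.1, final_lr_scale=0.0):
    decay_start = int(max_steps * (1 - fract_decay))
    if step < warmup_steps:
        # linear warmup from 0 to base_lr
        return base_lr * step / max(1, warmup_steps)
    if step < decay_start:
        # stable phase: constant learning rate
        return base_lr
    # cooldown phase: linear decay from base_lr to final_lr_scale * base_lr
    progress = (step - decay_start) / max(1, max_steps - decay_start)
    return base_lr * (1.0 - progress * (1.0 - final_lr_scale))
```

In contrast, a cosine schedule starts annealing immediately and ties the decay to one fixed total length, which is exactly what the constant + cooldown alternative avoids.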
14 | *(Figures: Loss Curve Envelopes)* 15 | 16 |
17 | 18 | **Figure:** The cooldown schedule makes it possible to perform scaling law experiments for a fraction of the compute. Instead of having to train from scratch for every length (cosine), we launch one long run and perform cooldowns from intermediate checkpoints after training. 19 | 20 | 21 | ## Quickstart 22 | 23 | Create a conda environment and install dependencies (we recommend Python 3.10): 24 | 25 | ```bash 26 | conda create -n env python=3.10 27 | conda activate env 28 | pip install -r requirements.txt 29 | ``` 30 | 31 | Run a simple training on the SlimPajama 6B dataset: 32 | ```bash 33 | python ./src/main.py 34 | ``` 35 | 36 | The above command trains a 213.34M-parameter model with the Llama-style architecture. We recommend using the `--compile` flag, which speeds up training noticeably (up to 20% in our setup). 37 | 38 | ## LR Schedules and Weight Averaging 39 | To use the cooldown schedule: 40 | ```bash 41 | python ./src/main.py --compile --scheduler wsd --wsd-fract-decay 0.2 42 | ``` 43 | The argument `--wsd-fract-decay` controls the fraction of training spent in the cooldown phase, and the functional form of the cooldown is selected with the argument `--decay-type`. 44 | 45 | If you want to use stochastic weight averaging: 46 | ```bash 47 | python ./src/main.py --compile --scheduler wsd --wsd-fract-decay 0.2 --weight-average 48 | ``` 49 | With this flag, the averaging is done automatically in windows of 500 steps; all model averages are stored (beware of the disk space). The frequency is controlled via the arguments `--wa-interval` (average every k steps) and `--wa-horizon` (the length of the averaging horizon/window). 50 | 51 | Moreover, the flag `--wa-sweep-horizon` automatically sweeps the horizon to find the best-performing window, but may slow down training. 52 | 53 | ## FLOPS helpers 54 | The [`flops.ipynb`](flops.ipynb) notebook provides a few helpers and utilities for FLOPS computations of transformer configurations. 55 | 56 | ## Contact & Reference 57 | Please do not hesitate to reach out to us if you have questions!
58 | 59 | In order to cite this work: 60 | ``` 61 | @article{hagele2024scaling, 62 | author = {Alexander H\"agele and Elie Bakouch and Atli Kosson and Loubna Ben Allal and Leandro Von Werra and Martin Jaggi}, 63 | title = {{Scaling Laws and Compute-Optimal Training Beyond Fixed Training Durations}}, 64 | year = {2024}, 65 | journal = {Advances in Neural Information Processing Systems}, 66 | url = {http://arxiv.org/abs/2405.18392} 67 | } 68 | ``` 69 | 70 | 71 | -------------------------------------------------------------------------------- /assets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfml/schedules-and-scaling/17e78d600a7054bc6609a5ca6d050d141d4ad79e/assets/.DS_Store -------------------------------------------------------------------------------- /assets/cos_vs_cooldown.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfml/schedules-and-scaling/17e78d600a7054bc6609a5ca6d050d141d4ad79e/assets/cos_vs_cooldown.png -------------------------------------------------------------------------------- /assets/savings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfml/schedules-and-scaling/17e78d600a7054bc6609a5ca6d050d141d4ad79e/assets/savings.png -------------------------------------------------------------------------------- /assets/scaling_curves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfml/schedules-and-scaling/17e78d600a7054bc6609a5ca6d050d141d4ad79e/assets/scaling_curves.png -------------------------------------------------------------------------------- /flops.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "def embedding(seq_len, vocab_size, d_model):\n", 10 | " return 2 * seq_len * vocab_size * d_model\n", 11 | "\n", 12 | "\n", 13 | "def attention(seq_len, d_model, key_size, num_heads):\n", 14 | " projections = 2 * 3 * seq_len * d_model * (key_size * num_heads)\n", 15 | " logits = 2 * seq_len * seq_len * (key_size * num_heads)\n", 16 | " softmax = 3 * num_heads * seq_len * seq_len\n", 17 | " softmax_query_reduction = 2 * seq_len * seq_len * (key_size * num_heads)\n", 18 | " final_layer = 2 * seq_len * (key_size * num_heads) * d_model\n", 19 | " return projections + logits + softmax + softmax_query_reduction + final_layer\n", 20 | "\n", 21 | "\n", 22 | "def dense(seq_len, d_model, ffw_size, swiglu=False):\n", 23 | " if not swiglu:\n", 24 | " return 2 * seq_len * (2 * d_model * ffw_size)\n", 25 | " else:\n", 26 | " return 2 * seq_len * (3 * d_model * ffw_size)\n", 27 | "\n", 28 | "\n", 29 | "def moe(dense, n_experts, top_k, seq_len, d_model, ffw_size, swiglu=False):\n", 30 | " dense_flops = top_k * dense(seq_len, d_model, ffw_size, swiglu)\n", 31 | " gate_flops = 3 * seq_len * n_experts\n", 32 | " return dense_flops + gate_flops\n", 33 | "\n", 34 | "\n", 35 | "def final_logits(seq_len, d_model, vocab_size):\n", 36 | " return 2 * seq_len * d_model * vocab_size\n", 37 | "\n", 38 | "\n", 39 | "def get_flops(\n", 40 | " n_layers,\n", 41 | " seq_len,\n", 42 | " vocab_size,\n", 43 | " d_model,\n", 44 | " key_size,\n", 45 | " num_heads,\n", 46 | " ffw_size,\n", 47 | " swiglu=False,\n", 48 | " **kwargs,\n", 49 | 
"):\n", 50 | " return (\n", 51 | " embedding(seq_len, vocab_size, d_model)\n", 52 | " + n_layers\n", 53 | " * (\n", 54 | " attention(seq_len, d_model, key_size, num_heads)\n", 55 | " + dense(seq_len, d_model, ffw_size, swiglu=swiglu)\n", 56 | " )\n", 57 | " + final_logits(seq_len, d_model, vocab_size)\n", 58 | " )\n", 59 | "\n", 60 | "\n", 61 | "def flops_moe(\n", 62 | " seq_len,\n", 63 | " vocab_size,\n", 64 | " n_layers,\n", 65 | " d_model,\n", 66 | " key_size,\n", 67 | " num_heads,\n", 68 | " ffw_size,\n", 69 | " n_experts,\n", 70 | " top_k,\n", 71 | " swiglu=False,\n", 72 | "):\n", 73 | " return (\n", 74 | " embedding(seq_len, vocab_size, d_model)\n", 75 | " + n_layers\n", 76 | " * (\n", 77 | " attention(seq_len, d_model, key_size, num_heads)\n", 78 | " + moe(dense, n_experts, top_k, seq_len, d_model, ffw_size, swiglu=swiglu)\n", 79 | " )\n", 80 | " + final_logits(seq_len, d_model, vocab_size)\n", 81 | " )\n", 82 | "\n", 83 | "\n", 84 | "def parameter_count(\n", 85 | " vocab_size,\n", 86 | " n_layers,\n", 87 | " d_model,\n", 88 | " key_size,\n", 89 | " num_heads,\n", 90 | " num_kv_heads,\n", 91 | " ffw_size,\n", 92 | " n_experts=1,\n", 93 | " swiglu_or_geglu=False,\n", 94 | " **kwargs,\n", 95 | "):\n", 96 | " mul_factor_ffn = 3 if swiglu_or_geglu else 2\n", 97 | " attn = 2 * d_model * num_heads * key_size + 2 * d_model * num_kv_heads * key_size\n", 98 | " return vocab_size * d_model + n_layers * (\n", 99 | " attn + mul_factor_ffn * n_experts * d_model * ffw_size\n", 100 | " )" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 2, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "multiple_of = 256\n", 110 | "\n", 111 | "tiny = {\n", 112 | " \"d_model\": 384,\n", 113 | " \"key_size\": 64,\n", 114 | " \"num_heads\": 6,\n", 115 | " \"num_kv_heads\": 6,\n", 116 | " \"ffw_size\": int(8 / 3 * 384),\n", 117 | " \"n_layers\": 8,\n", 118 | " \"vocab_size\": 50257,\n", 119 | " \"swiglu\": True,\n", 120 | " \"seq_len\": 512,\n", 121 | "}\n", 122 | "tiny[\"ffw_size\"] = multiple_of * (\n", 123 | " (tiny[\"ffw_size\"] + multiple_of - 1) // multiple_of\n", 124 | ")\n", 125 | "mini = {\n", 126 | " \"d_model\": 512,\n", 127 | " \"key_size\": 64,\n", 128 | " \"num_heads\": 8,\n", 129 | " \"num_kv_heads\": 8,\n", 130 | " \"ffw_size\": int(8 / 3 * 512),\n", 131 | " \"n_layers\": 10,\n", 132 | " \"vocab_size\": 50257,\n", 133 | " \"swiglu\": True,\n", 134 | " \"seq_len\": 512,\n", 135 | "}\n", 136 | "mini[\"ffw_size\"] = multiple_of * (\n", 137 | " (mini[\"ffw_size\"] + multiple_of - 1) // multiple_of\n", 138 | ")\n", 139 | "\n", 140 | "small = {\n", 141 | " \"d_model\": 768,\n", 142 | " \"key_size\": 64,\n", 143 | " \"num_heads\": 12,\n", 144 | " \"num_kv_heads\": 12,\n", 145 | " \"ffw_size\": int(8 / 3 * 768),\n", 146 | " \"n_layers\": 12,\n", 147 | " \"vocab_size\": 50257,\n", 148 | " \"swiglu\": True,\n", 149 | " \"seq_len\": 512,\n", 150 | "}\n", 151 | "small[\"ffw_size\"] = multiple_of * (\n", 152 | " (small[\"ffw_size\"] + multiple_of - 1) // multiple_of\n", 153 | ")\n", 154 | "\n", 155 | "\n", 156 | "_210M = {\n", 157 | " \"d_model\": 768,\n", 158 | " \"key_size\": 64,\n", 159 | " \"num_heads\": 12,\n", 160 | " \"num_kv_heads\": 12,\n", 161 | " \"ffw_size\": int(8 / 3 * 768),\n", 162 | " \"n_layers\": 24,\n", 163 | " \"vocab_size\": 50257,\n", 164 | " \"swiglu\": True,\n", 165 | " \"seq_len\": 512,\n", 166 | "}\n", 167 | "_210M[\"ffw_size\"] = multiple_of * (\n", 168 | " (_210M[\"ffw_size\"] + multiple_of - 1) // multiple_of\n", 169 | ")\n" 170 | ] 171 | 
}, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 3, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "all_flops = []\n", 179 | "all_params = []" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 4, 185 | "metadata": {}, 186 | "outputs": [ 187 | { 188 | "name": "stdout", 189 | "output_type": "stream", 190 | "text": [ 191 | "33.454464\n", 192 | "171834605568\n", 193 | "iters [3.0, 7.0, 10.0]\n", 194 | "tokens [0.3, 0.7, 1.0]\n", 195 | "ratio [9.2, 21.4, 30.6]\n", 196 | "flops [0.1031007633408, 0.2405684477952, 0.343669211136]\n", 197 | "flop savings 0.6\n" 198 | ] 199 | } 200 | ], 201 | "source": [ 202 | "model = tiny\n", 203 | "\n", 204 | "n_layers = model[\"n_layers\"]\n", 205 | "d_model = model[\"d_model\"]\n", 206 | "key_size = model[\"key_size\"]\n", 207 | "num_heads = model[\"num_heads\"]\n", 208 | "num_kv_heads = model[\"num_kv_heads\"]\n", 209 | "ffw_size = model[\"ffw_size\"]\n", 210 | "vocab_size = model[\"vocab_size\"]\n", 211 | "swiglu = model[\"swiglu\"]\n", 212 | "n_experts = 8\n", 213 | "top_k = 2\n", 214 | "seq_len = model[\"seq_len\"]\n", 215 | "\n", 216 | "\n", 217 | "flops = 3 * get_flops(\n", 218 | " n_layers,\n", 219 | " seq_len,\n", 220 | " vocab_size,\n", 221 | " d_model,\n", 222 | " key_size,\n", 223 | " num_heads=num_heads,\n", 224 | " ffw_size=ffw_size,\n", 225 | " swiglu=swiglu,\n", 226 | ")\n", 227 | "params = parameter_count(\n", 228 | " vocab_size=vocab_size,\n", 229 | " n_layers=n_layers,\n", 230 | " d_model=d_model,\n", 231 | " key_size=key_size,\n", 232 | " num_heads=num_heads,\n", 233 | " num_kv_heads=num_kv_heads,\n", 234 | " ffw_size=ffw_size,\n", 235 | " swiglu_or_geglu=swiglu,\n", 236 | ")\n", 237 | "# lr 0.002 for cos, 0.001 for wsd\n", 238 | "print(params / 1e6)\n", 239 | "print(flops)\n", 240 | "iters = [2400 / 0.8, 5600 / 0.8, 8000 / 0.8]#, 9600 / 0.8]\n", 241 | "print(\"iters\", [float(f\"{i / 1e3:.1f}\") for i in iters])\n", 242 | "print(\"tokens\", [float(f\"{200 * 512 * i / 1e9:.1f}\") for i in iters])\n", 243 | "print(\"ratio\", [float(f\"{200 * 512 * i / params:.1f}\") for i in iters])\n", 244 | "flops_all = [flops * 200 * i / 1e18 for i in iters]\n", 245 | "print(\"flops\", flops_all)\n", 246 | "all_flops.append(flops_all)\n", 247 | "all_params.append(params)\n", 248 | "print(\"flop savings\", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 5, 254 | "metadata": {}, 255 | "outputs": [ 256 | { 257 | "name": "stdout", 258 | "output_type": "stream", 259 | "text": [ 260 | "52.99456\n", 261 | "254882611200\n", 262 | "iters [3.8, 7.5, 11.2]\n", 263 | "tokens [0.4, 0.8, 1.2]\n", 264 | "ratio [7.2, 14.5, 21.7]\n", 265 | "flops [0.1911619584, 0.3823239168, 0.5734858752]\n", 266 | "flop savings 0.6\n" 267 | ] 268 | } 269 | ], 270 | "source": [ 271 | "tiny2 = {\n", 272 | " \"d_model\": 512,\n", 273 | " \"key_size\": 64,\n", 274 | " \"num_heads\": 8,\n", 275 | " \"num_kv_heads\": 8,\n", 276 | " \"ffw_size\": int(8 / 3 * 512),\n", 277 | " \"n_layers\": 8,\n", 278 | " \"vocab_size\": 50257,\n", 279 | " \"swiglu\": True,\n", 280 | " \"seq_len\": 512,\n", 281 | "}\n", 282 | "tiny2[\"ffw_size\"] = multiple_of * (\n", 283 | " (tiny2[\"ffw_size\"] + multiple_of - 1) // multiple_of\n", 284 | ")\n", 285 | "\n", 286 | "model = tiny2\n", 287 | "n_layers = model[\"n_layers\"]\n", 288 | "d_model = model[\"d_model\"]\n", 289 | "key_size = model[\"key_size\"]\n", 290 | "num_heads = model[\"num_heads\"]\n", 291 | 
"num_kv_heads = model[\"num_kv_heads\"]\n", 292 | "ffw_size = model[\"ffw_size\"]\n", 293 | "vocab_size = model[\"vocab_size\"]\n", 294 | "swiglu = model[\"swiglu\"]\n", 295 | "n_experts = 8\n", 296 | "top_k = 2\n", 297 | "seq_len = model[\"seq_len\"]\n", 298 | "\n", 299 | "\n", 300 | "flops = 3 * get_flops(\n", 301 | " n_layers,\n", 302 | " seq_len,\n", 303 | " vocab_size,\n", 304 | " d_model,\n", 305 | " key_size,\n", 306 | " num_heads=num_heads,\n", 307 | " ffw_size=ffw_size,\n", 308 | " swiglu=swiglu,\n", 309 | ")\n", 310 | "params = parameter_count(\n", 311 | " vocab_size=vocab_size,\n", 312 | " n_layers=n_layers,\n", 313 | " d_model=d_model,\n", 314 | " key_size=key_size,\n", 315 | " num_heads=num_heads,\n", 316 | " num_kv_heads=num_kv_heads,\n", 317 | " ffw_size=ffw_size,\n", 318 | " swiglu_or_geglu=swiglu,\n", 319 | ")\n", 320 | "# lr 0.002 for cos, 0.001 for wsd\n", 321 | "print(params / 1e6)\n", 322 | "print(flops)\n", 323 | "iters = [3000 / 0.8, 6000 / 0.8, 9000 / 0.8]# 12000 / 0.8]\n", 324 | "print(\"iters\", [float(f\"{i / 1e3:.1f}\") for i in iters])\n", 325 | "print(\"tokens\", [float(f\"{200 * 512 * i / 1e9:.1f}\") for i in iters])\n", 326 | "print(\"ratio\", [float(f\"{200 * 512 * i / params:.1f}\") for i in iters])\n", 327 | "flops_all = [flops * 200 * i / 1e18 for i in iters]\n", 328 | "print(\"flops\", flops_all)\n", 329 | "all_flops.append(flops_all)\n", 330 | "all_params.append(params)\n", 331 | "print(\"flop savings\", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 6, 337 | "metadata": {}, 338 | "outputs": [ 339 | { 340 | "name": "stdout", 341 | "output_type": "stream", 342 | "text": [ 343 | "59.810304\n", 344 | "279079550976\n", 345 | "iters [7.5, 12.5, 17.5]\n", 346 | "tokens [0.8, 1.3, 1.8]\n", 347 | "ratio [12.8, 21.4, 30.0]\n", 348 | "flops [0.418619326464, 0.69769887744, 0.976778428416]\n", 349 | "flop savings 0.5733333333333334\n" 350 | ] 351 | } 352 | ], 353 | "source": [ 354 | "model = mini\n", 355 | "\n", 356 | "n_layers = model[\"n_layers\"]\n", 357 | "d_model = model[\"d_model\"]\n", 358 | "key_size = model[\"key_size\"]\n", 359 | "num_heads = model[\"num_heads\"]\n", 360 | "num_kv_heads = model[\"num_kv_heads\"]\n", 361 | "ffw_size = model[\"ffw_size\"]\n", 362 | "vocab_size = model[\"vocab_size\"]\n", 363 | "swiglu = model[\"swiglu\"]\n", 364 | "n_experts = 8\n", 365 | "top_k = 2\n", 366 | "seq_len = model[\"seq_len\"]\n", 367 | "\n", 368 | "\n", 369 | "flops = 3 * get_flops(\n", 370 | " n_layers,\n", 371 | " seq_len,\n", 372 | " vocab_size,\n", 373 | " d_model,\n", 374 | " key_size,\n", 375 | " num_heads=num_heads,\n", 376 | " ffw_size=ffw_size,\n", 377 | " swiglu=swiglu,\n", 378 | ")\n", 379 | "params = parameter_count(\n", 380 | " vocab_size=vocab_size,\n", 381 | " n_layers=n_layers,\n", 382 | " d_model=d_model,\n", 383 | " key_size=key_size,\n", 384 | " num_heads=num_heads,\n", 385 | " num_kv_heads=num_kv_heads,\n", 386 | " ffw_size=ffw_size,\n", 387 | " swiglu_or_geglu=swiglu,\n", 388 | ")\n", 389 | "\n", 390 | "print(params / 1e6)\n", 391 | "print(flops)\n", 392 | "iters = [6000 / 0.8, 10000 / 0.8, 14000 / 0.8]# 18000 / 0.8]\n", 393 | "print(\"iters\", [float(f\"{i / 1e3:.1f}\") for i in iters])\n", 394 | "print(\"tokens\", [float(f\"{200 * 512 * i / 1e9:.1f}\") for i in iters])\n", 395 | "print(\"ratio\", [float(f\"{200 * 512 * i / params:.1f}\") for i in iters])\n", 396 | "flops_all = [flops * 200 * i / 1e18 for i in iters]\n", 397 | "print(\"flops\", 
flops_all)\n", 398 | "all_flops.append(flops_all)\n", 399 | "all_params.append(params)\n", 400 | "print(\"flop savings\", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))\n", 401 | "# lr 0.002 for cos, 0.001 for wsd\n" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": 7, 407 | "metadata": {}, 408 | "outputs": [ 409 | { 410 | "name": "stdout", 411 | "output_type": "stream", 412 | "text": [ 413 | "93.11296\n", 414 | "409294602240\n", 415 | "iters [10.0, 17.5, 25.0]\n", 416 | "tokens [1.0, 1.8, 2.6]\n", 417 | "ratio [11.0, 19.2, 27.5]\n", 418 | "flops [0.81858920448, 1.43253110784, 2.0464730112]\n", 419 | "flop savings 0.5809523809523809\n" 420 | ] 421 | } 422 | ], 423 | "source": [ 424 | "mini2 = {\n", 425 | " \"d_model\": 640,\n", 426 | " \"key_size\": 64,\n", 427 | " \"num_heads\": 10,\n", 428 | " \"num_kv_heads\": 10,\n", 429 | " \"ffw_size\": int(8 / 3 * 640),\n", 430 | " \"n_layers\": 12,\n", 431 | " \"vocab_size\": 50257,\n", 432 | " \"swiglu\": True,\n", 433 | " \"seq_len\": 512,\n", 434 | "}\n", 435 | "mini2[\"ffw_size\"] = multiple_of * (\n", 436 | " (mini2[\"ffw_size\"] + multiple_of - 1) // multiple_of\n", 437 | ")\n", 438 | "\n", 439 | "\n", 440 | "model = mini2\n", 441 | "\n", 442 | "n_layers = model[\"n_layers\"]\n", 443 | "d_model = model[\"d_model\"]\n", 444 | "key_size = model[\"key_size\"]\n", 445 | "num_heads = model[\"num_heads\"]\n", 446 | "num_kv_heads = model[\"num_kv_heads\"]\n", 447 | "ffw_size = model[\"ffw_size\"]\n", 448 | "vocab_size = model[\"vocab_size\"]\n", 449 | "swiglu = model[\"swiglu\"]\n", 450 | "n_experts = 8\n", 451 | "top_k = 2\n", 452 | "seq_len = model[\"seq_len\"]\n", 453 | "\n", 454 | "\n", 455 | "flops = 3 * get_flops(\n", 456 | " n_layers,\n", 457 | " seq_len,\n", 458 | " vocab_size,\n", 459 | " d_model,\n", 460 | " key_size,\n", 461 | " num_heads=num_heads,\n", 462 | " ffw_size=ffw_size,\n", 463 | " swiglu=swiglu,\n", 464 | ")\n", 465 | "params = parameter_count(\n", 466 | " vocab_size=vocab_size,\n", 467 | " n_layers=n_layers,\n", 468 | " d_model=d_model,\n", 469 | " key_size=key_size,\n", 470 | " num_heads=num_heads,\n", 471 | " num_kv_heads=num_kv_heads,\n", 472 | " ffw_size=ffw_size,\n", 473 | " swiglu_or_geglu=swiglu,\n", 474 | ")\n", 475 | "\n", 476 | "print(params / 1e6)\n", 477 | "print(flops)\n", 478 | "iters = [8000 / 0.8, 14000 / 0.8, 20000 / 0.8]# 26000 / 0.8]\n", 479 | "print(\"iters\", [float(f\"{i / 1e3:.1f}\") for i in iters])\n", 480 | "print(\"tokens\", [float(f\"{200 * 512 * i / 1e9:.1f}\") for i in iters])\n", 481 | "print(\"ratio\", [float(f\"{200 * 512 * i / params:.1f}\") for i in iters])\n", 482 | "flops_all = [flops * 200 * i / 1e18 for i in iters]\n", 483 | "print(\"flops\", flops_all)\n", 484 | "all_flops.append(flops_all)\n", 485 | "all_params.append(params)\n", 486 | "print(\"flop savings\", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))\n", 487 | "# lr 0.002 for cos, 0.001 for wsd" 488 | ] 489 | }, 490 | { 491 | "cell_type": "code", 492 | "execution_count": 8, 493 | "metadata": {}, 494 | "outputs": [ 495 | { 496 | "name": "stdout", 497 | "output_type": "stream", 498 | "text": [ 499 | "123.532032\n", 500 | "527392309248\n", 501 | "iters [15.0, 25.0, 35.0]\n", 502 | "tokens [1.5, 2.6, 3.6]\n", 503 | "ratio [12.4, 20.7, 29.0]\n", 504 | "flops [1.582176927744, 2.63696154624, 3.691746164736]\n", 505 | "flop savings 0.5733333333333334\n" 506 | ] 507 | } 508 | ], 509 | "source": [ 510 | "model = small\n", 511 | "\n", 512 | "n_layers = model[\"n_layers\"]\n", 513 | 
"d_model = model[\"d_model\"]\n", 514 | "key_size = model[\"key_size\"]\n", 515 | "num_heads = model[\"num_heads\"]\n", 516 | "num_kv_heads = model[\"num_kv_heads\"]\n", 517 | "ffw_size = model[\"ffw_size\"]\n", 518 | "vocab_size = model[\"vocab_size\"]\n", 519 | "swiglu = model[\"swiglu\"]\n", 520 | "n_experts = 8\n", 521 | "top_k = 2\n", 522 | "seq_len = model[\"seq_len\"]\n", 523 | "\n", 524 | "\n", 525 | "flops = 3 * get_flops(\n", 526 | " n_layers,\n", 527 | " seq_len,\n", 528 | " vocab_size,\n", 529 | " d_model,\n", 530 | " key_size,\n", 531 | " num_heads=num_heads,\n", 532 | " ffw_size=ffw_size,\n", 533 | " swiglu=swiglu,\n", 534 | ")\n", 535 | "params = parameter_count(\n", 536 | " vocab_size=vocab_size,\n", 537 | " n_layers=n_layers,\n", 538 | " d_model=d_model,\n", 539 | " key_size=key_size,\n", 540 | " num_heads=num_heads,\n", 541 | " num_kv_heads=num_kv_heads,\n", 542 | " ffw_size=ffw_size,\n", 543 | " swiglu_or_geglu=swiglu,\n", 544 | ")\n", 545 | "print(params / 1e6)\n", 546 | "print(flops)\n", 547 | "iters = [12000 / 0.8, 20000 / 0.8, 28000 / 0.8]# 36000 / 0.8]\n", 548 | "print(\"iters\", [float(f\"{i / 1e3:.1f}\") for i in iters])\n", 549 | "print(\"tokens\", [float(f\"{200 * 512 * i / 1e9:.1f}\") for i in iters])\n", 550 | "print(\"ratio\", [float(f\"{200 * 512 * i / params:.1f}\") for i in iters])\n", 551 | "flops_all = [flops * 200 * i / 1e18 for i in iters]\n", 552 | "print(\"flops\", flops_all)\n", 553 | "all_flops.append(flops_all)\n", 554 | "all_params.append(params)\n", 555 | "print(\"flop savings\", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))\n", 556 | "# lr 0.001" 557 | ] 558 | }, 559 | { 560 | "cell_type": "code", 561 | "execution_count": 9, 562 | "metadata": {}, 563 | "outputs": [ 564 | { 565 | "name": "stdout", 566 | "output_type": "stream", 567 | "text": [ 568 | "151.843584\n", 569 | "624142319616\n", 570 | "iters [25.0, 37.5, 50.0]\n", 571 | "tokens [2.6, 3.8, 5.1]\n", 572 | "ratio [16.9, 25.3, 33.7]\n", 573 | "flops [3.12071159808, 4.68106739712, 6.24142319616]\n", 574 | "flop savings 0.5555555555555556\n" 575 | ] 576 | } 577 | ], 578 | "source": [ 579 | "_151M = {\n", 580 | " \"d_model\": 768,\n", 581 | " \"key_size\": 64,\n", 582 | " \"num_heads\": 12,\n", 583 | " \"num_kv_heads\": 12,\n", 584 | " \"ffw_size\": int(8 / 3 * 768),\n", 585 | " \"n_layers\": 16,\n", 586 | " \"vocab_size\": 50257,\n", 587 | " \"swiglu\": True,\n", 588 | " \"seq_len\": 512,\n", 589 | "}\n", 590 | "\n", 591 | "_151M[\"ffw_size\"] = multiple_of * (\n", 592 | " (_151M[\"ffw_size\"] + multiple_of - 1) // multiple_of\n", 593 | ")\n", 594 | "\n", 595 | "model = _151M\n", 596 | "\n", 597 | "n_layers = model[\"n_layers\"]\n", 598 | "d_model = model[\"d_model\"]\n", 599 | "key_size = model[\"key_size\"]\n", 600 | "num_heads = model[\"num_heads\"]\n", 601 | "num_kv_heads = model[\"num_kv_heads\"]\n", 602 | "ffw_size = model[\"ffw_size\"]\n", 603 | "vocab_size = model[\"vocab_size\"]\n", 604 | "swiglu = model[\"swiglu\"]\n", 605 | "n_experts = 8\n", 606 | "top_k = 2\n", 607 | "seq_len = model[\"seq_len\"]\n", 608 | "\n", 609 | "\n", 610 | "flops = 3 * get_flops(\n", 611 | " n_layers,\n", 612 | " seq_len,\n", 613 | " vocab_size,\n", 614 | " d_model,\n", 615 | " key_size,\n", 616 | " num_heads=num_heads,\n", 617 | " ffw_size=ffw_size,\n", 618 | " swiglu=swiglu,\n", 619 | ")\n", 620 | "params = parameter_count(\n", 621 | " vocab_size=vocab_size,\n", 622 | " n_layers=n_layers,\n", 623 | " d_model=d_model,\n", 624 | " key_size=key_size,\n", 625 | " num_heads=num_heads,\n", 626 
| " num_kv_heads=num_kv_heads,\n", 627 | " ffw_size=ffw_size,\n", 628 | " swiglu_or_geglu=swiglu,\n", 629 | ")\n", 630 | "print(params / 1e6)\n", 631 | "print(flops)\n", 632 | "iters = [20000 / 0.8, 30000 / 0.8, 40000 / 0.8]\n", 633 | "print(\"iters\", [float(f\"{i / 1e3:.1f}\") for i in iters])\n", 634 | "print(\"tokens\", [float(f\"{200 * 512 * i / 1e9:.1f}\") for i in iters])\n", 635 | "print(\"ratio\", [float(f\"{200 * 512 * i / params:.1f}\") for i in iters])\n", 636 | "flops_all = [flops * 200 * i / 1e18 for i in iters]\n", 637 | "print(\"flops\", flops_all)\n", 638 | "all_flops.append(flops_all)\n", 639 | "all_params.append(params)\n", 640 | "print(\"flop savings\", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))\n", 641 | "# double batch size?" 642 | ] 643 | }, 644 | { 645 | "cell_type": "code", 646 | "execution_count": 10, 647 | "metadata": {}, 648 | "outputs": [ 649 | { 650 | "name": "stdout", 651 | "output_type": "stream", 652 | "text": [ 653 | "166.1408\n", 654 | "682936762368\n", 655 | "iters [25.0, 37.5, 50.0]\n", 656 | "tokens [2.6, 3.8, 5.1]\n", 657 | "ratio [15.4, 23.1, 30.8]\n", 658 | "flops [3.41468381184, 5.12202571776, 6.82936762368]\n", 659 | "flop savings 0.5555555555555556\n" 660 | ] 661 | } 662 | ], 663 | "source": [ 664 | "_166M = {\n", 665 | " \"d_model\": 896,\n", 666 | " \"key_size\": 64,\n", 667 | " \"num_heads\": 14,\n", 668 | " \"num_kv_heads\": 14,\n", 669 | " \"ffw_size\": int(8 / 3 * 896),\n", 670 | " \"n_layers\": 12,\n", 671 | " \"vocab_size\": 50257,\n", 672 | " \"swiglu\": True,\n", 673 | " \"seq_len\": 512,\n", 674 | "}\n", 675 | "\n", 676 | "_166M[\"ffw_size\"] = multiple_of * (\n", 677 | " (_166M[\"ffw_size\"] + multiple_of - 1) // multiple_of\n", 678 | ")\n", 679 | "\n", 680 | "model = _166M\n", 681 | "\n", 682 | "n_layers = model[\"n_layers\"]\n", 683 | "d_model = model[\"d_model\"]\n", 684 | "key_size = model[\"key_size\"]\n", 685 | "num_heads = model[\"num_heads\"]\n", 686 | "num_kv_heads = model[\"num_kv_heads\"]\n", 687 | "ffw_size = model[\"ffw_size\"]\n", 688 | "vocab_size = model[\"vocab_size\"]\n", 689 | "swiglu = model[\"swiglu\"]\n", 690 | "n_experts = 8\n", 691 | "top_k = 2\n", 692 | "seq_len = model[\"seq_len\"]\n", 693 | "\n", 694 | "\n", 695 | "flops = 3 * get_flops(\n", 696 | " n_layers,\n", 697 | " seq_len,\n", 698 | " vocab_size,\n", 699 | " d_model,\n", 700 | " key_size,\n", 701 | " num_heads=num_heads,\n", 702 | " ffw_size=ffw_size,\n", 703 | " swiglu=swiglu,\n", 704 | ")\n", 705 | "params = parameter_count(\n", 706 | " vocab_size=vocab_size,\n", 707 | " n_layers=n_layers,\n", 708 | " d_model=d_model,\n", 709 | " key_size=key_size,\n", 710 | " num_heads=num_heads,\n", 711 | " num_kv_heads=num_kv_heads,\n", 712 | " ffw_size=ffw_size,\n", 713 | " swiglu_or_geglu=swiglu,\n", 714 | ")\n", 715 | "print(params / 1e6)\n", 716 | "print(flops)\n", 717 | "iters = [20000 / 0.8, 30000 / 0.8, 40000 / 0.8]\n", 718 | "print(\"iters\", [float(f\"{i / 1e3:.1f}\") for i in iters])\n", 719 | "print(\"tokens\", [float(f\"{200 * 512 * i / 1e9:.1f}\") for i in iters])\n", 720 | "print(\"ratio\", [float(f\"{200 * 512 * i / params:.1f}\") for i in iters])\n", 721 | "flops_all = [flops * 200 * i / 1e18 for i in iters]\n", 722 | "print(\"flops\", flops_all)\n", 723 | "all_flops.append(flops_all)\n", 724 | "all_params.append(params)\n", 725 | "print(\"flop savings\", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))\n", 726 | "# double batch size?" 
727 | ] 728 | }, 729 | { 730 | "cell_type": "code", 731 | "execution_count": 11, 732 | "metadata": {}, 733 | "outputs": [ 734 | { 735 | "name": "stdout", 736 | "output_type": "stream", 737 | "text": [ 738 | "208.466688\n", 739 | "817642340352\n", 740 | "iters [37.5, 50.0, 62.5]\n", 741 | "tokens [3.8, 5.1, 6.4]\n", 742 | "ratio [18.4, 24.6, 30.7]\n", 743 | "flops [6.13231755264, 8.17642340352, 10.2205292544]\n", 744 | "flop savings 0.5333333333333333\n" 745 | ] 746 | } 747 | ], 748 | "source": [ 749 | "model = _210M\n", 750 | "\n", 751 | "n_layers = model[\"n_layers\"]\n", 752 | "d_model = model[\"d_model\"]\n", 753 | "key_size = model[\"key_size\"]\n", 754 | "num_heads = model[\"num_heads\"]\n", 755 | "num_kv_heads = model[\"num_kv_heads\"]\n", 756 | "ffw_size = model[\"ffw_size\"]\n", 757 | "vocab_size = model[\"vocab_size\"]\n", 758 | "swiglu = model[\"swiglu\"]\n", 759 | "n_experts = 8\n", 760 | "top_k = 2\n", 761 | "seq_len = model[\"seq_len\"]\n", 762 | "\n", 763 | "\n", 764 | "flops = 3 * get_flops(\n", 765 | " n_layers,\n", 766 | " seq_len,\n", 767 | " vocab_size,\n", 768 | " d_model,\n", 769 | " key_size,\n", 770 | " num_heads=num_heads,\n", 771 | " ffw_size=ffw_size,\n", 772 | " swiglu=swiglu,\n", 773 | ")\n", 774 | "params = parameter_count(\n", 775 | " vocab_size=vocab_size,\n", 776 | " n_layers=n_layers,\n", 777 | " d_model=d_model,\n", 778 | " key_size=key_size,\n", 779 | " num_heads=num_heads,\n", 780 | " num_kv_heads=num_kv_heads,\n", 781 | " ffw_size=ffw_size,\n", 782 | " swiglu_or_geglu=swiglu,\n", 783 | ")\n", 784 | "print(params / 1e6)\n", 785 | "print(flops)\n", 786 | "# iters = [22222, 44444, 66666]\n", 787 | "iters = [30000 / 0.8, 40000 / 0.8, 50000 / 0.8]\n", 788 | "print(\"iters\", [float(f\"{i / 1e3:.1f}\") for i in iters])\n", 789 | "print(\"tokens\", [float(f\"{200 * 512 * i / 1e9:.1f}\") for i in iters])\n", 790 | "print(\"ratio\", [float(f\"{200 * 512 * i / params:.1f}\") for i in iters])\n", 791 | "flops_all = [flops * 200 * i / 1e18 for i in iters]\n", 792 | "print(\"flops\", flops_all)\n", 793 | "all_flops.append(flops_all)\n", 794 | "all_params.append(params)\n", 795 | "print(\"flop savings\", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))" 796 | ] 797 | }, 798 | { 799 | "cell_type": "code", 800 | "execution_count": 12, 801 | "metadata": {}, 802 | "outputs": [ 803 | { 804 | "name": "stdout", 805 | "output_type": "stream", 806 | "text": [ 807 | "359.744512\n", 808 | "1341445373952\n", 809 | "iters [25.0, 37.5, 50.0]\n", 810 | "tokens [5.1, 7.7, 10.2]\n", 811 | "ratio [14.2, 21.3, 28.5]\n", 812 | "flops [6.70722686976, 10.06084030464, 13.41445373952]\n", 813 | "flop savings 0.5555555555555556\n" 814 | ] 815 | } 816 | ], 817 | "source": [ 818 | "_350M = {\n", 819 | " \"d_model\": 1024,\n", 820 | " \"key_size\": 64,\n", 821 | " \"num_heads\": 16,\n", 822 | " \"num_kv_heads\": 16,\n", 823 | " \"ffw_size\": int(8 / 3 * 1024),\n", 824 | " \"n_layers\": 24,\n", 825 | " \"vocab_size\": 50257,\n", 826 | " \"swiglu\": True,\n", 827 | " \"seq_len\": 512,\n", 828 | "}\n", 829 | "\n", 830 | "_350M[\"ffw_size\"] = multiple_of * (\n", 831 | " (_350M[\"ffw_size\"] + multiple_of - 1) // multiple_of\n", 832 | ")\n", 833 | "\n", 834 | "model = _350M\n", 835 | "\n", 836 | "n_layers = model[\"n_layers\"]\n", 837 | "d_model = model[\"d_model\"]\n", 838 | "key_size = model[\"key_size\"]\n", 839 | "num_heads = model[\"num_heads\"]\n", 840 | "num_kv_heads = model[\"num_kv_heads\"]\n", 841 | "ffw_size = model[\"ffw_size\"]\n", 842 | "vocab_size = 
model[\"vocab_size\"]\n", 843 | "swiglu = model[\"swiglu\"]\n", 844 | "n_experts = 8\n", 845 | "top_k = 2\n", 846 | "seq_len = model[\"seq_len\"]\n", 847 | "\n", 848 | "\n", 849 | "flops = 3 * get_flops(\n", 850 | " n_layers,\n", 851 | " seq_len,\n", 852 | " vocab_size,\n", 853 | " d_model,\n", 854 | " key_size,\n", 855 | " num_heads=num_heads,\n", 856 | " ffw_size=ffw_size,\n", 857 | " swiglu=swiglu,\n", 858 | ")\n", 859 | "params = parameter_count(\n", 860 | " vocab_size=vocab_size,\n", 861 | " n_layers=n_layers,\n", 862 | " d_model=d_model,\n", 863 | " key_size=key_size,\n", 864 | " num_heads=num_heads,\n", 865 | " num_kv_heads=num_kv_heads,\n", 866 | " ffw_size=ffw_size,\n", 867 | " swiglu_or_geglu=swiglu,\n", 868 | ")\n", 869 | "print(params / 1e6)\n", 870 | "print(flops)\n", 871 | "iters = [20000 / 0.8, 30000 / 0.8, 40000 / 0.8]\n", 872 | "print(\"iters\", [float(f\"{i / 1e3:.1f}\") for i in iters])\n", 873 | "print(\"tokens\", [float(f\"{400 * 512 * i / 1e9:.1f}\") for i in iters])\n", 874 | "print(\"ratio\", [float(f\"{400 * 512 * i / params:.1f}\") for i in iters])\n", 875 | "flops_all = [flops * 200 * i / 1e18 for i in iters]\n", 876 | "print(\"flops\", flops_all)\n", 877 | "all_flops.append(flops_all)\n", 878 | "all_params.append(params)\n", 879 | "print(\"flop savings\", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))\n", 880 | "# double batch size?" 881 | ] 882 | } 883 | ], 884 | "metadata": { 885 | "kernelspec": { 886 | "display_name": "llm-baselines", 887 | "language": "python", 888 | "name": "python3" 889 | }, 890 | "language_info": { 891 | "codemirror_mode": { 892 | "name": "ipython", 893 | "version": 3 894 | }, 895 | "file_extension": ".py", 896 | "mimetype": "text/x-python", 897 | "name": "python", 898 | "nbconvert_exporter": "python", 899 | "pygments_lexer": "ipython3", 900 | "version": "3.10.14" 901 | } 902 | }, 903 | "nbformat": 4, 904 | "nbformat_minor": 2 905 | } 906 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tiktoken 2 | --find-links https://download.pytorch.org/whl/torch_stable.html 3 | torch==2.2.0+cu118 4 | torchaudio==2.2.0+cu118 5 | torchvision==0.17.0+cu118 6 | tqdm==4.65.0 7 | schedulefree 8 | transformers 9 | wandb 10 | datasets 11 | zstandard -------------------------------------------------------------------------------- /src/config/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import base 2 | 3 | CONFIG_FORMAT_TO_MODULE_MAP = { 4 | "base": base, 5 | } 6 | 7 | 8 | def parse_args_with_format(format, base_parser, args, namespace): 9 | return CONFIG_FORMAT_TO_MODULE_MAP[format].parse_args(base_parser, args, namespace) 10 | 11 | 12 | def registered_formats(): 13 | return CONFIG_FORMAT_TO_MODULE_MAP.keys() 14 | -------------------------------------------------------------------------------- /src/config/base.py: -------------------------------------------------------------------------------- 1 | import distributed 2 | 3 | 4 | def parse_args(base_parser, args, namespace): 5 | parser = base_parser 6 | # General training params 7 | parser.add_argument("--experiment-name", default=None, type=str) 8 | parser.add_argument("--seed", default=0, type=int) 9 | parser.add_argument("--data-seed", default=1337, type=int) 10 | parser.add_argument("--eval-interval", default=200, type=int) 11 | parser.add_argument("--full-eval-at", nargs="+", type=int) 12 | parser.add_argument("--eval-batches", default=32, type=int) 13 | parser.add_argument("--device", default="cuda:0", type=str) 14 | parser.add_argument( 15 | "--distributed-backend", 16 | default=None, 17 | type=str, 18 | required=False, 19 | choices=distributed.registered_backends(), 20 | ) 21 | parser.add_argument("--log-interval", default=50, type=int) 22 | 23 | # Checkpointing 24 | parser.add_argument("--results-base-folder", default="./exps", type=str) 25 | parser.add_argument("--permanent-ckpt-interval", default=0, type=int) 26 | parser.add_argument("--latest-ckpt-interval", default=0, type=int) 27 | parser.add_argument("--resume-from", default=None, type=str) 28 | parser.add_argument("--resume-from-swa", default=None, type=str) 29 | 30 | parser.add_argument("--auto-resume", default=True) 31 | 32 | # logging params (WandB) 33 | parser.add_argument("--wandb", action="store_true") # whether to use wandb or not 34 | parser.add_argument("--wandb-project", default="my-project", type=str) 35 | parser.add_argument( 36 | "--wandb-run-prefix", default="none", type=str 37 | ) # is added before the autogenerated experiment name 38 | parser.add_argument( 39 | "--eval-seq-prefix", default="none", type=str 40 | ) # prefix used to generate sequences 41 | parser.add_argument("--log-dynamics", action="store_true") 42 | parser.add_argument( 43 | "--dynamics-logger-cfg", default="./src/logger/rotational_logger.yaml", type=str 44 | ) 45 | 46 | # Schedule 47 | parser.add_argument( 48 | "--scheduler", 49 | default="cos", 50 | choices=["linear", "cos", "wsd", "none", "cos_inf"], 51 | ) 52 | parser.add_argument("--cos-inf-steps", default=0, type=int) 53 | # parser.add_argument("--cos-final-lr", default=1e-6, type=float) 54 | parser.add_argument("--iterations", default=15000, type=int) 55 | parser.add_argument("--warmup-steps", default=300, type=int) 56 | parser.add_argument("--lr", default=1e-3, type=float) 57 | # wsd 58 | parser.add_argument("--wsd-final-lr-scale", default=0.0, type=float) 59 | parser.add_argument("--wsd-fract-decay", default=0.1, type=float) 60 | # parser.add_argument("--wsd-exponential-decay", action="store_true") 61 | parser.add_argument( 62 | "--decay-type", 63 | default="linear", 64 | choices=["linear", "cosine", "exp", "miror_cosine", "square", "sqrt"], 65 | ) 66 | # Optimization 67 | parser.add_argument("--opt", default="adamw", choices=["adamw", "sgd", "SFAdamW"]) 68 | parser.add_argument("--batch-size", default=50, type=int) 69 | parser.add_argument("--acc-steps", default=4, type=int) 70 | 
parser.add_argument("--weight-decay", default=1e-1, type=float) 71 | parser.add_argument("--beta1", default=0.9, type=float) 72 | parser.add_argument("--beta2", default=0.95, type=float) 73 | parser.add_argument( 74 | "--grad-clip", default=1.0, type=float 75 | ) # default value is 1.0 in NanoGPT 76 | 77 | # Weight Averaging 78 | parser.add_argument("--weight-average", action="store_true") 79 | parser.add_argument( 80 | "--wa-interval", 81 | default=5, 82 | type=int, 83 | help="How often to take the average (every k steps). Must divide wa-horizon.", 84 | ) 85 | parser.add_argument( 86 | "--wa-horizon", 87 | default=500, 88 | type=int, 89 | help="How frequently we save uniform model averages. Should divide " 90 | + "latest-ckpt-interval, otherwise some points may not be saved " 91 | + "correctly.", 92 | ) 93 | parser.add_argument( 94 | "--wa-dtype", 95 | default="float32", 96 | type=str, 97 | choices=["float32", "float64"], 98 | ) 99 | 100 | parser.add_argument("--wa-use-temp-dir", action="store_true") 101 | parser.add_argument("--wa-sweep-horizon", action="store_true") 102 | parser.add_argument("--max-num-wa-sweeps", default=5, type=int) 103 | 104 | parser.add_argument("--exponential-moving-average", action="store_true") 105 | parser.add_argument( 106 | "--ema-interval", 107 | default=10, 108 | type=int, 109 | help="How often to take the EMA average (every k steps).", 110 | ) 111 | parser.add_argument( 112 | "--ema-decay", 113 | default=0.95, 114 | type=float, 115 | help="EMA decay parameter (between 0.9 and 1).", 116 | ) 117 | parser.add_argument( 118 | "--ema-after-warmup", 119 | action="store_true", 120 | help="Start EMA after warmup steps.", 121 | ) 122 | 123 | # Dataset params 124 | parser.add_argument("--datasets-dir", type=str, default="./datasets/") 125 | parser.add_argument( 126 | "--dataset", 127 | default="slimpajama", 128 | choices=[ 129 | "wikitext", 130 | "shakespeare-char", 131 | "arxiv", 132 | "arxiv2000", 133 | "arxiv+wiki", 134 | "openwebtext2", 135 | "redpajama", 136 | "slimpajama", 137 | "slimpajama_chunk1", 138 | "redpajamav2", 139 | ], 140 | ) 141 | parser.add_argument( 142 | "--tokenizer", default="gpt2", type=str, choices=["gpt2", "mistral"] 143 | ) 144 | parser.add_argument("--vocab-size", default=50304, type=int) 145 | parser.add_argument( 146 | "--data-in-ram", action="store_true" 147 | ) # force the data to RAM, mostly useless except for openwebtext2 148 | 149 | # Model params 150 | parser.add_argument( 151 | "--model", 152 | default="llama", 153 | choices=[ 154 | "base", 155 | "llama", 156 | ], 157 | ) 158 | parser.add_argument("--parallel-block", action="store_true") 159 | parser.add_argument( 160 | "--use-pretrained", default="none", type=str 161 | ) # 'none', 'gpt-2' or a path to the pretraind model 162 | parser.add_argument("--from-dense", action="store_true") 163 | parser.add_argument("--init-std", default=0.02, type=float) 164 | parser.add_argument("--dropout", default=0.0, type=float) 165 | parser.add_argument("--n-head", default=12, type=int) 166 | parser.add_argument("--n-layer", default=24, type=int) # depths in att + ff blocks 167 | parser.add_argument("--sequence-length", default=512, type=int) 168 | parser.add_argument( 169 | "--n-embd", default=768, type=int # embedding size / hidden size ... 
170 | ) 171 | parser.add_argument( 172 | "--multiple-of", # make SwiGLU hidden layer size multiple of large power of 2 173 | default=256, 174 | type=int, 175 | ) 176 | parser.add_argument("--rmsnorm-eps", default=1e-5, type=float) 177 | parser.add_argument( 178 | "--dtype", 179 | default="bfloat16", 180 | type=str, 181 | choices=["float32", "float16", "bfloat16"], 182 | ) 183 | parser.add_argument("--bias", default=False, type=bool) 184 | parser.add_argument("--compile", action="store_true") 185 | parser.add_argument("--mlp-dim-exp-factor", default=1.0, type=float) 186 | return parser.parse_args(args, namespace) 187 | -------------------------------------------------------------------------------- /src/data/arxiv.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tarfile 3 | import logging 4 | from pathlib import Path 5 | from typing import Optional 6 | from multiprocessing import Pool 7 | from tempfile import NamedTemporaryFile 8 | from subprocess import Popen, TimeoutExpired, PIPE 9 | from typing import Tuple, List 10 | 11 | import numpy as np 12 | import requests 13 | from tqdm.auto import tqdm 14 | import tiktoken 15 | 16 | 17 | def convert_to_markdown(args: Tuple[Path, Path]): 18 | texfile, mdroot = args 19 | mdfile = mdroot / f"{texfile.name}.md" 20 | with Popen( 21 | ["pandoc", "--wrap=none", "--from", "latex", texfile, "--output", mdfile], 22 | stderr=PIPE, 23 | ) as proc: 24 | try: 25 | proc.communicate(timeout=1) 26 | except TimeoutExpired: 27 | proc.kill() 28 | 29 | 30 | def fetch_arxiv(root: Path, year: int): 31 | # download latex 32 | url = f"https://www.cs.cornell.edu/projects/kddcup/download/hep-th-{year}.tar.gz" 33 | texroot = root / "tex" 34 | print("Downloading Arxiv year", year) 35 | req = requests.get(url, timeout=60) 36 | with NamedTemporaryFile(suffix=".tar.gz") as f: 37 | f.write(req.content) 38 | logging.debug("Tar saved in tempfile %s" % f.name) 39 | with tarfile.open(f.name) as tar: 40 | logging.debug("Extracting tarfile") 41 | tar.extractall(texroot) 42 | 43 | # convert to markdown 44 | mdroot = root / "md" / str(year) 45 | mdroot.mkdir(parents=True) 46 | files = list((texroot / str(year)).iterdir()) 47 | with Pool(os.cpu_count()) as p: 48 | args = [(texfile, mdroot) for texfile in files] 49 | for _ in tqdm( 50 | p.imap_unordered(convert_to_markdown, args), 51 | desc="Converting to markdown", 52 | total=len(files), 53 | ): 54 | pass 55 | 56 | 57 | def tokenize_arxiv(root: Path, year: int): 58 | tokenizer = tiktoken.get_encoding("gpt2") 59 | tokens = [] 60 | tokens_val = [] 61 | tokens_test = [] 62 | mds = root / "md" / str(year) 63 | 64 | # tokenize 65 | desc = f"Tokenizing {year}" 66 | for i, mdpath in enumerate(tqdm(list(mds.iterdir()), desc=desc)): 67 | with open(mdpath, encoding="utf8") as f: 68 | text = "".join(f.readlines()) 69 | if i % 10 <= 6: # train split 70 | tokens += tokenizer.encode(text) 71 | elif i % 10 <= 8: # val split 72 | tokens_val += tokenizer.encode(text) 73 | else: # test split 74 | tokens_test += tokenizer.encode(text) 75 | 76 | # save to dir 77 | tpath = root / str(year) 78 | tpath.mkdir(parents=True) 79 | for x, name in zip([tokens, tokens_val, tokens_test], ["train", "val", "test"]): 80 | mem = np.memmap(tpath / f"{name}.npy", dtype=np.uint16, mode="w+", shape=len(x)) 81 | for i, v in enumerate(x): 82 | mem[i] = v 83 | 84 | 85 | def load_arxiv(cachedir: Path, years: Optional[List[int]] = None): 86 | all_years = list(range(1993, 2004)) # 1992 seems to give some problem 87 | if 
years is None: 88 | years = all_years 89 | assert set(years) <= set(all_years) 90 | root = cachedir / "arxiv" 91 | root.mkdir(exist_ok=True, parents=True) 92 | 93 | # download all years requested that are not present 94 | for year in years: 95 | if not (root / "md" / str(year)).exists(): 96 | fetch_arxiv(root, year) 97 | 98 | # tokenize all years not previously tokenized 99 | for year in years: 100 | if not (root / str(year)).exists(): 101 | tokenize_arxiv(root, year) 102 | 103 | # load meta 104 | ret = {} 105 | for split in ["train", "val"]: 106 | paths = [root / str(year) / f"{split}.npy" for year in years] 107 | x = [np.memmap(path, dtype=np.uint16, mode="r") for path in paths] 108 | ret[split] = np.concatenate(x) 109 | return ret 110 | 111 | 112 | def get_arxiv_2000(datasets_base_dir): 113 | return load_arxiv(Path(datasets_base_dir), [2000]) 114 | 115 | 116 | def get_arxiv_full(datasets_base_dir): 117 | return load_arxiv(Path(datasets_base_dir)) 118 | -------------------------------------------------------------------------------- /src/data/openwebtext2.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import numpy as np 4 | import tiktoken 5 | from datasets import load_dataset 6 | 7 | 8 | tknzr = tiktoken.get_encoding("gpt2") 9 | 10 | 11 | def get_openwebtext2_data(datasets_base_dir, num_proc=40): 12 | """https://openwebtext2.readthedocs.io/en/latest/""" 13 | OWT2_DATA_PATH = os.path.join(datasets_base_dir, "openwebtext2/") 14 | if not os.path.exists(os.path.join(OWT2_DATA_PATH, "train.bin")): 15 | os.makedirs(OWT2_DATA_PATH, exist_ok=True) 16 | dataset = load_dataset("the_pile_openwebtext2") 17 | 18 | split_dataset = dataset["train"].train_test_split( 19 | test_size=0.0005, seed=2357, shuffle=True 20 | ) 21 | split_dataset["val"] = split_dataset.pop("test") 22 | 23 | def process(example): 24 | ids = tknzr.encode_ordinary( 25 | example["text"] 26 | ) # encode_ordinary ignores any special tokens 27 | ids.append( 28 | tknzr.eot_token 29 | ) # add the end of text token, e.g. 50256 for gpt2 bpe 30 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 
31 | out = {"ids": ids, "len": len(ids)} 32 | return out 33 | 34 | # tokenize the dataset 35 | tokenized = split_dataset.map( 36 | process, 37 | remove_columns=["text"], 38 | desc="tokenizing the splits", 39 | num_proc=num_proc, 40 | ) 41 | 42 | # concatenate all the ids in each dataset into one large file we can use for training 43 | for split, dset in tokenized.items(): 44 | arr_len = np.sum(dset["len"]) 45 | filename = os.path.join(OWT2_DATA_PATH, f"{split}.bin") 46 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 47 | arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,)) 48 | total_batches = 1024 49 | 50 | idx = 0 51 | for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"): 52 | # Batch together samples for faster write 53 | batch = dset.shard( 54 | num_shards=total_batches, index=batch_idx, contiguous=True 55 | ).with_format("numpy") 56 | arr_batch = np.concatenate(batch["ids"]) 57 | # Write into mmap 58 | arr[idx : idx + len(arr_batch)] = arr_batch 59 | idx += len(arr_batch) 60 | arr.flush() 61 | 62 | return { 63 | "train": os.path.join(OWT2_DATA_PATH, "train.bin"), 64 | "val": os.path.join(OWT2_DATA_PATH, "val.bin"), 65 | } 66 | -------------------------------------------------------------------------------- /src/data/redpajama.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import numpy as np 4 | import tiktoken 5 | from datasets import load_dataset 6 | 7 | 8 | tknzr = tiktoken.get_encoding("gpt2") 9 | 10 | 11 | def get_redpajama_data(datasets_dir, num_proc=40): 12 | RPJ_DATA_PATH = os.path.join(datasets_dir, "redpajama1Tsample/") 13 | if not os.path.exists(os.path.join(RPJ_DATA_PATH, "train.bin")): 14 | os.makedirs(RPJ_DATA_PATH, exist_ok=True) 15 | dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample") 16 | 17 | split_dataset = dataset["train"].train_test_split( 18 | test_size=0.0005, seed=2357, shuffle=True 19 | ) 20 | split_dataset["val"] = split_dataset.pop("test") 21 | 22 | def process(example): 23 | ids = tknzr.encode_ordinary( 24 | example["text"] 25 | ) # encode_ordinary ignores any special tokens 26 | ids.append( 27 | tknzr.eot_token 28 | ) # add the end of text token, e.g. 
50256 for gpt2 bpe 29 | out = {"ids": ids, "len": len(ids)} 30 | return out 31 | 32 | # tokenize the dataset 33 | tokenized = split_dataset.map( 34 | process, 35 | remove_columns=["text"], 36 | desc="tokenizing the splits", 37 | num_proc=num_proc, 38 | ) 39 | 40 | # concatenate all the ids in each dataset into one large file we can use for training 41 | for split, dset in tokenized.items(): 42 | arr_len = np.sum(dset["len"]) 43 | filename = os.path.join(RPJ_DATA_PATH, f"{split}.bin") 44 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 45 | arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,)) 46 | total_batches = min(1024, len(dset)) 47 | 48 | idx = 0 49 | for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"): 50 | # Batch together samples for faster write 51 | batch = dset.shard( 52 | num_shards=total_batches, index=batch_idx, contiguous=True 53 | ).with_format("numpy") 54 | arr_batch = np.concatenate(batch["ids"]) 55 | # Write into mmap 56 | arr[idx : idx + len(arr_batch)] = arr_batch 57 | idx += len(arr_batch) 58 | arr.flush() 59 | 60 | train_data = np.memmap( 61 | os.path.join(RPJ_DATA_PATH, "train.bin"), dtype=np.uint16, mode="r" 62 | ) 63 | val_data = np.memmap( 64 | os.path.join(RPJ_DATA_PATH, "val.bin"), dtype=np.uint16, mode="r" 65 | ) 66 | 67 | return {"train": train_data, "val": val_data} 68 | 69 | 70 | def get_redpajamav2_data(datasets_dir, num_proc=40): 71 | """https://openwebtext2.readthedocs.io/en/latest/""" 72 | RPJ_V2_DATA_PATH = os.path.join(datasets_dir, "redpajamaV2sample/") 73 | if not os.path.exists(os.path.join(RPJ_V2_DATA_PATH, "train.bin")): 74 | os.makedirs(RPJ_V2_DATA_PATH, exist_ok=True) 75 | dataset = load_dataset("togethercomputer/RedPajama-Data-V2", name="sample") 76 | 77 | split_dataset = dataset["train"].train_test_split( 78 | test_size=0.0005, seed=2357, shuffle=True 79 | ) 80 | split_dataset["val"] = split_dataset.pop("test") 81 | 82 | def process(example): 83 | ids = tknzr.encode_ordinary( 84 | example["raw_content"] 85 | ) # encode_ordinary ignores any special tokens 86 | ids.append( 87 | tknzr.eot_token 88 | ) # add the end of text token, e.g. 50256 for gpt2 bpe 89 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 
90 | out = {"ids": ids, "len": len(ids)} 91 | return out 92 | 93 | # tokenize the dataset 94 | tokenized = split_dataset.map( 95 | process, 96 | remove_columns=["raw_content"], 97 | desc="tokenizing the splits", 98 | num_proc=num_proc, 99 | ) 100 | 101 | # concatenate all the ids in each dataset into one large file we can use for training 102 | for split, dset in tokenized.items(): 103 | arr_len = np.sum(dset["len"]) 104 | filename = os.path.join(RPJ_V2_DATA_PATH, f"{split}.bin") 105 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 106 | arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,)) 107 | total_batches = min(1024, len(dset)) 108 | 109 | idx = 0 110 | for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"): 111 | # Batch together samples for faster write 112 | batch = dset.shard( 113 | num_shards=total_batches, index=batch_idx, contiguous=True 114 | ).with_format("numpy") 115 | arr_batch = np.concatenate(batch["ids"]) 116 | # Write into mmap 117 | arr[idx : idx + len(arr_batch)] = arr_batch 118 | idx += len(arr_batch) 119 | arr.flush() 120 | 121 | return { 122 | "train": os.path.join(RPJ_V2_DATA_PATH, "train.bin"), 123 | "val": os.path.join(RPJ_V2_DATA_PATH, "val.bin"), 124 | } 125 | -------------------------------------------------------------------------------- /src/data/shakespeare.py: -------------------------------------------------------------------------------- 1 | import os 2 | from string import ascii_letters, digits, punctuation 3 | 4 | import numpy as np 5 | import requests 6 | 7 | 8 | _char_decode = dict( 9 | enumerate(sorted(set(ascii_letters + digits + punctuation + " \n"))) 10 | ) 11 | _char_encode = {char: i for i, char in _char_decode.items()} 12 | 13 | 14 | def char_tknzr(txt: str): 15 | return [_char_encode[char] for char in txt if char in _char_encode] 16 | 17 | 18 | def get_shakespeare_data(datasets_dir): 19 | """Inspired from https://github.com/karpathy/nanoGPT/""" 20 | DATA_PATH = os.path.join(datasets_dir, "shakespeare") 21 | raw_path = os.path.join(DATA_PATH, "raw.txt") 22 | train_path = os.path.join(DATA_PATH, f"train.npy") 23 | test_path = os.path.join(DATA_PATH, f"test.npy") 24 | 25 | # if path is not even there, download all data 26 | if not os.path.exists(DATA_PATH): 27 | print("Downloading raw Shakespeare texts") 28 | url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" 29 | os.makedirs(DATA_PATH, exist_ok=True) 30 | text = requests.get(url, timeout=60).text 31 | with open(raw_path, "w+", encoding="utf8") as f: 32 | f.write(text) 33 | 34 | # attempt to find cached version for current tokenizer 35 | if not os.path.exists(train_path) or not os.path.exists(test_path): 36 | print("Tokenizing Shakespeare texts") 37 | # load text 38 | with open(raw_path, encoding="utf8") as f: 39 | text = "".join(f.readlines()) 40 | i = int(0.8 * len(text)) 41 | # encode text 42 | x = np.array(char_tknzr(text[:i]), dtype=np.uint16) 43 | x_test = np.array(char_tknzr(text[i:]), dtype=np.uint16) 44 | # map memory 45 | mem = np.memmap(train_path, dtype=np.uint16, mode="w+", shape=x.shape) 46 | mem[:] = x 47 | mem = np.memmap(test_path, dtype=np.uint16, mode="w+", shape=x_test.shape) 48 | mem[:] = x_test 49 | 50 | return { 51 | "train": train_path, 52 | "val": test_path, 53 | } 54 | -------------------------------------------------------------------------------- /src/data/slimpajama.py: -------------------------------------------------------------------------------- 1 | from 
tqdm import tqdm 2 | import numpy as np 3 | import tiktoken 4 | from datasets import load_dataset 5 | import os 6 | 7 | 8 | tknzr = tiktoken.get_encoding("gpt2") 9 | 10 | 11 | def get_slimpajama_data(datasets_dir, num_proc=40): 12 | SPJ_DATA_PATH = os.path.join(datasets_dir, "slimpajama6B/") 13 | if not os.path.exists(os.path.join(SPJ_DATA_PATH, "train.bin")): 14 | os.makedirs(SPJ_DATA_PATH, exist_ok=True) 15 | dataset = load_dataset("DKYoon/SlimPajama-6B") 16 | 17 | split_dataset = dataset["train"].train_test_split( 18 | test_size=0.0005, seed=2357, shuffle=True 19 | ) 20 | split_dataset["val"] = split_dataset.pop("test") 21 | 22 | def process(example): 23 | ids = tknzr.encode_ordinary( 24 | example["text"] 25 | ) # encode_ordinary ignores any special tokens 26 | ids.append( 27 | tknzr.eot_token 28 | ) # add the end of text token, e.g. 50256 for gpt2 bpe 29 | out = {"ids": ids, "len": len(ids)} 30 | return out 31 | 32 | # tokenize the dataset 33 | tokenized = split_dataset.map( 34 | process, 35 | remove_columns=["text"], 36 | desc="tokenizing the splits", 37 | num_proc=num_proc, 38 | ) 39 | 40 | # concatenate all the ids in each dataset into one large file we can use for training 41 | for split, dset in tokenized.items(): 42 | arr_len = np.sum(dset["len"]) 43 | filename = os.path.join(SPJ_DATA_PATH, f"{split}.bin") 44 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 45 | arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,)) 46 | total_batches = min(1024, len(dset)) 47 | 48 | idx = 0 49 | for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"): 50 | # Batch together samples for faster write 51 | batch = dset.shard( 52 | num_shards=total_batches, index=batch_idx, contiguous=True 53 | ).with_format("numpy") 54 | arr_batch = np.concatenate(batch["ids"]) 55 | # Write into mmap 56 | arr[idx : idx + len(arr_batch)] = arr_batch 57 | idx += len(arr_batch) 58 | arr.flush() 59 | 60 | return { 61 | "train": os.path.join(SPJ_DATA_PATH, "train.bin"), 62 | "val": os.path.join(SPJ_DATA_PATH, "val.bin"), 63 | } 64 | 65 | 66 | def get_slimpajama_chunk1(datasets_dir, num_proc=40): 67 | SPJ_DATA_PATH = os.path.join(datasets_dir, "slimpajama6B/") 68 | SPJ_CHUNK_1_DATA_PATH = os.path.join(SPJ_DATA_PATH, "chunk1") 69 | if not os.path.exists(os.path.join(SPJ_CHUNK_1_DATA_PATH, "train.bin")): 70 | os.makedirs(SPJ_DATA_PATH, exist_ok=True) 71 | dataset = load_dataset("cerebras/SlimPajama-627B", split="train/chunk1") 72 | 73 | split_dataset = dataset["train"].train_test_split( 74 | test_size=0.0005, seed=2357, shuffle=True 75 | ) 76 | split_dataset["val"] = split_dataset.pop("test") 77 | 78 | def process(example): 79 | ids = tknzr.encode_ordinary( 80 | example["text"] 81 | ) # encode_ordinary ignores any special tokens 82 | ids.append( 83 | tknzr.eot_token 84 | ) # add the end of text token, e.g. 
50256 for gpt2 bpe 85 | out = {"ids": ids, "len": len(ids)} 86 | return out 87 | 88 | # tokenize the dataset 89 | tokenized = split_dataset.map( 90 | process, 91 | remove_columns=["text"], 92 | desc="tokenizing the splits", 93 | num_proc=num_proc, 94 | ) 95 | 96 | # concatenate all the ids in each dataset into one large file we can use for training 97 | for split, dset in tokenized.items(): 98 | arr_len = np.sum(dset["len"]) 99 | filename = os.path.join(SPJ_DATA_PATH, f"{split}.bin") 100 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 101 | arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,)) 102 | total_batches = min(1024, len(dset)) 103 | 104 | idx = 0 105 | for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"): 106 | # Batch together samples for faster write 107 | batch = dset.shard( 108 | num_shards=total_batches, index=batch_idx, contiguous=True 109 | ).with_format("numpy") 110 | arr_batch = np.concatenate(batch["ids"]) 111 | # Write into mmap 112 | arr[idx : idx + len(arr_batch)] = arr_batch 113 | idx += len(arr_batch) 114 | arr.flush() 115 | 116 | return { 117 | "train": os.path.join(SPJ_DATA_PATH, "train.bin"), 118 | "val": os.path.join(SPJ_DATA_PATH, "val.bin"), 119 | } 120 | -------------------------------------------------------------------------------- /src/data/utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | from typing import Dict 4 | import torch 5 | import torch.distributed as dist 6 | 7 | from .shakespeare import get_shakespeare_data 8 | from .wikitext import get_wikitext_data 9 | from .arxiv import get_arxiv_2000, get_arxiv_full 10 | from .openwebtext2 import get_openwebtext2_data 11 | from .redpajama import get_redpajama_data, get_redpajamav2_data 12 | from .slimpajama import get_slimpajama_data 13 | 14 | 15 | def get_dataset(args) -> Dict[str, np.ndarray]: 16 | """Fetch the right dataset given by the args.dataset parameter. The logic for each dataset is 17 | contained in its own python file. The expected format at the moment is a dictionary of np.memmap 18 | containing two keys: 'train' and 'val', corresponding to the tokenized training and validation data. 
19 | """ 20 | if args.dataset == "wikitext": 21 | return get_wikitext_data(args.datasets_dir) 22 | if args.dataset == "shakespeare-char": 23 | return get_shakespeare_data(args.datasets_dir) 24 | if args.dataset == "arxiv2000": 25 | return get_arxiv_2000(args.datasets_dir) 26 | if args.dataset == "arxiv": 27 | return get_arxiv_full(args.datasets_dir) 28 | if args.dataset == "arxiv+wiki": 29 | arxiv_data = get_arxiv_full(args.datasets_dir) 30 | wiki_data = get_wikitext_data(args.datasets_dir) 31 | train_data = np.concatenate((arxiv_data["train"], wiki_data["train"])) 32 | val_data = np.concatenate((arxiv_data["val"], wiki_data["val"])) 33 | return {"train": train_data, "val": val_data} 34 | if args.dataset == "openwebtext2": 35 | return get_openwebtext2_data(args.datasets_dir) 36 | if args.dataset == "redpajama": 37 | return get_redpajama_data(args.datasets_dir) 38 | if args.dataset == "redpajamav2": 39 | return get_redpajamav2_data(args.datasets_dir) 40 | if args.dataset == "slimpajama": 41 | return get_slimpajama_data(args.datasets_dir) 42 | else: 43 | raise NotImplementedError(f"Unknown dataset key '{args.dataset}'") 44 | 45 | 46 | class DataReader: 47 | def __init__( 48 | self, 49 | data_src, 50 | batch_size, 51 | sequence_length, 52 | seed=1337, 53 | with_replacement=False, 54 | auto_shard=True, 55 | keep_in_ram=False, 56 | ): 57 | if isinstance(data_src, (str, Path)): 58 | self.data_path = Path(data_src) 59 | self.keep_in_ram = keep_in_ram 60 | if keep_in_ram: 61 | self.data = np.array( 62 | np.memmap(self.data_path, dtype=np.uint16, mode="r") 63 | ) 64 | else: 65 | self.data = None 66 | elif isinstance(data_src, (np.ndarray, np.memmap)): 67 | self.data_path = None 68 | self.data = data_src 69 | self.keep_in_ram = True 70 | 71 | self.batch_size = batch_size 72 | self.sequence_length = sequence_length 73 | self.seed = seed 74 | self.with_replacement = with_replacement 75 | 76 | self.num_tokens = len(self._get_data()) 77 | 78 | if auto_shard and dist.is_initialized(): 79 | self.world_size = dist.get_world_size() 80 | self.rank = dist.get_rank() 81 | print( 82 | f"Distributed DataReader Initialized for Worker {self.rank}/{self.world_size}" 83 | ) 84 | else: 85 | self.world_size = 1 86 | self.rank = 0 87 | 88 | # Sampling without replacement 89 | self.last_epoch = None 90 | self.order = None 91 | self.epoch_offset = None 92 | self.step = 0 93 | self.num_batches_of_seqlen = 0 94 | if not with_replacement: 95 | self._shuffle_epoch(0) 96 | 97 | def __len__(self): 98 | # Length in valid start indices for a sequence 99 | # Extra -1 to have a valid next token for the final token of the last idx 100 | return self.num_tokens - self.sequence_length - 1 101 | 102 | def _get_data(self): 103 | if self.data is not None: 104 | return self.data 105 | else: 106 | # Construct the memmap each time to avoid a memory leak per NanoGPT 107 | # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122 108 | return np.memmap(self.data_path, dtype=np.uint16, mode="r") 109 | 110 | def __getitem__(self, idx): 111 | # Return the underlying datapoint, no random sampling, no worker sharding 112 | assert 0 <= idx < len(self) 113 | data = self._get_data() 114 | x = torch.from_numpy(data[idx : idx + self.sequence_length].astype(np.int64)) 115 | y = torch.from_numpy( 116 | data[idx + 1 : idx + self.sequence_length + 1].astype(np.int64) 117 | ) 118 | return x, y 119 | 120 | def set_step(self, step): 121 | self.step = step 122 | 123 | def sample_batch(self): 124 | data =
self._get_data() 125 | 126 | if self.with_replacement: 127 | idxs = self._sample_with_replacement(self.step) 128 | else: 129 | idxs = self._sample_without_replacement(self.step) 130 | self.step += 1 131 | 132 | xy = np.stack([data[i : i + self.sequence_length + 1] for i in idxs]).astype( 133 | np.int64 134 | ) 135 | x = torch.from_numpy(xy[:, :-1]).contiguous() 136 | y = torch.from_numpy(xy[:, 1:]).contiguous() 137 | return x, y 138 | 139 | def _sample_with_replacement(self, idx): 140 | # Return an array of token indices of length self.batch_size 141 | # Sampled with replacement, can get repeats at any time 142 | seed = self.seed + idx * self.world_size + self.rank 143 | rng = np.random.default_rng(seed) 144 | return rng.integers(len(self), self.batch_size) 145 | 146 | def _shuffle_epoch(self, epoch): 147 | seed = self.seed + epoch 148 | rng = np.random.default_rng(seed) 149 | # Drop one sequence to allow different offsets per epoch: 150 | self.order = rng.permutation((len(self)) // self.sequence_length - 1) 151 | # Shift all sequences in this epoch by this amount: 152 | self.epoch_offset = rng.integers(self.sequence_length) 153 | self.last_epoch = epoch 154 | self.num_batches_of_seqlen = ( 155 | len(self.order) // self.batch_size 156 | ) # Drops remainder batch 157 | 158 | def _sample_without_replacement(self, step): 159 | # Return an array of token indices of length self.batch_size 160 | # Sampled without replacement, cycle all sequences before potential repeats 161 | # Sequences are randomly offset in every epoch as well 162 | batch_idx = self.world_size * step + self.rank 163 | epoch_length = self.num_batches_of_seqlen 164 | 165 | epoch = batch_idx // epoch_length 166 | if epoch != self.last_epoch: 167 | self._shuffle_epoch(epoch) 168 | epoch_idx = batch_idx % epoch_length 169 | 170 | start = epoch_idx * self.batch_size 171 | end = start + self.batch_size 172 | return self.order[start:end] * self.sequence_length + self.epoch_offset 173 | 174 | def num_batches(self): 175 | if self.with_replacement: 176 | return self.num_tokens // self.batch_size 177 | return self.num_batches_of_seqlen 178 | -------------------------------------------------------------------------------- /src/data/wikitext.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | import urllib 4 | import numpy as np 5 | import tiktoken 6 | 7 | 8 | def get_wikitext_data(datasets_base_dir): 9 | """Inspired from https://github.com/tysam-code/hlb-gpt""" 10 | WIKITEXT_DATA_PATH = os.path.join(datasets_base_dir, "wikitext/") 11 | if not os.path.exists(WIKITEXT_DATA_PATH): 12 | os.makedirs(WIKITEXT_DATA_PATH, exist_ok=True) 13 | print("downloading data and tokenizing (1-2 min)") 14 | raw_data_source = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip" 15 | urllib.request.urlretrieve( 16 | raw_data_source, os.path.join(WIKITEXT_DATA_PATH, "data.zip") 17 | ) 18 | 19 | with zipfile.ZipFile( 20 | os.path.join(WIKITEXT_DATA_PATH, "data.zip"), "r" 21 | ) as zip_ref: 22 | zip_ref.extractall(WIKITEXT_DATA_PATH) 23 | 24 | with open( 25 | os.path.join(WIKITEXT_DATA_PATH, "wikitext-103-raw/wiki.train.raw"), "r" 26 | ) as data_file: 27 | raw_train_data = data_file.read() 28 | 29 | with open( 30 | os.path.join(WIKITEXT_DATA_PATH, "wikitext-103-raw/wiki.valid.raw"), "r" 31 | ) as data_file: 32 | raw_eval_data = data_file.read() 33 | 34 | tokenizer = tiktoken.get_encoding("gpt2") 35 | raw_tokenized_train = tokenizer.encode_ordinary(raw_train_data) 36 
| raw_tokenized_eval = tokenizer.encode_ordinary(raw_eval_data) 37 | 38 | train_tokenized = np.array(raw_tokenized_train, dtype=np.uint16) 39 | eval_tokenized = np.array(raw_tokenized_eval, dtype=np.uint16) 40 | 41 | train_tokenized.tofile(os.path.join(WIKITEXT_DATA_PATH, "train.bin")) 42 | eval_tokenized.tofile(os.path.join(WIKITEXT_DATA_PATH, "val.bin")) 43 | print("completed the tokenization process!") 44 | 45 | return { 46 | "train": os.path.join(WIKITEXT_DATA_PATH, "train.bin"), 47 | "val": os.path.join(WIKITEXT_DATA_PATH, "val.bin"), 48 | } 49 | -------------------------------------------------------------------------------- /src/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ddp 2 | from . import single 3 | 4 | BACKEND_TYPE_TO_MODULE_MAP = { 5 | "nccl": ddp.DataParallelDistributedBackend, 6 | None: single.SinlgeNodeBackend, 7 | } 8 | 9 | 10 | def make_backend_from_args(args): 11 | return BACKEND_TYPE_TO_MODULE_MAP[args.distributed_backend](args) 12 | 13 | 14 | def registered_backends(): 15 | return BACKEND_TYPE_TO_MODULE_MAP.keys() 16 | -------------------------------------------------------------------------------- /src/distributed/backend.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | class DistributedBackend(object): 5 | 6 | def __init__(self, args): 7 | pass 8 | 9 | def transform_model(self, model): 10 | raise NotImplementedError 11 | 12 | def get_context_for_microstep_forward( 13 | self, model, microstep_idx, gradient_accumulation_steps 14 | ): 15 | raise NotImplementedError 16 | 17 | def is_master_process(self) -> bool: 18 | raise NotImplementedError 19 | 20 | def get_adjusted_args_for_process(self, args): 21 | raise NotImplementedError 22 | 23 | def get_raw_model(self, model): 24 | raise NotImplementedError 25 | 26 | def translate_model_parameter_name_for_node(self, parameter_name) -> List[str]: 27 | raise NotImplementedError 28 | 29 | def get_world_size(self): 30 | raise NotImplementedError 31 | 32 | def finalize(self): 33 | pass 34 | -------------------------------------------------------------------------------- /src/distributed/ddp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from contextlib import contextmanager 4 | 5 | from torch.nn.parallel import DistributedDataParallel as DDP 6 | from torch.distributed import init_process_group, destroy_process_group, get_world_size 7 | 8 | from .backend import DistributedBackend 9 | 10 | 11 | class DataParallelDistributedBackend(DistributedBackend): 12 | 13 | def __init__(self, args): 14 | self.rank = int(os.environ.get("RANK", -1)) 15 | assert self.rank != -1, "DDP backend can not be used without rank" 16 | assert "cuda" in args.device, "DDP backend can not be used on non-CUDA devices" 17 | init_process_group(backend=args.distributed_backend) 18 | self.local_rank = int(os.environ["LOCAL_RANK"]) 19 | 20 | def get_adjusted_args_for_process(self, args): 21 | effective_batch_size = args.batch_size * args.acc_steps 22 | world_size = self.get_world_size() 23 | if effective_batch_size % world_size != 0: 24 | raise ValueError( 25 | f"Effective batch size " 26 | "{effective_batch_size} is not divisible " 27 | "by the world size {world_size}." 
28 | ) 29 | acc_steps_div = math.gcd(args.acc_steps, world_size) 30 | args.acc_steps = args.acc_steps // acc_steps_div 31 | args.batch_size = args.batch_size // (world_size // acc_steps_div) 32 | args.device = f"cuda:{self.local_rank}" 33 | args.seed = args.seed + self.local_rank 34 | args.data_seed = args.data_seed 35 | return args 36 | 37 | def transform_model(self, model): 38 | return DDP(model, device_ids=[self.local_rank]) 39 | 40 | @contextmanager 41 | def get_context_for_microstep_forward( 42 | self, model, microstep_idx, gradient_accumulation_steps 43 | ): 44 | model.require_backward_grad_sync = ( 45 | microstep_idx == gradient_accumulation_steps - 1 46 | ) 47 | yield 48 | 49 | def is_master_process(self) -> bool: 50 | return self.rank == 0 51 | 52 | def get_raw_model(self, model): 53 | return model.module 54 | 55 | def translate_model_parameter_name_for_node(self, parameter_name): 56 | return [f"module.{parameter_name}"] 57 | 58 | def get_world_size(self): 59 | return get_world_size() 60 | 61 | def finalize(self): 62 | destroy_process_group() 63 | -------------------------------------------------------------------------------- /src/distributed/single.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext 2 | 3 | from .backend import DistributedBackend 4 | 5 | 6 | class SinlgeNodeBackend(DistributedBackend): 7 | 8 | def __init__(self, args): 9 | super().__init__(args) 10 | self.rank = 0 11 | 12 | def transform_model(self, model): 13 | return model 14 | 15 | def get_context_for_microstep_forward(self, *args, **kwargs): 16 | return nullcontext() 17 | 18 | def get_adjusted_args_for_process(self, args): 19 | return args 20 | 21 | def is_master_process(self) -> bool: 22 | return True 23 | 24 | def get_raw_model(self, model): 25 | return model 26 | 27 | def get_world_size(self): 28 | return 1 29 | 30 | def translate_model_parameter_name_for_node(self, parameter_name): 31 | return [parameter_name] 32 | -------------------------------------------------------------------------------- /src/logger/logger.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import copy 3 | from functools import wraps 4 | from pathlib import Path 5 | import pickle 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import wandb 10 | 11 | 12 | class DynamicsLogger: 13 | def __init__(self, model, optimizer, cfg, output_folder, wandb=False): 14 | self.model = model 15 | self.optimizer = optimizer 16 | self.iteration = 0 17 | self.output_folder = output_folder 18 | self.wandb = wandb 19 | 20 | self.interval = None # Set in step 21 | 22 | if isinstance(cfg["interval"], int): 23 | self.interval = cfg["interval"] 24 | elif isinstance(cfg["interval"], (tuple, list)): 25 | # Tuples of iter, interval with iter0 < iter1 < ... 26 | # [(iter0, interval0), (iter1, interval1), (iter2, interval2)] 27 | # At iterX change interval to intervalX 28 | self.interval = cfg["interval"][0][1] 29 | 30 | # TODO: Add default cfg here once tested 31 | self.cfg = copy.deepcopy(cfg) 32 | if self.cfg["disk_stats"] == "all": 33 | self.cfg["disk_stats"] = self.cfg["stats"] 34 | if self.cfg["wandb_stats"] == "all": 35 | self.cfg["wandb_stats"] = self.cfg["stats"] 36 | 37 | self.stats = defaultdict(lambda: defaultdict(list)) 38 | self.reducers = defaultdict(lambda: "rms") 39 | self.reducers.update( 40 | { 41 | # Make sure any signed stats are not reduced with RMS! 
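# e.g. neuron-wise cosine similarities lie in [-1, 1]: the values [-0.5, 0.5]
# have an RMS of 0.5 but a mean of 0, so an RMS reduction would report strong
# alignment where, on average, there is none.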
42 | # 'layer_cos_gradient_angle': 'mean', 43 | # 'neuron/cos_gradient_angle': 'mean', 44 | # 'vector/values': 'mean', 45 | # 'vector/update_values': 'mean', 46 | } 47 | ) 48 | 49 | # Wrap the step method of the optimizer 50 | self.optimizer.original_step = self.optimizer.step 51 | self.optimizer.step = self_preserving_overwrite(self.step).__get__( 52 | self.optimizer 53 | ) 54 | 55 | self.wandb_setup_complete = False 56 | 57 | def step(self, *args, **kwargs): 58 | # Dictionaries keyed by parameter name 59 | # NOTE: Some may be direct references to model / optimizer (do not change in-place) 60 | pre_params = dict() 61 | pre_grads = dict() 62 | pre_states = dict() 63 | post_params = dict() 64 | post_states = dict() 65 | 66 | if isinstance(self.cfg["interval"], int): 67 | self.interval = self.cfg["interval"] 68 | elif isinstance(self.cfg["interval"], (tuple, list)): 69 | # Tuples of iter, interval with iter0 < iter1 < ... 70 | # [(iter0, interval0), (iter1, interval1), (iter2, interval2)] 71 | # At iterX change interval to intervalX 72 | idx = 0 73 | while ( 74 | idx < len(self.cfg["interval"]) 75 | and self.cfg["interval"][idx][0] <= self.iteration 76 | ): 77 | idx += 1 78 | self.interval = self.cfg["interval"][idx - 1][1] 79 | 80 | if "eps" in self.optimizer.defaults: 81 | eps = self.optimizer.defaults["eps"] 82 | else: 83 | eps = 1e-8 84 | 85 | if self.iteration % self.interval == 0: 86 | for name, param in self.model.named_parameters(): 87 | pre_params[name] = param.clone().detach() 88 | if param.grad is not None: 89 | pre_grads[name] = param.grad 90 | else: 91 | pre_grads[name] = None 92 | 93 | pre_states[name] = copy.deepcopy(self.optimizer.state[param]) 94 | 95 | self.optimizer.original_step(*args, **kwargs) # Assuming no change to grads 96 | 97 | for name, param in self.model.named_parameters(): 98 | post_params[name] = param.detach() 99 | post_states[name] = self.optimizer.state[param] 100 | 101 | self.log_statistics( 102 | pre_params, post_params, pre_grads, pre_states, post_states, eps 103 | ) 104 | else: 105 | # Normal optimizer step, no logging 106 | self.optimizer.original_step(*args, **kwargs) 107 | 108 | self.iteration += 1 109 | 110 | @torch.no_grad() 111 | def log_statistics( 112 | self, pre_params, post_params, pre_grads, pre_states, post_states, eps 113 | ): 114 | requested_stats = set(self.cfg["stats"]) 115 | 116 | if {"layer_norm", "neuron_norm"} & requested_stats: 117 | for name, param in pre_params.items(): 118 | if param.dim() < 2: 119 | # Only higher dimensional weights (linear, conv etc) 120 | continue 121 | 122 | # Compute neuron norms (assume shape K x C x ...) 123 | neuron_norm = torch.linalg.vector_norm(param.flatten(1), dim=1) 124 | 125 | if "layer_norm" in requested_stats: 126 | # This makes more sense with layernorm 127 | # for BN rms_neuron_norm is what we predict (closely related) 128 | layer_norm = torch.linalg.vector_norm(neuron_norm) 129 | self.stats["layer_norm"][name].append(layer_norm) 130 | if "neuron_norm" in requested_stats: 131 | self.stats["neuron_norm"][name].append(neuron_norm) 132 | 133 | if {"layer_grad_norm", "neuron_grad_norm"} & requested_stats: 134 | for name, grad in pre_grads.items(): 135 | if grad.dim() < 2: 136 | # Only higher dimensional weights (linear, conv etc) 137 | continue 138 | 139 | # Compute neuron norms (assume shape K x C x ...) 
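# grad.flatten(1) has shape (K, C * ...), so the vector norm over dim=1 below
# gives one gradient norm per output neuron; the norm of that K-vector in turn
# equals the full layer gradient norm (the Frobenius norm of the weight gradient).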
140 | neuron_grad_norm = torch.linalg.vector_norm(grad.flatten(1), dim=1) 141 | 142 | if "layer_grad_norm" in requested_stats: 143 | grad_norm = torch.linalg.vector_norm(neuron_grad_norm) 144 | self.stats["layer_grad_norm"][name].append(grad_norm) 145 | if "neuron_grad_norm" in requested_stats: 146 | self.stats["neuron_grad_norm"][name].append(neuron_grad_norm) 147 | 148 | if {"layer_update_norm", "neuron_update_norm"} & requested_stats: 149 | for name, pre_param in pre_params.items(): 150 | if pre_param.dim() < 2: 151 | # Only higher dimensional weights (linear, conv etc) 152 | continue 153 | 154 | post_param = post_params[name] 155 | diff = post_param - pre_param 156 | 157 | # Compute neuron norms (assume shape K x C x ...) 158 | neuron_update_norm = torch.linalg.vector_norm(diff.flatten(1), dim=1) 159 | 160 | if "layer_update_norm" in requested_stats: 161 | layer_norm = torch.linalg.vector_norm(neuron_update_norm) 162 | self.stats["layer_update_norm"][name].append(layer_norm) 163 | if "neuron_update_norm" in requested_stats: 164 | self.stats["neuron_update_norm"][name].append(neuron_update_norm) 165 | 166 | if {"layer_relative_update", "neuron_relative_update"} & requested_stats: 167 | for name, pre_param in pre_params.items(): 168 | if pre_param.dim() < 2: 169 | # Only higher dimensional weights (linear, conv etc) 170 | continue 171 | 172 | post_param = post_params[name] 173 | diff = post_param - pre_param 174 | 175 | if "layer_relative_update" in requested_stats: 176 | layer_diff_norm = torch.linalg.vector_norm(diff) 177 | layer_norm = torch.linalg.vector_norm(pre_param) 178 | layer_relative_update = (layer_diff_norm + eps) / (layer_norm + eps) 179 | self.stats["layer_relative_update"][name].append( 180 | layer_relative_update 181 | ) 182 | 183 | if "neuron_relative_update" in requested_stats: 184 | neuron_diff_norm = torch.linalg.vector_norm(diff.flatten(1), dim=1) 185 | neuron_norm = torch.linalg.vector_norm(pre_param.flatten(1), dim=1) 186 | neuron_relative_update = (neuron_diff_norm + eps) / ( 187 | neuron_norm + eps 188 | ) 189 | self.stats["neuron_relative_update"][name].append( 190 | neuron_relative_update 191 | ) 192 | 193 | if {"layer_angular_update", "neuron_angular_update"} & requested_stats: 194 | for name, pre_param in pre_params.items(): 195 | if pre_param.dim() < 2: 196 | # Only higher dimensional weights (linear, conv etc) 197 | continue 198 | 199 | post_param = post_params[name] 200 | pre_param = ( 201 | pre_param.double() 202 | ) # There is a lot of noise for small angles 203 | 204 | if "layer_angular_update" in requested_stats: 205 | cos = F.cosine_similarity( 206 | pre_param.flatten(), post_param.flatten(), dim=0 207 | ) 208 | angles = torch.acos(torch.clamp(cos, min=-1, max=1)) 209 | self.stats["layer_angular_update"][name].append(angles) 210 | 211 | if "neuron_angular_update" in requested_stats: 212 | cos = F.cosine_similarity( 213 | pre_param.flatten(1), post_param.flatten(1), dim=1 214 | ) 215 | angles = torch.acos(torch.clamp(cos, min=-1, max=1)) 216 | self.stats["neuron_angular_update"][name].append(angles) 217 | 218 | if {"layer_grad_alignment", "neuron_grad_alignment"} & requested_stats: 219 | for name, pre_param in pre_params.items(): 220 | if pre_param.dim() < 2: 221 | # Only higher dimensional weights (linear, conv etc) 222 | continue 223 | 224 | pre_grad = pre_grads[name] 225 | mean_layer_alignment = self.layer_cosine_sim(pre_grad, pre_param) 226 | self.stats["layer_grad_alignment"][name].append(mean_layer_alignment) 227 | 228 | mean_neuron_alignment = 
self.neuron_cosine_sim(pre_grad, pre_param) 229 | self.stats["neuron_grad_alignment"][name].append(mean_neuron_alignment) 230 | 231 | if { 232 | "layer_grad_velocity_alignment", 233 | "neuron_grad_velocity_alignment", 234 | } & requested_stats: 235 | for name, pre_param in pre_params.items(): 236 | if pre_param.dim() < 2: 237 | # Only higher dimensional weights (linear, conv etc) 238 | continue 239 | 240 | # Adam and similar only 241 | if "exp_avg" not in pre_states[name]: 242 | self.stats["layer_grad_velocity_alignment"][name].append( 243 | torch.tensor(0).to(pre_param.device) 244 | ) 245 | self.stats["neuron_grad_velocity_alignment"][name].append( 246 | torch.zeros(pre_param.shape[0]).to(pre_param.device) 247 | ) 248 | continue 249 | 250 | pre_grad = pre_grads[name] 251 | pre_state = pre_states[name] 252 | pre_m = pre_state["exp_avg"] 253 | 254 | mean_layer_alignment = self.layer_cosine_sim(pre_grad, pre_m) 255 | self.stats["layer_grad_velocity_alignment"][name].append( 256 | mean_layer_alignment 257 | ) 258 | 259 | mean_neuron_alignment = self.neuron_cosine_sim(pre_grad, pre_m) 260 | self.stats["neuron_grad_velocity_alignment"][name].append( 261 | mean_neuron_alignment 262 | ) 263 | 264 | # TODO: Log averages for scalar vectors 265 | if "scalar_rms" in requested_stats: 266 | for name, pre_param in pre_params.items(): 267 | if pre_param.dim() >= 2: 268 | # Only scalars and scalar vectors 269 | continue 270 | 271 | self.stats["scalar_rms"][name].append((param**2).mean().sqrt()) 272 | 273 | if "scalar_update_rms" in requested_stats: 274 | for name, pre_param in pre_params.items(): 275 | if pre_param.dim() >= 2: 276 | # Only scalars and scalar vectors 277 | continue 278 | 279 | post_param = post_params[name] 280 | diff = post_param - pre_param 281 | 282 | self.stats["scalar_update_rms"][name].append((diff**2).mean().sqrt()) 283 | 284 | if "scalar_grad_rms" in requested_stats: 285 | for name, grad in pre_grads.items(): 286 | if grad.dim() >= 2: 287 | # Only scalars and scalar vectors 288 | continue 289 | 290 | self.stats["scalar_grad_rms"][name].append((grad**2).mean().sqrt()) 291 | 292 | # Could add similar per elements histograms for the following: 293 | # scalar_value 294 | # scalar_update_value 295 | # scalar_grad_value 296 | 297 | # More obscure metrics below, only used in select experiments 298 | if { 299 | "layer_mean_second_grad_moment", 300 | "neuron_mean_second_grad_moment", 301 | } & requested_stats: 302 | for name, pre_param in pre_params.items(): 303 | if pre_param.dim() < 2: 304 | # Only higher dimensional weights (linear, conv etc) 305 | continue 306 | 307 | post_state = post_states[name] 308 | post_v = post_state["exp_avg_sq"] 309 | 310 | if "layer_mean_second_grad_moment" in requested_stats: 311 | mean_v = torch.mean(post_v) 312 | self.stats["layer_mean_second_grad_moment"][name].append(mean_v) 313 | 314 | if "neuron_mean_second_grad_moment" in requested_stats: 315 | mean_v = torch.mean(post_v.flatten(1), dim=1) 316 | self.stats["neuron_mean_second_grad_moment"][name].append(mean_v) 317 | 318 | if { 319 | "layer_second_grad_moment_std_mean_ratio", 320 | "neuron_second_grad_moment_std_mean_ratio", 321 | } & requested_stats: 322 | for name, pre_param in pre_params.items(): 323 | if pre_param.dim() < 2: 324 | # Only higher dimensional weights (linear, conv etc) 325 | continue 326 | 327 | post_state = post_states[name] 328 | post_v = post_state["exp_avg_sq"] 329 | 330 | v_neuron_mean = post_v.flatten(1).mean(dim=1) 331 | v_neuron_std = post_v.flatten(1).std(dim=1) 332 | 
neuron_std_mean_ratio = torch.div(v_neuron_std, v_neuron_mean) 333 | self.stats["neuron_second_grad_moment_std_mean_ratio"][name].append( 334 | neuron_std_mean_ratio 335 | ) 336 | 337 | v_layer_mean = post_v.flatten(0).mean(dim=0) 338 | v_layer_std = post_v.flatten(0).std(dim=0) 339 | layer_std_mean_ratio = torch.div(v_layer_std, v_layer_mean) 340 | self.stats["layer_second_grad_moment_std_mean_ratio"][name].append( 341 | layer_std_mean_ratio 342 | ) 343 | 344 | if {"layer_scaled_grad_norm", "neuron_scaled_grad_norm"} & requested_stats: 345 | for name, pre_param in pre_params.items(): 346 | if pre_param.dim() < 2: 347 | # Only higher dimensional weights (linear, conv etc) 348 | continue 349 | 350 | pre_grad = pre_grads[name] 351 | post_state = post_states[name] 352 | post_v = post_state["exp_avg_sq"] 353 | 354 | scaled_grad = torch.div(pre_grad, (post_v.sqrt() + 1e-8)) 355 | 356 | layer_scaled_grad_norm = self.layer_norm(scaled_grad) 357 | self.stats["layer_scaled_grad_norm"][name].append( 358 | layer_scaled_grad_norm 359 | ) 360 | 361 | neuron_scaled_grad_norm = self.neuron_norm(scaled_grad) 362 | self.stats["neuron_scaled_grad_norm"][name].append( 363 | neuron_scaled_grad_norm 364 | ) 365 | 366 | if { 367 | "layer_scaled_grad_wd_projection", 368 | "neuron_scaled_grad_wd_projection", 369 | } & requested_stats: 370 | for name, pre_param in pre_params.items(): 371 | if pre_param.dim() < 2: 372 | # Only higher dimensional weights (linear, conv etc) 373 | continue 374 | 375 | pre_grad = pre_grads[name] 376 | post_state = post_states[name] 377 | post_v = post_state["exp_avg_sq"] 378 | 379 | scaled_grad = torch.div(pre_grad, (post_v.sqrt() + 1e-8)) 380 | layer_scaled_grad_wd_projection = self.layer_gradient_wd_project( 381 | scaled_grad, pre_param 382 | ) 383 | self.stats["layer_scaled_grad_wd_projection"][name].append( 384 | layer_scaled_grad_wd_projection 385 | ) 386 | 387 | neuron_scaled_grad_wd_projection = self.neuron_gradient_wd_project( 388 | scaled_grad, pre_param 389 | ) 390 | self.stats["neuron_scaled_grad_wd_projection"][name].append( 391 | neuron_scaled_grad_wd_projection 392 | ) 393 | 394 | T_disk = self.cfg["disk_save_interval"] or 0 395 | T_wandb = self.cfg["wandb_interval"] or 0 396 | 397 | # Maybe log to disk 398 | if T_disk and (self.iteration + self.interval) % (T_disk * self.interval) == 0: 399 | self.log_to_disk() 400 | 401 | # Maybe log to wandb 402 | if ( 403 | self.wandb 404 | and T_wandb 405 | and (self.iteration + self.interval) % (T_wandb * self.interval) == 0 406 | ): 407 | self.log_to_wandb() 408 | 409 | def layer_gradient_wd_project(self, g_t, w_t): 410 | norm = self.layer_norm(w_t) 411 | dot_prod = torch.sum(w_t.flatten() * g_t.flatten(), dim=0) 412 | projection = torch.div(dot_prod, norm * norm) 413 | return projection 414 | 415 | def neuron_gradient_wd_project(self, g_t, w_t): 416 | norm = self.neuron_norm(w_t) 417 | dot_prod = torch.sum(w_t.flatten(1) * g_t.flatten(1), dim=1) 418 | projection = torch.div(dot_prod, norm * norm) 419 | return projection 420 | 421 | def layer_cosine_sim(self, v1, v2): 422 | return F.cosine_similarity(v1.flatten(), v2.flatten(), dim=0) 423 | 424 | def neuron_cosine_sim(self, v1, v2): 425 | return F.cosine_similarity(v1.flatten(1), v2.flatten(1), dim=1) 426 | 427 | def layer_norm(self, v1): 428 | return torch.linalg.vector_norm(v1.flatten(), dim=0) 429 | 430 | def neuron_norm(self, v1): 431 | return torch.linalg.vector_norm(v1.flatten(1), dim=1) 432 | 433 | def log_to_disk(self, free_buffers=True): 434 | out_dict = dict() 435 | 
T_disk = self.cfg["disk_save_interval"] 436 | for stat_name in self.cfg["disk_stats"]: 437 | out_dict[stat_name] = dict() 438 | reducer = self.reducers[stat_name] 439 | 440 | for param_name, values in self.stats[stat_name].items(): 441 | values = torch.stack(values[-T_disk:]) 442 | if self.cfg["disk_max_channels"] > 0 and values.dim() > 1: 443 | values = values[:, : self.cfg["disk_max_channels"]] 444 | if self.cfg["disk_downsample"] > 1: 445 | assert T_disk % self.cfg["disk_downsample"] == 0 446 | values = values.reshape( 447 | ( 448 | T_disk // self.cfg["disk_downsample"], 449 | self.cfg["disk_downsample"], 450 | -1, 451 | ) 452 | ) 453 | if reducer == "mean": 454 | values = values.mean(dim=1) 455 | elif reducer == "rms": 456 | values = (values**2).mean(dim=1).sqrt() 457 | elif reducer == "first": 458 | values = values[:, 0] 459 | else: 460 | raise ValueError(f"Unknown {reducer=}") 461 | 462 | values = values.cpu() 463 | out_dict[stat_name][param_name] = values 464 | 465 | out_path = Path(self.output_folder) / "dynamics.pkl" 466 | with open(out_path, "ab") as fp: 467 | # Multiple dumps in a single file 468 | # https://stackoverflow.com/a/12762056 469 | pickle.dump(out_dict, fp) 470 | 471 | if free_buffers: 472 | self.free_buffers("disk") 473 | 474 | def log_to_wandb(self, free_buffers=True): 475 | # Assume stats are logged as a list of tensors for each stat 476 | # Reducer can be individual samples (i.e. the first) or mean 477 | 478 | out_dict = dict() 479 | T_wandb = self.cfg["wandb_interval"] 480 | for stat_name in self.cfg["wandb_stats"]: 481 | out_dict[stat_name] = dict() 482 | reducer = self.reducers[stat_name] 483 | 484 | for param_name, values in self.stats[stat_name].items(): 485 | values = torch.stack(values[-T_wandb:]) 486 | 487 | if reducer == "mean": 488 | values = values.mean(dim=0) 489 | elif reducer == "global_mean": 490 | values = values.mean(dim=0).mean() 491 | elif reducer == "rms": 492 | values = (values**2).mean(dim=0).sqrt() 493 | elif reducer == "global_rms": 494 | values = (values**2).mean(dim=0).sqrt().mean() 495 | elif reducer == "first": 496 | values = values[0] 497 | else: 498 | raise ValueError(f"Unknown {reducer=}") 499 | 500 | values = values.cpu().numpy() 501 | 502 | if values.size > 1: 503 | values = wandb.Histogram(values) 504 | 505 | out_dict[f"{stat_name}/{param_name}"] = values 506 | 507 | if not self.wandb_setup_complete: 508 | # For whatever reason using globs at init doesn't work 509 | wandb.define_metric("iter") 510 | for stat in out_dict: 511 | wandb.define_metric(stat, step_metric="iter") 512 | self.wandb_setup_complete = True 513 | 514 | out_dict["iter"] = self.iteration - (T_wandb - 1) * self.interval 515 | wandb.log( 516 | data=out_dict, 517 | # step=self.iteration-(T_wandb-1)*self.interval 518 | ) 519 | 520 | if free_buffers: 521 | self.free_buffers("wandb") 522 | 523 | def free_buffers(self, set_name="all"): 524 | # Delete old stat values that are no longer needed i.e. 
those that 525 | # have been logged by both wandb and to disk where appropriate 526 | 527 | if set_name == "all": 528 | self.stats.clear() 529 | return 530 | if set_name == "disk": 531 | main = "disk_stats" 532 | other = "wandb_stats" 533 | elif set_name == "wandb": 534 | main = "wandb_stats" 535 | other = "disk_stats" 536 | else: 537 | raise ValueError(f"Unknown {set_name=}") 538 | 539 | private_stats = set(self.cfg[main]) - set(self.cfg[other]) 540 | for stat in private_stats: 541 | del self.stats[stat] 542 | 543 | T_disk = self.cfg["disk_save_interval"] or 0 544 | T_wandb = self.cfg["wandb_interval"] or 0 545 | buffer_size = max(T_disk, T_wandb) 546 | shared_stats = set(self.cfg[main]) & set(self.cfg[other]) 547 | for stat_name in shared_stats: 548 | for param_name in self.stats[stat_name]: 549 | new_buffer = self.stats[stat_name][param_name][-buffer_size:] 550 | self.stats[stat_name][param_name] = new_buffer 551 | 552 | @staticmethod 553 | def load_stats(path): 554 | path = Path(path) 555 | 556 | log_fragments = [] 557 | with open(path, "rb") as f: 558 | while True: 559 | try: 560 | log_fragments.append(pickle.load(f)) 561 | except EOFError: 562 | break 563 | 564 | out_dict = dict() 565 | for stat_name in log_fragments[0]: 566 | stat_dict = {} 567 | for param_name in log_fragments[0][stat_name]: 568 | chunks = [] 569 | for log_fragment in log_fragments: 570 | chunks.append(log_fragment[stat_name][param_name]) 571 | stat_dict[param_name] = torch.concatenate(chunks) 572 | out_dict[stat_name] = stat_dict 573 | return out_dict 574 | 575 | 576 | def move_to_cpu(data, clone=False): 577 | def recurse(data): 578 | if isinstance(data, dict): 579 | return {k: recurse(v) for k, v in data.items()} 580 | if isinstance(data, list): 581 | return [recurse(v) for v in data] 582 | if isinstance(data, tuple): 583 | return tuple(recurse(v) for v in data) 584 | 585 | if isinstance(data, torch.Tensor): 586 | data = data.detach() 587 | if clone: 588 | data = data.clone() 589 | return data.to(device="cpu") 590 | else: 591 | # Others int, float, str, None etc 592 | if clone: 593 | return copy.deepcopy(data) # Copy just in case 594 | else: 595 | return data 596 | 597 | return recurse(data) 598 | 599 | 600 | # Bind this to the original object to preserve the self property 601 | # E.g. 
obj.method = self_preserving_overwrite(some_func).__get__(obj) 602 | def self_preserving_overwrite(method): 603 | @wraps(method) 604 | def _impl(inner_self, *args, **kwargs): 605 | return method(*args, **kwargs) 606 | 607 | return _impl 608 | -------------------------------------------------------------------------------- /src/logger/rotational_logger.yaml: -------------------------------------------------------------------------------- 1 | interval: 5 2 | stats: 3 | [ 4 | "layer_norm", 5 | "neuron_norm", 6 | "layer_grad_norm", 7 | "neuron_grad_norm", 8 | "layer_update_norm", 9 | "neuron_update_norm", 10 | "layer_angular_update", 11 | "neuron_angular_update", 12 | "layer_relative_update", 13 | "neuron_relative_update", 14 | "scalar_rms", 15 | "scalar_update_rms", 16 | "scalar_grad_rms", 17 | ] 18 | 19 | disk_save_interval: null # Save f, null to disable 20 | disk_stats: all 21 | disk_max_channels: 16 # Max channels per layer saved 22 | disk_downsample: 10 # 23 | 24 | wandb_interval: 5 25 | wandb_stats: all 26 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | import random 5 | import os 6 | import schedulefree 7 | 8 | import numpy as np 9 | import torch 10 | import wandb 11 | 12 | import config 13 | from data.utils import DataReader, get_dataset 14 | import distributed 15 | from models.utils import get_model 16 | from optim.base import train 17 | from optim.utils import cos_inf_schedule, wsd_schedule 18 | 19 | 20 | def main(args): 21 | distributed_backend = distributed.make_backend_from_args(args) 22 | args = distributed_backend.get_adjusted_args_for_process(args) 23 | args.world_size = distributed_backend.get_world_size() 24 | 25 | if args.full_eval_at is None: 26 | args.full_eval_at = [] 27 | 28 | # NOTE args.seed is offset per worker in get_adjusted_args_for_process 29 | torch.backends.cuda.matmul.allow_tf32 = True 30 | torch.backends.cudnn.allow_tf32 = True 31 | torch.manual_seed(args.seed) 32 | random.seed(args.seed) 33 | np.random.seed(args.seed) 34 | if "cuda" in args.device: 35 | torch.cuda.set_device(torch.device(args.device)) 36 | # torch.use_deterministic_algorithms(True) # CUBLAS_WORKSPACE_CONFIG=:4096:8 37 | 38 | exp_name = get_exp_name(args, distributed_backend) 39 | exp_dir = Path(args.results_base_folder) / exp_name 40 | if distributed_backend.is_master_process() and args.wandb: 41 | wandb.init( 42 | project=args.wandb_project, 43 | name=exp_name, 44 | config=vars(args), 45 | ) 46 | wandb.define_metric("iter") 47 | wandb.define_metric("train/*", step_metric="iter") 48 | wandb.define_metric("val/*", step_metric="iter") 49 | wandb.define_metric("lr", step_metric="iter") 50 | 51 | print(f"Starting Experiment: {exp_name}") 52 | print(f"Experiment Directory: {exp_dir}") 53 | print(f"Config:\n{vars(args)}\n") 54 | 55 | print(f"Loading dataset: '{args.dataset}'") 56 | datareaders = get_data_readers(args) 57 | 58 | model = get_model(args).to(args.device) 59 | # TODO: take care of initializing the model if args.use_pretrained != 'none' 60 | print(f"\nModel:\n{model}") 61 | 62 | model = distributed_backend.transform_model(model) 63 | group_specs = distributed_backend.get_raw_model(model).get_parameter_group_specs() 64 | param_name_mapping = {p_name: p for p_name, p in model.named_parameters()} 65 | optimized_params_cnt = 0 66 | for g in group_specs: 67 | params = [] 68 | for p_name in 
g["params"]: 69 | translated_p_names = ( 70 | distributed_backend.translate_model_parameter_name_for_node(p_name) 71 | ) 72 | params += [param_name_mapping[p_name] for p_name in translated_p_names] 73 | g["params"] = params 74 | optimized_params_cnt += sum([p.numel() for p in g["params"]]) 75 | params_cnt = distributed_backend.get_raw_model(model).get_num_params() 76 | print("number of parameters: %.2fM" % (params_cnt / 1e6,)) 77 | print("number of optimized parameters: %.2fM" % (optimized_params_cnt / 1e6,)) 78 | if args.wandb and distributed_backend.is_master_process(): 79 | wandb.log( 80 | {"parameters": params_cnt, "optimized_parameters": optimized_params_cnt} 81 | ) 82 | 83 | if args.opt == "adamw": 84 | opt = torch.optim.AdamW( 85 | group_specs, 86 | lr=args.lr, 87 | betas=(args.beta1, args.beta2), 88 | weight_decay=args.weight_decay, 89 | ) 90 | elif args.opt == "SFAdamW": 91 | opt = schedulefree.AdamWScheduleFree( 92 | group_specs, 93 | lr=args.lr, 94 | betas=(args.beta1, args.beta2), 95 | weight_decay=args.weight_decay, 96 | warmup_steps=args.warmup_steps, 97 | ) 98 | 99 | else: 100 | opt = torch.optim.SGD( 101 | group_specs, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay 102 | ) 103 | print(f"\nOptimizer:\n{opt}") 104 | 105 | if args.scheduler != "none": 106 | assert args.warmup_steps < args.iterations, "Warmup steps must be < iterations." 107 | if args.scheduler in ["cos", "linear"]: 108 | # initial lr is args.lr / div_factor 109 | # final lr is initial_lr/final_div_factor = args.lr / div_factor / final_div_factor 110 | scheduler = torch.optim.lr_scheduler.OneCycleLR( 111 | optimizer=opt, 112 | max_lr=[group.get("lr", args.lr) for group in group_specs], 113 | total_steps=args.iterations, 114 | pct_start=args.warmup_steps / args.iterations, 115 | anneal_strategy=args.scheduler, 116 | cycle_momentum=False, 117 | div_factor=1e2, 118 | final_div_factor=0.1, 119 | ) 120 | elif args.scheduler == "cos_inf": 121 | lambda_schedule = cos_inf_schedule( 122 | n_iterations=args.iterations, 123 | n_warmup=args.warmup_steps, 124 | n_inf=args.cos_inf_steps, 125 | div_factor=1e2, 126 | final_div_factor=0.1, 127 | ) 128 | scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lambda_schedule) 129 | elif args.scheduler == "wsd": 130 | lambda_schedule = wsd_schedule( 131 | n_iterations=args.iterations, 132 | n_warmup=args.warmup_steps, 133 | fract_decay=args.wsd_fract_decay, 134 | init_div_factor=1e2, 135 | final_lr_factor=args.wsd_final_lr_scale, # should be 0 here 136 | decay_type=args.decay_type, 137 | ) 138 | scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lambda_schedule) 139 | else: 140 | raise NotImplementedError(f"Unknown scheduler type: {args.scheduler}.") 141 | else: 142 | scheduler = None 143 | 144 | if (exp_dir / "ckpts" / "latest" / "main.pt").exists(): 145 | if not args.auto_resume: 146 | raise ValueError( 147 | f"The experiment dir {exp_dir} already exists. " 148 | + "To resume training, set auto_resume=True. " 149 | + "Otherwise, specify a different experiment name. 
" 150 | ) 151 | else: 152 | # Auto resume overwrites resume_from 153 | args.resume_from = str(exp_dir / "ckpts" / "latest") 154 | 155 | elif distributed_backend.is_master_process(): 156 | exp_dir.mkdir(parents=True, exist_ok=True) 157 | 158 | stats = train( 159 | model=model, 160 | opt=opt, 161 | datareaders=datareaders, 162 | scheduler=scheduler, 163 | exp_dir=exp_dir, 164 | distributed_backend=distributed_backend, 165 | cfg=args, 166 | ) 167 | 168 | stats["args"] = vars(args) 169 | if distributed_backend.is_master_process(): 170 | with open(exp_dir / "summary.json", "w") as fs: 171 | json.dump(stats, fs) 172 | distributed_backend.finalize() 173 | 174 | 175 | def get_args(): 176 | parser = argparse.ArgumentParser(allow_abbrev=False) 177 | parser.add_argument( 178 | "--config_format", default="base", choices=config.registered_formats() 179 | ) 180 | 181 | args, rem_args = parser.parse_known_args() 182 | 183 | return config.parse_args_with_format( 184 | format=args.config_format, base_parser=parser, args=rem_args, namespace=args 185 | ) 186 | 187 | 188 | def get_exp_name(args, distributed_backend): 189 | """Returns the name of the experiment, used for saving models and wandb.""" 190 | if args.experiment_name is not None: 191 | return args.experiment_name 192 | 193 | rank = distributed_backend.rank 194 | 195 | exp_name = ( 196 | f"{args.dataset}_{args.model}_nlayers{args.n_layer}" 197 | f"_nhead{args.n_head}_lr{args.lr}" 198 | f"_sched_{args.scheduler}_warmup{args.warmup_steps}" 199 | f"_decay_{args.decay_type}_{args.wsd_fract_decay}" 200 | f"_iter{args.iterations}" 201 | f"_bs{args.batch_size}x{args.acc_steps}_ws{args.world_size}" 202 | ) 203 | # for mup 204 | if args.model == "mup_noam": 205 | exp_name = ( 206 | f"{args.dataset}_{args.model}" 207 | f"_opt{args.opt}" 208 | f"_nlayers{args.n_layer}" 209 | # f"_nhead{args.n_head}" 210 | f"_lr{args.lr}" 211 | f"_sched_{args.scheduler}" 212 | f"_decay_{args.decay_type}" 213 | # f"_warmup{args.warmup_steps}" 214 | f"_iter{args.iterations}" 215 | f"_init{args.init_std}_sce{args.scale_emb}" 216 | f"_scd{args.scale_depth}" 217 | # f"_bs{args.batch_size}x{args.acc_steps}_ws{args.world_size}" 218 | ) 219 | if args.wandb_run_prefix != "none": 220 | exp_name = args.wandb_run_prefix + "_" + exp_name 221 | exp_name += f"_seed{args.seed - rank}" 222 | exp_name += f"_data_seed{args.data_seed}" 223 | 224 | if args.weight_average: 225 | exp_name += f"_WA" 226 | if args.opt == "SFAdamW": 227 | exp_name += f"_beta1_{args.beta1}" 228 | exp_name += f"_beta2_{args.beta2}" 229 | return exp_name 230 | 231 | 232 | def get_data_readers(args, verbose=True): 233 | data_srcs = get_dataset(args) 234 | train_reader = DataReader( 235 | data_src=data_srcs["train"], 236 | batch_size=args.batch_size, 237 | sequence_length=args.sequence_length, 238 | seed=args.data_seed, 239 | with_replacement=False, 240 | auto_shard=True, 241 | keep_in_ram=args.data_in_ram, 242 | ) 243 | val_reader = DataReader( 244 | data_src=data_srcs["val"], 245 | batch_size=args.batch_size, 246 | sequence_length=args.sequence_length, 247 | seed=args.data_seed, 248 | with_replacement=False, 249 | auto_shard=False, # NOTE Identical Per Rank 250 | keep_in_ram=args.data_in_ram, 251 | ) 252 | 253 | if verbose: 254 | print(f"Num training tokens: {train_reader.num_tokens}") 255 | print(f"Num validation tokens: {val_reader.num_tokens}") 256 | 257 | return { 258 | "train": train_reader, 259 | "val": val_reader, 260 | } 261 | 262 | 263 | if __name__ == "__main__": 264 | args = get_args() 265 | main(args) 266 | 
-------------------------------------------------------------------------------- /src/models/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Full definition of a GPT Language Model, all of it in this single file. 3 | References: 4 | 1) the official GPT-2 TensorFlow implementation released by OpenAI: 5 | https://github.com/openai/gpt-2/blob/master/src/model.py 6 | 2) huggingface/transformers PyTorch implementation: 7 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py 8 | """ 9 | 10 | import math 11 | 12 | import tiktoken 13 | import torch 14 | import torch.nn as nn 15 | from torch.nn import functional as F 16 | 17 | 18 | class LayerNorm(nn.Module): 19 | """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False""" 20 | 21 | def __init__(self, ndim, bias): 22 | super().__init__() 23 | self.weight = nn.Parameter(torch.ones(ndim)) 24 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None 25 | 26 | def forward(self, input): 27 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) 28 | 29 | 30 | class CausalSelfAttention(nn.Module): 31 | def __init__(self, config): 32 | super().__init__() 33 | assert config.n_embd % config.n_head == 0 34 | # key, query, value projections for all heads, but in a batch 35 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) 36 | # output projection 37 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 38 | # regularization 39 | self.attn_dropout = nn.Dropout(config.dropout) 40 | self.resid_dropout = nn.Dropout(config.dropout) 41 | self.n_head = config.n_head 42 | self.n_embd = config.n_embd 43 | self.dropout = config.dropout 44 | # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 45 | self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") 46 | if not self.flash: 47 | print( 48 | "WARNING: using slow attention. 
Flash Attention requires PyTorch >= 2.0" 49 | ) 50 | # causal mask to ensure that attention is only applied to the left in the input sequence 51 | self.register_buffer( 52 | "bias", 53 | torch.tril( 54 | torch.ones(config.sequence_length, config.sequence_length) 55 | ).view(1, 1, config.sequence_length, config.sequence_length), 56 | ) 57 | 58 | def forward(self, x): 59 | # batch size, sequence length, embedding dimensionality (n_embd) 60 | ( 61 | B, 62 | T, 63 | C, 64 | ) = x.size() 65 | 66 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 67 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 68 | # (B, T, nh, hs) 69 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) 70 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) 71 | 72 | # (B, nh, T, hs) 73 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) 74 | 75 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 76 | if self.flash: 77 | # efficient attention using Flash Attention CUDA kernels 78 | y = torch.nn.functional.scaled_dot_product_attention( 79 | q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True 80 | ) 81 | else: 82 | # manual implementation of attention 83 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 84 | att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) 85 | att = F.softmax(att, dim=-1) 86 | att = self.attn_dropout(att) 87 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 88 | y = ( 89 | y.transpose(1, 2).contiguous().view(B, T, C) 90 | ) # re-assemble all head outputs side by side 91 | 92 | # output projection 93 | y = self.resid_dropout(self.c_proj(y)) 94 | return y 95 | 96 | 97 | class MLP(nn.Module): 98 | def __init__(self, config, exp_factor=1.0): 99 | super().__init__() 100 | self.dim_exp_factor = exp_factor * 4 101 | 102 | self.c_fc = nn.Linear( 103 | config.n_embd, int(self.dim_exp_factor * config.n_embd), bias=config.bias 104 | ) 105 | self.c_proj = nn.Linear( 106 | int(self.dim_exp_factor * config.n_embd), config.n_embd, bias=config.bias 107 | ) 108 | self.dropout = nn.Dropout(config.dropout) 109 | self.activation = nn.GELU() 110 | 111 | def forward(self, x): 112 | x = self.c_fc(x) 113 | x = self.activation(x) 114 | x = self.c_proj(x) 115 | x = self.dropout(x) 116 | return x, {} 117 | 118 | 119 | class Block(nn.Module): 120 | def __init__(self, config): 121 | super().__init__() 122 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) 123 | self.attn = CausalSelfAttention(config) 124 | self.parallel = config.parallel_block 125 | if not self.parallel: 126 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) 127 | self.mlp = MLP(config) 128 | 129 | def forward(self, x, *args, **kwargs): 130 | if self.parallel: 131 | # from GPT-J 6B https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L299 132 | x_ln = self.ln_1(x, *args, **kwargs) 133 | x_attn = self.attn(x_ln) 134 | x_ffn = self.mlp(x_ln) 135 | x = x + x_attn + x_ffn 136 | else: 137 | x = x + self.attn(self.ln_1(x, *args, **kwargs)) 138 | x_ = self.mlp(self.ln_2(x, *args, **kwargs)) 139 | x = x + x_ 140 | return x 141 | 142 | 143 | class GPTBase(nn.Module): 144 | def __init__(self, config): 145 | super().__init__() 146 | assert config.vocab_size is not None 147 | assert config.sequence_length is not None 148 | self.config = config 149 | self.tokenizer = tiktoken.get_encoding("gpt2") 150 | 151 | 
self.transformer = nn.ModuleDict( 152 | dict( 153 | wte=nn.Embedding(config.vocab_size, config.n_embd), 154 | wpe=nn.Embedding(config.sequence_length, config.n_embd), 155 | drop=nn.Dropout(config.dropout), 156 | h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 157 | ln_f=LayerNorm(config.n_embd, bias=config.bias), 158 | ) 159 | ) 160 | 161 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 162 | # with weight tying when using torch.compile() some warnings get generated: 163 | # "UserWarning: functional_call was passed multiple values for tied weights. 164 | # This behavior is deprecated and will be an error in future versions" 165 | # not 100% sure what this is, so far seems to be harmless. TODO investigate 166 | self.transformer.wte.weight = ( 167 | self.lm_head.weight 168 | ) # https://paperswithcode.com/method/weight-tying 169 | 170 | # init all weights 171 | self.apply(self._init_weights) 172 | # apply special scaled init to the residual projections, per GPT-2 paper 173 | for pn, p in self.named_parameters(): 174 | if pn.endswith("c_proj.weight"): 175 | torch.nn.init.normal_( 176 | p, 177 | mean=0.0, 178 | std=self.config.init_std / math.sqrt(2 * config.n_layer), 179 | ) 180 | 181 | def get_num_params(self, non_embedding=True): 182 | """ 183 | Return the number of parameters in the model. 184 | For non-embedding count (default), the position embeddings get subtracted. 185 | The token embeddings would too, except due to the parameter sharing these 186 | params are actually used as weights in the final layer, so we include them. 187 | """ 188 | n_params = sum(p.numel() for p in self.parameters()) 189 | if non_embedding: 190 | n_params -= self.transformer.wpe.weight.numel() 191 | return n_params 192 | 193 | def _init_weights(self, module): 194 | if isinstance(module, nn.Linear): 195 | torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std) 196 | if module.bias is not None: 197 | torch.nn.init.zeros_(module.bias) 198 | elif isinstance(module, nn.Embedding): 199 | torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std) 200 | 201 | def forward(self, idx, targets=None, get_logits=False): 202 | device = idx.device 203 | b, t = idx.size() 204 | assert ( 205 | t <= self.config.sequence_length 206 | ), f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}" 207 | # shape (1, t) 208 | pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) 209 | 210 | # forward the GPT model itself 211 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 212 | pos_emb = self.transformer.wpe( 213 | pos 214 | ) # position embeddings of shape (1, t, n_embd) 215 | x = self.transformer.drop(tok_emb + pos_emb) 216 | 217 | # router logits is a list for each layer's routing, each of shape (b * seq_len, n_experts) 218 | router_logits = [] 219 | # experts is a list for each layer's selected experts, shape (b * seq_len, topk) 220 | experts = [] 221 | 222 | # forward pass through all the transformer blocks 223 | for block in self.transformer.h: 224 | x, logits_and_experts = block(x) 225 | x = self.transformer.ln_f(x) 226 | 227 | if targets is not None: 228 | # if we are given some desired targets also calculate the loss 229 | logits = self.lm_head(x) 230 | loss = F.cross_entropy( 231 | logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1 232 | ) 233 | 234 | else: 235 | # inference-time mini-optimization: only forward the lm_head on the very last position 236 | 
logits = self.lm_head( 237 | x[:, [-1], :] 238 | ) # note: using list [-1] to preserve the time dim 239 | loss = None 240 | logits = logits if get_logits else None 241 | return { 242 | "logits": logits, 243 | "loss": loss, 244 | } 245 | 246 | def crop_sequence_length(self, sequence_length): 247 | # model surgery to decrease the block size if necessary 248 | # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024) 249 | # but want to use a smaller block size for some smaller, simpler model 250 | assert sequence_length <= self.config.sequence_length 251 | self.config.sequence_length = sequence_length 252 | self.transformer.wpe.weight = nn.Parameter( 253 | self.transformer.wpe.weight[:sequence_length] 254 | ) 255 | for block in self.transformer.h: 256 | block.attn.bias = block.attn.bias[:, :, :sequence_length, :sequence_length] 257 | 258 | def from_pretrained( 259 | self, 260 | model_path, 261 | ): 262 | paths = model_path.split(",") 263 | if len(paths) == 1: 264 | # TODO: with distributed? 265 | loaded_state = torch.load( 266 | str(model_path + "/ckpt.pt"), 267 | map_location=torch.device(self.config.device), 268 | ) 269 | state_to_load = loaded_state["model"] 270 | 271 | # remap checkpoint keys (e.g. drop the _orig_mod prefix added by torch.compile) 272 | state_to_load = { 273 | ".".join(k.split(".")[1:]): v # drop _orig_mod from keys 274 | for k, v in state_to_load.items() 275 | } 276 | self.load_state_dict(state_to_load) # load the prepared weights into this model 277 | def get_parameter_group_specs(self): 278 | """ 279 | This long function is unfortunately doing something very simple and is being very defensive: 280 | We are separating out all parameters of the model into two buckets: those that will experience 281 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 282 | We are then returning the parameter groups for the PyTorch optimizer. 283 | """ 284 | 285 | # separate out all parameters to those that will and won't experience regularizing weight decay 286 | decay = set() 287 | no_decay = set() 288 | whitelist_weight_modules = (torch.nn.Linear,) 289 | # need to do import here to avoid circular import (since llama imports from base here) 290 | from .utils import BLACKLIST_WEIGHT_MODULES 291 | 292 | for mn, m in self.named_modules(): 293 | for pn, p in m.named_parameters(): 294 | fpn = "%s.%s" % (mn, pn) if mn else pn # full param name 295 | # random note: because named_modules and named_parameters are recursive 296 | # we will see the same tensors p many many times. but doing it this way 297 | # allows us to know which parent module any tensor p belongs to... 298 | if pn.endswith("bias"): 299 | # all biases will not be decayed 300 | no_decay.add(fpn) 301 | elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): 302 | # weights of whitelist modules will be weight decayed 303 | decay.add(fpn) 304 | elif pn.endswith("weight") and isinstance(m, BLACKLIST_WEIGHT_MODULES): 305 | # weights of blacklist modules will NOT be weight decayed 306 | no_decay.add(fpn) 307 | 308 | # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they 309 | # will appear in the no_decay and decay sets respectively after the above. 310 | # In addition, because named_parameters() doesn't return duplicates, it 311 | # will only return the first occurrence, keyed by 'transformer.wte.weight', below. 312 | # so let's manually remove 'lm_head.weight' from the decay set. This keeps the 313 | # tied tensor in the optimizer via transformer.wte.weight only, without weight decay.
314 | decay.remove("lm_head.weight") 315 | 316 | # validate that we considered every parameter 317 | param_dict = {pn: p for pn, p in self.named_parameters()} 318 | inter_params = decay & no_decay 319 | union_params = decay | no_decay 320 | assert ( 321 | len(inter_params) == 0 322 | ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),) 323 | assert ( 324 | len(param_dict.keys() - union_params) == 0 325 | ), "parameters %s were not separated into either decay/no_decay set!" % ( 326 | str(param_dict.keys() - union_params), 327 | ) 328 | 329 | # create the pytorch optimizer object 330 | return [ 331 | {"params": sorted(list(decay))}, 332 | {"params": sorted(list(no_decay)), "weight_decay": 0.0}, 333 | ] 334 | 335 | @torch.no_grad() 336 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): 337 | """ 338 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete 339 | the sequence max_new_tokens times, feeding the predictions back into the model each time. 340 | Most likely you'll want to make sure to be in model.eval() mode of operation for this. 341 | """ 342 | for _ in range(max_new_tokens): 343 | # if the sequence context is growing too long we must crop it at sequence_length 344 | idx_cond = ( 345 | idx 346 | if idx.size(1) <= self.config.sequence_length 347 | else idx[:, -self.config.sequence_length :] 348 | ) 349 | # forward the model to get the logits for the index in the sequence 350 | logits = self(idx_cond, get_logits=True)["logits"] 351 | # pluck the logits at the final step and scale by desired temperature 352 | logits = logits[:, -1, :] / temperature 353 | # optionally crop the logits to only the top k options 354 | if top_k is not None: 355 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 356 | logits[logits < v[:, [-1]]] = -float("Inf") 357 | # apply softmax to convert logits to (normalized) probabilities 358 | probs = F.softmax(logits, dim=-1) 359 | # sample from the distribution 360 | idx_next = torch.multinomial(probs, num_samples=1) 361 | # append sampled index to the running sequence and continue 362 | idx = torch.cat((idx, idx_next), dim=1) 363 | 364 | return idx 365 | 366 | @torch.no_grad() 367 | def generate_from_string(self, in_str, max_new_tokens, temperature=1.0, top_k=None): 368 | idx = ( 369 | torch.tensor( 370 | self.tokenizer.encode(in_str, allowed_special={"<|endoftext|>"}) 371 | ) 372 | .view(1, -1) 373 | .to(self.lm_head.weight.device) 374 | ) 375 | out_idx = ( 376 | self.generate(idx, max_new_tokens, temperature, top_k) 377 | .view(-1) 378 | .to("cpu") 379 | .numpy() 380 | ) 381 | return self.tokenizer.decode(out_idx) 382 | -------------------------------------------------------------------------------- /src/models/llama.py: -------------------------------------------------------------------------------- 1 | """ 2 | Llama style Language Model that is 3 | compilable (avoids torch complex) 4 | """ 5 | 6 | import math 7 | 8 | import tiktoken 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn import functional as F 12 | from models.base import CausalSelfAttention, GPTBase 13 | 14 | 15 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: 16 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 17 | t = torch.arange(end, device=freqs.device) # type: ignore 18 | freqs = torch.outer(t, freqs).float() # type: ignore 19 | cos_freqs = torch.cos(freqs) 20 | sin_freqs = torch.sin(freqs) 21 | # Stack the cos and sin 
parts in the last dimension to simulate complex numbers 22 | return torch.stack((cos_freqs, sin_freqs), dim=-1) 23 | 24 | 25 | def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: 26 | """ 27 | freqs_cis: complex - (seq_len, head_dim / 2) 28 | x: complex - (bsz, seq_len, head_dim / 2) 29 | """ 30 | ndim = x.ndim 31 | assert 1 < ndim 32 | assert freqs_cis.shape[:-1] == (x.shape[1], x.shape[-2]) 33 | # New shape for broadcasting 34 | shape = [ 35 | 1 if i != 1 and i != ndim - 2 else d for i, d in enumerate(x.shape[:-1]) 36 | ] + [2] 37 | return freqs_cis.view(*shape) 38 | 39 | 40 | def apply_rotary_emb(q, k, freqs_cis): 41 | # q, k: (B, T, nh, hs) 42 | # freq_cis: (T, hs) 43 | # return: (B, T, nh, hs), (B, T, nh, hs) 44 | q = q.float().reshape(*q.shape[:-1], -1, 2) 45 | k = k.float().reshape(*k.shape[:-1], -1, 2) 46 | 47 | freqs_cis = _reshape_for_broadcast(freqs_cis, q) 48 | 49 | # Perform manual "complex" multiplication 50 | q_cos = q[..., 0] * freqs_cis[..., 0] - q[..., 1] * freqs_cis[..., 1] 51 | q_sin = q[..., 0] * freqs_cis[..., 1] + q[..., 1] * freqs_cis[..., 0] 52 | k_cos = k[..., 0] * freqs_cis[..., 0] - k[..., 1] * freqs_cis[..., 1] 53 | k_sin = k[..., 0] * freqs_cis[..., 1] + k[..., 1] * freqs_cis[..., 0] 54 | 55 | # Combine the results back into the interleaved format expected by q and k 56 | q_out = torch.stack((q_cos, q_sin), dim=-1).reshape(q.shape).flatten(3) 57 | k_out = torch.stack((k_cos, k_sin), dim=-1).reshape(k.shape).flatten(3) 58 | 59 | return q_out, k_out 60 | 61 | 62 | class RMSNorm(nn.Module): 63 | def __init__(self, dim: int, eps: float = 1e-6): 64 | super().__init__() 65 | self.eps = eps 66 | self.weight = nn.Parameter(torch.ones(dim)) 67 | 68 | def _norm(self, x): 69 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 70 | 71 | def forward(self, x): 72 | output = self._norm(x.float()).type_as(x) 73 | return output * self.weight 74 | 75 | 76 | class LlamaMLP(nn.Module): 77 | def __init__(self, config): 78 | super().__init__() 79 | 80 | hidden_dim = config.n_embd * 4 81 | hidden_dim = int(2 * hidden_dim / 3) 82 | hidden_dim = config.multiple_of * ( 83 | (hidden_dim + config.multiple_of - 1) // config.multiple_of 84 | ) 85 | 86 | self.w1 = nn.Linear(config.n_embd, hidden_dim, bias=False) 87 | self.w2 = nn.Linear(config.n_embd, hidden_dim, bias=False) 88 | self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False) 89 | 90 | def forward(self, x): 91 | return self.c_proj(nn.functional.silu(self.w1(x)) * self.w2(x)) 92 | 93 | 94 | class LlamaAttention(CausalSelfAttention): 95 | 96 | def forward(self, x, freqs_cis): 97 | # batch size, sequence length, embedding dimensionality (n_embd) 98 | ( 99 | B, 100 | T, 101 | C, 102 | ) = x.size() 103 | 104 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 105 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 106 | # (B, T, nh, hs) 107 | k = k.view(B, T, self.n_head, C // self.n_head) 108 | q = q.view(B, T, self.n_head, C // self.n_head) 109 | q, k = apply_rotary_emb(q, k, freqs_cis) 110 | # (B, nh, T, hs) 111 | q, k = q.transpose(1, 2), k.transpose(1, 2) 112 | 113 | # (B, nh, T, hs) 114 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) 115 | 116 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 117 | if self.flash: 118 | # efficient attention using Flash Attention CUDA kernels 119 | y = torch.nn.functional.scaled_dot_product_attention( 120 | q, k, v, attn_mask=None, 
dropout_p=self.dropout, is_causal=True 121 | ) 122 | else: 123 | # manual implementation of attention 124 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 125 | att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) 126 | att = F.softmax(att, dim=-1) 127 | att = self.attn_dropout(att) 128 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 129 | y = ( 130 | y.transpose(1, 2).contiguous().view(B, T, C) 131 | ) # re-assemble all head outputs side by side 132 | 133 | # output projection 134 | y = self.resid_dropout(self.c_proj(y)) 135 | return y 136 | 137 | 138 | class LlamaBlock(nn.Module): 139 | def __init__(self, config): 140 | super().__init__() 141 | self.ln_1 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps) 142 | self.attn = LlamaAttention(config) 143 | self.ln_2 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps) 144 | self.mlp = LlamaMLP(config) 145 | 146 | def forward(self, x, freqs_cis): 147 | x = x + self.attn(self.ln_1(x), freqs_cis) 148 | x_ = self.mlp(self.ln_2(x)) 149 | x = x + x_ 150 | return x 151 | 152 | 153 | class Llama(GPTBase): 154 | def __init__(self, config): 155 | super().__init__(config) 156 | assert config.vocab_size is not None 157 | assert config.sequence_length is not None 158 | self.config = config 159 | self.tokenizer = tiktoken.get_encoding("gpt2") 160 | 161 | # create the token and position embeddings 162 | self.head_dim = config.n_embd // config.n_head 163 | self.freqs_cis = precompute_freqs_cis(self.head_dim, config.sequence_length) 164 | 165 | self.transformer = nn.ModuleDict( 166 | dict( 167 | wte=nn.Embedding(config.vocab_size, config.n_embd), 168 | drop=nn.Dropout(config.dropout), 169 | h=nn.ModuleList([LlamaBlock(config) for _ in range(config.n_layer)]), 170 | ln_f=RMSNorm(config.n_embd, eps=config.rmsnorm_eps), 171 | ) 172 | ) 173 | 174 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 175 | # with weight tying when using torch.compile() some warnings get generated: 176 | # "UserWarning: functional_call was passed multiple values for tied weights. 177 | # This behavior is deprecated and will be an error in future versions" 178 | # not 100% sure what this is, so far seems to be harmless. TODO investigate 179 | self.transformer.wte.weight = ( 180 | self.lm_head.weight 181 | ) # https://paperswithcode.com/method/weight-tying 182 | 183 | # init all weights 184 | self.apply(self._init_weights) 185 | # apply special scaled init to the residual projections, per GPT-2 paper 186 | for pn, p in self.named_parameters(): 187 | if pn.endswith("c_proj.weight"): 188 | torch.nn.init.normal_( 189 | p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer) 190 | ) 191 | 192 | def get_num_params(self, non_embedding=True): 193 | """ 194 | Return the number of parameters in the model. 195 | For non-embedding count (default) 196 | The token embeddings would too, except due to the parameter sharing these 197 | params are actually used as weights in the final layer, so we include them. 
198 | """ 199 | n_params = sum(p.numel() for p in self.parameters()) 200 | return n_params 201 | 202 | def forward(self, idx, targets=None, get_logits=False): 203 | device = idx.device 204 | b, t = idx.size() 205 | assert ( 206 | t <= self.config.sequence_length 207 | ), f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}" 208 | # shape (1, t) 209 | pos = torch.arange(0, t, dtype=torch.long, device=device) 210 | 211 | # forward the GPT model itself 212 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 213 | 214 | x = self.transformer.drop(tok_emb) 215 | freqs_cis = self.freqs_cis.to(x.device)[pos] 216 | 217 | for block in self.transformer.h: 218 | x = block(x, freqs_cis=freqs_cis) 219 | x = self.transformer.ln_f(x) 220 | 221 | if targets is not None: 222 | # if we are given some desired targets also calculate the loss 223 | logits = self.lm_head(x) 224 | loss = F.cross_entropy( 225 | logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1 226 | ) 227 | else: 228 | # inference-time mini-optimization: only forward the lm_head on the very last position 229 | logits = self.lm_head( 230 | x[:, [-1], :] 231 | ) # note: using list [-1] to preserve the time dim 232 | loss = None 233 | 234 | logits = logits if get_logits else None 235 | 236 | return { 237 | "logits": logits, 238 | "loss": loss, 239 | } 240 | -------------------------------------------------------------------------------- /src/models/utils.py: -------------------------------------------------------------------------------- 1 | from .llama import Llama, RMSNorm 2 | from .base import GPTBase, LayerNorm 3 | import torch 4 | 5 | BLACKLIST_WEIGHT_MODULES = ( 6 | torch.nn.LayerNorm, 7 | LayerNorm, 8 | RMSNorm, 9 | torch.nn.Embedding, 10 | ) 11 | 12 | 13 | def get_model(args): 14 | """Return the right model""" 15 | if args.model == "base": 16 | model = GPTBase(args) 17 | if args.use_pretrained != "none": 18 | model.from_pretrained(args.use_pretrained) 19 | return model 20 | elif args.model == "llama": 21 | model = Llama(args) 22 | if args.use_pretrained != "none": 23 | raise NotImplementedError( 24 | f"Loading of pretrained models not yet implemented for model '{args.model}'." 
25 | ) 26 | return model 27 | else: 28 | raise KeyError(f"Unknown model '{args.model}'.") 29 | -------------------------------------------------------------------------------- /src/optim/base.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext 2 | import copy 3 | from pathlib import Path 4 | import time 5 | import yaml 6 | 7 | import torch 8 | import wandb 9 | 10 | from logger.logger import DynamicsLogger 11 | from optim.weight_averaging import ( 12 | WeightAverager, 13 | eval_ema, 14 | eval_wa, 15 | ExponentialWeightAverager, 16 | ) 17 | from .utils import ( 18 | eval, 19 | get_batch, 20 | load_checkpoint, 21 | load_worker_state, 22 | save_checkpoint, 23 | save_worker_state, 24 | ) 25 | 26 | 27 | def train( 28 | model, 29 | opt, 30 | datareaders, 31 | scheduler, 32 | exp_dir, 33 | distributed_backend, 34 | cfg, 35 | ): 36 | not_compiled_model = model 37 | if cfg.compile: 38 | print(f"Compiling model ...") 39 | model = torch.compile(model) 40 | 41 | if "cuda" in cfg.device: 42 | type_ctx = torch.amp.autocast( 43 | device_type="cuda", 44 | dtype={ 45 | "float32": torch.float32, 46 | "float16": torch.float16, 47 | "bfloat16": torch.bfloat16, 48 | }[cfg.dtype], 49 | ) 50 | else: 51 | type_ctx = nullcontext() 52 | 53 | if cfg.resume_from: 54 | # This is a full resume including the model weights, optimizer state, 55 | # dataloader state, random seed, etc. Not intended for fine tuning or 56 | # other scenarios where some of these should change. 57 | print(f"\nResuming Training From {cfg.resume_from}") 58 | ckpt_dir = Path(cfg.resume_from) 59 | curr_iter = load_checkpoint( 60 | model, 61 | opt, 62 | scheduler, 63 | ckpt_dir / "main.pt", 64 | cfg.device, 65 | ) 66 | load_worker_state(ckpt_dir) 67 | else: 68 | curr_iter = 0 69 | 70 | if cfg.weight_average: 71 | # This generally does not support resuming training, but will work if 72 | # cfg.wa_interval perfectly divides the iteration number of the chkpt. 73 | # Otherwise, the first avg will not be correctly computed, with a bias 74 | # towards the first sample and missing values for earlier iterations.
75 | weight_averager = WeightAverager( 76 | not_compiled_model, 77 | horizon=cfg.wa_horizon, 78 | interval=cfg.wa_interval, 79 | save_dir=None if cfg.wa_use_temp_dir else exp_dir / "avgs", 80 | dtype={ 81 | "float32": torch.float32, 82 | "float64": torch.float64, 83 | }[cfg.wa_dtype], 84 | count=curr_iter, 85 | ) 86 | 87 | if cfg.exponential_moving_average: 88 | ema = ExponentialWeightAverager( 89 | not_compiled_model, 90 | interval=cfg.ema_interval, 91 | decay=cfg.ema_decay, 92 | warmup=cfg.warmup_steps if cfg.ema_after_warmup else 0, 93 | dtype={ 94 | "float32": torch.float32, 95 | "float64": torch.float64, 96 | }[cfg.wa_dtype], 97 | ) 98 | 99 | if distributed_backend.is_master_process() and cfg.log_dynamics: 100 | with open(cfg.dynamics_logger_cfg, "r") as f: 101 | dlcfg = yaml.safe_load(f) 102 | 103 | # Hooks into optimizer 104 | dlogger = DynamicsLogger( 105 | model, opt, dlcfg, cfg.results_base_folder, wandb=cfg.wandb 106 | ) 107 | dlogger.iteration = curr_iter 108 | 109 | substep = curr_iter * cfg.acc_steps 110 | train_reader, val_reader = datareaders["train"], datareaders["val"] 111 | train_reader.set_step(substep) 112 | stats = {"train_loss": [], "val_loss": [], "val_pp": [], "val_acc": []} 113 | model.train() 114 | 115 | while curr_iter <= cfg.iterations: 116 | # Save permanent checkpoint 117 | if cfg.permanent_ckpt_interval > 0: 118 | if curr_iter % cfg.permanent_ckpt_interval == 0: 119 | ckpt_dir = exp_dir / "ckpts" / str(curr_iter) 120 | if distributed_backend.is_master_process(): 121 | save_checkpoint(model, opt, scheduler, curr_iter, ckpt_dir) 122 | save_worker_state(ckpt_dir) 123 | 124 | # Save temporary checkpoint for resuming training 125 | if cfg.latest_ckpt_interval > 0: 126 | if curr_iter % cfg.latest_ckpt_interval == 0 or curr_iter == cfg.iterations: 127 | ckpt_dir = exp_dir / "ckpts" / "latest" 128 | if distributed_backend.is_master_process(): 129 | save_checkpoint(model, opt, scheduler, curr_iter, ckpt_dir) 130 | save_worker_state(ckpt_dir) 131 | 132 | ws = distributed_backend.get_world_size() 133 | tokens = ws * substep * cfg.sequence_length * cfg.batch_size 134 | epoch = tokens / train_reader.num_tokens 135 | if ( 136 | curr_iter % cfg.eval_interval == 0 137 | or curr_iter == cfg.iterations 138 | or (curr_iter in cfg.full_eval_at) 139 | ): 140 | eval_and_log( 141 | curr_iter, 142 | epoch, 143 | model, 144 | val_reader, 145 | type_ctx, 146 | distributed_backend, 147 | cfg, 148 | opt, 149 | full_eval=(curr_iter in cfg.full_eval_at), 150 | ) 151 | 152 | if curr_iter > cfg.wa_interval and cfg.weight_average: 153 | eval_wa( 154 | curr_iter, 155 | not_compiled_model, 156 | weight_averager, 157 | val_reader, 158 | type_ctx, 159 | distributed_backend, 160 | cfg, 161 | full_eval=(curr_iter in cfg.full_eval_at), 162 | ) 163 | if cfg.exponential_moving_average: 164 | eval_ema( 165 | curr_iter, 166 | not_compiled_model, 167 | ema, 168 | val_reader, 169 | type_ctx, 170 | distributed_backend, 171 | cfg, 172 | full_eval=(curr_iter in cfg.full_eval_at), 173 | ) 174 | 175 | if curr_iter == cfg.iterations: 176 | # Save checkpoints and evaluate at final iteration, but no need to train further 177 | break 178 | 179 | # Train model 180 | t_start = time.perf_counter_ns() 181 | for microstep_idx in range(cfg.acc_steps): # gradient accumulation 182 | x, y = get_batch(train_reader, device=cfg.device) 183 | with type_ctx: 184 | with distributed_backend.get_context_for_microstep_forward( 185 | model=model, 186 | microstep_idx=microstep_idx, 187 | gradient_accumulation_steps=cfg.acc_steps, 
188 | ): 189 | outputs = model(x, targets=y) 190 | 191 | loss = outputs["loss"] / cfg.acc_steps 192 | loss.backward() 193 | substep += 1 194 | 195 | if cfg.grad_clip != 0.0: 196 | torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip) 197 | if cfg.opt == "SFAdamW": 198 | opt.train() 199 | opt.step() 200 | scheduler.step() 201 | opt.zero_grad(set_to_none=True) 202 | if cfg.weight_average: 203 | weight_averager.step(not_compiled_model, distributed_backend.is_master_process()) 204 | if cfg.exponential_moving_average: 205 | ema.step(not_compiled_model, distributed_backend.is_master_process()) 206 | dt = (time.perf_counter_ns() - t_start) / 1e9 207 | 208 | curr_iter += 1 209 | 210 | if ( 211 | cfg.log_interval 212 | and curr_iter % cfg.log_interval == 0 213 | and distributed_backend.is_master_process() # Only log on master rank 214 | ): 215 | train_loss = loss.detach().cpu().item() * cfg.acc_steps 216 | 217 | current_lrs = [param_group["lr"] for param_group in opt.param_groups] 218 | 219 | print( 220 | f"Train: Iter={curr_iter} ({epoch:0.3f} epochs) " 221 | f"train_loss={train_loss:.3f} iter_dt={dt:.2e}s " 222 | f"lr={current_lrs[0]:.2e}" 223 | ) 224 | 225 | if cfg.wandb: 226 | wandb.log( 227 | { 228 | "iter": curr_iter, 229 | "train/loss": train_loss, 230 | "train/perplexity": 2.71828**train_loss, 231 | "lr": current_lrs[0], 232 | "iter_dt": dt, 233 | } 234 | ) 235 | 236 | return stats 237 | 238 | 239 | def eval_and_log( 240 | curr_iter, 241 | epoch, 242 | model, 243 | val_reader, 244 | type_ctx, 245 | distributed_backend, 246 | cfg, 247 | opt, 248 | full_eval=False, 249 | ): 250 | if not distributed_backend.is_master_process(): 251 | # Only evaluate and log on master rank 252 | return 253 | 254 | model.eval() 255 | if cfg.opt == "SFAdamW": 256 | opt.eval() 257 | 258 | if curr_iter == cfg.iterations or full_eval: 259 | max_num_batches = val_reader.num_batches() 260 | else: 261 | max_num_batches = cfg.eval_batches 262 | 263 | # to make sure we start from the beginning of the validation set, 264 | # i.e. repeat the same batches 265 | val_reader.set_step(0) 266 | val_acc, val_loss, val_perplexity = eval( 267 | model, 268 | val_reader, 269 | cfg.device, 270 | max_num_batches=max_num_batches, 271 | ctx=type_ctx, 272 | cfg=cfg, 273 | ) 274 | 275 | print( 276 | f">Eval: Iter={curr_iter} ({epoch:0.3f} epochs) " 277 | f"val_loss={val_loss:.3f} " 278 | f"val_pp={val_perplexity:.3f} " 279 | f"val_acc={val_acc:3f}" 280 | ) 281 | 282 | if cfg.wandb: 283 | if curr_iter == cfg.iterations or full_eval: 284 | logs = { 285 | "iter": curr_iter, 286 | "final-val/loss": val_loss, 287 | "final-val/perplexity": val_perplexity, 288 | "final-val/acc": val_acc, 289 | } 290 | else: 291 | logs = { 292 | "iter": curr_iter, 293 | "val/loss": val_loss, 294 | "val/perplexity": val_perplexity, 295 | "val/acc": val_acc, 296 | } 297 | 298 | wandb.log(logs) 299 | if cfg.eval_seq_prefix != "none" and ( 300 | curr_iter % (cfg.eval_interval * 5) == 0 or curr_iter == cfg.iterations 301 | ): 302 | text_table = wandb.Table(columns=["itr", "val-pp", "text"]) 303 | 304 | out_str = distributed_backend.get_raw_model(model).generate_from_string( 305 | cfg.eval_seq_prefix, 306 | max_new_tokens=40, 307 | temperature=0.9, 308 | top_k=None, 309 | ) 310 | text_table.add_data(curr_iter, val_perplexity, out_str) 311 | # why a copy? 
see github.com/wandb/wandb/issues/2981 312 | wandb.log({f"generated-text-{wandb.run.name}": copy.copy(text_table)}) 313 | model.train() 314 | -------------------------------------------------------------------------------- /src/optim/utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import random 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from contextlib import nullcontext 7 | import torch.distributed as dist 8 | import math 9 | import wandb 10 | 11 | 12 | def get_batch(datareader, device="cpu"): 13 | x, y = datareader.sample_batch() 14 | if "cuda" in torch.device(device).type: 15 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 16 | x = x.pin_memory().to(device, non_blocking=True) 17 | y = y.pin_memory().to(device, non_blocking=True) 18 | else: 19 | x = x.to(device) 20 | y = y.to(device) 21 | return x, y 22 | 23 | 24 | def cos_inf_schedule(n_iterations, n_warmup, div_factor, final_div_factor, n_inf): 25 | """Cosine annealing with warmup and _constant_ final_lr after cycle ended. 26 | Args: 27 | n_iterations: total number of iterations 28 | n_warmup: number of warmup iterations 29 | div_factor: initial division factor for warmup 30 | final_div_factor: final division factor for final lr 31 | n_inf: number of iterations for the final lr (constant lr after cycle ended) 32 | Returns: 33 | schedule: a function that takes the current iteration and 34 | returns the multiplicative factor for the learning rate 35 | """ 36 | max_lr = 1.0 37 | base_lr = max_lr / div_factor 38 | final_lr = base_lr / final_div_factor 39 | 40 | n_anneal_steps = n_iterations - n_inf 41 | 42 | def schedule(step): 43 | if step < n_warmup: 44 | return (step / n_warmup) + (1 - step / n_warmup) / div_factor 45 | elif step < n_anneal_steps: 46 | t = (step - n_warmup) / (n_anneal_steps - n_warmup) 47 | lr = final_lr + 0.5 * (max_lr - final_lr) * (1 + np.cos(np.pi * t)) 48 | return lr 49 | else: 50 | return final_lr 51 | 52 | return schedule 53 | 54 | 55 | def wsd_schedule( 56 | n_iterations, 57 | final_lr_factor=0.0, 58 | n_warmup=1000, 59 | init_div_factor=100, 60 | fract_decay=0.1, 61 | decay_type="linear", 62 | ): 63 | """Warmup, hold, and decay schedule. 
64 | Args: 65 | n_iterations: total number of iterations 66 | final_lr_factor: factor by which to reduce max_lr at the end 67 | n_warmup: number of warmup iterations 68 | init_div_factor: initial division factor for warmup 69 | fract_decay: fraction of iterations used for decay 70 | Returns: 71 | schedule: a function that takes the current iteration and 72 | returns the multiplicative factor for the learning rate 73 | """ 74 | n_anneal_steps = int(fract_decay * n_iterations) 75 | n_hold = n_iterations - n_anneal_steps 76 | 77 | def schedule(step): 78 | if step < n_warmup: 79 | return (step / n_warmup) + (1 - step / n_warmup) / init_div_factor 80 | elif step < n_hold: 81 | return 1.0 82 | elif step < n_iterations: 83 | if decay_type == "linear": 84 | return final_lr_factor + (1 - final_lr_factor) * ( 85 | 1 - (step - n_hold) / n_anneal_steps 86 | ) 87 | elif decay_type == "exp": 88 | return final_lr_factor ** ((step - n_hold) / n_anneal_steps) 89 | elif decay_type == "cosine": 90 | return ( 91 | final_lr_factor 92 | + (1 - final_lr_factor) 93 | * (1 + math.cos(math.pi * (step - n_hold) / n_anneal_steps)) 94 | * 0.5 95 | ) 96 | elif decay_type == "miror_cosine": 97 | cosine_value = ( 98 | final_lr_factor 99 | + (1 - final_lr_factor) 100 | * (1 + math.cos(math.pi * (step - n_hold) / n_anneal_steps)) 101 | * 0.5 102 | ) 103 | linear_value = final_lr_factor + (1 - final_lr_factor) * ( 104 | 1 - (step - n_hold) / n_anneal_steps 105 | ) 106 | return linear_value * 2 - cosine_value 107 | elif decay_type == "square": 108 | return final_lr_factor + (1 - final_lr_factor) * ( 109 | 1 - ((step - n_hold) / n_anneal_steps) ** 2 110 | ) 111 | 112 | elif decay_type == "sqrt": 113 | return final_lr_factor + (1 - final_lr_factor) * ( 114 | 1 - math.sqrt((step - n_hold) / n_anneal_steps) 115 | ) 116 | 117 | else: 118 | raise ValueError( 119 | f"decay type {decay_type} is not in ['linear','exp','cosine','miror_cosine','square','sqrt']" 120 | ) 121 | 122 | else: 123 | return final_lr_factor 124 | 125 | return schedule 126 | 127 | 128 | @torch.no_grad() 129 | def eval( 130 | model, 131 | reader, 132 | device="cpu", 133 | max_num_batches=24, 134 | ctx=nullcontext(), 135 | cfg=None, 136 | ): 137 | assert model.training == False 138 | 139 | loss_list_val, acc_list = [], [] 140 | 141 | for idx in range(max_num_batches): 142 | x, y = get_batch(reader, device=device) 143 | with ctx: 144 | outputs = model(x, targets=y, get_logits=True) 145 | val_loss = outputs["loss"] 146 | 147 | loss_list_val.append(val_loss) 148 | acc_list.append((outputs["logits"].argmax(-1) == y).float().mean()) 149 | 150 | val_acc = torch.stack(acc_list).mean().item() 151 | val_loss = torch.stack(loss_list_val).mean().item() 152 | val_perplexity = 2.71828**val_loss 153 | 154 | return val_acc, val_loss, val_perplexity 155 | 156 | 157 | @torch.no_grad() 158 | def eval_sweep_dropk( 159 | model, 160 | data_tensor, 161 | sequence_length, 162 | batch_size, 163 | n_heads, 164 | device="cpu", 165 | max_num_batches=24, 166 | ctx=nullcontext(), 167 | ): 168 | assert model.training == False 169 | 170 | x_axis, y_axis_pp, y_axis_acc, y_axis_loss = ( 171 | torch.linspace(0.0, 0.95, 15), 172 | [], 173 | [], 174 | [], 175 | ) 176 | loss_list_val, acc_list = [], [] 177 | 178 | for frac in x_axis: 179 | drop_k = int(sequence_length * frac * n_heads) 180 | for _ in range(max_num_batches): 181 | x, y = get_batch(data_tensor, sequence_length, batch_size, device=device) 182 | with ctx: 183 | outputs = model( 184 | x, targets=y, alpha_th=None, drop_k=drop_k,
get_logits=True 185 | ) 186 | loss_list_val.append(outputs["ce_loss"]) 187 | acc_list.append((outputs["logits"].argmax(-1) == y).float().mean()) 188 | 189 | y_axis_acc.append(torch.stack(acc_list).mean().item()) 190 | y_axis_loss.append(np.mean(loss_list_val)) 191 | y_axis_pp.append(2.71828 ** y_axis_loss[-1]) 192 | 193 | return x_axis, y_axis_acc, y_axis_pp, y_axis_loss 194 | 195 | 196 | @torch.no_grad() 197 | def eval_sweep_alphath( 198 | model, 199 | data_tensor, 200 | sequence_length, 201 | batch_size, 202 | device="cpu", 203 | max_num_batches=24, 204 | ctx=nullcontext(), 205 | ): 206 | assert model.training == False 207 | 208 | alpha_ths, y_axis_pp, y_axis_acc, y_axis_loss = ( 209 | [0, 1e-4, 1e-3, 1e-2, 1e-1, 2e-1, 3e-1, 4e-1, 5e-1], 210 | [], 211 | [], 212 | [], 213 | ) 214 | loss_list_val, acc_list, x_axis = [], [], [] 215 | 216 | for alpha_th in alpha_ths: 217 | frac_heads_pruned_list = [] 218 | for _ in range(max_num_batches): 219 | x, y = get_batch(data_tensor, sequence_length, batch_size, device=device) 220 | with ctx: 221 | outputs = model( 222 | x, targets=y, alpha_th=alpha_th, drop_k=None, get_logits=True 223 | ) 224 | nph, nh = ( 225 | outputs["num_head_pruned_per_layer"], 226 | outputs["num_heads_per_layer"], 227 | ) 228 | frac_heads_pruned = np.sum(nph) / np.sum( 229 | nh 230 | ) # fractions of heads removed given alpha_th 231 | frac_heads_pruned_list.append(frac_heads_pruned) 232 | loss_list_val.append(outputs["ce_loss"]) 233 | acc_list.append((outputs["logits"].argmax(-1) == y).float().mean()) 234 | 235 | x_axis.append(np.mean(frac_heads_pruned_list)) 236 | y_axis_acc.append(torch.stack(acc_list).mean().item()) 237 | y_axis_loss.append(np.mean(loss_list_val)) 238 | y_axis_pp.append(2.71828 ** y_axis_loss[-1]) 239 | 240 | return x_axis, y_axis_acc, y_axis_pp, y_axis_loss 241 | 242 | 243 | def save_checkpoint(model, opt, scheduler, itr, ckpt_dir: Path): 244 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 245 | model = model.module 246 | 247 | checkpoint = { 248 | "model": model.state_dict(), 249 | "optimizer": opt.state_dict(), 250 | "scheduler": scheduler.state_dict(), 251 | "itr": itr, 252 | } 253 | ckpt_dir.mkdir(exist_ok=True, parents=True) 254 | torch.save(checkpoint, ckpt_dir / "main.pt") 255 | 256 | 257 | def load_checkpoint(model, opt, scheduler, ckpt_path, device): 258 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 259 | model = model.module 260 | 261 | ckpt = torch.load(ckpt_path, map_location=device) 262 | model.load_state_dict(ckpt["model"]) 263 | opt.load_state_dict(ckpt["optimizer"]) 264 | scheduler.load_state_dict(ckpt["scheduler"]) 265 | itr = ckpt["itr"] 266 | return itr 267 | 268 | 269 | def save_worker_state(ckpt_dir: Path): 270 | # Dataloader, rng states 271 | worker_state = { 272 | "rng_torch_cpu": torch.random.get_rng_state(), 273 | "rng_torch_gpu": torch.cuda.get_rng_state(), 274 | "rng_np": np.random.get_state(), 275 | "rng_python": random.getstate(), 276 | } 277 | rank = 0 if not dist.is_initialized() else dist.get_rank() 278 | ckpt_dir.mkdir(exist_ok=True, parents=True) 279 | torch.save(worker_state, ckpt_dir / f"worker_{rank}.pt") 280 | 281 | 282 | def load_worker_state(ckpt_dir: Path): 283 | rank = 0 if not dist.is_initialized() else dist.get_rank() 284 | worker_state = torch.load(ckpt_dir / f"worker_{rank}.pt") 285 | torch.random.set_rng_state(worker_state["rng_torch_cpu"]) 286 | torch.cuda.set_rng_state(worker_state["rng_torch_gpu"]) 287 | np.random.set_state(worker_state["rng_np"]) 288 | 
random.setstate(worker_state["rng_python"]) 289 | -------------------------------------------------------------------------------- /src/optim/weight_averaging.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from pathlib import Path 3 | import tempfile 4 | 5 | import torch 6 | import wandb 7 | 8 | from .utils import eval 9 | 10 | 11 | class WeightAverager: 12 | def __init__( 13 | self, 14 | model, 15 | horizon=100, 16 | interval=1, 17 | save_dir=None, 18 | device=None, 19 | dtype=torch.float32, 20 | count=0, 21 | ): 22 | super().__init__() 23 | self.device = device # Where to keep avg model 24 | self.dtype = dtype # Precision for accumulation (>= float32) 25 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 26 | model = model.module 27 | self.module = deepcopy(model).to(dtype=self.dtype, device=device) 28 | 29 | assert horizon % interval == 0, "Interval should divide period" 30 | self.interval = interval 31 | self.horizon = horizon 32 | self.period = horizon // interval 33 | if save_dir is None: 34 | # Keep in tempdir 35 | self._tempdir = tempfile.TemporaryDirectory() 36 | self.save_dir = Path(self._tempdir.name) 37 | else: 38 | self.save_dir = Path(save_dir) 39 | self.save_dir.mkdir(parents=True, exist_ok=True) 40 | self.count = count 41 | # check if there are any checkpoints saved in the directory and set 42 | # num_saved to number of checkpoints with name <= count 43 | self.num_saved = len( 44 | [f for f in self.save_dir.iterdir() if f.is_file() and int(f.stem) <= count] 45 | ) 46 | 47 | @torch.no_grad() 48 | def step(self, model, is_master_rank=True): 49 | # Update module with current state 50 | if self.count % self.interval == 0: 51 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 52 | model = model.module 53 | for key, avg in self.module.state_dict().items(): 54 | curr = model.state_dict()[key].to(device=self.device, dtype=avg.dtype) 55 | rate = 1 / ((self.count % self.horizon) // self.interval + 1) 56 | avg.copy_(torch.lerp(avg, curr, rate)) 57 | 58 | self.count += 1 59 | 60 | if self.count % self.horizon == 0 and is_master_rank: 61 | torch.save( 62 | self.module.to().state_dict(), 63 | self.save_dir / f"{self.count}.pt", 64 | ) 65 | self.num_saved += 1 66 | 67 | def get_latest_like(self, model): 68 | # Return model for latest completed period 69 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 70 | model = model.module 71 | new_model = deepcopy(model) 72 | 73 | # Assumes that we saved at a specific iteration, will fail otherwise 74 | count = self.count - self.count % self.horizon 75 | latest_path = self.save_dir / f"{count}.pt" 76 | map_and_load_state_dict(new_model, torch.load(latest_path)) 77 | 78 | return new_model 79 | 80 | def sweep_horizon_like(self, model, max_num=None): 81 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 82 | model = model.module 83 | new_model = deepcopy(model) 84 | avg_state = deepcopy(self.module.state_dict()) 85 | if max_num is None: 86 | max_num = self.num_saved 87 | # Assumes all points exist 88 | for n in range(min(self.num_saved, max_num)): 89 | # Load state from the corresponding checkpoint 90 | count = self.count - self.count % self.horizon - n * self.horizon 91 | state = torch.load(self.save_dir / f"{count}.pt") 92 | 93 | # Update average state 94 | for key, avg in avg_state.items(): 95 | new = state[key].to(dtype=avg.dtype, device=avg.device) 96 | rate = 1 / (n + 1) 97 | avg.copy_(torch.lerp(avg, new, 
rate)) 98 | 99 | # Set new_model state and yield it 100 | map_and_load_state_dict(new_model, avg_state) 101 | yield ((n + 1) * self.horizon, new_model) 102 | 103 | 104 | def map_and_load_state_dict(model, state_dict): 105 | for key, m_val in model.state_dict().items(): 106 | for alias in (f'_orig_mod.{key}', f'_orig_mod.module.{key}'): # handle compiled / nested model 107 | if key not in state_dict and alias in state_dict: 108 | key = alias 109 | break 110 | s_val = state_dict[key] 111 | m_val.copy_(s_val.to(device=m_val.device, dtype=m_val.dtype)) 112 | 113 | 114 | def eval_wa( 115 | curr_iter, 116 | model, 117 | weight_averager, 118 | val_reader, 119 | type_ctx, 120 | distributed_backend, 121 | cfg, 122 | full_eval=False, 123 | ): 124 | if not distributed_backend.is_master_process(): 125 | # Only evaluate and log on master rank 126 | return 127 | 128 | if weight_averager.num_saved == 0: 129 | return 130 | if not cfg.wa_sweep_horizon: 131 | val_reader.set_step(0) 132 | val_acc, val_loss, val_perplexity = eval( 133 | weight_averager.get_latest_like(model).eval(), 134 | val_reader, 135 | cfg.device, 136 | max_num_batches=( 137 | val_reader.num_batches() 138 | if curr_iter == cfg.iterations or full_eval 139 | else cfg.eval_batches 140 | ), 141 | ctx=type_ctx, 142 | cfg=cfg, 143 | ) 144 | 145 | if cfg.wandb: 146 | if curr_iter == cfg.iterations or full_eval: 147 | logs = { 148 | "iter": curr_iter, 149 | "final-val/loss_wa": val_loss, 150 | "final-val/perplexity_wa": val_perplexity, 151 | "final-val/acc_wa": val_acc, 152 | } 153 | else: 154 | logs = { 155 | "iter": curr_iter, 156 | "val/loss_wa": val_loss, 157 | "val/perplexity_wa": val_perplexity, 158 | "val/acc_wa": val_acc, 159 | } 160 | wandb.log(logs) 161 | print( 162 | f">WA Eval: Iter={curr_iter} " 163 | f"val_loss={val_loss:.3f} " 164 | f"val_pp={val_perplexity:.3f} " 165 | f"val_acc={val_acc:3f}" 166 | ) 167 | else: 168 | losses = [] 169 | for horizon, avg_model in weight_averager.sweep_horizon_like( 170 | model, cfg.max_num_wa_sweeps 171 | ): 172 | avg_model.eval() 173 | val_reader.set_step(0) 174 | _, val_loss, _ = eval( 175 | avg_model, 176 | val_reader, 177 | cfg.device, 178 | max_num_batches=( 179 | val_reader.num_batches() 180 | if curr_iter == cfg.iterations or full_eval 181 | else cfg.eval_batches 182 | ), 183 | ctx=type_ctx, 184 | cfg=cfg, 185 | ) 186 | 187 | losses.append((val_loss, horizon)) 188 | if len(losses) == 0: # in case of none saved yet 189 | return 190 | best_loss, best_horizon = sorted(losses)[0] 191 | 192 | print(f"WA Eval: {[(h, f'{l:0.3e}') for (l,h) in losses]}") 193 | 194 | if cfg.wandb: 195 | if curr_iter == cfg.iterations or full_eval: 196 | logs = { 197 | "iter": curr_iter, 198 | "final-val/loss_wa": losses[0][0], 199 | "final-val/perplexity_wa": 2.71828 ** losses[0][0], 200 | "final-val/best_loss_wa": best_loss, 201 | "final-val/best_perplexity_wa": 2.71828**best_loss, 202 | } 203 | else: 204 | logs = { 205 | "iter": curr_iter, 206 | "val/loss_wa": losses[0][0], 207 | "val/perplexity_wa": 2.71828 ** losses[0][0], 208 | "val/best_loss_wa": best_loss, 209 | "val/best_perplexity_wa": 2.71828**best_loss, 210 | "wa_best_horizon": best_horizon, 211 | } 212 | wandb.log(logs) 213 | 214 | 215 | class ExponentialWeightAverager: 216 | def __init__( 217 | self, 218 | model, 219 | interval=1, 220 | decay=0.95, 221 | device=None, 222 | warmup=0, 223 | dtype=torch.float32, 224 | ): 225 | super().__init__() 226 | self.device = device # Where to keep avg model 227 | self.dtype = dtype # Precision for accumulation (>= 
float32) 228 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 229 | model = model.module 230 | self.module = deepcopy(model).to(dtype=self.dtype, device=device) 231 | 232 | self.interval = interval 233 | self.decay = decay 234 | self.num_saved = 0 235 | self.warmup = warmup 236 | self.count = 0 237 | 238 | @torch.no_grad() 239 | def step(self, model, is_master_rank=True): 240 | # Update module with current state 241 | 242 | if self.count < self.warmup: 243 | self.count += 1 244 | return 245 | 246 | if self.count == self.warmup: 247 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 248 | model = model.module 249 | for key, avg in self.module.state_dict().items(): 250 | curr = model.state_dict()[key].to(device=self.device, dtype=avg.dtype) 251 | avg.copy_(curr) 252 | 253 | elif self.count % self.interval == 0: 254 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 255 | model = model.module 256 | for key, avg in self.module.state_dict().items(): 257 | curr = model.state_dict()[key].to(device=self.device, dtype=avg.dtype) 258 | avg.copy_(torch.lerp(avg, curr, 1 - self.decay)) 259 | self.num_saved += 1 260 | 261 | self.count += 1 262 | 263 | # if self.count % self.horizon == 0 and is_master_rank: 264 | # torch.save( 265 | # self.module.to(dtype=torch.bfloat16).state_dict(), 266 | # self.save_dir / f"{self.count}.pt", 267 | # ) 268 | # self.num_saved += 1 269 | 270 | def get_latest_like(self, model): 271 | # Return model for latest completed period 272 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 273 | model = model.module 274 | new_model = deepcopy(model) 275 | 276 | map_and_load_state_dict( 277 | new_model, self.module.to(dtype=torch.bfloat16).state_dict() 278 | ) 279 | 280 | return new_model 281 | 282 | 283 | def eval_ema( 284 | curr_iter, 285 | model, 286 | ema, 287 | val_reader, 288 | type_ctx, 289 | distributed_backend, 290 | cfg, 291 | full_eval=False, 292 | ): 293 | if not distributed_backend.is_master_process(): 294 | # Only evaluate and log on master rank 295 | return 296 | 297 | val_reader.set_step(0) 298 | val_acc, val_loss, val_perplexity = eval( 299 | ema.get_latest_like(model).eval(), 300 | val_reader, 301 | cfg.device, 302 | max_num_batches=( 303 | val_reader.num_batches() 304 | if curr_iter == cfg.iterations or full_eval 305 | else cfg.eval_batches 306 | ), 307 | ctx=type_ctx, 308 | cfg=cfg, 309 | ) 310 | 311 | if cfg.wandb: 312 | if curr_iter == cfg.iterations or full_eval: 313 | logs = { 314 | "iter": curr_iter, 315 | "final-val/loss_ema": val_loss, 316 | "final-val/perplexity_ema": val_perplexity, 317 | "final-val/acc_ema": val_acc, 318 | } 319 | else: 320 | logs = { 321 | "iter": curr_iter, 322 | "val/loss_ema": val_loss, 323 | "val/perplexity_ema": val_perplexity, 324 | "val/acc_ema": val_acc, 325 | } 326 | wandb.log(logs) 327 | print( 328 | f">EMA Eval: Iter={curr_iter} " 329 | f"val_loss={val_loss:.3f} " 330 | f"val_pp={val_perplexity:.3f} " 331 | f"val_acc={val_acc:3f}" 332 | ) 333 | --------------------------------------------------------------------------------
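
Usage sketch (illustrative, not a file from the repository): the warmup-stable-decay / cooldown schedule defined in src/optim/utils.py returns a multiplicative learning-rate factor per step, so it can be attached to an optimizer through the standard torch.optim.lr_scheduler.LambdaLR. The toy model, optimizer, and hyperparameter values below are assumptions made for the example, not the repository's defaults.

import torch
from optim.utils import wsd_schedule  # assumes src/ is on the path, matching the repo's own absolute imports

model = torch.nn.Linear(16, 16)  # stand-in for the actual GPT/Llama model
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

n_iterations = 10_000
schedule = wsd_schedule(
    n_iterations=n_iterations,
    final_lr_factor=0.0,  # decay all the way to zero
    n_warmup=300,         # linear warmup steps
    init_div_factor=100,  # warmup starts at roughly lr / 100
    fract_decay=0.2,      # cooldown over the last 20% of steps
    decay_type="linear",
)
scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=schedule)

for step in range(n_iterations):
    # ... forward/backward on a batch would go here ...
    opt.step()
    scheduler.step()
    opt.zero_grad(set_to_none=True)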