├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── .DS_Store
│   ├── cos_vs_cooldown.png
│   ├── savings.png
│   └── scaling_curves.png
├── flops.ipynb
├── requirements.txt
└── src
    ├── config
    │   ├── __init__.py
    │   └── base.py
    ├── data
    │   ├── arxiv.py
    │   ├── openwebtext2.py
    │   ├── redpajama.py
    │   ├── shakespeare.py
    │   ├── slimpajama.py
    │   ├── utils.py
    │   └── wikitext.py
    ├── distributed
    │   ├── __init__.py
    │   ├── backend.py
    │   ├── ddp.py
    │   └── single.py
    ├── logger
    │   ├── logger.py
    │   └── rotational_logger.yaml
    ├── main.py
    ├── models
    │   ├── base.py
    │   ├── llama.py
    │   └── utils.py
    └── optim
        ├── base.py
        ├── utils.py
        └── weight_averaging.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Datasets, assets, logging
2 | datasets/
3 | wandb/
4 | exps/
5 | scripts/
6 | cluster/
7 | assets/*.csv
8 | assets/*.json
9 |
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | pip-wheel-metadata/
33 | share/python-wheels/
34 | *.egg-info/
35 | .installed.cfg
36 | *.egg
37 | MANIFEST
38 |
39 | # PyInstaller
40 | # Usually these files are written by a python script from a template
41 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
42 | *.manifest
43 | *.spec
44 |
45 | # Installer logs
46 | pip-log.txt
47 | pip-delete-this-directory.txt
48 |
49 | # Unit test / coverage reports
50 | htmlcov/
51 | .tox/
52 | .nox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *.cover
59 | *.py,cover
60 | .hypothesis/
61 | .pytest_cache/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 | db.sqlite3
71 | db.sqlite3-journal
72 |
73 | # Flask stuff:
74 | instance/
75 | .webassets-cache
76 |
77 | # Scrapy stuff:
78 | .scrapy
79 |
80 | # Sphinx documentation
81 | docs/_build/
82 |
83 | # PyBuilder
84 | target/
85 |
86 | # Jupyter Notebook
87 | .ipynb_checkpoints
88 |
89 | # IPython
90 | profile_default/
91 | ipython_config.py
92 |
93 | # pyenv
94 | .python-version
95 |
96 | # pipenv
97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
100 | # install all needed dependencies.
101 | #Pipfile.lock
102 |
103 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
104 | __pypackages__/
105 |
106 | # Celery stuff
107 | celerybeat-schedule
108 | celerybeat.pid
109 |
110 | # SageMath parsed files
111 | *.sage.py
112 |
113 | # Environments
114 | .env
115 | .venv
116 | env/
117 | venv/
118 | ENV/
119 | env.bak/
120 | venv.bak/
121 |
122 | # Spyder project settings
123 | .spyderproject
124 | .spyproject
125 |
126 | # Rope project settings
127 | .ropeproject
128 |
129 | # mkdocs documentation
130 | /site
131 |
132 | # mypy
133 | .mypy_cache/
134 | .dmypy.json
135 | dmypy.json
136 |
137 | # Pyre type checker
138 | .pyre/
139 |
140 | # vscode
141 | .vscode/
142 |
143 | # cluster stuff
144 | user.yaml
145 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 EPFL Machine Learning and Optimization Laboratory
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Codebase: Scaling Laws and Compute-Optimal Training Beyond Fixed Training Durations
2 | This is the codebase accompanying the paper [*Scaling Laws and Compute-Optimal Training Beyond Fixed Training Durations*](https://arxiv.org/abs/2405.18392). The code is largely based on our framework [llm-baselines](https://github.com/epfml/llm-baselines) for research on LLM training, itself an extension of [NanoGPT](https://github.com/karpathy/nanogpt).
3 |
4 | **Abstract:**
5 | > Scale has become a main ingredient in obtaining strong machine learning models. As a result, understanding a model's scaling properties is key to effectively designing both the right training setup as well as future generations of architectures. In this work, we argue that scale and training research has been needlessly complex due to reliance on the cosine schedule, which prevents training across different lengths for the same model size. We investigate the training behavior of a direct alternative - constant learning rate and cooldowns - and find that it scales predictably and reliably similar to cosine. Additionally, we show that stochastic weight averaging yields improved performance along the training trajectory, without additional training costs, across different scales. Importantly, with these findings we demonstrate that scaling experiments can be performed with significantly reduced compute and GPU hours by utilizing fewer but reusable training runs.
6 |
7 |
8 |
9 |
10 |
11 | **Figure:** Whereas the cosine learning rate anneals slowly over the whole run, the alternative schedule of a constant LR + cooldown is characterized by a fast drop towards the end of training. This cooldown phase initiates a sharp decrease in loss that matches cosine; the training perplexity follows the same behavior.
12 |
13 |
14 |
15 |
16 |
17 |
18 | **Figure:** The cooldown schedule makes it possible to perform scaling law experiments for a fraction of the compute. Instead of having to train from scratch for every training duration (cosine), we launch one long run and perform cooldowns from intermediate checkpoints after training.
19 |
20 |
21 | ## Quickstart
22 |
23 | Create a conda environment and install dependencies (we recommend Python 3.10):
24 |
25 | ```bash
26 | conda create -n env python=3.10
27 | conda activate env
28 | pip install -r requirements.txt
29 | ```
30 |
31 | Run a simple training on the SlimPajama 6B dataset:
32 | ```bash
33 | python ./src/main.py
34 | ```
35 |
36 | The above command trains a 213.34M-parameter model with the Llama-style architecture. We recommend using the `--compile` flag, which speeds up training noticeably (up to 20% in our setup).
37 |
38 | ## LR Schedules and Weight Averaging
39 | To use the cooldown schedule (the `wsd` scheduler):
40 | ```bash
41 | python ./src/main.py --compile --scheduler wsd --wsd-fract-decay 0.2
42 | ```
43 | The argument `--wsd-fract-decay` controls the fraction of training spent in the cooldown phase, and the functional form of the cooldown is selected with the argument `--decay-type`. A minimal sketch of the schedule is shown below.
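
For intuition, here is a minimal sketch of the constant-LR + cooldown (`wsd`) schedule with a linear cooldown; it mirrors the meaning of `--warmup-steps`, `--iterations`, `--wsd-fract-decay`, and `--wsd-final-lr-scale`, but is an illustration rather than the exact implementation in `src/optim`:

```python
def wsd_lr(step, max_lr=1e-3, warmup_steps=300, iterations=15000,
           fract_decay=0.2, final_lr_scale=0.0):
    """Warmup, then constant LR, then a linear cooldown (illustrative sketch)."""
    decay_steps = int(fract_decay * iterations)
    cooldown_start = iterations - decay_steps
    if step < warmup_steps:                      # linear warmup
        return max_lr * step / warmup_steps
    if step < cooldown_start:                    # constant ("stable") phase
        return max_lr
    progress = (step - cooldown_start) / decay_steps  # 0 -> 1 during cooldown
    return max_lr * ((1.0 - progress) + progress * final_lr_scale)

# e.g. the LR right before the cooldown and at the end of training:
print(wsd_lr(11999), wsd_lr(15000))  # 0.001 (constant phase), 0.0 (fully cooled down)
```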
44 |
45 | If you want to use stochastic weight averaging:
46 | ```bash
47 | python ./src/main.py --compile --scheduler wsd --wsd-fract-decay 0.2 --weight-average
48 | ```
49 | With this, the averaging is done automatically in windows of 500 steps; all model averages are stored (beware of the required disk space). The behavior is controlled via the arguments `--wa-interval` (take an average every k steps) and `--wa-horizon` (the length of the averaging window in steps).
50 |
51 | Moreover, the flag `--wa-sweep-horizon` automatically sweeps the horizon to find the best-performing window, but may slow down training. The sketch below illustrates the averaging scheme.
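
As a rough illustration of the scheme (not the code in `src/optim/weight_averaging.py`; the helper name and the dummy weights are made up for the example):

```python
import copy

def update_running_average(avg_state, new_state, n_seen):
    """Fold one more weight snapshot into a running uniform average (sketch)."""
    if avg_state is None:
        return copy.deepcopy(new_state), 1
    for name in avg_state:
        avg_state[name] += (new_state[name] - avg_state[name]) / (n_seen + 1)
    return avg_state, n_seen + 1

# Illustrative loop mirroring --wa-interval / --wa-horizon (dummy weights, not a model):
wa_interval, wa_horizon = 5, 500
avg_state, n_seen = None, 0
for step in range(1, 1001):
    weights = {"w": float(step)}          # stand-in for model.state_dict()
    if step % wa_interval == 0:
        avg_state, n_seen = update_running_average(avg_state, weights, n_seen)
    if step % wa_horizon == 0:
        print(f"step {step}: saving uniform average of {n_seen} snapshots")
        avg_state, n_seen = None, 0       # start a fresh averaging window
```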
52 |
53 | ## FLOPS helpers
54 | The notebook [`flops.ipynb`](flops.ipynb) provides a few helpers for FLOPS and parameter-count computations of transformer configurations.
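
For example, after running the notebook's first cell (which defines `get_flops` and `parameter_count`), the helpers can be used as follows; the configuration mirrors the notebook's `small` model, and the names `cfg` and `fwd_bwd_flops` are only for illustration:

```python
# Assumes get_flops and parameter_count from the notebook's first cell are in scope.
cfg = dict(
    n_layers=12, seq_len=512, vocab_size=50257, d_model=768,
    key_size=64, num_heads=12, num_kv_heads=12, ffw_size=2048, swiglu=True,
)
fwd_bwd_flops = 3 * get_flops(**cfg)   # backward ~ 2x forward, hence the factor 3
params = parameter_count(**cfg, swiglu_or_geglu=cfg["swiglu"])
print(f"{params / 1e6:.1f}M params, {fwd_bwd_flops:.3e} FLOPs per sequence")
```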
55 |
56 | # Contact & Reference
57 | Please do not hesitate to reach out to us if you have questions!
58 |
59 | To cite this work:
60 | ```
61 | @article{hagele2024scaling,
62 | author = {Alexander H\"agele and Elie Bakouch and Atli Kosson and Loubna Ben Allal and Leandro Von Werra and Martin Jaggi},
63 | title = {{Scaling Laws and Compute-Optimal Training Beyond Fixed Training Durations}},
64 | year = {2024},
65 | journal = {Advances in Neural Information Processing Systems},
66 | url = {http://arxiv.org/abs/2405.18392}
67 | }
68 | ```
69 |
70 |
71 |
--------------------------------------------------------------------------------
/assets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfml/schedules-and-scaling/17e78d600a7054bc6609a5ca6d050d141d4ad79e/assets/.DS_Store
--------------------------------------------------------------------------------
/assets/cos_vs_cooldown.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfml/schedules-and-scaling/17e78d600a7054bc6609a5ca6d050d141d4ad79e/assets/cos_vs_cooldown.png
--------------------------------------------------------------------------------
/assets/savings.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfml/schedules-and-scaling/17e78d600a7054bc6609a5ca6d050d141d4ad79e/assets/savings.png
--------------------------------------------------------------------------------
/assets/scaling_curves.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfml/schedules-and-scaling/17e78d600a7054bc6609a5ca6d050d141d4ad79e/assets/scaling_curves.png
--------------------------------------------------------------------------------
/flops.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "def embedding(seq_len, vocab_size, d_model):\n",
10 | " return 2 * seq_len * vocab_size * d_model\n",
11 | "\n",
12 | "\n",
13 | "def attention(seq_len, d_model, key_size, num_heads):\n",
14 | " projections = 2 * 3 * seq_len * d_model * (key_size * num_heads)\n",
15 | " logits = 2 * seq_len * seq_len * (key_size * num_heads)\n",
16 | " softmax = 3 * num_heads * seq_len * seq_len\n",
17 | " softmax_query_reduction = 2 * seq_len * seq_len * (key_size * num_heads)\n",
18 | " final_layer = 2 * seq_len * (key_size * num_heads) * d_model\n",
19 | " return projections + logits + softmax + softmax_query_reduction + final_layer\n",
20 | "\n",
21 | "\n",
22 | "def dense(seq_len, d_model, ffw_size, swiglu=False):\n",
23 | " if not swiglu:\n",
24 | " return 2 * seq_len * (2 * d_model * ffw_size)\n",
25 | " else:\n",
26 | " return 2 * seq_len * (3 * d_model * ffw_size)\n",
27 | "\n",
28 | "\n",
29 | "def moe(dense, n_experts, top_k, seq_len, d_model, ffw_size, swiglu=False):\n",
30 | " dense_flops = top_k * dense(seq_len, d_model, ffw_size, swiglu)\n",
31 | " gate_flops = 3 * seq_len * n_experts\n",
32 | " return dense_flops + gate_flops\n",
33 | "\n",
34 | "\n",
35 | "def final_logits(seq_len, d_model, vocab_size):\n",
36 | " return 2 * seq_len * d_model * vocab_size\n",
37 | "\n",
38 | "\n",
39 | "def get_flops(\n",
40 | " n_layers,\n",
41 | " seq_len,\n",
42 | " vocab_size,\n",
43 | " d_model,\n",
44 | " key_size,\n",
45 | " num_heads,\n",
46 | " ffw_size,\n",
47 | " swiglu=False,\n",
48 | " **kwargs,\n",
49 | "):\n",
50 | " return (\n",
51 | " embedding(seq_len, vocab_size, d_model)\n",
52 | " + n_layers\n",
53 | " * (\n",
54 | " attention(seq_len, d_model, key_size, num_heads)\n",
55 | " + dense(seq_len, d_model, ffw_size, swiglu=swiglu)\n",
56 | " )\n",
57 | " + final_logits(seq_len, d_model, vocab_size)\n",
58 | " )\n",
59 | "\n",
60 | "\n",
61 | "def flops_moe(\n",
62 | " seq_len,\n",
63 | " vocab_size,\n",
64 | " n_layers,\n",
65 | " d_model,\n",
66 | " key_size,\n",
67 | " num_heads,\n",
68 | " ffw_size,\n",
69 | " n_experts,\n",
70 | " top_k,\n",
71 | " swiglu=False,\n",
72 | "):\n",
73 | " return (\n",
74 | " embedding(seq_len, vocab_size, d_model)\n",
75 | " + n_layers\n",
76 | " * (\n",
77 | " attention(seq_len, d_model, key_size, num_heads)\n",
78 | " + moe(dense, n_experts, top_k, seq_len, d_model, ffw_size, swiglu=swiglu)\n",
79 | " )\n",
80 | " + final_logits(seq_len, d_model, vocab_size)\n",
81 | " )\n",
82 | "\n",
83 | "\n",
84 | "def parameter_count(\n",
85 | " vocab_size,\n",
86 | " n_layers,\n",
87 | " d_model,\n",
88 | " key_size,\n",
89 | " num_heads,\n",
90 | " num_kv_heads,\n",
91 | " ffw_size,\n",
92 | " n_experts=1,\n",
93 | " swiglu_or_geglu=False,\n",
94 | " **kwargs,\n",
95 | "):\n",
96 | " mul_factor_ffn = 3 if swiglu_or_geglu else 2\n",
97 | " attn = 2 * d_model * num_heads * key_size + 2 * d_model * num_kv_heads * key_size\n",
98 | " return vocab_size * d_model + n_layers * (\n",
99 | " attn + mul_factor_ffn * n_experts * d_model * ffw_size\n",
100 | " )"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": 2,
106 | "metadata": {},
107 | "outputs": [],
108 | "source": [
109 | "multiple_of = 256\n",
110 | "\n",
111 | "tiny = {\n",
112 | " \"d_model\": 384,\n",
113 | " \"key_size\": 64,\n",
114 | " \"num_heads\": 6,\n",
115 | " \"num_kv_heads\": 6,\n",
116 | " \"ffw_size\": int(8 / 3 * 384),\n",
117 | " \"n_layers\": 8,\n",
118 | " \"vocab_size\": 50257,\n",
119 | " \"swiglu\": True,\n",
120 | " \"seq_len\": 512,\n",
121 | "}\n",
122 | "tiny[\"ffw_size\"] = multiple_of * (\n",
123 | " (tiny[\"ffw_size\"] + multiple_of - 1) // multiple_of\n",
124 | ")\n",
125 | "mini = {\n",
126 | " \"d_model\": 512,\n",
127 | " \"key_size\": 64,\n",
128 | " \"num_heads\": 8,\n",
129 | " \"num_kv_heads\": 8,\n",
130 | " \"ffw_size\": int(8 / 3 * 512),\n",
131 | " \"n_layers\": 10,\n",
132 | " \"vocab_size\": 50257,\n",
133 | " \"swiglu\": True,\n",
134 | " \"seq_len\": 512,\n",
135 | "}\n",
136 | "mini[\"ffw_size\"] = multiple_of * (\n",
137 | " (mini[\"ffw_size\"] + multiple_of - 1) // multiple_of\n",
138 | ")\n",
139 | "\n",
140 | "small = {\n",
141 | " \"d_model\": 768,\n",
142 | " \"key_size\": 64,\n",
143 | " \"num_heads\": 12,\n",
144 | " \"num_kv_heads\": 12,\n",
145 | " \"ffw_size\": int(8 / 3 * 768),\n",
146 | " \"n_layers\": 12,\n",
147 | " \"vocab_size\": 50257,\n",
148 | " \"swiglu\": True,\n",
149 | " \"seq_len\": 512,\n",
150 | "}\n",
151 | "small[\"ffw_size\"] = multiple_of * (\n",
152 | " (small[\"ffw_size\"] + multiple_of - 1) // multiple_of\n",
153 | ")\n",
154 | "\n",
155 | "\n",
156 | "_210M = {\n",
157 | " \"d_model\": 768,\n",
158 | " \"key_size\": 64,\n",
159 | " \"num_heads\": 12,\n",
160 | " \"num_kv_heads\": 12,\n",
161 | " \"ffw_size\": int(8 / 3 * 768),\n",
162 | " \"n_layers\": 24,\n",
163 | " \"vocab_size\": 50257,\n",
164 | " \"swiglu\": True,\n",
165 | " \"seq_len\": 512,\n",
166 | "}\n",
167 | "_210M[\"ffw_size\"] = multiple_of * (\n",
168 | " (_210M[\"ffw_size\"] + multiple_of - 1) // multiple_of\n",
169 | ")\n"
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "execution_count": 3,
175 | "metadata": {},
176 | "outputs": [],
177 | "source": [
178 | "all_flops = []\n",
179 | "all_params = []"
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": 4,
185 | "metadata": {},
186 | "outputs": [
187 | {
188 | "name": "stdout",
189 | "output_type": "stream",
190 | "text": [
191 | "33.454464\n",
192 | "171834605568\n",
193 | "iters [3.0, 7.0, 10.0]\n",
194 | "tokens [0.3, 0.7, 1.0]\n",
195 | "ratio [9.2, 21.4, 30.6]\n",
196 | "flops [0.1031007633408, 0.2405684477952, 0.343669211136]\n",
197 | "flop savings 0.6\n"
198 | ]
199 | }
200 | ],
201 | "source": [
202 | "model = tiny\n",
203 | "\n",
204 | "n_layers = model[\"n_layers\"]\n",
205 | "d_model = model[\"d_model\"]\n",
206 | "key_size = model[\"key_size\"]\n",
207 | "num_heads = model[\"num_heads\"]\n",
208 | "num_kv_heads = model[\"num_kv_heads\"]\n",
209 | "ffw_size = model[\"ffw_size\"]\n",
210 | "vocab_size = model[\"vocab_size\"]\n",
211 | "swiglu = model[\"swiglu\"]\n",
212 | "n_experts = 8\n",
213 | "top_k = 2\n",
214 | "seq_len = model[\"seq_len\"]\n",
215 | "\n",
216 | "\n",
217 | "flops = 3 * get_flops(\n",
218 | " n_layers,\n",
219 | " seq_len,\n",
220 | " vocab_size,\n",
221 | " d_model,\n",
222 | " key_size,\n",
223 | " num_heads=num_heads,\n",
224 | " ffw_size=ffw_size,\n",
225 | " swiglu=swiglu,\n",
226 | ")\n",
227 | "params = parameter_count(\n",
228 | " vocab_size=vocab_size,\n",
229 | " n_layers=n_layers,\n",
230 | " d_model=d_model,\n",
231 | " key_size=key_size,\n",
232 | " num_heads=num_heads,\n",
233 | " num_kv_heads=num_kv_heads,\n",
234 | " ffw_size=ffw_size,\n",
235 | " swiglu_or_geglu=swiglu,\n",
236 | ")\n",
237 | "# lr 0.002 for cos, 0.001 for wsd\n",
238 | "print(params / 1e6)\n",
239 | "print(flops)\n",
240 | "iters = [2400 / 0.8, 5600 / 0.8, 8000 / 0.8]#, 9600 / 0.8]\n",
241 | "print(\"iters\", [float(f\"{i / 1e3:.1f}\") for i in iters])\n",
242 | "print(\"tokens\", [float(f\"{200 * 512 * i / 1e9:.1f}\") for i in iters])\n",
243 | "print(\"ratio\", [float(f\"{200 * 512 * i / params:.1f}\") for i in iters])\n",
244 | "flops_all = [flops * 200 * i / 1e18 for i in iters]\n",
245 | "print(\"flops\", flops_all)\n",
246 | "all_flops.append(flops_all)\n",
247 | "all_params.append(params)\n",
248 | "print(\"flop savings\", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": 5,
254 | "metadata": {},
255 | "outputs": [
256 | {
257 | "name": "stdout",
258 | "output_type": "stream",
259 | "text": [
260 | "52.99456\n",
261 | "254882611200\n",
262 | "iters [3.8, 7.5, 11.2]\n",
263 | "tokens [0.4, 0.8, 1.2]\n",
264 | "ratio [7.2, 14.5, 21.7]\n",
265 | "flops [0.1911619584, 0.3823239168, 0.5734858752]\n",
266 | "flop savings 0.6\n"
267 | ]
268 | }
269 | ],
270 | "source": [
271 | "tiny2 = {\n",
272 | " \"d_model\": 512,\n",
273 | " \"key_size\": 64,\n",
274 | " \"num_heads\": 8,\n",
275 | " \"num_kv_heads\": 8,\n",
276 | " \"ffw_size\": int(8 / 3 * 512),\n",
277 | " \"n_layers\": 8,\n",
278 | " \"vocab_size\": 50257,\n",
279 | " \"swiglu\": True,\n",
280 | " \"seq_len\": 512,\n",
281 | "}\n",
282 | "tiny2[\"ffw_size\"] = multiple_of * (\n",
283 | " (tiny2[\"ffw_size\"] + multiple_of - 1) // multiple_of\n",
284 | ")\n",
285 | "\n",
286 | "model = tiny2\n",
287 | "n_layers = model[\"n_layers\"]\n",
288 | "d_model = model[\"d_model\"]\n",
289 | "key_size = model[\"key_size\"]\n",
290 | "num_heads = model[\"num_heads\"]\n",
291 | "num_kv_heads = model[\"num_kv_heads\"]\n",
292 | "ffw_size = model[\"ffw_size\"]\n",
293 | "vocab_size = model[\"vocab_size\"]\n",
294 | "swiglu = model[\"swiglu\"]\n",
295 | "n_experts = 8\n",
296 | "top_k = 2\n",
297 | "seq_len = model[\"seq_len\"]\n",
298 | "\n",
299 | "\n",
300 | "flops = 3 * get_flops(\n",
301 | " n_layers,\n",
302 | " seq_len,\n",
303 | " vocab_size,\n",
304 | " d_model,\n",
305 | " key_size,\n",
306 | " num_heads=num_heads,\n",
307 | " ffw_size=ffw_size,\n",
308 | " swiglu=swiglu,\n",
309 | ")\n",
310 | "params = parameter_count(\n",
311 | " vocab_size=vocab_size,\n",
312 | " n_layers=n_layers,\n",
313 | " d_model=d_model,\n",
314 | " key_size=key_size,\n",
315 | " num_heads=num_heads,\n",
316 | " num_kv_heads=num_kv_heads,\n",
317 | " ffw_size=ffw_size,\n",
318 | " swiglu_or_geglu=swiglu,\n",
319 | ")\n",
320 | "# lr 0.002 for cos, 0.001 for wsd\n",
321 | "print(params / 1e6)\n",
322 | "print(flops)\n",
323 | "iters = [3000 / 0.8, 6000 / 0.8, 9000 / 0.8]# 12000 / 0.8]\n",
324 | "print(\"iters\", [float(f\"{i / 1e3:.1f}\") for i in iters])\n",
325 | "print(\"tokens\", [float(f\"{200 * 512 * i / 1e9:.1f}\") for i in iters])\n",
326 | "print(\"ratio\", [float(f\"{200 * 512 * i / params:.1f}\") for i in iters])\n",
327 | "flops_all = [flops * 200 * i / 1e18 for i in iters]\n",
328 | "print(\"flops\", flops_all)\n",
329 | "all_flops.append(flops_all)\n",
330 | "all_params.append(params)\n",
331 | "print(\"flop savings\", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))"
332 | ]
333 | },
334 | {
335 | "cell_type": "code",
336 | "execution_count": 6,
337 | "metadata": {},
338 | "outputs": [
339 | {
340 | "name": "stdout",
341 | "output_type": "stream",
342 | "text": [
343 | "59.810304\n",
344 | "279079550976\n",
345 | "iters [7.5, 12.5, 17.5]\n",
346 | "tokens [0.8, 1.3, 1.8]\n",
347 | "ratio [12.8, 21.4, 30.0]\n",
348 | "flops [0.418619326464, 0.69769887744, 0.976778428416]\n",
349 | "flop savings 0.5733333333333334\n"
350 | ]
351 | }
352 | ],
353 | "source": [
354 | "model = mini\n",
355 | "\n",
356 | "n_layers = model[\"n_layers\"]\n",
357 | "d_model = model[\"d_model\"]\n",
358 | "key_size = model[\"key_size\"]\n",
359 | "num_heads = model[\"num_heads\"]\n",
360 | "num_kv_heads = model[\"num_kv_heads\"]\n",
361 | "ffw_size = model[\"ffw_size\"]\n",
362 | "vocab_size = model[\"vocab_size\"]\n",
363 | "swiglu = model[\"swiglu\"]\n",
364 | "n_experts = 8\n",
365 | "top_k = 2\n",
366 | "seq_len = model[\"seq_len\"]\n",
367 | "\n",
368 | "\n",
369 | "flops = 3 * get_flops(\n",
370 | " n_layers,\n",
371 | " seq_len,\n",
372 | " vocab_size,\n",
373 | " d_model,\n",
374 | " key_size,\n",
375 | " num_heads=num_heads,\n",
376 | " ffw_size=ffw_size,\n",
377 | " swiglu=swiglu,\n",
378 | ")\n",
379 | "params = parameter_count(\n",
380 | " vocab_size=vocab_size,\n",
381 | " n_layers=n_layers,\n",
382 | " d_model=d_model,\n",
383 | " key_size=key_size,\n",
384 | " num_heads=num_heads,\n",
385 | " num_kv_heads=num_kv_heads,\n",
386 | " ffw_size=ffw_size,\n",
387 | " swiglu_or_geglu=swiglu,\n",
388 | ")\n",
389 | "\n",
390 | "print(params / 1e6)\n",
391 | "print(flops)\n",
392 | "iters = [6000 / 0.8, 10000 / 0.8, 14000 / 0.8]# 18000 / 0.8]\n",
393 | "print(\"iters\", [float(f\"{i / 1e3:.1f}\") for i in iters])\n",
394 | "print(\"tokens\", [float(f\"{200 * 512 * i / 1e9:.1f}\") for i in iters])\n",
395 | "print(\"ratio\", [float(f\"{200 * 512 * i / params:.1f}\") for i in iters])\n",
396 | "flops_all = [flops * 200 * i / 1e18 for i in iters]\n",
397 | "print(\"flops\", flops_all)\n",
398 | "all_flops.append(flops_all)\n",
399 | "all_params.append(params)\n",
400 | "print(\"flop savings\", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))\n",
401 | "# lr 0.002 for cos, 0.001 for wsd\n"
402 | ]
403 | },
404 | {
405 | "cell_type": "code",
406 | "execution_count": 7,
407 | "metadata": {},
408 | "outputs": [
409 | {
410 | "name": "stdout",
411 | "output_type": "stream",
412 | "text": [
413 | "93.11296\n",
414 | "409294602240\n",
415 | "iters [10.0, 17.5, 25.0]\n",
416 | "tokens [1.0, 1.8, 2.6]\n",
417 | "ratio [11.0, 19.2, 27.5]\n",
418 | "flops [0.81858920448, 1.43253110784, 2.0464730112]\n",
419 | "flop savings 0.5809523809523809\n"
420 | ]
421 | }
422 | ],
423 | "source": [
424 | "mini2 = {\n",
425 | " \"d_model\": 640,\n",
426 | " \"key_size\": 64,\n",
427 | " \"num_heads\": 10,\n",
428 | " \"num_kv_heads\": 10,\n",
429 | " \"ffw_size\": int(8 / 3 * 640),\n",
430 | " \"n_layers\": 12,\n",
431 | " \"vocab_size\": 50257,\n",
432 | " \"swiglu\": True,\n",
433 | " \"seq_len\": 512,\n",
434 | "}\n",
435 | "mini2[\"ffw_size\"] = multiple_of * (\n",
436 | " (mini2[\"ffw_size\"] + multiple_of - 1) // multiple_of\n",
437 | ")\n",
438 | "\n",
439 | "\n",
440 | "model = mini2\n",
441 | "\n",
442 | "n_layers = model[\"n_layers\"]\n",
443 | "d_model = model[\"d_model\"]\n",
444 | "key_size = model[\"key_size\"]\n",
445 | "num_heads = model[\"num_heads\"]\n",
446 | "num_kv_heads = model[\"num_kv_heads\"]\n",
447 | "ffw_size = model[\"ffw_size\"]\n",
448 | "vocab_size = model[\"vocab_size\"]\n",
449 | "swiglu = model[\"swiglu\"]\n",
450 | "n_experts = 8\n",
451 | "top_k = 2\n",
452 | "seq_len = model[\"seq_len\"]\n",
453 | "\n",
454 | "\n",
455 | "flops = 3 * get_flops(\n",
456 | " n_layers,\n",
457 | " seq_len,\n",
458 | " vocab_size,\n",
459 | " d_model,\n",
460 | " key_size,\n",
461 | " num_heads=num_heads,\n",
462 | " ffw_size=ffw_size,\n",
463 | " swiglu=swiglu,\n",
464 | ")\n",
465 | "params = parameter_count(\n",
466 | " vocab_size=vocab_size,\n",
467 | " n_layers=n_layers,\n",
468 | " d_model=d_model,\n",
469 | " key_size=key_size,\n",
470 | " num_heads=num_heads,\n",
471 | " num_kv_heads=num_kv_heads,\n",
472 | " ffw_size=ffw_size,\n",
473 | " swiglu_or_geglu=swiglu,\n",
474 | ")\n",
475 | "\n",
476 | "print(params / 1e6)\n",
477 | "print(flops)\n",
478 | "iters = [8000 / 0.8, 14000 / 0.8, 20000 / 0.8]# 26000 / 0.8]\n",
479 | "print(\"iters\", [float(f\"{i / 1e3:.1f}\") for i in iters])\n",
480 | "print(\"tokens\", [float(f\"{200 * 512 * i / 1e9:.1f}\") for i in iters])\n",
481 | "print(\"ratio\", [float(f\"{200 * 512 * i / params:.1f}\") for i in iters])\n",
482 | "flops_all = [flops * 200 * i / 1e18 for i in iters]\n",
483 | "print(\"flops\", flops_all)\n",
484 | "all_flops.append(flops_all)\n",
485 | "all_params.append(params)\n",
486 | "print(\"flop savings\", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))\n",
487 | "# lr 0.002 for cos, 0.001 for wsd"
488 | ]
489 | },
490 | {
491 | "cell_type": "code",
492 | "execution_count": 8,
493 | "metadata": {},
494 | "outputs": [
495 | {
496 | "name": "stdout",
497 | "output_type": "stream",
498 | "text": [
499 | "123.532032\n",
500 | "527392309248\n",
501 | "iters [15.0, 25.0, 35.0]\n",
502 | "tokens [1.5, 2.6, 3.6]\n",
503 | "ratio [12.4, 20.7, 29.0]\n",
504 | "flops [1.582176927744, 2.63696154624, 3.691746164736]\n",
505 | "flop savings 0.5733333333333334\n"
506 | ]
507 | }
508 | ],
509 | "source": [
510 | "model = small\n",
511 | "\n",
512 | "n_layers = model[\"n_layers\"]\n",
513 | "d_model = model[\"d_model\"]\n",
514 | "key_size = model[\"key_size\"]\n",
515 | "num_heads = model[\"num_heads\"]\n",
516 | "num_kv_heads = model[\"num_kv_heads\"]\n",
517 | "ffw_size = model[\"ffw_size\"]\n",
518 | "vocab_size = model[\"vocab_size\"]\n",
519 | "swiglu = model[\"swiglu\"]\n",
520 | "n_experts = 8\n",
521 | "top_k = 2\n",
522 | "seq_len = model[\"seq_len\"]\n",
523 | "\n",
524 | "\n",
525 | "flops = 3 * get_flops(\n",
526 | " n_layers,\n",
527 | " seq_len,\n",
528 | " vocab_size,\n",
529 | " d_model,\n",
530 | " key_size,\n",
531 | " num_heads=num_heads,\n",
532 | " ffw_size=ffw_size,\n",
533 | " swiglu=swiglu,\n",
534 | ")\n",
535 | "params = parameter_count(\n",
536 | " vocab_size=vocab_size,\n",
537 | " n_layers=n_layers,\n",
538 | " d_model=d_model,\n",
539 | " key_size=key_size,\n",
540 | " num_heads=num_heads,\n",
541 | " num_kv_heads=num_kv_heads,\n",
542 | " ffw_size=ffw_size,\n",
543 | " swiglu_or_geglu=swiglu,\n",
544 | ")\n",
545 | "print(params / 1e6)\n",
546 | "print(flops)\n",
547 | "iters = [12000 / 0.8, 20000 / 0.8, 28000 / 0.8]# 36000 / 0.8]\n",
548 | "print(\"iters\", [float(f\"{i / 1e3:.1f}\") for i in iters])\n",
549 | "print(\"tokens\", [float(f\"{200 * 512 * i / 1e9:.1f}\") for i in iters])\n",
550 | "print(\"ratio\", [float(f\"{200 * 512 * i / params:.1f}\") for i in iters])\n",
551 | "flops_all = [flops * 200 * i / 1e18 for i in iters]\n",
552 | "print(\"flops\", flops_all)\n",
553 | "all_flops.append(flops_all)\n",
554 | "all_params.append(params)\n",
555 | "print(\"flop savings\", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))\n",
556 | "# lr 0.001"
557 | ]
558 | },
559 | {
560 | "cell_type": "code",
561 | "execution_count": 9,
562 | "metadata": {},
563 | "outputs": [
564 | {
565 | "name": "stdout",
566 | "output_type": "stream",
567 | "text": [
568 | "151.843584\n",
569 | "624142319616\n",
570 | "iters [25.0, 37.5, 50.0]\n",
571 | "tokens [2.6, 3.8, 5.1]\n",
572 | "ratio [16.9, 25.3, 33.7]\n",
573 | "flops [3.12071159808, 4.68106739712, 6.24142319616]\n",
574 | "flop savings 0.5555555555555556\n"
575 | ]
576 | }
577 | ],
578 | "source": [
579 | "_151M = {\n",
580 | " \"d_model\": 768,\n",
581 | " \"key_size\": 64,\n",
582 | " \"num_heads\": 12,\n",
583 | " \"num_kv_heads\": 12,\n",
584 | " \"ffw_size\": int(8 / 3 * 768),\n",
585 | " \"n_layers\": 16,\n",
586 | " \"vocab_size\": 50257,\n",
587 | " \"swiglu\": True,\n",
588 | " \"seq_len\": 512,\n",
589 | "}\n",
590 | "\n",
591 | "_151M[\"ffw_size\"] = multiple_of * (\n",
592 | " (_151M[\"ffw_size\"] + multiple_of - 1) // multiple_of\n",
593 | ")\n",
594 | "\n",
595 | "model = _151M\n",
596 | "\n",
597 | "n_layers = model[\"n_layers\"]\n",
598 | "d_model = model[\"d_model\"]\n",
599 | "key_size = model[\"key_size\"]\n",
600 | "num_heads = model[\"num_heads\"]\n",
601 | "num_kv_heads = model[\"num_kv_heads\"]\n",
602 | "ffw_size = model[\"ffw_size\"]\n",
603 | "vocab_size = model[\"vocab_size\"]\n",
604 | "swiglu = model[\"swiglu\"]\n",
605 | "n_experts = 8\n",
606 | "top_k = 2\n",
607 | "seq_len = model[\"seq_len\"]\n",
608 | "\n",
609 | "\n",
610 | "flops = 3 * get_flops(\n",
611 | " n_layers,\n",
612 | " seq_len,\n",
613 | " vocab_size,\n",
614 | " d_model,\n",
615 | " key_size,\n",
616 | " num_heads=num_heads,\n",
617 | " ffw_size=ffw_size,\n",
618 | " swiglu=swiglu,\n",
619 | ")\n",
620 | "params = parameter_count(\n",
621 | " vocab_size=vocab_size,\n",
622 | " n_layers=n_layers,\n",
623 | " d_model=d_model,\n",
624 | " key_size=key_size,\n",
625 | " num_heads=num_heads,\n",
626 | " num_kv_heads=num_kv_heads,\n",
627 | " ffw_size=ffw_size,\n",
628 | " swiglu_or_geglu=swiglu,\n",
629 | ")\n",
630 | "print(params / 1e6)\n",
631 | "print(flops)\n",
632 | "iters = [20000 / 0.8, 30000 / 0.8, 40000 / 0.8]\n",
633 | "print(\"iters\", [float(f\"{i / 1e3:.1f}\") for i in iters])\n",
634 | "print(\"tokens\", [float(f\"{200 * 512 * i / 1e9:.1f}\") for i in iters])\n",
635 | "print(\"ratio\", [float(f\"{200 * 512 * i / params:.1f}\") for i in iters])\n",
636 | "flops_all = [flops * 200 * i / 1e18 for i in iters]\n",
637 | "print(\"flops\", flops_all)\n",
638 | "all_flops.append(flops_all)\n",
639 | "all_params.append(params)\n",
640 | "print(\"flop savings\", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))\n",
641 | "# double batch size?"
642 | ]
643 | },
644 | {
645 | "cell_type": "code",
646 | "execution_count": 10,
647 | "metadata": {},
648 | "outputs": [
649 | {
650 | "name": "stdout",
651 | "output_type": "stream",
652 | "text": [
653 | "166.1408\n",
654 | "682936762368\n",
655 | "iters [25.0, 37.5, 50.0]\n",
656 | "tokens [2.6, 3.8, 5.1]\n",
657 | "ratio [15.4, 23.1, 30.8]\n",
658 | "flops [3.41468381184, 5.12202571776, 6.82936762368]\n",
659 | "flop savings 0.5555555555555556\n"
660 | ]
661 | }
662 | ],
663 | "source": [
664 | "_166M = {\n",
665 | " \"d_model\": 896,\n",
666 | " \"key_size\": 64,\n",
667 | " \"num_heads\": 14,\n",
668 | " \"num_kv_heads\": 14,\n",
669 | " \"ffw_size\": int(8 / 3 * 896),\n",
670 | " \"n_layers\": 12,\n",
671 | " \"vocab_size\": 50257,\n",
672 | " \"swiglu\": True,\n",
673 | " \"seq_len\": 512,\n",
674 | "}\n",
675 | "\n",
676 | "_166M[\"ffw_size\"] = multiple_of * (\n",
677 | " (_166M[\"ffw_size\"] + multiple_of - 1) // multiple_of\n",
678 | ")\n",
679 | "\n",
680 | "model = _166M\n",
681 | "\n",
682 | "n_layers = model[\"n_layers\"]\n",
683 | "d_model = model[\"d_model\"]\n",
684 | "key_size = model[\"key_size\"]\n",
685 | "num_heads = model[\"num_heads\"]\n",
686 | "num_kv_heads = model[\"num_kv_heads\"]\n",
687 | "ffw_size = model[\"ffw_size\"]\n",
688 | "vocab_size = model[\"vocab_size\"]\n",
689 | "swiglu = model[\"swiglu\"]\n",
690 | "n_experts = 8\n",
691 | "top_k = 2\n",
692 | "seq_len = model[\"seq_len\"]\n",
693 | "\n",
694 | "\n",
695 | "flops = 3 * get_flops(\n",
696 | " n_layers,\n",
697 | " seq_len,\n",
698 | " vocab_size,\n",
699 | " d_model,\n",
700 | " key_size,\n",
701 | " num_heads=num_heads,\n",
702 | " ffw_size=ffw_size,\n",
703 | " swiglu=swiglu,\n",
704 | ")\n",
705 | "params = parameter_count(\n",
706 | " vocab_size=vocab_size,\n",
707 | " n_layers=n_layers,\n",
708 | " d_model=d_model,\n",
709 | " key_size=key_size,\n",
710 | " num_heads=num_heads,\n",
711 | " num_kv_heads=num_kv_heads,\n",
712 | " ffw_size=ffw_size,\n",
713 | " swiglu_or_geglu=swiglu,\n",
714 | ")\n",
715 | "print(params / 1e6)\n",
716 | "print(flops)\n",
717 | "iters = [20000 / 0.8, 30000 / 0.8, 40000 / 0.8]\n",
718 | "print(\"iters\", [float(f\"{i / 1e3:.1f}\") for i in iters])\n",
719 | "print(\"tokens\", [float(f\"{200 * 512 * i / 1e9:.1f}\") for i in iters])\n",
720 | "print(\"ratio\", [float(f\"{200 * 512 * i / params:.1f}\") for i in iters])\n",
721 | "flops_all = [flops * 200 * i / 1e18 for i in iters]\n",
722 | "print(\"flops\", flops_all)\n",
723 | "all_flops.append(flops_all)\n",
724 | "all_params.append(params)\n",
725 | "print(\"flop savings\", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))\n",
726 | "# double batch size?"
727 | ]
728 | },
729 | {
730 | "cell_type": "code",
731 | "execution_count": 11,
732 | "metadata": {},
733 | "outputs": [
734 | {
735 | "name": "stdout",
736 | "output_type": "stream",
737 | "text": [
738 | "208.466688\n",
739 | "817642340352\n",
740 | "iters [37.5, 50.0, 62.5]\n",
741 | "tokens [3.8, 5.1, 6.4]\n",
742 | "ratio [18.4, 24.6, 30.7]\n",
743 | "flops [6.13231755264, 8.17642340352, 10.2205292544]\n",
744 | "flop savings 0.5333333333333333\n"
745 | ]
746 | }
747 | ],
748 | "source": [
749 | "model = _210M\n",
750 | "\n",
751 | "n_layers = model[\"n_layers\"]\n",
752 | "d_model = model[\"d_model\"]\n",
753 | "key_size = model[\"key_size\"]\n",
754 | "num_heads = model[\"num_heads\"]\n",
755 | "num_kv_heads = model[\"num_kv_heads\"]\n",
756 | "ffw_size = model[\"ffw_size\"]\n",
757 | "vocab_size = model[\"vocab_size\"]\n",
758 | "swiglu = model[\"swiglu\"]\n",
759 | "n_experts = 8\n",
760 | "top_k = 2\n",
761 | "seq_len = model[\"seq_len\"]\n",
762 | "\n",
763 | "\n",
764 | "flops = 3 * get_flops(\n",
765 | " n_layers,\n",
766 | " seq_len,\n",
767 | " vocab_size,\n",
768 | " d_model,\n",
769 | " key_size,\n",
770 | " num_heads=num_heads,\n",
771 | " ffw_size=ffw_size,\n",
772 | " swiglu=swiglu,\n",
773 | ")\n",
774 | "params = parameter_count(\n",
775 | " vocab_size=vocab_size,\n",
776 | " n_layers=n_layers,\n",
777 | " d_model=d_model,\n",
778 | " key_size=key_size,\n",
779 | " num_heads=num_heads,\n",
780 | " num_kv_heads=num_kv_heads,\n",
781 | " ffw_size=ffw_size,\n",
782 | " swiglu_or_geglu=swiglu,\n",
783 | ")\n",
784 | "print(params / 1e6)\n",
785 | "print(flops)\n",
786 | "# iters = [22222, 44444, 66666]\n",
787 | "iters = [30000 / 0.8, 40000 / 0.8, 50000 / 0.8]\n",
788 | "print(\"iters\", [float(f\"{i / 1e3:.1f}\") for i in iters])\n",
789 | "print(\"tokens\", [float(f\"{200 * 512 * i / 1e9:.1f}\") for i in iters])\n",
790 | "print(\"ratio\", [float(f\"{200 * 512 * i / params:.1f}\") for i in iters])\n",
791 | "flops_all = [flops * 200 * i / 1e18 for i in iters]\n",
792 | "print(\"flops\", flops_all)\n",
793 | "all_flops.append(flops_all)\n",
794 | "all_params.append(params)\n",
795 | "print(\"flop savings\", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))"
796 | ]
797 | },
798 | {
799 | "cell_type": "code",
800 | "execution_count": 12,
801 | "metadata": {},
802 | "outputs": [
803 | {
804 | "name": "stdout",
805 | "output_type": "stream",
806 | "text": [
807 | "359.744512\n",
808 | "1341445373952\n",
809 | "iters [25.0, 37.5, 50.0]\n",
810 | "tokens [5.1, 7.7, 10.2]\n",
811 | "ratio [14.2, 21.3, 28.5]\n",
812 | "flops [6.70722686976, 10.06084030464, 13.41445373952]\n",
813 | "flop savings 0.5555555555555556\n"
814 | ]
815 | }
816 | ],
817 | "source": [
818 | "_350M = {\n",
819 | " \"d_model\": 1024,\n",
820 | " \"key_size\": 64,\n",
821 | " \"num_heads\": 16,\n",
822 | " \"num_kv_heads\": 16,\n",
823 | " \"ffw_size\": int(8 / 3 * 1024),\n",
824 | " \"n_layers\": 24,\n",
825 | " \"vocab_size\": 50257,\n",
826 | " \"swiglu\": True,\n",
827 | " \"seq_len\": 512,\n",
828 | "}\n",
829 | "\n",
830 | "_350M[\"ffw_size\"] = multiple_of * (\n",
831 | " (_350M[\"ffw_size\"] + multiple_of - 1) // multiple_of\n",
832 | ")\n",
833 | "\n",
834 | "model = _350M\n",
835 | "\n",
836 | "n_layers = model[\"n_layers\"]\n",
837 | "d_model = model[\"d_model\"]\n",
838 | "key_size = model[\"key_size\"]\n",
839 | "num_heads = model[\"num_heads\"]\n",
840 | "num_kv_heads = model[\"num_kv_heads\"]\n",
841 | "ffw_size = model[\"ffw_size\"]\n",
842 | "vocab_size = model[\"vocab_size\"]\n",
843 | "swiglu = model[\"swiglu\"]\n",
844 | "n_experts = 8\n",
845 | "top_k = 2\n",
846 | "seq_len = model[\"seq_len\"]\n",
847 | "\n",
848 | "\n",
849 | "flops = 3 * get_flops(\n",
850 | " n_layers,\n",
851 | " seq_len,\n",
852 | " vocab_size,\n",
853 | " d_model,\n",
854 | " key_size,\n",
855 | " num_heads=num_heads,\n",
856 | " ffw_size=ffw_size,\n",
857 | " swiglu=swiglu,\n",
858 | ")\n",
859 | "params = parameter_count(\n",
860 | " vocab_size=vocab_size,\n",
861 | " n_layers=n_layers,\n",
862 | " d_model=d_model,\n",
863 | " key_size=key_size,\n",
864 | " num_heads=num_heads,\n",
865 | " num_kv_heads=num_kv_heads,\n",
866 | " ffw_size=ffw_size,\n",
867 | " swiglu_or_geglu=swiglu,\n",
868 | ")\n",
869 | "print(params / 1e6)\n",
870 | "print(flops)\n",
871 | "iters = [20000 / 0.8, 30000 / 0.8, 40000 / 0.8]\n",
872 | "print(\"iters\", [float(f\"{i / 1e3:.1f}\") for i in iters])\n",
873 | "print(\"tokens\", [float(f\"{400 * 512 * i / 1e9:.1f}\") for i in iters])\n",
874 | "print(\"ratio\", [float(f\"{400 * 512 * i / params:.1f}\") for i in iters])\n",
875 | "flops_all = [flops * 200 * i / 1e18 for i in iters]\n",
876 | "print(\"flops\", flops_all)\n",
877 | "all_flops.append(flops_all)\n",
878 | "all_params.append(params)\n",
879 | "print(\"flop savings\", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))\n",
880 | "# double batch size?"
881 | ]
882 | }
883 | ],
884 | "metadata": {
885 | "kernelspec": {
886 | "display_name": "llm-baselines",
887 | "language": "python",
888 | "name": "python3"
889 | },
890 | "language_info": {
891 | "codemirror_mode": {
892 | "name": "ipython",
893 | "version": 3
894 | },
895 | "file_extension": ".py",
896 | "mimetype": "text/x-python",
897 | "name": "python",
898 | "nbconvert_exporter": "python",
899 | "pygments_lexer": "ipython3",
900 | "version": "3.10.14"
901 | }
902 | },
903 | "nbformat": 4,
904 | "nbformat_minor": 2
905 | }
906 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tiktoken
2 | --find-links https://download.pytorch.org/whl/torch_stable.html
3 | torch==2.2.0+cu118
4 | torchaudio==2.2.0+cu118
5 | torchvision==0.17.0+cu118
6 | tqdm==4.65.0
7 | schedulefree
8 | transformers
9 | wandb
10 | datasets
11 | zstandard
--------------------------------------------------------------------------------
/src/config/__init__.py:
--------------------------------------------------------------------------------
1 | from . import base
2 |
3 | CONFIG_FORMAT_TO_MODULE_MAP = {
4 | "base": base,
5 | }
6 |
7 |
8 | def parse_args_with_format(format, base_parser, args, namespace):
9 | return CONFIG_FORMAT_TO_MODULE_MAP[format].parse_args(base_parser, args, namespace)
10 |
11 |
12 | def registered_formats():
13 | return CONFIG_FORMAT_TO_MODULE_MAP.keys()
14 |
--------------------------------------------------------------------------------
/src/config/base.py:
--------------------------------------------------------------------------------
1 | import distributed
2 |
3 |
4 | def parse_args(base_parser, args, namespace):
5 | parser = base_parser
6 | # General training params
7 | parser.add_argument("--experiment-name", default=None, type=str)
8 | parser.add_argument("--seed", default=0, type=int)
9 | parser.add_argument("--data-seed", default=1337, type=int)
10 | parser.add_argument("--eval-interval", default=200, type=int)
11 | parser.add_argument("--full-eval-at", nargs="+", type=int)
12 | parser.add_argument("--eval-batches", default=32, type=int)
13 | parser.add_argument("--device", default="cuda:0", type=str)
14 | parser.add_argument(
15 | "--distributed-backend",
16 | default=None,
17 | type=str,
18 | required=False,
19 | choices=distributed.registered_backends(),
20 | )
21 | parser.add_argument("--log-interval", default=50, type=int)
22 |
23 | # Checkpointing
24 | parser.add_argument("--results-base-folder", default="./exps", type=str)
25 | parser.add_argument("--permanent-ckpt-interval", default=0, type=int)
26 | parser.add_argument("--latest-ckpt-interval", default=0, type=int)
27 | parser.add_argument("--resume-from", default=None, type=str)
28 | parser.add_argument("--resume-from-swa", default=None, type=str)
29 |
30 | parser.add_argument("--auto-resume", default=True)
31 |
32 | # logging params (WandB)
33 | parser.add_argument("--wandb", action="store_true") # whether to use wandb or not
34 | parser.add_argument("--wandb-project", default="my-project", type=str)
35 | parser.add_argument(
36 | "--wandb-run-prefix", default="none", type=str
37 | ) # is added before the autogenerated experiment name
38 | parser.add_argument(
39 | "--eval-seq-prefix", default="none", type=str
40 | ) # prefix used to generate sequences
41 | parser.add_argument("--log-dynamics", action="store_true")
42 | parser.add_argument(
43 | "--dynamics-logger-cfg", default="./src/logger/rotational_logger.yaml", type=str
44 | )
45 |
46 | # Schedule
47 | parser.add_argument(
48 | "--scheduler",
49 | default="cos",
50 | choices=["linear", "cos", "wsd", "none", "cos_inf"],
51 | )
52 | parser.add_argument("--cos-inf-steps", default=0, type=int)
53 | # parser.add_argument("--cos-final-lr", default=1e-6, type=float)
54 | parser.add_argument("--iterations", default=15000, type=int)
55 | parser.add_argument("--warmup-steps", default=300, type=int)
56 | parser.add_argument("--lr", default=1e-3, type=float)
57 | # wsd
58 | parser.add_argument("--wsd-final-lr-scale", default=0.0, type=float)
59 | parser.add_argument("--wsd-fract-decay", default=0.1, type=float)
60 | # parser.add_argument("--wsd-exponential-decay", action="store_true")
61 | parser.add_argument(
62 | "--decay-type",
63 | default="linear",
64 | choices=["linear", "cosine", "exp", "miror_cosine", "square", "sqrt"],
65 | )
66 | # Optimization
67 | parser.add_argument("--opt", default="adamw", choices=["adamw", "sgd", "SFAdamW"])
68 | parser.add_argument("--batch-size", default=50, type=int)
69 | parser.add_argument("--acc-steps", default=4, type=int)
70 | parser.add_argument("--weight-decay", default=1e-1, type=float)
71 | parser.add_argument("--beta1", default=0.9, type=float)
72 | parser.add_argument("--beta2", default=0.95, type=float)
73 | parser.add_argument(
74 | "--grad-clip", default=1.0, type=float
75 | ) # default value is 1.0 in NanoGPT
76 |
77 | # Weight Averaging
78 | parser.add_argument("--weight-average", action="store_true")
79 | parser.add_argument(
80 | "--wa-interval",
81 | default=5,
82 | type=int,
83 | help="How often to take the average (every k steps). Must divide wa-horizon.",
84 | )
85 | parser.add_argument(
86 | "--wa-horizon",
87 | default=500,
88 | type=int,
89 | help="How frequently we save uniform model averages. Should divide "
90 | + "latest-ckpt-interval, otherwise some points may not be saved "
91 | + "correctly.",
92 | )
93 | parser.add_argument(
94 | "--wa-dtype",
95 | default="float32",
96 | type=str,
97 | choices=["float32", "float64"],
98 | )
99 |
100 | parser.add_argument("--wa-use-temp-dir", action="store_true")
101 | parser.add_argument("--wa-sweep-horizon", action="store_true")
102 | parser.add_argument("--max-num-wa-sweeps", default=5, type=int)
103 |
104 | parser.add_argument("--exponential-moving-average", action="store_true")
105 | parser.add_argument(
106 | "--ema-interval",
107 | default=10,
108 | type=int,
109 | help="How often to take the EMA average (every k steps).",
110 | )
111 | parser.add_argument(
112 | "--ema-decay",
113 | default=0.95,
114 | type=float,
115 | help="EMA decay parameter (between 0.9 and 1).",
116 | )
117 | parser.add_argument(
118 | "--ema-after-warmup",
119 | action="store_true",
120 | help="Start EMA after warmup steps.",
121 | )
122 |
123 | # Dataset params
124 | parser.add_argument("--datasets-dir", type=str, default="./datasets/")
125 | parser.add_argument(
126 | "--dataset",
127 | default="slimpajama",
128 | choices=[
129 | "wikitext",
130 | "shakespeare-char",
131 | "arxiv",
132 | "arxiv2000",
133 | "arxiv+wiki",
134 | "openwebtext2",
135 | "redpajama",
136 | "slimpajama",
137 | "slimpajama_chunk1",
138 | "redpajamav2",
139 | ],
140 | )
141 | parser.add_argument(
142 | "--tokenizer", default="gpt2", type=str, choices=["gpt2", "mistral"]
143 | )
144 | parser.add_argument("--vocab-size", default=50304, type=int)
145 | parser.add_argument(
146 | "--data-in-ram", action="store_true"
147 | ) # force the data to RAM, mostly useless except for openwebtext2
148 |
149 | # Model params
150 | parser.add_argument(
151 | "--model",
152 | default="llama",
153 | choices=[
154 | "base",
155 | "llama",
156 | ],
157 | )
158 | parser.add_argument("--parallel-block", action="store_true")
159 | parser.add_argument(
160 | "--use-pretrained", default="none", type=str
161 |     ) # 'none', 'gpt-2' or a path to the pretrained model
162 | parser.add_argument("--from-dense", action="store_true")
163 | parser.add_argument("--init-std", default=0.02, type=float)
164 | parser.add_argument("--dropout", default=0.0, type=float)
165 | parser.add_argument("--n-head", default=12, type=int)
166 | parser.add_argument("--n-layer", default=24, type=int) # depths in att + ff blocks
167 | parser.add_argument("--sequence-length", default=512, type=int)
168 | parser.add_argument(
169 | "--n-embd", default=768, type=int # embedding size / hidden size ...
170 | )
171 | parser.add_argument(
172 | "--multiple-of", # make SwiGLU hidden layer size multiple of large power of 2
173 | default=256,
174 | type=int,
175 | )
176 | parser.add_argument("--rmsnorm-eps", default=1e-5, type=float)
177 | parser.add_argument(
178 | "--dtype",
179 | default="bfloat16",
180 | type=str,
181 | choices=["float32", "float16", "bfloat16"],
182 | )
183 | parser.add_argument("--bias", default=False, type=bool)
184 | parser.add_argument("--compile", action="store_true")
185 | parser.add_argument("--mlp-dim-exp-factor", default=1.0, type=float)
186 | return parser.parse_args(args, namespace)
187 |
--------------------------------------------------------------------------------
/src/data/arxiv.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tarfile
3 | import logging
4 | from pathlib import Path
5 | from typing import Optional
6 | from multiprocessing import Pool
7 | from tempfile import NamedTemporaryFile
8 | from subprocess import Popen, TimeoutExpired, PIPE
9 | from typing import Tuple, List
10 |
11 | import numpy as np
12 | import requests
13 | from tqdm.auto import tqdm
14 | import tiktoken
15 |
16 |
17 | def convert_to_markdown(args: Tuple[Path, Path]):
18 | texfile, mdroot = args
19 | mdfile = mdroot / f"{texfile.name}.md"
20 | with Popen(
21 | ["pandoc", "--wrap=none", "--from", "latex", texfile, "--output", mdfile],
22 | stderr=PIPE,
23 | ) as proc:
24 | try:
25 | proc.communicate(timeout=1)
26 | except TimeoutExpired:
27 | proc.kill()
28 |
29 |
30 | def fetch_arxiv(root: Path, year: int):
31 | # download latex
32 | url = f"https://www.cs.cornell.edu/projects/kddcup/download/hep-th-{year}.tar.gz"
33 | texroot = root / "tex"
34 | print("Downloading Arxiv year", year)
35 | req = requests.get(url, timeout=60)
36 | with NamedTemporaryFile(suffix=".tar.gz") as f:
37 | f.write(req.content)
38 | logging.debug("Tar saved in tempfile %s" % f.name)
39 | with tarfile.open(f.name) as tar:
40 | logging.debug("Extracting tarfile")
41 | tar.extractall(texroot)
42 |
43 | # convert to markdown
44 | mdroot = root / "md" / str(year)
45 | mdroot.mkdir(parents=True)
46 | files = list((texroot / str(year)).iterdir())
47 | with Pool(os.cpu_count()) as p:
48 | args = [(texfile, mdroot) for texfile in files]
49 | for _ in tqdm(
50 | p.imap_unordered(convert_to_markdown, args),
51 | desc="Converting to markdown",
52 | total=len(files),
53 | ):
54 | pass
55 |
56 |
57 | def tokenize_arxiv(root: Path, year: int):
58 | tokenizer = tiktoken.get_encoding("gpt2")
59 | tokens = []
60 | tokens_val = []
61 | tokens_test = []
62 | mds = root / "md" / str(year)
63 |
64 | # tokenize
65 | desc = f"Tokenizing {year}"
66 | for i, mdpath in enumerate(tqdm(list(mds.iterdir()), desc=desc)):
67 | with open(mdpath, encoding="utf8") as f:
68 | text = "".join(f.readlines())
69 | if i % 10 <= 6: # train split
70 | tokens += tokenizer.encode(text)
71 | elif i % 10 <= 8: # val split
72 | tokens_val += tokenizer.encode(text)
73 | else: # test split
74 | tokens_test += tokenizer.encode(text)
75 |
76 | # save to dir
77 | tpath = root / str(year)
78 | tpath.mkdir(parents=True)
79 | for x, name in zip([tokens, tokens_val, tokens_test], ["train", "val", "test"]):
80 | mem = np.memmap(tpath / f"{name}.npy", dtype=np.uint16, mode="w+", shape=len(x))
81 | for i, v in enumerate(x):
82 | mem[i] = v
83 |
84 |
85 | def load_arxiv(cachedir: Path, years: Optional[List[int]] = None):
86 | all_years = list(range(1993, 2004)) # 1992 seems to give some problem
87 | if years is None:
88 | years = all_years
89 | assert set(years) <= set(all_years)
90 | root = cachedir / "arxiv"
91 | root.mkdir(exist_ok=True, parents=True)
92 |
93 | # download all years requested that are not present
94 | for year in years:
95 | if not (root / "md" / str(year)).exists():
96 | fetch_arxiv(root, year)
97 |
98 | # tokenize all years not previously tokenized
99 | for year in years:
100 | if not (root / str(year)).exists():
101 | tokenize_arxiv(root, year)
102 |
103 | # load meta
104 | ret = {}
105 | for split in ["train", "val"]:
106 | paths = [root / str(year) / f"{split}.npy" for year in years]
107 | x = [np.memmap(path, dtype=np.uint16, mode="r") for path in paths]
108 | ret[split] = np.concatenate(x)
109 | return ret
110 |
111 |
112 | def get_arxiv_2000(datasets_base_dir):
113 | return load_arxiv(Path(datasets_base_dir), [2000])
114 |
115 |
116 | def get_arxiv_full(datasets_base_dir):
117 | return load_arxiv(Path(datasets_base_dir))
118 |
--------------------------------------------------------------------------------
/src/data/openwebtext2.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import numpy as np
4 | import tiktoken
5 | from datasets import load_dataset
6 |
7 |
8 | tknzr = tiktoken.get_encoding("gpt2")
9 |
10 |
11 | def get_openwebtext2_data(datasets_base_dir, num_proc=40):
12 | """https://openwebtext2.readthedocs.io/en/latest/"""
13 | OWT2_DATA_PATH = os.path.join(datasets_base_dir, "openwebtext2/")
14 | if not os.path.exists(os.path.join(OWT2_DATA_PATH, "train.bin")):
15 | os.makedirs(OWT2_DATA_PATH, exist_ok=True)
16 | dataset = load_dataset("the_pile_openwebtext2")
17 |
18 | split_dataset = dataset["train"].train_test_split(
19 | test_size=0.0005, seed=2357, shuffle=True
20 | )
21 | split_dataset["val"] = split_dataset.pop("test")
22 |
23 | def process(example):
24 | ids = tknzr.encode_ordinary(
25 | example["text"]
26 | ) # encode_ordinary ignores any special tokens
27 | ids.append(
28 | tknzr.eot_token
29 | ) # add the end of text token, e.g. 50256 for gpt2 bpe
30 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though...
31 | out = {"ids": ids, "len": len(ids)}
32 | return out
33 |
34 | # tokenize the dataset
35 | tokenized = split_dataset.map(
36 | process,
37 | remove_columns=["text"],
38 | desc="tokenizing the splits",
39 | num_proc=num_proc,
40 | )
41 |
42 | # concatenate all the ids in each dataset into one large file we can use for training
43 | for split, dset in tokenized.items():
44 | arr_len = np.sum(dset["len"])
45 | filename = os.path.join(OWT2_DATA_PATH, f"{split}.bin")
46 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
47 | arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,))
48 | total_batches = 1024
49 |
50 | idx = 0
51 | for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"):
52 | # Batch together samples for faster write
53 | batch = dset.shard(
54 | num_shards=total_batches, index=batch_idx, contiguous=True
55 | ).with_format("numpy")
56 | arr_batch = np.concatenate(batch["ids"])
57 | # Write into mmap
58 | arr[idx : idx + len(arr_batch)] = arr_batch
59 | idx += len(arr_batch)
60 | arr.flush()
61 |
62 | return {
63 | "train": os.path.join(OWT2_DATA_PATH, "train.bin"),
64 | "val": os.path.join(OWT2_DATA_PATH, "val.bin"),
65 | }
66 |
--------------------------------------------------------------------------------
/src/data/redpajama.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import numpy as np
4 | import tiktoken
5 | from datasets import load_dataset
6 |
7 |
8 | tknzr = tiktoken.get_encoding("gpt2")
9 |
10 |
11 | def get_redpajama_data(datasets_dir, num_proc=40):
12 | RPJ_DATA_PATH = os.path.join(datasets_dir, "redpajama1Tsample/")
13 | if not os.path.exists(os.path.join(RPJ_DATA_PATH, "train.bin")):
14 | os.makedirs(RPJ_DATA_PATH, exist_ok=True)
15 | dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample")
16 |
17 | split_dataset = dataset["train"].train_test_split(
18 | test_size=0.0005, seed=2357, shuffle=True
19 | )
20 | split_dataset["val"] = split_dataset.pop("test")
21 |
22 | def process(example):
23 | ids = tknzr.encode_ordinary(
24 | example["text"]
25 | ) # encode_ordinary ignores any special tokens
26 | ids.append(
27 | tknzr.eot_token
28 | ) # add the end of text token, e.g. 50256 for gpt2 bpe
29 | out = {"ids": ids, "len": len(ids)}
30 | return out
31 |
32 | # tokenize the dataset
33 | tokenized = split_dataset.map(
34 | process,
35 | remove_columns=["text"],
36 | desc="tokenizing the splits",
37 | num_proc=num_proc,
38 | )
39 |
40 | # concatenate all the ids in each dataset into one large file we can use for training
41 | for split, dset in tokenized.items():
42 | arr_len = np.sum(dset["len"])
43 | filename = os.path.join(RPJ_DATA_PATH, f"{split}.bin")
44 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
45 | arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,))
46 | total_batches = min(1024, len(dset))
47 |
48 | idx = 0
49 | for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"):
50 | # Batch together samples for faster write
51 | batch = dset.shard(
52 | num_shards=total_batches, index=batch_idx, contiguous=True
53 | ).with_format("numpy")
54 | arr_batch = np.concatenate(batch["ids"])
55 | # Write into mmap
56 | arr[idx : idx + len(arr_batch)] = arr_batch
57 | idx += len(arr_batch)
58 | arr.flush()
59 |
60 | train_data = np.memmap(
61 | os.path.join(RPJ_DATA_PATH, "train.bin"), dtype=np.uint16, mode="r"
62 | )
63 | val_data = np.memmap(
64 | os.path.join(RPJ_DATA_PATH, "val.bin"), dtype=np.uint16, mode="r"
65 | )
66 |
67 | return {"train": train_data, "val": val_data}
68 |
69 |
70 | def get_redpajamav2_data(datasets_dir, num_proc=40):
71 | """https://openwebtext2.readthedocs.io/en/latest/"""
72 | RPJ_V2_DATA_PATH = os.path.join(datasets_dir, "redpajamaV2sample/")
73 | if not os.path.exists(os.path.join(RPJ_V2_DATA_PATH, "train.bin")):
74 | os.makedirs(RPJ_V2_DATA_PATH, exist_ok=True)
75 | dataset = load_dataset("togethercomputer/RedPajama-Data-V2", name="sample")
76 |
77 | split_dataset = dataset["train"].train_test_split(
78 | test_size=0.0005, seed=2357, shuffle=True
79 | )
80 | split_dataset["val"] = split_dataset.pop("test")
81 |
82 | def process(example):
83 | ids = tknzr.encode_ordinary(
84 | example["raw_content"]
85 | ) # encode_ordinary ignores any special tokens
86 | ids.append(
87 | tknzr.eot_token
88 | ) # add the end of text token, e.g. 50256 for gpt2 bpe
89 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though...
90 | out = {"ids": ids, "len": len(ids)}
91 | return out
92 |
93 | # tokenize the dataset
94 | tokenized = split_dataset.map(
95 | process,
96 | remove_columns=["raw_content"],
97 | desc="tokenizing the splits",
98 | num_proc=num_proc,
99 | )
100 |
101 | # concatenate all the ids in each dataset into one large file we can use for training
102 | for split, dset in tokenized.items():
103 | arr_len = np.sum(dset["len"])
104 | filename = os.path.join(RPJ_V2_DATA_PATH, f"{split}.bin")
105 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
106 | arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,))
107 | total_batches = min(1024, len(dset))
108 |
109 | idx = 0
110 | for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"):
111 | # Batch together samples for faster write
112 | batch = dset.shard(
113 | num_shards=total_batches, index=batch_idx, contiguous=True
114 | ).with_format("numpy")
115 | arr_batch = np.concatenate(batch["ids"])
116 | # Write into mmap
117 | arr[idx : idx + len(arr_batch)] = arr_batch
118 | idx += len(arr_batch)
119 | arr.flush()
120 |
121 | return {
122 | "train": os.path.join(RPJ_V2_DATA_PATH, "train.bin"),
123 | "val": os.path.join(RPJ_V2_DATA_PATH, "val.bin"),
124 | }
125 |
--------------------------------------------------------------------------------
/src/data/shakespeare.py:
--------------------------------------------------------------------------------
1 | import os
2 | from string import ascii_letters, digits, punctuation
3 |
4 | import numpy as np
5 | import requests
6 |
7 |
8 | _char_decode = dict(
9 | enumerate(sorted(set(ascii_letters + digits + punctuation + " \n")))
10 | )
11 | _char_encode = {char: i for i, char in _char_decode.items()}
12 |
13 |
14 | def char_tknzr(txt: str):
15 | return [_char_encode[char] for char in txt if char in _char_encode]
16 |
17 |
18 | def get_shakespeare_data(datasets_dir):
19 | """Inspired from https://github.com/karpathy/nanoGPT/"""
20 | DATA_PATH = os.path.join(datasets_dir, "shakespeare")
21 | raw_path = os.path.join(DATA_PATH, "raw.txt")
22 |     train_path = os.path.join(DATA_PATH, "train.npy")
23 |     test_path = os.path.join(DATA_PATH, "test.npy")
24 |
25 | # if path is not even there, download all data
26 | if not os.path.exists(DATA_PATH):
27 | print("Downloading raw Shakespeare texts")
28 | url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
29 | os.makedirs(DATA_PATH, exist_ok=True)
30 | text = requests.get(url, timeout=60).text
31 | with open(raw_path, "w+", encoding="utf8") as f:
32 | f.write(text)
33 |
34 | # attempt to find cached version for current tokenizer
35 | if not os.path.exists(train_path) or not os.path.exists(test_path):
36 | print("Tokenizing Shakespeare texts")
37 | # load text
38 | with open(raw_path, encoding="utf8") as f:
39 | text = "".join(f.readlines())
40 | i = int(0.8 * len(text))
41 | # encode text
42 | x = np.array(char_tknzr(text[:i]), dtype=np.uint16)
43 | x_test = np.array(char_tknzr(text[i:]), dtype=np.uint16)
44 | # map memory
45 | mem = np.memmap(train_path, dtype=np.uint16, mode="w+", shape=x.shape)
46 | mem[:] = x
47 | mem = np.memmap(test_path, dtype=np.uint16, mode="w+", shape=x_test.shape)
48 | mem[:] = x_test
49 |
50 | return {
51 | "train": train_path,
52 | "val": test_path,
53 | }
54 |
--------------------------------------------------------------------------------
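A side note on the character-level tokenizer above: _char_decode is the inverse mapping of _char_encode, so decoding token ids back to text is a one-liner. A minimal sketch, assuming it is run with src/ on the path; the char_detknzr helper is hypothetical and not part of the file:

    from data.shakespeare import char_tknzr, _char_decode

    def char_detknzr(ids):
        # inverse of char_tknzr: map token ids back to characters
        return "".join(_char_decode[i] for i in ids)

    ids = char_tknzr("To be, or not to be")
    assert char_detknzr(ids) == "To be, or not to be"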
/src/data/slimpajama.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import numpy as np
3 | import tiktoken
4 | from datasets import load_dataset
5 | import os
6 |
7 |
8 | tknzr = tiktoken.get_encoding("gpt2")
9 |
10 |
11 | def get_slimpajama_data(datasets_dir, num_proc=40):
12 | SPJ_DATA_PATH = os.path.join(datasets_dir, "slimpajama6B/")
13 | if not os.path.exists(os.path.join(SPJ_DATA_PATH, "train.bin")):
14 | os.makedirs(SPJ_DATA_PATH, exist_ok=True)
15 | dataset = load_dataset("DKYoon/SlimPajama-6B")
16 |
17 | split_dataset = dataset["train"].train_test_split(
18 | test_size=0.0005, seed=2357, shuffle=True
19 | )
20 | split_dataset["val"] = split_dataset.pop("test")
21 |
22 | def process(example):
23 | ids = tknzr.encode_ordinary(
24 | example["text"]
25 | ) # encode_ordinary ignores any special tokens
26 | ids.append(
27 | tknzr.eot_token
28 | ) # add the end of text token, e.g. 50256 for gpt2 bpe
29 | out = {"ids": ids, "len": len(ids)}
30 | return out
31 |
32 | # tokenize the dataset
33 | tokenized = split_dataset.map(
34 | process,
35 | remove_columns=["text"],
36 | desc="tokenizing the splits",
37 | num_proc=num_proc,
38 | )
39 |
40 | # concatenate all the ids in each dataset into one large file we can use for training
41 | for split, dset in tokenized.items():
42 | arr_len = np.sum(dset["len"])
43 | filename = os.path.join(SPJ_DATA_PATH, f"{split}.bin")
44 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
45 | arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,))
46 | total_batches = min(1024, len(dset))
47 |
48 | idx = 0
49 | for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"):
50 | # Batch together samples for faster write
51 | batch = dset.shard(
52 | num_shards=total_batches, index=batch_idx, contiguous=True
53 | ).with_format("numpy")
54 | arr_batch = np.concatenate(batch["ids"])
55 | # Write into mmap
56 | arr[idx : idx + len(arr_batch)] = arr_batch
57 | idx += len(arr_batch)
58 | arr.flush()
59 |
60 | return {
61 | "train": os.path.join(SPJ_DATA_PATH, "train.bin"),
62 | "val": os.path.join(SPJ_DATA_PATH, "val.bin"),
63 | }
64 |
65 |
66 | def get_slimpajama_chunk1(datasets_dir, num_proc=40):
67 | SPJ_DATA_PATH = os.path.join(datasets_dir, "slimpajama6B/")
68 | SPJ_CHUNK_1_DATA_PATH = os.path.join(SPJ_DATA_PATH, "chunk1")
69 | if not os.path.exists(os.path.join(SPJ_CHUNK_1_DATA_PATH, "train.bin")):
70 |         os.makedirs(SPJ_CHUNK_1_DATA_PATH, exist_ok=True)
71 | dataset = load_dataset("cerebras/SlimPajama-627B", split="train/chunk1")
72 |
73 |         # load_dataset(..., split=...) returns a Dataset, not a DatasetDict
74 |         split_dataset = dataset.train_test_split(
74 | test_size=0.0005, seed=2357, shuffle=True
75 | )
76 | split_dataset["val"] = split_dataset.pop("test")
77 |
78 | def process(example):
79 | ids = tknzr.encode_ordinary(
80 | example["text"]
81 | ) # encode_ordinary ignores any special tokens
82 | ids.append(
83 | tknzr.eot_token
84 | ) # add the end of text token, e.g. 50256 for gpt2 bpe
85 | out = {"ids": ids, "len": len(ids)}
86 | return out
87 |
88 | # tokenize the dataset
89 | tokenized = split_dataset.map(
90 | process,
91 | remove_columns=["text"],
92 | desc="tokenizing the splits",
93 | num_proc=num_proc,
94 | )
95 |
96 | # concatenate all the ids in each dataset into one large file we can use for training
97 | for split, dset in tokenized.items():
98 | arr_len = np.sum(dset["len"])
99 |             filename = os.path.join(SPJ_CHUNK_1_DATA_PATH, f"{split}.bin")
100 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
101 | arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,))
102 | total_batches = min(1024, len(dset))
103 |
104 | idx = 0
105 | for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"):
106 | # Batch together samples for faster write
107 | batch = dset.shard(
108 | num_shards=total_batches, index=batch_idx, contiguous=True
109 | ).with_format("numpy")
110 | arr_batch = np.concatenate(batch["ids"])
111 | # Write into mmap
112 | arr[idx : idx + len(arr_batch)] = arr_batch
113 | idx += len(arr_batch)
114 | arr.flush()
115 |
116 | return {
117 |         "train": os.path.join(SPJ_CHUNK_1_DATA_PATH, "train.bin"),
118 |         "val": os.path.join(SPJ_CHUNK_1_DATA_PATH, "val.bin"),
119 | }
120 |
--------------------------------------------------------------------------------
/src/data/utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import numpy as np
3 | from typing import Dict
4 | import torch
5 | import torch.distributed as dist
6 |
7 | from .shakespeare import get_shakespeare_data
8 | from .wikitext import get_wikitext_data
9 | from .arxiv import get_arxiv_2000, get_arxiv_full
10 | from .openwebtext2 import get_openwebtext2_data
11 | from .redpajama import get_redpajama_data, get_redpajamav2_data
12 | from .slimpajama import get_slimpajama_data
13 |
14 |
15 | def get_dataset(args) -> Dict[str, np.ndarray]:
16 | """Fetch the right dataset given by the args.dataset parameter. The logic for each dataset is
17 | contained in its own python file. The expected format at the moment is a dictionary of np.memmap
18 | containing two keys: 'train' and 'val', corresponding to the tokenized training and validation data.
19 | """
20 | if args.dataset == "wikitext":
21 | return get_wikitext_data(args.datasets_dir)
22 | if args.dataset == "shakespeare-char":
23 | return get_shakespeare_data(args.datasets_dir)
24 | if args.dataset == "arxiv2000":
25 | return get_arxiv_2000(args.datasets_dir)
26 | if args.dataset == "arxiv":
27 | return get_arxiv_full(args.datasets_dir)
28 | if args.dataset == "arxiv+wiki":
29 | arxiv_data = get_arxiv_full(args.datasets_dir)
30 | wiki_data = get_wikitext_data(args.datasets_dir)
31 | train_data = np.concatenate((arxiv_data["train"], wiki_data["train"]))
32 | val_data = np.concatenate((arxiv_data["val"], wiki_data["val"]))
33 | return {"train": train_data, "val": val_data}
34 | if args.dataset == "openwebtext2":
35 | return get_openwebtext2_data(args.datasets_dir)
36 | if args.dataset == "redpajama":
37 | return get_redpajama_data(args.datasets_dir)
38 | if args.dataset == "redpajamav2":
39 | return get_redpajamav2_data(args.datasets_dir)
40 | if args.dataset == "slimpajama":
41 | return get_slimpajama_data(args.datasets_dir)
42 | else:
43 |         raise NotImplementedError(f"Unknown dataset key '{args.dataset}'")
44 |
45 |
46 | class DataReader:
47 | def __init__(
48 | self,
49 | data_src,
50 | batch_size,
51 | sequence_length,
52 | seed=1337,
53 | with_replacement=False,
54 | auto_shard=True,
55 | keep_in_ram=False,
56 | ):
57 | if isinstance(data_src, (str, Path)):
58 | self.data_path = Path(data_src)
59 | self.keep_in_ram = keep_in_ram
60 | if keep_in_ram:
61 | self.data = np.array(
62 | np.memmap(self.data_path, dtype=np.uint16, mode="r")
63 | )
64 | else:
65 | self.data = None
66 | elif isinstance(data_src, (np.ndarray, np.memmap)):
67 | self.data_path = None
68 | self.data = data_src
69 | self.keep_in_ram = True
70 |
71 | self.batch_size = batch_size
72 | self.sequence_length = sequence_length
73 | self.seed = seed
74 | self.with_replacement = with_replacement
75 |
76 | self.num_tokens = len(self._get_data())
77 |
78 | if auto_shard and dist.is_initialized():
79 | self.world_size = dist.get_world_size()
80 | self.rank = dist.get_rank()
81 | print(
82 | f"Distributed DataReader Initialized for Worker {self.rank}/{self.world_size}"
83 | )
84 | else:
85 | self.world_size = 1
86 | self.rank = 0
87 |
88 | # Sampling without replacement
89 | self.last_epoch = None
90 | self.order = None
91 | self.epoch_offset = None
92 | self.step = 0
93 | self.num_batches_of_seqlen = 0
94 | if not with_replacement:
95 | self._shuffle_epoch(0)
96 |
97 | def __len__(self):
98 | # Length in valid start indices for a sequence
99 | # Extra -1 to have a valid next token for the final token of the last idx
100 | return self.num_tokens - self.sequence_length - 1
101 |
102 | def _get_data(self):
103 | if self.data is not None:
104 | return self.data
105 | else:
106 | # Construct the memmap each time to avoid a memory leak per NanoGPT
107 | # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
108 | return np.memmap(self.data_path, dtype=np.uint16, mode="r")
109 |
110 | def __getitem__(self, idx):
111 | # Return the underlying datapoint, no random sampling, no worker sharding
112 | assert 0 <= idx < len(self)
113 | data = self._get_data()
114 | x = torch.from_numpy(data[idx : idx + self.sequence_length].astype(np.int64))
115 | y = torch.from_numpy(
116 |             data[idx + 1 : idx + self.sequence_length + 1].astype(np.int64)
117 | )
118 | return x, y
119 |
120 | def set_step(self, step):
121 | self.step = step
122 |
123 | def sample_batch(self):
124 | data = self._get_data()
125 |
126 | if self.with_replacement:
127 | idxs = self._sample_with_replacement(self.step)
128 | else:
129 | idxs = self._sample_without_replacement(self.step)
130 | self.step += 1
131 |
132 | xy = np.stack([data[i : i + self.sequence_length + 1] for i in idxs]).astype(
133 | np.int64
134 | )
135 | x = torch.from_numpy(xy[:, :-1]).contiguous()
136 | y = torch.from_numpy(xy[:, 1:]).contiguous()
137 | return x, y
138 |
139 | def _sample_with_replacement(self, idx):
140 | # Return an array of token indices of length self.batch_size
141 | # Sampled with replacement, can get repeats at any time
142 | seed = self.seed + idx * self.world_size + self.rank
143 | rng = np.random.default_rng(seed)
144 |         return rng.integers(len(self), size=self.batch_size)
145 |
146 | def _shuffle_epoch(self, epoch):
147 | seed = self.seed + epoch
148 | rng = np.random.default_rng(seed)
149 | # Drop one sequence to allow different offsets per epoch:
150 | self.order = rng.permutation((len(self)) // self.sequence_length - 1)
151 | # Shift all sequences in this epoch by this amount:
152 | self.epoch_offset = rng.integers(self.sequence_length)
153 | self.last_epoch = epoch
154 | self.num_batches_of_seqlen = (
155 | len(self.order) // self.batch_size
156 | ) # Drops remainder batch
157 |
158 | def _sample_without_replacement(self, step):
159 | # Return an array of token indices of length self.batch_size
160 | # Sampled without replacement, cycle all sequences before potential repeats
161 | # Sequences are randomly offset in every epoch as well
162 | batch_idx = self.world_size * step + self.rank
163 | epoch_length = self.num_batches_of_seqlen
164 |
165 | epoch = batch_idx // epoch_length
166 | if epoch != self.last_epoch:
167 | self._shuffle_epoch(epoch)
168 | epoch_idx = batch_idx % epoch_length
169 |
170 | start = epoch_idx * self.batch_size
171 | end = start + self.batch_size
172 | return self.order[start:end] * self.sequence_length + self.epoch_offset
173 |
174 | def num_batches(self):
175 | if self.with_replacement:
176 | return self.num_tokens // self.batch_size
177 | return self.num_batches_of_seqlen
178 |
--------------------------------------------------------------------------------
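For orientation, a minimal sketch of how DataReader is typically driven, mirroring get_data_readers in src/main.py; the .bin path and the batch/sequence sizes below are placeholders, not values fixed by the code:

    from data.utils import DataReader

    # token file produced by one of the dataset scripts above (flat uint16 token ids)
    reader = DataReader(
        data_src="datasets/slimpajama6B/train.bin",
        batch_size=8,
        sequence_length=512,
        seed=1337,
        with_replacement=False,
        auto_shard=False,
        keep_in_ram=False,
    )

    x, y = reader.sample_batch()  # int64 tensors of shape (8, 512); y is x shifted by one token

With with_replacement=False, every sequence-aligned window is visited once per epoch (with a fresh random offset each epoch) before any sequence can repeat.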
/src/data/wikitext.py:
--------------------------------------------------------------------------------
1 | import os
2 | import zipfile
3 | import urllib.request
4 | import numpy as np
5 | import tiktoken
6 |
7 |
8 | def get_wikitext_data(datasets_base_dir):
9 | """Inspired from https://github.com/tysam-code/hlb-gpt"""
10 | WIKITEXT_DATA_PATH = os.path.join(datasets_base_dir, "wikitext/")
11 | if not os.path.exists(WIKITEXT_DATA_PATH):
12 | os.makedirs(WIKITEXT_DATA_PATH, exist_ok=True)
13 | print("downloading data and tokenizing (1-2 min)")
14 | raw_data_source = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip"
15 | urllib.request.urlretrieve(
16 | raw_data_source, os.path.join(WIKITEXT_DATA_PATH, "data.zip")
17 | )
18 |
19 | with zipfile.ZipFile(
20 | os.path.join(WIKITEXT_DATA_PATH, "data.zip"), "r"
21 | ) as zip_ref:
22 | zip_ref.extractall(WIKITEXT_DATA_PATH)
23 |
24 | with open(
25 | os.path.join(WIKITEXT_DATA_PATH, "wikitext-103-raw/wiki.train.raw"), "r"
26 | ) as data_file:
27 | raw_train_data = data_file.read()
28 |
29 | with open(
30 | os.path.join(WIKITEXT_DATA_PATH, "wikitext-103-raw/wiki.valid.raw"), "r"
31 | ) as data_file:
32 | raw_eval_data = data_file.read()
33 |
34 | tokenizer = tiktoken.get_encoding("gpt2")
35 | raw_tokenized_train = tokenizer.encode_ordinary(raw_train_data)
36 | raw_tokenized_eval = tokenizer.encode_ordinary(raw_eval_data)
37 |
38 | train_tokenized = np.array(raw_tokenized_train, dtype=np.uint16)
39 | eval_tokenized = np.array(raw_tokenized_eval, dtype=np.uint16)
40 |
41 | train_tokenized.tofile(os.path.join(WIKITEXT_DATA_PATH, "train.bin"))
42 | eval_tokenized.tofile(os.path.join(WIKITEXT_DATA_PATH, "val.bin"))
43 | print("completed the tokenization process!")
44 |
45 | return {
46 | "train": os.path.join(WIKITEXT_DATA_PATH, "train.bin"),
47 | "val": os.path.join(WIKITEXT_DATA_PATH, "val.bin"),
48 | }
49 |
--------------------------------------------------------------------------------
/src/distributed/__init__.py:
--------------------------------------------------------------------------------
1 | from . import ddp
2 | from . import single
3 |
4 | BACKEND_TYPE_TO_MODULE_MAP = {
5 | "nccl": ddp.DataParallelDistributedBackend,
6 |     None: single.SingleNodeBackend,
7 | }
8 |
9 |
10 | def make_backend_from_args(args):
11 | return BACKEND_TYPE_TO_MODULE_MAP[args.distributed_backend](args)
12 |
13 |
14 | def registered_backends():
15 | return BACKEND_TYPE_TO_MODULE_MAP.keys()
16 |
--------------------------------------------------------------------------------
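How the mapping above is meant to be used, as a sketch; args stands in for the parsed config namespace built in src/main.py:

    from types import SimpleNamespace
    import distributed

    # single-process run: no distributed backend requested
    args = SimpleNamespace(distributed_backend=None, device="cuda:0")
    backend = distributed.make_backend_from_args(args)
    print(backend.get_world_size())  # 1

    # a multi-GPU run launched with torchrun would instead pass "nccl",
    # selecting ddp.DataParallelDistributedBackend (which calls init_process_group)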
/src/distributed/backend.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 |
4 | class DistributedBackend(object):
5 |
6 | def __init__(self, args):
7 | pass
8 |
9 | def transform_model(self, model):
10 | raise NotImplementedError
11 |
12 | def get_context_for_microstep_forward(
13 | self, model, microstep_idx, gradient_accumulation_steps
14 | ):
15 | raise NotImplementedError
16 |
17 | def is_master_process(self) -> bool:
18 | raise NotImplementedError
19 |
20 | def get_adjusted_args_for_process(self, args):
21 | raise NotImplementedError
22 |
23 | def get_raw_model(self, model):
24 | raise NotImplementedError
25 |
26 | def translate_model_parameter_name_for_node(self, parameter_name) -> List[str]:
27 | raise NotImplementedError
28 |
29 | def get_world_size(self):
30 | raise NotImplementedError
31 |
32 | def finalize(self):
33 | pass
34 |
--------------------------------------------------------------------------------
/src/distributed/ddp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | from contextlib import contextmanager
4 |
5 | from torch.nn.parallel import DistributedDataParallel as DDP
6 | from torch.distributed import init_process_group, destroy_process_group, get_world_size
7 |
8 | from .backend import DistributedBackend
9 |
10 |
11 | class DataParallelDistributedBackend(DistributedBackend):
12 |
13 | def __init__(self, args):
14 | self.rank = int(os.environ.get("RANK", -1))
15 |         assert self.rank != -1, "DDP backend cannot be used without a rank"
16 |         assert "cuda" in args.device, "DDP backend cannot be used on non-CUDA devices"
17 | init_process_group(backend=args.distributed_backend)
18 | self.local_rank = int(os.environ["LOCAL_RANK"])
19 |
20 | def get_adjusted_args_for_process(self, args):
21 | effective_batch_size = args.batch_size * args.acc_steps
22 | world_size = self.get_world_size()
23 | if effective_batch_size % world_size != 0:
24 | raise ValueError(
25 | f"Effective batch size "
26 |                 f"{effective_batch_size} is not divisible "
27 |                 f"by the world size {world_size}."
28 | )
29 | acc_steps_div = math.gcd(args.acc_steps, world_size)
30 | args.acc_steps = args.acc_steps // acc_steps_div
31 | args.batch_size = args.batch_size // (world_size // acc_steps_div)
32 | args.device = f"cuda:{self.local_rank}"
33 | args.seed = args.seed + self.local_rank
34 |         args.data_seed = args.data_seed  # unlike args.seed, the data seed is not offset per rank
35 | return args
36 |
37 | def transform_model(self, model):
38 | return DDP(model, device_ids=[self.local_rank])
39 |
40 | @contextmanager
41 | def get_context_for_microstep_forward(
42 | self, model, microstep_idx, gradient_accumulation_steps
43 | ):
44 | model.require_backward_grad_sync = (
45 | microstep_idx == gradient_accumulation_steps - 1
46 | )
47 | yield
48 |
49 | def is_master_process(self) -> bool:
50 | return self.rank == 0
51 |
52 | def get_raw_model(self, model):
53 | return model.module
54 |
55 | def translate_model_parameter_name_for_node(self, parameter_name):
56 | return [f"module.{parameter_name}"]
57 |
58 | def get_world_size(self):
59 | return get_world_size()
60 |
61 | def finalize(self):
62 | destroy_process_group()
63 |
--------------------------------------------------------------------------------
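To make the arithmetic in get_adjusted_args_for_process concrete with illustrative numbers: with batch_size=32, acc_steps=4 and a world size of 8, the effective batch size is 32 * 4 = 128; acc_steps_div = gcd(4, 8) = 4, so each rank ends up with acc_steps = 1 and batch_size = 32 // (8 // 4) = 16, i.e. 16 sequences per optimizer step per rank, and 8 ranks * 16 = 128 preserves the original effective batch size.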
/src/distributed/single.py:
--------------------------------------------------------------------------------
1 | from contextlib import nullcontext
2 |
3 | from .backend import DistributedBackend
4 |
5 |
6 | class SingleNodeBackend(DistributedBackend):
7 |
8 | def __init__(self, args):
9 | super().__init__(args)
10 | self.rank = 0
11 |
12 | def transform_model(self, model):
13 | return model
14 |
15 | def get_context_for_microstep_forward(self, *args, **kwargs):
16 | return nullcontext()
17 |
18 | def get_adjusted_args_for_process(self, args):
19 | return args
20 |
21 | def is_master_process(self) -> bool:
22 | return True
23 |
24 | def get_raw_model(self, model):
25 | return model
26 |
27 | def get_world_size(self):
28 | return 1
29 |
30 | def translate_model_parameter_name_for_node(self, parameter_name):
31 | return [parameter_name]
32 |
--------------------------------------------------------------------------------
/src/logger/logger.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import copy
3 | from functools import wraps
4 | from pathlib import Path
5 | import pickle
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | import wandb
10 |
11 |
12 | class DynamicsLogger:
13 | def __init__(self, model, optimizer, cfg, output_folder, wandb=False):
14 | self.model = model
15 | self.optimizer = optimizer
16 | self.iteration = 0
17 | self.output_folder = output_folder
18 | self.wandb = wandb
19 |
20 | self.interval = None # Set in step
21 |
22 | if isinstance(cfg["interval"], int):
23 | self.interval = cfg["interval"]
24 | elif isinstance(cfg["interval"], (tuple, list)):
25 | # Tuples of iter, interval with iter0 < iter1 < ...
26 | # [(iter0, interval0), (iter1, interval1), (iter2, interval2)]
27 | # At iterX change interval to intervalX
28 | self.interval = cfg["interval"][0][1]
29 |
30 | # TODO: Add default cfg here once tested
31 | self.cfg = copy.deepcopy(cfg)
32 | if self.cfg["disk_stats"] == "all":
33 | self.cfg["disk_stats"] = self.cfg["stats"]
34 | if self.cfg["wandb_stats"] == "all":
35 | self.cfg["wandb_stats"] = self.cfg["stats"]
36 |
37 | self.stats = defaultdict(lambda: defaultdict(list))
38 | self.reducers = defaultdict(lambda: "rms")
39 | self.reducers.update(
40 | {
41 | # Make sure any signed stats are not reduced with RMS!
42 | # 'layer_cos_gradient_angle': 'mean',
43 | # 'neuron/cos_gradient_angle': 'mean',
44 | # 'vector/values': 'mean',
45 | # 'vector/update_values': 'mean',
46 | }
47 | )
48 |
49 | # Wrap the step method of the optimizer
50 | self.optimizer.original_step = self.optimizer.step
51 | self.optimizer.step = self_preserving_overwrite(self.step).__get__(
52 | self.optimizer
53 | )
54 |
55 | self.wandb_setup_complete = False
56 |
57 | def step(self, *args, **kwargs):
58 | # Dictionaries keyed by parameter name
59 | # NOTE: Some may be direct references to model / optimizer (do not change in-place)
60 | pre_params = dict()
61 | pre_grads = dict()
62 | pre_states = dict()
63 | post_params = dict()
64 | post_states = dict()
65 |
66 | if isinstance(self.cfg["interval"], int):
67 | self.interval = self.cfg["interval"]
68 | elif isinstance(self.cfg["interval"], (tuple, list)):
69 | # Tuples of iter, interval with iter0 < iter1 < ...
70 | # [(iter0, interval0), (iter1, interval1), (iter2, interval2)]
71 | # At iterX change interval to intervalX
72 | idx = 0
73 | while (
74 | idx < len(self.cfg["interval"])
75 | and self.cfg["interval"][idx][0] <= self.iteration
76 | ):
77 | idx += 1
78 | self.interval = self.cfg["interval"][idx - 1][1]
79 |
80 | if "eps" in self.optimizer.defaults:
81 | eps = self.optimizer.defaults["eps"]
82 | else:
83 | eps = 1e-8
84 |
85 | if self.iteration % self.interval == 0:
86 | for name, param in self.model.named_parameters():
87 | pre_params[name] = param.clone().detach()
88 | if param.grad is not None:
89 | pre_grads[name] = param.grad
90 | else:
91 | pre_grads[name] = None
92 |
93 | pre_states[name] = copy.deepcopy(self.optimizer.state[param])
94 |
95 | self.optimizer.original_step(*args, **kwargs) # Assuming no change to grads
96 |
97 | for name, param in self.model.named_parameters():
98 | post_params[name] = param.detach()
99 | post_states[name] = self.optimizer.state[param]
100 |
101 | self.log_statistics(
102 | pre_params, post_params, pre_grads, pre_states, post_states, eps
103 | )
104 | else:
105 | # Normal optimizer step, no logging
106 | self.optimizer.original_step(*args, **kwargs)
107 |
108 | self.iteration += 1
109 |
110 | @torch.no_grad()
111 | def log_statistics(
112 | self, pre_params, post_params, pre_grads, pre_states, post_states, eps
113 | ):
114 | requested_stats = set(self.cfg["stats"])
115 |
116 | if {"layer_norm", "neuron_norm"} & requested_stats:
117 | for name, param in pre_params.items():
118 | if param.dim() < 2:
119 | # Only higher dimensional weights (linear, conv etc)
120 | continue
121 |
122 | # Compute neuron norms (assume shape K x C x ...)
123 | neuron_norm = torch.linalg.vector_norm(param.flatten(1), dim=1)
124 |
125 | if "layer_norm" in requested_stats:
126 | # This makes more sense with layernorm
127 | # for BN rms_neuron_norm is what we predict (closely related)
128 | layer_norm = torch.linalg.vector_norm(neuron_norm)
129 | self.stats["layer_norm"][name].append(layer_norm)
130 | if "neuron_norm" in requested_stats:
131 | self.stats["neuron_norm"][name].append(neuron_norm)
132 |
133 | if {"layer_grad_norm", "neuron_grad_norm"} & requested_stats:
134 | for name, grad in pre_grads.items():
135 |                 if grad is None or grad.dim() < 2:
136 | # Only higher dimensional weights (linear, conv etc)
137 | continue
138 |
139 | # Compute neuron norms (assume shape K x C x ...)
140 | neuron_grad_norm = torch.linalg.vector_norm(grad.flatten(1), dim=1)
141 |
142 | if "layer_grad_norm" in requested_stats:
143 | grad_norm = torch.linalg.vector_norm(neuron_grad_norm)
144 | self.stats["layer_grad_norm"][name].append(grad_norm)
145 | if "neuron_grad_norm" in requested_stats:
146 | self.stats["neuron_grad_norm"][name].append(neuron_grad_norm)
147 |
148 | if {"layer_update_norm", "neuron_update_norm"} & requested_stats:
149 | for name, pre_param in pre_params.items():
150 | if pre_param.dim() < 2:
151 | # Only higher dimensional weights (linear, conv etc)
152 | continue
153 |
154 | post_param = post_params[name]
155 | diff = post_param - pre_param
156 |
157 | # Compute neuron norms (assume shape K x C x ...)
158 | neuron_update_norm = torch.linalg.vector_norm(diff.flatten(1), dim=1)
159 |
160 | if "layer_update_norm" in requested_stats:
161 | layer_norm = torch.linalg.vector_norm(neuron_update_norm)
162 | self.stats["layer_update_norm"][name].append(layer_norm)
163 | if "neuron_update_norm" in requested_stats:
164 | self.stats["neuron_update_norm"][name].append(neuron_update_norm)
165 |
166 | if {"layer_relative_update", "neuron_relative_update"} & requested_stats:
167 | for name, pre_param in pre_params.items():
168 | if pre_param.dim() < 2:
169 | # Only higher dimensional weights (linear, conv etc)
170 | continue
171 |
172 | post_param = post_params[name]
173 | diff = post_param - pre_param
174 |
175 | if "layer_relative_update" in requested_stats:
176 | layer_diff_norm = torch.linalg.vector_norm(diff)
177 | layer_norm = torch.linalg.vector_norm(pre_param)
178 | layer_relative_update = (layer_diff_norm + eps) / (layer_norm + eps)
179 | self.stats["layer_relative_update"][name].append(
180 | layer_relative_update
181 | )
182 |
183 | if "neuron_relative_update" in requested_stats:
184 | neuron_diff_norm = torch.linalg.vector_norm(diff.flatten(1), dim=1)
185 | neuron_norm = torch.linalg.vector_norm(pre_param.flatten(1), dim=1)
186 | neuron_relative_update = (neuron_diff_norm + eps) / (
187 | neuron_norm + eps
188 | )
189 | self.stats["neuron_relative_update"][name].append(
190 | neuron_relative_update
191 | )
192 |
193 | if {"layer_angular_update", "neuron_angular_update"} & requested_stats:
194 | for name, pre_param in pre_params.items():
195 | if pre_param.dim() < 2:
196 | # Only higher dimensional weights (linear, conv etc)
197 | continue
198 |
199 | post_param = post_params[name]
200 | pre_param = (
201 | pre_param.double()
202 | ) # There is a lot of noise for small angles
203 |
204 | if "layer_angular_update" in requested_stats:
205 | cos = F.cosine_similarity(
206 | pre_param.flatten(), post_param.flatten(), dim=0
207 | )
208 | angles = torch.acos(torch.clamp(cos, min=-1, max=1))
209 | self.stats["layer_angular_update"][name].append(angles)
210 |
211 | if "neuron_angular_update" in requested_stats:
212 | cos = F.cosine_similarity(
213 | pre_param.flatten(1), post_param.flatten(1), dim=1
214 | )
215 | angles = torch.acos(torch.clamp(cos, min=-1, max=1))
216 | self.stats["neuron_angular_update"][name].append(angles)
217 |
218 | if {"layer_grad_alignment", "neuron_grad_alignment"} & requested_stats:
219 | for name, pre_param in pre_params.items():
220 | if pre_param.dim() < 2:
221 | # Only higher dimensional weights (linear, conv etc)
222 | continue
223 |
224 | pre_grad = pre_grads[name]
225 | mean_layer_alignment = self.layer_cosine_sim(pre_grad, pre_param)
226 | self.stats["layer_grad_alignment"][name].append(mean_layer_alignment)
227 |
228 | mean_neuron_alignment = self.neuron_cosine_sim(pre_grad, pre_param)
229 | self.stats["neuron_grad_alignment"][name].append(mean_neuron_alignment)
230 |
231 | if {
232 | "layer_grad_velocity_alignment",
233 | "neuron_grad_velocity_alignment",
234 | } & requested_stats:
235 | for name, pre_param in pre_params.items():
236 | if pre_param.dim() < 2:
237 | # Only higher dimensional weights (linear, conv etc)
238 | continue
239 |
240 | # Adam and similar only
241 | if "exp_avg" not in pre_states[name]:
242 | self.stats["layer_grad_velocity_alignment"][name].append(
243 | torch.tensor(0).to(pre_param.device)
244 | )
245 | self.stats["neuron_grad_velocity_alignment"][name].append(
246 | torch.zeros(pre_param.shape[0]).to(pre_param.device)
247 | )
248 | continue
249 |
250 | pre_grad = pre_grads[name]
251 | pre_state = pre_states[name]
252 | pre_m = pre_state["exp_avg"]
253 |
254 | mean_layer_alignment = self.layer_cosine_sim(pre_grad, pre_m)
255 | self.stats["layer_grad_velocity_alignment"][name].append(
256 | mean_layer_alignment
257 | )
258 |
259 | mean_neuron_alignment = self.neuron_cosine_sim(pre_grad, pre_m)
260 | self.stats["neuron_grad_velocity_alignment"][name].append(
261 | mean_neuron_alignment
262 | )
263 |
264 | # TODO: Log averages for scalar vectors
265 | if "scalar_rms" in requested_stats:
266 | for name, pre_param in pre_params.items():
267 | if pre_param.dim() >= 2:
268 | # Only scalars and scalar vectors
269 | continue
270 |
271 |                 self.stats["scalar_rms"][name].append((pre_param**2).mean().sqrt())
272 |
273 | if "scalar_update_rms" in requested_stats:
274 | for name, pre_param in pre_params.items():
275 | if pre_param.dim() >= 2:
276 | # Only scalars and scalar vectors
277 | continue
278 |
279 | post_param = post_params[name]
280 | diff = post_param - pre_param
281 |
282 | self.stats["scalar_update_rms"][name].append((diff**2).mean().sqrt())
283 |
284 | if "scalar_grad_rms" in requested_stats:
285 | for name, grad in pre_grads.items():
286 |                 if grad is None or grad.dim() >= 2:
287 | # Only scalars and scalar vectors
288 | continue
289 |
290 | self.stats["scalar_grad_rms"][name].append((grad**2).mean().sqrt())
291 |
292 | # Could add similar per elements histograms for the following:
293 | # scalar_value
294 | # scalar_update_value
295 | # scalar_grad_value
296 |
297 | # More obscure metrics below, only used in select experiments
298 | if {
299 | "layer_mean_second_grad_moment",
300 | "neuron_mean_second_grad_moment",
301 | } & requested_stats:
302 | for name, pre_param in pre_params.items():
303 | if pre_param.dim() < 2:
304 | # Only higher dimensional weights (linear, conv etc)
305 | continue
306 |
307 | post_state = post_states[name]
308 | post_v = post_state["exp_avg_sq"]
309 |
310 | if "layer_mean_second_grad_moment" in requested_stats:
311 | mean_v = torch.mean(post_v)
312 | self.stats["layer_mean_second_grad_moment"][name].append(mean_v)
313 |
314 | if "neuron_mean_second_grad_moment" in requested_stats:
315 | mean_v = torch.mean(post_v.flatten(1), dim=1)
316 | self.stats["neuron_mean_second_grad_moment"][name].append(mean_v)
317 |
318 | if {
319 | "layer_second_grad_moment_std_mean_ratio",
320 | "neuron_second_grad_moment_std_mean_ratio",
321 | } & requested_stats:
322 | for name, pre_param in pre_params.items():
323 | if pre_param.dim() < 2:
324 | # Only higher dimensional weights (linear, conv etc)
325 | continue
326 |
327 | post_state = post_states[name]
328 | post_v = post_state["exp_avg_sq"]
329 |
330 | v_neuron_mean = post_v.flatten(1).mean(dim=1)
331 | v_neuron_std = post_v.flatten(1).std(dim=1)
332 | neuron_std_mean_ratio = torch.div(v_neuron_std, v_neuron_mean)
333 | self.stats["neuron_second_grad_moment_std_mean_ratio"][name].append(
334 | neuron_std_mean_ratio
335 | )
336 |
337 | v_layer_mean = post_v.flatten(0).mean(dim=0)
338 | v_layer_std = post_v.flatten(0).std(dim=0)
339 | layer_std_mean_ratio = torch.div(v_layer_std, v_layer_mean)
340 | self.stats["layer_second_grad_moment_std_mean_ratio"][name].append(
341 | layer_std_mean_ratio
342 | )
343 |
344 | if {"layer_scaled_grad_norm", "neuron_scaled_grad_norm"} & requested_stats:
345 | for name, pre_param in pre_params.items():
346 | if pre_param.dim() < 2:
347 | # Only higher dimensional weights (linear, conv etc)
348 | continue
349 |
350 | pre_grad = pre_grads[name]
351 | post_state = post_states[name]
352 | post_v = post_state["exp_avg_sq"]
353 |
354 | scaled_grad = torch.div(pre_grad, (post_v.sqrt() + 1e-8))
355 |
356 | layer_scaled_grad_norm = self.layer_norm(scaled_grad)
357 | self.stats["layer_scaled_grad_norm"][name].append(
358 | layer_scaled_grad_norm
359 | )
360 |
361 | neuron_scaled_grad_norm = self.neuron_norm(scaled_grad)
362 | self.stats["neuron_scaled_grad_norm"][name].append(
363 | neuron_scaled_grad_norm
364 | )
365 |
366 | if {
367 | "layer_scaled_grad_wd_projection",
368 | "neuron_scaled_grad_wd_projection",
369 | } & requested_stats:
370 | for name, pre_param in pre_params.items():
371 | if pre_param.dim() < 2:
372 | # Only higher dimensional weights (linear, conv etc)
373 | continue
374 |
375 | pre_grad = pre_grads[name]
376 | post_state = post_states[name]
377 | post_v = post_state["exp_avg_sq"]
378 |
379 | scaled_grad = torch.div(pre_grad, (post_v.sqrt() + 1e-8))
380 | layer_scaled_grad_wd_projection = self.layer_gradient_wd_project(
381 | scaled_grad, pre_param
382 | )
383 | self.stats["layer_scaled_grad_wd_projection"][name].append(
384 | layer_scaled_grad_wd_projection
385 | )
386 |
387 | neuron_scaled_grad_wd_projection = self.neuron_gradient_wd_project(
388 | scaled_grad, pre_param
389 | )
390 | self.stats["neuron_scaled_grad_wd_projection"][name].append(
391 | neuron_scaled_grad_wd_projection
392 | )
393 |
394 | T_disk = self.cfg["disk_save_interval"] or 0
395 | T_wandb = self.cfg["wandb_interval"] or 0
396 |
397 | # Maybe log to disk
398 | if T_disk and (self.iteration + self.interval) % (T_disk * self.interval) == 0:
399 | self.log_to_disk()
400 |
401 | # Maybe log to wandb
402 | if (
403 | self.wandb
404 | and T_wandb
405 | and (self.iteration + self.interval) % (T_wandb * self.interval) == 0
406 | ):
407 | self.log_to_wandb()
408 |
409 | def layer_gradient_wd_project(self, g_t, w_t):
410 | norm = self.layer_norm(w_t)
411 | dot_prod = torch.sum(w_t.flatten() * g_t.flatten(), dim=0)
412 | projection = torch.div(dot_prod, norm * norm)
413 | return projection
414 |
415 | def neuron_gradient_wd_project(self, g_t, w_t):
416 | norm = self.neuron_norm(w_t)
417 | dot_prod = torch.sum(w_t.flatten(1) * g_t.flatten(1), dim=1)
418 | projection = torch.div(dot_prod, norm * norm)
419 | return projection
420 |
421 | def layer_cosine_sim(self, v1, v2):
422 | return F.cosine_similarity(v1.flatten(), v2.flatten(), dim=0)
423 |
424 | def neuron_cosine_sim(self, v1, v2):
425 | return F.cosine_similarity(v1.flatten(1), v2.flatten(1), dim=1)
426 |
427 | def layer_norm(self, v1):
428 | return torch.linalg.vector_norm(v1.flatten(), dim=0)
429 |
430 | def neuron_norm(self, v1):
431 | return torch.linalg.vector_norm(v1.flatten(1), dim=1)
432 |
433 | def log_to_disk(self, free_buffers=True):
434 | out_dict = dict()
435 | T_disk = self.cfg["disk_save_interval"]
436 | for stat_name in self.cfg["disk_stats"]:
437 | out_dict[stat_name] = dict()
438 | reducer = self.reducers[stat_name]
439 |
440 | for param_name, values in self.stats[stat_name].items():
441 | values = torch.stack(values[-T_disk:])
442 | if self.cfg["disk_max_channels"] > 0 and values.dim() > 1:
443 | values = values[:, : self.cfg["disk_max_channels"]]
444 | if self.cfg["disk_downsample"] > 1:
445 | assert T_disk % self.cfg["disk_downsample"] == 0
446 | values = values.reshape(
447 | (
448 | T_disk // self.cfg["disk_downsample"],
449 | self.cfg["disk_downsample"],
450 | -1,
451 | )
452 | )
453 | if reducer == "mean":
454 | values = values.mean(dim=1)
455 | elif reducer == "rms":
456 | values = (values**2).mean(dim=1).sqrt()
457 | elif reducer == "first":
458 | values = values[:, 0]
459 | else:
460 | raise ValueError(f"Unknown {reducer=}")
461 |
462 | values = values.cpu()
463 | out_dict[stat_name][param_name] = values
464 |
465 | out_path = Path(self.output_folder) / "dynamics.pkl"
466 | with open(out_path, "ab") as fp:
467 | # Multiple dumps in a single file
468 | # https://stackoverflow.com/a/12762056
469 | pickle.dump(out_dict, fp)
470 |
471 | if free_buffers:
472 | self.free_buffers("disk")
473 |
474 | def log_to_wandb(self, free_buffers=True):
475 | # Assume stats are logged as a list of tensors for each stat
476 | # Reducer can be individual samples (i.e. the first) or mean
477 |
478 | out_dict = dict()
479 | T_wandb = self.cfg["wandb_interval"]
480 | for stat_name in self.cfg["wandb_stats"]:
481 | out_dict[stat_name] = dict()
482 | reducer = self.reducers[stat_name]
483 |
484 | for param_name, values in self.stats[stat_name].items():
485 | values = torch.stack(values[-T_wandb:])
486 |
487 | if reducer == "mean":
488 | values = values.mean(dim=0)
489 | elif reducer == "global_mean":
490 | values = values.mean(dim=0).mean()
491 | elif reducer == "rms":
492 | values = (values**2).mean(dim=0).sqrt()
493 | elif reducer == "global_rms":
494 | values = (values**2).mean(dim=0).sqrt().mean()
495 | elif reducer == "first":
496 | values = values[0]
497 | else:
498 | raise ValueError(f"Unknown {reducer=}")
499 |
500 | values = values.cpu().numpy()
501 |
502 | if values.size > 1:
503 | values = wandb.Histogram(values)
504 |
505 | out_dict[f"{stat_name}/{param_name}"] = values
506 |
507 | if not self.wandb_setup_complete:
508 | # For whatever reason using globs at init doesn't work
509 | wandb.define_metric("iter")
510 | for stat in out_dict:
511 | wandb.define_metric(stat, step_metric="iter")
512 | self.wandb_setup_complete = True
513 |
514 | out_dict["iter"] = self.iteration - (T_wandb - 1) * self.interval
515 | wandb.log(
516 | data=out_dict,
517 | # step=self.iteration-(T_wandb-1)*self.interval
518 | )
519 |
520 | if free_buffers:
521 | self.free_buffers("wandb")
522 |
523 | def free_buffers(self, set_name="all"):
524 | # Delete old stat values that are no longer needed i.e. those that
525 | # have been logged by both wandb and to disk where appropriate
526 |
527 | if set_name == "all":
528 | self.stats.clear()
529 | return
530 | if set_name == "disk":
531 | main = "disk_stats"
532 | other = "wandb_stats"
533 | elif set_name == "wandb":
534 | main = "wandb_stats"
535 | other = "disk_stats"
536 | else:
537 | raise ValueError(f"Unknown {set_name=}")
538 |
539 | private_stats = set(self.cfg[main]) - set(self.cfg[other])
540 | for stat in private_stats:
541 | del self.stats[stat]
542 |
543 | T_disk = self.cfg["disk_save_interval"] or 0
544 | T_wandb = self.cfg["wandb_interval"] or 0
545 | buffer_size = max(T_disk, T_wandb)
546 | shared_stats = set(self.cfg[main]) & set(self.cfg[other])
547 | for stat_name in shared_stats:
548 | for param_name in self.stats[stat_name]:
549 | new_buffer = self.stats[stat_name][param_name][-buffer_size:]
550 | self.stats[stat_name][param_name] = new_buffer
551 |
552 | @staticmethod
553 | def load_stats(path):
554 | path = Path(path)
555 |
556 | log_fragments = []
557 | with open(path, "rb") as f:
558 | while True:
559 | try:
560 | log_fragments.append(pickle.load(f))
561 | except EOFError:
562 | break
563 |
564 | out_dict = dict()
565 | for stat_name in log_fragments[0]:
566 | stat_dict = {}
567 | for param_name in log_fragments[0][stat_name]:
568 | chunks = []
569 | for log_fragment in log_fragments:
570 | chunks.append(log_fragment[stat_name][param_name])
571 | stat_dict[param_name] = torch.concatenate(chunks)
572 | out_dict[stat_name] = stat_dict
573 | return out_dict
574 |
575 |
576 | def move_to_cpu(data, clone=False):
577 | def recurse(data):
578 | if isinstance(data, dict):
579 | return {k: recurse(v) for k, v in data.items()}
580 | if isinstance(data, list):
581 | return [recurse(v) for v in data]
582 | if isinstance(data, tuple):
583 | return tuple(recurse(v) for v in data)
584 |
585 | if isinstance(data, torch.Tensor):
586 | data = data.detach()
587 | if clone:
588 | data = data.clone()
589 | return data.to(device="cpu")
590 | else:
591 | # Others int, float, str, None etc
592 | if clone:
593 | return copy.deepcopy(data) # Copy just in case
594 | else:
595 | return data
596 |
597 | return recurse(data)
598 |
599 |
600 | # Bind this to the original object to preserve the self property
601 | # E.g. obj.method = self_preserving_overwrite(some_func).__get__(obj)
602 | def self_preserving_overwrite(method):
603 | @wraps(method)
604 | def _impl(inner_self, *args, **kwargs):
605 | return method(*args, **kwargs)
606 |
607 | return _impl
608 |
--------------------------------------------------------------------------------
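A minimal sketch of wiring DynamicsLogger up with the YAML config that follows; the toy model, optimizer and output folder are placeholders, and the cfg keys match what step(), log_to_disk() and log_to_wandb() read:

    import yaml
    import torch
    from logger.logger import DynamicsLogger

    model = torch.nn.Linear(16, 16)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

    with open("logger/rotational_logger.yaml") as f:
        cfg = yaml.safe_load(f)

    # Wraps opt.step; statistics are collected every cfg["interval"] optimizer steps
    dlogger = DynamicsLogger(model, opt, cfg, output_folder="exps/demo", wandb=False)

    for _ in range(20):
        loss = model(torch.randn(4, 16)).pow(2).mean()
        loss.backward()
        opt.step()  # logging happens transparently inside the wrapped step
        opt.zero_grad()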
/src/logger/rotational_logger.yaml:
--------------------------------------------------------------------------------
1 | interval: 5
2 | stats:
3 | [
4 | "layer_norm",
5 | "neuron_norm",
6 | "layer_grad_norm",
7 | "neuron_grad_norm",
8 | "layer_update_norm",
9 | "neuron_update_norm",
10 | "layer_angular_update",
11 | "neuron_angular_update",
12 | "layer_relative_update",
13 | "neuron_relative_update",
14 | "scalar_rms",
15 | "scalar_update_rms",
16 | "scalar_grad_rms",
17 | ]
18 |
19 | disk_save_interval: null # Save to disk every this many logging events, null to disable
20 | disk_stats: all
21 | disk_max_channels: 16 # Max channels per layer saved
22 | disk_downsample: 10 # Reduce every 10 consecutive logged values to one before saving
23 |
24 | wandb_interval: 5
25 | wandb_stats: all
26 |
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from pathlib import Path
4 | import random
5 | import os
6 | import schedulefree
7 |
8 | import numpy as np
9 | import torch
10 | import wandb
11 |
12 | import config
13 | from data.utils import DataReader, get_dataset
14 | import distributed
15 | from models.utils import get_model
16 | from optim.base import train
17 | from optim.utils import cos_inf_schedule, wsd_schedule
18 |
19 |
20 | def main(args):
21 | distributed_backend = distributed.make_backend_from_args(args)
22 | args = distributed_backend.get_adjusted_args_for_process(args)
23 | args.world_size = distributed_backend.get_world_size()
24 |
25 | if args.full_eval_at is None:
26 | args.full_eval_at = []
27 |
28 | # NOTE args.seed is offset per worker in get_adjusted_args_for_process
29 | torch.backends.cuda.matmul.allow_tf32 = True
30 | torch.backends.cudnn.allow_tf32 = True
31 | torch.manual_seed(args.seed)
32 | random.seed(args.seed)
33 | np.random.seed(args.seed)
34 | if "cuda" in args.device:
35 | torch.cuda.set_device(torch.device(args.device))
36 | # torch.use_deterministic_algorithms(True) # CUBLAS_WORKSPACE_CONFIG=:4096:8
37 |
38 | exp_name = get_exp_name(args, distributed_backend)
39 | exp_dir = Path(args.results_base_folder) / exp_name
40 | if distributed_backend.is_master_process() and args.wandb:
41 | wandb.init(
42 | project=args.wandb_project,
43 | name=exp_name,
44 | config=vars(args),
45 | )
46 | wandb.define_metric("iter")
47 | wandb.define_metric("train/*", step_metric="iter")
48 | wandb.define_metric("val/*", step_metric="iter")
49 | wandb.define_metric("lr", step_metric="iter")
50 |
51 | print(f"Starting Experiment: {exp_name}")
52 | print(f"Experiment Directory: {exp_dir}")
53 | print(f"Config:\n{vars(args)}\n")
54 |
55 | print(f"Loading dataset: '{args.dataset}'")
56 | datareaders = get_data_readers(args)
57 |
58 | model = get_model(args).to(args.device)
59 | # TODO: take care of initializing the model if args.use_pretrained != 'none'
60 | print(f"\nModel:\n{model}")
61 |
62 | model = distributed_backend.transform_model(model)
63 | group_specs = distributed_backend.get_raw_model(model).get_parameter_group_specs()
64 | param_name_mapping = {p_name: p for p_name, p in model.named_parameters()}
65 | optimized_params_cnt = 0
66 | for g in group_specs:
67 | params = []
68 | for p_name in g["params"]:
69 | translated_p_names = (
70 | distributed_backend.translate_model_parameter_name_for_node(p_name)
71 | )
72 | params += [param_name_mapping[p_name] for p_name in translated_p_names]
73 | g["params"] = params
74 | optimized_params_cnt += sum([p.numel() for p in g["params"]])
75 | params_cnt = distributed_backend.get_raw_model(model).get_num_params()
76 | print("number of parameters: %.2fM" % (params_cnt / 1e6,))
77 | print("number of optimized parameters: %.2fM" % (optimized_params_cnt / 1e6,))
78 | if args.wandb and distributed_backend.is_master_process():
79 | wandb.log(
80 | {"parameters": params_cnt, "optimized_parameters": optimized_params_cnt}
81 | )
82 |
83 | if args.opt == "adamw":
84 | opt = torch.optim.AdamW(
85 | group_specs,
86 | lr=args.lr,
87 | betas=(args.beta1, args.beta2),
88 | weight_decay=args.weight_decay,
89 | )
90 | elif args.opt == "SFAdamW":
91 | opt = schedulefree.AdamWScheduleFree(
92 | group_specs,
93 | lr=args.lr,
94 | betas=(args.beta1, args.beta2),
95 | weight_decay=args.weight_decay,
96 | warmup_steps=args.warmup_steps,
97 | )
98 |
99 | else:
100 | opt = torch.optim.SGD(
101 | group_specs, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay
102 | )
103 | print(f"\nOptimizer:\n{opt}")
104 |
105 | if args.scheduler != "none":
106 | assert args.warmup_steps < args.iterations, "Warmup steps must be < iterations."
107 | if args.scheduler in ["cos", "linear"]:
108 | # initial lr is args.lr / div_factor
109 | # final lr is initial_lr/final_div_factor = args.lr / div_factor / final_div_factor
110 | scheduler = torch.optim.lr_scheduler.OneCycleLR(
111 | optimizer=opt,
112 | max_lr=[group.get("lr", args.lr) for group in group_specs],
113 | total_steps=args.iterations,
114 | pct_start=args.warmup_steps / args.iterations,
115 | anneal_strategy=args.scheduler,
116 | cycle_momentum=False,
117 | div_factor=1e2,
118 | final_div_factor=0.1,
119 | )
120 | elif args.scheduler == "cos_inf":
121 | lambda_schedule = cos_inf_schedule(
122 | n_iterations=args.iterations,
123 | n_warmup=args.warmup_steps,
124 | n_inf=args.cos_inf_steps,
125 | div_factor=1e2,
126 | final_div_factor=0.1,
127 | )
128 | scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lambda_schedule)
129 | elif args.scheduler == "wsd":
130 | lambda_schedule = wsd_schedule(
131 | n_iterations=args.iterations,
132 | n_warmup=args.warmup_steps,
133 | fract_decay=args.wsd_fract_decay,
134 | init_div_factor=1e2,
135 | final_lr_factor=args.wsd_final_lr_scale, # should be 0 here
136 | decay_type=args.decay_type,
137 | )
138 | scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lambda_schedule)
139 | else:
140 | raise NotImplementedError(f"Unknown scheduler type: {args.scheduler}.")
141 | else:
142 | scheduler = None
143 |
144 | if (exp_dir / "ckpts" / "latest" / "main.pt").exists():
145 | if not args.auto_resume:
146 | raise ValueError(
147 | f"The experiment dir {exp_dir} already exists. "
148 | + "To resume training, set auto_resume=True. "
149 | + "Otherwise, specify a different experiment name. "
150 | )
151 | else:
152 | # Auto resume overwrites resume_from
153 | args.resume_from = str(exp_dir / "ckpts" / "latest")
154 |
155 | elif distributed_backend.is_master_process():
156 | exp_dir.mkdir(parents=True, exist_ok=True)
157 |
158 | stats = train(
159 | model=model,
160 | opt=opt,
161 | datareaders=datareaders,
162 | scheduler=scheduler,
163 | exp_dir=exp_dir,
164 | distributed_backend=distributed_backend,
165 | cfg=args,
166 | )
167 |
168 | stats["args"] = vars(args)
169 | if distributed_backend.is_master_process():
170 | with open(exp_dir / "summary.json", "w") as fs:
171 | json.dump(stats, fs)
172 | distributed_backend.finalize()
173 |
174 |
175 | def get_args():
176 | parser = argparse.ArgumentParser(allow_abbrev=False)
177 | parser.add_argument(
178 | "--config_format", default="base", choices=config.registered_formats()
179 | )
180 |
181 | args, rem_args = parser.parse_known_args()
182 |
183 | return config.parse_args_with_format(
184 | format=args.config_format, base_parser=parser, args=rem_args, namespace=args
185 | )
186 |
187 |
188 | def get_exp_name(args, distributed_backend):
189 | """Returns the name of the experiment, used for saving models and wandb."""
190 | if args.experiment_name is not None:
191 | return args.experiment_name
192 |
193 | rank = distributed_backend.rank
194 |
195 | exp_name = (
196 | f"{args.dataset}_{args.model}_nlayers{args.n_layer}"
197 | f"_nhead{args.n_head}_lr{args.lr}"
198 | f"_sched_{args.scheduler}_warmup{args.warmup_steps}"
199 | f"_decay_{args.decay_type}_{args.wsd_fract_decay}"
200 | f"_iter{args.iterations}"
201 | f"_bs{args.batch_size}x{args.acc_steps}_ws{args.world_size}"
202 | )
203 | # for mup
204 | if args.model == "mup_noam":
205 | exp_name = (
206 | f"{args.dataset}_{args.model}"
207 | f"_opt{args.opt}"
208 | f"_nlayers{args.n_layer}"
209 | # f"_nhead{args.n_head}"
210 | f"_lr{args.lr}"
211 | f"_sched_{args.scheduler}"
212 | f"_decay_{args.decay_type}"
213 | # f"_warmup{args.warmup_steps}"
214 | f"_iter{args.iterations}"
215 | f"_init{args.init_std}_sce{args.scale_emb}"
216 | f"_scd{args.scale_depth}"
217 | # f"_bs{args.batch_size}x{args.acc_steps}_ws{args.world_size}"
218 | )
219 | if args.wandb_run_prefix != "none":
220 | exp_name = args.wandb_run_prefix + "_" + exp_name
221 | exp_name += f"_seed{args.seed - rank}"
222 | exp_name += f"_data_seed{args.data_seed}"
223 |
224 | if args.weight_average:
225 |         exp_name += "_WA"
226 | if args.opt == "SFAdamW":
227 | exp_name += f"_beta1_{args.beta1}"
228 | exp_name += f"_beta2_{args.beta2}"
229 | return exp_name
230 |
231 |
232 | def get_data_readers(args, verbose=True):
233 | data_srcs = get_dataset(args)
234 | train_reader = DataReader(
235 | data_src=data_srcs["train"],
236 | batch_size=args.batch_size,
237 | sequence_length=args.sequence_length,
238 | seed=args.data_seed,
239 | with_replacement=False,
240 | auto_shard=True,
241 | keep_in_ram=args.data_in_ram,
242 | )
243 | val_reader = DataReader(
244 | data_src=data_srcs["val"],
245 | batch_size=args.batch_size,
246 | sequence_length=args.sequence_length,
247 | seed=args.data_seed,
248 | with_replacement=False,
249 | auto_shard=False, # NOTE Identical Per Rank
250 | keep_in_ram=args.data_in_ram,
251 | )
252 |
253 | if verbose:
254 | print(f"Num training tokens: {train_reader.num_tokens}")
255 | print(f"Num validation tokens: {val_reader.num_tokens}")
256 |
257 | return {
258 | "train": train_reader,
259 | "val": val_reader,
260 | }
261 |
262 |
263 | if __name__ == "__main__":
264 | args = get_args()
265 | main(args)
266 |
--------------------------------------------------------------------------------
/src/models/base.py:
--------------------------------------------------------------------------------
1 | """
2 | Full definition of a GPT Language Model, all of it in this single file.
3 | References:
4 | 1) the official GPT-2 TensorFlow implementation released by OpenAI:
5 | https://github.com/openai/gpt-2/blob/master/src/model.py
6 | 2) huggingface/transformers PyTorch implementation:
7 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
8 | """
9 |
10 | import math
11 |
12 | import tiktoken
13 | import torch
14 | import torch.nn as nn
15 | from torch.nn import functional as F
16 |
17 |
18 | class LayerNorm(nn.Module):
19 | """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""
20 |
21 | def __init__(self, ndim, bias):
22 | super().__init__()
23 | self.weight = nn.Parameter(torch.ones(ndim))
24 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
25 |
26 | def forward(self, input):
27 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
28 |
29 |
30 | class CausalSelfAttention(nn.Module):
31 | def __init__(self, config):
32 | super().__init__()
33 | assert config.n_embd % config.n_head == 0
34 | # key, query, value projections for all heads, but in a batch
35 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
36 | # output projection
37 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
38 | # regularization
39 | self.attn_dropout = nn.Dropout(config.dropout)
40 | self.resid_dropout = nn.Dropout(config.dropout)
41 | self.n_head = config.n_head
42 | self.n_embd = config.n_embd
43 | self.dropout = config.dropout
44 | # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
45 | self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
46 | if not self.flash:
47 | print(
48 | "WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0"
49 | )
50 | # causal mask to ensure that attention is only applied to the left in the input sequence
51 | self.register_buffer(
52 | "bias",
53 | torch.tril(
54 | torch.ones(config.sequence_length, config.sequence_length)
55 | ).view(1, 1, config.sequence_length, config.sequence_length),
56 | )
57 |
58 | def forward(self, x):
59 | # batch size, sequence length, embedding dimensionality (n_embd)
60 | (
61 | B,
62 | T,
63 | C,
64 | ) = x.size()
65 |
66 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
67 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
68 | # (B, T, nh, hs)
69 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
70 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
71 |
72 | # (B, nh, T, hs)
73 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
74 |
75 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
76 | if self.flash:
77 | # efficient attention using Flash Attention CUDA kernels
78 | y = torch.nn.functional.scaled_dot_product_attention(
79 |                 q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True
80 | )
81 | else:
82 | # manual implementation of attention
83 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
84 | att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
85 | att = F.softmax(att, dim=-1)
86 | att = self.attn_dropout(att)
87 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
88 | y = (
89 | y.transpose(1, 2).contiguous().view(B, T, C)
90 | ) # re-assemble all head outputs side by side
91 |
92 | # output projection
93 | y = self.resid_dropout(self.c_proj(y))
94 | return y
95 |
96 |
97 | class MLP(nn.Module):
98 | def __init__(self, config, exp_factor=1.0):
99 | super().__init__()
100 | self.dim_exp_factor = exp_factor * 4
101 |
102 | self.c_fc = nn.Linear(
103 | config.n_embd, int(self.dim_exp_factor * config.n_embd), bias=config.bias
104 | )
105 | self.c_proj = nn.Linear(
106 | int(self.dim_exp_factor * config.n_embd), config.n_embd, bias=config.bias
107 | )
108 | self.dropout = nn.Dropout(config.dropout)
109 | self.activation = nn.GELU()
110 |
111 | def forward(self, x):
112 | x = self.c_fc(x)
113 | x = self.activation(x)
114 | x = self.c_proj(x)
115 | x = self.dropout(x)
116 | return x, {}
117 |
118 |
119 | class Block(nn.Module):
120 | def __init__(self, config):
121 | super().__init__()
122 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
123 | self.attn = CausalSelfAttention(config)
124 | self.parallel = config.parallel_block
125 | if not self.parallel:
126 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
127 | self.mlp = MLP(config)
128 |
129 | def forward(self, x, *args, **kwargs):
130 | if self.parallel:
131 | # from GPT-J 6B https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L299
132 | x_ln = self.ln_1(x, *args, **kwargs)
133 | x_attn = self.attn(x_ln)
134 |             x_ffn, logits_and_experts = self.mlp(x_ln)  # mlp returns (output, aux dict)
135 |             x = x + x_attn + x_ffn
136 |         else:
137 |             x = x + self.attn(self.ln_1(x, *args, **kwargs))
138 |             x_, logits_and_experts = self.mlp(self.ln_2(x, *args, **kwargs))
139 |             x = x + x_
140 |         return x, logits_and_experts
141 |
142 |
143 | class GPTBase(nn.Module):
144 | def __init__(self, config):
145 | super().__init__()
146 | assert config.vocab_size is not None
147 | assert config.sequence_length is not None
148 | self.config = config
149 | self.tokenizer = tiktoken.get_encoding("gpt2")
150 |
151 | self.transformer = nn.ModuleDict(
152 | dict(
153 | wte=nn.Embedding(config.vocab_size, config.n_embd),
154 | wpe=nn.Embedding(config.sequence_length, config.n_embd),
155 | drop=nn.Dropout(config.dropout),
156 | h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
157 | ln_f=LayerNorm(config.n_embd, bias=config.bias),
158 | )
159 | )
160 |
161 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
162 | # with weight tying when using torch.compile() some warnings get generated:
163 | # "UserWarning: functional_call was passed multiple values for tied weights.
164 | # This behavior is deprecated and will be an error in future versions"
165 | # not 100% sure what this is, so far seems to be harmless. TODO investigate
166 | self.transformer.wte.weight = (
167 | self.lm_head.weight
168 | ) # https://paperswithcode.com/method/weight-tying
169 |
170 | # init all weights
171 | self.apply(self._init_weights)
172 | # apply special scaled init to the residual projections, per GPT-2 paper
173 | for pn, p in self.named_parameters():
174 | if pn.endswith("c_proj.weight"):
175 | torch.nn.init.normal_(
176 | p,
177 | mean=0.0,
178 | std=self.config.init_std / math.sqrt(2 * config.n_layer),
179 | )
180 |
181 | def get_num_params(self, non_embedding=True):
182 | """
183 | Return the number of parameters in the model.
184 | For non-embedding count (default), the position embeddings get subtracted.
185 | The token embeddings would too, except due to the parameter sharing these
186 | params are actually used as weights in the final layer, so we include them.
187 | """
188 | n_params = sum(p.numel() for p in self.parameters())
189 | if non_embedding:
190 | n_params -= self.transformer.wpe.weight.numel()
191 | return n_params
192 |
193 | def _init_weights(self, module):
194 | if isinstance(module, nn.Linear):
195 | torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
196 | if module.bias is not None:
197 | torch.nn.init.zeros_(module.bias)
198 | elif isinstance(module, nn.Embedding):
199 | torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
200 |
201 | def forward(self, idx, targets=None, get_logits=False):
202 | device = idx.device
203 | b, t = idx.size()
204 | assert (
205 | t <= self.config.sequence_length
206 | ), f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}"
207 | # shape (1, t)
208 | pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)
209 |
210 | # forward the GPT model itself
211 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
212 | pos_emb = self.transformer.wpe(
213 | pos
214 | ) # position embeddings of shape (1, t, n_embd)
215 | x = self.transformer.drop(tok_emb + pos_emb)
216 |
217 | # router logits is a list for each layer's routing, each of shape (b * seq_len, n_experts)
218 | router_logits = []
219 | # experts is a list for each layer's selected experts, shape (b * seq_len, topk)
220 | experts = []
221 |
222 | # forward pass through all the transformer blocks
223 | for block in self.transformer.h:
224 | x, logits_and_experts = block(x)
225 | x = self.transformer.ln_f(x)
226 |
227 | if targets is not None:
228 | # if we are given some desired targets also calculate the loss
229 | logits = self.lm_head(x)
230 | loss = F.cross_entropy(
231 | logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
232 | )
233 |
234 | else:
235 | # inference-time mini-optimization: only forward the lm_head on the very last position
236 | logits = self.lm_head(
237 | x[:, [-1], :]
238 | ) # note: using list [-1] to preserve the time dim
239 | loss = None
240 | logits = logits if get_logits else None
241 | return {
242 | "logits": logits,
243 | "loss": loss,
244 | }
245 |
246 | def crop_sequence_length(self, sequence_length):
247 | # model surgery to decrease the block size if necessary
248 | # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
249 | # but want to use a smaller block size for some smaller, simpler model
250 | assert sequence_length <= self.config.sequence_length
251 | self.config.sequence_length = sequence_length
252 | self.transformer.wpe.weight = nn.Parameter(
253 | self.transformer.wpe.weight[:sequence_length]
254 | )
255 | for block in self.transformer.h:
256 | block.attn.bias = block.attn.bias[:, :, :sequence_length, :sequence_length]
257 |
258 | def from_pretrained(
259 | self,
260 | model_path,
261 | ):
262 | paths = model_path.split(",")
263 | if len(paths) == 1:
264 | # TODO: with distributed?
265 | loaded_state = torch.load(
266 | str(model_path + "/ckpt.pt"),
267 | map_location=torch.device(self.config.device),
268 | )
269 | state_to_load = loaded_state["model"]
270 |
271 |             # strip the "_orig_mod." prefix (added by torch.compile) and load the weights
272 |             state_to_load = {
273 |                 ".".join(k.split(".")[1:]): v for k, v in state_to_load.items()
274 |             }
275 |             self.load_state_dict(state_to_load)
276 |
277 | def get_parameter_group_specs(self):
278 | """
279 | This long function is unfortunately doing something very simple and is being very defensive:
280 | We are separating out all parameters of the model into two buckets: those that will experience
281 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
282 |         We then return the parameter group specs that get passed to the PyTorch optimizer.
283 | """
284 |
285 | # separate out all parameters to those that will and won't experience regularizing weight decay
286 | decay = set()
287 | no_decay = set()
288 | whitelist_weight_modules = (torch.nn.Linear,)
289 | # need to do import here to avoid circular import (since llama imports from base here)
290 | from .utils import BLACKLIST_WEIGHT_MODULES
291 |
292 | for mn, m in self.named_modules():
293 | for pn, p in m.named_parameters():
294 | fpn = "%s.%s" % (mn, pn) if mn else pn # full param name
295 | # random note: because named_modules and named_parameters are recursive
296 | # we will see the same tensors p many many times. but doing it this way
297 | # allows us to know which parent module any tensor p belongs to...
298 | if pn.endswith("bias"):
299 | # all biases will not be decayed
300 | no_decay.add(fpn)
301 | elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
302 | # weights of whitelist modules will be weight decayed
303 | decay.add(fpn)
304 | elif pn.endswith("weight") and isinstance(m, BLACKLIST_WEIGHT_MODULES):
305 | # weights of blacklist modules will NOT be weight decayed
306 | no_decay.add(fpn)
307 |
308 | # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they
309 | # will appear in the no_decay and decay sets respectively after the above.
310 | # In addition, because named_parameters() doesn't return duplicates, it
311 |         # will only return the first occurrence, keyed by 'transformer.wte.weight', below.
312 | # so let's manually remove 'lm_head.weight' from decay set. This will include
313 | # this tensor into optimization via transformer.wte.weight only, and not decayed.
314 | decay.remove("lm_head.weight")
315 |
316 | # validate that we considered every parameter
317 | param_dict = {pn: p for pn, p in self.named_parameters()}
318 | inter_params = decay & no_decay
319 | union_params = decay | no_decay
320 | assert (
321 | len(inter_params) == 0
322 | ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
323 | assert (
324 | len(param_dict.keys() - union_params) == 0
325 | ), "parameters %s were not separated into either decay/no_decay set!" % (
326 | str(param_dict.keys() - union_params),
327 | )
328 |
329 |         # return the parameter group specs used to create the pytorch optimizer
330 | return [
331 | {"params": sorted(list(decay))},
332 | {"params": sorted(list(no_decay)), "weight_decay": 0.0},
333 | ]
334 |
335 | @torch.no_grad()
336 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
337 | """
338 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
339 | the sequence max_new_tokens times, feeding the predictions back into the model each time.
340 | Most likely you'll want to make sure to be in model.eval() mode of operation for this.
341 | """
342 | for _ in range(max_new_tokens):
343 | # if the sequence context is growing too long we must crop it at sequence_length
344 | idx_cond = (
345 | idx
346 | if idx.size(1) <= self.config.sequence_length
347 | else idx[:, -self.config.sequence_length :]
348 | )
349 | # forward the model to get the logits for the index in the sequence
350 | logits = self(idx_cond, get_logits=True)["logits"]
351 | # pluck the logits at the final step and scale by desired temperature
352 | logits = logits[:, -1, :] / temperature
353 | # optionally crop the logits to only the top k options
354 | if top_k is not None:
355 | v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
356 | logits[logits < v[:, [-1]]] = -float("Inf")
357 | # apply softmax to convert logits to (normalized) probabilities
358 | probs = F.softmax(logits, dim=-1)
359 | # sample from the distribution
360 | idx_next = torch.multinomial(probs, num_samples=1)
361 | # append sampled index to the running sequence and continue
362 | idx = torch.cat((idx, idx_next), dim=1)
363 |
364 | return idx
365 |
366 | @torch.no_grad()
367 | def generate_from_string(self, in_str, max_new_tokens, temperature=1.0, top_k=None):
368 | idx = (
369 | torch.tensor(
370 | self.tokenizer.encode(in_str, allowed_special={"<|endoftext|>"})
371 | )
372 | .view(1, -1)
373 | .to(self.lm_head.weight.device)
374 | )
375 | out_idx = (
376 | self.generate(idx, max_new_tokens, temperature, top_k)
377 | .view(-1)
378 | .to("cpu")
379 | .numpy()
380 | )
381 | return self.tokenizer.decode(out_idx)
382 |
--------------------------------------------------------------------------------
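[Added usage note, not part of the repository] A minimal sketch of how GPTBase from src/models/base.py above can be instantiated and run on a dummy batch. The config fields are the attributes the class reads; the concrete values are illustrative assumptions rather than the project's defaults, and the snippet assumes it is run from inside src/ so that `from models.base import GPTBase` resolves.

    # illustrative sketch; assumptions: run from src/, toy hyperparameters
    from types import SimpleNamespace
    import torch
    from models.base import GPTBase

    cfg = SimpleNamespace(
        vocab_size=50304, sequence_length=256,
        n_layer=4, n_head=4, n_embd=128,
        dropout=0.0, bias=False, init_std=0.02,
        parallel_block=False,  # standard sequential block, not the GPT-J parallel variant
        device="cpu",
    )
    model = GPTBase(cfg).eval()
    idx = torch.randint(0, cfg.vocab_size, (1, 8))  # (batch, time)
    out = model(idx, targets=idx, get_logits=True)
    print(out["logits"].shape, out["loss"].item())  # torch.Size([1, 8, 50304]) and a scalar loss

From here, model.generate_from_string("To be", max_new_tokens=20) samples a continuation using the gpt2 BPE tokenizer, as defined at the end of the class above.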
/src/models/llama.py:
--------------------------------------------------------------------------------
1 | """
2 | Llama-style Language Model that is torch.compile-friendly:
3 | rotary embeddings use stacked cos/sin pairs instead of torch complex tensors.
4 | """
5 |
6 | import math
7 |
8 | import tiktoken
9 | import torch
10 | import torch.nn as nn
11 | from torch.nn import functional as F
12 | from models.base import CausalSelfAttention, GPTBase
13 |
14 |
15 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
16 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
17 | t = torch.arange(end, device=freqs.device) # type: ignore
18 | freqs = torch.outer(t, freqs).float() # type: ignore
19 | cos_freqs = torch.cos(freqs)
20 | sin_freqs = torch.sin(freqs)
21 | # Stack the cos and sin parts in the last dimension to simulate complex numbers
22 | return torch.stack((cos_freqs, sin_freqs), dim=-1)
23 |
24 |
25 | def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
26 | """
27 |     freqs_cis: stacked cos/sin pairs - (seq_len, head_dim / 2, 2)
28 |     x: stacked real/imag pairs - (bsz, seq_len, n_head, head_dim / 2, 2)
29 | """
30 | ndim = x.ndim
31 | assert 1 < ndim
32 | assert freqs_cis.shape[:-1] == (x.shape[1], x.shape[-2])
33 | # New shape for broadcasting
34 | shape = [
35 | 1 if i != 1 and i != ndim - 2 else d for i, d in enumerate(x.shape[:-1])
36 | ] + [2]
37 | return freqs_cis.view(*shape)
38 |
39 |
40 | def apply_rotary_emb(q, k, freqs_cis):
41 | # q, k: (B, T, nh, hs)
42 | # freq_cis: (T, hs)
43 | # return: (B, T, nh, hs), (B, T, nh, hs)
44 | q = q.float().reshape(*q.shape[:-1], -1, 2)
45 | k = k.float().reshape(*k.shape[:-1], -1, 2)
46 |
47 | freqs_cis = _reshape_for_broadcast(freqs_cis, q)
48 |
49 | # Perform manual "complex" multiplication
50 | q_cos = q[..., 0] * freqs_cis[..., 0] - q[..., 1] * freqs_cis[..., 1]
51 | q_sin = q[..., 0] * freqs_cis[..., 1] + q[..., 1] * freqs_cis[..., 0]
52 | k_cos = k[..., 0] * freqs_cis[..., 0] - k[..., 1] * freqs_cis[..., 1]
53 | k_sin = k[..., 0] * freqs_cis[..., 1] + k[..., 1] * freqs_cis[..., 0]
54 |
55 | # Combine the results back into the interleaved format expected by q and k
56 | q_out = torch.stack((q_cos, q_sin), dim=-1).reshape(q.shape).flatten(3)
57 | k_out = torch.stack((k_cos, k_sin), dim=-1).reshape(k.shape).flatten(3)
58 |
59 | return q_out, k_out
60 |
61 |
62 | class RMSNorm(nn.Module):
63 | def __init__(self, dim: int, eps: float = 1e-6):
64 | super().__init__()
65 | self.eps = eps
66 | self.weight = nn.Parameter(torch.ones(dim))
67 |
68 | def _norm(self, x):
69 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
70 |
71 | def forward(self, x):
72 | output = self._norm(x.float()).type_as(x)
73 | return output * self.weight
74 |
75 |
76 | class LlamaMLP(nn.Module):
77 | def __init__(self, config):
78 | super().__init__()
79 |
80 | hidden_dim = config.n_embd * 4
81 | hidden_dim = int(2 * hidden_dim / 3)
82 | hidden_dim = config.multiple_of * (
83 | (hidden_dim + config.multiple_of - 1) // config.multiple_of
84 | )
85 |
86 | self.w1 = nn.Linear(config.n_embd, hidden_dim, bias=False)
87 | self.w2 = nn.Linear(config.n_embd, hidden_dim, bias=False)
88 | self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
89 |
90 | def forward(self, x):
91 | return self.c_proj(nn.functional.silu(self.w1(x)) * self.w2(x))
92 |
93 |
94 | class LlamaAttention(CausalSelfAttention):
95 |
96 | def forward(self, x, freqs_cis):
97 | # batch size, sequence length, embedding dimensionality (n_embd)
98 | (
99 | B,
100 | T,
101 | C,
102 | ) = x.size()
103 |
104 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
105 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
106 | # (B, T, nh, hs)
107 | k = k.view(B, T, self.n_head, C // self.n_head)
108 | q = q.view(B, T, self.n_head, C // self.n_head)
109 | q, k = apply_rotary_emb(q, k, freqs_cis)
110 | # (B, nh, T, hs)
111 | q, k = q.transpose(1, 2), k.transpose(1, 2)
112 |
113 | # (B, nh, T, hs)
114 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
115 |
116 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
117 | if self.flash:
118 | # efficient attention using Flash Attention CUDA kernels
119 | y = torch.nn.functional.scaled_dot_product_attention(
120 | q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True
121 | )
122 | else:
123 | # manual implementation of attention
124 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
125 | att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
126 | att = F.softmax(att, dim=-1)
127 | att = self.attn_dropout(att)
128 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
129 | y = (
130 | y.transpose(1, 2).contiguous().view(B, T, C)
131 | ) # re-assemble all head outputs side by side
132 |
133 | # output projection
134 | y = self.resid_dropout(self.c_proj(y))
135 | return y
136 |
137 |
138 | class LlamaBlock(nn.Module):
139 | def __init__(self, config):
140 | super().__init__()
141 | self.ln_1 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps)
142 | self.attn = LlamaAttention(config)
143 | self.ln_2 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps)
144 | self.mlp = LlamaMLP(config)
145 |
146 | def forward(self, x, freqs_cis):
147 | x = x + self.attn(self.ln_1(x), freqs_cis)
148 | x_ = self.mlp(self.ln_2(x))
149 | x = x + x_
150 | return x
151 |
152 |
153 | class Llama(GPTBase):
154 | def __init__(self, config):
155 | super().__init__(config)
156 | assert config.vocab_size is not None
157 | assert config.sequence_length is not None
158 | self.config = config
159 | self.tokenizer = tiktoken.get_encoding("gpt2")
160 |
161 | # create the token and position embeddings
162 | self.head_dim = config.n_embd // config.n_head
163 | self.freqs_cis = precompute_freqs_cis(self.head_dim, config.sequence_length)
164 |
165 | self.transformer = nn.ModuleDict(
166 | dict(
167 | wte=nn.Embedding(config.vocab_size, config.n_embd),
168 | drop=nn.Dropout(config.dropout),
169 | h=nn.ModuleList([LlamaBlock(config) for _ in range(config.n_layer)]),
170 | ln_f=RMSNorm(config.n_embd, eps=config.rmsnorm_eps),
171 | )
172 | )
173 |
174 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
175 | # with weight tying when using torch.compile() some warnings get generated:
176 | # "UserWarning: functional_call was passed multiple values for tied weights.
177 | # This behavior is deprecated and will be an error in future versions"
178 | # not 100% sure what this is, so far seems to be harmless. TODO investigate
179 | self.transformer.wte.weight = (
180 | self.lm_head.weight
181 | ) # https://paperswithcode.com/method/weight-tying
182 |
183 | # init all weights
184 | self.apply(self._init_weights)
185 | # apply special scaled init to the residual projections, per GPT-2 paper
186 | for pn, p in self.named_parameters():
187 | if pn.endswith("c_proj.weight"):
188 | torch.nn.init.normal_(
189 | p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
190 | )
191 |
192 | def get_num_params(self, non_embedding=True):
193 | """
194 |         Return the number of parameters in the model.
195 |         For the non-embedding count (default), nothing is subtracted: there are
196 |         no separate position embeddings, and due to the parameter sharing the token
197 |         embeddings are actually used as weights in the final layer, so we include them.
198 | """
199 | n_params = sum(p.numel() for p in self.parameters())
200 | return n_params
201 |
202 | def forward(self, idx, targets=None, get_logits=False):
203 | device = idx.device
204 | b, t = idx.size()
205 | assert (
206 | t <= self.config.sequence_length
207 | ), f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}"
208 | # shape (1, t)
209 | pos = torch.arange(0, t, dtype=torch.long, device=device)
210 |
211 | # forward the GPT model itself
212 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
213 |
214 | x = self.transformer.drop(tok_emb)
215 | freqs_cis = self.freqs_cis.to(x.device)[pos]
216 |
217 | for block in self.transformer.h:
218 | x = block(x, freqs_cis=freqs_cis)
219 | x = self.transformer.ln_f(x)
220 |
221 | if targets is not None:
222 | # if we are given some desired targets also calculate the loss
223 | logits = self.lm_head(x)
224 | loss = F.cross_entropy(
225 | logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
226 | )
227 | else:
228 | # inference-time mini-optimization: only forward the lm_head on the very last position
229 | logits = self.lm_head(
230 | x[:, [-1], :]
231 | ) # note: using list [-1] to preserve the time dim
232 | loss = None
233 |
234 | logits = logits if get_logits else None
235 |
236 | return {
237 | "logits": logits,
238 | "loss": loss,
239 | }
240 |
--------------------------------------------------------------------------------
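[Added usage note, not part of the repository] A small shape and sanity check for the rotary-embedding helpers defined in src/models/llama.py above: precompute_freqs_cis returns stacked cos/sin pairs of shape (seq_len, head_dim/2, 2), and apply_rotary_emb rotates each (even, odd) channel pair of q and k in place of a complex multiplication, leaving shapes and per-vector norms unchanged. The sizes below are arbitrary, and running from inside src/ is assumed so the import resolves.

    # illustrative sketch; assumptions: run from src/, arbitrary sizes (head_dim must be even)
    import torch
    from models.llama import precompute_freqs_cis, apply_rotary_emb

    B, T, n_head, head_dim = 2, 16, 4, 32
    freqs_cis = precompute_freqs_cis(head_dim, T)     # (T, head_dim // 2, 2)
    q = torch.randn(B, T, n_head, head_dim)
    k = torch.randn(B, T, n_head, head_dim)

    q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis)  # same shapes as q and k
    print(freqs_cis.shape, q_rot.shape, k_rot.shape)

    # rotations preserve the norm of each (cos, sin) pair, hence of each head vector
    print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))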
/src/models/utils.py:
--------------------------------------------------------------------------------
1 | from .llama import Llama, RMSNorm
2 | from .base import GPTBase, LayerNorm
3 | import torch
4 |
5 | BLACKLIST_WEIGHT_MODULES = (
6 | torch.nn.LayerNorm,
7 | LayerNorm,
8 | RMSNorm,
9 | torch.nn.Embedding,
10 | )
11 |
12 |
13 | def get_model(args):
14 | """Return the right model"""
15 | if args.model == "base":
16 | model = GPTBase(args)
17 | if args.use_pretrained != "none":
18 | model.from_pretrained(args.use_pretrained)
19 | return model
20 | elif args.model == "llama":
21 | model = Llama(args)
22 | if args.use_pretrained != "none":
23 | raise NotImplementedError(
24 | f"Loading of pretrained models not yet implemented for model '{args.model}'."
25 | )
26 | return model
27 | else:
28 | raise KeyError(f"Unknown model '{args.model}'.")
29 |
--------------------------------------------------------------------------------
/src/optim/base.py:
--------------------------------------------------------------------------------
1 | from contextlib import nullcontext
2 | import copy
3 | from pathlib import Path
4 | import time
5 | import yaml
6 |
7 | import torch
8 | import wandb
9 |
10 | from logger.logger import DynamicsLogger
11 | from optim.weight_averaging import (
12 | WeightAverager,
13 | eval_ema,
14 | eval_wa,
15 | ExponentialWeightAverager,
16 | )
17 | from .utils import (
18 | eval,
19 | get_batch,
20 | load_checkpoint,
21 | load_worker_state,
22 | save_checkpoint,
23 | save_worker_state,
24 | )
25 |
26 |
27 | def train(
28 | model,
29 | opt,
30 | datareaders,
31 | scheduler,
32 | exp_dir,
33 | distributed_backend,
34 | cfg,
35 | ):
36 | not_compiled_model = model
37 | if cfg.compile:
38 | print(f"Compiling model ...")
39 | model = torch.compile(model)
40 |
41 | if "cuda" in cfg.device:
42 | type_ctx = torch.amp.autocast(
43 | device_type="cuda",
44 | dtype={
45 | "float32": torch.float32,
46 | "float16": torch.float16,
47 | "bfloat16": torch.bfloat16,
48 | }[cfg.dtype],
49 | )
50 | else:
51 | type_ctx = nullcontext()
52 |
53 | if cfg.resume_from:
54 |         # This is a full resume including the model weights, optimizer state,
55 |         # dataloader state, random seed, etc. Not intended for fine-tuning or
56 | # other scenarios where some of these should change.
57 | print(f"\nResuming Training From {cfg.resume_from}")
58 | ckpt_dir = Path(cfg.resume_from)
59 | curr_iter = load_checkpoint(
60 | model,
61 | opt,
62 | scheduler,
63 | ckpt_dir / "main.pt",
64 | cfg.device,
65 | )
66 | load_worker_state(ckpt_dir)
67 | else:
68 | curr_iter = 0
69 |
70 | if cfg.weight_average:
71 |         # This generally does not support resuming training, but it will work if
72 |         # cfg.wa_interval evenly divides the iteration number of the checkpoint.
73 |         # Otherwise, the first average will not be computed correctly, with a bias
74 |         # towards the first sample and missing values for earlier iterations.
75 | weight_averager = WeightAverager(
76 | not_compiled_model,
77 | horizon=cfg.wa_horizon,
78 | interval=cfg.wa_interval,
79 | save_dir=None if cfg.wa_use_temp_dir else exp_dir / "avgs",
80 | dtype={
81 | "float32": torch.float32,
82 | "float64": torch.float64,
83 | }[cfg.wa_dtype],
84 | count=curr_iter,
85 | )
86 |
87 | if cfg.exponential_moving_average:
88 | ema = ExponentialWeightAverager(
89 | not_compiled_model,
90 | interval=cfg.ema_interval,
91 | decay=cfg.ema_decay,
92 | warmup=cfg.warmup_steps if cfg.ema_after_warmup else 0,
93 | dtype={
94 | "float32": torch.float32,
95 | "float64": torch.float64,
96 | }[cfg.wa_dtype],
97 | )
98 |
99 | if distributed_backend.is_master_process() and cfg.log_dynamics:
100 | with open(cfg.dynamics_logger_cfg, "r") as f:
101 | dlcfg = yaml.safe_load(f)
102 |
103 | # Hooks into optimizer
104 | dlogger = DynamicsLogger(
105 | model, opt, dlcfg, cfg.results_base_folder, wandb=cfg.wandb
106 | )
107 | dlogger.iteration = curr_iter
108 |
109 | substep = curr_iter * cfg.acc_steps
110 | train_reader, val_reader = datareaders["train"], datareaders["val"]
111 | train_reader.set_step(substep)
112 | stats = {"train_loss": [], "val_loss": [], "val_pp": [], "val_acc": []}
113 | model.train()
114 |
115 | while curr_iter <= cfg.iterations:
116 | # Save permanent checkpoint
117 | if cfg.permanent_ckpt_interval > 0:
118 | if curr_iter % cfg.permanent_ckpt_interval == 0:
119 | ckpt_dir = exp_dir / "ckpts" / str(curr_iter)
120 | if distributed_backend.is_master_process():
121 | save_checkpoint(model, opt, scheduler, curr_iter, ckpt_dir)
122 | save_worker_state(ckpt_dir)
123 |
124 | # Save temporary checkpoint for resuming training
125 | if cfg.latest_ckpt_interval > 0:
126 | if curr_iter % cfg.latest_ckpt_interval == 0 or curr_iter == cfg.iterations:
127 | ckpt_dir = exp_dir / "ckpts" / "latest"
128 | if distributed_backend.is_master_process():
129 | save_checkpoint(model, opt, scheduler, curr_iter, ckpt_dir)
130 | save_worker_state(ckpt_dir)
131 |
132 | ws = distributed_backend.get_world_size()
133 | tokens = ws * substep * cfg.sequence_length * cfg.batch_size
134 | epoch = tokens / train_reader.num_tokens
135 | if (
136 | curr_iter % cfg.eval_interval == 0
137 | or curr_iter == cfg.iterations
138 | or (curr_iter in cfg.full_eval_at)
139 | ):
140 | eval_and_log(
141 | curr_iter,
142 | epoch,
143 | model,
144 | val_reader,
145 | type_ctx,
146 | distributed_backend,
147 | cfg,
148 | opt,
149 | full_eval=(curr_iter in cfg.full_eval_at),
150 | )
151 |
152 | if curr_iter > cfg.wa_interval and cfg.weight_average:
153 | eval_wa(
154 | curr_iter,
155 | not_compiled_model,
156 | weight_averager,
157 | val_reader,
158 | type_ctx,
159 | distributed_backend,
160 | cfg,
161 | full_eval=(curr_iter in cfg.full_eval_at),
162 | )
163 | if cfg.exponential_moving_average:
164 | eval_ema(
165 | curr_iter,
166 | not_compiled_model,
167 | ema,
168 | val_reader,
169 | type_ctx,
170 | distributed_backend,
171 | cfg,
172 | full_eval=(curr_iter in cfg.full_eval_at),
173 | )
174 |
175 | if curr_iter == cfg.iterations:
176 | # Save checkpoints and evaluate at final iteration, but no need to train further
177 | break
178 |
179 | # Train model
180 | t_start = time.perf_counter_ns()
181 | for microstep_idx in range(cfg.acc_steps): # gradient accumulation
182 | x, y = get_batch(train_reader, device=cfg.device)
183 | with type_ctx:
184 | with distributed_backend.get_context_for_microstep_forward(
185 | model=model,
186 | microstep_idx=microstep_idx,
187 | gradient_accumulation_steps=cfg.acc_steps,
188 | ):
189 | outputs = model(x, targets=y)
190 |
191 | loss = outputs["loss"] / cfg.acc_steps
192 | loss.backward()
193 | substep += 1
194 |
195 | if cfg.grad_clip != 0.0:
196 | torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
197 | if cfg.opt == "SFAdamW":
198 | opt.train()
199 | opt.step()
200 | scheduler.step()
201 | opt.zero_grad(set_to_none=True)
202 | if cfg.weight_average:
203 | weight_averager.step(not_compiled_model, distributed_backend.is_master_process())
204 | if cfg.exponential_moving_average:
205 | ema.step(not_compiled_model, distributed_backend.is_master_process())
206 | dt = (time.perf_counter_ns() - t_start) / 1e9
207 |
208 | curr_iter += 1
209 |
210 | if (
211 | cfg.log_interval
212 | and curr_iter % cfg.log_interval == 0
213 | and distributed_backend.is_master_process() # Only log on master rank
214 | ):
215 | train_loss = loss.detach().cpu().item() * cfg.acc_steps
216 |
217 | current_lrs = [param_group["lr"] for param_group in opt.param_groups]
218 |
219 | print(
220 | f"Train: Iter={curr_iter} ({epoch:0.3f} epochs) "
221 | f"train_loss={train_loss:.3f} iter_dt={dt:.2e}s "
222 | f"lr={current_lrs[0]:.2e}"
223 | )
224 |
225 | if cfg.wandb:
226 | wandb.log(
227 | {
228 | "iter": curr_iter,
229 | "train/loss": train_loss,
230 | "train/perplexity": 2.71828**train_loss,
231 | "lr": current_lrs[0],
232 | "iter_dt": dt,
233 | }
234 | )
235 |
236 | return stats
237 |
238 |
239 | def eval_and_log(
240 | curr_iter,
241 | epoch,
242 | model,
243 | val_reader,
244 | type_ctx,
245 | distributed_backend,
246 | cfg,
247 | opt,
248 | full_eval=False,
249 | ):
250 | if not distributed_backend.is_master_process():
251 | # Only evaluate and log on master rank
252 | return
253 |
254 | model.eval()
255 | if cfg.opt == "SFAdamW":
256 | opt.eval()
257 |
258 | if curr_iter == cfg.iterations or full_eval:
259 | max_num_batches = val_reader.num_batches()
260 | else:
261 | max_num_batches = cfg.eval_batches
262 |
263 | # to make sure we start from the beginning of the validation set,
264 | # i.e. repeat the same batches
265 | val_reader.set_step(0)
266 | val_acc, val_loss, val_perplexity = eval(
267 | model,
268 | val_reader,
269 | cfg.device,
270 | max_num_batches=max_num_batches,
271 | ctx=type_ctx,
272 | cfg=cfg,
273 | )
274 |
275 | print(
276 | f">Eval: Iter={curr_iter} ({epoch:0.3f} epochs) "
277 | f"val_loss={val_loss:.3f} "
278 | f"val_pp={val_perplexity:.3f} "
279 |         f"val_acc={val_acc:.3f}"
280 | )
281 |
282 | if cfg.wandb:
283 | if curr_iter == cfg.iterations or full_eval:
284 | logs = {
285 | "iter": curr_iter,
286 | "final-val/loss": val_loss,
287 | "final-val/perplexity": val_perplexity,
288 | "final-val/acc": val_acc,
289 | }
290 | else:
291 | logs = {
292 | "iter": curr_iter,
293 | "val/loss": val_loss,
294 | "val/perplexity": val_perplexity,
295 | "val/acc": val_acc,
296 | }
297 |
298 | wandb.log(logs)
299 | if cfg.eval_seq_prefix != "none" and (
300 | curr_iter % (cfg.eval_interval * 5) == 0 or curr_iter == cfg.iterations
301 | ):
302 | text_table = wandb.Table(columns=["itr", "val-pp", "text"])
303 |
304 | out_str = distributed_backend.get_raw_model(model).generate_from_string(
305 | cfg.eval_seq_prefix,
306 | max_new_tokens=40,
307 | temperature=0.9,
308 | top_k=None,
309 | )
310 | text_table.add_data(curr_iter, val_perplexity, out_str)
311 | # why a copy? see github.com/wandb/wandb/issues/2981
312 | wandb.log({f"generated-text-{wandb.run.name}": copy.copy(text_table)})
313 | model.train()
314 |
--------------------------------------------------------------------------------
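[Added usage note, not part of the repository] The train() loop above divides each micro-batch loss by cfg.acc_steps and only calls opt.step() after cfg.acc_steps backward passes, so the accumulated gradient matches a single batch of acc_steps * batch_size sequences. A standalone sketch of that pattern, with a toy model and toy data standing in for the transformer and the data readers:

    # illustrative sketch of the gradient-accumulation pattern used in train()
    import torch

    model = torch.nn.Linear(10, 1)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    acc_steps, batch_size = 4, 8

    for microstep_idx in range(acc_steps):
        x, y = torch.randn(batch_size, 10), torch.randn(batch_size, 1)
        loss = torch.nn.functional.mse_loss(model(x), y) / acc_steps
        loss.backward()  # gradients accumulate across micro-steps

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # analogue of cfg.grad_clip
    opt.step()
    opt.zero_grad(set_to_none=True)  # one optimizer step per acc_steps micro-batches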
/src/optim/utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import random
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | from contextlib import nullcontext
7 | import torch.distributed as dist
8 | import math
9 | import wandb
10 |
11 |
12 | def get_batch(datareader, device="cpu"):
13 | x, y = datareader.sample_batch()
14 | if "cuda" in torch.device(device).type:
15 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
16 | x = x.pin_memory().to(device, non_blocking=True)
17 | y = y.pin_memory().to(device, non_blocking=True)
18 | else:
19 | x = x.to(device)
20 | y = y.to(device)
21 | return x, y
22 |
23 |
24 | def cos_inf_schedule(n_iterations, n_warmup, div_factor, final_div_factor, n_inf):
25 | """Cosine annealing with warmup and _constant_ final_lr after cycle ended.
26 | Args:
27 | n_iterations: total number of iterations
28 | n_warmup: number of warmup iterations
29 | div_factor: initial division factor for warmup
30 | final_div_factor: final division factor for final lr
31 | n_inf: number of iterations for the final lr (constant lr after cycle ended)
32 | Returns:
33 | schedule: a function that takes the current iteration and
34 | returns the multiplicative factor for the learning rate
35 | """
36 | max_lr = 1.0
37 | base_lr = max_lr / div_factor
38 | final_lr = base_lr / final_div_factor
39 |
40 | n_anneal_steps = n_iterations - n_inf
41 |
42 | def schedule(step):
43 | if step < n_warmup:
44 | return (step / n_warmup) + (1 - step / n_warmup) / div_factor
45 | elif step < n_anneal_steps:
46 | t = (step - n_warmup) / (n_anneal_steps - n_warmup)
47 | lr = final_lr + 0.5 * (max_lr - final_lr) * (1 + np.cos(np.pi * t))
48 | return lr
49 | else:
50 | return final_lr
51 |
52 | return schedule
53 |
54 |
55 | def wsd_schedule(
56 | n_iterations,
57 | final_lr_factor=0.0,
58 | n_warmup=1000,
59 | init_div_factor=100,
60 | fract_decay=0.1,
61 | decay_type="linear",
62 | ):
63 | """Warmup, hold, and decay schedule.
64 | Args:
65 | n_iterations: total number of iterations
66 |         final_lr_factor: multiplicative factor applied to max_lr at the end of decay
67 |         n_warmup: number of warmup iterations
68 |         init_div_factor: initial division factor for warmup
69 |         fract_decay: fraction of iterations used for decay (decay_type sets its shape)
70 | Returns:
71 | schedule: a function that takes the current iteration and
72 | returns the multiplicative factor for the learning rate
73 | """
74 | n_anneal_steps = int(fract_decay * n_iterations)
75 | n_hold = n_iterations - n_anneal_steps
76 |
77 | def schedule(step):
78 | if step < n_warmup:
79 | return (step / n_warmup) + (1 - step / n_warmup) / init_div_factor
80 | elif step < n_hold:
81 | return 1.0
82 | elif step < n_iterations:
83 | if decay_type == "linear":
84 | return final_lr_factor + (1 - final_lr_factor) * (
85 | 1 - (step - n_hold) / n_anneal_steps
86 | )
87 | elif decay_type == "exp":
88 | return final_lr_factor ** ((step - n_hold) / n_anneal_steps)
89 | elif decay_type == "cosine":
90 | return (
91 | final_lr_factor
92 | + (1 - final_lr_factor)
93 | * (1 + math.cos(math.pi * (step - n_hold) / n_anneal_steps))
94 | * 0.5
95 | )
96 | elif decay_type == "miror_cosine":
97 | cosine_value = (
98 | final_lr_factor
99 | + (1 - final_lr_factor)
100 | * (1 + math.cos(math.pi * (step - n_hold) / n_anneal_steps))
101 | * 0.5
102 | )
103 | linear_value = final_lr_factor + (1 - final_lr_factor) * (
104 | 1 - (step - n_hold) / n_anneal_steps
105 | )
106 | return linear_value * 2 - cosine_value
107 | elif decay_type == "square":
108 | return final_lr_factor + (1 - final_lr_factor) * (
109 | 1 - ((step - n_hold) / n_anneal_steps) ** 2
110 | )
111 |
112 | elif decay_type == "sqrt":
113 | return final_lr_factor + (1 - final_lr_factor) * (
114 | 1 - math.sqrt((step - n_hold) / n_anneal_steps)
115 | )
116 |
117 | else:
118 | raise ValueError(
119 |                         f"decay type {decay_type} is not in ['linear','exp','cosine','miror_cosine','square','sqrt']"
120 | )
121 |
122 | else:
123 | return final_lr_factor
124 |
125 | return schedule
126 |
127 |
128 | @torch.no_grad()
129 | def eval(
130 | model,
131 | reader,
132 | device="cpu",
133 | max_num_batches=24,
134 | ctx=nullcontext(),
135 | cfg=None,
136 | ):
137 | assert model.training == False
138 |
139 | loss_list_val, acc_list = [], []
140 |
141 | for idx in range(max_num_batches):
142 | x, y = get_batch(reader, device=device)
143 | with ctx:
144 | outputs = model(x, targets=y, get_logits=True)
145 | val_loss = outputs["loss"]
146 |
147 | loss_list_val.append(val_loss)
148 | acc_list.append((outputs["logits"].argmax(-1) == y).float().mean())
149 |
150 | val_acc = torch.stack(acc_list).mean().item()
151 | val_loss = torch.stack(loss_list_val).mean().item()
152 | val_perplexity = 2.71828**val_loss
153 |
154 | return val_acc, val_loss, val_perplexity
155 |
156 |
157 | @torch.no_grad()
158 | def eval_sweep_dropk(
159 | model,
160 | data_tensor,
161 | sequence_length,
162 | batch_size,
163 | n_heads,
164 | device="cpu",
165 | max_num_batches=24,
166 | ctx=nullcontext(),
167 | ):
168 | assert model.training == False
169 |
170 | x_axis, y_axis_pp, y_axis_acc, y_axis_loss = (
171 | torch.linspace(0.0, 0.95, 15),
172 | [],
173 | [],
174 | [],
175 | )
176 | loss_list_val, acc_list = [], []
177 |
178 | for frac in x_axis:
179 | drop_k = int(sequence_length * frac * n_heads)
180 | for _ in range(max_num_batches):
181 | x, y = get_batch(data_tensor, sequence_length, batch_size, device=device)
182 | with ctx:
183 | outputs = model(
184 | x, targets=y, alpha_th=None, drop_k=drop_k, get_logits=True
185 | )
186 | loss_list_val.append(outputs["ce_loss"])
187 | acc_list.append((outputs["logits"].argmax(-1) == y).float().mean())
188 |
189 | y_axis_acc.append(torch.stack(acc_list).mean().item())
190 | y_axis_loss.append(np.mean(loss_list_val))
191 | y_axis_pp.append(2.71828 ** y_axis_loss[-1])
192 |
193 | return x_axis, y_axis_acc, y_axis_pp, y_axis_loss
194 |
195 |
196 | @torch.no_grad()
197 | def eval_sweep_alphath(
198 | model,
199 | data_tensor,
200 | sequence_length,
201 | batch_size,
202 | device="cpu",
203 | max_num_batches=24,
204 | ctx=nullcontext(),
205 | ):
206 | assert model.training == False
207 |
208 | alpha_ths, y_axis_pp, y_axis_acc, y_axis_loss = (
209 | [0, 1e-4, 1e-3, 1e-2, 1e-1, 2e-1, 3e-1, 4e-1, 5e-1],
210 | [],
211 | [],
212 | [],
213 | )
214 | loss_list_val, acc_list, x_axis = [], [], []
215 |
216 | for alpha_th in alpha_ths:
217 | frac_heads_pruned_list = []
218 | for _ in range(max_num_batches):
219 | x, y = get_batch(data_tensor, sequence_length, batch_size, device=device)
220 | with ctx:
221 | outputs = model(
222 | x, targets=y, alpha_th=alpha_th, drop_k=None, get_logits=True
223 | )
224 | nph, nh = (
225 | outputs["num_head_pruned_per_layer"],
226 | outputs["num_heads_per_layer"],
227 | )
228 | frac_heads_pruned = np.sum(nph) / np.sum(
229 | nh
230 | ) # fractions of heads removed given alpha_th
231 | frac_heads_pruned_list.append(frac_heads_pruned)
232 | loss_list_val.append(outputs["ce_loss"])
233 | acc_list.append((outputs["logits"].argmax(-1) == y).float().mean())
234 |
235 | x_axis.append(np.mean(frac_heads_pruned_list))
236 | y_axis_acc.append(torch.stack(acc_list).mean().item())
237 | y_axis_loss.append(np.mean(loss_list_val))
238 | y_axis_pp.append(2.71828 ** y_axis_loss[-1])
239 |
240 | return x_axis, y_axis_acc, y_axis_pp, y_axis_loss
241 |
242 |
243 | def save_checkpoint(model, opt, scheduler, itr, ckpt_dir: Path):
244 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
245 | model = model.module
246 |
247 | checkpoint = {
248 | "model": model.state_dict(),
249 | "optimizer": opt.state_dict(),
250 | "scheduler": scheduler.state_dict(),
251 | "itr": itr,
252 | }
253 | ckpt_dir.mkdir(exist_ok=True, parents=True)
254 | torch.save(checkpoint, ckpt_dir / "main.pt")
255 |
256 |
257 | def load_checkpoint(model, opt, scheduler, ckpt_path, device):
258 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
259 | model = model.module
260 |
261 | ckpt = torch.load(ckpt_path, map_location=device)
262 | model.load_state_dict(ckpt["model"])
263 | opt.load_state_dict(ckpt["optimizer"])
264 | scheduler.load_state_dict(ckpt["scheduler"])
265 | itr = ckpt["itr"]
266 | return itr
267 |
268 |
269 | def save_worker_state(ckpt_dir: Path):
270 | # Dataloader, rng states
271 | worker_state = {
272 | "rng_torch_cpu": torch.random.get_rng_state(),
273 | "rng_torch_gpu": torch.cuda.get_rng_state(),
274 | "rng_np": np.random.get_state(),
275 | "rng_python": random.getstate(),
276 | }
277 | rank = 0 if not dist.is_initialized() else dist.get_rank()
278 | ckpt_dir.mkdir(exist_ok=True, parents=True)
279 | torch.save(worker_state, ckpt_dir / f"worker_{rank}.pt")
280 |
281 |
282 | def load_worker_state(ckpt_dir: Path):
283 | rank = 0 if not dist.is_initialized() else dist.get_rank()
284 | worker_state = torch.load(ckpt_dir / f"worker_{rank}.pt")
285 | torch.random.set_rng_state(worker_state["rng_torch_cpu"])
286 | torch.cuda.set_rng_state(worker_state["rng_torch_gpu"])
287 | np.random.set_state(worker_state["rng_np"])
288 | random.setstate(worker_state["rng_python"])
289 |
--------------------------------------------------------------------------------
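[Added usage note, not part of the repository] cos_inf_schedule and wsd_schedule above return a function mapping the step index to a multiplicative learning-rate factor, which is the interface torch.optim.lr_scheduler.LambdaLR expects; train() in src/optim/base.py then calls scheduler.step() once per iteration. A usage sketch with an illustrative parameter, optimizer, and iteration budget (run from src/ so the import resolves):

    # illustrative sketch: plugging wsd_schedule into LambdaLR
    import torch
    from optim.utils import wsd_schedule

    params = [torch.nn.Parameter(torch.zeros(1))]
    opt = torch.optim.AdamW(params, lr=1e-3)

    schedule = wsd_schedule(
        n_iterations=1000, n_warmup=100, fract_decay=0.1,
        init_div_factor=100, final_lr_factor=0.0, decay_type="linear",
    )
    scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=schedule)

    for step in range(1000):
        opt.step()        # training step elided
        scheduler.step()  # warmup for 100 steps, hold at 1e-3, linear decay over the last 100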
/src/optim/weight_averaging.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from pathlib import Path
3 | import tempfile
4 |
5 | import torch
6 | import wandb
7 |
8 | from .utils import eval
9 |
10 |
11 | class WeightAverager:
12 | def __init__(
13 | self,
14 | model,
15 | horizon=100,
16 | interval=1,
17 | save_dir=None,
18 | device=None,
19 | dtype=torch.float32,
20 | count=0,
21 | ):
22 | super().__init__()
23 | self.device = device # Where to keep avg model
24 | self.dtype = dtype # Precision for accumulation (>= float32)
25 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
26 | model = model.module
27 | self.module = deepcopy(model).to(dtype=self.dtype, device=device)
28 |
29 |         assert horizon % interval == 0, "Interval should divide horizon"
30 | self.interval = interval
31 | self.horizon = horizon
32 | self.period = horizon // interval
33 | if save_dir is None:
34 | # Keep in tempdir
35 | self._tempdir = tempfile.TemporaryDirectory()
36 | self.save_dir = Path(self._tempdir.name)
37 | else:
38 | self.save_dir = Path(save_dir)
39 | self.save_dir.mkdir(parents=True, exist_ok=True)
40 | self.count = count
41 | # check if there are any checkpoints saved in the directory and set
42 | # num_saved to number of checkpoints with name <= count
43 | self.num_saved = len(
44 | [f for f in self.save_dir.iterdir() if f.is_file() and int(f.stem) <= count]
45 | )
46 |
47 | @torch.no_grad()
48 | def step(self, model, is_master_rank=True):
49 | # Update module with current state
50 | if self.count % self.interval == 0:
51 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
52 | model = model.module
53 | for key, avg in self.module.state_dict().items():
54 | curr = model.state_dict()[key].to(device=self.device, dtype=avg.dtype)
55 | rate = 1 / ((self.count % self.horizon) // self.interval + 1)
56 | avg.copy_(torch.lerp(avg, curr, rate))
57 |
58 | self.count += 1
59 |
60 | if self.count % self.horizon == 0 and is_master_rank:
61 | torch.save(
62 | self.module.to().state_dict(),
63 | self.save_dir / f"{self.count}.pt",
64 | )
65 | self.num_saved += 1
66 |
67 | def get_latest_like(self, model):
68 | # Return model for latest completed period
69 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
70 | model = model.module
71 | new_model = deepcopy(model)
72 |
73 | # Assumes that we saved at a specific iteration, will fail otherwise
74 | count = self.count - self.count % self.horizon
75 | latest_path = self.save_dir / f"{count}.pt"
76 | map_and_load_state_dict(new_model, torch.load(latest_path))
77 |
78 | return new_model
79 |
80 | def sweep_horizon_like(self, model, max_num=None):
81 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
82 | model = model.module
83 | new_model = deepcopy(model)
84 | avg_state = deepcopy(self.module.state_dict())
85 | if max_num is None:
86 | max_num = self.num_saved
87 | # Assumes all points exist
88 | for n in range(min(self.num_saved, max_num)):
89 | # Load state from the corresponding checkpoint
90 | count = self.count - self.count % self.horizon - n * self.horizon
91 | state = torch.load(self.save_dir / f"{count}.pt")
92 |
93 | # Update average state
94 | for key, avg in avg_state.items():
95 | new = state[key].to(dtype=avg.dtype, device=avg.device)
96 | rate = 1 / (n + 1)
97 | avg.copy_(torch.lerp(avg, new, rate))
98 |
99 | # Set new_model state and yield it
100 | map_and_load_state_dict(new_model, avg_state)
101 | yield ((n + 1) * self.horizon, new_model)
102 |
103 |
104 | def map_and_load_state_dict(model, state_dict):
105 | for key, m_val in model.state_dict().items():
106 | for alias in (f'_orig_mod.{key}', f'_orig_mod.module.{key}'): # handle compiled / nested model
107 | if key not in state_dict and alias in state_dict:
108 | key = alias
109 | break
110 | s_val = state_dict[key]
111 | m_val.copy_(s_val.to(device=m_val.device, dtype=m_val.dtype))
112 |
113 |
114 | def eval_wa(
115 | curr_iter,
116 | model,
117 | weight_averager,
118 | val_reader,
119 | type_ctx,
120 | distributed_backend,
121 | cfg,
122 | full_eval=False,
123 | ):
124 | if not distributed_backend.is_master_process():
125 | # Only evaluate and log on master rank
126 | return
127 |
128 | if weight_averager.num_saved == 0:
129 | return
130 | if not cfg.wa_sweep_horizon:
131 | val_reader.set_step(0)
132 | val_acc, val_loss, val_perplexity = eval(
133 | weight_averager.get_latest_like(model).eval(),
134 | val_reader,
135 | cfg.device,
136 | max_num_batches=(
137 | val_reader.num_batches()
138 | if curr_iter == cfg.iterations or full_eval
139 | else cfg.eval_batches
140 | ),
141 | ctx=type_ctx,
142 | cfg=cfg,
143 | )
144 |
145 | if cfg.wandb:
146 | if curr_iter == cfg.iterations or full_eval:
147 | logs = {
148 | "iter": curr_iter,
149 | "final-val/loss_wa": val_loss,
150 | "final-val/perplexity_wa": val_perplexity,
151 | "final-val/acc_wa": val_acc,
152 | }
153 | else:
154 | logs = {
155 | "iter": curr_iter,
156 | "val/loss_wa": val_loss,
157 | "val/perplexity_wa": val_perplexity,
158 | "val/acc_wa": val_acc,
159 | }
160 | wandb.log(logs)
161 | print(
162 | f">WA Eval: Iter={curr_iter} "
163 | f"val_loss={val_loss:.3f} "
164 | f"val_pp={val_perplexity:.3f} "
165 |                 f"val_acc={val_acc:.3f}"
166 | )
167 | else:
168 | losses = []
169 | for horizon, avg_model in weight_averager.sweep_horizon_like(
170 | model, cfg.max_num_wa_sweeps
171 | ):
172 | avg_model.eval()
173 | val_reader.set_step(0)
174 | _, val_loss, _ = eval(
175 | avg_model,
176 | val_reader,
177 | cfg.device,
178 | max_num_batches=(
179 | val_reader.num_batches()
180 | if curr_iter == cfg.iterations or full_eval
181 | else cfg.eval_batches
182 | ),
183 | ctx=type_ctx,
184 | cfg=cfg,
185 | )
186 |
187 | losses.append((val_loss, horizon))
188 | if len(losses) == 0: # in case of none saved yet
189 | return
190 | best_loss, best_horizon = sorted(losses)[0]
191 |
192 | print(f"WA Eval: {[(h, f'{l:0.3e}') for (l,h) in losses]}")
193 |
194 | if cfg.wandb:
195 | if curr_iter == cfg.iterations or full_eval:
196 | logs = {
197 | "iter": curr_iter,
198 | "final-val/loss_wa": losses[0][0],
199 | "final-val/perplexity_wa": 2.71828 ** losses[0][0],
200 | "final-val/best_loss_wa": best_loss,
201 | "final-val/best_perplexity_wa": 2.71828**best_loss,
202 | }
203 | else:
204 | logs = {
205 | "iter": curr_iter,
206 | "val/loss_wa": losses[0][0],
207 | "val/perplexity_wa": 2.71828 ** losses[0][0],
208 | "val/best_loss_wa": best_loss,
209 | "val/best_perplexity_wa": 2.71828**best_loss,
210 | "wa_best_horizon": best_horizon,
211 | }
212 | wandb.log(logs)
213 |
214 |
215 | class ExponentialWeightAverager:
216 | def __init__(
217 | self,
218 | model,
219 | interval=1,
220 | decay=0.95,
221 | device=None,
222 | warmup=0,
223 | dtype=torch.float32,
224 | ):
225 | super().__init__()
226 | self.device = device # Where to keep avg model
227 | self.dtype = dtype # Precision for accumulation (>= float32)
228 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
229 | model = model.module
230 | self.module = deepcopy(model).to(dtype=self.dtype, device=device)
231 |
232 | self.interval = interval
233 | self.decay = decay
234 | self.num_saved = 0
235 | self.warmup = warmup
236 | self.count = 0
237 |
238 | @torch.no_grad()
239 | def step(self, model, is_master_rank=True):
240 | # Update module with current state
241 |
242 | if self.count < self.warmup:
243 | self.count += 1
244 | return
245 |
246 | if self.count == self.warmup:
247 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
248 | model = model.module
249 | for key, avg in self.module.state_dict().items():
250 | curr = model.state_dict()[key].to(device=self.device, dtype=avg.dtype)
251 | avg.copy_(curr)
252 |
253 | elif self.count % self.interval == 0:
254 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
255 | model = model.module
256 | for key, avg in self.module.state_dict().items():
257 | curr = model.state_dict()[key].to(device=self.device, dtype=avg.dtype)
258 | avg.copy_(torch.lerp(avg, curr, 1 - self.decay))
259 | self.num_saved += 1
260 |
261 | self.count += 1
262 |
263 | # if self.count % self.horizon == 0 and is_master_rank:
264 | # torch.save(
265 | # self.module.to(dtype=torch.bfloat16).state_dict(),
266 | # self.save_dir / f"{self.count}.pt",
267 | # )
268 | # self.num_saved += 1
269 |
270 | def get_latest_like(self, model):
271 | # Return model for latest completed period
272 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
273 | model = model.module
274 | new_model = deepcopy(model)
275 |
276 | map_and_load_state_dict(
277 | new_model, self.module.to(dtype=torch.bfloat16).state_dict()
278 | )
279 |
280 | return new_model
281 |
282 |
283 | def eval_ema(
284 | curr_iter,
285 | model,
286 | ema,
287 | val_reader,
288 | type_ctx,
289 | distributed_backend,
290 | cfg,
291 | full_eval=False,
292 | ):
293 | if not distributed_backend.is_master_process():
294 | # Only evaluate and log on master rank
295 | return
296 |
297 | val_reader.set_step(0)
298 | val_acc, val_loss, val_perplexity = eval(
299 | ema.get_latest_like(model).eval(),
300 | val_reader,
301 | cfg.device,
302 | max_num_batches=(
303 | val_reader.num_batches()
304 | if curr_iter == cfg.iterations or full_eval
305 | else cfg.eval_batches
306 | ),
307 | ctx=type_ctx,
308 | cfg=cfg,
309 | )
310 |
311 | if cfg.wandb:
312 | if curr_iter == cfg.iterations or full_eval:
313 | logs = {
314 | "iter": curr_iter,
315 | "final-val/loss_ema": val_loss,
316 | "final-val/perplexity_ema": val_perplexity,
317 | "final-val/acc_ema": val_acc,
318 | }
319 | else:
320 | logs = {
321 | "iter": curr_iter,
322 | "val/loss_ema": val_loss,
323 | "val/perplexity_ema": val_perplexity,
324 | "val/acc_ema": val_acc,
325 | }
326 | wandb.log(logs)
327 | print(
328 | f">EMA Eval: Iter={curr_iter} "
329 | f"val_loss={val_loss:.3f} "
330 | f"val_pp={val_perplexity:.3f} "
331 |             f"val_acc={val_acc:.3f}"
332 | )
333 |
--------------------------------------------------------------------------------
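[Added usage note, not part of the repository] A minimal sketch of how ExponentialWeightAverager above is driven by a training loop, mirroring train() in src/optim/base.py: step() is called after every optimizer update, and get_latest_like() returns a copy of the live model with the averaged weights loaded, ready for evaluation. The toy model, optimizer, and hyperparameters are illustrative assumptions; running from src/ is assumed so the import resolves.

    # illustrative sketch of the EMA weight-averaging loop
    import torch
    from optim.weight_averaging import ExponentialWeightAverager

    model = torch.nn.Linear(10, 1)
    opt = torch.optim.SGD(model.parameters(), lr=1e-2)
    ema = ExponentialWeightAverager(model, interval=1, decay=0.95, warmup=5)

    for it in range(50):
        x, y = torch.randn(8, 10), torch.randn(8, 1)
        loss = torch.nn.functional.mse_loss(model(x), y)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        ema.step(model)  # copies weights once warmup is reached, then lerps by (1 - decay)

    avg_model = ema.get_latest_like(model)  # deepcopy of `model` with the EMA weights loaded
    avg_model.eval()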