├── LICENSE
├── README.md
├── logs
│   ├── A100_run.txt
│   ├── GPT4-tok-run.txt
│   ├── Muon_run.txt
│   ├── PSGD_run.txt
│   ├── lr_test_runs.txt
│   ├── shrunk_run.txt
│   └── tweaks_run_nosave.txt
└── src
    ├── blank
    ├── data
    │   ├── data.py
    │   └── finewebedu10b
    │       └── download.py
    ├── model.py
    ├── plot.py
    ├── sample.py
    └── train.py
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 VatsaDev 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NanoPoor 2 | NanoGPT-speedrunning for the poor T4 enjoyers 3 | 4 | [Colab Notebook](https://colab.research.google.com/drive/1x87U-mCZCt7Kwc5-HGPOR1NVCOYAN1dr?usp=sharing) 5 | 6 | Inspired by [Modded NanoGPT](https://github.com/KellerJordan/modded-nanogpt) and my goat [Jonas Geiping (Cramming)](https://arxiv.org/pdf/2212.14034), I trained a custom GPT I've been working on over at [Dagonet](https://github.com/BambooML/Dagonet), and got to a 3.28 val loss on a single T4. 7 | 8 | **Important! 
Note/Future-bugfix** 9 | 10 | As [@main_horse](https://x.com/main_horse/status/1907238044434104633) pointed out, I wrote the DSMoE class so that it sends the current token to every expert and then applies the router weights, which removes the router's hard selection and turns it into a soft weighting instead. Hard routing gets the loss ~0.1 lower (about 10 steps faster), but wallclock time per step is 2x longer and init was 8x longer; still working on GEMMs. 11 | 12 | **caveats:** 13 | - Fewer params than the 120M from the main speedrun, for stability 14 | - Data was just a 1B-token subset of finewebedu10b, not filtered or anything; that's all I had processed at the time, will probably fix this later 15 | 16 | ## Runs 17 | 18 | | Ranking | Time - date | Data | Person | Description | log | 19 | | -------- | ----------- | ---- | ------ | ----------- | --- | 20 | | 1 | 7.09m - 4/5/25 | ~3.27M tok (1024 * 8 * 4 * 100) | Vatsa | now GPT-2 tokenizer (shrunk vocab_size), also shrunk head_lm and n_experts for stability; fewer params, now at ~73M | [log](https://github.com/VatsaDev/NanoPoor/blob/main/logs/shrunk_run.txt) | 21 | | 2 | 11.69m - 4/4/25 | ~3.93M tok (1024 * 8 * 4 * 120) | Vatsa | lr tuning (5e-4) | [log](https://github.com/VatsaDev/NanoPoor/blob/main/logs/lr_test_runs.txt) | 22 | | 3 | 14.86m - 4/2/25 | ~5.24M tok (1024 * 8 * 4 * 160) | Vatsa | 3x lr, removed ckpt saves every step, less printing | [log](https://github.com/VatsaDev/NanoPoor/blob/main/logs/tweaks_run_nosave.txt) | 23 | | 4 | 15.04m - 4/1/25 | ~3.89M tok (1024 * 5 * 4 * 190) | Vatsa | Used Muon instead | [log](https://github.com/VatsaDev/NanoPoor/blob/main/logs/Muon_run.txt) | 24 | | 5 | 37.17m - 4/1/25 | ~6.14M tok (1024 * 5 * 4 * 300) | Vatsa | Added PSGD | [log](https://github.com/VatsaDev/NanoPoor/blob/main/logs/PSGD_run.txt) | 25 | | 6 | 70.61m - 3/31/25 | ~14M tok (1024 * 6 * 4 * 570) | Vatsa | First run: DS-MoE, MLA+NSA hybrid, RoPE, etc. | [log](https://github.com/VatsaDev/NanoPoor/blob/main/logs/GPT4-tok-run.txt) | 26 | 27 | ## Unofficial Runs 28 | 29 | | Ranking | Time - date | Data | Person | Description | log | 30 | | -------- | ----------- | ---- | ------ | ----------- | --- | 31 | | 1st | 7.63m - 4/1/25 | ~6.96M tok (1024 * 10 * 4 * 170) | Vatsa | Re-ran the Muon setup (15.04m - 4/1/25) on an A100 to see how it looks on a real GPU | [log](https://github.com/VatsaDev/NanoPoor/blob/main/logs/A100_run.txt) | 32 | -------------------------------------------------------------------------------- /logs/A100_run.txt: -------------------------------------------------------------------------------- 1 | num Muon parameters: 57,824,256 2 | num AdamW parameters: 61,822,481 3 | Using Muon (optimizer 0) and AdamW (optimizer 1) 4 | compiling the model... 5 | compiled 6 | 119.646737 M trainable parameters 7 | 8 | Starting training loop from iteration 0... 
9 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 10 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 11 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 12 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 13 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 14 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 15 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 16 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 17 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 18 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 19 | /usr/lib/python3/dist-packages/torch/optim/lr_scheduler.py:243: UserWarning: The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose. 20 | warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) 21 | step: 10, train loss: 7.9090, val loss: 7.9055, elapsed time: 3.53 min, dt: 211932.50 ms 22 | Saving checkpoint to checkpoints/Sp8LcM_check_10.pt 23 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 24 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 25 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 26 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 27 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 28 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 29 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 30 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 31 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 32 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 33 | step: 20, train loss: 5.3240, val loss: 5.2993, elapsed time: 4.49 min, dt: 57423.77 ms 34 | Saving checkpoint to checkpoints/Sp8LcM_check_20.pt 35 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 36 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 37 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 38 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 39 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 40 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 41 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 42 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 43 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 44 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 45 | step: 30, train loss: 4.1583, val loss: 4.1320, elapsed time: 4.70 min, dt: 12739.16 ms 46 | Saving checkpoint to checkpoints/Sp8LcM_check_30.pt 47 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 48 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 49 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 50 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 51 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 
52 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 53 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 54 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 55 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 56 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 57 | step: 40, train loss: 3.9705, val loss: 3.9540, elapsed time: 4.91 min, dt: 12491.79 ms 58 | Saving checkpoint to checkpoints/Sp8LcM_check_40.pt 59 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 60 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 61 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 62 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 63 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 64 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 65 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 66 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 67 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 68 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 69 | step: 50, train loss: 3.9581, val loss: 3.9803, elapsed time: 5.12 min, dt: 12494.17 ms 70 | Saving checkpoint to checkpoints/Sp8LcM_check_50.pt 71 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 72 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 73 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 74 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 75 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 76 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 77 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 78 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 79 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 80 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 81 | step: 60, train loss: 3.9341, val loss: 3.9294, elapsed time: 5.33 min, dt: 12726.67 ms 82 | Saving checkpoint to checkpoints/Sp8LcM_check_60.pt 83 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 84 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 85 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 86 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 87 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 88 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 89 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 90 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 91 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 92 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 93 | step: 70, train loss: 3.8891, val loss: 3.9142, elapsed time: 5.54 min, dt: 12492.11 ms 94 | Saving checkpoint to checkpoints/Sp8LcM_check_70.pt 95 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 96 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 97 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 98 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 99 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 100 | Inside loop, before 
scaler.step(Muon): dist.is_initialized() = True 101 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 102 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 103 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 104 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 105 | step: 80, train loss: 3.8947, val loss: 3.8826, elapsed time: 5.75 min, dt: 12482.31 ms 106 | Saving checkpoint to checkpoints/Sp8LcM_check_80.pt 107 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 108 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 109 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 110 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 111 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 112 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 113 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 114 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 115 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 116 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 117 | step: 90, train loss: 3.8449, val loss: 3.8586, elapsed time: 5.96 min, dt: 12718.64 ms 118 | Saving checkpoint to checkpoints/Sp8LcM_check_90.pt 119 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 120 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 121 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 122 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 123 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 124 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 125 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 126 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 127 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 128 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 129 | step: 100, train loss: 3.8152, val loss: 3.7845, elapsed time: 6.17 min, dt: 12722.94 ms 130 | Saving checkpoint to checkpoints/Sp8LcM_check_100.pt 131 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 132 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 133 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 134 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 135 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 136 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 137 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 138 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 139 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 140 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 141 | step: 110, train loss: 3.7092, val loss: 3.7006, elapsed time: 6.38 min, dt: 12736.79 ms 142 | Saving checkpoint to checkpoints/Sp8LcM_check_110.pt 143 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 144 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 145 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 146 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 147 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 148 | 
Inside loop, before scaler.step(Muon): dist.is_initialized() = True 149 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 150 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 151 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 152 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 153 | step: 120, train loss: 3.6268, val loss: 3.6139, elapsed time: 6.59 min, dt: 12491.84 ms 154 | Saving checkpoint to checkpoints/Sp8LcM_check_120.pt 155 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 156 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 157 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 158 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 159 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 160 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 161 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 162 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 163 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 164 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 165 | step: 130, train loss: 3.5350, val loss: 3.5343, elapsed time: 6.80 min, dt: 12474.93 ms 166 | Saving checkpoint to checkpoints/Sp8LcM_check_130.pt 167 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 168 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 169 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 170 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 171 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 172 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 173 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 174 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 175 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 176 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 177 | step: 140, train loss: 3.4531, val loss: 3.4620, elapsed time: 7.01 min, dt: 12716.14 ms 178 | Saving checkpoint to checkpoints/Sp8LcM_check_140.pt 179 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 180 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 181 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 182 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 183 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 184 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 185 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 186 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 187 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 188 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 189 | step: 150, train loss: 3.3810, val loss: 3.4169, elapsed time: 7.22 min, dt: 12479.96 ms 190 | Saving checkpoint to checkpoints/Sp8LcM_check_150.pt 191 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 192 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 193 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 194 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 195 | Inside loop, before scaler.step(Muon): 
dist.is_initialized() = True 196 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 197 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 198 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 199 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 200 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 201 | step: 160, train loss: 3.2809, val loss: 3.3219, elapsed time: 7.43 min, dt: 12465.71 ms 202 | Saving checkpoint to checkpoints/Sp8LcM_check_160.pt 203 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 204 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 205 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 206 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 207 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 208 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 209 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 210 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 211 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 212 | Inside loop, before scaler.step(Muon): dist.is_initialized() = True 213 | step: 170, train loss: 3.1845, val loss: 3.1830, elapsed time: 7.63 min, dt: 12464.43 ms 214 | -------------------------------------------------------------------------------- /logs/GPT4-tok-run.txt: -------------------------------------------------------------------------------- 1 | # First ever run 2 | 3 | # tokens used is ctx_len*batch_size*grad_accum*steps = 1024*6*4*570 = 14008320, 14M tokens 4 | 5 | found vocab_size = 100277 (inside data/tokenized_data/meta.pkl) 6 | num decayed parameter tensors: 278, with 119,443,712 parameters 7 | num non-decayed parameter tensors: 206, with 203,025 parameters 8 | using fused AdamW: True 9 | starting run 8mqSZf from scratch 10 | compiling the model... 11 | compiled 12 | 119.646797 M params 13 | W0331 21:26:50.368000 4777 torch/_inductor/utils.py:1137] [0/0] Not enough SMs to use max_autotune_gemm mode 14 | step: 0, train loss: 11.6878, val loss: 11.6905, elapsed time: 4.2598 min, mfu: 0.00109919, total_flops: 1.8261e+13 15 | /usr/local/lib/python3.11/dist-packages/torch/optim/lr_scheduler.py:243: UserWarning: The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose. 
16 | warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) 17 | step: 10, train loss: 7.5543, val loss: 7.5441, elapsed time: 5.9373 min, mfu: 0.00279129, total_flops: 6.4634e+13 18 | step: 20, train loss: 5.1971, val loss: 5.1872, elapsed time: 6.9738 min, mfu: 0.00451737, total_flops: 1.2286e+14 19 | step: 30, train loss: 4.6162, val loss: 4.5774, elapsed time: 7.9321 min, mfu: 0.00488616, total_flops: 1.5115e+14 20 | step: 40, train loss: 4.1497, val loss: 4.1250, elapsed time: 8.8896 min, mfu: 0.00489007, total_flops: 1.6954e+14 21 | step: 50, train loss: 3.9786, val loss: 4.0142, elapsed time: 9.8746 min, mfu: 0.00475375, total_flops: 1.8307e+14 22 | step: 60, train loss: 3.9823, val loss: 3.9745, elapsed time: 11.7877 min, mfu: 0.00244749, total_flops: 1.1252e+14 23 | step: 70, train loss: 3.9559, val loss: 3.9975, elapsed time: 12.8118 min, mfu: 0.00457232, total_flops: 2.2846e+14 24 | step: 80, train loss: 3.9556, val loss: 3.9723, elapsed time: 14.1266 min, mfu: 0.00356108, total_flops: 1.9619e+14 25 | step: 90, train loss: 3.9474, val loss: 3.9719, elapsed time: 15.1842 min, mfu: 0.00442757, total_flops: 2.6219e+14 26 | step: 100, train loss: 3.9621, val loss: 3.9485, elapsed time: 16.3599 min, mfu: 0.00398247, total_flops: 2.5410e+14 27 | step: 110, train loss: 3.9657, val loss: 3.9476, elapsed time: 17.5582 min, mfu: 0.00390742, total_flops: 2.6757e+14 28 | step: 120, train loss: 3.9218, val loss: 3.9505, elapsed time: 18.6338 min, mfu: 0.00435334, total_flops: 3.1637e+14 29 | step: 130, train loss: 3.9347, val loss: 3.9418, elapsed time: 19.7117 min, mfu: 0.00434395, total_flops: 3.3394e+14 30 | step: 140, train loss: 3.9684, val loss: 3.9282, elapsed time: 20.7854 min, mfu: 0.00436084, total_flops: 3.5350e+14 31 | step: 150, train loss: 3.9149, val loss: 3.9402, elapsed time: 21.9225 min, mfu: 0.00411777, total_flops: 3.5206e+14 32 | step: 160, train loss: 3.9220, val loss: 3.9144, elapsed time: 23.6963 min, mfu: 0.00263982, total_flops: 2.4396e+14 33 | step: 170, train loss: 3.8918, val loss: 3.8444, elapsed time: 24.8278 min, mfu: 0.00413804, total_flops: 4.0068e+14 34 | step: 180, train loss: 3.8807, val loss: 3.8477, elapsed time: 25.7900 min, mfu: 0.00486624, total_flops: 4.8945e+14 35 | step: 190, train loss: 3.8007, val loss: 3.8110, elapsed time: 27.0022 min, mfu: 0.00386283, total_flops: 4.0679e+14 36 | step: 200, train loss: 3.7953, val loss: 3.7890, elapsed time: 28.2383 min, mfu: 0.00378779, total_flops: 4.1715e+14 37 | /content/Dagonet/src/plot.py:31: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`). Consider using `matplotlib.pyplot.close()`. 
38 | plt.figure(figsize=(8, 4), dpi=100) 39 | step: 210, train loss: 3.7528, val loss: 3.7398, elapsed time: 29.2736 min, mfu: 0.00452280, total_flops: 5.1636e+14 40 | step: 220, train loss: 3.7527, val loss: 3.7445, elapsed time: 30.3741 min, mfu: 0.00425478, total_flops: 5.0402e+14 41 | step: 230, train loss: 3.6900, val loss: 3.7211, elapsed time: 32.3297 min, mfu: 0.00239432, total_flops: 3.0189e+14 42 | step: 240, train loss: 3.6982, val loss: 3.6794, elapsed time: 33.4205 min, mfu: 0.00429260, total_flops: 5.5950e+14 43 | step: 250, train loss: 3.6778, val loss: 3.6606, elapsed time: 34.4805 min, mfu: 0.00441723, total_flops: 5.9400e+14 44 | step: 260, train loss: 3.6398, val loss: 3.6378, elapsed time: 35.4419 min, mfu: 0.00487057, total_flops: 6.7323e+14 45 | step: 270, train loss: 3.6196, val loss: 3.6187, elapsed time: 36.6087 min, mfu: 0.00401265, total_flops: 5.7290e+14 46 | step: 280, train loss: 3.5822, val loss: 3.6158, elapsed time: 37.7452 min, mfu: 0.00412023, total_flops: 6.0652e+14 47 | step: 290, train loss: 3.5738, val loss: 3.5825, elapsed time: 38.9017 min, mfu: 0.00404859, total_flops: 6.1424e+14 48 | step: 300, train loss: 3.5570, val loss: 3.5395, elapsed time: 40.0013 min, mfu: 0.00425832, total_flops: 6.6432e+14 49 | step: 310, train loss: 3.5456, val loss: 3.5815, elapsed time: 41.3260 min, mfu: 0.00353467, total_flops: 5.6969e+14 50 | step: 320, train loss: 3.5627, val loss: 3.5953, elapsed time: 42.3248 min, mfu: 0.00468803, total_flops: 7.7384e+14 51 | step: 330, train loss: 3.5262, val loss: 3.5836, elapsed time: 43.4853 min, mfu: 0.00403454, total_flops: 6.8423e+14 52 | step: 340, train loss: 3.5075, val loss: 3.5125, elapsed time: 44.7525 min, mfu: 0.00369516, total_flops: 6.4493e+14 53 | step: 350, train loss: 3.5361, val loss: 3.5139, elapsed time: 45.8161 min, mfu: 0.00440234, total_flops: 7.8662e+14 54 | step: 360, train loss: 3.4930, val loss: 3.5195, elapsed time: 46.8953 min, mfu: 0.00433842, total_flops: 7.9346e+14 55 | step: 370, train loss: 3.4761, val loss: 3.4916, elapsed time: 47.9731 min, mfu: 0.00434452, total_flops: 8.1284e+14 56 | step: 380, train loss: 3.4725, val loss: 3.4822, elapsed time: 49.1806 min, mfu: 0.00387787, total_flops: 7.4379e+14 57 | step: 390, train loss: 3.5043, val loss: 3.4283, elapsed time: 50.2219 min, mfu: 0.00449652, total_flops: 8.8071e+14 58 | step: 400, train loss: 3.4199, val loss: 3.4215, elapsed time: 51.3271 min, mfu: 0.00423647, total_flops: 8.4804e+14 59 | step: 410, train loss: 3.4194, val loss: 3.4291, elapsed time: 52.4343 min, mfu: 0.00422906, total_flops: 8.6482e+14 60 | step: 420, train loss: 3.4330, val loss: 3.3939, elapsed time: 53.5508 min, mfu: 0.00419388, total_flops: 8.7588e+14 61 | step: 430, train loss: 3.3952, val loss: 3.4142, elapsed time: 54.7765 min, mfu: 0.00382009, total_flops: 8.1608e+14 62 | step: 440, train loss: 3.3877, val loss: 3.4125, elapsed time: 55.7788 min, mfu: 0.00467162, total_flops: 1.0163e+15 63 | step: 450, train loss: 3.3836, val loss: 3.3682, elapsed time: 56.8280 min, mfu: 0.00446271, total_flops: 9.8907e+14 64 | step: 460, train loss: 3.3667, val loss: 3.4045, elapsed time: 58.0875 min, mfu: 0.00371753, total_flops: 8.4217e+14 65 | step: 470, train loss: 3.3659, val loss: 3.4039, elapsed time: 59.4221 min, mfu: 0.00350840, total_flops: 8.1306e+14 66 | step: 480, train loss: 3.3528, val loss: 3.3764, elapsed time: 60.7892 min, mfu: 0.00342520, total_flops: 8.1204e+14 67 | step: 490, train loss: 3.3788, val loss: 3.3866, elapsed time: 61.9427 min, mfu: 0.00405923, 
total_flops: 9.8061e+14 68 | step: 500, train loss: 3.3362, val loss: 3.3539, elapsed time: 62.9484 min, mfu: 0.00465548, total_flops: 1.1429e+15 69 | step: 510, train loss: 3.3687, val loss: 3.3481, elapsed time: 64.0334 min, mfu: 0.00431545, total_flops: 1.0777e+15 70 | step: 520, train loss: 3.3415, val loss: 3.3037, elapsed time: 65.2529 min, mfu: 0.00383981, total_flops: 9.7718e+14 71 | step: 530, train loss: 3.3251, val loss: 3.3111, elapsed time: 66.3568 min, mfu: 0.00424145, total_flops: 1.0977e+15 72 | step: 540, train loss: 3.3138, val loss: 3.3068, elapsed time: 67.4962 min, mfu: 0.00410959, total_flops: 1.0818e+15 73 | step: 550, train loss: 3.2846, val loss: 3.3099, elapsed time: 68.4892 min, mfu: 0.00471533, total_flops: 1.2595e+15 74 | step: 560, train loss: 3.2775, val loss: 3.2866, elapsed time: 69.6459 min, mfu: 0.00404811, total_flops: 1.0995e+15 75 | step: 570, train loss: 3.2876, val loss: 3.2631, elapsed time: 70.6161 min, mfu: 0.00482616, total_flops: 1.3291e+15 76 | -------------------------------------------------------------------------------- /logs/Muon_run.txt: -------------------------------------------------------------------------------- 1 | step: 10, train loss: 7.9172, val loss: 7.9133, elapsed time: 2.03 min, dt: 121513.46 ms 2 | step: 20, train loss: 5.3151, val loss: 5.3263, elapsed time: 3.64 min, dt: 97023.25 ms 3 | Saving checkpoint to checkpoints/MXLrqn_check_20.pt 4 | step: 30, train loss: 4.1351, val loss: 4.1755, elapsed time: 4.28 min, dt: 38154.20 ms 5 | Saving checkpoint to checkpoints/MXLrqn_check_30.pt 6 | step: 40, train loss: 3.9834, val loss: 3.9873, elapsed time: 4.97 min, dt: 41521.85 ms 7 | Saving checkpoint to checkpoints/MXLrqn_check_40.pt 8 | step: 50, train loss: 3.9849, val loss: 3.9488, elapsed time: 5.59 min, dt: 36892.43 ms 9 | Saving checkpoint to checkpoints/MXLrqn_check_50.pt 10 | step: 60, train loss: 3.9570, val loss: 3.9586, elapsed time: 6.37 min, dt: 46864.00 ms 11 | Saving checkpoint to checkpoints/MXLrqn_check_60.pt 12 | step: 70, train loss: 3.9274, val loss: 3.9428, elapsed time: 7.02 min, dt: 39311.94 ms 13 | Saving checkpoint to checkpoints/MXLrqn_check_70.pt 14 | step: 80, train loss: 3.8958, val loss: 3.8950, elapsed time: 7.66 min, dt: 38221.44 ms 15 | Saving checkpoint to checkpoints/MXLrqn_check_80.pt 16 | step: 90, train loss: 3.8799, val loss: 3.8430, elapsed time: 8.30 min, dt: 38579.50 ms 17 | Saving checkpoint to checkpoints/MXLrqn_check_90.pt 18 | step: 100, train loss: 3.8561, val loss: 3.8345, elapsed time: 8.92 min, dt: 36825.70 ms 19 | Saving checkpoint to checkpoints/MXLrqn_check_100.pt 20 | step: 110, train loss: 3.7931, val loss: 3.7668, elapsed time: 9.56 min, dt: 38775.39 ms 21 | Saving checkpoint to checkpoints/MXLrqn_check_110.pt 22 | step: 120, train loss: 3.6871, val loss: 3.7047, elapsed time: 10.20 min, dt: 38315.45 ms 23 | Saving checkpoint to checkpoints/MXLrqn_check_120.pt 24 | step: 130, train loss: 3.6668, val loss: 3.6315, elapsed time: 10.83 min, dt: 37585.17 ms 25 | Saving checkpoint to checkpoints/MXLrqn_check_130.pt 26 | step: 140, train loss: 3.5746, val loss: 3.5911, elapsed time: 11.48 min, dt: 38963.46 ms 27 | Saving checkpoint to checkpoints/MXLrqn_check_140.pt 28 | step: 150, train loss: 3.5368, val loss: 3.5283, elapsed time: 12.22 min, dt: 44538.58 ms 29 | Saving checkpoint to checkpoints/MXLrqn_check_150.pt 30 | step: 160, train loss: 3.4759, val loss: 3.4888, elapsed time: 13.06 min, dt: 50308.02 ms 31 | Saving checkpoint to checkpoints/MXLrqn_check_160.pt 32 | step: 
170, train loss: 3.4451, val loss: 3.4001, elapsed time: 13.69 min, dt: 38077.79 ms 33 | Saving checkpoint to checkpoints/MXLrqn_check_170.pt 34 | step: 180, train loss: 3.3543, val loss: 3.3593, elapsed time: 14.40 min, dt: 42558.34 ms 35 | Saving checkpoint to checkpoints/MXLrqn_check_180.pt 36 | step: 190, train loss: 3.2829, val loss: 3.2258, elapsed time: 15.04 min, dt: 38442.20 ms 37 | -------------------------------------------------------------------------------- /logs/PSGD_run.txt: -------------------------------------------------------------------------------- 1 | step: 0, train loss: 11.4827, val loss: 11.4840, elapsed time: 2.2009 min, mfu: 0.00177285, total_flops: 1.5218e+13 2 | step: 10, train loss: 10.8437, val loss: 10.8442, elapsed time: 3.7332 min, mfu: 0.00254662, total_flops: 3.7077e+13 3 | step: 20, train loss: 8.2496, val loss: 8.2467, elapsed time: 4.9877 min, mfu: 0.00311037, total_flops: 6.0502e+13 4 | step: 30, train loss: 6.1695, val loss: 6.1480, elapsed time: 6.2757 min, mfu: 0.00302938, total_flops: 7.4145e+13 5 | step: 40, train loss: 5.2552, val loss: 5.2632, elapsed time: 7.3975 min, mfu: 0.00347816, total_flops: 1.0035e+14 6 | step: 50, train loss: 4.7883, val loss: 4.7777, elapsed time: 8.5980 min, mfu: 0.00325029, total_flops: 1.0899e+14 7 | step: 60, train loss: 4.3622, val loss: 4.3403, elapsed time: 9.7631 min, mfu: 0.00334917, total_flops: 1.2752e+14 8 | step: 70, train loss: 4.2487, val loss: 4.2070, elapsed time: 11.0581 min, mfu: 0.00301291, total_flops: 1.2994e+14 9 | step: 80, train loss: 4.1395, val loss: 4.1081, elapsed time: 13.0404 min, mfu: 0.00196848, total_flops: 1.0011e+14 10 | step: 90, train loss: 4.1091, val loss: 4.0845, elapsed time: 14.5393 min, mfu: 0.00260310, total_flops: 1.4760e+14 11 | step: 100, train loss: 4.0750, val loss: 4.0714, elapsed time: 15.9790 min, mfu: 0.00271028, total_flops: 1.6890e+14 12 | step: 120, train loss: 4.0499, val loss: 3.9653, elapsed time: 18.1044 min, mfu: 0.00183583, total_flops: 1.2962e+14 13 | step: 140, train loss: 4.0030, val loss: 3.9880, elapsed time: 20.8908 min, mfu: 0.00140035, total_flops: 1.1409e+14 14 | step: 160, train loss: 3.9687, val loss: 3.9773, elapsed time: 23.2816 min, mfu: 0.00163213, total_flops: 1.4819e+14 15 | step: 180, train loss: 3.9467, val loss: 3.9294, elapsed time: 25.7613 min, mfu: 0.00157354, total_flops: 1.5809e+14 16 | step: 200, train loss: 3.8427, val loss: 3.8493, elapsed time: 28.0373 min, mfu: 0.00171441, total_flops: 1.8746e+14 17 | step: 220, train loss: 3.7118, val loss: 3.6894, elapsed time: 29.9996 min, mfu: 0.00198839, total_flops: 2.3264e+14 18 | step: 240, train loss: 3.6133, val loss: 3.5821, elapsed time: 31.7676 min, mfu: 0.00220696, total_flops: 2.7343e+14 19 | step: 260, train loss: 3.4895, val loss: 3.4729, elapsed time: 33.5910 min, mfu: 0.00213999, total_flops: 2.8035e+14 20 | step: 280, train loss: 3.3502, val loss: 3.3854, elapsed time: 35.4013 min, mfu: 0.00215542, total_flops: 2.9759e+14 21 | step: 300, train loss: 2.9473, val loss: 2.9681, elapsed time: 37.1768 min, mfu: 0.00219760, total_flops: 3.1863e+14 22 | -------------------------------------------------------------------------------- /logs/lr_test_runs.txt: -------------------------------------------------------------------------------- 1 | 1e-4 2 | 3 | step: 10, train loss: 8.7964, val loss: 8.7941, elapsed time: 4.89 min, dt: 293679.71 ms 4 | step: 20, train loss: 7.7850, val loss: 7.7917, elapsed time: 6.56 min, dt: 100037.10 ms 5 | step: 30, train loss: 6.8389, val loss: 
6.8585, elapsed time: 7.31 min, dt: 44992.72 ms 6 | step: 40, train loss: 6.0018, val loss: 6.0188, elapsed time: 8.05 min, dt: 44339.52 ms 7 | step: 50, train loss: 5.2365, val loss: 5.2694, elapsed time: 8.79 min, dt: 44438.28 ms 8 | step: 60, train loss: 4.6692, val loss: 4.6830, elapsed time: 9.53 min, dt: 44325.31 ms 9 | step: 70, train loss: 4.2921, val loss: 4.2862, elapsed time: 10.27 min, dt: 44460.13 ms 10 | step: 80, train loss: 4.0717, val loss: 4.0941, elapsed time: 11.01 min, dt: 44308.59 ms 11 | step: 90, train loss: 3.9965, val loss: 3.9957, elapsed time: 11.75 min, dt: 44342.15 ms 12 | step: 100, train loss: 3.9821, val loss: 3.9249, elapsed time: 12.49 min, dt: 44411.13 ms 13 | step: 110, train loss: 3.8991, val loss: 3.9044, elapsed time: 13.23 min, dt: 44269.83 ms 14 | step: 120, train loss: 3.8378, val loss: 3.8553, elapsed time: 13.96 min, dt: 44213.40 ms 15 | step: 130, train loss: 3.7199, val loss: 3.7348, elapsed time: 14.70 min, dt: 44302.53 ms 16 | step: 140, train loss: 3.5814, val loss: 3.5846, elapsed time: 15.44 min, dt: 44384.17 ms 17 | step: 150, train loss: 3.4695, val loss: 3.4810, elapsed time: 16.18 min, dt: 44302.74 ms 18 | step: 160, train loss: 3.3894, val loss: 3.3638, elapsed time: 16.92 min, dt: 44373.32 ms 19 | step: 170, train loss: 3.2664, val loss: 3.2955, elapsed time: 17.66 min, dt: 44479.63 ms 20 | step: 180, train loss: 3.1568, val loss: 3.1678, elapsed time: 18.40 min, dt: 44281.46 ms 21 | Saving checkpoint to checkpoints/Cl90Sp_check_180.pt 22 | 23 | 1e-3 24 | 25 | step: 10, train loss: 8.0774, val loss: 8.0777, elapsed time: 2.25 min, dt: 135175.40 ms 26 | step: 20, train loss: 5.4560, val loss: 5.4668, elapsed time: 3.33 min, dt: 64902.41 ms 27 | step: 30, train loss: 4.2038, val loss: 4.1714, elapsed time: 4.09 min, dt: 45321.44 ms 28 | step: 40, train loss: 3.9768, val loss: 3.9988, elapsed time: 4.83 min, dt: 44692.93 ms 29 | step: 50, train loss: 3.9563, val loss: 3.9668, elapsed time: 5.59 min, dt: 45009.04 ms 30 | step: 60, train loss: 3.9251, val loss: 3.9293, elapsed time: 6.33 min, dt: 44735.48 ms 31 | step: 70, train loss: 3.9318, val loss: 3.9181, elapsed time: 7.08 min, dt: 44715.18 ms 32 | step: 80, train loss: 3.8700, val loss: 3.9168, elapsed time: 7.82 min, dt: 44758.88 ms 33 | step: 90, train loss: 3.8777, val loss: 3.8872, elapsed time: 8.57 min, dt: 44604.11 ms 34 | step: 100, train loss: 3.8154, val loss: 3.8027, elapsed time: 9.31 min, dt: 44587.97 ms 35 | step: 110, train loss: 3.7157, val loss: 3.7052, elapsed time: 10.05 min, dt: 44641.93 ms 36 | step: 120, train loss: 3.5947, val loss: 3.5681, elapsed time: 10.80 min, dt: 44666.71 ms 37 | step: 130, train loss: 3.5073, val loss: 3.4977, elapsed time: 11.54 min, dt: 44679.17 ms 38 | step: 140, train loss: 3.3685, val loss: 3.3885, elapsed time: 12.29 min, dt: 44640.66 ms 39 | step: 150, train loss: 3.3154, val loss: 3.3111, elapsed time: 13.03 min, dt: 44698.11 ms 40 | step: 160, train loss: 3.2338, val loss: 3.2122, elapsed time: 13.77 min, dt: 44632.22 ms 41 | 42 | 3e-3 43 | 44 | step: 10, train loss: 6.0486, val loss: 6.0423, elapsed time: 2.26 min, dt: 135523.44 ms 45 | step: 20, train loss: 4.0450, val loss: 4.0298, elapsed time: 3.33 min, dt: 64531.28 ms 46 | step: 30, train loss: 3.9973, val loss: 4.0013, elapsed time: 4.09 min, dt: 45291.70 ms 47 | step: 40, train loss: 3.9750, val loss: 4.0184, elapsed time: 4.83 min, dt: 44723.28 ms 48 | step: 50, train loss: 3.9773, val loss: 3.9618, elapsed time: 5.58 min, dt: 44856.33 ms 49 | step: 60, train loss: 
3.9682, val loss: 3.9468, elapsed time: 6.33 min, dt: 44627.84 ms 50 | step: 70, train loss: 3.9276, val loss: 3.9363, elapsed time: 7.07 min, dt: 44636.57 ms 51 | step: 80, train loss: 3.8621, val loss: 3.8797, elapsed time: 7.81 min, dt: 44497.26 ms 52 | step: 90, train loss: 3.7987, val loss: 3.7652, elapsed time: 8.55 min, dt: 44549.54 ms 53 | step: 100, train loss: 3.7104, val loss: 3.7015, elapsed time: 9.30 min, dt: 44666.38 ms 54 | step: 110, train loss: 3.5863, val loss: 3.6089, elapsed time: 10.04 min, dt: 44658.51 ms 55 | step: 120, train loss: 3.5288, val loss: 3.5259, elapsed time: 10.79 min, dt: 44808.46 ms 56 | step: 130, train loss: 3.4518, val loss: 3.4448, elapsed time: 11.53 min, dt: 44683.01 ms 57 | step: 140, train loss: 3.3864, val loss: 3.3816, elapsed time: 12.28 min, dt: 44669.32 ms 58 | step: 150, train loss: 3.2992, val loss: 3.3092, elapsed time: 13.02 min, dt: 44564.80 ms 59 | step: 160, train loss: 3.3091, val loss: 3.2976, elapsed time: 13.77 min, dt: 44639.72 ms 60 | step: 170, train loss: 3.2627, val loss: 3.2717, elapsed time: 14.51 min, dt: 44566.19 ms 61 | 62 | 8e-4 63 | 64 | step: 10, train loss: 8.1012, val loss: 8.1027, elapsed time: 2.24 min, dt: 134186.37 ms 65 | step: 20, train loss: 5.8543, val loss: 5.8421, elapsed time: 3.32 min, dt: 64865.44 ms 66 | step: 30, train loss: 4.4041, val loss: 4.4239, elapsed time: 4.07 min, dt: 44886.29 ms 67 | step: 40, train loss: 4.0104, val loss: 4.0330, elapsed time: 4.80 min, dt: 44320.51 ms 68 | step: 50, train loss: 3.9523, val loss: 3.9493, elapsed time: 5.55 min, dt: 44517.93 ms 69 | step: 60, train loss: 3.9473, val loss: 3.9172, elapsed time: 6.29 min, dt: 44325.69 ms 70 | step: 70, train loss: 3.9095, val loss: 3.9101, elapsed time: 7.03 min, dt: 44424.55 ms 71 | step: 80, train loss: 3.8870, val loss: 3.8780, elapsed time: 7.76 min, dt: 44290.22 ms 72 | step: 90, train loss: 3.8378, val loss: 3.8602, elapsed time: 8.50 min, dt: 44312.84 ms 73 | step: 100, train loss: 3.7882, val loss: 3.7610, elapsed time: 9.24 min, dt: 44404.35 ms 74 | step: 110, train loss: 3.6771, val loss: 3.6783, elapsed time: 9.98 min, dt: 44357.48 ms 75 | step: 120, train loss: 3.4918, val loss: 3.5198, elapsed time: 10.72 min, dt: 44345.01 ms 76 | step: 130, train loss: 3.3143, val loss: 3.3387, elapsed time: 11.46 min, dt: 44376.19 ms 77 | step: 140, train loss: 3.1061, val loss: 3.1340, elapsed time: 12.20 min, dt: 44407.92 ms 78 | 79 | 5e-4 80 | 81 | step: 10, train loss: 7.4084, val loss: 7.4085, elapsed time: 2.26 min, dt: 135703.52 ms 82 | step: 20, train loss: 4.4590, val loss: 4.4569, elapsed time: 3.97 min, dt: 102532.80 ms 83 | step: 30, train loss: 3.9711, val loss: 4.0074, elapsed time: 4.76 min, dt: 47304.89 ms 84 | step: 40, train loss: 3.9685, val loss: 3.9865, elapsed time: 5.53 min, dt: 46010.53 ms 85 | step: 50, train loss: 3.9490, val loss: 3.9342, elapsed time: 6.30 min, dt: 46481.06 ms 86 | step: 60, train loss: 3.9105, val loss: 3.9100, elapsed time: 7.07 min, dt: 46149.87 ms 87 | step: 70, train loss: 3.8982, val loss: 3.8994, elapsed time: 7.84 min, dt: 46214.96 ms 88 | step: 80, train loss: 3.8798, val loss: 3.8729, elapsed time: 8.61 min, dt: 46169.85 ms 89 | step: 90, train loss: 3.8716, val loss: 3.8647, elapsed time: 9.38 min, dt: 46242.58 ms 90 | step: 100, train loss: 3.8138, val loss: 3.7876, elapsed time: 10.15 min, dt: 46211.50 ms 91 | step: 110, train loss: 3.5337, val loss: 3.5820, elapsed time: 10.92 min, dt: 46306.79 ms 92 | step: 120, train loss: 3.2169, val loss: 3.2139, elapsed time: 
11.69 min, dt: 46230.55 ms 93 | 94 | 3e-4 95 | 96 | step: 10, train loss: 8.6369, val loss: 8.6391, elapsed time: 2.23 min, dt: 133903.98 ms 97 | step: 20, train loss: 7.6096, val loss: 7.5991, elapsed time: 3.30 min, dt: 63902.89 ms 98 | step: 30, train loss: 6.6714, val loss: 6.6685, elapsed time: 4.04 min, dt: 44736.74 ms 99 | step: 40, train loss: 5.8167, val loss: 5.8201, elapsed time: 4.78 min, dt: 44215.30 ms 100 | step: 50, train loss: 5.0828, val loss: 5.0998, elapsed time: 5.52 min, dt: 44353.50 ms 101 | step: 60, train loss: 4.5565, val loss: 4.5469, elapsed time: 6.26 min, dt: 44238.05 ms 102 | step: 70, train loss: 4.2280, val loss: 4.2522, elapsed time: 7.00 min, dt: 44363.59 ms 103 | step: 80, train loss: 4.0634, val loss: 4.0536, elapsed time: 7.73 min, dt: 44229.74 ms 104 | step: 90, train loss: 3.9594, val loss: 3.9790, elapsed time: 8.47 min, dt: 44369.64 ms 105 | step: 100, train loss: 3.9177, val loss: 3.9497, elapsed time: 9.21 min, dt: 44358.37 ms 106 | step: 110, train loss: 3.8734, val loss: 3.8821, elapsed time: 9.95 min, dt: 44293.98 ms 107 | step: 120, train loss: 3.8265, val loss: 3.8212, elapsed time: 10.69 min, dt: 44366.00 ms 108 | step: 130, train loss: 3.7684, val loss: 3.7515, elapsed time: 11.43 min, dt: 44306.83 ms 109 | step: 140, train loss: 3.6896, val loss: 3.6767, elapsed time: 12.17 min, dt: 44368.17 ms 110 | step: 150, train loss: 3.6194, val loss: 3.5902, elapsed time: 12.91 min, dt: 44348.29 ms 111 | step: 160, train loss: 3.5321, val loss: 3.5389, elapsed time: 13.65 min, dt: 44384.85 ms 112 | step: 170, train loss: 3.4791, val loss: 3.4776, elapsed time: 14.38 min, dt: 44337.62 ms 113 | step: 180, train loss: 3.4409, val loss: 3.4231, elapsed time: 15.12 min, dt: 44342.49 ms 114 | step: 190, train loss: 3.3183, val loss: 3.3625, elapsed time: 15.86 min, dt: 44326.46 ms 115 | step: 200, train loss: 3.3048, val loss: 3.3348, elapsed time: 16.60 min, dt: 44310.72 ms 116 | step: 210, train loss: 3.2634, val loss: 3.2462, elapsed time: 17.34 min, dt: 44252.46 ms 117 | -------------------------------------------------------------------------------- /logs/shrunk_run.txt: -------------------------------------------------------------------------------- 1 | --- Attempting Distributed Initialization (TCP Method, Device: cuda) --- 2 | Successfully called init_process_group (TCP) with backend: nccl. 3 | Is distributed initialized after setup? True 4 | --- Finished Distributed Initialization Attempt --- 5 | found vocab_size = 50257 (inside data/tokenized_data/meta.pkl) 6 | Starting run EZwiE9 from scratch 7 | num Muon parameters: 24,720,384 8 | num AdamW parameters: 48,819,712 9 | Using Muon (optimizer 0) and AdamW (optimizer 1) 10 | compiling the model... 11 | compiled 12 | 73.540096 M trainable parameters 13 | 14 | Starting training loop from iteration 0... 15 | /usr/local/lib/python3.11/dist-packages/torch/optim/lr_scheduler.py:243: UserWarning: The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose. 
16 | warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) 17 | step: 10, train loss: 6.4833, val loss: 6.4835, elapsed time: 1.94 min, dt: 116364.03 ms 18 | step: 20, train loss: 3.9841, val loss: 3.9786, elapsed time: 2.80 min, dt: 51917.61 ms 19 | step: 30, train loss: 3.8680, val loss: 3.8540, elapsed time: 3.35 min, dt: 32916.64 ms 20 | step: 40, train loss: 3.8437, val loss: 3.8203, elapsed time: 3.89 min, dt: 32303.15 ms 21 | step: 50, train loss: 3.8178, val loss: 3.7975, elapsed time: 4.42 min, dt: 31744.37 ms 22 | step: 60, train loss: 3.7735, val loss: 3.7632, elapsed time: 4.96 min, dt: 32284.19 ms 23 | step: 70, train loss: 3.7154, val loss: 3.7056, elapsed time: 5.49 min, dt: 31933.22 ms 24 | step: 80, train loss: 3.6110, val loss: 3.6092, elapsed time: 6.02 min, dt: 32022.54 ms 25 | step: 90, train loss: 3.4527, val loss: 3.4277, elapsed time: 6.56 min, dt: 31916.10 ms 26 | step: 100, train loss: 3.2231, val loss: 3.2284, elapsed time: 7.09 min, dt: 31845.57 ms 27 | Saving checkpoint to checkpoints/EZwiE9_check_100.pt 28 | -------------------------------------------------------------------------------- /logs/tweaks_run_nosave.txt: -------------------------------------------------------------------------------- 1 | --- Attempting Distributed Initialization (TCP Method, Device: cuda) --- 2 | Successfully called init_process_group (TCP) with backend: nccl. 3 | Is distributed initialized after setup? True 4 | --- Finished Distributed Initialization Attempt --- 5 | found vocab_size = 100277 (inside data/tokenized_data/meta.pkl) 6 | Starting run 2yapG8 from scratch 7 | num Muon parameters: 57,824,256 8 | num AdamW parameters: 61,624,832 9 | Using Muon (optimizer 0) and AdamW (optimizer 1) 10 | compiling the model... 11 | compiled 12 | 119.449088 M trainable parameters 13 | 14 | Starting training loop from iteration 0... 15 | /usr/local/lib/python3.11/dist-packages/torch/optim/lr_scheduler.py:243: UserWarning: The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose. 
16 | warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) 17 | step: 10, train loss: 5.8910, val loss: 5.9025, elapsed time: 2.30 min, dt: 138136.45 ms 18 | step: 20, train loss: 4.0102, val loss: 4.0368, elapsed time: 3.45 min, dt: 69152.66 ms 19 | step: 30, train loss: 4.0095, val loss: 4.0048, elapsed time: 4.28 min, dt: 49494.04 ms 20 | step: 40, train loss: 3.9923, val loss: 3.9634, elapsed time: 5.10 min, dt: 49015.26 ms 21 | step: 50, train loss: 3.9593, val loss: 3.9881, elapsed time: 5.91 min, dt: 49056.54 ms 22 | step: 60, train loss: 3.9213, val loss: 3.9232, elapsed time: 6.73 min, dt: 48822.89 ms 23 | step: 70, train loss: 3.9156, val loss: 3.9319, elapsed time: 7.54 min, dt: 48766.85 ms 24 | step: 80, train loss: 3.8683, val loss: 3.8783, elapsed time: 8.35 min, dt: 48658.99 ms 25 | step: 90, train loss: 3.7697, val loss: 3.7703, elapsed time: 9.16 min, dt: 48777.39 ms 26 | step: 100, train loss: 3.6790, val loss: 3.7020, elapsed time: 9.98 min, dt: 48859.98 ms 27 | step: 110, train loss: 3.6369, val loss: 3.5895, elapsed time: 10.79 min, dt: 48859.54 ms 28 | step: 120, train loss: 3.5343, val loss: 3.5527, elapsed time: 11.60 min, dt: 48677.10 ms 29 | step: 130, train loss: 3.4485, val loss: 3.4703, elapsed time: 12.42 min, dt: 48793.44 ms 30 | step: 140, train loss: 3.3778, val loss: 3.3954, elapsed time: 13.23 min, dt: 48763.12 ms 31 | step: 150, train loss: 3.3642, val loss: 3.3163, elapsed time: 14.04 min, dt: 48821.04 ms 32 | step: 160, train loss: 3.2074, val loss: 3.1957, elapsed time: 14.86 min, dt: 48685.86 ms 33 | -------------------------------------------------------------------------------- /src/blank: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/data/data.py: -------------------------------------------------------------------------------- 1 | # parallel_tokenize.py (Updated with tokenizer padding) 2 | 3 | import os 4 | import pickle 5 | import numpy as np 6 | import glob 7 | from tqdm import tqdm 8 | import sys 9 | import multiprocessing as mp 10 | from functools import partial 11 | import random 12 | import tiktoken 13 | 14 | # Config 15 | shard_size = 10_000_000 16 | data_dir = "data" 17 | dataset_folder = "finewebedu10b/fineweb_chunks" 18 | output_dir = "tokenized_data" 19 | tiktoken_model_name = "cl100k_base" 20 | max_workers = max(1, mp.cpu_count() // 2) 21 | power_of_base_padding = 64 # Pad vocab to nearest power of this value 22 | 23 | try: 24 | enc = tiktoken.get_encoding(tiktoken_model_name) 25 | print(f"Loaded tiktoken encoder '{tiktoken_model_name}' with {enc.n_vocab:,} tokens") 26 | except Exception as e: 27 | print(f"Error initializing tiktoken encoder '{tiktoken_model_name}': {e}") 28 | sys.exit(1) 29 | 30 | def n64(n): 31 | return (64*(n//64))+64 32 | 33 | def write_shard(filename, tokens): 34 | if isinstance(tokens, np.ndarray): 35 | tokens = tokens.tolist() 36 | 37 | header = np.array([20240520, 1, len(tokens)], dtype=np.uint32) # Magic, Version, Length 38 | 39 | dtype = np.uint32 40 | 41 | try: 42 | with open(filename, 'wb') as f: 43 | f.write(header.tobytes()) 44 | f.write(np.array(tokens, dtype=dtype).tobytes()) 45 | except IOError as e: 46 | print(f"Error writing shard {filename}: {e}", file=sys.stderr) 47 | except Exception as e: 48 | print(f"Unexpected error writing shard {filename}: {e}", file=sys.stderr) 49 | 50 | worker_enc = None 51 | 52 | def init_worker(model_name): 53 | global worker_enc 54 | try: 55 | 
worker_enc = tiktoken.get_encoding(model_name) 56 | except Exception as e: 57 | print(f"Error initializing tiktoken in worker {os.getpid()}: {e}", file=sys.stderr) 58 | raise e 59 | 60 | def process_chunk(file_path): 61 | """Tokenizes a single file using the worker's tiktoken encoder. 62 | Always returns tuple (error_message_or_None, list_of_tokens).""" 63 | global worker_enc 64 | if worker_enc is None: 65 | # This should not happen if init_worker is called correctly, if you see this, you messed up 66 | return (f"Error: Tiktoken encoder not initialized in worker for file {file_path}", []) 67 | 68 | try: 69 | with open(file_path, 'r', encoding='utf-8') as f: 70 | text = f.read() 71 | tokens = worker_enc.encode_ordinary(text) 72 | return (None, tokens) 73 | except FileNotFoundError: 74 | return (f"Error: File not found {file_path}", []) 75 | except UnicodeDecodeError as e: 76 | return (f"Error decoding file {file_path}: {e}", []) 77 | except Exception as e: 78 | return (f"Error processing {file_path}: {e}", []) 79 | 80 | def parallel_tokenize(files, workers, model_name): 81 | """Tokenizes files in parallel with a multiprocessing pool; returns a flat list of token ids.""" 82 | results = [] 83 | errors = [] 84 | 85 | chunksize = max(1, len(files) // (workers * 4) if workers > 0 else len(files)) 86 | 87 | try: 88 | with mp.Pool(workers, initializer=init_worker, initargs=(model_name,)) as pool: 89 | with tqdm(total=len(files), desc="Tokenizing Chunks") as pbar: 90 | for error_msg, tokens in pool.imap(process_chunk, files, chunksize=chunksize): 91 | if error_msg: 92 | errors.append(error_msg) 93 | elif tokens: # Check if tokens list is not empty 94 | results.extend(tokens) 95 | pbar.update(1) 96 | 97 | except Exception as e: 98 | print(f"\nCritical error during parallel processing: {e}", file=sys.stderr) # kinda basic handling 99 | return [] # callers expect a flat token list, so return an empty list on failure 100 | 101 | if errors: 102 | print(f"\nEncountered {len(errors)} errors during tokenization (showing first 10):") 103 | for i, err in enumerate(errors[:10]): 104 | print(f" {i+1}. {err}") 105 | if len(errors) > 10: 106 | print(f" ... 
and {len(errors) - 10} more errors.") 107 | 108 | return results 109 | 110 | def create_shards(tokens, output_dir, split_name): 111 | os.makedirs(output_dir, exist_ok=True) 112 | num_shards = (len(tokens) + shard_size - 1) // shard_size 113 | print(f"Writing {len(tokens):,} tokens into {num_shards} shards for '{split_name}' split...") 114 | 115 | shard_tokens = [] 116 | 117 | token_count = 0 118 | shard_index = 0 119 | 120 | for token in tqdm(tokens, desc=f"Writing {split_name} shards", total=len(tokens)): 121 | shard_tokens.append(token) 122 | token_count += 1 123 | if len(shard_tokens) >= shard_size: 124 | shard_filename = os.path.join(output_dir, f"{split_name}_{shard_index:06d}.bin") 125 | write_shard(shard_filename, shard_tokens) 126 | shard_index += 1 127 | shard_tokens = [] 128 | 129 | if shard_tokens: 130 | shard_filename = os.path.join(output_dir, f"{split_name}_{shard_index:06d}.bin") 131 | write_shard(shard_filename, shard_tokens) 132 | shard_index += 1 # match num_shards calculated earlier 133 | 134 | if shard_index != num_shards: 135 | print(f"Warning: Expected {num_shards} shards, but wrote {shard_index} for {split_name}", file=sys.stderr) 136 | 137 | return shard_index 138 | 139 | if __name__ == '__main__': 140 | 141 | input_path = os.path.join(data_dir, dataset_folder) 142 | print(f"Searching for .txt files in: {input_path}") 143 | files = sorted(glob.glob(os.path.join(input_path, "*.txt")))[:100] 144 | 145 | if not files: 146 | print(f"Error: No .txt files found in {input_path}", file=sys.stderr) 147 | sys.exit(1) 148 | 149 | print(f"Found {len(files)} .txt files.") 150 | 151 | # Shuffle and split files 152 | random.seed(42) 153 | random.shuffle(files) 154 | split_idx = int(0.9 * len(files)) 155 | train_files, val_files = files[:split_idx], files[split_idx:] 156 | print(f"Splitting into {len(train_files)} train files and {len(val_files)} validation files.") 157 | 158 | print(f"\nStarting tokenization for TRAINING split ({len(train_files)} files)...") 159 | train_tokens = parallel_tokenize(train_files, max_workers, tiktoken_model_name) 160 | if not train_tokens: 161 | print("No training tokens were generated. Check errors.", file=sys.stderr) 162 | else: 163 | print(f"Successfully tokenized training files. Total tokens: {len(train_tokens):,}") 164 | train_shards = create_shards(train_tokens, output_dir, "train") 165 | print(f"Finished writing {train_shards} training shards.") 166 | del train_tokens 167 | 168 | print(f"\nStarting tokenization for VALIDATION split ({len(val_files)} files)...") 169 | val_tokens = parallel_tokenize(val_files, max_workers, tiktoken_model_name) 170 | if not val_tokens: 171 | print("No validation tokens were generated. Check errors.", file=sys.stderr) 172 | else: 173 | print(f"Successfully tokenized validation files. 
Total tokens: {len(val_tokens):,}") 174 | val_shards = create_shards(val_tokens, output_dir, "val") 175 | print(f"Finished writing {val_shards} validation shards.") 176 | del val_tokens 177 | 178 | meta_path = os.path.join(output_dir, 'meta.pkl') 179 | print(f"\nSaving metadata to {meta_path}...") 180 | 181 | try: 182 | enc_meta = tiktoken.get_encoding(tiktoken_model_name) 183 | padded_vocab_size = n64(enc_meta.n_vocab) 184 | 185 | metadata = { 186 | 'vocab_size': padded_vocab_size, 187 | 'block_size': 1024, 188 | 'tokenizer': tiktoken_model_name, 189 | 'num_train_shards': train_shards if 'train_shards' in locals() else 0, 190 | 'num_val_shards': val_shards if 'val_shards' in locals() else 0, 191 | } 192 | 193 | with open(meta_path, 'wb') as f: 194 | pickle.dump(metadata, f) 195 | 196 | print("Metadata saved successfully:") 197 | print(metadata) 198 | print(f"Padded vocab_size: {padded_vocab_size}, Original vocab_size: {enc_meta.n_vocab}") # Print both for clarity 199 | 200 | except Exception as e: 201 | print(f"Error saving metadata: {e}", file=sys.stderr) 202 | 203 | print(f"\nTokenization complete. Output shards in: {output_dir}") 204 | print(f"Total shards created: {metadata.get('num_train_shards', 0)} (train) + {metadata.get('num_val_shards', 0)} (val)") 205 | print(f"Remember to update 'vocab_size' in your `config.py` to: {metadata.get('vocab_size')}") 206 | -------------------------------------------------------------------------------- /src/data/finewebedu10b/download.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import os 3 | import time 4 | from tqdm import tqdm 5 | 6 | # Configuration 7 | dataset_name = "HuggingFaceFW/fineweb-edu" 8 | dataset_subset = "sample-10BT" 9 | output_folder = "./fineweb_chunks" # Folder to save chunk files 10 | chunk_size = 10000 # Number of examples per chunk file 11 | 12 | # Create output folder if it doesn't exist 13 | os.makedirs(output_folder, exist_ok=True) 14 | 15 | # Load the dataset (sample-10BT subset) 16 | dataset = load_dataset(dataset_name, dataset_subset, split="train", streaming=True) 17 | 18 | file_counter = 0 19 | example_counter = 0 20 | current_file = None 21 | 22 | start_time = time.time() # Start time for overall progress 23 | 24 | total_examples = 10_000_000 # Approximate total examples in sample-10BT (adjust if needed for other subsets) 25 | 26 | print(f"Downloading and chunking dataset '{dataset_name}/{dataset_subset}'...") 27 | 28 | # Wrap dataset iteration with tqdm for progress bar 29 | with tqdm(total=total_examples, desc="Chunking Dataset") as progress_bar: 30 | for example in dataset: 31 | text_content = example["text"] 32 | 33 | if example_counter % chunk_size == 0: 34 | # Close the previous file if it's open 35 | if current_file: 36 | current_file.close() 37 | 38 | # Create a new chunk file 39 | chunk_filename = os.path.join(output_folder, f"chunk_{file_counter:05d}.txt") # 5 digits for chunk number 40 | current_file = open(chunk_filename, "w", encoding="utf-8") 41 | progress_bar.write(f"Creating chunk file: {chunk_filename}") # Use progress_bar.write to avoid messing up progress bar 42 | file_counter += 1 43 | 44 | if text_content: # Check if text content is not None or empty 45 | current_file.write(text_content) 46 | current_file.write("\n\n<|file_separator|>\n\n") # Add a separator between examples 47 | 48 | example_counter += 1 49 | progress_bar.update(1) # Increment progress bar by 1 example 50 | 51 | # Close the last file 52 | if 
current_file: 53 | current_file.close() 54 | 55 | end_time = time.time() # End time 56 | elapsed_time = end_time - start_time # Calculate elapsed time 57 | 58 | print(f"Dataset chunking complete.") 59 | print(f"Saved chunk files to '{output_folder}'.") 60 | print(f"Total examples processed: {example_counter}") 61 | print(f"Total chunk files created: {file_counter}") 62 | print(f"Elapsed time: {elapsed_time:.2f} seconds") # Print elapsed time 63 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | # Add old pre-deepseek meta MTP and scale with that 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | import math 7 | import inspect 8 | from muon import Muon 9 | 10 | config = { 11 | "n_embd": 256, 12 | "n_head": 16, 13 | "n_layer": 4, 14 | "n_experts": 32, 15 | "dropout": 0.2, 16 | "vocab_size": 65, 17 | "ctx_len": 2048, 18 | "init_moe_scaling": 1.25, 19 | "type": ['mlp', 'moe', 'mlp', 'moe'], 20 | "device": 'cuda' if torch.cuda.is_available() else 'cpu' 21 | } 22 | 23 | # RoPE 24 | 25 | class RoPE(nn.Module): 26 | def __init__(self, d, base=100_000_000_000, device=config['device']): 27 | super().__init__() 28 | 29 | self.base = base 30 | self.d = d 31 | self.device = device 32 | self.cos_cached = None 33 | self.sin_cached = None 34 | 35 | def _build_cache(self, x): 36 | if self.cos_cached is not None: 37 | return 38 | 39 | head_dim = x.shape[-1] 40 | 41 | theta = 1 / (self.base ** (torch.arange(0, head_dim, 2, device=self.device).float() / self.d)) 42 | seq_idx = torch.arange(x.shape[0], device=self.device).float() 43 | idx_theta = torch.einsum('n,d->nd', seq_idx, theta) 44 | 45 | cos_cache = torch.cos(idx_theta) 46 | sin_cache = torch.sin(idx_theta) 47 | 48 | self.cos_cached = torch.cat([cos_cache, cos_cache], dim=-1).unsqueeze(0).unsqueeze(0) 49 | self.sin_cached = torch.cat([sin_cache, sin_cache], dim=-1).unsqueeze(0).unsqueeze(0) 50 | 51 | def _neg_half(self, x): 52 | head_dim = x.shape[-1] 53 | d_2 = head_dim // 2 54 | return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1) 55 | 56 | def forward(self, x): 57 | if self.cos_cached is None or self.cos_cached.shape[2] != x.shape[1]: 58 | self._build_cache(x) 59 | 60 | x_rope = x.clone() # VERY IMPORTANT: Create a copy! 
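# Note on the two lines below: they implement the standard "rotate-half" RoPE identity,
#   out = x * cos(theta) + rotate_half(x) * sin(theta),
# where rotate_half swaps the two halves of the head dimension and negates the first half,
# i.e. rotate_half([x1, x2]) = [-x2, x1] (that is what _neg_half computes). Also note that
# _build_cache() returns early whenever a cache already exists, so the sequence-length check
# in forward() will not rebuild the cache for a different length; the Attn class appears to
# rely on the separate precompute_freqs_cis / apply_rope path rather than this module.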
61 | neg_half_x = self._neg_half(x_rope) 62 | x_out = (x_rope * self.cos_cached[:, :, :x.shape[1], :]) + (neg_half_x * self.sin_cached[:, :, :x.shape[1], :]) 63 | return x_out 64 | 65 | def precompute_freqs_cis(dim, end, device, theta=10000.0): 66 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim)) 67 | t = torch.arange(end, device=device) 68 | freqs = torch.outer(t, freqs) 69 | return torch.cos(freqs), torch.sin(freqs) 70 | 71 | def apply_rope(x: torch.Tensor, y: torch.Tensor, freqs_cis) -> tuple[torch.Tensor,torch.Tensor]: 72 | cos_freqs, sin_freqs = freqs_cis 73 | seq_len = x.shape[-2] 74 | 75 | cos_seq = cos_freqs[:seq_len] 76 | sin_seq = sin_freqs[:seq_len] 77 | cos_seq = cos_seq.unsqueeze(0).unsqueeze(0) 78 | sin_seq = sin_seq.unsqueeze(0).unsqueeze(0) 79 | x_real, x_imag = x.chunk(2, dim=-1) 80 | y_real, y_imag = y.chunk(2, dim=-1) 81 | x_rotated_real = x_real * cos_seq - x_imag * sin_seq 82 | x_rotated_imag = x_real * sin_seq + x_imag * cos_seq 83 | y_rotated_real = y_real * cos_seq - y_imag * sin_seq 84 | y_rotated_imag = y_real * sin_seq + y_imag * cos_seq 85 | x_rotated = torch.cat([x_rotated_real, x_rotated_imag], dim=-1) 86 | y_rotated = torch.cat([y_rotated_real, y_rotated_imag], dim=-1) 87 | return x_rotated.type_as(x), y_rotated.type_as(y) 88 | 89 | # MLA-NSA hybrid, not hardware optimized, just uses NSA sparsity for better training rn 90 | 91 | class Attn(nn.Module): 92 | """ 93 | Native Sparse Attention with Multi-headed Latent Attention integration. 94 | Combines MLA's compression techniques with NSA's natural sparsity, also better loss 95 | """ 96 | def __init__(self): 97 | super().__init__() 98 | self.device = config['device'] 99 | self.n_embd = config['n_embd'] 100 | self.n_head = config['n_head'] 101 | self.dropout = config['dropout'] 102 | self.ctx_len = config['ctx_len'] 103 | self.rms_norm_eps = config.get('rms_norm_eps', 1e-6) 104 | 105 | # Original MLA parameters 106 | self.v_head_dim = 32 107 | self.kv_lora_rank = 32 108 | self.q_lora_rank = 3 * self.kv_lora_rank 109 | self.rope_head_dim = 64 110 | self.nope_head_dim = 32 111 | self.value_dim = self.n_head * self.v_head_dim 112 | self.nope_dim = self.n_head * self.nope_head_dim 113 | self.rope_dim = self.n_head * self.rope_head_dim 114 | 115 | # NSA-specific parameters 116 | self.block_size = config.get('block_size', 16) # Size of token blocks for compression 117 | self.num_blocks = self.ctx_len // self.block_size 118 | self.window_size = config.get('window_size', 128) # Sliding window size 119 | self.num_tokens_to_keep = config.get('num_tokens_to_keep', self.ctx_len // 4) # Number of fine-grained tokens to keep 120 | 121 | # === Branch 1: Coarse-grained compression branch (adapted from MLA) === 122 | self.compress_q_linear = nn.Linear(self.n_embd, self.q_lora_rank, bias=False) 123 | self.q_norm = nn.RMSNorm(self.q_lora_rank, eps=self.rms_norm_eps) 124 | self.decompress_q_nope = nn.Linear(self.q_lora_rank, self.nope_dim, bias=False) 125 | self.decompress_q_rope = nn.Linear(self.q_lora_rank, self.rope_dim, bias=False) 126 | 127 | self.compress_kv_linear = nn.Linear(self.n_embd, self.kv_lora_rank, bias=False) 128 | self.kv_norm = nn.RMSNorm(self.kv_lora_rank, eps=self.rms_norm_eps) 129 | self.decompress_k_nope = nn.Linear(self.kv_lora_rank, self.nope_dim, bias=False) 130 | self.decompress_v_linear = nn.Linear(self.kv_lora_rank, self.value_dim, bias=False) 131 | self.k_rope_linear = nn.Linear(self.n_embd, self.rope_head_dim, bias=False) 132 | 133 | # === Branch 2: Token Selection 
Branch (NSA) === 134 | # Components for importance-based token selection 135 | self.importance_scorer = nn.Linear(self.n_embd, 1,bias=False) 136 | # Independent KV for selected tokens 137 | self.selection_k = nn.Linear(self.n_embd, self.n_head * (self.rope_head_dim + self.nope_head_dim), bias=False) 138 | self.selection_v = nn.Linear(self.n_embd, self.value_dim, bias=False) 139 | 140 | # === Branch 3: Sliding Window Branch (NSA) === 141 | # Independent KV for sliding window 142 | self.window_k = nn.Linear(self.n_embd, self.n_head * (self.rope_head_dim + self.nope_head_dim), bias=False) 143 | self.window_v = nn.Linear(self.n_embd, self.value_dim, bias=False) 144 | 145 | # Token Compression Mechanism (NSA) 146 | self.block_compressor = nn.Sequential( 147 | nn.Linear(self.block_size * self.n_embd, 4 * self.n_embd,bias=False), 148 | nn.GELU(), 149 | nn.Linear(4 * self.n_embd, self.n_embd,bias=False) 150 | ) 151 | 152 | # Intra-block position encoding 153 | self.intra_block_pos_encoding = nn.Parameter( 154 | torch.randn(1, self.block_size, self.n_embd) 155 | ) 156 | 157 | # Gated Multi-Branch Integration (NSA) 158 | self.branch_gate = nn.Linear(self.n_embd, 3,bias=False) # 3 gates for 3 branches 159 | 160 | # Output projection 161 | self.proj = nn.Linear(self.value_dim, self.n_embd, bias=False) 162 | self.res_dropout = nn.Dropout(p=self.dropout) 163 | 164 | # Caching for inference 165 | self.k_cache = None 166 | self.v_cache = None 167 | self.cache_filled = 0 168 | 169 | # RoPE 170 | self.rope = RoPE(self.rope_head_dim, device=self.device) 171 | self.freqs_cis = precompute_freqs_cis(self.rope_head_dim, self.ctx_len, self.device) 172 | 173 | def _compress_tokens(self, x): 174 | """Token compression mechanism from NSA""" 175 | B, T, C = x.size() 176 | 177 | # Ensure T is divisible by block_size for simplicity 178 | padded_len = ((T + self.block_size - 1) // self.block_size) * self.block_size 179 | if padded_len > T: 180 | padding = torch.zeros(B, padded_len - T, C, device=x.device, dtype=x.dtype) 181 | x_padded = torch.cat([x, padding], dim=1) 182 | else: 183 | x_padded = x 184 | 185 | # Add intra-block position encoding 186 | blocks = x_padded.view(B, -1, self.block_size, C) 187 | pos_encoded_blocks = blocks + self.intra_block_pos_encoding 188 | 189 | # Reshape for compression 190 | blocks_flat = pos_encoded_blocks.view(B, -1, self.block_size * C) 191 | 192 | # Apply block compression 193 | compressed_blocks = self.block_compressor(blocks_flat) 194 | 195 | return compressed_blocks 196 | 197 | def _select_important_tokens(self, x, importance_scores): 198 | """Select the most important tokens based on scores""" 199 | B, T, _ = x.size() 200 | 201 | # Get indices of top-k tokens by importance 202 | _, indices = torch.topk(importance_scores.squeeze(-1), 203 | min(self.num_tokens_to_keep, T), 204 | dim=1) 205 | 206 | # Sort indices to maintain sequence order (continuity-aware) 207 | indices, _ = torch.sort(indices, dim=1) 208 | 209 | # Gather selected tokens 210 | batch_indices = torch.arange(B, device=x.device).unsqueeze(1).expand(-1, indices.size(1)) 211 | selected_tokens = x[batch_indices, indices] 212 | 213 | return selected_tokens, indices 214 | 215 | def _get_sliding_window_tokens(self, x, current_pos=None): 216 | """Extract tokens within the sliding window""" 217 | if self.training or current_pos is None: 218 | # During training, we can use the whole sequence with windowed attention 219 | return x 220 | else: 221 | # During inference, get a window centered around the current position 222 | B, 
T, _ = x.size() 223 | window_start = max(0, current_pos - self.window_size // 2) 224 | window_end = min(T, window_start + self.window_size) 225 | return x[:, window_start:window_end] 226 | 227 | def forward(self, x): 228 | B, T, C = x.size() 229 | 230 | # === Prepare queries using MLA's approach === 231 | compressed_q = self.compress_q_linear(x) 232 | norm_q = self.q_norm(compressed_q) 233 | query_nope = self.decompress_q_nope(norm_q) 234 | query_rope = self.decompress_q_rope(norm_q) 235 | 236 | # Reshape and transpose queries 237 | query_nope = query_nope.view(B, T, self.n_head, self.nope_head_dim).transpose(1, 2) 238 | query_rope = query_rope.view(B, T, self.n_head, self.rope_head_dim).transpose(1, 2) 239 | 240 | # Apply RoPE to query 241 | q_rope, _ = apply_rope(query_rope, query_rope, self.freqs_cis) # Corrected 242 | 243 | # Recombine query parts 244 | q_recombined = torch.empty((B, self.n_head, T, self.rope_head_dim + self.nope_head_dim), 245 | device=x.device, dtype=x.dtype) 246 | q_recombined[:, :, :, :self.nope_head_dim] = query_nope 247 | q_recombined[:, :, :, self.nope_head_dim:] = q_rope 248 | 249 | # Compute branch gates for dynamic weighting 250 | branch_weights = F.softmax(self.branch_gate(x).mean(dim=1), dim=-1) # [B, 3] 251 | 252 | # === Branch 1: Coarse-grained compression branch (from MLA) === 253 | compressed_kv = self.compress_kv_linear(x) 254 | norm_kv = self.kv_norm(compressed_kv) 255 | key_nope_1 = self.decompress_k_nope(norm_kv) 256 | value_1 = self.decompress_v_linear(norm_kv) 257 | key_rope_1 = self.k_rope_linear(x) 258 | 259 | # Reshape keys and values 260 | key_nope_1 = key_nope_1.view(B, T, self.n_head, self.nope_head_dim).transpose(1, 2) 261 | key_rope_1 = key_rope_1.view(B, T, 1, self.rope_head_dim).transpose(1, 2) 262 | value_1 = value_1.view(B, T, self.n_head, self.v_head_dim).transpose(1, 2) 263 | 264 | # Apply RoPE to keys 265 | key_rope_1 = key_rope_1 / self.n_head # Scale like in original code 266 | _, k_rope_1 = apply_rope(key_rope_1, key_rope_1, self.freqs_cis) # Corrected 267 | 268 | # Recombine key parts for branch 1 269 | k_recombined_1 = torch.empty((B, self.n_head, T, self.rope_head_dim + self.nope_head_dim), 270 | device=x.device, dtype=x.dtype) 271 | k_recombined_1[:, :, :, :self.nope_head_dim] = key_nope_1 272 | k_recombined_1[:, :, :, self.nope_head_dim:] = k_rope_1 273 | 274 | # === Branch 2: Token Selection Branch (NSA) === 275 | # Compute importance scores 276 | importance_scores = self.importance_scorer(x) 277 | 278 | # Select important tokens 279 | selected_tokens, selected_indices = self._select_important_tokens(x, importance_scores) 280 | 281 | # Get KV for selected tokens 282 | B, S, _ = selected_tokens.size() # S is the number of selected tokens 283 | k_selected = self.selection_k(selected_tokens) 284 | v_selected = self.selection_v(selected_tokens) 285 | 286 | # Reshape 287 | k_selected = k_selected.view(B, S, self.n_head, self.rope_head_dim + self.nope_head_dim).transpose(1, 2) 288 | v_selected = v_selected.view(B, S, self.n_head, self.v_head_dim).transpose(1, 2) 289 | 290 | # Apply RoPE (only to the RoPE portion) 291 | k_selected_rope = k_selected[:, :, :, self.nope_head_dim:] 292 | k_selected_nope = k_selected[:, :, :, :self.nope_head_dim] 293 | # Corrected: pass k_selected_rope for both x and y 294 | _, k_selected_rope = apply_rope(k_selected_rope, k_selected_rope, self.freqs_cis) 295 | 296 | 297 | # Recombine 298 | k_selected[:, :, :, self.nope_head_dim:] = k_selected_rope 299 | k_selected[:, :, :, :self.nope_head_dim] = 
k_selected_nope # make sure we add the nope back! 300 | 301 | # === Branch 3: Sliding Window Branch (NSA) === 302 | window_tokens = self._get_sliding_window_tokens(x) 303 | B, W, _ = window_tokens.size() # W is window size 304 | 305 | k_window = self.window_k(window_tokens) 306 | v_window = self.window_v(window_tokens) 307 | 308 | # Reshape 309 | k_window = k_window.view(B, W, self.n_head, self.rope_head_dim + self.nope_head_dim).transpose(1, 2) 310 | v_window = v_window.view(B, W, self.n_head, self.v_head_dim).transpose(1, 2) 311 | 312 | # Apply RoPE (only to the RoPE portion) 313 | k_window_rope = k_window[:, :, :, self.nope_head_dim:] 314 | k_window_nope = k_window[:, :, :, :self.nope_head_dim] 315 | # Corrected: pass k_window_rope for both x and y 316 | _, k_window_rope = apply_rope(k_window_rope, k_window_rope, self.freqs_cis) 317 | 318 | 319 | # Recombine 320 | k_window[:, :, :, self.nope_head_dim:] = k_window_rope 321 | k_window[:, :, :, :self.nope_head_dim] = k_window_nope 322 | 323 | # === Compute attention for each branch and blend results === 324 | if self.training: 325 | self.cache_filled = 0 326 | 327 | # Branch 1: Original MLA attention with full sequence 328 | output_1 = F.scaled_dot_product_attention( 329 | q_recombined, k_recombined_1, value_1, 330 | is_causal=True, dropout_p=self.dropout 331 | ) 332 | 333 | # Branch 2: Attention with selected tokens 334 | # For selected tokens, we need to compute attention differently 335 | # as they might not be in sequence order 336 | output_2 = F.scaled_dot_product_attention( 337 | q_recombined, k_selected, v_selected, 338 | is_causal=False, dropout_p=self.dropout # Non-causal for selected tokens 339 | ) 340 | 341 | # Branch 3: Sliding window attention 342 | output_3 = F.scaled_dot_product_attention( 343 | q_recombined, k_window, v_window, 344 | is_causal=True, dropout_p=self.dropout 345 | ) 346 | 347 | # Blend outputs using branch weights 348 | blended_output = ( 349 | output_1 * branch_weights[:, 0].view(B, 1, 1, 1) + 350 | output_2 * branch_weights[:, 1].view(B, 1, 1, 1) + 351 | output_3 * branch_weights[:, 2].view(B, 1, 1, 1) 352 | ) 353 | 354 | else: 355 | # Inference mode with KV caching 356 | if self.k_cache is None or self.v_cache is None or self.k_cache.size(0) != B: 357 | self.k_cache = torch.zeros( 358 | B, self.n_head, self.ctx_len, self.rope_head_dim + self.nope_head_dim, 359 | device=self.device, dtype=x.dtype 360 | ) 361 | self.v_cache = torch.zeros( 362 | B, self.n_head, self.ctx_len, self.v_head_dim, 363 | device=self.device, dtype=x.dtype 364 | ) 365 | self.cache_filled = 0 366 | 367 | # Update cache with new tokens 368 | new_cache_filled = min(self.cache_filled + T, self.ctx_len) 369 | 370 | # Branch 1: Update cache 371 | k_to_cache = k_recombined_1[:, :, :new_cache_filled - self.cache_filled] 372 | v_to_cache = value_1[:, :, :new_cache_filled - self.cache_filled] 373 | 374 | self.k_cache[:, :, self.cache_filled:new_cache_filled] = k_to_cache 375 | self.v_cache[:, :, self.cache_filled:new_cache_filled] = v_to_cache 376 | self.cache_filled = new_cache_filled 377 | 378 | # Get cached KVs 379 | k1 = self.k_cache[:, :, :self.cache_filled] 380 | v1 = self.v_cache[:, :, :self.cache_filled] 381 | 382 | # Branch 1: Attention with cached KVs 383 | output_1 = F.scaled_dot_product_attention( 384 | q_recombined, k1, v1, is_causal=True, dropout_p=0 385 | ) 386 | 387 | # Branch 2: Attention with selected tokens (from current sequence) 388 | output_2 = F.scaled_dot_product_attention( 389 | q_recombined, k_selected, v_selected, 
is_causal=False, dropout_p=0 390 | ) 391 | 392 | # Branch 3: Sliding window attention 393 | current_pos = self.cache_filled - 1 # Current position for window centering 394 | output_3 = F.scaled_dot_product_attention( 395 | q_recombined, k_window, v_window, is_causal=True, dropout_p=0 396 | ) 397 | 398 | # Blend outputs using branch weights 399 | blended_output = ( 400 | output_1 * branch_weights[:, 0].view(B, 1, 1, 1) + 401 | output_2 * branch_weights[:, 1].view(B, 1, 1, 1) + 402 | output_3 * branch_weights[:, 2].view(B, 1, 1, 1) 403 | ) 404 | 405 | # Final processing 406 | output = blended_output.transpose(1, 2).contiguous().view(B, T, self.value_dim) 407 | output = self.proj(output) 408 | output = self.res_dropout(output) 409 | 410 | return output 411 | 412 | # Reg MLP 413 | 414 | class MLP(nn.Module): 415 | def __init__(self): 416 | super().__init__() 417 | n_embd = config['n_embd'] 418 | self.c_fc = nn.Linear(n_embd, 4 * n_embd,bias=False) 419 | self.c_proj = nn.Linear(4 * n_embd, n_embd,bias=False) 420 | self.dropout = nn.Dropout(config['dropout']) 421 | 422 | def forward(self, x): 423 | x = self.c_fc(x) 424 | x = F.relu(x).square() # relu sq, not gelu 425 | x = self.c_proj(x) 426 | x = self.dropout(x) 427 | return x 428 | 429 | # DS-MoE Layer 430 | 431 | class UnitCenteredNoise(nn.Module): 432 | def __init__(self, scaling=0.02): 433 | super(UnitCenteredNoise, self).__init__() 434 | self.scaling = scaling 435 | self.base = 1 - (scaling * 0.5) 436 | 437 | def forward(self, x): 438 | if self.training: 439 | noise = torch.rand(x.size(), device=x.device, dtype=x.dtype) 440 | noise_centered = (noise * self.scaling) + self.base 441 | return x * noise_centered 442 | else: 443 | return x 444 | 445 | class DSMoE(nn.Module): 446 | 447 | def __init__(self, index, num_exp=4): 448 | super().__init__() 449 | self.hidden_dim = config['n_embd'] * 2 # was 4, had to shrink by 1/2 450 | self.num_experts = config["n_experts"] 451 | self.num_exp = num_exp 452 | self.moe_scaling = config["init_moe_scaling"] 453 | self.experts = nn.ModuleList([MLP() for _ in range(self.num_experts)]) 454 | self.gate = nn.Sequential( 455 | nn.Linear(config['n_embd'], self.num_experts - 1,bias=False), # exclude shared expert 456 | UnitCenteredNoise(scaling=0.02), 457 | nn.Softmax(dim=-1) 458 | ) 459 | # Initialize expert bias (excluding the shared expert) 460 | self.expert_bias = nn.Parameter(torch.zeros(self.num_experts - 1), requires_grad=False) 461 | 462 | 463 | def forward(self, x): 464 | b, t, c = x.shape 465 | x_flat = x.reshape(b * t, c) 466 | 467 | gate_val_continuous = self.gate(x_flat) 468 | 469 | # Apply expert bias *before* topk 470 | biased_gate_vals = gate_val_continuous + self.expert_bias 471 | 472 | # get top-(num_exp-1) experts excluding the first one 473 | gate_vals, gate_val_indices = torch.topk(biased_gate_vals, self.num_exp - 1, dim=-1) 474 | gate_vals = gate_vals / gate_vals.sum(dim=-1, keepdim=True) # normalize 475 | 476 | # prepend the shared expert (index 0) - Corrected handling 477 | shared_expert_weight = torch.ones_like(gate_vals[:, :1]) / self.num_exp 478 | gate_vals = torch.cat([shared_expert_weight, gate_vals * (self.num_exp - 1) / self.num_exp], dim=-1) 479 | gate_val_indices = torch.cat([torch.zeros_like(gate_val_indices[:, :1]), gate_val_indices + 1], dim=-1) 480 | 481 | # process all experts once (fully static) 482 | expert_outputs = torch.stack([expert(x_flat) for expert in self.experts], dim=0) # [num_experts, b*t, c] 483 | 484 | # create routing weights matrix (one-hot * gate values) 485 
| router_weights = torch.zeros(x_flat.size(0), self.num_experts, device=x.device) 486 | for i in range(self.num_exp): 487 | idx = gate_val_indices[:, i:i+1] # [b*t, 1] 488 | val = gate_vals[:, i:i+1] # [b*t, 1] 489 | router_weights.scatter_add_(1, idx, val) 490 | 491 | # apply routing weights to expert outputs 492 | weighted_outputs = expert_outputs * router_weights.transpose(0, 1).unsqueeze(-1) # [num_experts, b*t, c] 493 | output = weighted_outputs.sum(dim=0) # [b*t, c] 494 | 495 | # Return both the output and the router_weights 496 | return output.reshape(b, t, c), router_weights 497 | 498 | class Block(nn.Module): 499 | def __init__(self, index): 500 | super().__init__() 501 | n_embd = config['n_embd'] 502 | self.attn = Attn() 503 | self.ffn_type = config['type'][index] 504 | 505 | if self.ffn_type == "mlp": 506 | self.ffn = MLP() 507 | elif self.ffn_type == "moe": 508 | self.ffn = DSMoE(index) 509 | else: 510 | raise ValueError(f"Invalid layer type: {self.ffn_type}") 511 | 512 | self.rm1 = nn.RMSNorm(n_embd) 513 | self.rm2 = nn.RMSNorm(n_embd) 514 | 515 | def forward(self, x): 516 | 517 | x = x + self.attn(self.rm1(x)) 518 | 519 | if self.ffn_type == "moe": 520 | x_ffn, router_weights = self.ffn(self.rm2(x)) 521 | return x + x_ffn, router_weights 522 | 523 | else: 524 | x_ffn = self.ffn(self.rm2(x)) 525 | return x + x_ffn, None # no MoE, no route weights 526 | 527 | class Transformer(nn.Module): 528 | def __init__(self): 529 | super().__init__() 530 | self.config = config 531 | self.token_embedding_table = nn.Embedding(config['vocab_size'], config['n_embd']) 532 | self.position_embedding_table = nn.Embedding(config['ctx_len'], config['n_embd']) 533 | self.blocks = nn.Sequential(*[Block(i) for i in range(config['n_layer'])]) 534 | self.rm_f = nn.RMSNorm(config['n_embd']) 535 | self.lm_head = nn.Linear(config['n_embd'], config['vocab_size'],bias=False) 536 | self.token_embedding_table.weight = self.lm_head.weight 537 | self.apply(self._init_weights) 538 | self.total_params = sum(p.numel() for p in self.parameters()) 539 | 540 | def _init_weights(self, module): 541 | if isinstance(module, nn.Linear): 542 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 543 | if module.bias is not None: 544 | torch.nn.init.zeros_(module.bias) 545 | elif isinstance(module, nn.Embedding): 546 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 547 | 548 | def forward(self, idx, targets=None): 549 | B, T = idx.shape 550 | tok_emb = self.token_embedding_table(idx).clone() 551 | pos_emb = self.position_embedding_table(torch.arange(T, device=self.config['device'])) 552 | x = tok_emb + pos_emb 553 | 554 | all_router_weights = [] # Collect router_weights across MoEs 555 | 556 | for block in self.blocks: 557 | x, router_weights = block(x) # Get router_weights from Block 558 | if router_weights is not None: 559 | all_router_weights.append(router_weights) 560 | 561 | x = self.rm_f(x) 562 | logits = self.lm_head(x) 563 | 564 | if targets is None: 565 | loss = None 566 | else: 567 | B, T, C = logits.shape 568 | logits = logits.view(B * T, C) 569 | targets = targets.view(B * T) 570 | loss = F.cross_entropy(logits, targets) 571 | 572 | return logits, loss, all_router_weights 573 | 574 | @torch.no_grad() 575 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, tiktoken_vocab_size=None): 576 | """ 577 | Generates sequences of tokens autoregressively. 578 | 579 | Args: 580 | idx (torch.LongTensor): Input sequence indices (shape: B, T). 
581 | max_new_tokens (int): Maximum number of new tokens to generate. 582 | temperature (float): Sampling temperature. Lower values make the distribution 583 | sharper (less random), higher values make it flatter (more random). 584 | Must be positive. Defaults to 1.0. 585 | top_k (int, optional): If set, only the top_k most probable tokens are considered 586 | for sampling at each step. Set to None or 0 to disable. 587 | Defaults to None. 588 | tiktoken_vocab_size (int, optional): The vocabulary size of the tokenizer. 589 | If provided and smaller than the model's internal 590 | vocab_size (config['vocab_size']), tokens with 591 | indices >= tiktoken_vocab_size will be masked out 592 | during sampling to prevent generating padding tokens. 593 | Defaults to None. 594 | 595 | Returns: 596 | Tuple[torch.LongTensor, float]: 597 | - idx: The generated sequence including the initial prompt (shape: B, T + max_new_tokens). 598 | - total_kv_cache_size_gb: Estimated size of the KV cache in GB after generation. 599 | """ 600 | # Ensure temperature is positive 601 | if temperature <= 0: 602 | # Using temperature=0 often implies greedy sampling (always pick the max logit). 603 | # You could implement that explicitly or just use a very small positive value. 604 | # For simplicity here, we'll just use a very small value to avoid division by zero 605 | # and maintain the sampling structure. Or raise an error. 606 | # raise ValueError("Temperature must be positive.") 607 | print("Warning: Temperature <= 0. Using a very small value (1e-6) instead.") 608 | temperature = 1e-6 609 | 610 | # Determine if vocabulary masking is needed 611 | model_vocab_size = config['vocab_size'] 612 | use_vocab_mask = False 613 | effective_vocab_size = model_vocab_size 614 | if tiktoken_vocab_size is not None: 615 | if tiktoken_vocab_size < model_vocab_size: 616 | print(f"generate(): Masking logits for indices >= {tiktoken_vocab_size} (model vocab size: {model_vocab_size})") 617 | use_vocab_mask = True 618 | effective_vocab_size = tiktoken_vocab_size # For top_k adjustment if needed 619 | elif tiktoken_vocab_size > model_vocab_size: 620 | print(f"generate(): Warning - tiktoken_vocab_size ({tiktoken_vocab_size}) > model_vocab_size ({model_vocab_size}). 
Masking ineffective.") 621 | # else: sizes match, no masking needed 622 | 623 | 624 | for _ in range(max_new_tokens): 625 | # Crop the context if it exceeds the maximum length 626 | # Use max() to handle initial prompts shorter than ctx_len 627 | start_pos = max(0, idx.size(1) - config['ctx_len']) 628 | idx_cond = idx[:, start_pos:] # shape (B, min(T, ctx_len)) 629 | 630 | # Forward pass to get logits for the next token 631 | # Assuming your model's forward returns (logits, loss, optional_other_data) 632 | # Adjust this based on your actual forward method's return signature 633 | logits, _, _ = self(idx_cond) # We only need logits here 634 | 635 | # Get the logits for the very last token position 636 | logits = logits[:, -1, :] # shape (B, model_vocab_size) 637 | 638 | # Apply temperature scaling 639 | logits = logits / temperature 640 | 641 | # --- Apply Vocabulary Masking (before top-k and softmax) --- 642 | if use_vocab_mask: 643 | logits[:, tiktoken_vocab_size:] = -float('Inf') 644 | # ----------------------------------------------------------- 645 | 646 | # --- Apply Top-k Filtering (before softmax) --- 647 | if top_k is not None and top_k > 0: 648 | # Determine the actual k to use (cannot exceed the number of available logits) 649 | # After masking, the effective number might be smaller, but topk handles -inf correctly. 650 | k = min(top_k, logits.size(-1)) # Use model_vocab_size as the upper bound 651 | 652 | # Get the top k values and indices for each batch element 653 | # We only need the values to find the threshold 654 | top_k_values, _ = torch.topk(logits, k=k, dim=-1) # shape (B, k) 655 | 656 | # Find the value of the k-th largest logit (the minimum value in the top-k set) 657 | kth_logit_value = top_k_values[:, [-1]] # shape (B, 1) 658 | 659 | # Create a mask for logits less than the k-th largest logit 660 | # Set logits below the threshold to negative infinity 661 | logits[logits < kth_logit_value] = -float('Inf') 662 | # ------------------------------------------------- 663 | 664 | # Convert logits to probabilities using softmax 665 | probs = F.softmax(logits, dim=-1) # shape (B, model_vocab_size) 666 | 667 | # Sample the next token index from the probability distribution 668 | idx_next = torch.multinomial(probs, num_samples=1) # shape (B, 1) 669 | 670 | # Append the newly sampled token index to the sequence 671 | idx = torch.cat((idx, idx_next), dim=1) # shape (B, T+1) 672 | 673 | # --- Calculate KV Cache Size (after generation loop) --- 674 | total_size_gb = 0 675 | # Ensure self.blocks exists and contains your transformer blocks 676 | if hasattr(self, 'blocks') and self.blocks is not None: 677 | for block in self.blocks: 678 | # Check if attention layer and its caches exist 679 | if hasattr(block, 'attn') and hasattr(block.attn, 'k_cache') and block.attn.k_cache is not None: 680 | # k_cache size 681 | size_bytes = block.attn.k_cache.numel() * block.attn.k_cache.element_size() 682 | total_size_gb += size_bytes / (1024**3) 683 | if hasattr(block, 'attn') and hasattr(block.attn, 'v_cache') and block.attn.v_cache is not None: 684 | # v_cache size 685 | size_bytes = block.attn.v_cache.numel() * block.attn.v_cache.element_size() 686 | total_size_gb += size_bytes / (1024**3) 687 | else: 688 | print("Warning: Cannot calculate KV cache size. 
`self.blocks` not found or is None.") 689 | 690 | return idx, total_size_gb 691 | 692 | def configure_optimizers(self, weight_decay, learning_rate, device): 693 | """ 694 | Configures optimizers to use Muon for >=2D parameters WITHIN `self.blocks` 695 | (excluding those known not to receive gradients or with requires_grad=False) 696 | and AdamW for all other parameters. 697 | """ 698 | muon_params = [] 699 | adamw_params = [] 700 | 701 | #print("--- Refining Parameter Assignment (configure_optimizers) ---") 702 | 703 | # List patterns within 'blocks' known not to receive gradients or that shouldn't be optimized by Muon 704 | # Note: '.weight'/'.bias' suffixes are often needed for precise matching. 705 | muon_exclude_patterns = [ 706 | 'attn.intra_block_pos_encoding', # Unused or detached 707 | 'attn.importance_scorer.weight', # Used with non-differentiable topk 708 | 'attn.importance_scorer.bias', # Used with non-differentiable topk 709 | 'attn.block_compressor', # Unused or detached 710 | # 'ffn.expert_bias', # This is already handled by the requires_grad check below 711 | ] 712 | 713 | for name, param in self.named_parameters(): 714 | # 1. Only consider parameters that require gradients 715 | if not param.requires_grad: 716 | #print(f"Skipping (requires_grad=False): {name}") 717 | continue # Skip parameters like expert_bias 718 | 719 | is_excluded = False 720 | # 2. Check if the parameter name contains any of the explicit exclusion patterns 721 | for pattern in muon_exclude_patterns: 722 | if pattern in name: 723 | is_excluded = True 724 | #print(f"Excluding from Muon (known non-grad pattern): {name}") 725 | break # Stop checking patterns once excluded 726 | 727 | #print(f"Processing: {name}, Dim: {param.ndim}, Requires Grad: {param.requires_grad}, Excluded: {is_excluded}") 728 | 729 | # 3. Assign to Muon if: in blocks, >= 2D, AND not explicitly excluded 730 | if 'blocks' in name and param.ndim >= 2 and not is_excluded: 731 | #print(f" -> Assigning to Muon: {name}") 732 | muon_params.append(param) 733 | else: 734 | # Assign to AdamW if: not in blocks, or < 2D, or explicitly excluded 735 | #print(f" -> Assigning to AdamW: {name}") 736 | adamw_params.append(param) 737 | 738 | 739 | #print("--- Final Parameter Group Counts ---") 740 | num_muon_params = sum(p.numel() for p in muon_params) 741 | num_adamw_params = sum(p.numel() for p in adamw_params) 742 | print(f"num Muon parameters: {num_muon_params:,}") 743 | print(f"num AdamW parameters: {num_adamw_params:,}") 744 | 745 | # Defensive check: Ensure Muon doesn't get an empty list 746 | if not muon_params: 747 | print("\n\n*** WARNING: Muon parameter list is EMPTY after filtering! 
***") 748 | print("This might be due to incorrect exclusion patterns or model structure.") 749 | print("Proceeding with only the AdamW optimizer.") 750 | # Return only AdamW optimizer in a list for consistent return type 751 | optimizers = [ 752 | torch.optim.AdamW(adamw_params, lr=learning_rate, betas=(0.90, 0.95), weight_decay=weight_decay) 753 | ] 754 | else: 755 | optimizers = [ 756 | Muon(muon_params, lr=0.02, momentum=0.95, rank=0, world_size=1), 757 | torch.optim.AdamW(adamw_params, lr=learning_rate, betas=(0.90, 0.95), weight_decay=weight_decay) 758 | ] 759 | 760 | return optimizers 761 | 762 | def update_expert_biases(self, all_router_weights, update_rate): 763 | 764 | with torch.no_grad(): 765 | # Iterate through the blocks and find MoE layers 766 | 767 | j = 0 768 | 769 | for block in self.blocks: 770 | if isinstance(block.ffn, DSMoE): 771 | 772 | router_weights = all_router_weights[j] 773 | j += 1 774 | 775 | c_i = router_weights[:, 1:].sum(dim=0) # Exclude shared expert, calculate expert load 776 | total_routed_tokens = c_i.sum() 777 | c_i_bar = total_routed_tokens / (block.ffn.num_experts - 1) # avg load 778 | e_i = c_i - c_i_bar # Load violation error 779 | 780 | block.ffn.expert_bias.add_(update_rate * torch.sign(e_i)) # update step 781 | 782 | def estimate_mfu(self, params, fwdbwd_per_iter, dt): 783 | N = params 784 | L, H, Q, T = config['n_layer'], config['n_head'], config['n_embd']//config['n_head'], config['ctx_len'] 785 | flops_per_token = 6*N + 12*L*H*Q*T # fix recalc for MoE 786 | flops_per_fwdbwd = flops_per_token * T 787 | flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter 788 | flops_achieved = flops_per_iter * (1.0/dt) # per second 789 | flops_promised = 65e12 # 65 tflops for a t4 790 | mfu = flops_achieved / flops_promised 791 | return mfu 792 | -------------------------------------------------------------------------------- /src/plot.py: -------------------------------------------------------------------------------- 1 | # wandb-like plots 2 | 3 | import numpy as np 4 | 5 | import matplotlib 6 | from matplotlib import pyplot as plt 7 | 8 | # Wandb at home styling 9 | 10 | plt.style.use('default') 11 | matplotlib.rcParams['font.family'] = 'sans-serif' 12 | matplotlib.rcParams['font.sans-serif'] = ['Arial', 'Helvetica', 'DejaVu Sans', 'Liberation Sans', 'Bitstream Vera Sans', 'sans-serif'] 13 | matplotlib.rcParams['axes.spines.top'] = False 14 | matplotlib.rcParams['axes.spines.right'] = False 15 | matplotlib.rcParams['axes.facecolor'] = '#f0f0f0' 16 | matplotlib.rcParams['figure.facecolor'] = '#f0f0f0' 17 | matplotlib.rcParams['grid.alpha'] = 0.4 18 | matplotlib.rcParams['axes.titlesize'] = 12 19 | matplotlib.rcParams['axes.labelsize'] = 12 20 | matplotlib.rcParams['xtick.labelsize'] = 10 21 | matplotlib.rcParams['ytick.labelsize'] = 10 22 | matplotlib.rcParams['legend.fontsize'] = 10 23 | matplotlib.rcParams['axes.titlecolor'] = 'grey' 24 | matplotlib.rcParams['axes.labelcolor'] = 'grey' 25 | matplotlib.rcParams['xtick.color'] = 'grey' 26 | matplotlib.rcParams['ytick.color'] = 'grey' 27 | matplotlib.rcParams['legend.labelcolor'] = 'grey' 28 | 29 | def plot_loss(train_hist, val_hist, i_eval, iter, run_name): 30 | 31 | plt.figure(figsize=(8, 4), dpi=100) 32 | ax = plt.gca() 33 | 34 | iter_train = range(len(train_hist)) 35 | iter_eval = range(0, len(val_hist) * i_eval, i_eval) 36 | 37 | iter_eval = [x for x in iter_eval if x <= iter] 38 | iter_eval = iter_eval[:len(val_hist)] # Ensures len 39 | 40 | l_train = plt.plot(iter_train, train_hist, label='Train', 
color='royalblue', linestyle='-', linewidth=2, marker='', alpha=0.7) 41 | val_line = plt.plot(iter_eval, val_hist, label='Val', color='palevioletred', linestyle='-', linewidth=2, marker='', alpha=0.7) 42 | 43 | plt.plot(iter_train[-1:], train_hist[-1:], marker='o', markersize=3, markerfacecolor='royalblue', markeredgecolor='none', linestyle='none') 44 | 45 | if val_hist: # Check if val_hist is non-empty 46 | plt.plot(iter_eval[-1:], val_hist[-1:], marker='o', markersize=3, markerfacecolor='palevioletred', markeredgecolor='none', linestyle='none') 47 | 48 | plt.xlabel("Steps", labelpad=8, color='grey') 49 | plt.ylabel("Loss", labelpad=8, color='grey') 50 | plt.title(f"Train/Val Loss", fontsize=12, color='grey') 51 | 52 | legend = plt.legend(frameon=False, loc='upper right') 53 | 54 | for line in legend.get_lines(): 55 | line.set_linewidth(2.5) 56 | line.set_solid_capstyle('round') 57 | 58 | ax.tick_params(axis='both', which='major', pad=8) 59 | 60 | ax.yaxis.set_ticks_position('left') 61 | ax.xaxis.set_ticks_position('bottom') 62 | 63 | y_ticks = [tick for tick in ax.get_yticks() if tick >= min(min(train_hist), min(val_hist)) and tick <= max(max(train_hist), max(val_hist))] 64 | 65 | ax.set_yticks(y_ticks[::2]) 66 | ax.set_yticks(np.arange(min(y_ticks), max(y_ticks) + 0.4, 0.4), minor=False) 67 | 68 | plt.grid(axis='y', color='grey', linestyle='-', linewidth=0.5) 69 | plt.savefig(f"plots/{run_name}_plot.png", bbox_inches='tight', dpi=300) 70 | plt.clf() 71 | 72 | -------------------------------------------------------------------------------- /src/sample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import pickle 5 | import argparse 6 | from contextlib import nullcontext 7 | import tiktoken # Import tiktoken 8 | 9 | # Assuming 'model.py' defines the Transformer correctly and its generate method 10 | # handles temperature, top_k, and tiktoken_vocab_size as implemented previously. 
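# Hypothetical invocation (paths and sampling values are illustrative only, not taken from
# the repo; the architecture flags must match whatever the checkpoint was trained with, and
# --data_dir must point at a folder containing the meta.pkl written by data.py):
#
#   python src/sample.py --ckpath out/ckpt.pt --data_dir data/finewebedu10b \
#       --prompt "Once upon a time" --max_tok 200 --temp 0.8 --top_k 50 \
#       --n_embd 256 --n_head 16 --n_layer 4 --n_experts 32 \
#       --ctx_len 1024 --types mlp moe mlp moe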
11 | import model 12 | from model import Transformer 13 | 14 | parser = argparse.ArgumentParser(description="Generate text using a Transformer model with Tiktoken.") 15 | 16 | parser.add_argument('--ckpath', type=str, required=True, help='Path to the model checkpoint (.pt file)') 17 | parser.add_argument('--data_dir', type=str, required=True, help='Directory containing meta.pkl (must contain tiktoken info)') 18 | parser.add_argument('--prompt', type=str, default="Hello!", help='Starting prompt for generation') 19 | # Model architecture arguments (should match the loaded checkpoint) 20 | # IMPORTANT: These should match the *padded* configuration if padding was used during training 21 | parser.add_argument('--n_embd', type=int, default=768, help='Embedding dimension (match checkpoint)') 22 | parser.add_argument('--n_head', type=int, default=12, help='Number of attention heads (match checkpoint)') 23 | parser.add_argument('--n_layer', type=int, default=12, help='Number of layers (match checkpoint)') 24 | parser.add_argument('--n_experts', type=int, default=None, help='Number of experts per MoE layer (if used, match checkpoint)') 25 | parser.add_argument('--ctx_len', type=int, default=1024, help='Context length (match checkpoint training or set for generation)') 26 | # Generation arguments 27 | parser.add_argument('--max_tok', type=int, default=100, help='Maximum number of new tokens to generate') 28 | parser.add_argument('--temp', type=float, default=0.1, help='Sampling temperature (e.g., 0.8 for less random, 1.0 for standard, >1.0 for more random)') 29 | parser.add_argument('--top_k', type=int, default=5000, help='Top-k sampling threshold (e.g., 50). Use 0 or None to disable.') 30 | # Optional: Specify model types if needed by your model architecture 31 | parser.add_argument('--types', nargs='*', type=str, default=['mlp'], help='Types of layers used (e.g., mlp moe) - match checkpoint') 32 | # Optional: Add other config args if your model.py needs them (like init_moe_scaling) 33 | parser.add_argument('--init_moe_scaling', type=float, default=1.0, help='Initial MoE scaling factor (match checkpoint if applicable)') # Example 34 | 35 | args = parser.parse_args() 36 | 37 | # --- Configuration and Device Setup --- 38 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 39 | print(f"Using device: {device}") 40 | if device == 'cuda': 41 | # Basic check, specific GPU name might differ 42 | try: 43 | print(f"Using GPU: {torch.cuda.get_device_name(0)}") 44 | except Exception as e: 45 | print(f"Could not get GPU name: {e}") 46 | 47 | 48 | use_compile = hasattr(torch, 'compile') and device=='cuda' 49 | 50 | # --- Load Tokenizer Info from Meta --- 51 | meta_path = os.path.join(args.data_dir, 'meta.pkl') 52 | if not os.path.exists(meta_path): 53 | print(f"Error: meta.pkl not found at {meta_path}") 54 | exit(1) 55 | 56 | print(f"Loading metadata from {meta_path}...") 57 | enc = None # Initialize enc to None 58 | model_vocab_size = None # Vocab size for the model architecture 59 | tiktoken_vocab_size = None # Vocab size reported by tiktoken 60 | 61 | try: 62 | with open(meta_path, 'rb') as f: 63 | meta = pickle.load(f) 64 | if 'vocab_size' not in meta or 'tokenizer' not in meta: 65 | print("Error: meta.pkl must contain 'vocab_size' and 'tokenizer' (tiktoken model name)") 66 | exit(1) 67 | 68 | # This is the vocab size the model was *trained* with (potentially padded) 69 | model_vocab_size = meta['vocab_size'] 70 | tiktoken_model_name = meta['tokenizer'] 71 | print(f" Model vocab size (from meta.pkl): 
{model_vocab_size:,}") 72 | print(f" Tokenizer model name: {tiktoken_model_name}") 73 | 74 | except Exception as e: 75 | print(f"Error loading or parsing meta.pkl: {e}") 76 | exit(1) 77 | 78 | # --- Initialize Tiktoken Encoder --- 79 | try: 80 | enc = tiktoken.get_encoding(tiktoken_model_name) 81 | encode = lambda s: enc.encode(s, allowed_special='all') 82 | decode = lambda l: enc.decode(l) 83 | tiktoken_vocab_size = enc.n_vocab # Get the actual size from tiktoken 84 | print(f"Tiktoken encoder '{tiktoken_model_name}' loaded successfully.") 85 | print(f" Tiktoken vocabulary size: {tiktoken_vocab_size:,}") 86 | 87 | # Check for mismatch and warn if necessary 88 | if model_vocab_size != tiktoken_vocab_size: 89 | print(f"\n!!! Vocab Size Mismatch Detected !!!") 90 | print(f" Model expects vocab size: {model_vocab_size} (from meta.pkl)") 91 | print(f" Tiktoken provides vocab size: {tiktoken_vocab_size}") 92 | if model_vocab_size > tiktoken_vocab_size: 93 | padding = model_vocab_size - tiktoken_vocab_size 94 | print(f" This suggests the model was trained with a vocabulary padded by {padding} tokens.") 95 | print(f" Generation will proceed, ignoring the padded token indices (>= {tiktoken_vocab_size}).") 96 | else: 97 | print(f" Warning: Model vocab size is smaller than Tiktoken's. This is unusual.") 98 | print(f" Proceeding, but the model might not be able to generate all tokens Tiktoken knows.") 99 | else: 100 | print(" Model and Tiktoken vocabulary sizes match.") 101 | 102 | except Exception as e: 103 | print(f"Error initializing tiktoken encoder '{tiktoken_model_name}': {e}") 104 | exit(1) 105 | 106 | # --- Configure and Load Model --- 107 | print("\nConfiguring model...") 108 | # Use the vocabulary size from meta.pkl for model configuration 109 | model.config['device'] = device 110 | model.config['vocab_size'] = model_vocab_size # CRITICAL: Use the potentially padded size 111 | model.config['ctx_len'] = args.ctx_len 112 | model.config['n_embd'] = args.n_embd 113 | model.config['n_head'] = args.n_head 114 | model.config['n_layer'] = args.n_layer 115 | model.config['n_experts'] = args.n_experts 116 | model.config['type'] = args.types 117 | # Add any other relevant config args from command line or defaults 118 | model.config['init_moe_scaling'] = args.init_moe_scaling # Example 119 | # Make sure dropout is set for eval mode (might be handled in model.eval() but good to be explicit if needed) 120 | # model.config['dropout'] = 0.0 # Typically dropout is disabled during eval 121 | 122 | print("Model Configuration:") 123 | # Filter config to show relevant items, avoid printing huge tensors if any 124 | relevant_keys = ['device', 'vocab_size', 'ctx_len', 'n_embd', 'n_head', 'n_layer', 'n_experts', 'type', 'init_moe_scaling', 'dropout'] 125 | for key in relevant_keys: 126 | if key in model.config: 127 | print(f" {key}: {model.config[key]}") 128 | 129 | transformer_model = Transformer() 130 | 131 | print(f"\nLoading model checkpoint from {args.ckpath}...") 132 | try: 133 | # --- MODIFICATION: Use weights_only=False as discussed --- 134 | print("Warning: Loading checkpoint with weights_only=False. Ensure the checkpoint source is trusted.") 135 | checkpoint = torch.load(args.ckpath, map_location=device, weights_only=False) 136 | # --- END MODIFICATION --- 137 | 138 | # --- Add Check: Ensure checkpoint is a dictionary containing 'model' --- 139 | if not isinstance(checkpoint, dict): 140 | print(f"Error: Loaded checkpoint is not a dictionary (found type: {type(checkpoint)}). 
Cannot extract state_dict.") 141 | exit(1) 142 | 143 | if 'model' not in checkpoint: 144 | print("Error: Checkpoint dictionary does not contain the key 'model'.") 145 | print("Available keys:", checkpoint.keys()) 146 | exit(1) 147 | # --- End Check --- 148 | 149 | state_dict = checkpoint['model'] 150 | 151 | # Fix potential state_dict key mismatches (e.g., from DataParallel/DDP or torch.compile) 152 | unwanted_prefix = '_orig_mod.' 153 | for k,v in list(state_dict.items()): 154 | if k.startswith(unwanted_prefix): 155 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 156 | # Add other potential prefix fixes if needed (e.g., 'module.') 157 | unwanted_prefix_ddp = 'module.' 158 | for k,v in list(state_dict.items()): 159 | if k.startswith(unwanted_prefix_ddp): 160 | state_dict[k[len(unwanted_prefix_ddp):]] = state_dict.pop(k) 161 | 162 | 163 | # Check if output layer size in checkpoint matches expected model_vocab_size 164 | output_head_key = 'lm_head.weight' # Adjust if your output layer name is different 165 | if output_head_key in state_dict: 166 | ckpt_vocab_size = state_dict[output_head_key].shape[0] 167 | if ckpt_vocab_size != model_vocab_size: 168 | print("\n!!! CRITICAL WARNING: Checkpoint Vocab Size Mismatch !!!") 169 | print(f" Model configured with vocab_size = {model_vocab_size} (from meta.pkl)") 170 | print(f" Checkpoint's output layer ('{output_head_key}') has size {ckpt_vocab_size}") 171 | print(f" Loading WILL LIKELY FAIL or lead to unexpected behavior.") 172 | print(f" Ensure meta.pkl reflects the *exact* vocab_size used for training this checkpoint.") 173 | # Consider exiting: exit(1) 174 | else: 175 | print(f"\nWarning: Could not find output layer key '{output_head_key}' in checkpoint to verify vocab size.") 176 | 177 | 178 | # Load the state dict using strict=False first for better debugging 179 | missing_keys, unexpected_keys = transformer_model.load_state_dict(state_dict, strict=False) 180 | 181 | # --- FIX: Initialize filtered_missing before the conditional block --- 182 | filtered_missing = [] 183 | # --- END FIX --- 184 | 185 | if missing_keys: 186 | print("\nWarning: Missing keys when loading state_dict:") 187 | # Filter out expected buffer keys if desired 188 | expected_missing = {'attn.bias', 'attn.masked_bias'} # Add other known buffers if needed 189 | # Now this line just updates filtered_missing if there are missing keys 190 | filtered_missing = [k for k in missing_keys if not any(buf in k for buf in expected_missing)] 191 | if filtered_missing: 192 | for key in filtered_missing: print(f" {key}") 193 | else: 194 | print(" (Only expected buffer keys like attn.bias seem missing)") 195 | 196 | if unexpected_keys: 197 | print("\nWarning: Unexpected keys when loading state_dict:") 198 | for key in unexpected_keys: print(f" {key}") 199 | 200 | # This condition will now always work because filtered_missing is guaranteed to exist 201 | if not filtered_missing and not unexpected_keys: 202 | print("Model state_dict loaded successfully (ignoring expected buffers).") 203 | else: 204 | print("Model state_dict loaded with potential issues (see missing/unexpected keys above).") 205 | 206 | 207 | except FileNotFoundError: 208 | print(f"Error: Checkpoint file not found at {args.ckpath}") 209 | exit(1) 210 | except pickle.UnpicklingError as e: 211 | print(f"Error: Failed to unpickle the checkpoint file at {args.ckpath}. 
It might be corrupted or saved incorrectly.") 212 | print(f"Unpickling error details: {e}") 213 | if "argparse.Namespace" in str(e): 214 | print("Hint: This often happens when loading older checkpoints with weights_only=True. Try weights_only=False (as implemented here), ensuring you trust the source.") 215 | exit(1) 216 | except Exception as e: 217 | print(f"Error loading checkpoint: {e}") 218 | print(f"Error type: {type(e).__name__}") 219 | exit(1) 220 | 221 | transformer_model.eval() # Set model to evaluation mode (disables dropout, etc.) 222 | transformer_model.to(device) 223 | 224 | # --- Optional: Compile Model --- 225 | if use_compile: 226 | print("\nCompiling model (takes a minute)...") 227 | try: 228 | transformer_model = torch.compile(transformer_model) 229 | print("Model compiled successfully.") 230 | except Exception as e: 231 | print(f"Model compilation failed: {e}") 232 | use_compile = False # Fallback to non-compiled 233 | 234 | # --- Generation Setup --- 235 | # Encode the starting prompt 236 | start_ids = encode(args.prompt) 237 | x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...] # Add batch dimension [1, T] 238 | 239 | print(f"\nStarting generation with prompt: \"{args.prompt}\" ({len(start_ids)} tokens)") 240 | print(f"Max new tokens: {args.max_tok}, Temperature: {args.temp}, Top-k: {args.top_k}") 241 | 242 | # Set up context manager for mixed precision (if applicable and desired) 243 | # BF16 is generally preferred on newer GPUs (Ampere+) if available 244 | pt_dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16 245 | ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(device_type=device, dtype=pt_dtype) 246 | print(f"Using autocast context with dtype: {pt_dtype if device=='cuda' else 'None (CPU)'}") 247 | # ctx = nullcontext() # Uncomment this line to disable autocast if it causes issues 248 | 249 | # --- Run Generation --- 250 | print("Generating...") 251 | with torch.no_grad(): # Crucial for inference efficiency 252 | with ctx: # Apply mixed precision context 253 | start_time = time.time() 254 | # Pass the actual tiktoken vocab size to the generate function 255 | # Ensure your generate function in model.py accepts these arguments 256 | y, kv_cache_size_gb = transformer_model.generate( 257 | idx=x, 258 | max_new_tokens=args.max_tok, 259 | temperature=args.temp, 260 | top_k=args.top_k, 261 | # Add the argument to handle vocabulary padding: 262 | tiktoken_vocab_size=tiktoken_vocab_size 263 | ) 264 | elapsed = time.time() - start_time 265 | 266 | # --- Decode and Print Output --- 267 | generated_tokens = y[0].tolist() # Get tokens from the first batch element 268 | generated_text = decode(generated_tokens) # Decode should work fine as padding tokens were avoided 269 | 270 | print("\n--- Generated Text ---") 271 | print(generated_text) 272 | print("----------------------") 273 | 274 | # Print generation statistics 275 | num_generated = len(generated_tokens) - len(start_ids) 276 | # Avoid division by zero if generation was instant or failed 277 | tokens_per_sec = num_generated / elapsed if elapsed > 0 else float('inf') 278 | print(f"\nGenerated {num_generated} tokens in {elapsed:.2f} seconds ({tokens_per_sec:.2f} tok/s)") 279 | if kv_cache_size_gb is not None: # If generate returns KV cache size info 280 | print(f"Estimated final KV Cache size: {kv_cache_size_gb:.4f} GB") 281 | -------------------------------------------------------------------------------- /src/train.py: 
-------------------------------------------------------------------------------- 1 | # Training Loop 2 | 3 | import os 4 | import math 5 | import time 6 | import glob 7 | import torch 8 | import string 9 | import random 10 | import pickle 11 | import argparse 12 | #import heavyball # No longer used 13 | import numpy as np 14 | from contextlib import nullcontext 15 | 16 | import torch.amp as amp # For GradScaler 17 | import torch._dynamo 18 | import torch.distributed as dist # <-- Added for distributed init 19 | 20 | import model 21 | from model import Transformer 22 | from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR 23 | 24 | import plot 25 | from plot import plot_loss 26 | 27 | from muon import Muon 28 | 29 | parser = argparse.ArgumentParser() 30 | 31 | parser.add_argument('--batch_size', type=int, default=20) 32 | parser.add_argument('--ctx_len', type=int, default=1024) 33 | parser.add_argument('--eval_interval', type=int, default=20) 34 | parser.add_argument('--grad_accum', type=int, default=4) 35 | 36 | parser.add_argument('--lr', type=float, default=1e-3) 37 | parser.add_argument('--min_lr', type=float, default=1e-4) # <-- Changed type to float 38 | parser.add_argument('--dropout', type=float, default=0.02) 39 | 40 | parser.add_argument('--max_iters', type=int, default=200) 41 | parser.add_argument('--eval_iters', type=int, default=20) 42 | parser.add_argument('--warmup_iters', type=int, default=10) 43 | 44 | parser.add_argument('--resume', type=bool, default=False) 45 | parser.add_argument('--res_path', type=str, default="") 46 | 47 | parser.add_argument('--data_dir', type=str, default="shakespeare") 48 | 49 | parser.add_argument('--n_embd', type=int, default=16) 50 | parser.add_argument('--n_head', type=int, default=2) 51 | parser.add_argument('--n_layer', type=int, default=2) 52 | parser.add_argument('--n_experts', type=int, default=32) 53 | parser.add_argument('--use_expert_bias', type=bool, default=True) 54 | 55 | parser.add_argument('--types', nargs='*', type=str, default= ['mlp','moe','mlp','moe']) 56 | parser.add_argument('--device', type=str, default="cpu") 57 | 58 | args = parser.parse_args() 59 | 60 | # Update the config with parsed arguments 61 | model.config['ctx_len'] = args.ctx_len 62 | model.config['device'] = args.device 63 | model.config['n_embd'] = args.n_embd 64 | model.config['n_head'] = args.n_head 65 | model.config['n_layer'] = args.n_layer 66 | model.config['n_experts'] = args.n_experts 67 | model.config['type'] = args.types 68 | model.config['use_expert_bias'] = args.use_expert_bias 69 | model.config['dropout'] = args.dropout 70 | 71 | # hyperparams 72 | batch_size = args.batch_size 73 | block_size = args.ctx_len # ctx_len 74 | model.config['ctx_len'] = args.ctx_len 75 | eval_interval = args.eval_interval 76 | grad_accum_steps = args.grad_accum # Num microbatches 77 | 78 | lr = 3*args.lr # 3x lr, was in early ngpt spdrns 79 | min_lr = args.min_lr # Now float 80 | 81 | max_iters = args.max_iters 82 | eval_iters = args.eval_iters 83 | warmup_iters = args.warmup_iters 84 | 85 | beta1 = 0.9 # AdamW beta1 86 | beta2 = 0.95 # AdamW beta2 87 | weight_decay = 1e-1 88 | max_grad_norm = 1.0 # Grad clipping 89 | 90 | train_losses_history = [] 91 | val_losses_history = [] 92 | 93 | # continue or scratch 94 | resume = args.resume 95 | data_dir = args.data_dir 96 | resume_checkpoint = args.res_path 97 | 98 | device = args.device 99 | model.config['device'] = args.device 100 | model.config.update(vars(args)) 101 | 102 | # --- Distributed 
Initialization for Muon (TCP method) --- 103 | distributed_initialized = False 104 | print(f"\n--- Attempting Distributed Initialization (TCP Method, Device: {device}) ---") 105 | # Only attempt initialization if CUDA is used and Muon is present in the expected optimizers 106 | # (Muon might strictly require CUDA/distributed backend) 107 | muon_in_use = True # Assume Muon is intended unless configure_optimizers excludes it 108 | if 'cuda' in device: 109 | try: 110 | # Choose backend 111 | backend = 'gloo' 112 | if dist.is_nccl_available(): 113 | backend = 'nccl' 114 | else: 115 | print("WARNING: NCCL backend not available, using 'gloo'. Muon performance might be affected.") 116 | 117 | # Use TCP initialization (simpler for single node) 118 | init_url = f"tcp://localhost:12355" # Use a free port 119 | dist.init_process_group(backend=backend, init_method=init_url, world_size=1, rank=0) 120 | print(f"Successfully called init_process_group (TCP) with backend: {backend}.") 121 | distributed_initialized = True 122 | print(f"Is distributed initialized after setup? {dist.is_initialized()}") 123 | 124 | except Exception as e: 125 | print(f"ERROR: Failed to initialize process group (TCP): {e}") 126 | print("Muon optimizer might fail if distributed initialization is required.") 127 | # Consider exiting if Muon is critical and initialization fails: 128 | # exit(1) 129 | else: 130 | print("INFO: Not initializing distributed process group (device is not CUDA).") 131 | print("--- Finished Distributed Initialization Attempt ---") 132 | 133 | # Mixed precision, dtype, etc. 134 | ctx = nullcontext() if args.device == 'cpu' else torch.amp.autocast(device_type=args.device, dtype=torch.float16) # Keep float16 for GradScaler 135 | scaler = amp.GradScaler(enabled=("cuda" in device)) 136 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # Check for bfloat later if needed 137 | 138 | # torch compile stuff 139 | torch._dynamo.config.cache_size_limit = 64 140 | torch._dynamo.config.verbose = False # Set to True for more dynamo details 141 | # os.environ["TORCH_LOGS"] = "recompiles" 142 | # os.environ["TORCHDYNAMO_VERBOSE"] = "1" 143 | # os.environ["TORCH_LOGS"] = "+dynamo" 144 | 145 | # run name 146 | characters = string.ascii_letters + string.digits 147 | run_name = ''.join(random.choice(characters) for i in range(6)) 148 | 149 | # --- get data func --- 150 | def get_batch(split): 151 | split_filenames = glob.glob(os.path.join("data", f"{data_dir}", f"{split}_*.bin")) 152 | if not split_filenames: 153 | raise FileNotFoundError(f"No {split} shard files found in {data_dir}") 154 | 155 | shard_file = np.random.choice(split_filenames) 156 | try: 157 | # Assuming header is 256 * 4 bytes based on original code 158 | data = np.memmap(shard_file, dtype=np.uint16, mode='r', offset=1024) 159 | except Exception as e: 160 | print(f"Error memory-mapping file {shard_file}: {e}") 161 | # Handle error, maybe retry or skip shard 162 | return get_batch(split) # Simple retry 163 | 164 | num_tokens_in_shard = len(data) 165 | 166 | if num_tokens_in_shard <= block_size + 1: 167 | print(f"Warning: Shard {shard_file} too small ({num_tokens_in_shard} tokens), resampling...") 168 | return get_batch(split) 169 | 170 | ix = torch.randint(num_tokens_in_shard - block_size - 1, (batch_size,)) 171 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 172 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) 173 | 174 | 
174 |     if 'cuda' in device: # was `device == 'cuda'`; also match device strings like 'cuda:0'
175 |         x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
176 |     else:
177 |         x, y = x.to(device), y.to(device)
178 |     return x, y
179 | 
180 | 
181 | # --- getting vocab size ---
182 | meta_path = f'data/{data_dir}/meta.pkl'
183 | meta_vocab_size = None
184 | if os.path.exists(meta_path):
185 |     with open(meta_path, 'rb') as f:
186 |         meta = pickle.load(f)
187 |     meta_vocab_size = meta['vocab_size']
188 |     print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
189 |     model.config["vocab_size"] = meta_vocab_size
190 | else:
191 |     print(f"Warning: meta.pkl not found at {meta_path}. Using default vocab size: {model.config['vocab_size']}")
192 | 
193 | # --- loss check ---
194 | @torch.no_grad()
195 | def estimate_loss():
196 |     out = {}
197 |     model.eval() # Ensure model is in eval mode
198 |     for split in ['train', 'val']:
199 |         losses = torch.zeros(eval_iters)
200 |         for k in range(eval_iters):
201 |             X, Y = get_batch(split)
202 |             # Use the appropriate context manager for evaluation
203 |             with ctx:
204 |                 logits, loss, _ = model(X, Y) # Ignore router weights in eval
205 |             losses[k] = loss.item()
206 |         out[split] = losses.mean()
207 |     model.train() # Set back to train mode
208 |     return out
209 | 
210 | # --- Model, Optimizer, Scheduler Init ---
211 | start_iter = 0
212 | scheduler = None # Initialize scheduler variable
213 | 
214 | if resume:
215 |     print(f"Resuming training from {resume_checkpoint}")
216 |     checkpoint = torch.load(resume_checkpoint, map_location=device)
217 |     # Load model config from checkpoint if available, otherwise use current args
218 |     if 'config' in checkpoint:
219 |         model.config.update(checkpoint['config'])
220 |         print("Loaded model config from checkpoint.")
221 |     else:
222 |         print("Warning: No config found in checkpoint, using current script args.")
223 | 
224 |     model = Transformer() # Initialize model with potentially updated config
225 |     state_dict = checkpoint['model']
226 |     # Handle potential issues with compiled models state_dict keys
227 |     unwrapped_state_dict = {}
228 |     for k, v in state_dict.items():
229 |         if k.startswith('_orig_mod.'):
230 |             unwrapped_state_dict[k[len('_orig_mod.'):]] = v
231 |         else:
232 |             unwrapped_state_dict[k] = v
233 |     model.load_state_dict(unwrapped_state_dict)
234 |     m = model.to(device)
235 |     print("Model loaded from checkpoint.")
236 | 
237 |     optimizers = model.configure_optimizers(weight_decay, lr, device) # Get the list
238 |     adamw_optimizer = optimizers[-1] # Assume AdamW is the last optimizer
239 | 
240 |     # Load optimizer states
241 |     if 'optimizer_states' in checkpoint and len(checkpoint['optimizer_states']) == len(optimizers):
242 |         for i, opt_state in enumerate(checkpoint['optimizer_states']):
243 |             try:
244 |                 optimizers[i].load_state_dict(opt_state)
245 |                 print(f"Loaded state for optimizer {i} ({type(optimizers[i]).__name__})")
246 |             except Exception as e:
247 |                 print(f"Warning: Could not load state for optimizer {i}: {e}")
248 |     else:
249 |         print("Warning: Optimizer states not found or mismatch in checkpoint. Initializing optimizers from scratch.")
250 | 
251 |     # Create and load scheduler (assuming it's for AdamW)
252 |     warmup_scheduler = LinearLR(adamw_optimizer, start_factor=1e-3, total_iters=warmup_iters)
253 |     cosine_scheduler = CosineAnnealingLR(adamw_optimizer, T_max=max_iters - warmup_iters, eta_min=min_lr)
254 |     scheduler = SequentialLR(adamw_optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[warmup_iters])
255 |     if 'scheduler' in checkpoint:
256 |         try:
257 |             scheduler.load_state_dict(checkpoint['scheduler'])
258 |             print("Scheduler state loaded.")
259 |         except Exception as e:
260 |             print(f"Warning: Could not load scheduler state: {e}. Initializing scheduler from scratch.")
261 |     else:
262 |         print("Warning: Scheduler state not found in checkpoint. Initializing scheduler from scratch.")
263 | 
264 | 
265 |     # Restore previous run state
266 |     start_iter = checkpoint.get('iter', 0) + 1 # Start from next iteration
267 |     run_name = checkpoint.get('run_name', run_name) # Keep original name if not in ckpt
268 | 
269 |     train_losses_history = checkpoint.get('train_losses_history', [])
270 |     val_losses_history = checkpoint.get('val_losses_history', [])
271 | 
272 |     print(f"Resuming run {run_name} from iteration {start_iter}")
273 | 
274 | else:
275 |     print(f"Starting run {run_name} from scratch")
276 |     model = Transformer()
277 |     m = model.to(device)
278 | 
279 |     optimizers = model.configure_optimizers(weight_decay, lr, device) # Get the list of optimizers
280 |     # Check if Muon was actually included (based on defensive check in configure_optimizers)
281 |     if len(optimizers) == 1:
282 |         print("WARNING: configure_optimizers returned only one optimizer. Assuming AdamW only.")
283 |         adamw_optimizer = optimizers[0]
284 |         muon_in_use = False
285 |     elif len(optimizers) == 2:
286 |         adamw_optimizer = optimizers[1] # Access AdamW optimizer (index 1)
287 |         muon_in_use = True
288 |         print("Using Muon (optimizer 0) and AdamW (optimizer 1)")
289 |     else:
290 |         print(f"ERROR: Unexpected number of optimizers ({len(optimizers)}) returned. Exiting.")
291 |         exit(1)
292 | 
293 |     # Create scheduler (for AdamW)
294 |     warmup_scheduler = LinearLR(adamw_optimizer, start_factor=1e-3, total_iters=warmup_iters)
295 |     cosine_scheduler = CosineAnnealingLR(adamw_optimizer, T_max=max_iters - warmup_iters, eta_min=min_lr)
296 |     scheduler = SequentialLR(adamw_optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[warmup_iters])
297 | 
298 |     start_iter = 0
299 | 
300 | # Compile the model AFTER potentially loading state dict and moving to device
301 | # And AFTER distributed init is done
302 | if "cuda" in device:
303 |     print("compiling the model...")
304 |     try:
305 |         # Disable dynamic shapes if causing issues with compile/distributed
306 |         model = torch.compile(model, fullgraph=False, dynamic=False)
307 |         print("compiled")
308 |     except Exception as e:
309 |         print(f"Torch compile failed: {e}. Running without compile.")
310 | 
311 | 
312 | p = sum(p.numel() for p in m.parameters() if p.requires_grad) # Count trainable params
313 | print(f"{p/1e6:.6f} M trainable parameters")
314 | 
315 | # Create directories if they don't exist
316 | if not os.path.exists("checkpoints"):
317 |     os.makedirs("checkpoints")
318 | if not os.path.exists("plots"):
319 |     os.makedirs("plots")
320 | 
321 | # --- Training Loop ---
322 | print(f"\nStarting training loop from iteration {start_iter}...")
323 | time_s = time.time()
324 | prev_time = time_s
325 | 
326 | try: # Wrap training loop in try...finally for cleanup
327 |     for iter in range(start_iter, max_iters + 1):
328 | 
329 |         # Evaluate loss periodically
330 |         if (iter % eval_interval == 0 or (iter < 100 and iter % 10 == 0) or iter == max_iters) and iter > 0:
331 |             losses = estimate_loss()
332 |             val_losses_history.append(losses['val'])
333 | 
334 |             time_n = time.time()
335 |             elapsed = time_n - time_s
336 |             dt = time_n - prev_time if iter > start_iter else time_n - time_s # Handle first dt calculation
337 |             prev_time = time_n
338 | 
339 |             # MFU calculation might need adjustment for MoE/complex attn
340 |             # mfu = model.estimate_mfu(p, batch_size * grad_accum_steps, dt) if hasattr(model, 'estimate_mfu') else 0.0
341 |             # total_flops = 65e12 * elapsed * mfu # Placeholder
342 | 
343 |             print(f"step: {iter}, train loss: {losses['train']:.4f}, val loss: {losses['val']:.4f}, elapsed time: {elapsed/60:.2f} min, dt: {dt*1000:.2f} ms") # Removed MFU/flops for now
344 | 
345 |             # Save checkpoint
346 |             # Ensure state dicts are handled correctly for compiled/uncompiled models
347 |             model_state_to_save = model.state_dict()
348 |             if hasattr(model, '_orig_mod'):
349 |                 model_state_to_save = model._orig_mod.state_dict()
350 | 
351 |             # Save optimizer states as a list
352 |             optimizer_states_to_save = [opt.state_dict() for opt in optimizers]
353 | 
354 |             checkpoint = {
355 |                 'model': model_state_to_save,
356 |                 'optimizer_states': optimizer_states_to_save, # Save list of states
357 |                 'scheduler': scheduler.state_dict() if scheduler else None,
358 |                 'iter': iter,
359 |                 'run_name': run_name,
360 |                 'train_losses_history': train_losses_history,
361 |                 'val_losses_history': val_losses_history,
362 |                 'config': model.config, # Save model config
363 |                 'args': args # Save script arguments
364 |             }
365 | 
366 |             if losses['val'] < 3.28:
367 | 
368 |                 checkpoint_path = f'checkpoints/{run_name}_check_{iter}.pt'
369 |                 print(f"Saving checkpoint to {checkpoint_path}")
370 |                 torch.save(checkpoint, checkpoint_path)
371 | 
372 |             # Plot loss
373 |             #plot_loss(train_losses_history, val_losses_history, eval_interval, iter, run_name)
374 | 
375 |         # --- Training Step ---
376 |         if iter == max_iters: # Don't do a training step after the last eval
377 |             break
378 | 
379 |         loss_accum = 0.0
380 |         all_router_weights_accum = []
381 | 
382 |         # Gradient Accumulation Loop
383 |         for micro_step in range(grad_accum_steps):
384 |             xb, yb = get_batch('train')
385 |             # Determine context based on GradScaler enabled status
386 |             current_ctx = ctx if scaler.is_enabled() else nullcontext()
387 | 
388 |             with current_ctx:
389 |                 logits, loss, rw = model(xb, yb)
390 |                 loss = loss / grad_accum_steps # Scale loss for accumulation
391 | 
392 |             if rw: # Check if router weights were returned
393 |                 all_router_weights_accum.extend(rw)
394 | 
395 |             # Scaled backward pass
396 |             scaler.scale(loss).backward()
397 |             loss_accum += loss.item() * grad_accum_steps # Unscale loss item for logging
398 | 
399 |         train_losses_history.append(loss_accum / grad_accum_steps) # Log average loss for the step
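        # Editorial note on the accumulation arithmetic above (not in the original source):
        # each micro-batch loss is divided by grad_accum_steps before backward(), so the
        # summed gradients equal the gradient of the mean loss over the micro-batches;
        # loss.item() is multiplied back by grad_accum_steps, so loss_accum holds the sum
        # of unscaled micro-batch losses and dividing by grad_accum_steps when appending
        # to train_losses_history logs the plain mean loss for the step.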
400 | 
401 |         # Unscale gradients for ALL optimizers BEFORE clipping
402 |         for opt in optimizers:
403 |             scaler.unscale_(opt)
404 | 
405 |         # Clip gradients
406 |         torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
407 | 
408 |         # Step EACH optimizer
409 |         for opt in optimizers:
410 |             # Inside loop check for distributed init status (for debugging)
411 |             # if isinstance(opt, Muon):
412 |             #     print(f"Inside loop, before scaler.step(Muon): dist.is_initialized() = {dist.is_initialized()}")
413 |             #     if not dist.is_initialized():
414 |             #         print("ERROR: Default process group IS NOT INITIALIZED right before Muon step!")
415 |             scaler.step(opt)
416 | 
417 |         # Update GradScaler
418 |         scaler.update()
419 | 
420 |         # Zero gradients for ALL optimizers
421 |         for opt in optimizers:
422 |             opt.zero_grad(set_to_none=True)
423 | 
424 |         # Step the scheduler (only affects AdamW)
425 |         if scheduler:
426 |             scheduler.step()
427 | 
428 |         # Update expert biases (if using DS-MoE and weights were collected)
429 |         if all_router_weights_accum and hasattr(model, 'update_expert_biases'):
430 |             model.update_expert_biases(all_router_weights_accum, 1e-3)
431 | 
432 | 
433 | finally: # Ensure cleanup happens even if errors occur
434 |     # --- Distributed Cleanup ---
435 |     if distributed_initialized:
436 |         if dist.is_initialized(): # Check before destroying
437 |             dist.destroy_process_group()
438 |             print("Destroyed default process group.")
439 |         else:
440 |             print("INFO: Default process group was already destroyed or not initialized.")
441 | 
442 |     print('\nTraining finished or interrupted.')
443 | 
--------------------------------------------------------------------------------