├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── data ├── cached_fineweb100B.py ├── cached_fineweb10B.py ├── fineweb.py └── requirements.txt ├── img ├── algo_optimizer.png ├── dofa.jpg ├── fig_optimizer.png ├── fig_tuned_nanogpt.png ├── nanogpt_speedrun51.png ├── nanogpt_speedrun52.png ├── nanogpt_speedrun53.png └── nanogpt_speedrun54.png ├── records ├── 060624_AdamW │ ├── README.md │ └── f66d43d7-e449-4029-8adf-e8537bab49ea.log ├── 100924_SOAP │ ├── 5bdc3988-496c-4232-b4ef-53764cb81c92.txt │ ├── README.md │ └── train_gpt2.py ├── 101024_Muon │ ├── eb5659d0-fb6a-49e5-a311-f1f89412f726.txt │ └── train_gpt2.py ├── 101324_llmc │ ├── README.md │ └── main.log ├── 101424_ModernArch │ ├── dabaaddd-237c-4ec9-939d-6608a9ed5e27.txt │ └── train_gpt2.py ├── 101724_DistributedMuon │ └── 22d24867-eb5a-4fcc-ae2c-263d0277dfd1.txt ├── 101824_PyTorch25 │ └── d4bfb25f-688d-4da5-8743-33926fad4842.txt ├── 102024_ScaleUp1B │ ├── 87bd51fd-6203-4c88-b3aa-8a849a6a83ca.txt │ ├── ad8d7ae5-7b2d-4ee9-bc52-f912e9174d7a.txt │ └── c0078066-c8c9-49c8-868a-ff4d4f32e615.txt ├── 102924_Optimizers │ ├── 8bfe4e35-c3fc-4b70-a984-3be937b71ff3.txt │ ├── 8d6193f4-27fc-4e68-899f-af70019a4d54.txt │ ├── 95a9fd44-7c13-49c7-b324-3e7d9e23a499.txt │ ├── README.md │ ├── e21a2838-a0f2-46f2-a247-db0021165682.txt │ ├── nanogpt_speedrun81w.png │ └── nanogpt_speedrun82w.png ├── 110324_UntieEmbed │ ├── README.md │ └── d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt ├── 110424_50Bruns │ ├── 3d715d41-453a-40d6-9506-421ba69766b2.txt │ ├── 4fbe61ec-f79a-4c19-836d-46d599deecce.txt │ ├── 530f3ee1-8862-4d21-be2b-da10eb05e6a9.txt │ ├── 69c33fc9-eabb-4a38-aa08-6922914eb405.txt │ └── README.md ├── 110624_ShortcutsTweaks │ ├── 042f9e87-07e6-4504-bb04-4ec59a380211.txt │ ├── 05b29e54-0be0-4a0f-a1e2-7d5317daedd3.txt │ ├── 10119f53-7001-4248-bfd9-33d32427a912.txt │ ├── 43f60c4f-0448-4de7-83d9-643ca26f61e7.txt │ ├── 4a71cc92-0f43-4058-a033-23e85c1e98f1.txt │ ├── README.md │ ├── d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt │ ├── dd7304a6-cc43-4d5e-adb8-c070111464a1.txt │ ├── nanogpt_speedrun110.png │ └── nanogpt_speedrun111.png ├── 110824_CastBf16 │ └── a833bed8-2fa8-4cfe-af05-58c1cc48bc30.txt ├── 110924_Replicateleloykun │ ├── 1621af10-aa0c-42af-bf54-8a773c63a2af.txt │ └── README.md ├── 111024_ScaleShortcuts │ ├── 3e55eb2e-6261-466a-b1e9-2b31f56fb16a.txt │ ├── 4897c987-9d09-435c-a23f-20585912936a.txt │ ├── 70a0ada6-8dee-4fef-8980-135379479c21.txt │ ├── a833bed8-2fa8-4cfe-af05-58c1cc48bc30.txt │ └── d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt ├── 111024_UNetDoubleLr │ ├── README.md │ └── c87bb826-797b-4f37-98c7-d3a5dad2de74.txt └── 111424_QuantizedFP4 │ ├── 433c1732-0c3d-4099-a4a8-ec31eae49b16.txt │ ├── 70a0ada6-8dee-4fef-8980-135379479c21.txt │ ├── 932bbe0e-41c3-4a5b-94bd-4ea3350909bd.txt │ └── a833bed8-2fa8-4cfe-af05-58c1cc48bc30.txt ├── requirements.txt ├── run.sh ├── run_rwkv6.sh ├── run_rwkv7.sh ├── rwkv_cuda ├── wkv6_cuda.cu ├── wkv6_op.cpp ├── wkv7g_op.cpp └── wkv7g_v1.cu ├── rwkv_cuda_wind ├── backstepping_f32.cpp ├── backstepping_f32_1.cu ├── backstepping_f32_2.cu ├── tile.cuh ├── wind_rwkv7.cpp └── wind_rwkv7.cu ├── rwkv_records ├── RWKV-6-2024-10-20-07-02-38.txt ├── RWKV-7-2024-10-21-14-51-20.txt ├── RWKV-7-fast-2024-10-29-07-48-34.txt └── RWKV-7-fast-2024-11-09-19-49-33.txt ├── train_gpt2.py ├── train_rwkv6.py └── train_rwkv7.py /.gitignore: -------------------------------------------------------------------------------- 1 | fineweb10B/ 2 | pylog124M/ 3 | __pycache__/ 4 | logs/ 5 | 
-------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.6.2-cudnn-devel-ubuntu24.04 2 | 3 | ENV DEBIAN_FRONTEND=noninteractive 4 | ENV PYTHON_VERSION=3.12.7 5 | ENV PATH=/usr/local/bin:$PATH 6 | 7 | RUN apt update && apt install -y --no-install-recommends build-essential libssl-dev zlib1g-dev \ 8 | libbz2-dev libreadline-dev libsqlite3-dev curl git libncursesw5-dev xz-utils tk-dev libxml2-dev \ 9 | libxmlsec1-dev libffi-dev liblzma-dev \ 10 | && apt clean && rm -rf /var/lib/apt/lists/* 11 | 12 | RUN curl -O https://www.python.org/ftp/python/${PYTHON_VERSION}/Python-${PYTHON_VERSION}.tgz && \ 13 | tar -xzf Python-${PYTHON_VERSION}.tgz && \ 14 | cd Python-${PYTHON_VERSION} && \ 15 | ./configure --enable-optimizations && \ 16 | make -j$(nproc) && \ 17 | make altinstall && \ 18 | cd .. && \ 19 | rm -rf Python-${PYTHON_VERSION} Python-${PYTHON_VERSION}.tgz 20 | 21 | RUN ln -s /usr/local/bin/python3.12 /usr/local/bin/python && \ 22 | ln -s /usr/local/bin/pip3.12 /usr/local/bin/pip 23 | 24 | COPY requirements.txt /modded-nanogpt/requirements.txt 25 | WORKDIR /modded-nanogpt 26 | 27 | RUN python -m pip install --upgrade pip && \ 28 | pip install -r requirements.txt 29 | 30 | CMD ["bash"] 31 | ENTRYPOINT [] 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Keller Jordan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Modded-NanoGPT-RWKV 2 | 3 | RWKV Discord: https://discord.gg/bDSBUMeFpc 4 | 5 | RWKV Twitter: https://twitter.com/BlinkDL_AI 6 | 7 | ## RWKV-6 and RWKV-7 8 | 9 | ### Latest run: 3200 steps to reach 3.27xx loss 10 | 11 | This is using latest (current) train_rwkv7.py 12 | ``` 13 | ./run_rwkv7.sh --adam_lr 0.0026 --muon_lr 0.02 --ln_lr 0.0090 --headsz 64 --bsz 512 --device_bsz 32 --fast_cuda 14 | ``` 15 | 16 | ### Old run: 5100 steps to reach 3.27xx loss 17 | 18 | This is using old train_rwkv7.py 19 | 20 | Please read https://x.com/BlinkDL_AI/status/1848343821467390156 first. 
21 | 22 | Modded-GPT 123.6M headsize 128 => val_loss 3.27xx 23 | 24 | RWKV-7 123.7M headsize 64 => val_loss 3.2715 (increase headsize to reach 3.26xx) 25 | 26 | RWKV-6 123.7M headsize 64 => val_loss 3.2914 27 | 28 | RWKV-6 123.7M headsize 192 => val_loss 3.28xx 29 | 30 | Check https://github.com/BlinkDL/modded-nanogpt-rwkv/tree/master/rwkv_records for training log. 31 | 32 | Try 0.0020/0.0022/0.0024 for adam_lr. Try 1.5/2/2.5 for emb_scale. Reduce device_bsz if OOM (will gradient accumulate). 33 | ``` 34 | Note: Currently inefficient implementation. Please help if you are a Pytorch / CUDA / triton master :) 35 | 36 | ./run_rwkv7.sh --adam_lr 0.0022 --emb_scale 2 --muon_lr 0.00036 --headsz 64 --bsz 512 --device_bsz 64 --fast_cuda (much faster cuda) 37 | 38 | ./run_rwkv7.sh --adam_lr 0.0022 --emb_scale 2 --muon_lr 0.00036 --headsz 64 --bsz 512 --device_bsz 32 (reference, takes more VRAM, have to reduce device_bsz) 39 | 40 | ./run_rwkv7.sh --adam_lr 0.0022 --emb_scale 2 --muon_lr 0.00036 --headsz 64 --bsz 512 --device_bsz 64 --wind_cuda (even faster cuda, likely worse loss) 41 | 42 | ./run_rwkv6.sh --adam_lr 0.0020 --emb_scale 1.5 --muon_lr 0.00036 --headsz 64 --bsz 512 --device_bsz 64 43 | ``` 44 | 45 | ## Original Readme 46 | 47 | This is a modified variant of the [PyTorch GPT-2 trainer](https://github.com/karpathy/llm.c/blob/7b929300217ff1a974b63791a228928b39b26409/train_gpt2.py) from 48 | Andrej Karpathy's [llm.c](https://github.com/karpathy/llm.c) repo, which attains the same final validation loss in: 49 | * 1.7B tokens instead of 10B 50 | * 7.8 minutes on 8xH100 instead of 45 51 | 52 | It uses the following techniques: 53 | * Modernized architecture: Rotary embeddings, QK-Norm, and ReLU^2. 54 | * New optimizer: Muon - Momentum Orthogonalized by Newton-schulz. 55 | * Untied head from embedding. 56 | * Projection and classification layers initialized to zero (muP-like). 57 | * Architectural shortcuts: value residual and embedding shortcut (partially following https://arxiv.org/abs/2410.17897). 58 | * Momentum warmup. 59 | * Tanh soft logit capping (following Gemma 2). 60 | 61 | --- 62 | 63 | ## Running the training 64 | 65 | To execute the training, run the following three commands. 66 | They should all complete within <20min on an 8xH100 with decent internet connection. 67 | ```bash 68 | pip install -r requirements.txt 69 | python data/cached_fineweb10B.py 18 # downloads only the first 1.8B training tokens to save time 70 | ./run.sh 71 | ``` 72 | 73 | The result will be a transformer with 124M active parameters trained for 3242 steps on 1.7B tokens of Fineweb [1], achieving ~3.278 validation loss. 74 | For comparison, the default llm.c PyTorch trainer yields [>3.28 validation loss after training for 19560 steps on 10B tokens](https://github.com/karpathy/llm.c/discussions/481#:~:text=By%20the%20end%20of%20the%20optimization%20we%27ll%20get%20to%20about%203.29). 75 | 76 | ## Running it on fewer GPUs or with less memory 77 | 78 | * To run on fewer GPUs, just modify `run.sh` to have a different `--nproc_per_node`. 79 | * If you're running out of memory, then go into `train_gpt2.py` and scale down the `device_batch_size` to either 16 or 32. 80 | 81 | Both of these changes will have no effect on the training - you should get the exact same loss curve as the most recent record, because the training code 82 | will automatically adjust the gradient accumulation in order to have the same total batch size. 
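For intuition, here is how that adjustment falls out of the batch-size bookkeeping (a minimal sketch with illustrative names, not the exact code in `train_gpt2.py`; the 512-sequence global batch and 64-sequence device batch are just the defaults used in the records here): the global batch size in sequences is held fixed, so fewer GPUs or a smaller per-device batch simply mean more gradient-accumulation micro-steps per optimizer step.

```python
# Minimal sketch of the batch-size bookkeeping; names and numbers are illustrative.
def accumulation_steps(batch_size: int, device_batch_size: int, world_size: int) -> int:
    # the global batch must be divisible by (per-device batch * number of GPUs)
    assert batch_size % (device_batch_size * world_size) == 0
    return batch_size // (device_batch_size * world_size)

print(accumulation_steps(512, 64, 8))  # 1 -> 8 GPUs with device batch 64 need no accumulation
print(accumulation_steps(512, 32, 8))  # 2 -> halving the device batch doubles the accumulation
print(accumulation_steps(512, 64, 4))  # 2 -> halving the GPU count does the same
```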
83 | 84 | ## Running with Docker 85 | 86 | For cases where CUDA or NCCL versions aren't compatible with your current system setup, Docker can be a helpful alternative. 87 | This approach standardizes versions for CUDA, NCCL, CUDNN, and Python, reducing dependency issues and simplifying setup. 88 | Note: an NVIDIA driver must already be installed on the system (useful if only the NVIDIA driver and Docker are available). 89 | 90 | ```bash 91 | sudo docker build -t modded-nanogpt . 92 | sudo docker run -it --rm --gpus all -v $(pwd):/modded-nanogpt modded-nanogpt python data/cached_fineweb10B.py 18 93 | sudo docker run -it --rm --gpus all -v $(pwd):/modded-nanogpt modded-nanogpt sh run.sh 94 | ``` 95 | --- 96 | 97 | ## World record history 98 | 99 | The following is the progression of world records for the task of *training a model with 124M active parameters to 3.28 validation loss on FineWeb in the minimal amount of time on an 8xH100 machine.* 100 | 101 | 1. [45 minutes: llm.c baseline](https://github.com/karpathy/llm.c/discussions/481) (05/28/24) [[training log](./records/101324_llmc/main.log)] (note: the 90 minute time is on 8xA100; it's 45 minutes on 8xH100. This run is essentially a hardware-optimized GPT-2 (small) replication using better training data.) 102 | 2. [31.4 minutes: Architectural modernizations and learning rate tuning](https://x.com/kellerjordan0/status/1798863559243513937) (06/06/24) [[training log](./records/060624_AdamW/f66d43d7-e449-4029-8adf-e8537bab49ea.log)] 103 | 3. [24.9 minutes: Introduced the Muon optimizer](https://x.com/kellerjordan0/status/1842300916864844014) (10/04/24) 104 | 4. [22.3 minutes: Muon improvements](https://x.com/kellerjordan0/status/1844820919061287009) (10/11/24) [[reproducible log](./records/101024_Muon/eb5659d0-fb6a-49e5-a311-f1f89412f726.txt)] 105 | 5. [15.2 minutes: Pad embeddings & architectural modernizations](https://x.com/kellerjordan0/status/1845865698532450646) (10/14/24) [[reproducible log](./records/101424_ModernArch/dabaaddd-237c-4ec9-939d-6608a9ed5e27.txt)] 106 | 6. [13.1 minutes: Distributed the overhead of Muon](https://x.com/kellerjordan0/status/1847291684016783746) (10/18/24) [[reproducible log](./records/101724_DistributedMuon/22d24867-eb5a-4fcc-ae2c-263d0277dfd1.txt)] 107 | 7. [12.0 minutes: Upgraded PyTorch from 2.4.1 to 2.5.0](https://x.com/kellerjordan0/status/1847358578686152764) (10/18/24) [[reproducible log](./records/101824_PyTorch25/d4bfb25f-688d-4da5-8743-33926fad4842.txt)] 108 | 8. [10.8 minutes: Untied embed and lm_head](https://x.com/kellerjordan0/status/1853188916704387239) (11/03/24) [[reproducible log](./records/110324_UntieEmbed/d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt)] 109 | 9. [8.2 minutes: Shortcuts & tweaks](https://x.com/kellerjordan0/status/1854296101303800108) (11/06/24) [[reproducible log](./records/110624_ShortcutsTweaks/dd7304a6-cc43-4d5e-adb8-c070111464a1.txt)] 110 | 11. [7.8 minutes: Bfloat16 activations](https://x.com/kellerjordan0/status/1855267054774865980) (11/08/24) [[reproducible log](./records/110824_CastBf16/a833bed8-2fa8-4cfe-af05-58c1cc48bc30.txt)] 111 | 12. [7.23 minutes: U-net & 2x lr](https://x.com/kellerjordan0/status/1856053121103093922) (11/10/24) [[reproducible log](./records/111024_UNetDoubleLr/c87bb826-797b-4f37-98c7-d3a5dad2de74.txt)] 112 | 113 | Please see the X threads for the contributors to each record. 114 | 115 | The `train_gpt2.py` in this repo is the 11/08/24 record. To run the latest 11/10/24 record, use the code in its reproducible log. 
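Each reproducible log embeds the full training script: the logging code in `train_gpt2.py` writes the source between two separator lines of 100 `=` characters. Assuming the 11/10/24 log follows the same layout, a small sketch for pulling the script back out (the output filename is just an example) is:

```python
# Sketch: extract the embedded training script from a reproducible log.
# Assumes the log layout produced by train_gpt2.py, where the source sits
# between the first two separator lines of 100 '=' characters.
log_path = "records/111024_UNetDoubleLr/c87bb826-797b-4f37-98c7-d3a5dad2de74.txt"
with open(log_path) as f:
    text = f.read()
script = text.split("=" * 100)[1].strip("\n")  # first delimited block is the Python source
with open("train_gpt2_111024.py", "w") as f:   # illustrative output name
    f.write(script)
```

The extracted file can then be launched with `torchrun` the same way `run.sh` launches `train_gpt2.py`.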
116 | 117 | 121 | 124 | 125 | 131 | 132 | ### Notable attempts 133 | 134 | 1. [An 11/07/24 attempt, which I attempted to certify on 11/09/24](./records/110924_Replicateleloykun) 135 | 136 | ### Notable forks 137 | 138 | * [https://github.com/BlinkDL/modded-nanogpt-rwkv](https://github.com/BlinkDL/modded-nanogpt-rwkv) 139 | * [https://github.com/nikhilvyas/modded-nanogpt-SOAP](https://github.com/nikhilvyas/modded-nanogpt-SOAP) 140 | 141 | ### Speedrun rules 142 | 143 | 1. Must not modify the train or validation data pipelines (except to change batch size if you want). 144 | 2. Must use ≤ 124M active parameters per token. 145 | 3. Must attain ≤ 3.28 val loss. A tasteful number would be 3.278 so that [this doesn't happen](./records/110924_Replicateleloykun/1621af10-aa0c-42af-bf54-8a773c63a2af.txt#L3780). 146 | 147 | Other than that, go crazy! Anything is fair game. 148 | 149 | --- 150 | 151 | ### Q: What is the point of NanoGPT speedrunning? 152 | 153 | A: The officially stated goal of NanoGPT speedrunning is as follows: `gotta go fast`. But for something a little more verbose involving an argument for good benchmarking, here's some kind of manifesto, adorned with a blessing from the master. [https://x.com/karpathy/status/1846790537262571739](https://x.com/karpathy/status/1846790537262571739) 154 | 155 | ### Q: What makes "NanoGPT speedrunning" not just another idiosyncratic benchmark? 156 | 157 | A: Because it is a *competitive* benchmark. In particular, if you attain a new speed record (using whatever method you want), there is an open invitation for you 158 | to post that record (on arXiv or X) and thereby vacuum up all the clout for yourself. I will even help you do it by reposting you as much as I can. 159 | 160 | 167 | 168 | ["Artificial intelligence advances by inventing games and gloating to goad others to play" - Professor Ben Recht](https://www.argmin.net/p/too-much-information) 169 | 170 | ### Q: NanoGPT speedrunning is cool and all, but meh it probably won't scale and is just overfitting to val loss 171 | 172 | A: This is hard to refute, since "at scale" is an infinite category (what if the methods stop working only for >100T models?), making it impossible to fully prove. 173 | Also, I would agree that some of the methods used in the speedrun are unlikely to scale. 174 | But if the reader cares about 1.5B models, they might be convinced by this result: 175 | 176 | *Straightforwardly scaling up the speedrun (10/18/24 version) to 1.5B parameters yields a model with GPT-2 (1.5B)-level HellaSwag performance 2.5x more cheaply than [@karpathy's baseline](https://github.com/karpathy/llm.c/discussions/677) ($233 instead of $576):* 177 | 178 | ![](img/nanogpt_speedrun51.png) 179 | [[reproducible log](https://github.com/KellerJordan/modded-nanogpt/blob/master/records/102024_ScaleUp1B/ad8d7ae5-7b2d-4ee9-bc52-f912e9174d7a.txt)] 180 | ![](img/nanogpt_speedrun52.png) 181 | 182 | --- 183 | 184 | ## [Muon optimizer](https://github.com/KellerJordan/Muon) 185 | 186 | Muon is defined as follows: 187 | 188 | ![](img/algo_optimizer.png) 189 | 190 | Where NewtonSchulz5 is the following Newton-Schulz iteration [2, 3], which approximately replaces `G` with `U @ V.T` where `U, S, V = G.svd()`.
191 | ```python 192 | @torch.compile 193 | def zeroth_power_via_newtonschulz5(G, steps=5, eps=1e-7): 194 | assert len(G.shape) == 2 195 | a, b, c = (3.4445, -4.7750, 2.0315) 196 | X = G.bfloat16() / (G.norm() + eps) 197 | if G.size(0) > G.size(1): 198 | X = X.T 199 | for _ in range(steps): 200 | A = X @ X.T 201 | B = b * A + c * A @ A 202 | X = a * X + B @ X 203 | if G.size(0) > G.size(1): 204 | X = X.T 205 | return X.to(G.dtype) 206 | ``` 207 | 208 | For this training scenario, Muon has the following favorable properties: 209 | * Lower memory usage than Adam 210 | * ~1.5x better sample-efficiency 211 | * <2% wallclock overhead 212 | 213 | 214 | ### Provenance 215 | 216 | Many of the choices made to generate this optimizer were obtained experimentally by our pursuit of [CIFAR-10 speedrunning](https://github.com/KellerJordan/cifar10-airbench). 217 | In particular, we experimentally obtained the following practices: 218 | * Using Nesterov momentum inside the update, with orthogonalization applied after momentum. 219 | * Using a specifically quintic Newton-Schulz iteration as the method of orthogonalization. 220 | * Using non-convergent coefficients for the quintic polynomial in order to maximize slope at zero, and thereby minimize the number of necessary Newton-Schulz iterations. 221 | It turns out that the variance doesn't actually matter that much, so we end up with a quintic that (rapidly) converges to the range 0.68, 1.13 upon repeated application, rather than to 1. 222 | * Running the Newton-Schulz iteration in bfloat16 (whereas Shampoo implementations often depend on inverse-pth-roots run in fp32 or fp64). 223 | 224 | Our use of a Newton-Schulz iteration for orthogonalization traces to [Bernstein & Newhouse (2024)](https://arxiv.org/abs/2409.20325), 225 | who suggested it as a way to compute Shampoo [5, 6] preconditioners, and theoretically explored Shampoo without preconditioner accumulation. 226 | In particular, Jeremy Bernstein @jxbz sent us the draft, which caused us to experiment with various Newton-Schulz iterations as the 227 | orthogonalization method for this optimizer. 228 | If we had used SVD instead of a Newton-Schulz iteration, this optimizer would have been too slow to be useful. 229 | Bernstein & Newhouse also pointed out that Shampoo without preconditioner accumulation is equivalent to steepest descent in the spectral norm, 230 | and therefore Shampoo can be thought of as a way to smooth out spectral steepest descent. 231 | The proposed optimizer can be thought of as a second way of smoothing spectral steepest descent, with a different set of memory and runtime tradeoffs 232 | compared to Shampoo. 233 | 234 | --- 235 | 236 | ## Startup script 237 | 238 | Here's a good startup script for a fresh 8xH100 instance. 239 | 240 | ``` 241 | sudo apt-get update 242 | sudo apt-get install vim tmux python3-pip python-is-python3 -y 243 | git clone https://github.com/KellerJordan/modded-nanogpt.git 244 | cd modded-nanogpt 245 | tmux 246 | 247 | pip install numpy==1.23.5 huggingface-hub tqdm 248 | pip install --upgrade torch & 249 | python data/cached_fineweb10B.py 18 250 | ``` 251 | 252 | --- 253 | 254 | ## References 255 | 256 | 1. [Penedo, Guilherme, et al. "The fineweb datasets: Decanting the web for the finest text data at scale." arXiv preprint arXiv:2406.17557 (2024).](https://arxiv.org/abs/2406.17557) 257 | 2. Nicholas J. Higham. Functions of Matrices. Society for Industrial and Applied Mathematics, 2008. Equation 5.22. 258 | 3. Günther Schulz. 
Iterative Berechnung der reziproken Matrix. Z. Angew. Math. Mech., 13:57–59, 1933. 259 | 4. [Jeremy Bernstein and Laker Newhouse. "Old Optimizer, New Norm: An Anthology." arxiv preprint arXiv:2409.20325 (2024).](https://arxiv.org/abs/2409.20325) 260 | 5. [Vineet Gupta, Tomer Koren, and Yoram Singer. "Shampoo: Preconditioned stochastic tensor optimization." International Conference on Machine Learning. PMLR, 2018.](https://arxiv.org/abs/1802.09568) 261 | 6. [Anil, Rohan, et al. "Scalable second order optimization for deep learning." arXiv preprint arXiv:2002.09018 (2020).](https://arxiv.org/abs/2002.09018) 262 | 7. [Hägele, Alexander, et al. "Scaling Laws and Compute-Optimal Training Beyond Fixed Training Durations." arXiv preprint arXiv:2405.18392 (2024).](https://arxiv.org/abs/2405.18392) 263 | 264 | [![video](https://img.youtube.com/vi/dv13gl0a-FA/0.jpg)](https://www.youtube.com/watch?v=dv13gl0a-FA) 265 | 266 | itsover_wereback 267 | 268 | -------------------------------------------------------------------------------- /data/cached_fineweb100B.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from huggingface_hub import hf_hub_download 4 | # Download the GPT-2 tokens of Fineweb100B from huggingface. This 5 | # saves about an hour of startup time compared to regenerating them. 6 | def get(fname): 7 | local_dir = os.path.join(os.path.dirname(__file__), 'fineweb100B') 8 | if not os.path.exists(os.path.join(local_dir, fname)): 9 | hf_hub_download(repo_id="kjj0/fineweb100B-gpt2", filename=fname, 10 | repo_type="dataset", local_dir=local_dir) 11 | get("fineweb_val_%06d.bin" % 0) 12 | num_chunks = 1030 # full fineweb100B. Each chunk is 100M tokens 13 | if len(sys.argv) >= 2: # we can pass an argument to download less 14 | num_chunks = int(sys.argv[1]) 15 | for i in range(1, num_chunks+1): 16 | get("fineweb_train_%06d.bin" % i) 17 | -------------------------------------------------------------------------------- /data/cached_fineweb10B.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from huggingface_hub import hf_hub_download 4 | # Download the GPT-2 tokens of Fineweb10B from huggingface. This 5 | # saves about an hour of startup time compared to regenerating them. 6 | def get(fname): 7 | local_dir = os.path.join(os.path.dirname(__file__), 'fineweb10B') 8 | if not os.path.exists(os.path.join(local_dir, fname)): 9 | hf_hub_download(repo_id="kjj0/fineweb10B-gpt2", filename=fname, 10 | repo_type="dataset", local_dir=local_dir) 11 | get("fineweb_val_%06d.bin" % 0) 12 | num_chunks = 103 # full fineweb10B. 
Each chunk is 100M tokens 13 | if len(sys.argv) >= 2: # we can pass an argument to download less 14 | num_chunks = int(sys.argv[1]) 15 | for i in range(1, num_chunks+1): 16 | get("fineweb_train_%06d.bin" % i) 17 | -------------------------------------------------------------------------------- /data/fineweb.py: -------------------------------------------------------------------------------- 1 | """ 2 | FineWeb dataset (for srs pretraining) 3 | https://huggingface.co/datasets/HuggingFaceFW/fineweb 4 | 5 | example doc to highlight the structure of the dataset: 6 | { 7 | "text": "Posted by mattsmith on 20th April 2012\nStraight from...", 8 | "id": "", 9 | "dump": "CC-MAIN-2013-20", 10 | "url": "http://nleastchatter.com/philliesphandom/tag/freddy-galvis/", 11 | "date": "2013-05-18T07:24:47Z", 12 | "file_path": "s3://commoncrawl/long.../path.../file.gz", 13 | "language": "en", 14 | "language_score": 0.9185474514961243, 15 | "token_count": 594 16 | } 17 | """ 18 | import os 19 | import argparse 20 | import multiprocessing as mp 21 | import numpy as np 22 | import tiktoken 23 | # from huggingface_hub import snapshot_download 24 | from datasets import load_dataset 25 | from tqdm import tqdm 26 | import argparse 27 | import numpy as np 28 | def write_datafile(filename, toks): 29 | """ 30 | Saves token data as a .bin file, for reading in C. 31 | - First comes a header with 256 int32s 32 | - The tokens follow, each as a uint16 33 | """ 34 | assert len(toks) < 2**31, "token count too large" # ~2.1B tokens 35 | # construct the header 36 | header = np.zeros(256, dtype=np.int32) 37 | header[0] = 20240520 # magic 38 | header[1] = 1 # version 39 | header[2] = len(toks) # number of tokens after the 256*4 bytes of header (each 2 bytes as uint16) 40 | # construct the tokens numpy array, if not already 41 | if not isinstance(toks, np.ndarray) or not toks.dtype == np.uint16: 42 | # validate that no token exceeds a uint16 43 | maxtok = 2**16 44 | assert all(0 <= t < maxtok for t in toks), "token dictionary too large for uint16" 45 | toks_np = np.array(toks, dtype=np.uint16) 46 | else: 47 | toks_np = toks 48 | # write to file 49 | print(f"writing {len(toks):,} tokens to {filename}") 50 | with open(filename, "wb") as f: 51 | f.write(header.tobytes()) 52 | f.write(toks_np.tobytes()) 53 | # ------------------------------------------ 54 | 55 | parser = argparse.ArgumentParser(description="FineWeb dataset preprocessing") 56 | parser.add_argument("-v", "--version", type=str, default="10B", help="Which version of fineweb to use 10B|100B") 57 | parser.add_argument("-s", "--shard_size", type=int, default=10**8, help="Size of each shard in tokens") 58 | args = parser.parse_args() 59 | 60 | # FineWeb has a few possible subsamples available 61 | assert args.version in ["10B", "100B"], "version must be one of 10B, 100B" 62 | if args.version == "10B": 63 | local_dir = "fineweb10B" 64 | remote_name = "sample-10BT" 65 | elif args.version == "100B": 66 | local_dir = "fineweb100B" 67 | remote_name = "sample-100BT" 68 | 69 | # create the cache the local directory if it doesn't exist yet 70 | DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), local_dir) 71 | os.makedirs(DATA_CACHE_DIR, exist_ok=True) 72 | 73 | # download the dataset 74 | fw = load_dataset("HuggingFaceFW/fineweb", name=remote_name, split="train") 75 | 76 | # init the tokenizer 77 | enc = tiktoken.get_encoding("gpt2") 78 | eot = enc._special_tokens['<|endoftext|>'] # end of text token 79 | def tokenize(doc): 80 | # tokenizes a single document and returns a numpy 
array of uint16 tokens 81 | tokens = [eot] # the special <|endoftext|> token delimits all documents 82 | tokens.extend(enc.encode_ordinary(doc["text"])) 83 | tokens_np = np.array(tokens) 84 | assert (0 <= tokens_np).all() and (tokens_np < 2**16).all(), "token dictionary too large for uint16" 85 | tokens_np_uint16 = tokens_np.astype(np.uint16) 86 | return tokens_np_uint16 87 | 88 | # tokenize all documents and write output shards, each of shard_size tokens (last shard has remainder) 89 | nprocs = max(1, os.cpu_count() - 2) # don't hog the entire system 90 | with mp.Pool(nprocs) as pool: 91 | shard_index = 0 92 | # preallocate buffer to hold current shard 93 | all_tokens_np = np.empty((args.shard_size,), dtype=np.uint16) 94 | token_count = 0 95 | progress_bar = None 96 | for tokens in pool.imap(tokenize, fw, chunksize=16): 97 | 98 | # is there enough space in the current shard for the new tokens? 99 | if token_count + len(tokens) < args.shard_size: 100 | # simply append tokens to current shard 101 | all_tokens_np[token_count:token_count+len(tokens)] = tokens 102 | token_count += len(tokens) 103 | # update progress bar 104 | if progress_bar is None: 105 | progress_bar = tqdm(total=args.shard_size, unit="tokens", desc=f"Shard {shard_index}") 106 | progress_bar.update(len(tokens)) 107 | else: 108 | # write the current shard and start a new one 109 | split = "val" if shard_index == 0 else "train" 110 | filename = os.path.join(DATA_CACHE_DIR, f"fineweb_{split}_{shard_index:06d}.bin") 111 | # split the document into whatever fits in this shard; the remainder goes to next one 112 | remainder = args.shard_size - token_count 113 | progress_bar.update(remainder) 114 | all_tokens_np[token_count:token_count+remainder] = tokens[:remainder] 115 | write_datafile(filename, all_tokens_np) 116 | shard_index += 1 117 | progress_bar = None 118 | # populate the next shard with the leftovers of the current doc 119 | all_tokens_np[0:len(tokens)-remainder] = tokens[remainder:] 120 | token_count = len(tokens)-remainder 121 | 122 | # write any remaining tokens as the last shard 123 | if token_count != 0: 124 | split = "val" if shard_index == 0 else "train" 125 | filename = os.path.join(DATA_CACHE_DIR, f"fineweb_{split}_{shard_index:06d}.bin") 126 | write_datafile(filename, all_tokens_np[:token_count]) 127 | -------------------------------------------------------------------------------- /data/requirements.txt: -------------------------------------------------------------------------------- 1 | datasets 2 | tiktoken 3 | -------------------------------------------------------------------------------- /img/algo_optimizer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/img/algo_optimizer.png -------------------------------------------------------------------------------- /img/dofa.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/img/dofa.jpg -------------------------------------------------------------------------------- /img/fig_optimizer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/img/fig_optimizer.png -------------------------------------------------------------------------------- 
/img/fig_tuned_nanogpt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/img/fig_tuned_nanogpt.png -------------------------------------------------------------------------------- /img/nanogpt_speedrun51.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/img/nanogpt_speedrun51.png -------------------------------------------------------------------------------- /img/nanogpt_speedrun52.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/img/nanogpt_speedrun52.png -------------------------------------------------------------------------------- /img/nanogpt_speedrun53.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/img/nanogpt_speedrun53.png -------------------------------------------------------------------------------- /img/nanogpt_speedrun54.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/img/nanogpt_speedrun54.png -------------------------------------------------------------------------------- /records/060624_AdamW/README.md: -------------------------------------------------------------------------------- 1 | This is the log for my baseline AdamW training to which I compared the new Muon and SOAP optimizers. 2 | 3 | just the log, which is in the old llm.c format ("tel" lines are val loss) 4 | 5 | this was batch size 2^19, so ~5B tokens 6 | 7 | was learning rate 0.0018, warmup=250, warmdown=2000, betas=(0.9, 0.95) IIRC 8 | 9 | -------------------------------------------------------------------------------- /records/100924_SOAP/README.md: -------------------------------------------------------------------------------- 1 | # SOAP record October 9 2024 2 | 3 | * New sample efficiency record: <3.28 validation loss in 3.15B tokens 4 | * Uses SOAP optimizer ([Vyas et al. 2024](https://arxiv.org/abs/2409.11321)) 5 | * 363ms/step - not a new wallclock record (SOAP is in active development to reduce the wallclock overhead for distributed training, so this may change) 6 | * Set by Nikhil Vyas @vyasnikhil96. 
Hyperparameters also tuned slightly by me 7 | * [https://x.com/vyasnikhil96/status/1842656792217858063](https://x.com/vyasnikhil96/status/1842656792217858063) 8 | * [https://github.com/nikhilvyas/modded-nanogpt-SOAP/tree/master](https://github.com/nikhilvyas/modded-nanogpt-SOAP/tree/master) 9 | 10 | -------------------------------------------------------------------------------- /records/101024_Muon/train_gpt2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | with open(sys.argv[0]) as f: 4 | code = f.read() # read the code of this file ASAP, for logging 5 | import uuid 6 | import glob 7 | import time 8 | from dataclasses import dataclass 9 | 10 | import numpy as np 11 | import torch 12 | from torch import nn 13 | import torch.nn.functional as F 14 | import torch.distributed as dist 15 | import torch._inductor.config as config 16 | from torch.nn.parallel import DistributedDataParallel as DDP 17 | 18 | # ----------------------------------------------------------------------------- 19 | # Muon optimizer 20 | 21 | def zeropower_via_svd(G, steps=None): 22 | U, S, V = G.svd() 23 | return U @ V.T 24 | 25 | @torch.compile 26 | def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): 27 | """ 28 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 29 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 30 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 31 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 32 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T 33 | where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model 34 | performance at all relative to UV^T, where USV^T = G is the SVD. 35 | """ 36 | assert len(G.shape) == 2 37 | a, b, c = (3.4445, -4.7750, 2.0315) 38 | X = G.bfloat16() / (G.norm() + eps) # ensure top singular value <= 1 39 | if G.size(0) > G.size(1): 40 | X = X.T 41 | for _ in range(steps): 42 | A = X @ X.T 43 | B = A @ X 44 | X = a * X + b * B + c * A @ B 45 | if G.size(0) > G.size(1): 46 | X = X.T 47 | return X.to(G.dtype) 48 | 49 | zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5) 50 | 51 | class Muon(torch.optim.Optimizer): 52 | """ 53 | Muon: MomentUm Orthogonalized by Newton-schulz 54 | 55 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- 56 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal 57 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has 58 | the advantage that it can be stably run in bfloat16 on the GPU. 59 | 60 | Some warnings: 61 | - This optimizer assumes that all parameters passed in are 2D. 62 | - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D 63 | parameters; those should all be optimized by a standard method (e.g., AdamW). 64 | - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. 65 | - We believe it is unlikely to work well for training with small batch size. 66 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this. 67 | - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M). 
68 | 69 | Arguments: 70 | lr: The learning rate used by the internal SGD. 71 | momentum: The momentum used by the internal SGD. 72 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) 73 | backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5') 74 | backend_steps: The number of iteration steps to use in the backend, if it is iterative. 75 | """ 76 | def __init__(self, params, lr=3e-4, momentum=0.95, nesterov=True, backend='newtonschulz5', backend_steps=5): 77 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps) 78 | super().__init__(params, defaults) 79 | 80 | def step(self): 81 | for group in self.param_groups: 82 | lr = group['lr'] 83 | momentum = group['momentum'] 84 | zeropower_backend = zeropower_backends[group['backend']] 85 | for p in group['params']: 86 | g = p.grad 87 | if g is None: 88 | continue 89 | state = self.state[p] 90 | if 'momentum_buffer' not in state: 91 | state['momentum_buffer'] = torch.zeros_like(g) 92 | buf = state['momentum_buffer'] 93 | buf.mul_(momentum).add_(g) 94 | if group['nesterov']: 95 | g = g.add(buf, alpha=momentum) 96 | if g.size(0) == 3 * g.size(1): # split grouped QKV parameters 97 | g = torch.cat([zeropower_backend(g1, steps=group['backend_steps']) for g1 in g.split(g.size(1))]) 98 | scale = g.size(1)**0.5 99 | else: 100 | g = zeropower_backend(g, steps=group['backend_steps']) 101 | scale = max(g.size(0), g.size(1))**0.5 # scale to have update.square().mean() == 1 102 | p.data.add_(g, alpha=-lr * scale) 103 | 104 | # ----------------------------------------------------------------------------- 105 | # PyTorch nn.Module definitions for the GPT-2 model 106 | 107 | class Rotary(torch.nn.Module): 108 | 109 | def __init__(self, dim, base=10000): 110 | super().__init__() 111 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) 112 | self.register_buffer("inv_freq", inv_freq) 113 | self.seq_len_cached = None 114 | self.cos_cached = None 115 | self.sin_cached = None 116 | 117 | def forward(self, x): 118 | seq_len = x.shape[1] 119 | if seq_len != self.seq_len_cached: 120 | self.seq_len_cached = seq_len 121 | t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) 122 | freqs = torch.outer(t, self.inv_freq).to(x.device) 123 | self.cos_cached = freqs.cos() 124 | self.sin_cached = freqs.sin() 125 | return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :] 126 | 127 | def apply_rotary_emb(x, cos, sin): 128 | assert x.ndim == 4 # multihead attention 129 | d = x.shape[3]//2 130 | x1 = x[..., :d] 131 | x2 = x[..., d:] 132 | y1 = x1 * cos + x2 * sin 133 | y2 = x1 * (-sin) + x2 * cos 134 | return torch.cat([y1, y2], 3) 135 | 136 | def rmsnorm(x0, eps=1e-6): 137 | x = x0.float() 138 | x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) 139 | return x.type_as(x0) 140 | 141 | class CausalSelfAttention(nn.Module): 142 | 143 | def __init__(self, config): 144 | super().__init__() 145 | self.n_head = config.n_head 146 | self.n_embd = config.n_embd 147 | self.head_dim = self.n_embd // self.n_head 148 | assert self.n_embd % self.n_head == 0 149 | # key, query, value projections for all heads, but in a batch 150 | self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False) 151 | # output projection 152 | self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) 153 | self.rotary = Rotary(self.head_dim) 154 | 155 | def forward(self, x): 156 | B, T, C = x.size() # batch size, sequence length, 
embedding dimensionality (n_embd) 157 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 158 | qkv = self.c_attn(x) 159 | q, k, v = qkv.split(self.n_embd, dim=2) 160 | k = k.view(B, T, self.n_head, self.head_dim) 161 | q = q.view(B, T, self.n_head, self.head_dim) 162 | v = v.view(B, T, self.n_head, self.head_dim) 163 | cos, sin = self.rotary(q) 164 | q = apply_rotary_emb(q, cos, sin) 165 | k = apply_rotary_emb(k, cos, sin) 166 | y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True) 167 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 168 | # output projection 169 | y = self.c_proj(y) 170 | return y 171 | 172 | class MLP(nn.Module): 173 | 174 | def __init__(self, config): 175 | super().__init__() 176 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False) 177 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False) 178 | 179 | def forward(self, x): 180 | x = self.c_fc(x) 181 | x = F.gelu(x) 182 | x = self.c_proj(x) 183 | return x 184 | 185 | class Block(nn.Module): 186 | 187 | def __init__(self, config): 188 | super().__init__() 189 | self.attn = CausalSelfAttention(config) 190 | self.mlp = MLP(config) 191 | self.attn_scale = (1 / (2 * config.n_layer)**0.5) 192 | 193 | def forward(self, x): 194 | x = x + self.attn_scale * self.attn(rmsnorm(x)) 195 | x = x + self.mlp(rmsnorm(x)) 196 | return x 197 | 198 | # ----------------------------------------------------------------------------- 199 | # The main GPT-2 model 200 | 201 | @dataclass 202 | class GPTConfig: 203 | vocab_size : int = 50257 204 | n_layer : int = 12 205 | n_head : int = 12 206 | n_embd : int = 768 207 | 208 | class GPT(nn.Module): 209 | 210 | def __init__(self, config): 211 | super().__init__() 212 | self.config = config 213 | 214 | self.transformer = nn.ModuleDict(dict( 215 | wte = nn.Embedding(config.vocab_size, config.n_embd), 216 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 217 | )) 218 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 219 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying 220 | 221 | def forward(self, idx, targets=None, return_logits=True): 222 | b, t = idx.size() 223 | pos = torch.arange(0, t, dtype=torch.long, device=idx.device) # shape (t) 224 | 225 | # forward the GPT model itself 226 | x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 227 | 228 | for block in self.transformer.h: 229 | x = block(x) 230 | x = rmsnorm(x) 231 | 232 | if targets is not None: 233 | # if we are given some desired targets also calculate the loss 234 | logits = self.lm_head(x) 235 | logits = logits.float() # use tf32/fp32 for logits 236 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 237 | else: 238 | # inference-time mini-optimization: only forward the lm_head on the very last position 239 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim 240 | logits = logits.float() # use tf32/fp32 for logits 241 | loss = None 242 | 243 | # there are performance reasons why not returning logits is prudent, if not needed 244 | if not return_logits: 245 | logits = None 246 | 247 | return logits, loss 248 | 249 | # ----------------------------------------------------------------------------- 250 | # Our own simple Distributed Data Loader 251 | 252 | def 
_peek_data_shard(filename): 253 | # only reads the header, returns header data 254 | with open(filename, "rb") as f: 255 | # first read the header, which is 256 int32 integers (4 bytes each) 256 | header = np.frombuffer(f.read(256*4), dtype=np.int32) 257 | if header[0] != 20240520: 258 | print("ERROR: magic number mismatch in the data .bin file!") 259 | print("---> HINT: Are you passing in a correct file with --input_bin?") 260 | print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README") 261 | print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try") 262 | exit(1) 263 | assert header[1] == 1, "unsupported version" 264 | ntok = header[2] # number of tokens (claimed) 265 | return ntok # for now just return the number of tokens 266 | 267 | def _load_data_shard(filename): 268 | with open(filename, "rb") as f: 269 | # first read the header, which is 256 int32 integers (4 bytes each) 270 | header = np.frombuffer(f.read(256*4), dtype=np.int32) 271 | assert header[0] == 20240520, "magic number mismatch in the data .bin file" 272 | assert header[1] == 1, "unsupported version" 273 | ntok = header[2] # number of tokens (claimed) 274 | # the rest of it are tokens, stored as uint16 275 | tokens = np.frombuffer(f.read(), dtype=np.uint16) 276 | assert len(tokens) == ntok, "number of tokens read does not match header?" 277 | return tokens 278 | 279 | class DistributedDataLoader: 280 | def __init__(self, filename_pattern, B, T, process_rank, num_processes): 281 | self.process_rank = process_rank 282 | self.num_processes = num_processes 283 | self.B = B 284 | self.T = T 285 | 286 | # glob files that match the pattern 287 | self.files = sorted(glob.glob(filename_pattern)) 288 | assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}" 289 | 290 | # load and validate all data shards, count number of tokens in total 291 | ntok_total = 0 292 | for fname in self.files: 293 | shard_ntok = _peek_data_shard(fname) 294 | assert shard_ntok >= num_processes * B * T + 1 295 | ntok_total += int(shard_ntok) 296 | self.ntok_total = ntok_total 297 | 298 | # kick things off 299 | self.reset() 300 | 301 | def reset(self): 302 | self.current_shard = 0 303 | self.current_position = self.process_rank * self.B * self.T 304 | self.tokens = _load_data_shard(self.files[self.current_shard]) 305 | 306 | def advance(self): # advance to next data shard 307 | self.current_shard = (self.current_shard + 1) % len(self.files) 308 | self.current_position = self.process_rank * self.B * self.T 309 | self.tokens = _load_data_shard(self.files[self.current_shard]) 310 | 311 | def next_batch(self): 312 | B = self.B 313 | T = self.T 314 | buf = self.tokens[self.current_position : self.current_position+B*T+1] 315 | buf = torch.tensor(buf.astype(np.int32), dtype=torch.long) 316 | x = (buf[:-1]).view(B, T) # inputs 317 | y = (buf[1:]).view(B, T) # targets 318 | # advance current position and load next shard if necessary 319 | self.current_position += B * T * self.num_processes 320 | if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens): 321 | self.advance() 322 | return x.cuda(), y.cuda() 323 | 324 | # ----------------------------------------------------------------------------- 325 | # int main 326 | 327 | @dataclass 328 | class Hyperparameters: 329 | # data hyperparams 330 | input_bin : str = 'data/fineweb10B/fineweb_train_*.bin' # input .bin to train on 331 | input_val_bin : str = 'data/fineweb10B/fineweb_val_*.bin' 
# input .bin to eval validation loss on 332 | # optimization hyperparams 333 | batch_size : int = 8*64 # batch size, in sequences, across all devices 334 | device_batch_size : int = 64 # batch size, in sequences, per device 335 | sequence_length : int = 1024 # sequence length, in tokens 336 | num_iterations : int = 6200 # number of iterations to run 337 | learning_rate : float = 0.0036 338 | warmup_iters : int = 0 339 | warmdown_iters : int = 1800 # number of iterations of linear warmup/warmdown for triangular or trapezoidal schedule 340 | weight_decay : float = 0 341 | # evaluation and logging hyperparams 342 | val_loss_every : int = 125 # every how many steps to evaluate val loss? 0 for only at the end 343 | val_tokens : int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons 344 | save_every : int = 0 # every how many steps to save the checkpoint? 0 for only at the end 345 | args = Hyperparameters() 346 | 347 | # set up DDP (distributed data parallel). torchrun sets this env variable 348 | assert torch.cuda.is_available() 349 | dist.init_process_group(backend='nccl') 350 | ddp_rank = int(os.environ['RANK']) 351 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 352 | ddp_world_size = int(os.environ['WORLD_SIZE']) 353 | device = f'cuda:{ddp_local_rank}' 354 | torch.cuda.set_device(device) 355 | print(f"using device: {device}") 356 | master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc. 357 | 358 | # convenience variables 359 | B, T = args.device_batch_size, args.sequence_length 360 | # calculate the number of steps to take in the val loop. 361 | assert args.val_tokens % (B * T * ddp_world_size) == 0 362 | val_steps = args.val_tokens // (B * T * ddp_world_size) 363 | # calculate the steps of gradient accumulation required to attain the desired global batch size. 
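# e.g., with the defaults above (batch_size = 512 sequences, device_batch_size = 64) on 8 GPUs this is 512 // (64 * 8) = 1, i.e. no accumulation; fewer GPUs or a smaller device_batch_size raise it proportionally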
364 | assert args.batch_size % (B * ddp_world_size) == 0 365 | train_accumulation_steps = args.batch_size // (B * ddp_world_size) 366 | 367 | # load tokens 368 | train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size) 369 | val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size) 370 | if master_process: 371 | print(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files") 372 | print(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files") 373 | x, y = train_loader.next_batch() 374 | 375 | # init the model from scratch 376 | num_vocab = 50257 377 | model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=12, n_embd=768)) 378 | model = model.cuda() 379 | if hasattr(config, "coordinate_descent_tuning"): 380 | config.coordinate_descent_tuning = True # suggested by @Chillee 381 | model = torch.compile(model) 382 | # here we wrap model into DDP container 383 | model = DDP(model, device_ids=[ddp_local_rank]) 384 | raw_model = model.module # always contains the "raw" unwrapped model 385 | ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16) 386 | 387 | # init the optimizer(s) 388 | optimizer1 = torch.optim.AdamW(raw_model.lm_head.parameters(), lr=args.learning_rate, betas=(0.9, 0.95), 389 | weight_decay=args.weight_decay, fused=True) 390 | optimizer2 = Muon(raw_model.transformer.h.parameters(), lr=0.1*args.learning_rate, momentum=0.95) 391 | optimizers = [optimizer1, optimizer2] 392 | # learning rate decay scheduler (linear warmup and warmdown) 393 | def get_lr(it): 394 | assert it <= args.num_iterations 395 | # 1) linear warmup for warmup_iters steps 396 | if it < args.warmup_iters: 397 | return (it+1) / args.warmup_iters 398 | # 2) constant lr for a while 399 | elif it < args.num_iterations - args.warmdown_iters: 400 | return 1.0 401 | # 3) linear warmdown 402 | else: 403 | decay_ratio = (args.num_iterations - it) / args.warmdown_iters 404 | return decay_ratio 405 | schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers] 406 | 407 | # begin logging 408 | if master_process: 409 | run_id = str(uuid.uuid4()) 410 | logdir = 'logs/%s/' % run_id 411 | os.makedirs(logdir, exist_ok=True) 412 | logfile = 'logs/%s.txt' % run_id 413 | # create the log file 414 | with open(logfile, "w") as f: 415 | # begin the log by printing this file (the Python code) 416 | f.write('='*100 + '\n') 417 | f.write(code) 418 | f.write('='*100 + '\n') 419 | # log information about the hardware/software environment this is running on 420 | # and print the full `nvidia-smi` to file 421 | f.write(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:\n") 422 | import subprocess 423 | result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) 424 | f.write(f'{result.stdout}\n') 425 | f.write('='*100 + '\n') 426 | 427 | training_time_ms = 0 428 | # start the clock 429 | torch.cuda.synchronize() 430 | t0 = time.time() 431 | # begin training 432 | train_loader.reset() 433 | for step in range(args.num_iterations + 1): 434 | last_step = (step == args.num_iterations) 435 | # This effectively ignores timing first 10 steps, which are slower for weird reasons. 436 | # Alternately, and slightly more correctly in terms of benchmarking, we could do 10 437 | # steps with dummy data first, and then re-initialize the model and reset the loader. 
438 | if step == 10: 439 | training_time_ms = 0 440 | t0 = time.time() 441 | timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val 442 | 443 | # once in a while evaluate the validation dataset 444 | if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)): 445 | # stop the clock 446 | torch.cuda.synchronize() 447 | training_time_ms += 1000 * (time.time() - t0) 448 | # run validation batches 449 | model.eval() 450 | val_loader.reset() 451 | val_loss = 0.0 452 | for _ in range(val_steps): 453 | x_val, y_val = val_loader.next_batch() 454 | with torch.no_grad(): # of course, we'd like to use ctx here too, but that creates a torch.compile error for some reason 455 | _, loss = model(x_val, y_val, return_logits=False) 456 | val_loss += loss 457 | dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) 458 | val_loss /= val_steps 459 | # log val loss to console and to logfile 460 | if master_process: 461 | print(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms') 462 | with open(logfile, "a") as f: 463 | f.write(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms\n') 464 | # start the clock again 465 | torch.cuda.synchronize() 466 | t0 = time.time() 467 | 468 | if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)): 469 | # stop the clock 470 | torch.cuda.synchronize() 471 | training_time_ms += 1000 * (time.time() - t0) 472 | # save the state of the training process 473 | log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) 474 | torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step)) 475 | # start the clock again 476 | torch.cuda.synchronize() 477 | t0 = time.time() 478 | 479 | # bit confusing: we want to make sure to eval on 0th iteration 480 | # but also after the very last iteration. so we loop for step <= num_iterations 481 | # instead of just < num_iterations (one extra due to <=), only to do 482 | # the validation/sampling one last time, and then we break right here as we're done. 483 | if last_step: 484 | break 485 | 486 | # --------------- TRAINING SECTION BEGIN ----------------- 487 | model.train() 488 | for i in range(1, train_accumulation_steps+1): 489 | # forward pass 490 | with ctx: 491 | _, loss = model(x, y, return_logits=False) 492 | train_loss = loss.detach() 493 | # advance the dataset for the next batch 494 | x, y = train_loader.next_batch() 495 | # backward pass 496 | if i < train_accumulation_steps: 497 | with model.no_sync(): # there's no need to sync gradients every accumulation step 498 | loss.backward() 499 | else: 500 | loss.backward() # just sync on the last step 501 | for p in model.parameters(): 502 | p.grad /= train_accumulation_steps 503 | # step the optimizers and schedulers 504 | for opt, sched in zip(optimizers, schedulers): 505 | opt.step() 506 | sched.step() 507 | # null the gradients 508 | model.zero_grad(set_to_none=True) 509 | # --------------- TRAINING SECTION END ------------------- 510 | # everything that follows now is just diagnostics, prints, logging, etc. 
511 | 512 | #dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower 513 | if master_process: 514 | approx_time = training_time_ms + 1000 * (time.time() - t0) 515 | print(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms") 516 | with open(logfile, "a") as f: 517 | f.write(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms\n") 518 | 519 | if master_process: 520 | print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") 521 | 522 | # ------------------------------------------------------------------------- 523 | # clean up nice 524 | dist.destroy_process_group() 525 | -------------------------------------------------------------------------------- /records/101324_llmc/README.md: -------------------------------------------------------------------------------- 1 | This is a log produced by running the current version of Andrej Karpathy's [llm.c](https://github.com/karpathy/llm.c), as of October 13th 2024. 2 | 3 | It was run on a node with 8x H100 HBM3 according to the instructions [here](https://github.com/karpathy/llm.c/discussions/481). 4 | The mean per-step time was 140ms. The total number of training tokens is 10.26B. The final validation loss was **3.2722**. 5 | 6 | This is (significantly) better than the quoted result of **3.29** val loss in 7 | [Andrej Karpathy's May 28th GPT-2 replication discussion](https://github.com/karpathy/llm.c/discussions/481#:~:text=By%20the%20end%20of%20the%20optimization%20we%27ll%20get%20to%20about%203.29). 8 | So it appears that there have been some improvements to the training algorithm used by llm.c since then. 9 | 10 | Note that the set of examples which llm.c uses for validation appears to be the same as what we do in this repo, i.e., the first `10 * 2**20` tokens of the val set. 11 | 12 | -------------------------------------------------------------------------------- /records/101424_ModernArch/train_gpt2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | with open(sys.argv[0]) as f: 4 | code = f.read() # read the code of this file ASAP, for logging 5 | import uuid 6 | import glob 7 | import time 8 | from dataclasses import dataclass 9 | 10 | import numpy as np 11 | import torch 12 | from torch import nn 13 | import torch.nn.functional as F 14 | import torch.distributed as dist 15 | import torch._inductor.config as config 16 | from torch.nn.parallel import DistributedDataParallel as DDP 17 | 18 | # ----------------------------------------------------------------------------- 19 | # Muon optimizer 20 | 21 | def zeropower_via_svd(G, steps=None): 22 | U, S, V = G.svd() 23 | return U @ V.T 24 | 25 | @torch.compile 26 | def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): 27 | """ 28 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 29 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 30 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 31 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 32 | on the interval. 
This iteration therefore does not produce UV^T but rather something like US'V^T 33 | where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model 34 | performance at all relative to UV^T, where USV^T = G is the SVD. 35 | """ 36 | assert len(G.shape) == 2 37 | a, b, c = (3.4445, -4.7750, 2.0315) 38 | X = G.bfloat16() 39 | X /= (X.norm() + eps) # ensure top singular value <= 1 40 | if G.size(0) > G.size(1): 41 | X = X.T 42 | for _ in range(steps): 43 | A = X @ X.T 44 | B = A @ X 45 | X = a * X + b * B + c * A @ B 46 | if G.size(0) > G.size(1): 47 | X = X.T 48 | return X 49 | 50 | zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5) 51 | 52 | class Muon(torch.optim.Optimizer): 53 | """ 54 | Muon - MomentUm Orthogonalized by Newton-schulz 55 | 56 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- 57 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal 58 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has 59 | the advantage that it can be stably run in bfloat16 on the GPU. 60 | 61 | Some warnings: 62 | - This optimizer assumes that all parameters passed in are 2D. 63 | - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D 64 | parameters; those should all be optimized by a standard method (e.g., AdamW). 65 | - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. 66 | - We believe it is unlikely to work well for training with small batch size. 67 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this. 68 | - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M). 69 | 70 | Arguments: 71 | lr: The learning rate used by the internal SGD. 72 | momentum: The momentum used by the internal SGD. 73 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) 74 | backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5') 75 | backend_steps: The number of iteration steps to use in the backend, if it is iterative. 
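    Example (illustrative usage, not code taken verbatim from this file; assumes the 2D
    transformer-body parameters are collected as in this repo, e.g. via
    raw_model.transformer.h.parameters()):

        optimizer = Muon(raw_model.transformer.h.parameters(), lr=0.02, momentum=0.95)
        loss.backward()
        optimizer.step()
        model.zero_grad(set_to_none=True)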
76 | """ 77 | def __init__(self, params, lr=3e-4, momentum=0.95, nesterov=True, backend='newtonschulz5', backend_steps=5): 78 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps) 79 | super().__init__(params, defaults) 80 | 81 | def step(self): 82 | for group in self.param_groups: 83 | lr = group['lr'] 84 | momentum = group['momentum'] 85 | zeropower_backend = zeropower_backends[group['backend']] 86 | for p in group['params']: 87 | g = p.grad 88 | if g is None: 89 | continue 90 | state = self.state[p] 91 | if 'momentum_buffer' not in state: 92 | state['momentum_buffer'] = torch.zeros_like(g) 93 | buf = state['momentum_buffer'] 94 | buf.mul_(momentum).add_(g) 95 | if group['nesterov']: 96 | g = g.add(buf, alpha=momentum) 97 | if g.size(0) == 3 * g.size(1): # split grouped QKV parameters 98 | g = torch.cat([zeropower_backend(g1, steps=group['backend_steps']) for g1 in g.split(g.size(1))]) 99 | scale = g.size(1)**0.5 100 | else: 101 | g = zeropower_backend(g, steps=group['backend_steps']) 102 | scale = max(g.size(0), g.size(1))**0.5 # scale to have update.square().mean() == 1 103 | p.data.add_(g, alpha=-lr * scale) 104 | 105 | # ----------------------------------------------------------------------------- 106 | # PyTorch nn.Module definitions for the GPT-2 model 107 | 108 | class Rotary(torch.nn.Module): 109 | 110 | def __init__(self, dim, base=10000): 111 | super().__init__() 112 | self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) 113 | self.seq_len_cached = None 114 | self.cos_cached = None 115 | self.sin_cached = None 116 | 117 | def forward(self, x): 118 | seq_len = x.shape[1] 119 | if seq_len != self.seq_len_cached: 120 | self.seq_len_cached = seq_len 121 | t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) 122 | freqs = torch.outer(t, self.inv_freq).to(x.device) 123 | self.cos_cached = freqs.cos().bfloat16() 124 | self.sin_cached = freqs.sin().bfloat16() 125 | return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :] 126 | 127 | def apply_rotary_emb(x, cos, sin): 128 | assert x.ndim == 4 # multihead attention 129 | d = x.shape[3]//2 130 | x1 = x[..., :d] 131 | x2 = x[..., d:] 132 | y1 = x1 * cos + x2 * sin 133 | y2 = x1 * (-sin) + x2 * cos 134 | return torch.cat([y1, y2], 3).type_as(x) 135 | 136 | class CausalSelfAttention(nn.Module): 137 | 138 | def __init__(self, config): 139 | super().__init__() 140 | self.n_head = config.n_head 141 | self.n_embd = config.n_embd 142 | self.head_dim = self.n_embd // self.n_head 143 | assert self.n_embd % self.n_head == 0 144 | self.c_q = nn.Linear(self.n_embd, self.n_embd, bias=False) 145 | self.c_k = nn.Linear(self.n_embd, self.n_embd, bias=False) 146 | self.c_v = nn.Linear(self.n_embd, self.n_embd, bias=False) 147 | # output projection 148 | self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) 149 | self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977 150 | self.rotary = Rotary(self.head_dim) 151 | 152 | def forward(self, x): 153 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 154 | q = self.c_q(x).view(B, T, self.n_head, self.head_dim) 155 | k = self.c_k(x).view(B, T, self.n_head, self.head_dim) 156 | v = self.c_v(x).view(B, T, self.n_head, self.head_dim) 157 | cos, sin = self.rotary(q) 158 | q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) 159 | q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) # QK norm suggested by @Grad62304977 160 | y = 
F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True) 161 | y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side 162 | y = self.c_proj(y) 163 | return y 164 | 165 | class MLP(nn.Module): 166 | 167 | def __init__(self, config): 168 | super().__init__() 169 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False) 170 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False) 171 | self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977 172 | 173 | def forward(self, x): 174 | x = self.c_fc(x) 175 | x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 176 | x = self.c_proj(x) 177 | return x 178 | 179 | class Block(nn.Module): 180 | 181 | def __init__(self, config): 182 | super().__init__() 183 | self.attn = CausalSelfAttention(config) 184 | self.mlp = MLP(config) 185 | 186 | def forward(self, x): 187 | x = x + self.attn(F.rms_norm(x, (x.size(-1),))) 188 | x = x + self.mlp(F.rms_norm(x, (x.size(-1),))) 189 | return x 190 | 191 | # ----------------------------------------------------------------------------- 192 | # The main GPT-2 model 193 | 194 | @dataclass 195 | class GPTConfig: 196 | vocab_size : int = 50304 197 | n_layer : int = 12 198 | n_head : int = 6 # head dim 128 suggested by @Grad62304977 199 | n_embd : int = 768 200 | 201 | class GPT(nn.Module): 202 | 203 | def __init__(self, config): 204 | super().__init__() 205 | self.config = config 206 | 207 | self.transformer = nn.ModuleDict(dict( 208 | wte = nn.Embedding(config.vocab_size, config.n_embd), 209 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 210 | )) 211 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 212 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying 213 | 214 | def forward(self, idx, targets=None, return_logits=True): 215 | 216 | # forward the GPT model itself 217 | x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 218 | for block in self.transformer.h: 219 | x = block(x) 220 | x = F.rms_norm(x, (x.size(-1),)) 221 | 222 | if targets is not None: 223 | # if we are given some desired targets also calculate the loss 224 | logits = self.lm_head(x) 225 | logits = logits.float() # use tf32/fp32 for logits 226 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 227 | else: 228 | # inference-time mini-optimization: only forward the lm_head on the very last position 229 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim 230 | logits = logits.float() # use tf32/fp32 for logits 231 | loss = None 232 | 233 | # there are performance reasons why not returning logits is prudent, if not needed 234 | if not return_logits: 235 | logits = None 236 | 237 | return logits, loss 238 | 239 | # ----------------------------------------------------------------------------- 240 | # Our own simple Distributed Data Loader 241 | 242 | def _peek_data_shard(filename): 243 | # only reads the header, returns header data 244 | with open(filename, "rb") as f: 245 | # first read the header, which is 256 int32 integers (4 bytes each) 246 | header = np.frombuffer(f.read(256*4), dtype=np.int32) 247 | if header[0] != 20240520: 248 | print("ERROR: magic number mismatch in the data .bin file!") 249 | print("---> HINT: Are you passing in a correct file with --input_bin?") 
250 | print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README") 251 | print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try") 252 | exit(1) 253 | assert header[1] == 1, "unsupported version" 254 | ntok = header[2] # number of tokens (claimed) 255 | return ntok # for now just return the number of tokens 256 | 257 | def _load_data_shard(filename): 258 | with open(filename, "rb") as f: 259 | # first read the header, which is 256 int32 integers (4 bytes each) 260 | header = np.frombuffer(f.read(256*4), dtype=np.int32) 261 | assert header[0] == 20240520, "magic number mismatch in the data .bin file" 262 | assert header[1] == 1, "unsupported version" 263 | ntok = header[2] # number of tokens (claimed) 264 | # the rest of it are tokens, stored as uint16 265 | tokens = np.frombuffer(f.read(), dtype=np.uint16) 266 | assert len(tokens) == ntok, "number of tokens read does not match header?" 267 | return tokens 268 | 269 | class DistributedDataLoader: 270 | def __init__(self, filename_pattern, B, T, process_rank, num_processes): 271 | self.process_rank = process_rank 272 | self.num_processes = num_processes 273 | self.B = B 274 | self.T = T 275 | 276 | # glob files that match the pattern 277 | self.files = sorted(glob.glob(filename_pattern)) 278 | assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}" 279 | 280 | # load and validate all data shards, count number of tokens in total 281 | ntok_total = 0 282 | for fname in self.files: 283 | shard_ntok = _peek_data_shard(fname) 284 | assert shard_ntok >= num_processes * B * T + 1 285 | ntok_total += int(shard_ntok) 286 | self.ntok_total = ntok_total 287 | 288 | # kick things off 289 | self.reset() 290 | 291 | def reset(self): 292 | self.current_shard = 0 293 | self.current_position = self.process_rank * self.B * self.T 294 | self.tokens = _load_data_shard(self.files[self.current_shard]) 295 | 296 | def advance(self): # advance to next data shard 297 | self.current_shard = (self.current_shard + 1) % len(self.files) 298 | self.current_position = self.process_rank * self.B * self.T 299 | self.tokens = _load_data_shard(self.files[self.current_shard]) 300 | 301 | def next_batch(self): 302 | B = self.B 303 | T = self.T 304 | buf = self.tokens[self.current_position : self.current_position+B*T+1] 305 | buf = torch.tensor(buf.astype(np.int32), dtype=torch.long) 306 | x = (buf[:-1]).view(B, T) # inputs 307 | y = (buf[1:]).view(B, T) # targets 308 | # advance current position and load next shard if necessary 309 | self.current_position += B * T * self.num_processes 310 | if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens): 311 | self.advance() 312 | return x.cuda(), y.cuda() 313 | 314 | # ----------------------------------------------------------------------------- 315 | # int main 316 | 317 | @dataclass 318 | class Hyperparameters: 319 | # data hyperparams 320 | input_bin : str = 'data/fineweb10B/fineweb_train_*.bin' # input .bin to train on 321 | input_val_bin : str = 'data/fineweb10B/fineweb_val_*.bin' # input .bin to eval validation loss on 322 | # optimization hyperparams 323 | batch_size : int = 8*64 # batch size, in sequences, across all devices 324 | device_batch_size : int = 64 # batch size, in sequences, per device 325 | sequence_length : int = 1024 # sequence length, in tokens 326 | num_iterations : int = 5100 # number of iterations to run 327 | learning_rate : float = 0.0036 328 | warmup_iters : int = 0 329 | 
warmdown_iters : int = 1450 # number of iterations of linear warmup/warmdown for triangular or trapezoidal schedule 330 | weight_decay : float = 0 331 | # evaluation and logging hyperparams 332 | val_loss_every : int = 125 # every how many steps to evaluate val loss? 0 for only at the end 333 | val_tokens : int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons 334 | save_every : int = 0 # every how many steps to save the checkpoint? 0 for only at the end 335 | args = Hyperparameters() 336 | 337 | # set up DDP (distributed data parallel). torchrun sets this env variable 338 | assert torch.cuda.is_available() 339 | dist.init_process_group(backend='nccl') 340 | ddp_rank = int(os.environ['RANK']) 341 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 342 | ddp_world_size = int(os.environ['WORLD_SIZE']) 343 | device = f'cuda:{ddp_local_rank}' 344 | torch.cuda.set_device(device) 345 | print(f"using device: {device}") 346 | master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc. 347 | 348 | # convenience variables 349 | B, T = args.device_batch_size, args.sequence_length 350 | # calculate the number of steps to take in the val loop. 351 | assert args.val_tokens % (B * T * ddp_world_size) == 0 352 | val_steps = args.val_tokens // (B * T * ddp_world_size) 353 | # calculate the steps of gradient accumulation required to attain the desired global batch size. 354 | assert args.batch_size % (B * ddp_world_size) == 0 355 | train_accumulation_steps = args.batch_size // (B * ddp_world_size) 356 | 357 | # load tokens 358 | train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size) 359 | val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size) 360 | if master_process: 361 | print(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files") 362 | print(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files") 363 | x, y = train_loader.next_batch() 364 | 365 | # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. suggested to me by @Grad62304977. 366 | # this originates from Karpathy's experiments. 
367 | num_vocab = 50304 368 | model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=6, n_embd=768)) 369 | model = model.cuda() 370 | if hasattr(config, "coordinate_descent_tuning"): 371 | config.coordinate_descent_tuning = True # suggested by @Chillee 372 | model = torch.compile(model) 373 | # here we wrap model into DDP container 374 | model = DDP(model, device_ids=[ddp_local_rank]) 375 | raw_model = model.module # always contains the "raw" unwrapped model 376 | ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16) 377 | 378 | # init the optimizer(s) 379 | optimizer1 = torch.optim.AdamW(raw_model.lm_head.parameters(), lr=args.learning_rate, betas=(0.9, 0.95), 380 | weight_decay=args.weight_decay, fused=True) 381 | optimizer2 = Muon(raw_model.transformer.h.parameters(), lr=0.1*args.learning_rate, momentum=0.95) 382 | optimizers = [optimizer1, optimizer2] 383 | # learning rate decay scheduler (linear warmup and warmdown) 384 | def get_lr(it): 385 | assert it <= args.num_iterations 386 | # 1) linear warmup for warmup_iters steps 387 | if it < args.warmup_iters: 388 | return (it+1) / args.warmup_iters 389 | # 2) constant lr for a while 390 | elif it < args.num_iterations - args.warmdown_iters: 391 | return 1.0 392 | # 3) linear warmdown 393 | else: 394 | decay_ratio = (args.num_iterations - it) / args.warmdown_iters 395 | return decay_ratio 396 | schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers] 397 | 398 | # begin logging 399 | if master_process: 400 | run_id = str(uuid.uuid4()) 401 | logdir = 'logs/%s/' % run_id 402 | os.makedirs(logdir, exist_ok=True) 403 | logfile = 'logs/%s.txt' % run_id 404 | # create the log file 405 | with open(logfile, "w") as f: 406 | # begin the log by printing this file (the Python code) 407 | f.write('='*100 + '\n') 408 | f.write(code) 409 | f.write('='*100 + '\n') 410 | # log information about the hardware/software environment this is running on 411 | # and print the full `nvidia-smi` to file 412 | f.write(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:\n") 413 | import subprocess 414 | result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) 415 | f.write(f'{result.stdout}\n') 416 | f.write('='*100 + '\n') 417 | 418 | training_time_ms = 0 419 | # start the clock 420 | torch.cuda.synchronize() 421 | t0 = time.time() 422 | # begin training 423 | train_loader.reset() 424 | for step in range(args.num_iterations + 1): 425 | last_step = (step == args.num_iterations) 426 | # This effectively ignores timing first 10 steps, which are slower for weird reasons. 427 | # Alternately, and slightly more correctly in terms of benchmarking, we could do 10 428 | # steps with dummy data first, and then re-initialize the model and reset the loader. 
429 | if step == 10: 430 | training_time_ms = 0 431 | t0 = time.time() 432 | timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val 433 | 434 | # once in a while evaluate the validation dataset 435 | if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)): 436 | # stop the clock 437 | torch.cuda.synchronize() 438 | training_time_ms += 1000 * (time.time() - t0) 439 | # run validation batches 440 | model.eval() 441 | val_loader.reset() 442 | val_loss = 0.0 443 | for _ in range(val_steps): 444 | x_val, y_val = val_loader.next_batch() 445 | with ctx: # of course, we'd like to use no_grad() here too, but that creates a torch.compile error for some reason 446 | _, loss = model(x_val, y_val, return_logits=False) 447 | val_loss += loss.detach() 448 | del loss 449 | dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) 450 | val_loss /= val_steps 451 | # log val loss to console and to logfile 452 | if master_process: 453 | print(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms') 454 | with open(logfile, "a") as f: 455 | f.write(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms\n') 456 | # start the clock again 457 | torch.cuda.synchronize() 458 | t0 = time.time() 459 | 460 | if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)): 461 | # stop the clock 462 | torch.cuda.synchronize() 463 | training_time_ms += 1000 * (time.time() - t0) 464 | # save the state of the training process 465 | log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) 466 | torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step)) 467 | # start the clock again 468 | torch.cuda.synchronize() 469 | t0 = time.time() 470 | 471 | # bit confusing: we want to make sure to eval on 0th iteration 472 | # but also after the very last iteration. so we loop for step <= num_iterations 473 | # instead of just < num_iterations (one extra due to <=), only to do 474 | # the validation/sampling one last time, and then we break right here as we're done. 475 | if last_step: 476 | break 477 | 478 | # --------------- TRAINING SECTION BEGIN ----------------- 479 | model.train() 480 | for i in range(1, train_accumulation_steps+1): 481 | # forward pass 482 | with ctx: 483 | _, loss = model(x, y, return_logits=False) 484 | train_loss = loss.detach() 485 | # advance the dataset for the next batch 486 | x, y = train_loader.next_batch() 487 | # backward pass 488 | if i < train_accumulation_steps: 489 | with model.no_sync(): # there's no need to sync gradients every accumulation step 490 | loss.backward() 491 | else: 492 | loss.backward() # just sync on the last step 493 | for p in model.parameters(): 494 | p.grad /= train_accumulation_steps 495 | # step the optimizers and schedulers 496 | for opt, sched in zip(optimizers, schedulers): 497 | opt.step() 498 | sched.step() 499 | # null the gradients 500 | model.zero_grad(set_to_none=True) 501 | # --------------- TRAINING SECTION END ------------------- 502 | # everything that follows now is just diagnostics, prints, logging, etc. 
503 | 504 | #dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower 505 | if master_process: 506 | approx_time = training_time_ms + 1000 * (time.time() - t0) 507 | print(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms") 508 | with open(logfile, "a") as f: 509 | f.write(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms\n") 510 | 511 | if master_process: 512 | print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") 513 | 514 | # ------------------------------------------------------------------------- 515 | # clean up nice 516 | dist.destroy_process_group() 517 | -------------------------------------------------------------------------------- /records/102924_Optimizers/README.md: -------------------------------------------------------------------------------- 1 | # Optimizer comparison for NanoGPT speedrunning 2 | 3 | This is a comparison between the four best optimizers I am aware of for NanoGPT speedrunning. They are compared using the 10/18/24 NanoGPT speedrunning record. 4 | 5 | Reproducible logs: 6 | * [Adam](95a9fd44-7c13-49c7-b324-3e7d9e23a499.txt) 7 | * [DistributedShampoo](8bfe4e35-c3fc-4b70-a984-3be937b71ff3) 8 | * [SOAP](e21a2838-a0f2-46f2-a247-db0021165682.txt) 9 | * [Muon](8d6193f4-27fc-4e68-899f-af70019a4d54.txt) 10 | 11 | Results: 12 | ![1](nanogpt_speedrun81w.png) 13 | ![2](nanogpt_speedrun82w.png) 14 | 15 | ### General notes for all optimizers 16 | 17 | All optimizers are run using zero weight decay (which is found to be empirically optimal). 18 | 19 | And they are all run with a warmup-stable-decay / trapezoidal schedule, which also seems to be optimal. That's what causes the kink in the loss curve ~75% of the way to the end. 20 | 21 | In addition, in all cases, we optimize the shared embedding/head layer just using Adam (which is also found to be empirically optimal). 22 | Note that in the following code snippets, `raw_model.transformer.h.parameters()` gives all parameters besides those two. 23 | 24 | In each case, the hyperparameters are the best ones I could find in around 20 attempts. 25 | 26 | ## [Adam](95a9fd44-7c13-49c7-b324-3e7d9e23a499.txt) 27 | The optimizer here is equivalent to: 28 | ``` 29 | torch.optim.Adam(raw_model.transformer.h.parameters(), lr=0.0018, betas=(0.9, 0.95)) 30 | ``` 31 | 32 | 33 | ## [DistributedShampoo](8bfe4e35-c3fc-4b70-a984-3be937b71ff3.txt) 34 | Run as follows: 35 | ``` 36 | DistributedShampoo( 37 | raw_model.transformer.h.parameters(), 38 | lr=0.0018, 39 | betas=(0.95, 0.95), 40 | epsilon=1e-12, 41 | weight_decay=0, 42 | max_preconditioner_dim=8192, 43 | precondition_frequency=10, 44 | use_decoupled_weight_decay=True, 45 | grafting_config=AdamGraftingConfig( 46 | beta2=0.95, 47 | epsilon=1e-8, 48 | ), 49 | distributed_config=DDPShampooConfig( 50 | communication_dtype=CommunicationDType.FP32, 51 | num_trainers_per_group=8, 52 | communicate_params=False, 53 | ), 54 | ) 55 | ``` 56 | 57 | This is using the official `DistributedShampoo` implementation from [here](https://github.com/facebookresearch/optimizers/tree/ad2809a291c01859f68fcabbcb49a2aa75fd7827/distributed_shampoo). 58 | 59 | Things that turned out to be important: 60 | * Don't use epsilon above 1e-8; this loses performance. 
Epsilon 1e-12 performs as well as 1e-15 61 | * Betas=(0.95, 0.95) seemed optimal, which turns out to be the same thing that SOAP uses 62 | * Higher preconditioner update frequency is better but slower 63 | 64 | I'm open to hyperparameter suggestions; the experiment takes ~20-30 minutes to run on a fresh 8xH100 instance, so it's not hard for me to run more attempts. 65 | 66 | 67 | ## [SOAP](e21a2838-a0f2-46f2-a247-db0021165682.txt) 68 | ``` 69 | SOAP(model.transformer.h.parameters(), lr=0.0018, betas=(.95, .95), precondition_frequency=10) 70 | ``` 71 | 72 | This is using the official SOAP implementation [here](https://github.com/nikhilvyas/SOAP/blob/bbce86e890d3b697380f4376acb600c2d6c3d203/soap.py). 73 | 74 | Based on conversations with the authors, it is likely that a future SOAP implementation will significantly reduce the wallclock overhead. 75 | 76 | 77 | ## [Muon](8d6193f4-27fc-4e68-899f-af70019a4d54.txt) 78 | ``` 79 | Muon(raw_model.transformer.h.parameters(), lr=0.02, momentum=0.95) 80 | ``` 81 | 82 | 83 | ## Openness 84 | 85 | These training logs are reproducible (just cut out the part besides the code, and run it using the `run.sh` in the top-level folder). They take 12-25 minutes to run. 86 | 87 | I tried to do a good job sweeping the hyperparameters for each optimizer, but I can easily have missed something, or just not have performed enough runs. 88 | 89 | Therefore, I am interested in any better hyperparameter settings which other researchers can find, for any of the optimizers. 90 | If you post or send me your own reproducible log with one of these optimizers, I will be very happy to boost it in any way I can. 91 | 92 | ## Appendix: Negative results 93 | 94 | I believe it was Shazeer who said something like "negative results in machine learning are not worth much, because your inability to make something work doesn't prove that it can't work" 95 | 96 | Given that disclaimer, here are some optimizers that I tried to make work, but was unable to get a significant boost over Adam with: 97 | * Sophia 98 | * Lion 99 | * AdamWScheduleFree 100 | * AdEmaMix (actually this was slightly better than Adam, just not enough to get near competing with the three Shampoo-like optimizers) 101 | 102 | Of course, this is just for NanoGPT speedrunning (short train duration); it's quite possible they work better at longer training duration or for larger models. 
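For concreteness, the shared setup described in the general notes above looks roughly like the sketch below, which adapts the pattern of the `train_gpt2.py` files elsewhere in this repo (names such as `raw_model`, `args`, and `Muon` follow those files; treat it as an illustration of the wiring, not the exact record code):

```
# shared embedding/head on Adam, transformer body on the optimizer being compared (Muon shown)
optimizer1 = torch.optim.AdamW(raw_model.lm_head.parameters(), lr=args.learning_rate,
                               betas=(0.9, 0.95), weight_decay=0, fused=True)
optimizer2 = Muon(raw_model.transformer.h.parameters(), lr=0.02, momentum=0.95)
optimizers = [optimizer1, optimizer2]

# warmup-stable-decay (trapezoidal) schedule: constant lr, then a linear warmdown
def get_lr(it):
    if it < args.num_iterations - args.warmdown_iters:
        return 1.0
    return (args.num_iterations - it) / args.warmdown_iters
schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]

# each training step then calls opt.step() and sched.step() for every optimizer/scheduler pair
```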
103 | 104 | -------------------------------------------------------------------------------- /records/102924_Optimizers/nanogpt_speedrun81w.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/records/102924_Optimizers/nanogpt_speedrun81w.png -------------------------------------------------------------------------------- /records/102924_Optimizers/nanogpt_speedrun82w.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/records/102924_Optimizers/nanogpt_speedrun82w.png -------------------------------------------------------------------------------- /records/110324_UntieEmbed/README.md: -------------------------------------------------------------------------------- 1 | # New record 11/03/24 2 | 3 | 4 | New NanoGPT training speed record: 3.28 FineWeb val loss in 10.8 minutes on 8xH100 5 | 6 | Previous record: 12.0 minutes 7 | Changelog: 8 | - untied embed and head weights 9 | - added RMSNorm after embed 10 | - init head to zero 11 | 12 | Driven by @Grad62304977 13 | 14 | --- 15 | 16 | Technically, this is somewhat of an "any%" record, since untying the embedding and lm_head adds 39M parameters. 17 | 18 | However, it doesn't change the number of active parameters or the inference throughput. Future records will stay constrained to 124M active parameters. 19 | 20 | --- 21 | 22 | Like the last architectural change, this record was driven by @Grad62304977. I just finetuned some things and did bookkeeping. 23 | 24 | --- 25 | 26 | Shoutout to @cloneofsimo whose scaling guide already suggests initializing the head to zero. This works quite well and is a significant fraction of the record. 27 | 28 | -------------------------------------------------------------------------------- /records/110424_50Bruns/README.md: -------------------------------------------------------------------------------- 1 | # 50B-token runs 2 | 3 | This folder contains four runs generated by extending the 11/03/24 speedrun record to 50B FineWeb tokens. 4 | The goal is to test how the speedrun generalizes to long durations, and especially how well Muon does. 5 | 6 | We compare two things: 7 | 1. We compare Muon to Adam as the optimizer for the transformer body. (The head and embedding are always optimized by Adam.) 8 | 2. We compare training on 5 epochs of 10B tokens to training on 50B tokens. 
(Surprisingly this does about the same) 9 | 10 | The four resulting runs are as follows: 11 | 12 | * [Muon 50B tokens](./530f3ee1-8862-4d21-be2b-da10eb05e6a9.txt) (HellaSwag=35.82) 13 | * [Adam 50B tokens](./69c33fc9-eabb-4a38-aa08-6922914eb405.txt) (HellaSwag=34.26) 14 | * [Muon 5x10B tokens](./4fbe61ec-f79a-4c19-836d-46d599deecce.txt) (HellaSwag=36.17) 15 | * [Adam 5x10B tokens](./3d715d41-453a-40d6-9506-421ba69766b2.txt) (HellaSwag=34.05) 16 | 17 | To get a sense of what a good HellaSwag score would be for this scale of model, here are some baselines: 18 | * Karpathy's baseline llm.c training (trained for 10B FineWeb tokens): 29.9 19 | * OpenAI GPT-2 (124M): 29.4 20 | * OpenAI GPT-3 (124M) (trained for 300B WebText tokens): 33.7 21 | * Huggingface SmolLM2-135M (trained for 2T FineWeb/DCLM/etc tokens): 42.1 22 | 23 | Note: I'm a little concerned that the learning rate schedule (WSD) and weight decay (zero), which are tuned for the speedrun duration, 24 | might become undertuned/suboptimal for trainings of this duration. 25 | It does look like the gap between Muon/Adam is too large to be closed by something like this, and the HellaSwag scores look quite reasonable, but you never know. 26 | 27 | -------------------------------------------------------------------------------- /records/110624_ShortcutsTweaks/README.md: -------------------------------------------------------------------------------- 1 | # New record 11/06/24 2 | 3 | 8.2 minutes on 8xH100 (previous record: 10.8 minutes) 4 | 5 | ![](nanogpt_speedrun110.png) 6 | ![](nanogpt_speedrun111.png) 7 | 8 | * [Old record 11/03/24](d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt) 9 | * [+shorten duration](4a71cc92-0f43-4058-a033-23e85c1e98f1.txt) 10 | * [+value residual](042f9e87-07e6-4504-bb04-4ec59a380211.txt) by @Grad62304977 following [1] 11 | * [+learnable lambda](43f60c4f-0448-4de7-83d9-643ca26f61e7.txt) @Grad62304977's innovation on top of [1] 12 | * [+embed shortcut](05b29e54-0be0-4a0f-a1e2-7d5317daedd3.txt) 13 | * [+momentum warmup](10119f53-7001-4248-bfd9-33d32427a912.txt) 14 | * [+tanh logit capping](dd7304a6-cc43-4d5e-adb8-c070111464a1.txt) by @Grad62304977 following [2] 15 | 16 | ## Code snippets 17 | 18 | ### Value residual 19 | 20 | In the attention layer: 21 | ``` 22 | def forward(self, x, v1=None): 23 | ... 24 | v = self.c_v(x).view(B, T, self.n_head, self.head_dim) 25 | if v1 is None: 26 | v1 = v 27 | v = 0.5 * v + 0.5 * v1.view_as(v) 28 | ``` 29 | Where the first block receives v1=None, and subsequent blocks receive v1 as the value produced by the first block. 30 | 31 | ### Learnable lambda 32 | 33 | In the attention block: 34 | ``` 35 | def __init__(self, config): 36 | ... 37 | self.lamb = nn.Parameter(torch.tensor(0.5)) 38 | 39 | def forward(self, x, v1=None): 40 | ... 41 | v = (1 - self.lamb) * v + self.lamb * v1.view_as(v) 42 | ``` 43 | That is, we just replace the fixed 0.5 constant used in standard value residual [1] with a learnable scalar (optimized by Adam(lr=0.02)). 
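To make the plumbing explicit, here is a small self-contained sketch of how `v1` can be threaded from the first block through the rest; the `ToyValueResidual` module and its names are illustrative only, and the record's real attention layer additionally applies rotary embeddings, QK-norm, and the attention itself:

```
import torch
from torch import nn

class ToyValueResidual(nn.Module):
    # illustrative stand-in for the attention layer's value path only
    def __init__(self, dim):
        super().__init__()
        self.c_v = nn.Linear(dim, dim, bias=False)
        self.lamb = nn.Parameter(torch.tensor(0.5))  # learnable lambda (in the record, optimized by Adam(lr=0.02))

    def forward(self, x, v1=None):
        v = self.c_v(x)
        if v1 is None:
            v1 = v                                   # first block: remember its value projection
        v = (1 - self.lamb) * v + self.lamb * v1.view_as(v)
        return v, v1                                 # hand v1 onward to later blocks

blocks = nn.ModuleList(ToyValueResidual(16) for _ in range(3))
x, v1 = torch.randn(2, 8, 16), None
for block in blocks:
    x, v1 = block(x, v1)                             # blocks after the first mix in the first block's values
```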
44 | 45 | ### Embed shortcut 46 | 47 | Replaces the standard transformer block with this: 48 | 49 | ``` 50 | class Block(nn.Module): 51 | 52 | def __init__(self, config): 53 | super().__init__() 54 | self.attn = CausalSelfAttention(config) 55 | self.mlp = MLP(config) 56 | self.lambdas = nn.Parameter(torch.tensor([1., 0.])) 57 | 58 | def forward(self, x, x0): 59 | x = self.lambdas[0] * x + self.lambdas[1] * x0 60 | x = x + self.attn(F.rms_norm(x, (x.size(-1),)), v1) 61 | x = x + self.mlp(F.rms_norm(x, (x.size(-1),))) 62 | return x 63 | ``` 64 | 65 | where the two scalars are optimized using Adam(lr=0.02), and `x0` is fed in from the initial embedding via: 66 | ``` 67 | ... 68 | x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 69 | x = F.rms_norm(x, (x.size(-1),)) 70 | x0 = x 71 | for block in self.transformer.h: 72 | x = block(x, x0) 73 | ... 74 | ``` 75 | 76 | ### Momentum warmup 77 | 78 | Just adds the following two lines. 79 | ``` 80 | frac = min(step/500, 1) 81 | optimizer3.param_groups[0]['momentum'] = (1 - frac) * 0.85 + frac * 0.95 82 | ``` 83 | where `optimizer3` is the Muon for the body of the transformer. 84 | 85 | ### Tanh soft capping 86 | 87 | Just adds the following line. 88 | 89 | ``` 90 | logits = 30 * torch.tanh(logits / 30) 91 | ``` 92 | 93 | 94 | ## References 95 | 96 | 1. [Zhou, Zhanchao, et al. "Value Residual Learning For Alleviating Attention Concentration In Transformers." arXiv preprint arXiv:2410.17897 (2024).](https://arxiv.org/abs/2410.17897) 97 | 2. [Team, Gemma, et al. "Gemma 2: Improving open language models at a practical size." arXiv preprint arXiv:2408.00118 (2024).](https://arxiv.org/abs/2408.00118) 98 | 99 | -------------------------------------------------------------------------------- /records/110624_ShortcutsTweaks/nanogpt_speedrun110.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/records/110624_ShortcutsTweaks/nanogpt_speedrun110.png -------------------------------------------------------------------------------- /records/110624_ShortcutsTweaks/nanogpt_speedrun111.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/records/110624_ShortcutsTweaks/nanogpt_speedrun111.png -------------------------------------------------------------------------------- /records/110924_Replicateleloykun/README.md: -------------------------------------------------------------------------------- 1 | This is a replication attempt for the record attempt described [here](https://x.com/leloykun/status/1854557419768254915) by @leloykun. 2 | 3 | The original record could not be directly accepted because it showed a slower wallclock time than the previous record - 4 | however, this was plausibly due to hardware differences, as the competitor's hardware was slightly slower. 5 | 6 | Therefore, to certify this attempt as the new record, here I replicated it on my own hardware. 7 | This did successfully reduce the wallclock time compared to the 11/07/24 record by ~11 seconds, however it also 8 | resulted in an invalid val loss of 3.2824, above the threshold of 3.28. 
9 | 10 | The [original record attempt's reproducible log](https://github.com/leloykun/modded-nanogpt/blob/224f10d190677d9dc3c9c45da280078196a6fe40/records/110724_EmbeddingBetasCooldown/6c9d875b-ad91-46c9-9ede-2c7f998b9b16.txt) attained a val loss of 3.2798, just barely below the 3.28 threshold. So this difference is plausibly due to random inter-run variance. 11 | 12 | This indicates that the true average val loss of the run may be worse than 3.28, meaning I am **unable to certify it as the new record.** 13 | 14 | Ideally, all records should attain a low enough val loss such that >95% of runs attain below 3.28. Good evidence for this would be a single run 15 | attaining <= 3.278. Previous records have adhered to this rule, but admittedly it's hard to define precisely and is therefore mostly a matter of taste. 16 | 17 | -------------------------------------------------------------------------------- /records/111024_UNetDoubleLr/README.md: -------------------------------------------------------------------------------- 1 | This is a record by Brendan Hogan Rappazzo [@brendanh0gan](https://x.com/brendanh0gan). 2 | 3 | New record: 7.23 minutes 4 | 5 | Previous record: 7.8 minutes 6 | 7 | Changelog: 8 | - Added U-net-like skip connections into the transformer 9 | - Doubled the learning rate 10 | 11 | --- 12 | 13 | This record was first posted [here](https://x.com/brendanh0gan/status/1855273758681866352), & then a few iterations were required to benchmark it on 8x SXM H100s. 14 | Brendan's fork of modded-nanogpt is [here](https://github.com/brendanhogan/modded-nanogpt/tree/master). The code for the record can also be extracted from the reproducible log in this folder. 15 | 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | tqdm 3 | torch 4 | huggingface-hub 5 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | torchrun --standalone --nproc_per_node=8 train_gpt2.py 2 | -------------------------------------------------------------------------------- /run_rwkv6.sh: -------------------------------------------------------------------------------- 1 | torchrun --standalone --nproc_per_node=8 train_rwkv6.py "$@" 2 | -------------------------------------------------------------------------------- /run_rwkv7.sh: -------------------------------------------------------------------------------- 1 | torchrun --standalone --nproc_per_node=8 train_rwkv7.py "$@" 2 | -------------------------------------------------------------------------------- /rwkv_cuda/wkv6_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | typedef at::BFloat16 bf16; 5 | 6 | template 7 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 8 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, 9 | F *__restrict__ const _y) 10 | { 11 | const int b = blockIdx.x / H; 12 | const int h = blockIdx.x % H; 13 | const int i = threadIdx.x; 14 | _u += h*_N_; 15 | 16 | __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; 17 | float state[_N_] = {0}; 18 | 19 | __syncthreads(); 20 | u[i] = float(_u[i]); 21 | __syncthreads(); 22 | 23 | for (int t = b*T*C + h*_N_ + i; t < 
(b+1)*T*C + h*_N_ + i; t += C) 24 | { 25 | __syncthreads(); 26 | w[i] = __expf(-__expf(float(_w[t]))); 27 | r[i] = float(_r[t]); 28 | k[i] = float(_k[t]); 29 | __syncthreads(); 30 | 31 | const float v = float(_v[t]); 32 | float y = 0; 33 | 34 | #pragma unroll 35 | for (int j = 0; j < _N_; j+=4) 36 | { 37 | const float4& r_ = (float4&)(r[j]); 38 | const float4& k_ = (float4&)(k[j]); 39 | const float4& w_ = (float4&)(w[j]); 40 | const float4& u_ = (float4&)(u[j]); 41 | float4& s = (float4&)(state[j]); 42 | float4 x; 43 | 44 | x.x = k_.x * v; 45 | x.y = k_.y * v; 46 | x.z = k_.z * v; 47 | x.w = k_.w * v; 48 | 49 | y += r_.x * (u_.x * x.x + s.x); 50 | y += r_.y * (u_.y * x.y + s.y); 51 | y += r_.z * (u_.z * x.z + s.z); 52 | y += r_.w * (u_.w * x.w + s.w); 53 | 54 | s.x = s.x * w_.x + x.x; 55 | s.y = s.y * w_.y + x.y; 56 | s.z = s.z * w_.z + x.z; 57 | s.w = s.w * w_.w + x.w; 58 | } 59 | _y[t] = F(y); 60 | } 61 | } 62 | 63 | template 64 | __global__ void kernel_backward_101(const int B, const int T, const int C, const int H, 65 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy, 66 | F *__restrict__ const _gr, F *__restrict__ const _gu) 67 | { 68 | const int b = blockIdx.x / H; 69 | const int h = blockIdx.x % H; 70 | const int i = threadIdx.x; 71 | 72 | __shared__ float v[_N_], gy[_N_]; 73 | 74 | const float u = float(_u[h*_N_ + i]); 75 | 76 | float state[_N_] = {0}; 77 | 78 | const int t_0 = b*T*C + h*_N_ + i; 79 | const int t_T = t_0 + T*C; 80 | 81 | float gu = 0; 82 | for (int t = t_0; t < t_T; t += C) 83 | { 84 | __syncthreads(); 85 | v[i] = float(_v[t]); 86 | gy[i] = float(_gy[t]); 87 | __syncthreads(); 88 | 89 | const float k = float(_k[t]); 90 | const float w = __expf(-__expf(float(_w[t]))); 91 | float gr = 0, gu_ = 0; 92 | 93 | #pragma unroll 94 | for (int j = 0; j < _N_; j++) 95 | { 96 | float& s = state[j]; 97 | float x = k * v[j]; 98 | 99 | gr += (u * x + s) * gy[j]; 100 | gu_ += x * gy[j]; 101 | s = s * w + x; 102 | } 103 | _gr[t] = F(gr); 104 | gu += float(_r[t]) * gu_; 105 | } 106 | _gu[b*C + h*_N_ + i] = F(gu); 107 | } 108 | 109 | template 110 | __global__ void kernel_backward_102(const int B, const int T, const int C, const int H, 111 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy, 112 | F *__restrict__ const _gk) 113 | { 114 | const int b = blockIdx.x / H; 115 | const int h = blockIdx.x % H; 116 | const int i = threadIdx.x; 117 | 118 | __shared__ float v[_N_], gy[_N_]; 119 | 120 | const float u = float(_u[h*_N_ + i]); 121 | 122 | float scccc[_N_] = {0}; 123 | 124 | const int t_0 = b*T*C + h*_N_ + i; 125 | const int t_T_1 = t_0 + (T-1)*C; 126 | 127 | for (int t = t_T_1; t >= t_0; t -= C) 128 | { 129 | __syncthreads(); 130 | v[i] = float(_v[t]); 131 | gy[i] = float(_gy[t]); 132 | __syncthreads(); 133 | 134 | const float rr = float(_r[t]); 135 | const float w = __expf(-__expf(float(_w[t]))); 136 | float gk = 0; 137 | 138 | #pragma unroll 139 | for (int j = 0; j < _N_; j++) 140 | { 141 | float& s = scccc[j]; 142 | float x = rr * gy[j]; 143 | 144 | gk += (u * x + s) * v[j]; 145 | s = x + s * w; 146 | } 147 | _gk[t] = F(gk); 148 | } 149 | } 150 | 151 | template 152 | __global__ void kernel_backward_103(const int B, const int T, const int C, const int H, 153 | const F *__restrict__ const _r, const F *__restrict__ const _k, 
const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy, 154 | F *__restrict__ const _gv) 155 | { 156 | const int b = blockIdx.x / H; 157 | const int h = blockIdx.x % H; 158 | const int i = threadIdx.x; 159 | _u += h*_N_; 160 | 161 | __shared__ float u_[_N_], r[_N_], k[_N_], w_[_N_]; 162 | __syncthreads(); 163 | u_[i] = float(_u[i]); 164 | __syncthreads(); 165 | 166 | float sdddd[_N_] = {0}; 167 | 168 | const int t_0 = b*T*C + h*_N_ + i; 169 | const int t_T_1 = t_0 + (T-1)*C; 170 | 171 | for (int t = t_T_1; t >= t_0; t -= C) 172 | { 173 | __syncthreads(); 174 | r[i] = float(_r[t]); 175 | k[i] = float(_k[t]); 176 | w_[i] = __expf(-__expf(float(_w[t]))); 177 | __syncthreads(); 178 | 179 | const float gyy = float(_gy[t]); 180 | float gv = 0; 181 | 182 | #pragma unroll 183 | for (int j = 0; j < _N_; j++) 184 | { 185 | float& s = sdddd[j]; 186 | float x = gyy * r[j]; 187 | 188 | gv += (u_[j] * x + s) * k[j]; 189 | s = x + s * w_[j]; 190 | } 191 | _gv[t] = F(gv); 192 | } 193 | } 194 | 195 | template 196 | __global__ void kernel_backward_201(const int B, const int T, const int C, const int H, 197 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy, 198 | F *__restrict__ const _gw) 199 | { 200 | const int b = blockIdx.x / H; 201 | const int h = blockIdx.x % H; 202 | const int i = threadIdx.x; 203 | 204 | __shared__ float v[_N_], gy[_N_]; 205 | float saaaa[_N_] = {0}, sbbbb[_T_-2] = {0}, scccc[_N_] = {0}; 206 | 207 | const int t_0 = b*T*C + h*_N_ + i; 208 | const int t_1 = t_0 + C; 209 | const int t_2 = t_0 + 2*C; 210 | const int t_T_1 = t_0 + (T-1)*C; 211 | 212 | for (int t = t_T_1; t > t_1; t -= C) 213 | { 214 | __syncthreads(); 215 | gy[i] = float(_gy[t]); 216 | v[i] = float(_v[t-2*C]); 217 | __syncthreads(); 218 | 219 | const float r = float(_r[t]); 220 | const float w = __expf(-__expf(float(_w[t-C]))); 221 | float sum = 0.0f; 222 | 223 | #pragma unroll 224 | for (int j = 0; j < _N_; j++) 225 | { 226 | float& s = saaaa[j]; 227 | float x = r * gy[j]; 228 | s = (s + x) * w; 229 | sum += s * v[j]; 230 | } 231 | sbbbb[(t-t_2)/C] = sum * float(_k[t-2*C]); 232 | } 233 | 234 | float sss = sbbbb[0]; 235 | _gw[t_0] = 0; 236 | _gw[t_1] = F(sss * -__expf(float(_w[t_1]))); 237 | 238 | for (int t = t_2; t < t_T_1; t += C) 239 | { 240 | __syncthreads(); 241 | gy[i] = float(_gy[t]); 242 | v[i] = float(_v[t-2*C]); 243 | __syncthreads(); 244 | 245 | const float w = __expf(-__expf(float(_w[t-C]))); 246 | const float k = float(_k[t-2*C]); 247 | float sum = 0.0f; 248 | 249 | #pragma unroll 250 | for (int j = 0; j < _N_; j++) 251 | { 252 | float& s = scccc[j]; 253 | float x = k * v[j]; 254 | s = (s + x) * w; 255 | sum += s * gy[j]; 256 | } 257 | sss += sbbbb[(t-t_1)/C] - (sum * float(_r[t])); 258 | _gw[t] = F(sss * -__expf(float(_w[t]))); 259 | } 260 | _gw[t_T_1] = 0; 261 | } 262 | 263 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *y) 264 | { 265 | assert(H*_N_ == C); 266 | assert(_N_%4 == 0); 267 | kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); 268 | } 269 | 270 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu) 271 | { 272 | assert(H*_N_ == C); 273 | assert(_N_%4 == 0); 274 | kernel_backward_101<<>>(B, T, C, H, r, k, v, w, u, gy, gr, gu); 275 | kernel_backward_102<<>>(B, 
T, C, H, r, k, v, w, u, gy, gk); 276 | kernel_backward_103<<>>(B, T, C, H, r, k, v, w, u, gy, gv); 277 | kernel_backward_201<<>>(B, T, C, H, r, k, v, w, u, gy, gw); 278 | } 279 | -------------------------------------------------------------------------------- /rwkv_cuda/wkv6_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | 5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *y); 6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu); 7 | 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 9 | cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); 10 | } 11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { 12 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr()); 13 | } 14 | 15 | TORCH_LIBRARY(wkv6, m) { 16 | m.def("forward(int B, int T, int C, int H, Tensor r, Tensor k, Tensor v, Tensor w, Tensor u, Tensor(a!) y) -> ()"); 17 | m.def("backward(int B, int T, int C, int H, Tensor r, Tensor k, Tensor v, Tensor w, Tensor u, Tensor gy, Tensor(a!) gr, Tensor(b!) gk, Tensor(c!) gv, Tensor(d!) gw, Tensor(e!) 
gu) -> ()"); 18 | } 19 | 20 | TORCH_LIBRARY_IMPL(wkv6, CUDA, m) { 21 | m.impl("forward", &forward); 22 | m.impl("backward", &backward); 23 | } 24 | -------------------------------------------------------------------------------- /rwkv_cuda/wkv7g_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | 4 | typedef at::BFloat16 bf16; 5 | 6 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y, float *saa, float* sss); 7 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, float *saa, float* sss, float* zzz, bf16 *gy, bf16 *gr, bf16 *gw, bf16 *gk, bf16 *gv, bf16 *ga, bf16 *gb); 8 | 9 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y, torch::Tensor &saa, torch::Tensor &sss) { 10 | cuda_forward(B, T, C, H, r.data_ptr(), w.data_ptr(), k.data_ptr(), v.data_ptr(), a.data_ptr(), b.data_ptr(), y.data_ptr(), saa.data_ptr(), sss.data_ptr()); 11 | } 12 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &saa, torch::Tensor &sss, torch::Tensor &zzz, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gw, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &ga, torch::Tensor &gb) { 13 | cuda_backward(B, T, C, H, r.data_ptr(), w.data_ptr(), k.data_ptr(), v.data_ptr(), a.data_ptr(), b.data_ptr(), saa.data_ptr(), sss.data_ptr(), zzz.data_ptr(), gy.data_ptr(), gr.data_ptr(), gw.data_ptr(), gk.data_ptr(), gv.data_ptr(), ga.data_ptr(), gb.data_ptr()); 14 | } 15 | 16 | TORCH_LIBRARY(wkv7g, m) { 17 | m.def("forward(int B, int T, int C, int H, Tensor r, Tensor w, Tensor k, Tensor v, Tensor a, Tensor b, Tensor(a!) y, Tensor(b!) saa, Tensor(c!) sss) -> ()"); 18 | m.def("backward(int B, int T, int C, int H, Tensor r, Tensor w, Tensor k, Tensor v, Tensor a, Tensor b, Tensor saa, Tensor sss, Tensor(a!) zzz, Tensor gy, Tensor(b!) gr, Tensor(c!) gw, Tensor(d!) gk, Tensor(e!) gv, Tensor(f!) ga, Tensor(g!) 
gb) -> ()"); 19 | } 20 | 21 | TORCH_LIBRARY_IMPL(wkv7g, CUDA, m) { 22 | m.impl("forward", &forward); 23 | m.impl("backward", &backward); 24 | } 25 | -------------------------------------------------------------------------------- /rwkv_cuda/wkv7g_v1.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | 5 | typedef at::BFloat16 bf16; 6 | 7 | template 8 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 9 | const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b, 10 | F *__restrict__ const _y, float *__restrict__ const _saa, float *__restrict__ const _sss) 11 | { 12 | const int e = blockIdx.x / H; 13 | const int h = blockIdx.x % H; 14 | const int i = threadIdx.x; 15 | 16 | float state[_N_] = {0}; 17 | __shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_]; 18 | 19 | float v[_T_]; 20 | for (int _t = 0; _t < T; _t++) 21 | { 22 | const int t = e*T*C + h*_N_ + i + _t * C; 23 | v[_t] = float(_v[t]); 24 | } 25 | 26 | for (int _t = 0; _t < T; _t++) 27 | { 28 | const int t = e*T*C + h*_N_ + i + _t * C; 29 | __syncthreads(); 30 | r[i] = float(_r[t]); 31 | w[i] = __expf(-__expf(float(_w[t]))); 32 | k[i] = float(_k[t]); 33 | a[i] = float(_a[t]); 34 | b[i] = float(_b[t]); 35 | __syncthreads(); 36 | 37 | float sa = 0; 38 | #pragma unroll 39 | for (int j = 0; j < _N_; j++) 40 | { 41 | sa += a[j] * state[j]; 42 | } 43 | _saa[t] = float(sa); 44 | 45 | float vv = v[_t]; 46 | float y = 0; 47 | #pragma unroll 48 | for (int j = 0; j < _N_; j++) 49 | { 50 | float& s = state[j]; 51 | s = s * w[j] + sa * b[j] + k[j] * vv; 52 | y += s * r[j]; 53 | } 54 | _y[t] = F(y); 55 | 56 | if ((_t+1) % _CHUNK_LEN_ == 0) 57 | { 58 | const int a = _t / _CHUNK_LEN_; 59 | const int c = _T_ / _CHUNK_LEN_; 60 | const int p = e*C*_N_*c + h*_N_*_N_*c + a*_N_*_N_ + i; 61 | #pragma unroll 62 | for (int j = 0; j < _N_; j++) 63 | { 64 | _sss[p + j*_N_] = float(state[j]); 65 | } 66 | } 67 | } 68 | } 69 | 70 | template 71 | __global__ void kernel_backward_zzz(const int B, const int T, const int C, const int H, 72 | const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _a, const F *__restrict__ const _b, const F *__restrict__ const _gy, 73 | float *__restrict__ const _zzz) 74 | { 75 | const int e = blockIdx.x / H; 76 | const int h = blockIdx.x % H; 77 | const int i = threadIdx.x; 78 | 79 | __shared__ float r[_N_], w[_N_], a[_N_], b[_N_]; 80 | 81 | const int T_1 = e*T*C + (T-1)*C + h*_N_; 82 | float z[_N_]; 83 | const float gy = _gy[T_1 + i]; 84 | __syncthreads(); 85 | r[i] = float(_r[T_1+i]); 86 | __syncthreads(); 87 | #pragma unroll 88 | for (int j = 0; j < _N_; j++) 89 | { 90 | z[j] = gy * r[j]; 91 | } 92 | 93 | for (int _t = T-2; _t > _CHUNK_LEN_-1; _t--) 94 | { 95 | const int t = e*T*C + h*_N_ + _t * C + i; 96 | const float gy = _gy[t]; 97 | __syncthreads(); 98 | r[i] = float(_r[t]); 99 | w[i] = __expf(-__expf(float(_w[t+C]))); 100 | a[i] = float(_a[t+C]); 101 | b[i] = float(_b[t+C]); 102 | __syncthreads(); 103 | 104 | float zz = 0; 105 | #pragma unroll 106 | for (int j = 0; j < _N_; j++) 107 | { 108 | zz += b[j] * z[j]; 109 | } 110 | #pragma unroll 111 | for (int j = 0; j < _N_; j++) 112 | { 113 | z[j] = z[j] * w[j] + gy * r[j] + a[j] * zz; 114 | // printf("t %d i %d j %d z %f\n", _t, i, j, z[j]); 115 | } 116 | if 
(_t % _CHUNK_LEN_ == 0) 117 | { 118 | const int a = _t / _CHUNK_LEN_ - 1; 119 | const int c = _T_ / _CHUNK_LEN_ - 1; 120 | const int p = e*C*_N_*c + h*_N_*_N_*c + a*_N_*_N_ + i; 121 | #pragma unroll 122 | for (int j = 0; j < _N_; j++) 123 | { 124 | _zzz[p + j*_N_] = float(z[j]); 125 | } 126 | } 127 | } 128 | } 129 | 130 | template 131 | __global__ void kernel_backward_rwkv(const int B, const int T, const int C, const int H, 132 | const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b, const float *__restrict__ const _saa, const float *__restrict__ const _sss, const float *__restrict__ const _zzz, 133 | const F *__restrict__ const _gy, F *__restrict__ const _gr, F *__restrict__ const _gw, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _ga, F *__restrict__ const _gb) 134 | { 135 | const int e = blockIdx.x / H; 136 | const int h = blockIdx.x % H; 137 | const int chunk = threadIdx.x; 138 | const int n_chunk = _T_ / _CHUNK_LEN_; 139 | 140 | float zzz[_N_*_N_] = {0}, sss[_N_*_N_] = {999}, saa[_N_] = {999}; 141 | float r[_N_] = {999}, w[_N_] = {0}, w1[_N_] = {999}, winv[_N_] = {999}, ww[_N_] = {999}; 142 | float k[_N_] = {0}, v[_N_] = {999}, a[_N_] = {0}, a1[_N_] = {999}, b[_N_] = {0}, b1[_N_] = {999}, gy[_N_] = {999}; 143 | 144 | if (chunk != n_chunk - 1) 145 | { 146 | const int p = e*T*C + (chunk+1)*_CHUNK_LEN_*C + h*_N_; 147 | for (int i = 0; i < _N_; i++) 148 | { 149 | k[i] = float(_k[p+i]); 150 | a[i] = float(_a[p+i]); 151 | b[i] = float(_b[p+i]); 152 | w[i] = __expf(-__expf(float(_w[p+i]))); 153 | const int p = e*C*_N_*(n_chunk-1) + h*_N_*_N_*(n_chunk-1) + chunk*_N_*_N_ + i*_N_; 154 | #pragma unroll 155 | for (int j = 0; j < _N_; j++) 156 | { 157 | zzz[i*_N_+j] = float(_zzz[p+j]); 158 | } 159 | } 160 | } 161 | for (int i = 0; i < _N_; i++) 162 | { 163 | const int p = e*C*_N_*n_chunk + h*_N_*_N_*n_chunk + chunk*_N_*_N_ + i*_N_; 164 | #pragma unroll 165 | for (int j = 0; j < _N_; j++) 166 | { 167 | sss[i*_N_+j] = float(_sss[p+j]); 168 | } 169 | } 170 | 171 | for (int _t = _CHUNK_LEN_-1; _t > -1; _t--) 172 | { 173 | const int t = chunk * _CHUNK_LEN_ + _t; 174 | const int b_t_h = e*T*C + t*C + h*_N_; 175 | #pragma unroll 176 | for (int n = 0; n < _N_; n++) 177 | { 178 | w1[n] = w[n]; 179 | a1[n] = a[n]; 180 | b1[n] = b[n]; 181 | 182 | r[n] = float(_r[b_t_h+n]); 183 | k[n] = float(_k[b_t_h+n]); 184 | v[n] = float(_v[b_t_h+n]); 185 | a[n] = float(_a[b_t_h+n]); 186 | b[n] = float(_b[b_t_h+n]); 187 | gy[n] = float(_gy[b_t_h+n]); 188 | saa[n] = float(_saa[b_t_h+n]); 189 | 190 | ww[n] = -__expf(float(_w[b_t_h+n])); 191 | w[n] = __expf(ww[n]); 192 | ww[n] = ww[n] * w[n]; 193 | winv[n] = 1.0f / w[n]; 194 | } 195 | 196 | for (int j = 0; j < _N_; j++) 197 | { 198 | float zz = 0; 199 | #pragma unroll 200 | for (int i = 0; i < _N_; i++) 201 | { 202 | zz += b1[i] * zzz[i*_N_+j]; 203 | } 204 | const float gyj = gy[j]; 205 | #pragma unroll 206 | for (int i = 0; i < _N_; i++) 207 | { 208 | zzz[i*_N_+j] = zzz[i*_N_+j] * w1[i] + gyj * r[i] + a1[i] * zz; 209 | // printf("t %d i %d j %d z %f\n",t,i,j,zzz[i*_N_+j]); 210 | // printf("t %d i %d j %d s %f\n",t,i,j,sss[i*_N_+j]); 211 | } 212 | } 213 | 214 | for (int i = 0; i < _N_; i++) 215 | { 216 | float gr = 0; 217 | #pragma unroll 218 | for (int j = 0; j < _N_; j++) 219 | { 220 | gr += gy[j] * sss[i*_N_+j]; 221 | } 222 | _gr[b_t_h+i] = F(gr); 223 | } 224 | 225 | for (int i = 0; i < _N_; i++) 226 | { 227 
| const float ki = k[i]; 228 | const float bi = b[i]; 229 | const float wi = winv[i]; 230 | #pragma unroll 231 | for (int j = 0; j < _N_; j++) 232 | { 233 | sss[i*_N_+j] = (sss[i*_N_+j] - ki * v[j] - bi * saa[j]) * wi; 234 | } 235 | } 236 | 237 | float gv[_N_] = {0}; float as[_N_] = {0}; float bz[_N_] = {0}; 238 | for (int i = 0; i < _N_; i++) 239 | { 240 | const float ki = k[i]; 241 | const float ai = a[i]; 242 | const float bi = b[i]; 243 | float gw = 0; 244 | float gk = 0; 245 | #pragma unroll 246 | for (int j = 0; j < _N_; j++) 247 | { 248 | const float sij = sss[i*_N_+j]; 249 | const float zij = zzz[i*_N_+j]; 250 | gv[j] += ki * zij; 251 | as[j] += ai * sij; 252 | bz[j] += bi * zij; 253 | gw += sij * zij; 254 | gk += v[j] * zij; 255 | } 256 | _gw[b_t_h+i] = F(gw * ww[i]); 257 | _gk[b_t_h+i] = F(gk); 258 | } 259 | for (int i = 0; i < _N_; i++) 260 | { 261 | _gv[b_t_h+i] = F(gv[i]); 262 | float ga = 0; 263 | float gb = 0; 264 | #pragma unroll 265 | for (int j = 0; j < _N_; j++) 266 | { 267 | ga += bz[j] * sss[i*_N_+j]; 268 | gb += as[j] * zzz[i*_N_+j]; 269 | } 270 | _ga[b_t_h+i] = F(ga); 271 | _gb[b_t_h+i] = F(gb); 272 | } 273 | } 274 | } 275 | 276 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y, float *saa, float* sss) 277 | { 278 | assert(H*_N_ == C); 279 | kernel_forward<<>>(B, T, C, H, r, w, k, v, a, b, y, saa, sss); 280 | } 281 | 282 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, float *saa, float* sss, float* zzz, bf16 *gy, bf16 *gr, bf16 *gw, bf16 *gk, bf16 *gv, bf16 *ga, bf16 *gb) 283 | { 284 | assert(H*_N_ == C); 285 | assert(T%_CHUNK_LEN_ == 0); 286 | 287 | kernel_backward_zzz<<>>(B, T, C, H, r, w, k, a, b, gy, zzz); 288 | kernel_backward_rwkv<<>>(B, T, C, H, r, w, k, v, a, b, saa, sss, zzz, gy, gr, gw, gk, gv, ga, gb); 289 | } 290 | -------------------------------------------------------------------------------- /rwkv_cuda_wind/backstepping_f32.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | using bf = __nv_bfloat16; 4 | 5 | void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa); 6 | 7 | void forward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &y, torch::Tensor &s, torch::Tensor &sa) { 8 | int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2]; 9 | cuda_forward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)y.data_ptr(), (float*)s.data_ptr(), (float*)sa.data_ptr()); 10 | } 11 | 12 | void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da); 13 | 14 | void backward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &dy, 15 | torch::Tensor &s, torch::Tensor &sa, torch::Tensor &dw, torch::Tensor &dq, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &dz, torch::Tensor &da) { 16 | int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2]; 17 | cuda_backward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)dy.data_ptr(), 18 | (float*)s.data_ptr(), (float*)sa.data_ptr(), (bf*)dw.data_ptr(), (bf*)dq.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), 
(bf*)dz.data_ptr(), (bf*)da.data_ptr()); 19 | } 20 | 21 | TORCH_LIBRARY(wind_backstepping, m) { 22 | m.def("forward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor(a!) y, Tensor(b!) s, Tensor(c!) sa) -> ()"); 23 | m.def("backward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor dy, Tensor s, Tensor sa, Tensor(a!) dw, Tensor(b!) dq, Tensor(c!) dk, Tensor(d!) dv, Tensor(e!) dz, Tensor(f!) da) -> ()"); 24 | } 25 | 26 | TORCH_LIBRARY_IMPL(wind_backstepping, CUDA, m) { 27 | m.impl("forward", &forward); 28 | m.impl("backward", &backward); 29 | } 30 | -------------------------------------------------------------------------------- /rwkv_cuda_wind/backstepping_f32_1.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | using bf = __nv_bfloat16; 5 | __device__ inline float to_float(const bf & u) { return __bfloat162float(u); } 6 | __device__ inline bf to_bf(const float & u) { return __float2bfloat16_rn(u); } 7 | 8 | typedef bf * __restrict__ F_; 9 | 10 | __global__ void forward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, bf* y_, float* s_, float* sa_) { 11 | constexpr int C = _C_; 12 | int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x; 13 | 14 | float state[C] = {0}; 15 | __shared__ float q[C], k[C], w[C], a[C], b[C]; 16 | 17 | for (int t = 0; t < T; t++) { 18 | int ind = bb*T*H*C + t*H*C + hh * C + i; 19 | __syncthreads(); 20 | q[i] = to_float(q_[ind]); 21 | w[i] = __expf(-__expf(to_float(w_[ind]))); 22 | k[i] = to_float(k_[ind]); 23 | a[i] = to_float(a_[ind]); 24 | b[i] = to_float(b_[ind]); 25 | __syncthreads(); 26 | 27 | float sa = 0; 28 | #pragma unroll 29 | for (int j = 0; j < C; j++) { 30 | sa += a[j] * state[j]; 31 | } 32 | sa_[ind] = sa; 33 | 34 | float v = to_float(v_[ind]); 35 | float y = 0; 36 | #pragma unroll 37 | for (int j = 0; j < C; j++) { 38 | float& s = state[j]; 39 | s = s * w[j] + sa * b[j] + k[j] * v; 40 | y += s * q[j]; 41 | } 42 | y_[ind] = to_bf(y); 43 | 44 | if ((t+1)%_CHUNK_LEN_ == 0) { 45 | int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i; 46 | #pragma unroll 47 | for (int j = 0; j < C; j++) { 48 | s_[base + j*C] = state[j]; 49 | } 50 | } 51 | } 52 | } 53 | 54 | __global__ void backward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_, float * __restrict__ s_, float * __restrict__ sa_, bf* dw_, bf* dq_, bf* dk_, bf* dv_, bf* da_, bf* db_) { 55 | constexpr int C = _C_; 56 | int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x; 57 | 58 | float stateT[C] = {0}, dstate[C] = {0}, dstateT[C] = {0}; 59 | __shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], dSb_shared[C]; 60 | float qi, wi, ki, ai, bi, dyi; 61 | 62 | for (int t = T-1; t >= 0; t--) { 63 | int ind = bb*T*H*C + t*H*C + hh * C + i; 64 | __syncthreads(); 65 | q[i] = qi = to_float(q_[ind]); 66 | float wi_fac = -__expf(to_float(w_[ind])); 67 | w[i] = wi = __expf(wi_fac); 68 | k[i] = ki = to_float(k_[ind]); 69 | a[i] = ai = to_float(a_[ind]); 70 | b[i] = bi = to_float(b_[ind]); 71 | v[i] = to_float(v_[ind]); 72 | dy[i] = dyi = to_float(dy_[ind]); 73 | sa[i] = sa_[ind]; 74 | __syncthreads(); 75 | 76 | if ((t+1)%_CHUNK_LEN_ == 0) { 77 | int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i*C; 78 | #pragma unroll 79 | for (int j = 0; j < C; j++) { 80 | stateT[j] = s_[base + j]; 81 | } 82 | } 83 | 84 | float dq = 0; 85 | #pragma unroll 86 | for (int j = 0; j < C; j++) { 87 | dq += stateT[j]*dy[j]; 88 | } 89 | dq_[ind] = to_bf(dq); 90 | 
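// Rewind the saved chunk state one step and fold dy into the running state gradients.
// Writing S[r][c] for the state (r = value channel, c = key channel), the forward update is
//   S[r][c] <- S[r][c]*w[c] + sa[r]*b[c] + v[r]*k[c],   with sa[r] = sum_c a[c]*S[r][c],
// so S_prev[r][c] = (S[r][c] - v[r]*k[c] - sa[r]*b[c]) / w[c], using the per-step sa
// saved by the forward kernel rather than the full state history. Thread i owns
// column c = i here (stateT[j] = S[j][i]); dstate[j] accumulates dL/dS[i][j] and
// dstateT[j] accumulates dL/dS[j][i].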
91 | float iwi = 1.0f/wi; 92 | #pragma unroll 93 | for (int j = 0; j < C; j++) { 94 | stateT[j] = (stateT[j] - ki*v[j] - bi*sa[j]) * iwi; 95 | dstate[j] += dyi * q[j]; 96 | dstateT[j] += qi * dy[j]; 97 | } 98 | 99 | float dw = 0, dk = 0, dv = 0, db = 0, dSb = 0; 100 | #pragma unroll 101 | for (int j = 0; j < C; j++) { 102 | dw += dstateT[j]*stateT[j]; 103 | dk += dstateT[j]*v[j]; 104 | dv += dstate[j]*k[j]; 105 | dSb += dstate[j]*b[j]; 106 | db += dstateT[j]*sa[j]; 107 | } 108 | dw_[ind] = to_bf(dw * wi * wi_fac); 109 | dk_[ind] = to_bf(dk); 110 | dv_[ind] = to_bf(dv); 111 | db_[ind] = to_bf(db); 112 | 113 | __syncthreads(); 114 | dSb_shared[i] = dSb; 115 | __syncthreads(); 116 | 117 | float da = 0; 118 | #pragma unroll 119 | for (int j = 0; j < C; j++) { 120 | da += stateT[j]*dSb_shared[j]; 121 | } 122 | da_[ind] = to_bf(da); 123 | 124 | #pragma unroll 125 | for (int j = 0; j < C; j++) { 126 | dstate[j] = dstate[j]*w[j] + dSb * a[j]; 127 | dstateT[j] = dstateT[j]*wi + ai * dSb_shared[j]; 128 | } 129 | } 130 | } 131 | 132 | void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa) { 133 | forward_kernel<<>>(T,H,w,q,k,v,z,a,y,s,sa); 134 | } 135 | void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da) { 136 | assert(T%_CHUNK_LEN_ == 0); 137 | backward_kernel<<>>(T,H,w,q,k,v,z,a,dy,s,sa,dw,dq,dk,dv,dz,da); 138 | } 139 | -------------------------------------------------------------------------------- /rwkv_cuda_wind/backstepping_f32_2.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | using bf = __nv_bfloat16; 5 | __device__ inline float to_float(const bf & u) { return __bfloat162float(u); } 6 | __device__ inline bf to_bf(const float & u) { return __float2bfloat16_rn(u); } 7 | 8 | typedef bf * __restrict__ F_; 9 | 10 | __global__ void forward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, bf* y_, float* s_, float* sa_) { 11 | constexpr int C = _C_; 12 | int bind = blockIdx.y, hind = blockIdx.x, i = threadIdx.x; 13 | 14 | float state[C] = {0}; 15 | __shared__ float q[C], k[C], w[C], a[C], b[C]; 16 | 17 | for (int t = 0; t < T; t++) { 18 | int ind = bind*T*H*C + t*H*C + hind * C + i; 19 | __syncthreads(); 20 | q[i] = to_float(q_[ind]); 21 | w[i] = __expf(-__expf(to_float(w_[ind]))); 22 | k[i] = to_float(k_[ind]); 23 | a[i] = to_float(a_[ind]); 24 | b[i] = to_float(b_[ind]); 25 | __syncthreads(); 26 | 27 | float sa = 0; 28 | #pragma unroll 29 | for (int j = 0; j < C; j++) { 30 | sa += a[j] * state[j]; 31 | } 32 | sa_[ind] = sa; 33 | 34 | float v = to_float(v_[ind]); 35 | float y = 0; 36 | #pragma unroll 37 | for (int j = 0; j < C; j++) { 38 | float& s = state[j]; 39 | s = s * w[j] + sa * b[j] + k[j] * v; 40 | y += s * q[j]; 41 | } 42 | y_[ind] = to_bf(y); 43 | 44 | if ((t+1)%_CHUNK_LEN_ == 0) { 45 | int base = (bind*H+hind)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i; 46 | #pragma unroll 47 | for (int j = 0; j < C; j++) { 48 | s_[base + j*C] = state[j]; 49 | } 50 | } 51 | } 52 | } 53 | 54 | __global__ void backward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_, float * __restrict__ s_, float * __restrict__ sa_, bf* dw_, bf* dq_, bf* dk_, bf* dv_, bf* da_, bf* db_) { 55 | constexpr int C = _C_; 56 | int bind = blockIdx.y, hind = blockIdx.x, i = threadIdx.x; 57 | 58 | float stateT[C] = {0}; 59 | __shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], 
dSb_shared[C]; 60 | 61 | extern __shared__ char smem_[]; 62 | float*dstate = (float*)smem_; //[C*(C+1)]; 63 | 64 | for (int j = 0; j < C; j++) { 65 | dstate[i*(C+1)+j] = 0; 66 | } 67 | 68 | for (int t = T-1; t >= 0; t--) { 69 | int ind = bind*T*H*C + t*H*C + hind * C + i; 70 | float bi, ki, dyi, wi; 71 | __syncthreads(); 72 | q[i] = to_float(q_[ind]); 73 | float wi_fac = -__expf(to_float(w_[ind])); 74 | w[i] = wi = __expf(wi_fac); 75 | k[i] = ki = to_float(k_[ind]); 76 | a[i] = to_float(a_[ind]); 77 | b[i] = bi = to_float(b_[ind]); 78 | v[i] = to_float(v_[ind]); 79 | dy[i] = dyi = to_float(dy_[ind]); 80 | sa[i] = sa_[ind]; 81 | __syncthreads(); 82 | 83 | if ((t+1)%_CHUNK_LEN_ == 0) { 84 | int base = (bind*H+hind)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i*C; 85 | #pragma unroll 86 | for (int j = 0; j < C; j++) { 87 | stateT[j] = s_[base + j]; 88 | } 89 | } 90 | 91 | float dq = 0; 92 | #pragma unroll 93 | for (int j = 0; j < C; j++) { 94 | dq += stateT[j]*dy[j]; 95 | } 96 | dq_[ind] = to_bf(dq); 97 | 98 | float iwi = 1.0f/wi; 99 | for (int j = 0; j < C; j++) { 100 | stateT[j] = (stateT[j] - ki*v[j] - bi*sa[j]) * iwi; 101 | dstate[i*(C+1)+j] += dyi * q[j]; 102 | } 103 | 104 | float dw = 0, dk = 0, dv = 0, db = 0, dSb = 0; 105 | #pragma unroll 106 | for (int j = 0; j < C; j++) { 107 | float ds = dstate[j*(C+1)+i]; 108 | dw += ds*stateT[j]; 109 | dk += ds*v[j]; 110 | db += ds*sa[j]; 111 | } 112 | #pragma unroll 113 | for (int j = 0; j < C; j++) { 114 | float ds = dstate[i*(C+1)+j]; 115 | dv += ds*k[j]; 116 | dSb += ds*b[j]; 117 | } 118 | dw_[ind] = to_bf(dw * wi * wi_fac); 119 | dk_[ind] = to_bf(dk); 120 | dv_[ind] = to_bf(dv); 121 | db_[ind] = to_bf(db); 122 | 123 | __syncthreads(); 124 | dSb_shared[i] = dSb; 125 | __syncthreads(); 126 | 127 | float da = 0; 128 | #pragma unroll 129 | for (int j = 0; j < C; j++) { 130 | da += stateT[j]*dSb_shared[j]; 131 | } 132 | da_[ind] = to_bf(da); 133 | 134 | for (int j = 0; j < C; j++) { 135 | dstate[i*(C+1)+j] = dstate[i*(C+1)+j]*w[j] + dSb * a[j]; 136 | } 137 | } 138 | } 139 | 140 | void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa) { 141 | forward_kernel<<>>(T,H,w,q,k,v,z,a,y,s,sa); 142 | } 143 | void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da) { 144 | assert(T%_CHUNK_LEN_ == 0); 145 | int shared_mem = _C_*(_C_+1)*4; 146 | assert(!cudaFuncSetAttribute(backward_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem)); 147 | backward_kernel<<>>(T,H,w,q,k,v,z,a,dy,s,sa,dw,dq,dk,dv,dz,da); 148 | } 149 | -------------------------------------------------------------------------------- /rwkv_cuda_wind/tile.cuh: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | //TODO: static? inline? __align__(16)? 
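// tile.cuh: a small 16x16 tile algebra used by wind_rwkv7.cu.
//   STile        - one 16x16 bf16 tile in shared memory
//   RTile        - the same tile in registers, as four bf16x2 MMA fragments
//   FTile        - the fp32 accumulator version
//   GTile/GFTile - thin views into global memory
// Assignments between these types (overloaded further down) issue cp.async copies,
// ldmatrix loads and mma.m16n8k16 tensor-core products, so the chunked RWKV-7
// forward/backward passes can be written as short tile expressions.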
5 | 6 | using bf = __nv_bfloat16; 7 | using bf2 = __nv_bfloat162; 8 | using uint = unsigned int; 9 | __device__ inline float to_float(const bf & u) { return __bfloat162float(u); } 10 | __device__ inline bf to_bf(const float & u) { return __float2bfloat16_rn(u); } 11 | __device__ inline float2 to_float2(const bf2 & u) { return __bfloat1622float2(u); } 12 | __device__ inline float2 to_float2(const float2 & u) { return u; } 13 | __device__ inline bf2 to_bf2(const float2 & u) { return __float22bfloat162_rn(u); } 14 | __device__ inline uint& as_uint(const bf2&x) { return *((uint*)(&x)); } 15 | __device__ inline uint __smem(const void*x) { return __cvta_generic_to_shared(x); } 16 | 17 | __device__ void __commit_group() { asm volatile("cp.async.commit_group;\n" ::); } 18 | __device__ void __wait_group() { asm volatile("cp.async.wait_all;\n" ::); } 19 | template __device__ void __wait_groups() { asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); } 20 | 21 | __device__ void __copy_wait() { __commit_group(); __wait_group(); } 22 | 23 | __device__ void operator*=(float2&a, const float2&b) { a.x *= b.x; a.y *= b.y; } 24 | __device__ void operator+=(float2&a, const float2&b) { a.x += b.x; a.y += b.y; } 25 | __device__ float2 operator+(const float2&a, const float2&b) { return {a.x+b.x,a.y+b.y}; } 26 | __device__ float2 operator*(const float2&a, const float2&b) { return {a.x*b.x,a.y*b.y}; } 27 | 28 | struct STile; 29 | struct RTile; 30 | struct FTile; 31 | 32 | struct GTile { 33 | bf*ga; 34 | int stride; 35 | __device__ GTile(bf*ga_, int stride_) : ga(ga_), stride(stride_) {} 36 | __device__ GTile& operator=(const RTile&); 37 | }; 38 | struct GFTile { 39 | float*ga; 40 | int stride; 41 | __device__ GFTile(float*ga_, int stride_) : ga(ga_), stride(stride_) {} 42 | __device__ GFTile& operator=(const FTile&); 43 | }; 44 | struct STileT { STile*st; }; 45 | 46 | struct __align__(16) STile { 47 | bf data[16*16]; 48 | __device__ STile() {} 49 | __device__ STile(const RTile&o) { *this=o; } 50 | __device__ STile& operator=(const GTile&); 51 | __device__ STile& operator=(const RTile&); 52 | __device__ STileT t() { return STileT{this}; } 53 | }; 54 | struct Product { const RTile*a, *b; }; 55 | struct ProductPlus { const RTile*a, *b; const FTile* c; }; 56 | struct RTile { 57 | bf2 data[4]; 58 | __device__ RTile() {} 59 | __device__ void zero_() { data[0] = data[1] = data[2] = data[3] = to_bf2({0.f,0.f}); } 60 | __device__ RTile(const STile&o) { *this=o; } 61 | __device__ RTile(const STileT&o) { *this=o; } 62 | __device__ RTile(const FTile&o) { *this=o; } 63 | __device__ RTile& operator=(const STile&); 64 | __device__ RTile& operator=(const STileT&); 65 | __device__ RTile& operator=(const FTile&fa); 66 | __device__ RTile& operator=(const GTile&); 67 | }; 68 | struct FTile { 69 | union { 70 | float2 data[4]; 71 | float fdata[8]; 72 | }; 73 | __device__ void zero_() { data[0] = data[1] = data[2] = data[3] = {0.f,0.f}; } 74 | __device__ FTile() {} 75 | __device__ FTile(const FTile&o) { for (int i = 0; i < 4; i++) data[i] = o.data[i]; } 76 | __device__ FTile(const RTile&r) { *this=r; } 77 | __device__ FTile(const Product&p) { *this=p; } 78 | __device__ FTile(const ProductPlus&p) { *this=p; } 79 | __device__ FTile& operator=(const Product&); 80 | __device__ FTile& operator=(const RTile&); 81 | __device__ FTile& operator=(const ProductPlus&); 82 | __device__ FTile& operator+=(const Product&); 83 | __device__ FTile& operator+=(const FTile&o) { for (int i = 0; i < 4; i++) data[i] += o.data[i]; return *this; } 84 | }; 85 | 
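// Operator conventions (implementations follow below):
//   a % b            -> Product:     the tile product a * b^T
//   a % b + c        -> ProductPlus: a * b^T + c, fused into a single mma
//   ftile += a % b   -> accumulate a * b^T into an existing fp32 tile
// The right-hand operand is always used transposed, matching the row.col
// operand layout of mma.m16n8k16.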
86 | __device__ void print(STile t) { 87 | if (threadIdx.x == 0) { 88 | for (int i = 0; i < 16; i++) { 89 | for (int j = 0; j < 16; j++) { 90 | printf("%f ", to_float(t.data[i*16+j])); 91 | } 92 | printf("\n"); 93 | } 94 | printf("\n"); 95 | } 96 | } 97 | 98 | template 99 | __device__ void print(T t, int warpi = 0) { 100 | int tid = threadIdx.x - warpi*32; 101 | for (int i = 0; i < 16; i++) { 102 | for (int j = 0; j < 16; j += 2) { 103 | if (tid == i%8*4+j%8/2) { 104 | float2 xy = to_float2(t.data[i/8+j/8*2]); 105 | printf("%f %f ", xy.x, xy.y); 106 | //printf("T%d:{a%d,a%d} ", threadIdx.x, (i/8+j/8*2)*2, (i/8+j/8*2)*2+1); 107 | } 108 | __syncthreads(); 109 | } 110 | if (tid == 0) printf("\n"); 111 | __syncthreads(); 112 | } 113 | if (tid == 0) printf("\n"); 114 | __syncthreads(); 115 | } 116 | 117 | template 118 | __device__ void print8(T mat) { 119 | for (int i = 0; i < 8; i++) { 120 | for (int j = 0; j < 8; j += 2) { 121 | if (threadIdx.x == i%8*4+j%8/2) { 122 | float2 xy = to_float2(mat); 123 | printf("%f %f ", xy.x, xy.y); 124 | } 125 | __syncthreads(); 126 | } 127 | if (threadIdx.x == 0) printf("\n"); 128 | __syncthreads(); 129 | } 130 | if (threadIdx.x == 0) printf("\n"); 131 | __syncthreads(); 132 | } 133 | 134 | 135 | 136 | __device__ void load(STile&sa, bf*ga, int stride) { 137 | int i = threadIdx.x%32/2, j = threadIdx.x%2; 138 | asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" :: "r"(__smem(&sa.data[i*16+j*8])), "l"(ga+stride*i+j*8), "n"(16)); 139 | } 140 | 141 | __device__ void load(RTile&ra, const STile&sa) { 142 | int i = threadIdx.x%8, j = threadIdx.x%32/16, k = threadIdx.x/8%2; 143 | asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" 144 | : "=r"(as_uint(ra.data[0])), "=r"(as_uint(ra.data[1])), "=r"(as_uint(ra.data[2])), "=r"(as_uint(ra.data[3])) 145 | : "r"(__smem(&sa.data[i*16+j*8+k*8*16]))); 146 | } 147 | __device__ void loadT(RTile&ra, const STile&sa) { 148 | int i = threadIdx.x%8, j = threadIdx.x%32/16, k = threadIdx.x/8%2; 149 | asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" 150 | : "=r"(as_uint(ra.data[0])), "=r"(as_uint(ra.data[1])), "=r"(as_uint(ra.data[2])), "=r"(as_uint(ra.data[3])) 151 | : "r"(__smem(&sa.data[i*16+j*8*16+k*8]))); 152 | } 153 | 154 | __device__ static inline void __m16n8k16(float2&d0, float2&d1, const bf2 &a0, const bf2 &a1, const bf2 &a2, const bf2 &a3, const bf2 &b0, const bf2 &b1, const float2 &c0, const float2 &c1) { 155 | asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" 156 | : "=f"(d0.x), "=f"(d0.y), "=f"(d1.x), "=f"(d1.y) 157 | : "r"(as_uint(a0)), "r"(as_uint(a1)), "r"(as_uint(a2)), "r"(as_uint(a3)), 158 | "r"(as_uint(b0)), "r"(as_uint(b1)), 159 | "f"(c0.x), "f"(c0.y), "f"(c1.x), "f"(c1.y)); 160 | } 161 | __device__ void mma(FTile&rd, const RTile&ra, const RTile&rb, const FTile&rc) { // d = a*b^T + c 162 | __m16n8k16(rd.data[0],rd.data[1], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[0],rb.data[2], rc.data[0],rc.data[1]); 163 | __m16n8k16(rd.data[2],rd.data[3], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[1],rb.data[3], rc.data[2],rc.data[3]); 164 | } 165 | __device__ static inline void __m16n8k16(float2&d0, float2&d1, const bf2 &a0, const bf2 &a1, const bf2 &a2, const bf2 &a3, const bf2 &b0, const bf2 &b1) { 166 | asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" 
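// Same mma.m16n8k16 instruction as above, but with the accumulator tied to the
// output registers ("+f" constraints below), i.e. it computes d += a*b^T in place
// rather than taking a separate c input.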
167 | : "+f"(d0.x), "+f"(d0.y), "+f"(d1.x), "+f"(d1.y) 168 | : "r"(as_uint(a0)), "r"(as_uint(a1)), "r"(as_uint(a2)), "r"(as_uint(a3)), 169 | "r"(as_uint(b0)), "r"(as_uint(b1)), 170 | "f"(d0.x), "f"(d0.y), "f"(d1.x), "f"(d1.y)); 171 | } 172 | __device__ void mma(FTile&rd, const RTile&ra, const RTile&rb) { // d += a*b^T 173 | __m16n8k16(rd.data[0],rd.data[1], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[0],rb.data[2]); 174 | __m16n8k16(rd.data[2],rd.data[3], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[1],rb.data[3]); 175 | } 176 | __device__ void mm(FTile&rd, const RTile&ra, const RTile&rb) { // d = a*b^T 177 | __m16n8k16(rd.data[0],rd.data[1], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[0],rb.data[2], {0.f,0.f}, {0.f,0.f}); 178 | __m16n8k16(rd.data[2],rd.data[3], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[1],rb.data[3], {0.f,0.f}, {0.f,0.f}); 179 | } 180 | 181 | __device__ void store(const FTile&ra, float*ga, int stride) { 182 | int i = threadIdx.x%32/4, j = threadIdx.x%4*2; 183 | *((float2*)&ga[ i *stride+j ]) = ra.data[0]; 184 | *((float2*)&ga[(i+8)*stride+j ]) = ra.data[1]; 185 | *((float2*)&ga[ i *stride+j+8]) = ra.data[2]; 186 | *((float2*)&ga[(i+8)*stride+j+8]) = ra.data[3]; 187 | } 188 | 189 | __device__ void store(const RTile&ra, bf*ga, int stride) { 190 | int i = threadIdx.x%32/4, j = threadIdx.x%4*2; 191 | *((bf2*)&ga[ i *stride+j ]) = ra.data[0]; 192 | *((bf2*)&ga[(i+8)*stride+j ]) = ra.data[1]; 193 | *((bf2*)&ga[ i *stride+j+8]) = ra.data[2]; 194 | *((bf2*)&ga[(i+8)*stride+j+8]) = ra.data[3]; 195 | } 196 | __device__ void load(RTile&ra, bf*ga, int stride) { 197 | int i = threadIdx.x%32/4, j = threadIdx.x%4*2; 198 | ra.data[0] = *((bf2*)&ga[ i *stride+j ]); 199 | ra.data[1] = *((bf2*)&ga[(i+8)*stride+j ]); 200 | ra.data[2] = *((bf2*)&ga[ i *stride+j+8]); 201 | ra.data[3] = *((bf2*)&ga[(i+8)*stride+j+8]); 202 | } 203 | __device__ void store(const RTile&ra, STile&sa) { //TODO: reduce bank conflicts? 
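// Thread-to-fragment mapping of a 16x16 accumulator tile: lane t owns the bf16x2
// pairs at rows {t/4, t/4+8} and column pairs starting at {t%4*2, t%4*2+8}, the
// same layout used by the global-memory load/store helpers above.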
204 | int i = threadIdx.x%32/4, j = threadIdx.x%4*2; 205 | *((bf2*)&sa.data[ i *16+j ]) = ra.data[0]; 206 | *((bf2*)&sa.data[(i+8)*16+j ]) = ra.data[1]; 207 | *((bf2*)&sa.data[ i *16+j+8]) = ra.data[2]; 208 | *((bf2*)&sa.data[(i+8)*16+j+8]) = ra.data[3]; 209 | } 210 | 211 | __device__ void convert(RTile&ra, const FTile&fa) { 212 | ra.data[0] = to_bf2(fa.data[0]); 213 | ra.data[1] = to_bf2(fa.data[1]); 214 | ra.data[2] = to_bf2(fa.data[2]); 215 | ra.data[3] = to_bf2(fa.data[3]); 216 | } 217 | __device__ void convert(FTile&fa, const RTile&ra) { 218 | fa.data[0] = to_float2(ra.data[0]); 219 | fa.data[1] = to_float2(ra.data[1]); 220 | fa.data[2] = to_float2(ra.data[2]); 221 | fa.data[3] = to_float2(ra.data[3]); 222 | } 223 | 224 | __device__ STile& STile::operator=(const GTile& ga) { load(*this, ga.ga, ga.stride); return *this; } 225 | __device__ RTile& RTile::operator=(const GTile& ga) { load(*this, ga.ga, ga.stride); return *this; } 226 | __device__ RTile& RTile::operator=(const STile& sa) { load(*this, sa); return *this; } 227 | __device__ STile& STile::operator=(const RTile& ra) { store(ra, *this); return *this; } 228 | __device__ RTile& RTile::operator=(const STileT& sa) { loadT(*this, *sa.st); return *this; } 229 | __device__ Product operator%(const RTile&ra, const RTile&rb) { return Product{&ra,&rb}; } 230 | __device__ ProductPlus operator+(const Product&prod, const FTile&rc) { return ProductPlus{prod.a,prod.b,&rc}; } 231 | __device__ FTile& FTile::operator=(const Product& prod) { mm(*this, *prod.a, *prod.b); return *this; } 232 | __device__ FTile& FTile::operator=(const ProductPlus& prod) { mma(*this, *prod.a, *prod.b, *prod.c); return *this; } 233 | __device__ FTile& FTile::operator+=(const Product& prod) { mma(*this, *prod.a, *prod.b); return *this; } 234 | __device__ RTile& RTile::operator=(const FTile&fa) { convert(*this,fa); return *this; } 235 | __device__ FTile& FTile::operator=(const RTile&ra) { convert(*this,ra); return *this; } 236 | __device__ GTile& GTile::operator=(const RTile&ra) { store(ra, this->ga, this->stride); return *this; } 237 | __device__ GFTile& GFTile::operator=(const FTile&fa) { store(fa, this->ga, this->stride); return *this; } 238 | 239 | // Is this kind of cumsum better than multiplying with a triangular matrix of ones? 
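// cumsumv / cumprodv: prefix sum / prefix product down the 16 rows of a tile
// (the time direction within a chunk), computed entirely in registers with XOR
// warp shuffles across lanes plus a final in-lane fix-up for the r / r+8 row
// halves. The flags choose inclusive vs. exclusive and forward vs. reversed
// accumulation; e.g. the wind_rwkv7 kernels use cumprodv<0,0>(fw) for the
// exclusive running product of the per-step decays.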
240 | template 241 | __device__ FTile cumsumv(FTile&w) { 242 | int tid = threadIdx.x%32, t = tid/4; 243 | 244 | FTile ret; 245 | if (inclusive) for (int i = 0; i < 4; i++) ret.data[i] = w.data[i]; 246 | else for (int i = 0; i < 4; i++) ret.data[i] = float2{0.f,0.f}; 247 | 248 | for (int b = 0; b < 3; b++) { 249 | for (int i = 0; i < 8; i++) { 250 | float other_w = __shfl_xor_sync(0xffffffff, w.fdata[i], 4<>b)%2 == !rev) ret.fdata[i] += other_w; 252 | w.fdata[i] += other_w; 253 | } 254 | } 255 | for (int i : {0,1,4,5}) { 256 | float &w0 = w.fdata[i^(2*!rev)], &w1 = w.fdata[i^(2*rev)]; 257 | ret.fdata[i^(2*!rev)] += w1; 258 | w0 += w1; 259 | w1 = w0; 260 | } 261 | return ret; 262 | } 263 | 264 | template 265 | __device__ FTile cumprodv(FTile&w) { 266 | int tid = threadIdx.x%32, t = tid/4; 267 | 268 | FTile ret; 269 | if (inclusive) for (int i = 0; i < 4; i++) ret.data[i] = w.data[i]; 270 | else for (int i = 0; i < 4; i++) ret.data[i] = float2{1.f,1.f}; 271 | 272 | for (int b = 0; b < 3; b++) { 273 | for (int i = 0; i < 8; i++) { 274 | float other_w = __shfl_xor_sync(0xffffffff, w.fdata[i], 4<>b)%2 == !rev) ret.fdata[i] *= other_w; 276 | w.fdata[i] *= other_w; 277 | } 278 | } 279 | for (int i : {0,1,4,5}) { 280 | float &w0 = w.fdata[i^(2*!rev)], &w1 = w.fdata[i^(2*rev)]; 281 | ret.fdata[i^(2*!rev)] *= w1; 282 | w0 *= w1; 283 | w1 = w0; 284 | } 285 | return ret; 286 | } 287 | 288 | __device__ FTile operator*(const FTile&a, const FTile&b) { 289 | FTile ret; 290 | for (int i = 0; i < 8; i++) ret.fdata[i] = a.fdata[i]*b.fdata[i]; 291 | return ret; 292 | } 293 | 294 | template // Lower triangular 295 | __device__ FTile sum_warp(float2*share, const FTile&f) { 296 | int tid = threadIdx.x%32, warpi = threadIdx.x/32; 297 | FTile sum; 298 | sum.zero_(); 299 | for (int i : {0,1,2,3}) { 300 | if (i == 2 && triangular) continue; 301 | for (int j = 0; j < WARPS; j++) { 302 | if (warpi == j) share[tid] = f.data[i]; 303 | __syncthreads(); 304 | sum.data[i].x += share[tid].x; 305 | sum.data[i].y += share[tid].y; 306 | __syncthreads(); 307 | } 308 | } 309 | return sum; 310 | } 311 | 312 | __device__ RTile from_warp(const RTile&ra, int src, float4*share) { 313 | int tid = threadIdx.x%32, warpi = threadIdx.x/32; 314 | RTile ret; 315 | if (warpi == src) share[tid] = *((float4*)ra.data); 316 | __syncthreads(); 317 | *((float4*)ret.data) = share[tid]; 318 | __syncthreads(); 319 | return ret; 320 | } 321 | 322 | // inv(I-f) where f is strictly lower triangular 323 | __device__ FTile tri_minv(const FTile&f, float*share) { 324 | int i0 = threadIdx.x%32/4, j0 = threadIdx.x%4*2; 325 | float inv[16] = {}; 326 | for (int k = 0; k < 8; k++) { 327 | int i = i0+k/2%2*8, j = j0+k%2+k/4*8; 328 | share[i*16+j] = f.fdata[k]; 329 | } 330 | int tid = threadIdx.x%32; 331 | inv[tid%16] = 1; 332 | for (int i = 1; i < 16; i++) { 333 | for (int j = 0; j < i; j++) { 334 | float fac = share[i*16+j]; 335 | inv[i] += fac*inv[j]; 336 | } 337 | } 338 | for (int i = 0; i < 16; i++) 339 | share[tid*16+i] = inv[i]; 340 | FTile ret; 341 | for (int k = 0; k < 8; k++) { 342 | int i = i0+k/2%2*8, j = j0+k%2+k/4*8; 343 | ret.fdata[k] = share[j*16+i]; 344 | } 345 | return ret; 346 | } 347 | 348 | template 349 | __device__ FTile tril(const FTile&f) { 350 | int i0 = threadIdx.x%32/4, j0 = threadIdx.x%4*2; 351 | FTile ret; 352 | for (int k = 0; k < 8; k++) { 353 | int i = i0+k/2%2*8, j = j0+k%2+k/4*8; 354 | if (strict) ret.fdata[k] = (i>j ? f.fdata[k] : 0.f); 355 | else ret.fdata[k] = (i>=j ? 
f.fdata[k] : 0.f); 356 | } 357 | return ret; 358 | } 359 | 360 | template 361 | __device__ void apply_(FTile&tile, F f) { 362 | for (int i = 0; i < 8; i++) tile.fdata[i] = f(tile.fdata[i]); 363 | } 364 | 365 | __device__ bf2 transpose(bf2 a) { 366 | bf2 ret; 367 | asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n" : "=r"(as_uint(ret)) : "r"(as_uint(a))); 368 | return ret; 369 | } 370 | 371 | __device__ RTile transpose(const RTile&ra) { 372 | RTile rb; 373 | rb.data[0] = transpose(ra.data[0]); 374 | rb.data[1] = transpose(ra.data[2]); 375 | rb.data[2] = transpose(ra.data[1]); 376 | rb.data[3] = transpose(ra.data[3]); 377 | return rb; 378 | } 379 | 380 | template 381 | __device__ FTile slow_dw(const RTile&A, const RTile&q, const RTile&k, STile*share) { 382 | share[0] = A; 383 | share[1] = q; 384 | share[2] = k; 385 | __syncthreads(); 386 | if (threadIdx.x%32 == 0) { 387 | for (int k = 0; k < 16; k++) { 388 | for (int j = 0; j < 16; j++) { 389 | float sum = 0; 390 | for (int l = 0; l < k; l++) { 391 | for (int r = k+strict; r < 16; r++) { 392 | sum += to_float(share[0].data[r*16+l]) * to_float(share[1].data[r*16+j]) * to_float(share[2].data[l*16+j]); 393 | } 394 | } 395 | share[3].data[k*16+j] = to_bf(sum); 396 | } 397 | } 398 | } 399 | __syncthreads(); 400 | RTile ret = (RTile)share[3]; 401 | __syncthreads(); 402 | return ret; 403 | } 404 | 405 | 406 | __device__ static inline void __m16n8k8(float2&d0, float2&d1, const bf2 &a0, const bf2 &a1, const bf2 &b0) { 407 | asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};" 408 | : "=f"(d0.x), "=f"(d0.y), "=f"(d1.x), "=f"(d1.y) : "r"(as_uint(a0)), "r"(as_uint(a1)), "r"(as_uint(b0)), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f)); 409 | } 410 | 411 | template 412 | __device__ RTile fast_dw(const RTile&A, const RTile&q, const RTile&k) { 413 | float2 qkA8[4]; 414 | RTile kt = transpose(k), qt = transpose(q); 415 | __m16n8k8(qkA8[0],qkA8[1], qt.data[2], qt.data[3], transpose(A.data[1])); 416 | __m16n8k8(qkA8[2],qkA8[3], kt.data[0], kt.data[1], A.data[1]); 417 | for (int x : {0,1}) { 418 | qkA8[x] *= to_float2(kt.data[x]); 419 | qkA8[2+x] *= to_float2(qt.data[2+x]); 420 | } 421 | 422 | int tid = threadIdx.x%32, j = threadIdx.x%4; 423 | // Non-inclusive cumsum 424 | for (int i = 0; i < 4; i++) { 425 | float sum = qkA8[i].x+qkA8[i].y; 426 | float psum = __shfl_xor_sync(0xffffffff, sum, 1); 427 | float ppsum = __shfl_xor_sync(0xffffffff, sum+psum, 2); 428 | if (i < 2) { 429 | psum = ppsum*(j>=2)+psum*(j%2); 430 | qkA8[i].y = psum + qkA8[i].x; 431 | qkA8[i].x = psum; 432 | } else { 433 | psum = ppsum*(j<2)+psum*(j%2==0); 434 | qkA8[i].x = psum + qkA8[i].y; 435 | qkA8[i].y = psum; 436 | } 437 | } 438 | 439 | float2 qkA4[4]; 440 | { 441 | RTile k_q; 442 | for (int i = 0; i < 8; i++) ((bf*)k_q.data)[i] = (j<2?((bf*)kt.data)[i]:((bf*)qt.data)[i]); 443 | float lower_left = (tid >= 16 && j < 2); 444 | bf2 A0 = to_bf2(to_float2(A.data[0])*float2{lower_left,lower_left}); 445 | bf2 A3 = to_bf2(to_float2(A.data[3])*float2{lower_left,lower_left}); 446 | __m16n8k8(qkA4[0],qkA4[1], k_q.data[0], k_q.data[1], A0 + transpose(A0)); 447 | __m16n8k8(qkA4[2],qkA4[3], k_q.data[2], k_q.data[3], A3 + transpose(A3)); 448 | for (int i = 0; i < 4; i++) 449 | qkA4[i] *= to_float2(k_q.data[i]); 450 | } 451 | 452 | // Non-inclusive cumsum 453 | for (int i = 0; i < 4; i++) { 454 | float sum = qkA4[i].x+qkA4[i].y; 455 | float psum = __shfl_xor_sync(0xffffffff, sum, 1); 456 | psum *= (j%2 == j<2); 457 | qkA4[i] = {psum 
+ qkA4[i].y*(j>=2), psum + qkA4[i].x*(j<2)}; 458 | } 459 | 460 | FTile ret; 461 | ret.data[0] = qkA8[0]+qkA4[0]; 462 | ret.data[1] = qkA8[1]+qkA4[1]; 463 | ret.data[2] = qkA8[2]+qkA4[2]; 464 | ret.data[3] = qkA8[3]+qkA4[3]; 465 | 466 | for (int ci : {0,1}) { 467 | for (int ti : {0,1}) { 468 | int Ai = ti*3, di = ti*2+ci; 469 | unsigned mask = 0xffff<<(j>=2)*16; 470 | bf A8x = __shfl_sync(mask, A.data[Ai].x, 8+(j>=2)*18); 471 | bf A12x = __shfl_sync(mask, A.data[Ai].x, 12+(j>=2)*18); 472 | bf A12y = __shfl_sync(mask, A.data[Ai].y, 12+(j>=2)*18); 473 | bf2 nq = __shfl_xor_sync(0xffffffff, qt.data[di], 1); 474 | bf2 pk = __shfl_xor_sync(0xffffffff, kt.data[di], 1); 475 | 476 | bool even = (j%2==0); 477 | float ax = to_float(even?A8x:A12x), ay = to_float(even?A12x:A12y), c = to_float(even?kt.data[di].x:qt.data[di].y); 478 | float2 b = to_float2(j%2?pk:nq); 479 | float d = (ax*b.x+ay*b.y)*c; 480 | ret.data[di].y += even*d; 481 | ret.data[di].x +=!even*d; 482 | } 483 | } 484 | 485 | if (!strict) { 486 | // Do we really need tril<1>()? 487 | ret += (kt % tril<1>(A)) * qt; 488 | } 489 | return transpose(ret); 490 | } 491 | 492 | __device__ void debug_set(RTile&ra, int i, int j, float v) { 493 | if (threadIdx.x%32 == i%8*4+j%8/2) ((bf*)ra.data)[i/8*2+j/8*4+j%2] = to_bf(v); 494 | } 495 | -------------------------------------------------------------------------------- /rwkv_cuda_wind/wind_rwkv7.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | using bf = __nv_bfloat16; 4 | 5 | void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*s0, bf*y, bf*s, bf*sT); 6 | 7 | void forward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &s0, torch::Tensor &y, torch::Tensor &s, torch::Tensor &sT) { 8 | int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2]; 9 | cuda_forward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)s0.data_ptr(), (bf*)y.data_ptr(), (bf*)s.data_ptr(), (bf*)sT.data_ptr()); 10 | } 11 | 12 | void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, bf*s, bf*dsT, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da, bf*ds0); 13 | 14 | void backward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &dy, 15 | torch::Tensor &s, torch::Tensor &dsT, torch::Tensor &dw, torch::Tensor &dq, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &dz, torch::Tensor &da, torch::Tensor &ds0) { 16 | int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2]; 17 | cuda_backward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)dy.data_ptr(), 18 | (bf*)s.data_ptr(), (bf*)dsT.data_ptr(), (bf*)dw.data_ptr(), (bf*)dq.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)dz.data_ptr(), (bf*)da.data_ptr(), (bf*)ds0.data_ptr()); 19 | } 20 | 21 | /*TORCH_LIBRARY(wind, m) { 22 | m.def("forward", forward); 23 | m.def("backward", backward); 24 | }*/ 25 | 26 | TORCH_LIBRARY(wind, m) { 27 | m.def("forward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor s0, Tensor(a!) y, Tensor(b!) s, Tensor(c!) sT) -> ()"); 28 | m.def("backward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor dy, Tensor s, Tensor dsT, Tensor(a!) dw, Tensor(b!) dq, Tensor(c!) dk, Tensor(d!) dv, Tensor(e!) dz, Tensor(f!) 
da, Tensor(g!) ds0) -> ()"); 29 | } 30 | 31 | TORCH_LIBRARY_IMPL(wind, CUDA, m) { 32 | m.impl("forward", &forward); 33 | m.impl("backward", &backward); 34 | } 35 | -------------------------------------------------------------------------------- /rwkv_cuda_wind/wind_rwkv7.cu: -------------------------------------------------------------------------------- 1 | #include "tile.cuh" 2 | #include 3 | typedef bf * __restrict__ F_; 4 | 5 | constexpr int WARPS = _C_/16; 6 | constexpr int fw_stages = 1, bw_stages = 1; 7 | 8 | __global__ void forward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ s0_, bf* y_, bf* s_, bf* sT_) { 9 | constexpr int C = _C_, K = 16; 10 | int bi = blockIdx.y, hi = blockIdx.x; 11 | extern __shared__ char smem_[]; 12 | char*smem = smem_; 13 | 14 | STile *sw_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS; 15 | STile *sq_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS; 16 | STile *sk_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS; 17 | STile *sv_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS; 18 | STile *sa_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS; 19 | STile *sb_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS; 20 | char*share = (char*)smem; 21 | 22 | int stride = H*C; 23 | int warpi = threadIdx.x/32; 24 | 25 | auto push = [&](int t) { 26 | int off = bi*T*H*C + t*K*H*C + hi*C + warpi*16; 27 | int si = t%fw_stages; 28 | sw_[si*WARPS+warpi] = GTile(w_+off, stride); 29 | sq_[si*WARPS+warpi] = GTile(q_+off, stride); 30 | sk_[si*WARPS+warpi] = GTile(k_+off, stride); 31 | sv_[si*WARPS+warpi] = GTile(v_+off, stride); 32 | sa_[si*WARPS+warpi] = GTile(a_+off, stride); 33 | sb_[si*WARPS+warpi] = GTile(b_+off, stride); 34 | }; 35 | for (int t = 0; t < fw_stages-1 && t < T/K; t++) push(t), __commit_group(); 36 | 37 | FTile state[WARPS]; 38 | for (int i = 0; i < WARPS; i++) { 39 | int off = bi*H*C*C + hi*C*C + warpi*16*C + i*16; 40 | RTile tmp; 41 | tmp = GTile(s0_+off, C); 42 | state[i] = tmp; 43 | } 44 | 45 | for (int t = 0; t < T/K; t++) { 46 | __syncthreads(); 47 | if (t+fw_stages-1 < T/K) 48 | push(t+fw_stages-1); 49 | __commit_group(); 50 | __wait_groups(); 51 | __syncthreads(); 52 | int si = t%fw_stages; 53 | STile &sw = sw_[si*WARPS+warpi], &sq = sq_[si*WARPS+warpi], &sk = sk_[si*WARPS+warpi], &sv = sv_[si*WARPS+warpi], &sa = sa_[si*WARPS+warpi], &sb = sb_[si*WARPS+warpi]; 54 | 55 | FTile w = (RTile)sw; 56 | apply_(w, [](float x) { return __expf(-__expf(x)); }); 57 | FTile fw = w; 58 | FTile non_incl_pref = cumprodv<0,0>(fw); 59 | FTile incl_pref = non_incl_pref * w; 60 | FTile inv_incl_pref = incl_pref; 61 | apply_(inv_incl_pref, [](float x) { return 1.f/x; }); 62 | 63 | RTile wq = (RTile)sq * incl_pref, kwi = (RTile)sk * inv_incl_pref; 64 | RTile wa = (RTile)sa * non_incl_pref, bwi = (RTile)sb * inv_incl_pref; 65 | FTile ab = sum_warp<1,WARPS>((float2*)share, tril<1>(wa % bwi)); 66 | RTile ak = sum_warp<1,WARPS>((float2*)share, tril<1>(wa % kwi)); 67 | 68 | RTile ab_inv; 69 | __syncthreads(); 70 | if (threadIdx.x < 32) ab_inv = tri_minv(ab, (float*)share); 71 | __syncthreads(); 72 | ab_inv = from_warp(ab_inv, 0, (float4*)share); 73 | 74 | RTile vt = sv.t(); 75 | FTile ab_ut = vt % ak; 76 | for (int i = 0; i < WARPS; i++) 77 | ab_ut += state[i] % from_warp(wa, i, (float4*)share); 78 | RTile ut = FTile(ab_ut % ab_inv); 79 | 80 | FTile y = sum_warp<1,WARPS>((float2*)share, tril<0>(wq % kwi)) % vt; 81 | y += sum_warp<1,WARPS>((float2*)share, tril<0>(wq % bwi)) % ut; 82 | for (int i = 0; i < WARPS; i++) 83 | y += 
from_warp(wq, i, (float4*)share) % state[i]; 84 | 85 | int off = bi*T*H*C + t*K*H*C + hi*C + warpi*16; 86 | GTile(y_+off, stride) = RTile(y); 87 | 88 | RTile kwt = transpose(kwi*fw), bwt = transpose(bwi*fw); 89 | for (int i = 0; i < WARPS; i++) { 90 | int off = bi*H*(T/K)*C*C + hi*(T/K)*C*C + t*C*C + warpi*16*C + i*16; 91 | GTile(s_+off, C) = (RTile)state[i]; 92 | 93 | FTile fstate = state[i] * from_warp(fw, i, (float4*)share); 94 | fstate += vt % from_warp(kwt, i, (float4*)share); 95 | fstate += ut % from_warp(bwt, i, (float4*)share); 96 | state[i] = fstate; 97 | } 98 | } 99 | for (int i = 0; i < WARPS; i++) { 100 | int off = bi*H*C*C + hi*C*C + warpi*16*C + i*16; 101 | GTile(sT_+off, C) = state[i]; 102 | } 103 | } 104 | 105 | void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*s0, bf*y, bf*s, bf*sT) { 106 | assert(T%16 == 0); 107 | constexpr int tmp_size1 = sizeof(float4)*32, tmp_size2 = sizeof(float)*16*16*2; 108 | constexpr int threads = 32*WARPS, shared_mem = sizeof(STile)*fw_stages*WARPS*6 + (tmp_size1 > tmp_size2 ? tmp_size1 : tmp_size2); 109 | static int reported = 0; 110 | if (!reported++) { 111 | #if defined VERBOSE 112 | printf("forward_kernel() uses %d bytes of (dynamic) shared memory\n", shared_mem); 113 | #endif 114 | cudaFuncAttributes attr; 115 | cudaFuncGetAttributes(&attr, forward_kernel); 116 | int cur_mem = attr.maxDynamicSharedSizeBytes; 117 | if (shared_mem > cur_mem) { 118 | #if defined VERBOSE 119 | printf("Increasing forward_kernel's MaxDynamicSharedMemorySize from %d to %d\n", cur_mem, shared_mem); 120 | #endif 121 | assert(!cudaFuncSetAttribute(forward_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem)); 122 | } 123 | } 124 | forward_kernel<<>>(T,H,w,q,k,v,z,a,s0,y,s,sT); 125 | } 126 | 127 | 128 | __global__ void backward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_, F_ s_, F_ dsT_, bf* dw_, bf* dq_, bf* dk_, bf* dv_, bf* da_, bf* db_, bf* ds0_) { 129 | constexpr int C = _C_, K = 16; 130 | int bi = blockIdx.y, hi = blockIdx.x; 131 | extern __shared__ char smem_[]; 132 | char*smem = smem_; 133 | 134 | STile *sw_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS; 135 | STile *sq_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS; 136 | STile *sk_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS; 137 | STile *sv_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS; 138 | STile *sa_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS; 139 | STile *sb_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS; 140 | STile *sdy_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS; 141 | STile *state_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS*WARPS; 142 | char*share = (char*)smem; 143 | 144 | int stride = H*C; 145 | int warpi = threadIdx.x/32; 146 | 147 | auto push = [&](int t) { 148 | int off = bi*T*H*C + t*K*H*C + hi*C + warpi*16; 149 | int si = t%fw_stages; 150 | sw_[si*WARPS+warpi] = GTile(w_+off, stride); 151 | sq_[si*WARPS+warpi] = GTile(q_+off, stride); 152 | sk_[si*WARPS+warpi] = GTile(k_+off, stride); 153 | sv_[si*WARPS+warpi] = GTile(v_+off, stride); 154 | sa_[si*WARPS+warpi] = GTile(a_+off, stride); 155 | sb_[si*WARPS+warpi] = GTile(b_+off, stride); 156 | sdy_[si*WARPS+warpi] = GTile(dy_+off, stride); 157 | for (int i = 0; i < WARPS; i++) { 158 | int off2 = bi*H*(T/K)*C*C + hi*(T/K)*C*C + t*C*C + warpi*16*C + i*16; 159 | state_[si*WARPS*WARPS+warpi*WARPS+i] = GTile(s_+off2, C); 160 | } 161 | }; 162 | 163 | FTile dstate[WARPS]; 164 | for (int i = 0; i < WARPS; i++) { 
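// Seed the backward recursion: dstate starts as the incoming gradient of the
// state left after the last chunk (dsT) and is then propagated back one
// 16-step chunk per iteration of the reversed t-loop below.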
165 | int off = bi*H*C*C + hi*C*C + warpi*16*C + i*16; 166 | RTile tmp; 167 | tmp = GTile(dsT_+off, C); 168 | dstate[i] = tmp; 169 | __commit_group(); 170 | } 171 | 172 | for (int t = 0; t < bw_stages-1 && t < T/K; t++) push(T/K-1-t), __commit_group(); 173 | 174 | for (int t = T/K-1; t >= 0; t--) { 175 | __syncthreads(); 176 | if (t-bw_stages+1 >= 0) 177 | push(t-bw_stages+1); 178 | __commit_group(); 179 | __wait_groups(); 180 | __syncthreads(); 181 | int si = t%bw_stages; 182 | STile &sw = sw_[si*WARPS+warpi], &sq = sq_[si*WARPS+warpi], &sk = sk_[si*WARPS+warpi], &sv = sv_[si*WARPS+warpi], &sa = sa_[si*WARPS+warpi], &sb = sb_[si*WARPS+warpi], &sdy = sdy_[si*WARPS+warpi]; 183 | STile*state = state_+si*WARPS*WARPS; 184 | 185 | FTile w = (RTile)sw; 186 | apply_(w, [](float x) { return __expf(-__expf(x)); }); 187 | FTile fw = w; 188 | FTile non_incl_pref = cumprodv<0,0>(fw); 189 | FTile incl_pref = non_incl_pref * w; 190 | FTile inv_incl_pref = incl_pref; 191 | apply_(inv_incl_pref, [](float x) { return 1.f/x; }); 192 | 193 | RTile wq = (RTile)sq * incl_pref, kwi = (RTile)sk * inv_incl_pref; 194 | RTile wa = (RTile)sa * non_incl_pref, bwi = (RTile)sb * inv_incl_pref; 195 | FTile ab = sum_warp<1,WARPS>((float2*)share, tril<1>(wa % bwi)); 196 | RTile ak = sum_warp<1,WARPS>((float2*)share, tril<1>(wa % kwi)); 197 | 198 | RTile ab_inv; 199 | __syncthreads(); 200 | if (threadIdx.x < 32) ab_inv = tri_minv(ab, (float*)share); 201 | __syncthreads(); 202 | ab_inv = from_warp(ab_inv, 0, (float4*)share); 203 | 204 | RTile vt = sv.t(); 205 | FTile ab_ut = vt % ak; 206 | for (int i = 0; i < WARPS; i++) 207 | ab_ut += state[warpi*WARPS+i] % from_warp(wa, i, (float4*)share); 208 | RTile ut = FTile(ab_ut % ab_inv); 209 | 210 | RTile qb = sum_warp<1,WARPS>((float2*)share, tril<0>(wq % bwi)); 211 | RTile qk = sum_warp<1,WARPS>((float2*)share, tril<0>(wq % kwi)); 212 | 213 | RTile dyt = sdy.t(); 214 | FTile dut = FTile(dyt % transpose(qb)); 215 | FTile dv = transpose(qk) % dyt; 216 | for (int i = 0; i < WARPS; i++) { 217 | RTile dstatei = dstate[i]; 218 | dut += dstatei % from_warp(bwi*fw, i, (float4*)share); 219 | dv += from_warp(kwi*fw, i, (float4*)share) % dstatei; 220 | } 221 | RTile dab_ut = FTile(dut % transpose(ab_inv)); 222 | dv += transpose(ak) % dab_ut; 223 | 224 | int off = bi*T*H*C + t*K*H*C + hi*C + warpi*16; 225 | GTile(dv_+off, stride) = RTile(dv); 226 | 227 | FTile dab = sum_warp<1,WARPS>((float2*)share, tril<1>(transpose(dab_ut) % transpose(ut))); 228 | FTile dak = sum_warp<1,WARPS>((float2*)share, tril<1>(transpose(dab_ut) % transpose(vt))); 229 | FTile dab_u_state0; 230 | dab_u_state0.zero_(); 231 | for (int i = 0; i < WARPS; i++) 232 | dab_u_state0 += from_warp(transpose(dab_ut), i, (float4*)share) % state[i*WARPS+warpi].t(); 233 | 234 | FTile da = dab_u_state0; 235 | da += dab % transpose(bwi); 236 | da += dak % transpose(kwi); 237 | da = non_incl_pref * da; 238 | GTile(da_+off, stride) = RTile(da); 239 | 240 | FTile dqb = sum_warp<1,WARPS>((float2*)share, tril<0>(transpose(dyt) % transpose(ut))); 241 | FTile dqk = sum_warp<1,WARPS>((float2*)share, tril<0>(transpose(dyt) % transpose(vt))); 242 | FTile dy_state0; 243 | dy_state0.zero_(); 244 | for (int i = 0; i < WARPS; i++) 245 | dy_state0 += from_warp(transpose(dyt), i, (float4*)share) % state[i*WARPS+warpi].t(); 246 | 247 | FTile dq = dy_state0; 248 | dq += dqb % transpose(bwi); 249 | dq += dqk % transpose(kwi); 250 | dq = incl_pref * dq; 251 | GTile(dq_+off, stride) = RTile(dq); 252 | 253 | RTile wqt = transpose(wq), wat = transpose(wa); 
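// Cross-warp reductions against the incoming state gradient: v_dstate and
// u_dstate collect the products of v and u (the values obtained above through
// the inverse of (I - ab)) with dstate, while dw accumulates per-channel totals
// of the elementwise state*dstate via a product with an all-ones tile. These
// feed the dk, db and dw outputs below.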
254 | 255 | FTile u_dstate, v_dstate, dw; 256 | u_dstate.zero_(); 257 | v_dstate.zero_(); 258 | dw.zero_(); 259 | RTile ones; 260 | for (int i = 0; i < 4; i++) ones.data[i] = to_bf2({1.f,1.f}); 261 | for (int i = 0; i < WARPS; i++) { 262 | int tid = threadIdx.x%32; 263 | if (warpi == i) { 264 | for (int j = 0; j < WARPS; j++) { 265 | RTile ra = dstate[j]; 266 | ((float4*)share)[j*32+tid] = *((float4*)ra.data); 267 | } 268 | } 269 | RTile dstatei;// = dstate[i*WARPS+warpi]; 270 | __syncthreads(); 271 | *((float4*)dstatei.data) = ((float4*)share)[warpi*32+tid]; 272 | __syncthreads(); 273 | RTile dstatei_t = transpose(dstatei); 274 | v_dstate += from_warp(transpose(vt), i, (float4*)share) % dstatei_t; 275 | u_dstate += from_warp(transpose(ut), i, (float4*)share) % dstatei_t; 276 | dw += ones % transpose((RTile)state[i*WARPS+warpi]*dstatei); 277 | } 278 | 279 | FTile db = fw * u_dstate; 280 | db += transpose(dab) % wat; 281 | db += transpose(dqb) % wqt; 282 | db = inv_incl_pref * db; 283 | GTile(db_+off, stride) = RTile(db); 284 | 285 | FTile dk = fw * v_dstate; 286 | dk += transpose(dak) % wat; 287 | dk += transpose(dqk) % wqt; 288 | dk = inv_incl_pref * dk; 289 | GTile(dk_+off, stride) = RTile(dk); 290 | 291 | dw = fw * dw; 292 | dw += fast_dw<1>(dab,wa,bwi); 293 | dw += fast_dw<1>(dak,wa,kwi); 294 | dw += fast_dw<0>(dqb,wq,bwi); 295 | dw += fast_dw<0>(dqk,wq,kwi); 296 | FTile tmp; 297 | dw += cumsumv<0,0>(tmp = v_dstate*(fw*kwi)); 298 | dw += cumsumv<0,0>(tmp = u_dstate*(fw*bwi)); 299 | dw += cumsumv<0,1>(tmp = dab_u_state0*wa); 300 | dw += cumsumv<1,1>(tmp = dy_state0*wq); 301 | 302 | FTile dw_fac = (RTile)sw; 303 | apply_(dw_fac, [](float x) { return -__expf(x); }); 304 | dw = dw * dw_fac; 305 | GTile(dw_+off, stride) = RTile(dw); 306 | 307 | __syncthreads(); 308 | for (int i = 0; i < WARPS; i++) { 309 | FTile ndstate = dstate[i] * from_warp(fw, i, (float4*)share); 310 | ndstate += dyt % from_warp(wqt, i, (float4*)share); 311 | ndstate += dab_ut % from_warp(wat, i, (float4*)share); 312 | dstate[i] = ndstate; 313 | } 314 | __syncthreads(); 315 | } 316 | for (int i = 0; i < WARPS; i++) { 317 | int off = bi*H*C*C + hi*C*C + warpi*16*C + i*16; 318 | GTile(ds0_+off, C) = dstate[i]; 319 | } 320 | } 321 | 322 | void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, bf*s, bf*dsT, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da, bf*ds0) { 323 | assert(T%16 == 0); 324 | constexpr int tmp_size1 = sizeof(float4)*32*WARPS, tmp_size2 = sizeof(float)*16*16*2; 325 | constexpr int threads = 32*WARPS, shared_mem = sizeof(STile)*WARPS*bw_stages*(7+WARPS) + (tmp_size1 > tmp_size2 ? 
tmp_size1 : tmp_size2); 326 | static int reported = 0; 327 | if (!reported++) { 328 | #if defined VERBOSE 329 | printf("backward_kernel() uses %d bytes of (dynamic) shared memory\n", shared_mem); 330 | #endif 331 | cudaFuncAttributes attr; 332 | cudaFuncGetAttributes(&attr, backward_kernel); 333 | int cur_mem = attr.maxDynamicSharedSizeBytes; 334 | if (shared_mem > cur_mem) { 335 | #if defined VERBOSE 336 | printf("Increasing backward_kernel's MaxDynamicSharedMemorySize from %d to %d\n", cur_mem, shared_mem); 337 | #endif 338 | assert(!cudaFuncSetAttribute(backward_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem)); 339 | } 340 | } 341 | backward_kernel<<>>(T,H,w,q,k,v,z,a,dy,s,dsT,dw,dq,dk,dv,dz,da,ds0); 342 | } 343 | -------------------------------------------------------------------------------- /train_gpt2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | with open(sys.argv[0]) as f: 4 | code = f.read() # read the code of this file ASAP, for logging 5 | import uuid 6 | import glob 7 | import time 8 | from dataclasses import dataclass 9 | 10 | import numpy as np 11 | import torch 12 | from torch import nn 13 | import torch.nn.functional as F 14 | import torch.distributed as dist 15 | import torch._inductor.config as config 16 | from torch.nn.parallel import DistributedDataParallel as DDP 17 | 18 | # ----------------------------------------------------------------------------- 19 | # Muon optimizer 20 | 21 | def zeropower_via_svd(G, steps=None): 22 | U, S, V = G.svd() 23 | return U @ V.T 24 | 25 | @torch.compile 26 | def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): 27 | r""" 28 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 29 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 30 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 31 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 32 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T 33 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model 34 | performance at all relative to UV^T, where USV^T = G is the SVD. 35 | """ 36 | assert len(G.shape) == 2 37 | a, b, c = (3.4445, -4.7750, 2.0315) 38 | X = G.bfloat16() 39 | X /= (X.norm() + eps) # ensure top singular value <= 1 40 | if G.size(0) > G.size(1): 41 | X = X.T 42 | for _ in range(steps): 43 | A = X @ X.T 44 | B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng 45 | X = a * X + B @ X 46 | if G.size(0) > G.size(1): 47 | X = X.T 48 | return X 49 | 50 | zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5) 51 | 52 | class Muon(torch.optim.Optimizer): 53 | """ 54 | Muon - MomentUm Orthogonalized by Newton-schulz 55 | 56 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- 57 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal 58 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has 59 | the advantage that it can be stably run in bfloat16 on the GPU. 60 | 61 | Some warnings: 62 | - This optimizer assumes that all parameters passed in are 2D. 
63 | - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D 64 | parameters; those should all be optimized by a standard method (e.g., AdamW). 65 | - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. 66 | - We believe it is unlikely to work well for training with small batch size. 67 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this. 68 | - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M). 69 | 70 | Arguments: 71 | lr: The learning rate used by the internal SGD. 72 | momentum: The momentum used by the internal SGD. 73 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) 74 | backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5') 75 | backend_steps: The number of iteration steps to use in the backend, if it is iterative. 76 | """ 77 | def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, 78 | backend='newtonschulz5', backend_steps=5): 79 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps) 80 | super().__init__(params, defaults) 81 | 82 | def step(self): 83 | 84 | for group in self.param_groups: 85 | 86 | lr = group['lr'] 87 | momentum = group['momentum'] 88 | zeropower_backend = zeropower_backends[group['backend']] 89 | 90 | # generate weight updates in distributed fashion 91 | total_params = sum(p.numel() for p in group['params']) 92 | updates_flat = torch.zeros(total_params, device='cuda', dtype=torch.bfloat16) 93 | curr_idx = 0 94 | for i, p in enumerate(group['params']): 95 | # luckily this will perfectly distribute a transformer with multiple of 4 layers to 8 GPUs 96 | if i % int(os.environ['WORLD_SIZE']) == int(os.environ['RANK']): 97 | g = p.grad 98 | assert g is not None 99 | state = self.state[p] 100 | if 'momentum_buffer' not in state: 101 | state['momentum_buffer'] = torch.zeros_like(g) 102 | buf = state['momentum_buffer'] 103 | buf.mul_(momentum).add_(g) 104 | if group['nesterov']: 105 | g = g.add(buf, alpha=momentum) 106 | g = zeropower_backend(g, steps=group['backend_steps']) 107 | g *= max(1, g.size(0)/g.size(1))**0.5 108 | updates_flat[curr_idx:curr_idx+p.numel()] = g.flatten() 109 | curr_idx += p.numel() 110 | 111 | # sync updates across devices. 
we are not memory-constrained so can do this simple deserialization 112 | dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) 113 | 114 | # deserialize and apply updates 115 | curr_idx = 0 116 | for p in group['params']: 117 | g = updates_flat[curr_idx:curr_idx+p.numel()].view_as(p.data).type_as(p.data) 118 | p.data.add_(g, alpha=-lr) 119 | curr_idx += p.numel() 120 | 121 | # ----------------------------------------------------------------------------- 122 | # PyTorch nn.Module definitions for the GPT-2 model 123 | 124 | class Rotary(torch.nn.Module): 125 | 126 | def __init__(self, dim, base=10000): 127 | super().__init__() 128 | self.dim = dim 129 | self.base = base 130 | self.inv_freq = None 131 | self.seq_len_cached = None 132 | self.cos_cached = None 133 | self.sin_cached = None 134 | 135 | def forward(self, x): 136 | seq_len = x.shape[1] 137 | if seq_len != self.seq_len_cached: 138 | self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim)) 139 | self.seq_len_cached = seq_len 140 | t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) 141 | freqs = torch.outer(t, self.inv_freq) 142 | self.cos_cached = freqs.cos().bfloat16() 143 | self.sin_cached = freqs.sin().bfloat16() 144 | return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :] 145 | 146 | def apply_rotary_emb(x, cos, sin): 147 | assert x.ndim == 4 # multihead attention 148 | d = x.shape[3]//2 149 | x1 = x[..., :d] 150 | x2 = x[..., d:] 151 | y1 = x1 * cos + x2 * sin 152 | y2 = x1 * (-sin) + x2 * cos 153 | return torch.cat([y1, y2], 3).type_as(x) 154 | 155 | class CastedLinear(nn.Linear): 156 | def forward(self, x): 157 | return F.linear(x, self.weight.to(x.dtype)) 158 | 159 | class CausalSelfAttention(nn.Module): 160 | 161 | def __init__(self, config): 162 | super().__init__() 163 | self.n_head = config.n_head 164 | self.n_embd = config.n_embd 165 | self.head_dim = self.n_embd // self.n_head 166 | assert self.n_embd % self.n_head == 0 167 | self.c_q = CastedLinear(self.n_embd, self.n_embd, bias=False) 168 | self.c_k = CastedLinear(self.n_embd, self.n_embd, bias=False) 169 | self.c_v = CastedLinear(self.n_embd, self.n_embd, bias=False) 170 | # output projection 171 | self.c_proj = CastedLinear(self.n_embd, self.n_embd, bias=False) 172 | self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977 173 | self.rotary = Rotary(self.head_dim) 174 | self.lamb = nn.Parameter(torch.tensor(0.5)) # @Grad62304977 175 | 176 | def forward(self, x, v1=None): 177 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 178 | q = self.c_q(x).view(B, T, self.n_head, self.head_dim) 179 | k = self.c_k(x).view(B, T, self.n_head, self.head_dim) 180 | v = self.c_v(x).view(B, T, self.n_head, self.head_dim) 181 | if v1 is None: 182 | v1 = v # This happens if we are in the first block. 
v needs to be accessed by subsequent blocks 183 | v = (1 - self.lamb) * v + self.lamb * v1.view_as(v) # @Grad62304977 184 | cos, sin = self.rotary(q) 185 | q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) # QK norm suggested by @Grad62304977 186 | q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) 187 | y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True) 188 | y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side 189 | y = self.c_proj(y) 190 | return y, v1 191 | 192 | class MLP(nn.Module): 193 | 194 | def __init__(self, config): 195 | super().__init__() 196 | self.c_fc = CastedLinear(config.n_embd, 4 * config.n_embd, bias=False) 197 | self.c_proj = CastedLinear(4 * config.n_embd, config.n_embd, bias=False) 198 | self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977 199 | 200 | def forward(self, x): 201 | x = self.c_fc(x) 202 | x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 203 | x = self.c_proj(x) 204 | return x 205 | 206 | class Block(nn.Module): 207 | 208 | def __init__(self, config): 209 | super().__init__() 210 | self.attn = CausalSelfAttention(config) 211 | self.mlp = MLP(config) 212 | self.lambdas = nn.Parameter(torch.tensor([1., 0.])) 213 | 214 | def forward(self, x, v1, x0): 215 | x = self.lambdas[0] * x + self.lambdas[1] * x0 216 | x1, v1 = self.attn(F.rms_norm(x, (x.size(-1),)), v1) 217 | x = x + x1 218 | x = x + self.mlp(F.rms_norm(x, (x.size(-1),))) 219 | return x, v1 220 | 221 | # ----------------------------------------------------------------------------- 222 | # The main GPT-2 model 223 | 224 | @dataclass 225 | class GPTConfig: 226 | vocab_size : int = 50304 227 | n_layer : int = 12 228 | n_head : int = 6 # head dim 128 suggested by @Grad62304977 229 | n_embd : int = 768 230 | 231 | class GPT(nn.Module): 232 | 233 | def __init__(self, config): 234 | super().__init__() 235 | self.config = config 236 | 237 | self.transformer = nn.ModuleDict(dict( 238 | wte = nn.Embedding(config.vocab_size, config.n_embd), 239 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 240 | )) 241 | self.lm_head = CastedLinear(config.n_embd, config.vocab_size, bias=False) 242 | self.lm_head.weight.data.zero_() # @Grad62304977 243 | 244 | def forward(self, idx, target): 245 | 246 | # forward the GPT model itself 247 | x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 248 | x = F.rms_norm(x, (x.size(-1),)) # @Grad62304977 249 | x0 = x 250 | v1 = None 251 | for block in self.transformer.h: 252 | x, v1 = block(x, v1, x0) 253 | x = F.rms_norm(x, (x.size(-1),)) 254 | 255 | logits = self.lm_head(x) 256 | logits = 30 * torch.tanh(logits / 30) # @Grad62304977 257 | logits = logits.float() 258 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target.view(-1)) 259 | return loss.float() 260 | 261 | # ----------------------------------------------------------------------------- 262 | # Our own simple Distributed Data Loader 263 | 264 | def _peek_data_shard(filename): 265 | # only reads the header, returns header data 266 | with open(filename, "rb") as f: 267 | # first read the header, which is 256 int32 integers (4 bytes each) 268 | header = np.frombuffer(f.read(256*4), dtype=np.int32) 269 | if header[0] != 20240520: 270 | print("ERROR: magic number mismatch in the data .bin file!") 271 | print("---> HINT: Are you passing in a correct file with 
--input_bin?") 272 | print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README") 273 | print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try") 274 | exit(1) 275 | assert header[1] == 1, "unsupported version" 276 | ntok = header[2] # number of tokens (claimed) 277 | return ntok # for now just return the number of tokens 278 | 279 | def _load_data_shard(filename): 280 | with open(filename, "rb") as f: 281 | # first read the header, which is 256 int32 integers (4 bytes each) 282 | header = np.frombuffer(f.read(256*4), dtype=np.int32) 283 | assert header[0] == 20240520, "magic number mismatch in the data .bin file" 284 | assert header[1] == 1, "unsupported version" 285 | ntok = header[2] # number of tokens (claimed) 286 | # the rest of it are tokens, stored as uint16 287 | tokens = np.frombuffer(f.read(), dtype=np.uint16) 288 | assert len(tokens) == ntok, "number of tokens read does not match header?" 289 | return tokens 290 | 291 | class DistributedDataLoader: 292 | def __init__(self, filename_pattern, B, T, process_rank, num_processes): 293 | self.process_rank = process_rank 294 | self.num_processes = num_processes 295 | self.B = B 296 | self.T = T 297 | 298 | # glob files that match the pattern 299 | self.files = sorted(glob.glob(filename_pattern)) 300 | assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}" 301 | 302 | # load and validate all data shards, count number of tokens in total 303 | ntok_total = 0 304 | for fname in self.files: 305 | shard_ntok = _peek_data_shard(fname) 306 | assert shard_ntok >= num_processes * B * T + 1 307 | ntok_total += int(shard_ntok) 308 | self.ntok_total = ntok_total 309 | 310 | # kick things off 311 | self.reset() 312 | 313 | def reset(self): 314 | self.current_shard = 0 315 | self.current_position = self.process_rank * self.B * self.T 316 | self.tokens = _load_data_shard(self.files[self.current_shard]) 317 | 318 | def advance(self): # advance to next data shard 319 | self.current_shard = (self.current_shard + 1) % len(self.files) 320 | self.current_position = self.process_rank * self.B * self.T 321 | self.tokens = _load_data_shard(self.files[self.current_shard]) 322 | 323 | def next_batch(self): 324 | B = self.B 325 | T = self.T 326 | buf = self.tokens[self.current_position : self.current_position+B*T+1] 327 | buf = torch.tensor(buf.astype(np.int32), dtype=torch.long) 328 | x = (buf[:-1]).view(B, T) # inputs 329 | y = (buf[1:]).view(B, T) # targets 330 | # advance current position and load next shard if necessary 331 | self.current_position += B * T * self.num_processes 332 | if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens): 333 | self.advance() 334 | return x.cuda(), y.cuda() 335 | 336 | # ----------------------------------------------------------------------------- 337 | # int main 338 | 339 | @dataclass 340 | class Hyperparameters: 341 | # data hyperparams 342 | input_bin : str = 'data/fineweb10B/fineweb_train_*.bin' # input .bin to train on 343 | input_val_bin : str = 'data/fineweb10B/fineweb_val_*.bin' # input .bin to eval validation loss on 344 | # optimization hyperparams 345 | batch_size : int = 8*64 # batch size, in sequences, across all devices 346 | device_batch_size : int = 64 # batch size, in sequences, per device 347 | sequence_length : int = 1024 # sequence length, in tokens 348 | num_iterations : int = 3242 # number of iterations to run 349 | warmup_iters : int = 0 350 | warmdown_iters : int = 
926 # number of iterations of linear warmup/warmdown for triangular or trapezoidal schedule 351 | weight_decay : float = 0 352 | # evaluation and logging hyperparams 353 | val_loss_every : int = 125 # every how many steps to evaluate val loss? 0 for only at the end 354 | val_tokens : int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons 355 | save_every : int = 0 # every how many steps to save the checkpoint? 0 for only at the end 356 | args = Hyperparameters() 357 | 358 | # set up DDP (distributed data parallel). torchrun sets this env variable 359 | assert torch.cuda.is_available() 360 | dist.init_process_group(backend='nccl') 361 | ddp_rank = int(os.environ['RANK']) 362 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 363 | ddp_world_size = int(os.environ['WORLD_SIZE']) 364 | device = f'cuda:{ddp_local_rank}' 365 | torch.cuda.set_device(device) 366 | print(f"using device: {device}") 367 | master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc. 368 | 369 | # begin logging 370 | logfile = None 371 | if master_process: 372 | run_id = str(uuid.uuid4()) 373 | logdir = 'logs/%s/' % run_id 374 | os.makedirs(logdir, exist_ok=True) 375 | logfile = 'logs/%s.txt' % run_id 376 | # create the log file 377 | with open(logfile, "w") as f: 378 | # begin the log by printing this file (the Python code) 379 | f.write('='*100 + '\n') 380 | f.write(code) 381 | f.write('='*100 + '\n') 382 | def print0(s, logonly=False): 383 | if master_process: 384 | with open(logfile, "a") as f: 385 | if not logonly: 386 | print(s) 387 | f.write(s+'\n') 388 | # log information about the hardware/software environment this is running on 389 | # and print the full `nvidia-smi` to file 390 | print0(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:") 391 | import subprocess 392 | result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) 393 | print0(f'{result.stdout}', logonly=True) 394 | print0('='*100, logonly=True) 395 | 396 | # convenience variables 397 | B, T = args.device_batch_size, args.sequence_length 398 | # calculate the number of steps to take in the val loop. 399 | assert args.val_tokens % (B * T * ddp_world_size) == 0 400 | val_steps = args.val_tokens // (B * T * ddp_world_size) 401 | # calculate the steps of gradient accumulation required to attain the desired global batch size. 402 | assert args.batch_size % (B * ddp_world_size) == 0 403 | train_accumulation_steps = args.batch_size // (B * ddp_world_size) 404 | 405 | # load tokens 406 | train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size) 407 | val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size) 408 | print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files") 409 | print0(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files") 410 | print0('='*100, logonly=True) 411 | x, y = train_loader.next_batch() 412 | 413 | # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. suggested to me by @Grad62304977. 414 | # this originates from Karpathy's experiments. 
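# (50257 rounded up to the next multiple of 128: ceil(50257 / 128) * 128 = 393 * 128 = 50304; the 47 padding ids never occur in the data, they only pad the embedding / lm_head shapes)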
415 | num_vocab = 50304 416 | model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=6, n_embd=768)) 417 | model = model.cuda().bfloat16() 418 | for m in model.modules(): 419 | if isinstance(m, CastedLinear): 420 | m.float() 421 | if hasattr(config, "coordinate_descent_tuning"): 422 | config.coordinate_descent_tuning = True # suggested by @Chillee 423 | model = torch.compile(model) 424 | # here we wrap model into DDP container 425 | model = DDP(model, device_ids=[ddp_local_rank]) 426 | raw_model = model.module # always contains the "raw" unwrapped model 427 | 428 | # CUDNN attention is ~4ms faster than Flash, but doesn't get selected by default in PyTorch 2.5.1 429 | from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp 430 | enable_cudnn_sdp(True) 431 | enable_flash_sdp(False) 432 | enable_mem_efficient_sdp(False) 433 | enable_math_sdp(False) 434 | 435 | # init the optimizer(s) 436 | optimizer1 = torch.optim.Adam([raw_model.transformer.wte.weight], lr=0.3, betas=(0.9, 0.95), fused=True) 437 | optimizer2 = torch.optim.Adam([raw_model.lm_head.weight], lr=0.002, betas=(0.9, 0.95), fused=True) 438 | params = list(raw_model.transformer.h.parameters()) 439 | matrix_params = [p for p in params if p.ndim == 2] 440 | scalar_params = [p for p in params if p.ndim < 2] 441 | optimizer3 = Muon(matrix_params, lr=0.02, momentum=0.95) 442 | optimizer4 = torch.optim.Adam(scalar_params, lr=0.02, betas=(0.9, 0.95), fused=True) # note that this learning rate is neither sensitive nor tuned 443 | optimizers = [optimizer1, optimizer2, optimizer3, optimizer4] 444 | # learning rate decay scheduler (linear warmup and warmdown) 445 | def get_lr(it): 446 | assert it <= args.num_iterations 447 | # 1) linear warmup for warmup_iters steps 448 | if it < args.warmup_iters: 449 | return (it+1) / args.warmup_iters 450 | # 2) constant lr for a while 451 | elif it < args.num_iterations - args.warmdown_iters: 452 | return 1.0 453 | # 3) linear warmdown 454 | else: 455 | decay_ratio = (args.num_iterations - it) / args.warmdown_iters 456 | return decay_ratio 457 | schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers] 458 | 459 | # Start training loop 460 | training_time_ms = 0 461 | # start the clock 462 | torch.cuda.synchronize() 463 | t0 = time.time() 464 | # begin training 465 | train_loader.reset() 466 | for step in range(args.num_iterations + 1): 467 | last_step = (step == args.num_iterations) 468 | # This effectively ignores timing first 10 steps, which are slower for weird reasons. 469 | # Alternately, and slightly more correctly in terms of benchmarking, we could do 10 470 | # steps with dummy data first, and then re-initialize the model and reset the loader. 
471 | if step == 10: 472 | training_time_ms = 0 473 | t0 = time.time() 474 | timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val 475 | 476 | # once in a while evaluate the validation dataset 477 | if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)): 478 | # stop the clock 479 | torch.cuda.synchronize() 480 | training_time_ms += 1000 * (time.time() - t0) 481 | # run validation batches 482 | model.eval() 483 | val_loader.reset() 484 | val_loss = 0.0 485 | for _ in range(val_steps): 486 | with torch.no_grad(): 487 | x_val, y_val = val_loader.next_batch() 488 | val_loss += model(x_val, y_val) 489 | dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) 490 | val_loss /= val_steps 491 | # log val loss to console and to logfile 492 | print0(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms') 493 | # start the clock again 494 | torch.cuda.synchronize() 495 | t0 = time.time() 496 | 497 | if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)): 498 | # stop the clock 499 | torch.cuda.synchronize() 500 | training_time_ms += 1000 * (time.time() - t0) 501 | # save the state of the training process 502 | log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) 503 | torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step)) 504 | # start the clock again 505 | torch.cuda.synchronize() 506 | t0 = time.time() 507 | 508 | # bit confusing: we want to make sure to eval on 0th iteration 509 | # but also after the very last iteration. so we loop for step <= num_iterations 510 | # instead of just < num_iterations (one extra due to <=), only to do 511 | # the validation/sampling one last time, and then we break right here as we're done. 512 | if last_step: 513 | break 514 | 515 | # --------------- TRAINING SECTION BEGIN ----------------- 516 | model.train() 517 | for i in range(1, train_accumulation_steps+1): 518 | # forward pass 519 | loss = model(x, y) 520 | train_loss = loss.detach() 521 | # advance the dataset for the next batch 522 | x, y = train_loader.next_batch() 523 | # backward pass 524 | if i < train_accumulation_steps: 525 | with model.no_sync(): # there's no need to sync gradients every accumulation step 526 | loss.backward() 527 | else: 528 | loss.backward() # just sync on the last step 529 | for p in model.parameters(): 530 | p.grad /= train_accumulation_steps 531 | # momentum warmup for Muon 532 | frac = min(step/500, 1) 533 | optimizer3.param_groups[0]['momentum'] = (1 - frac) * 0.85 + frac * 0.95 534 | # step the optimizers and schedulers 535 | for opt, sched in zip(optimizers, schedulers): 536 | opt.step() 537 | sched.step() 538 | # null the gradients 539 | model.zero_grad(set_to_none=True) 540 | # --------------- TRAINING SECTION END ------------------- 541 | # everything that follows now is just diagnostics, prints, logging, etc. 
542 | 543 | #dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower 544 | approx_time = training_time_ms + 1000 * (time.time() - t0) 545 | print0(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms") 546 | 547 | if master_process: 548 | print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") 549 | 550 | # ------------------------------------------------------------------------- 551 | # clean up nice 552 | dist.destroy_process_group() 553 | --------------------------------------------------------------------------------
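Muon.step above orthogonalizes each 2-D gradient through zeropower_backend, i.e. the 'newtonschulz5' backend that is defined earlier in train_gpt2.py and not shown in this excerpt. As a minimal, self-contained sketch of the same orthogonalization idea (a plain cubic Newton-Schulz iteration in float32, not the repo's tuned higher-order backend, and with an illustrative helper name), it could look like:

import torch

@torch.no_grad()
def newton_schulz_orthogonalize(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Push a 2-D gradient toward the nearest semi-orthogonal matrix (U @ V.T of its
    # reduced SVD), so the resulting update has roughly uniform singular values.
    assert G.ndim == 2
    X = G.float()
    X = X / (X.norm() + 1e-7)            # Frobenius norm <= 1 implies spectral norm <= 1
    transposed = X.size(0) > X.size(1)
    if transposed:                       # iterate on the short-and-wide orientation
        X = X.T
    for _ in range(steps):
        X = 1.5 * X - 0.5 * X @ X.T @ X  # cubic Newton-Schulz step toward orthogonality
    if transposed:
        X = X.T
    return X.type_as(G)

In Muon.step, each selected gradient matrix goes through such a routine (with steps=backend_steps) after the Nesterov momentum update and before the sqrt(rows/cols) rescaling, the all-reduce of updates_flat, and the final p.data.add_(g, alpha=-lr).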