├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── data
│   ├── cached_fineweb100B.py
│   ├── cached_fineweb10B.py
│   ├── fineweb.py
│   └── requirements.txt
├── img
│   ├── algo_optimizer.png
│   ├── dofa.jpg
│   ├── fig_optimizer.png
│   ├── fig_tuned_nanogpt.png
│   ├── nanogpt_speedrun51.png
│   ├── nanogpt_speedrun52.png
│   ├── nanogpt_speedrun53.png
│   └── nanogpt_speedrun54.png
├── records
│   ├── 060624_AdamW
│   │   ├── README.md
│   │   └── f66d43d7-e449-4029-8adf-e8537bab49ea.log
│   ├── 100924_SOAP
│   │   ├── 5bdc3988-496c-4232-b4ef-53764cb81c92.txt
│   │   ├── README.md
│   │   └── train_gpt2.py
│   ├── 101024_Muon
│   │   ├── eb5659d0-fb6a-49e5-a311-f1f89412f726.txt
│   │   └── train_gpt2.py
│   ├── 101324_llmc
│   │   ├── README.md
│   │   └── main.log
│   ├── 101424_ModernArch
│   │   ├── dabaaddd-237c-4ec9-939d-6608a9ed5e27.txt
│   │   └── train_gpt2.py
│   ├── 101724_DistributedMuon
│   │   └── 22d24867-eb5a-4fcc-ae2c-263d0277dfd1.txt
│   ├── 101824_PyTorch25
│   │   └── d4bfb25f-688d-4da5-8743-33926fad4842.txt
│   ├── 102024_ScaleUp1B
│   │   ├── 87bd51fd-6203-4c88-b3aa-8a849a6a83ca.txt
│   │   ├── ad8d7ae5-7b2d-4ee9-bc52-f912e9174d7a.txt
│   │   └── c0078066-c8c9-49c8-868a-ff4d4f32e615.txt
│   ├── 102924_Optimizers
│   │   ├── 8bfe4e35-c3fc-4b70-a984-3be937b71ff3.txt
│   │   ├── 8d6193f4-27fc-4e68-899f-af70019a4d54.txt
│   │   ├── 95a9fd44-7c13-49c7-b324-3e7d9e23a499.txt
│   │   ├── README.md
│   │   ├── e21a2838-a0f2-46f2-a247-db0021165682.txt
│   │   ├── nanogpt_speedrun81w.png
│   │   └── nanogpt_speedrun82w.png
│   ├── 110324_UntieEmbed
│   │   ├── README.md
│   │   └── d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt
│   ├── 110424_50Bruns
│   │   ├── 3d715d41-453a-40d6-9506-421ba69766b2.txt
│   │   ├── 4fbe61ec-f79a-4c19-836d-46d599deecce.txt
│   │   ├── 530f3ee1-8862-4d21-be2b-da10eb05e6a9.txt
│   │   ├── 69c33fc9-eabb-4a38-aa08-6922914eb405.txt
│   │   └── README.md
│   ├── 110624_ShortcutsTweaks
│   │   ├── 042f9e87-07e6-4504-bb04-4ec59a380211.txt
│   │   ├── 05b29e54-0be0-4a0f-a1e2-7d5317daedd3.txt
│   │   ├── 10119f53-7001-4248-bfd9-33d32427a912.txt
│   │   ├── 43f60c4f-0448-4de7-83d9-643ca26f61e7.txt
│   │   ├── 4a71cc92-0f43-4058-a033-23e85c1e98f1.txt
│   │   ├── README.md
│   │   ├── d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt
│   │   ├── dd7304a6-cc43-4d5e-adb8-c070111464a1.txt
│   │   ├── nanogpt_speedrun110.png
│   │   └── nanogpt_speedrun111.png
│   ├── 110824_CastBf16
│   │   └── a833bed8-2fa8-4cfe-af05-58c1cc48bc30.txt
│   ├── 110924_Replicateleloykun
│   │   ├── 1621af10-aa0c-42af-bf54-8a773c63a2af.txt
│   │   └── README.md
│   ├── 111024_ScaleShortcuts
│   │   ├── 3e55eb2e-6261-466a-b1e9-2b31f56fb16a.txt
│   │   ├── 4897c987-9d09-435c-a23f-20585912936a.txt
│   │   ├── 70a0ada6-8dee-4fef-8980-135379479c21.txt
│   │   ├── a833bed8-2fa8-4cfe-af05-58c1cc48bc30.txt
│   │   └── d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt
│   ├── 111024_UNetDoubleLr
│   │   ├── README.md
│   │   └── c87bb826-797b-4f37-98c7-d3a5dad2de74.txt
│   └── 111424_QuantizedFP4
│       ├── 433c1732-0c3d-4099-a4a8-ec31eae49b16.txt
│       ├── 70a0ada6-8dee-4fef-8980-135379479c21.txt
│       ├── 932bbe0e-41c3-4a5b-94bd-4ea3350909bd.txt
│       └── a833bed8-2fa8-4cfe-af05-58c1cc48bc30.txt
├── requirements.txt
├── run.sh
├── run_rwkv6.sh
├── run_rwkv7.sh
├── rwkv_cuda
│   ├── wkv6_cuda.cu
│   ├── wkv6_op.cpp
│   ├── wkv7g_op.cpp
│   └── wkv7g_v1.cu
├── rwkv_cuda_wind
│   ├── backstepping_f32.cpp
│   ├── backstepping_f32_1.cu
│   ├── backstepping_f32_2.cu
│   ├── tile.cuh
│   ├── wind_rwkv7.cpp
│   └── wind_rwkv7.cu
├── rwkv_records
│   ├── RWKV-6-2024-10-20-07-02-38.txt
│   ├── RWKV-7-2024-10-21-14-51-20.txt
│   ├── RWKV-7-fast-2024-10-29-07-48-34.txt
│   └── RWKV-7-fast-2024-11-09-19-49-33.txt
├── train_gpt2.py
├── train_rwkv6.py
└── train_rwkv7.py
/.gitignore:
--------------------------------------------------------------------------------
1 | fineweb10B/
2 | pylog124M/
3 | __pycache__/
4 | logs/
5 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:12.6.2-cudnn-devel-ubuntu24.04
2 |
3 | ENV DEBIAN_FRONTEND=noninteractive
4 | ENV PYTHON_VERSION=3.12.7
5 | ENV PATH=/usr/local/bin:$PATH
6 |
7 | RUN apt update && apt install -y --no-install-recommends build-essential libssl-dev zlib1g-dev \
8 | libbz2-dev libreadline-dev libsqlite3-dev curl git libncursesw5-dev xz-utils tk-dev libxml2-dev \
9 | libxmlsec1-dev libffi-dev liblzma-dev \
10 | && apt clean && rm -rf /var/lib/apt/lists/*
11 |
12 | RUN curl -O https://www.python.org/ftp/python/${PYTHON_VERSION}/Python-${PYTHON_VERSION}.tgz && \
13 | tar -xzf Python-${PYTHON_VERSION}.tgz && \
14 | cd Python-${PYTHON_VERSION} && \
15 | ./configure --enable-optimizations && \
16 | make -j$(nproc) && \
17 | make altinstall && \
18 | cd .. && \
19 | rm -rf Python-${PYTHON_VERSION} Python-${PYTHON_VERSION}.tgz
20 |
21 | RUN ln -s /usr/local/bin/python3.12 /usr/local/bin/python && \
22 | ln -s /usr/local/bin/pip3.12 /usr/local/bin/pip
23 |
24 | COPY requirements.txt /modded-nanogpt/requirements.txt
25 | WORKDIR /modded-nanogpt
26 |
27 | RUN python -m pip install --upgrade pip && \
28 | pip install -r requirements.txt
29 |
30 | CMD ["bash"]
31 | ENTRYPOINT []
32 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Keller Jordan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Modded-NanoGPT-RWKV
2 |
3 | RWKV Discord: https://discord.gg/bDSBUMeFpc
4 |
5 | RWKV Twitter: https://twitter.com/BlinkDL_AI
6 |
7 | ## RWKV-6 and RWKV-7
8 |
9 | ### Latest run: 3200 steps to reach 3.27xx loss
10 |
11 | This uses the latest (current) train_rwkv7.py:
12 | ```
13 | ./run_rwkv7.sh --adam_lr 0.0026 --muon_lr 0.02 --ln_lr 0.0090 --headsz 64 --bsz 512 --device_bsz 32 --fast_cuda
14 | ```
15 |
16 | ### Old run: 5100 steps to reach 3.27xx loss
17 |
18 | This uses the old train_rwkv7.py.
19 |
20 | Please read https://x.com/BlinkDL_AI/status/1848343821467390156 first.
21 |
22 | Modded-GPT 123.6M headsize 128 => val_loss 3.27xx
23 |
24 | RWKV-7 123.7M headsize 64 => val_loss 3.2715 (increase headsize to reach 3.26xx)
25 |
26 | RWKV-6 123.7M headsize 64 => val_loss 3.2914
27 |
28 | RWKV-6 123.7M headsize 192 => val_loss 3.28xx
29 |
30 | Check https://github.com/BlinkDL/modded-nanogpt-rwkv/tree/master/rwkv_records for the training logs.
31 |
32 | Try 0.0020/0.0022/0.0024 for adam_lr. Try 1.5/2/2.5 for emb_scale. Reduce device_bsz if you hit OOM (the trainer will gradient-accumulate to keep the total batch size).
33 | ```
34 | Note: Currently inefficient implementation. Please help if you are a Pytorch / CUDA / triton master :)
35 |
36 | ./run_rwkv7.sh --adam_lr 0.0022 --emb_scale 2 --muon_lr 0.00036 --headsz 64 --bsz 512 --device_bsz 64 --fast_cuda (much faster cuda)
37 |
38 | ./run_rwkv7.sh --adam_lr 0.0022 --emb_scale 2 --muon_lr 0.00036 --headsz 64 --bsz 512 --device_bsz 32 (reference, takes more VRAM, have to reduce device_bsz)
39 |
40 | ./run_rwkv7.sh --adam_lr 0.0022 --emb_scale 2 --muon_lr 0.00036 --headsz 64 --bsz 512 --device_bsz 64 --wind_cuda (even faster cuda, likely worse loss)
41 |
42 | ./run_rwkv6.sh --adam_lr 0.0020 --emb_scale 1.5 --muon_lr 0.00036 --headsz 64 --bsz 512 --device_bsz 64
43 | ```
44 |
45 | ## Original Readme
46 |
47 | This is a modified variant of the [PyTorch GPT-2 trainer](https://github.com/karpathy/llm.c/blob/7b929300217ff1a974b63791a228928b39b26409/train_gpt2.py) from
48 | Andrej Karpathy's [llm.c](https://github.com/karpathy/llm.c) repo, which attains the same final validation loss in:
49 | * 1.7B tokens instead of 10B
50 | * 7.8 minutes on 8xH100 instead of 45
51 |
52 | It uses the following techniques:
53 | * Modernized architecture: Rotary embeddings, QK-Norm, and ReLU^2.
54 | * New optimizer: Muon - Momentum Orthogonalized by Newton-schulz.
55 | * Untied head from embedding.
56 | * Projection and classification layers initialized to zero (muP-like).
57 | * Architectural shortcuts: value residual and embedding shortcut (partially following https://arxiv.org/abs/2410.17897).
58 | * Momentum warmup.
59 | * Tanh soft logit capping (following Gemma 2).
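
To make two of these concrete, here is a minimal illustrative sketch of a ReLU^2 MLP block and tanh soft logit capping. This is not the record code itself (see the scripts under `records/` for the exact shapes and constants used); the capping constant below is only an assumed example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ReLUSquaredMLP(nn.Module):
    # feed-forward block using the ReLU^2 activation in place of GELU
    def __init__(self, dim: int):
        super().__init__()
        self.c_fc = nn.Linear(dim, 4 * dim, bias=False)
        self.c_proj = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x).square()  # ReLU^2
        return self.c_proj(x)

def soft_cap_logits(logits, cap: float = 30.0):
    # tanh soft capping squashes logits smoothly into (-cap, cap); cap=30 is an assumption
    return cap * torch.tanh(logits / cap)
```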
60 |
61 | ---
62 |
63 | ## Running the training
64 |
65 | To execute the training, run the following three commands.
66 | They should all complete within <20 min on an 8xH100 with a decent internet connection.
67 | ```bash
68 | pip install -r requirements.txt
69 | python data/cached_fineweb10B.py 18 # downloads only the first 1.8B training tokens to save time
70 | ./run.sh
71 | ```
72 |
73 | The result will be a transformer with 124M active parameters trained for 3242 steps on 1.7B tokens of Fineweb [1], achieving ~3.278 validation loss.
74 | For comparison, the default llm.c PyTorch trainer yields [>3.28 validation loss after training for 19560 steps on 10B tokens](https://github.com/karpathy/llm.c/discussions/481#:~:text=By%20the%20end%20of%20the%20optimization%20we%27ll%20get%20to%20about%203.29).
75 |
76 | ## Running it on fewer GPUs or with less memory
77 |
78 | * To run on fewer GPUs, just modify `run.sh` to have a different `--nproc_per_node`.
79 | * If you're running out of memory, then go into `train_gpt2.py` and scale down the `device_batch_size` to either 16 or 32.
80 |
81 | Both of these changes will have no effect on the training - you should get the exact same loss curve as the most recent record, because the training code
82 | will automatically adjust the gradient accumulation in order to have the same total batch size.
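
Concretely, the trainer derives the number of gradient-accumulation steps from the fixed total batch size. A minimal sketch of that arithmetic, mirroring the logic in `train_gpt2.py` (variable names loosely follow that script):

```python
batch_size = 512         # total sequences per optimizer step, across all devices (fixed)
device_batch_size = 32   # sequences per GPU per micro-step (lower this if you run out of memory)
world_size = 8           # number of GPUs, i.e. --nproc_per_node

# the total batch must divide evenly into per-device micro-batches
assert batch_size % (device_batch_size * world_size) == 0
train_accumulation_steps = batch_size // (device_batch_size * world_size)
# e.g. 8 GPUs at device_batch_size=64 -> 1 micro-step; at 32 -> 2; 4 GPUs at 32 -> 4
```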
83 |
84 | ## Running with Docker
85 |
86 | For cases where CUDA or NCCL versions aren't compatible with your current system setup, Docker can be a helpful alternative.
87 | This approach standardizes versions for CUDA, NCCL, CUDNN, and Python, reducing dependency issues and simplifying setup.
88 | Note: an NVIDIA driver must already be installed on the system (useful if only the NVIDIA driver and Docker are available).
89 |
90 | ```bash
91 | sudo docker build -t modded-nanogpt .
92 | sudo docker run -it --rm --gpus all -v $(pwd):/modded-nanogpt modded-nanogpt python data/cached_fineweb10B.py 18
93 | sudo docker run -it --rm --gpus all -v $(pwd):/modded-nanogpt modded-nanogpt sh run.sh
94 | ```
95 | ---
96 |
97 | ## World record history
98 |
99 | The following is the progression of world records for the task of *training a model with 124M active parameters to 3.28 validation loss on FineWeb in the minimal amount of time on an 8xH100 machine.*
100 |
101 | 1. [45 minutes: llm.c baseline](https://github.com/karpathy/llm.c/discussions/481) (05/28/24) [[training log](./records/101324_llmc/main.log)] (note: the 90 minute time is on 8xA100; it's 45 minutes on 8xH100. This run is essentially a hardware-optimized GPT-2 (small) replication using better training data.)
102 | 2. [31.4 minutes: Architectural modernizations and learning rate tuning](https://x.com/kellerjordan0/status/1798863559243513937) (06/06/24) [[training log](./records/060624_AdamW/f66d43d7-e449-4029-8adf-e8537bab49ea.log)]
103 | 3. [24.9 minutes: Introduced the Muon optimizer](https://x.com/kellerjordan0/status/1842300916864844014) (10/04/24)
104 | 4. [22.3 minutes: Muon improvements](https://x.com/kellerjordan0/status/1844820919061287009) (10/11/24) [[reproducible log](./records/101024_Muon/eb5659d0-fb6a-49e5-a311-f1f89412f726.txt)]
105 | 5. [15.2 minutes: Pad embeddings & architectural modernizations](https://x.com/kellerjordan0/status/1845865698532450646) (10/14/24) [[reproducible log](./records/101424_ModernArch/dabaaddd-237c-4ec9-939d-6608a9ed5e27.txt)]
106 | 6. [13.1 minutes: Distributed the overhead of Muon](https://x.com/kellerjordan0/status/1847291684016783746) (10/18/24) [[reproducible log](./records/101724_DistributedMuon/22d24867-eb5a-4fcc-ae2c-263d0277dfd1.txt)]
107 | 7. [12.0 minutes: Upgraded PyTorch from 2.4.1 to 2.5.0](https://x.com/kellerjordan0/status/1847358578686152764) (10/18/24) [[reproducible log](./records/101824_PyTorch25/d4bfb25f-688d-4da5-8743-33926fad4842.txt)]
108 | 8. [10.8 minutes: Untied embed and lm_head](https://x.com/kellerjordan0/status/1853188916704387239) (11/03/24) [[reproducible log](./records/110324_UntieEmbed/d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt)]
109 | 9. [8.2 minutes: Shortcuts & tweaks](https://x.com/kellerjordan0/status/1854296101303800108) (11/06/24) [[reproducible log](./records/110624_ShortcutsTweaks/dd7304a6-cc43-4d5e-adb8-c070111464a1.txt)]
110 | 11. [7.8 minutes: Bfloat16 activations](https://x.com/kellerjordan0/status/1855267054774865980) (11/08/24) [[reproducible log](./records/110824_CastBf16/a833bed8-2fa8-4cfe-af05-58c1cc48bc30.txt)]
111 | 12. [7.23 minutes: U-net & 2x lr](https://x.com/kellerjordan0/status/1856053121103093922) (11/10/24) [[reproducible log](./records/111024_UNetDoubleLr/c87bb826-797b-4f37-98c7-d3a5dad2de74.txt)]
112 |
113 | Please see the X threads for the contributors to each record.
114 |
115 | The `train_gpt2.py` in this repo is the 11/08/24 record. To run the latest 11/10/24 record, use the code in its reproducible log.
116 |
117 |
121 |
124 |
125 |
131 |
132 | ### Notable attempts
133 |
134 | 1. [An 11/07/24 attempt, which I attempted to certify on 11/09/24](./records/110924_Replicateleloykun)
135 |
136 | ### Notable forks
137 |
138 | * [https://github.com/BlinkDL/modded-nanogpt-rwkv](https://github.com/BlinkDL/modded-nanogpt-rwkv)
139 | * [https://github.com/nikhilvyas/modded-nanogpt-SOAP](https://github.com/nikhilvyas/modded-nanogpt-SOAP)
140 |
141 | ### Speedrun rules
142 |
143 | 1. Must not modify the train or validation data pipelines (except to change batch size if you want).
144 | 2. Must use ≤ 124M active parameters per token.
145 | 3. Must attain ≤ 3.28 val loss. A tasteful number would be 3.278 so that [this doesn't happen](./records/110924_Replicateleloykun/1621af10-aa0c-42af-bf54-8a773c63a2af.txt#L3780).
146 |
147 | Other than that, go crazy! Anything is fair game.
148 |
149 | ---
150 |
151 | ### Q: What is the point of NanoGPT speedrunning?
152 |
153 | A: The officially stated goal of NanoGPT speedrunning is as follows: `gotta go fast`. But for something a little more verbose involving an argument for good benchmarking, here's some kind of manifesto, adorned with a blessing from the master. [https://x.com/karpathy/status/1846790537262571739](https://x.com/karpathy/status/1846790537262571739)
154 |
155 | ### Q: What makes "NanoGPT speedrunning" not just another idiosyncratic benchmark?
156 |
157 | A: Because it is a *competitive* benchmark. In particular, if you attain a new speed record (using whatever method you want), there is an open invitation for you
158 | to post that record (on arXiv or X) and thereby vacuum up all the clout for yourself. I will even help you do it by reposting you as much as I can.
159 |
160 |
167 |
168 | ["Artificial intelligence advances by inventing games and gloating to goad others to play" - Professor Ben Recht](https://www.argmin.net/p/too-much-information)
169 |
170 | ### Q: NanoGPT speedrunning is cool and all, but meh it probably won't scale and is just overfitting to val loss
171 |
172 | A: This is hard to refute, since "at scale" is an infinite category (what if the methods stop working only for >100T models?), making it impossible to fully prove.
173 | Also, I would agree that some of the methods used in the speedrun are unlikely to scale.
174 | But if the reader cares about 1.5B models, they might be convinced by this result:
175 |
176 | *Straightforwardly scaling up the speedrun (10/18/24 version) to 1.5B parameters yields a model with GPT-2 (1.5B)-level HellaSwag performance 2.5x more cheaply than [@karpathy's baseline](https://github.com/karpathy/llm.c/discussions/677) ($233 instead of $576):*
177 |
178 | 
179 | [[reproducible log](https://github.com/KellerJordan/modded-nanogpt/blob/master/records/102024_ScaleUp1B/ad8d7ae5-7b2d-4ee9-bc52-f912e9174d7a.txt)]
180 | 
181 |
182 | ---
183 |
184 | ## [Muon optimizer](https://github.com/KellerJordan/Muon)
185 |
186 | Muon is defined as follows:
187 |
188 | 
189 |
190 | Where NewtonSchulz5 is the following Newton-Schulz iteration [2, 3], which approximately replaces `G` with `U @ V.T` where `U, S, V = G.svd()`.
191 | ```python
192 | @torch.compile
193 | def zeroth_power_via_newtonschulz5(G, steps=5, eps=1e-7):
194 |     assert len(G.shape) == 2
195 |     a, b, c = (3.4445, -4.7750, 2.0315)
196 |     X = G.bfloat16() / (G.norm() + eps)
197 |     if G.size(0) > G.size(1):
198 |         X = X.T
199 |     for _ in range(steps):
200 |         A = X @ X.T
201 |         B = b * A + c * A @ A
202 |         X = a * X + B @ X
203 |     if G.size(0) > G.size(1):
204 |         X = X.T
205 |     return X.to(G.dtype)
206 | ```
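
For context, here is a condensed sketch of how this orthogonalization is used inside a Muon update step. It is simplified from the full `Muon` optimizer class in `train_gpt2.py` (which additionally splits grouped QKV weights and supports an SVD backend), so treat it as illustrative rather than the exact implementation:

```python
def muon_step(param, grad, buf, lr, momentum=0.95, nesterov=True, ns_steps=5):
    # buf is this parameter's persistent momentum buffer (a tensor shaped like grad)
    buf.mul_(momentum).add_(grad)                           # standard SGD-momentum accumulation
    g = grad.add(buf, alpha=momentum) if nesterov else buf  # Nesterov-style lookahead
    g = zeroth_power_via_newtonschulz5(g, steps=ns_steps)   # orthogonalize the update
    scale = max(g.size(0), g.size(1)) ** 0.5                # so that update.square().mean() is ~1
    param.data.add_(g, alpha=-lr * scale)
```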
207 |
208 | For this training scenario, Muon has the following favorable properties:
209 | * Lower memory usage than Adam
210 | * ~1.5x better sample-efficiency
211 | * <2% wallclock overhead
212 |
213 |
214 | ### Provenance
215 |
216 | Many of the choices that went into this optimizer were obtained experimentally through our pursuit of [CIFAR-10 speedrunning](https://github.com/KellerJordan/cifar10-airbench).
217 | In particular, we experimentally obtained the following practices:
218 | * Using Nesterov momentum inside the update, with orthogonalization applied after momentum.
219 | * Using a specifically quintic Newton-Schulz iteration as the method of orthogonalization.
220 | * Using non-convergent coefficients for the quintic polynomial in order to maximize slope at zero, and thereby minimize the number of necessary Newton-Schulz iterations.
221 | It turns out that the resulting variance in the singular values doesn't actually matter much, so we end up with a quintic whose iterates (rapidly) converge to the range [0.68, 1.13] upon repeated application, rather than to 1.
222 | * Running the Newton-Schulz iteration in bfloat16 (whereas Shampoo implementations often depend on inverse-pth-roots run in fp32 or fp64).
223 |
224 | Our use of a Newton-Schulz iteration for orthogonalization traces to [Bernstein & Newhouse (2024)](https://arxiv.org/abs/2409.20325),
225 | who suggested it as a way to compute Shampoo [5, 6] preconditioners, and theoretically explored Shampoo without preconditioner accumulation.
226 | In particular, Jeremy Bernstein @jxbz sent us the draft, which caused us to experiment with various Newton-Schulz iterations as the
227 | orthogonalization method for this optimizer.
228 | If we had used SVD instead of a Newton-Schulz iteration, this optimizer would have been too slow to be useful.
229 | Bernstein & Newhouse also pointed out that Shampoo without preconditioner accumulation is equivalent to steepest descent in the spectral norm,
230 | and therefore Shampoo can be thought of as a way to smooth out spectral steepest descent.
231 | The proposed optimizer can be thought of as a second way of smoothing spectral steepest descent, with a different set of memory and runtime tradeoffs
232 | compared to Shampoo.
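
A compact way to state that equivalence (my summary of the claim above, not the paper's notation): writing the gradient's SVD as $G = U \Sigma V^\top$, the orthogonalized momentum maximizes the first-order loss decrease among all updates of unit spectral norm,

$$
U V^\top \;=\; \operatorname*{arg\,max}_{\|O\|_{\mathrm{op}} \le 1} \; \langle G, O \rangle ,
$$

so replacing the momentum with its orthogonalization is exactly a steepest-descent step measured in the spectral norm (assuming $G$ has full rank, so the maximizer is unique).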
233 |
234 | ---
235 |
236 | ## Startup script
237 |
238 | Here's a good startup script for a fresh 8xH100 instance.
239 |
240 | ```
241 | sudo apt-get update
242 | sudo apt-get install vim tmux python3-pip python-is-python3 -y
243 | git clone https://github.com/KellerJordan/modded-nanogpt.git
244 | cd modded-nanogpt
245 | tmux
246 |
247 | pip install numpy==1.23.5 huggingface-hub tqdm
248 | pip install --upgrade torch &
249 | python data/cached_fineweb10B.py 18
250 | ```
251 |
252 | ---
253 |
254 | ## References
255 |
256 | 1. [Penedo, Guilherme, et al. "The fineweb datasets: Decanting the web for the finest text data at scale." arXiv preprint arXiv:2406.17557 (2024).](https://arxiv.org/abs/2406.17557)
257 | 2. Nicholas J. Higham. Functions of Matrices. Society for Industrial and Applied Mathematics, 2008. Equation 5.22.
258 | 3. Günther Schulz. Iterative Berechnung der reziproken Matrix. Z. Angew. Math. Mech., 13:57–59, 1933.
259 | 4. [Jeremy Bernstein and Laker Newhouse. "Old Optimizer, New Norm: An Anthology." arxiv preprint arXiv:2409.20325 (2024).](https://arxiv.org/abs/2409.20325)
260 | 5. [Vineet Gupta, Tomer Koren, and Yoram Singer. "Shampoo: Preconditioned stochastic tensor optimization." International Conference on Machine Learning. PMLR, 2018.](https://arxiv.org/abs/1802.09568)
261 | 6. [Anil, Rohan, et al. "Scalable second order optimization for deep learning." arXiv preprint arXiv:2002.09018 (2020).](https://arxiv.org/abs/2002.09018)
262 | 7. [Hägele, Alexander, et al. "Scaling Laws and Compute-Optimal Training Beyond Fixed Training Durations." arXiv preprint arXiv:2405.18392 (2024).](https://arxiv.org/abs/2405.18392)
263 |
264 | [](https://www.youtube.com/watch?v=dv13gl0a-FA)
265 |
266 |
267 |
268 |
--------------------------------------------------------------------------------
/data/cached_fineweb100B.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from huggingface_hub import hf_hub_download
4 | # Download the GPT-2 tokens of Fineweb100B from huggingface. This
5 | # saves about an hour of startup time compared to regenerating them.
6 | def get(fname):
7 |     local_dir = os.path.join(os.path.dirname(__file__), 'fineweb100B')
8 |     if not os.path.exists(os.path.join(local_dir, fname)):
9 |         hf_hub_download(repo_id="kjj0/fineweb100B-gpt2", filename=fname,
10 |                         repo_type="dataset", local_dir=local_dir)
11 | get("fineweb_val_%06d.bin" % 0)
12 | num_chunks = 1030 # full fineweb100B. Each chunk is 100M tokens
13 | if len(sys.argv) >= 2: # we can pass an argument to download less
14 |     num_chunks = int(sys.argv[1])
15 | for i in range(1, num_chunks+1):
16 |     get("fineweb_train_%06d.bin" % i)
17 |
--------------------------------------------------------------------------------
/data/cached_fineweb10B.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from huggingface_hub import hf_hub_download
4 | # Download the GPT-2 tokens of Fineweb10B from huggingface. This
5 | # saves about an hour of startup time compared to regenerating them.
6 | def get(fname):
7 |     local_dir = os.path.join(os.path.dirname(__file__), 'fineweb10B')
8 |     if not os.path.exists(os.path.join(local_dir, fname)):
9 |         hf_hub_download(repo_id="kjj0/fineweb10B-gpt2", filename=fname,
10 |                         repo_type="dataset", local_dir=local_dir)
11 | get("fineweb_val_%06d.bin" % 0)
12 | num_chunks = 103 # full fineweb10B. Each chunk is 100M tokens
13 | if len(sys.argv) >= 2: # we can pass an argument to download less
14 |     num_chunks = int(sys.argv[1])
15 | for i in range(1, num_chunks+1):
16 |     get("fineweb_train_%06d.bin" % i)
17 |
--------------------------------------------------------------------------------
/data/fineweb.py:
--------------------------------------------------------------------------------
1 | """
2 | FineWeb dataset (for srs pretraining)
3 | https://huggingface.co/datasets/HuggingFaceFW/fineweb
4 |
5 | example doc to highlight the structure of the dataset:
6 | {
7 | "text": "Posted by mattsmith on 20th April 2012\nStraight from...",
8 | "id": "",
9 | "dump": "CC-MAIN-2013-20",
10 | "url": "http://nleastchatter.com/philliesphandom/tag/freddy-galvis/",
11 | "date": "2013-05-18T07:24:47Z",
12 | "file_path": "s3://commoncrawl/long.../path.../file.gz",
13 | "language": "en",
14 | "language_score": 0.9185474514961243,
15 | "token_count": 594
16 | }
17 | """
18 | import os
19 | import argparse
20 | import multiprocessing as mp
21 | import numpy as np
22 | import tiktoken
23 | # from huggingface_hub import snapshot_download
24 | from datasets import load_dataset
25 | from tqdm import tqdm
26 | import argparse
27 | import numpy as np
28 | def write_datafile(filename, toks):
29 | """
30 | Saves token data as a .bin file, for reading in C.
31 | - First comes a header with 256 int32s
32 | - The tokens follow, each as a uint16
33 | """
34 | assert len(toks) < 2**31, "token count too large" # ~2.1B tokens
35 | # construct the header
36 | header = np.zeros(256, dtype=np.int32)
37 | header[0] = 20240520 # magic
38 | header[1] = 1 # version
39 | header[2] = len(toks) # number of tokens after the 256*4 bytes of header (each 2 bytes as uint16)
40 | # construct the tokens numpy array, if not already
41 | if not isinstance(toks, np.ndarray) or not toks.dtype == np.uint16:
42 | # validate that no token exceeds a uint16
43 | maxtok = 2**16
44 | assert all(0 <= t < maxtok for t in toks), "token dictionary too large for uint16"
45 | toks_np = np.array(toks, dtype=np.uint16)
46 | else:
47 | toks_np = toks
48 | # write to file
49 | print(f"writing {len(toks):,} tokens to {filename}")
50 | with open(filename, "wb") as f:
51 | f.write(header.tobytes())
52 | f.write(toks_np.tobytes())
53 | # ------------------------------------------
54 |
55 | parser = argparse.ArgumentParser(description="FineWeb dataset preprocessing")
56 | parser.add_argument("-v", "--version", type=str, default="10B", help="Which version of fineweb to use 10B|100B")
57 | parser.add_argument("-s", "--shard_size", type=int, default=10**8, help="Size of each shard in tokens")
58 | args = parser.parse_args()
59 |
60 | # FineWeb has a few possible subsamples available
61 | assert args.version in ["10B", "100B"], "version must be one of 10B, 100B"
62 | if args.version == "10B":
63 | local_dir = "fineweb10B"
64 | remote_name = "sample-10BT"
65 | elif args.version == "100B":
66 | local_dir = "fineweb100B"
67 | remote_name = "sample-100BT"
68 |
69 | # create the cache the local directory if it doesn't exist yet
70 | DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), local_dir)
71 | os.makedirs(DATA_CACHE_DIR, exist_ok=True)
72 |
73 | # download the dataset
74 | fw = load_dataset("HuggingFaceFW/fineweb", name=remote_name, split="train")
75 |
76 | # init the tokenizer
77 | enc = tiktoken.get_encoding("gpt2")
78 | eot = enc._special_tokens['<|endoftext|>'] # end of text token
79 | def tokenize(doc):
80 |     # tokenizes a single document and returns a numpy array of uint16 tokens
81 |     tokens = [eot] # the special <|endoftext|> token delimits all documents
82 |     tokens.extend(enc.encode_ordinary(doc["text"]))
83 |     tokens_np = np.array(tokens)
84 |     assert (0 <= tokens_np).all() and (tokens_np < 2**16).all(), "token dictionary too large for uint16"
85 |     tokens_np_uint16 = tokens_np.astype(np.uint16)
86 |     return tokens_np_uint16
87 |
88 | # tokenize all documents and write output shards, each of shard_size tokens (last shard has remainder)
89 | nprocs = max(1, os.cpu_count() - 2) # don't hog the entire system
90 | with mp.Pool(nprocs) as pool:
91 |     shard_index = 0
92 |     # preallocate buffer to hold current shard
93 |     all_tokens_np = np.empty((args.shard_size,), dtype=np.uint16)
94 |     token_count = 0
95 |     progress_bar = None
96 |     for tokens in pool.imap(tokenize, fw, chunksize=16):
97 |
98 |         # is there enough space in the current shard for the new tokens?
99 |         if token_count + len(tokens) < args.shard_size:
100 |             # simply append tokens to current shard
101 |             all_tokens_np[token_count:token_count+len(tokens)] = tokens
102 |             token_count += len(tokens)
103 |             # update progress bar
104 |             if progress_bar is None:
105 |                 progress_bar = tqdm(total=args.shard_size, unit="tokens", desc=f"Shard {shard_index}")
106 |             progress_bar.update(len(tokens))
107 |         else:
108 |             # write the current shard and start a new one
109 |             split = "val" if shard_index == 0 else "train"
110 |             filename = os.path.join(DATA_CACHE_DIR, f"fineweb_{split}_{shard_index:06d}.bin")
111 |             # split the document into whatever fits in this shard; the remainder goes to next one
112 |             remainder = args.shard_size - token_count
113 |             progress_bar.update(remainder)
114 |             all_tokens_np[token_count:token_count+remainder] = tokens[:remainder]
115 |             write_datafile(filename, all_tokens_np)
116 |             shard_index += 1
117 |             progress_bar = None
118 |             # populate the next shard with the leftovers of the current doc
119 |             all_tokens_np[0:len(tokens)-remainder] = tokens[remainder:]
120 |             token_count = len(tokens)-remainder
121 |
122 |     # write any remaining tokens as the last shard
123 |     if token_count != 0:
124 |         split = "val" if shard_index == 0 else "train"
125 |         filename = os.path.join(DATA_CACHE_DIR, f"fineweb_{split}_{shard_index:06d}.bin")
126 |         write_datafile(filename, all_tokens_np[:token_count])
127 |
--------------------------------------------------------------------------------
/data/requirements.txt:
--------------------------------------------------------------------------------
1 | datasets
2 | tiktoken
3 |
--------------------------------------------------------------------------------
/img/algo_optimizer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/img/algo_optimizer.png
--------------------------------------------------------------------------------
/img/dofa.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/img/dofa.jpg
--------------------------------------------------------------------------------
/img/fig_optimizer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/img/fig_optimizer.png
--------------------------------------------------------------------------------
/img/fig_tuned_nanogpt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/img/fig_tuned_nanogpt.png
--------------------------------------------------------------------------------
/img/nanogpt_speedrun51.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/img/nanogpt_speedrun51.png
--------------------------------------------------------------------------------
/img/nanogpt_speedrun52.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/img/nanogpt_speedrun52.png
--------------------------------------------------------------------------------
/img/nanogpt_speedrun53.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/img/nanogpt_speedrun53.png
--------------------------------------------------------------------------------
/img/nanogpt_speedrun54.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/img/nanogpt_speedrun54.png
--------------------------------------------------------------------------------
/records/060624_AdamW/README.md:
--------------------------------------------------------------------------------
1 | This is the log for my baseline AdamW training run, against which I compared the new Muon and SOAP optimizers.
2 |
3 | It contains just the log, which is in the old llm.c format ("tel" lines report the val loss).
4 |
5 | This run used a batch size of 2^19 tokens, so ~5B tokens in total.
6 |
7 | The learning rate was 0.0018, with warmup=250, warmdown=2000, and betas=(0.9, 0.95), IIRC.
8 |
9 |
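For reference, a hedged sketch of what this baseline configuration looks like in code. The hyperparameters are taken from the notes above (which are themselves "IIRC"), the schedule mirrors the trapezoidal `get_lr` used elsewhere in this repo, and the weight decay is an assumption since it isn't stated:

```python
import torch

def make_adamw_baseline(model, num_iterations, warmup_iters=250, warmdown_iters=2000):
    # lr/betas per the notes above; weight_decay=0.0 is assumed, not documented
    opt = torch.optim.AdamW(model.parameters(), lr=0.0018, betas=(0.9, 0.95), weight_decay=0.0)

    def get_lr_frac(it):
        # linear warmup -> constant -> linear warmdown (trapezoidal schedule)
        if it < warmup_iters:
            return (it + 1) / warmup_iters
        if it < num_iterations - warmdown_iters:
            return 1.0
        return (num_iterations - it) / warmdown_iters

    sched = torch.optim.lr_scheduler.LambdaLR(opt, get_lr_frac)
    return opt, sched
```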
--------------------------------------------------------------------------------
/records/100924_SOAP/README.md:
--------------------------------------------------------------------------------
1 | # SOAP record October 9 2024
2 |
3 | * New sample efficiency record: <3.28 validation loss in 3.15B tokens
4 | * Uses SOAP optimizer ([Vyas et al. 2024](https://arxiv.org/abs/2409.11321))
5 | * 363ms/step - not a new wallclock record (SOAP is in active development to reduce the wallclock overhead for distributed training, so this may change)
6 | * Set by Nikhil Vyas @vyasnikhil96. Hyperparameters also tuned slightly by me
7 | * [https://x.com/vyasnikhil96/status/1842656792217858063](https://x.com/vyasnikhil96/status/1842656792217858063)
8 | * [https://github.com/nikhilvyas/modded-nanogpt-SOAP/tree/master](https://github.com/nikhilvyas/modded-nanogpt-SOAP/tree/master)
9 |
10 |
--------------------------------------------------------------------------------
/records/101024_Muon/train_gpt2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | with open(sys.argv[0]) as f:
4 | code = f.read() # read the code of this file ASAP, for logging
5 | import uuid
6 | import glob
7 | import time
8 | from dataclasses import dataclass
9 |
10 | import numpy as np
11 | import torch
12 | from torch import nn
13 | import torch.nn.functional as F
14 | import torch.distributed as dist
15 | import torch._inductor.config as config
16 | from torch.nn.parallel import DistributedDataParallel as DDP
17 |
18 | # -----------------------------------------------------------------------------
19 | # Muon optimizer
20 |
21 | def zeropower_via_svd(G, steps=None):
22 | U, S, V = G.svd()
23 | return U @ V.T
24 |
25 | @torch.compile
26 | def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
27 | """
28 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
29 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
30 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
31 | zero even beyond the point where the iteration no longer converges all the way to one everywhere
32 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
33 | where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model
34 | performance at all relative to UV^T, where USV^T = G is the SVD.
35 | """
36 | assert len(G.shape) == 2
37 | a, b, c = (3.4445, -4.7750, 2.0315)
38 | X = G.bfloat16() / (G.norm() + eps) # ensure top singular value <= 1
39 | if G.size(0) > G.size(1):
40 | X = X.T
41 | for _ in range(steps):
42 | A = X @ X.T
43 | B = A @ X
44 | X = a * X + b * B + c * A @ B
45 | if G.size(0) > G.size(1):
46 | X = X.T
47 | return X.to(G.dtype)
48 |
49 | zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5)
50 |
51 | class Muon(torch.optim.Optimizer):
52 | """
53 | Muon: MomentUm Orthogonalized by Newton-schulz
54 |
55 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
56 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
57 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
58 | the advantage that it can be stably run in bfloat16 on the GPU.
59 |
60 | Some warnings:
61 | - This optimizer assumes that all parameters passed in are 2D.
62 | - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
63 | parameters; those should all be optimized by a standard method (e.g., AdamW).
64 | - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
65 | - We believe it is unlikely to work well for training with small batch size.
66 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
67 | - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).
68 |
69 | Arguments:
70 | lr: The learning rate used by the internal SGD.
71 | momentum: The momentum used by the internal SGD.
72 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
73 | backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5')
74 | backend_steps: The number of iteration steps to use in the backend, if it is iterative.
75 | """
76 | def __init__(self, params, lr=3e-4, momentum=0.95, nesterov=True, backend='newtonschulz5', backend_steps=5):
77 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps)
78 | super().__init__(params, defaults)
79 |
80 | def step(self):
81 | for group in self.param_groups:
82 | lr = group['lr']
83 | momentum = group['momentum']
84 | zeropower_backend = zeropower_backends[group['backend']]
85 | for p in group['params']:
86 | g = p.grad
87 | if g is None:
88 | continue
89 | state = self.state[p]
90 | if 'momentum_buffer' not in state:
91 | state['momentum_buffer'] = torch.zeros_like(g)
92 | buf = state['momentum_buffer']
93 | buf.mul_(momentum).add_(g)
94 | if group['nesterov']:
95 | g = g.add(buf, alpha=momentum)
96 | if g.size(0) == 3 * g.size(1): # split grouped QKV parameters
97 | g = torch.cat([zeropower_backend(g1, steps=group['backend_steps']) for g1 in g.split(g.size(1))])
98 | scale = g.size(1)**0.5
99 | else:
100 | g = zeropower_backend(g, steps=group['backend_steps'])
101 | scale = max(g.size(0), g.size(1))**0.5 # scale to have update.square().mean() == 1
102 | p.data.add_(g, alpha=-lr * scale)
103 |
104 | # -----------------------------------------------------------------------------
105 | # PyTorch nn.Module definitions for the GPT-2 model
106 |
107 | class Rotary(torch.nn.Module):
108 |
109 | def __init__(self, dim, base=10000):
110 | super().__init__()
111 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
112 | self.register_buffer("inv_freq", inv_freq)
113 | self.seq_len_cached = None
114 | self.cos_cached = None
115 | self.sin_cached = None
116 |
117 | def forward(self, x):
118 | seq_len = x.shape[1]
119 | if seq_len != self.seq_len_cached:
120 | self.seq_len_cached = seq_len
121 | t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
122 | freqs = torch.outer(t, self.inv_freq).to(x.device)
123 | self.cos_cached = freqs.cos()
124 | self.sin_cached = freqs.sin()
125 | return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
126 |
127 | def apply_rotary_emb(x, cos, sin):
128 | assert x.ndim == 4 # multihead attention
129 | d = x.shape[3]//2
130 | x1 = x[..., :d]
131 | x2 = x[..., d:]
132 | y1 = x1 * cos + x2 * sin
133 | y2 = x1 * (-sin) + x2 * cos
134 | return torch.cat([y1, y2], 3)
135 |
136 | def rmsnorm(x0, eps=1e-6):
137 | x = x0.float()
138 | x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
139 | return x.type_as(x0)
140 |
141 | class CausalSelfAttention(nn.Module):
142 |
143 | def __init__(self, config):
144 | super().__init__()
145 | self.n_head = config.n_head
146 | self.n_embd = config.n_embd
147 | self.head_dim = self.n_embd // self.n_head
148 | assert self.n_embd % self.n_head == 0
149 | # key, query, value projections for all heads, but in a batch
150 | self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
151 | # output projection
152 | self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
153 | self.rotary = Rotary(self.head_dim)
154 |
155 | def forward(self, x):
156 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
157 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
158 | qkv = self.c_attn(x)
159 | q, k, v = qkv.split(self.n_embd, dim=2)
160 | k = k.view(B, T, self.n_head, self.head_dim)
161 | q = q.view(B, T, self.n_head, self.head_dim)
162 | v = v.view(B, T, self.n_head, self.head_dim)
163 | cos, sin = self.rotary(q)
164 | q = apply_rotary_emb(q, cos, sin)
165 | k = apply_rotary_emb(k, cos, sin)
166 | y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
167 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
168 | # output projection
169 | y = self.c_proj(y)
170 | return y
171 |
172 | class MLP(nn.Module):
173 |
174 | def __init__(self, config):
175 | super().__init__()
176 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
177 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
178 |
179 | def forward(self, x):
180 | x = self.c_fc(x)
181 | x = F.gelu(x)
182 | x = self.c_proj(x)
183 | return x
184 |
185 | class Block(nn.Module):
186 |
187 | def __init__(self, config):
188 | super().__init__()
189 | self.attn = CausalSelfAttention(config)
190 | self.mlp = MLP(config)
191 | self.attn_scale = (1 / (2 * config.n_layer)**0.5)
192 |
193 | def forward(self, x):
194 | x = x + self.attn_scale * self.attn(rmsnorm(x))
195 | x = x + self.mlp(rmsnorm(x))
196 | return x
197 |
198 | # -----------------------------------------------------------------------------
199 | # The main GPT-2 model
200 |
201 | @dataclass
202 | class GPTConfig:
203 | vocab_size : int = 50257
204 | n_layer : int = 12
205 | n_head : int = 12
206 | n_embd : int = 768
207 |
208 | class GPT(nn.Module):
209 |
210 | def __init__(self, config):
211 | super().__init__()
212 | self.config = config
213 |
214 | self.transformer = nn.ModuleDict(dict(
215 | wte = nn.Embedding(config.vocab_size, config.n_embd),
216 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
217 | ))
218 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
219 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
220 |
221 | def forward(self, idx, targets=None, return_logits=True):
222 | b, t = idx.size()
223 | pos = torch.arange(0, t, dtype=torch.long, device=idx.device) # shape (t)
224 |
225 | # forward the GPT model itself
226 | x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
227 |
228 | for block in self.transformer.h:
229 | x = block(x)
230 | x = rmsnorm(x)
231 |
232 | if targets is not None:
233 | # if we are given some desired targets also calculate the loss
234 | logits = self.lm_head(x)
235 | logits = logits.float() # use tf32/fp32 for logits
236 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
237 | else:
238 | # inference-time mini-optimization: only forward the lm_head on the very last position
239 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
240 | logits = logits.float() # use tf32/fp32 for logits
241 | loss = None
242 |
243 | # there are performance reasons why not returning logits is prudent, if not needed
244 | if not return_logits:
245 | logits = None
246 |
247 | return logits, loss
248 |
249 | # -----------------------------------------------------------------------------
250 | # Our own simple Distributed Data Loader
251 |
252 | def _peek_data_shard(filename):
253 | # only reads the header, returns header data
254 | with open(filename, "rb") as f:
255 | # first read the header, which is 256 int32 integers (4 bytes each)
256 | header = np.frombuffer(f.read(256*4), dtype=np.int32)
257 | if header[0] != 20240520:
258 | print("ERROR: magic number mismatch in the data .bin file!")
259 | print("---> HINT: Are you passing in a correct file with --input_bin?")
260 | print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README")
261 | print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try")
262 | exit(1)
263 | assert header[1] == 1, "unsupported version"
264 | ntok = header[2] # number of tokens (claimed)
265 | return ntok # for now just return the number of tokens
266 |
267 | def _load_data_shard(filename):
268 | with open(filename, "rb") as f:
269 | # first read the header, which is 256 int32 integers (4 bytes each)
270 | header = np.frombuffer(f.read(256*4), dtype=np.int32)
271 | assert header[0] == 20240520, "magic number mismatch in the data .bin file"
272 | assert header[1] == 1, "unsupported version"
273 | ntok = header[2] # number of tokens (claimed)
274 | # the rest of it are tokens, stored as uint16
275 | tokens = np.frombuffer(f.read(), dtype=np.uint16)
276 | assert len(tokens) == ntok, "number of tokens read does not match header?"
277 | return tokens
278 |
279 | class DistributedDataLoader:
280 | def __init__(self, filename_pattern, B, T, process_rank, num_processes):
281 | self.process_rank = process_rank
282 | self.num_processes = num_processes
283 | self.B = B
284 | self.T = T
285 |
286 | # glob files that match the pattern
287 | self.files = sorted(glob.glob(filename_pattern))
288 | assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"
289 |
290 | # load and validate all data shards, count number of tokens in total
291 | ntok_total = 0
292 | for fname in self.files:
293 | shard_ntok = _peek_data_shard(fname)
294 | assert shard_ntok >= num_processes * B * T + 1
295 | ntok_total += int(shard_ntok)
296 | self.ntok_total = ntok_total
297 |
298 | # kick things off
299 | self.reset()
300 |
301 | def reset(self):
302 | self.current_shard = 0
303 | self.current_position = self.process_rank * self.B * self.T
304 | self.tokens = _load_data_shard(self.files[self.current_shard])
305 |
306 | def advance(self): # advance to next data shard
307 | self.current_shard = (self.current_shard + 1) % len(self.files)
308 | self.current_position = self.process_rank * self.B * self.T
309 | self.tokens = _load_data_shard(self.files[self.current_shard])
310 |
311 | def next_batch(self):
312 | B = self.B
313 | T = self.T
314 | buf = self.tokens[self.current_position : self.current_position+B*T+1]
315 | buf = torch.tensor(buf.astype(np.int32), dtype=torch.long)
316 | x = (buf[:-1]).view(B, T) # inputs
317 | y = (buf[1:]).view(B, T) # targets
318 | # advance current position and load next shard if necessary
319 | self.current_position += B * T * self.num_processes
320 | if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
321 | self.advance()
322 | return x.cuda(), y.cuda()
323 |
324 | # -----------------------------------------------------------------------------
325 | # int main
326 |
327 | @dataclass
328 | class Hyperparameters:
329 | # data hyperparams
330 | input_bin : str = 'data/fineweb10B/fineweb_train_*.bin' # input .bin to train on
331 | input_val_bin : str = 'data/fineweb10B/fineweb_val_*.bin' # input .bin to eval validation loss on
332 | # optimization hyperparams
333 | batch_size : int = 8*64 # batch size, in sequences, across all devices
334 | device_batch_size : int = 64 # batch size, in sequences, per device
335 | sequence_length : int = 1024 # sequence length, in tokens
336 | num_iterations : int = 6200 # number of iterations to run
337 | learning_rate : float = 0.0036
338 | warmup_iters : int = 0
339 | warmdown_iters : int = 1800 # number of iterations of linear warmup/warmdown for triangular or trapezoidal schedule
340 | weight_decay : float = 0
341 | # evaluation and logging hyperparams
342 | val_loss_every : int = 125 # every how many steps to evaluate val loss? 0 for only at the end
343 | val_tokens : int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
344 | save_every : int = 0 # every how many steps to save the checkpoint? 0 for only at the end
345 | args = Hyperparameters()
346 |
347 | # set up DDP (distributed data parallel). torchrun sets this env variable
348 | assert torch.cuda.is_available()
349 | dist.init_process_group(backend='nccl')
350 | ddp_rank = int(os.environ['RANK'])
351 | ddp_local_rank = int(os.environ['LOCAL_RANK'])
352 | ddp_world_size = int(os.environ['WORLD_SIZE'])
353 | device = f'cuda:{ddp_local_rank}'
354 | torch.cuda.set_device(device)
355 | print(f"using device: {device}")
356 | master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.
357 |
358 | # convenience variables
359 | B, T = args.device_batch_size, args.sequence_length
360 | # calculate the number of steps to take in the val loop.
361 | assert args.val_tokens % (B * T * ddp_world_size) == 0
362 | val_steps = args.val_tokens // (B * T * ddp_world_size)
363 | # calculate the steps of gradient accumulation required to attain the desired global batch size.
364 | assert args.batch_size % (B * ddp_world_size) == 0
365 | train_accumulation_steps = args.batch_size // (B * ddp_world_size)
366 |
367 | # load tokens
368 | train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
369 | val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)
370 | if master_process:
371 | print(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
372 | print(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
373 | x, y = train_loader.next_batch()
374 |
375 | # init the model from scratch
376 | num_vocab = 50257
377 | model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=12, n_embd=768))
378 | model = model.cuda()
379 | if hasattr(config, "coordinate_descent_tuning"):
380 | config.coordinate_descent_tuning = True # suggested by @Chillee
381 | model = torch.compile(model)
382 | # here we wrap model into DDP container
383 | model = DDP(model, device_ids=[ddp_local_rank])
384 | raw_model = model.module # always contains the "raw" unwrapped model
385 | ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)
386 |
387 | # init the optimizer(s)
388 | optimizer1 = torch.optim.AdamW(raw_model.lm_head.parameters(), lr=args.learning_rate, betas=(0.9, 0.95),
389 | weight_decay=args.weight_decay, fused=True)
390 | optimizer2 = Muon(raw_model.transformer.h.parameters(), lr=0.1*args.learning_rate, momentum=0.95)
391 | optimizers = [optimizer1, optimizer2]
392 | # learning rate decay scheduler (linear warmup and warmdown)
393 | def get_lr(it):
394 | assert it <= args.num_iterations
395 | # 1) linear warmup for warmup_iters steps
396 | if it < args.warmup_iters:
397 | return (it+1) / args.warmup_iters
398 | # 2) constant lr for a while
399 | elif it < args.num_iterations - args.warmdown_iters:
400 | return 1.0
401 | # 3) linear warmdown
402 | else:
403 | decay_ratio = (args.num_iterations - it) / args.warmdown_iters
404 | return decay_ratio
405 | schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]
406 |
407 | # begin logging
408 | if master_process:
409 | run_id = str(uuid.uuid4())
410 | logdir = 'logs/%s/' % run_id
411 | os.makedirs(logdir, exist_ok=True)
412 | logfile = 'logs/%s.txt' % run_id
413 | # create the log file
414 | with open(logfile, "w") as f:
415 | # begin the log by printing this file (the Python code)
416 | f.write('='*100 + '\n')
417 | f.write(code)
418 | f.write('='*100 + '\n')
419 | # log information about the hardware/software environment this is running on
420 | # and print the full `nvidia-smi` to file
421 | f.write(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:\n")
422 | import subprocess
423 | result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
424 | f.write(f'{result.stdout}\n')
425 | f.write('='*100 + '\n')
426 |
427 | training_time_ms = 0
428 | # start the clock
429 | torch.cuda.synchronize()
430 | t0 = time.time()
431 | # begin training
432 | train_loader.reset()
433 | for step in range(args.num_iterations + 1):
434 | last_step = (step == args.num_iterations)
435 | # This effectively ignores timing first 10 steps, which are slower for weird reasons.
436 | # Alternately, and slightly more correctly in terms of benchmarking, we could do 10
437 | # steps with dummy data first, and then re-initialize the model and reset the loader.
438 | if step == 10:
439 | training_time_ms = 0
440 | t0 = time.time()
441 | timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val
442 |
443 | # once in a while evaluate the validation dataset
444 | if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)):
445 | # stop the clock
446 | torch.cuda.synchronize()
447 | training_time_ms += 1000 * (time.time() - t0)
448 | # run validation batches
449 | model.eval()
450 | val_loader.reset()
451 | val_loss = 0.0
452 | for _ in range(val_steps):
453 | x_val, y_val = val_loader.next_batch()
454 | with torch.no_grad(): # of course, we'd like to use ctx here too, but that creates a torch.compile error for some reason
455 | _, loss = model(x_val, y_val, return_logits=False)
456 | val_loss += loss
457 | dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
458 | val_loss /= val_steps
459 | # log val loss to console and to logfile
460 | if master_process:
461 | print(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms')
462 | with open(logfile, "a") as f:
463 | f.write(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms\n')
464 | # start the clock again
465 | torch.cuda.synchronize()
466 | t0 = time.time()
467 |
468 | if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)):
469 | # stop the clock
470 | torch.cuda.synchronize()
471 | training_time_ms += 1000 * (time.time() - t0)
472 | # save the state of the training process
473 | log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
474 | torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step))
475 | # start the clock again
476 | torch.cuda.synchronize()
477 | t0 = time.time()
478 |
479 | # bit confusing: we want to make sure to eval on 0th iteration
480 | # but also after the very last iteration. so we loop for step <= num_iterations
481 | # instead of just < num_iterations (one extra due to <=), only to do
482 | # the validation/sampling one last time, and then we break right here as we're done.
483 | if last_step:
484 | break
485 |
486 | # --------------- TRAINING SECTION BEGIN -----------------
487 | model.train()
488 | for i in range(1, train_accumulation_steps+1):
489 | # forward pass
490 | with ctx:
491 | _, loss = model(x, y, return_logits=False)
492 | train_loss = loss.detach()
493 | # advance the dataset for the next batch
494 | x, y = train_loader.next_batch()
495 | # backward pass
496 | if i < train_accumulation_steps:
497 | with model.no_sync(): # there's no need to sync gradients every accumulation step
498 | loss.backward()
499 | else:
500 | loss.backward() # just sync on the last step
501 | for p in model.parameters():
502 | p.grad /= train_accumulation_steps
503 | # step the optimizers and schedulers
504 | for opt, sched in zip(optimizers, schedulers):
505 | opt.step()
506 | sched.step()
507 | # null the gradients
508 | model.zero_grad(set_to_none=True)
509 | # --------------- TRAINING SECTION END -------------------
510 | # everything that follows now is just diagnostics, prints, logging, etc.
511 |
512 | #dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower
513 | if master_process:
514 | approx_time = training_time_ms + 1000 * (time.time() - t0)
515 | print(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms")
516 | with open(logfile, "a") as f:
517 | f.write(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms\n")
518 |
519 | if master_process:
520 | print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
521 |
522 | # -------------------------------------------------------------------------
523 | # clean up nice
524 | dist.destroy_process_group()
525 |
--------------------------------------------------------------------------------
/records/101324_llmc/README.md:
--------------------------------------------------------------------------------
1 | This is a log produced by running the current version of Andrej Karpathy's [llm.c](https://github.com/karpathy/llm.c), as of October 13th 2024.
2 |
3 | It was run on a node with 8x H100 HBM3 according to the instructions [here](https://github.com/karpathy/llm.c/discussions/481).
4 | The mean per-step time was 140ms. The total number of training tokens is 10.26B. The final validation loss was **3.2722**.
5 |
6 | This is (significantly) better than the quoted result of **3.29** val loss in
7 | [Andrej Karpathy's May 28th GPT-2 replication discussion](https://github.com/karpathy/llm.c/discussions/481#:~:text=By%20the%20end%20of%20the%20optimization%20we%27ll%20get%20to%20about%203.29).
8 | So it appears that there have been some improvements to the training algorithm used by llm.c since then.
9 |
10 | Note that the set of examples which llm.c uses for validation appears to be the same as what we do in this repo, i.e., the first `10 * 2**20` tokens of the val set.
11 |
12 |
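13 | For reference, that validation size corresponds to the `val_tokens` hyperparameter in this repo's `train_gpt2.py`:
14 | 
15 | ```
16 | val_tokens = 10 * 2**20  # = 10,485,760 tokens of the FineWeb val set
17 | ```
18 | 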
--------------------------------------------------------------------------------
/records/101424_ModernArch/train_gpt2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | with open(sys.argv[0]) as f:
4 | code = f.read() # read the code of this file ASAP, for logging
5 | import uuid
6 | import glob
7 | import time
8 | from dataclasses import dataclass
9 |
10 | import numpy as np
11 | import torch
12 | from torch import nn
13 | import torch.nn.functional as F
14 | import torch.distributed as dist
15 | import torch._inductor.config as config
16 | from torch.nn.parallel import DistributedDataParallel as DDP
17 |
18 | # -----------------------------------------------------------------------------
19 | # Muon optimizer
20 |
21 | def zeropower_via_svd(G, steps=None):
22 | U, S, V = G.svd()
23 | return U @ V.T
24 |
25 | @torch.compile
26 | def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
27 | """
28 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
29 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
30 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
31 | zero even beyond the point where the iteration no longer converges all the way to one everywhere
32 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
33 | where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model
34 | performance at all relative to UV^T, where USV^T = G is the SVD.
35 | """
36 | assert len(G.shape) == 2
37 | a, b, c = (3.4445, -4.7750, 2.0315)
38 | X = G.bfloat16()
39 | X /= (X.norm() + eps) # ensure top singular value <= 1
40 | if G.size(0) > G.size(1):
41 | X = X.T
42 | for _ in range(steps):
43 | A = X @ X.T
44 | B = A @ X
45 | X = a * X + b * B + c * A @ B
46 | if G.size(0) > G.size(1):
47 | X = X.T
48 | return X
49 |
50 | zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5)
51 |
52 | class Muon(torch.optim.Optimizer):
53 | """
54 | Muon - MomentUm Orthogonalized by Newton-schulz
55 |
56 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
57 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
58 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
59 | the advantage that it can be stably run in bfloat16 on the GPU.
60 |
61 | Some warnings:
62 | - This optimizer assumes that all parameters passed in are 2D.
63 | - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
64 | parameters; those should all be optimized by a standard method (e.g., AdamW).
65 | - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
66 | - We believe it is unlikely to work well for training with small batch size.
67 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
68 | - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).
69 |
70 | Arguments:
71 | lr: The learning rate used by the internal SGD.
72 | momentum: The momentum used by the internal SGD.
73 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
74 | backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5')
75 | backend_steps: The number of iteration steps to use in the backend, if it is iterative.
76 | """
77 | def __init__(self, params, lr=3e-4, momentum=0.95, nesterov=True, backend='newtonschulz5', backend_steps=5):
78 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps)
79 | super().__init__(params, defaults)
80 |
81 | def step(self):
82 | for group in self.param_groups:
83 | lr = group['lr']
84 | momentum = group['momentum']
85 | zeropower_backend = zeropower_backends[group['backend']]
86 | for p in group['params']:
87 | g = p.grad
88 | if g is None:
89 | continue
90 | state = self.state[p]
91 | if 'momentum_buffer' not in state:
92 | state['momentum_buffer'] = torch.zeros_like(g)
93 | buf = state['momentum_buffer']
94 | buf.mul_(momentum).add_(g)
95 | if group['nesterov']:
96 | g = g.add(buf, alpha=momentum)
97 | if g.size(0) == 3 * g.size(1): # split grouped QKV parameters
98 | g = torch.cat([zeropower_backend(g1, steps=group['backend_steps']) for g1 in g.split(g.size(1))])
99 | scale = g.size(1)**0.5
100 | else:
101 | g = zeropower_backend(g, steps=group['backend_steps'])
102 | scale = max(g.size(0), g.size(1))**0.5 # scale to have update.square().mean() == 1
103 | p.data.add_(g, alpha=-lr * scale)
104 |
105 | # -----------------------------------------------------------------------------
106 | # PyTorch nn.Module definitions for the GPT-2 model
107 |
108 | class Rotary(torch.nn.Module):
109 |
110 | def __init__(self, dim, base=10000):
111 | super().__init__()
112 | self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
113 | self.seq_len_cached = None
114 | self.cos_cached = None
115 | self.sin_cached = None
116 |
117 | def forward(self, x):
118 | seq_len = x.shape[1]
119 | if seq_len != self.seq_len_cached:
120 | self.seq_len_cached = seq_len
121 | t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
122 | freqs = torch.outer(t, self.inv_freq).to(x.device)
123 | self.cos_cached = freqs.cos().bfloat16()
124 | self.sin_cached = freqs.sin().bfloat16()
125 | return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
126 |
127 | def apply_rotary_emb(x, cos, sin):
128 | assert x.ndim == 4 # multihead attention
129 | d = x.shape[3]//2
130 | x1 = x[..., :d]
131 | x2 = x[..., d:]
132 | y1 = x1 * cos + x2 * sin
133 | y2 = x1 * (-sin) + x2 * cos
134 | return torch.cat([y1, y2], 3).type_as(x)
135 |
136 | class CausalSelfAttention(nn.Module):
137 |
138 | def __init__(self, config):
139 | super().__init__()
140 | self.n_head = config.n_head
141 | self.n_embd = config.n_embd
142 | self.head_dim = self.n_embd // self.n_head
143 | assert self.n_embd % self.n_head == 0
144 | self.c_q = nn.Linear(self.n_embd, self.n_embd, bias=False)
145 | self.c_k = nn.Linear(self.n_embd, self.n_embd, bias=False)
146 | self.c_v = nn.Linear(self.n_embd, self.n_embd, bias=False)
147 | # output projection
148 | self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
149 | self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977
150 | self.rotary = Rotary(self.head_dim)
151 |
152 | def forward(self, x):
153 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
154 | q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
155 | k = self.c_k(x).view(B, T, self.n_head, self.head_dim)
156 | v = self.c_v(x).view(B, T, self.n_head, self.head_dim)
157 | cos, sin = self.rotary(q)
158 | q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
159 | q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) # QK norm suggested by @Grad62304977
160 | y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
161 | y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side
162 | y = self.c_proj(y)
163 | return y
164 |
165 | class MLP(nn.Module):
166 |
167 | def __init__(self, config):
168 | super().__init__()
169 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
170 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
171 | self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977
172 |
173 | def forward(self, x):
174 | x = self.c_fc(x)
175 | x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
176 | x = self.c_proj(x)
177 | return x
178 |
179 | class Block(nn.Module):
180 |
181 | def __init__(self, config):
182 | super().__init__()
183 | self.attn = CausalSelfAttention(config)
184 | self.mlp = MLP(config)
185 |
186 | def forward(self, x):
187 | x = x + self.attn(F.rms_norm(x, (x.size(-1),)))
188 | x = x + self.mlp(F.rms_norm(x, (x.size(-1),)))
189 | return x
190 |
191 | # -----------------------------------------------------------------------------
192 | # The main GPT-2 model
193 |
194 | @dataclass
195 | class GPTConfig:
196 | vocab_size : int = 50304
197 | n_layer : int = 12
198 | n_head : int = 6 # head dim 128 suggested by @Grad62304977
199 | n_embd : int = 768
200 |
201 | class GPT(nn.Module):
202 |
203 | def __init__(self, config):
204 | super().__init__()
205 | self.config = config
206 |
207 | self.transformer = nn.ModuleDict(dict(
208 | wte = nn.Embedding(config.vocab_size, config.n_embd),
209 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
210 | ))
211 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
212 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
213 |
214 | def forward(self, idx, targets=None, return_logits=True):
215 |
216 | # forward the GPT model itself
217 | x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
218 | for block in self.transformer.h:
219 | x = block(x)
220 | x = F.rms_norm(x, (x.size(-1),))
221 |
222 | if targets is not None:
223 | # if we are given some desired targets also calculate the loss
224 | logits = self.lm_head(x)
225 | logits = logits.float() # use tf32/fp32 for logits
226 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
227 | else:
228 | # inference-time mini-optimization: only forward the lm_head on the very last position
229 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
230 | logits = logits.float() # use tf32/fp32 for logits
231 | loss = None
232 |
233 | # there are performance reasons why not returning logits is prudent, if not needed
234 | if not return_logits:
235 | logits = None
236 |
237 | return logits, loss
238 |
239 | # -----------------------------------------------------------------------------
240 | # Our own simple Distributed Data Loader
241 |
242 | def _peek_data_shard(filename):
243 | # only reads the header, returns header data
244 | with open(filename, "rb") as f:
245 | # first read the header, which is 256 int32 integers (4 bytes each)
246 | header = np.frombuffer(f.read(256*4), dtype=np.int32)
247 | if header[0] != 20240520:
248 | print("ERROR: magic number mismatch in the data .bin file!")
249 | print("---> HINT: Are you passing in a correct file with --input_bin?")
250 | print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README")
251 | print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try")
252 | exit(1)
253 | assert header[1] == 1, "unsupported version"
254 | ntok = header[2] # number of tokens (claimed)
255 | return ntok # for now just return the number of tokens
256 |
257 | def _load_data_shard(filename):
258 | with open(filename, "rb") as f:
259 | # first read the header, which is 256 int32 integers (4 bytes each)
260 | header = np.frombuffer(f.read(256*4), dtype=np.int32)
261 | assert header[0] == 20240520, "magic number mismatch in the data .bin file"
262 | assert header[1] == 1, "unsupported version"
263 | ntok = header[2] # number of tokens (claimed)
264 | # the rest of it are tokens, stored as uint16
265 | tokens = np.frombuffer(f.read(), dtype=np.uint16)
266 | assert len(tokens) == ntok, "number of tokens read does not match header?"
267 | return tokens
268 |
269 | class DistributedDataLoader:
270 | def __init__(self, filename_pattern, B, T, process_rank, num_processes):
271 | self.process_rank = process_rank
272 | self.num_processes = num_processes
273 | self.B = B
274 | self.T = T
275 |
276 | # glob files that match the pattern
277 | self.files = sorted(glob.glob(filename_pattern))
278 | assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"
279 |
280 | # load and validate all data shards, count number of tokens in total
281 | ntok_total = 0
282 | for fname in self.files:
283 | shard_ntok = _peek_data_shard(fname)
284 | assert shard_ntok >= num_processes * B * T + 1
285 | ntok_total += int(shard_ntok)
286 | self.ntok_total = ntok_total
287 |
288 | # kick things off
289 | self.reset()
290 |
291 | def reset(self):
292 | self.current_shard = 0
293 | self.current_position = self.process_rank * self.B * self.T
294 | self.tokens = _load_data_shard(self.files[self.current_shard])
295 |
296 | def advance(self): # advance to next data shard
297 | self.current_shard = (self.current_shard + 1) % len(self.files)
298 | self.current_position = self.process_rank * self.B * self.T
299 | self.tokens = _load_data_shard(self.files[self.current_shard])
300 |
301 | def next_batch(self):
302 | B = self.B
303 | T = self.T
304 | buf = self.tokens[self.current_position : self.current_position+B*T+1]
305 | buf = torch.tensor(buf.astype(np.int32), dtype=torch.long)
306 | x = (buf[:-1]).view(B, T) # inputs
307 | y = (buf[1:]).view(B, T) # targets
308 | # advance current position and load next shard if necessary
309 | self.current_position += B * T * self.num_processes
310 | if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
311 | self.advance()
312 | return x.cuda(), y.cuda()
313 |
314 | # -----------------------------------------------------------------------------
315 | # int main
316 |
317 | @dataclass
318 | class Hyperparameters:
319 | # data hyperparams
320 | input_bin : str = 'data/fineweb10B/fineweb_train_*.bin' # input .bin to train on
321 | input_val_bin : str = 'data/fineweb10B/fineweb_val_*.bin' # input .bin to eval validation loss on
322 | # optimization hyperparams
323 | batch_size : int = 8*64 # batch size, in sequences, across all devices
324 | device_batch_size : int = 64 # batch size, in sequences, per device
325 | sequence_length : int = 1024 # sequence length, in tokens
326 | num_iterations : int = 5100 # number of iterations to run
327 | learning_rate : float = 0.0036
328 | warmup_iters : int = 0
329 | warmdown_iters : int = 1450 # number of iterations of linear warmup/warmdown for triangular or trapezoidal schedule
330 | weight_decay : float = 0
331 | # evaluation and logging hyperparams
332 | val_loss_every : int = 125 # every how many steps to evaluate val loss? 0 for only at the end
333 | val_tokens : int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
334 | save_every : int = 0 # every how many steps to save the checkpoint? 0 for only at the end
335 | args = Hyperparameters()
336 |
337 | # set up DDP (distributed data parallel). torchrun sets this env variable
338 | assert torch.cuda.is_available()
339 | dist.init_process_group(backend='nccl')
340 | ddp_rank = int(os.environ['RANK'])
341 | ddp_local_rank = int(os.environ['LOCAL_RANK'])
342 | ddp_world_size = int(os.environ['WORLD_SIZE'])
343 | device = f'cuda:{ddp_local_rank}'
344 | torch.cuda.set_device(device)
345 | print(f"using device: {device}")
346 | master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.
347 |
348 | # convenience variables
349 | B, T = args.device_batch_size, args.sequence_length
350 | # calculate the number of steps to take in the val loop.
351 | assert args.val_tokens % (B * T * ddp_world_size) == 0
352 | val_steps = args.val_tokens // (B * T * ddp_world_size)
353 | # calculate the steps of gradient accumulation required to attain the desired global batch size.
354 | assert args.batch_size % (B * ddp_world_size) == 0
355 | train_accumulation_steps = args.batch_size // (B * ddp_world_size)
356 |
357 | # load tokens
358 | train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
359 | val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)
360 | if master_process:
361 | print(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
362 | print(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
363 | x, y = train_loader.next_batch()
364 |
365 | # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. suggested to me by @Grad62304977.
366 | # this originates from Karpathy's experiments.
367 | num_vocab = 50304
368 | model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=6, n_embd=768))
369 | model = model.cuda()
370 | if hasattr(config, "coordinate_descent_tuning"):
371 | config.coordinate_descent_tuning = True # suggested by @Chillee
372 | model = torch.compile(model)
373 | # here we wrap model into DDP container
374 | model = DDP(model, device_ids=[ddp_local_rank])
375 | raw_model = model.module # always contains the "raw" unwrapped model
376 | ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)
377 |
378 | # init the optimizer(s)
379 | optimizer1 = torch.optim.AdamW(raw_model.lm_head.parameters(), lr=args.learning_rate, betas=(0.9, 0.95),
380 | weight_decay=args.weight_decay, fused=True)
381 | optimizer2 = Muon(raw_model.transformer.h.parameters(), lr=0.1*args.learning_rate, momentum=0.95)
382 | optimizers = [optimizer1, optimizer2]
383 | # learning rate decay scheduler (linear warmup and warmdown)
384 | def get_lr(it):
385 | assert it <= args.num_iterations
386 | # 1) linear warmup for warmup_iters steps
387 | if it < args.warmup_iters:
388 | return (it+1) / args.warmup_iters
389 | # 2) constant lr for a while
390 | elif it < args.num_iterations - args.warmdown_iters:
391 | return 1.0
392 | # 3) linear warmdown
393 | else:
394 | decay_ratio = (args.num_iterations - it) / args.warmdown_iters
395 | return decay_ratio
396 | schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]
397 |
398 | # begin logging
399 | if master_process:
400 | run_id = str(uuid.uuid4())
401 | logdir = 'logs/%s/' % run_id
402 | os.makedirs(logdir, exist_ok=True)
403 | logfile = 'logs/%s.txt' % run_id
404 | # create the log file
405 | with open(logfile, "w") as f:
406 | # begin the log by printing this file (the Python code)
407 | f.write('='*100 + '\n')
408 | f.write(code)
409 | f.write('='*100 + '\n')
410 | # log information about the hardware/software environment this is running on
411 | # and print the full `nvidia-smi` to file
412 | f.write(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:\n")
413 | import subprocess
414 | result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
415 | f.write(f'{result.stdout}\n')
416 | f.write('='*100 + '\n')
417 |
418 | training_time_ms = 0
419 | # start the clock
420 | torch.cuda.synchronize()
421 | t0 = time.time()
422 | # begin training
423 | train_loader.reset()
424 | for step in range(args.num_iterations + 1):
425 | last_step = (step == args.num_iterations)
426 | # This effectively ignores timing first 10 steps, which are slower for weird reasons.
427 | # Alternately, and slightly more correctly in terms of benchmarking, we could do 10
428 | # steps with dummy data first, and then re-initialize the model and reset the loader.
429 | if step == 10:
430 | training_time_ms = 0
431 | t0 = time.time()
432 | timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val
433 |
434 | # once in a while evaluate the validation dataset
435 | if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)):
436 | # stop the clock
437 | torch.cuda.synchronize()
438 | training_time_ms += 1000 * (time.time() - t0)
439 | # run validation batches
440 | model.eval()
441 | val_loader.reset()
442 | val_loss = 0.0
443 | for _ in range(val_steps):
444 | x_val, y_val = val_loader.next_batch()
445 | with ctx: # of course, we'd like to use no_grad() here too, but that creates a torch.compile error for some reason
446 | _, loss = model(x_val, y_val, return_logits=False)
447 | val_loss += loss.detach()
448 | del loss
449 | dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
450 | val_loss /= val_steps
451 | # log val loss to console and to logfile
452 | if master_process:
453 | print(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms')
454 | with open(logfile, "a") as f:
455 | f.write(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms\n')
456 | # start the clock again
457 | torch.cuda.synchronize()
458 | t0 = time.time()
459 |
460 | if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)):
461 | # stop the clock
462 | torch.cuda.synchronize()
463 | training_time_ms += 1000 * (time.time() - t0)
464 | # save the state of the training process
465 | log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
466 | torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step))
467 | # start the clock again
468 | torch.cuda.synchronize()
469 | t0 = time.time()
470 |
471 | # bit confusing: we want to make sure to eval on 0th iteration
472 | # but also after the very last iteration. so we loop for step <= num_iterations
473 | # instead of just < num_iterations (one extra due to <=), only to do
474 | # the validation/sampling one last time, and then we break right here as we're done.
475 | if last_step:
476 | break
477 |
478 | # --------------- TRAINING SECTION BEGIN -----------------
479 | model.train()
480 | for i in range(1, train_accumulation_steps+1):
481 | # forward pass
482 | with ctx:
483 | _, loss = model(x, y, return_logits=False)
484 | train_loss = loss.detach()
485 | # advance the dataset for the next batch
486 | x, y = train_loader.next_batch()
487 | # backward pass
488 | if i < train_accumulation_steps:
489 | with model.no_sync(): # there's no need to sync gradients every accumulation step
490 | loss.backward()
491 | else:
492 | loss.backward() # just sync on the last step
493 | for p in model.parameters():
494 | p.grad /= train_accumulation_steps
495 | # step the optimizers and schedulers
496 | for opt, sched in zip(optimizers, schedulers):
497 | opt.step()
498 | sched.step()
499 | # null the gradients
500 | model.zero_grad(set_to_none=True)
501 | # --------------- TRAINING SECTION END -------------------
502 | # everything that follows now is just diagnostics, prints, logging, etc.
503 |
504 | #dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower
505 | if master_process:
506 | approx_time = training_time_ms + 1000 * (time.time() - t0)
507 | print(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms")
508 | with open(logfile, "a") as f:
509 | f.write(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms\n")
510 |
511 | if master_process:
512 | print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
513 |
514 | # -------------------------------------------------------------------------
515 | # clean up nice
516 | dist.destroy_process_group()
517 |
--------------------------------------------------------------------------------
/records/102924_Optimizers/README.md:
--------------------------------------------------------------------------------
1 | # Optimizer comparison for NanoGPT speedrunning
2 |
3 | This is a comparison between the four best optimizers I am aware of for NanoGPT speedrunning. They are compared using the 10/18/24 NanoGPT speedrunning record.
4 |
5 | Reproducible logs:
6 | * [Adam](95a9fd44-7c13-49c7-b324-3e7d9e23a499.txt)
7 | * [DistributedShampoo](8bfe4e35-c3fc-4b70-a984-3be937b71ff3.txt)
8 | * [SOAP](e21a2838-a0f2-46f2-a247-db0021165682.txt)
9 | * [Muon](8d6193f4-27fc-4e68-899f-af70019a4d54.txt)
10 |
11 | Results:
12 | 
13 | 
14 |
15 | ### General notes for all optimizers
16 |
17 | All optimizers are run using zero weight decay (which is found to be empirically optimal).
18 |
19 | They are all run with a warmup-stable-decay / trapezoidal schedule (sketched at the end of this README), which also seems to be optimal. That schedule is what causes the kink in the loss curves ~75% of the way to the end.
20 |
21 | In addition, in all cases, we optimize the shared embedding/head layer just using Adam (which is also found to be empirically optimal).
22 | Note that in the following code snippets, `raw_model.transformer.h.parameters()` gives all parameters besides those two.
23 |
24 | In each case, the hyperparameters are the best ones I could find in around 20 attempts.
25 |
26 | ## [Adam](95a9fd44-7c13-49c7-b324-3e7d9e23a499.txt)
27 | The optimizer here is equivalent to:
28 | ```
29 | torch.optim.Adam(raw_model.transformer.h.parameters(), lr=0.0018, betas=(0.9, 0.95))
30 | ```
31 |
32 |
33 | ## [DistributedShampoo](8bfe4e35-c3fc-4b70-a984-3be937b71ff3.txt)
34 | Run as follows:
35 | ```
36 | DistributedShampoo(
37 | raw_model.transformer.h.parameters(),
38 | lr=0.0018,
39 | betas=(0.95, 0.95),
40 | epsilon=1e-12,
41 | weight_decay=0,
42 | max_preconditioner_dim=8192,
43 | precondition_frequency=10,
44 | use_decoupled_weight_decay=True,
45 | grafting_config=AdamGraftingConfig(
46 | beta2=0.95,
47 | epsilon=1e-8,
48 | ),
49 | distributed_config=DDPShampooConfig(
50 | communication_dtype=CommunicationDType.FP32,
51 | num_trainers_per_group=8,
52 | communicate_params=False,
53 | ),
54 | )
55 | ```
56 |
57 | This is using the official `DistributedShampoo` implementation from [here](https://github.com/facebookresearch/optimizers/tree/ad2809a291c01859f68fcabbcb49a2aa75fd7827/distributed_shampoo).
58 |
59 | Things that turned out to be important:
60 | * Don't use epsilon above 1e-8; this loses performance. Epsilon 1e-12 performs as well as 1e-15
61 | * Betas=(0.95, 0.95) seemed optimal, which turns out to be the same thing that SOAP uses
62 | * Higher preconditioner update frequency is better but slower
63 |
64 | I'm open to hyperparameter suggestions; the experiment takes ~20-30 minutes to run on a fresh 8xH100 instance, so it's not hard for me to run more attempts.
65 |
66 |
67 | ## [SOAP](e21a2838-a0f2-46f2-a247-db0021165682.txt)
68 | ```
69 | SOAP(model.transformer.h.parameters(), lr=0.0018, betas=(.95, .95), precondition_frequency=10)
70 | ```
71 |
72 | This is using the official SOAP implementation [here](https://github.com/nikhilvyas/SOAP/blob/bbce86e890d3b697380f4376acb600c2d6c3d203/soap.py).
73 |
74 | Based on conversations with the authors, it is likely that a future SOAP implementation will significantly reduce the wallclock overhead.
75 |
76 |
77 | ## [Muon](8d6193f4-27fc-4e68-899f-af70019a4d54.txt)
78 | ```
79 | Muon(raw_model.transformer.h.parameters(), lr=0.02, momentum=0.95)
80 | ```
81 |
82 |
83 | ## Openness
84 |
85 | These training logs are reproducible (just strip out everything besides the code, and run it using the `run.sh` in the top-level folder). They take 12-25 minutes to run.
86 |
87 | I tried to do a good job sweeping the hyperparameters for each optimizer, but I can easily have missed something, or just not have performed enough runs.
88 |
89 | Therefore, I am interested in any better hyperparameter settings which other researchers can find, for any of the optimizers.
90 | If you post or send me your own reproducible log with one of these optimizers, I will be very happy to boost it in any way I can.
91 |
92 | ## Appendix: Negative results
93 |
94 | I believe it was Shazeer who said something like "negative results in machine learning are not worth much, because your inability to make something work doesn't prove that it can't work."
95 |
96 | Given that disclaimer, here are some optimizers that I tried to make work, but was unable to get a significant boost over Adam with:
97 | * Sophia
98 | * Lion
99 | * AdamWScheduleFree
100 | * AdEmaMix (actually this was slightly better than Adam, just not enough to get near competing with the three Shampoo-like optimizers)
101 |
102 | Of course, this is just for NanoGPT speedrunning (short train duration); it's quite possible they work better at longer training duration or for larger models.
103 |
104 |
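105 | ## Appendix: the learning rate schedule
106 | 
107 | For reference, the warmup-stable-decay / trapezoidal schedule mentioned in the general notes is the same `LambdaLR` used by this repo's `train_gpt2.py`. A minimal sketch, with `warmup_iters`, `num_iterations`, and `warmdown_iters` standing in for each run's actual values:
108 | 
109 | ```
110 | def get_lr(it):
111 |     # 1) linear warmup (warmup_iters is 0 in these runs, so this branch never fires)
112 |     if it < warmup_iters:
113 |         return (it+1) / warmup_iters
114 |     # 2) constant lr for a while
115 |     elif it < num_iterations - warmdown_iters:
116 |         return 1.0
117 |     # 3) linear warmdown to zero
118 |     else:
119 |         return (num_iterations - it) / warmdown_iters
120 | 
121 | # one scheduler per optimizer (the body optimizer plus the Adam for the embedding/head)
122 | schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]
123 | ```
124 | 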
--------------------------------------------------------------------------------
/records/102924_Optimizers/nanogpt_speedrun81w.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/records/102924_Optimizers/nanogpt_speedrun81w.png
--------------------------------------------------------------------------------
/records/102924_Optimizers/nanogpt_speedrun82w.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/records/102924_Optimizers/nanogpt_speedrun82w.png
--------------------------------------------------------------------------------
/records/110324_UntieEmbed/README.md:
--------------------------------------------------------------------------------
1 | # New record 11/03/24
2 |
3 |
4 | New NanoGPT training speed record: 3.28 FineWeb val loss in 10.8 minutes on 8xH100
5 |
6 | Previous record: 12.0 minutes
7 | Changelog:
8 | - untied embed and head weights
9 | - added RMSNorm after embed
10 | - init head to zero
11 |
12 | Driven by @Grad62304977
13 |
14 | ---
15 |
16 | Technically, this is somewhat of an "any%" record, since untying the embedding and lm_head adds 39M parameters.
17 |
18 | However, it doesn't change the number of active parameters or the inference throughput. Future records will stay constrained to 124M active parameters.
19 |
20 | ---
21 |
22 | Like the last architectural change, this record was driven by @Grad62304977. I just finetuned some things and did bookkeeping.
23 |
24 | ---
25 |
26 | Shoutout to @cloneofsimo whose scaling guide already suggests initializing the head to zero. This works quite well and is a significant fraction of the record.
27 |
28 |
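29 | ---
30 | 
31 | For concreteness, a minimal sketch of the three changelog items above (illustrative only, not the exact record code; names follow the `GPT` module in this repo's `train_gpt2.py`):
32 | 
33 | ```
34 | # untie embed and head: drop the weight-tying line and give the head its own weights
35 | self.transformer.wte = nn.Embedding(config.vocab_size, config.n_embd)
36 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
37 | self.lm_head.weight.data.zero_()  # init head to zero
38 | 
39 | # RMSNorm added right after the embedding, at the start of GPT.forward
40 | x = self.transformer.wte(idx)
41 | x = F.rms_norm(x, (x.size(-1),))
42 | ```
43 | 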
--------------------------------------------------------------------------------
/records/110424_50Bruns/README.md:
--------------------------------------------------------------------------------
1 | # 50B-token runs
2 |
3 | This folder contains four runs generated by extending the 11/03/24 speedrun record to 50B FineWeb tokens.
4 | The goal is to test how the speedrun generalizes to long durations, and especially how well Muon does.
5 |
6 | We compare two things:
7 | 1. We compare Muon to Adam as the optimizer for the transformer body. (The head and embedding are always optimized by Adam; see the sketch at the end of this README.)
8 | 2. We compare training on 5 epochs of 10B tokens to training on 50B tokens. (Surprisingly, the two perform about the same.)
9 |
10 | The four resulting runs are as follows:
11 |
12 | * [Muon 50B tokens](./530f3ee1-8862-4d21-be2b-da10eb05e6a9.txt) (HellaSwag=35.82)
13 | * [Adam 50B tokens](./69c33fc9-eabb-4a38-aa08-6922914eb405.txt) (HellaSwag=34.26)
14 | * [Muon 5x10B tokens](./4fbe61ec-f79a-4c19-836d-46d599deecce.txt) (HellaSwag=36.17)
15 | * [Adam 5x10B tokens](./3d715d41-453a-40d6-9506-421ba69766b2.txt) (HellaSwag=34.05)
16 |
17 | To get a sense of what a good HellaSwag score would be for this scale of model, here are some baselines:
18 | * Karpathy's baseline llm.c training (trained for 10B FineWeb tokens): 29.9
19 | * OpenAI GPT-2 (124M): 29.4
20 | * OpenAI GPT-3 (124M) (trained for 300B WebText tokens): 33.7
21 | * Huggingface SmolLM2-135M (trained for 2T FineWeb/DCLM/etc tokens): 42.1
22 |
23 | Note: I'm a little concerned that the learning rate schedule (WSD) and weight decay (zero), which are tuned for the speedrun duration,
24 | might be undertuned/suboptimal for runs of this length.
25 | The gap between Muon and Adam does look too large to be closed by something like this, and the HellaSwag scores look quite reasonable, but you never know.
26 |
27 |
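28 | For reference, "Muon for the transformer body, Adam for the head and embedding" refers to the same two-optimizer split used by the speedrun scripts in this repo. A rough sketch (parameter grouping and learning rates are illustrative; see the reproducible logs for what each run actually used):
29 | 
30 | ```
31 | head_lr, body_lr = 0.0018, 0.02  # speedrun-tuned values; the 50B-token runs may differ
32 | head_and_embed = list(raw_model.lm_head.parameters()) + list(raw_model.transformer.wte.parameters())
33 | adam_opt = torch.optim.Adam(head_and_embed, lr=head_lr, betas=(0.9, 0.95))
34 | body_opt = Muon(raw_model.transformer.h.parameters(), lr=body_lr, momentum=0.95)  # swapped for Adam in the Adam runs
35 | optimizers = [adam_opt, body_opt]
36 | ```
37 | 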
--------------------------------------------------------------------------------
/records/110624_ShortcutsTweaks/README.md:
--------------------------------------------------------------------------------
1 | # New record 11/06/24
2 |
3 | 8.2 minutes on 8xH100 (previous record: 10.8 minutes)
4 |
5 | 
6 | 
7 |
8 | * [Old record 11/03/24](d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt)
9 | * [+shorten duration](4a71cc92-0f43-4058-a033-23e85c1e98f1.txt)
10 | * [+value residual](042f9e87-07e6-4504-bb04-4ec59a380211.txt) by @Grad62304977 following [1]
11 | * [+learnable lambda](43f60c4f-0448-4de7-83d9-643ca26f61e7.txt) @Grad62304977's innovation on top of [1]
12 | * [+embed shortcut](05b29e54-0be0-4a0f-a1e2-7d5317daedd3.txt)
13 | * [+momentum warmup](10119f53-7001-4248-bfd9-33d32427a912.txt)
14 | * [+tanh logit capping](dd7304a6-cc43-4d5e-adb8-c070111464a1.txt) by @Grad62304977 following [2]
15 |
16 | ## Code snippets
17 |
18 | ### Value residual
19 |
20 | In the attention layer:
21 | ```
22 | def forward(self, x, v1=None):
23 | ...
24 | v = self.c_v(x).view(B, T, self.n_head, self.head_dim)
25 | if v1 is None:
26 | v1 = v
27 | v = 0.5 * v + 0.5 * v1.view_as(v)
28 | ```
29 | where the first block receives `v1=None` (and therefore sets `v1` to its own value tensor), and every subsequent block receives the `v1` produced by the first block (see the appendix at the end of this README for how `v1` is threaded through the forward pass).
30 |
31 | ### Learnable lambda
32 |
33 | In the attention block:
34 | ```
35 | def __init__(self, config):
36 | ...
37 | self.lamb = nn.Parameter(torch.tensor(0.5))
38 |
39 | def forward(self, x, v1=None):
40 | ...
41 | v = (1 - self.lamb) * v + self.lamb * v1.view_as(v)
42 | ```
43 | That is, we just replace the fixed 0.5 constant used in standard value residual [1] with a learnable scalar (optimized by Adam(lr=0.02)).
44 |
45 | ### Embed shortcut
46 |
47 | Replaces the standard transformer block with this:
48 |
49 | ```
50 | class Block(nn.Module):
51 |
52 | def __init__(self, config):
53 | super().__init__()
54 | self.attn = CausalSelfAttention(config)
55 | self.mlp = MLP(config)
56 | self.lambdas = nn.Parameter(torch.tensor([1., 0.]))
57 |
58 | def forward(self, x, x0):
59 | x = self.lambdas[0] * x + self.lambdas[1] * x0
60 |         x = x + self.attn(F.rms_norm(x, (x.size(-1),)))
61 | x = x + self.mlp(F.rms_norm(x, (x.size(-1),)))
62 | return x
63 | ```
64 |
65 | where the two scalars are optimized using Adam(lr=0.02), and `x0` is fed in from the initial embedding via:
66 | ```
67 | ...
68 | x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
69 | x = F.rms_norm(x, (x.size(-1),))
70 | x0 = x
71 | for block in self.transformer.h:
72 | x = block(x, x0)
73 | ...
74 | ```
75 |
76 | ### Momentum warmup
77 |
78 | Just adds the following two lines.
79 | ```
80 | frac = min(step/500, 1)
81 | optimizer3.param_groups[0]['momentum'] = (1 - frac) * 0.85 + frac * 0.95
82 | ```
83 | where `optimizer3` is the Muon for the body of the transformer.
84 |
85 | ### Tanh soft capping
86 |
87 | Just adds the following line.
88 |
89 | ```
90 | logits = 30 * torch.tanh(logits / 30)
91 | ```
92 |
93 |
94 | ## References
95 |
96 | 1. [Zhou, Zhanchao, et al. "Value Residual Learning For Alleviating Attention Concentration In Transformers." arXiv preprint arXiv:2410.17897 (2024).](https://arxiv.org/abs/2410.17897)
97 | 2. [Team, Gemma, et al. "Gemma 2: Improving open language models at a practical size." arXiv preprint arXiv:2408.00118 (2024).](https://arxiv.org/abs/2408.00118)
98 |
99 |
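100 | ## Appendix: threading the value residual
101 | 
102 | As a complement to the "Value residual" snippet above: for `v1` to reach the later blocks, each block also has to hand it back to the model's forward loop. A rough sketch (not the exact record code):
103 | 
104 | ```
105 | v1 = None
106 | for block in self.transformer.h:
107 |     x, v1 = block(x, v1, x0)  # the first block sets v1 from its own values; later blocks reuse it
108 | ```
109 | 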
--------------------------------------------------------------------------------
/records/110624_ShortcutsTweaks/nanogpt_speedrun110.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/records/110624_ShortcutsTweaks/nanogpt_speedrun110.png
--------------------------------------------------------------------------------
/records/110624_ShortcutsTweaks/nanogpt_speedrun111.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlinkDL/modded-nanogpt-rwkv/49f7076ab95a660e3585be94c382e4b2ada920af/records/110624_ShortcutsTweaks/nanogpt_speedrun111.png
--------------------------------------------------------------------------------
/records/110924_Replicateleloykun/README.md:
--------------------------------------------------------------------------------
1 | This is a replication attempt for the record attempt described [here](https://x.com/leloykun/status/1854557419768254915) by @leloykun.
2 |
3 | The original record could not be directly accepted because it showed a slower wallclock time than the previous record -
4 | however, this was plausibly due to hardware differences, as the competitor's hardware was slightly slower.
5 |
6 | Therefore, to certify this attempt as the new record, here I replicated it on my own hardware.
7 | This did successfully reduce the wallclock time by ~11 seconds compared to the 11/07/24 record; however, it also
8 | resulted in an invalid val loss of 3.2824, above the threshold of 3.28.
9 |
10 | The [original record attempt's reproducible log](https://github.com/leloykun/modded-nanogpt/blob/224f10d190677d9dc3c9c45da280078196a6fe40/records/110724_EmbeddingBetasCooldown/6c9d875b-ad91-46c9-9ede-2c7f998b9b16.txt) attained a val loss of 3.2798, just barely below the 3.28 threshold. So this difference is plausibly due to random inter-run variance.
11 |
12 | This indicates that the true average val loss of the run may be worse than 3.28, meaning I am **unable to certify it as the new record.**
13 |
14 | Ideally, all records should attain a low enough val loss such that >95% of runs attain below 3.28. Good evidence for this would be a single run
15 | attaining <= 3.278. Previous records have adhered to this rule, but admittedly it's hard to define precisely and is therefore mostly a matter of taste.
16 |
17 |
--------------------------------------------------------------------------------
/records/111024_UNetDoubleLr/README.md:
--------------------------------------------------------------------------------
1 | This is a record by Brendan Hogan Rappazzo [@brendanh0gan](https://x.com/brendanh0gan).
2 |
3 | New record: 7.23 minutes
4 |
5 | Previous record: 7.8 minutes
6 |
7 | Changelog:
8 | - Added U-net-like skip connections into the transformer
9 | - Doubled the learning rate
10 |
11 | ---
12 |
13 | This record was first posted [here](https://x.com/brendanh0gan/status/1855273758681866352), and then a few iterations were required to benchmark it on 8x SXM H100s.
14 | Brendan's fork of modded-nanogpt is [here](https://github.com/brendanhogan/modded-nanogpt/tree/master). The code for the record can also be extracted from the reproducible log in this folder.
15 |
16 |
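17 | ---
18 | 
19 | For readers unfamiliar with the idea: "U-net-like skip connections" here means saving the residual-stream activations from the first half of the blocks and feeding them back into the second half, pairing the last saved activation with the first block of the second half. A minimal illustrative sketch of the forward loop (block signature simplified; the exact record code can be extracted from the reproducible log in this folder):
20 | 
21 | ```
22 | skips = []
23 | n = len(self.transformer.h)
24 | for i, block in enumerate(self.transformer.h):
25 |     if i >= n // 2:
26 |         x = x + skips.pop()  # second half: add back a saved first-half activation
27 |     x = block(x)
28 |     if i < n // 2:
29 |         skips.append(x)      # first half: save activations for the skip connections
30 | ```
31 | 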
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | tqdm
3 | torch
4 | huggingface-hub
5 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 train_gpt2.py
2 |
--------------------------------------------------------------------------------
/run_rwkv6.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 train_rwkv6.py "$@"
2 |
--------------------------------------------------------------------------------
/run_rwkv7.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=8 train_rwkv7.py "$@"
2 |
--------------------------------------------------------------------------------
/rwkv_cuda/wkv6_cuda.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <assert.h>
3 | #include "ATen/ATen.h"
4 | typedef at::BFloat16 bf16;
5 |
6 | template <typename F>
7 | __global__ void kernel_forward(const int B, const int T, const int C, const int H,
8 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u,
9 | F *__restrict__ const _y)
10 | {
11 | const int b = blockIdx.x / H;
12 | const int h = blockIdx.x % H;
13 | const int i = threadIdx.x;
14 | _u += h*_N_;
15 |
16 | __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
17 | float state[_N_] = {0};
18 |
19 | __syncthreads();
20 | u[i] = float(_u[i]);
21 | __syncthreads();
22 |
23 | for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
24 | {
25 | __syncthreads();
26 | w[i] = __expf(-__expf(float(_w[t])));
27 | r[i] = float(_r[t]);
28 | k[i] = float(_k[t]);
29 | __syncthreads();
30 |
31 | const float v = float(_v[t]);
32 | float y = 0;
33 |
34 | #pragma unroll
35 | for (int j = 0; j < _N_; j+=4)
36 | {
37 | const float4& r_ = (float4&)(r[j]);
38 | const float4& k_ = (float4&)(k[j]);
39 | const float4& w_ = (float4&)(w[j]);
40 | const float4& u_ = (float4&)(u[j]);
41 | float4& s = (float4&)(state[j]);
42 | float4 x;
43 |
44 | x.x = k_.x * v;
45 | x.y = k_.y * v;
46 | x.z = k_.z * v;
47 | x.w = k_.w * v;
48 |
49 | y += r_.x * (u_.x * x.x + s.x);
50 | y += r_.y * (u_.y * x.y + s.y);
51 | y += r_.z * (u_.z * x.z + s.z);
52 | y += r_.w * (u_.w * x.w + s.w);
53 |
54 | s.x = s.x * w_.x + x.x;
55 | s.y = s.y * w_.y + x.y;
56 | s.z = s.z * w_.z + x.z;
57 | s.w = s.w * w_.w + x.w;
58 | }
59 | _y[t] = F(y);
60 | }
61 | }
62 |
63 | template <typename F>
64 | __global__ void kernel_backward_101(const int B, const int T, const int C, const int H,
65 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy,
66 | F *__restrict__ const _gr, F *__restrict__ const _gu)
67 | {
68 | const int b = blockIdx.x / H;
69 | const int h = blockIdx.x % H;
70 | const int i = threadIdx.x;
71 |
72 | __shared__ float v[_N_], gy[_N_];
73 |
74 | const float u = float(_u[h*_N_ + i]);
75 |
76 | float state[_N_] = {0};
77 |
78 | const int t_0 = b*T*C + h*_N_ + i;
79 | const int t_T = t_0 + T*C;
80 |
81 | float gu = 0;
82 | for (int t = t_0; t < t_T; t += C)
83 | {
84 | __syncthreads();
85 | v[i] = float(_v[t]);
86 | gy[i] = float(_gy[t]);
87 | __syncthreads();
88 |
89 | const float k = float(_k[t]);
90 | const float w = __expf(-__expf(float(_w[t])));
91 | float gr = 0, gu_ = 0;
92 |
93 | #pragma unroll
94 | for (int j = 0; j < _N_; j++)
95 | {
96 | float& s = state[j];
97 | float x = k * v[j];
98 |
99 | gr += (u * x + s) * gy[j];
100 | gu_ += x * gy[j];
101 | s = s * w + x;
102 | }
103 | _gr[t] = F(gr);
104 | gu += float(_r[t]) * gu_;
105 | }
106 | _gu[b*C + h*_N_ + i] = F(gu);
107 | }
108 |
109 | template <typename F>
110 | __global__ void kernel_backward_102(const int B, const int T, const int C, const int H,
111 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy,
112 | F *__restrict__ const _gk)
113 | {
114 | const int b = blockIdx.x / H;
115 | const int h = blockIdx.x % H;
116 | const int i = threadIdx.x;
117 |
118 | __shared__ float v[_N_], gy[_N_];
119 |
120 | const float u = float(_u[h*_N_ + i]);
121 |
122 | float scccc[_N_] = {0};
123 |
124 | const int t_0 = b*T*C + h*_N_ + i;
125 | const int t_T_1 = t_0 + (T-1)*C;
126 |
127 | for (int t = t_T_1; t >= t_0; t -= C)
128 | {
129 | __syncthreads();
130 | v[i] = float(_v[t]);
131 | gy[i] = float(_gy[t]);
132 | __syncthreads();
133 |
134 | const float rr = float(_r[t]);
135 | const float w = __expf(-__expf(float(_w[t])));
136 | float gk = 0;
137 |
138 | #pragma unroll
139 | for (int j = 0; j < _N_; j++)
140 | {
141 | float& s = scccc[j];
142 | float x = rr * gy[j];
143 |
144 | gk += (u * x + s) * v[j];
145 | s = x + s * w;
146 | }
147 | _gk[t] = F(gk);
148 | }
149 | }
150 |
151 | template <typename F>
152 | __global__ void kernel_backward_103(const int B, const int T, const int C, const int H,
153 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy,
154 | F *__restrict__ const _gv)
155 | {
156 | const int b = blockIdx.x / H;
157 | const int h = blockIdx.x % H;
158 | const int i = threadIdx.x;
159 | _u += h*_N_;
160 |
161 | __shared__ float u_[_N_], r[_N_], k[_N_], w_[_N_];
162 | __syncthreads();
163 | u_[i] = float(_u[i]);
164 | __syncthreads();
165 |
166 | float sdddd[_N_] = {0};
167 |
168 | const int t_0 = b*T*C + h*_N_ + i;
169 | const int t_T_1 = t_0 + (T-1)*C;
170 |
171 | for (int t = t_T_1; t >= t_0; t -= C)
172 | {
173 | __syncthreads();
174 | r[i] = float(_r[t]);
175 | k[i] = float(_k[t]);
176 | w_[i] = __expf(-__expf(float(_w[t])));
177 | __syncthreads();
178 |
179 | const float gyy = float(_gy[t]);
180 | float gv = 0;
181 |
182 | #pragma unroll
183 | for (int j = 0; j < _N_; j++)
184 | {
185 | float& s = sdddd[j];
186 | float x = gyy * r[j];
187 |
188 | gv += (u_[j] * x + s) * k[j];
189 | s = x + s * w_[j];
190 | }
191 | _gv[t] = F(gv);
192 | }
193 | }
194 |
195 | template <typename F>
196 | __global__ void kernel_backward_201(const int B, const int T, const int C, const int H,
197 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy,
198 | F *__restrict__ const _gw)
199 | {
200 | const int b = blockIdx.x / H;
201 | const int h = blockIdx.x % H;
202 | const int i = threadIdx.x;
203 |
204 | __shared__ float v[_N_], gy[_N_];
205 | float saaaa[_N_] = {0}, sbbbb[_T_-2] = {0}, scccc[_N_] = {0};
206 |
207 | const int t_0 = b*T*C + h*_N_ + i;
208 | const int t_1 = t_0 + C;
209 | const int t_2 = t_0 + 2*C;
210 | const int t_T_1 = t_0 + (T-1)*C;
211 |
212 | for (int t = t_T_1; t > t_1; t -= C)
213 | {
214 | __syncthreads();
215 | gy[i] = float(_gy[t]);
216 | v[i] = float(_v[t-2*C]);
217 | __syncthreads();
218 |
219 | const float r = float(_r[t]);
220 | const float w = __expf(-__expf(float(_w[t-C])));
221 | float sum = 0.0f;
222 |
223 | #pragma unroll
224 | for (int j = 0; j < _N_; j++)
225 | {
226 | float& s = saaaa[j];
227 | float x = r * gy[j];
228 | s = (s + x) * w;
229 | sum += s * v[j];
230 | }
231 | sbbbb[(t-t_2)/C] = sum * float(_k[t-2*C]);
232 | }
233 |
234 | float sss = sbbbb[0];
235 | _gw[t_0] = 0;
236 | _gw[t_1] = F(sss * -__expf(float(_w[t_1])));
237 |
238 | for (int t = t_2; t < t_T_1; t += C)
239 | {
240 | __syncthreads();
241 | gy[i] = float(_gy[t]);
242 | v[i] = float(_v[t-2*C]);
243 | __syncthreads();
244 |
245 | const float w = __expf(-__expf(float(_w[t-C])));
246 | const float k = float(_k[t-2*C]);
247 | float sum = 0.0f;
248 |
249 | #pragma unroll
250 | for (int j = 0; j < _N_; j++)
251 | {
252 | float& s = scccc[j];
253 | float x = k * v[j];
254 | s = (s + x) * w;
255 | sum += s * gy[j];
256 | }
257 | sss += sbbbb[(t-t_1)/C] - (sum * float(_r[t]));
258 | _gw[t] = F(sss * -__expf(float(_w[t])));
259 | }
260 | _gw[t_T_1] = 0;
261 | }
262 |
263 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *y)
264 | {
265 | assert(H*_N_ == C);
266 | assert(_N_%4 == 0);
267 |     kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, y);
268 | }
269 |
270 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
271 | {
272 | assert(H*_N_ == C);
273 | assert(_N_%4 == 0);
274 |     kernel_backward_101<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, gy, gr, gu);
275 |     kernel_backward_102<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, gy, gk);
276 |     kernel_backward_103<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, gy, gv);
277 |     kernel_backward_201<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, gy, gw);
278 | }
279 |
--------------------------------------------------------------------------------
/rwkv_cuda/wkv6_op.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include "ATen/ATen.h"
3 | typedef at::BFloat16 bf16;
4 |
5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *y);
6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);
7 |
8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
9 |     cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<bf16>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
10 | }
11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
12 |     cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<bf16>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
13 | }
14 |
15 | TORCH_LIBRARY(wkv6, m) {
16 | m.def("forward(int B, int T, int C, int H, Tensor r, Tensor k, Tensor v, Tensor w, Tensor u, Tensor(a!) y) -> ()");
17 | m.def("backward(int B, int T, int C, int H, Tensor r, Tensor k, Tensor v, Tensor w, Tensor u, Tensor gy, Tensor(a!) gr, Tensor(b!) gk, Tensor(c!) gv, Tensor(d!) gw, Tensor(e!) gu) -> ()");
18 | }
19 |
20 | TORCH_LIBRARY_IMPL(wkv6, CUDA, m) {
21 | m.impl("forward", &forward);
22 | m.impl("backward", &backward);
23 | }
24 |
--------------------------------------------------------------------------------
/rwkv_cuda/wkv7g_op.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include "ATen/ATen.h"
3 |
4 | typedef at::BFloat16 bf16;
5 |
6 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y, float *saa, float* sss);
7 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, float *saa, float* sss, float* zzz, bf16 *gy, bf16 *gr, bf16 *gw, bf16 *gk, bf16 *gv, bf16 *ga, bf16 *gb);
8 |
9 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y, torch::Tensor &saa, torch::Tensor &sss) {
10 |     cuda_forward(B, T, C, H, r.data_ptr<bf16>(), w.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), a.data_ptr<bf16>(), b.data_ptr<bf16>(), y.data_ptr<bf16>(), saa.data_ptr<float>(), sss.data_ptr<float>());
11 | }
12 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &saa, torch::Tensor &sss, torch::Tensor &zzz, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gw, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &ga, torch::Tensor &gb) {
13 |     cuda_backward(B, T, C, H, r.data_ptr<bf16>(), w.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), a.data_ptr<bf16>(), b.data_ptr<bf16>(), saa.data_ptr<float>(), sss.data_ptr<float>(), zzz.data_ptr<float>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gw.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), ga.data_ptr<bf16>(), gb.data_ptr<bf16>());
14 | }
15 |
16 | TORCH_LIBRARY(wkv7g, m) {
17 | m.def("forward(int B, int T, int C, int H, Tensor r, Tensor w, Tensor k, Tensor v, Tensor a, Tensor b, Tensor(a!) y, Tensor(b!) saa, Tensor(c!) sss) -> ()");
18 | m.def("backward(int B, int T, int C, int H, Tensor r, Tensor w, Tensor k, Tensor v, Tensor a, Tensor b, Tensor saa, Tensor sss, Tensor(a!) zzz, Tensor gy, Tensor(b!) gr, Tensor(c!) gw, Tensor(d!) gk, Tensor(e!) gv, Tensor(f!) ga, Tensor(g!) gb) -> ()");
19 | }
20 |
21 | TORCH_LIBRARY_IMPL(wkv7g, CUDA, m) {
22 | m.impl("forward", &forward);
23 | m.impl("backward", &backward);
24 | }
25 |
--------------------------------------------------------------------------------
/rwkv_cuda/wkv7g_v1.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <assert.h>
3 | #include "ATen/ATen.h"
4 |
5 | typedef at::BFloat16 bf16;
6 |
7 | template <typename F>
8 | __global__ void kernel_forward(const int B, const int T, const int C, const int H,
9 | const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b,
10 | F *__restrict__ const _y, float *__restrict__ const _saa, float *__restrict__ const _sss)
11 | {
12 | const int e = blockIdx.x / H;
13 | const int h = blockIdx.x % H;
14 | const int i = threadIdx.x;
15 |
16 | float state[_N_] = {0};
17 | __shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_];
18 |
19 | float v[_T_];
20 | for (int _t = 0; _t < T; _t++)
21 | {
22 | const int t = e*T*C + h*_N_ + i + _t * C;
23 | v[_t] = float(_v[t]);
24 | }
25 |
26 | for (int _t = 0; _t < T; _t++)
27 | {
28 | const int t = e*T*C + h*_N_ + i + _t * C;
29 | __syncthreads();
30 | r[i] = float(_r[t]);
31 | w[i] = __expf(-__expf(float(_w[t])));
32 | k[i] = float(_k[t]);
33 | a[i] = float(_a[t]);
34 | b[i] = float(_b[t]);
35 | __syncthreads();
36 |
37 | float sa = 0;
38 | #pragma unroll
39 | for (int j = 0; j < _N_; j++)
40 | {
41 | sa += a[j] * state[j];
42 | }
43 | _saa[t] = float(sa);
44 |
45 | float vv = v[_t];
46 | float y = 0;
47 | #pragma unroll
48 | for (int j = 0; j < _N_; j++)
49 | {
50 | float& s = state[j];
51 | s = s * w[j] + sa * b[j] + k[j] * vv;
52 | y += s * r[j];
53 | }
54 | _y[t] = F(y);
55 |
56 | if ((_t+1) % _CHUNK_LEN_ == 0)
57 | {
58 | const int a = _t / _CHUNK_LEN_;
59 | const int c = _T_ / _CHUNK_LEN_;
60 | const int p = e*C*_N_*c + h*_N_*_N_*c + a*_N_*_N_ + i;
61 | #pragma unroll
62 | for (int j = 0; j < _N_; j++)
63 | {
64 | _sss[p + j*_N_] = float(state[j]);
65 | }
66 | }
67 | }
68 | }
69 |
70 | template <typename F>
71 | __global__ void kernel_backward_zzz(const int B, const int T, const int C, const int H,
72 | const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _a, const F *__restrict__ const _b, const F *__restrict__ const _gy,
73 | float *__restrict__ const _zzz)
74 | {
75 | const int e = blockIdx.x / H;
76 | const int h = blockIdx.x % H;
77 | const int i = threadIdx.x;
78 |
79 | __shared__ float r[_N_], w[_N_], a[_N_], b[_N_];
80 |
81 | const int T_1 = e*T*C + (T-1)*C + h*_N_;
82 | float z[_N_];
83 | const float gy = _gy[T_1 + i];
84 | __syncthreads();
85 | r[i] = float(_r[T_1+i]);
86 | __syncthreads();
87 | #pragma unroll
88 | for (int j = 0; j < _N_; j++)
89 | {
90 | z[j] = gy * r[j];
91 | }
92 |
93 | for (int _t = T-2; _t > _CHUNK_LEN_-1; _t--)
94 | {
95 | const int t = e*T*C + h*_N_ + _t * C + i;
96 | const float gy = _gy[t];
97 | __syncthreads();
98 | r[i] = float(_r[t]);
99 | w[i] = __expf(-__expf(float(_w[t+C])));
100 | a[i] = float(_a[t+C]);
101 | b[i] = float(_b[t+C]);
102 | __syncthreads();
103 |
104 | float zz = 0;
105 | #pragma unroll
106 | for (int j = 0; j < _N_; j++)
107 | {
108 | zz += b[j] * z[j];
109 | }
110 | #pragma unroll
111 | for (int j = 0; j < _N_; j++)
112 | {
113 | z[j] = z[j] * w[j] + gy * r[j] + a[j] * zz;
114 | // printf("t %d i %d j %d z %f\n", _t, i, j, z[j]);
115 | }
116 | if (_t % _CHUNK_LEN_ == 0)
117 | {
118 | const int a = _t / _CHUNK_LEN_ - 1;
119 | const int c = _T_ / _CHUNK_LEN_ - 1;
120 | const int p = e*C*_N_*c + h*_N_*_N_*c + a*_N_*_N_ + i;
121 | #pragma unroll
122 | for (int j = 0; j < _N_; j++)
123 | {
124 | _zzz[p + j*_N_] = float(z[j]);
125 | }
126 | }
127 | }
128 | }
129 |
130 | template <typename F>
131 | __global__ void kernel_backward_rwkv(const int B, const int T, const int C, const int H,
132 | const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b, const float *__restrict__ const _saa, const float *__restrict__ const _sss, const float *__restrict__ const _zzz,
133 | const F *__restrict__ const _gy, F *__restrict__ const _gr, F *__restrict__ const _gw, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _ga, F *__restrict__ const _gb)
134 | {
135 | const int e = blockIdx.x / H;
136 | const int h = blockIdx.x % H;
137 | const int chunk = threadIdx.x;
138 | const int n_chunk = _T_ / _CHUNK_LEN_;
139 |
140 | float zzz[_N_*_N_] = {0}, sss[_N_*_N_] = {999}, saa[_N_] = {999};
141 | float r[_N_] = {999}, w[_N_] = {0}, w1[_N_] = {999}, winv[_N_] = {999}, ww[_N_] = {999};
142 | float k[_N_] = {0}, v[_N_] = {999}, a[_N_] = {0}, a1[_N_] = {999}, b[_N_] = {0}, b1[_N_] = {999}, gy[_N_] = {999};
143 |
144 | if (chunk != n_chunk - 1)
145 | {
146 | const int p = e*T*C + (chunk+1)*_CHUNK_LEN_*C + h*_N_;
147 | for (int i = 0; i < _N_; i++)
148 | {
149 | k[i] = float(_k[p+i]);
150 | a[i] = float(_a[p+i]);
151 | b[i] = float(_b[p+i]);
152 | w[i] = __expf(-__expf(float(_w[p+i])));
153 | const int p = e*C*_N_*(n_chunk-1) + h*_N_*_N_*(n_chunk-1) + chunk*_N_*_N_ + i*_N_;
154 | #pragma unroll
155 | for (int j = 0; j < _N_; j++)
156 | {
157 | zzz[i*_N_+j] = float(_zzz[p+j]);
158 | }
159 | }
160 | }
161 | for (int i = 0; i < _N_; i++)
162 | {
163 | const int p = e*C*_N_*n_chunk + h*_N_*_N_*n_chunk + chunk*_N_*_N_ + i*_N_;
164 | #pragma unroll
165 | for (int j = 0; j < _N_; j++)
166 | {
167 | sss[i*_N_+j] = float(_sss[p+j]);
168 | }
169 | }
170 |
171 | for (int _t = _CHUNK_LEN_-1; _t > -1; _t--)
172 | {
173 | const int t = chunk * _CHUNK_LEN_ + _t;
174 | const int b_t_h = e*T*C + t*C + h*_N_;
175 | #pragma unroll
176 | for (int n = 0; n < _N_; n++)
177 | {
178 | w1[n] = w[n];
179 | a1[n] = a[n];
180 | b1[n] = b[n];
181 |
182 | r[n] = float(_r[b_t_h+n]);
183 | k[n] = float(_k[b_t_h+n]);
184 | v[n] = float(_v[b_t_h+n]);
185 | a[n] = float(_a[b_t_h+n]);
186 | b[n] = float(_b[b_t_h+n]);
187 | gy[n] = float(_gy[b_t_h+n]);
188 | saa[n] = float(_saa[b_t_h+n]);
189 |
190 | ww[n] = -__expf(float(_w[b_t_h+n]));
191 | w[n] = __expf(ww[n]);
192 | ww[n] = ww[n] * w[n];
193 | winv[n] = 1.0f / w[n];
194 | }
195 |
196 | for (int j = 0; j < _N_; j++)
197 | {
198 | float zz = 0;
199 | #pragma unroll
200 | for (int i = 0; i < _N_; i++)
201 | {
202 | zz += b1[i] * zzz[i*_N_+j];
203 | }
204 | const float gyj = gy[j];
205 | #pragma unroll
206 | for (int i = 0; i < _N_; i++)
207 | {
208 | zzz[i*_N_+j] = zzz[i*_N_+j] * w1[i] + gyj * r[i] + a1[i] * zz;
209 | // printf("t %d i %d j %d z %f\n",t,i,j,zzz[i*_N_+j]);
210 | // printf("t %d i %d j %d s %f\n",t,i,j,sss[i*_N_+j]);
211 | }
212 | }
213 |
214 | for (int i = 0; i < _N_; i++)
215 | {
216 | float gr = 0;
217 | #pragma unroll
218 | for (int j = 0; j < _N_; j++)
219 | {
220 | gr += gy[j] * sss[i*_N_+j];
221 | }
222 | _gr[b_t_h+i] = F(gr);
223 | }
224 |
225 | for (int i = 0; i < _N_; i++)
226 | {
227 | const float ki = k[i];
228 | const float bi = b[i];
229 | const float wi = winv[i];
230 | #pragma unroll
231 | for (int j = 0; j < _N_; j++)
232 | {
233 | sss[i*_N_+j] = (sss[i*_N_+j] - ki * v[j] - bi * saa[j]) * wi;
234 | }
235 | }
236 |
237 | float gv[_N_] = {0}; float as[_N_] = {0}; float bz[_N_] = {0};
238 | for (int i = 0; i < _N_; i++)
239 | {
240 | const float ki = k[i];
241 | const float ai = a[i];
242 | const float bi = b[i];
243 | float gw = 0;
244 | float gk = 0;
245 | #pragma unroll
246 | for (int j = 0; j < _N_; j++)
247 | {
248 | const float sij = sss[i*_N_+j];
249 | const float zij = zzz[i*_N_+j];
250 | gv[j] += ki * zij;
251 | as[j] += ai * sij;
252 | bz[j] += bi * zij;
253 | gw += sij * zij;
254 | gk += v[j] * zij;
255 | }
256 | _gw[b_t_h+i] = F(gw * ww[i]);
257 | _gk[b_t_h+i] = F(gk);
258 | }
259 | for (int i = 0; i < _N_; i++)
260 | {
261 | _gv[b_t_h+i] = F(gv[i]);
262 | float ga = 0;
263 | float gb = 0;
264 | #pragma unroll
265 | for (int j = 0; j < _N_; j++)
266 | {
267 | ga += bz[j] * sss[i*_N_+j];
268 | gb += as[j] * zzz[i*_N_+j];
269 | }
270 | _ga[b_t_h+i] = F(ga);
271 | _gb[b_t_h+i] = F(gb);
272 | }
273 | }
274 | }
275 |
276 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y, float *saa, float* sss)
277 | {
278 | assert(H*_N_ == C);
279 |     kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, w, k, v, a, b, y, saa, sss);
280 | }
281 |
282 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, float *saa, float* sss, float* zzz, bf16 *gy, bf16 *gr, bf16 *gw, bf16 *gk, bf16 *gv, bf16 *ga, bf16 *gb)
283 | {
284 | assert(H*_N_ == C);
285 | assert(T%_CHUNK_LEN_ == 0);
286 |
287 |     kernel_backward_zzz<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, w, k, a, b, gy, zzz);
288 |     kernel_backward_rwkv<<<dim3(B * H), dim3(_T_ / _CHUNK_LEN_)>>>(B, T, C, H, r, w, k, v, a, b, saa, sss, zzz, gy, gr, gw, gk, gv, ga, gb);
289 | }
290 |
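
For reference, the recurrence that kernel_forward computes for one (batch, head) pair can be written with a dense state matrix S: using the per-step decay w = exp(-exp(w_raw)), the update is S <- S*diag(w) + (S a) b^T + v k^T and the output is y = S r, with thread i holding row S[i, :] in registers and every _CHUNK_LEN_-th state checkpointed into _sss for the backward pass. Below is a minimal PyTorch sketch of that recurrence, not part of the repository, with float32 reference math and (T, N) inputs named after the kernel arguments:

    import torch

    def wkv7_head_reference(r, w, k, v, a, b):
        # One batch element, one head; N corresponds to _N_ in the kernel.
        T, N = r.shape
        S = torch.zeros(N, N)                         # S[i, j] mirrors thread i's state[j]
        y = torch.zeros(T, N)
        for t in range(T):
            wt = torch.exp(-torch.exp(w[t].float()))  # same decay transform as the kernel
            sa = S @ a[t].float()                     # the per-row sum the kernel stores in _saa
            S = S * wt + torch.outer(sa, b[t].float()) + torch.outer(v[t].float(), k[t].float())
            y[t] = S @ r[t].float()
        return y
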
--------------------------------------------------------------------------------
/rwkv_cuda_wind/backstepping_f32.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <cuda_bf16.h>
3 | using bf = __nv_bfloat16;
4 |
5 | void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa);
6 |
7 | void forward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &y, torch::Tensor &s, torch::Tensor &sa) {
8 | int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
9 | cuda_forward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)y.data_ptr(), (float*)s.data_ptr(), (float*)sa.data_ptr());
10 | }
11 |
12 | void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da);
13 |
14 | void backward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &dy,
15 | torch::Tensor &s, torch::Tensor &sa, torch::Tensor &dw, torch::Tensor &dq, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &dz, torch::Tensor &da) {
16 | int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
17 | cuda_backward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)dy.data_ptr(),
18 | (float*)s.data_ptr(), (float*)sa.data_ptr(), (bf*)dw.data_ptr(), (bf*)dq.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)dz.data_ptr(), (bf*)da.data_ptr());
19 | }
20 |
21 | TORCH_LIBRARY(wind_backstepping, m) {
22 | m.def("forward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor(a!) y, Tensor(b!) s, Tensor(c!) sa) -> ()");
23 | m.def("backward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor dy, Tensor s, Tensor sa, Tensor(a!) dw, Tensor(b!) dq, Tensor(c!) dk, Tensor(d!) dv, Tensor(e!) dz, Tensor(f!) da) -> ()");
24 | }
25 |
26 | TORCH_LIBRARY_IMPL(wind_backstepping, CUDA, m) {
27 | m.impl("forward", &forward);
28 | m.impl("backward", &backward);
29 | }
30 |
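
The wind_backstepping ops registered above are intended to be driven from Python once the extension has been compiled and loaded (for example via torch.utils.cpp_extension.load). The wrapper below is a hypothetical sketch that only illustrates the expected shapes and the output buffers the schema marks as mutable; CHUNK_LEN and the exact binding used by the training script are assumptions:

    import torch

    CHUNK_LEN = 16  # assumed; must match the _CHUNK_LEN_ the kernels were compiled with

    class WindBackstepping(torch.autograd.Function):
        @staticmethod
        def forward(ctx, w, q, k, v, z, a):
            # All inputs: (B, T, H, C) bfloat16, with C equal to the compiled _C_.
            B, T, H, C = w.shape
            assert T % CHUNK_LEN == 0
            y  = torch.empty_like(v)
            s  = torch.empty(B, H, T // CHUNK_LEN, C, C, dtype=torch.float32, device=w.device)
            sa = torch.empty(B, T, H, C, dtype=torch.float32, device=w.device)
            torch.ops.wind_backstepping.forward(w, q, k, v, z, a, y, s, sa)
            ctx.save_for_backward(w, q, k, v, z, a, s, sa)
            return y

        @staticmethod
        def backward(ctx, dy):
            w, q, k, v, z, a, s, sa = ctx.saved_tensors
            dw, dq, dk, dv, dz, da = [torch.empty_like(x) for x in (w, q, k, v, z, a)]
            torch.ops.wind_backstepping.backward(w, q, k, v, z, a, dy.contiguous(), s, sa,
                                                 dw, dq, dk, dv, dz, da)
            return dw, dq, dk, dv, dz, da

Usage would then be y = WindBackstepping.apply(w, q, k, v, z, a).
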
--------------------------------------------------------------------------------
/rwkv_cuda_wind/backstepping_f32_1.cu:
--------------------------------------------------------------------------------
1 | #include <cuda_bf16.h>
2 | #include <assert.h>
3 |
4 | using bf = __nv_bfloat16;
5 | __device__ inline float to_float(const bf & u) { return __bfloat162float(u); }
6 | __device__ inline bf to_bf(const float & u) { return __float2bfloat16_rn(u); }
7 |
8 | typedef bf * __restrict__ F_;
9 |
10 | __global__ void forward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, bf* y_, float* s_, float* sa_) {
11 | constexpr int C = _C_;
12 | int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
13 |
14 | float state[C] = {0};
15 | __shared__ float q[C], k[C], w[C], a[C], b[C];
16 |
17 | for (int t = 0; t < T; t++) {
18 | int ind = bb*T*H*C + t*H*C + hh * C + i;
19 | __syncthreads();
20 | q[i] = to_float(q_[ind]);
21 | w[i] = __expf(-__expf(to_float(w_[ind])));
22 | k[i] = to_float(k_[ind]);
23 | a[i] = to_float(a_[ind]);
24 | b[i] = to_float(b_[ind]);
25 | __syncthreads();
26 |
27 | float sa = 0;
28 | #pragma unroll
29 | for (int j = 0; j < C; j++) {
30 | sa += a[j] * state[j];
31 | }
32 | sa_[ind] = sa;
33 |
34 | float v = to_float(v_[ind]);
35 | float y = 0;
36 | #pragma unroll
37 | for (int j = 0; j < C; j++) {
38 | float& s = state[j];
39 | s = s * w[j] + sa * b[j] + k[j] * v;
40 | y += s * q[j];
41 | }
42 | y_[ind] = to_bf(y);
43 |
44 | if ((t+1)%_CHUNK_LEN_ == 0) {
45 | int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i;
46 | #pragma unroll
47 | for (int j = 0; j < C; j++) {
48 | s_[base + j*C] = state[j];
49 | }
50 | }
51 | }
52 | }
53 |
54 | __global__ void backward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_, float * __restrict__ s_, float * __restrict__ sa_, bf* dw_, bf* dq_, bf* dk_, bf* dv_, bf* da_, bf* db_) {
55 | constexpr int C = _C_;
56 | int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
57 |
58 | float stateT[C] = {0}, dstate[C] = {0}, dstateT[C] = {0};
59 | __shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], dSb_shared[C];
60 | float qi, wi, ki, ai, bi, dyi;
61 |
62 | for (int t = T-1; t >= 0; t--) {
63 | int ind = bb*T*H*C + t*H*C + hh * C + i;
64 | __syncthreads();
65 | q[i] = qi = to_float(q_[ind]);
66 | float wi_fac = -__expf(to_float(w_[ind]));
67 | w[i] = wi = __expf(wi_fac);
68 | k[i] = ki = to_float(k_[ind]);
69 | a[i] = ai = to_float(a_[ind]);
70 | b[i] = bi = to_float(b_[ind]);
71 | v[i] = to_float(v_[ind]);
72 | dy[i] = dyi = to_float(dy_[ind]);
73 | sa[i] = sa_[ind];
74 | __syncthreads();
75 |
76 | if ((t+1)%_CHUNK_LEN_ == 0) {
77 | int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i*C;
78 | #pragma unroll
79 | for (int j = 0; j < C; j++) {
80 | stateT[j] = s_[base + j];
81 | }
82 | }
83 |
84 | float dq = 0;
85 | #pragma unroll
86 | for (int j = 0; j < C; j++) {
87 | dq += stateT[j]*dy[j];
88 | }
89 | dq_[ind] = to_bf(dq);
90 |
91 | float iwi = 1.0f/wi;
92 | #pragma unroll
93 | for (int j = 0; j < C; j++) {
94 | stateT[j] = (stateT[j] - ki*v[j] - bi*sa[j]) * iwi;
95 | dstate[j] += dyi * q[j];
96 | dstateT[j] += qi * dy[j];
97 | }
98 |
99 | float dw = 0, dk = 0, dv = 0, db = 0, dSb = 0;
100 | #pragma unroll
101 | for (int j = 0; j < C; j++) {
102 | dw += dstateT[j]*stateT[j];
103 | dk += dstateT[j]*v[j];
104 | dv += dstate[j]*k[j];
105 | dSb += dstate[j]*b[j];
106 | db += dstateT[j]*sa[j];
107 | }
108 | dw_[ind] = to_bf(dw * wi * wi_fac);
109 | dk_[ind] = to_bf(dk);
110 | dv_[ind] = to_bf(dv);
111 | db_[ind] = to_bf(db);
112 |
113 | __syncthreads();
114 | dSb_shared[i] = dSb;
115 | __syncthreads();
116 |
117 | float da = 0;
118 | #pragma unroll
119 | for (int j = 0; j < C; j++) {
120 | da += stateT[j]*dSb_shared[j];
121 | }
122 | da_[ind] = to_bf(da);
123 |
124 | #pragma unroll
125 | for (int j = 0; j < C; j++) {
126 | dstate[j] = dstate[j]*w[j] + dSb * a[j];
127 | dstateT[j] = dstateT[j]*wi + ai * dSb_shared[j];
128 | }
129 | }
130 | }
131 |
132 | void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa) {
133 |     forward_kernel<<<dim3(H,B), dim3(_C_)>>>(T,H,w,q,k,v,z,a,y,s,sa);
134 | }
135 | void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da) {
136 | assert(T%_CHUNK_LEN_ == 0);
137 |     backward_kernel<<<dim3(H,B), dim3(_C_)>>>(T,H,w,q,k,v,z,a,dy,s,sa,dw,dq,dk,dv,dz,da);
138 | }
139 |
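
backward_kernel above avoids storing every state: only every _CHUNK_LEN_-th state is kept in s_, and the line stateT[j] = (stateT[j] - ki*v[j] - bi*sa[j]) * iwi walks the state backwards one step using the sa values saved by the forward pass. A small PyTorch check of that inversion identity (illustrative names and shapes; decays are drawn away from zero so the division stays well conditioned):

    import torch

    N = 16
    S_prev = torch.randn(N, N)
    w = torch.rand(N) * 0.5 + 0.5                # decays in [0.5, 1)
    k, v, a, b = (torch.randn(N) * 0.1 for _ in range(4))
    sa = S_prev @ a                              # saved by the forward pass (sa_)
    S = S_prev * w + torch.outer(sa, b) + torch.outer(v, k)
    S_rec = (S - torch.outer(v, k) - torch.outer(sa, b)) / w
    assert torch.allclose(S_prev, S_rec, atol=1e-4)
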
--------------------------------------------------------------------------------
/rwkv_cuda_wind/backstepping_f32_2.cu:
--------------------------------------------------------------------------------
1 | #include <cuda_bf16.h>
2 | #include <assert.h>
3 |
4 | using bf = __nv_bfloat16;
5 | __device__ inline float to_float(const bf & u) { return __bfloat162float(u); }
6 | __device__ inline bf to_bf(const float & u) { return __float2bfloat16_rn(u); }
7 |
8 | typedef bf * __restrict__ F_;
9 |
10 | __global__ void forward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, bf* y_, float* s_, float* sa_) {
11 | constexpr int C = _C_;
12 | int bind = blockIdx.y, hind = blockIdx.x, i = threadIdx.x;
13 |
14 | float state[C] = {0};
15 | __shared__ float q[C], k[C], w[C], a[C], b[C];
16 |
17 | for (int t = 0; t < T; t++) {
18 | int ind = bind*T*H*C + t*H*C + hind * C + i;
19 | __syncthreads();
20 | q[i] = to_float(q_[ind]);
21 | w[i] = __expf(-__expf(to_float(w_[ind])));
22 | k[i] = to_float(k_[ind]);
23 | a[i] = to_float(a_[ind]);
24 | b[i] = to_float(b_[ind]);
25 | __syncthreads();
26 |
27 | float sa = 0;
28 | #pragma unroll
29 | for (int j = 0; j < C; j++) {
30 | sa += a[j] * state[j];
31 | }
32 | sa_[ind] = sa;
33 |
34 | float v = to_float(v_[ind]);
35 | float y = 0;
36 | #pragma unroll
37 | for (int j = 0; j < C; j++) {
38 | float& s = state[j];
39 | s = s * w[j] + sa * b[j] + k[j] * v;
40 | y += s * q[j];
41 | }
42 | y_[ind] = to_bf(y);
43 |
44 | if ((t+1)%_CHUNK_LEN_ == 0) {
45 | int base = (bind*H+hind)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i;
46 | #pragma unroll
47 | for (int j = 0; j < C; j++) {
48 | s_[base + j*C] = state[j];
49 | }
50 | }
51 | }
52 | }
53 |
54 | __global__ void backward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_, float * __restrict__ s_, float * __restrict__ sa_, bf* dw_, bf* dq_, bf* dk_, bf* dv_, bf* da_, bf* db_) {
55 | constexpr int C = _C_;
56 | int bind = blockIdx.y, hind = blockIdx.x, i = threadIdx.x;
57 |
58 | float stateT[C] = {0};
59 | __shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], dSb_shared[C];
60 |
61 | extern __shared__ char smem_[];
62 | float*dstate = (float*)smem_; //[C*(C+1)];
63 |
64 | for (int j = 0; j < C; j++) {
65 | dstate[i*(C+1)+j] = 0;
66 | }
67 |
68 | for (int t = T-1; t >= 0; t--) {
69 | int ind = bind*T*H*C + t*H*C + hind * C + i;
70 | float bi, ki, dyi, wi;
71 | __syncthreads();
72 | q[i] = to_float(q_[ind]);
73 | float wi_fac = -__expf(to_float(w_[ind]));
74 | w[i] = wi = __expf(wi_fac);
75 | k[i] = ki = to_float(k_[ind]);
76 | a[i] = to_float(a_[ind]);
77 | b[i] = bi = to_float(b_[ind]);
78 | v[i] = to_float(v_[ind]);
79 | dy[i] = dyi = to_float(dy_[ind]);
80 | sa[i] = sa_[ind];
81 | __syncthreads();
82 |
83 | if ((t+1)%_CHUNK_LEN_ == 0) {
84 | int base = (bind*H+hind)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i*C;
85 | #pragma unroll
86 | for (int j = 0; j < C; j++) {
87 | stateT[j] = s_[base + j];
88 | }
89 | }
90 |
91 | float dq = 0;
92 | #pragma unroll
93 | for (int j = 0; j < C; j++) {
94 | dq += stateT[j]*dy[j];
95 | }
96 | dq_[ind] = to_bf(dq);
97 |
98 | float iwi = 1.0f/wi;
99 | for (int j = 0; j < C; j++) {
100 | stateT[j] = (stateT[j] - ki*v[j] - bi*sa[j]) * iwi;
101 | dstate[i*(C+1)+j] += dyi * q[j];
102 | }
103 |
104 | float dw = 0, dk = 0, dv = 0, db = 0, dSb = 0;
105 | #pragma unroll
106 | for (int j = 0; j < C; j++) {
107 | float ds = dstate[j*(C+1)+i];
108 | dw += ds*stateT[j];
109 | dk += ds*v[j];
110 | db += ds*sa[j];
111 | }
112 | #pragma unroll
113 | for (int j = 0; j < C; j++) {
114 | float ds = dstate[i*(C+1)+j];
115 | dv += ds*k[j];
116 | dSb += ds*b[j];
117 | }
118 | dw_[ind] = to_bf(dw * wi * wi_fac);
119 | dk_[ind] = to_bf(dk);
120 | dv_[ind] = to_bf(dv);
121 | db_[ind] = to_bf(db);
122 |
123 | __syncthreads();
124 | dSb_shared[i] = dSb;
125 | __syncthreads();
126 |
127 | float da = 0;
128 | #pragma unroll
129 | for (int j = 0; j < C; j++) {
130 | da += stateT[j]*dSb_shared[j];
131 | }
132 | da_[ind] = to_bf(da);
133 |
134 | for (int j = 0; j < C; j++) {
135 | dstate[i*(C+1)+j] = dstate[i*(C+1)+j]*w[j] + dSb * a[j];
136 | }
137 | }
138 | }
139 |
140 | void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa) {
141 |     forward_kernel<<<dim3(H,B), dim3(_C_)>>>(T,H,w,q,k,v,z,a,y,s,sa);
142 | }
143 | void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da) {
144 | assert(T%_CHUNK_LEN_ == 0);
145 | int shared_mem = _C_*(_C_+1)*4;
146 | assert(!cudaFuncSetAttribute(backward_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem));
147 |     backward_kernel<<<dim3(H,B), dim3(_C_), shared_mem>>>(T,H,w,q,k,v,z,a,dy,s,sa,dw,dq,dk,dv,dz,da);
148 | }
149 |
--------------------------------------------------------------------------------
/rwkv_cuda_wind/tile.cuh:
--------------------------------------------------------------------------------
1 | #include <cuda_bf16.h>
2 | #include <stdio.h>
3 |
4 | //TODO: static? inline? __align__(16)?
5 |
6 | using bf = __nv_bfloat16;
7 | using bf2 = __nv_bfloat162;
8 | using uint = unsigned int;
9 | __device__ inline float to_float(const bf & u) { return __bfloat162float(u); }
10 | __device__ inline bf to_bf(const float & u) { return __float2bfloat16_rn(u); }
11 | __device__ inline float2 to_float2(const bf2 & u) { return __bfloat1622float2(u); }
12 | __device__ inline float2 to_float2(const float2 & u) { return u; }
13 | __device__ inline bf2 to_bf2(const float2 & u) { return __float22bfloat162_rn(u); }
14 | __device__ inline uint& as_uint(const bf2&x) { return *((uint*)(&x)); }
15 | __device__ inline uint __smem(const void*x) { return __cvta_generic_to_shared(x); }
16 |
17 | __device__ void __commit_group() { asm volatile("cp.async.commit_group;\n" ::); }
18 | __device__ void __wait_group() { asm volatile("cp.async.wait_all;\n" ::); }
19 | template <int N> __device__ void __wait_groups() { asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); }
20 |
21 | __device__ void __copy_wait() { __commit_group(); __wait_group(); }
22 |
23 | __device__ void operator*=(float2&a, const float2&b) { a.x *= b.x; a.y *= b.y; }
24 | __device__ void operator+=(float2&a, const float2&b) { a.x += b.x; a.y += b.y; }
25 | __device__ float2 operator+(const float2&a, const float2&b) { return {a.x+b.x,a.y+b.y}; }
26 | __device__ float2 operator*(const float2&a, const float2&b) { return {a.x*b.x,a.y*b.y}; }
27 |
28 | struct STile;
29 | struct RTile;
30 | struct FTile;
31 |
32 | struct GTile {
33 | bf*ga;
34 | int stride;
35 | __device__ GTile(bf*ga_, int stride_) : ga(ga_), stride(stride_) {}
36 | __device__ GTile& operator=(const RTile&);
37 | };
38 | struct GFTile {
39 | float*ga;
40 | int stride;
41 | __device__ GFTile(float*ga_, int stride_) : ga(ga_), stride(stride_) {}
42 | __device__ GFTile& operator=(const FTile&);
43 | };
44 | struct STileT { STile*st; };
45 |
46 | struct __align__(16) STile {
47 | bf data[16*16];
48 | __device__ STile() {}
49 | __device__ STile(const RTile&o) { *this=o; }
50 | __device__ STile& operator=(const GTile&);
51 | __device__ STile& operator=(const RTile&);
52 | __device__ STileT t() { return STileT{this}; }
53 | };
54 | struct Product { const RTile*a, *b; };
55 | struct ProductPlus { const RTile*a, *b; const FTile* c; };
56 | struct RTile {
57 | bf2 data[4];
58 | __device__ RTile() {}
59 | __device__ void zero_() { data[0] = data[1] = data[2] = data[3] = to_bf2({0.f,0.f}); }
60 | __device__ RTile(const STile&o) { *this=o; }
61 | __device__ RTile(const STileT&o) { *this=o; }
62 | __device__ RTile(const FTile&o) { *this=o; }
63 | __device__ RTile& operator=(const STile&);
64 | __device__ RTile& operator=(const STileT&);
65 | __device__ RTile& operator=(const FTile&fa);
66 | __device__ RTile& operator=(const GTile&);
67 | };
68 | struct FTile {
69 | union {
70 | float2 data[4];
71 | float fdata[8];
72 | };
73 | __device__ void zero_() { data[0] = data[1] = data[2] = data[3] = {0.f,0.f}; }
74 | __device__ FTile() {}
75 | __device__ FTile(const FTile&o) { for (int i = 0; i < 4; i++) data[i] = o.data[i]; }
76 | __device__ FTile(const RTile&r) { *this=r; }
77 | __device__ FTile(const Product&p) { *this=p; }
78 | __device__ FTile(const ProductPlus&p) { *this=p; }
79 | __device__ FTile& operator=(const Product&);
80 | __device__ FTile& operator=(const RTile&);
81 | __device__ FTile& operator=(const ProductPlus&);
82 | __device__ FTile& operator+=(const Product&);
83 | __device__ FTile& operator+=(const FTile&o) { for (int i = 0; i < 4; i++) data[i] += o.data[i]; return *this; }
84 | };
85 |
86 | __device__ void print(STile t) {
87 | if (threadIdx.x == 0) {
88 | for (int i = 0; i < 16; i++) {
89 | for (int j = 0; j < 16; j++) {
90 | printf("%f ", to_float(t.data[i*16+j]));
91 | }
92 | printf("\n");
93 | }
94 | printf("\n");
95 | }
96 | }
97 |
98 | template <class T>
99 | __device__ void print(T t, int warpi = 0) {
100 | int tid = threadIdx.x - warpi*32;
101 | for (int i = 0; i < 16; i++) {
102 | for (int j = 0; j < 16; j += 2) {
103 | if (tid == i%8*4+j%8/2) {
104 | float2 xy = to_float2(t.data[i/8+j/8*2]);
105 | printf("%f %f ", xy.x, xy.y);
106 | //printf("T%d:{a%d,a%d} ", threadIdx.x, (i/8+j/8*2)*2, (i/8+j/8*2)*2+1);
107 | }
108 | __syncthreads();
109 | }
110 | if (tid == 0) printf("\n");
111 | __syncthreads();
112 | }
113 | if (tid == 0) printf("\n");
114 | __syncthreads();
115 | }
116 |
117 | template <class T>
118 | __device__ void print8(T mat) {
119 | for (int i = 0; i < 8; i++) {
120 | for (int j = 0; j < 8; j += 2) {
121 | if (threadIdx.x == i%8*4+j%8/2) {
122 | float2 xy = to_float2(mat);
123 | printf("%f %f ", xy.x, xy.y);
124 | }
125 | __syncthreads();
126 | }
127 | if (threadIdx.x == 0) printf("\n");
128 | __syncthreads();
129 | }
130 | if (threadIdx.x == 0) printf("\n");
131 | __syncthreads();
132 | }
133 |
134 |
135 |
136 | __device__ void load(STile&sa, bf*ga, int stride) {
137 | int i = threadIdx.x%32/2, j = threadIdx.x%2;
138 | asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" :: "r"(__smem(&sa.data[i*16+j*8])), "l"(ga+stride*i+j*8), "n"(16));
139 | }
140 |
141 | __device__ void load(RTile&ra, const STile&sa) {
142 | int i = threadIdx.x%8, j = threadIdx.x%32/16, k = threadIdx.x/8%2;
143 | asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
144 | : "=r"(as_uint(ra.data[0])), "=r"(as_uint(ra.data[1])), "=r"(as_uint(ra.data[2])), "=r"(as_uint(ra.data[3]))
145 | : "r"(__smem(&sa.data[i*16+j*8+k*8*16])));
146 | }
147 | __device__ void loadT(RTile&ra, const STile&sa) {
148 | int i = threadIdx.x%8, j = threadIdx.x%32/16, k = threadIdx.x/8%2;
149 | asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
150 | : "=r"(as_uint(ra.data[0])), "=r"(as_uint(ra.data[1])), "=r"(as_uint(ra.data[2])), "=r"(as_uint(ra.data[3]))
151 | : "r"(__smem(&sa.data[i*16+j*8*16+k*8])));
152 | }
153 |
154 | __device__ static inline void __m16n8k16(float2&d0, float2&d1, const bf2 &a0, const bf2 &a1, const bf2 &a2, const bf2 &a3, const bf2 &b0, const bf2 &b1, const float2 &c0, const float2 &c1) {
155 | asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
156 | : "=f"(d0.x), "=f"(d0.y), "=f"(d1.x), "=f"(d1.y)
157 | : "r"(as_uint(a0)), "r"(as_uint(a1)), "r"(as_uint(a2)), "r"(as_uint(a3)),
158 | "r"(as_uint(b0)), "r"(as_uint(b1)),
159 | "f"(c0.x), "f"(c0.y), "f"(c1.x), "f"(c1.y));
160 | }
161 | __device__ void mma(FTile&rd, const RTile&ra, const RTile&rb, const FTile&rc) { // d = a*b^T + c
162 | __m16n8k16(rd.data[0],rd.data[1], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[0],rb.data[2], rc.data[0],rc.data[1]);
163 | __m16n8k16(rd.data[2],rd.data[3], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[1],rb.data[3], rc.data[2],rc.data[3]);
164 | }
165 | __device__ static inline void __m16n8k16(float2&d0, float2&d1, const bf2 &a0, const bf2 &a1, const bf2 &a2, const bf2 &a3, const bf2 &b0, const bf2 &b1) {
166 | asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
167 | : "+f"(d0.x), "+f"(d0.y), "+f"(d1.x), "+f"(d1.y)
168 | : "r"(as_uint(a0)), "r"(as_uint(a1)), "r"(as_uint(a2)), "r"(as_uint(a3)),
169 | "r"(as_uint(b0)), "r"(as_uint(b1)),
170 | "f"(d0.x), "f"(d0.y), "f"(d1.x), "f"(d1.y));
171 | }
172 | __device__ void mma(FTile&rd, const RTile&ra, const RTile&rb) { // d += a*b^T
173 | __m16n8k16(rd.data[0],rd.data[1], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[0],rb.data[2]);
174 | __m16n8k16(rd.data[2],rd.data[3], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[1],rb.data[3]);
175 | }
176 | __device__ void mm(FTile&rd, const RTile&ra, const RTile&rb) { // d = a*b^T
177 | __m16n8k16(rd.data[0],rd.data[1], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[0],rb.data[2], {0.f,0.f}, {0.f,0.f});
178 | __m16n8k16(rd.data[2],rd.data[3], ra.data[0],ra.data[1],ra.data[2],ra.data[3], rb.data[1],rb.data[3], {0.f,0.f}, {0.f,0.f});
179 | }
180 |
181 | __device__ void store(const FTile&ra, float*ga, int stride) {
182 | int i = threadIdx.x%32/4, j = threadIdx.x%4*2;
183 | *((float2*)&ga[ i *stride+j ]) = ra.data[0];
184 | *((float2*)&ga[(i+8)*stride+j ]) = ra.data[1];
185 | *((float2*)&ga[ i *stride+j+8]) = ra.data[2];
186 | *((float2*)&ga[(i+8)*stride+j+8]) = ra.data[3];
187 | }
188 |
189 | __device__ void store(const RTile&ra, bf*ga, int stride) {
190 | int i = threadIdx.x%32/4, j = threadIdx.x%4*2;
191 | *((bf2*)&ga[ i *stride+j ]) = ra.data[0];
192 | *((bf2*)&ga[(i+8)*stride+j ]) = ra.data[1];
193 | *((bf2*)&ga[ i *stride+j+8]) = ra.data[2];
194 | *((bf2*)&ga[(i+8)*stride+j+8]) = ra.data[3];
195 | }
196 | __device__ void load(RTile&ra, bf*ga, int stride) {
197 | int i = threadIdx.x%32/4, j = threadIdx.x%4*2;
198 | ra.data[0] = *((bf2*)&ga[ i *stride+j ]);
199 | ra.data[1] = *((bf2*)&ga[(i+8)*stride+j ]);
200 | ra.data[2] = *((bf2*)&ga[ i *stride+j+8]);
201 | ra.data[3] = *((bf2*)&ga[(i+8)*stride+j+8]);
202 | }
203 | __device__ void store(const RTile&ra, STile&sa) { //TODO: reduce bank conflicts?
204 | int i = threadIdx.x%32/4, j = threadIdx.x%4*2;
205 | *((bf2*)&sa.data[ i *16+j ]) = ra.data[0];
206 | *((bf2*)&sa.data[(i+8)*16+j ]) = ra.data[1];
207 | *((bf2*)&sa.data[ i *16+j+8]) = ra.data[2];
208 | *((bf2*)&sa.data[(i+8)*16+j+8]) = ra.data[3];
209 | }
210 |
211 | __device__ void convert(RTile&ra, const FTile&fa) {
212 | ra.data[0] = to_bf2(fa.data[0]);
213 | ra.data[1] = to_bf2(fa.data[1]);
214 | ra.data[2] = to_bf2(fa.data[2]);
215 | ra.data[3] = to_bf2(fa.data[3]);
216 | }
217 | __device__ void convert(FTile&fa, const RTile&ra) {
218 | fa.data[0] = to_float2(ra.data[0]);
219 | fa.data[1] = to_float2(ra.data[1]);
220 | fa.data[2] = to_float2(ra.data[2]);
221 | fa.data[3] = to_float2(ra.data[3]);
222 | }
223 |
224 | __device__ STile& STile::operator=(const GTile& ga) { load(*this, ga.ga, ga.stride); return *this; }
225 | __device__ RTile& RTile::operator=(const GTile& ga) { load(*this, ga.ga, ga.stride); return *this; }
226 | __device__ RTile& RTile::operator=(const STile& sa) { load(*this, sa); return *this; }
227 | __device__ STile& STile::operator=(const RTile& ra) { store(ra, *this); return *this; }
228 | __device__ RTile& RTile::operator=(const STileT& sa) { loadT(*this, *sa.st); return *this; }
229 | __device__ Product operator%(const RTile&ra, const RTile&rb) { return Product{&ra,&rb}; }
230 | __device__ ProductPlus operator+(const Product&prod, const FTile&rc) { return ProductPlus{prod.a,prod.b,&rc}; }
231 | __device__ FTile& FTile::operator=(const Product& prod) { mm(*this, *prod.a, *prod.b); return *this; }
232 | __device__ FTile& FTile::operator=(const ProductPlus& prod) { mma(*this, *prod.a, *prod.b, *prod.c); return *this; }
233 | __device__ FTile& FTile::operator+=(const Product& prod) { mma(*this, *prod.a, *prod.b); return *this; }
234 | __device__ RTile& RTile::operator=(const FTile&fa) { convert(*this,fa); return *this; }
235 | __device__ FTile& FTile::operator=(const RTile&ra) { convert(*this,ra); return *this; }
236 | __device__ GTile& GTile::operator=(const RTile&ra) { store(ra, this->ga, this->stride); return *this; }
237 | __device__ GFTile& GFTile::operator=(const FTile&fa) { store(fa, this->ga, this->stride); return *this; }
238 |
239 | // Is this kind of cumsum better than multiplying with a triangular matrix of ones?
240 | template <int inclusive, int rev>
241 | __device__ FTile cumsumv(FTile&w) {
242 | int tid = threadIdx.x%32, t = tid/4;
243 |
244 | FTile ret;
245 | if (inclusive) for (int i = 0; i < 4; i++) ret.data[i] = w.data[i];
246 | else for (int i = 0; i < 4; i++) ret.data[i] = float2{0.f,0.f};
247 |
248 | for (int b = 0; b < 3; b++) {
249 | for (int i = 0; i < 8; i++) {
250 |       float other_w = __shfl_xor_sync(0xffffffff, w.fdata[i], 4<<b);
251 |       if ((t>>b)%2 == !rev) ret.fdata[i] += other_w;
252 | w.fdata[i] += other_w;
253 | }
254 | }
255 | for (int i : {0,1,4,5}) {
256 | float &w0 = w.fdata[i^(2*!rev)], &w1 = w.fdata[i^(2*rev)];
257 | ret.fdata[i^(2*!rev)] += w1;
258 | w0 += w1;
259 | w1 = w0;
260 | }
261 | return ret;
262 | }
263 |
264 | template <int inclusive, int rev>
265 | __device__ FTile cumprodv(FTile&w) {
266 | int tid = threadIdx.x%32, t = tid/4;
267 |
268 | FTile ret;
269 | if (inclusive) for (int i = 0; i < 4; i++) ret.data[i] = w.data[i];
270 | else for (int i = 0; i < 4; i++) ret.data[i] = float2{1.f,1.f};
271 |
272 | for (int b = 0; b < 3; b++) {
273 | for (int i = 0; i < 8; i++) {
274 |       float other_w = __shfl_xor_sync(0xffffffff, w.fdata[i], 4<<b);
275 |       if ((t>>b)%2 == !rev) ret.fdata[i] *= other_w;
276 | w.fdata[i] *= other_w;
277 | }
278 | }
279 | for (int i : {0,1,4,5}) {
280 | float &w0 = w.fdata[i^(2*!rev)], &w1 = w.fdata[i^(2*rev)];
281 | ret.fdata[i^(2*!rev)] *= w1;
282 | w0 *= w1;
283 | w1 = w0;
284 | }
285 | return ret;
286 | }
287 |
288 | __device__ FTile operator*(const FTile&a, const FTile&b) {
289 | FTile ret;
290 | for (int i = 0; i < 8; i++) ret.fdata[i] = a.fdata[i]*b.fdata[i];
291 | return ret;
292 | }
293 |
294 | template <int triangular, int WARPS> // Lower triangular
295 | __device__ FTile sum_warp(float2*share, const FTile&f) {
296 | int tid = threadIdx.x%32, warpi = threadIdx.x/32;
297 | FTile sum;
298 | sum.zero_();
299 | for (int i : {0,1,2,3}) {
300 | if (i == 2 && triangular) continue;
301 | for (int j = 0; j < WARPS; j++) {
302 | if (warpi == j) share[tid] = f.data[i];
303 | __syncthreads();
304 | sum.data[i].x += share[tid].x;
305 | sum.data[i].y += share[tid].y;
306 | __syncthreads();
307 | }
308 | }
309 | return sum;
310 | }
311 |
312 | __device__ RTile from_warp(const RTile&ra, int src, float4*share) {
313 | int tid = threadIdx.x%32, warpi = threadIdx.x/32;
314 | RTile ret;
315 | if (warpi == src) share[tid] = *((float4*)ra.data);
316 | __syncthreads();
317 | *((float4*)ret.data) = share[tid];
318 | __syncthreads();
319 | return ret;
320 | }
321 |
322 | // inv(I-f) where f is strictly lower triangular
323 | __device__ FTile tri_minv(const FTile&f, float*share) {
324 | int i0 = threadIdx.x%32/4, j0 = threadIdx.x%4*2;
325 | float inv[16] = {};
326 | for (int k = 0; k < 8; k++) {
327 | int i = i0+k/2%2*8, j = j0+k%2+k/4*8;
328 | share[i*16+j] = f.fdata[k];
329 | }
330 | int tid = threadIdx.x%32;
331 | inv[tid%16] = 1;
332 | for (int i = 1; i < 16; i++) {
333 | for (int j = 0; j < i; j++) {
334 | float fac = share[i*16+j];
335 | inv[i] += fac*inv[j];
336 | }
337 | }
338 | for (int i = 0; i < 16; i++)
339 | share[tid*16+i] = inv[i];
340 | FTile ret;
341 | for (int k = 0; k < 8; k++) {
342 | int i = i0+k/2%2*8, j = j0+k%2+k/4*8;
343 | ret.fdata[k] = share[j*16+i];
344 | }
345 | return ret;
346 | }
347 |
348 | template <int strict>
349 | __device__ FTile tril(const FTile&f) {
350 | int i0 = threadIdx.x%32/4, j0 = threadIdx.x%4*2;
351 | FTile ret;
352 | for (int k = 0; k < 8; k++) {
353 | int i = i0+k/2%2*8, j = j0+k%2+k/4*8;
354 | if (strict) ret.fdata[k] = (i>j ? f.fdata[k] : 0.f);
355 | else ret.fdata[k] = (i>=j ? f.fdata[k] : 0.f);
356 | }
357 | return ret;
358 | }
359 |
360 | template <class F>
361 | __device__ void apply_(FTile&tile, F f) {
362 | for (int i = 0; i < 8; i++) tile.fdata[i] = f(tile.fdata[i]);
363 | }
364 |
365 | __device__ bf2 transpose(bf2 a) {
366 | bf2 ret;
367 | asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n" : "=r"(as_uint(ret)) : "r"(as_uint(a)));
368 | return ret;
369 | }
370 |
371 | __device__ RTile transpose(const RTile&ra) {
372 | RTile rb;
373 | rb.data[0] = transpose(ra.data[0]);
374 | rb.data[1] = transpose(ra.data[2]);
375 | rb.data[2] = transpose(ra.data[1]);
376 | rb.data[3] = transpose(ra.data[3]);
377 | return rb;
378 | }
379 |
380 | template <int strict>
381 | __device__ FTile slow_dw(const RTile&A, const RTile&q, const RTile&k, STile*share) {
382 | share[0] = A;
383 | share[1] = q;
384 | share[2] = k;
385 | __syncthreads();
386 | if (threadIdx.x%32 == 0) {
387 | for (int k = 0; k < 16; k++) {
388 | for (int j = 0; j < 16; j++) {
389 | float sum = 0;
390 | for (int l = 0; l < k; l++) {
391 | for (int r = k+strict; r < 16; r++) {
392 | sum += to_float(share[0].data[r*16+l]) * to_float(share[1].data[r*16+j]) * to_float(share[2].data[l*16+j]);
393 | }
394 | }
395 | share[3].data[k*16+j] = to_bf(sum);
396 | }
397 | }
398 | }
399 | __syncthreads();
400 | RTile ret = (RTile)share[3];
401 | __syncthreads();
402 | return ret;
403 | }
404 |
405 |
406 | __device__ static inline void __m16n8k8(float2&d0, float2&d1, const bf2 &a0, const bf2 &a1, const bf2 &b0) {
407 | asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};"
408 | : "=f"(d0.x), "=f"(d0.y), "=f"(d1.x), "=f"(d1.y) : "r"(as_uint(a0)), "r"(as_uint(a1)), "r"(as_uint(b0)), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f));
409 | }
410 |
411 | template <int strict>
412 | __device__ RTile fast_dw(const RTile&A, const RTile&q, const RTile&k) {
413 | float2 qkA8[4];
414 | RTile kt = transpose(k), qt = transpose(q);
415 | __m16n8k8(qkA8[0],qkA8[1], qt.data[2], qt.data[3], transpose(A.data[1]));
416 | __m16n8k8(qkA8[2],qkA8[3], kt.data[0], kt.data[1], A.data[1]);
417 | for (int x : {0,1}) {
418 | qkA8[x] *= to_float2(kt.data[x]);
419 | qkA8[2+x] *= to_float2(qt.data[2+x]);
420 | }
421 |
422 | int tid = threadIdx.x%32, j = threadIdx.x%4;
423 | // Non-inclusive cumsum
424 | for (int i = 0; i < 4; i++) {
425 | float sum = qkA8[i].x+qkA8[i].y;
426 | float psum = __shfl_xor_sync(0xffffffff, sum, 1);
427 | float ppsum = __shfl_xor_sync(0xffffffff, sum+psum, 2);
428 | if (i < 2) {
429 | psum = ppsum*(j>=2)+psum*(j%2);
430 | qkA8[i].y = psum + qkA8[i].x;
431 | qkA8[i].x = psum;
432 | } else {
433 | psum = ppsum*(j<2)+psum*(j%2==0);
434 | qkA8[i].x = psum + qkA8[i].y;
435 | qkA8[i].y = psum;
436 | }
437 | }
438 |
439 | float2 qkA4[4];
440 | {
441 | RTile k_q;
442 | for (int i = 0; i < 8; i++) ((bf*)k_q.data)[i] = (j<2?((bf*)kt.data)[i]:((bf*)qt.data)[i]);
443 | float lower_left = (tid >= 16 && j < 2);
444 | bf2 A0 = to_bf2(to_float2(A.data[0])*float2{lower_left,lower_left});
445 | bf2 A3 = to_bf2(to_float2(A.data[3])*float2{lower_left,lower_left});
446 | __m16n8k8(qkA4[0],qkA4[1], k_q.data[0], k_q.data[1], A0 + transpose(A0));
447 | __m16n8k8(qkA4[2],qkA4[3], k_q.data[2], k_q.data[3], A3 + transpose(A3));
448 | for (int i = 0; i < 4; i++)
449 | qkA4[i] *= to_float2(k_q.data[i]);
450 | }
451 |
452 | // Non-inclusive cumsum
453 | for (int i = 0; i < 4; i++) {
454 | float sum = qkA4[i].x+qkA4[i].y;
455 | float psum = __shfl_xor_sync(0xffffffff, sum, 1);
456 | psum *= (j%2 == j<2);
457 | qkA4[i] = {psum + qkA4[i].y*(j>=2), psum + qkA4[i].x*(j<2)};
458 | }
459 |
460 | FTile ret;
461 | ret.data[0] = qkA8[0]+qkA4[0];
462 | ret.data[1] = qkA8[1]+qkA4[1];
463 | ret.data[2] = qkA8[2]+qkA4[2];
464 | ret.data[3] = qkA8[3]+qkA4[3];
465 |
466 | for (int ci : {0,1}) {
467 | for (int ti : {0,1}) {
468 | int Ai = ti*3, di = ti*2+ci;
469 | unsigned mask = 0xffff<<(j>=2)*16;
470 | bf A8x = __shfl_sync(mask, A.data[Ai].x, 8+(j>=2)*18);
471 | bf A12x = __shfl_sync(mask, A.data[Ai].x, 12+(j>=2)*18);
472 | bf A12y = __shfl_sync(mask, A.data[Ai].y, 12+(j>=2)*18);
473 | bf2 nq = __shfl_xor_sync(0xffffffff, qt.data[di], 1);
474 | bf2 pk = __shfl_xor_sync(0xffffffff, kt.data[di], 1);
475 |
476 | bool even = (j%2==0);
477 | float ax = to_float(even?A8x:A12x), ay = to_float(even?A12x:A12y), c = to_float(even?kt.data[di].x:qt.data[di].y);
478 | float2 b = to_float2(j%2?pk:nq);
479 | float d = (ax*b.x+ay*b.y)*c;
480 | ret.data[di].y += even*d;
481 | ret.data[di].x +=!even*d;
482 | }
483 | }
484 |
485 | if (!strict) {
486 | // Do we really need tril<1>()?
487 | ret += (kt % tril<1>(A)) * qt;
488 | }
489 | return transpose(ret);
490 | }
491 |
492 | __device__ void debug_set(RTile&ra, int i, int j, float v) {
493 | if (threadIdx.x%32 == i%8*4+j%8/2) ((bf*)ra.data)[i/8*2+j/8*4+j%2] = to_bf(v);
494 | }
495 |
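
The question before cumsumv (line 239) concerns the warp-level scan these helpers implement: instead of multiplying by a 16x16 triangular matrix of ones, cumsumv and cumprodv run a log2(16)-step XOR-shuffle butterfly over the rows of a tile, with the template arguments selecting the inclusive and reversed variants. A 1-D NumPy analogue of that butterfly (the real code additionally handles the two-values-per-thread register layout with a final fix-up pass):

    import numpy as np

    def butterfly_scan(x, inclusive=True, rev=False, op=np.add):
        x = np.array(x, dtype=float)       # length must be a power of two
        n = len(x)
        idx = np.arange(n)
        ret = x.copy() if inclusive else np.full(n, 0.0 if op is np.add else 1.0)
        for b in range(n.bit_length() - 1):
            other = x[idx ^ (1 << b)]      # value held by the XOR partner
            take = ((idx >> b) & 1) == (0 if rev else 1)
            ret[take] = op(ret[take], other[take])
            x = op(x, other)               # running segment totals
        return ret

    assert np.allclose(butterfly_scan([1, 2, 3, 4]), [1, 3, 6, 10])
    assert np.allclose(butterfly_scan([1, 2, 3, 4], inclusive=False), [0, 1, 3, 6])
    assert np.allclose(butterfly_scan([1, 2, 3, 4], rev=True), [10, 9, 7, 4])
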
--------------------------------------------------------------------------------
/rwkv_cuda_wind/wind_rwkv7.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <cuda_bf16.h>
3 | using bf = __nv_bfloat16;
4 |
5 | void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*s0, bf*y, bf*s, bf*sT);
6 |
7 | void forward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &s0, torch::Tensor &y, torch::Tensor &s, torch::Tensor &sT) {
8 | int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
9 | cuda_forward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)s0.data_ptr(), (bf*)y.data_ptr(), (bf*)s.data_ptr(), (bf*)sT.data_ptr());
10 | }
11 |
12 | void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, bf*s, bf*dsT, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da, bf*ds0);
13 |
14 | void backward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &dy,
15 | torch::Tensor &s, torch::Tensor &dsT, torch::Tensor &dw, torch::Tensor &dq, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &dz, torch::Tensor &da, torch::Tensor &ds0) {
16 | int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
17 | cuda_backward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)dy.data_ptr(),
18 | (bf*)s.data_ptr(), (bf*)dsT.data_ptr(), (bf*)dw.data_ptr(), (bf*)dq.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)dz.data_ptr(), (bf*)da.data_ptr(), (bf*)ds0.data_ptr());
19 | }
20 |
21 | /*TORCH_LIBRARY(wind, m) {
22 | m.def("forward", forward);
23 | m.def("backward", backward);
24 | }*/
25 |
26 | TORCH_LIBRARY(wind, m) {
27 | m.def("forward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor s0, Tensor(a!) y, Tensor(b!) s, Tensor(c!) sT) -> ()");
28 | m.def("backward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor dy, Tensor s, Tensor dsT, Tensor(a!) dw, Tensor(b!) dq, Tensor(c!) dk, Tensor(d!) dv, Tensor(e!) dz, Tensor(f!) da, Tensor(g!) ds0) -> ()");
29 | }
30 |
31 | TORCH_LIBRARY_IMPL(wind, CUDA, m) {
32 | m.impl("forward", &forward);
33 | m.impl("backward", &backward);
34 | }
35 |
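
These bindings register their ops under torch.ops.wind once the extension is built. A hypothetical loading sketch follows; the head size and compiler flags are assumptions, and the training script may build the extension differently:

    import torch
    from torch.utils import cpp_extension

    HEAD_SIZE = 64  # assumed value of _C_; must match the model's head size

    cpp_extension.load(
        name='wind_rwkv7',
        sources=['rwkv_cuda_wind/wind_rwkv7.cpp', 'rwkv_cuda_wind/wind_rwkv7.cu'],
        is_python_module=False, verbose=True,
        extra_cuda_cflags=[f'-D_C_={HEAD_SIZE}', '-O3', '--use_fast_math'])

    # Afterwards the schema above is reachable as torch.ops.wind.forward(...) and
    # torch.ops.wind.backward(...).
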
--------------------------------------------------------------------------------
/rwkv_cuda_wind/wind_rwkv7.cu:
--------------------------------------------------------------------------------
1 | #include "tile.cuh"
2 | #include <assert.h>
3 | typedef bf * __restrict__ F_;
4 |
5 | constexpr int WARPS = _C_/16;
6 | constexpr int fw_stages = 1, bw_stages = 1;
7 |
8 | __global__ void forward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ s0_, bf* y_, bf* s_, bf* sT_) {
9 | constexpr int C = _C_, K = 16;
10 | int bi = blockIdx.y, hi = blockIdx.x;
11 | extern __shared__ char smem_[];
12 | char*smem = smem_;
13 |
14 | STile *sw_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS;
15 | STile *sq_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS;
16 | STile *sk_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS;
17 | STile *sv_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS;
18 | STile *sa_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS;
19 | STile *sb_ = (STile*)smem; smem += sizeof(STile)*fw_stages*WARPS;
20 | char*share = (char*)smem;
21 |
22 | int stride = H*C;
23 | int warpi = threadIdx.x/32;
24 |
25 | auto push = [&](int t) {
26 | int off = bi*T*H*C + t*K*H*C + hi*C + warpi*16;
27 | int si = t%fw_stages;
28 | sw_[si*WARPS+warpi] = GTile(w_+off, stride);
29 | sq_[si*WARPS+warpi] = GTile(q_+off, stride);
30 | sk_[si*WARPS+warpi] = GTile(k_+off, stride);
31 | sv_[si*WARPS+warpi] = GTile(v_+off, stride);
32 | sa_[si*WARPS+warpi] = GTile(a_+off, stride);
33 | sb_[si*WARPS+warpi] = GTile(b_+off, stride);
34 | };
35 | for (int t = 0; t < fw_stages-1 && t < T/K; t++) push(t), __commit_group();
36 |
37 | FTile state[WARPS];
38 | for (int i = 0; i < WARPS; i++) {
39 | int off = bi*H*C*C + hi*C*C + warpi*16*C + i*16;
40 | RTile tmp;
41 | tmp = GTile(s0_+off, C);
42 | state[i] = tmp;
43 | }
44 |
45 | for (int t = 0; t < T/K; t++) {
46 | __syncthreads();
47 | if (t+fw_stages-1 < T/K)
48 | push(t+fw_stages-1);
49 | __commit_group();
50 |         __wait_groups<fw_stages-1>();
51 | __syncthreads();
52 | int si = t%fw_stages;
53 | STile &sw = sw_[si*WARPS+warpi], &sq = sq_[si*WARPS+warpi], &sk = sk_[si*WARPS+warpi], &sv = sv_[si*WARPS+warpi], &sa = sa_[si*WARPS+warpi], &sb = sb_[si*WARPS+warpi];
54 |
55 | FTile w = (RTile)sw;
56 | apply_(w, [](float x) { return __expf(-__expf(x)); });
57 | FTile fw = w;
58 | FTile non_incl_pref = cumprodv<0,0>(fw);
59 | FTile incl_pref = non_incl_pref * w;
60 | FTile inv_incl_pref = incl_pref;
61 | apply_(inv_incl_pref, [](float x) { return 1.f/x; });
62 |
63 | RTile wq = (RTile)sq * incl_pref, kwi = (RTile)sk * inv_incl_pref;
64 | RTile wa = (RTile)sa * non_incl_pref, bwi = (RTile)sb * inv_incl_pref;
65 | FTile ab = sum_warp<1,WARPS>((float2*)share, tril<1>(wa % bwi));
66 | RTile ak = sum_warp<1,WARPS>((float2*)share, tril<1>(wa % kwi));
67 |
68 | RTile ab_inv;
69 | __syncthreads();
70 | if (threadIdx.x < 32) ab_inv = tri_minv(ab, (float*)share);
71 | __syncthreads();
72 | ab_inv = from_warp(ab_inv, 0, (float4*)share);
73 |
74 | RTile vt = sv.t();
75 | FTile ab_ut = vt % ak;
76 | for (int i = 0; i < WARPS; i++)
77 | ab_ut += state[i] % from_warp(wa, i, (float4*)share);
78 | RTile ut = FTile(ab_ut % ab_inv);
79 |
80 | FTile y = sum_warp<1,WARPS>((float2*)share, tril<0>(wq % kwi)) % vt;
81 | y += sum_warp<1,WARPS>((float2*)share, tril<0>(wq % bwi)) % ut;
82 | for (int i = 0; i < WARPS; i++)
83 | y += from_warp(wq, i, (float4*)share) % state[i];
84 |
85 | int off = bi*T*H*C + t*K*H*C + hi*C + warpi*16;
86 | GTile(y_+off, stride) = RTile(y);
87 |
88 | RTile kwt = transpose(kwi*fw), bwt = transpose(bwi*fw);
89 | for (int i = 0; i < WARPS; i++) {
90 | int off = bi*H*(T/K)*C*C + hi*(T/K)*C*C + t*C*C + warpi*16*C + i*16;
91 | GTile(s_+off, C) = (RTile)state[i];
92 |
93 | FTile fstate = state[i] * from_warp(fw, i, (float4*)share);
94 | fstate += vt % from_warp(kwt, i, (float4*)share);
95 | fstate += ut % from_warp(bwt, i, (float4*)share);
96 | state[i] = fstate;
97 | }
98 | }
99 | for (int i = 0; i < WARPS; i++) {
100 | int off = bi*H*C*C + hi*C*C + warpi*16*C + i*16;
101 | GTile(sT_+off, C) = state[i];
102 | }
103 | }
104 |
105 | void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*s0, bf*y, bf*s, bf*sT) {
106 | assert(T%16 == 0);
107 | constexpr int tmp_size1 = sizeof(float4)*32, tmp_size2 = sizeof(float)*16*16*2;
108 | constexpr int threads = 32*WARPS, shared_mem = sizeof(STile)*fw_stages*WARPS*6 + (tmp_size1 > tmp_size2 ? tmp_size1 : tmp_size2);
109 | static int reported = 0;
110 | if (!reported++) {
111 | #if defined VERBOSE
112 | printf("forward_kernel() uses %d bytes of (dynamic) shared memory\n", shared_mem);
113 | #endif
114 | cudaFuncAttributes attr;
115 | cudaFuncGetAttributes(&attr, forward_kernel);
116 | int cur_mem = attr.maxDynamicSharedSizeBytes;
117 | if (shared_mem > cur_mem) {
118 | #if defined VERBOSE
119 | printf("Increasing forward_kernel's MaxDynamicSharedMemorySize from %d to %d\n", cur_mem, shared_mem);
120 | #endif
121 | assert(!cudaFuncSetAttribute(forward_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem));
122 | }
123 | }
124 |     forward_kernel<<<dim3(H,B), dim3(threads), shared_mem>>>(T,H,w,q,k,v,z,a,s0,y,s,sT);
125 | }
126 |
127 |
128 | __global__ void backward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_, F_ s_, F_ dsT_, bf* dw_, bf* dq_, bf* dk_, bf* dv_, bf* da_, bf* db_, bf* ds0_) {
129 | constexpr int C = _C_, K = 16;
130 | int bi = blockIdx.y, hi = blockIdx.x;
131 | extern __shared__ char smem_[];
132 | char*smem = smem_;
133 |
134 | STile *sw_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS;
135 | STile *sq_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS;
136 | STile *sk_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS;
137 | STile *sv_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS;
138 | STile *sa_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS;
139 | STile *sb_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS;
140 | STile *sdy_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS;
141 | STile *state_ = (STile*)smem; smem += sizeof(STile)*bw_stages*WARPS*WARPS;
142 | char*share = (char*)smem;
143 |
144 | int stride = H*C;
145 | int warpi = threadIdx.x/32;
146 |
147 | auto push = [&](int t) {
148 | int off = bi*T*H*C + t*K*H*C + hi*C + warpi*16;
149 | int si = t%fw_stages;
150 | sw_[si*WARPS+warpi] = GTile(w_+off, stride);
151 | sq_[si*WARPS+warpi] = GTile(q_+off, stride);
152 | sk_[si*WARPS+warpi] = GTile(k_+off, stride);
153 | sv_[si*WARPS+warpi] = GTile(v_+off, stride);
154 | sa_[si*WARPS+warpi] = GTile(a_+off, stride);
155 | sb_[si*WARPS+warpi] = GTile(b_+off, stride);
156 | sdy_[si*WARPS+warpi] = GTile(dy_+off, stride);
157 | for (int i = 0; i < WARPS; i++) {
158 | int off2 = bi*H*(T/K)*C*C + hi*(T/K)*C*C + t*C*C + warpi*16*C + i*16;
159 | state_[si*WARPS*WARPS+warpi*WARPS+i] = GTile(s_+off2, C);
160 | }
161 | };
162 |
163 | FTile dstate[WARPS];
164 | for (int i = 0; i < WARPS; i++) {
165 | int off = bi*H*C*C + hi*C*C + warpi*16*C + i*16;
166 | RTile tmp;
167 | tmp = GTile(dsT_+off, C);
168 | dstate[i] = tmp;
169 | __commit_group();
170 | }
171 |
172 | for (int t = 0; t < bw_stages-1 && t < T/K; t++) push(T/K-1-t), __commit_group();
173 |
174 | for (int t = T/K-1; t >= 0; t--) {
175 | __syncthreads();
176 | if (t-bw_stages+1 >= 0)
177 | push(t-bw_stages+1);
178 | __commit_group();
179 |         __wait_groups<bw_stages-1>();
180 | __syncthreads();
181 | int si = t%bw_stages;
182 | STile &sw = sw_[si*WARPS+warpi], &sq = sq_[si*WARPS+warpi], &sk = sk_[si*WARPS+warpi], &sv = sv_[si*WARPS+warpi], &sa = sa_[si*WARPS+warpi], &sb = sb_[si*WARPS+warpi], &sdy = sdy_[si*WARPS+warpi];
183 | STile*state = state_+si*WARPS*WARPS;
184 |
185 | FTile w = (RTile)sw;
186 | apply_(w, [](float x) { return __expf(-__expf(x)); });
187 | FTile fw = w;
188 | FTile non_incl_pref = cumprodv<0,0>(fw);
189 | FTile incl_pref = non_incl_pref * w;
190 | FTile inv_incl_pref = incl_pref;
191 | apply_(inv_incl_pref, [](float x) { return 1.f/x; });
192 |
193 | RTile wq = (RTile)sq * incl_pref, kwi = (RTile)sk * inv_incl_pref;
194 | RTile wa = (RTile)sa * non_incl_pref, bwi = (RTile)sb * inv_incl_pref;
195 | FTile ab = sum_warp<1,WARPS>((float2*)share, tril<1>(wa % bwi));
196 | RTile ak = sum_warp<1,WARPS>((float2*)share, tril<1>(wa % kwi));
197 |
198 | RTile ab_inv;
199 | __syncthreads();
200 | if (threadIdx.x < 32) ab_inv = tri_minv(ab, (float*)share);
201 | __syncthreads();
202 | ab_inv = from_warp(ab_inv, 0, (float4*)share);
203 |
204 | RTile vt = sv.t();
205 | FTile ab_ut = vt % ak;
206 | for (int i = 0; i < WARPS; i++)
207 | ab_ut += state[warpi*WARPS+i] % from_warp(wa, i, (float4*)share);
208 | RTile ut = FTile(ab_ut % ab_inv);
209 |
210 | RTile qb = sum_warp<1,WARPS>((float2*)share, tril<0>(wq % bwi));
211 | RTile qk = sum_warp<1,WARPS>((float2*)share, tril<0>(wq % kwi));
212 |
213 | RTile dyt = sdy.t();
214 | FTile dut = FTile(dyt % transpose(qb));
215 | FTile dv = transpose(qk) % dyt;
216 | for (int i = 0; i < WARPS; i++) {
217 | RTile dstatei = dstate[i];
218 | dut += dstatei % from_warp(bwi*fw, i, (float4*)share);
219 | dv += from_warp(kwi*fw, i, (float4*)share) % dstatei;
220 | }
221 | RTile dab_ut = FTile(dut % transpose(ab_inv));
222 | dv += transpose(ak) % dab_ut;
223 |
224 | int off = bi*T*H*C + t*K*H*C + hi*C + warpi*16;
225 | GTile(dv_+off, stride) = RTile(dv);
226 |
227 | FTile dab = sum_warp<1,WARPS>((float2*)share, tril<1>(transpose(dab_ut) % transpose(ut)));
228 | FTile dak = sum_warp<1,WARPS>((float2*)share, tril<1>(transpose(dab_ut) % transpose(vt)));
229 | FTile dab_u_state0;
230 | dab_u_state0.zero_();
231 | for (int i = 0; i < WARPS; i++)
232 | dab_u_state0 += from_warp(transpose(dab_ut), i, (float4*)share) % state[i*WARPS+warpi].t();
233 |
234 | FTile da = dab_u_state0;
235 | da += dab % transpose(bwi);
236 | da += dak % transpose(kwi);
237 | da = non_incl_pref * da;
238 | GTile(da_+off, stride) = RTile(da);
239 |
240 | FTile dqb = sum_warp<1,WARPS>((float2*)share, tril<0>(transpose(dyt) % transpose(ut)));
241 | FTile dqk = sum_warp<1,WARPS>((float2*)share, tril<0>(transpose(dyt) % transpose(vt)));
242 | FTile dy_state0;
243 | dy_state0.zero_();
244 | for (int i = 0; i < WARPS; i++)
245 | dy_state0 += from_warp(transpose(dyt), i, (float4*)share) % state[i*WARPS+warpi].t();
246 |
247 | FTile dq = dy_state0;
248 | dq += dqb % transpose(bwi);
249 | dq += dqk % transpose(kwi);
250 | dq = incl_pref * dq;
251 | GTile(dq_+off, stride) = RTile(dq);
252 |
253 | RTile wqt = transpose(wq), wat = transpose(wa);
254 |
255 | FTile u_dstate, v_dstate, dw;
256 | u_dstate.zero_();
257 | v_dstate.zero_();
258 | dw.zero_();
259 | RTile ones;
260 | for (int i = 0; i < 4; i++) ones.data[i] = to_bf2({1.f,1.f});
261 | for (int i = 0; i < WARPS; i++) {
262 | int tid = threadIdx.x%32;
263 | if (warpi == i) {
264 | for (int j = 0; j < WARPS; j++) {
265 | RTile ra = dstate[j];
266 | ((float4*)share)[j*32+tid] = *((float4*)ra.data);
267 | }
268 | }
269 | RTile dstatei;// = dstate[i*WARPS+warpi];
270 | __syncthreads();
271 | *((float4*)dstatei.data) = ((float4*)share)[warpi*32+tid];
272 | __syncthreads();
273 | RTile dstatei_t = transpose(dstatei);
274 | v_dstate += from_warp(transpose(vt), i, (float4*)share) % dstatei_t;
275 | u_dstate += from_warp(transpose(ut), i, (float4*)share) % dstatei_t;
276 | dw += ones % transpose((RTile)state[i*WARPS+warpi]*dstatei);
277 | }
278 |
279 | FTile db = fw * u_dstate;
280 | db += transpose(dab) % wat;
281 | db += transpose(dqb) % wqt;
282 | db = inv_incl_pref * db;
283 | GTile(db_+off, stride) = RTile(db);
284 |
285 | FTile dk = fw * v_dstate;
286 | dk += transpose(dak) % wat;
287 | dk += transpose(dqk) % wqt;
288 | dk = inv_incl_pref * dk;
289 | GTile(dk_+off, stride) = RTile(dk);
290 |
291 | dw = fw * dw;
292 | dw += fast_dw<1>(dab,wa,bwi);
293 | dw += fast_dw<1>(dak,wa,kwi);
294 | dw += fast_dw<0>(dqb,wq,bwi);
295 | dw += fast_dw<0>(dqk,wq,kwi);
296 | FTile tmp;
297 | dw += cumsumv<0,0>(tmp = v_dstate*(fw*kwi));
298 | dw += cumsumv<0,0>(tmp = u_dstate*(fw*bwi));
299 | dw += cumsumv<0,1>(tmp = dab_u_state0*wa);
300 | dw += cumsumv<1,1>(tmp = dy_state0*wq);
301 |
302 | FTile dw_fac = (RTile)sw;
303 | apply_(dw_fac, [](float x) { return -__expf(x); });
304 | dw = dw * dw_fac;
305 | GTile(dw_+off, stride) = RTile(dw);
306 |
307 | __syncthreads();
308 | for (int i = 0; i < WARPS; i++) {
309 | FTile ndstate = dstate[i] * from_warp(fw, i, (float4*)share);
310 | ndstate += dyt % from_warp(wqt, i, (float4*)share);
311 | ndstate += dab_ut % from_warp(wat, i, (float4*)share);
312 | dstate[i] = ndstate;
313 | }
314 | __syncthreads();
315 | }
316 | for (int i = 0; i < WARPS; i++) {
317 | int off = bi*H*C*C + hi*C*C + warpi*16*C + i*16;
318 | GTile(ds0_+off, C) = dstate[i];
319 | }
320 | }
321 |
322 | void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, bf*s, bf*dsT, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da, bf*ds0) {
323 | assert(T%16 == 0);
324 | constexpr int tmp_size1 = sizeof(float4)*32*WARPS, tmp_size2 = sizeof(float)*16*16*2;
325 | constexpr int threads = 32*WARPS, shared_mem = sizeof(STile)*WARPS*bw_stages*(7+WARPS) + (tmp_size1 > tmp_size2 ? tmp_size1 : tmp_size2);
326 | static int reported = 0;
327 | if (!reported++) {
328 | #if defined VERBOSE
329 | printf("backward_kernel() uses %d bytes of (dynamic) shared memory\n", shared_mem);
330 | #endif
331 | cudaFuncAttributes attr;
332 | cudaFuncGetAttributes(&attr, backward_kernel);
333 | int cur_mem = attr.maxDynamicSharedSizeBytes;
334 | if (shared_mem > cur_mem) {
335 | #if defined VERBOSE
336 | printf("Increasing backward_kernel's MaxDynamicSharedMemorySize from %d to %d\n", cur_mem, shared_mem);
337 | #endif
338 | assert(!cudaFuncSetAttribute(backward_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem));
339 | }
340 | }
341 |     backward_kernel<<<dim3(H,B), dim3(threads), shared_mem>>>(T,H,w,q,k,v,z,a,dy,s,dsT,dw,dq,dk,dv,dz,da,ds0);
342 | }
343 |
--------------------------------------------------------------------------------
/train_gpt2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | with open(sys.argv[0]) as f:
4 | code = f.read() # read the code of this file ASAP, for logging
5 | import uuid
6 | import glob
7 | import time
8 | from dataclasses import dataclass
9 |
10 | import numpy as np
11 | import torch
12 | from torch import nn
13 | import torch.nn.functional as F
14 | import torch.distributed as dist
15 | import torch._inductor.config as config
16 | from torch.nn.parallel import DistributedDataParallel as DDP
17 |
18 | # -----------------------------------------------------------------------------
19 | # Muon optimizer
20 |
21 | def zeropower_via_svd(G, steps=None):
22 | U, S, V = G.svd()
23 | return U @ V.T
24 |
25 | @torch.compile
26 | def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
27 | r"""
28 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
29 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
30 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
31 | zero even beyond the point where the iteration no longer converges all the way to one everywhere
32 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
33 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
34 | performance at all relative to UV^T, where USV^T = G is the SVD.
35 | """
36 | assert len(G.shape) == 2
37 | a, b, c = (3.4445, -4.7750, 2.0315)
38 | X = G.bfloat16()
39 | X /= (X.norm() + eps) # ensure top singular value <= 1
40 | if G.size(0) > G.size(1):
41 | X = X.T
42 | for _ in range(steps):
43 | A = X @ X.T
44 | B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
45 | X = a * X + B @ X
46 | if G.size(0) > G.size(1):
47 | X = X.T
48 | return X
49 |
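# A quick sanity check of the behavior described in the docstring above (guarded
# behind a hypothetical NS_CHECK env var so it never runs during training): the
# singular values of the output should spread out roughly over the (0.5, 1.5)
# band described there rather than land exactly at 1.
if os.environ.get('NS_CHECK'):
    _G = torch.randn(1024, 768, device='cuda')
    _X = zeropower_via_newtonschulz5(_G, steps=5)
    print(torch.linalg.svdvals(_X.float()).aminmax())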
50 | zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5)
51 |
52 | class Muon(torch.optim.Optimizer):
53 | """
54 | Muon - MomentUm Orthogonalized by Newton-schulz
55 |
56 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
57 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
58 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
59 | the advantage that it can be stably run in bfloat16 on the GPU.
60 |
61 | Some warnings:
62 | - This optimizer assumes that all parameters passed in are 2D.
63 | - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
64 | parameters; those should all be optimized by a standard method (e.g., AdamW).
65 | - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
66 | - We believe it is unlikely to work well for training with small batch size.
67 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
68 | - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).
69 |
70 | Arguments:
71 | lr: The learning rate used by the internal SGD.
72 | momentum: The momentum used by the internal SGD.
73 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
74 | backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5')
75 | backend_steps: The number of iteration steps to use in the backend, if it is iterative.
76 | """
77 | def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True,
78 | backend='newtonschulz5', backend_steps=5):
79 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps)
80 | super().__init__(params, defaults)
81 |
82 | def step(self):
83 |
84 | for group in self.param_groups:
85 |
86 | lr = group['lr']
87 | momentum = group['momentum']
88 | zeropower_backend = zeropower_backends[group['backend']]
89 |
90 | # generate weight updates in distributed fashion
91 | total_params = sum(p.numel() for p in group['params'])
92 | updates_flat = torch.zeros(total_params, device='cuda', dtype=torch.bfloat16)
93 | curr_idx = 0
94 | for i, p in enumerate(group['params']):
95 | # luckily this will perfectly distribute a transformer with a multiple of 4 layers across 8 GPUs
96 | if i % int(os.environ['WORLD_SIZE']) == int(os.environ['RANK']):
97 | g = p.grad
98 | assert g is not None
99 | state = self.state[p]
100 | if 'momentum_buffer' not in state:
101 | state['momentum_buffer'] = torch.zeros_like(g)
102 | buf = state['momentum_buffer']
103 | buf.mul_(momentum).add_(g)
104 | if group['nesterov']:
105 | g = g.add(buf, alpha=momentum)
106 | g = zeropower_backend(g, steps=group['backend_steps'])
107 | g *= max(1, g.size(0)/g.size(1))**0.5
108 | updates_flat[curr_idx:curr_idx+p.numel()] = g.flatten()
109 | curr_idx += p.numel()
110 |
111 | # sync updates across devices. we are not memory-constrained, so this simple all-reduce of one flat bf16 buffer per param group is fine
112 | dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
113 |
114 | # deserialize and apply updates
115 | curr_idx = 0
116 | for p in group['params']:
117 | g = updates_flat[curr_idx:curr_idx+p.numel()].view_as(p.data).type_as(p.data)
118 | p.data.add_(g, alpha=-lr)
119 | curr_idx += p.numel()
120 |
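# Hedged notes on Muon.step() above (illustrative, not part of the training path):
# - step() reads RANK / WORLD_SIZE from the environment and all-reduces a flat update
#   buffer, so it assumes a torchrun-style launch with an initialized process group.
# - The max(1, m/n)**0.5 factor applied after orthogonalization appears to keep the
#   per-entry RMS of an (m, n) update at roughly 1/sqrt(n) regardless of the output
#   dimension m, since the orthogonalized matrix has entry RMS of about 1/sqrt(max(m, n)).
def _sketch_muon_shape_scale(m, n):
    return max(1, m / n) ** 0.5  # e.g. 2.0 for a (3072, 768) c_fc weight, 1.0 for square
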
121 | # -----------------------------------------------------------------------------
122 | # PyTorch nn.Module definitions for the GPT-2 model
123 |
124 | class Rotary(torch.nn.Module):
125 |
126 | def __init__(self, dim, base=10000):
127 | super().__init__()
128 | self.dim = dim
129 | self.base = base
130 | self.inv_freq = None
131 | self.seq_len_cached = None
132 | self.cos_cached = None
133 | self.sin_cached = None
134 |
135 | def forward(self, x):
136 | seq_len = x.shape[1]
137 | if seq_len != self.seq_len_cached:
138 | self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim))
139 | self.seq_len_cached = seq_len
140 | t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
141 | freqs = torch.outer(t, self.inv_freq)
142 | self.cos_cached = freqs.cos().bfloat16()
143 | self.sin_cached = freqs.sin().bfloat16()
144 | return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
145 |
146 | def apply_rotary_emb(x, cos, sin):
147 | assert x.ndim == 4 # multihead attention
148 | d = x.shape[3]//2
149 | x1 = x[..., :d]
150 | x2 = x[..., d:]
151 | y1 = x1 * cos + x2 * sin
152 | y2 = x1 * (-sin) + x2 * cos
153 | return torch.cat([y1, y2], 3).type_as(x)
154 |
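# Small illustrative helper (never called during training): apply_rotary_emb rotates each
# (x[..., i], x[..., i + d]) pair by a position-dependent angle, so up to the bfloat16
# precision of the cached cos/sin tables it preserves the norm of every head vector.
# The shapes below are assumptions matching this model's head_dim of 128.
def _sketch_rotary_preserves_norm(B=2, T=16, H=6, D=128):
    rotary = Rotary(D)
    x = torch.randn(B, T, H, D, device='cuda')
    cos, sin = rotary(x)
    y = apply_rotary_emb(x, cos, sin)
    return (x.norm(dim=-1) - y.norm(dim=-1)).abs().max().item()  # ~0 up to bf16 rounding
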
155 | class CastedLinear(nn.Linear):
156 | def forward(self, x):
157 | return F.linear(x, self.weight.to(x.dtype))
158 |
159 | class CausalSelfAttention(nn.Module):
160 |
161 | def __init__(self, config):
162 | super().__init__()
163 | self.n_head = config.n_head
164 | self.n_embd = config.n_embd
165 | self.head_dim = self.n_embd // self.n_head
166 | assert self.n_embd % self.n_head == 0
167 | self.c_q = CastedLinear(self.n_embd, self.n_embd, bias=False)
168 | self.c_k = CastedLinear(self.n_embd, self.n_embd, bias=False)
169 | self.c_v = CastedLinear(self.n_embd, self.n_embd, bias=False)
170 | # output projection
171 | self.c_proj = CastedLinear(self.n_embd, self.n_embd, bias=False)
172 | self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977
173 | self.rotary = Rotary(self.head_dim)
174 | self.lamb = nn.Parameter(torch.tensor(0.5)) # @Grad62304977
175 |
176 | def forward(self, x, v1=None):
177 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
178 | q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
179 | k = self.c_k(x).view(B, T, self.n_head, self.head_dim)
180 | v = self.c_v(x).view(B, T, self.n_head, self.head_dim)
181 | if v1 is None:
182 | v1 = v # This happens if we are in the first block. v needs to be accessed by subsequent blocks
183 | v = (1 - self.lamb) * v + self.lamb * v1.view_as(v) # @Grad62304977
184 | cos, sin = self.rotary(q)
185 | q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) # QK norm suggested by @Grad62304977
186 | q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
187 | y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
188 | y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side
189 | y = self.c_proj(y)
190 | return y, v1
191 |
192 | class MLP(nn.Module):
193 |
194 | def __init__(self, config):
195 | super().__init__()
196 | self.c_fc = CastedLinear(config.n_embd, 4 * config.n_embd, bias=False)
197 | self.c_proj = CastedLinear(4 * config.n_embd, config.n_embd, bias=False)
198 | self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977
199 |
200 | def forward(self, x):
201 | x = self.c_fc(x)
202 | x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
203 | x = self.c_proj(x)
204 | return x
205 |
206 | class Block(nn.Module):
207 |
208 | def __init__(self, config):
209 | super().__init__()
210 | self.attn = CausalSelfAttention(config)
211 | self.mlp = MLP(config)
212 | self.lambdas = nn.Parameter(torch.tensor([1., 0.]))
213 |
214 | def forward(self, x, v1, x0):
215 | x = self.lambdas[0] * x + self.lambdas[1] * x0
216 | x1, v1 = self.attn(F.rms_norm(x, (x.size(-1),)), v1)
217 | x = x + x1
218 | x = x + self.mlp(F.rms_norm(x, (x.size(-1),)))
219 | return x, v1
220 |
221 | # -----------------------------------------------------------------------------
222 | # The main GPT-2 model
223 |
224 | @dataclass
225 | class GPTConfig:
226 | vocab_size : int = 50304
227 | n_layer : int = 12
228 | n_head : int = 6 # head dim 128 suggested by @Grad62304977
229 | n_embd : int = 768
230 |
231 | class GPT(nn.Module):
232 |
233 | def __init__(self, config):
234 | super().__init__()
235 | self.config = config
236 |
237 | self.transformer = nn.ModuleDict(dict(
238 | wte = nn.Embedding(config.vocab_size, config.n_embd),
239 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
240 | ))
241 | self.lm_head = CastedLinear(config.n_embd, config.vocab_size, bias=False)
242 | self.lm_head.weight.data.zero_() # @Grad62304977
243 |
244 | def forward(self, idx, target):
245 |
246 | # forward the GPT model itself
247 | x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
248 | x = F.rms_norm(x, (x.size(-1),)) # @Grad62304977
249 | x0 = x
250 | v1 = None
251 | for block in self.transformer.h:
252 | x, v1 = block(x, v1, x0)
253 | x = F.rms_norm(x, (x.size(-1),))
254 |
255 | logits = self.lm_head(x)
256 | logits = 30 * torch.tanh(logits / 30) # @Grad62304977
257 | logits = logits.float()
258 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target.view(-1))
259 | return loss.float()
260 |
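# Hedged aside on the 30 * torch.tanh(logits / 30) line above (helper is illustrative and
# never called): it soft-caps every logit to (-30, 30) while staying near the identity for
# small values, and since tanh is monotone the ranking of logits is unchanged.
def _sketch_logit_softcap(logit_value: float) -> float:
    # e.g. _sketch_logit_softcap(1.0) ~ 0.9996, _sketch_logit_softcap(100.0) ~ 29.92
    return float(30 * torch.tanh(torch.tensor(logit_value / 30.0)))
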
261 | # -----------------------------------------------------------------------------
262 | # Our own simple Distributed Data Loader
263 |
264 | def _peek_data_shard(filename):
265 | # only reads the header, returns header data
266 | with open(filename, "rb") as f:
267 | # first read the header, which is 256 int32 integers (4 bytes each)
268 | header = np.frombuffer(f.read(256*4), dtype=np.int32)
269 | if header[0] != 20240520:
270 | print("ERROR: magic number mismatch in the data .bin file!")
271 | print("---> HINT: Are you passing in a correct file with --input_bin?")
272 | print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README")
273 | print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try")
274 | exit(1)
275 | assert header[1] == 1, "unsupported version"
276 | ntok = header[2] # number of tokens (claimed)
277 | return ntok # for now just return the number of tokens
278 |
279 | def _load_data_shard(filename):
280 | with open(filename, "rb") as f:
281 | # first read the header, which is 256 int32 integers (4 bytes each)
282 | header = np.frombuffer(f.read(256*4), dtype=np.int32)
283 | assert header[0] == 20240520, "magic number mismatch in the data .bin file"
284 | assert header[1] == 1, "unsupported version"
285 | ntok = header[2] # number of tokens (claimed)
286 | # the rest of it are tokens, stored as uint16
287 | tokens = np.frombuffer(f.read(), dtype=np.uint16)
288 | assert len(tokens) == ntok, "number of tokens read does not match header?"
289 | return tokens
290 |
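# Hedged sketch of the on-disk shard format the two readers above expect, reconstructed
# from their asserts rather than taken from the repo's actual data scripts: a 256-entry
# int32 header (magic 20240520, version 1, token count), then the tokens as uint16.
# Illustrative only; never called in this file.
def _sketch_write_data_shard(filename, tokens_uint16):
    header = np.zeros(256, dtype=np.int32)
    header[0] = 20240520             # magic number checked by _peek/_load_data_shard
    header[1] = 1                    # version
    header[2] = len(tokens_uint16)   # claimed token count
    with open(filename, "wb") as f:
        f.write(header.tobytes())
        f.write(np.asarray(tokens_uint16, dtype=np.uint16).tobytes())
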
291 | class DistributedDataLoader:
292 | def __init__(self, filename_pattern, B, T, process_rank, num_processes):
293 | self.process_rank = process_rank
294 | self.num_processes = num_processes
295 | self.B = B
296 | self.T = T
297 |
298 | # glob files that match the pattern
299 | self.files = sorted(glob.glob(filename_pattern))
300 | assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"
301 |
302 | # load and validate all data shards, count number of tokens in total
303 | ntok_total = 0
304 | for fname in self.files:
305 | shard_ntok = _peek_data_shard(fname)
306 | assert shard_ntok >= num_processes * B * T + 1
307 | ntok_total += int(shard_ntok)
308 | self.ntok_total = ntok_total
309 |
310 | # kick things off
311 | self.reset()
312 |
313 | def reset(self):
314 | self.current_shard = 0
315 | self.current_position = self.process_rank * self.B * self.T
316 | self.tokens = _load_data_shard(self.files[self.current_shard])
317 |
318 | def advance(self): # advance to next data shard
319 | self.current_shard = (self.current_shard + 1) % len(self.files)
320 | self.current_position = self.process_rank * self.B * self.T
321 | self.tokens = _load_data_shard(self.files[self.current_shard])
322 |
323 | def next_batch(self):
324 | B = self.B
325 | T = self.T
326 | buf = self.tokens[self.current_position : self.current_position+B*T+1]
327 | buf = torch.tensor(buf.astype(np.int32), dtype=torch.long)
328 | x = (buf[:-1]).view(B, T) # inputs
329 | y = (buf[1:]).view(B, T) # targets
330 | # advance current position and load next shard if necessary
331 | self.current_position += B * T * self.num_processes
332 | if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
333 | self.advance()
334 | return x.cuda(), y.cuda()
335 |
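# Worked example of how ranks stride through a shard (numbers assume the 8-GPU launch this
# script appears to target, with B=64 sequences of T=1024 tokens per device):
#   rank r starts at position r * B * T = r * 65,536 tokens and reads B*T + 1 tokens, so the
#   8 ranks cover disjoint windows of one 524,288-token global batch; every next_batch()
#   then advances all ranks by B * T * num_processes = 524,288 tokens, and the loader hops
#   to the next shard once the upcoming window would run past the end of the current file.
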
336 | # -----------------------------------------------------------------------------
337 | # int main
338 |
339 | @dataclass
340 | class Hyperparameters:
341 | # data hyperparams
342 | input_bin : str = 'data/fineweb10B/fineweb_train_*.bin' # input .bin to train on
343 | input_val_bin : str = 'data/fineweb10B/fineweb_val_*.bin' # input .bin to eval validation loss on
344 | # optimization hyperparams
345 | batch_size : int = 8*64 # batch size, in sequences, across all devices
346 | device_batch_size : int = 64 # batch size, in sequences, per device
347 | sequence_length : int = 1024 # sequence length, in tokens
348 | num_iterations : int = 3242 # number of iterations to run
349 | warmup_iters : int = 0
350 | warmdown_iters : int = 926 # number of iterations of linear warmdown at the end of training (the decay leg of the trapezoidal schedule)
351 | weight_decay : float = 0
352 | # evaluation and logging hyperparams
353 | val_loss_every : int = 125 # every how many steps to evaluate val loss? 0 for only at the end
354 | val_tokens : int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
355 | save_every : int = 0 # every how many steps to save the checkpoint? 0 for only at the end
356 | args = Hyperparameters()
357 |
358 | # set up DDP (distributed data parallel). torchrun sets this env variable
359 | assert torch.cuda.is_available()
360 | dist.init_process_group(backend='nccl')
361 | ddp_rank = int(os.environ['RANK'])
362 | ddp_local_rank = int(os.environ['LOCAL_RANK'])
363 | ddp_world_size = int(os.environ['WORLD_SIZE'])
364 | device = f'cuda:{ddp_local_rank}'
365 | torch.cuda.set_device(device)
366 | print(f"using device: {device}")
367 | master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.
368 |
369 | # begin logging
370 | logfile = None
371 | if master_process:
372 | run_id = str(uuid.uuid4())
373 | logdir = 'logs/%s/' % run_id
374 | os.makedirs(logdir, exist_ok=True)
375 | logfile = 'logs/%s.txt' % run_id
376 | # create the log file
377 | with open(logfile, "w") as f:
378 | # begin the log by printing this file (the Python code)
379 | f.write('='*100 + '\n')
380 | f.write(code)
381 | f.write('='*100 + '\n')
382 | def print0(s, logonly=False):
383 | if master_process:
384 | with open(logfile, "a") as f:
385 | if not logonly:
386 | print(s)
387 | f.write(s+'\n')
388 | # log information about the hardware/software environment this is running on
389 | # and print the full `nvidia-smi` to file
390 | print0(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:")
391 | import subprocess
392 | result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
393 | print0(f'{result.stdout}', logonly=True)
394 | print0('='*100, logonly=True)
395 |
396 | # convenience variables
397 | B, T = args.device_batch_size, args.sequence_length
398 | # calculate the number of steps to take in the val loop.
399 | assert args.val_tokens % (B * T * ddp_world_size) == 0
400 | val_steps = args.val_tokens // (B * T * ddp_world_size)
401 | # calculate the steps of gradient accumulation required to attain the desired global batch size.
402 | assert args.batch_size % (B * ddp_world_size) == 0
403 | train_accumulation_steps = args.batch_size // (B * ddp_world_size)
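# Worked numbers, assuming the 8-GPU launch the defaults imply (B=64, T=1024, world_size=8):
#   val_steps = 10,485,760 / (64 * 1024 * 8) = 20 validation iterations
#   train_accumulation_steps = 512 / (64 * 8) = 1, i.e. no gradient accumulation needed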
404 |
405 | # load tokens
406 | train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
407 | val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)
408 | print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
409 | print0(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
410 | print0('='*100, logonly=True)
411 | x, y = train_loader.next_batch()
412 |
413 | # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. suggested to me by @Grad62304977.
414 | # this originates from Karpathy's experiments.
415 | num_vocab = 50304
416 | model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=6, n_embd=768))
417 | model = model.cuda().bfloat16()
418 | for m in model.modules():
419 | if isinstance(m, CastedLinear):
420 | m.float()
421 | if hasattr(config, "coordinate_descent_tuning"):
422 | config.coordinate_descent_tuning = True # suggested by @Chillee
423 | model = torch.compile(model)
424 | # here we wrap model into DDP container
425 | model = DDP(model, device_ids=[ddp_local_rank])
426 | raw_model = model.module # always contains the "raw" unwrapped model
427 |
428 | # CUDNN attention is ~4ms faster than Flash, but doesn't get selected by default in PyTorch 2.5.1
429 | from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
430 | enable_cudnn_sdp(True)
431 | enable_flash_sdp(False)
432 | enable_mem_efficient_sdp(False)
433 | enable_math_sdp(False)
434 |
435 | # init the optimizer(s)
436 | optimizer1 = torch.optim.Adam([raw_model.transformer.wte.weight], lr=0.3, betas=(0.9, 0.95), fused=True)
437 | optimizer2 = torch.optim.Adam([raw_model.lm_head.weight], lr=0.002, betas=(0.9, 0.95), fused=True)
438 | params = list(raw_model.transformer.h.parameters())
439 | matrix_params = [p for p in params if p.ndim == 2]
440 | scalar_params = [p for p in params if p.ndim < 2]
441 | optimizer3 = Muon(matrix_params, lr=0.02, momentum=0.95)
442 | optimizer4 = torch.optim.Adam(scalar_params, lr=0.02, betas=(0.9, 0.95), fused=True) # note: this learning rate is not sensitive and has not been tuned
443 | optimizers = [optimizer1, optimizer2, optimizer3, optimizer4]
444 | # learning rate decay scheduler (linear warmup and warmdown)
445 | def get_lr(it):
446 | assert it <= args.num_iterations
447 | # 1) linear warmup for warmup_iters steps
448 | if it < args.warmup_iters:
449 | return (it+1) / args.warmup_iters
450 | # 2) constant lr for a while
451 | elif it < args.num_iterations - args.warmdown_iters:
452 | return 1.0
453 | # 3) linear warmdown
454 | else:
455 | decay_ratio = (args.num_iterations - it) / args.warmdown_iters
456 | return decay_ratio
457 | schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]
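# Worked multipliers for the schedule above, using the defaults (warmup_iters=0,
# num_iterations=3242, warmdown_iters=926); each optimizer's base lr is scaled by:
#   steps 0 .. 2315 -> 1.0                        (constant plateau; no warmup leg)
#   step  2779      -> (3242 - 2779) / 926 = 0.5  (halfway through the linear warmdown)
#   step  3242      -> 0.0                        (final step)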
458 |
459 | # Start training loop
460 | training_time_ms = 0
461 | # start the clock
462 | torch.cuda.synchronize()
463 | t0 = time.time()
464 | # begin training
465 | train_loader.reset()
466 | for step in range(args.num_iterations + 1):
467 | last_step = (step == args.num_iterations)
468 | # This effectively ignores the timing of the first 10 steps, which are slower due to compilation and other warmup overhead.
469 | # Alternatively, and slightly more correctly in terms of benchmarking, we could do 10
470 | # steps with dummy data first, and then re-initialize the model and reset the loader.
471 | if step == 10:
472 | training_time_ms = 0
473 | t0 = time.time()
474 | timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val
475 |
476 | # once in a while evaluate the validation dataset
477 | if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)):
478 | # stop the clock
479 | torch.cuda.synchronize()
480 | training_time_ms += 1000 * (time.time() - t0)
481 | # run validation batches
482 | model.eval()
483 | val_loader.reset()
484 | val_loss = 0.0
485 | for _ in range(val_steps):
486 | with torch.no_grad():
487 | x_val, y_val = val_loader.next_batch()
488 | val_loss += model(x_val, y_val)
489 | dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
490 | val_loss /= val_steps
491 | # log val loss to console and to logfile
492 | print0(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms')
493 | # start the clock again
494 | torch.cuda.synchronize()
495 | t0 = time.time()
496 |
497 | if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)):
498 | # stop the clock
499 | torch.cuda.synchronize()
500 | training_time_ms += 1000 * (time.time() - t0)
501 | # save the state of the training process
502 | log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
503 | torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step))
504 | # start the clock again
505 | torch.cuda.synchronize()
506 | t0 = time.time()
507 |
508 | # A bit confusing: we want to make sure to eval on the 0th iteration but also after the
509 | # very last iteration, so we loop for step <= num_iterations instead of just < num_iterations
510 | # (one extra iteration due to <=), do the validation/checkpointing one last time,
511 | # and then break right here since we're done.
512 | if last_step:
513 | break
514 |
515 | # --------------- TRAINING SECTION BEGIN -----------------
516 | model.train()
517 | for i in range(1, train_accumulation_steps+1):
518 | # forward pass
519 | loss = model(x, y)
520 | train_loss = loss.detach()
521 | # advance the dataset for the next batch
522 | x, y = train_loader.next_batch()
523 | # backward pass
524 | if i < train_accumulation_steps:
525 | with model.no_sync(): # there's no need to sync gradients every accumulation step
526 | loss.backward()
527 | else:
528 | loss.backward() # just sync on the last step
529 | for p in model.parameters():
530 | p.grad /= train_accumulation_steps
531 | # momentum warmup for Muon
532 | frac = min(step/500, 1)
533 | optimizer3.param_groups[0]['momentum'] = (1 - frac) * 0.85 + frac * 0.95
534 | # step the optimizers and schedulers
535 | for opt, sched in zip(optimizers, schedulers):
536 | opt.step()
537 | sched.step()
538 | # null the gradients
539 | model.zero_grad(set_to_none=True)
540 | # --------------- TRAINING SECTION END -------------------
541 | # everything that follows now is just diagnostics, prints, logging, etc.
542 |
543 | #dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower
544 | approx_time = training_time_ms + 1000 * (time.time() - t0)
545 | print0(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms")
546 |
547 | if master_process:
548 | print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
549 |
550 | # -------------------------------------------------------------------------
551 | # clean up nice
552 | dist.destroy_process_group()
553 |
--------------------------------------------------------------------------------