├── README.md
├── annotated_examples
│   ├── classics
│   │   ├── cross_entropy.py
│   │   ├── geglu.py
│   │   ├── layernorm.py
│   │   ├── matmul.py
│   │   └── swiglu.py
│   └── gemlite
│       ├── gemm.py
│       └── gemm_splitK.py
├── assets
│   └── logo.png
└── backprop_math
    ├── cross_entropy.md
    ├── geglu.md
    ├── layernorm.md
    └── swiglu.md
/README.md:
--------------------------------------------------------------------------------
1 | # Triton Academy
2 |
3 |
4 |
5 |
6 |
7 |
8 | Triton Academy is an initiative to educate developers on writing efficient GPU kernels using Triton. It is a work in progress, and its goal is to provide:
9 |
10 | - 🎓 Tutorials – Learn Triton step by step
11 | - 🔍 Kernels – Detailed explanations of well-known Triton kernels
12 | - 📐 Math – Formulas and proofs behind backpropagation in kernels
13 | - 📖 Documentation & Debugging – Official docs and best practices
14 | - 🔬 Benchmarks – Performance comparisons with CUDA
15 |
16 | ## What is Triton?
17 |
18 | Triton is an open-source programming language and compiler designed specifically for GPU programming. It aims to simplify the development of efficient GPU kernels by providing a higher-level abstraction than CUDA or other low-level GPU programming models.
19 |
20 | Triton enables developers to write high-performance GPU code with Python-like syntax while automatically handling many low-level optimizations that would otherwise require significant expertise in GPU architecture. It was developed by OpenAI and is now widely used in machine learning and scientific computing applications.
21 |
22 | ## Why Triton instead of CUDA?
23 |
24 | While CUDA offers fine-grained control over GPU programming, it comes with several challenges:
25 |
26 | 1. **Steep Learning Curve**: CUDA requires understanding complex GPU architecture details
27 | 2. **Verbose Code**: Simple operations often require extensive boilerplate code
28 | 3. **Manual Optimization**: Developers must manually handle memory coalescing, tiling, and other optimizations
29 | 4. **Limited Portability**: CUDA is specific to NVIDIA GPUs
30 |
31 | Triton addresses these issues by:
32 |
33 | 1. **Higher Abstraction**: Provides intuitive programming constructs
34 | 2. **Automatic Optimization**: Handles many low-level optimizations automatically
35 | 3. **Python Integration**: Seamlessly integrates with the Python ecosystem
36 | 4. **Performance**: Achieves performance comparable to hand-optimized CUDA in many cases
37 | 5. **Readability**: More maintainable code that clearly expresses intent
38 |
39 | ## Quick Example: Vector Addition in Triton
40 |
41 | Here is a minimal vector addition kernel that shows the basic structure of a Triton program:
42 |
43 | ```python
44 | import torch
45 | import triton
46 | import triton.language as tl
47 |
48 | @triton.jit
49 | def vector_add_kernel(X, Y, Z, N: tl.constexpr):
50 | pid = tl.program_id(axis=0) # Get the program ID
51 | block_size = 128 # Number of elements per block
52 | offsets = pid * block_size + tl.arange(0, block_size) # Compute memory offsets
53 | mask = offsets < N # Ensure we don’t go out of bounds
54 | x = tl.load(X + offsets, mask=mask) # Load X
55 | y = tl.load(Y + offsets, mask=mask) # Load Y
56 | tl.store(Z + offsets, x + y, mask=mask) # Store the result
57 |
58 | # Example usage
59 | N = 1024
60 | X = torch.randn(N, device="cuda")
61 | Y = torch.randn(N, device="cuda")
62 | Z = torch.empty(N, device="cuda")
63 |
64 | grid = (triton.cdiv(N, 128),)  # one program per block of 128 elements
65 | vector_add_kernel[grid](X, Y, Z, N=N)
66 | print(Z) # Output: X + Y
67 | ```
68 |
69 | Explanation:
70 | - `tl.program_id(axis=0)` → gets this program's index in the launch grid
71 | - `tl.arange(0, block_size)` → generates the per-element offsets within a block
72 | - `tl.load` / `tl.store` → perform masked reads and writes to global memory
73 |
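In practice, the block size is usually exposed as a `tl.constexpr` meta-parameter and the grid is computed with `triton.cdiv`, which is the pattern used throughout the annotated examples in this repository. A minimal sketch of the same vector addition written that way (the kernel name and `BLOCK_SIZE` value are just illustrative):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def vector_add_kernel_v2(X, Y, Z, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N  # guard the final, possibly partial block
    x = tl.load(X + offsets, mask=mask)
    y = tl.load(Y + offsets, mask=mask)
    tl.store(Z + offsets, x + y, mask=mask)

N = 1000  # deliberately not a multiple of the block size
X, Y = torch.randn(N, device="cuda"), torch.randn(N, device="cuda")
Z = torch.empty(N, device="cuda")
grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
vector_add_kernel_v2[grid](X, Y, Z, N, BLOCK_SIZE=128)
assert torch.allclose(Z, X + Y)
```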
74 | ## Resources
75 | - [CUDA Programming Course – High-Performance Computing with GPUs](https://www.youtube.com/watch?v=86FAWCzIe_4&t=30156s&pp=ygUNdHJpdG9uIGNvdXJzZQ%3D%3D)
76 | - [Triton Tutorials](https://triton-lang.org/main/getting-started/tutorials/index.html)
77 | - [Triton Puzzles](https://github.com/srush/Triton-Puzzles)
79 |
--------------------------------------------------------------------------------
/annotated_examples/classics/cross_entropy.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2 | # Modifications Copyright 2025 Mekkcyber.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import triton
17 | import triton.language as tl
18 | import torch
20 |
21 | from triton.language.extra import libdevice
22 | triton_tanh = libdevice.tanh
23 | triton_cast = tl.cast
24 | MAX_FUSED_SIZE : int = 65536
25 | next_power_of_2 = triton.next_power_of_2
26 |
27 | def calculate_settings(n : int) -> (int, int,):
28 | BLOCK_SIZE : int = next_power_of_2(n)
29 | if BLOCK_SIZE > MAX_FUSED_SIZE:
30 | raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
31 | f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
32 | num_warps : int = 4
33 | if BLOCK_SIZE >= 32768: num_warps = 32
34 | elif BLOCK_SIZE >= 8192: num_warps = 16
35 | elif BLOCK_SIZE >= 2048: num_warps = 8
36 | return BLOCK_SIZE, num_warps
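# Example (illustrative): for a Llama-style vocabulary of 32000 tokens, next_power_of_2 gives
# BLOCK_SIZE = 32768, which selects num_warps = 32.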
37 |
38 | @triton.jit
39 | def _cross_entropy_forward(
40 | logits_ptr , # Pointer to logits tensor [batch*seq_len, vocab_size]
41 | logits_row_stride , # Stride for accessing rows in logits
42 | loss_ptr , # Pointer to output loss values
43 | logsumexp_ptr , # Pointer to store logsumexp values (needed for backward)
44 | labels_ptr , # Pointer to label indices
45 | VOCAB_SIZE , # Size of vocabulary
46 | BLOCK_SIZE : tl.constexpr, # Block size for parallel processing
47 | DO_SOFTCAPPING , # Flag for logit softcapping (e.g., for Gemma 2)
48 | SOFTCAP , # Softcapping parameter value
49 | DO_LOGIT_SCALING , # Flag for logit scaling (e.g., for Cohere models)
50 | LOGIT_SCALE , # Scaling factor for logits
51 | ):
52 | """
53 | Computes cross-entropy loss in a numerically stable way.
54 |
55 | Cross Entropy Loss Formula:
56 | CE = -∑(y_i * log(p_i)) where p_i = softmax(x_i) = exp(x_i) / ∑exp(x_j)
57 |
58 | For one-hot labels (our case), this simplifies to:
59 | CE = -log(p_correct) = -(logit_correct - logsumexp)
60 | = logsumexp - logit_correct
61 |
62 | Numerical Stability:
63 | We use the LogSumExp trick for numerical stability:
64 | logsumexp(x) = max(x) + log(∑exp(x - max(x)))
65 |
66 | This prevents overflow by ensuring the largest exponentiated term is exp(0.0) = 1.0.
67 |
68 | Special handling:
69 | - If label == -100: loss = 0 (ignore token, e.g., padding)
70 | - Otherwise: loss = logsumexp - logit_correct
71 | """
72 | # Get current row index from the program ID, every block thread will handle a different row
73 | row_idx = tl.program_id(0)
74 |
75 | # Offset pointers to the current row
76 | # we cast to tl.int64 to avoid overflow because vocab sizes are large
77 | logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
78 | # each row corresponds to a different token in the sequence, hence a loss, logsumexp and label
79 | loss_ptr += row_idx
80 | logsumexp_ptr += row_idx
81 | labels_ptr += row_idx
82 |
83 | # Create offsets for accessing columns in parallel
84 | col_offsets = tl.arange(0, BLOCK_SIZE)
85 | # Create mask for valid vocabulary indices
86 | mask = col_offsets < VOCAB_SIZE
87 |
88 | # Load the label index for this row
89 | label_idx = tl.load(labels_ptr).to(tl.int32)
90 | # Load logits for this row, masking invalid indices
91 | # we mask invalid indices to -infinity to ensure they don't contribute to the sum (exp(-infinity) = 0)
92 | logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
93 |
94 | # Apply logit scaling if enabled: x → t*x (t = LOGIT_SCALE)
95 | # This scales the logits before softmax, affecting the "temperature" of the distribution
96 | # Higher values (t > 1) make the distribution more uniform/smoother
97 | # Lower values (0 < t < 1) make the distribution more peaked/confident
98 | # Logit scaling was introduced in models like Cohere Command and Claude to control
99 | # the model's confidence in its predictions. It helps prevent overconfidence and
100 | # can improve model calibration, especially in out-of-distribution scenarios.
101 | # Unlike temperature sampling at inference time, this scaling is applied during training.
102 | if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
103 |
104 | # Apply logit softcapping if enabled: x → t*tanh(x/t) (t = SOFTCAP)
105 | # This bounds logits to [-t, t] range, preventing extreme values
106 | # Softcapping was introduced in models like Gemma 2 to improve training stability
107 | # by preventing logits from growing too large, which can cause:
108 | # 1. Numerical instability in softmax computation
109 | # 2. Overconfident predictions leading to poor generalization
110 | # 3. Gradient explosion during backpropagation
111 | # Unlike simple clipping, tanh-based softcapping maintains differentiability
112 | # and allows gradients to flow even for extreme values, just at a reduced magnitude.
113 | if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
114 |
115 | # Compute logsumexp in a numerically stable way
116 | # First find the maximum logit value
117 | c = tl.max(logits, 0)
118 | # Then compute logsumexp = max + log(sum(exp(logits - max)))
119 | logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
120 |
121 | # Compute loss only if label is valid (not -100)
122 | if label_idx != -100:
123 | # Load the logit for the correct class
124 | x = tl.load(logits_ptr + label_idx).to(tl.float32)
125 |
126 | # Apply the same transformations to the target logit
127 | if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
128 | if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
129 |
130 | # Compute cross entropy: logsumexp - correct_logit
131 | # This is equivalent to -log(softmax(correct_logit))
132 | loss = logsumexp - x
133 | else:
134 | # For padding tokens (label_idx == -100), set loss to 0
135 | loss = 0.0
136 |
137 | # Store results for this row
138 | tl.store(logsumexp_ptr, logsumexp) # Save logsumexp for backward pass
139 | tl.store(loss_ptr, loss) # Save the computed loss
140 |
141 | @triton.jit
142 | def _cross_entropy_backward(
143 | logits_ptr , # Pointer to input logits
144 | logits_row_stride , # Stride between rows in logits tensor
145 | dloss_ptr , # Pointer to gradient of loss w.r.t output
146 | dloss_row_stride , # Stride between rows in dloss tensor
147 | logsumexp_ptr , # Pointer to precomputed logsumexp values
148 | labels_ptr , # Pointer to target labels
149 | VOCAB_SIZE , # Size of vocabulary (number of classes)
150 | BLOCK_SIZE : tl.constexpr, # Size of processing block
151 | DO_SOFTCAPPING , # Whether to apply softcapping
152 | SOFTCAP , # Softcapping parameter value
153 | DO_LOGIT_SCALING , # Whether to apply logit scaling
154 | LOGIT_SCALE , # Logit scaling parameter value
155 | ):
156 | """
157 | Backward pass for cross entropy loss.
158 |
159 | Cross Entropy Loss: CE(x, class) = -log(softmax(x)[class])
160 | = -log(exp(x_class) / sum(exp(x_i)))
161 | = -x_class + log(sum(exp(x_i)))
162 |
163 | For the backward pass, we need to compute gradients w.r.t. each logit.
164 |
165 | Let L = CE(x, class) and z = log(sum(exp(x_i))) (logsumexp)
166 |
167 | For the correct class (i = class):
168 | dL/dx_class = d/dx_class(-x_class + z) = -1 + exp(x_class - z) = -1 + softmax(x_class) (check backprop_math/cross_entropy.md)
169 |
170 | For other classes (i ≠ class):
171 | dL/dx_i = d/dx_i(-x_class + z) = d/dx_i(z) = exp(x_i - z) = softmax(x_i) (check backprop_math/cross_entropy.md)
172 |
173 | When logit transformations are applied, we use the chain rule to compute gradients.
174 | """
175 | # Get current row and block indices
176 | row_idx = tl.program_id(0)
177 |
178 | # Calculate pointers for current row
179 | logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
180 | dloss_ptr += row_idx * dloss_row_stride
181 |
182 | # Calculate column offsets for current block
183 | col_offsets = tl.arange(0, BLOCK_SIZE)
184 | # Create mask for valid vocabulary indices
185 | mask = col_offsets < VOCAB_SIZE
186 |
187 | # Load the target label for current row
188 | label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
189 |
190 | # Load gradient of loss w.r.t output
191 | # For padding tokens (label_idx == -100), set gradient to 0
192 | if label_idx != -100:
193 | dloss = tl.load(dloss_ptr)
194 | else:
195 | dloss = 0.0
196 |
197 | # Load logits for current row
198 | x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
199 |
200 | # Apply logit scaling if enabled
201 | # If x is scaled as x' = s*x in forward, then dx'/dx = s
202 | if DO_LOGIT_SCALING:
203 | x = x * LOGIT_SCALE
204 |
205 | # Store original values before softcapping for gradient calculation
206 | # For softcapping: x' = t*tanh(x/t), we keep tanh(x/t) because it is reused in the backward chain rule
207 | tanh_term = x
208 | if DO_SOFTCAPPING:
209 | # Apply softcapping: x' = t*tanh(x/t)
210 | tanh_term = triton_tanh(x / SOFTCAP) # Store tanh(x/t) for gradient calculation
211 | x = SOFTCAP * tanh_term # This is the softcapped value
212 |
213 | logsumexp = tl.load(logsumexp_ptr + row_idx)
214 |
215 | # Compute softmax: exp(x - logsumexp) = softmax(x) for the whole row
216 | # This gives us part of the gradient formula
217 | y = tl.exp(x - logsumexp)
218 |
219 | # Adjust gradient for the target class
220 | # For i = target: gradient = softmax(x_i) - 1
221 | # For i ≠ target: gradient = softmax(x_i)
222 | y = tl.where(
223 | col_offsets == label_idx,
224 | y - 1.0, # For target class: exp(x - logsumexp) - 1
225 | y, # For other classes: exp(x - logsumexp)
226 | )
227 |
228 | # Apply chain rule for logit scaling
229 | # If x' = s*x, then dL/dx = dL/dx' * dx'/dx = dL/dx' * s
230 | if DO_LOGIT_SCALING:
231 | y = y * LOGIT_SCALE
232 |
233 | # Apply chain rule for softcapping
234 | # For x' = t*tanh(x/t), dx'/dx = 1 - tanh²(x/t)
235 | # This is the derivative of tanh: d/dx[tanh(x)] = 1 - tanh²(x)
236 | if DO_SOFTCAPPING:
237 | y = y * (1.0 - tanh_term*tanh_term) # tanh_term = tanh(x/t)
238 |
239 | # Store final gradients
240 | # For padding tokens (label_idx == -100), gradient is 0
241 | tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
242 |
243 | class Fast_CrossEntropyLoss(torch.autograd.Function):
244 | @staticmethod
245 | def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : float = 0):
246 | n_rows : int
247 | vocab_size : int
248 | n_rows, vocab_size = logits.shape
249 |
250 | losses = torch.empty(n_rows, dtype = torch.float32, device = logits.device)
251 |
252 | DO_SOFTCAPPING : bool = bool(logit_softcapping != 0)
253 | DO_LOGIT_SCALING : bool = bool(logit_scaling != 0)
254 |
255 | BLOCK_SIZE : int
256 | num_warps : int
257 | # For small vocabs <= 65536 like Llama, Mistral
258 | BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
259 | logsumexp = torch.empty(n_rows, dtype = torch.float32, device = logits.device)
260 |
261 | _cross_entropy_forward[(n_rows,)](
262 | logits, logits.stride(0),
263 | losses,
264 | logsumexp,
265 | labels,
266 | VOCAB_SIZE = vocab_size,
267 | BLOCK_SIZE = BLOCK_SIZE,
268 | DO_SOFTCAPPING = DO_SOFTCAPPING,
269 | SOFTCAP = logit_softcapping,
270 | DO_LOGIT_SCALING = DO_LOGIT_SCALING,
271 | LOGIT_SCALE = logit_scaling,
272 | num_warps = num_warps,
273 | )
274 |
275 | ctx.save_for_backward(logits, logsumexp, labels)
276 | ctx.DO_SOFTCAPPING = DO_SOFTCAPPING
277 | ctx.logit_softcapping = logit_softcapping
278 | ctx.DO_LOGIT_SCALING = DO_LOGIT_SCALING
279 | ctx.logit_scaling = logit_scaling
280 | return losses
281 | pass
282 |
283 |
284 | @staticmethod
285 | def backward(ctx, dlosses):
286 | logits, logsumexp, labels = ctx.saved_tensors
287 | n_rows : int
288 | vocab_size : int
289 | n_rows, vocab_size = logits.shape
290 |
291 | BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
292 |
293 | _cross_entropy_backward[(n_rows,)](
294 | logits, logits.stride(0),
295 | dlosses, dlosses.stride(0),
296 | logsumexp,
297 | labels,
298 | VOCAB_SIZE = vocab_size,
299 | BLOCK_SIZE = BLOCK_SIZE,
300 | DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
301 | SOFTCAP = ctx.logit_softcapping,
302 | DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING,
303 | LOGIT_SCALE = ctx.logit_scaling,
304 | num_warps = num_warps,
305 | )
306 | return logits, None, None, None
307 |
308 |
309 | def fast_cross_entropy_loss(logits, labels, logit_softcapping=0, logit_scaling=0, n_items=None):
310 | """
311 | Arguments:
312 | logits: (batch, seq_len, vocab_size)
313 | labels: (batch, seq_len,)
314 | Returns:
315 | losses: float
316 | """
317 | batch, seq_len, d = logits.shape
318 | assert(labels.shape == (batch, seq_len))
319 |
320 | loss = Fast_CrossEntropyLoss.apply(
321 | logits.view(batch*seq_len, d),
322 | labels.view(-1),
323 | logit_softcapping,
324 | logit_scaling,
325 | )
326 | if n_items is None:
327 | n_items = torch.count_nonzero(labels != -100)
328 | return loss.sum() / n_items
329 |
330 |
331 | def reference_cross_entropy_loss(logits, labels, logit_softcapping=0, logit_scaling=0):
332 | """Reference implementation using PyTorch's native functions"""
333 | if logit_scaling != 0:
334 | logits = logits * logit_scaling
335 |
336 | if logit_softcapping != 0:
337 | logits = logit_softcapping * torch.tanh(logits / logit_softcapping)
338 |
339 | log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
340 |
341 | # Get the log probability for the correct labels
342 | label_mask = labels != -100
343 | labels_masked = labels.clone()
344 | labels_masked[~label_mask] = 0
345 |
346 | # Gather the log probabilities for the correct labels
347 | label_log_probs = log_probs.gather(dim=-1, index=labels_masked.unsqueeze(-1)).squeeze(-1)
348 |
349 | # Apply the mask to ignore padding tokens
350 | label_log_probs = label_log_probs * label_mask
351 |
352 | loss = -label_log_probs.sum() / label_mask.sum()
353 | return loss
354 |
355 |
356 | def test_cross_entropy():
357 | """Test the forward and backward pass of the custom cross entropy implementation"""
358 | print("Testing Fast Cross Entropy implementation...")
359 |
360 | # Test configurations
361 | test_configs = [
362 | {"name": "Standard", "softcap": 0, "scaling": 0},
363 | {"name": "With Softcapping", "softcap": 10.0, "scaling": 0},
364 | {"name": "With Scaling", "softcap": 0, "scaling": 2.0},
365 | {"name": "With Both", "softcap": 10.0, "scaling": 2.0}
366 | ]
367 |
368 | for config in test_configs:
369 | print(f"\nTesting {config['name']} configuration...")
370 |
371 | # Create test inputs
372 | batch_size, seq_len, vocab_size = 2, 10, 32000
373 | logits = torch.randn(batch_size, seq_len, vocab_size, device='cuda', requires_grad=True)
374 | # Create labels with some -100 values to test padding
375 | labels = torch.randint(0, vocab_size, (batch_size, seq_len), device='cuda')
376 | labels[0, 0] = -100 # Add some padding tokens
377 |
378 | # Clone inputs for reference implementation
379 | logits_ref = logits.clone().detach().requires_grad_(True)
380 |
381 | # Forward pass
382 | our_loss = fast_cross_entropy_loss(
383 | logits, labels,
384 | logit_softcapping=config['softcap'],
385 | logit_scaling=config['scaling']
386 | )
387 |
388 | # Reference implementation
389 | ref_loss = reference_cross_entropy_loss(
390 | logits_ref, labels,
391 | logit_softcapping=config['softcap'],
392 | logit_scaling=config['scaling']
393 | )
394 |
395 | # Compare forward results
396 | forward_diff = torch.abs(our_loss - ref_loss).item()
397 | print(f"Forward pass difference: {forward_diff:.6f}")
398 | assert forward_diff < 1e-4, f"Forward pass failed for {config['name']} configuration!"
399 |
400 | # Backward pass
401 | our_loss.backward()
402 | ref_loss.backward()
403 | # Compare gradients
404 |
405 | grad_diff = torch.max(torch.abs(logits.grad - logits_ref.grad)).item()
406 | print(f"Max gradient difference: {grad_diff:.6f}")
407 | assert grad_diff < 1e-4, f"Backward pass failed for {config['name']} configuration!"
408 |
409 | # Reset gradients for next test
410 | logits.grad.zero_()
411 | logits_ref.grad.zero_()
412 |
413 | print("\nAll tests passed successfully!")
414 | return True
415 |
416 | if __name__ == "__main__":
417 | test_cross_entropy()
418 |
--------------------------------------------------------------------------------
/annotated_examples/classics/geglu.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2 | # Modifications Copyright 2025 Mekkcyber.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import triton
17 | import triton.language as tl
18 | import torch
19 |
20 | from triton.language.extra import libdevice
21 | triton_tanh = libdevice.tanh
22 | triton_cast = tl.cast
23 | MAX_FUSED_SIZE : int = 65536
24 | next_power_of_2 = triton.next_power_of_2
25 |
26 | def calculate_settings(n : int) -> (int, int,):
27 | BLOCK_SIZE : int = next_power_of_2(n)
28 | if BLOCK_SIZE > MAX_FUSED_SIZE:
29 | raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
30 | f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
31 | num_warps : int = 4
32 | if BLOCK_SIZE >= 32768: num_warps = 32
33 | elif BLOCK_SIZE >= 8192: num_warps = 16
34 | elif BLOCK_SIZE >= 2048: num_warps = 8
35 | return BLOCK_SIZE, num_warps
36 |
37 | @triton.jit
38 | def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
39 | """
40 | Forward pass for the GeGLU (Gated Gaussian Error Linear Unit) activation function.
41 |
42 | This kernel computes:
43 | 1. The GELU activation: f = 0.5 * e * (1 + erf(e/sqrt(2)))
44 | 2. The GeGLU output: h = f * g
45 |
46 | Parameters:
47 | - e: gate values (first half of the projection)
48 | - g: up values (second half of the projection)
49 | - h: output tensor to store the result
50 | - n_elements: total number of elements in the tensors
51 | - BLOCK_SIZE: size of each CUDA block for parallelization
52 | """
53 | # Get the current block index in the grid
54 | block_idx = tl.program_id(0)
55 |
56 | # Calculate memory offsets for this block
57 | # Each block thread processes BLOCK_SIZE elements from the input
58 | offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
59 |
60 | # Create a mask to handle the case where n_elements is not divisible by BLOCK_SIZE
61 | mask = offsets < n_elements
62 |
63 | # Load input values from global memory
64 | e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
65 | g_row = tl.load(g + offsets, mask = mask, other = 0)
66 |
67 | # Compute GELU activation using the exact formula:
68 | # f(x) = 0.5 * x * (1 + erf(x/sqrt(2)))
69 | # where erf is the error function
70 | # rsqrt(2.0) computes 1/sqrt(2)
71 | f_row = 0.5 * e_row * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
72 |
73 | # Convert the result back to the same dtype as g_row
74 | f_row = f_row.to(g_row.dtype)
75 |
76 | # Compute the final GeGLU output by multiplying the GELU activation with g element-wise
77 | h_row = f_row * g_row
78 |
79 | # Store the result back to global memory
80 | tl.store(h + offsets, h_row, mask = mask)
81 |
82 |
83 | def geglu_exact_forward_kernel(gate, up):
84 | batch, seq_len, hd = gate.shape
85 | n_elements = gate.numel()
86 | out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda")
87 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
88 | _exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
89 | return out
90 |
91 |
92 | @triton.jit
93 | def _exact_backward_kernel(dY, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
94 | """
95 | Backward pass for the GeGLU (Gated Gaussian Error Linear Unit) activation function.
96 |
97 | In the forward pass:
98 | - f = 1/2 * e * (1 + erf(1/sqrt(2) * e)) # The GELU function
99 | - h = f * g # The GeGLU output
100 |
101 | Where:
102 | - e: gate values (first half of the projection)
103 | - g: up values (second half of the projection)
104 | - h: output of GeGLU
105 |
106 | In the backward pass, we need to compute:
107 | - de: gradient with respect to e (gate values)
108 | - dg: gradient with respect to g (up values)
109 |
110 | For de, we need the derivative of f with respect to e:
111 | df/de = 1/2 * (1 + erf(1/sqrt(2) * x)) + 1/sqrt(2*pi) * x * exp(-1/2 * x^2) (see backprop_math/geglu.md)
112 |
113 | Parameters:
114 | - dY: gradient flowing from the next layer (dL/dh)
115 | - e: gate values from forward pass
116 | - g: up values from forward pass
117 | - n_elements: total number of elements in the tensors
118 | - BLOCK_SIZE: size of each CUDA block for parallelization
119 | """
120 | block_idx = tl.program_id(0)
121 | offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
122 | mask = offsets < n_elements
123 |
124 | # Load the gradients and values
125 | dY_row = tl.load(dY + offsets, mask = mask, other = 0)
126 | e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
127 | g_row = tl.load(g + offsets, mask = mask, other = 0)
128 |
129 | # Compute the partial GELU activation: 1/2 * (1 + erf(1/sqrt(2) * e))
130 | # This is reused in both the forward computation and the derivative
131 | f_partial_row = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
132 |
133 | # Complete the GELU computation: f = f_partial_row * e_row
134 | f_row = f_partial_row * e_row
135 | f_row = f_row.to(dY_row.dtype)
136 |
137 | # Compute gradient for g: dg = dY * f
138 | # By chain rule: dL/dg = dL/dh * dh/dg = dY * f (as specified above h is the output of the GeGLU)
139 | dg_row = dY_row * f_row
140 |
141 | # Compute gradient for e using the derivative of GELU
142 | # df/de = f_partial_row + (1/sqrt(2*pi)) * e * exp(-e²/2)
143 | t = 0.3989422804014327 # 1/sqrt(2*pi)
144 | df_de = f_partial_row + t * e_row * tl.exp(-0.5 * e_row * e_row)
145 |
146 | # Apply chain rule: dL/de = dL/dh * dh/de = dY * (g * df/de)
147 | de_row = g_row.to(tl.float32) * df_de
148 | de_row = de_row.to(dY_row.dtype) * dY_row
149 |
150 | # Store the computed gradients back to memory
151 | tl.store(e + offsets, de_row, mask = mask)
152 | tl.store(g + offsets, dg_row, mask = mask)
153 |
154 |
155 | def geglu_exact_backward_kernel(DW, e, g):
156 | n_elements = e.numel()
157 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
158 | _exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
159 | return e, g
160 |
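# A minimal sketch (not part of the original kernels) of how the exact-GeGLU kernels above could be
# wired into a torch.autograd.Function, following the same pattern as Fast_CrossEntropyLoss in
# classics/cross_entropy.py. The class name is illustrative; note that the backward kernel writes
# its results into the `e` and `g` buffers in place, so clones of the saved activations are passed.
class Fast_GeGLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, gate, up):
        h = geglu_exact_forward_kernel(gate, up)
        ctx.save_for_backward(gate, up)
        return h

    @staticmethod
    def backward(ctx, dY):
        gate, up = ctx.saved_tensors
        # geglu_exact_backward_kernel returns (de, dg) stored in the cloned buffers
        de, dg = geglu_exact_backward_kernel(dY.contiguous(), gate.clone(), up.clone())
        return de, dg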
161 |
162 | @triton.jit
163 | def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
164 | """
165 | Computes the forward pass of the approximate GELU activation function for GeGLU.
166 |
167 | GeGLU (Gated GELU Linear Unit) combines a gating mechanism with GELU activation:
168 | - Input is split into two parts: gate (e) and up (g)
169 | - GELU is applied to the gate values
170 | - The result is multiplied by the up values
171 |
172 | This kernel implements the approximate version of GELU which is faster but slightly less accurate.
173 |
174 | Formula:
175 | f(e) = 0.5 * e * (1 + tanh(sqrt(2/π) * e * (1 + 0.044715 * e²)))
176 | h = f(e) * g
177 |
178 | Where:
179 | - e: gate values
180 | - g: up values
181 | - h: output
182 | - f(e): approximate GELU activation
183 | """
184 | # Get the current block index and compute offsets for each thread
185 | block_idx = tl.program_id(0)
186 | offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
187 | mask = offsets < n_elements
188 |
189 | # Constant for the GELU approximation: sqrt(2/π)
190 | s = 0.7978845608028654 # Precomputed value of sqrt(2/π)
191 |
192 | # Load gate and up values from memory
193 | e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
194 | g_row = tl.load(g + offsets, mask = mask, other = 0)
195 |
196 | # Compute the approximate GELU activation:
197 | # f(e) = 0.5 * e * (1 + tanh(sqrt(2/π) * e * (1 + 0.044715 * e²)))
198 | #
199 | # This is a faster approximation of the exact GELU:
200 | # f(e) = 0.5 * e * (1 + erf(e/sqrt(2)))
201 | #
202 | # The approximation uses tanh instead of erf and adds a cubic term
203 | # to better match the shape of the exact GELU function
204 | inner_term = s * e_row * (1.0 + 0.044715 * e_row * e_row)
205 | f_row = 0.5 * e_row * (triton_tanh(inner_term) + 1.0)
206 |
207 | # Convert back to the original data type
208 | f_row = f_row.to(g_row.dtype)
209 |
210 | # Compute the final output: h = f(e) * g
211 | h_row = f_row * g_row
212 |
213 | # Store the result back to memory
214 | tl.store(h + offsets, h_row, mask = mask)
215 |
216 |
217 | def geglu_approx_forward_kernel(gate, up):
218 | batch, seq_len, hd = gate.shape
219 | n_elements = gate.numel()
220 | out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda")
221 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
222 | _approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
223 | return out
224 |
225 | @triton.jit
226 | def _approx_backward_kernel(dY, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
227 | """
228 | Backward pass for the approximate GELU activation function.
229 |
230 | Forward pass:
231 | f(e) = 0.5 * e * (1 + tanh(sqrt(2/π) * e * (1 + 0.044715 * e²)))
232 | h = f(e) * g
233 |
234 | Where:
235 | - e: gate values
236 | - g: up values
237 | - dY: gradient from upstream layers
238 | - h: output
239 |
240 | Backward pass derivatives:
241 | 1. df/de = 0.5 * (1 + tanh(inner))(1 + e * (2 - (1 + tanh(inner))) * d(inner)/de)
242 | where inner = sqrt(2/π) * e * (1 + 0.044715 * e²)
243 | 2. d(inner)/de = sqrt(2/π) * (1 + 0.044715 * e² * 3) = (a + 3b * e²) where a = sqrt(2/π) and b = 0.044715*sqrt(2/π)
244 | 3. de = dY * g * df/de
245 | 4. dg = dY * f(e)
246 |
247 | """
248 | # Get block index and compute offsets for parallel processing
249 | block_idx = tl.program_id(0)
250 | offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
251 | mask = offsets < n_elements
252 |
253 | # Load input values from memory
254 | dY_row = tl.load(dY + offsets, mask=mask, other=0)
255 | e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)
256 | g_row = tl.load(g + offsets, mask=mask, other=0)
257 |
258 | # Constants for the GELU approximation
259 | a = 0.7978845608028654 # Precomputed value of sqrt(2/π)
260 | b = 0.044715 * a
261 |
262 | # Full inner term: sqrt(2/π) * e * (1 + 0.044715 * e²)
263 | inner = e_row * (a + b * e_row * e_row)
264 |
265 | # Compute tanh of the inner term
266 | tanh_inner = triton_tanh(inner)
267 |
268 | # Compute (1 + tanh(inner_term))
269 | v = 1.0 + tanh_inner
270 |
271 | # compute f(e) = 0.5 * e * (1 + tanh(inner_term)) based on the forward pass formula in the backprop_math/geglu.md
272 | f_row = 0.5 * e_row * v
273 |
274 | # compute df/de based on the fomula in the backprop_math/geglu.md
275 | df_de = 0.5 * v * (1.0 + e_row * (2.0 - v) * (a + 3.0 * b * e_row * e_row))
276 |
277 | # Compute gradients for backpropagation:
278 | # dg = dY * f(e)
279 | dg_row = dY_row * f_row
280 |
281 | # Compute gradients for backpropagation:
282 | # de = dY * g * df/de
283 | de_row = g_row * df_de
284 | de_row = de_row.to(dY_row.dtype) * dY_row
285 |
286 | # Store results and gradients back to memory
287 | tl.store(e + offsets, de_row, mask=mask) # Store gradient for gate values
288 | tl.store(g + offsets, dg_row, mask=mask) # Store gradient for up values
289 |
290 |
291 | def geglu_approx_backward_kernel(dY, e, g):
292 | n_elements = e.numel()
293 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
294 | _approx_backward_kernel[grid](dY, e, g, n_elements, BLOCK_SIZE = 1024,)
295 | return e, g
296 |
297 |
298 | def test_geglu_correctness(use_approx=False):
299 | """
300 | Test the correctness of the GEGLU implementation by comparing with a reference implementation.
301 | Tests both forward and backward passes for GEGLU (Gated GELU).
302 |
303 | Args:
304 | use_approx (bool): If True, use the approximate GEGLU implementation.
305 | If False, use the exact GEGLU implementation.
306 | """
307 | import torch
308 | import torch.nn.functional as F
309 |
310 | # Define reference implementations for GEGLU (Gated GELU)
311 | def geglu_reference_forward(x):
312 | """Reference implementation of GEGLU forward pass"""
313 | x_chunks = torch.chunk(x, 2, dim=-1)
314 | gate, value = x_chunks[0], x_chunks[1]
315 | return value * F.gelu(gate)
316 |
317 | # Select the appropriate kernels based on the use_approx flag
318 | forward_kernel = _approx_forward_kernel if use_approx else _exact_forward_kernel
319 | backward_kernel = _approx_backward_kernel if use_approx else _exact_backward_kernel
320 |
321 | implementation_type = "approximate" if use_approx else "exact"
322 | print(f"Testing {implementation_type} GEGLU implementation...")
323 |
324 | def test_forward():
325 | """Test the forward pass of GEGLU"""
326 | print(f"Testing {implementation_type} GEGLU forward pass...")
327 |
328 | batch_size, seq_len, hidden_dim = 2, 10, 128
329 | x = torch.randn(batch_size, seq_len, hidden_dim * 2, device='cuda', requires_grad=True)
330 |
331 | ref_output = geglu_reference_forward(x)
332 |
333 | x_chunks = torch.chunk(x, 2, dim=-1)
334 | gate, value = x_chunks[0], x_chunks[1]
335 | gate_flat = gate.reshape(-1)
336 | value_flat = value.reshape(-1)
337 |
338 | output_flat = torch.empty_like(gate_flat)
339 |
340 | n_elements = gate_flat.numel()
341 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
342 | forward_kernel[grid](gate_flat, value_flat, output_flat, n_elements, BLOCK_SIZE=1024)
343 |
344 | our_output = output_flat.reshape(gate.shape)
345 |
346 | max_diff = torch.max(torch.abs(ref_output - our_output))
347 | print(f"Max difference in {implementation_type} GEGLU forward pass: {max_diff.item()}")
348 | assert max_diff < (1e-2 if use_approx else 1e-5), f"{implementation_type} GEGLU forward pass implementation is incorrect!"
349 | return True
350 |
351 | def test_backward():
352 | """Test the backward pass of GEGLU"""
353 | print(f"Testing {implementation_type} GEGLU backward pass...")
354 |
355 | batch_size, seq_len, hidden_dim = 2, 10, 128
356 | x = torch.randn(batch_size, seq_len, hidden_dim * 2, device='cuda', requires_grad=True)
357 |
358 | x_ref = x.clone().detach().requires_grad_(True)
359 | ref_output = geglu_reference_forward(x_ref)
360 |
361 | grad_output = torch.randn_like(ref_output)
362 |
363 | ref_output.backward(grad_output)
364 | ref_grad = x_ref.grad.clone()
365 |
366 | x_chunks = torch.chunk(x, 2, dim=-1)
367 | gate, value = x_chunks[0], x_chunks[1]
368 | gate_flat = gate.reshape(-1)
369 | value_flat = value.reshape(-1)
370 |
371 | output_flat = torch.empty_like(gate_flat)
372 | n_elements = gate_flat.numel()
373 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
374 | forward_kernel[grid](gate_flat, value_flat, output_flat, n_elements, BLOCK_SIZE=1024)
375 |
376 | grad_output_flat = grad_output.reshape(-1)
377 |
378 | dW = grad_output_flat.clone()
379 | e = gate_flat.clone()
380 | g = value_flat.clone()
381 |
382 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
383 | backward_kernel[grid](dW, e, g, n_elements, BLOCK_SIZE=1024)
384 |
385 | our_grad = torch.cat([e.reshape(gate.shape), g.reshape(value.shape)], dim=-1)
386 |
387 | max_diff = torch.max(torch.abs(ref_grad - our_grad))
388 | print(f"Max difference in {implementation_type} GEGLU backward pass: {max_diff.item()}")
389 | assert max_diff < (1e-2 if use_approx else 1e-5), f"{implementation_type} GEGLU backward pass implementation is incorrect!"
390 | return True
391 |
392 | forward_passed = test_forward()
393 | backward_passed = test_backward()
394 |
395 | if forward_passed and backward_passed:
396 | print(f"All tests passed! {implementation_type.capitalize()} GEGLU implementation is correct.")
397 | else:
398 | print(f"Tests failed! {implementation_type.capitalize()} GEGLU implementation needs fixing.")
399 |
400 | if __name__ == "__main__":
401 | # Test exact implementation
402 | test_geglu_correctness(use_approx=False)
403 |
404 | # Test approximate implementation
405 | test_geglu_correctness(use_approx=True)
406 |
--------------------------------------------------------------------------------
/annotated_examples/classics/layernorm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2 | # Modifications Copyright 2025 Mekkcyber.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License
15 |
16 | import triton
17 | import triton.language as tl
18 | import torch
19 |
20 | MAX_FUSED_SIZE : int = 65536
21 | next_power_of_2 = triton.next_power_of_2
22 |
23 | def calculate_settings(n : int) -> (int, int,):
24 | BLOCK_SIZE : int = next_power_of_2(n)
25 | if BLOCK_SIZE > MAX_FUSED_SIZE:
26 | raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
27 | f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
28 | num_warps : int = 4
29 | if BLOCK_SIZE >= 32768: num_warps = 32
30 | elif BLOCK_SIZE >= 8192: num_warps = 16
31 | elif BLOCK_SIZE >= 2048: num_warps = 8
32 | return BLOCK_SIZE, num_warps
33 |
34 | @triton.jit
35 | def layernorm_forward(
36 | Y, Y_row_stride, # Output tensor and its row stride
37 | X, X_row_stride, # Input tensor and its row stride
38 | weight, # Scale parameter for normalization
39 | bias, # Bias parameter for normalization
40 | inv_var, # Buffer to store inverse variance
41 | mean, # Buffer to store mean
42 | n_cols, eps, # Number of columns and epsilon for numerical stability
43 | BLOCK_SIZE : tl.constexpr # Compile-time constant for block size
44 | ):
45 | """
46 | This kernel implements the forward pass of Layer Normalization using Triton.
47 |
48 | Layer Normalization normalizes each input row independently using the formula:
49 | y = ((x - mean) / sqrt(variance + eps)) * weight + bias
50 |
51 | Example with a 3x5 input matrix X:
52 | X = [
53 | [1.0, 2.0, 3.0, 4.0, 5.0], # Row 0
54 | [6.0, 7.0, 8.0, 9.0, 10.0], # Row 1
55 | [11.0, 12.0, 13.0, 14.0, 15.0] # Row 2
56 | ]
57 | weight = [0.5, 0.5, 0.5, 0.5, 0.5]
58 | bias = [0.1, 0.1, 0.1, 0.1, 0.1]
59 |
60 | The inline comments in the body below walk through the computation for row_idx = 1 (the second row).
61 | """
62 |
63 | # Each CUDA thread block processes one row of the input
64 | row_idx = tl.program_id(0)
65 |
66 | # Create column indices [0, 1, 2, ..., BLOCK_SIZE-1]
67 | # BLOCK_SIZE is the nearest power of 2 greater than or equal to n_cols
68 | # These will be used to access elements within a row
69 | col_offsets = tl.arange(0, BLOCK_SIZE)
70 |
71 | # Create a mask to handle cases where n_cols < BLOCK_SIZE
72 | # For example, if n_cols=5 and BLOCK_SIZE=8, only the first 5 elements are valid
73 | mask = col_offsets < n_cols
74 |
75 | # In the case of Layer Normalization, the input tensor X and output tensor Y have the same shape.
76 | # This means we can use the same indexing pattern to access corresponding elements in both tensors.
77 | # We're using row_idx to determine which row we're processing, and then using the same
78 | # col_offsets within that row to access individual elements.
79 | #
80 | # The row_stride parameters (X_row_stride and Y_row_stride) tell us how far to jump in memory
81 | # to move from one row to the next. While these are often the same for X and Y, having separate
82 | # stride parameters allows for flexibility in memory layout.
83 |
84 | # In row-major order, elements in a row are stored contiguously in memory
85 | # For a matrix with n_cols columns, the row_stride equals n_cols
86 | # Example with our 3x5 matrix X stored in row-major order in memory:
87 | # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0]
88 | # |--------- Row0 ---------|---------- Row1 ---------|---------- Row2 -----------|
89 | # To move from the start of Row 0 to the start of Row 1, we add row_stride (5)
90 |
91 | # In the beginning X and Y point to the first element of their first row
92 | # if row_idx==1 :
93 | # - Y + row_idx * Y_row_stride = Y + 1 * 5 = Y + 5 points to the second row of Y
94 | # - X + row_idx * X_row_stride = X + 1 * 5 = X + 5 points to the second row of X
95 |
96 | Y += row_idx * Y_row_stride
97 | X += row_idx * X_row_stride
98 | # inv_var and mean are 1D tensors with n_rows elements
99 | # when row_idx==1, inv_var points to the second element in the inverse variance buffer
100 | # when row_idx==1, mean points to the second element in the mean buffer
101 | inv_var += row_idx
102 | mean += row_idx
103 |
104 | # Load the entire row from input tensor X
105 | # For row_idx=1: X_row = [6.0, 7.0, 8.0, 9.0, 10.0, 0, 0, 0]
106 | # The 'other=0' parameter sets values outside the mask to 0
107 | X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
108 |
109 | # Load weight parameters for this row
110 | # weight_row = [0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0] (if BLOCK_SIZE=8)
111 | weight_row = tl.load(weight + col_offsets, mask = mask, other = 0).to(tl.float32)
112 |
113 | # Load bias parameters for this row
114 | # bias_row = [0.1, 0.1, 0.1, 0.1, 0.1, 0, 0, 0] (if BLOCK_SIZE=8)
115 | bias_row = tl.load(bias + col_offsets, mask = mask, other = 0).to(tl.float32)
116 |
117 | # Calculate mean of the row
118 | # For row_idx=1: mean_X = (6.0 + 7.0 + 8.0 + 9.0 + 10.0) / 5 = 8.0
119 | mean_X = tl.sum(X_row, axis = 0) / n_cols
120 |
121 | # Subtract mean from each element in the row
122 | # For row_idx=1: XX = [-2.0, -1.0, 0.0, 1.0, 2.0, 0.0, 0.0, 0.0]
123 | XX = tl.where(mask, X_row - mean_X, 0)
124 |
125 | # Calculate variance of the row
126 | # For row_idx=1: row_var = ((-2.0)² + (-1.0)² + 0.0² + 1.0² + 2.0²) / 5 = (4.0 + 1.0 + 0.0 + 1.0 + 4.0) / 5 = 2.0
127 | row_var = tl.sum(XX * XX, axis = 0) / n_cols
128 |
129 | # Calculate inverse square root of variance (for stability, add epsilon)
130 | # For row_idx=1: inv_var_val = 1/sqrt(2.0 + eps) ≈ 0.707
131 | inv_var_val = tl.math.rsqrt(row_var + eps)
132 |
133 | # Store the inverse variance and mean for later use in backward pass
134 | tl.store(inv_var, inv_var_val)
135 | tl.store(mean, mean_X)
136 |
137 | # Calculate normalized output with scaling and bias
138 | # For row_idx=1:
139 | # output = ([-2.0, -1.0, 0.0, 1.0, 2.0] * 0.707) * [0.5, 0.5, 0.5, 0.5, 0.5] + [0.1, 0.1, 0.1, 0.1, 0.1]
140 | # = [-0.607, -0.2535, 0.1, 0.4535, 0.807]
141 | output = (XX * inv_var_val) * weight_row + bias_row
142 |
143 | # Store the output row
144 | tl.store(Y + col_offsets, output, mask = mask)
145 |
146 |
147 | @triton.jit
148 | def layernorm_backward(
149 | dY, dY_row_stride, # Gradient from upstream and its row stride
150 | X, X_row_stride, # Input tensor and its row stride
151 | weight, # Scale parameter for normalization
152 | bias, # Bias parameter for normalization
153 | inv_var, # Stored inverse variance from forward pass
154 | mean, # Stored mean from forward pass
155 | n_cols, eps, # Number of columns and epsilon for numerical stability
156 | BLOCK_SIZE : tl.constexpr # Compile-time constant for block size
157 | ):
158 | """
159 | This kernel implements the backward pass of Layer Normalization using Triton.
160 |
161 | The backward pass computes the gradient with respect to the input (dX) given the
162 | gradient with respect to the output (dY).
163 |
164 | Example with a 3x5 input matrix X and corresponding gradient dY:
165 | X = [
166 | [1.0, 2.0, 3.0, 4.0, 5.0], # Row 0
167 | [6.0, 7.0, 8.0, 9.0, 10.0], # Row 1
168 | [11.0, 12.0, 13.0, 14.0, 15.0] # Row 2
169 | ]
170 | dY = [
171 | [0.1, 0.2, 0.3, 0.4, 0.5], # Row 0
172 | [0.6, 0.7, 0.8, 0.9, 1.0], # Row 1
173 | [1.1, 1.2, 1.3, 1.4, 1.5] # Row 2
174 | ]
175 | weight = [0.5, 0.5, 0.5, 0.5, 0.5]
176 | """
177 |
178 | # Each CUDA thread block processes one row of the input
179 | row_idx = tl.program_id(0)
180 |
181 | # Create column indices [0, 1, 2, ..., BLOCK_SIZE-1]
182 | col_offsets = tl.arange(0, BLOCK_SIZE)
183 |
184 | # Create a mask to handle cases where n_cols < BLOCK_SIZE
185 | mask = col_offsets < n_cols
186 |
187 | # Calculate pointers to the current row in each tensor
188 | # For row_idx=1, we're processing the second row of each tensor
189 | dY += row_idx * dY_row_stride
190 | X += row_idx * X_row_stride
191 | inv_var += row_idx
192 | mean += row_idx
193 |
194 | # Load the gradient from upstream (dY)
195 | # For row_idx=1: dY_row = [0.6, 0.7, 0.8, 0.9, 1.0, 0, 0, 0]
196 | dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
197 |
198 | # Load the input values
199 | # For row_idx=1: X_row = [6.0, 7.0, 8.0, 9.0, 10.0, 0, 0, 0]
200 | X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
201 |
202 | # Load weight parameters
203 | # weight_row = [0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0]
204 | weight_row = tl.load(weight + col_offsets, mask = mask, other = 0).to(tl.float32)
205 |
206 | # Load the stored inverse variance and mean from the forward pass
207 | # For row_idx=1: inv_var_val ≈ 0.707, mean_val = 8.0
208 | inv_var_val = tl.load(inv_var).to(tl.float32)
209 | mean_val = tl.load(mean).to(tl.float32)
210 |
211 | # Calculate the normalized input values (same as in forward pass)
212 | # For row_idx=1: normed = [(6.0-8.0)*0.707, (7.0-8.0)*0.707, (8.0-8.0)*0.707, (9.0-8.0)*0.707, (10.0-8.0)*0.707]
213 | # = [-1.414, -0.707, 0.0, 0.707, 1.414]
214 | normed = (X_row - mean_val) * inv_var_val
215 |
216 | # Scale the upstream gradient by the weight
217 | # For row_idx=1: dY_W = [0.6*0.5, 0.7*0.5, 0.8*0.5, 0.9*0.5, 1.0*0.5]
218 | # = [0.3, 0.35, 0.4, 0.45, 0.5]
219 | dY_W = dY_row * weight_row
220 |
221 | # Calculate the gradient with respect to the input
222 | # This follows the chain rule for backpropagation through layer normalization
223 | # The formula has three terms:
224 | # 1. dY_W: direct contribution from upstream gradient
225 | # 2. -tl.sum(dY_W, axis=0)/n_cols: contribution from the mean term
226 | # 3. -normed * tl.sum(dY_W * normed, axis=0)/n_cols: contribution from the variance term
227 |
228 | # In general, the result would be non-zero and would then be scaled by inv_var_val
229 | dX_row = dY_W - tl.sum(dY_W, axis = 0) / n_cols - normed * tl.sum(dY_W * normed, axis = 0) / n_cols
230 | dX_row = dX_row * inv_var_val
231 |
232 | # Store the gradient with respect to the input
233 | # Note: We're reusing the dY tensor to store the result (in-place operation)
234 | tl.store(dY + col_offsets, dX_row, mask = mask)
235 |
236 |
237 | class Fast_Layernorm(torch.autograd.Function):
238 | @staticmethod
239 | def forward(ctx, X, weight, bias, eps):
240 | shape = X.shape
241 | dim = shape[-1]
242 | X = X.view(-1, dim)
243 | n_rows, n_cols = X.shape
244 | BLOCK_SIZE, num_warps = calculate_settings(n_cols)
245 | device = X.device
246 | Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device)
247 | inv_var = torch.empty(n_rows, dtype = torch.float32, device = device)
248 | mean = torch.empty(n_rows, dtype = torch.float32, device = device)
249 |
250 | layernorm_forward[(n_rows,)](
251 | Y, Y.stride(0),
252 | X, X.stride(0),
253 | weight,
254 | bias,
255 | inv_var,
256 | mean,
257 | n_cols, eps,
258 | BLOCK_SIZE = BLOCK_SIZE,
259 | num_warps = num_warps,
260 | )
261 | ctx.eps = eps
262 | ctx.BLOCK_SIZE = BLOCK_SIZE
263 | ctx.num_warps = num_warps
264 | ctx.save_for_backward(X, weight, bias, inv_var, mean)
265 | return Y.view(*shape)
266 | pass
267 |
268 | @staticmethod
269 | def backward(ctx, dY):
270 | shape = dY.shape
271 | dim = shape[-1]
272 | dY = dY.view(-1, dim)
273 | X, weight, bias, inv_var, mean = ctx.saved_tensors
274 | n_rows, n_cols = dY.shape
275 |
276 | layernorm_backward[(n_rows,)](
277 | dY, dY.stride(0),
278 | X, X.stride(0),
279 | weight,
280 | bias,
281 | inv_var,
282 | mean,
283 | n_cols, ctx.eps,
284 | BLOCK_SIZE = ctx.BLOCK_SIZE,
285 | num_warps = ctx.num_warps,
286 | )
287 | dX = dY.view(*shape)
288 | return dX, None, None, None
289 |
290 | def fast_layernorm(layernorm, X):
291 | assert(layernorm.elementwise_affine is True)
292 | W = layernorm.weight
293 | bias = layernorm.bias
294 | eps = layernorm.variance_epsilon if \
295 | hasattr(layernorm, "variance_epsilon") \
296 | else layernorm.eps
297 | out = Fast_Layernorm.apply(X, W, bias, eps)
298 | return out
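# A small correctness check (not in the original file), in the same spirit as the tests in
# cross_entropy.py and geglu.py: compare fast_layernorm against torch.nn.LayerNorm on random input.
# The shapes and tolerance below are illustrative.
def test_layernorm():
    batch, seq_len, dim = 2, 16, 512
    X = torch.randn(batch, seq_len, dim, device="cuda")
    ln = torch.nn.LayerNorm(dim, eps=1e-5).to("cuda")
    ref = ln(X)
    out = fast_layernorm(ln, X)
    max_diff = (ref - out).abs().max().item()
    print(f"Max difference vs torch.nn.LayerNorm: {max_diff:.6f}")
    assert max_diff < 1e-4, "LayerNorm forward mismatch!"
    return True

if __name__ == "__main__":
    test_layernorm()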
--------------------------------------------------------------------------------
/annotated_examples/classics/matmul.py:
--------------------------------------------------------------------------------
1 | import torch, math, random, copy
2 | from torch import Tensor
3 | import triton
4 | import triton.language as tl
5 | import pdb
6 |
7 | @triton.jit
8 | def matmul_kernel(
9 | # Pointers to matrices
10 | a_ptr, b_ptr, c_ptr,
11 | # Matrix dimensions
12 | M, N, K,
13 | # The stride variables represent how much to increase the ptr by when moving by 1
14 | # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
15 | # by to get the element one row down (A has M rows).
16 | stride_am, stride_ak, # Strides for matrix A (M, K)
17 | stride_bk, stride_bn, # Strides for matrix B (K, N)
18 | stride_cm, stride_cn, # Strides for matrix C (M, N)
19 | # Meta-parameters
20 | BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # Tile sizes
21 | GROUP_SIZE_M: tl.constexpr, # Number of M-dimension tiles per group (for L2 cache optimization)
22 | ACTIVATION: tl.constexpr # Optional activation function to apply
23 | ):
24 | """Kernel for computing the matmul C = A x B.
25 | A has shape (M, K), B has shape (K, N) and C has shape (M, N)
26 |
27 | Example:
28 | For M=16, N=8, K=16 with BLOCK_SIZE_M=2, BLOCK_SIZE_N=2, BLOCK_SIZE_K=2, GROUP_SIZE_M=3:
29 |
30 | Matrix A (16x16):
31 | [A00, A01, A02, ..., A0,15]
32 | [A10, A11, A12, ..., A1,15]
33 | [... ]
34 | [A15,0, A15,1, ..., A15,15]
35 |
36 | Matrix B (16x8):
37 | [B00, B01, B02, ..., B07]
38 | [B10, B11, B12, ..., B17]
39 | [... ]
40 | [B15,0, B15,1, ..., B15,7]
41 |
42 | Matrix C (16x8):
43 | [C00, C01, C02, ..., C07]
44 | [C10, C11, C12, ..., C17]
45 | [... ]
46 | [C15,0, C15,1, ..., C15,7]
47 |
48 | - We divide matrices into blocks of size 2x2
49 | - Matrix A (16x16) has 8x8 blocks
50 | - Matrix B (16x8) has 8x4 blocks
51 | - Matrix C (16x8) has 8x4 blocks
52 | - We'll have 32 thread blocks computing the 32 blocks of C
53 | """
54 | # -----------------------------------------------------------
55 | # STEP 1: Determine which block of the output matrix C this thread block will compute
56 | # -----------------------------------------------------------
57 | # Each thread block is assigned a unique program ID (pid)
58 | pid = tl.program_id(axis=0)
59 |
60 | # Calculate how many blocks we need in each dimension
61 | # For our example: num_pid_m = ceil(16/2) = 8, num_pid_n = ceil(8/2) = 4
62 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) # Number of blocks in M dimension
63 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) # Number of blocks in N dimension
64 |
65 | # L2 cache optimization: Group blocks along M dimension to promote data reuse
66 | # For our example: num_pid_in_group = 3*4 = 12 (total blocks in a group)
67 | num_pid_in_group = GROUP_SIZE_M * num_pid_n
68 |
69 | # Determine which group this thread block belongs to
70 | # We have 3 groups, theoretically they all have 12 blocks
71 | # For pid<12: group_id = 0
72 | # For 12<=pid<24: group_id = 1
73 | # else: group_id = 2
74 | group_id = pid // num_pid_in_group
75 |
76 | # Find the first block index in M dimension for this group
77 | # For group_id=0: first_pid_m = 0
78 | # For group_id=1: first_pid_m = 3
79 | # For group_id=2: first_pid_m = 6
80 | first_pid_m = group_id * GROUP_SIZE_M
81 |
82 | # Calculate actual group size (might be smaller at boundaries)
83 | # For first_pid_m=0: group_size_m = min(8-0, 3) = 3
84 | # For first_pid_m=3: group_size_m = min(8-3, 3) = 3
85 | # For first_pid_m=6: group_size_m = min(8-6, 3) = 2 (last group is smaller)
86 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
87 |
88 | # Calculate the specific block indices this thread block will compute
89 | # For pid=0: pid_m = 0 + ((0 % 12) % 3) = 0, pid_n = (0 % 12) // 3 = 0
90 | # For pid=1: pid_m = 0 + ((1 % 12) % 3) = 1, pid_n = (1 % 12) // 3 = 0
91 | # For pid=2: pid_m = 0 + ((2 % 12) % 3) = 2, pid_n = (2 % 12) // 3 = 0
92 | # For pid=3: pid_m = 0 + ((3 % 12) % 3) = 0, pid_n = (3 % 12) // 3 = 1
93 | # For pid=4: pid_m = 0 + ((4 % 12) % 3) = 1, pid_n = (4 % 12) // 3 = 1
94 | # For pid=5: pid_m = 0 + ((5 % 12) % 3) = 2, pid_n = (5 % 12) // 3 = 1
95 | # For pid=6: pid_m = 0 + ((6 % 12) % 3) = 0, pid_n = (6 % 12) // 3 = 2
96 | # For pid=7: pid_m = 0 + ((7 % 12) % 3) = 1, pid_n = (7 % 12) // 3 = 2
97 | # For pid=8: pid_m = 0 + ((8 % 12) % 3) = 2, pid_n = (8 % 12) // 3 = 2
98 | # For pid=9: pid_m = 0 + ((9 % 12) % 3) = 0, pid_n = (9 % 12) // 3 = 3
99 | # For pid=10: pid_m = 0 + ((10 % 12) % 3) = 1, pid_n = (10 % 12) // 3 = 3
100 | # For pid=11: pid_m = 0 + ((11 % 12) % 3) = 2, pid_n = (11 % 12) // 3 = 3
101 | # For pid=12: pid_m = 3 + ((12 % 12) % 3) = 3, pid_n = (12 % 12) // 3 = 0
102 | # For pid=13: pid_m = 3 + ((13 % 12) % 3) = 4, pid_n = (13 % 12) // 3 = 0
103 | # For pid=14: pid_m = 3 + ((14 % 12) % 3) = 5, pid_n = (14 % 12) // 3 = 0
104 | # For pid=15: pid_m = 3 + ((15 % 12) % 3) = 3, pid_n = (15 % 12) // 3 = 1
105 | # For pid=16: pid_m = 3 + ((16 % 12) % 3) = 4, pid_n = (16 % 12) // 3 = 1
106 | # For pid=17: pid_m = 3 + ((17 % 12) % 3) = 5, pid_n = (17 % 12) // 3 = 1
107 | # For pid=18: pid_m = 3 + ((18 % 12) % 3) = 3, pid_n = (18 % 12) // 3 = 2
108 | # For pid=19: pid_m = 3 + ((19 % 12) % 3) = 4, pid_n = (19 % 12) // 3 = 2
109 | # For pid=20: pid_m = 3 + ((20 % 12) % 3) = 5, pid_n = (20 % 12) // 3 = 2
110 | # For pid=21: pid_m = 3 + ((21 % 12) % 3) = 3, pid_n = (21 % 12) // 3 = 3
111 | # For pid=22: pid_m = 3 + ((22 % 12) % 3) = 4, pid_n = (22 % 12) // 3 = 3
112 | # For pid=23: pid_m = 3 + ((23 % 12) % 3) = 5, pid_n = (23 % 12) // 3 = 3
113 | # For pid=24 (last group, where group_size_m = 2): pid_m = 6 + ((24 % 12) % 2) = 6, pid_n = (24 % 12) // 2 = 0
114 | # For pid=25: pid_m = 6 + ((25 % 12) % 2) = 7, pid_n = (25 % 12) // 2 = 0
115 | # For pid=26: pid_m = 6 + ((26 % 12) % 2) = 6, pid_n = (26 % 12) // 2 = 1
116 | # For pid=27: pid_m = 6 + ((27 % 12) % 2) = 7, pid_n = (27 % 12) // 2 = 1
117 | # For pid=28: pid_m = 6 + ((28 % 12) % 2) = 6, pid_n = (28 % 12) // 2 = 2
118 | # For pid=29: pid_m = 6 + ((29 % 12) % 2) = 7, pid_n = (29 % 12) // 2 = 2
119 | # For pid=30: pid_m = 6 + ((30 % 12) % 2) = 6, pid_n = (30 % 12) // 2 = 3
120 | # For pid=31: pid_m = 6 + ((31 % 12) % 2) = 7, pid_n = (31 % 12) // 2 = 3
121 | #
122 | # Matrix C (16x8) with blocks of size 2x2 will be computed by these pids:
123 | # [pid=0, pid=3, pid=6, pid=9 ]
124 | # [pid=1, pid=4, pid=7, pid=10]
125 | # [pid=2, pid=5, pid=8, pid=11]
126 | # [pid=12, pid=15, pid=18, pid=21]
127 | # [pid=13, pid=16, pid=19, pid=22]
128 | # [pid=14, pid=17, pid=20, pid=23]
129 | # [pid=24, pid=26, pid=28, pid=30]
130 | # [pid=25, pid=27, pid=29, pid=31]
131 | #
132 | # Swizzle pattern visualization:
133 | # Group 0: Group 1: Group 2:
134 | # +---+---+---+---+ +---+---+---+---+ +---+---+---+---+
135 | # | 0 | 3 | 6 | 9 | |12 |15 |18 |21 | |24 |26 |28 |30 |
136 | # +---+---+---+---+ +---+---+---+---+ +---+---+---+---+
137 | # | 1 | 4 | 7 |10 | |13 |16 |19 |22 | |25 |27 |29 |31 |
138 | # +---+---+---+---+ +---+---+---+---+ +---+---+---+---+
139 | # | 2 | 5 | 8 |11 |    |14 |17 |20 |23 |
140 | # +---+---+---+---+ +---+---+---+---+
141 | #
142 | # Notice how threads are assigned in column-major order within each group:
143 | # - Within each group, we process blocks in a column-first pattern (0,1,2 then 3,4,5 etc.)
144 | # - This creates spatial locality for memory accesses within each group
145 | # - Adjacent thread blocks process adjacent memory, improving cache efficiency
146 |
147 | pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
148 | pid_n = (pid % num_pid_in_group) // group_size_m
149 |
150 | # -----------------------------------------------------------
151 | # STEP 2: Create pointers to the blocks of input matrices A and B
152 | # -----------------------------------------------------------
153 | # Calculate offsets for the block of A and B this thread block will process
154 | # For pid=13 (computing block in row 4, column 0 of C):
155 | # offs_am = [8,9] (rows 8-9 of A)
156 | # offs_bn = [0,1] (columns 0-1 of B)
157 | # offs_k = [0,1] (columns 0-1 of A / rows 0-1 of B)
158 | offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
159 | offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
160 | offs_k = tl.arange(0, BLOCK_SIZE_K)
161 |
162 | # Create pointer blocks for A and B
163 | # Strides represent the memory distance (in elements) between consecutive entries along a dimension:
164 | # - stride_am: elements between consecutive rows of matrix A (used to move between rows of A)
165 | # - stride_ak: elements between consecutive columns of matrix A (used to move between columns of A)
166 | # - stride_bk: elements between consecutive rows of matrix B (used to move between rows of B)
167 | # - stride_bn: elements between consecutive columns of matrix B (used to move between columns of B)
168 | #
169 | # For pid=13 (computing block in row 4, column 0 of C):
170 | # a_ptrs calculates pointers to A[8:10, 0:2] by:
171 | # - Starting at base pointer a_ptr
172 | # - Adding row offsets (offs_am[:, None] * stride_am) to move to rows 8-9
173 | # - Adding column offsets (offs_k[None, :] * stride_ak) to access columns 0-1
174 | # b_ptrs calculates pointers to B[0:2, 0:2] by:
175 | # - Starting at base pointer b_ptr
176 | # - Adding row offsets (offs_k[:, None] * stride_bk) to move to rows 0-1
177 | # - Adding column offsets (offs_bn[None, :] * stride_bn) to access columns 0-1
178 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
179 | b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
180 |
181 | # -----------------------------------------------------------
182 | # STEP 3: Compute the matrix multiplication C = A × B block by block
183 | # -----------------------------------------------------------
184 | # Initialize accumulator with zeros
185 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
186 |
187 | # For our example with K=16 and BLOCK_SIZE_K=2, we need 8 iterations
188 | for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
189 | # Load blocks of A and B
190 | # For pid=0, iteration 0: we load A[0:2, 0:2] and B[0:2, 0:2]
191 | # For pid=0, iteration 1: we load A[0:2, 2:4] and B[2:4, 0:2]
192 | # For pid=0, iteration 2: we load A[0:2, 4:6] and B[4:6, 0:2]
193 | # And so on...
194 |
195 | # Calculate how many elements remain in the K dimension for the current iteration
196 | k_remaining = K - k * BLOCK_SIZE_K
197 |
198 | # a_mask handles the columns of matrix A (K dimension)
199 | # For example, if K=10 and BLOCK_SIZE_K=4, in the last iteration (k=2):
200 | # - k_remaining = 10 - 2*4 = 2
201 | # - offs_k = [0,1,2,3]
202 | # - a_mask will be [[True,True,False,False]]
203 | # This ensures we only load the valid 2 remaining columns
204 | a_mask = (offs_k[None, :] < k_remaining)
205 |
206 | # b_mask handles the rows of matrix B (K dimension)
207 | # Using the same example, b_mask will be:
208 | # [[True], [True], [False], [False]]
209 | # This ensures we only load the valid 2 remaining rows
210 | b_mask = (offs_k[:, None] < k_remaining)
211 | a = tl.load(a_ptrs, mask=a_mask, other=0.0)
212 | b = tl.load(b_ptrs, mask=b_mask, other=0.0)
213 |
214 | # Compute matrix multiplication for this block and accumulate
215 | # For pid=0, iteration 0:
216 | # C[0,0] += A[0,0]*B[0,0] + A[0,1]*B[1,0]
217 | # C[0,1] += A[0,0]*B[0,1] + A[0,1]*B[1,1]
218 | # C[1,0] += A[1,0]*B[0,0] + A[1,1]*B[1,0]
219 | # C[1,1] += A[1,0]*B[0,1] + A[1,1]*B[1,1]
220 | accumulator = tl.dot(a, b, accumulator)
221 |
222 | # Move pointers to the next K block
223 | # For iteration 1: a_ptrs now points to A[0:2, 2:4]
224 | # For iteration 1: b_ptrs now points to B[2:4, 0:2]
225 | a_ptrs += BLOCK_SIZE_K * stride_ak
226 | b_ptrs += BLOCK_SIZE_K * stride_bk
227 |
228 | # -----------------------------------------------------------
229 | # STEP 4: Apply activation function (if specified) and prepare for output
230 | # -----------------------------------------------------------
231 | # Apply activation function if specified
232 | if ACTIVATION == "leaky_relu":
233 | accumulator = leaky_relu(accumulator)
234 |
235 | # Convert back to float16 for output
236 | c = accumulator.to(tl.float16)
237 |
238 | # -----------------------------------------------------------
239 | # STEP 5: Write the computed block back to matrix C
240 | # -----------------------------------------------------------
241 | # Calculate global indices for this block in matrix C
242 | # Same construction as offs_am and offs_bn above, but without the modulo wrap (the mask below handles out-of-bounds rows/columns)
243 | offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
244 | offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
245 |
246 | # Create pointers to the output locations in C
247 | # same as the a_ptrs and b_ptrs
248 | c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
249 |
250 | # Create mask to handle boundary conditions
251 | c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
252 |
253 | # Store the computed values to matrix C
254 | # For pid=0: This writes to C[0:2, 0:2]
255 | tl.store(c_ptrs, c, mask=c_mask)
256 |
257 | @triton.jit
258 | def leaky_relu(x):
259 | return tl.where(x >= 0, x, 0.01 * x)
260 |
261 | def matmul(a, b, activation=None):
262 | M, K = a.shape
263 | K, N = b.shape
264 | # Initialize output tensor
265 | c = torch.empty((M, N), device=a.device, dtype=torch.float16)
266 | BLOCK_SIZE_M = 64
267 | BLOCK_SIZE_N = 64
268 | BLOCK_SIZE_K = 64
269 | GROUP_SIZE_M = 8
270 | # Calculate grid dimensions based on block sizes
271 | grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),)
272 |
273 | # Launch the kernel (tensors are passed directly; Triton treats them as pointers to their first element)
274 | matmul_kernel[grid](
275 | a_ptr=a,
276 | b_ptr=b,
277 | c_ptr=c,
278 | M=M, N=N, K=K,
279 | stride_am=a.stride(0),
280 | stride_ak=a.stride(1),
281 | stride_bk=b.stride(0),
282 | stride_bn=b.stride(1),
283 | stride_cm=c.stride(0),
284 | stride_cn=c.stride(1),
285 | ACTIVATION=activation,
286 | BLOCK_SIZE_M=BLOCK_SIZE_M,
287 | BLOCK_SIZE_N=BLOCK_SIZE_N,
288 | BLOCK_SIZE_K=BLOCK_SIZE_K,
289 | GROUP_SIZE_M=GROUP_SIZE_M,
290 | )
291 |
292 | return c
293 |
294 | # Test function for the matmul implementation
295 | def test_matmul():
296 | import torch
297 | for m, n, k in [(32, 32, 32), (256, 512, 128), (1024, 1024, 1024)]:
298 | a = torch.randn((m, k), device='cuda', dtype=torch.float16)
299 | b = torch.randn((k, n), device='cuda', dtype=torch.float16)
300 | c_ref = torch.matmul(a, b)
301 | c_triton = matmul(a, b)
302 | assert torch.allclose(c_ref, c_triton, rtol=1e-2, atol=1e-2), f"Failed for size {m}x{k}x{n}"
303 |
304 | print("Size tests passed!")
305 |
306 | if __name__ == "__main__":
307 | test_matmul()
308 |
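309 | # The following is a small host-side sketch (not used by the kernel) that reproduces the
310 | # pid -> (pid_m, pid_n) swizzle mapping worked out in the comments above, so the tables
311 | # can be regenerated for any grid size / GROUP_SIZE_M and inspected directly.
312 | def print_swizzle_pattern(num_pid_m=8, num_pid_n=4, GROUP_SIZE_M=3):
313 |     grid_map = [[None] * num_pid_n for _ in range(num_pid_m)]
314 |     num_pid_in_group = GROUP_SIZE_M * num_pid_n
315 |     for pid in range(num_pid_m * num_pid_n):
316 |         group_id = pid // num_pid_in_group
317 |         first_pid_m = group_id * GROUP_SIZE_M
318 |         group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
319 |         pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
320 |         pid_n = (pid % num_pid_in_group) // group_size_m
321 |         grid_map[pid_m][pid_n] = pid
322 |     for row in grid_map:
323 |         print(row)  # each row of blocks in C, e.g. [0, 3, 6, 9] for the first row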
--------------------------------------------------------------------------------
/annotated_examples/classics/swiglu.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2 | # Modifications Copyright 2025 Mekkcyber.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import triton
17 | import triton.language as tl
18 | import torch
19 |
20 | MAX_FUSED_SIZE : int = 65536
21 | next_power_of_2 = triton.next_power_of_2
22 |
23 | def calculate_settings(n : int) -> (int, int,):
24 | BLOCK_SIZE : int = next_power_of_2(n)
25 | if BLOCK_SIZE > MAX_FUSED_SIZE:
26 | raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
27 | f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
28 | num_warps : int = 4
29 | if BLOCK_SIZE >= 32768: num_warps = 32
30 | elif BLOCK_SIZE >= 8192: num_warps = 16
31 | elif BLOCK_SIZE >= 2048: num_warps = 8
32 | return BLOCK_SIZE, num_warps
33 |
34 | @triton.jit
35 | def _forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
36 | block_idx = tl.program_id(0)
37 | offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
38 | mask = offsets < n_elements
39 |
40 | e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
41 | g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
42 |
43 | # f = e * sigmoid(e)
44 | f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row))
45 | f_row = f_row.to(g_row.dtype)
46 | # h = f * g
47 | h_row = f_row * g_row
48 |
49 | # Store h
50 | tl.store(h + offsets, h_row, mask = mask)
51 |
52 |
53 | def swiglu_forward_kernel(e, g):
54 | batch, seq_len, hd = e.shape
55 | n_elements = e.numel()
56 | h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = e.device)
57 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
58 | _forward_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,)
59 | return h
60 |
61 |
62 | @triton.jit
63 | def _backward_kernel(dY, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
64 | """
65 | Backward pass for SwiGLU activation function.
66 |
67 | Forward pass (for reference):
68 | f = e * sigmoid(e) # SiLU/Swish activation
69 | h = f * g # Gating mechanism
70 |
71 | Backward pass derivation:
72 | Given dL/dh (dY), we need to compute:
73 | - dL/de: Gradient with respect to first input
74 | - dL/dg: Gradient with respect to second input
75 |
76 | Using the chain rule:
77 | dL/dg = dL/dh * dh/dg = dY * f
78 | dL/de = dL/dh * dh/de = dY * dh/de = dY * g * df/de
79 | Where df/de = sigmoid(e) + e * sigmoid(e) * (1 - sigmoid(e))
80 | = sigmoid(e) * (1 + e * (1 - sigmoid(e))) (see backprop_math/swiglu.md)
81 | """
82 | # Get the block index and calculate offsets for parallel processing
83 | block_idx = tl.program_id(0)
84 | offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
85 | mask = offsets < n_elements
86 |
87 | # Load input tensors
88 | dY_row = tl.load(dY + offsets, mask = mask, other = 0)
89 | e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
90 | g_row = tl.load(g + offsets, mask = mask, other = 0)
91 |
92 | # Compute sigmoid(e) - needed for both forward and backward calculations
93 | se_row = tl.sigmoid(e_row) # sigmoid(e)
94 |
95 | # Compute f = e * sigmoid(e) (SiLU/Swish activation)
96 | f_row = se_row * e_row
97 | f_row = f_row.to(dY_row.dtype) # Convert back to original dtype
98 |
99 | # Compute dL/dg = dY * f
100 | dg_row = dY_row * f_row
101 |
102 | # Compute dL/de = dY * g * sigmoid(e) * (1 + e * (1 - sigmoid(e)))
103 | # This is the derivative of SwiGLU with respect to e
104 | de_row = dY_row.to(tl.float32) * g_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row))
105 | de_row = de_row.to(dY_row.dtype) # Convert back to original dtype
106 |
107 | # Store computed gradients back to the input buffers
108 | # Note: We're reusing the input buffers to store the gradients
109 | tl.store(e + offsets, de_row, mask = mask) # Store dL/de in e buffer
110 | tl.store(g + offsets, dg_row, mask = mask) # Store dL/dg in g buffer
111 |
112 |
113 | def swiglu_DWf_DW_dfg_kernel(dY, e, g):
114 | n_elements = e.numel()
115 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
116 | _backward_kernel[grid](dY, e, g, n_elements, BLOCK_SIZE = 1024,)
117 | return e, g
118 |
119 | def test_swiglu_correctness():
120 | """
121 | Test the correctness of the SwiGLU implementation by comparing with a reference implementation.
122 | Tests both forward and backward passes for SwiGLU (SiLU/Swish).
123 | """
124 | import torch
125 | import torch.nn.functional as F
126 |
127 | def swiglu_reference_forward(e, g):
128 | """Reference implementation of SwiGLU forward pass"""
129 | return g * (e * F.sigmoid(e))
130 |
131 | forward_kernel = _forward_kernel
132 | backward_kernel = _backward_kernel
133 |
134 | batch_size, seq_len, hidden_dim = 2, 10, 128
135 | e = torch.randn(batch_size, seq_len, hidden_dim, dtype=torch.float32, device='cuda')
136 | g = torch.randn(batch_size, seq_len, hidden_dim, dtype=torch.float32, device='cuda')
137 |
138 | h = torch.empty_like(e)
139 |
140 | n_elements = e.numel()
141 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
142 | forward_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024)
143 |
144 | our_output = h.clone()
145 |
146 | ref_output = swiglu_reference_forward(e, g)
147 |
148 | max_diff = torch.max(torch.abs(ref_output - our_output))
149 | print(f"Max difference in SwiGLU forward pass: {max_diff.item()}")
150 | assert max_diff < 1e-5, "SwiGLU forward pass implementation is incorrect!"
151 |
152 | # Test backward pass
153 | dY = torch.randn_like(h)
154 |
155 | # Compute reference gradients
156 | e.requires_grad_(True)
157 | g.requires_grad_(True)
158 | ref_output = swiglu_reference_forward(e, g)
159 | ref_output.backward(dY)
160 | ref_de = e.grad.clone()
161 | ref_dg = g.grad.clone()
162 |
163 | backward_kernel[grid](dY, e, g, n_elements, BLOCK_SIZE = 1024)
164 |
165 | max_diff_de = torch.max(torch.abs(ref_de - e))
166 | print(f"Max difference in SwiGLU backward pass (de): {max_diff_de.item()}")
167 | assert max_diff_de < 1e-5, "SwiGLU backward pass implementation for de is incorrect!"
168 |
169 | max_diff_dg = torch.max(torch.abs(ref_dg - g))
170 | print(f"Max difference in SwiGLU backward pass (dg): {max_diff_dg.item()}")
171 | assert max_diff_dg < 1e-5, "SwiGLU backward pass implementation for dg is incorrect!"
172 |
173 | print("All tests passed!")
174 |
175 |
176 | if __name__ == "__main__":
177 | test_swiglu_correctness()
178 |
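179 | # Minimal usage sketch: where the fused forward kernel fits inside a SwiGLU MLP block.
180 | # The module below is hypothetical (layer names and sizes are illustrative, not part of the
181 | # kernels above) and forward-only; wiring up _backward_kernel would additionally require
182 | # a torch.autograd.Function.
183 | class SwiGLUMLP(torch.nn.Module):
184 |     def __init__(self, hidden_dim, intermediate_dim):
185 |         super().__init__()
186 |         self.gate_proj = torch.nn.Linear(hidden_dim, intermediate_dim, bias=False)
187 |         self.up_proj = torch.nn.Linear(hidden_dim, intermediate_dim, bias=False)
188 |         self.down_proj = torch.nn.Linear(intermediate_dim, hidden_dim, bias=False)
189 |
190 |     def forward(self, x):                  # x: (batch, seq_len, hidden_dim)
191 |         e = self.gate_proj(x)              # gate branch
192 |         g = self.up_proj(x)                # up branch
193 |         h = swiglu_forward_kernel(e, g)    # fused g * (e * sigmoid(e))
194 |         return self.down_proj(h)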
--------------------------------------------------------------------------------
/annotated_examples/gemlite/gemm.py:
--------------------------------------------------------------------------------
1 | # Written by Dr. Hicham Badri @Mobius Labs GmbH - 2024
2 | # Modified by Mekkcyber - 2025
3 | #*******************************************************
4 | import torch, time
5 | import triton
6 | import triton.language as tl
7 |
8 | ########################################################################################################################################################################
9 |
10 | # Prerequisites for the GEMM. For a better understanding, start from the gemm_kernel function below; everything is explained there.
11 | @triton.jit
12 | def dequantize(b, scales, zeros, q_shift, meta_dtype, unpack_mask, elements_per_sample: tl.constexpr, W_group_mode: tl.constexpr, zero_is_scalar: tl.constexpr):
13 | """
14 | Dequantizes packed integer values into floating point values using various quantization schemes.
15 |
16 | Args:
17 | b: Packed quantized values (typically int32)
18 | scales: Scaling factors for dequantization (per group or channel)
19 | zeros: Zero points for asymmetric quantization (per group or channel)
20 | q_shift: Bit shift amount for unpacking elements from packed format
21 | meta_dtype: Target data type for metadata operations
22 | unpack_mask: Bit mask for extracting individual elements (e.g., 0xF for 4-bit)
23 | elements_per_sample: Number of quantized elements packed into each storage unit
24 | W_group_mode: Quantization scheme to use (1-4)
25 | zero_is_scalar: Whether zero point is shared across all elements
26 |
27 | Returns:
28 | Dequantized tensor in floating point format
29 | """
30 | # Step 1: Unpack the elements if they are packed (e.g., 8 4-bit values in one int32)
31 | if(elements_per_sample > 1):
32 | # Extract individual quantized values using bit shifting and masking
33 | # q_shift determines which element to extract based on position
34 | b = (b >> q_shift) & unpack_mask # int32 -> int32
35 |
36 | # Step 2: Apply the appropriate dequantization formula based on W_group_mode
37 |
38 | if(W_group_mode == 1): # Shift-only mode (zero-point subtraction)
39 | # Formula: dequantized = quantized - zero_point
40 | b = b.to(meta_dtype) - zeros
41 |
42 | if(W_group_mode == 2): # Scale-only mode (symmetric quantization)
43 | # Formula: dequantized = quantized * scale
44 | # Used when quantized values are centered around zero
45 | b = b.to(meta_dtype) * scales
46 |
47 | if(W_group_mode == 3): # Scale and shift mode (asymmetric quantization)
48 | # Formula: dequantized = (quantized - zero_point) * scale
49 | if(zero_is_scalar):
50 | # When zero_point is shared across all elements (memory optimization)
51 | b = (b - zeros).to(meta_dtype) * scales
52 | else:
53 | # When each group has its own zero_point
54 | b = (b.to(meta_dtype) - zeros) * scales
55 |
56 | if(W_group_mode == 4): # Fused multiply-add mode
57 | # Formula: dequantized = quantized * scale + zero
58 | # Uses fused multiply-add for better performance
59 | # Note: in this mode, 'zeros' is actually an additive term, not a zero point
60 | b = tl.fma(b.to(meta_dtype), scales, zeros)
61 |
62 | return b
63 |
64 | @triton.jit
65 | def swizzle_tile(pid, M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
66 | grid_m = tl.cdiv(M, BLOCK_SIZE_M)
67 | grid_n = tl.cdiv(N, BLOCK_SIZE_N)
68 | width = GROUP_SIZE_M * grid_n
69 | group_id = pid // width
70 | group_size = tl.minimum(grid_m - group_id * GROUP_SIZE_M, GROUP_SIZE_M)
71 | pid_m = group_id * GROUP_SIZE_M + (pid % group_size)
72 | pid_n = (pid % width) // group_size
73 | return pid_m, pid_n
74 |
75 | @triton.jit
76 | def linear_tile(pid, M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
77 | pid_m = pid // tl.cdiv(N, BLOCK_SIZE_N)
78 | pid_n = pid % tl.cdiv(N, BLOCK_SIZE_N)
79 | return pid_m, pid_n
80 | ########################################################################################################################################################################
81 |
82 | # START HERE for the gemm kernel
83 | @triton.jit
84 | def gemm_kernel(
85 | a_ptr, b_ptr, c_ptr,
86 | scales_ptr, zeros_ptr, scales_a_ptr,
87 | M, N, K,
88 | ######### Quant parms #########
89 | W_nbits: tl.constexpr,
90 | group_size: tl.constexpr,
91 | unpack_mask: tl.constexpr,
92 | elements_per_sample: tl.constexpr,
93 | ######### Strides #########
94 | stride_am, stride_ak,
95 | stride_bk, stride_bn,
96 | stride_cm, stride_cn,
97 | stride_meta_g, stride_meta_n,
98 | ######### Dtypes #########
99 | input_dtype: tl.constexpr,
100 | output_dtype: tl.constexpr,
101 | acc_dtype: tl.constexpr,
102 | meta_dtype: tl.constexpr,
103 | ######### Meta-data mode #########
104 | channel_scale_mode: tl.constexpr,
105 | W_group_mode: tl.constexpr,
106 | zero_is_scalar: tl.constexpr,
107 | ######### tuning params #########
108 | BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
109 | GROUP_SIZE_M: tl.constexpr,
110 | A_load_order: tl.constexpr, meta_evict_policy: tl.constexpr,
111 | data_contiguous: tl.constexpr,
112 | ):
113 | """
114 | Based on https://github.com/fpgaminer/GPTQ-triton
115 | GEMM for C = matmul(A, dequantize(B, scales, zeros))
116 | A is of shape (M, K): float16 or bfloat16
117 | B is of shape (K//elements_per_sample, N): int32 as a packed matrix
118 | C is of shape (M, N): float16 or bfloat16 depending on the input A
119 | scales and zeros is of shape (group_size, N): float16 or bfloat16
120 |
121 | BLOCK_SIZE_M >=16
122 | BLOCK_SIZE_K <= group_size
123 | """
124 |
125 | # This kernel implements a quantized matrix multiplication operation where:
126 | # - Matrix A is a full-precision matrix (float16/bfloat16)
127 | # - Matrix B is a quantized matrix (packed into int32)
128 | # - The result C = A @ dequantize(B)
129 |
130 | # Example:
131 | # If we have:
132 | # - A: 128x512 matrix (M=128, K=512) in float16
133 | # - B: quantized 128x256 matrix packed into int32
134 | # - W_nbits=4 (4-bit quantization)
135 | # - elements_per_sample=8 (8 elements packed into each int32)
136 | # B should be dequantized to a 512x256 matrix in float16
137 | # Then C will be a 128x256 matrix in float16
138 |
139 | # Get the program ID which identifies which tile this thread block processes
140 | pid = tl.program_id(axis=0)
141 |
142 | # for each pid, we need to find the corresponding block in the output matrix, we can do this in two ways:
143 | # 1. swizzle_tile: Creating groups horizontally and inside each group, pids are mapped vertically
144 | # Example of swizzle pattern with GROUP_SIZE_M=2:
145 | # pid layout in the output matrix (each number represents a block's pid):
146 | # 0 2 4 6
147 | # 1 3 5 7
148 | # 8 10 12 14
149 | # 9 11 13 15
150 | # This improves cache locality by keeping adjacent thread blocks working on adjacent rows, see classic/matmul.py for more details
151 | # 2. linear_tile: we simply set pid_m = pid // (tl.cdiv(N, BLOCK_SIZE_N)) and pid_n = pid % (tl.cdiv(N, BLOCK_SIZE_N))
152 | # Example of linear pattern with GROUP_SIZE_M=2:
153 | # pid layout in the output matrix (each number represents a block's pid):
154 | # 0 1 2 3
155 | # 4 5 6 7
156 | # 8 9 10 11
157 | # 12 13 14 15
158 | # This is simpler to compute but may result in poorer cache locality
159 | pid_m, pid_n = swizzle_tile(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M)
160 |
161 | # Calculate how many blocks we need in the K dimension
162 | # For example, if K=512 and BLOCK_SIZE_K=64, we need 8 iterations
163 | num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
164 |
165 | # Calculate offsets for each thread within the block
166 | # If BLOCK_SIZE_M=16, BLOCK_SIZE_N=16, each thread block processes a 16x16 tile
167 | # offs_m will be [0,1,2,...,15] + (pid_m * 16)
168 | # offs_n will be [0,1,2,...,15] + (pid_n * 16)
169 | # offs_k will be [0,1,2,...,BLOCK_SIZE_K-1]
170 | offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
171 | offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
172 | offs_k = tl.arange(0, BLOCK_SIZE_K)
173 |
174 | # Optimize memory access patterns based on data layout
175 | if(data_contiguous):
176 | # tl.multiple_of(tensor, values): informs the compiler that the values in the tensor are multiples of 'values'
177 | # - This allows the compiler to optimize memory access patterns and vectorize loads
178 | # - Here it is used to assert that the starting offset of the block in offs_bn is aligned to BLOCK_SIZE_N
179 | #
180 | # tl.max_contiguous(tensor, values): informs the compiler that the first 'values' elements in the tensor are contiguous
181 | # - This helps the compiler generate more efficient memory access patterns
182 | # - Here the BLOCK_SIZE_N offsets in offs_bn are consecutive integers (pid_n*BLOCK_SIZE_N, pid_n*BLOCK_SIZE_N+1, ...),
183 | # so loads along the N dimension can be coalesced
184 | offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
185 |
186 | # Calculate pointers to input matrices
187 | # For matrix A: If stride_am=K and stride_ak=1, this accesses A in row-major order
188 | # Example: For A[2,3] with K=512, the offset would be 2*512 + 3 = 1027
189 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
190 | # Create a mask to handle boundary conditions (when M is not divisible by BLOCK_SIZE_M)
191 | a_mask = (offs_am[:, None] < M)
192 |
193 | # For matrix B: Calculate pointers based on the packed format
194 | # If B is packed with 8 elements per int32, we divide offs_k by 8 to get the actual memory location
195 | #
196 | # Example: With 8-bit elements packed into 32-bit integers (elements_per_sample = 4):
197 | # - For offs_k values [0,1,2,3], the division gives [0,0,0,0]
198 | # - For offs_k values [4,5,6,7], the division gives [1,1,1,1]
199 | #
200 | # This means that elements 0-3 are stored in the first 32-bit word, elements 4-7 in the second word, etc.
201 | # Later, we'll use q_shift to extract the correct bits from each packed word.
202 | b_ptrs = b_ptr + ((offs_k[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn)
203 |
204 | # Calculate bit shift for unpacking quantized values from packed integers
205 | #
206 | # Example 1: With 4-bit quantization (W_nbits=4) and 8 elements per int32 (elements_per_sample=8):
207 | # - For offs_k = [0,1,2,3,4,5,6,7]:
208 | # offs_k % elements_per_sample = [0,1,2,3,4,5,6,7]
209 | # q_shift = [0,4,8,12,16,20,24,28] bits
210 | #
211 | # Example 2: With 8-bit quantization (W_nbits=8) and 4 elements per int32 (elements_per_sample=4):
212 | # - For offs_k = [0,1,2,3,4,5,6,7]:
213 | # offs_k % elements_per_sample = [0,1,2,3,0,1,2,3]
214 | # q_shift = [0,8,16,24,0,8,16,24] bits
215 | #
216 | # The modulo operation (%) wraps around when we exceed elements_per_sample,
217 | # ensuring we extract the correct element position within each packed integer.
218 | q_shift = ((offs_k % elements_per_sample) * W_nbits).to(tl.int32)[:, None]
219 |
220 | # Calculate pointers to quantization metadata (scales and zeros)
221 | # These pointers point to the start of each column in the metadata matrices
222 | # For example, if we have a matrix with N columns, each column has its own
223 | # scale and zero point values for dequantization
224 | scales_ptrs = scales_ptr + offs_bn[None, :] * stride_meta_n
225 | zeros_ptrs = zeros_ptr + offs_bn[None, :] * stride_meta_n
226 |
227 | # Calculate stride multiplier for group quantization
228 | # If group_size=64 and BLOCK_SIZE_K=32, stride_mul=0.5
229 | # This means we need a new scale/zero for every 2 K-dimension blocks
230 | stride_mul = BLOCK_SIZE_K / group_size
231 |
232 | # If zero point is a scalar (same for all elements), load it once
233 | # eviction_policy='evict_last' tells the compiler how to manage the cache:
234 | # - 'evict_last': Keep the data in cache as long as possible (good for reused data)
235 | # - 'evict_first': Evict from cache quickly (good for data used only once)
236 | # This helps optimize memory access patterns and cache utilization
237 | if(zero_is_scalar):
238 | zero_scalar = tl.load(zeros_ptr, eviction_policy='evict_last')
239 |
240 | # Initialize accumulator for matrix multiplication
241 | # This will store the partial sums during the computation
242 | # For a 16x16 tile, this is a 16x16 matrix initialized to zeros
243 | acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
244 |
245 | # Main computation loop - iterate over blocks in the K dimension
246 | # For example, if K=512 and BLOCK_SIZE_K=64, we do 8 iterations
247 | for k in range(num_pid_k):
248 | # Load matrix A based on the specified loading order
249 | # Different loading orders can help with instruction scheduling and can lead to better performance
250 | if(A_load_order == 0): # Early load - load A before B
251 | a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy='evict_last')
252 |
253 | # Load packed quantized values from matrix B
254 | # Load packed quantized weights with 'evict_first' policy since we'll immediately
255 | # dequantize these values and won't need the packed representation in cache
256 | # Each row of B is repeated elements_per_sample times, this way we can unpack it using the q_shift
257 | # If you don't get why you can look at how b_ptrs is computed
258 | b = tl.load(b_ptrs, eviction_policy='evict_first')
259 |
260 | if(A_load_order == 1): # Load A after loading B
261 | a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy='evict_last')
262 |
263 | # Load quantization metadata (scales and zero points) based on the group mode
264 | # Different modes use different patterns of scales and zeros
265 | # W_group_mode controls how quantization parameters (scales and zeros) are applied:
266 |
267 | # Mode 0: No quantization - neither scales nor zeros are used
268 | # Mode 1: Zero-point only quantization - only zero points are used, no scales
269 | # Used for integer quantization where only zero-point shifting is needed
270 | # Mode 2: Scale-only quantization - only scales are used, no zero points
271 | # Used for symmetric quantization where values are centered around zero
272 | # Mode 3: Full quantization - both scales and zero points are used
273 | # Used for asymmetric quantization with arbitrary ranges
274 | # Mode 4: Fused multiply-add - both scales and zeros are used, as b * scales + zeros ('zeros' acts as an additive term)
275 |
276 | if(W_group_mode > 0):
277 | # Calculate offset for grouped quantization
278 | # For every group_size weights, we have a single scale/zero point
279 | # stride_mul = BLOCK_SIZE_K / group_size controls how often we need new metadata
280 | #
281 | # Examples:
282 | # 1. If group_size=64 and BLOCK_SIZE_K=64, stride_mul=1
283 | # We need a new scale for each K block (k_m increases by 1 each iteration)
284 | # 2. If group_size=128 and BLOCK_SIZE_K=64, stride_mul=0.5
285 | # We need a new scale every 2 K blocks (k_m increases by 1 every 2 iterations)
286 | #
287 | # This mapping ensures we use the correct scale/zero for each weight group
288 | k_m = (k * stride_mul).to(tl.int32)
289 |
290 | # Load scales if needed (modes 2, 3, 4)
291 | # Example: For per-channel quantization, each output channel has its own scale
292 | if(W_group_mode >= 2): # [2, 3, 4]
293 | scales = tl.load(scales_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy)
294 | else:
295 | scales = None
296 |
297 | # Load zero points if needed (modes 1, 3, 4)
298 | # Example: For per-channel quantization, each output channel has its own zero point
299 | if(W_group_mode == 1 or W_group_mode >= 3): # [1, 3, 4]
300 | if(zero_is_scalar):
301 | # If zero_is_scalar=1, use the same zero point for all elements
302 | # This saves memory and bandwidth when all channels share the same zero point
303 | zeros = zero_scalar
304 | else:
305 | # Otherwise load per-group zero points from memory
306 | # stride_meta_g controls the spacing between groups in memory
307 | zeros = tl.load(zeros_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy)
308 | else:
309 | zeros = None
310 |
311 | if(A_load_order == 2): # Mid load - load A after loading metadata
312 | a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy='evict_last')
313 |
314 | # Unpack and dequantize the values from matrix B
315 | # Dequantization formula depends on the W_group_mode:
316 | # - Mode 0: No dequantization just unpacking
317 | # - Mode 1: dequantized_value = quantized_value - zero_point
318 | # - Mode 2: dequantized_value = quantized_value * scale
319 | # - Mode 3: dequantized_value = (quantized_value - zero_point) * scale
320 | # - Mode 4: dequantized_value = b*scales + zeros
321 | b = dequantize(b, scales, zeros, q_shift, meta_dtype, unpack_mask, elements_per_sample, W_group_mode, zero_is_scalar)
322 |
323 | if(A_load_order == 3): # Late load - load A after dequantization
324 | a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy='evict_last')
325 |
326 | # Perform matrix multiplication for this block
327 | # For 16x16 tiles, this computes a 16x16 result and adds to accumulator
328 | # Example: If a is 16x64 and b is 64x16, this computes a 16x16 partial result
329 | acc = tl.dot(a, b.to(input_dtype), acc=acc, out_dtype=acc_dtype, input_precision="tf32")
330 |
331 | # Advance pointers for the next iteration
332 | # Move to the next block in the K dimension
333 | a_ptrs += BLOCK_SIZE_K * stride_ak
334 | b_ptrs += (BLOCK_SIZE_K // elements_per_sample) * stride_bk
335 |
336 | # Apply channel-wise scaling to the result if needed
337 | # This is used for various quantization schemes
338 | if(channel_scale_mode == 1): # Weight-only scaling
339 | # Load scales for each output channel
340 | # Example: If each output has a different scale factor
341 | scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy)
342 | # Apply scales to each column of the result
343 | acc = acc.to(meta_dtype) * scales_b[None, :]
344 |
345 | if(channel_scale_mode == 2): # Activation-only scaling
346 | # Load scaling factors for each input channel (row of the activation matrix)
347 | scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy)
348 |
349 | # Create a vector of ones for the output dimension
350 | # Since we're only scaling by activation (input channels), we use 1.0 for all output channels
351 | # This creates a vector of size BLOCK_SIZE_N filled with 1's of the metadata type
352 | scales_b = tl.full((BLOCK_SIZE_N,), value=1, dtype=meta_dtype)
353 |
354 | # Apply the scaling factors to the accumulated result:
355 | # 1. scales_a[:, None]: Reshape scales_a from [BLOCK_SIZE_M] to [BLOCK_SIZE_M, 1] for broadcasting
356 | # 2. scales_b[None, :]: Reshape scales_b from [BLOCK_SIZE_N] to [1, BLOCK_SIZE_N] for broadcasting
357 | # 3. This creates a scaling matrix of shape [BLOCK_SIZE_M, BLOCK_SIZE_N] where each row is scaled by its corresponding scales_a value
358 | # 4. Multiply the accumulator by this scaling matrix element-wise
359 | acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :])
360 |
361 | if(channel_scale_mode == 3): # Both weight and activation scaling
362 | # Load scales for both input and output channels
363 | scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy)
364 | scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy)
365 | # Apply both scales to the result
366 | # Example: If row 2 has scale 0.5 and column 3 has scale 0.25,
367 | # element [2,3] is multiplied by 0.5*0.25=0.125
368 | acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :])
369 |
370 | # Convert the result to the output data type
371 | acc = acc.to(output_dtype)
372 |
373 | # Calculate pointers to the output matrix
374 | # Similar to input pointers, but for matrix C
375 | offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
376 | offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
377 | offs_cn = tl.max_contiguous(tl.multiple_of(offs_cn, BLOCK_SIZE_N), BLOCK_SIZE_N)
378 | c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn)
379 |
380 | # Store the result to the output matrix
381 | # Use masks to handle boundary conditions
382 | tl.store(c_ptrs, acc, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))
383 |
384 |
385 | def test_kernel():
386 |
387 | # Set up test parameters
388 | M, N, K = 128, 256, 512
389 | W_nbits = 4
390 | group_size = 64
391 | elements_per_sample = 8 # For 4-bit quantization, 8 elements fit in one int32
392 |
393 | # Create input matrices
394 | a = torch.randn(M, K, dtype=torch.float16, device='cuda')
395 |
396 | # Create quantized weights (normally this would come from a quantization process)
397 | # For testing, we'll create random data
398 | b_unpacked = torch.randint(-8, 7, (K, N), dtype=torch.int8, device='cuda')
399 |
400 | # Pack the weights into int32
401 | b_packed = torch.zeros((K // elements_per_sample, N), dtype=torch.int32, device='cuda')
402 | for i in range(elements_per_sample):
403 | b_packed |= (b_unpacked[i::elements_per_sample, :].to(torch.int32) & ((1 << W_nbits) - 1)) << (i * W_nbits)
404 |
405 | # Create scales and zeros for dequantization
406 | scales = torch.ones((K // group_size, N), dtype=torch.float16, device='cuda')
407 | zeros = torch.zeros((K // group_size, N), dtype=torch.float16, device='cuda')
408 | scales_a = torch.ones(M, dtype=torch.float16, device='cuda')
409 |
410 | # Output matrix
411 | c_triton = torch.zeros((M, N), dtype=torch.float16, device='cuda')
412 |
413 | # Calculate strides
414 | stride_am, stride_ak = a.stride(0), a.stride(1)
415 | stride_bk, stride_bn = b_packed.stride(0), b_packed.stride(1)
416 | stride_cm, stride_cn = c_triton.stride(0), c_triton.stride(1)
417 | stride_meta_g, stride_meta_n = scales.stride(0), scales.stride(1)
418 |
419 | # Define grid
420 | grid = lambda META: (
421 | triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
422 | )
423 |
424 | # For verification, compute reference result using PyTorch
425 | # Dequantize the weights
426 | b_dequantized = torch.zeros((K, N), dtype=torch.float16, device='cuda')
427 | for g in range(K // group_size):
428 | start_idx = g * group_size
429 | end_idx = min((g + 1) * group_size, K)
430 | for i in range(start_idx, end_idx):
431 | element_idx = i % elements_per_sample
432 | packed_idx = i // elements_per_sample
433 | shift = element_idx * W_nbits
434 | mask = (1 << W_nbits) - 1
435 | b_dequantized[i] = ((b_packed[packed_idx] >> shift) & mask).to(torch.float16)
436 | # Apply scales
437 | b_dequantized[i] *= scales[g]
438 |
439 | # Compute reference result
440 | c_ref = torch.matmul(a, b_dequantized)
441 |
442 | # Benchmark (this is also where the kernel actually runs and fills c_triton)
443 | warmup = 25
444 | rep = 100
445 |
446 | torch.cuda.synchronize()
447 | start = time.time()
448 | for _ in range(warmup + rep):
449 | gemm_kernel[grid](
450 | a_ptr=a, b_ptr=b_packed, c_ptr=c_triton,
451 | scales_ptr=scales, zeros_ptr=zeros, scales_a_ptr=scales_a,
452 | M=M, N=N, K=K,
453 | W_nbits=W_nbits, group_size=group_size,
454 | unpack_mask=(1 << W_nbits) - 1, elements_per_sample=elements_per_sample,
455 | stride_am=stride_am, stride_ak=stride_ak,
456 | stride_bk=stride_bk, stride_bn=stride_bn,
457 | stride_cm=stride_cm, stride_cn=stride_cn,
458 | stride_meta_g=stride_meta_g, stride_meta_n=stride_meta_n,
459 | input_dtype=tl.float16, output_dtype=tl.float16, acc_dtype=tl.float32, meta_dtype=tl.float16,
460 | channel_scale_mode=1, W_group_mode=1, zero_is_scalar=0,
461 | BLOCK_SIZE_M=16, BLOCK_SIZE_N=16, BLOCK_SIZE_K=64, GROUP_SIZE_M=8,
462 | A_load_order=1, meta_evict_policy='evict_last', data_contiguous=1,
463 | )
464 | torch.cuda.synchronize()
465 | end = time.time()
466 |
467 | elapsed_time = (end - start) / rep
468 | print(f"Triton kernel time: {elapsed_time * 1000:.2f} ms")
469 |
470 | # Check correctness only after the kernel has been launched and c_triton is populated
471 | max_diff = torch.max(torch.abs(c_ref - c_triton))
472 | print(f"Max difference between PyTorch and Triton: {max_diff.item()}")
473 |
474 |
475 | return c_triton, c_ref, max_diff.item()
476 |
477 | if __name__ == "__main__":
478 | test_kernel()
479 |
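480 | # Small host-side illustration (independent of the kernel) of the bit-packing scheme used
481 | # above: 8 4-bit values packed into a single int32, then recovered with the same
482 | # q_shift / unpack_mask arithmetic the kernel applies inside dequantize().
483 | def unpacking_demo():
484 |     W_nbits = 4
485 |     elements_per_sample = 8
486 |     unpack_mask = (1 << W_nbits) - 1                               # 0xF
487 |     values = torch.arange(elements_per_sample, dtype=torch.int32)  # values 0..7 to pack
488 |     packed = torch.zeros((), dtype=torch.int32)
489 |     for i in range(elements_per_sample):
490 |         packed |= (values[i] & unpack_mask) << (i * W_nbits)
491 |     # q_shift selects which 4-bit slice to extract, exactly as in the kernel
492 |     q_shift = torch.arange(elements_per_sample, dtype=torch.int32) * W_nbits
493 |     unpacked = (packed >> q_shift) & unpack_mask
494 |     print(hex(packed.item()), unpacked.tolist())                   # 0x76543210 and [0, 1, ..., 7]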
--------------------------------------------------------------------------------
/annotated_examples/gemlite/gemm_splitK.py:
--------------------------------------------------------------------------------
1 | # Written by Dr. Hicham Badri @Mobius Labs GmbH - 2024
2 | # Modified by MekkCyber - 2025
3 | #*******************************************************
4 | import triton
5 | import triton.language as tl
6 |
7 | @triton.jit
8 | def dequantize(b, scales, zeros, q_shift, meta_dtype, unpack_mask, elements_per_sample: tl.constexpr, W_group_mode: tl.constexpr, zero_is_scalar: tl.constexpr):
9 | """
10 | Dequantizes packed integer values into floating point values using various quantization schemes.
11 |
12 | Args:
13 | b: Packed quantized values (typically int32)
14 | scales: Scaling factors for dequantization (per group or channel)
15 | zeros: Zero points for asymmetric quantization (per group or channel)
16 | q_shift: Bit shift amount for unpacking elements from packed format
17 | meta_dtype: Target data type for metadata operations
18 | unpack_mask: Bit mask for extracting individual elements (e.g., 0xF for 4-bit)
19 | elements_per_sample: Number of quantized elements packed into each storage unit
20 | W_group_mode: Quantization scheme to use (1-4)
21 | zero_is_scalar: Whether zero point is shared across all elements
22 |
23 | Returns:
24 | Dequantized tensor in floating point format
25 | """
26 | # Step 1: Unpack the elements if they are packed (e.g., 8 4-bit values in one int32)
27 | if(elements_per_sample > 1):
28 | # Extract individual quantized values using bit shifting and masking
29 | # q_shift determines which element to extract based on position
30 | b = (b >> q_shift) & unpack_mask # int32 -> int32
31 |
32 | # Step 2: Apply the appropriate dequantization formula based on W_group_mode
33 |
34 | if(W_group_mode == 1): # Shift-only mode (zero-point subtraction)
35 | # Formula: dequantized = quantized - zero_point
36 | b = b.to(meta_dtype) - zeros
37 |
38 | if(W_group_mode == 2): # Scale-only mode (symmetric quantization)
39 | # Formula: dequantized = quantized * scale
40 | # Used when quantized values are centered around zero
41 | b = b.to(meta_dtype) * scales
42 |
43 | if(W_group_mode == 3): # Scale and shift mode (asymmetric quantization)
44 | # Formula: dequantized = (quantized - zero_point) * scale
45 | if(zero_is_scalar):
46 | # When zero_point is shared across all elements (memory optimization)
47 | b = (b - zeros).to(meta_dtype) * scales
48 | else:
49 | # When each group has its own zero_point
50 | b = (b.to(meta_dtype) - zeros) * scales
51 |
52 | if(W_group_mode == 4): # Fused multiply-add mode
53 | # Formula: dequantized = quantized * scale + zero
54 | # Uses fused multiply-add for better performance
55 | # Note: in this mode, 'zeros' is actually an additive term, not a zero point
56 | b = tl.fma(b.to(meta_dtype), scales, zeros)
57 |
58 | return b
59 |
60 | @triton.jit
61 | def swizzle_tile(pid, M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
62 | grid_m = tl.cdiv(M, BLOCK_SIZE_M)
63 | grid_n = tl.cdiv(N, BLOCK_SIZE_N)
64 | width = GROUP_SIZE_M * grid_n
65 | group_id = pid // width
66 | group_size = tl.minimum(grid_m - group_id * GROUP_SIZE_M, GROUP_SIZE_M)
67 | pid_m = group_id * GROUP_SIZE_M + (pid % group_size)
68 | pid_n = (pid % width) // group_size
69 | return pid_m, pid_n
70 |
71 | @triton.jit
72 | def linear_tile(pid, M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
73 | pid_m = pid // tl.cdiv(N, BLOCK_SIZE_N)
74 | pid_n = pid % tl.cdiv(N, BLOCK_SIZE_N)
75 | return pid_m, pid_n
76 |
77 |
78 | @triton.jit
79 | def gemm_splitK_kernel(
80 | a_ptr, b_ptr, c_ptr,
81 | scales_ptr, zeros_ptr, scales_a_ptr,
82 | M, N, K,
83 | ######### Quant parms #########
84 | W_nbits: tl.constexpr,
85 | group_size: tl.constexpr,
86 | unpack_mask: tl.constexpr,
87 | elements_per_sample: tl.constexpr,
88 | ######### Strides #########
89 | stride_am, stride_ak,
90 | stride_bk, stride_bn,
91 | stride_cm, stride_cn,
92 | stride_meta_g, stride_meta_n,
93 | ######### Dtypes #########
94 | input_dtype: tl.constexpr,
95 | output_dtype: tl.constexpr,
96 | acc_dtype: tl.constexpr,
97 | meta_dtype: tl.constexpr,
98 | ######### Meta-data mode #########
99 | channel_scale_mode: tl.constexpr,
100 | W_group_mode: tl.constexpr,
101 | zero_is_scalar: tl.constexpr,
102 | ######### tuning params #########
103 | BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
104 | GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr,
105 | A_load_order: tl.constexpr, meta_evict_policy: tl.constexpr, atomic_mode: tl.constexpr,
106 | data_contiguous: tl.constexpr,
107 | ):
108 | """
109 | Quantized GEMM with split-K parallelism for C = matmul(A, dequantize(B, scales, zeros))
110 |
111 | A is of shape (M, K): float16 or bfloat16
112 | B is of shape (K//elements_per_sample, N): int32 as a packed matrix
113 | C is of shape (M, N): float16 or bfloat16 (same dtype as input A)
114 | scales is of shape (K//group_size, N) or (1, N): meta_dtype
115 | zeros is of shape (K//group_size, N) or (1, 1): meta_dtype
116 |
117 | Requirements:
118 | - BLOCK_SIZE_M must be >= 16
119 | - BLOCK_SIZE_K * SPLIT_K must be <= group_size
120 |
121 | Based on the split-K dequantization GEMM implementation from:
122 | https://github.com/foundation-model-stack/foundation-model-stack/blob/main/fms/models/llm/kernels/triton/splitk_dequant_gemm.py
123 | """
124 |
125 | # It's recommended to understand the standard GEMM implementation first in gemlite/gemm.py, as this split-K version
126 | # builds upon it, so we will not delve into the details of the standard GEMM implementation here.
127 | #
128 | # Unlike standard GEMM where each thread block computes its entire output tile,
129 | # in split-K GEMM we have a 2D grid of thread blocks where:
130 | # - pid (x-dim): Each thread block along x-axis is responsible for a unique output tile
131 | # - pid_k (y-dim): Multiple thread blocks along y-axis collaborate on the same output tile,
132 | # each computing a partial result and using atomic adds to safely accumulate into the final output
133 | pid = tl.program_id(axis=0) # Determines which output tile to compute
134 | pid_k = tl.program_id(axis=1) # Determines which K-slice to process for this output tile
135 |
136 | pid_m, pid_n = swizzle_tile(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M)
137 |
138 | num_pid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
139 |
140 |
141 | offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
142 | offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
143 | offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
144 |
145 | offs_am = offs_m
146 | offs_ak = tl.max_contiguous(tl.multiple_of(offs_k, BLOCK_SIZE_K), BLOCK_SIZE_K)
147 |
148 | if(data_contiguous):
149 | offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK_SIZE_N), BLOCK_SIZE_N)
150 | offs_bk = offs_k
151 | else:
152 | offs_bn = offs_n
153 | offs_bk = tl.max_contiguous(tl.multiple_of(offs_k, BLOCK_SIZE_K), BLOCK_SIZE_K)
154 |
155 | b_ptrs = b_ptr + ((offs_bk[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn)
156 | q_shift = ((offs_bk % elements_per_sample) * W_nbits).to(tl.int32)[:, None]
157 |
158 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak)
159 | a_mask = offs_am[:, None] < M
160 |
161 | scales_ptrs = scales_ptr + offs_bn[None, :] * stride_meta_n
162 | zeros_ptrs = zeros_ptr + offs_bn[None, :] * stride_meta_n
163 |
164 | stride_mul: tl.constexpr = BLOCK_SIZE_K / group_size
165 |
166 | # BLOCK_SIZE_K_U: How much to advance pointers in matrix A (the unpacked matrix) between iterations
167 | # We multiply by SPLIT_K because each thread block only handles every SPLIT_K-th chunk of BLOCK_SIZE_K
168 | # elements along K (K / SPLIT_K elements in total), so consecutive chunks it processes are BLOCK_SIZE_K * SPLIT_K apart
169 | BLOCK_SIZE_K_U: tl.constexpr = BLOCK_SIZE_K * SPLIT_K
170 |
171 | # BLOCK_SIZE_K_P: How much to advance pointers in matrix B (packed matrix)
172 | # Since B is packed with elements_per_sample values per int32
173 | # We divide BLOCK_SIZE_K by elements_per_sample to get number of int32s
174 | # Then multiply by SPLIT_K for the same reason as above
175 | # This represents the stride in the K dimension for packed matrix B
176 | BLOCK_SIZE_K_P: tl.constexpr = (BLOCK_SIZE_K // elements_per_sample) * SPLIT_K
177 |
178 | if(zero_is_scalar):
179 | zero_scalar = tl.load(zeros_ptr, eviction_policy='evict_last')
180 |
181 | acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
182 |
183 | for k in range(num_pid_k):
184 |
185 | if(A_load_order == 0):
186 | a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy='evict_last')
187 |
188 | b = tl.load(b_ptrs, eviction_policy='evict_first')
189 |
190 | if(A_load_order == 1):
191 | a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy='evict_last')
192 |
193 | # This code calculates which group we're currently processing for weight quantization
194 | if(W_group_mode > 0):
195 | # Important: BLOCK_SIZE_K must not exceed group_size (the docstring requires BLOCK_SIZE_K * SPLIT_K <= group_size)
196 | # This is because we only load a single row of scales/zeros per iteration
197 | #
198 | # Example with proper sizing:
199 | # - BLOCK_SIZE_K = 8 (how many K elements we process per block)
200 | # - group_size = 32 (we quantize weights in groups of 32)
201 | # - SPLIT_K = 4 (we split K dimension across 4 blocks)
202 | # - stride_mul = BLOCK_SIZE_K/group_size = 8/32 = 0.25 (fraction of a group in one block)
203 | #
204 | # k: outer loop counter (0, 1, 2..., tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K))
205 | # pid_k: which split the thread block is responsible for (0 to SPLIT_K-1)
206 | #
207 | # For example, if k=2, SPLIT_K=4, pid_k=1:
208 | # k * SPLIT_K + pid_k = 2 * 4 + 1 = 9
209 | #
210 | # Multiply by stride_mul to convert from K blocks to quantization groups:
211 | # 9 * 0.25 = 2.25, which truncates to k_m = 2, so this K block uses the scales/zeros of group index 2
212 | k_m = ((k * SPLIT_K + pid_k) * stride_mul).to(tl.int32)
213 |
214 | if(W_group_mode >= 2): #[2, 3, 4]
215 | scales = tl.load(scales_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy)
216 | else:
217 | scales = None
218 |
219 | if(W_group_mode == 1 or W_group_mode >= 3): #[1, 3, 4]
220 | if(zero_is_scalar):
221 | zeros = zero_scalar
222 | else:
223 | zeros = tl.load(zeros_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy)
224 | else:
225 | zeros = None
226 |
227 | if(A_load_order == 2): #Mid load
228 | a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy='evict_last')
229 |
230 | b = dequantize(b, scales, zeros, q_shift, meta_dtype, unpack_mask, elements_per_sample, W_group_mode, zero_is_scalar)
231 |
232 | if(A_load_order == 3):
233 | a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy='evict_last')
234 |
235 | acc = tl.dot(a, b.to(input_dtype), acc=acc, out_dtype=acc_dtype, input_precision="tf32")
236 |
237 | # Advance pointers for the next iteration of the k-loop
238 | a_ptrs += BLOCK_SIZE_K_U * stride_ak
239 | b_ptrs += BLOCK_SIZE_K_P * stride_bk
240 |
241 | if(channel_scale_mode == 1):
242 | scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy)
243 | acc = acc.to(meta_dtype) * scales_b[None, :]
244 |
245 | if(channel_scale_mode == 2):
246 | scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy)
247 | scales_b = tl.full((BLOCK_SIZE_N,), value=1, dtype=meta_dtype)
248 | acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :])
249 |
250 | if(channel_scale_mode == 3):
251 | scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy)
252 | scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy)
253 | acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :])
254 |
255 | acc = acc.to(output_dtype)
256 |
257 | offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
258 | offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
259 | offs_cn = tl.max_contiguous(tl.multiple_of(offs_cn, BLOCK_SIZE_N), BLOCK_SIZE_N)
260 | c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn)
261 |
262 | # We know that each thread block computes a partial result for the same output location (M,N coordinates)
263 | # When SPLIT_K > 1, multiple blocks will write to the same memory location
264 | # We use atomic_add to safely accumulate these partial results from different blocks
265 | # without race conditions, ensuring all contributions are correctly summed
266 | # The atomic operation guarantees that concurrent updates to the same memory
267 | # location happen in a coordinated way, preventing data corruption
268 | if(SPLIT_K > 1):
269 | tl.atomic_add(c_ptrs, acc, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N), sem=atomic_mode) #release / relaxed
270 | else:
271 | # When SPLIT_K = 1, each output location is computed by exactly one block
272 | # so we can use a simple store operation instead of an atomic add (this is the same as the standard GEMM)
273 | tl.store(c_ptrs, acc, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))
274 |
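275 | # Minimal launch sketch (hypothetical block sizes, not a tuned or tested configuration):
276 | # unlike the standard GEMM, the split-K kernel is launched on a 2D grid, and when
277 | # SPLIT_K > 1 the output tensor must be zero-initialized beforehand, because the partial
278 | # results of the K-slices are accumulated into it with tl.atomic_add.
279 | def splitk_grid(M, N, SPLIT_K, BLOCK_SIZE_M=16, BLOCK_SIZE_N=32):
280 |     # x-dim: one program per output tile, y-dim: SPLIT_K programs cooperating on each tile
281 |     return (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), SPLIT_K)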
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MekkCyber/TritonAcademy/7a3bf78b72406edbeb9fb6c4be07d7e694061e3f/assets/logo.png
--------------------------------------------------------------------------------
/backprop_math/cross_entropy.md:
--------------------------------------------------------------------------------
1 | # Cross Entropy
2 |
3 | ## Definition of Cross Entropy Loss
4 | Cross Entropy Loss for a classification problem is defined as:
5 |
6 | $$CE(x, class) = -\log(\text{softmax}(x)[class])$$
7 |
8 | Where softmax is defined as:
9 |
10 | $$\text{softmax}(x)[i] = \frac{e^{x_i}}{\sum_j e^{x_j}}$$
11 |
12 | ## Expanded form of Cross Entropy Loss
13 |
14 | $$CE(x, class) = -\log\left(\frac{e^{x_{class}}}{\sum_i e^{x_i}}\right)$$
15 |
16 | $$CE(x, class) = -x_{class} + \log\left(\sum_i e^{x_i}\right)$$
17 |
18 | Let's denote $z = \log\left(\sum_i e^{x_i}\right)$, which is the LogSumExp function.
19 |
20 | ## Compute gradients with respect to each logit
21 |
22 | ### Case 1: For the correct class (i = class)
23 |
24 | $$\frac{\partial CE}{\partial x_{class}} = \frac{\partial}{\partial x_{class}}(-x_{class} + z)$$
25 |
26 | $$\frac{\partial CE}{\partial x_{class}} = -1 + \frac{\partial z}{\partial x_{class}}$$
27 |
28 | For the LogSumExp term:
29 |
30 | $$\frac{\partial z}{\partial x_{class}} = \frac{\partial}{\partial x_{class}}\log\left(\sum_i e^{x_i}\right)$$
31 |
32 | $$\frac{\partial z}{\partial x_{class}} = \frac{1}{\sum_i e^{x_i}} \cdot \frac{\partial}{\partial x_{class}}\left(\sum_i e^{x_i}\right)$$
33 |
34 | $$\frac{\partial z}{\partial x_{class}} = \frac{1}{\sum_i e^{x_i}} \cdot e^{x_{class}}$$
35 |
36 | $$\frac{\partial z}{\partial x_{class}} = \frac{e^{x_{class}}}{\sum_i e^{x_i}} = \text{softmax}(x)[class]$$
37 |
38 | Substituting back:
39 |
40 | $$\frac{\partial CE}{\partial x_{class}} = -1 + \text{softmax}(x)[class]$$
41 |
42 | ### Case 2: For other classes (i ≠ class)
43 |
44 | $$\frac{\partial CE}{\partial x_i} = \frac{\partial}{\partial x_i}(-x_{class} + z)$$
45 |
46 | $$\frac{\partial CE}{\partial x_i} = 0 + \frac{\partial z}{\partial x_i}$$
47 |
48 | For the LogSumExp term, following the same steps as before:
49 |
50 | $$\frac{\partial z}{\partial x_i} = \frac{e^{x_i}}{\sum_j e^{x_j}} = \text{softmax}(x)[i]$$
51 |
52 | Putting the two cases together, for the correct class (i = class):
53 |
54 | $$\frac{\partial CE}{\partial x_{class}} = -1 + \text{softmax}(x)[class]$$
55 |
56 | For other classes (i ≠ class):
57 |
58 | $$\frac{\partial CE}{\partial x_i} = \text{softmax}(x)[i]$$
59 |
60 | ## Generalize to a single formula
61 | We can combine both cases into one formula:
62 |
63 | $$\frac{\partial CE}{\partial x_i} = \text{softmax}(x)[i] - \mathbf{1}_{i=class}$$
64 |
65 | Where $\mathbf{1}_{i=class}$ is an indicator function that equals 1 when i is the correct class and 0 otherwise.
66 |
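67 | ## Sanity check
68 |
69 | As a quick numerical check (a small PyTorch snippet, not part of the derivation above), the closed-form gradient $\text{softmax}(x)[i] - \mathbf{1}_{i=class}$ can be compared against autograd:
70 |
71 | ```python
72 | import torch
73 |
74 | x = torch.randn(5, requires_grad=True)   # logits for 5 classes
75 | target = 2                               # index of the correct class
76 |
77 | loss = torch.nn.functional.cross_entropy(x.unsqueeze(0), torch.tensor([target]))
78 | loss.backward()
79 |
80 | analytic = torch.softmax(x.detach(), dim=0)
81 | analytic[target] -= 1.0                  # softmax(x)[i] - 1_{i=class}
82 |
83 | print(torch.allclose(x.grad, analytic, atol=1e-6))  # True
84 | ```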
--------------------------------------------------------------------------------
/backprop_math/geglu.md:
--------------------------------------------------------------------------------
1 | # Derivative of GeLU
2 |
3 | ## Exact Derivative
4 |
5 | Starting with:
6 |
7 | $$f = \frac{1}{2} \cdot x \cdot \left(1 + \text{erf}\left(\frac{1}{\sqrt{2}} \cdot x\right)\right)$$
8 |
9 | Using the product rule:
10 |
11 | $$\frac{d}{dx}[u(x) \cdot v(x)] = u'(x) \cdot v(x) + u(x) \cdot v'(x)$$
12 |
13 | Let:
14 |
15 | - $u(x) = \frac{1}{2} \cdot x$
16 | - $v(x) = 1 + \text{erf}\left(\frac{1}{\sqrt{2}} \cdot x\right)$
17 |
18 | Step 1: Find $u'(x)$
19 |
20 | $$u'(x) = \frac{d}{dx}\left[\frac{1}{2} \cdot x\right] = \frac{1}{2}$$
21 |
22 | Step 2: Find $v'(x)$
23 | We need the chain rule here. The derivative of erf(x) is
24 |
25 | $$\frac{2}{\sqrt{\pi}} \cdot e^{-x^2}$$
26 |
27 | $$v'(x) = \frac{d}{dx}\left[1 + \text{erf}\left(\frac{1}{\sqrt{2}} \cdot x\right)\right] = \frac{d}{dx}\left[\text{erf}\left(\frac{1}{\sqrt{2}} \cdot x\right)\right]$$
28 |
29 | Using the chain rule with $g(x) = \frac{1}{\sqrt{2}} \cdot x$:
30 |
31 | $$v'(x) = \frac{2}{\sqrt{\pi}} \cdot e^{-\left(\frac{1}{\sqrt{2}} \cdot x\right)^2} \cdot \frac{d}{dx}\left[\frac{1}{\sqrt{2}} \cdot x\right]$$
32 |
33 | $$v'(x) = \frac{2}{\sqrt{\pi}} \cdot e^{-\frac{x^2}{2}} \cdot \frac{1}{\sqrt{2}}$$
34 |
35 | $$v'(x) = \frac{2}{\sqrt{\pi}} \cdot \frac{1}{\sqrt{2}} \cdot e^{-\frac{x^2}{2}}$$
36 |
37 | $$v'(x) = \frac{2}{\sqrt{2\pi}} \cdot e^{-\frac{x^2}{2}}$$
38 |
39 | Step 3: Apply the product rule
40 |
41 | $$\frac{df}{dx} = u'(x) \cdot v(x) + u(x) \cdot v'(x)$$
42 |
43 | $$\frac{df}{dx} = \frac{1}{2} \cdot \left(1 + \text{erf}\left(\frac{1}{\sqrt{2}} \cdot x\right)\right) + \frac{1}{2} \cdot x \cdot \frac{2}{\sqrt{2\pi}} \cdot e^{-\frac{x^2}{2}}$$
44 |
45 | $$\frac{df}{dx} = \frac{1}{2} \cdot \left(1 + \text{erf}\left(\frac{1}{\sqrt{2}} \cdot x\right)\right) + \frac{x}{\sqrt{2\pi}} \cdot e^{-\frac{x^2}{2}}$$
46 |
47 | This is our final result:
48 |
49 | $$\frac{df}{dx} = \frac{1}{2} \cdot \left(1 + \text{erf}\left(\frac{1}{\sqrt{2}} \cdot x\right)\right) + \frac{x}{\sqrt{2\pi}} \cdot e^{-\frac{x^2}{2}}$$
50 |
51 |
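Before moving on, the exact derivative can be sanity-checked against autograd. The sketch below is illustrative only (the `gelu`/`gelu_grad` helpers are defined here just for the check and are not the kernels from this repo):

```python
import math
import torch

def gelu(x):
    # Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))

def gelu_grad(x):
    # Closed-form derivative derived above
    return 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + x * torch.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)

x = torch.linspace(-4.0, 4.0, steps=101, requires_grad=True)
gelu(x).sum().backward()   # autograd gradient of sum(f(x)) is df/dx elementwise

print(torch.allclose(x.grad, gelu_grad(x.detach()), atol=1e-5))  # True
```
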
52 | ## Approximate Derivative
53 |
54 | Starting with:
55 |
56 | $$f(x) = 0.5 \cdot x \cdot (1 + \tanh(\sqrt{\frac{2}{\pi}} \cdot x \cdot (1 + 0.044715 \cdot x^2)))$$
57 |
58 | For simplicity, let's denote:
59 |
60 | $$z(x) = \sqrt{\frac{2}{\pi}} \cdot x \cdot (1 + 0.044715 \cdot x^2) = x \cdot (a + b \cdot x^2), \quad \text{with } a = \sqrt{\tfrac{2}{\pi}},\ b = 0.044715\sqrt{\tfrac{2}{\pi}}$$
61 |
62 | and
63 |
64 | $$v(x) = 1 + \tanh(z(x))$$
65 |
66 | Then:
67 |
68 | $$f(x) = 0.5 \cdot x \cdot (1 + \tanh(z(x))) = 0.5 \cdot x \cdot v(x)$$
69 |
70 | Using the product rule:
71 |
72 | $$\frac{d}{dx}[u(x) \cdot v(x)] = u'(x) \cdot v(x) + u(x) \cdot v'(x)$$
73 |
74 | Let:
75 | - $u(x) = 0.5 \cdot x$
76 | - $v(x) = 1 + \tanh(z(x))$
77 |
78 | Step 1: Find $u'(x)$
79 |
80 | $$u'(x) = 0.5$$
81 |
82 | Step 2: Find $v'(x)$
83 | Using the chain rule and the fact that the derivative of $\tanh(x)$ is $1 - \tanh^2(x)$:
84 |
85 | $$v'(x) = (1 - \tanh^2(z(x))) \cdot z'(x)$$
86 |
87 | The derivative of $z(x)$:
88 |
89 | $$z'(x) = a + 3b \cdot x^2$$
90 |
91 | Step 3: Using the identity for $1 - \tanh^2(z(x))$:
92 |
93 | $$1 - \tanh^2(z(x)) = (1 - \tanh(z(x)))(1 + \tanh(z(x)))$$
94 |
95 | $$1 - \tanh^2(z(x)) = (2 - (1 + \tanh(z(x))))(1 + \tanh(z(x)))$$
96 |
97 | $$1 - \tanh^2(z(x)) = (2 - v(x)) \cdot v(x)$$
98 |
99 | So $v'(x)$ can be written entirely in terms of $v(x)$ and $z'(x)$:
100 |
101 | $$v'(x) = (2 - v(x)) \cdot v(x) \cdot z'(x)$$
102 |
103 | Step 4: Apply the product rule for the complete derivative:
104 |
105 | $$\frac{df}{dx} = 0.5 \cdot (1 + \tanh(z(x))) + 0.5 \cdot x \cdot (2 - v(x)) \cdot v(x) \cdot z'(x)$$
106 |
107 | Substituting $z'(x) = a + 3b \cdot x^2$:
108 |
109 | $$\frac{df}{dx} = 0.5 \cdot v(x) + 0.5 \cdot x \cdot (2 - v(x)) \cdot v(x) \cdot (a + 3b \cdot x^2)$$
110 |
111 | $$\frac{df}{dx} = 0.5 \cdot v(x) \cdot \left[1 + x \cdot (2 - v(x)) \cdot (a + 3b \cdot x^2)\right]$$
112 |
113 |
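The same kind of check works for the approximate form. The sketch below (with illustrative helper names, not the repo's kernels) verifies the $v(x)$ formulation of the derivative against autograd:

```python
import math
import torch

a = math.sqrt(2.0 / math.pi)      # a = sqrt(2/pi)
b = 0.044715 * a                  # b = 0.044715 * sqrt(2/pi)

def gelu_tanh(x):
    # Tanh approximation: 0.5 * x * (1 + tanh(z(x))), with z(x) = x * (a + b * x^2)
    return 0.5 * x * (1.0 + torch.tanh(x * (a + b * x * x)))

def gelu_tanh_grad(x):
    v = 1.0 + torch.tanh(x * (a + b * x * x))                      # v(x)
    return 0.5 * v * (1.0 + x * (2.0 - v) * (a + 3.0 * b * x * x))

x = torch.linspace(-4.0, 4.0, steps=101, requires_grad=True)
gelu_tanh(x).sum().backward()

print(torch.allclose(x.grad, gelu_tanh_grad(x.detach()), atol=1e-5))  # True
```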
--------------------------------------------------------------------------------
/backprop_math/layernorm.md:
--------------------------------------------------------------------------------
1 | # Layer Normalization Backward Pass Derivation
2 |
3 | ## Forward Pass
4 |
5 | First, let's establish the forward pass equations:
6 |
7 | $$\mu = \frac{1}{n}\sum_{i=1}^n X_i$$
8 |
9 | $$\sigma = \sqrt{\frac{1}{n}\sum_{i=1}^n(X_i-\mu)^2}$$
10 |
11 | $$\hat{X} = \frac{X - \mu}{\sigma}$$
12 |
13 | $$Y = \gamma \odot \hat{X} + \beta$$
14 |
15 | Where:
16 | - $X$ is the input tensor
17 | - $\mu$ is the mean (scalar)
18 | - $\sigma$ is the standard deviation (scalar)
19 | - $\hat{X}$ is the normalized input
20 | - $\gamma$ and $\beta$ are learnable parameters
21 | - $n$ is the feature dimension
22 | - $\odot$ represents element-wise multiplication
23 |
24 | ## Backward Pass Derivation
25 |
26 | We'll derive $\nabla_X$ (gradient with respect to input) given $\nabla_Y$ (gradient from the output).
27 |
28 | ### Step 1: Gradient from $Y$ to $\hat{X}$
29 |
30 | Starting with
31 |
32 | $$Y = \gamma \odot \hat{X} + \beta$$
33 |
34 | Taking the derivative with respect to $\hat{X}$:
35 |
36 | $$ \nabla_{\hat{X}} = \frac{\partial \mathcal{L}}{\partial \hat{X}} = \frac{\partial \mathcal{L}}{\partial Y} \cdot \frac{\partial Y}{\partial \hat{X}} = \nabla_Y \odot \gamma$$
37 |
38 | This means each element of the gradient with respect to $\hat{X}$ is the corresponding element of $\nabla_Y$ multiplied by the corresponding element of $\gamma$.
39 |
40 | ### Step 2: Gradient from $\hat{X}$ to $X$
41 |
42 | Now we need to compute $\nabla_X$ given $\nabla_{\hat{X}}$, using the chain rule again:
43 |
44 | $$\nabla_X = \frac{\partial \hat{X}}{\partial X} \cdot \nabla_{\hat{X}}$$
45 |
46 | We need to compute the gradient of $\hat{X}$ with respect to $X$. The normalized value is:
47 |
48 | $$\hat{X} = \frac{X - \mu}{\sigma}$$
49 |
50 | We need to account for how changes in $X$ affect $\hat{X}$ both directly and through $\mu$ and $\sigma$.
51 |
52 | #### Component 1: Direct effect on $X$
53 |
54 | For the direct effect (ignoring effects through $\mu$ and $\sigma$):
55 |
56 | $$\frac{\partial \hat{X}}{\partial X}_{\text{direct}} = \frac{1}{\sigma}\mathbf{I}$$
57 |
58 | Where $\mathbf{I}$ is the identity matrix.
59 |
60 | #### Component 2: Effect through $\mu$
61 |
62 | The mean $\mu = \frac{1}{n}\sum_{i=1}^n X_i$ depends on all elements of $X$.
63 |
64 | For any element $X_j$:
65 |
66 | $$\frac{\partial \mu}{\partial X_j} = \frac{1}{n}$$
67 |
68 | The effect on $\hat{X}_i$ through $\mu$ is:
69 |
70 | $$\frac{\partial \hat{X}_i}{\partial \mu} = -\frac{1}{\sigma}$$
71 |
72 | Combining these:
73 |
74 | $$\frac{\partial \hat{X}_i}{\partial X_j}_{\text{via }\mu} = \frac{\partial \hat{X}_i}{\partial \mu} \cdot \frac{\partial \mu}{\partial X_j} = -\frac{1}{\sigma} \cdot \frac{1}{n} = -\frac{1}{n\sigma}$$
75 |
76 | #### Component 3: Effect through $\sigma$
77 |
78 | The standard deviation $\sigma = \sqrt{\frac{1}{n}\sum_{i=1}^n(X_i-\mu)^2}$ also depends on all elements of $X$.
79 |
80 | First, let's compute $\frac{\partial \sigma}{\partial X_j}$:
81 |
82 | $${2\sigma} \cdot \frac{\partial \sigma}{\partial X_j} = \frac{\partial}{\partial X_j}\left(\frac{1}{n}\sum_{i=1}^n(X_i-\mu)^2\right)$$
83 |
84 | Which is:
85 |
86 | $$\frac{\partial \sigma}{\partial X_j} = \frac{1}{2\sigma} \cdot \frac{\partial}{\partial X_j}\left(\frac{1}{n}\sum_{i=1}^n(X_i-\mu)^2\right)$$
87 |
88 | We need to account for both the direct effect on $(X_j-\mu)^2$ and the indirect effect through $\mu$ on all terms $(X_i-\mu)^2$.
89 |
90 | The direct effect when $i = j$ is:
91 |
92 | $$\frac{\partial}{\partial X_j}(X_j-\mu)^2 = 2(X_j-\mu) \cdot \left(1 - \frac{\partial \mu}{\partial X_j}\right) = 2(X_j-\mu) \cdot \left(1 - \frac{1}{n}\right)$$
93 |
94 | The indirect effect through $\mu$ for each $i \neq j$ is:
95 |
96 | $$\frac{\partial}{\partial X_j}(X_i-\mu)^2 = 2(X_i-\mu) \cdot \left(- \frac{\partial \mu}{\partial X_j}\right) = -2(X_i-\mu) \cdot \frac{1}{n}$$
97 |
98 | Combining these and simplifying:
99 |
100 | $$\frac{\partial \sigma}{\partial X_j} = \frac{1}{2\sigma} \cdot \frac{1}{n} \cdot \left(2(X_j-\mu)\left(1-\frac{1}{n}\right) - \sum_{i \neq j}2(X_i-\mu)\frac{1}{n}\right)$$
101 |
102 | This further simplifies to:
103 |
104 | $$\frac{\partial \sigma}{\partial X_j} = \frac{1}{n\sigma}(X_j-\mu)$$
105 |
106 | because $\sum_{i=1}^n (X_i-\mu) = 0$, which follows directly from $\mu = \frac{1}{n}\sum_{i=1}^n X_i$.
107 |
108 | Or in terms of $\hat{X}$:
109 |
110 | $$\frac{\partial \sigma}{\partial X_j} = \frac{1}{n}\hat{X}_j$$
111 |
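As a quick sanity check of this step, the following tiny PyTorch sketch (variable names are illustrative) compares autograd's gradient of $\sigma$ with $\hat{X}/n$:

```python
import torch

torch.manual_seed(0)
n = 8
x = torch.randn(n, dtype=torch.float64, requires_grad=True)

sigma = x.var(unbiased=False).sqrt()   # population standard deviation, as in the forward pass
sigma.backward()

x_hat = (x.detach() - x.detach().mean()) / sigma.detach()
print(torch.allclose(x.grad, x_hat / n))  # True
```
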
112 | Now, the effect on $\hat{X}_i$ through $\sigma$ is:
113 |
114 | $$\frac{\partial \hat{X}_i}{\partial \sigma} = -\frac{X_i-\mu}{\sigma^2} = -\frac{\hat{X}_i}{\sigma}$$
115 |
116 | Combining these:
117 |
118 | $$\frac{\partial \hat{X}_i}{\partial X_j}_{\text{via }\sigma} = \frac{\partial \hat{X}_i}{\partial \sigma} \cdot \frac{\partial \sigma}{\partial X_j} = -\frac{\hat{X}_i}{\sigma} \cdot \frac{1}{n}\hat{X}_j = -\frac{1}{n\sigma}\hat{X}_i\hat{X}_j$$
119 |
120 | #### Combining All Components
121 |
122 | Adding all three components together:
123 |
124 | $$\nabla_{X_i} = \left(\frac{\partial \hat{X}}{\partial X}\right)_{i,:}\cdot \nabla_{\hat{X}}$$
125 |
126 | $$= \frac{1}{\sigma}\nabla_{\hat{X}_i} - \frac{1}{n\sigma}\sum_{j=1}^n \nabla_{\hat{X}_j} - \frac{\hat{X}_i}{n\sigma}\sum_{j=1}^n \nabla_{\hat{X}_j}\hat{X}_j$$
127 |
128 | In vector notation:
129 |
130 | $$\nabla_X = \frac{1}{\sigma}\nabla_{\hat{X}} - \frac{1}{n\sigma}\mathbf{1}\sum_{i=1}^n \nabla_{\hat{X}_i} - \frac{1}{n\sigma}\hat{X} \odot \left(\sum_{i=1}^n \nabla_{\hat{X}_i}\hat{X}_i\right)$$
131 |
132 | Where $\mathbf{1}$ is a vector of ones.
133 |
134 | Substituting $\nabla_{\hat{X}} = \nabla_Y \odot \gamma$:
135 |
136 | $$\nabla_X = \frac{1}{\sigma}\left(\nabla_Y \odot \gamma - \frac{1}{n}\mathbf{1}\sum_{i=1}^n (\nabla_Y \odot \gamma)_i - \hat{X} \odot \frac{1}{n}\sum_{i=1}^n(\nabla_Y \odot \gamma \odot \hat{X})_i\right)$$
137 |
138 | This can be written more compactly as:
139 |
140 | $$\nabla_X = \frac{1}{\sigma}\left(\nabla_Y \odot \gamma - \frac{1}{n}\left(\nabla_Y \cdot \gamma\right)\mathbf{1} - \left(\frac{1}{n}\hat{X} \cdot (\nabla_Y \odot \gamma)\right) \odot \hat{X}\right)$$
141 |
142 | This is the complete formula for the backward pass of layer normalization with respect to the input $X$.
143 |
144 |
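The complete formula can be verified against autograd with a short PyTorch sketch (names such as `x_hat`, `dy`, and `g` are illustrative, and this is plain PyTorch rather than the Triton kernel from this repo):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n = 16
x = torch.randn(4, n, dtype=torch.float64, requires_grad=True)
gamma = torch.randn(n, dtype=torch.float64)
beta = torch.randn(n, dtype=torch.float64)
dy = torch.randn(4, n, dtype=torch.float64)   # upstream gradient, nabla_Y

# Autograd reference (eps=0 so it matches the derivation exactly)
y = F.layer_norm(x, (n,), weight=gamma, bias=beta, eps=0.0)
y.backward(dy)

# Manual backward: nabla_X = (1/sigma) * (g - mean(g) - x_hat * mean(g * x_hat)), with g = dy * gamma
with torch.no_grad():
    mu = x.mean(dim=-1, keepdim=True)
    sigma = x.var(dim=-1, unbiased=False, keepdim=True).sqrt()
    x_hat = (x - mu) / sigma
    g = dy * gamma
    dx = (g - g.mean(dim=-1, keepdim=True) - x_hat * (g * x_hat).mean(dim=-1, keepdim=True)) / sigma

print(torch.allclose(x.grad, dx, atol=1e-10))  # True
```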
--------------------------------------------------------------------------------
/backprop_math/swiglu.md:
--------------------------------------------------------------------------------
1 | # Derivative of SwiGLU
2 |
3 | The gating nonlinearity in SwiGLU is the Swish (SiLU) function:
4 |
5 | $$f(x) = \frac{x}{1 + e^{-x}}$$
6 |
7 | Find $\frac{df}{dx}$ using the quotient rule.
8 | For $f(x) = \frac{u(x)}{v(x)}$, the quotient rule gives us:
9 |
10 | $$\frac{df}{dx} = \frac{u'(x) \cdot v(x) - u(x) \cdot v'(x)}{v(x)^2}$$
11 |
12 | Where:
13 | - $u(x) = x$
14 | - $v(x) = 1 + e^{-x}$
15 |
16 | Calculate $u'(x)$:
17 |
18 | $$u'(x) = \frac{d}{dx}[x] = 1$$
19 |
20 | Calculate $v'(x)$:
21 |
22 | $$v'(x) = \frac{d}{dx}[1 + e^{-x}] = -e^{-x}$$
23 |
24 | We apply the quotient rule:
25 |
26 | $$\frac{df}{dx} = \frac{1 \cdot (1 + e^{-x}) - x \cdot (-e^{-x})}{(1 + e^{-x})^2} = \frac{1 \cdot (1 + e^{-x}) + x \cdot e^{-x}}{(1 + e^{-x})^2}$$
27 |
28 | $$\frac{df}{dx} = \frac{1 + e^{-x} + x \cdot e^{-x}}{(1 + e^{-x})^2}$$
29 |
30 | $$\frac{df}{dx} = \frac{1 + e^{-x}}{(1 + e^{-x})^2} + \frac{x \cdot (e^{-x} + 1) - x}{(1 + e^{-x})^2} = \frac{1}{1 + e^{-x}} + \frac{x \cdot (e^{-x} + 1)}{(1 + e^{-x})^2} - \frac{x}{(1 + e^{-x})^2}$$
31 |
32 | Alternative expression using the sigmoid function.
33 | Since $\sigma(x) = \frac{1}{1 + e^{-x}}$, the three terms above are $\sigma(x)$, $x \cdot \sigma(x)$, and $x \cdot \sigma(x)^2$, so we can write:
34 |
35 | $$\frac{df}{dx} = \sigma(x) + x \cdot \sigma(x) - x \cdot \sigma(x)^2 = \sigma(x) \cdot (1 + x \cdot (1 - \sigma(x)))$$
36 |
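This can be checked numerically; the short sketch below (illustrative names only) compares autograd's gradient of $x \cdot \sigma(x)$ with the closed form:

```python
import torch

x = torch.linspace(-6.0, 6.0, steps=121, requires_grad=True)
(x * torch.sigmoid(x)).sum().backward()        # autograd gradient of Swish/SiLU

s = torch.sigmoid(x.detach())
manual = s * (1.0 + x.detach() * (1.0 - s))    # sigma(x) * (1 + x * (1 - sigma(x)))

print(torch.allclose(x.grad, manual, atol=1e-6))  # True
```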
--------------------------------------------------------------------------------