├── README.md ├── annotated_examples ├── classics │ ├── cross_entropy.py │ ├── geglu.py │ ├── layernorm.py │ ├── matmul.py │ └── swiglu.py └── gemlite │ ├── gemm.py │ └── gemm_splitK.py ├── assets └── logo.png └── backprop_math ├── cross_entropy.md ├── geglu.md ├── layernorm.md └── swiglu.md /README.md: -------------------------------------------------------------------------------- 1 | # Triton Academy 2 | 3 |
4 | 5 |
6 | 7 | 8 | Triton Academy is an initiative to educate developers on writing efficient GPU kernels using Triton. It's a work in progress and its goal is to provide: 9 | 10 | 🎓 Tutorials – Learn Triton step by step 11 | 🔍 Famous Triton Kernels – Detailed explanations 12 | 📐 Mathematical formulas and proofs behind backpropagation in kernels 13 | 📖 Documentation & Debugging – Official docs and best practices 14 | 🔬 Benchmarks – Performance comparisons with CUDA 15 | 16 | ## What is Triton? 17 | 18 | Triton is an open-source programming language and compiler designed specifically for GPU programming. It aims to simplify the development of efficient GPU kernels by providing a higher-level abstraction than CUDA or other low-level GPU programming models. 19 | 20 | Triton enables developers to write high-performance GPU code with Python-like syntax while automatically handling many low-level optimizations that would otherwise require significant expertise in GPU architecture. It was developed by OpenAI and is now widely used in machine learning and scientific computing applications. 21 | 22 | ## Why Triton instead of CUDA? 23 | 24 | While CUDA offers fine-grained control over GPU programming, it comes with several challenges: 25 | 26 | 1. **Steep Learning Curve**: CUDA requires understanding complex GPU architecture details 27 | 2. **Verbose Code**: Simple operations often require extensive boilerplate code 28 | 3. **Manual Optimization**: Developers must manually handle memory coalescing, tiling, and other optimizations 29 | 4. **Limited Portability**: CUDA is specific to NVIDIA GPUs 30 | 31 | Triton addresses these issues by: 32 | 33 | 1. **Higher Abstraction**: You program blocks (tiles) of data instead of individual threads 34 | 2. **Automatic Optimization**: The compiler handles memory coalescing, shared-memory allocation, and synchronization within a block 35 | 3. **Python Integration**: Seamlessly integrates with the Python ecosystem 36 | 4. **Performance**: Achieves performance comparable to hand-optimized CUDA in many cases 37 | 5.
**Readability**: More maintainable code that clearly expresses intent 38 | 39 | ## Quick Example: Vector Addition in Triton 40 | 41 | The kernel below adds two vectors. Each program instance handles one block of `BLOCK_SIZE` elements, and a mask guards the last block when `N` is not a multiple of `BLOCK_SIZE`. 42 | 43 | ```python 44 | import torch 45 | import triton 46 | import triton.language as tl 47 | 48 | @triton.jit 49 | def vector_add_kernel(X, Y, Z, N, BLOCK_SIZE: tl.constexpr): 50 | pid = tl.program_id(axis=0) # Index of this program instance 51 | offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) # Offsets of the elements this program handles 52 | mask = offsets < N # Ensure we don’t go out of bounds 53 | x = tl.load(X + offsets, mask=mask) # Load X 54 | y = tl.load(Y + offsets, mask=mask) # Load Y 55 | tl.store(Z + offsets, x + y, mask=mask) # Store the result 56 | 57 | # Example usage 58 | N = 1000 # Not a multiple of BLOCK_SIZE, so the mask matters 59 | BLOCK_SIZE = 128 # Number of elements per program instance 60 | X = torch.randn(N, device="cuda") 61 | Y = torch.randn(N, device="cuda") 62 | Z = torch.empty_like(X) 63 | 64 | grid = (triton.cdiv(N, BLOCK_SIZE),) # Round up so the last partial block is launched 65 | vector_add_kernel[grid](X, Y, Z, N, BLOCK_SIZE=BLOCK_SIZE) 66 | print(torch.allclose(Z, X + Y)) # True 67 | ``` 68 | 69 | Explanation: 70 | • tl.program_id(axis=0) → Index of the current program instance in the launch grid 71 | • tl.arange(0, BLOCK_SIZE) → A vector of BLOCK_SIZE consecutive indices processed together by one program (Triton maps them onto GPU threads) 72 | • tl.load and tl.store → Masked memory reads and writes; masked-off positions are not accessed 73 | 74 | ## Resources 75 | - [CUDA Programming Course – High-Performance Computing with GPUs](https://www.youtube.com/watch?v=86FAWCzIe_4&t=30156s&pp=ygUNdHJpdG9uIGNvdXJzZQ%3D%3D) 76 | - [Triton Tutorials](https://triton-lang.org/main/getting-started/tutorials/index.html) 77 | - [Triton Puzzles](https://github.com/srush/Triton-Puzzles) 78 | 79 | -------------------------------------------------------------------------------- /annotated_examples/classics/cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2 | # Modifications Copyright 2025 Mekkcyber. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import triton 17 | import triton.language as tl 18 | import torch 19 | 20 | 21 | from triton.language.extra import libdevice 22 | triton_tanh = libdevice.tanh 23 | triton_cast = tl.cast 24 | MAX_FUSED_SIZE : int = 65536 25 | next_power_of_2 = triton.next_power_of_2 26 | 27 | def calculate_settings(n : int) -> (int, int,): 28 | BLOCK_SIZE : int = next_power_of_2(n) 29 | if BLOCK_SIZE > MAX_FUSED_SIZE: 30 | raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\ 31 | f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.") 32 | num_warps : int = 4 33 | if BLOCK_SIZE >= 32768: num_warps = 32 34 | elif BLOCK_SIZE >= 8192: num_warps = 16 35 | elif BLOCK_SIZE >= 2048: num_warps = 8 36 | return BLOCK_SIZE, num_warps 37 | 38 | @triton.jit 39 | def _cross_entropy_forward( 40 | logits_ptr , # Pointer to logits tensor [batch*seq_len, vocab_size] 41 | logits_row_stride , # Stride for accessing rows in logits 42 | loss_ptr , # Pointer to output loss values 43 | logsumexp_ptr , # Pointer to store logsumexp values (needed for backward) 44 | labels_ptr , # Pointer to label indices 45 | VOCAB_SIZE , # Size of vocabulary 46 | BLOCK_SIZE : tl.constexpr, # Columns loaded per row (next power of 2 >= VOCAB_SIZE) 47 | DO_SOFTCAPPING , # Flag for logit softcapping (e.g., for Gemma 2) 48 | SOFTCAP ,
# Softcapping parameter value 49 | DO_LOGIT_SCALING , # Flag for logit scaling (e.g., for Cohere models) 50 | LOGIT_SCALE , # Scaling factor for logits 51 | ): 52 | """ 53 | Computes cross-entropy loss in a numerically stable way. 54 | 55 | Cross Entropy Loss Formula: 56 | CE = -∑(y_i * log(p_i)) where p_i = softmax(x_i) = exp(x_i) / ∑exp(x_j) 57 | 58 | For one-hot labels (our case), this simplifies to: 59 | CE = -log(p_correct) = -(logit_correct - logsumexp) 60 | = logsumexp - logit_correct 61 | 62 | Numerical Stability: 63 | We use the LogSumExp trick for numerical stability: 64 | logsumexp(x) = max(x) + log(∑exp(x - max(x))) 65 | 66 | This prevents overflow by ensuring the largest exponentiated term is exp(0.0) = 1.0. 67 | 68 | Special handling: 69 | - If label == -100: loss = 0 (ignore token, e.g., padding) 70 | - Otherwise: loss = logsumexp - logit_correct 71 | """ 72 | # Get the current row index from the program ID; each program instance handles exactly one row 73 | row_idx = tl.program_id(0) 74 | 75 | # Offset pointers to the current row 76 | # row_idx * stride can exceed the int32 range for large (batch*seq_len) x vocab_size tensors, so we cast to tl.int64 77 | logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64) 78 | # each row corresponds to a different token in the sequence, hence its own loss, logsumexp and label 79 | loss_ptr += row_idx 80 | logsumexp_ptr += row_idx 81 | labels_ptr += row_idx 82 | 83 | # Create offsets for accessing columns in parallel 84 | col_offsets = tl.arange(0, BLOCK_SIZE) 85 | # Create mask for valid vocabulary indices 86 | mask = col_offsets < VOCAB_SIZE 87 | 88 | # Load the label index for this row 89 | label_idx = tl.load(labels_ptr).to(tl.int32) 90 | # Load logits for this row, masking invalid indices 91 | # out-of-range columns are filled with -infinity so they don't contribute to the sum (exp(-infinity) = 0) 92 | logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32) 93 | 94 | # Apply logit scaling if enabled: x → t*x (t =
LOGIT_SCALE) 95 | # This scales the logits before softmax, affecting the "temperature" of the distribution 96 | # Higher values (t > 1) make the distribution more peaked/confident 97 | # Lower values (0 < t < 1) make the distribution more uniform/smoother 98 | # Logit scaling is used in models like Cohere Command (via a logit_scale setting) to control 99 | # how sharp the output distribution is. Multiplying the logits by t is equivalent to 100 | # dividing them by a temperature of 1/t. 101 | # Unlike temperature sampling at inference time, this scaling is part of the model, so it is also applied during training. 102 | if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits 103 | 104 | # Apply logit softcapping if enabled: x → t*tanh(x/t) (t = SOFTCAP) 105 | # This bounds logits to the open interval (-t, t), preventing extreme values 106 | # Softcapping is used in models like Gemma 2 to improve training stability 107 | # by preventing logits from growing too large, which can cause: 108 | # 1. Numerical instability in softmax computation 109 | # 2. Overconfident predictions leading to poor generalization 110 | # 3. Gradient explosion during backpropagation 111 | # Unlike simple clipping, tanh-based softcapping maintains differentiability 112 | # and allows gradients to flow even for extreme values, just at a reduced magnitude.
113 | if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP) 114 | 115 | # Compute logsumexp in a numerically stable way 116 | # First find the maximum logit value 117 | c = tl.max(logits, 0) 118 | # Then compute logsumexp = max + log(sum(exp(logits - max))) 119 | logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0)) 120 | 121 | # Compute loss only if label is valid (not -100) 122 | if label_idx != -100: 123 | # Load the logit for the correct class 124 | x = tl.load(logits_ptr + label_idx).to(tl.float32) 125 | 126 | # Apply the same transformations to the target logit 127 | if DO_LOGIT_SCALING: x = LOGIT_SCALE * x 128 | if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP) 129 | 130 | # Compute cross entropy: logsumexp - correct_logit 131 | # This is equivalent to -log(softmax(correct_logit)) 132 | loss = logsumexp - x 133 | else: 134 | # For padding tokens (label_idx == -100), set loss to 0 135 | loss = 0.0 136 | 137 | # Store results for this row 138 | tl.store(logsumexp_ptr, logsumexp) # Save logsumexp for backward pass 139 | tl.store(loss_ptr, loss) # Save the computed loss 140 | 141 | @triton.jit 142 | def _cross_entropy_backward( 143 | logits_ptr , # Pointer to input logits 144 | logits_row_stride , # Stride between rows in logits tensor 145 | dloss_ptr , # Pointer to gradient of loss w.r.t output 146 | dloss_row_stride , # Stride between rows in dloss tensor 147 | logsumexp_ptr , # Pointer to precomputed logsumexp values 148 | labels_ptr , # Pointer to target labels 149 | VOCAB_SIZE , # Size of vocabulary (number of classes) 150 | BLOCK_SIZE : tl.constexpr, # Size of processing block 151 | DO_SOFTCAPPING , # Whether to apply softcapping 152 | SOFTCAP , # Softcapping parameter value 153 | DO_LOGIT_SCALING , # Whether to apply logit scaling 154 | LOGIT_SCALE , # Logit scaling parameter value 155 | ): 156 | """ 157 | Backward pass for cross entropy loss. 
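As a sanity check of the gradient derivation that follows, here is a minimal pure-Python sketch for a single row. It assumes only the standard `math` module (no Triton or PyTorch), and the helper names are hypothetical. It mirrors this kernel's math: optional scaling s*x, then optional softcapping t*tanh(x/t) (a value of 0 disables each, as in the kernel). It then compares the analytic gradient (softmax - onehot) * s * (1 - tanh²(s*x/t)) against central finite differences.

```python
import math

def transform(xs, scale=0.0, softcap=0.0):
    # Mirror the kernel: optional scaling x -> s*x, then optional softcapping x -> t*tanh(x/t)
    out = []
    for x in xs:
        if scale != 0:
            x = scale * x
        if softcap != 0:
            x = softcap * math.tanh(x / softcap)
        out.append(x)
    return out

def logsumexp(zs):
    # Max-shift trick: the largest exponentiated term is exp(0) = 1
    c = max(zs)
    return c + math.log(sum(math.exp(z - c) for z in zs))

def row_loss(xs, label, scale=0.0, softcap=0.0):
    # CE = logsumexp(z) - z[label] on the transformed logits z
    zs = transform(xs, scale, softcap)
    return logsumexp(zs) - zs[label]

def analytic_grad(xs, label, scale=0.0, softcap=0.0):
    # dL/dz_i = softmax(z)_i - 1[i == label], then the chain rule through each transform
    zs = transform(xs, scale, softcap)
    lse = logsumexp(zs)
    s = scale if scale != 0 else 1.0
    grads = []
    for i, x in enumerate(xs):
        g = math.exp(zs[i] - lse) - (1.0 if i == label else 0.0)
        g *= s  # d(s*x)/dx = s (s = 1 when scaling is disabled)
        if softcap != 0:
            th = math.tanh(s * x / softcap)
            g *= 1.0 - th * th  # d(t*tanh(u/t))/du = 1 - tanh^2(u/t)
        grads.append(g)
    return grads

def numeric_grad(xs, label, scale=0.0, softcap=0.0, eps=1e-6):
    # Central finite differences, one logit at a time
    grads = []
    for i in range(len(xs)):
        plus, minus = list(xs), list(xs)
        plus[i] += eps
        minus[i] -= eps
        diff = row_loss(plus, label, scale, softcap) - row_loss(minus, label, scale, softcap)
        grads.append(diff / (2 * eps))
    return grads

logits = [2.0, -1.0, 0.5, 3.0]
for scale, softcap in [(0.0, 0.0), (2.0, 0.0), (0.0, 10.0), (2.0, 10.0)]:
    a = analytic_grad(logits, 1, scale, softcap)
    n = numeric_grad(logits, 1, scale, softcap)
    print(scale, softcap, max(abs(p - q) for p, q in zip(a, n)))
```

Without any transform, the gradients of one row sum to zero, because the softmax entries sum to 1 and exactly one entry has 1 subtracted.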
158 | 159 | Cross Entropy Loss: CE(x, class) = -log(softmax(x)[class]) 160 | = -log(exp(x_class) / sum(exp(x_i))) 161 | = -x_class + log(sum(exp(x_i))) 162 | 163 | For the backward pass, we need to compute gradients w.r.t. each logit. 164 | 165 | Let L = CE(x, class) and z = log(sum(exp(x_i))) (logsumexp) 166 | 167 | For the correct class (i = class): 168 | dL/dx_class = d/dx_class(-x_class + z) = -1 + exp(x_class - z) = -1 + softmax(x_class) (check backprop_math/cross_entropy.md) 169 | 170 | For other classes (i ≠ class): 171 | dL/dx_i = d/dx_i(-x_class + z) = d/dx_i(z) = exp(x_i - z) = softmax(x_i) (check backprop_math/cross_entropy.md) 172 | 173 | When logit scaling (x → s*x) or softcapping (x → t*tanh(x/t)) is applied, the chain rule multiplies these gradients by s and by 1 - tanh²(x/t), respectively. 174 | """ 175 | # Get current row index 176 | row_idx = tl.program_id(0) 177 | 178 | # Calculate pointers for current row 179 | logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64) 180 | dloss_ptr += row_idx * dloss_row_stride 181 | 182 | # Column offsets covering the whole row (BLOCK_SIZE is the next power of 2 >= VOCAB_SIZE) 183 | col_offsets = tl.arange(0, BLOCK_SIZE) 184 | # Create mask for valid vocabulary indices 185 | mask = col_offsets < VOCAB_SIZE 186 | 187 | # Load the target label for current row 188 | label_idx = tl.load(labels_ptr + row_idx).to(tl.int32) 189 | 190 | # Load gradient of loss w.r.t output 191 | # For padding tokens (label_idx == -100), set gradient to 0 192 | if label_idx != -100: 193 | dloss = tl.load(dloss_ptr) 194 | else: 195 | dloss = 0.0 196 | 197 | # Load logits for current row 198 | x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32) 199 | 200 | # Apply logit scaling if enabled 201 | # If x is scaled as x' = s*x in forward, then dx'/dx = s 202 | if DO_LOGIT_SCALING: 203 | x = x * LOGIT_SCALE 204 | 205 | # Keep the intermediate tanh value from softcapping for the gradient calculation 206 | # For softcapping x' = t*tanh(x/t), we keep tanh(x/t) because it is needed later for the
chain rule 207 | tanh_term = x 208 | if DO_SOFTCAPPING: 209 | # Apply softcapping: x' = t*tanh(x/t) 210 | tanh_term = triton_tanh(x / SOFTCAP) # Store tanh(x/t) for gradient calculation 211 | x = SOFTCAP * tanh_term # This is the softcapped value 212 | 213 | logsumexp = tl.load(logsumexp_ptr + row_idx) 214 | 215 | # Compute softmax: exp(x - logsumexp) = softmax(x) for the whole row 216 | # This gives us part of the gradient formula 217 | y = tl.exp(x - logsumexp) 218 | 219 | # Adjust gradient for the target class 220 | # For i = target: gradient = softmax(x_i) - 1 221 | # For i ≠ target: gradient = softmax(x_i) 222 | y = tl.where( 223 | col_offsets == label_idx, 224 | y - 1.0, # For target class: exp(x - logsumexp) - 1 225 | y, # For other classes: exp(x - logsumexp) 226 | ) 227 | 228 | # Apply chain rule for logit scaling 229 | # If x' = s*x, then dL/dx = dL/dx' * dx'/dx = dL/dx' * s 230 | if DO_LOGIT_SCALING: 231 | y = y * LOGIT_SCALE 232 | 233 | # Apply chain rule for softcapping 234 | # For x' = t*tanh(x/t), dx'/dx = t * (1/t) * (1 - tanh²(x/t)) = 1 - tanh²(x/t) 235 | # This is the derivative of tanh: d/dx[tanh(x)] = 1 - tanh²(x) 236 | if DO_SOFTCAPPING: 237 | y = y * (1.0 - tanh_term*tanh_term) # tanh_term = tanh(x/t) 238 | 239 | # Store final gradients in place into the logits buffer (saves allocating a new tensor) 240 | # For padding tokens (label_idx == -100), gradient is 0 241 | tl.store(logits_ptr + col_offsets, dloss * y, mask = mask) 242 | 243 | class Fast_CrossEntropyLoss(torch.autograd.Function): 244 | @staticmethod 245 | def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : float = 0): 246 | n_rows : int 247 | vocab_size : int 248 | n_rows, vocab_size = logits.shape 249 | 250 | losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda") 251 | 252 | DO_SOFTCAPPING : bool = bool(logit_softcapping != 0) 253 | DO_LOGIT_SCALING : bool = bool(logit_scaling != 0) 254 | 255 | BLOCK_SIZE : int 256 | num_warps : int 257 | # For small vocabs <= 65536 like Llama, Mistral 258 | BLOCK_SIZE, num_warps =
calculate_settings(vocab_size) 259 | logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda") 260 | 261 | _cross_entropy_forward[(n_rows,)]( 262 | logits, logits.stride(0), 263 | losses, 264 | logsumexp, 265 | labels, 266 | VOCAB_SIZE = vocab_size, 267 | BLOCK_SIZE = BLOCK_SIZE, 268 | DO_SOFTCAPPING = DO_SOFTCAPPING, 269 | SOFTCAP = logit_softcapping, 270 | DO_LOGIT_SCALING = DO_LOGIT_SCALING, 271 | LOGIT_SCALE = logit_scaling, 272 | num_warps = num_warps, 273 | ) 274 | 275 | ctx.save_for_backward(logits, logsumexp, labels) 276 | ctx.DO_SOFTCAPPING = DO_SOFTCAPPING 277 | ctx.logit_softcapping = logit_softcapping 278 | ctx.DO_LOGIT_SCALING = DO_LOGIT_SCALING 279 | ctx.logit_scaling = logit_scaling 280 | return losses 281 | pass 282 | 283 | 284 | @staticmethod 285 | def backward(ctx, dlosses): 286 | logits, logsumexp, labels = ctx.saved_tensors 287 | n_rows : int 288 | vocab_size : int 289 | n_rows, vocab_size = logits.shape 290 | 291 | BLOCK_SIZE, num_warps = calculate_settings(vocab_size) 292 | 293 | _cross_entropy_backward[(n_rows,)]( 294 | logits, logits.stride(0), 295 | dlosses, dlosses.stride(0), 296 | logsumexp, 297 | labels, 298 | VOCAB_SIZE = vocab_size, 299 | BLOCK_SIZE = BLOCK_SIZE, 300 | DO_SOFTCAPPING = ctx.DO_SOFTCAPPING, 301 | SOFTCAP = ctx.logit_softcapping, 302 | DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING, 303 | LOGIT_SCALE = ctx.logit_scaling, 304 | num_warps = num_warps, 305 | ) 306 | return logits, None, None, None 307 | 308 | 309 | def fast_cross_entropy_loss(logits, labels, logit_softcapping=0, logit_scaling=0, n_items=None): 310 | """ 311 | Arguments: 312 | logits: (batch, seq_len, vocab_size) 313 | labels: (batch, seq_len,) 314 | Returns: 315 | losses: float 316 | """ 317 | batch, seq_len, d = logits.shape 318 | assert(labels.shape == (batch, seq_len)) 319 | 320 | loss = Fast_CrossEntropyLoss.apply( 321 | logits.view(batch*seq_len, d), 322 | labels.view(-1), 323 | logit_softcapping, 324 | logit_scaling, 325 | ) 326 | if n_items is 
None: 327 | n_items = torch.count_nonzero(labels != -100) 328 | return loss.sum() / n_items 329 | 330 | 331 | def reference_cross_entropy_loss(logits, labels, logit_softcapping=0, logit_scaling=0): 332 | """Reference implementation using PyTorch's native functions""" 333 | if logit_scaling != 0: 334 | logits = logits * logit_scaling 335 | 336 | if logit_softcapping != 0: 337 | logits = logit_softcapping * torch.tanh(logits / logit_softcapping) 338 | 339 | log_probs = torch.nn.functional.log_softmax(logits, dim=-1) 340 | 341 | # Get the log probability for the correct labels 342 | label_mask = labels != -100 343 | labels_masked = labels.clone() 344 | labels_masked[~label_mask] = 0 345 | 346 | # Gather the log probabilities for the correct labels 347 | label_log_probs = log_probs.gather(dim=-1, index=labels_masked.unsqueeze(-1)).squeeze(-1) 348 | 349 | # Apply the mask to ignore padding tokens 350 | label_log_probs = label_log_probs * label_mask 351 | 352 | loss = -label_log_probs.sum() / label_mask.sum() 353 | return loss 354 | 355 | 356 | def test_cross_entropy(): 357 | """Test the forward and backward pass of the custom cross entropy implementation""" 358 | print("Testing Fast Cross Entropy implementation...") 359 | 360 | # Test configurations 361 | test_configs = [ 362 | {"name": "Standard", "softcap": 0, "scaling": 0}, 363 | {"name": "With Softcapping", "softcap": 10.0, "scaling": 0}, 364 | {"name": "With Scaling", "softcap": 0, "scaling": 2.0}, 365 | {"name": "With Both", "softcap": 10.0, "scaling": 2.0} 366 | ] 367 | 368 | for config in test_configs: 369 | print(f"\nTesting {config['name']} configuration...") 370 | 371 | # Create test inputs 372 | batch_size, seq_len, vocab_size = 2, 10, 32000 373 | logits = torch.randn(batch_size, seq_len, vocab_size, device='cuda', requires_grad=True) 374 | # Create labels with some -100 values to test padding 375 | labels = torch.randint(0, vocab_size, (batch_size, seq_len), device='cuda') 376 | labels[0, 0] = -100 # Add 
some padding tokens 377 | 378 | # Clone inputs for reference implementation 379 | logits_ref = logits.clone().detach().requires_grad_(True) 380 | 381 | # Forward pass 382 | our_loss = fast_cross_entropy_loss( 383 | logits, labels, 384 | logit_softcapping=config['softcap'], 385 | logit_scaling=config['scaling'] 386 | ) 387 | 388 | # Reference implementation 389 | ref_loss = reference_cross_entropy_loss( 390 | logits_ref, labels, 391 | logit_softcapping=config['softcap'], 392 | logit_scaling=config['scaling'] 393 | ) 394 | 395 | # Compare forward results 396 | forward_diff = torch.abs(our_loss - ref_loss).item() 397 | print(f"Forward pass difference: {forward_diff:.6f}") 398 | assert forward_diff < 1e-4, f"Forward pass failed for {config['name']} configuration!" 399 | 400 | # Backward pass 401 | our_loss.backward() 402 | ref_loss.backward() 403 | # Compare gradients 404 | 405 | grad_diff = torch.max(torch.abs(logits.grad - logits_ref.grad)).item() 406 | print(f"Max gradient difference: {grad_diff:.6f}") 407 | assert grad_diff < 1e-4, f"Backward pass failed for {config['name']} configuration!" 408 | 409 | # Reset gradients for next test 410 | logits.grad.zero_() 411 | logits_ref.grad.zero_() 412 | 413 | print("\nAll tests passed successfully!") 414 | return True 415 | 416 | if __name__ == "__main__": 417 | test_cross_entropy() 418 | -------------------------------------------------------------------------------- /annotated_examples/classics/geglu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # Modifications Copyright 2025 Mekkcyber. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import triton 17 | import triton.language as tl 18 | import torch 19 | 20 | from triton.language.extra import libdevice 21 | triton_tanh = libdevice.tanh 22 | triton_cast = tl.cast 23 | MAX_FUSED_SIZE : int = 65536 24 | next_power_of_2 = triton.next_power_of_2 25 | 26 | def calculate_settings(n : int) -> (int, int,): 27 | BLOCK_SIZE : int = next_power_of_2(n) 28 | if BLOCK_SIZE > MAX_FUSED_SIZE: 29 | raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\ 30 | f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.") 31 | num_warps : int = 4 32 | if BLOCK_SIZE >= 32768: num_warps = 32 33 | elif BLOCK_SIZE >= 8192: num_warps = 16 34 | elif BLOCK_SIZE >= 2048: num_warps = 8 35 | return BLOCK_SIZE, num_warps 36 | 37 | @triton.jit 38 | def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,): 39 | """ 40 | Forward pass for the GeGLU (Gated Gaussian Error Linear Unit) activation function. 41 | 42 | This kernel computes: 43 | 1. The GELU activation: f = 0.5 * e * (1 + erf(e/sqrt(2))) 44 | 2. 
The GeGLU output: h = f * g 45 | 46 | Parameters: 47 | - e: gate values (first half of the projection) 48 | - g: up values (second half of the projection) 49 | - h: output tensor to store the result 50 | - n_elements: total number of elements in the tensors 51 | - BLOCK_SIZE: number of elements processed by each program instance 52 | """ 53 | # Get the current block index in the grid 54 | block_idx = tl.program_id(0) 55 | 56 | # Calculate memory offsets for this block 57 | # Each program instance processes BLOCK_SIZE consecutive elements from the input 58 | offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 59 | 60 | # Create a mask to handle the case where n_elements is not divisible by BLOCK_SIZE 61 | mask = offsets < n_elements 62 | 63 | # Load input values from global memory 64 | e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32) 65 | g_row = tl.load(g + offsets, mask = mask, other = 0) 66 | 67 | # Compute GELU activation using the exact formula: 68 | # f(x) = 0.5 * x * (1 + erf(x/sqrt(2))) 69 | # where erf is the error function 70 | # rsqrt(2.0) computes 1/sqrt(2) 71 | f_row = 0.5 * e_row * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0) 72 | 73 | # Convert the result back to the same dtype as g_row 74 | f_row = f_row.to(g_row.dtype) 75 | 76 | # Compute the final GeGLU output by multiplying the GELU activation with g element-wise 77 | h_row = f_row * g_row 78 | 79 | # Store the result back to global memory 80 | tl.store(h + offsets, h_row, mask = mask) 81 | 82 | 83 | def geglu_exact_forward_kernel(gate, up): 84 | batch, seq_len, hd = gate.shape 85 | n_elements = gate.numel() 86 | out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda") 87 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 88 | _exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,) 89 | return out 90 | 91 | 92 | @triton.jit 93 | def _exact_backward_kernel(dY, e, g, n_elements, BLOCK_SIZE : tl.constexpr,): 94 | """ 95 | Backward pass for
the GeGLU (Gated Gaussian Error Linear Unit) activation function. 96 | 97 | In the forward pass: 98 | - f = 1/2 * e * (1 + erf(1/sqrt(2) * e)) # The GELU function 99 | - h = f * g # The GeGLU output 100 | 101 | Where: 102 | - e: gate values (first half of the projection) 103 | - g: up values (second half of the projection) 104 | - h: output of GeGLU 105 | 106 | In the backward pass, we need to compute: 107 | - de: gradient with respect to e (gate values) 108 | - dg: gradient with respect to g (up values) 109 | 110 | For de, we need the derivative of f with respect to e: 111 | df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2) (see backprop_math/geglu.md), giving de = dY * g * df/de and dg = dY * f 112 | 113 | Parameters: 114 | - dY: gradient flowing from the next layer (dL/dh) 115 | - e: gate values from forward pass 116 | - g: up values from forward pass 117 | - n_elements: total number of elements in the tensors 118 | - BLOCK_SIZE: number of elements processed by each program instance 119 | """ 120 | block_idx = tl.program_id(0) 121 | offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 122 | mask = offsets < n_elements 123 | 124 | # Load the gradients and values 125 | dY_row = tl.load(dY + offsets, mask = mask, other = 0) 126 | e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32) 127 | g_row = tl.load(g + offsets, mask = mask, other = 0) 128 | 129 | # Compute the partial GELU activation: 1/2 * (1 + erf(1/sqrt(2) * e)) 130 | # This is reused in both the forward computation and the derivative 131 | f_partial_row = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0) 132 | 133 | # Complete the GELU computation: f = f_partial_row * e_row 134 | f_row = f_partial_row * e_row 135 | f_row = f_row.to(dY_row.dtype) 136 | 137 | # Compute gradient for g: dg = dY * f 138 | # By chain rule: dL/dg = dL/dh * dh/dg = dY * f (as specified above h is the output of the GeGLU) 139 | dg_row = dY_row * f_row 140 | 141 | # Compute gradient for e using the derivative of GELU 142 | # df/de =
f_partial_row + (1/sqrt(2*pi)) * e * exp(-e²/2) 143 | t = 0.3989422804014327 # 1/sqrt(2*pi) 144 | df_de = f_partial_row + t * e_row * tl.exp(-0.5 * e_row * e_row) 145 | 146 | # Apply chain rule: dL/de = dL/dh * dh/de = dY * (g * df/de) 147 | de_row = g_row.to(tl.float32) * df_de 148 | de_row = de_row.to(dY_row.dtype) * dY_row 149 | 150 | # Store the computed gradients back to memory 151 | tl.store(e + offsets, de_row, mask = mask) 152 | tl.store(g + offsets, dg_row, mask = mask) 153 | 154 | 155 | def geglu_exact_backward_kernel(DW, e, g): 156 | n_elements = e.numel() 157 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 158 | _exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,) 159 | return e, g 160 | 161 | 162 | @triton.jit 163 | def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,): 164 | """ 165 | Computes the forward pass of the approximate GELU activation function for GeGLU. 166 | 167 | GeGLU (Gated GELU Linear Unit) combines a gating mechanism with GELU activation: 168 | - Input is split into two parts: gate (e) and up (g) 169 | - GELU is applied to the gate values 170 | - The result is multiplied by the up values 171 | 172 | This kernel implements the approximate version of GELU which is faster but slightly less accurate. 
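To put a number on "slightly less accurate", here is a minimal pure-Python sketch. It assumes only the standard `math` module, and the function names are hypothetical. It evaluates the exact erf-based GELU and this tanh approximation on a grid over [-6, 6], then checks the closed-form derivatives used by the exact and approximate backward kernels against central finite differences.

```python
import math

A = math.sqrt(2.0 / math.pi)  # sqrt(2/pi), called s or a in the kernels
B = 0.044715 * A              # b = 0.044715 * sqrt(2/pi)

def gelu_exact(x):
    # f(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # f(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * x * (1 + 0.044715 * x^2)))
    return 0.5 * x * (1.0 + math.tanh(x * (A + B * x * x)))

def dgelu_exact(x):
    # df/dx = 0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x^2 / 2) / sqrt(2*pi)
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0))) + x * math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)

def dgelu_tanh(x):
    # df/dx = 0.5 * v * (1 + x * (2 - v) * (a + 3b * x^2)), with v = 1 + tanh(inner)
    v = 1.0 + math.tanh(x * (A + B * x * x))
    return 0.5 * v * (1.0 + x * (2.0 - v) * (A + 3.0 * B * x * x))

def central_diff(f, x, eps=1e-5):
    # Numerical derivative used as an independent check
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)

xs = [i / 100.0 for i in range(-600, 601)]  # grid over [-6, 6]
approx_err = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in xs)
exact_deriv_err = max(abs(dgelu_exact(x) - central_diff(gelu_exact, x)) for x in xs)
tanh_deriv_err = max(abs(dgelu_tanh(x) - central_diff(gelu_tanh, x)) for x in xs)
print(f"max |exact - tanh approx| on [-6, 6]: {approx_err:.2e}")
print(f"derivative check errors: {exact_deriv_err:.1e}, {tanh_deriv_err:.1e}")
```

On this grid, the approximation error stays well below the 1e-2 tolerance that the approximate-mode test in this file uses.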
173 | 174 | Formula: 175 | f(e) = 0.5 * e * (1 + tanh(sqrt(2/π) * e * (1 + 0.044715 * e²))) 176 | h = f(e) * g 177 | 178 | Where: 179 | - e: gate values 180 | - g: up values 181 | - h: output 182 | - f(e): approximate GELU activation 183 | """ 184 | # Get the current block index and compute offsets for each thread 185 | block_idx = tl.program_id(0) 186 | offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 187 | mask = offsets < n_elements 188 | 189 | # Constant for the GELU approximation: sqrt(2/π) 190 | s = 0.7978845608028654 # Precomputed value of sqrt(2/π) 191 | 192 | # Load gate and up values from memory 193 | e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32) 194 | g_row = tl.load(g + offsets, mask = mask, other = 0) 195 | 196 | # Compute the approximate GELU activation: 197 | # f(e) = 0.5 * e * (1 + tanh(sqrt(2/π) * e * (1 + 0.044715 * e²))) 198 | # 199 | # This is a faster approximation of the exact GELU: 200 | # f(e) = 0.5 * e * (1 + erf(e/sqrt(2))) 201 | # 202 | # The approximation uses tanh instead of erf and adds a cubic term 203 | # to better match the shape of the exact GELU function 204 | inner_term = s * e_row * (1.0 + 0.044715 * e_row * e_row) 205 | f_row = 0.5 * e_row * (triton_tanh(inner_term) + 1.0) 206 | 207 | # Convert back to the original data type 208 | f_row = f_row.to(g_row.dtype) 209 | 210 | # Compute the final output: h = f(e) * g 211 | h_row = f_row * g_row 212 | 213 | # Store the result back to memory 214 | tl.store(h + offsets, h_row, mask = mask) 215 | 216 | 217 | def geglu_approx_forward_kernel(gate, up): 218 | batch, seq_len, hd = gate.shape 219 | n_elements = gate.numel() 220 | out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda") 221 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 222 | _approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,) 223 | return out 224 | 225 | @triton.jit 226 | def _approx_backward_kernel(dY, e, g, n_elements, 
BLOCK_SIZE : tl.constexpr,): 227 | """ 228 | Backward pass for the approximate GELU activation function. 229 | 230 | Forward pass: 231 | f(e) = 0.5 * e * (1 + tanh(sqrt(2/π) * e * (1 + 0.044715 * e²))) 232 | h = f(e) * g 233 | 234 | Where: 235 | - e: gate values 236 | - g: up values 237 | - dY: gradient flowing back from the next layer (dL/dh) 238 | - h: output 239 | 240 | Backward pass derivatives: 241 | 1. df/de = 0.5 * (1 + tanh(inner))(1 + e * (2 - (1 + tanh(inner))) * d(inner)/de) 242 | where inner = sqrt(2/π) * e * (1 + 0.044715 * e²), using 1 - tanh²(inner) = (1 + tanh(inner)) * (2 - (1 + tanh(inner))) 243 | 2. d(inner)/de = sqrt(2/π) * (1 + 0.044715 * e² * 3) = (a + 3b * e²) where a = sqrt(2/π) and b = 0.044715*sqrt(2/π) 244 | 3. de = dY * g * df/de 245 | 4. dg = dY * f(e) 246 | 247 | """ 248 | # Get block index and compute offsets for parallel processing 249 | block_idx = tl.program_id(0) 250 | offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 251 | mask = offsets < n_elements 252 | 253 | # Load input values from memory 254 | dY_row = tl.load(dY + offsets, mask=mask, other=0) 255 | e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32) 256 | g_row = tl.load(g + offsets, mask=mask, other=0) 257 | 258 | # Constants for the GELU approximation 259 | a = 0.7978845608028654 # Precomputed value of sqrt(2/π) 260 | b = 0.044715 * a 261 | 262 | # Full inner term: sqrt(2/π) * e * (1 + 0.044715 * e²) 263 | inner = e_row * (a + b * e_row * e_row) 264 | 265 | # Compute tanh of the inner term 266 | tanh_inner = triton_tanh(inner) 267 | 268 | # Compute (1 + tanh(inner_term)) 269 | v = 1.0 + tanh_inner 270 | 271 | # compute f(e) = 0.5 * e * (1 + tanh(inner_term)) using the forward pass formula in backprop_math/geglu.md 272 | f_row = 0.5 * e_row * v 273 | 274 | # compute df/de using the formula in backprop_math/geglu.md 275 | df_de = 0.5 * v * (1.0 + e_row * (2.0 - v) * (a + 3.0 * b * e_row * e_row)) 276 | 277 | # Compute gradients for backpropagation: 278 | # dg = dY * f(e) 279 | dg_row = dY_row * f_row 280 | 281 | # Compute
gradients for backpropagation: 282 | # de = dY * g * df/de 283 | de_row = g_row * df_de 284 | de_row = de_row.to(dY_row.dtype) * dY_row 285 | 286 | # Store the gradients in place: e is overwritten with de and g with dg 287 | tl.store(e + offsets, de_row, mask=mask) # Store gradient for gate values 288 | tl.store(g + offsets, dg_row, mask=mask) # Store gradient for up values 289 | 290 | 291 | def geglu_approx_backward_kernel(dY, e, g): 292 | n_elements = e.numel() 293 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 294 | _approx_backward_kernel[grid](dY, e, g, n_elements, BLOCK_SIZE = 1024,) 295 | return e, g # e and g now hold the gradients de and dg 296 | 297 | 298 | def test_geglu_correctness(use_approx=False): 299 | """ 300 | Test the correctness of the GEGLU implementation by comparing with a reference implementation. 301 | Tests both forward and backward passes for GEGLU (Gated GELU). 302 | 303 | Args: 304 | use_approx (bool): If True, use the approximate GEGLU implementation. 305 | If False, use the exact GEGLU implementation. 306 | """ 307 | import torch 308 | import torch.nn.functional as F 309 | 310 | # Define the reference GEGLU; it uses exact (erf-based) GELU, so the tanh-approximate kernels are checked against it with a looser tolerance 311 | def geglu_reference_forward(x): 312 | """Reference implementation of GEGLU forward pass""" 313 | x_chunks = torch.chunk(x, 2, dim=-1) 314 | gate, value = x_chunks[0], x_chunks[1] 315 | return value * F.gelu(gate) 316 | 317 | # Select the appropriate kernels based on the use_approx flag 318 | forward_kernel = _approx_forward_kernel if use_approx else _exact_forward_kernel 319 | backward_kernel = _approx_backward_kernel if use_approx else _exact_backward_kernel 320 | 321 | implementation_type = "approximate" if use_approx else "exact" 322 | print(f"Testing {implementation_type} GEGLU implementation...") 323 | 324 | def test_forward(): 325 | """Test the forward pass of GEGLU""" 326 | print(f"Testing {implementation_type} GEGLU forward pass...") 327 | 328 | batch_size, seq_len, hidden_dim = 2, 10, 128 329 | x = torch.randn(batch_size, seq_len, 
hidden_dim * 2, device='cuda', requires_grad=True) 330 | 331 | ref_output = geglu_reference_forward(x) 332 | 333 | x_chunks = torch.chunk(x, 2, dim=-1) 334 | gate, value = x_chunks[0], x_chunks[1] 335 | gate_flat = gate.reshape(-1) 336 | value_flat = value.reshape(-1) 337 | 338 | output_flat = torch.empty_like(gate_flat) 339 | 340 | n_elements = gate_flat.numel() 341 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 342 | forward_kernel[grid](gate_flat, value_flat, output_flat, n_elements, BLOCK_SIZE=1024) 343 | 344 | our_output = output_flat.reshape(gate.shape) 345 | 346 | max_diff = torch.max(torch.abs(ref_output - our_output)) 347 | print(f"Max difference in {implementation_type} GEGLU forward pass: {max_diff.item()}") 348 | assert max_diff < (1e-2 if use_approx else 1e-5), f"{implementation_type} GEGLU forward pass implementation is incorrect!" 349 | return True 350 | 351 | def test_backward(): 352 | """Test the backward pass of GEGLU""" 353 | print(f"Testing {implementation_type} GEGLU backward pass...") 354 | 355 | batch_size, seq_len, hidden_dim = 2, 10, 128 356 | x = torch.randn(batch_size, seq_len, hidden_dim * 2, device='cuda', requires_grad=True) 357 | 358 | x_ref = x.clone().detach().requires_grad_(True) 359 | ref_output = geglu_reference_forward(x_ref) 360 | 361 | grad_output = torch.randn_like(ref_output) 362 | 363 | ref_output.backward(grad_output) 364 | ref_grad = x_ref.grad.clone() 365 | 366 | x_chunks = torch.chunk(x, 2, dim=-1) 367 | gate, value = x_chunks[0], x_chunks[1] 368 | gate_flat = gate.reshape(-1) 369 | value_flat = value.reshape(-1) 370 | 371 | output_flat = torch.empty_like(gate_flat) 372 | n_elements = gate_flat.numel() 373 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 374 | forward_kernel[grid](gate_flat, value_flat, output_flat, n_elements, BLOCK_SIZE=1024) 375 | 376 | grad_output_flat = grad_output.reshape(-1) 377 | 378 | dW = grad_output_flat.clone() 379 | e = gate_flat.clone() 380 | g = 
value_flat.clone() 381 | 382 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 383 | backward_kernel[grid](dW, e, g, n_elements, BLOCK_SIZE=1024) 384 | 385 | our_grad = torch.cat([e.reshape(gate.shape), g.reshape(value.shape)], dim=-1) 386 | 387 | max_diff = torch.max(torch.abs(ref_grad - our_grad)) 388 | print(f"Max difference in {implementation_type} GEGLU backward pass: {max_diff.item()}") 389 | assert max_diff < (1e-2 if use_approx else 1e-5), f"{implementation_type} GEGLU backward pass implementation is incorrect!" 390 | return True 391 | 392 | forward_passed = test_forward() 393 | backward_passed = test_backward() 394 | 395 | if forward_passed and backward_passed: 396 | print(f"All tests passed! {implementation_type.capitalize()} GEGLU implementation is correct.") 397 | else: 398 | print(f"Tests failed! {implementation_type.capitalize()} GEGLU implementation needs fixing.") 399 | 400 | if __name__ == "__main__": 401 | # Test exact implementation 402 | test_geglu_correctness(use_approx=False) 403 | 404 | # Test approximate implementation 405 | test_geglu_correctness(use_approx=True) 406 | -------------------------------------------------------------------------------- /annotated_examples/classics/layernorm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # Modifications Copyright 2025 Mekkcyber. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License 15 | 16 | import triton 17 | import triton.language as tl 18 | import torch 19 | 20 | MAX_FUSED_SIZE : int = 65536 21 | next_power_of_2 = triton.next_power_of_2 22 | 23 | def calculate_settings(n : int) -> (int, int,): 24 | BLOCK_SIZE : int = next_power_of_2(n) 25 | if BLOCK_SIZE > MAX_FUSED_SIZE: 26 | raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\ 27 | f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.") 28 | num_warps : int = 4 29 | if BLOCK_SIZE >= 32768: num_warps = 32 30 | elif BLOCK_SIZE >= 8192: num_warps = 16 31 | elif BLOCK_SIZE >= 2048: num_warps = 8 32 | return BLOCK_SIZE, num_warps 33 | 34 | @triton.jit 35 | def layernorm_forward( 36 | Y, Y_row_stride, # Output tensor and its row stride 37 | X, X_row_stride, # Input tensor and its row stride 38 | weight, # Scale parameter for normalization 39 | bias, # Bias parameter for normalization 40 | inv_var, # Buffer to store inverse variance 41 | mean, # Buffer to store mean 42 | n_cols, eps, # Number of columns and epsilon for numerical stability 43 | BLOCK_SIZE : tl.constexpr # Compile-time constant for block size 44 | ): 45 | """ 46 | This kernel implements the forward pass of Layer Normalization using Triton. 
47 | 48 | Layer Normalization normalizes each input row independently using the formula: 49 | y = ((x - mean) / sqrt(variance + eps)) * weight + bias 50 | 51 | Example with a 3x5 input matrix X: 52 | X = [ 53 | [1.0, 2.0, 3.0, 4.0, 5.0], # Row 0 54 | [6.0, 7.0, 8.0, 9.0, 10.0], # Row 1 55 | [11.0, 12.0, 13.0, 14.0, 15.0] # Row 2 56 | ] 57 | weight = [0.5, 0.5, 0.5, 0.5, 0.5] 58 | bias = [0.1, 0.1, 0.1, 0.1, 0.1] 59 | 60 | The comments below trace row_idx = 1 (the second program instance, i.e. the second CUDA thread block) through each step. 61 | """ 62 | 63 | # Each CUDA thread block processes one row of the input 64 | row_idx = tl.program_id(0) 65 | 66 | # Create column indices [0, 1, 2, ..., BLOCK_SIZE-1] 67 | # BLOCK_SIZE is the smallest power of 2 greater than or equal to n_cols 68 | # These will be used to access elements within a row 69 | col_offsets = tl.arange(0, BLOCK_SIZE) 70 | 71 | # Create a mask to handle cases where n_cols < BLOCK_SIZE 72 | # For example, if n_cols=5 and BLOCK_SIZE=8, only the first 5 elements are valid 73 | mask = col_offsets < n_cols 74 | 75 | # In the case of Layer Normalization, the input tensor X and output tensor Y have the same shape. 76 | # This means we can use the same indexing pattern to access corresponding elements in both tensors. 77 | # We're using row_idx to determine which row we're processing, and then using the same 78 | # col_offsets within that row to access individual elements. 79 | # 80 | # The row_stride parameters (X_row_stride and Y_row_stride) give the number of elements to skip in memory 81 | # to move from one row to the next. While these are often the same for X and Y, having separate 82 | # stride parameters allows for flexibility in memory layout.
83 | 84 | # In row-major order, elements in a row are stored contiguously in memory 85 | # For a matrix with n_cols columns, the row_stride equals n_cols 86 | # Example with our 3x5 matrix X stored in row-major order in memory: 87 | # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0] 88 | # |--------- Row0 ---------|---------- Row1 ---------|---------- Row2 -----------| 89 | # To move from the start of Row 0 to the start of Row 1, we add row_stride (5) 90 | 91 | # In the beginning X and Y point to the first element of their first row 92 | # if row_idx==1 : 93 | # - Y + row_idx * Y_row_stride = Y + 1 * 5 = Y + 5 points to the second row of Y 94 | # - X + row_idx * X_row_stride = X + 1 * 5 = X + 5 points to the second row of X 95 | 96 | Y += row_idx * Y_row_stride 97 | X += row_idx * X_row_stride 98 | # inv_var and mean are 1D tensors with n_rows elements 99 | # when row_idx==1, inv_var points to the second element in the inverse variance buffer 100 | # when row_idx==1, mean points to the second element in the mean buffer 101 | inv_var += row_idx 102 | mean += row_idx 103 | 104 | # Load the entire row from input tensor X 105 | # For row_idx=1: X_row = [6.0, 7.0, 8.0, 9.0, 10.0, 0, 0, 0] 106 | # The 'other=0' parameter sets values outside the mask to 0 107 | X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32) 108 | 109 | # Load weight parameters for this row 110 | # weight_row = [0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0] (if BLOCK_SIZE=8) 111 | weight_row = tl.load(weight + col_offsets, mask = mask, other = 0).to(tl.float32) 112 | 113 | # Load bias parameters for this row 114 | # bias_row = [0.1, 0.1, 0.1, 0.1, 0.1, 0, 0, 0] (if BLOCK_SIZE=8) 115 | bias_row = tl.load(bias + col_offsets, mask = mask, other = 0).to(tl.float32) 116 | 117 | # Calculate mean of the row 118 | # For row_idx=1: mean_X = (6.0 + 7.0 + 8.0 + 9.0 + 10.0) / 5 = 8.0 119 | mean_X = tl.sum(X_row, axis = 0) / n_cols 120 | 121 | # Subtract mean from each 
element in the row 122 | # For row_idx=1: XX = [-2.0, -1.0, 0.0, 1.0, 2.0, 0.0, 0.0, 0.0] 123 | XX = tl.where(mask, X_row - mean_X, 0) 124 | 125 | # Calculate variance of the row 126 | # For row_idx=1: row_var = ((-2.0)² + (-1.0)² + 0.0² + 1.0² + 2.0²) / 5 = (4.0 + 1.0 + 0.0 + 1.0 + 4.0) / 5 = 2.0 127 | row_var = tl.sum(XX * XX, axis = 0) / n_cols 128 | 129 | # Calculate inverse square root of variance (for stability, add epsilon) 130 | # For row_idx=1: inv_var_val = 1/sqrt(2.0 + eps) ≈ 0.707 131 | inv_var_val = tl.math.rsqrt(row_var + eps) 132 | 133 | # Store the inverse variance and mean for later use in backward pass 134 | tl.store(inv_var, inv_var_val) 135 | tl.store(mean, mean_X) 136 | 137 | # Calculate normalized output with scaling and bias 138 | # For row_idx=1: 139 | # output = ([-2.0, -1.0, 0.0, 1.0, 2.0] * 0.707) * [0.5, 0.5, 0.5, 0.5, 0.5] + [0.1, 0.1, 0.1, 0.1, 0.1] 140 | # = [-0.607, -0.2535, 0.1, 0.4535, 0.807] 141 | output = (XX * inv_var_val) * weight_row + bias_row 142 | 143 | # Store the output row 144 | tl.store(Y + col_offsets, output, mask = mask) 145 | 146 | 147 | @triton.jit 148 | def layernorm_backward( 149 | dY, dY_row_stride, # Gradient from upstream and its row stride 150 | X, X_row_stride, # Input tensor and its row stride 151 | weight, # Scale parameter for normalization 152 | bias, # Bias parameter for normalization 153 | inv_var, # Stored inverse variance from forward pass 154 | mean, # Stored mean from forward pass 155 | n_cols, eps, # Number of columns and epsilon for numerical stability 156 | BLOCK_SIZE : tl.constexpr # Compile-time constant for block size 157 | ): 158 | """ 159 | This kernel implements the backward pass of Layer Normalization using Triton. 160 | 161 | The backward pass computes the gradient with respect to the input (dX) given the 162 | gradient with respect to the output (dY). 
163 | 164 | Example with a 3x5 input matrix X and corresponding gradient dY: 165 | X = [ 166 | [1.0, 2.0, 3.0, 4.0, 5.0], # Row 0 167 | [6.0, 7.0, 8.0, 9.0, 10.0], # Row 1 168 | [11.0, 12.0, 13.0, 14.0, 15.0] # Row 2 169 | ] 170 | dY = [ 171 | [0.1, 0.2, 0.3, 0.4, 0.5], # Row 0 172 | [0.6, 0.7, 0.8, 0.9, 1.0], # Row 1 173 | [1.1, 1.2, 1.3, 1.4, 1.5] # Row 2 174 | ] 175 | weight = [0.5, 0.5, 0.5, 0.5, 0.5] 176 | """ 177 | 178 | # Each CUDA thread block processes one row of the input 179 | row_idx = tl.program_id(0) 180 | 181 | # Create column indices [0, 1, 2, ..., BLOCK_SIZE-1] 182 | col_offsets = tl.arange(0, BLOCK_SIZE) 183 | 184 | # Create a mask to handle cases where n_cols < BLOCK_SIZE 185 | mask = col_offsets < n_cols 186 | 187 | # Calculate pointers to the current row in each tensor 188 | # For row_idx=1, we're processing the second row of each tensor 189 | dY += row_idx * dY_row_stride 190 | X += row_idx * X_row_stride 191 | inv_var += row_idx 192 | mean += row_idx 193 | 194 | # Load the gradient from upstream (dY) 195 | # For row_idx=1: dY_row = [0.6, 0.7, 0.8, 0.9, 1.0, 0, 0, 0] 196 | dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32) 197 | 198 | # Load the input values 199 | # For row_idx=1: X_row = [6.0, 7.0, 8.0, 9.0, 10.0, 0, 0, 0] 200 | X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32) 201 | 202 | # Load weight parameters 203 | # weight_row = [0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0] 204 | weight_row = tl.load(weight + col_offsets, mask = mask, other = 0).to(tl.float32) 205 | 206 | # Load the stored inverse variance and mean from the forward pass 207 | # For row_idx=1: inv_var_val ≈ 0.707, mean_val = 8.0 208 | inv_var_val = tl.load(inv_var).to(tl.float32) 209 | mean_val = tl.load(mean).to(tl.float32) 210 | 211 | # Calculate the normalized input values (same as in forward pass) 212 | # For row_idx=1: normed = [(6.0-8.0)*0.707, (7.0-8.0)*0.707, (8.0-8.0)*0.707, (9.0-8.0)*0.707, (10.0-8.0)*0.707] 213 | # = [-1.414, 
-0.707, 0.0, 0.707, 1.414] 214 | normed = (X_row - mean_val) * inv_var_val 215 | 216 | # Scale the upstream gradient by the weight 217 | # For row_idx=1: dY_W = [0.6*0.5, 0.7*0.5, 0.8*0.5, 0.9*0.5, 1.0*0.5] 218 | # = [0.3, 0.35, 0.4, 0.45, 0.5] 219 | dY_W = dY_row * weight_row 220 | 221 | # Calculate the gradient with respect to the input 222 | # This follows the chain rule for backpropagation through layer normalization 223 | # The formula has three terms: 224 | # 1. dY_W: direct contribution from upstream gradient 225 | # 2. -tl.sum(dY_W, axis=0)/n_cols: contribution from the mean term 226 | # 3. -normed * tl.sum(dY_W * normed, axis=0)/n_cols: contribution from the variance term 227 | 228 | # Sum the three terms, then scale by inv_var_val (= 1/sqrt(var + eps)) to obtain dX 229 | dX_row = dY_W - tl.sum(dY_W, axis = 0) / n_cols - normed * tl.sum(dY_W * normed, axis = 0) / n_cols 230 | dX_row = dX_row * inv_var_val 231 | 232 | # Store the gradient with respect to the input 233 | # Note: the result overwrites dY in place, and only dX is computed (no weight or bias gradients) 234 | tl.store(dY + col_offsets, dX_row, mask = mask) 235 | 236 | 237 | class Fast_Layernorm(torch.autograd.Function): 238 | @staticmethod 239 | def forward(ctx, X, weight, bias, eps): 240 | shape = X.shape 241 | dim = shape[-1] 242 | X = X.view(-1, dim) 243 | n_rows, n_cols = X.shape 244 | BLOCK_SIZE, num_warps = calculate_settings(n_cols) 245 | device = X.device 246 | Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device) 247 | inv_var = torch.empty(n_rows, dtype = torch.float32, device = device) 248 | mean = torch.empty(n_rows, dtype = torch.float32, device = device) 249 | 250 | layernorm_forward[(n_rows,)]( 251 | Y, Y.stride(0), 252 | X, X.stride(0), 253 | weight, 254 | bias, 255 | inv_var, 256 | mean, 257 | n_cols, eps, 258 | BLOCK_SIZE = BLOCK_SIZE, 259 | num_warps = num_warps, 260 | ) 261 | ctx.eps = eps 262 | ctx.BLOCK_SIZE = BLOCK_SIZE 263 | ctx.num_warps = num_warps 264 | ctx.save_for_backward
ctx.save_for_backward(X, weight, bias, inv_var, mean) 265 | return Y.view(*shape) 266 | pass 267 | 268 | @staticmethod 269 | def backward(ctx, dY): 270 | shape = dY.shape 271 | dim = shape[-1] 272 | dY = dY.view(-1, dim) 273 | X, weight, bias, inv_var, mean = ctx.saved_tensors 274 | n_rows, n_cols = dY.shape 275 | 276 | layernorm_backward[(n_rows,)]( 277 | dY, dY.stride(0), 278 | X, X.stride(0), 279 | weight, 280 | bias, 281 | inv_var, 282 | mean, 283 | n_cols, ctx.eps, 284 | BLOCK_SIZE = ctx.BLOCK_SIZE, 285 | num_warps = ctx.num_warps, 286 | ) 287 | dX = dY.view(*shape) 288 | return dX, None, None, None, None 289 | 290 | def fast_layernorm(layernorm, X): 291 | assert(layernorm.elementwise_affine is True) 292 | W = layernorm.weight 293 | bias = layernorm.bias 294 | eps = layernorm.variance_epsilon if \ 295 | hasattr(layernorm, "variance_epsilon") \ 296 | else layernorm.eps 297 | out = Fast_Layernorm.apply(X, W, bias, eps) 298 | return out -------------------------------------------------------------------------------- /annotated_examples/classics/matmul.py: -------------------------------------------------------------------------------- 1 | import torch, math, random, copy 2 | from torch import Tensor 3 | import triton 4 | import triton.language as tl 5 | import pdb 6 | 7 | @triton.jit 8 | def matmul_kernel( 9 | # Pointers to matrices 10 | a_ptr, b_ptr, c_ptr, 11 | # Matrix dimensions 12 | M, N, K, 13 | # The stride variables represent how much to increase the ptr by when moving by 1 14 | # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` 15 | # by to get the element one row down (A has M rows). 
16 | stride_am, stride_ak, # Strides for matrix A (M, K) 17 | stride_bk, stride_bn, # Strides for matrix B (K, N) 18 | stride_cm, stride_cn, # Strides for matrix C (M, N) 19 | # Meta-parameters 20 | BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # Tile sizes 21 | GROUP_SIZE_M: tl.constexpr, # Number of M-dimension tiles per group (for L2 cache optimization) 22 | ACTIVATION: tl.constexpr # Optional activation function to apply 23 | ): 24 | """Kernel for computing the matmul C = A x B. 25 | A has shape (M, K), B has shape (K, N) and C has shape (M, N) 26 | 27 | Example: 28 | For M=16, N=8, K=16 with BLOCK_SIZE_M=2, BLOCK_SIZE_N=2, BLOCK_SIZE_K=2, GROUP_SIZE_M=3: 29 | 30 | Matrix A (16x16): 31 | [A00, A01, A02, ..., A0,15] 32 | [A10, A11, A12, ..., A1,15] 33 | [... ] 34 | [A15,0, A15,1, ..., A15,15] 35 | 36 | Matrix B (16x8): 37 | [B00, B01, B02, ..., B07] 38 | [B10, B11, B12, ..., B17] 39 | [... ] 40 | [B15,0, B15,1, ..., B15,7] 41 | 42 | Matrix C (16x8): 43 | [C00, C01, C02, ..., C07] 44 | [C10, C11, C12, ..., C17] 45 | [... 
] 46 | [C15,0, C15,1, ..., C15,7] 47 | 48 | - We divide matrices into blocks of size 2x2 49 | - Matrix A (16x16) has 8x8 blocks 50 | - Matrix B (16x8) has 8x4 blocks 51 | - Matrix C (16x8) has 8x4 blocks 52 | - We'll have 32 thread blocks computing the 32 blocks of C 53 | """ 54 | # ----------------------------------------------------------- 55 | # STEP 1: Determine which block of the output matrix C this thread block will compute 56 | # ----------------------------------------------------------- 57 | # Each thread block is assigned a unique program ID (pid) 58 | pid = tl.program_id(axis=0) 59 | 60 | # Calculate how many blocks we need in each dimension 61 | # For our example: num_pid_m = ceil(16/2) = 8, num_pid_n = ceil(8/2) = 4 62 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) # Number of blocks in M dimension 63 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) # Number of blocks in N dimension 64 | 65 | # L2 cache optimization: Group blocks along M dimension to promote data reuse 66 | # For our example: num_pid_in_group = 3*4 = 12 (total blocks in a group) 67 | num_pid_in_group = GROUP_SIZE_M * num_pid_n 68 | 69 | # Determine which group this thread block belongs to 70 | # We have 3 groups of up to 12 blocks each; the last group spans only 2 block-rows, so it holds 2*4 = 8 blocks 71 | # For pid<12: group_id = 0 72 | # For 12<=pid<24: group_id = 1 73 | # else: group_id = 2 74 | group_id = pid // num_pid_in_group 75 | 76 | # Find the first block index in M dimension for this group 77 | # For group_id=0: first_pid_m = 0 78 | # For group_id=1: first_pid_m = 3 79 | # For group_id=2: first_pid_m = 6 80 | first_pid_m = group_id * GROUP_SIZE_M 81 | 82 | # Calculate actual group size (might be smaller at boundaries) 83 | # For first_pid_m=0: group_size_m = min(8-0, 3) = 3 84 | # For first_pid_m=3: group_size_m = min(8-3, 3) = 3 85 | # For first_pid_m=6: group_size_m = min(8-6, 3) = 2 (last group is smaller) 86 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 87 | 88 | # Calculate the specific block indices this thread block will 
compute 89 | # For pid=0: pid_m = 0 + ((0 % 12) % 3) = 0, pid_n = (0 % 12) // 3 = 0 90 | # For pid=1: pid_m = 0 + ((1 % 12) % 3) = 1, pid_n = (1 % 12) // 3 = 0 91 | # For pid=2: pid_m = 0 + ((2 % 12) % 3) = 2, pid_n = (2 % 12) // 3 = 0 92 | # For pid=3: pid_m = 0 + ((3 % 12) % 3) = 0, pid_n = (3 % 12) // 3 = 1 93 | # For pid=4: pid_m = 0 + ((4 % 12) % 3) = 1, pid_n = (4 % 12) // 3 = 1 94 | # For pid=5: pid_m = 0 + ((5 % 12) % 3) = 2, pid_n = (5 % 12) // 3 = 1 95 | # For pid=6: pid_m = 0 + ((6 % 12) % 3) = 0, pid_n = (6 % 12) // 3 = 2 96 | # For pid=7: pid_m = 0 + ((7 % 12) % 3) = 1, pid_n = (7 % 12) // 3 = 2 97 | # For pid=8: pid_m = 0 + ((8 % 12) % 3) = 2, pid_n = (8 % 12) // 3 = 2 98 | # For pid=9: pid_m = 0 + ((9 % 12) % 3) = 0, pid_n = (9 % 12) // 3 = 3 99 | # For pid=10: pid_m = 0 + ((10 % 12) % 3) = 1, pid_n = (10 % 12) // 3 = 3 100 | # For pid=11: pid_m = 0 + ((11 % 12) % 3) = 2, pid_n = (11 % 12) // 3 = 3 101 | # For pid=12: pid_m = 3 + ((12 % 12) % 3) = 3, pid_n = (12 % 12) // 3 = 0 102 | # For pid=13: pid_m = 3 + ((13 % 12) % 3) = 4, pid_n = (13 % 12) // 3 = 0 103 | # For pid=14: pid_m = 3 + ((14 % 12) % 3) = 5, pid_n = (14 % 12) // 3 = 0 104 | # For pid=15: pid_m = 3 + ((15 % 12) % 3) = 3, pid_n = (15 % 12) // 3 = 1 105 | # For pid=16: pid_m = 3 + ((16 % 12) % 3) = 4, pid_n = (16 % 12) // 3 = 1 106 | # For pid=17: pid_m = 3 + ((17 % 12) % 3) = 5, pid_n = (17 % 12) // 3 = 1 107 | # For pid=18: pid_m = 3 + ((18 % 12) % 3) = 3, pid_n = (18 % 12) // 3 = 2 108 | # For pid=19: pid_m = 3 + ((19 % 12) % 3) = 4, pid_n = (19 % 12) // 3 = 2 109 | # For pid=20: pid_m = 3 + ((20 % 12) % 3) = 5, pid_n = (20 % 12) // 3 = 2 110 | # For pid=21: pid_m = 3 + ((21 % 12) % 3) = 3, pid_n = (21 % 12) // 3 = 3 111 | # For pid=22: pid_m = 3 + ((22 % 12) % 3) = 4, pid_n = (22 % 12) // 3 = 3 112 | # For pid=23: pid_m = 3 + ((23 % 12) % 3) = 5, pid_n = (23 % 12) // 3 = 3 113 | # For pid=24 (last group, group_size_m = 2, so % 2 and // 2): pid_m = 6 + ((24 % 12) % 2) = 6, pid_n = (24 % 12) // 2 = 0 114 | # For pid=25: pid_m = 6 + 
((25 % 12) % 2) = 7, pid_n = (25 % 12) // 2 = 0 115 | # For pid=26: pid_m = 6 + ((26 % 12) % 2) = 6, pid_n = (26 % 12) // 2 = 1 116 | # For pid=27: pid_m = 6 + ((27 % 12) % 2) = 7, pid_n = (27 % 12) // 2 = 1 117 | # For pid=28: pid_m = 6 + ((28 % 12) % 2) = 6, pid_n = (28 % 12) // 2 = 2 118 | # For pid=29: pid_m = 6 + ((29 % 12) % 2) = 7, pid_n = (29 % 12) // 2 = 2 119 | # For pid=30: pid_m = 6 + ((30 % 12) % 2) = 6, pid_n = (30 % 12) // 2 = 3 120 | # For pid=31: pid_m = 6 + ((31 % 12) % 2) = 7, pid_n = (31 % 12) // 2 = 3 121 | # 122 | # Matrix C (16x8) with blocks of size 2x2 will be computed by these pids: 123 | # [pid=0, pid=3, pid=6, pid=9 ] 124 | # [pid=1, pid=4, pid=7, pid=10] 125 | # [pid=2, pid=5, pid=8, pid=11] 126 | # [pid=12, pid=15, pid=18, pid=21] 127 | # [pid=13, pid=16, pid=19, pid=22] 128 | # [pid=14, pid=17, pid=20, pid=23] 129 | # [pid=24, pid=26, pid=28, pid=30] 130 | # [pid=25, pid=27, pid=29, pid=31] 131 | # 132 | # Swizzle pattern visualization: 133 | # Group 0: Group 1: Group 2: 134 | # +---+---+---+---+ +---+---+---+---+ +---+---+---+---+ 135 | # | 0 | 3 | 6 | 9 | |12 |15 |18 |21 | |24 |26 |28 |30 | 136 | # +---+---+---+---+ +---+---+---+---+ +---+---+---+---+ 137 | # | 1 | 4 | 7 |10 | |13 |16 |19 |22 | |25 |27 |29 |31 | 138 | # +---+---+---+---+ +---+---+---+---+ +---+---+---+---+ 139 | # | 2 | 5 | 8 |11 | |14 |17 |20 |23 | +---+---+---+---+ 140 | # +---+---+---+---+ +---+---+---+---+ 141 | # 142 | # Notice how pids are assigned in column-major order within each group: 143 | # - Within each group, we process blocks in a column-first pattern (0,1,2 then 3,4,5 etc.)
144 | # - This creates spatial locality for memory accesses within each group 145 | # - Adjacent thread blocks process adjacent memory, improving cache efficiency 146 | 147 | pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) 148 | pid_n = (pid % num_pid_in_group) // group_size_m 149 | 150 | # ----------------------------------------------------------- 151 | # STEP 2: Create pointers to the blocks of input matrices A and B 152 | # ----------------------------------------------------------- 153 | # Calculate offsets for the block of A and B this thread block will process (% M and % N wrap out-of-range rows/columns back into bounds; the store mask discards those results) 154 | # For pid=13 (computing block in row 4, column 0 of C): 155 | # offs_am = [8,9] (rows 8-9 of A) 156 | # offs_bn = [0,1] (columns 0-1 of B) 157 | # offs_k = [0,1] (columns 0-1 of A / rows 0-1 of B) 158 | offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M 159 | offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N 160 | offs_k = tl.arange(0, BLOCK_SIZE_K) 161 | 162 | # Create pointer blocks for A and B 163 | # Strides give the distance, in elements (not bytes), between consecutive entries along a dimension: 164 | # - stride_am: elements between consecutive rows in matrix A (used to move between rows of A) 165 | # - stride_ak: elements between consecutive columns in matrix A (used to move between columns of A) 166 | # - stride_bk: elements between consecutive rows in matrix B (used to move between rows of B) 167 | # - stride_bn: elements between consecutive columns in matrix B (used to move between columns of B) 168 | # 169 | # For pid=13 (computing block in row 4, column 0 of C): 170 | # a_ptrs calculates pointers to A[8:10, 0:2] by: 171 | # - Starting at base pointer a_ptr 172 | # - Adding row offsets (offs_am[:, None] * stride_am) to move to rows 8-9 173 | # - Adding column offsets (offs_k[None, :] * stride_ak) to access columns 0-1 174 | # b_ptrs calculates pointers to B[0:2, 0:2] by: 175 | # - Starting at base pointer b_ptr 176 | # - Adding row offsets (offs_k[:, None] * 
stride_bk) to move to rows 0-1 177 | # - Adding column offsets (offs_bn[None, :] * stride_bn) to access columns 0-1 178 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) 179 | b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) 180 | 181 | # ----------------------------------------------------------- 182 | # STEP 3: Compute the matrix multiplication C = A × B block by block 183 | # ----------------------------------------------------------- 184 | # Initialize accumulator with zeros 185 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 186 | 187 | # For our example with K=16 and BLOCK_SIZE_K=2, we need 8 iterations 188 | for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): 189 | # Load blocks of A and B 190 | # For pid=0, iteration 0: we load A[0:2, 0:2] and B[0:2, 0:2] 191 | # For pid=0, iteration 1: we load A[0:2, 2:4] and B[2:4, 0:2] 192 | # For pid=0, iteration 2: we load A[0:2, 4:6] and B[4:6, 0:2] 193 | # And so on... 
194 | 195 | # Calculate how many elements remain in the K dimension for the current iteration 196 | k_remaining = K - k * BLOCK_SIZE_K 197 | 198 | # a_mask handles the columns of matrix A (K dimension) 199 | # For example, if K=10 and BLOCK_SIZE_K=4, in the last iteration (k=2): 200 | # - k_remaining = 10 - 2*4 = 2 201 | # - offs_k = [0,1,2,3] 202 | # - a_mask will be [[True,True,False,False]] 203 | # This ensures we only load the valid 2 remaining columns 204 | a_mask = (offs_k[None, :] < k_remaining) 205 | 206 | # b_mask handles the rows of matrix B (K dimension) 207 | # Using the same example, b_mask will be: 208 | # [[True], [True], [False], [False]] 209 | # This ensures we only load the valid 2 remaining rows 210 | b_mask = (offs_k[:, None] < k_remaining) 211 | a = tl.load(a_ptrs, mask=a_mask, other=0.0) 212 | b = tl.load(b_ptrs, mask=b_mask, other=0.0) 213 | 214 | # Compute matrix multiplication for this block and accumulate 215 | # For pid=0, iteration 0: 216 | # C[0,0] += A[0,0]*B[0,0] + A[0,1]*B[1,0] 217 | # C[0,1] += A[0,0]*B[0,1] + A[0,1]*B[1,1] 218 | # C[1,0] += A[1,0]*B[0,0] + A[1,1]*B[1,0] 219 | # C[1,1] += A[1,0]*B[0,1] + A[1,1]*B[1,1] 220 | accumulator = tl.dot(a, b, accumulator) 221 | 222 | # Move pointers to the next K block 223 | # For iteration 1: a_ptrs now points to A[0:2, 2:4] 224 | # For iteration 1: b_ptrs now points to B[2:4, 0:2] 225 | a_ptrs += BLOCK_SIZE_K * stride_ak 226 | b_ptrs += BLOCK_SIZE_K * stride_bk 227 | 228 | # ----------------------------------------------------------- 229 | # STEP 4: Apply activation function (if specified) and prepare for output 230 | # ----------------------------------------------------------- 231 | # Apply activation function if specified 232 | if ACTIVATION == "leaky_relu": 233 | accumulator = leaky_relu(accumulator) 234 | 235 | # Convert back to float16 for output 236 | c = accumulator.to(tl.float16) 237 | 238 | # ----------------------------------------------------------- 239 | # STEP 5: Write the 
computed block back to matrix C 240 | # ----------------------------------------------------------- 241 | # Calculate global indices for this block in matrix C 242 | # same as offs_am and offs_bn, but without the % wrap-around so out-of-range rows/columns can be masked 243 | offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 244 | offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 245 | 246 | # Create pointers to the output locations in C 247 | # built the same way as a_ptrs and b_ptrs 248 | c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] 249 | 250 | # Create mask to handle boundary conditions 251 | c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) 252 | 253 | # Store the computed values to matrix C 254 | # For pid=0: This writes to C[0:2, 0:2] 255 | tl.store(c_ptrs, c, mask=c_mask) 256 | 257 | @triton.jit 258 | def leaky_relu(x): 259 | return tl.where(x >= 0, x, 0.01 * x) 260 | 261 | def matmul(a, b, activation=""): 262 | M, K = a.shape 263 | K_b, N = b.shape 264 | assert K == K_b, "Incompatible dimensions: a.shape[1] must equal b.shape[0]" 265 | c = torch.empty((M, N), device=a.device, dtype=torch.float16) # Initialize output tensor 266 | BLOCK_SIZE_M = 64 267 | BLOCK_SIZE_N = 64 268 | BLOCK_SIZE_K = 64 269 | GROUP_SIZE_M = 8 270 | # Calculate grid dimensions based on block sizes 271 | grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),) 272 | 273 | # Launch the kernel; pass the tensors themselves so Triton treats them as typed pointers (a raw .data_ptr() int would be passed as an integer scalar) 274 | matmul_kernel[grid]( 275 | a_ptr=a, 276 | b_ptr=b, 277 | c_ptr=c, 278 | M=M, N=N, K=K, 279 | stride_am=a.stride(0), 280 | stride_ak=a.stride(1), 281 | stride_bk=b.stride(0), 282 | stride_bn=b.stride(1), 283 | stride_cm=c.stride(0), 284 | stride_cn=c.stride(1), 285 | ACTIVATION=activation, 286 | BLOCK_SIZE_M=BLOCK_SIZE_M, 287 | BLOCK_SIZE_N=BLOCK_SIZE_N, 288 | BLOCK_SIZE_K=BLOCK_SIZE_K, 289 | GROUP_SIZE_M=GROUP_SIZE_M, 290 | ) 291 | 292 | return c 293 | 294 | # Test function for the matmul implementation 295 | def test_matmul(): 296 | import torch 297 | for m, n, k in [(32, 32, 32), (256, 512, 128), (1024, 1024, 1024)]: 298 | a = torch.randn((m, 
k), device='cuda', dtype=torch.float16) 299 | b = torch.randn((k, n), device='cuda', dtype=torch.float16) 300 | c_ref = torch.matmul(a, b) 301 | c_triton = matmul(a, b) 302 | assert torch.allclose(c_ref, c_triton, rtol=1e-2, atol=1e-2), f"Failed for size {m}x{k}x{n}" 303 | 304 | print("Size tests passed!") 305 | 306 | if __name__ == "__main__": 307 | test_matmul() 308 | -------------------------------------------------------------------------------- /annotated_examples/classics/swiglu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # Modifications Copyright 2025 Mekkcyber. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
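Before the Triton kernels in this file, a plain PyTorch reference for the SwiGLU math can help when reading them. The sketch below is not part of the original file, and the names `swiglu_forward_ref` and `swiglu_backward_ref` are hypothetical. It computes h = e * sigmoid(e) * g together with the hand-derived gradients dL/dg = dY * f and dL/de = dY * g * sigmoid(e) * (1 + e * (1 - sigmoid(e))), then checks them against autograd in float64 on the CPU.

```python
import torch


def swiglu_forward_ref(e: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    # f = e * sigmoid(e) (SiLU/Swish), h = f * g (gating)
    return e * torch.sigmoid(e) * g


def swiglu_backward_ref(dY: torch.Tensor, e: torch.Tensor, g: torch.Tensor):
    # Hand-derived gradients, mirroring the formulas in _backward_kernel's docstring
    se = torch.sigmoid(e)
    f = e * se
    dg = dY * f                                  # dL/dg = dY * f
    de = dY * g * se * (1.0 + e * (1.0 - se))    # dL/de = dY * g * df/de
    return de, dg


torch.manual_seed(0)
e = torch.randn(2, 10, 128, dtype=torch.float64, requires_grad=True)
g = torch.randn(2, 10, 128, dtype=torch.float64, requires_grad=True)
dY = torch.randn(2, 10, 128, dtype=torch.float64)

swiglu_forward_ref(e, g).backward(dY)            # autograd fills e.grad and g.grad
de, dg = swiglu_backward_ref(dY, e.detach(), g.detach())
print(torch.allclose(de, e.grad), torch.allclose(dg, g.grad))  # True True
```

Float64 is used here so the comparison isolates the algebra from the float16/float32 rounding that the Triton kernels introduce.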
15 | 16 | import triton 17 | import triton.language as tl 18 | import torch 19 | 20 | MAX_FUSED_SIZE : int = 65536 21 | next_power_of_2 = triton.next_power_of_2 22 | 23 | def calculate_settings(n : int) -> (int, int,): 24 | BLOCK_SIZE : int = next_power_of_2(n) 25 | if BLOCK_SIZE > MAX_FUSED_SIZE: 26 | raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\ 27 | f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.") 28 | num_warps : int = 4 29 | if BLOCK_SIZE >= 32768: num_warps = 32 30 | elif BLOCK_SIZE >= 8192: num_warps = 16 31 | elif BLOCK_SIZE >= 2048: num_warps = 8 32 | return BLOCK_SIZE, num_warps 33 | 34 | @triton.jit 35 | def _forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,): 36 | block_idx = tl.program_id(0) 37 | offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 38 | mask = offsets < n_elements 39 | 40 | e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32) 41 | g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32) 42 | 43 | # f = e * sigmoid(e) 44 | f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row)) 45 | f_row = f_row.to(g_row.dtype) 46 | # h = f * g 47 | h_row = f_row * g_row 48 | 49 | # Store h 50 | tl.store(h + offsets, h_row, mask = mask) 51 | 52 | 53 | def swiglu_forward_kernel(e, g): 54 | batch, seq_len, hd = e.shape 55 | n_elements = e.numel() 56 | h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = e.device) 57 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 58 | _forward_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,) 59 | return h 60 | 61 | 62 | @triton.jit 63 | def _backward_kernel(dY, e, g, n_elements, BLOCK_SIZE : tl.constexpr,): 64 | """ 65 | Backward pass for SwiGLU activation function. 
66 | 67 | Forward pass (for reference): 68 | f = e * sigmoid(e) # SiLU/Swish activation 69 | h = f * g # Gating mechanism 70 | 71 | Backward pass derivation: 72 | Given dL/dh (dY), we need to compute: 73 | - dL/de: Gradient with respect to first input 74 | - dL/dg: Gradient with respect to second input 75 | 76 | Using the chain rule: 77 | dL/dg = dL/dh * dh/dg = dY * f 78 | dL/de = dL/dh * dh/de = dY * dh/de = dY * g * df/de 79 | Where df/de = sigmoid(e) + e * sigmoid(e) * (1 - sigmoid(e)) 80 | = sigmoid(e) * (1 + e * (1 - sigmoid(e))) (see backprop_math/swiglu.md) 81 | """ 82 | # Get the block index and calculate offsets for parallel processing 83 | block_idx = tl.program_id(0) 84 | offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 85 | mask = offsets < n_elements 86 | 87 | # Load input tensors 88 | dY_row = tl.load(dY + offsets, mask = mask, other = 0) 89 | e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32) 90 | g_row = tl.load(g + offsets, mask = mask, other = 0) 91 | 92 | # Compute sigmoid(e) - needed for both forward and backward calculations 93 | se_row = tl.sigmoid(e_row) # sigmoid(e) 94 | 95 | # Compute f = e * sigmoid(e) (SiLU/Swish activation) 96 | f_row = se_row * e_row 97 | f_row = f_row.to(dY_row.dtype) # Convert back to original dtype 98 | 99 | # Compute dL/dg = dY * f 100 | dg_row = dY_row * f_row 101 | 102 | # Compute dL/de = dY * g * sigmoid(e) * (1 + e * (1 - sigmoid(e))) 103 | # This is the derivative of SwiGLU with respect to e 104 | de_row = dY_row.to(tl.float32) * g_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row)) 105 | de_row = de_row.to(dY_row.dtype) # Convert back to original dtype 106 | 107 | # Store computed gradients back to the input buffers 108 | # Note: We're reusing the input buffers to store the gradients 109 | tl.store(e + offsets, de_row, mask = mask) # Store dL/de in e buffer 110 | tl.store(g + offsets, dg_row, mask = mask) # Store dL/dg in g buffer 111 | 112 | 113 | def 
swiglu_DWf_DW_dfg_kernel(dY, e, g): 114 | n_elements = e.numel() 115 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 116 | _backward_kernel[grid](dY, e, g, n_elements, BLOCK_SIZE = 1024,) 117 | return e, g 118 | 119 | def test_swiglu_correctness(): 120 | """ 121 | Test the correctness of the SwiGLU implementation by comparing with a reference implementation. 122 | Tests both forward and backward passes for SwiGLU (SiLU/Swish). 123 | """ 124 | import torch 125 | import torch.nn.functional as F 126 | 127 | def swiglu_reference_forward(e, g): 128 | """Reference implementation of SwiGLU forward pass""" 129 | return g * (e * F.sigmoid(e)) 130 | 131 | forward_kernel = _forward_kernel 132 | backward_kernel = _backward_kernel 133 | 134 | batch_size, seq_len, hidden_dim = 2, 10, 128 135 | e = torch.randn(batch_size, seq_len, hidden_dim, dtype=torch.float32, device='cuda') 136 | g = torch.randn(batch_size, seq_len, hidden_dim, dtype=torch.float32, device='cuda') 137 | 138 | h = torch.empty_like(e) 139 | 140 | n_elements = e.numel() 141 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 142 | forward_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024) 143 | 144 | our_output = h.clone() 145 | 146 | ref_output = swiglu_reference_forward(e, g) 147 | 148 | max_diff = torch.max(torch.abs(ref_output - our_output)) 149 | print(f"Max difference in SwiGLU forward pass: {max_diff.item()}") 150 | assert max_diff < 1e-5, "SwiGLU forward pass implementation is incorrect!" 
151 | 152 | # Test backward pass 153 | dY = torch.randn_like(h) 154 | 155 | # Compute reference gradients 156 | e.requires_grad_(True) 157 | g.requires_grad_(True) 158 | ref_output = swiglu_reference_forward(e, g) 159 | ref_output.backward(dY) 160 | ref_de = e.grad.clone() 161 | ref_dg = g.grad.clone() 162 | 163 | backward_kernel[grid](dY, e, g, n_elements, BLOCK_SIZE = 1024) 164 | 165 | max_diff_de = torch.max(torch.abs(ref_de - e)) 166 | print(f"Max difference in SwiGLU backward pass (de): {max_diff_de.item()}") 167 | assert max_diff_de < 1e-5, "SwiGLU backward pass implementation for de is incorrect!" 168 | 169 | max_diff_dg = torch.max(torch.abs(ref_dg - g)) 170 | print(f"Max difference in SwiGLU backward pass (dg): {max_diff_dg.item()}") 171 | assert max_diff_dg < 1e-5, "SwiGLU backward pass implementation for dg is incorrect!" 172 | 173 | print("All tests passed!") 174 | 175 | 176 | if __name__ == "__main__": 177 | test_swiglu_correctness() 178 | -------------------------------------------------------------------------------- /annotated_examples/gemlite/gemm.py: -------------------------------------------------------------------------------- 1 | # Written by Dr. Hicham Badri @Mobius Labs GmbH - 2024 2 | # Modified by Mekkcyber - 2025 3 | #******************************************************* 4 | import torch, time 5 | import triton 6 | import triton.language as tl 7 | 8 | ######################################################################################################################################################################## 9 | 10 | # Prerequisites for the GEMM. For a better understanding, start from the gemm_kernel function, where everything is explained 11 | @triton.jit 12 | def dequantize(b, scales, zeros, q_shift, meta_dtype, unpack_mask, elements_per_sample: tl.constexpr, W_group_mode: tl.constexpr, zero_is_scalar: tl.constexpr): 13 | """ 14 | Dequantizes packed integer values into floating point values using various quantization schemes. 
15 | 16 | Args: 17 | b: Packed quantized values (typically int32) 18 | scales: Scaling factors for dequantization (per group or channel) 19 | zeros: Zero points for asymmetric quantization (per group or channel) 20 | q_shift: Bit shift amount for unpacking elements from packed format 21 | meta_dtype: Target data type for metadata operations 22 | unpack_mask: Bit mask for extracting individual elements (e.g., 0xF for 4-bit) 23 | elements_per_sample: Number of quantized elements packed into each storage unit 24 | W_group_mode: Quantization scheme to use (1-4) 25 | zero_is_scalar: Whether zero point is shared across all elements 26 | 27 | Returns: 28 | Dequantized tensor in floating point format 29 | """ 30 | # Step 1: Unpack the elements if they are packed (e.g., 8 4-bit values in one int32) 31 | if(elements_per_sample > 1): 32 | # Extract individual quantized values using bit shifting and masking 33 | # q_shift determines which element to extract based on position 34 | b = (b >> q_shift) & unpack_mask # int32 -> int32 35 | 36 | # Step 2: Apply the appropriate dequantization formula based on W_group_mode 37 | 38 | if(W_group_mode == 1): # Shift-only mode (zero-point subtraction) 39 | # Formula: dequantized = quantized - zero_point 40 | b = b.to(meta_dtype) - zeros 41 | 42 | if(W_group_mode == 2): # Scale-only mode (symmetric quantization) 43 | # Formula: dequantized = quantized * scale 44 | # Used when quantized values are centered around zero 45 | b = b.to(meta_dtype) * scales 46 | 47 | if(W_group_mode == 3): # Scale and shift mode (asymmetric quantization) 48 | # Formula: dequantized = (quantized - zero_point) * scale 49 | if(zero_is_scalar): 50 | # When zero_point is shared across all elements (memory optimization) 51 | b = (b - zeros).to(meta_dtype) * scales 52 | else: 53 | # When each group has its own zero_point 54 | b = (b.to(meta_dtype) - zeros) * scales 55 | 56 | if(W_group_mode == 4): # Fused multiply-add mode 57 | # Formula: dequantized = quantized * scale 
+ zero 58 | # Uses fused multiply-add for better performance 59 | # Note: in this mode, 'zeros' is actually an additive term, not a zero point 60 | b = tl.fma(b.to(meta_dtype), scales, zeros) 61 | 62 | return b 63 | 64 | @triton.jit 65 | def swizzle_tile(pid, M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, GROUP_SIZE_M: tl.constexpr): 66 | grid_m = tl.cdiv(M, BLOCK_SIZE_M) 67 | grid_n = tl.cdiv(N, BLOCK_SIZE_N) 68 | width = GROUP_SIZE_M * grid_n 69 | group_id = pid // width 70 | group_size = tl.minimum(grid_m - group_id * GROUP_SIZE_M, GROUP_SIZE_M) 71 | pid_m = group_id * GROUP_SIZE_M + (pid % group_size) 72 | pid_n = (pid % width) // group_size 73 | return pid_m, pid_n 74 | 75 | @triton.jit 76 | def linear_tile(pid, M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, GROUP_SIZE_M: tl.constexpr): 77 | pid_m = pid // tl.cdiv(N, BLOCK_SIZE_N) 78 | pid_n = pid % tl.cdiv(N, BLOCK_SIZE_N) 79 | return pid_m, pid_n 80 | ######################################################################################################################################################################## 81 | 82 | # START HERE for the gemm kernel 83 | @triton.jit 84 | def gemm_kernel( 85 | a_ptr, b_ptr, c_ptr, 86 | scales_ptr, zeros_ptr, scales_a_ptr, 87 | M, N, K, 88 | ######### Quant parms ######### 89 | W_nbits: tl.constexpr, 90 | group_size: tl.constexpr, 91 | unpack_mask: tl.constexpr, 92 | elements_per_sample: tl.constexpr, 93 | ######### Strides ######### 94 | stride_am, stride_ak, 95 | stride_bk, stride_bn, 96 | stride_cm, stride_cn, 97 | stride_meta_g, stride_meta_n, 98 | ######### Dtypes ######### 99 | input_dtype: tl.constexpr, 100 | output_dtype: tl.constexpr, 101 | acc_dtype: tl.constexpr, 102 | meta_dtype: tl.constexpr, 103 | ######### Meta-data mode ######### 104 | channel_scale_mode: tl.constexpr, 105 | W_group_mode: tl.constexpr, 106 | zero_is_scalar: tl.constexpr, 107 | ######### tuning params ######### 108 | BLOCK_SIZE_M: tl.constexpr, 
BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, 109 | GROUP_SIZE_M: tl.constexpr, 110 | A_load_order: tl.constexpr, meta_evict_policy: tl.constexpr, 111 | data_contiguous: tl.constexpr, 112 | ): 113 | """ 114 | Based on https://github.com/fpgaminer/GPTQ-triton 115 | GEMM for C = matmul(A, dequantize(B, scales, zeros)) 116 | A is of shape (M, K): float16 or bfloat16 117 | B is of shape (K//elements_per_sample, N): int32 as a packed matrix 118 | C is of shape (M, N): float16 or bfloat16 depending on the input A 119 | scales and zeros are of shape (K//group_size, N): float16 or bfloat16 120 | 121 | BLOCK_SIZE_M >=16 122 | BLOCK_SIZE_K <= group_size 123 | """ 124 | 125 | # This kernel implements a quantized matrix multiplication operation where: 126 | # - Matrix A is an unquantized matrix (float16/bfloat16) 127 | # - Matrix B is a quantized matrix (packed into int32) 128 | # - The result C = A @ dequantize(B) 129 | 130 | # Example: 131 | # If we have: 132 | # - A: 128x512 matrix (M=128, K=512) in float16 133 | # - B: packed 64x256 int32 matrix (K//elements_per_sample = 512//8 = 64 rows, N=256) 134 | # - W_nbits=4 (4-bit quantization) 135 | # - elements_per_sample=8 (8 elements packed into each int32) 136 | # B is dequantized to a 512x256 matrix in float16 137 | # Then C will be a 128x256 matrix in float16 138 | 139 | # Get the program ID, which identifies the output tile this program instance computes 140 | pid = tl.program_id(axis=0) 141 | 142 | # For each pid, we need to find the corresponding block in the output matrix. There are two ways to do this: 143 | # 1. swizzle_tile: block rows are grouped into bands of GROUP_SIZE_M rows; within a band, consecutive pids go down a column first, then move to the next column 144 | # Example of swizzle pattern with GROUP_SIZE_M=2: 145 | # pid layout in the output matrix (each number represents a block's pid): 146 | # 0 2 4 6 147 | # 1 3 5 7 148 | # 8 10 12 14 149 | # 9 11 13 15 150 | # This improves L2 cache reuse, since pids that run close together load the same rows of A and columns of B; see classics/matmul.py for more details 151 | # 2.
linear_tile: we simply set pid_m = pid // (tl.cdiv(N, BLOCK_SIZE_N)) and pid_n = pid % (tl.cdiv(N, BLOCK_SIZE_N)) 152 | # Example of the linear pattern for the same 4x4 grid of blocks (GROUP_SIZE_M is ignored): 153 | # pid layout in the output matrix (each number represents a block's pid): 154 | # 0 1 2 3 155 | # 4 5 6 7 156 | # 8 9 10 11 157 | # 12 13 14 15 158 | # This is simpler to compute but may result in poorer cache locality 159 | pid_m, pid_n = swizzle_tile(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M) 160 | 161 | # Calculate how many blocks we need in the K dimension 162 | # For example, if K=512 and BLOCK_SIZE_K=64, we need 8 iterations 163 | num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) 164 | 165 | # Calculate the row, column and K offsets covered by this block 166 | # If BLOCK_SIZE_M=16, BLOCK_SIZE_N=16, each program processes a 16x16 output tile 167 | # offs_am will be [0,1,2,...,15] + (pid_m * 16) 168 | # offs_bn will be [0,1,2,...,15] + (pid_n * 16) 169 | # offs_k will be [0,1,2,...,BLOCK_SIZE_K-1] 170 | offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 171 | offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 172 | offs_k = tl.arange(0, BLOCK_SIZE_K) 173 | 174 | # Optimize memory access patterns based on data layout 175 | if(data_contiguous): 176 | # multiple_of(tensor, value): tells the compiler that the values in the tensor are multiples of 'value' 177 | # - Applied to a contiguous range such as offs_bn, it asserts that the range starts at a multiple of BLOCK_SIZE_N (pid_n * BLOCK_SIZE_N) 178 | # - Knowing the accesses are aligned lets the compiler emit vectorized memory accesses 179 | # 180 | # max_contiguous(tensor, value): tells the compiler that the first 'value' elements of the tensor are contiguous 181 | # - This helps the compiler generate more efficient memory access patterns 182 | # - For example, the BLOCK_SIZE_N elements of offs_bn are consecutive (pid_n*BLOCK_SIZE_N, pid_n*BLOCK_SIZE_N+1, ...), 183 | # so the compiler can use coalesced memory accesses 184 | offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, 
BLOCK_SIZE_N), BLOCK_SIZE_N) 185 | 186 | # Calculate pointers to input matrices 187 | # For matrix A: If stride_am=K and stride_ak=1, this accesses A in row-major order 188 | # Example: For A[2,3] with K=512, the offset would be 2*512 + 3 = 1027 189 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) 190 | # Create a mask to handle boundary conditions (when M is not divisible by BLOCK_SIZE_M) 191 | a_mask = (offs_am[:, None] < M) 192 | 193 | # For matrix B: Calculate pointers based on the packed format 194 | # If B is packed with 8 elements per int32, we divide offs_k by 8 to get the actual memory location 195 | # 196 | # Example: With 8-bit elements packed into 32-bit integers (elements_per_sample = 4): 197 | # - For offs_k values [0,1,2,3], the division gives [0,0,0,0] 198 | # - For offs_k values [4,5,6,7], the division gives [1,1,1,1] 199 | # 200 | # This means that elements 0-3 are stored in the first 32-bit word, elements 4-7 in the second word, etc. 201 | # Later, we'll use q_shift to extract the correct bits from each packed word. 202 | b_ptrs = b_ptr + ((offs_k[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn) 203 | 204 | # Calculate bit shift for unpacking quantized values from packed integers 205 | # 206 | # Example 1: With 4-bit quantization (W_nbits=4) and 8 elements per int32 (elements_per_sample=8): 207 | # - For offs_k = [0,1,2,3,4,5,6,7]: 208 | # offs_k % elements_per_sample = [0,1,2,3,4,5,6,7] 209 | # q_shift = [0,4,8,12,16,20,24,28] bits 210 | # 211 | # Example 2: With 8-bit quantization (W_nbits=8) and 4 elements per int32 (elements_per_sample=4): 212 | # - For offs_k = [0,1,2,3,4,5,6,7]: 213 | # offs_k % elements_per_sample = [0,1,2,3,0,1,2,3] 214 | # q_shift = [0,8,16,24,0,8,16,24] bits 215 | # 216 | # The modulo operation (%) wraps around when we exceed elements_per_sample, 217 | # ensuring we extract the correct element position within each packed integer. 
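The packed-row indexing (`offs_k // elements_per_sample`) and the `q_shift` unpacking described above can be reproduced on the CPU with plain PyTorch. This is a minimal sketch, assuming 4-bit values stored unsigned in [0, 16) and packed along K into int32 with the same bit layout as the packing loop in `test_kernel`. `pack_rows` and `unpack_rows` are hypothetical helper names, not part of gemlite.

```python
import torch

W_NBITS = 4
ELEMENTS_PER_SAMPLE = 32 // W_NBITS   # 8 four-bit values per int32
UNPACK_MASK = (1 << W_NBITS) - 1      # 0xF


def pack_rows(w_q: torch.Tensor) -> torch.Tensor:
    """Pack a (K, N) tensor of unsigned W_NBITS values into (K // 8, N) int32.

    Row k lands in packed row k // ELEMENTS_PER_SAMPLE at bit offset
    (k % ELEMENTS_PER_SAMPLE) * W_NBITS.
    """
    K, N = w_q.shape
    assert K % ELEMENTS_PER_SAMPLE == 0
    packed = torch.zeros((K // ELEMENTS_PER_SAMPLE, N), dtype=torch.int32)
    for i in range(ELEMENTS_PER_SAMPLE):
        # Values shifted into bit 31 wrap to a negative int32; only the bit pattern matters
        packed |= (w_q[i::ELEMENTS_PER_SAMPLE].to(torch.int32) & UNPACK_MASK) << (i * W_NBITS)
    return packed


def unpack_rows(packed: torch.Tensor, K: int) -> torch.Tensor:
    """Recover the (K, N) values with the same indexing the kernel uses."""
    offs_k = torch.arange(K)
    # Like b_ptrs: ELEMENTS_PER_SAMPLE consecutive k offsets read the same packed row
    rows = packed[offs_k // ELEMENTS_PER_SAMPLE]                                  # (K, N)
    # Like q_shift: select the bit field for position k % ELEMENTS_PER_SAMPLE
    q_shift = ((offs_k % ELEMENTS_PER_SAMPLE) * W_NBITS).to(torch.int32)[:, None]
    # Arithmetic right shift followed by the mask isolates the 4-bit value
    return (rows >> q_shift) & UNPACK_MASK


torch.manual_seed(0)
w_q = torch.randint(0, 1 << W_NBITS, (64, 16))
packed = pack_rows(w_q)
print(packed.shape)                                                # torch.Size([8, 16])
print(torch.equal(unpack_rows(packed, 64), w_q.to(torch.int32)))   # True
```

The masked AND after the shift is what makes negative packed words harmless: the sign-extended high bits are discarded, just as `(b >> q_shift) & unpack_mask` does in `dequantize`.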
218 | q_shift = ((offs_k % elements_per_sample) * W_nbits).to(tl.int32)[:, None] 219 | 220 | # Calculate pointers to quantization metadata (scales and zeros) 221 | # These pointers point to the start of each column in the metadata matrices 222 | # For example, if we have a matrix with N columns, each column has its own 223 | # scale and zero point values for dequantization 224 | scales_ptrs = scales_ptr + offs_bn[None, :] * stride_meta_n 225 | zeros_ptrs = zeros_ptr + offs_bn[None, :] * stride_meta_n 226 | 227 | # Calculate stride multiplier for group quantization 228 | # If group_size=64 and BLOCK_SIZE_K=32, stride_mul=0.5 229 | # This means we need a new scale/zero for every 2 K-dimension blocks 230 | stride_mul = BLOCK_SIZE_K / group_size 231 | 232 | # If zero point is a scalar (same for all elements), load it once 233 | # eviction_policy='evict_last' tells the compiler how to manage the cache: 234 | # - 'evict_last': Keep the data in cache as long as possible (good for reused data) 235 | # - 'evict_first': Evict from cache quickly (good for data used only once) 236 | # This helps optimize memory access patterns and cache utilization 237 | if(zero_is_scalar): 238 | zero_scalar = tl.load(zeros_ptr, eviction_policy='evict_last') 239 | 240 | # Initialize accumulator for matrix multiplication 241 | # This will store the partial sums during the computation 242 | # For a 16x16 tile, this is a 16x16 matrix initialized to zeros 243 | acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) 244 | 245 | # Main computation loop - iterate over blocks in the K dimension 246 | # For example, if K=512 and BLOCK_SIZE_K=64, we do 8 iterations 247 | for k in range(num_pid_k): 248 | # Load matrix A based on the specified loading order 249 | # Different loading orders can help with instruction scheduling and can lead to better performance 250 | if(A_load_order == 0): # Early load - load A before B 251 | a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy='evict_last') 
252 | 253 | # Load packed quantized values from matrix B 254 | # Use the 'evict_first' policy, since these values are dequantized immediately 255 | # and the packed representation is not needed in cache afterwards 256 | # b_ptrs maps elements_per_sample consecutive k offsets to the same packed row, so each packed row is loaded elements_per_sample times; q_shift then extracts a different bit field from each copy 257 | # Note: neither this load nor the A load masks the K dimension, so K is assumed to be a multiple of BLOCK_SIZE_K 258 | b = tl.load(b_ptrs, eviction_policy='evict_first') 259 | 260 | if(A_load_order == 1): # Load A after loading B 261 | a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy='evict_last') 262 | 263 | # Load quantization metadata (scales and zero points) based on the group mode 264 | # Different modes use different patterns of scales and zeros 265 | # W_group_mode controls how quantization parameters (scales and zeros) are applied: 266 | 267 | # Mode 0: No metadata - neither scales nor zeros are used (values are only unpacked) 268 | # Mode 1: Zero-point only quantization - only zero points are used, no scales 269 | # Used for integer quantization where only zero-point shifting is needed 270 | # Mode 2: Scale-only quantization - only scales are used, no zero points 271 | # Used for symmetric quantization where values are centered around zero 272 | # Mode 3: Full quantization - both scales and zero points are used 273 | # Used for asymmetric quantization with arbitrary ranges 274 | # Mode 4: Asymmetric (Grouped - b*scales + zeros) 275 | 276 | if(W_group_mode > 0): 277 | # Calculate offset for grouped quantization 278 | # For every group_size weights, we have a single scale/zero point 279 | # stride_mul = BLOCK_SIZE_K / group_size controls how often we need new metadata 280 | # 281 | # Examples: 282 | # 1. If group_size=64 and BLOCK_SIZE_K=64, stride_mul=1 283 | # We need a new scale for each K block (k_m increases by 1 each iteration) 284 | # 2.
If group_size=128 and BLOCK_SIZE_K=64, stride_mul=0.5 285 | # We need a new scale every 2 K blocks (k_m increases by 1 every 2 iterations) 286 | # 287 | # This mapping ensures we use the correct scale/zero for each weight group 288 | k_m = (k * stride_mul).to(tl.int32) 289 | 290 | # Load scales if needed (modes 2, 3, 4) 291 | # Example: For per-channel quantization, each output channel has its own scale 292 | if(W_group_mode >= 2): # [2, 3, 4] 293 | scales = tl.load(scales_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy) 294 | else: 295 | scales = None 296 | 297 | # Load zero points if needed (modes 1, 3, 4) 298 | # Example: For per-channel quantization, each output channel has its own zero point 299 | if(W_group_mode == 1 or W_group_mode >= 3): # [1, 3, 4] 300 | if(zero_is_scalar): 301 | # If zero_is_scalar=1, use the same zero point for all elements 302 | # This saves memory and bandwidth when all channels share the same zero point 303 | zeros = zero_scalar 304 | else: 305 | # Otherwise load per-group zero points from memory 306 | # stride_meta_g controls the spacing between groups in memory 307 | zeros = tl.load(zeros_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy) 308 | else: 309 | zeros = None 310 | 311 | if(A_load_order == 2): # Mid load - load A after loading metadata 312 | a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy='evict_last') 313 | 314 | # Unpack and dequantize the values from matrix B 315 | # Dequantization formula depends on the W_group_mode: 316 | # - Mode 0: No dequantization just unpacking 317 | # - Mode 1: dequantized_value = quantized_value - zero_point 318 | # - Mode 2: dequantized_value = quantized_value * scale 319 | # - Mode 3: dequantized_value = (quantized_value - zero_point) * scale 320 | # - Mode 4: dequantized_value = b*scales + zeros 321 | b = dequantize(b, scales, zeros, q_shift, meta_dtype, unpack_mask, elements_per_sample, W_group_mode, zero_is_scalar) 322 | 323 | if(A_load_order == 3): # 
Late load - load A after dequantization 324 | a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy='evict_last') 325 | 326 | # Perform matrix multiplication for this block and add it to the accumulator 327 | # Example: if a is 16x64 (BLOCK_SIZE_M x BLOCK_SIZE_K) and b is 64x16 (BLOCK_SIZE_K x BLOCK_SIZE_N), this computes a 16x16 partial result 328 | # input_precision="tf32" only affects float32 inputs; it is ignored for float16/bfloat16 inputs 329 | acc = tl.dot(a, b.to(input_dtype), acc=acc, out_dtype=acc_dtype, input_precision="tf32") 330 | 331 | # Advance pointers for the next iteration 332 | # Move to the next block in the K dimension 333 | a_ptrs += BLOCK_SIZE_K * stride_ak 334 | b_ptrs += (BLOCK_SIZE_K // elements_per_sample) * stride_bk 335 | 336 | # Apply channel-wise scaling to the result if needed 337 | # channel_scale_mode selects which per-channel scales are applied after the K loop 338 | if(channel_scale_mode == 1): # Weight-only scaling 339 | # Load one scale per output channel (column of C) 340 | # Example: column j of the result is multiplied by scales_b[j] 341 | scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy) 342 | # Apply scales to each column of the result 343 | acc = acc.to(meta_dtype) * scales_b[None, :] 344 | 345 | if(channel_scale_mode == 2): # Activation-only scaling 346 | # Load one scaling factor per row of A (i.e. per row of the output C) 347 | scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) 348 | 349 | # Create a vector of ones for the output dimension 350 | # Since we only scale by the activation rows, we use 1.0 for all output channels 351 | # This creates a vector of size BLOCK_SIZE_N filled with 1's of the metadata type 352 | scales_b = tl.full((BLOCK_SIZE_N,), value=1, dtype=meta_dtype) 353 | 354 | # Apply the scaling factors to the accumulated result: 355 | # 1. scales_a[:, None]: Reshape scales_a from [BLOCK_SIZE_M] to [BLOCK_SIZE_M, 1] for broadcasting 356 | # 2.
scales_b[None, :]: Reshape scales_b from [BLOCK_SIZE_N] to [1, BLOCK_SIZE_N] for broadcasting 357 | # 3. This creates a scaling matrix of shape [BLOCK_SIZE_M, BLOCK_SIZE_N] where each row is scaled by its corresponding scales_a value 358 | # 4. Multiply the accumulator by this scaling matrix element-wise 359 | acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) 360 | 361 | if(channel_scale_mode == 3): # Both weight and activation scaling 362 | # Load per-row activation scales (scales_a) and per-column weight scales (scales_b) 363 | scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) 364 | scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy) 365 | # Apply both scales to the result 366 | # Example: If row 2 has scale 0.5 and column 3 has scale 0.25, 367 | # element [2,3] is multiplied by 0.5*0.25=0.125 368 | acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) 369 | 370 | # Convert the result to the output data type 371 | acc = acc.to(output_dtype) 372 | 373 | # Calculate pointers to the output matrix 374 | # Similar to input pointers, but for matrix C 375 | offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 376 | offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 377 | offs_cn = tl.max_contiguous(tl.multiple_of(offs_cn, BLOCK_SIZE_N), BLOCK_SIZE_N) 378 | c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) 379 | 380 | # Store the result to the output matrix 381 | # Use masks to handle boundary conditions 382 | tl.store(c_ptrs, acc, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) 383 | 384 | 385 | def test_kernel(): 386 | 387 | # Set up test parameters 388 | M, N, K = 128, 256, 512 389 | W_nbits = 4 390 | group_size = 64 391 | elements_per_sample = 8 # For 4-bit quantization, 8 elements fit in one int32 392 | 393 | # Create input matrices 394 | a = torch.randn(M, K, dtype=torch.float16, device='cuda') 395 | 396 | # Create 
quantized weights (normally this would come from a quantization process) 397 | # For testing, we'll create random data 398 | b_unpacked = torch.randint(-8, 7, (K, N), dtype=torch.int8, device='cuda') 399 | 400 | # Pack the weights into int32 401 | b_packed = torch.zeros((K // elements_per_sample, N), dtype=torch.int32, device='cuda') 402 | for i in range(elements_per_sample): 403 | b_packed |= (b_unpacked[i::elements_per_sample, :].to(torch.int32) & ((1 << W_nbits) - 1)) << (i * W_nbits) 404 | 405 | # Create scales and zeros for dequantization 406 | scales = torch.ones((K // group_size, N), dtype=torch.float16, device='cuda') 407 | zeros = torch.zeros((K // group_size, N), dtype=torch.float16, device='cuda') 408 | scales_a = torch.ones(M, dtype=torch.float16, device='cuda') 409 | 410 | # Output matrix 411 | c_triton = torch.zeros((M, N), dtype=torch.float16, device='cuda') 412 | 413 | # Calculate strides 414 | stride_am, stride_ak = a.stride(0), a.stride(1) 415 | stride_bk, stride_bn = b_packed.stride(0), b_packed.stride(1) 416 | stride_cm, stride_cn = c_triton.stride(0), c_triton.stride(1) 417 | stride_meta_g, stride_meta_n = scales.stride(0), scales.stride(1) 418 | 419 | # Define grid 420 | grid = lambda META: ( 421 | triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), 422 | ) 423 | 424 | # For verification, compute reference result using PyTorch 425 | # Dequantize the weights 426 | b_dequantized = torch.zeros((K, N), dtype=torch.float16, device='cuda') 427 | for g in range(K // group_size): 428 | start_idx = g * group_size 429 | end_idx = min((g + 1) * group_size, K) 430 | for i in range(start_idx, end_idx): 431 | element_idx = i % elements_per_sample 432 | packed_idx = i // elements_per_sample 433 | shift = element_idx * W_nbits 434 | mask = (1 << W_nbits) - 1 435 | b_dequantized[i] = ((b_packed[packed_idx] >> shift) & mask).to(torch.float16) 436 | # Apply scales 437 | b_dequantized[i] *= scales[g] 438 | 439 | # Compute reference result 
440 | c_ref = torch.matmul(a, b_dequantized) 441 | 442 | # Launch helper: the same kernel call is used for the correctness check and the benchmark 443 | def launch(): 444 | gemm_kernel[grid]( 445 | a_ptr=a, b_ptr=b_packed, c_ptr=c_triton, 446 | scales_ptr=scales, zeros_ptr=zeros, scales_a_ptr=scales_a, 447 | M=M, N=N, K=K, 448 | W_nbits=W_nbits, group_size=group_size, 449 | unpack_mask=(1 << W_nbits) - 1, elements_per_sample=elements_per_sample, 450 | stride_am=stride_am, stride_ak=stride_ak, 451 | stride_bk=stride_bk, stride_bn=stride_bn, 452 | stride_cm=stride_cm, stride_cn=stride_cn, 453 | stride_meta_g=stride_meta_g, stride_meta_n=stride_meta_n, 454 | input_dtype=tl.float16, output_dtype=tl.float16, acc_dtype=tl.float32, meta_dtype=tl.float16, 455 | channel_scale_mode=1, W_group_mode=1, zero_is_scalar=0, 456 | BLOCK_SIZE_M=16, BLOCK_SIZE_N=16, BLOCK_SIZE_K=64, GROUP_SIZE_M=8, 457 | A_load_order=1, meta_evict_policy='evict_last', data_contiguous=1, 458 | ) 459 | 460 | # Run the kernel once (this also triggers JIT compilation) before checking the result 461 | launch() 462 | torch.cuda.synchronize() 463 | 464 | # Check correctness (c_triton is only filled in after the kernel has run) 465 | max_diff = torch.max(torch.abs(c_ref - c_triton)) 466 | print(f"Max difference between PyTorch and Triton: {max_diff.item()}") 467 | 468 | # Benchmark: run the warmup iterations first, then time only the measured ones 469 | warmup = 25 470 | rep = 100 471 | for _ in range(warmup): 472 | launch() 473 | torch.cuda.synchronize() 474 | start = time.time() 475 | for _ in range(rep): 476 | launch() 477 | torch.cuda.synchronize() 478 | end = time.time() 479 | 480 | elapsed_time = (end - start) / rep 481 | print(f"Triton kernel time: {elapsed_time * 1000:.2f} ms") 482 | 483 | return c_triton, c_ref, max_diff.item() 484 | 485 | if __name__ == "__main__": 486 | test_kernel() 487 | -------------------------------------------------------------------------------- /annotated_examples/gemlite/gemm_splitK.py: -------------------------------------------------------------------------------- 1 | # Written by Dr.
Hicham Badri @Mobius Labs GmbH - 2024 2 | # Modified by MekkCyber - 2025 3 | #******************************************************* 4 | import triton 5 | import triton.language as tl 6 | 7 | @triton.jit 8 | def dequantize(b, scales, zeros, q_shift, meta_dtype, unpack_mask, elements_per_sample: tl.constexpr, W_group_mode: tl.constexpr, zero_is_scalar: tl.constexpr): 9 | """ 10 | Dequantizes packed integer values into floating point values using various quantization schemes. 11 | 12 | Args: 13 | b: Packed quantized values (typically int32) 14 | scales: Scaling factors for dequantization (per group or channel) 15 | zeros: Zero points for asymmetric quantization (per group or channel) 16 | q_shift: Bit shift amount for unpacking elements from packed format 17 | meta_dtype: Target data type for metadata operations 18 | unpack_mask: Bit mask for extracting individual elements (e.g., 0xF for 4-bit) 19 | elements_per_sample: Number of quantized elements packed into each storage unit 20 | W_group_mode: Quantization scheme to use (1-4) 21 | zero_is_scalar: Whether zero point is shared across all elements 22 | 23 | Returns: 24 | Dequantized tensor in floating point format 25 | """ 26 | # Step 1: Unpack the elements if they are packed (e.g., 8 4-bit values in one int32) 27 | if(elements_per_sample > 1): 28 | # Extract individual quantized values using bit shifting and masking 29 | # q_shift determines which element to extract based on position 30 | b = (b >> q_shift) & unpack_mask # int32 -> int32 31 | 32 | # Step 2: Apply the appropriate dequantization formula based on W_group_mode 33 | 34 | if(W_group_mode == 1): # Shift-only mode (zero-point subtraction) 35 | # Formula: dequantized = quantized - zero_point 36 | b = b.to(meta_dtype) - zeros 37 | 38 | if(W_group_mode == 2): # Scale-only mode (symmetric quantization) 39 | # Formula: dequantized = quantized * scale 40 | # Used when quantized values are centered around zero 41 | b = b.to(meta_dtype) * scales 42 | 43 | 

    if(W_group_mode == 3): # Scale and shift mode (asymmetric quantization)
        # Formula: dequantized = (quantized - zero_point) * scale
        if(zero_is_scalar):
            # When zero_point is shared across all elements (memory optimization)
            b = (b - zeros).to(meta_dtype) * scales
        else:
            # When each group has its own zero_point
            b = (b.to(meta_dtype) - zeros) * scales

    if(W_group_mode == 4): # Fused multiply-add mode
        # Formula: dequantized = quantized * scale + zero
        # Uses fused multiply-add for better performance
        # Note: in this mode, 'zeros' is actually an additive term, not a zero point
        b = tl.fma(b.to(meta_dtype), scales, zeros)

    return b

@triton.jit
def swizzle_tile(pid, M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    # Grouped ("swizzled") tile ordering: consecutive pids walk down a column of up to
    # GROUP_SIZE_M output tiles before moving to the next column, which improves L2 cache reuse
    grid_m = tl.cdiv(M, BLOCK_SIZE_M)
    grid_n = tl.cdiv(N, BLOCK_SIZE_N)
    width = GROUP_SIZE_M * grid_n
    group_id = pid // width
    group_size = tl.minimum(grid_m - group_id * GROUP_SIZE_M, GROUP_SIZE_M)
    pid_m = group_id * GROUP_SIZE_M + (pid % group_size)
    pid_n = (pid % width) // group_size
    return pid_m, pid_n

@triton.jit
def linear_tile(pid, M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    # Plain row-major tile ordering (GROUP_SIZE_M is unused here)
    pid_m = pid // tl.cdiv(N, BLOCK_SIZE_N)
    pid_n = pid % tl.cdiv(N, BLOCK_SIZE_N)
    return pid_m, pid_n


@triton.jit
def gemm_splitK_kernel(
    a_ptr, b_ptr, c_ptr,
    scales_ptr, zeros_ptr, scales_a_ptr,
    M, N, K,
    ######### Quant params #########
    W_nbits: tl.constexpr,
    group_size: tl.constexpr,
    unpack_mask: tl.constexpr,
    elements_per_sample: tl.constexpr,
    ######### Strides #########
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    stride_meta_g, stride_meta_n,
    ######### Dtypes #########
    input_dtype: tl.constexpr,
    output_dtype: tl.constexpr,
    acc_dtype: tl.constexpr,
    meta_dtype: tl.constexpr,
    ######### Meta-data mode #########
    channel_scale_mode: tl.constexpr,
    W_group_mode: tl.constexpr,
    zero_is_scalar: tl.constexpr,
    ######### tuning params #########
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr,
    A_load_order: tl.constexpr, meta_evict_policy: tl.constexpr, atomic_mode: tl.constexpr,
    data_contiguous: tl.constexpr,
):
    """
    Quantized GEMM with split-K parallelism for C = matmul(A, dequantize(B, scales, zeros))

    A is of shape (M, K): float16 or bfloat16
    B is of shape (K//elements_per_sample, N): int32 as a packed matrix
    C is of shape (M, N): float16 or bfloat16 (same dtype as input A)
    scales is of shape (K//group_size, N) or (1, N): meta_dtype
    zeros is of shape (K//group_size, N) or (1, 1): meta_dtype

    Requirements:
        - BLOCK_SIZE_M must be >= 16
        - BLOCK_SIZE_K * SPLIT_K must be <= group_size
        - K must be a multiple of BLOCK_SIZE_K * SPLIT_K (the k-loop does not mask loads along K)

    Based on the split-K dequantization GEMM implementation from:
    https://github.com/foundation-model-stack/foundation-model-stack/blob/main/fms/models/llm/kernels/triton/splitk_dequant_gemm.py
    """

    # It's recommended to understand the standard GEMM implementation in gemlite/gemm.py first: this split-K
    # version builds upon it, so we will not delve into the details of the standard GEMM implementation here.
    #
    # Unlike standard GEMM, where each thread block computes its output tile over the full K dimension,
    # in split-K GEMM we have a 2D grid of thread blocks where:
    #   - pid (x-dim): Each thread block along the x-axis is responsible for a unique output tile
    #   - pid_k (y-dim): Multiple thread blocks along the y-axis collaborate on the same output tile,
    #     each computing a partial result and using atomic adds to accumulate it into the final output
    pid = tl.program_id(axis=0)    # Determines which output tile to compute
    pid_k = tl.program_id(axis=1)  # Determines which K-slice to process for this output tile

    pid_m, pid_n = swizzle_tile(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M)

    # Number of k-loop iterations per program: the SPLIT_K programs sharing an output tile
    # interleave over K, each consuming BLOCK_SIZE_K elements per iteration
    num_pid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)  # this program starts at the pid_k-th K-block

    # tl.multiple_of / tl.max_contiguous are compiler hints that the offsets along the
    # contiguous dimension are aligned and contiguous, which enables vectorized loads
    offs_am = offs_m
    offs_ak = tl.max_contiguous(tl.multiple_of(offs_k, BLOCK_SIZE_K), BLOCK_SIZE_K)

    if(data_contiguous):
        offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK_SIZE_N), BLOCK_SIZE_N)
        offs_bk = offs_k
    else:
        offs_bn = offs_n
        offs_bk = tl.max_contiguous(tl.multiple_of(offs_k, BLOCK_SIZE_K), BLOCK_SIZE_K)

    b_ptrs = b_ptr + ((offs_bk[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn)
    q_shift = ((offs_bk % elements_per_sample) * W_nbits).to(tl.int32)[:, None]

    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak)
    a_mask = offs_am[:, None] < M

    scales_ptrs = scales_ptr + offs_bn[None, :] * stride_meta_n
    zeros_ptrs = zeros_ptr + offs_bn[None, :] * stride_meta_n

    # Fraction of a quantization group covered by one BLOCK_SIZE_K tile
    stride_mul: tl.constexpr = BLOCK_SIZE_K / group_size

    # BLOCK_SIZE_K_U: How much to advance pointers in matrix A (unpacked matrix) after each iteration.
    # Each program handles every SPLIT_K-th block of BLOCK_SIZE_K elements along K, so it must skip
    # over the SPLIT_K - 1 blocks handled by the other programs: the step is BLOCK_SIZE_K * SPLIT_K.
    # This represents the stride in the K dimension for matrix A
    BLOCK_SIZE_K_U: tl.constexpr = BLOCK_SIZE_K * SPLIT_K

    # BLOCK_SIZE_K_P: How much to advance pointers in matrix B (packed matrix)
    # Since B is packed with elements_per_sample values per int32,
    # we divide BLOCK_SIZE_K by elements_per_sample to get the number of int32s,
    # then multiply by SPLIT_K for the same reason as above.
    # This represents the stride in the K dimension for packed matrix B
    BLOCK_SIZE_K_P: tl.constexpr = (BLOCK_SIZE_K // elements_per_sample) * SPLIT_K

    if(zero_is_scalar):
        zero_scalar = tl.load(zeros_ptr, eviction_policy='evict_last')

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)

    for k in range(num_pid_k):

        if(A_load_order == 0):
            a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy='evict_last')

        b = tl.load(b_ptrs, eviction_policy='evict_first')

        if(A_load_order == 1):
            a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy='evict_last')

        # Compute which quantization group the current K-block belongs to
        if(W_group_mode > 0):
            # Important: BLOCK_SIZE_K must not exceed group_size. Only one row of scales/zeros is
            # loaded per iteration, so every BLOCK_SIZE_K tile has to fall inside a single group.
            #
            # Example with proper sizing:
            # - BLOCK_SIZE_K = 8 (how many K elements we process per block)
            # - group_size = 32 (we quantize weights in groups of 32)
            # - SPLIT_K = 4 (we split the K dimension across 4 blocks)
            # - stride_mul = BLOCK_SIZE_K/group_size = 8/32 = 0.25 (fraction of a group in one block)
            #
            # k: outer loop counter (0, 1, ..., num_pid_k - 1)
            # pid_k: which split the thread block is responsible for (0 to SPLIT_K-1)
            #
            # For example, if k=2, SPLIT_K=4, pid_k=1:
            #   k * SPLIT_K + pid_k = 2 * 4 + 1 = 9   (the global index of this K-block)
            #
            # Multiply by stride_mul to convert from blocks to groups:
            #   9 * 0.25 = 2.25, which truncates to k_m = 2, i.e. the third group (group index 2).
            #   Check: block 9 covers K elements 72..79, and 72 // 32 = 2.
            k_m = ((k * SPLIT_K + pid_k) * stride_mul).to(tl.int32)

        if(W_group_mode >= 2): #[2, 3, 4]
            scales = tl.load(scales_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy)
        else:
            scales = None

        if(W_group_mode == 1 or W_group_mode >= 3): #[1, 3, 4]
            if(zero_is_scalar):
                zeros = zero_scalar
            else:
                zeros = tl.load(zeros_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy)
        else:
            zeros = None

        if(A_load_order == 2): #Mid load
            a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy='evict_last')

        b = dequantize(b, scales, zeros, q_shift, meta_dtype, unpack_mask, elements_per_sample, W_group_mode, zero_is_scalar)

        if(A_load_order == 3):
            a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy='evict_last')

        acc = tl.dot(a, b.to(input_dtype), acc=acc, out_dtype=acc_dtype, input_precision="tf32")

        # Advance pointers for the next iteration of the k-loop
        a_ptrs += BLOCK_SIZE_K_U * stride_ak
        b_ptrs += BLOCK_SIZE_K_P * stride_bk

    # Optional per-channel rescaling after the k-loop:
    #   1: per-column scales of B, 2: per-row scales of A, 3: both
    if(channel_scale_mode == 1):
        scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy)
        acc = acc.to(meta_dtype) * scales_b[None, :]

    if(channel_scale_mode == 2):
        scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy)
        scales_b = tl.full((BLOCK_SIZE_N,), value=1, dtype=meta_dtype)
        acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :])

    if(channel_scale_mode == 3):
        scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy)
        scales_b = tl.load(scales_ptr + offs_bn,
                           mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy)
        acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :])

    acc = acc.to(output_dtype)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_cn = tl.max_contiguous(tl.multiple_of(offs_cn, BLOCK_SIZE_N), BLOCK_SIZE_N)
    c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn)

    # All SPLIT_K programs that share this pid compute partial results for the same (M, N) output tile.
    # When SPLIT_K > 1 they write to the same memory locations, so the partial results are accumulated
    # with atomic_add, which applies concurrent updates to the same address without losing any of them.
    # Because atomic_add adds to whatever C already holds, C must be zero-initialized before the launch.
    if(SPLIT_K > 1):
        tl.atomic_add(c_ptrs, acc, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N), sem=atomic_mode) #release / relaxed
    else:
        # When SPLIT_K = 1, each output location is computed by exactly one block,
        # so a plain store suffices instead of an atomic add (this is the same as the standard GEMM)
        tl.store(c_ptrs, acc, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))

--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MekkCyber/TritonAcademy/7a3bf78b72406edbeb9fb6c4be07d7e694061e3f/assets/logo.png
--------------------------------------------------------------------------------
/backprop_math/cross_entropy.md:
--------------------------------------------------------------------------------
# Cross Entropy

## Definition of Cross Entropy Loss
Cross Entropy Loss for a classification problem is defined as:

$$CE(x, class) = -\log(\text{softmax}(x)[class])$$

Where softmax is defined as:

$$\text{softmax}(x)[i] = \frac{e^{x_i}}{\sum_j e^{x_j}}$$

## Expanded form of Cross Entropy Loss

$$CE(x, class) = -\log\left(\frac{e^{x_{class}}}{\sum_i e^{x_i}}\right)$$

$$CE(x, class) = -x_{class} + \log\left(\sum_i e^{x_i}\right)$$

Let's denote $z = \log\left(\sum_i e^{x_i}\right)$, which is the LogSumExp function.

## Compute gradients with respect to each logit

### Case 1: For the correct class (i = class)

$$\frac{\partial CE}{\partial x_{class}} = \frac{\partial}{\partial x_{class}}(-x_{class} + z)$$

$$\frac{\partial CE}{\partial x_{class}} = -1 + \frac{\partial z}{\partial x_{class}}$$

For the LogSumExp term:

$$\frac{\partial z}{\partial x_{class}} = \frac{\partial}{\partial x_{class}}\log\left(\sum_i e^{x_i}\right)$$

$$\frac{\partial z}{\partial x_{class}} = \frac{1}{\sum_i e^{x_i}} \cdot \frac{\partial}{\partial x_{class}}\left(\sum_i e^{x_i}\right)$$

$$\frac{\partial z}{\partial x_{class}} = \frac{1}{\sum_i e^{x_i}} \cdot e^{x_{class}}$$

$$\frac{\partial z}{\partial x_{class}} = \frac{e^{x_{class}}}{\sum_i e^{x_i}} = \text{softmax}(x)[class]$$

Substituting back:

$$\frac{\partial CE}{\partial x_{class}} = -1 + \text{softmax}(x)[class]$$

### Case 2: For other classes (i ≠ class)

$$\frac{\partial CE}{\partial x_i} = \frac{\partial}{\partial x_i}(-x_{class} + z)$$

$$\frac{\partial CE}{\partial x_i} = 0 + \frac{\partial z}{\partial x_i}$$

As before, the LogSumExp term gives:

$$\frac{\partial z}{\partial x_i} = \frac{e^{x_i}}{\sum_j e^{x_j}} = \text{softmax}(x)[i]$$

In summary, for the correct class (i = class):

$$\frac{\partial CE}{\partial x_{class}} = -1 + \text{softmax}(x)[class]$$

For other classes (i ≠ class):

$$\frac{\partial CE}{\partial x_i} = \text{softmax}(x)[i]$$

## Generalize to a single formula
We can combine both cases into one formula:

$$\frac{\partial CE}{\partial x_i} = \text{softmax}(x)[i] - \mathbf{1}_{i=class}$$

Where $\mathbf{1}_{i=class}$ is an indicator function that equals 1 when i is the correct class and 0 otherwise.

--------------------------------------------------------------------------------
/backprop_math/geglu.md:
--------------------------------------------------------------------------------
# Derivative of GeLU

## Exact Derivative

Starting with:

$$f(x) = \frac{1}{2} \cdot x \cdot \left(1 + \text{erf}\left(\frac{1}{\sqrt{2}} \cdot x\right)\right)$$

Using the product rule:

$$\frac{d}{dx}[u(x) \cdot v(x)] = u'(x) \cdot v(x) + u(x) \cdot v'(x)$$

Let:

- $u(x) = \frac{1}{2} \cdot x$
- $v(x) = 1 + \text{erf}\left(\frac{1}{\sqrt{2}} \cdot x\right)$

Step 1: Find $u'(x)$

$$u'(x) = \frac{d}{dx}\left[\frac{1}{2} \cdot x\right] = \frac{1}{2}$$

Step 2: Find $v'(x)$
We need the chain rule here.
The derivative of $\text{erf}(x)$ is:

$$\frac{d}{dx}\text{erf}(x) = \frac{2}{\sqrt{\pi}} \cdot e^{-x^2}$$

$$v'(x) = \frac{d}{dx}\left[1 + \text{erf}\left(\frac{1}{\sqrt{2}} \cdot x\right)\right] = \frac{d}{dx}\left[\text{erf}\left(\frac{1}{\sqrt{2}} \cdot x\right)\right]$$

Using the chain rule with $g(x) = \frac{1}{\sqrt{2}} \cdot x$:

$$v'(x) = \frac{2}{\sqrt{\pi}} \cdot e^{-\left(\frac{1}{\sqrt{2}} \cdot x\right)^2} \cdot \frac{d}{dx}\left[\frac{1}{\sqrt{2}} \cdot x\right]$$

$$v'(x) = \frac{2}{\sqrt{\pi}} \cdot e^{-\frac{x^2}{2}} \cdot \frac{1}{\sqrt{2}}$$

$$v'(x) = \frac{2}{\sqrt{\pi}} \cdot \frac{1}{\sqrt{2}} \cdot e^{-\frac{x^2}{2}}$$

$$v'(x) = \frac{2}{\sqrt{2\pi}} \cdot e^{-\frac{x^2}{2}}$$

Step 3: Apply the product rule

$$\frac{df}{dx} = u'(x) \cdot v(x) + u(x) \cdot v'(x)$$

$$\frac{df}{dx} = \frac{1}{2} \cdot \left(1 + \text{erf}\left(\frac{1}{\sqrt{2}} \cdot x\right)\right) + \frac{1}{2} \cdot x \cdot \frac{2}{\sqrt{2\pi}} \cdot e^{-\frac{x^2}{2}}$$

Simplifying the second term gives our final result:

$$\frac{df}{dx} = \frac{1}{2} \cdot \left(1 + \text{erf}\left(\frac{1}{\sqrt{2}} \cdot x\right)\right) + \frac{x}{\sqrt{2\pi}} \cdot e^{-\frac{x^2}{2}}$$


## Approximate Derivative

Starting with:

$$f(x) = 0.5 \cdot x \cdot (1 + \tanh(\sqrt{\frac{2}{\pi}} \cdot x \cdot (1 + 0.044715 \cdot x^2)))$$

For simplicity, let's denote:

$$z(x) = \sqrt{\frac{2}{\pi}} \cdot x \cdot (1 + 0.044715 \cdot x^2) = x \cdot (a + b \cdot x^2)$$

where $a = \sqrt{\frac{2}{\pi}}$ and $b = 0.044715 \cdot \sqrt{\frac{2}{\pi}}$, and

$$v(x) = 1 + \tanh(z(x))$$

Then:

$$f(x) = 0.5 \cdot x \cdot (1 + \tanh(z(x))) = 0.5 \cdot x \cdot v(x)$$

Using the product rule:

$$\frac{d}{dx}[u(x) \cdot v(x)] = u'(x) \cdot v(x) + u(x) \cdot v'(x)$$

Let:
- $u(x) = 0.5 \cdot x$
- $v(x) = 1 + \tanh(z(x))$

Step 1: Find $u'(x)$

$$u'(x) = 0.5$$

Step 2: Find $v'(x)$
Using the chain rule and the fact that the derivative of $\tanh(x)$ is $1 - \tanh^2(x)$:

$$v'(x) = (1 - \tanh^2(z(x))) \cdot z'(x)$$

The derivative of $z(x)$:

$$z'(x) = a + 3b \cdot x^2$$

Step 3: Rewrite $1 - \tanh^2(z(x))$ in terms of $v(x)$:

$$1 - \tanh^2(z(x)) = (1 - \tanh(z(x)))(1 + \tanh(z(x)))$$

$$1 - \tanh^2(z(x)) = (2 - (1 + \tanh(z(x))))(1 + \tanh(z(x)))$$

$$1 - \tanh^2(z(x)) = (2- v(x))v(x)$$

Substituting this identity into $v'(x)$:

$$v'(x) = (2- v(x))v(x) \cdot z'(x)$$

Step 4: Apply the product rule for the complete derivative:

$$\frac{df}{dx} = 0.5 \cdot (1 + \tanh(z(x))) + 0.5 \cdot x \cdot (2- v(x))v(x) \cdot z'(x)$$

Substituting $z'(x) = a + 3b \cdot x^2$:

$$\frac{df}{dx} = 0.5 \cdot v(x) + 0.5 \cdot x \cdot (2 - v(x)) \cdot v(x) \cdot (a + 3b \cdot x^2)$$

$$\frac{df}{dx} = 0.5 \cdot v(x) \cdot \left[1 + x \cdot (2 - v(x)) \cdot (a + 3b \cdot x^2)\right]$$

--------------------------------------------------------------------------------
/backprop_math/layernorm.md:
--------------------------------------------------------------------------------
# Layer Normalization Backward Pass Derivation

## Forward Pass

First, let's establish the forward pass equations:

$$\mu = \frac{1}{n}\sum_{i=1}^n X_i$$

$$\sigma = \sqrt{\frac{1}{n}\sum_{i=1}^n(X_i-\mu)^2}$$

$$\hat{X} = \frac{X - \mu}{\sigma}$$

$$Y = \gamma \odot \hat{X} + \beta$$

Where:
- $X$ is the input vector (one row of the input tensor, normalized over its $n$ features)
- $\mu$ is the mean (scalar)
- $\sigma$ is the standard deviation (scalar)
- $\hat{X}$ is the normalized input
- $\gamma$ and $\beta$ are learnable parameters
- $n$ is the feature dimension
- $\odot$ represents element-wise multiplication

## Backward Pass Derivation

We'll derive $\nabla_X = \frac{\partial \mathcal{L}}{\partial X}$ (the gradient of the loss $\mathcal{L}$ with respect to the input) given $\nabla_Y = \frac{\partial \mathcal{L}}{\partial Y}$ (the gradient flowing back from the output).

### Step 1: Gradient from $Y$ to $\hat{X}$

Starting with:

$$Y = \gamma \odot \hat{X} + \beta$$

Taking the derivative with respect to $\hat{X}$:

$$\nabla_{\hat{X}} = \frac{\partial \mathcal{L}}{\partial \hat{X}} = \frac{\partial \mathcal{L}}{\partial Y} \cdot \frac{\partial Y}{\partial \hat{X}} = \nabla_Y \odot \gamma$$

This means each element of the gradient with respect to $\hat{X}$ is the corresponding element of $\nabla_Y$ multiplied by the corresponding element of $\gamma$.

### Step 2: Gradient from $\hat{X}$ to $X$

Now we need to compute $\nabla_X$ given $\nabla_{\hat{X}}$, using the chain rule again:

$$\nabla_X = \left(\frac{\partial \hat{X}}{\partial X}\right)^{\top} \nabla_{\hat{X}}$$

where $\frac{\partial \hat{X}}{\partial X}$ is the $n \times n$ Jacobian matrix with entries $\frac{\partial \hat{X}_i}{\partial X_j}$. The normalized value is:

$$\hat{X} = \frac{X - \mu}{\sigma}$$

We need to account for how changes in $X$ affect $\hat{X}$ both directly and through $\mu$ and $\sigma$.

#### Component 1: Direct effect of $X$

For the direct effect (ignoring effects through $\mu$ and $\sigma$):

$$\frac{\partial \hat{X}}{\partial X}_{\text{direct}} = \frac{1}{\sigma}\mathbf{I}$$

Where $\mathbf{I}$ is the identity matrix.

#### Component 2: Effect through $\mu$

The mean $\mu = \frac{1}{n}\sum_{i=1}^n X_i$ depends on all elements of $X$.
For any element $X_j$:

$$\frac{\partial \mu}{\partial X_j} = \frac{1}{n}$$

The effect on $\hat{X}_i$ through $\mu$ is:

$$\frac{\partial \hat{X}_i}{\partial \mu} = -\frac{1}{\sigma}$$

Combining these:

$$\frac{\partial \hat{X}_i}{\partial X_j}_{\text{via }\mu} = \frac{\partial \hat{X}_i}{\partial \mu} \cdot \frac{\partial \mu}{\partial X_j} = -\frac{1}{\sigma} \cdot \frac{1}{n} = -\frac{1}{n\sigma}$$

#### Component 3: Effect through $\sigma$

The standard deviation $\sigma = \sqrt{\frac{1}{n}\sum_{i=1}^n(X_i-\mu)^2}$ also depends on all elements of $X$.

First, let's compute $\frac{\partial \sigma}{\partial X_j}$. Differentiating $\sigma^2 = \frac{1}{n}\sum_{i=1}^n(X_i-\mu)^2$ with respect to $X_j$, the chain rule on the left side gives:

$$2\sigma \cdot \frac{\partial \sigma}{\partial X_j} = \frac{\partial}{\partial X_j}\left(\frac{1}{n}\sum_{i=1}^n(X_i-\mu)^2\right)$$

Dividing both sides by $2\sigma$:

$$\frac{\partial \sigma}{\partial X_j} = \frac{1}{2\sigma} \cdot \frac{\partial}{\partial X_j}\left(\frac{1}{n}\sum_{i=1}^n(X_i-\mu)^2\right)$$

We need to account for both the direct effect on $(X_j-\mu)^2$ and the indirect effect through $\mu$ on all terms $(X_i-\mu)^2$.
The direct effect when $i = j$ is:

$$\frac{\partial}{\partial X_j}(X_j-\mu)^2 = 2(X_j-\mu) \cdot \left(1 - \frac{\partial \mu}{\partial X_j}\right) = 2(X_j-\mu) \cdot \left(1 - \frac{1}{n}\right)$$

The indirect effect through $\mu$ for each $i \neq j$ is:

$$\frac{\partial}{\partial X_j}(X_i-\mu)^2 = 2(X_i-\mu) \cdot \left(- \frac{\partial \mu}{\partial X_j}\right) = -2(X_i-\mu) \cdot \frac{1}{n}$$

Combining these:

$$\frac{\partial \sigma}{\partial X_j} = \frac{1}{2\sigma} \cdot \frac{1}{n} \cdot \left(2(X_j-\mu)\left(1-\frac{1}{n}\right) - \sum_{i \neq j}2(X_i-\mu)\frac{1}{n}\right)$$

Moving the $-\frac{2}{n}(X_j-\mu)$ term into the sum, the bracket equals $2(X_j-\mu) - \frac{2}{n}\sum_{i=1}^n (X_i-\mu)$. Since $\mu = \frac{1}{n}\sum_{i=1}^n X_i$, we have $\sum_{i=1}^n (X_i-\mu) = 0$, so this simplifies to:

$$\frac{\partial \sigma}{\partial X_j} = \frac{1}{n\sigma}(X_j-\mu)$$

Or in terms of $\hat{X}$:

$$\frac{\partial \sigma}{\partial X_j} = \frac{1}{n}\hat{X}_j$$

Now, the effect on $\hat{X}_i$ through $\sigma$ is:

$$\frac{\partial \hat{X}_i}{\partial \sigma} = -\frac{X_i-\mu}{\sigma^2} = -\frac{\hat{X}_i}{\sigma}$$

Combining these:

$$\frac{\partial \hat{X}_i}{\partial X_j}_{\text{via }\sigma} = \frac{\partial \hat{X}_i}{\partial \sigma} \cdot \frac{\partial \sigma}{\partial X_j} = -\frac{\hat{X}_i}{\sigma} \cdot \frac{1}{n}\hat{X}_j = -\frac{1}{n\sigma}\hat{X}_i\hat{X}_j$$

#### Combining All Components

Adding all three components gives the full Jacobian:

$$\frac{\partial \hat{X}_i}{\partial X_j} = \frac{1}{\sigma}\delta_{ij} - \frac{1}{n\sigma} - \frac{1}{n\sigma}\hat{X}_i\hat{X}_j$$

where $\delta_{ij}$ equals 1 when $i = j$ and 0 otherwise. This Jacobian is symmetric in $i$ and $j$, so applying its transpose to $\nabla_{\hat{X}}$ gives:

$$\nabla_{X_i} = \sum_{j=1}^n \frac{\partial \hat{X}_j}{\partial X_i}\nabla_{\hat{X}_j}$$

$$= \frac{1}{\sigma}\nabla_{\hat{X}_i} - \frac{1}{n\sigma}\sum_{j=1}^n \nabla_{\hat{X}_j} - \frac{\hat{X}_i}{n\sigma}\sum_{j=1}^n \nabla_{\hat{X}_j}\hat{X}_j$$

In vector notation:

$$\nabla_X = \frac{1}{\sigma}\nabla_{\hat{X}} - \frac{1}{n\sigma}\left(\sum_{i=1}^n \nabla_{\hat{X}_i}\right)\mathbf{1} - \frac{1}{n\sigma}\left(\sum_{i=1}^n \nabla_{\hat{X}_i}\hat{X}_i\right)\hat{X}$$

Where $\mathbf{1}$ is a vector of ones.

Substituting $\nabla_{\hat{X}} = \nabla_Y \odot \gamma$:

$$\nabla_X = \frac{1}{\sigma}\left(\nabla_Y \odot \gamma - \frac{1}{n}\left(\sum_{i=1}^n (\nabla_Y \odot \gamma)_i\right)\mathbf{1} - \frac{1}{n}\left(\sum_{i=1}^n(\nabla_Y \odot \gamma \odot \hat{X})_i\right)\hat{X}\right)$$

Writing the sums as dot products ($a \cdot b = \sum_i a_i b_i$), this can be expressed more compactly as:

$$\nabla_X = \frac{1}{\sigma}\left(\nabla_Y \odot \gamma - \frac{1}{n}\left(\nabla_Y \cdot \gamma\right)\mathbf{1} - \frac{1}{n}\left(\hat{X} \cdot (\nabla_Y \odot \gamma)\right)\hat{X}\right)$$

This is the complete formula for the backward pass of layer normalization with respect to the input $X$.

--------------------------------------------------------------------------------
/backprop_math/swiglu.md:
--------------------------------------------------------------------------------
# Derivative of SwiGLU

SwiGLU multiplies a linear projection by a Swish (SiLU)-activated projection, so its backward pass needs the derivative of the Swish/SiLU function:

$$f(x) = \frac{x}{1 + e^{-x}}$$

Find $\frac{df}{dx}$ using the quotient rule.
For $f(x) = \frac{u(x)}{v(x)}$, the quotient rule gives us:

$$\frac{df}{dx} = \frac{u'(x) \cdot v(x) - u(x) \cdot v'(x)}{v(x)^2}$$

Where:
- $u(x) = x$
- $v(x) = 1 + e^{-x}$

Calculate $u'(x)$:

$$u'(x) = \frac{d}{dx}[x] = 1$$

Calculate $v'(x)$:

$$v'(x) = \frac{d}{dx}[1 + e^{-x}] = -e^{-x}$$

Applying the quotient rule (note that $-u(x) \cdot v'(x) = +x \cdot e^{-x}$):

$$\frac{df}{dx} = \frac{1 \cdot (1 + e^{-x}) + x \cdot e^{-x}}{(1 + e^{-x})^2}$$

$$\frac{df}{dx} = \frac{1 + e^{-x} + x \cdot e^{-x}}{(1 + e^{-x})^2}$$

Splitting the numerator as $(1 + e^{-x}) + x \cdot (e^{-x} + 1) - x$:

$$\frac{df}{dx} = \frac{1}{1 + e^{-x}} + \frac{x \cdot (e^{-x} + 1)}{(1 + e^{-x})^2} - \frac{x}{(1 + e^{-x})^2}$$

We can also express this using the sigmoid function.
Since $\sigma(x) = \frac{1}{1 + e^{-x}}$, the three terms above are $\sigma(x)$, $x \cdot \sigma(x)$, and $x \cdot \sigma(x)^2$, so we can write:

$$\frac{df}{dx} = \sigma(x) + x \cdot \sigma(x) - x \cdot \sigma(x)^2 = \sigma(x) \cdot (1 + x \cdot (1 - \sigma(x)))$$

--------------------------------------------------------------------------------
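The closed-form derivatives in `backprop_math/` (SiLU, exact and tanh-approximate GeLU, cross entropy, and the LayerNorm input gradient) can be sanity-checked against central finite differences. The sketch below is an illustration under stated assumptions, not code from this repository: it uses plain Python floats and lists with the standard `math` module instead of Triton or PyTorch, and all of its function names (`central_diff`, `partial_diff`, `silu_grad`, `gelu_tanh_grad`, `layernorm_grad_x`, and so on) are hypothetical. The LayerNorm check differentiates the scalar loss `sum(dy[i] * Y[i])`, whose gradient with respect to `Y` is exactly `dy`, and, like the derivation, it omits the epsilon that practical kernels add to the variance.

```python
import math


def central_diff(f, x, h=1e-6):
    """Central finite-difference estimate of f'(x) for a scalar function f."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


def partial_diff(f, vec, j, h=1e-6):
    """Central finite-difference estimate of d f(vec) / d vec[j]."""
    up, down = list(vec), list(vec)
    up[j] += h
    down[j] -= h
    return (f(up) - f(down)) / (2.0 * h)


# swiglu.md: Swish/SiLU f(x) = x * sigmoid(x), f'(x) = s * (1 + x * (1 - s))
def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def silu(x):
    return x * sigmoid(x)

def silu_grad(x):
    s = sigmoid(x)
    return s * (1.0 + x * (1.0 - s))


# geglu.md, exact form: f(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
def gelu(x):
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_grad(x):
    return (0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
            + x / math.sqrt(2.0 * math.pi) * math.exp(-0.5 * x * x))


# geglu.md, tanh approximation: z(x) = x * (a + b * x^2), v = 1 + tanh(z)
A = math.sqrt(2.0 / math.pi)
B = 0.044715 * A

def gelu_tanh(x):
    return 0.5 * x * (1.0 + math.tanh(x * (A + B * x * x)))

def gelu_tanh_grad(x):
    v = 1.0 + math.tanh(x * (A + B * x * x))
    return 0.5 * v * (1.0 + x * (2.0 - v) * (A + 3.0 * B * x * x))


# cross_entropy.md: CE(x, c) = -x_c + logsumexp(x), grad = softmax(x) - onehot(c)
def cross_entropy(x, c):
    m = max(x)  # shift by the max so exp() cannot overflow
    return -x[c] + m + math.log(sum(math.exp(xi - m) for xi in x))

def cross_entropy_grad(x, c):
    m = max(x)
    exps = [math.exp(xi - m) for xi in x]
    total = sum(exps)
    return [e / total - (1.0 if i == c else 0.0) for i, e in enumerate(exps)]


# layernorm.md: dX = (g - mean(g) - x_hat * mean(g * x_hat)) / sigma, with g = dY * gamma
def _mean_std(x):
    mu = sum(x) / len(x)
    return mu, math.sqrt(sum((xi - mu) ** 2 for xi in x) / len(x))

def layernorm(x, gamma, beta):
    mu, sigma = _mean_std(x)
    return [w * (xi - mu) / sigma + b for xi, w, b in zip(x, gamma, beta)]

def layernorm_grad_x(x, gamma, dy):
    n = len(x)
    mu, sigma = _mean_std(x)
    x_hat = [(xi - mu) / sigma for xi in x]
    g = [d * w for d, w in zip(dy, gamma)]  # gradient w.r.t. x_hat
    mean_g = sum(g) / n
    mean_gx = sum(gi * xh for gi, xh in zip(g, x_hat)) / n
    return [(gi - mean_g - xh * mean_gx) / sigma for gi, xh in zip(g, x_hat)]


if __name__ == "__main__":
    for x in (-3.0, -0.5, 0.0, 1.2, 4.0):
        for name, f, df in (("silu", silu, silu_grad), ("gelu", gelu, gelu_grad),
                            ("gelu_tanh", gelu_tanh, gelu_tanh_grad)):
            print(f"{name:9s} x={x:+.1f} analytic={df(x):.6f} numeric={central_diff(f, x):.6f}")

    logits, label = [1.0, -2.0, 0.5, 3.0], 2
    print("CE analytic:", cross_entropy_grad(logits, label))
    print("CE numeric: ", [partial_diff(lambda v: cross_entropy(v, label), logits, j) for j in range(4)])

    x, gamma, beta, dy = [0.3, -1.2, 2.5, 0.7], [1.5, -0.7, 0.9, 2.0], [0.1, 0.2, -0.3, 0.0], [0.2, -1.0, 0.4, 0.8]
    loss = lambda v: sum(d * y for d, y in zip(dy, layernorm(v, gamma, beta)))
    print("LN analytic:", layernorm_grad_x(x, gamma, dy))
    print("LN numeric: ", [partial_diff(loss, x, j) for j in range(4)])
```

Running the script prints each analytic derivative next to its finite-difference estimate; with `h = 1e-6` the two should agree to several decimal places, so a wrong sign or a missing factor in one of the closed forms would show up as a visible mismatch.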