├── Checkpoints
│   └── chekpoint.txt
├── Data
│   └── input.txt
├── README.md
└── triton_nanoGPT.py

/Checkpoints/chekpoint.txt:
--------------------------------------------------------------------------------
.pth
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Triton-Accelerated NanoGPT

The WHY behind this ordeal

After practicing Triton for about two weeks, I challenged myself to implement custom Triton kernels for Karpathy's nanoGPT. Quite an ordeal it was, but I got something working. It is not perfect yet, but it is getting there :) Contributions are welcome.

## Kernels

Lightweight custom Triton kernels for softmax, layer normalization, cross-entropy loss, and GELU activation, each wrapped as a drop-in PyTorch module or autograd function.

## Training

GPU-aware training loop with learning-rate scheduling, gradient clipping, and validation-loss tracking.

- **Setup**: Requires a GPU! Ensure you have PyTorch and Triton installed. GPU poor? I am too; I used a free T4 on Google Colab.

- **Data**: Uses the Tiny Shakespeare dataset by default. It is downloaded automatically if not present.

- **Training**:
```bash
python triton_nanoGPT.py
```
This trains for 500 epochs of 100 batches each, saves the best checkpoint to `Checkpoints/nanoGPT_cpkt.pth`, and samples from the trained model (a sketch for re-sampling from the saved checkpoint follows below).
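
To sample again later from a saved checkpoint without retraining, a minimal sketch along these lines should work (this assumes the default hyperparameters in `triton_nanoGPT.py` and a CUDA device; importing the script also downloads the dataset and builds the vocabulary as a side effect):

```python
import torch
from triton_nanoGPT import NanoGPT, encode, decode, vocab_size

# must match the hyperparameters used for training
model = NanoGPT(vocab_size=vocab_size, dim=384, num_heads=6, num_layers=6,
                seq_length=256, dropout=0.1).cuda()
model.load_state_dict(torch.load("Checkpoints/nanoGPT_cpkt.pth", weights_only=True))
model.eval()

ids = encode("Once upon").unsqueeze(0).cuda()
with torch.no_grad():
    for _ in range(200):                                  # stay under seq_length
        logits = model(ids)
        probs = torch.softmax(logits[:, -1, :], dim=-1)
        ids = torch.cat([ids, torch.multinomial(probs, 1)], dim=1)
print(decode(ids[0].cpu()))
```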

## License

MIT
--------------------------------------------------------------------------------
/triton_nanoGPT.py:
--------------------------------------------------------------------------------
import os
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
import math
import time

# ================================================================
# The WHY behind this ordeal?
# After practicing Triton for about 2 weeks, I attempted to
# implement custom Triton kernels for Karpathy's nanoGPT.
# Still not perfect; contributions are welcome :)
# ================================================================

# -----------------------------
# Data Preprocessing
# -----------------------------

def dataset(url, filepath):
    if not os.path.exists(filepath):
        print(f"Downloading dataset from {url}...")
        response = requests.get(url)
        response.raise_for_status()  # fail loudly on a bad download
        with open(filepath, 'wb') as f:
            f.write(response.content)
        print(f"Dataset downloaded and saved to {filepath}.")
    else:
        print(f"Dataset already exists at {filepath}.")

url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
filepath = "input.txt"

dataset(url, filepath)
with open('input.txt', 'r') as f:
    text = f.read()

chars = sorted(list(set(text)))
vocab_size = len(chars)
print(f"Vocabulary size: {vocab_size}")

stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }

def encode(text):
    return torch.tensor([stoi[c] for c in text], dtype=torch.long)

def decode(indices):
    return ''.join([itos[i.item()] for i in indices])

data = encode(text)

n = int(0.9 * len(data))
train_data = data[:n].cuda()
test_data = data[n:].cuda()

print(f"Training data size: {train_data.numel()} characters")
print(f"Testing data size: {test_data.numel()} characters")

# -----------------------------
# Triton Kernels
# -----------------------------

@triton.jit
def softmax_kernel(
    output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
    BLOCK_SIZE: tl.constexpr
):
    # one program per row; subtract the row max for numerical stability
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    input_row_ptr = input_ptr + row_idx * input_row_stride + col_offsets
    output_row_ptr = output_ptr + row_idx * output_row_stride + col_offsets

    logits = tl.load(input_row_ptr, mask=mask, other=float('-inf'))
    max_logits = tl.max(logits, axis=0)
    logits = logits - max_logits
    exp_logits = tl.exp(logits)
    sum_exp_logits = tl.sum(exp_logits, axis=0) + 1e-6

    softmax_output = exp_logits / sum_exp_logits
    tl.store(output_row_ptr, softmax_output, mask=mask)

@triton.jit
def layer_norm_kernel(
    x_ptr, weight_ptr, bias_ptr, y_ptr,
    N, eps: tl.constexpr,
    BLOCK_SIZE: tl.constexpr
):
    # one program per row; BLOCK_SIZE must cover all N columns
    row_idx = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    x_offset = x_ptr + row_idx * N + cols
    x = tl.load(x_offset, mask=mask, other=0.0)

    mean = tl.sum(x, axis=0) / N
    x_centered = x - mean
    var = tl.sum(x_centered * x_centered, axis=0) / N
    rstd = 1.0 / tl.sqrt(var + eps)

    w = tl.load(weight_ptr + cols, mask=mask, other=1.0)
    b = tl.load(bias_ptr + cols, mask=mask, other=0.0)

    y = (x_centered * rstd) * w + b
    tl.store(y_ptr + row_idx * N + cols, y, mask=mask)

@triton.jit
def cross_entropy_loss_kernel(
    logits_ptr, targets_ptr, loss_ptr,
    n_classes, n_elements,
    BLOCK_SIZE: tl.constexpr
):
    # each program handles BLOCK_SIZE rows; loss = logsumexp(logits) - logit[target]
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    targets = tl.load(targets_ptr + offsets, mask=mask, other=-1)

    row_max = tl.full([BLOCK_SIZE], float('-inf'), dtype=tl.float32)
    row_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)

    # pass 1: row-wise max over classes
    for i in range(n_classes):
        col_offset = offsets * n_classes + i
        logit = tl.load(logits_ptr + col_offset, mask=mask, other=float('-inf'))
        row_max = tl.maximum(row_max, logit)

    # pass 2: accumulate exp-sums and pick out the target logit
    loss = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for i in range(n_classes):
        col_offset = offsets * n_classes + i
        logit = tl.load(logits_ptr + col_offset, mask=mask, other=float('-inf'))
        exp_logit = tl.exp(logit - row_max)
        row_sum += exp_logit
        loss = tl.where(targets == i, loss - logit + row_max, loss)

    loss += tl.log(row_sum)

    tl.store(loss_ptr + offsets, loss, mask=mask)

@triton.jit
def gelu_kernel(
    x_ptr, y_ptr, n_elements,
    BLOCK_SIZE: tl.constexpr
):
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)

    sqrt_2_over_pi = 0.7978845608028654
    coeff = sqrt_2_over_pi * (1 + 0.044715 * x * x)
    # tanh-style GELU, with tanh(z) approximated by softsign(z) = z / (1 + |z|)
    y = 0.5 * x * (1 + (x * coeff) / (1 + tl.abs(x * coeff)))

    tl.store(y_ptr + offsets, y, mask=mask)

# -----------------------------------
# Triton-accelerated Launch Functions
# -----------------------------------

class TritonSoftmax(nn.Module):
    def forward(self, x):
        original_shape = x.shape
        if len(original_shape) > 2:
            x = x.view(-1, original_shape[-1])
        x = x.clamp(-100, 100)
        B, N = x.shape
        y = torch.empty_like(x)
        grid = lambda meta: (B,)  # one program per row
        softmax_kernel[grid](
            y, x,
            x.stride(0), y.stride(0), N,
            BLOCK_SIZE=triton.next_power_of_2(N)
        )
        # guard against degenerate rows, then renormalize
        y = y + 1e-8
        y = y / y.sum(dim=-1, keepdim=True)
        return y.view(original_shape)

def triton_cross_entropy_loss(logits, targets):
    return TritonCrossEntropyLoss.apply(logits, targets)

class TritonCrossEntropyLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, targets):
        n_elements, n_classes = logits.shape
        loss = torch.empty(n_elements, device=logits.device, dtype=logits.dtype)

        grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

        cross_entropy_loss_kernel[grid](
            logits, targets, loss,
            n_classes, n_elements,
            BLOCK_SIZE=1024
        )

        ctx.save_for_backward(logits, targets)
        return loss.mean()

    @staticmethod
    def backward(ctx, grad_output):
        logits, targets = ctx.saved_tensors
        batch_size, n_classes = logits.shape

        # d(loss)/d(logits) = softmax(logits) - one_hot(targets), averaged over the batch
        logits_exp = torch.exp(logits - logits.max(dim=-1, keepdim=True).values)
        softmax_output = logits_exp / logits_exp.sum(dim=-1, keepdim=True)

        grad_input = softmax_output.clone()
        grad_input.scatter_add_(1, targets.unsqueeze(1), -torch.ones_like(grad_input))
        grad_input *= grad_output.view(-1, 1) / batch_size

        return grad_input, None


class TritonLayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        self.normalized_shape = tuple(normalized_shape) if isinstance(normalized_shape, (tuple, list)) else (normalized_shape,)
        self.weight = nn.Parameter(torch.ones(self.normalized_shape))
        self.bias = nn.Parameter(torch.zeros(self.normalized_shape))
        self.eps = eps

    def forward(self, x):
        assert x.shape[-len(self.normalized_shape):] == self.normalized_shape, "Input shape does not match normalized_shape."
        y = torch.empty_like(x)
        x_ = x.reshape(-1, self.normalized_shape[-1])
        y_ = y.reshape(-1, self.normalized_shape[-1])
        M, N = x_.shape
        # one program per row; BLOCK_SIZE must cover the full row of N columns
        grid = lambda meta: (M,)
        layer_norm_kernel[grid](
            x_, self.weight, self.bias, y_,
            N, eps=self.eps,
            BLOCK_SIZE=triton.next_power_of_2(N)
        )
        return y

class TritonGELU(nn.Module):
    def forward(self, x):
        n_elements = x.numel()
        y = torch.empty_like(x)
        grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
        gelu_kernel[grid](
            x, y, n_elements,
            BLOCK_SIZE=1024
        )
        return y
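
# ------------------------------------------------------------------
# Added sketch (not called anywhere by default): a quick sanity check
# of the Triton-backed ops against their PyTorch references. Assumes a
# CUDA device. The GELU kernel uses a softsign-based approximation of
# tanh, so it is only compared loosely (max deviation is printed).
# ------------------------------------------------------------------
def _sanity_check_triton_ops(rows=32, cols=384):
    x = torch.randn(rows, cols, device='cuda')
    tgt = torch.randint(0, cols, (rows,), device='cuda')

    # softmax
    assert torch.allclose(TritonSoftmax()(x), F.softmax(x, dim=-1), atol=1e-3)

    # layer norm (weight=1, bias=0 at init, so it should match F.layer_norm)
    ln = TritonLayerNorm(cols).cuda()
    ref_ln = F.layer_norm(x, (cols,), ln.weight, ln.bias, eps=ln.eps)
    assert torch.allclose(ln(x), ref_ln, atol=1e-3)

    # GELU: softsign approximation can deviate noticeably from the tanh form
    gelu_diff = (TritonGELU()(x) - F.gelu(x, approximate='tanh')).abs().max()
    print(f"GELU softsign-vs-tanh max abs diff: {gelu_diff:.3f}")

    # cross entropy (mean reduction on both sides)
    assert torch.allclose(triton_cross_entropy_loss(x, tgt), F.cross_entropy(x, tgt), atol=1e-3)
    print("Triton softmax, layer norm and cross entropy match PyTorch within tolerance.")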

# -----------------------------
# Model
# -----------------------------

class MultiHeadAttention(nn.Module):
    def __init__(self, dim, num_heads, seq_length, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.seq_length = seq_length

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

        self.softmax = TritonSoftmax()
        self.register_buffer("mask", torch.tril(torch.ones(seq_length, seq_length)).bool())

    def forward(self, x):
        B, T, C = x.shape
        qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.masked_fill(~self.mask[:T, :T], float('-inf'))
        attn = self.softmax(attn)
        attn = self.dropout(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, T, C)
        x = self.proj(x)
        return x

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim, bias=False),
            TritonGELU(),
            nn.Linear(hidden_dim, dim, bias=False),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, seq_length, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(dim, num_heads, seq_length, dropout)
        self.ff = FeedForward(dim, 4 * dim, dropout)
        self.ln1 = TritonLayerNorm(dim)
        self.ln2 = TritonLayerNorm(dim)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x

class NanoGPT(nn.Module):
    def __init__(self, vocab_size, dim, num_heads, num_layers, seq_length, dropout=0.1):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.seq_length = seq_length

        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.position_embedding = nn.Embedding(seq_length, dim)
        self.blocks = nn.ModuleList([
            TransformerBlock(dim, num_heads, seq_length, dropout)
            for _ in range(num_layers)
        ])
        self.ln_f = TritonLayerNorm(dim)
        self.head = nn.Linear(dim, vocab_size, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx):
        B, T = idx.shape
        assert T <= self.seq_length, f"Input sequence length {T} exceeds model's maximum sequence length {self.seq_length}"

        tok_emb = self.token_embedding(idx)
        pos_emb = self.position_embedding(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb

        for block in self.blocks:
            x = block(x)

        x = self.ln_f(x)
        logits = self.head(x)

        return logits

    def compute_loss(self, logits, targets):
        return triton_cross_entropy_loss(logits.view(-1, logits.size(-1)), targets.view(-1))
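
# ------------------------------------------------------------------
# Added sketch (not called anywhere by default): a tiny smoke test of
# the model wiring, using a smaller config than the defaults below.
# Assumes a CUDA device, since all the kernels are GPU-only.
# ------------------------------------------------------------------
def _model_smoke_test():
    m = NanoGPT(vocab_size=vocab_size, dim=128, num_heads=4,
                num_layers=2, seq_length=64).cuda()
    idx = torch.randint(0, vocab_size, (2, 64), device='cuda')
    logits = m(idx)
    assert logits.shape == (2, 64, vocab_size)
    loss = m.compute_loss(logits, idx)  # any targets of the right shape will do here
    n_params = sum(p.numel() for p in m.parameters())
    print(f"smoke test OK: loss {loss.item():.3f}, {n_params/1e6:.2f}M parameters")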

#----------------------------
# Training
#----------------------------

def train(model, train_data, val_data, batch_size, seq_length, learning_rate, num_epochs):
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.1)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    def get_batch(split):
        data = train_data if split == 'train' else val_data
        ix = torch.randint(len(data) - seq_length, (batch_size,))
        x = torch.stack([data[i:i+seq_length] for i in ix])
        y = torch.stack([data[i+1:i+seq_length+1] for i in ix])
        return x.to(model.token_embedding.weight.device), y.to(model.token_embedding.weight.device)

    def estimate_mfu(model, dt):
        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = sum(p.numel() for p in model.parameters())
        L, H, Q, T = model.num_layers, model.num_heads, model.dim // model.num_heads, model.seq_length
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T * batch_size  # multiply by batch size
        flops_achieved = flops_per_fwdbwd * (1.0/dt)  # per second
        flops_promised = 312e12  # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu

    iter_num = 0
    best_val_loss = float('inf')
    val_losses = []

    model.train()
    t0 = time.time()
    for epoch in range(num_epochs):
        for _ in range(100):  # 100 batches per epoch
            iter_num += 1

            t_start = time.time()

            # Data loading
            xb, yb = get_batch('train')
            t_data = time.time()

            # Forward pass
            logits = model(xb)
            t_forward = time.time()

            # Loss computation
            loss = model.compute_loss(logits, yb)
            t_loss = time.time()

            if torch.isnan(loss).any() or torch.isinf(loss).any():
                print(f"Warning: NaN or Inf detected in loss at iteration {iter_num}")
                print(f"Logits min: {logits.min()}, max: {logits.max()}")
                print(f"Target min: {yb.min()}, max: {yb.max()}")
                continue

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            t_backward = time.time()

            # Optimizer step
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            torch.cuda.synchronize()
            t_optim = time.time()

            if iter_num % 10 == 0:
                dt = t_optim - t_start
                dt_data = t_data - t_start
                dt_forward = t_forward - t_data
                dt_loss = t_loss - t_forward
                dt_backward = t_backward - t_loss
                dt_optim = t_optim - t_backward
                mfu = estimate_mfu(model, dt)

                print(f"iter {iter_num}: loss {loss.item():.4f}, time {dt*1000:.2f}ms, mfu {mfu*100:.2f}%")
                # print(f"  Data loading: {dt_data*1000:.2f}ms")
                # print(f"  Forward pass: {dt_forward*1000:.2f}ms")
                # print(f"  Loss computation: {dt_loss*1000:.2f}ms")
                # print(f"  Backward pass: {dt_backward*1000:.2f}ms")
                # print(f"  Optimizer step: {dt_optim*1000:.2f}ms")
                # print(f"  Other time: {(dt - dt_data - dt_forward - dt_loss - dt_backward - dt_optim)*1000:.2f}ms")

        scheduler.step()

        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for _ in range(50):  # 50 val batches
                xb, yb = get_batch('val')
                logits = model(xb)
                val_loss += model.compute_loss(logits, yb).item()
        val_loss /= 50
        val_losses.append(val_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            os.makedirs('Checkpoints', exist_ok=True)  # ensure the checkpoint directory exists
            torch.save(model.state_dict(), 'Checkpoints/nanoGPT_cpkt.pth')
            print(f"Saved checkpoint for validation loss: {best_val_loss:.4f}")

        model.train()

    return model, val_losses
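
# ------------------------------------------------------------------
# Added sketch (not called anywhere): back-of-the-envelope check of the
# MFU arithmetic in estimate_mfu() for the default config below
# (roughly 10.78M parameters, L=6, H=6, Q=64, T=256, batch_size=64).
# The 100 ms iteration time is an assumed example value.
# ------------------------------------------------------------------
def _mfu_back_of_envelope(dt=0.1):
    N, L, H, Q, T, B = 10_775_040, 6, 6, 64, 256, 64
    flops_per_token = 6 * N + 12 * L * H * Q * T   # ~7.2e7
    flops_per_iter = flops_per_token * T * B       # ~1.2e12 per fwd+bwd pass
    mfu = (flops_per_iter / dt) / 312e12           # vs. A100 bf16 peak
    print(f"flops/iter ~ {flops_per_iter:.2e}; at {dt*1000:.0f} ms/iter, MFU ~ {mfu*100:.1f}%")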

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Hyperparameters
    vocab_size = 65  # matches len(chars) computed above for Tiny Shakespeare
    dim = 384
    num_heads = 6
    num_layers = 6
    seq_length = 256
    dropout = 0.1
    batch_size = 64
    learning_rate = 3e-4
    num_epochs = 500

    model = NanoGPT(
        vocab_size=vocab_size,
        dim=dim,
        num_heads=num_heads,
        num_layers=num_layers,
        seq_length=seq_length,
        dropout=dropout
    ).to(device)

    model.config = type('Config', (), {
        'n_layer': num_layers,
        'n_head': num_heads,
        'n_embd': dim,
        'block_size': seq_length
    })

    # Train config
    model, validation_losses = train(
        model,
        train_data,
        test_data,
        batch_size=batch_size,
        seq_length=seq_length,
        learning_rate=learning_rate,
        num_epochs=num_epochs
    )

    # Load checkpoint
    model.load_state_dict(torch.load('Checkpoints/nanoGPT_cpkt.pth', weights_only=True))

    # Generate sample
    model.eval()
    start_text = "Once upon"
    input_ids = encode(start_text).unsqueeze(0).to(device)
    with torch.no_grad():
        for _ in range(240):
            logits = model(input_ids)
            next_token_logits = logits[:, -1, :]
            next_token_logits = torch.clamp(next_token_logits, -100, 100)
            probs = F.softmax(next_token_logits, dim=-1) + 1e-8
            probs = probs / probs.sum()
            if torch.isnan(probs).any() or torch.isinf(probs).any():
                probs = torch.ones_like(probs) / probs.shape[-1]

            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)

    generated_text = decode(input_ids[0].cpu())
    print("Generated Text:")
    print(generated_text)
--------------------------------------------------------------------------------