├── .gitignore
├── part_1
│   ├── out
│   │   ├── multi_head_attn_grid.png
│   │   └── mha_shapes.txt
│   ├── tests
│   │   ├── test_causal_mask.py
│   │   └── test_attn_math.py
│   ├── attn_mask.py
│   ├── demo_visualize_multi_head.py
│   ├── block.py
│   ├── ffn.py
│   ├── pos_encoding.py
│   ├── single_head.py
│   ├── vis_utils.py
│   ├── attn_numpy_demo.py
│   ├── demo_mha_shapes.py
│   ├── multi_head.py
│   └── orchestrator.py
├── .vscode
│   ├── settings.json
│   └── launch.json
├── part_3
│   ├── tests
│   │   ├── test_rmsnorm.py
│   │   ├── test_kvcache_shapes.py
│   │   └── test_rope_apply.py
│   ├── tokenizer.py
│   ├── rmsnorm.py
│   ├── swiglu.py
│   ├── utils.py
│   ├── block_modern.py
│   ├── kv_cache.py
│   ├── orchestrator.py
│   ├── demo_generate.py
│   ├── rope_custom.py
│   ├── attn_modern.py
│   └── model_modern.py
├── part_2
│   ├── tests
│   │   ├── test_tokenizer.py
│   │   └── test_dataset_shift.py
│   ├── tokenizer.py
│   ├── dataset.py
│   ├── utils.py
│   ├── eval_loss.py
│   ├── sample.py
│   ├── orchestrator.py
│   ├── model_gpt.py
│   └── train.py
├── part_5
│   ├── tests
│   │   ├── test_hybrid_block.py
│   │   ├── test_gate_shapes.py
│   │   └── test_moe_forward.py
│   ├── block_hybrid.py
│   ├── experts.py
│   ├── demo_moe.py
│   ├── moe.py
│   ├── orchestrator.py
│   ├── gating.py
│   └── README.md
├── part_7
│   ├── tests
│   │   ├── test_bt_loss.py
│   │   └── test_reward_forward.py
│   ├── loss_reward.py
│   ├── model_reward.py
│   ├── eval_rm.py
│   ├── data_prefs.py
│   ├── orchestrator.py
│   ├── collator_rm.py
│   └── train_rm.py
├── part_8
│   ├── tests
│   │   ├── test_policy_forward.py
│   │   └── test_ppo_loss.py
│   ├── ppo_loss.py
│   ├── policy.py
│   ├── orchestrator.py
│   ├── eval_ppo.py
│   ├── rollout.py
│   └── train_ppo.py
├── part_6
│   ├── tests
│   │   ├── test_formatter.py
│   │   └── test_masking.py
│   ├── curriculum.py
│   ├── formatters.py
│   ├── evaluate.py
│   ├── dataset_sft.py
│   ├── sample_sft.py
│   ├── orchestrator.py
│   ├── collator_sft.py
│   └── train_sft.py
├── part_4
│   ├── tests
│   │   ├── test_scheduler.py
│   │   ├── test_tokenizer_bpe.py
│   │   └── test_resume_shapes.py
│   ├── lr_scheduler.py
│   ├── amp_accum.py
│   ├── dataset_bpe.py
│   ├── orchestrator.py
│   ├── tokenizer_bpe.py
│   ├── sample.py
│   ├── logger.py
│   ├── train.py
│   └── checkpointing.py
├── part_9
│   ├── tests
│   │   └── test_grpo_loss.py
│   ├── policy.py
│   ├── grpo_loss.py
│   ├── orchestrator.py
│   ├── eval_ppo.py
│   ├── rollout.py
│   └── train_grpo.py
├── requirements.txt
└── README.md
/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .pytest_cache/ 3 | .DS_Store 4 | runs/ -------------------------------------------------------------------------------- /part_1/out/multi_head_attn_grid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vivekkalyanarangan30/llm_from_scratch/HEAD/part_1/out/multi_head_attn_grid.png -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python-envs.defaultEnvManager": "ms-python.python:conda", 3 | "python-envs.defaultPackageManager": "ms-python.python:conda", 4 | "python-envs.pythonProjects": [] 5 | } -------------------------------------------------------------------------------- /part_3/tests/test_rmsnorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from rmsnorm import RMSNorm 3 | 4 | def test_rmsnorm_shapes(): 5 | x = torch.randn(2,3,8) 6 | rn = RMSNorm(8) 7 | y = rn(x) 8 | assert y.shape == x.shape -------------------------------------------------------------------------------- /part_1/tests/test_causal_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from 
attn_mask import causal_mask 3 | 4 | def test_mask_is_upper_triangle(): 5 | m = causal_mask(5) 6 | # ensure shape and diagonal rule 7 | assert m.shape == (1,1,5,5) 8 | assert m[0,0].sum() == torch.triu(torch.ones(5,5), diagonal=1).sum() 9 | -------------------------------------------------------------------------------- /part_2/tests/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tokenizer import ByteTokenizer 3 | 4 | def test_roundtrip(): 5 | tok = ByteTokenizer() 6 | s = "Hello, ByteTok! äö" 7 | ids = tok.encode(s) 8 | assert ids.dtype == torch.long 9 | s2 = tok.decode(ids) 10 | assert len(s2) > 0 -------------------------------------------------------------------------------- /part_5/tests/test_hybrid_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from block_hybrid import HybridFFN 3 | 4 | def test_hybrid_ffn_blend(): 5 | B,T,C = 1, 4, 16 6 | ffn = HybridFFN(dim=C, alpha=0.3, n_expert=3, k=2) 7 | x = torch.randn(B,T,C) 8 | y, aux = ffn(x) 9 | assert y.shape == x.shape 10 | assert aux.item() >= 0.0 -------------------------------------------------------------------------------- /part_7/tests/test_bt_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from loss_reward import bradley_terry_loss 3 | 4 | def test_bradley_terry_monotonic(): 5 | pos = torch.tensor([2.0, 3.0]) 6 | neg = torch.tensor([1.0, 1.5]) 7 | l1 = bradley_terry_loss(pos, neg) 8 | l2 = bradley_terry_loss(pos+1.0, neg) # increase margin 9 | assert l2 < l1 -------------------------------------------------------------------------------- /part_1/attn_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def causal_mask(T: int, device=None): 4 | """Returns a bool mask where True means *masked* (disallowed). 5 | Shape: (1, 1, T, T) suitable for broadcasting with (B, heads, T, T). 
6 | """ 7 | m = torch.triu(torch.ones((T, T), dtype=torch.bool, device=device), diagonal=1) 8 | return m.view(1, 1, T, T) -------------------------------------------------------------------------------- /part_8/tests/test_policy_forward.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from policy import PolicyWithValue 3 | 4 | def test_policy_shapes(): 5 | B,T,V = 2, 16, 256 6 | pol = PolicyWithValue(vocab_size=V, block_size=T, n_layer=2, n_head=2, n_embd=64) 7 | x = torch.randint(0, V, (B,T)) 8 | logits, values, loss = pol(x, None) 9 | assert logits.shape == (B,T,V) 10 | assert values.shape == (B,T) 11 | -------------------------------------------------------------------------------- /part_6/tests/test_formatter.py: -------------------------------------------------------------------------------- 1 | from formatters import Example, format_example, format_prompt_only 2 | 3 | def test_template_contains_markers(): 4 | ex = Example("Say hi","Hello!") 5 | s = format_example(ex) 6 | assert "### Instruction:" in s and "### Response:" in s 7 | p = format_prompt_only("Explain transformers.") 8 | assert p.endswith("### Response:\n") or p.endswith("### Response:\n") -------------------------------------------------------------------------------- /part_2/tests/test_dataset_shift.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataset import ByteDataset 3 | 4 | 5 | def test_shift_alignment(tmp_path): 6 | p = tmp_path / 'toy.txt' 7 | p.write_text('abcdefg') 8 | ds = ByteDataset(str(p), block_size=3, split=1.0) 9 | x, y = ds.get_batch('train', 2, device=torch.device('cpu')) 10 | # shift must be next-token 11 | assert (y[:, :-1] == x[:, 1:]).all() 12 | -------------------------------------------------------------------------------- /part_3/tests/test_kvcache_shapes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kv_cache import RollingKV 3 | 4 | def test_rolling_kv_keep_window_with_sink(): 5 | B,H,D = 1,2,4 6 | kv = RollingKV(window=4, sink=2) 7 | for _ in range(10): 8 | k_new = torch.randn(B,H,1,D) 9 | v_new = torch.randn(B,H,1,D) 10 | k,v = kv.step(k_new, v_new) 11 | # Should never exceed sink+window length 12 | assert k.size(2) <= 6 -------------------------------------------------------------------------------- /part_5/tests/test_gate_shapes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from gating import TopKGate 3 | 4 | def test_gate_topk_shapes(): 5 | S, C, E, K = 32, 64, 4, 2 6 | x = torch.randn(S, C) 7 | gate = TopKGate(C, E, k=K) 8 | idx, w, aux = gate(x) 9 | assert idx.shape == (S, K) 10 | assert w.shape == (S, K) 11 | assert aux.ndim == 0 12 | # per-token weights are non-negative and <=1 13 | assert torch.all(w >= 0) 14 | assert torch.all(w <= 1) -------------------------------------------------------------------------------- /part_8/tests/test_ppo_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ppo_loss import ppo_losses 3 | 4 | def test_clipped_objective_behaves(): 5 | N = 32 6 | old_logp = torch.zeros(N) 7 | new_logp = torch.log(torch.full((N,), 1.2)) # ratio=1.2 8 | adv = torch.ones(N) 9 | new_v = torch.zeros(N) 10 | old_v = torch.zeros(N) 11 | ret = torch.ones(N) 12 | out = ppo_losses(new_logp, old_logp, adv, new_v, old_v, ret, clip_ratio=0.1) 13 | assert out.total_loss.ndim == 0 
-------------------------------------------------------------------------------- /part_5/tests/test_moe_forward.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from moe import MoE 3 | 4 | def test_moe_forward_dims_and_grad(): 5 | B,T,C = 2, 8, 32 6 | moe = MoE(dim=C, n_expert=4, k=1) 7 | x = torch.randn(B,T,C, requires_grad=True) 8 | y, aux = moe(x) 9 | assert y.shape == x.shape 10 | loss = (y**2).mean() + 0.01*aux 11 | loss.backward() 12 | # some gradient must flow to gate and experts 13 | grads = [p.grad is not None for p in moe.parameters()] 14 | assert any(grads) -------------------------------------------------------------------------------- /part_7/tests/test_reward_forward.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from model_reward import RewardModel 3 | 4 | def test_reward_shapes_and_grad(): 5 | B,T,V = 4, 16, 256 6 | m = RewardModel(vocab_size=V, block_size=T, n_layer=2, n_head=2, n_embd=64) 7 | x = torch.randint(0, V, (B,T)) 8 | r = m(x) 9 | assert r.shape == (B,) 10 | # gradient flows 11 | loss = (r**2).mean() 12 | loss.backward() 13 | grads = [p.grad is not None for p in m.parameters()] 14 | assert any(grads) -------------------------------------------------------------------------------- /part_1/demo_visualize_multi_head.py: -------------------------------------------------------------------------------- 1 | """Visualize multi-head attention weights per head (grid).""" 2 | import torch 3 | from multi_head import MultiHeadSelfAttention 4 | from vis_utils import save_attention_heads_grid 5 | 6 | B, T, d_model, n_head = 1, 5, 12, 3 7 | x = torch.randn(B, T, d_model) 8 | attn = MultiHeadSelfAttention(d_model, n_head, trace_shapes=False) 9 | 10 | out, w = attn(x) # w: (B, H, T, T) 11 | 12 | save_attention_heads_grid(w.detach().cpu().numpy(), filename="multi_head_attn_grid.png") -------------------------------------------------------------------------------- /part_4/tests/test_scheduler.py: -------------------------------------------------------------------------------- 1 | from lr_scheduler import WarmupCosineLR 2 | 3 | class DummyOpt: 4 | def __init__(self): 5 | self.param_groups = [{'lr': 0.0}] 6 | 7 | def test_warmup_cosine_lr_progression(): 8 | opt = DummyOpt() 9 | sch = WarmupCosineLR(opt, warmup_steps=10, total_steps=110, base_lr=1e-3) 10 | lrs = [sch.step() for _ in range(110)] 11 | assert max(lrs) <= 1e-3 + 1e-12 12 | assert lrs[0] < lrs[9] # warmup increasing 13 | assert lrs[-1] < lrs[10] # cosine decays -------------------------------------------------------------------------------- /part_3/tokenizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch 3 | 4 | class ByteTokenizer: 5 | """Simple byte-level tokenizer (0..255).""" 6 | def encode(self, s: str) -> torch.Tensor: 7 | return torch.tensor(list(s.encode('utf-8')), dtype=torch.long) 8 | def decode(self, ids) -> str: 9 | if isinstance(ids, torch.Tensor): 10 | ids = ids.tolist() 11 | return bytes(ids).decode('utf-8', errors='ignore') 12 | @property 13 | def vocab_size(self) -> int: 14 | return 256 -------------------------------------------------------------------------------- /part_3/rmsnorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class RMSNorm(nn.Module): 5 | """Root Mean Square Layer Normalization. 
6 | y = x * g / rms(x), rms(x) = sqrt(mean(x^2) + eps) 7 | """ 8 | def __init__(self, dim: int, eps: float = 1e-8): 9 | super().__init__() 10 | self.eps = eps 11 | self.weight = nn.Parameter(torch.ones(dim)) 12 | def forward(self, x: torch.Tensor) -> torch.Tensor: 13 | rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt() 14 | return (x / rms) * self.weight -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | 8 | 9 | 10 | { 11 | "name": "Python Debugger: Current File", 12 | "type": "debugpy", 13 | "request": "launch", 14 | "program": "${file}", 15 | "console": "integratedTerminal", 16 | "args": ["--demo"] 17 | } 18 | ] 19 | } -------------------------------------------------------------------------------- /part_6/tests/test_masking.py: -------------------------------------------------------------------------------- 1 | from collator_sft import SFTCollator 2 | from formatters import Example 3 | 4 | def test_masking_sets_prompt_to_ignore(): 5 | col = SFTCollator(block_size=256, bpe_dir='../part_4/runs/part4-demo/tokenizer') 6 | text = "This is a tiny test." 7 | x, y = col.collate([(text, "OK")]) 8 | # All labels up to response marker should be -100 9 | boundary = ("\n### Instruction:\n" + text + "\n\n### Response:\n") 10 | # We don't have direct access to the tokenized boundary; just sanity check: some -100s present 11 | assert (y[0] == -100).sum() > 0 12 | -------------------------------------------------------------------------------- /part_6/curriculum.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import List 3 | 4 | class LengthCurriculum: 5 | """6.3 Curriculum: iterate examples from short→long prompts (one pass demo).""" 6 | def __init__(self, items: List[tuple[str,str]]): 7 | self.items = sorted(items, key=lambda p: len(p[0])) 8 | self._i = 0 9 | def __iter__(self): 10 | self._i = 0 11 | return self 12 | def __next__(self): 13 | if self._i >= len(self.items): 14 | raise StopIteration 15 | it = self.items[self._i] 16 | self._i += 1 17 | return it -------------------------------------------------------------------------------- /part_4/tests/test_tokenizer_bpe.py: -------------------------------------------------------------------------------- 1 | import os, tempfile 2 | from tokenizer_bpe import BPETokenizer 3 | 4 | def test_bpe_train_save_load_roundtrip(): 5 | with tempfile.TemporaryDirectory() as d: 6 | txt = os.path.join(d, 'tiny.txt') 7 | with open(txt, 'w') as f: 8 | f.write('hello hello world') 9 | tok = BPETokenizer(vocab_size=100) 10 | tok.train(txt) 11 | out = os.path.join(d, 'tok') 12 | tok.save(out) 13 | tok2 = BPETokenizer() 14 | tok2.load(out) 15 | ids = tok2.encode('hello world') 16 | assert isinstance(ids, list) and len(ids) > 0 -------------------------------------------------------------------------------- /part_2/tokenizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch 3 | 4 | class ByteTokenizer: 5 | """Ultra-simple byte-level tokenizer. 
6 | - encode(str) -> LongTensor [N] 7 | - decode(Tensor[int]) -> str 8 | - vocab_size = 256 9 | """ 10 | def encode(self, s: str) -> torch.Tensor: 11 | return torch.tensor(list(s.encode('utf-8')), dtype=torch.long) 12 | 13 | def decode(self, ids) -> str: 14 | if isinstance(ids, torch.Tensor): 15 | ids = ids.tolist() 16 | return bytes(ids).decode('utf-8', errors='ignore') 17 | 18 | @property 19 | def vocab_size(self) -> int: 20 | return 256 -------------------------------------------------------------------------------- /part_3/swiglu.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class SwiGLU(nn.Module): 4 | """SwiGLU FFN: (xW1) ⊗ swish(xW2) W3 with expansion factor `mult`. 5 | """ 6 | def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0): 7 | super().__init__() 8 | inner = mult * dim 9 | self.w1 = nn.Linear(dim, inner, bias=False) 10 | self.w2 = nn.Linear(dim, inner, bias=False) 11 | self.w3 = nn.Linear(inner, dim, bias=False) 12 | self.act = nn.SiLU() 13 | self.drop = nn.Dropout(dropout) 14 | def forward(self, x): 15 | a = self.w1(x) 16 | b = self.act(self.w2(x)) 17 | return self.drop(self.w3(a * b)) -------------------------------------------------------------------------------- /part_6/formatters.py: -------------------------------------------------------------------------------- 1 | """Prompt/response formatting utilities (6.1). 2 | We keep a very simple template with clear separators. 3 | """ 4 | from dataclasses import dataclass 5 | 6 | template = ( 7 | "\n" 8 | "### Instruction:\n{instruction}\n\n" 9 | "### Response:\n{response}" 10 | ) 11 | 12 | @dataclass 13 | class Example: 14 | instruction: str 15 | response: str 16 | 17 | 18 | def format_example(ex: Example) -> str: 19 | return template.format(instruction=ex.instruction.strip(), response=ex.response.strip()) 20 | 21 | 22 | def format_prompt_only(instruction: str) -> str: 23 | return template.format(instruction=instruction.strip(), response="") -------------------------------------------------------------------------------- /part_7/loss_reward.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch, torch.nn.functional as F 3 | 4 | def bradley_terry_loss(r_pos: torch.Tensor, r_neg: torch.Tensor) -> torch.Tensor: 5 | """-log sigma(r_pos - r_neg) = softplus(-(r_pos - r_neg)) 6 | https://docs.pytorch.org/docs/stable/generated/torch.nn.Softplus.html""" 7 | diff = r_pos - r_neg 8 | return F.softplus(-diff).mean() 9 | 10 | 11 | def margin_ranking_loss(r_pos: torch.Tensor, r_neg: torch.Tensor, margin: float = 1.0) -> torch.Tensor: 12 | """https://docs.pytorch.org/docs/stable/generated/torch.nn.MarginRankingLoss.html""" 13 | y = torch.ones_like(r_pos) 14 | return F.margin_ranking_loss(r_pos, r_neg, y, margin=margin) -------------------------------------------------------------------------------- /part_1/block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from multi_head import MultiHeadSelfAttention 3 | from ffn import FeedForward 4 | 5 | class TransformerBlock(nn.Module): 6 | """1.6 Transformer block = LN → MHA → residual → LN → FFN → residual.""" 7 | def __init__(self, d_model: int, n_head: int, dropout: float = 0.0): 8 | super().__init__() 9 | self.ln1 = nn.LayerNorm(d_model) 10 | self.attn = MultiHeadSelfAttention(d_model, n_head, dropout) 11 | self.ln2 = nn.LayerNorm(d_model) 12 | self.ffn = 
FeedForward(d_model, mult=4, dropout=dropout) 13 | 14 | def forward(self, x): 15 | x = x + self.attn(self.ln1(x))[0] 16 | x = x + self.ffn(self.ln2(x)) 17 | return x -------------------------------------------------------------------------------- /part_1/out/mha_shapes.txt: -------------------------------------------------------------------------------- 1 | Input x: (1, 5, 12) = (B,T,d_model) 2 | Linear qkv(x): (1, 5, 36) = (B,T,3*d_model) 3 | view to 5D: (1, 5, 3, 3, 4) = (B,T,3,heads,d_head) 4 | q,k,v split: q=(1, 5, 3, 4) k=(1, 5, 3, 4) v=(1, 5, 3, 4) 5 | transpose heads: q=(1, 3, 5, 4) k=(1, 3, 5, 4) v=(1, 3, 5, 4) = (B,heads,T,d_head) 6 | scores q@k^T: (1, 3, 5, 5) = (B,heads,T,T) 7 | softmax(weights): (1, 3, 5, 5) = (B,heads,T,T) 8 | context @v: (1, 3, 5, 4) = (B,heads,T,d_head) 9 | merge heads: (1, 5, 12) = (B,T,d_model) 10 | final proj: (1, 5, 12) = (B,T,d_model) 11 | 12 | Legend: 13 | B=batch, T=sequence length, d_model=embedding size, heads=n_head, d_head=d_model/heads 14 | qkv(x) is a single Linear producing [Q|K|V]; we reshape then split into q,k,v 15 | -------------------------------------------------------------------------------- /part_1/ffn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class FeedForward(nn.Module): 4 | """1.5 FFN with expansion factor `mult`. 5 | 6 | Dimensions: 7 | input: (B, T, d_model) 8 | inner: (B, T, mult*d_model) 9 | output: (B, T, d_model) 10 | 11 | `mult*d_model` means the hidden width is `mult` times larger than `d_model`. 12 | Typical values: mult=4 for GELU FFN in GPT-style blocks. 13 | """ 14 | def __init__(self, d_model: int, mult: int = 4, dropout: float = 0.0): 15 | super().__init__() 16 | self.net = nn.Sequential( 17 | nn.Linear(d_model, mult * d_model), 18 | nn.GELU(), 19 | nn.Linear(mult * d_model, d_model), 20 | nn.Dropout(dropout), 21 | ) 22 | 23 | def forward(self, x): 24 | return self.net(x) -------------------------------------------------------------------------------- /part_6/evaluate.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import re 3 | from typing import List, Tuple 4 | 5 | def _normalize(s: str) -> str: 6 | s = s.lower() 7 | s = re.sub(r"[^a-z0-9\s]", " ", s) 8 | s = re.sub(r"\s+", " ", s).strip() 9 | return s 10 | 11 | def exact_match(pred: str, gold: str) -> float: 12 | return float(_normalize(pred) == _normalize(gold)) 13 | 14 | def token_f1(pred: str, gold: str) -> float: 15 | p = _normalize(pred).split() 16 | g = _normalize(gold).split() 17 | if not p and not g: 18 | return 1.0 19 | if not p or not g: 20 | return 0.0 21 | common = 0 22 | gp = g.copy() 23 | for t in p: 24 | if t in gp: 25 | gp.remove(t); common += 1 26 | if common == 0: 27 | return 0.0 28 | prec = common / len(p) 29 | rec = common / len(g) 30 | return 2*prec*rec/(prec+rec) -------------------------------------------------------------------------------- /part_4/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | class WarmupCosineLR: 4 | """Linear warmup → cosine decay (per-step API).""" 5 | def __init__(self, optimizer, warmup_steps: int, total_steps: int, base_lr: float): 6 | self.optimizer = optimizer 7 | self.warmup_steps = max(1, warmup_steps) 8 | self.total_steps = max(self.warmup_steps+1, total_steps) 9 | self.base_lr = base_lr 10 | self.step_num = 0 11 | def step(self): 12 | self.step_num += 1 13 | if self.step_num <= 
self.warmup_steps: 14 | lr = self.base_lr * self.step_num / self.warmup_steps 15 | else: 16 | progress = (self.step_num - self.warmup_steps) / (self.total_steps - self.warmup_steps) 17 | lr = 0.5 * self.base_lr * (1.0 + math.cos(math.pi * progress)) 18 | for g in self.optimizer.param_groups: 19 | g['lr'] = lr 20 | return lr -------------------------------------------------------------------------------- /part_9/tests/test_grpo_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from grpo_loss import ppo_policy_only_losses 3 | 4 | def test_grpo_clipped_objective_behaves(): 5 | N = 32 6 | old_logp = torch.zeros(N) # log π_old = 0 → π_old = 1.0 7 | new_logp = torch.log(torch.full((N,), 1.2)) # log π_new = log(1.2) → ratio=1.2 8 | adv = torch.ones(N) # positive advantage 9 | kl_mean = torch.tensor(0.5) # pretend KL(π||π_ref)=0.5 10 | 11 | out = ppo_policy_only_losses( 12 | new_logp=new_logp, 13 | old_logp=old_logp, 14 | adv=adv, 15 | clip_ratio=0.1, 16 | ent_coef=0.0, 17 | kl_coef=0.1, 18 | kl_mean=kl_mean, 19 | ) 20 | 21 | # Ensure scalar loss 22 | assert out.total_loss.ndim == 0 23 | assert torch.isfinite(out.policy_loss) 24 | # KL penalty should have been added 25 | assert torch.allclose(out.kl_ref, kl_mean) 26 | -------------------------------------------------------------------------------- /part_5/block_hybrid.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch.nn as nn 3 | from moe import MoE 4 | 5 | class HybridFFN(nn.Module): 6 | """Blend dense FFN with MoE output: y = α * Dense(x) + (1−α) * MoE(x). 7 | Use α∈[0,1] to trade between stability (dense) and capacity (MoE). 8 | """ 9 | def __init__(self, dim: int, alpha: float = 0.5, mult: int = 4, swiglu: bool = True, n_expert: int = 4, k: int = 1, dropout: float = 0.0): 10 | super().__init__() 11 | self.alpha = alpha 12 | inner = mult * dim 13 | self.dense = nn.Sequential( 14 | nn.Linear(dim, inner), nn.GELU(), nn.Linear(inner, dim), nn.Dropout(dropout) 15 | ) 16 | self.moe = MoE(dim, n_expert=n_expert, k=k, mult=mult, swiglu=swiglu, dropout=dropout) 17 | def forward(self, x): 18 | y_dense = self.dense(x) 19 | y_moe, aux = self.moe(x) 20 | y = self.alpha * y_dense + (1.0 - self.alpha) * y_moe 21 | return y, aux -------------------------------------------------------------------------------- /part_5/experts.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch.nn as nn 3 | 4 | class ExpertMLP(nn.Module): 5 | """Single expert MLP (SwiGLU or GELU).""" 6 | def __init__(self, dim: int, mult: int = 4, swiglu: bool = True, dropout: float = 0.0): 7 | super().__init__() 8 | inner = mult * dim 9 | if swiglu: 10 | self.inp1 = nn.Linear(dim, inner, bias=False) 11 | self.inp2 = nn.Linear(dim, inner, bias=False) 12 | self.act = nn.SiLU() 13 | self.out = nn.Linear(inner, dim, bias=False) 14 | self.drop = nn.Dropout(dropout) 15 | self.swiglu = True 16 | else: 17 | self.ff = nn.Sequential( 18 | nn.Linear(dim, inner), nn.GELU(), nn.Linear(inner, dim), nn.Dropout(dropout) 19 | ) 20 | self.swiglu = False 21 | def forward(self, x): 22 | if self.swiglu: 23 | a = self.inp1(x); b = self.act(self.inp2(x)) 24 | return self.drop(self.out(a * b)) 25 | else: 26 | return self.ff(x) -------------------------------------------------------------------------------- /part_5/demo_moe.py: 
-------------------------------------------------------------------------------- 1 | import argparse, torch 2 | from moe import MoE 3 | 4 | if __name__ == "__main__": 5 | p = argparse.ArgumentParser() 6 | p.add_argument('--tokens', type=int, default=64) 7 | p.add_argument('--hidden', type=int, default=128) 8 | p.add_argument('--experts', type=int, default=4) 9 | p.add_argument('--top_k', type=int, default=1) 10 | p.add_argument('--cpu', action='store_true') 11 | args = p.parse_args() 12 | 13 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu') 14 | x = torch.randn(2, args.tokens//2, args.hidden, device=device) # (B=2,T=tokens/2,C) 15 | 16 | moe = MoE(dim=args.hidden, n_expert=args.experts, k=args.top_k).to(device) 17 | with torch.no_grad(): 18 | y, aux = moe(x) 19 | 20 | # simple routing histogram 21 | from gating import TopKGate 22 | gate = moe.gate 23 | idx, w, _ = gate(x.view(-1, args.hidden)) 24 | hist = torch.bincount(idx[:,0], minlength=args.experts) 25 | print(f"Output shape: {tuple(y.shape)} | aux={float(aux):.4f}") 26 | print("Primary expert load (counts):", hist.tolist()) -------------------------------------------------------------------------------- /part_1/tests/test_attn_math.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from single_head import SingleHeadSelfAttention 4 | 5 | # mirror the tiny example in attn_numpy_demo.py 6 | X = np.array([[[0.1, 0.2, 0.3, 0.4], 7 | [0.5, 0.4, 0.3, 0.2], 8 | [0.0, 0.1, 0.0, 0.1]]], dtype=np.float32) 9 | Wq = np.array([[ 0.2, -0.1],[ 0.0, 0.1],[ 0.1, 0.2],[-0.1, 0.0]], dtype=np.float32) 10 | Wk = np.array([[ 0.1, 0.1],[ 0.0, -0.1],[ 0.2, 0.0],[ 0.0, 0.2]], dtype=np.float32) 11 | Wv = np.array([[ 0.1, 0.0],[-0.1, 0.1],[ 0.2, -0.1],[ 0.0, 0.2]], dtype=np.float32) 12 | 13 | def test_single_head_matches_numpy(): 14 | torch.manual_seed(0) 15 | x = torch.tensor(X) 16 | attn = SingleHeadSelfAttention(d_model=4, d_k=2) 17 | # load weights 18 | with torch.no_grad(): 19 | attn.q.weight.copy_(torch.tensor(Wq).t()) 20 | attn.k.weight.copy_(torch.tensor(Wk).t()) 21 | attn.v.weight.copy_(torch.tensor(Wv).t()) 22 | out, w = attn(x) 23 | assert out.shape == (1,3,2) 24 | # Basic numeric sanity 25 | assert torch.isfinite(out).all() 26 | assert torch.isfinite(w).all() -------------------------------------------------------------------------------- /part_2/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from pathlib import Path 3 | import torch 4 | 5 | class ByteDataset: 6 | """Holds raw bytes of a text file and yields (x,y) blocks for LM. 
7 | - block_size: sequence length (context window) 8 | - split: fraction for training (rest is val) 9 | """ 10 | def __init__(self, path: str, block_size: int = 256, split: float = 0.9): 11 | data = Path(path).read_bytes() 12 | data = torch.tensor(list(data), dtype=torch.long) 13 | n = int(len(data) * split) 14 | self.train = data[:n] 15 | self.val = data[n:] 16 | self.block_size = block_size 17 | 18 | def get_batch(self, which: str, batch_size: int, device: torch.device): 19 | buf = self.train if which == 'train' else self.val 20 | assert len(buf) > self.block_size + 1, 'file too small for given block_size' 21 | ix = torch.randint(0, len(buf) - self.block_size - 1, (batch_size,)) 22 | x = torch.stack([buf[i:i+self.block_size] for i in ix]) 23 | y = torch.stack([buf[i+1:i+1+self.block_size] for i in ix]) 24 | return x.to(device), y.to(device) -------------------------------------------------------------------------------- /part_4/tests/test_resume_shapes.py: -------------------------------------------------------------------------------- 1 | import torch, tempfile, os 2 | import torch.nn as nn 3 | from checkpointing import save_checkpoint, load_checkpoint 4 | 5 | class Dummy(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | self.l = torch.nn.Linear(8,8) 9 | 10 | 11 | def test_save_and_load(tmp_path): 12 | m = Dummy(); opt = torch.optim.AdamW(m.parameters(), lr=1e-3) 13 | class S: pass 14 | sch = S(); sch.__dict__ = {'warmup_steps': 10, 'total_steps': 100, 'base_lr': 1e-3, 'step_num': 5} 15 | class A: pass 16 | amp = A(); amp.scaler = torch.cuda.amp.GradScaler(enabled=False) 17 | 18 | out = tmp_path/"chk" 19 | save_checkpoint(m, opt, sch, amp, step=123, out_dir=str(out), tokenizer_dir=None) 20 | 21 | m2 = Dummy(); opt2 = torch.optim.AdamW(m2.parameters(), lr=1e-3) 22 | sch2 = S(); sch2.__dict__ = {'warmup_steps': 1, 'total_steps': 1, 'base_lr': 1e-3, 'step_num': 0} 23 | amp2 = A(); amp2.scaler = torch.cuda.amp.GradScaler(enabled=False) 24 | 25 | step = load_checkpoint(m2, str(out/"model_last.pt"), optimizer=opt2, scheduler=sch2, amp=amp2) 26 | assert isinstance(step, int) -------------------------------------------------------------------------------- /part_4/amp_accum.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class AmpGrad: 4 | """AMP + gradient accumulation wrapper. 
5 | Usage: 6 | amp = AmpGrad(optimizer, accum=4, amp=True) 7 | amp.backward(loss) 8 | if amp.should_step(): amp.step(); amp.zero_grad() 9 | """ 10 | def __init__(self, optimizer, accum: int = 1, amp: bool = True): 11 | self.optim = optimizer 12 | self.accum = max(1, accum) 13 | self.amp = amp and torch.cuda.is_available() 14 | self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp) 15 | self._n = 0 16 | def backward(self, loss: torch.Tensor): 17 | loss = loss / self.accum 18 | if self.amp: 19 | self.scaler.scale(loss).backward() 20 | else: 21 | loss.backward() 22 | self._n += 1 23 | def should_step(self): 24 | return (self._n % self.accum) == 0 25 | def step(self): 26 | if self.amp: 27 | self.scaler.step(self.optim) 28 | self.scaler.update() 29 | else: 30 | self.optim.step() 31 | def zero_grad(self): 32 | self.optim.zero_grad(set_to_none=True) -------------------------------------------------------------------------------- /part_2/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch 3 | 4 | def top_k_top_p_filtering(logits: torch.Tensor, top_k: int | None = None, top_p: float | None = None): 5 | """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering. 6 | - logits: (B, vocab) 7 | Returns filtered logits with -inf for masked entries. 8 | """ 9 | B, V = logits.shape 10 | filtered = logits.clone() 11 | 12 | if top_k is not None and top_k < V: 13 | topk_vals, _ = torch.topk(filtered, top_k, dim=-1) 14 | kth = topk_vals[:, -1].unsqueeze(-1) 15 | filtered[filtered < kth] = float('-inf') 16 | 17 | if top_p is not None and 0 < top_p < 1.0: 18 | sorted_logits, sorted_idx = torch.sort(filtered, descending=True, dim=-1) 19 | probs = torch.softmax(sorted_logits, dim=-1) 20 | cumsum = torch.cumsum(probs, dim=-1) 21 | mask = cumsum > top_p 22 | # keep at least 1 token 23 | mask[..., 0] = False 24 | sorted_logits[mask] = float('-inf') 25 | # Scatter back 26 | filtered = torch.full_like(filtered, float('-inf')) 27 | filtered.scatter_(1, sorted_idx, sorted_logits) 28 | 29 | return filtered -------------------------------------------------------------------------------- /part_3/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch 3 | 4 | def top_k_top_p_filtering(logits: torch.Tensor, top_k: int | None = None, top_p: float | None = None): 5 | """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering. 6 | - logits: (B, vocab) 7 | Returns filtered logits with -inf for masked entries. 
8 | """ 9 | B, V = logits.shape 10 | filtered = logits.clone() 11 | 12 | if top_k is not None and top_k < V: 13 | topk_vals, _ = torch.topk(filtered, top_k, dim=-1) 14 | kth = topk_vals[:, -1].unsqueeze(-1) 15 | filtered[filtered < kth] = float('-inf') 16 | 17 | if top_p is not None and 0 < top_p < 1.0: 18 | sorted_logits, sorted_idx = torch.sort(filtered, descending=True, dim=-1) 19 | probs = torch.softmax(sorted_logits, dim=-1) 20 | cumsum = torch.cumsum(probs, dim=-1) 21 | mask = cumsum > top_p 22 | # keep at least 1 token 23 | mask[..., 0] = False 24 | sorted_logits[mask] = float('-inf') 25 | # Scatter back 26 | filtered = torch.full_like(filtered, float('-inf')) 27 | filtered.scatter_(1, sorted_idx, sorted_logits) 28 | 29 | return filtered -------------------------------------------------------------------------------- /part_8/ppo_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch, torch.nn.functional as F 3 | from dataclasses import dataclass 4 | 5 | @dataclass 6 | class PPOLossOut: 7 | policy_loss: torch.Tensor 8 | value_loss: torch.Tensor 9 | entropy: torch.Tensor 10 | approx_kl: torch.Tensor 11 | total_loss: torch.Tensor 12 | 13 | 14 | def ppo_losses(new_logp, old_logp, adv, new_values, old_values, returns, 15 | clip_ratio=0.2, vf_coef=0.5, ent_coef=0.0): 16 | # policy 17 | ratio = torch.exp(new_logp - old_logp) # (N,) 18 | unclipped = ratio * adv 19 | clipped = torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * adv 20 | policy_loss = -torch.mean(torch.min(unclipped, clipped)) 21 | 22 | # value (clip optional → here: simple MSE) 23 | value_loss = F.mse_loss(new_values, returns) 24 | 25 | # entropy bonus (we approximate entropy via -new_logp mean; strictly needs full dist) 26 | entropy = -new_logp.mean() 27 | 28 | # approx KL for logging 29 | approx_kl = torch.mean(old_logp - new_logp) 30 | 31 | total = policy_loss + vf_coef * value_loss - ent_coef * entropy 32 | return PPOLossOut(policy_loss, value_loss, entropy, approx_kl, total) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.3.1 2 | aiohappyeyeballs==2.6.1 3 | aiohttp==3.12.15 4 | aiosignal==1.4.0 5 | attrs==25.3.0 6 | certifi==2025.8.3 7 | charset-normalizer==3.4.3 8 | contourpy==1.3.3 9 | cycler==0.12.1 10 | datasets==4.0.0 11 | dill==0.3.8 12 | filelock==3.19.1 13 | fonttools==4.59.1 14 | frozenlist==1.7.0 15 | fsspec==2025.3.0 16 | grpcio==1.74.0 17 | hf-xet==1.1.7 18 | huggingface-hub==0.34.4 19 | idna==3.10 20 | iniconfig==2.1.0 21 | Jinja2==3.1.6 22 | kiwisolver==1.4.9 23 | Markdown==3.8.2 24 | MarkupSafe==3.0.2 25 | matplotlib==3.10.5 26 | mpmath==1.3.0 27 | multidict==6.6.4 28 | multiprocess==0.70.16 29 | networkx==3.5 30 | numpy==2.3.2 31 | packaging==25.0 32 | pandas==2.3.2 33 | pillow==11.3.0 34 | pluggy==1.6.0 35 | propcache==0.3.2 36 | protobuf==6.32.0 37 | pyarrow==21.0.0 38 | Pygments==2.19.2 39 | pyparsing==3.2.3 40 | pytest==8.4.1 41 | python-dateutil==2.9.0.post0 42 | pytz==2025.2 43 | PyYAML==6.0.2 44 | requests==2.32.4 45 | six==1.17.0 46 | sympy==1.14.0 47 | tensorboard==2.20.0 48 | tensorboard-data-server==0.7.2 49 | tokenizers==0.21.4 50 | torch==2.8.0 51 | tqdm==4.67.1 52 | typing_extensions==4.14.1 53 | tzdata==2025.2 54 | urllib3==2.5.0 55 | Werkzeug==3.1.3 56 | xxhash==3.5.0 57 | yarl==1.20.1 58 | 
-------------------------------------------------------------------------------- /part_4/dataset_bpe.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | from pathlib import Path 5 | from typing import Tuple 6 | from tokenizer_bpe import BPETokenizer 7 | 8 | class TextBPEBuffer(Dataset): 9 | """Memory-mapped-ish single-file dataset: tokenize once → long tensor of ids. 10 | get(idx) returns a (block_size,) slice; we construct (x,y) with shift inside collate. 11 | """ 12 | def __init__(self, path: str, tokenizer: BPETokenizer, block_size: int = 256): 13 | super().__init__() 14 | self.block_size = block_size 15 | text = Path(path).read_text(encoding='utf-8') 16 | self.ids = torch.tensor(tokenizer.encode(text), dtype=torch.long) 17 | def __len__(self): 18 | return max(0, self.ids.numel() - self.block_size - 1) 19 | def __getitem__(self, i: int): 20 | x = self.ids[i:i+self.block_size] 21 | y = self.ids[i+1:i+self.block_size+1] 22 | return x, y 23 | 24 | def make_loader(path: str, tokenizer: BPETokenizer, block_size: int, batch_size: int, shuffle=True) -> DataLoader: 25 | ds = TextBPEBuffer(path, tokenizer, block_size) 26 | return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, drop_last=True) -------------------------------------------------------------------------------- /part_3/block_modern.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from rmsnorm import RMSNorm 3 | from swiglu import SwiGLU 4 | from attn_modern import CausalSelfAttentionModern 5 | 6 | class TransformerBlockModern(nn.Module): 7 | def __init__(self, n_embd: int, n_head: int, dropout: float = 0.0, 8 | use_rmsnorm: bool = True, use_swiglu: bool = True, 9 | rope: bool = True, max_pos: int = 4096, 10 | sliding_window: int | None = None, attention_sink: int = 0, n_kv_head: int | None = None): 11 | super().__init__() 12 | Norm = RMSNorm if use_rmsnorm else nn.LayerNorm 13 | self.ln1 = Norm(n_embd) 14 | self.attn = CausalSelfAttentionModern(n_embd, n_head, dropout, rope, max_pos, sliding_window, attention_sink, n_kv_head) 15 | self.ln2 = Norm(n_embd) 16 | self.ffn = SwiGLU(n_embd, mult=4, dropout=dropout) if use_swiglu else nn.Sequential( 17 | nn.Linear(n_embd, 4*n_embd), nn.GELU(), nn.Linear(4*n_embd, n_embd), nn.Dropout(dropout) 18 | ) 19 | def forward(self, x, kv_cache=None, start_pos: int = 0): 20 | a, kv_cache = self.attn(self.ln1(x), kv_cache=kv_cache, start_pos=start_pos) 21 | x = x + a 22 | x = x + self.ffn(self.ln2(x)) 23 | return x, kv_cache -------------------------------------------------------------------------------- /part_1/pos_encoding.py: -------------------------------------------------------------------------------- 1 | """1.1 Positional encodings (absolute learned + sinusoidal).""" 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | 6 | class LearnedPositionalEncoding(nn.Module): 7 | def __init__(self, max_len: int, d_model: int): 8 | super().__init__() 9 | self.emb = nn.Embedding(max_len, d_model) 10 | 11 | def forward(self, x: torch.Tensor): 12 | # x: (B, T, d_model) — we only need its T and device 13 | B, T, _ = x.shape 14 | pos = torch.arange(T, device=x.device) 15 | pos_emb = self.emb(pos) # (T, d_model) 16 | return x + pos_emb.unsqueeze(0) # broadcast over batch 17 | 18 | class SinusoidalPositionalEncoding(nn.Module): 19 | def __init__(self, max_len: int, d_model: int): 20 | 
super().__init__() 21 | pe = torch.zeros(max_len, d_model) 22 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 23 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 24 | pe[:, 0::2] = torch.sin(position * div_term) 25 | pe[:, 1::2] = torch.cos(position * div_term) 26 | self.register_buffer('pe', pe) # (max_len, d_model) 27 | 28 | def forward(self, x: torch.Tensor): 29 | B, T, _ = x.shape 30 | return x + self.pe[:T].unsqueeze(0) -------------------------------------------------------------------------------- /part_3/kv_cache.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch 3 | from dataclasses import dataclass 4 | 5 | @dataclass 6 | class KVCache: 7 | k: torch.Tensor # (B,H,T,D) 8 | v: torch.Tensor # (B,H,T,D) 9 | 10 | @property 11 | def T(self): 12 | return self.k.size(2) 13 | 14 | class RollingKV: 15 | """Rolling buffer with optional attention sink. 16 | Keeps first `sink` tokens + last `window` tokens. 17 | """ 18 | def __init__(self, window: int, sink: int = 0): 19 | self.window = window 20 | self.sink = sink 21 | self.k = None 22 | self.v = None 23 | def step(self, k_new: torch.Tensor, v_new: torch.Tensor): 24 | if self.k is None: 25 | self.k, self.v = k_new, v_new 26 | else: 27 | self.k = torch.cat([self.k, k_new], dim=2) 28 | self.v = torch.cat([self.v, v_new], dim=2) 29 | # crop 30 | if self.k.size(2) > self.window + self.sink: 31 | sink_part = self.k[:, :, :self.sink, :] 32 | sink_val = self.v[:, :, :self.sink, :] 33 | tail_k = self.k[:, :, -self.window:, :] 34 | tail_v = self.v[:, :, -self.window:, :] 35 | self.k = torch.cat([sink_part, tail_k], dim=2) 36 | self.v = torch.cat([sink_val, tail_v], dim=2) 37 | return self.k, self.v -------------------------------------------------------------------------------- /part_1/single_head.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from attn_mask import causal_mask 6 | 7 | class SingleHeadSelfAttention(nn.Module): 8 | """1.3 Single-head attention (explicit shapes).""" 9 | def __init__(self, d_model: int, d_k: int, dropout: float = 0.0, trace_shapes: bool = False): 10 | super().__init__() 11 | self.q = nn.Linear(d_model, d_k, bias=False) 12 | self.k = nn.Linear(d_model, d_k, bias=False) 13 | self.v = nn.Linear(d_model, d_k, bias=False) 14 | self.dropout = nn.Dropout(dropout) 15 | self.trace_shapes = trace_shapes 16 | 17 | def forward(self, x: torch.Tensor): # x: (B, T, d_model) 18 | B, T, _ = x.shape 19 | q = self.q(x) # (B,T,d_k) 20 | k = self.k(x) # (B,T,d_k) 21 | v = self.v(x) # (B,T,d_k) 22 | if self.trace_shapes: 23 | print(f"q {q.shape} k {k.shape} v {v.shape}") 24 | scale = 1.0 / math.sqrt(q.size(-1)) 25 | attn = torch.matmul(q, k.transpose(-2, -1)) * scale # (B,T,T) 26 | mask = causal_mask(T, device=x.device) 27 | attn = attn.masked_fill(mask.squeeze(1), float('-inf')) 28 | w = F.softmax(attn, dim=-1) 29 | w = self.dropout(w) 30 | out = torch.matmul(w, v) # (B,T,d_k) 31 | if self.trace_shapes: 32 | print(f"weights {w.shape} out {out.shape}") 33 | return out, w -------------------------------------------------------------------------------- /part_2/eval_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import argparse, torch 3 | from dataset import ByteDataset 4 | 
from model_gpt import GPT 5 | 6 | 7 | def main(): 8 | p = argparse.ArgumentParser() 9 | p.add_argument('--data', type=str, required=True) 10 | p.add_argument('--ckpt', type=str, required=True) 11 | p.add_argument('--block_size', type=int, default=256) 12 | p.add_argument('--batch_size', type=int, default=32) 13 | p.add_argument('--iters', type=int, default=100) 14 | p.add_argument('--cpu', action='store_true') 15 | args = p.parse_args() 16 | 17 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu') 18 | 19 | ds = ByteDataset(args.data, block_size=args.block_size) 20 | ckpt = torch.load(args.ckpt, map_location=device) 21 | cfg = ckpt.get('config', { 22 | 'vocab_size': 256, 23 | 'block_size': args.block_size, 24 | 'n_layer': 4, 25 | 'n_head': 4, 26 | 'n_embd': 256, 27 | 'dropout': 0.0, 28 | }) 29 | model = GPT(**cfg).to(device) 30 | model.load_state_dict(ckpt['model']) 31 | 32 | model.eval() 33 | losses = [] 34 | with torch.no_grad(): 35 | for _ in range(args.iters): 36 | xb, yb = ds.get_batch('val', args.batch_size, device) 37 | _, loss = model(xb, yb) 38 | losses.append(loss.item()) 39 | print(f"val loss: {sum(losses)/len(losses):.4f}") 40 | 41 | 42 | if __name__ == '__main__': 43 | main() -------------------------------------------------------------------------------- /part_3/tests/test_rope_apply.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from rope_custom import RoPECache, apply_rope_single 3 | 4 | def test_rope_rotation_shapes_single(): 5 | # Vanilla case: same #heads for q and k 6 | B, H, T, D = 1, 2, 5, 8 7 | rc = RoPECache(head_dim=D, max_pos=32) 8 | q = torch.randn(B, H, T, D) 9 | k = torch.randn(B, H, T, D) 10 | pos = torch.arange(0, T) 11 | cos, sin = rc.get(pos) 12 | 13 | q2 = apply_rope_single(q, cos, sin) 14 | k2 = apply_rope_single(k, cos, sin) 15 | 16 | assert q2.shape == q.shape 17 | assert k2.shape == k.shape 18 | # check that rotation mixes even/odd features (values should change) 19 | assert not torch.allclose(q2, q) 20 | assert not torch.allclose(k2, k) 21 | 22 | def test_rope_rotation_shapes_gqa(): 23 | # GQA case: q has H heads; k has fewer Hk heads (shared KV) 24 | B, H, Hk, T, D = 2, 8, 2, 7, 16 25 | rc = RoPECache(head_dim=D, max_pos=128) 26 | q = torch.randn(B, H, T, D) 27 | k = torch.randn(B, Hk, T, D) 28 | pos = torch.arange(10, 10 + T) # arbitrary start position 29 | cos, sin = rc.get(pos) 30 | 31 | q2 = apply_rope_single(q, cos, sin) 32 | k2 = apply_rope_single(k, cos, sin) 33 | 34 | assert q2.shape == (B, H, T, D) 35 | assert k2.shape == (B, Hk, T, D) 36 | # rotations should be deterministic for same positions 37 | # check a couple of coords moved as expected (values changed) 38 | assert not torch.allclose(q2, q) 39 | assert not torch.allclose(k2, k) 40 | -------------------------------------------------------------------------------- /part_1/vis_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | OUT_DIR = os.path.join(os.path.dirname(__file__), 'out') 6 | 7 | 8 | def _ensure_out(): 9 | os.makedirs(OUT_DIR, exist_ok=True) 10 | 11 | 12 | def save_matrix_heatmap(mat: np.ndarray, title: str, filename: str, xlabel: str = '', ylabel: str = ''): 13 | """Generic matrix heatmap saver. 14 | Do not set any specific colors/styles; keep defaults for clarity. 
15 | """ 16 | _ensure_out() 17 | plt.figure() 18 | plt.imshow(mat, aspect='auto') 19 | plt.title(title) 20 | plt.xlabel(xlabel) 21 | plt.ylabel(ylabel) 22 | plt.colorbar() 23 | path = os.path.join(OUT_DIR, filename) 24 | plt.savefig(path, bbox_inches='tight') 25 | plt.close() 26 | print(f"Saved: {path}") 27 | 28 | 29 | def save_attention_heads_grid(weights: np.ndarray, filename: str, title_prefix: str = "Head"): 30 | """Plot all heads in a single grid figure (B=1 assumed). 31 | weights: (1, H, T, T) 32 | """ 33 | _ensure_out() 34 | _, H, T, _ = weights.shape 35 | cols = min(4, H) 36 | rows = (H + cols - 1) // cols 37 | plt.figure(figsize=(3*cols, 3*rows)) 38 | for h in range(H): 39 | ax = plt.subplot(rows, cols, h+1) 40 | ax.imshow(weights[0, h], aspect='auto') 41 | ax.set_title(f"{title_prefix} {h}") 42 | ax.set_xlabel('Key pos') 43 | ax.set_ylabel('Query pos') 44 | plt.tight_layout() 45 | path = os.path.join(OUT_DIR, filename) 46 | plt.savefig(path, bbox_inches='tight') 47 | plt.close() 48 | print(f"Saved: {path}") -------------------------------------------------------------------------------- /part_5/moe.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch, torch.nn as nn 3 | from gating import TopKGate 4 | from experts import ExpertMLP 5 | 6 | class MoE(nn.Module): 7 | """Mixture‑of‑Experts layer (token‑wise top‑k routing). 8 | Implementation is single‑GPU friendly (loops over experts for clarity). 9 | https://arxiv.org/pdf/2101.03961 10 | """ 11 | def __init__(self, dim: int, n_expert: int, k: int = 1, mult: int = 4, swiglu: bool = True, dropout: float = 0.0): 12 | super().__init__() 13 | self.dim = dim 14 | self.n_expert = n_expert 15 | self.k = k 16 | self.gate = TopKGate(dim, n_expert, k=k) 17 | self.experts = nn.ModuleList([ExpertMLP(dim, mult=mult, swiglu=swiglu, dropout=dropout) for _ in range(n_expert)]) 18 | 19 | def forward(self, x: torch.Tensor): 20 | """x: (B, T, C) → y: (B, T, C), aux_loss 21 | Steps: flatten tokens → gate → per‑expert forward → scatter back with weights. 22 | """ 23 | B, T, C = x.shape 24 | S = B * T 25 | x_flat = x.reshape(S, C) 26 | idx, w, aux = self.gate(x_flat) # (S,k), (S,k) 27 | 28 | y = torch.zeros_like(x_flat) # (S,C) 29 | for e in range(self.n_expert): 30 | # tokens where expert e is selected at any of k slots 31 | for slot in range(self.k): 32 | sel = (idx[:, slot] == e) 33 | if sel.any(): 34 | x_e = x_flat[sel] 35 | y_e = self.experts[e](x_e) 36 | y[sel] += w[sel, slot:slot+1] * y_e 37 | y = y.view(B, T, C) 38 | return y, aux -------------------------------------------------------------------------------- /part_7/model_reward.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch, torch.nn as nn 3 | 4 | class RewardModel(nn.Module): 5 | """Transformer encoder → pooled representation → scalar reward. 6 | Bidirectional encoder is fine for reward modeling (not used for generation). 
7 | """ 8 | def __init__(self, vocab_size: int, block_size: int, n_layer: int = 4, n_head: int = 4, n_embd: int = 256, dropout: float = 0.1): 9 | super().__init__() 10 | self.vocab_size = vocab_size 11 | self.block_size = block_size 12 | self.tok_emb = nn.Embedding(vocab_size, n_embd) 13 | self.pos_emb = nn.Embedding(block_size, n_embd) 14 | enc_layer = nn.TransformerEncoderLayer(d_model=n_embd, nhead=n_head, dim_feedforward=4*n_embd, 15 | dropout=dropout, activation='gelu', batch_first=True) 16 | self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layer) 17 | self.ln = nn.LayerNorm(n_embd) 18 | self.head = nn.Linear(n_embd, 1) 19 | 20 | def forward(self, x: torch.Tensor): 21 | B, T = x.shape 22 | pos = torch.arange(T, device=x.device).unsqueeze(0) 23 | h = self.tok_emb(x) + self.pos_emb(pos) 24 | pad_mask = (x == 2) 25 | h = self.encoder(h, src_key_padding_mask=pad_mask) 26 | h = self.ln(h) 27 | # masked mean pool over tokens (ignoring pads) 28 | mask = (~pad_mask).float().unsqueeze(-1) 29 | h_sum = (h * mask).sum(dim=1) 30 | len_ = mask.sum(dim=1).clamp_min(1.0) 31 | pooled = h_sum / len_ 32 | r = self.head(pooled).squeeze(-1) # (B,) 33 | return r -------------------------------------------------------------------------------- /part_8/policy.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch, torch.nn as nn 3 | import sys 4 | from pathlib import Path as _P 5 | # Try user’s structure first 6 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_3')) 7 | try: 8 | from model_utils.model_modern import GPTModern # user-custom path 9 | except Exception: 10 | from model_modern import GPTModern # fallback 11 | 12 | class PolicyWithValue(nn.Module): 13 | """Policy network = SFT LM + tiny value head. 14 | NOTE: For simplicity we place value head on top of LM logits (vocab→1). 15 | This avoids depending on hidden-state internals while keeping the tutorial runnable. 16 | """ 17 | def __init__(self, vocab_size: int, block_size: int, n_layer=4, n_head=4, n_embd=256, 18 | use_rmsnorm=True, use_swiglu=True, rope=True, dropout=0.0): 19 | super().__init__() 20 | self.lm = GPTModern(vocab_size=vocab_size, block_size=block_size, n_layer=n_layer, 21 | n_head=n_head, n_embd=n_embd, use_rmsnorm=use_rmsnorm, 22 | use_swiglu=use_swiglu, rope=rope, dropout=dropout) 23 | # value head over logits (toy). Shapes: (B,T,V) -> (B,T,1) -> (B,T) 24 | self.val_head = nn.Linear(vocab_size, 1, bias=False) 25 | 26 | def forward(self, x: torch.Tensor, y: torch.Tensor | None = None): 27 | # Delegate LM forward; returns logits (B,T,V), loss, _ 28 | logits, loss, _ = self.lm(x, y) 29 | values = self.val_head(logits).squeeze(-1) # (B,T) 30 | return logits, values, loss 31 | 32 | def generate(self, *args, **kwargs): 33 | return self.lm.generate(*args, **kwargs) -------------------------------------------------------------------------------- /part_9/policy.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch, torch.nn as nn 3 | import sys 4 | from pathlib import Path as _P 5 | # Try user’s structure first 6 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_3')) 7 | try: 8 | from model_utils.model_modern import GPTModern # user-custom path 9 | except Exception: 10 | from model_modern import GPTModern # fallback 11 | 12 | class PolicyWithValue(nn.Module): 13 | """Policy network = SFT LM + tiny value head. 
14 | NOTE: For simplicity we place value head on top of LM logits (vocab→1). 15 | This avoids depending on hidden-state internals while keeping the tutorial runnable. 16 | """ 17 | def __init__(self, vocab_size: int, block_size: int, n_layer=4, n_head=4, n_embd=256, 18 | use_rmsnorm=True, use_swiglu=True, rope=True, dropout=0.0): 19 | super().__init__() 20 | self.lm = GPTModern(vocab_size=vocab_size, block_size=block_size, n_layer=n_layer, 21 | n_head=n_head, n_embd=n_embd, use_rmsnorm=use_rmsnorm, 22 | use_swiglu=use_swiglu, rope=rope, dropout=dropout) 23 | # value head over logits (toy). Shapes: (B,T,V) -> (B,T,1) -> (B,T) 24 | self.val_head = nn.Linear(vocab_size, 1, bias=False) 25 | 26 | def forward(self, x: torch.Tensor, y: torch.Tensor | None = None): 27 | # Delegate LM forward; returns logits (B,T,V), loss, _ 28 | logits, loss, _ = self.lm(x, y) 29 | values = self.val_head(logits).squeeze(-1) # (B,T) 30 | return logits, values, loss 31 | 32 | def generate(self, *args, **kwargs): 33 | return self.lm.generate(*args, **kwargs) -------------------------------------------------------------------------------- /part_2/sample.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import argparse, torch 3 | from tokenizer import ByteTokenizer 4 | from model_gpt import GPT 5 | 6 | 7 | def main(): 8 | p = argparse.ArgumentParser() 9 | p.add_argument('--ckpt', type=str, required=True) 10 | p.add_argument('--prompt', type=str, default='') 11 | p.add_argument('--tokens', type=int, default=200) 12 | p.add_argument('--temperature', type=float, default=1.0) 13 | p.add_argument('--top_k', type=int, default=50) 14 | p.add_argument('--top_p', type=float, default=None) 15 | p.add_argument('--cpu', action='store_true') 16 | args = p.parse_args() 17 | 18 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu') 19 | 20 | tok = ByteTokenizer() 21 | prompt_ids = tok.encode(args.prompt).unsqueeze(0).to(device) 22 | if prompt_ids.numel() == 0: 23 | # If no prompt provided, seed with newline byte (10) 24 | prompt_ids = torch.tensor([[10]], dtype=torch.long, device=device) 25 | 26 | 27 | ckpt = torch.load(args.ckpt, map_location=device) 28 | config = ckpt.get('config', None) 29 | 30 | if config is None: 31 | # fall back to defaults 32 | model = GPT(tok.vocab_size, block_size=256).to(device) 33 | model.load_state_dict(ckpt['model']) 34 | else: 35 | model = GPT(**config).to(device) 36 | model.load_state_dict(ckpt['model']) 37 | 38 | with torch.no_grad(): 39 | out = model.generate(prompt_ids, max_new_tokens=args.tokens, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p) 40 | print(tok.decode(out[0].cpu())) 41 | 42 | 43 | if __name__ == '__main__': 44 | main() -------------------------------------------------------------------------------- /part_9/grpo_loss.py: -------------------------------------------------------------------------------- 1 | # grpo_loss.py 2 | from __future__ import annotations 3 | import torch 4 | from dataclasses import dataclass 5 | 6 | @dataclass 7 | class PolicyOnlyLossOut: 8 | policy_loss: torch.Tensor 9 | entropy: torch.Tensor 10 | approx_kl: torch.Tensor 11 | kl_ref: torch.Tensor 12 | total_loss: torch.Tensor 13 | 14 | 15 | def ppo_policy_only_losses(new_logp, old_logp, adv, clip_ratio=0.2, ent_coef=0.0, 16 | kl_coef: float = 0.0, kl_mean: torch.Tensor | None = None): 17 | """ 18 | PPO-style clipped policy loss, *policy only* (no value head), 19 | plus a separate 
KL(π||π_ref) penalty term: total = L_PPO + kl_coef * KL.
20 | Inputs are flat over action tokens: new_logp, old_logp, adv: (N_act,).
21 | kl_mean is a scalar tensor (mean over action tokens).
22 | """
23 | device = new_logp.device  # keep the zero tensors on the same device as the inputs
24 | if new_logp.numel() == 0:
25 | zero = torch.tensor(0.0, device=device)
26 | return PolicyOnlyLossOut(zero, zero, zero, zero, zero)
27 | 
28 | ratio = torch.exp(new_logp - old_logp) # (N,)
29 | unclipped = ratio * adv
30 | clipped = torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * adv
31 | policy_loss = -torch.mean(torch.min(unclipped, clipped))
32 | 
33 | entropy = -new_logp.mean() if ent_coef != 0.0 else new_logp.new_tensor(0.0)
34 | approx_kl = torch.mean(old_logp - new_logp)
35 | 
36 | kl_ref = kl_mean if kl_mean is not None else new_logp.new_tensor(0.0)
37 | 
38 | total = policy_loss - ent_coef * entropy + kl_coef * kl_ref # entropy bonus was not used in original GRPO paper
39 | return PolicyOnlyLossOut(policy_loss, entropy, approx_kl, kl_ref, total)
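40 | 
41 | if __name__ == "__main__":
42 |     # Tiny numeric sketch (not part of the tests): with new_logp == old_logp the
43 |     # ratio is 1 everywhere, so the clipped and unclipped terms coincide and the
44 |     # policy loss is -mean(adv) = -mean([1.0, -0.5]) = -0.25.
45 |     lp = torch.log(torch.tensor([0.5, 0.25]))
46 |     out = ppo_policy_only_losses(lp, lp.clone(), torch.tensor([1.0, -0.5]))
47 |     print(float(out.policy_loss), float(out.total_loss))  # -0.25 -0.25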
-------------------------------------------------------------------------------- /part_5/orchestrator.py: --------------------------------------------------------------------------------
1 | # Repository layout (Part 5)
2 | #
3 | # part_5/
4 | # orchestrator.py # run unit tests + optional MoE demo
5 | # README.md # 5.1/5.3 concept notes (compact MD)
6 | # gating.py # router/gating (top‑k) + load‑balancing aux loss
7 | # experts.py # MLP experts (SwiGLU or GELU)
8 | # moe.py # Mixture-of-Experts layer (dispatch/combine)
9 | # block_hybrid.py # Hybrid dense+MoE block examples
10 | # demo_moe.py # small forward pass demo + routing histogram
11 | # tests/
12 | # test_gate_shapes.py
13 | # test_moe_forward.py
14 | # test_hybrid_block.py
15 | #
16 | # Run from inside `part_5/`:
17 | # cd part_5
18 | # python orchestrator.py --demo
19 | # pytest -q
20 | 
21 | import argparse, pathlib, subprocess, sys, shlex
22 | 
23 | ROOT = pathlib.Path(__file__).resolve().parent
24 | 
25 | def run(cmd: str):
26 | print(f"\n>>> {cmd}")
27 | res = subprocess.run(shlex.split(cmd), cwd=ROOT)
28 | if res.returncode != 0:
29 | sys.exit(res.returncode)
30 | 
31 | if __name__ == "__main__":
32 | p = argparse.ArgumentParser()
33 | p.add_argument("--demo", action="store_true", help="run a tiny MoE demo")
34 | args = p.parse_args()
35 | 
36 | # 1) unit tests
37 | run("python -m pytest -q tests/test_gate_shapes.py")
38 | run("python -m pytest -q tests/test_moe_forward.py")
39 | run("python -m pytest -q tests/test_hybrid_block.py")
40 | 
41 | # 2) optional demo
42 | if args.demo:
43 | run("python demo_moe.py --tokens 6 --hidden 128 --experts 4 --top_k 1")
44 | 
45 | print("\nPart 5 checks complete. ✅")
-------------------------------------------------------------------------------- /part_6/dataset_sft.py: --------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import List, Dict, Tuple
3 | from dataclasses import dataclass
4 | import os
5 | import traceback
6 | 
7 | try:
8 | from datasets import load_dataset
9 | except Exception:
10 | print("Couldn't import `datasets`. Will use fallback data only.")
11 | load_dataset = None
12 | 
13 | from formatters import Example
14 | 
15 | @dataclass
16 | class SFTItem:
17 | prompt: str
18 | response: str
19 | 
20 | 
21 | def load_tiny_hf(split: str = "train[:200]", sample_dataset: bool = False) -> List[SFTItem]:
22 | """Try to load a tiny instruction dataset from HF; fall back to a baked-in list.
23 | We use `tatsu-lab/alpaca` as a familiar schema (instruction, input, output) and keep only a slice.
24 | """
25 | items: List[SFTItem] = []
26 | if load_dataset is not None and not sample_dataset:
27 | try:
28 | ds = load_dataset("tatsu-lab/alpaca", split=split)
29 | for row in ds:
30 | instr = row.get("instruction", "").strip()
31 | inp = row.get("input", "").strip()
32 | out = row.get("output", "").strip()
33 | if inp:
34 | instr = instr + "\n" + inp
35 | if instr and out:
36 | items.append(SFTItem(prompt=instr, response=out))
37 | except Exception:
38 | print("Failed to load tatsu-lab/alpaca dataset. Using fallback seeds.")
39 | if not items:
40 | # fallback tiny list
41 | seeds = [
42 | ("First prime number", "2"),
43 | ("What are the three primary colors?", "red, blue, and yellow"),
44 | ("Device name which points to direction?", "compass"),
45 | ]
46 | items = [SFTItem(prompt=p, response=r) for p,r in seeds]
47 | return items
-------------------------------------------------------------------------------- /part_7/eval_rm.py: --------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import argparse, torch
3 | from data_prefs import load_preferences
4 | from collator_rm import PairCollator
5 | from model_reward import RewardModel
6 | 
7 | 
8 | def main():
9 | p = argparse.ArgumentParser()
10 | p.add_argument('--ckpt', type=str, required=True)
11 | p.add_argument('--split', type=str, default='val[:200]')
12 | p.add_argument('--cpu', action='store_true')
13 | p.add_argument('--bpe_dir', type=str, default=None)
14 | args = p.parse_args()
15 | 
16 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')
17 | 
18 | items = load_preferences(split=args.split)
19 | triples = [(it.prompt, it.chosen, it.rejected) for it in items]
20 | 
21 | col = PairCollator(block_size=256, bpe_dir=args.bpe_dir)
22 | ckpt = torch.load(args.ckpt, map_location=device)
23 | cfg = ckpt.get('config', {})
24 | 
25 | model = RewardModel(vocab_size=cfg.get('vocab_size', col.vocab_size), block_size=cfg.get('block_size', 256),
26 | n_layer=cfg.get('n_layer', 4), n_head=cfg.get('n_head', 4), n_embd=cfg.get('n_embd', 256))
27 | model.load_state_dict(ckpt['model'])
28 | model.to(device).eval()
29 | 
30 | # Evaluate accuracy r_pos>r_neg
31 | import math
32 | B = 16
33 | correct = 0; total = 0
34 | for i in range(0, len(triples), B):
35 | batch = triples[i:i+B]
36 | pos, neg = col.collate(batch)
37 | pos, neg = pos.to(device), neg.to(device)
38 | with torch.no_grad():
39 | r_pos = model(pos)
40 | r_neg = model(neg)
41 | correct += (r_pos > r_neg).sum().item()
42 | total += pos.size(0)
43 | acc = correct / max(1, total)
44 | print(f"pairs={total} accuracy (r_pos>r_neg) = {acc:.3f}")
45 | 
46 | if __name__ == '__main__':
47 | main()
-------------------------------------------------------------------------------- /part_7/data_prefs.py: --------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from dataclasses import dataclass
3 | from typing import List, Tuple
4 | 
5 | try:
6 | from datasets import load_dataset
7 | except Exception:
8 | load_dataset = None
9 | 
10 | @dataclass
11 | class PrefExample:
12 | prompt: str
13 | chosen: str
14 | rejected: str
15 | 
16 | 
17 | def load_preferences(split: str = "train[:200]") -> List[PrefExample]:
18 | """Load a tiny preference set. Tries Anthropic HH; falls back to a toy set.
19 | HH fields: 'chosen', 'rejected' (full conversations). We use an empty prompt.
20 | """
21 | items: List[PrefExample] = []
22 | if load_dataset is not None:
23 | try:
24 | ds = load_dataset("Anthropic/hh-rlhf", split=split)
25 | for row in ds:
26 | ch = str(row.get("chosen", "")).strip()
27 | rj = str(row.get("rejected", "")).strip()
28 | if ch and rj:
29 | items.append(PrefExample(prompt="", chosen=ch, rejected=rj))
30 | except Exception:
31 | print("Failed to load Anthropic/hh-rlhf dataset. Using fallback toy pairs.")
32 | pass
33 | if not items:
34 | # fallback toy pairs
35 | items = [
36 | PrefExample("Summarize: Scaling laws for neural language models.",
37 | "Scaling laws describe how performance improves predictably as model size, data, and compute increase.",
38 | "Scaling laws are when you scale pictures to look bigger."),
39 | PrefExample("Give two uses of attention in transformers.",
40 | "It lets the model focus on relevant tokens and enables parallel context integration across positions.",
41 | "It remembers all past words exactly without any computation."),
42 | ]
43 | return items
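44 | 
45 | if __name__ == "__main__":
46 |     # Quick sanity peek (prints the two toy pairs when the HF download fails):
47 |     for ex in load_preferences("train[:2]"):
48 |         print(ex.prompt[:40], "| chosen:", ex.chosen[:40], "| rejected:", ex.rejected[:40])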
20 | """ 21 | items: List[PrefExample] = [] 22 | if load_dataset is not None: 23 | try: 24 | ds = load_dataset("Anthropic/hh-rlhf", split=split) 25 | for row in ds: 26 | ch = str(row.get("chosen", "")).strip() 27 | rj = str(row.get("rejected", "")).strip() 28 | if ch and rj: 29 | items.append(PrefExample(prompt="", chosen=ch, rejected=rj)) 30 | except Exception: 31 | print("Failed to load Anthropic/hh-rlhf dataset. Using fallback toy pairs.") 32 | pass 33 | if not items: 34 | # fallback toy pairs 35 | items = [ 36 | PrefExample("Summarize: Scaling laws for neural language models.", 37 | "Scaling laws describe how performance improves predictably as model size, data, and compute increase.", 38 | "Scaling laws are when you scale pictures to look bigger."), 39 | PrefExample("Give two uses of attention in transformers.", 40 | "It lets the model focus on relevant tokens and enables parallel context integration across positions.", 41 | "It remembers all past words exactly without any computation."), 42 | ] 43 | return items -------------------------------------------------------------------------------- /part_5/gating.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch, torch.nn as nn 3 | 4 | class TopKGate(nn.Module): 5 | """Top‑k softmax gating with Switch‑style load‑balancing aux loss. 6 | Args: 7 | dim: input hidden size 8 | n_expert: number of experts 9 | k: number of experts to route per token (1 or 2 typical) 10 | Returns: 11 | (indices, weights, aux_loss) where 12 | indices: (S, k) long, expert ids for each token 13 | weights: (S, k) float, gate weights (sum ≤ 1 per token) 14 | aux_loss: scalar load‑balancing penalty 15 | """ 16 | def __init__(self, dim: int, n_expert: int, k: int = 1): 17 | super().__init__() 18 | assert k >= 1 and k <= n_expert 19 | self.n_expert = n_expert 20 | self.k = k 21 | self.w_g = nn.Linear(dim, n_expert, bias=True) 22 | 23 | def forward(self, x: torch.Tensor): 24 | # x: (S, C) where S = tokens (batch * seq) 25 | logits = self.w_g(x) # (S, E) 26 | probs = torch.softmax(logits, dim=-1) # (S, E) 27 | topk_vals, topk_idx = torch.topk(probs, k=self.k, dim=-1) # (S,k) 28 | 29 | # Load‑balancing aux loss (Switch): 30 | S, E = probs.size(0), probs.size(1) 31 | # importance: avg prob per expert 32 | importance = probs.mean(dim=0) # (E,) 33 | # load: fraction of tokens assigned as primary (top‑1 hard assignment) 34 | hard1 = topk_idx[:, 0] # (S,) 35 | load = torch.zeros(E, device=x.device) 36 | load.scatter_add_(0, hard1, torch.ones_like(hard1, dtype=load.dtype)) 37 | load = load / max(S, 1) 38 | aux_loss = (E * (importance * load).sum()) 39 | # print("*"*50) 40 | # print(probs, importance, hard1, load, aux_loss) 41 | # print("*"*50) 42 | 43 | return topk_idx, topk_vals, aux_loss -------------------------------------------------------------------------------- /part_9/orchestrator.py: -------------------------------------------------------------------------------- 1 | # Repository layout (Part 9 — RLHF with GRPO) 2 | # 3 | # part_9/ 4 | # orchestrator.py # run unit tests + optional tiny PPO demo 5 | # policy.py # policy = SFT LM + value head (toy head on logits) 6 | # rollout.py # prompt formatting, sampling, logprobs/KL utilities 7 | # grpo_loss.py # PPO clipped objective + value + entropy + KL penalty 8 | # train_ppo.py # single‑GPU RLHF loop (tiny, on‑policy) 9 | # eval_ppo.py # compare reward vs. 
-------------------------------------------------------------------------------- /part_9/orchestrator.py: --------------------------------------------------------------------------------
1 | # Repository layout (Part 9 — RLHF with GRPO)
2 | #
3 | # part_9/
4 | # orchestrator.py # run unit tests + optional tiny GRPO demo
5 | # policy.py # policy = SFT LM + value head (toy head on logits)
6 | # rollout.py # prompt formatting, sampling, logprobs/KL utilities
7 | # grpo_loss.py # PPO-style clipped objective (policy-only) + entropy + KL penalty
8 | # train_grpo.py # single‑GPU RLHF loop (tiny, on‑policy)
9 | # eval_ppo.py # compare reward vs. reference on a small set
10 | # tests/
11 | # test_grpo_loss.py
12 | #
13 | #
14 | # Run from inside `part_9/`:
15 | # cd part_9
16 | # python orchestrator.py --demo
17 | # pytest -q
18 | 
19 | import argparse, pathlib, subprocess, sys
20 | ROOT = pathlib.Path(__file__).resolve().parent
21 | 
22 | def run(cmd: str):
23 | print(f"\n>>> {cmd}")
24 | res = subprocess.run(cmd.split(), cwd=ROOT)
25 | if res.returncode != 0:
26 | sys.exit(res.returncode)
27 | 
28 | if __name__ == "__main__":
29 | p = argparse.ArgumentParser()
30 | p.add_argument("--demo", action="store_true", help="tiny GRPO demo")
31 | args = p.parse_args()
32 | 
33 | # 1) unit tests
34 | run("python -m pytest -q tests/test_grpo_loss.py")
35 | 
36 | # 2) optional demo (requires SFT+RM checkpoints from Parts 6 & 7)
37 | if args.demo:
38 | run("python train_grpo.py --group_size 4 --policy_ckpt ../part_6/runs/sft-demo/model_last.pt --reward_ckpt ../part_7/runs/rm-demo/model_last.pt --steps 200 --batch_prompts 4 --resp_len 128 --bpe_dir ../part_4/runs/part4-demo/tokenizer")
39 | run("python eval_ppo.py --policy_ckpt runs/grpo-demo/model_last.pt --reward_ckpt ../part_7/runs/rm-demo/model_last.pt --split train[:24] --bpe_dir ../part_4/runs/part4-demo/tokenizer")
40 | 
41 | print("\nPart 9 checks complete. ✅")
42 | 
-------------------------------------------------------------------------------- /part_3/orchestrator.py: --------------------------------------------------------------------------------
1 | # Repository layout (Part 3)
2 | #
3 | # part_3/
4 | # orchestrator.py # runs tests + a small generation demo
5 | # tokenizer.py # local byte-level tokenizer (self-contained)
6 | # rmsnorm.py # 3.1 RMSNorm
7 | # rope.py # 3.2 RoPE cache + apply
8 | # swiglu.py # 3.3 SwiGLU FFN
9 | # kv_cache.py # 3.4/3.6 KV cache + rolling buffer
10 | # attn_modern.py # attention w/ RoPE, sliding window, sink, optional KV cache
11 | # block_modern.py # block = (RMSNorm|LN) + modern attention + (SwiGLU|GELU)
12 | # model_modern.py # GPTModern wrapper with feature flags
13 | # demo_generate.py # simple generation demo (shows KV cache + sliding window)
14 | # tests/
15 | # test_rmsnorm.py
16 | # test_rope_apply.py
17 | # test_kvcache_shapes.py
18 | #
19 | # Run from inside `part_3/`:
20 | # cd part_3
21 | # python orchestrator.py --demo
22 | # pytest -q
23 | 
24 | import argparse, pathlib, subprocess, sys, shlex
25 | 
26 | ROOT = pathlib.Path(__file__).resolve().parent
27 | 
28 | def run(cmd: str):
29 | print(f"\n>>> {cmd}")
30 | res = subprocess.run(shlex.split(cmd), cwd=ROOT)
31 | if res.returncode != 0:
32 | sys.exit(res.returncode)
33 | 
34 | if __name__ == "__main__":
35 | p = argparse.ArgumentParser()
36 | p.add_argument("--demo", action="store_true", help="run a tiny generation demo")
37 | args = p.parse_args()
38 | 
39 | # 1) run unit tests
40 | run("python -m pytest -q tests/test_rmsnorm.py")
41 | run("python -m pytest -q tests/test_rope_apply.py")
42 | run("python -m pytest -q tests/test_kvcache_shapes.py")
43 | 
44 | # 2) (optional) generation demo
45 | if args.demo:
46 | run("python demo_generate.py --rmsnorm --rope --swiglu --sliding_window 64 --sink 4 --tokens 200")
47 | 
48 | print("\nPart 3 checks complete. 
✅") -------------------------------------------------------------------------------- /part_7/orchestrator.py: -------------------------------------------------------------------------------- 1 | # Repository layout (Part 7) 2 | # 3 | # part_7/ 4 | # orchestrator.py # run unit tests + optional tiny RM demo 5 | # data_prefs.py # 7.1 HF preference loader (+tiny fallback) 6 | # collator_rm.py # pairwise tokenization → (pos, neg) tensors 7 | # model_reward.py # 7.2 reward model (Transformer encoder → scalar) 8 | # loss_reward.py # 7.3 Bradley–Terry & margin-ranking losses 9 | # train_rm.py # minimal one‑GPU training on tiny slice 10 | # eval_rm.py # 7.4 sanity checks & simple accuracy on val 11 | # tests/ 12 | # test_bt_loss.py 13 | # test_reward_forward.py 14 | # 15 | # Run from inside `part_7/`: 16 | # cd part_7 17 | # python orchestrator.py --demo 18 | # pytest -q 19 | 20 | import argparse, pathlib, subprocess, sys, shlex 21 | ROOT = pathlib.Path(__file__).resolve().parent 22 | 23 | def run(cmd: str): 24 | print(f"\n>>> {cmd}") 25 | res = subprocess.run(shlex.split(cmd), cwd=ROOT) 26 | if res.returncode != 0: 27 | sys.exit(res.returncode) 28 | 29 | if __name__ == "__main__": 30 | p = argparse.ArgumentParser() 31 | p.add_argument("--demo", action="store_true", help="tiny reward‑model demo") 32 | args = p.parse_args() 33 | 34 | # 1) unit tests 35 | run("python -m pytest -q tests/test_bt_loss.py") 36 | run("python -m pytest -q tests/test_reward_forward.py") 37 | 38 | # 2) optional demo: tiny train + eval 39 | if args.demo: 40 | run("python train_rm.py --steps 300 --batch_size 8 --block_size 256 --n_layer 2 --n_head 2 --n_embd 128 --loss bt --bpe_dir ../part_4/runs/part4-demo/tokenizer") 41 | run("python eval_rm.py --ckpt runs/rm-demo/model_last.pt --split train[:8] --bpe_dir ../part_4/runs/part4-demo/tokenizer") 42 | run("python eval_rm.py --ckpt runs/rm-demo/model_last.pt --split test[:8] --bpe_dir ../part_4/runs/part4-demo/tokenizer") 43 | 44 | print("\nPart 7 checks complete. ✅") -------------------------------------------------------------------------------- /part_5/README.md: -------------------------------------------------------------------------------- 1 | # Part 5 — Mixture-of-Experts (MoE) 2 | 3 | This part focuses purely on the **Mixture-of-Experts feed-forward component**. 4 | We are **not** implementing self-attention in this module. 5 | In a full transformer block, the order is generally: 6 | 7 | ``` 8 | [LayerNorm] → [Self-Attention] → [Residual Add] 9 | → [LayerNorm] → [Feed-Forward Block (Dense or MoE)] → [Residual Add] 10 | ``` 11 | 12 | Our MoE layer is a drop-in replacement for the dense feed-forward block. 13 | You can imagine it being called **after** the attention output, before the second residual connection. 14 | 15 | --- 16 | 17 | ## **5.1 Theory in 60 seconds** 18 | - **Experts**: multiple parallel MLPs; each token activates only a small subset (top-k) → sparse compute. 19 | - **Gate/Router**: scores each token across experts; picks top-k and assigns weights via a softmax. 20 | - **Dispatch/Combine**: send token to chosen experts, run their MLP, combine results using gate weights. 21 | - **Load balancing**: encourage uniform expert usage. 
A common aux loss (Switch Transformer) is 22 | `L_aux = E * Σ ( importance * load )` where: 23 | - *importance* = avg gate probability per expert 24 | - *load* = fraction of tokens routed as primary to that expert 25 | 26 | --- 27 | 28 | ## **5.3 Distributed notes (single-GPU friendly)** 29 | - Real MoE implementations distribute experts across GPUs (**expert parallelism**). 30 | - Here we keep everything **on one device** for simplicity. Dispatch is simulated with indexing/masking. 31 | - In production, dispatch/combination typically involves **all-to-all communication** across devices. 32 | 33 | --- 34 | 35 | ## **5.4 Hybrid architectures** 36 | - MoE need not replace every FFN — use it in alternating layers or blend outputs: 37 | `y = α * Dense(x) + (1 − α) * MoE(x)` 38 | 39 | --- 40 | 41 | **Milestone for this part:** integrate this MoE layer in place of a dense feed-forward in a transformer block and compare efficiency/accuracy trade-offs — even with a toy attention block if you wish. 42 | -------------------------------------------------------------------------------- /part_3/demo_generate.py: -------------------------------------------------------------------------------- 1 | import argparse, torch 2 | from tokenizer import ByteTokenizer 3 | from model_modern import GPTModern 4 | import time 5 | 6 | if __name__ == "__main__": 7 | p = argparse.ArgumentParser() 8 | p.add_argument('--rmsnorm', action='store_true') 9 | p.add_argument('--rope', action='store_true') 10 | p.add_argument('--swiglu', action='store_true') 11 | p.add_argument('--sliding_window', type=int, default=None) 12 | p.add_argument('--sink', type=int, default=0) 13 | p.add_argument('--group_size', type=int, default=2) 14 | p.add_argument('--tokens', type=int, default=120) 15 | p.add_argument('--cpu', action='store_true') 16 | args = p.parse_args() 17 | 18 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu') 19 | 20 | tok = ByteTokenizer() 21 | model = GPTModern(vocab_size=tok.vocab_size, block_size=128, n_layer=2, n_head=4, n_embd=128, 22 | use_rmsnorm=args.rmsnorm, use_swiglu=args.swiglu, rope=args.rope, 23 | max_pos=4096, sliding_window=args.sliding_window, attention_sink=args.sink, n_kv_head=args.group_size).to(device) 24 | 25 | # empty prompt → newline 26 | prompt = torch.tensor([[10]], dtype=torch.long, device=device) 27 | 28 | with torch.no_grad(): 29 | start = time.time() 30 | out = model.generate(prompt, max_new_tokens=args.tokens, temperature=0.0, top_k=50, top_p=None, 31 | sliding_window=args.sliding_window, attention_sink=args.sink) 32 | print(f"Generated {args.tokens} tokens in {time.time()-start:.2f} sec") 33 | 34 | start = time.time() 35 | out_nocache = model.generate_nocache(prompt, max_new_tokens=args.tokens, temperature=0.0, top_k=50, top_p=None, 36 | sliding_window=args.sliding_window, attention_sink=args.sink) 37 | print(f"(nocache) Generated {args.tokens} tokens in {time.time()-start:.2f} sec") 38 | print(tok.decode(out[0].cpu())) 39 | print(tok.decode(out_nocache[0].cpu())) -------------------------------------------------------------------------------- /part_3/rope_custom.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch 3 | import math 4 | 5 | class RoPECache: 6 | """Precompute cos/sin for positions up to max_pos for even head_dim.""" 7 | def __init__(self, head_dim: int, max_pos: int, base: float = 10000.0, device: torch.device | None = None): 8 | assert head_dim % 2 == 
0, "RoPE head_dim must be even" 9 | self.head_dim = head_dim 10 | self.base = base 11 | self.device = device 12 | self._build(max_pos) 13 | def get(self, positions: torch.Tensor): 14 | # positions: (T,) or (1,T) 15 | if positions.dim() == 2: 16 | positions = positions[0] 17 | need = int(positions.max().item()) + 1 if positions.numel() > 0 else 1 18 | if need > self.max_pos: 19 | # grow tables 20 | self._build(max(need, int(self.max_pos * 2))) 21 | cos = self.cos[positions] # (T, D/2) 22 | sin = self.sin[positions] 23 | return cos, sin 24 | 25 | def _build(self, max_pos: int): 26 | """(Re)build cos/sin tables for a new max_pos.""" 27 | self.max_pos = max_pos 28 | inv_freq = 1.0 / (10000.0 ** (torch.arange(0, self.head_dim, 2, device=self.device).float() / self.head_dim)) 29 | t = torch.arange(max_pos, device=self.device).float() 30 | freqs = torch.outer(t, inv_freq) # (max_pos, head_dim/2) 31 | self.cos = torch.cos(freqs) 32 | self.sin = torch.sin(freqs) 33 | 34 | def apply_rope_single(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: 35 | """Rotate pairs along last dim for RoPE. 36 | x: (B,H,T,D) with D even; cos/sin: (T,D/2) 37 | """ 38 | assert x.size(-1) % 2 == 0 39 | cos = cos.unsqueeze(0).unsqueeze(0) # (1,1,T,D/2) 40 | sin = sin.unsqueeze(0).unsqueeze(0) 41 | x1 = x[..., ::2] 42 | x2 = x[..., 1::2] 43 | xr1 = x1 * cos - x2 * sin 44 | xr2 = x1 * sin + x2 * cos 45 | out = torch.empty_like(x) 46 | out[..., ::2] = xr1 47 | out[..., 1::2] = xr2 48 | return out 49 | -------------------------------------------------------------------------------- /part_2/orchestrator.py: -------------------------------------------------------------------------------- 1 | # Repository layout (Part 2) 2 | # 3 | # part_2/ 4 | # orchestrator.py # runs quick smoke-train + eval + sample 5 | # tokenizer.py # 2.1 byte-level tokenizer (0..255) 6 | # dataset.py # 2.2 dataset + batching + shift 7 | # utils.py # 2.5 sampling helpers (top-k/top-p) 8 | # model_gpt.py # tiny GPT: tok/pos emb + blocks + head 9 | # train.py # 2.3/2.4 training loop w/ val eval & ckpt 10 | # sample.py # 2.5 text generation from a checkpoint 11 | # eval_loss.py # 2.6 evaluate loss on a file/ckpt 12 | # tests/ 13 | # test_tokenizer.py # round-trip encode/decode 14 | # test_dataset_shift.py # label shift sanity 15 | # runs/ # (created at runtime) checkpoints & logs 16 | # 17 | # NOTE ON IMPORTS 18 | # ---------------- 19 | # All imports are LOCAL. Run from inside `part_2/`. 
20 | # Example quickstart (CPU ok): 21 | # cd part_2 22 | # python train.py --data tiny.txt --steps 300 --sample_every 100 23 | # python sample.py --ckpt runs/min-gpt/model_best.pt --tokens 200 --prompt 'Once upon a time ' 24 | 25 | 26 | import subprocess, sys, pathlib, shlex 27 | 28 | ROOT = pathlib.Path(__file__).resolve().parent 29 | RUNS = ROOT / 'runs' / 'min-gpt' 30 | 31 | def run(cmd: str): 32 | print(f"\n>>> {cmd}") 33 | res = subprocess.run(shlex.split(cmd), cwd=ROOT) 34 | if res.returncode != 0: 35 | sys.exit(res.returncode) 36 | 37 | if __name__ == '__main__': 38 | # quick smoke training on a tiny file path tiny_hi.txt; adjust as needed 39 | run("python train.py --data tiny_hi.txt --steps 400 --sample_every 100 --eval_interval 100 --batch_size 32 --block_size 128 --n_layer 2 --n_head 2 --n_embd 128") 40 | 41 | # sample from the best checkpoint 42 | run(f"python sample.py --ckpt {RUNS}/model_best.pt --tokens 200 --prompt 'Once upon a time '") 43 | 44 | # evaluate final val loss 45 | run(f"python eval_loss.py --data tiny_hi.txt --ckpt {RUNS}/model_best.pt --iters 50 --block_size 128") -------------------------------------------------------------------------------- /part_1/attn_numpy_demo.py: -------------------------------------------------------------------------------- 1 | """1.2 Self-attention from first principles on a tiny example (NumPy only). 2 | We use T=3 tokens, d_model=4, d_k=d_v=2, single-head. 3 | This script prints intermediate tensors so you can trace the math. 4 | 5 | Dimensions summary (single head) 6 | -------------------------------- 7 | X: (B=1, T=3, d_model=4) 8 | Wq/Wk/Wv: (d_model=4, d_k=2) 9 | Q,K,V: (1, 3, 2) 10 | Scores: (1, 3, 3) = Q @ K^T 11 | Weights: (1, 3, 3) = softmax over last dim 12 | Output: (1, 3, 2) = Weights @ V 13 | """ 14 | import numpy as np 15 | 16 | np.set_printoptions(precision=4, suppress=True) 17 | 18 | # Toy inputs (batch=1, seq=3, d_model=4) 19 | X = np.array([[[0.1, 0.2, 0.3, 0.4], 20 | [0.5, 0.4, 0.3, 0.2], 21 | [0.0, 0.1, 0.0, 0.1]]], dtype=np.float32) 22 | 23 | # Weight matrices (learned in real models). We fix numbers for determinism. 
24 | Wq = np.array([[ 0.2, -0.1], 25 | [ 0.0, 0.1], 26 | [ 0.1, 0.2], 27 | [-0.1, 0.0]], dtype=np.float32) 28 | Wk = np.array([[ 0.1, 0.1], 29 | [ 0.0, -0.1], 30 | [ 0.2, 0.0], 31 | [ 0.0, 0.2]], dtype=np.float32) 32 | Wv = np.array([[ 0.1, 0.0], 33 | [-0.1, 0.1], 34 | [ 0.2, -0.1], 35 | [ 0.0, 0.2]], dtype=np.float32) 36 | 37 | # Project to Q, K, V 38 | Q = X @ Wq # (1,3,2) 39 | K = X @ Wk # (1,3,2) 40 | V = X @ Wv # (1,3,2) 41 | 42 | print("Q shape:", Q.shape, "\nQ=\n", Q[0]) 43 | print("K shape:", K.shape, "\nK=\n", K[0]) 44 | print("V shape:", V.shape, "\nV=\n", V[0]) 45 | 46 | # Scaled dot-products 47 | scale = 1.0 / np.sqrt(Q.shape[-1]) 48 | attn_scores = (Q @ K.transpose(0,2,1)) * scale # (1,3,3) 49 | 50 | # Causal mask (upper triangle set to -inf so softmax->0) 51 | mask = np.triu(np.ones((1,3,3), dtype=bool), k=1) 52 | attn_scores = np.where(mask, -1e9, attn_scores) 53 | 54 | # Softmax over last dim 55 | weights = np.exp(attn_scores - attn_scores.max(axis=-1, keepdims=True)) 56 | weights = weights / weights.sum(axis=-1, keepdims=True) 57 | print("Weights shape:", weights.shape, "\nAttention Weights (causal)=\n", weights[0]) 58 | 59 | # Weighted sum of V 60 | out = weights @ V # (1,3,2) 61 | print("Output shape:", out.shape, "\nOutput=\n", out[0]) -------------------------------------------------------------------------------- /part_4/orchestrator.py: -------------------------------------------------------------------------------- 1 | # Repository layout (Part 4) 2 | # 3 | # part_4/ 4 | # orchestrator.py # run unit tests + optional smoke train & sample 5 | # tokenizer_bpe.py # 4.1 BPE tokenization (train/save/load) 6 | # dataset_bpe.py # streaming dataset + batching & label shift 7 | # lr_scheduler.py # 4.3 Warmup + cosine decay scheduler 8 | # amp_accum.py # 4.2 AMP (autocast+GradScaler) + grad accumulation helpers 9 | # checkpointing.py # 4.4 save/resume (model/opt/scaler/scheduler/tokenizer) 10 | # logger.py # 4.5 logging backends (wandb / tensorboard / noop) 11 | # train.py # core training loop (no Trainer API) 12 | # sample.py # load checkpoint & generate text 13 | # tests/ 14 | # test_tokenizer_bpe.py 15 | # test_scheduler.py 16 | # test_resume_shapes.py 17 | # 18 | # Run from inside `part_4/`: 19 | # cd part_4 20 | # python orchestrator.py --demo # tiny smoke run on ../tiny.txt 21 | # pytest -q 22 | # tensorboard --logdir=runs/part4-demo 23 | 24 | import argparse, pathlib, subprocess, sys, shlex 25 | 26 | ROOT = pathlib.Path(__file__).resolve().parent 27 | 28 | def run(cmd: str): 29 | print(f"\n>>> {cmd}") 30 | res = subprocess.run(shlex.split(cmd), cwd=ROOT) 31 | if res.returncode != 0: 32 | sys.exit(res.returncode) 33 | 34 | if __name__ == "__main__": 35 | p = argparse.ArgumentParser() 36 | p.add_argument("--demo", action="store_true", help="run a tiny smoke train+sample") 37 | args = p.parse_args() 38 | 39 | # 1) unit tests 40 | run("python -m pytest -q tests/test_tokenizer_bpe.py") 41 | run("python -m pytest -q tests/test_scheduler.py") 42 | run("python -m pytest -q tests/test_resume_shapes.py") 43 | 44 | # 2) optional demo (quick overfit on tiny file) 45 | if args.demo: 46 | run("python train.py --data ../part_2/tiny.txt --out runs/part4-demo --bpe --vocab_size 8000 --epochs 1 --steps 300 --batch_size 16 --block_size 128 --n_layer 2 --n_head 2 --n_embd 128 --mixed_precision --grad_accum_steps 2 --log tensorboard") 47 | run("python sample.py --ckpt runs/part4-demo/model_last.pt --tokens 100 --prompt 'Generate a short story'") 48 | 49 | print("\nPart 4 checks complete. 
✅") -------------------------------------------------------------------------------- /part_1/demo_mha_shapes.py: -------------------------------------------------------------------------------- 1 | """Walkthrough of multi-head attention with explicit matrix math and shapes. 2 | Generates a text log at ./out/mha_shapes.txt. 3 | """ 4 | import os 5 | import math 6 | import torch 7 | from multi_head import MultiHeadSelfAttention 8 | 9 | OUT_TXT = os.path.join(os.path.dirname(__file__), 'out', 'mha_shapes.txt') 10 | 11 | 12 | def log(s): 13 | print(s) 14 | with open(OUT_TXT, 'a') as f: 15 | f.write(s + "\n") 16 | 17 | 18 | if __name__ == "__main__": 19 | # Reset file 20 | os.makedirs(os.path.dirname(OUT_TXT), exist_ok=True) 21 | open(OUT_TXT, 'w').close() 22 | 23 | B, T, d_model, n_head = 1, 5, 12, 3 24 | d_head = d_model // n_head 25 | x = torch.randn(B, T, d_model) 26 | attn = MultiHeadSelfAttention(d_model, n_head, trace_shapes=True) 27 | 28 | log(f"Input x: {tuple(x.shape)} = (B,T,d_model)") 29 | qkv = attn.qkv(x) # (B,T,3*d_model) 30 | log(f"Linear qkv(x): {tuple(qkv.shape)} = (B,T,3*d_model)") 31 | 32 | qkv = qkv.view(B, T, 3, n_head, d_head) 33 | log(f"view to 5D: {tuple(qkv.shape)} = (B,T,3,heads,d_head)") 34 | 35 | q, k, v = qkv.unbind(dim=2) 36 | log(f"q,k,v split: q={tuple(q.shape)} k={tuple(k.shape)} v={tuple(v.shape)}") 37 | 38 | q = q.transpose(1, 2) 39 | k = k.transpose(1, 2) 40 | v = v.transpose(1, 2) 41 | log(f"transpose heads: q={tuple(q.shape)} k={tuple(k.shape)} v={tuple(v.shape)} = (B,heads,T,d_head)") 42 | 43 | scale = 1.0 / math.sqrt(d_head) 44 | scores = torch.matmul(q, k.transpose(-2, -1)) * scale 45 | log(f"scores q@k^T: {tuple(scores.shape)} = (B,heads,T,T)") 46 | 47 | weights = torch.softmax(scores, dim=-1) 48 | log(f"softmax(weights): {tuple(weights.shape)} = (B,heads,T,T)") 49 | 50 | ctx = torch.matmul(weights, v) 51 | log(f"context @v: {tuple(ctx.shape)} = (B,heads,T,d_head)") 52 | 53 | out = ctx.transpose(1, 2).contiguous().view(B, T, d_model) 54 | log(f"merge heads: {tuple(out.shape)} = (B,T,d_model)") 55 | 56 | out = attn.proj(out) 57 | log(f"final proj: {tuple(out.shape)} = (B,T,d_model)") 58 | 59 | log("\nLegend:") 60 | log(" B=batch, T=sequence length, d_model=embedding size, heads=n_head, d_head=d_model/heads") 61 | log(" qkv(x) is a single Linear producing [Q|K|V]; we reshape then split into q,k,v") -------------------------------------------------------------------------------- /part_6/sample_sft.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import argparse, torch 3 | 4 | # Reuse GPTModern 5 | import sys 6 | from pathlib import Path as _P 7 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_3')) 8 | from model_modern import GPTModern # noqa: E402 9 | 10 | from collator_sft import SFTCollator 11 | from formatters import format_prompt_only 12 | 13 | 14 | def main(): 15 | p = argparse.ArgumentParser() 16 | p.add_argument('--ckpt', type=str, required=True) 17 | p.add_argument('--prompt', type=str, required=True) 18 | p.add_argument('--block_size', type=int, default=256) 19 | p.add_argument('--n_layer', type=int, default=4) 20 | p.add_argument('--n_head', type=int, default=4) 21 | p.add_argument('--n_embd', type=int, default=256) 22 | p.add_argument('--tokens', type=int, default=80) 23 | p.add_argument('--temperature', type=float, default=0.2) 24 | p.add_argument('--cpu', action='store_true') 25 | p.add_argument('--bpe_dir', type=str, 
default='../part_4/runs/part4-demo/tokenizer') 26 | args = p.parse_args() 27 | 28 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu') 29 | 30 | ckpt = torch.load(args.ckpt, map_location=device) 31 | cfg = ckpt.get('config', {}) 32 | 33 | col = SFTCollator(block_size=cfg.get('block_size', 256), bpe_dir=args.bpe_dir) 34 | model = GPTModern(vocab_size=col.vocab_size, block_size=args.block_size, 35 | n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd, 36 | use_rmsnorm=True, use_swiglu=True, rope=True).to(device) 37 | model.load_state_dict(ckpt['model']) 38 | model.eval() 39 | 40 | prompt_text = format_prompt_only(args.prompt).replace('','') 41 | ids = col.encode(prompt_text) 42 | idx = torch.tensor([ids], dtype=torch.long, device=device) 43 | 44 | with torch.no_grad(): 45 | out = model.generate(idx, max_new_tokens=args.tokens, 46 | temperature=args.temperature, top_k=3) 47 | 48 | # decode: prefer BPE if collator has it, else fall back to bytes 49 | out_ids = out[0].tolist() 50 | orig_len = idx.size(1) 51 | if hasattr(col, "tok") and hasattr(col.tok, "decode"): 52 | # decode full text or just the generated suffix; suffix is often clearer 53 | generated = col.tok.decode(out_ids) 54 | print(generated) 55 | else: 56 | generated = bytes(out_ids[orig_len:]).decode("utf-8", errors="ignore") 57 | print(generated) 58 | 59 | 60 | if __name__ == '__main__': 61 | main() -------------------------------------------------------------------------------- /part_1/multi_head.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from attn_mask import causal_mask 6 | 7 | class MultiHeadSelfAttention(nn.Module): 8 | """1.4 Multi-head attention with explicit shape tracing. 
9 | 10 | Dimensions (before masking): 11 | x: (B, T, d_model) 12 | qkv: (B, T, 3*d_model) 13 | view→ (B, T, 3, n_head, d_head) where d_head = d_model // n_head 14 | split→ q,k,v each (B, T, n_head, d_head) 15 | swap→ (B, n_head, T, d_head) 16 | scores: (B, n_head, T, T) = q @ k^T / sqrt(d_head) 17 | weights:(B, n_head, T, T) = softmax(scores) 18 | ctx: (B, n_head, T, d_head) = weights @ v 19 | merge: (B, T, n_head*d_head) = (B, T, d_model) 20 | """ 21 | def __init__(self, d_model: int, n_head: int, dropout: float = 0.0, trace_shapes: bool = True): 22 | super().__init__() 23 | assert d_model % n_head == 0, "d_model must be divisible by n_head" 24 | self.n_head = n_head 25 | self.d_head = d_model // n_head 26 | self.qkv = nn.Linear(d_model, 3 * d_model, bias=False) 27 | self.proj = nn.Linear(d_model, d_model, bias=False) 28 | self.dropout = nn.Dropout(dropout) 29 | self.trace_shapes = trace_shapes 30 | 31 | def forward(self, x: torch.Tensor): # (B,T,d_model) 32 | B, T, C = x.shape 33 | qkv = self.qkv(x) # (B,T,3*C) 34 | qkv = qkv.view(B, T, 3, self.n_head, self.d_head) # (B,T,3,heads,dim) 35 | if self.trace_shapes: 36 | print("qkv view:", qkv.shape) 37 | q, k, v = qkv.unbind(dim=2) # each: (B,T,heads,dim) 38 | q = q.transpose(1, 2) # (B,heads,T,dim) 39 | k = k.transpose(1, 2) 40 | v = v.transpose(1, 2) 41 | if self.trace_shapes: 42 | print("q:", q.shape, "k:", k.shape, "v:", v.shape) 43 | 44 | scale = 1.0 / math.sqrt(self.d_head) 45 | attn = torch.matmul(q, k.transpose(-2, -1)) * scale # (B,heads,T,T) 46 | mask = causal_mask(T, device=x.device) 47 | attn = attn.masked_fill(mask, float('-inf')) 48 | w = F.softmax(attn, dim=-1) 49 | w = self.dropout(w) 50 | ctx = torch.matmul(w, v) # (B,heads,T,dim) 51 | if self.trace_shapes: 52 | print("weights:", w.shape, "ctx:", ctx.shape) 53 | out = ctx.transpose(1, 2).contiguous().view(B, T, C) # (B,T,d_model) 54 | out = self.proj(out) 55 | if self.trace_shapes: 56 | print("out:", out.shape) 57 | return out, w -------------------------------------------------------------------------------- /part_1/orchestrator.py: -------------------------------------------------------------------------------- 1 | # Repository layout (Part 1) 2 | # 3 | # part_1/ 4 | # orchestrator.py # runs demos/tests/visualizations for Part 1 5 | # pos_encoding.py # 1.1 positional encodings (learned + sinusoidal) 6 | # attn_numpy_demo.py # 1.2 self-attention math with tiny numbers (NumPy) 7 | # single_head.py # 1.3 single attention head (PyTorch) 8 | # multi_head.py # 1.4 multi-head attention (with shape tracing) 9 | # ffn.py # 1.5 feed-forward network (GELU, width = mult*d_model) 10 | # block.py # 1.6 Transformer block (residuals + LayerNorm) 11 | # attn_mask.py # causal mask helpers 12 | # vis_utils.py # plotting helpers (matrices & attention maps) 13 | # demo_mha_shapes.py # prints explicit matrix multiplications & shapes step-by-step 14 | # demo_visualize_multi_head.py # saves attention heatmaps per head (grid) 15 | # out/ # (created at runtime) images & logs live here 16 | # tests/ 17 | # test_attn_math.py # correctness: tiny example vs PyTorch single-head 18 | # test_causal_mask.py # verifies masking behavior 19 | # 20 | # NOTE ON IMPORTS 21 | # ---------------- 22 | # All imports are LOCAL. Run from inside `part_1/`. 
23 | # Example quickstart (CPU ok): 24 | # cd part_1 25 | # python orchestrator.py --visualize 26 | 27 | 28 | import subprocess, sys, pathlib, argparse, shlex 29 | 30 | ROOT = pathlib.Path(__file__).resolve().parent 31 | OUT = ROOT / "out" 32 | 33 | 34 | def run(cmd: str): 35 | print(f"\n>>> {cmd}") 36 | res = subprocess.run(shlex.split(cmd), cwd=ROOT) 37 | if res.returncode != 0: 38 | sys.exit(res.returncode) 39 | 40 | 41 | def main(): 42 | p = argparse.ArgumentParser() 43 | p.add_argument("--visualize", action="store_true", help="run visualization scripts and save PNGs to ./out") 44 | args = p.parse_args() 45 | 46 | OUT.mkdir(exist_ok=True) 47 | 48 | # 1.2 sanity check: NumPy tiny example 49 | run("python attn_numpy_demo.py") 50 | 51 | # 1.3/1.4 unit tests 52 | run("python -m pytest -q tests/test_attn_math.py") 53 | run("python -m pytest -q tests/test_causal_mask.py") 54 | 55 | # Matrix math walkthrough for MHA 56 | run("python demo_mha_shapes.py") 57 | 58 | if args.visualize: 59 | run("python demo_visualize_multi_head.py") 60 | print(f"\nVisualization images saved to: {OUT}") 61 | 62 | print("\nAll Part 1 demos/tests completed. ✅") 63 | 64 | 65 | if __name__ == "__main__": 66 | main() -------------------------------------------------------------------------------- /part_8/orchestrator.py: -------------------------------------------------------------------------------- 1 | # Repository layout (Part 8 — RLHF with PPO) 2 | # 3 | # part_8/ 4 | # orchestrator.py # run unit tests + optional tiny PPO demo 5 | # policy.py # policy = SFT LM + value head (toy head on logits) 6 | # rollout.py # prompt formatting, sampling, logprobs/KL utilities 7 | # ppo_loss.py # PPO clipped objective + value + entropy + KL penalty 8 | # train_ppo.py # single‑GPU RLHF loop (tiny, on‑policy) 9 | # eval_ppo.py # compare reward vs. 
reference on a small set 10 | # tests/ 11 | # test_ppo_loss.py 12 | # test_policy_forward.py 13 | # 14 | # Run from inside `part_8/`: 15 | # cd part_8 16 | # python orchestrator.py --demo 17 | # pytest -q 18 | 19 | import argparse, pathlib, subprocess, sys 20 | ROOT = pathlib.Path(__file__).resolve().parent 21 | 22 | def run(cmd: str): 23 | print(f"\n>>> {cmd}") 24 | res = subprocess.run(cmd.split(), cwd=ROOT) 25 | if res.returncode != 0: 26 | sys.exit(res.returncode) 27 | 28 | if __name__ == "__main__": 29 | p = argparse.ArgumentParser() 30 | p.add_argument("--demo", action="store_true", help="tiny PPO demo") 31 | args = p.parse_args() 32 | 33 | # 1) unit tests 34 | run("python -m pytest -q tests/test_ppo_loss.py") 35 | run("python -m pytest -q tests/test_policy_forward.py") 36 | 37 | # 2) optional demo (requires SFT+RM checkpoints from Parts 6 & 7) 38 | if args.demo: 39 | # run("python train_ppo.py --policy_ckpt ../part_6/runs/sft-demo/model_last.pt --reward_ckpt ../part_7/runs/rm-demo/model_last.pt --steps 10 --batch_size 4 --resp_len 128 --bpe_dir ../part_4/runs/part4-demo/tokenizer") 40 | # run("python eval_ppo.py --policy_ckpt runs/ppo-demo/model_last.pt --reward_ckpt ../part_7/runs/rm-demo/model_last.pt --split train[:24] --bpe_dir ../part_4/runs/part4-demo/tokenizer") 41 | 42 | # run("python train_ppo.py --policy_ckpt ../part_6/runs/sft-demo/model_last.pt --reward_ckpt ../part_7/runs/rm-demo/model_last.pt --steps 50 --batch_size 4 --resp_len 128 --bpe_dir ../part_4/runs/part4-demo/tokenizer") 43 | # run("python eval_ppo.py --policy_ckpt runs/ppo-demo/model_last.pt --reward_ckpt ../part_7/runs/rm-demo/model_last.pt --split train[:24] --bpe_dir ../part_4/runs/part4-demo/tokenizer") 44 | 45 | run("python train_ppo.py --policy_ckpt ../part_6/runs/sft-demo/model_last.pt --reward_ckpt ../part_7/runs/rm-demo/model_last.pt --steps 100 --batch_size 4 --resp_len 128 --bpe_dir ../part_4/runs/part4-demo/tokenizer") 46 | run("python eval_ppo.py --policy_ckpt runs/ppo-demo/model_last.pt --reward_ckpt ../part_7/runs/rm-demo/model_last.pt --split train[:24] --bpe_dir ../part_4/runs/part4-demo/tokenizer") 47 | 48 | print("\nPart 8 checks complete. ✅") -------------------------------------------------------------------------------- /part_7/collator_rm.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import List, Tuple 3 | import torch 4 | 5 | # Prefer BPE from Part 4, else ByteTokenizer from Part 3 6 | import sys 7 | from pathlib import Path as _P 8 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_4')) 9 | try: 10 | from tokenizer_bpe import BPETokenizer 11 | _HAS_BPE = True 12 | except Exception: 13 | _HAS_BPE = False 14 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_3')) 15 | try: 16 | from tokenizer import ByteTokenizer 17 | except Exception: 18 | ByteTokenizer = None 19 | 20 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_6')) 21 | try: 22 | from formatters import Example, format_example # reuse formatting 23 | except Exception: 24 | pass 25 | 26 | class PairCollator: 27 | """Tokenize preference pairs into (pos, neg) input ids. 28 | We format as the SFT template with the 'chosen' or 'rejected' text as the Response. 
29 | """ 30 | def __init__(self, block_size: int = 256, bpe_dir: str | None = None, vocab_size: int | None = None): 31 | self.block_size = block_size 32 | self.tok = None 33 | if _HAS_BPE: 34 | try: 35 | self.tok = BPETokenizer(vocab_size=vocab_size or 8000) 36 | if bpe_dir: 37 | self.tok.load(bpe_dir) 38 | except Exception: 39 | self.tok = None 40 | if self.tok is None and ByteTokenizer is not None: 41 | self.tok = ByteTokenizer() 42 | if self.tok is None: 43 | raise RuntimeError("No tokenizer available.") 44 | 45 | @property 46 | def vocab_size(self) -> int: 47 | return getattr(self.tok, 'vocab_size', 256) 48 | 49 | def _encode(self, text: str) -> List[int]: 50 | if hasattr(self.tok, 'encode'): 51 | ids = self.tok.encode(text) 52 | if isinstance(ids, torch.Tensor): 53 | ids = ids.tolist() 54 | return ids 55 | return list(text.encode('utf-8')) 56 | 57 | def collate(self, batch: List[Tuple[str, str, str]]): 58 | # batch of (prompt, chosen, rejected) 59 | pos_ids, neg_ids = [], [] 60 | for prompt, chosen, rejected in batch: 61 | pos_text = format_example(Example(prompt, chosen)) 62 | neg_text = format_example(Example(prompt, rejected)) 63 | pos_ids.append(self._encode(pos_text)[:self.block_size]) 64 | neg_ids.append(self._encode(neg_text)[:self.block_size]) 65 | def pad_to(x, pad=2): 66 | return x + [pad] * (self.block_size - len(x)) if len(x) < self.block_size else x[:self.block_size] 67 | pos = torch.tensor([pad_to(x) for x in pos_ids], dtype=torch.long) 68 | neg = torch.tensor([pad_to(x) for x in neg_ids], dtype=torch.long) 69 | return pos, neg -------------------------------------------------------------------------------- /part_7/train_rm.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import argparse, torch 3 | from pathlib import Path 4 | 5 | from data_prefs import load_preferences 6 | from collator_rm import PairCollator 7 | from model_reward import RewardModel 8 | from loss_reward import bradley_terry_loss, margin_ranking_loss 9 | 10 | 11 | def main(): 12 | p = argparse.ArgumentParser() 13 | p.add_argument('--out', type=str, default='runs/rm-demo') 14 | p.add_argument('--steps', type=int, default=500) 15 | p.add_argument('--batch_size', type=int, default=8) 16 | p.add_argument('--block_size', type=int, default=256) 17 | p.add_argument('--n_layer', type=int, default=4) 18 | p.add_argument('--n_head', type=int, default=4) 19 | p.add_argument('--n_embd', type=int, default=256) 20 | p.add_argument('--lr', type=float, default=1e-4) 21 | p.add_argument('--loss', choices=['bt','margin'], default='bt') 22 | p.add_argument('--cpu', action='store_true') 23 | p.add_argument('--bpe_dir', type=str, default=None) 24 | args = p.parse_args() 25 | 26 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu') 27 | 28 | # data 29 | items = load_preferences(split='train[:80]') 30 | triples = [(it.prompt, it.chosen, it.rejected) for it in items] 31 | 32 | # collator + model 33 | col = PairCollator(block_size=args.block_size, bpe_dir=args.bpe_dir) 34 | model = RewardModel(vocab_size=col.vocab_size, block_size=args.block_size, 35 | n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd).to(device) 36 | opt = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.999)) 37 | 38 | # train (tiny) 39 | step = 0; i = 0 40 | while step < args.steps: 41 | batch = triples[i:i+args.batch_size] 42 | if not batch: 43 | i = 0; continue 44 | pos, neg = col.collate(batch) 45 | pos, neg = 
-------------------------------------------------------------------------------- /part_7/train_rm.py: --------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import argparse, torch
3 | from pathlib import Path
4 | 
5 | from data_prefs import load_preferences
6 | from collator_rm import PairCollator
7 | from model_reward import RewardModel
8 | from loss_reward import bradley_terry_loss, margin_ranking_loss
9 | 
10 | 
11 | def main():
12 | p = argparse.ArgumentParser()
13 | p.add_argument('--out', type=str, default='runs/rm-demo')
14 | p.add_argument('--steps', type=int, default=500)
15 | p.add_argument('--batch_size', type=int, default=8)
16 | p.add_argument('--block_size', type=int, default=256)
17 | p.add_argument('--n_layer', type=int, default=4)
18 | p.add_argument('--n_head', type=int, default=4)
19 | p.add_argument('--n_embd', type=int, default=256)
20 | p.add_argument('--lr', type=float, default=1e-4)
21 | p.add_argument('--loss', choices=['bt','margin'], default='bt')
22 | p.add_argument('--cpu', action='store_true')
23 | p.add_argument('--bpe_dir', type=str, default=None)
24 | args = p.parse_args()
25 | 
26 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')
27 | 
28 | # data
29 | items = load_preferences(split='train[:80]')
30 | triples = [(it.prompt, it.chosen, it.rejected) for it in items]
31 | 
32 | # collator + model
33 | col = PairCollator(block_size=args.block_size, bpe_dir=args.bpe_dir)
34 | model = RewardModel(vocab_size=col.vocab_size, block_size=args.block_size,
35 | n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd).to(device)
36 | opt = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
37 | 
38 | # train (tiny)
39 | step = 0; i = 0
40 | while step < args.steps:
41 | batch = triples[i:i+args.batch_size]
42 | if not batch:
43 | i = 0; continue
44 | pos, neg = col.collate(batch)
45 | pos, neg = pos.to(device), neg.to(device)
46 | r_pos = model(pos)
47 | r_neg = model(neg)
48 | if args.loss == 'bt':
49 | loss = bradley_terry_loss(r_pos, r_neg)
50 | else:
51 | loss = margin_ranking_loss(r_pos, r_neg, margin=1.0)
52 | opt.zero_grad(set_to_none=True)
53 | loss.backward()
54 | opt.step()
55 | step += 1; i += args.batch_size
56 | if step % 25 == 0:
57 | acc = (r_pos > r_neg).float().mean().item()
58 | print(f"step {step}: loss={loss.item():.4f} acc={acc:.2f}")
59 | 
60 | Path(args.out).mkdir(parents=True, exist_ok=True)
61 | torch.save({'model': model.state_dict(), 'config': {
62 | 'vocab_size': col.vocab_size,
63 | 'block_size': args.block_size,
64 | 'n_layer': args.n_layer,
65 | 'n_head': args.n_head,
66 | 'n_embd': args.n_embd,
67 | }}, str(Path(args.out)/'model_last.pt'))
68 | print(f"Saved reward model to {args.out}/model_last.pt")
69 | 
70 | if __name__ == '__main__':
71 | main()
-------------------------------------------------------------------------------- /part_4/tokenizer_bpe.py: --------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import os, json
3 | from pathlib import Path
4 | from typing import List, Union
5 | 
6 | try:
7 | from tokenizers import ByteLevelBPETokenizer, Tokenizer
8 | except Exception:
9 | ByteLevelBPETokenizer = None
10 | 
11 | class BPETokenizer:
12 | """Minimal BPE wrapper (HuggingFace tokenizers).
13 | Trains on a text file or a folder of .txt files. Saves merges/vocab to out_dir.
14 | """
15 | def __init__(self, vocab_size: int = 32000, special_tokens: List[str] | None = None):
16 | if ByteLevelBPETokenizer is None:
17 | raise ImportError("Please `pip install tokenizers` for BPETokenizer.")
18 | self.vocab_size = vocab_size
19 | self.special_tokens = special_tokens or ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]  # byte-level BPE specials; id 2 ("</s>") doubles as the pad id downstream
20 | self._tok = None
21 | 
22 | def train(self, data_path: Union[str, Path]):
23 | files: List[str] = []
24 | p = Path(data_path)
25 | if p.is_dir():
26 | files = [str(fp) for fp in p.glob("**/*.txt")]
27 | else:
28 | files = [str(p)]
29 | tok = ByteLevelBPETokenizer()
30 | tok.train(files=files, vocab_size=self.vocab_size, min_frequency=2, special_tokens=self.special_tokens)
31 | self._tok = tok
32 | 
33 | def save(self, out_dir: Union[str, Path]):
34 | out = Path(out_dir); out.mkdir(parents=True, exist_ok=True)
35 | assert self._tok is not None, "Train or load before save()."
36 | self._tok.save_model(str(out))
37 | self._tok.save(str(out / "tokenizer.json"))
38 | meta = {"vocab_size": self.vocab_size, "special_tokens": self.special_tokens}
39 | (out/"bpe_meta.json").write_text(json.dumps(meta))
40 | 
41 | def load(self, dir_path: Union[str, Path]):
42 | dirp = Path(dir_path)
43 | # Prefer the single-file tokenizer.json; fall back to vocab/merges if needed. 
44 | vocab = dirp / "vocab.json"
45 | merges = dirp / "merges.txt"
46 | tokenizer = dirp / "tokenizer.json"
47 | if tokenizer.exists():
48 | # Preferred: the single-file tokenizer written by save()
49 | tok = Tokenizer.from_file(str(tokenizer))
50 | else:
51 | # Fallback: rebuild from vocab/merges (allowing custom basenames)
52 | if not vocab.exists() or not merges.exists():
53 | vs = list(dirp.glob("*.json"))
54 | ms = list(dirp.glob("*.txt"))
55 | if not vs or not ms:
56 | raise FileNotFoundError(f"Could not find vocab.json/merges.txt in {dirp}")
57 | vocab, merges = vs[0], ms[0]
58 | tok = ByteLevelBPETokenizer(str(vocab), str(merges))
59 | self._tok = tok
60 | meta_file = dirp / "bpe_meta.json"
61 | if meta_file.exists():
62 | meta = json.loads(meta_file.read_text())
63 | self.vocab_size = meta.get("vocab_size", self.vocab_size)
64 | self.special_tokens = meta.get("special_tokens", self.special_tokens)
65 | 
66 | def encode(self, text: str):
67 | ids = self._tok.encode(text).ids
68 | return ids
69 | 
70 | def decode(self, ids):
71 | return self._tok.decode(ids)
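72 | 
73 | if __name__ == "__main__":
74 |     # Round-trip sketch (assumes `tokenizers` is installed): train a throwaway
75 |     # 300-token BPE on repeated text, then encode/decode a sample string.
76 |     import tempfile
77 |     with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
78 |         f.write("hello world, hello tokenizer\n" * 50)
79 |     bpe = BPETokenizer(vocab_size=300)
80 |     bpe.train(f.name)
81 |     ids = bpe.encode("hello tokenizer")
82 |     print(ids, "->", bpe.decode(ids))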
-------------------------------------------------------------------------------- /part_6/orchestrator.py: --------------------------------------------------------------------------------
1 | # Repository layout (Part 6)
2 | #
3 | # part_6/
4 | # orchestrator.py # run unit tests + optional tiny SFT demo
5 | # formatters.py # 6.1 prompt/response templates
6 | # dataset_sft.py # HF dataset loader (+tiny fallback) → (prompt, response)
7 | # collator_sft.py # 6.2 causal LM labels with masking
8 | # curriculum.py # 6.3 length‑based curriculum sampler
9 | # evaluate.py # 6.4 simple exact/F1 metrics
10 | # train_sft.py # minimal one‑GPU SFT loop (few steps)
11 | # sample_sft.py # load ckpt & generate from instructions
12 | # tests/
13 | # test_formatter.py
14 | # test_masking.py
15 | #
16 | # Run from inside `part_6/`:
17 | # cd part_6
18 | # python orchestrator.py --demo
19 | # pytest -q
20 | 
21 | import argparse, pathlib, subprocess, sys, shlex
22 | ROOT = pathlib.Path(__file__).resolve().parent
23 | 
24 | def run(cmd: str):
25 | print(f"\n>>> {cmd}")
26 | res = subprocess.run(shlex.split(cmd), cwd=ROOT)
27 | if res.returncode != 0:
28 | sys.exit(res.returncode)
29 | 
30 | if __name__ == "__main__":
31 | p = argparse.ArgumentParser()
32 | p.add_argument("--demo", action="store_true", help="tiny SFT demo on a few samples")
33 | args = p.parse_args()
34 | 
35 | # 1) unit tests
36 | run("python -m pytest -q tests/test_formatter.py")
37 | run("python -m pytest -q tests/test_masking.py")
38 | 
39 | # 2) optional demo
40 | if args.demo:
41 | # --ckpt ../part_4/runs/part4-demo/model_last.pt # assumes Part 4 demo has been run
42 | run("python train_sft.py --data huggingface --ckpt ../part_4/runs/part4-demo/model_last.pt --out runs/sft-demo --steps 300 --batch_size 8 --block_size 256 --n_layer 2 --n_head 2 --n_embd 128")
43 | run("python sample_sft.py --ckpt runs/sft-demo/model_last.pt --block_size 256 --n_layer 2 --n_head 2 --n_embd 128 --prompt 'What are the three primary colors?' --tokens 30 --temperature 0.2")
44 | run("python sample_sft.py --ckpt runs/sft-demo/model_last.pt --block_size 256 --n_layer 2 --n_head 2 --n_embd 128 --prompt 'What does DNA stand for?' --tokens 30 --temperature 0.2")
45 | run("python sample_sft.py --ckpt runs/sft-demo/model_last.pt --block_size 256 --n_layer 2 --n_head 2 --n_embd 128 --prompt 'Reverse engineer this code to create a new version\ndef factorialize(num):\n factorial = 1\n for i in range(1, num):\n factorial *= i\n \n return factorial' --tokens 64 --temperature 0.2")
46 | 
47 | print("\nPart 6 checks complete. ✅")
-------------------------------------------------------------------------------- /part_8/eval_ppo.py: --------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import argparse, torch
3 | from pathlib import Path
4 | 
5 | from policy import PolicyWithValue
6 | from rollout import RLHFTokenizer, sample_prompts, format_prompt_only
7 | 
8 | # Reward model
9 | import sys
10 | from pathlib import Path as _P
11 | sys.path.extend([str(_P(__file__).resolve().parents[1]/'part_7'), str(_P(__file__).resolve().parents[1])])  # part_7 for the RM; repo root so `part_6.formatters` resolves below
12 | from model_reward import RewardModel # noqa: E402
13 | 
14 | 
15 | def score_policy(policy_ckpt: str, rm_ckpt: str, bpe_dir: str | None, n: int = 16):
16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17 | tok = RLHFTokenizer(block_size=256, bpe_dir=bpe_dir)
18 | 
19 | ckpt = torch.load(policy_ckpt, map_location=device)
20 | cfg = ckpt.get('config', {})
21 | pol = PolicyWithValue(cfg.get('vocab_size', tok.vocab_size), cfg.get('block_size', tok.block_size),
22 | cfg.get('n_layer', 2), cfg.get('n_head', 2), cfg.get('n_embd', 128)).to(device)
23 | pol.load_state_dict(ckpt['model'])
24 | pol.eval()
25 | 
26 | # For comparing against reference policy (SFT)
27 | ref = PolicyWithValue(cfg.get('vocab_size', tok.vocab_size), cfg.get('block_size', tok.block_size),
28 | cfg.get('n_layer', 2), cfg.get('n_head', 2), cfg.get('n_embd', 128)).to(device)
29 | ckpt_ref = torch.load("../part_6/runs/sft-demo/model_last.pt", map_location=device) # hardcoded path to SFT checkpoint
30 | ref.lm.load_state_dict(ckpt_ref['model'])
31 | for p_ in ref.parameters():
32 | p_.requires_grad_(False)
33 | ref.eval()
34 | 
35 | rckpt = torch.load(rm_ckpt, map_location=device)
36 | rm = RewardModel(vocab_size=rckpt['config'].get('vocab_size', tok.vocab_size), block_size=rckpt['config'].get('block_size', tok.block_size),
37 | n_layer=rckpt['config'].get('n_layer', 4), n_head=rckpt['config'].get('n_head', 4), n_embd=rckpt['config'].get('n_embd', 256)).to(device)
38 | rm.load_state_dict(rckpt['model'])
39 | rm.eval()
40 | 
41 | prompts = sample_prompts(n)
42 | rewards = []
43 | for p in prompts:
44 | prefix = format_prompt_only(p).replace('', '')
45 | ids = tok.encode(prefix)
46 | x = torch.tensor([ids[-tok.block_size:]], dtype=torch.long, device=device)
47 | with torch.no_grad():
48 | y = pol.generate(x, max_new_tokens=128, temperature=0.2, top_k=50)
49 | y_old = ref.generate(x, max_new_tokens=128, temperature=0.2, top_k=50)
50 | resp = tok.decode(y[0].tolist()[len(ids[-tok.block_size:]):])
51 | resp_old = tok.decode(y_old[0].tolist()[len(ids[-tok.block_size:]):])  # reference continuation, kept for side-by-side inspection
52 | 
53 | # compute RM reward on formatted full text
54 | from part_6.formatters import Example, format_example
55 | text = format_example(Example(p, resp))
56 | z = torch.tensor([tok.encode(text)[:tok.block_size]], 
dtype=torch.long, device=device)
57 | with torch.no_grad():
58 | r = rm(z)[0].item()
59 | rewards.append(r)
60 | return sum(rewards)/max(1,len(rewards))
61 | 
62 | 
63 | if __name__ == '__main__':
64 | p = argparse.ArgumentParser()
65 | p.add_argument('--policy_ckpt', type=str, required=True)
66 | p.add_argument('--reward_ckpt', type=str, required=True)
67 | p.add_argument('--split', type=str, default='val[:32]') # unused in this tiny script
68 | p.add_argument('--bpe_dir', type=str, default=None)
69 | args = p.parse_args()
70 | 
71 | avg_r = score_policy(args.policy_ckpt, args.reward_ckpt, args.bpe_dir, n=16)
72 | print(f"Avg RM reward: {avg_r:.4f}")
-------------------------------------------------------------------------------- /part_9/eval_ppo.py: --------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import argparse, torch
3 | from pathlib import Path
4 | 
5 | from policy import PolicyWithValue
6 | from rollout import RLHFTokenizer, sample_prompts, format_prompt_only
7 | 
8 | # Reward model
9 | import sys
10 | from pathlib import Path as _P
11 | sys.path.extend([str(_P(__file__).resolve().parents[1]/'part_7'), str(_P(__file__).resolve().parents[1])])  # part_7 for the RM; repo root so `part_6.formatters` resolves below
12 | from model_reward import RewardModel # noqa: E402
13 | 
14 | 
15 | def score_policy(policy_ckpt: str, rm_ckpt: str, bpe_dir: str | None, n: int = 16):
16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17 | tok = RLHFTokenizer(block_size=256, bpe_dir=bpe_dir)
18 | 
19 | ckpt = torch.load(policy_ckpt, map_location=device)
20 | cfg = ckpt.get('config', {})
21 | pol = PolicyWithValue(cfg.get('vocab_size', tok.vocab_size), cfg.get('block_size', tok.block_size),
22 | cfg.get('n_layer', 2), cfg.get('n_head', 2), cfg.get('n_embd', 128)).to(device)
23 | pol.load_state_dict(ckpt['model'])
24 | pol.eval()
25 | 
26 | # For comparing against reference policy (SFT)
27 | ref = PolicyWithValue(cfg.get('vocab_size', tok.vocab_size), cfg.get('block_size', tok.block_size),
28 | cfg.get('n_layer', 2), cfg.get('n_head', 2), cfg.get('n_embd', 128)).to(device)
29 | ckpt_ref = torch.load("../part_6/runs/sft-demo/model_last.pt", map_location=device) # hardcoded path to SFT checkpoint
30 | ref.lm.load_state_dict(ckpt_ref['model'])
31 | for p_ in ref.parameters():
32 | p_.requires_grad_(False)
33 | ref.eval()
34 | 
35 | rckpt = torch.load(rm_ckpt, map_location=device)
36 | rm = RewardModel(vocab_size=rckpt['config'].get('vocab_size', tok.vocab_size), block_size=rckpt['config'].get('block_size', tok.block_size),
37 | n_layer=rckpt['config'].get('n_layer', 4), n_head=rckpt['config'].get('n_head', 4), n_embd=rckpt['config'].get('n_embd', 256)).to(device)
38 | rm.load_state_dict(rckpt['model'])
39 | rm.eval()
40 | 
41 | prompts = sample_prompts(n)
42 | rewards = []
43 | for p in prompts:
44 | prefix = format_prompt_only(p).replace('', '')
45 | ids = tok.encode(prefix)
46 | x = torch.tensor([ids[-tok.block_size:]], dtype=torch.long, device=device)
47 | with torch.no_grad():
48 | y = pol.generate(x, max_new_tokens=128, temperature=0.2, top_k=50)
49 | y_old = ref.generate(x, max_new_tokens=128, temperature=0.2, top_k=50)
50 | resp = tok.decode(y[0].tolist()[len(ids[-tok.block_size:]):])
51 | resp_old = tok.decode(y_old[0].tolist()[len(ids[-tok.block_size:]):])  # reference continuation, kept for side-by-side inspection
52 | 
53 | # compute RM reward on formatted full text
54 | from part_6.formatters import Example, format_example
55 | text = format_example(Example(p, resp))
56 | z = torch.tensor([tok.encode(text)[:tok.block_size]], dtype=torch.long, device=device)
57 | with 
torch.no_grad(): 58 | r = rm(z)[0].item() 59 | rewards.append(r) 60 | return sum(rewards)/max(1,len(rewards)) 61 | 62 | 63 | if __name__ == '__main__': 64 | p = argparse.ArgumentParser() 65 | p.add_argument('--policy_ckpt', type=str, required=True) 66 | p.add_argument('--reward_ckpt', type=str, required=True) 67 | p.add_argument('--split', type=str, default='val[:32]') # unused in this tiny script 68 | p.add_argument('--bpe_dir', type=str, default=None) 69 | args = p.parse_args() 70 | 71 | avg_r = score_policy(args.policy_ckpt, args.reward_ckpt, args.bpe_dir, n=16) 72 | print(f"Avg RM reward: {avg_r:.4f}") -------------------------------------------------------------------------------- /part_4/sample.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import argparse, torch 3 | from pathlib import Path 4 | 5 | # load Part 3 model 6 | import sys 7 | from pathlib import Path as _P 8 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_3')) 9 | from model_modern import GPTModern # noqa: E402 10 | 11 | from tokenizer_bpe import BPETokenizer 12 | 13 | 14 | def main(): 15 | p = argparse.ArgumentParser() 16 | p.add_argument('--ckpt', type=str, required=True) 17 | p.add_argument('--prompt', type=str, default='') 18 | p.add_argument('--tokens', type=int, default=100) 19 | p.add_argument('--cpu', action='store_true') 20 | args = p.parse_args() 21 | 22 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu') 23 | 24 | ckpt = torch.load(args.ckpt, map_location='cpu') # load on CPU first; move model later 25 | sd = ckpt['model'] 26 | cfg = ckpt.get('config') or {} 27 | 28 | # tokenizer (if present) 29 | tok = None 30 | tok_dir_file = Path(args.ckpt).with_name('tokenizer_dir.txt') 31 | if tok_dir_file.exists(): 32 | tok_dir = tok_dir_file.read_text().strip() # file contains the dir path 33 | tok = BPETokenizer() 34 | tok.load(tok_dir) # <-- instance method, pass the directory 35 | vocab_from_tok = tok.vocab_size 36 | else: 37 | vocab_from_tok = None 38 | 39 | 40 | # ---- build config (prefer saved config; otherwise infer) ---- 41 | if not cfg: 42 | # If a tokenizer is present and vocab differs, override with tokenizer vocab 43 | # if vocab_from_tok is not None and cfg.get('vocab_size') != vocab_from_tok: 44 | # cfg = {**cfg, 'vocab_size': vocab_from_tok} 45 | # else: 46 | # Old checkpoints without config: infer essentials from weights 47 | # tok_emb.weight: [V, C] where C == n_embd 48 | V, C = sd['tok_emb.weight'].shape 49 | # pos_emb.weight: [block_size, C] if present 50 | block_size = sd['pos_emb.weight'].shape[0] if 'pos_emb.weight' in sd else 256 51 | # count transformer blocks present 52 | import re 53 | layer_ids = {int(m.group(1)) for k in sd.keys() if (m := re.match(r"blocks\.(\d+)\.", k))} 54 | n_layer = max(layer_ids) + 1 if layer_ids else 1 55 | # pick an n_head that divides C (head count doesn't affect weight shapes) 56 | n_head = 8 if C % 8 == 0 else 4 if C % 4 == 0 else 2 if C % 2 == 0 else 1 57 | cfg = dict( 58 | vocab_size=vocab_from_tok or V, 59 | block_size=block_size, 60 | n_layer=n_layer, 61 | n_head=n_head, 62 | n_embd=C, 63 | dropout=0.0, 64 | use_rmsnorm=True, 65 | use_swiglu=True, 66 | rope=True, 67 | max_pos=4096, 68 | sliding_window=None, 69 | attention_sink=0, 70 | ) 71 | 72 | # ---- build & load model ---- 73 | model = GPTModern(**cfg).to(device).eval() 74 | model.load_state_dict(ckpt['model']) 75 | model.to(device).eval() 76 | 77 | # prompt ids 78 | if tok: 79 
| ids = tok.encode(args.prompt) 80 | if len(ids) == 0: ids = [10] 81 | else: 82 | ids = [10] if args.prompt == '' else list(args.prompt.encode('utf-8')) 83 | idx = torch.tensor([ids], dtype=torch.long, device=device) 84 | 85 | with torch.no_grad(): 86 | out = model.generate(idx, max_new_tokens=args.tokens) 87 | out_ids = out[0].tolist() 88 | if tok: 89 | print(tok.decode(out_ids)) 90 | else: 91 | print(bytes(out_ids).decode('utf-8', errors='ignore')) 92 | 93 | if __name__ == '__main__': 94 | main() -------------------------------------------------------------------------------- /part_6/collator_sft.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import List, Tuple 3 | import torch 4 | import traceback 5 | 6 | # Reuse tokenizers: prefer BPE from Part 4 if available; else byte-level from Part 3 7 | import sys 8 | from pathlib import Path as _P 9 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_4')) 10 | try: 11 | from tokenizer_bpe import BPETokenizer 12 | _HAS_BPE = True 13 | except Exception: 14 | _HAS_BPE = False 15 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_3')) 16 | try: 17 | from tokenizer import ByteTokenizer 18 | except Exception: 19 | ByteTokenizer = None 20 | 21 | from formatters import Example, format_example, format_prompt_only 22 | 23 | class SFTCollator: 24 | """Turn (instruction,response) into token ids and masked labels for causal LM (6.2). 25 | Labels for the prompt part are set to -100 so they don't contribute to loss. 26 | """ 27 | def __init__(self, block_size: int = 256, bpe_dir: str | None = None): 28 | self.block_size = block_size 29 | self.tok = None 30 | if _HAS_BPE: 31 | # If a trained tokenizer directory exists from Part 4, you can `load` it. 32 | # Otherwise we create an ad-hoc BPE on the fly using fallback prompts during demo. 33 | try: 34 | self.tok = BPETokenizer(vocab_size=8000) 35 | if bpe_dir: 36 | self.tok.load(bpe_dir) 37 | print(f"Loaded BPE tokenizer from {bpe_dir}") 38 | else: 39 | # weak ad-hoc training would belong elsewhere; for the demo we assume Part 4 tokenizer exists 40 | pass 41 | except Exception: 42 | print(traceback.format_exc()) 43 | self.tok = None 44 | if self.tok is None and ByteTokenizer is not None: 45 | self.tok = ByteTokenizer() 46 | if self.tok is None: 47 | raise RuntimeError("No tokenizer available. Install tokenizers or ensure Part 3 ByteTokenizer exists.") 48 | 49 | @property 50 | def vocab_size(self) -> int: 51 | return getattr(self.tok, 'vocab_size', 256) 52 | 53 | def encode(self, text: str) -> List[int]: 54 | if hasattr(self.tok, 'encode'): 55 | ids = self.tok.encode(text) 56 | if isinstance(ids, torch.Tensor): 57 | ids = ids.tolist() 58 | return ids 59 | # ByteTokenizer-like 60 | return list(text.encode('utf-8')) 61 | 62 | def collate(self, batch: List[Tuple[str,str]]): 63 | # Build "prompt + response" and create label mask where prompt positions are -100. 
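# Illustrative sketch of the masking below (tokens are made up): with
# ids = [P0, P1, P2, R0, R1] and n_prompt = 3, the shift y[t] = ids[t+1]
# gives [P1, P2, R0, R1, -100]; masking the first n_prompt-1 positions
# gives [-100, -100, R0, R1, -100]. The last prompt position is kept
# because it predicts the first response token, and -100 is the default
# ignore_index of F.cross_entropy, so masked positions add no loss.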
64 | input_ids = [] 65 | labels = [] 66 | for prompt, response in batch: 67 | prefix_text = format_prompt_only(prompt).replace('','') 68 | text = format_example(Example(prompt, response)) 69 | ids = self.encode(text)[:self.block_size] 70 | prompt_ids = self.encode(prefix_text)[:self.block_size] 71 | n_prompt = min(len(prompt_ids), len(ids)) 72 | x = ids 73 | y = ids.copy() 74 | for t in range(len(y) - 1): 75 | y[t] = ids[t + 1] 76 | y[-1] = -100 77 | for i in range(n_prompt-1): 78 | y[i] = -100 79 | input_ids.append(x) 80 | labels.append(y) 81 | # pad to block_size 82 | def pad_to(ids, val): 83 | if len(ids) < self.block_size: 84 | ids = ids + [val]*(self.block_size - len(ids)) 85 | return ids[:self.block_size] 86 | x = torch.tensor([pad_to(s, 2) for s in input_ids], dtype=torch.long) 87 | y = torch.tensor([pad_to(s, -100) for s in labels], dtype=torch.long) 88 | return x, y -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LLM from Scratch — Hands-On Curriculum (PyTorch) 2 | 3 | [![Watch the video](https://img.youtube.com/vi/p3sij8QzONQ/0.jpg)](https://youtu.be/p3sij8QzONQ?si=yEuD584cBZRNiUYm) 4 | 5 | ## Part 0 — Foundations & Mindset 6 | - **0.1** Understanding the high-level LLM training pipeline (pretraining → finetuning → alignment) 7 | - **0.2** Hardware & software environment setup (PyTorch, CUDA/Mac, mixed precision, profiling tools) 8 | 9 | ``` 10 | conda create -n llm_from_scratch python=3.11 11 | conda activate llm_from_scratch 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | ## Part 1 — Core Transformer Architecture 16 | - **1.1** Positional embeddings (absolute learned vs. sinusoidal) 17 | - **1.2** Self-attention from first principles (manual computation with a tiny example) 18 | - **1.3** Building a *single attention head* in PyTorch 19 | - **1.4** Multi-head attention (splitting, concatenation, projections) 20 | - **1.5** Feed-forward networks (MLP layers) — GELU, dimensionality expansion 21 | - **1.5** Residual connections & **LayerNorm** 22 | - **1.6** Stacking into a full Transformer block 23 | 24 | ## Part 2 — Training a Tiny LLM 25 | - **2.1** Byte-level tokenization 26 | - **2.2** Dataset batching & shifting for next-token prediction 27 | - **2.3** Cross-entropy loss & label shifting 28 | - **2.4** Training loop from scratch (no Trainer API) 29 | - **2.5** Sampling: temperature, top-k, top-p 30 | - **2.6** Evaluating loss on val set 31 | 32 | ## Part 3 — Modernizing the Architecture 33 | - **3.1** **RMSNorm** (replace LayerNorm, compare gradients & convergence) 34 | - **3.2** **RoPE** (Rotary Positional Embeddings) — theory & code 35 | - **3.3** SwiGLU activations in MLP 36 | - **3.4** KV cache for faster inference 37 | - **3.5** Sliding-window attention & **attention sink** 38 | - **3.6** Rolling buffer KV cache for streaming 39 | 40 | ## Part 4 — Scaling Up 41 | - **4.1** Switching from byte-level to BPE tokenization 42 | - **4.2** Gradient accumulation & mixed precision 43 | - **4.3** Learning rate schedules & warmup 44 | - **4.4** Checkpointing & resuming 45 | - **4.5** Logging & visualization (TensorBoard / wandb) 46 | 47 | ## Part 5 — Mixture-of-Experts (MoE) 48 | - **5.1** MoE theory: expert routing, gating networks, and load balancing 49 | - **5.2** Implementing MoE layers in PyTorch 50 | - **5.3** Combining MoE with dense layers for hybrid architectures 51 | 52 | ## Part 6 — Supervised Fine-Tuning (SFT) 53 | - **6.1** 
Instruction dataset formatting (prompt + response) 54 | - **6.2** Causal LM loss with masked labels 55 | - **6.3** Curriculum learning for instruction data 56 | - **6.4** Evaluating outputs against gold responses 57 | 58 | ## Part 7 — Reward Modeling 59 | - **7.1** Preference datasets (pairwise rankings) 60 | - **7.2** Reward model architecture (transformer encoder) 61 | - **7.3** Loss functions: Bradley–Terry, margin ranking loss 62 | - **7.4** Sanity checks for reward shaping 63 | 64 | ## Part 8 — RLHF with PPO 65 | - **8.1** Policy network: our base LM (from SFT) with a value head for reward prediction. 66 | - **8.2** Reward signal: provided by the reward model trained in Part 7. 67 | - **8.3** PPO objective: balance between maximizing reward and staying close to the SFT policy (KL penalty). 68 | - **8.4** Training loop: sample prompts → generate completions → score with reward model → optimize policy via PPO. 69 | - **8.5** Logging & stability tricks: reward normalization, KL-controlled rollout length, gradient clipping. 70 | 71 | ## Part 9 — RLHF with GRPO 72 | - **9.1** Group-relative baseline: instead of a value head, multiple completions are sampled per prompt and their rewards are normalized against the group mean. 73 | - **9.2** Advantage calculation: each completion’s advantage = (reward – group mean reward), broadcast to all tokens in that trajectory. 74 | - **9.3** Objective: PPO-style clipped policy loss, but *policy-only* (no value loss). 75 | - **9.4** KL regularization: explicit KL(π‖π_ref) penalty term added directly to the loss (not folded into the advantage). 76 | - **9.5** Training loop differences: sample `k` completions per prompt → compute rewards → subtract per-prompt mean → apply GRPO loss with KL penalty. 77 | -------------------------------------------------------------------------------- /part_8/rollout.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch 3 | from typing import List, Tuple 4 | 5 | # tokenizer pref: BPE from Part 4 → fallback to ByteTokenizer from Part 3 6 | import sys 7 | from pathlib import Path as _P 8 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_4')) 9 | try: 10 | from tokenizer_bpe import BPETokenizer 11 | _HAS_BPE = True 12 | except Exception: 13 | _HAS_BPE = False 14 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_3')) 15 | try: 16 | from tokenizer import ByteTokenizer 17 | except Exception: 18 | ByteTokenizer = None 19 | 20 | from part_6.formatters import Example, format_example, format_prompt_only 21 | 22 | # ---------- tokenizer helpers ---------- 23 | class RLHFTokenizer: 24 | def __init__(self, block_size: int, bpe_dir: str | None = None, vocab_size: int = 8000): 25 | self.block_size = block_size 26 | self.tok = None 27 | if _HAS_BPE: 28 | try: 29 | self.tok = BPETokenizer(vocab_size=vocab_size) 30 | if bpe_dir: 31 | self.tok.load(bpe_dir) 32 | except Exception: 33 | self.tok = None 34 | if self.tok is None and ByteTokenizer is not None: 35 | self.tok = ByteTokenizer() 36 | if self.tok is None: 37 | raise RuntimeError("No tokenizer available for RLHF.") 38 | 39 | @property 40 | def vocab_size(self) -> int: 41 | return getattr(self.tok, 'vocab_size', 256) 42 | 43 | def encode(self, text: str) -> List[int]: 44 | ids = self.tok.encode(text) 45 | if isinstance(ids, torch.Tensor): 46 | ids = ids.tolist() 47 | return ids 48 | 49 | def decode(self, ids: List[int]) -> str: 50 | if hasattr(self.tok, 'decode'): 51 | return 
self.tok.decode(ids) 52 | return bytes(ids).decode('utf-8', errors='ignore') 53 | 54 | # ---------- logprob utilities ---------- 55 | 56 | def shift_labels(x: torch.Tensor) -> torch.Tensor: 57 | # For causal LM: predict x[t+1] from x[:t] 58 | return x[:, 1:].contiguous() 59 | 60 | def gather_logprobs(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 61 | """Compute per-token logprobs of the given labels. 62 | logits: (B,T,V), labels: (B,T) over same T 63 | returns: (B,T) log p(labels) 64 | """ 65 | logp = torch.log_softmax(logits, dim=-1) 66 | return logp.gather(-1, labels.unsqueeze(-1)).squeeze(-1) 67 | 68 | @torch.no_grad() 69 | def model_logprobs(model, x: torch.Tensor) -> torch.Tensor: 70 | # compute log p(x[t+1] | x[:t]) for t 71 | logits, _, _ = model.lm(x, None) if hasattr(model, 'lm') else model(x, None) 72 | labels = shift_labels(x) 73 | lp = gather_logprobs(logits[:, :-1, :], labels) 74 | return lp # (B, T-1) 75 | 76 | # ---------- KL ---------- 77 | 78 | def approx_kl(policy_logp: torch.Tensor, ref_logp: torch.Tensor) -> torch.Tensor: 79 | # Mean over tokens: KL(pi||ref) ≈ (logp_pi - logp_ref).mean() 80 | return (policy_logp - ref_logp).mean() 81 | 82 | # ---------- small prompt source ---------- 83 | try: 84 | from datasets import load_dataset as _load_ds 85 | except Exception: 86 | _load_ds = None 87 | 88 | def sample_prompts(n: int) -> List[str]: 89 | if _load_ds is not None: 90 | try: 91 | ds = _load_ds("tatsu-lab/alpaca", split="train[:24]") 92 | arr = [] 93 | for r in ds: 94 | inst = (r.get('instruction') or '').strip() 95 | inp = (r.get('input') or '').strip() 96 | if inp: 97 | inst = inst + "\n" + inp 98 | if inst: 99 | arr.append(inst) 100 | if len(arr) >= n: 101 | break 102 | if arr: 103 | return arr 104 | except Exception: 105 | pass 106 | # fallback 107 | base = [ 108 | "Explain the purpose of attention in transformers.", 109 | "Give two pros and cons of BPE tokenization.", 110 | "Summarize why PPO is used in RLHF.", 111 | "Write a tiny Python function that reverses a list.", 112 | ] 113 | return (base * ((n+len(base)-1)//len(base)))[:n] -------------------------------------------------------------------------------- /part_9/rollout.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch 3 | from typing import List, Tuple 4 | 5 | # tokenizer pref: BPE from Part 4 → fallback to ByteTokenizer from Part 3 6 | import sys 7 | from pathlib import Path as _P 8 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_4')) 9 | try: 10 | from tokenizer_bpe import BPETokenizer 11 | _HAS_BPE = True 12 | except Exception: 13 | _HAS_BPE = False 14 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_3')) 15 | try: 16 | from tokenizer import ByteTokenizer 17 | except Exception: 18 | ByteTokenizer = None 19 | 20 | from part_6.formatters import Example, format_example, format_prompt_only 21 | 22 | # ---------- tokenizer helpers ---------- 23 | class RLHFTokenizer: 24 | def __init__(self, block_size: int, bpe_dir: str | None = None, vocab_size: int = 8000): 25 | self.block_size = block_size 26 | self.tok = None 27 | if _HAS_BPE: 28 | try: 29 | self.tok = BPETokenizer(vocab_size=vocab_size) 30 | if bpe_dir: 31 | self.tok.load(bpe_dir) 32 | except Exception: 33 | self.tok = None 34 | if self.tok is None and ByteTokenizer is not None: 35 | self.tok = ByteTokenizer() 36 | if self.tok is None: 37 | raise RuntimeError("No tokenizer available for RLHF.") 38 | 39 | @property 40 | def 
vocab_size(self) -> int: 41 | return getattr(self.tok, 'vocab_size', 256) 42 | 43 | def encode(self, text: str) -> List[int]: 44 | ids = self.tok.encode(text) 45 | if isinstance(ids, torch.Tensor): 46 | ids = ids.tolist() 47 | return ids 48 | 49 | def decode(self, ids: List[int]) -> str: 50 | if hasattr(self.tok, 'decode'): 51 | return self.tok.decode(ids) 52 | return bytes(ids).decode('utf-8', errors='ignore') 53 | 54 | # ---------- logprob utilities ---------- 55 | 56 | def shift_labels(x: torch.Tensor) -> torch.Tensor: 57 | # For causal LM: predict x[t+1] from x[:t] 58 | return x[:, 1:].contiguous() 59 | 60 | def gather_logprobs(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 61 | """Compute per-token logprobs of the given labels. 62 | logits: (B,T,V), labels: (B,T) over same T 63 | returns: (B,T) log p(labels) 64 | """ 65 | logp = torch.log_softmax(logits, dim=-1) 66 | return logp.gather(-1, labels.unsqueeze(-1)).squeeze(-1) 67 | 68 | @torch.no_grad() 69 | def model_logprobs(model, x: torch.Tensor) -> torch.Tensor: 70 | # compute log p(x[t+1] | x[:t]) for t 71 | logits, _, _ = model.lm(x, None) if hasattr(model, 'lm') else model(x, None) 72 | labels = shift_labels(x) 73 | lp = gather_logprobs(logits[:, :-1, :], labels) 74 | return lp # (B, T-1) 75 | 76 | # ---------- KL ---------- 77 | 78 | def approx_kl(policy_logp: torch.Tensor, ref_logp: torch.Tensor) -> torch.Tensor: 79 | # Mean over tokens: KL(pi||ref) ≈ (logp_pi - logp_ref).mean() 80 | return (policy_logp - ref_logp).mean() 81 | 82 | # ---------- small prompt source ---------- 83 | try: 84 | from datasets import load_dataset as _load_ds 85 | except Exception: 86 | _load_ds = None 87 | 88 | def sample_prompts(n: int) -> List[str]: 89 | if _load_ds is not None: 90 | try: 91 | ds = _load_ds("tatsu-lab/alpaca", split="train[:24]") 92 | arr = [] 93 | for r in ds: 94 | inst = (r.get('instruction') or '').strip() 95 | inp = (r.get('input') or '').strip() 96 | if inp: 97 | inst = inst + "\n" + inp 98 | if inst: 99 | arr.append(inst) 100 | if len(arr) >= n: 101 | break 102 | if arr: 103 | return arr 104 | except Exception: 105 | pass 106 | # fallback 107 | base = [ 108 | "Explain the purpose of attention in transformers.", 109 | "Give two pros and cons of BPE tokenization.", 110 | "Summarize why PPO is used in RLHF.", 111 | "Write a tiny Python function that reverses a list.", 112 | ] 113 | return (base * ((n+len(base)-1)//len(base)))[:n] -------------------------------------------------------------------------------- /part_6/train_sft.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import argparse, torch 3 | import torch.nn as nn 4 | from pathlib import Path 5 | torch.manual_seed(0) 6 | 7 | # Reuse GPTModern from Part 3 8 | import sys 9 | from pathlib import Path as _P 10 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_3')) 11 | from model_modern import GPTModern # noqa: E402 12 | 13 | from dataset_sft import load_tiny_hf 14 | from collator_sft import SFTCollator 15 | from curriculum import LengthCurriculum 16 | 17 | 18 | def main(): 19 | p = argparse.ArgumentParser() 20 | p.add_argument('--data', type=str, default='huggingface', help='huggingface or path to local jsonl (unused in demo)') 21 | p.add_argument('--ckpt', type=str, required=False) 22 | p.add_argument('--out', type=str, default='runs/sft') 23 | p.add_argument('--steps', type=int, default=200) 24 | p.add_argument('--batch_size', type=int, default=8) 25 | 
p.add_argument('--block_size', type=int, default=256) 26 | p.add_argument('--n_layer', type=int, default=4) 27 | p.add_argument('--n_head', type=int, default=4) 28 | p.add_argument('--n_embd', type=int, default=256) 29 | p.add_argument('--lr', type=float, default=3e-4) 30 | p.add_argument('--cpu', action='store_true') 31 | p.add_argument('--bpe_dir', type=str, default='../part_4/runs/part4-demo/tokenizer') # assumes tokenizer exists from Part 4 32 | args = p.parse_args() 33 | 34 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu') 35 | 36 | # Load a tiny HF slice or fallback examples 37 | items = load_tiny_hf(split='train[:24]', sample_dataset=False) 38 | 39 | # Print few samples 40 | print(f"Loaded {len(items)} SFT items. Few samples:") 41 | for it in items[:3]: 42 | print(f"PROMPT: {it.prompt}\nRESPONSE: {it.response}\n{'-'*40}") 43 | 44 | # Curriculum over (prompt,response) 45 | tuples = [(it.prompt, it.response) for it in items] 46 | cur = list(LengthCurriculum(tuples)) 47 | print(cur) 48 | 49 | # Collator + model 50 | col = SFTCollator(block_size=args.block_size, bpe_dir=args.bpe_dir) 51 | model = GPTModern(vocab_size=col.vocab_size, block_size=args.block_size, 52 | n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd, 53 | use_rmsnorm=True, use_swiglu=True, rope=True).to(device) 54 | 55 | if args.ckpt: 56 | print(f"Using model config from checkpoint {args.ckpt}") 57 | ckpt = torch.load(args.ckpt, map_location=device) 58 | cfg = ckpt.get('config', {}) 59 | model.load_state_dict(ckpt['model']) 60 | 61 | opt = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.1) 62 | model.train() 63 | 64 | # Simple loop (single machine). We just cycle curriculum to fill batches, for a few steps. 
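# Why the masking works: the collator sets prompt/pad label positions to -100,
# and GPTModern's loss uses F.cross_entropy with its default ignore_index=-100,
# so only response tokens produce gradients. Each iteration below slices the
# curriculum, collates to (x, y), runs forward/backward, and steps AdamW.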
65 | step = 0 66 | i = 0 67 | while step < args.steps: 68 | batch = cur[i:i+args.batch_size] 69 | if not batch: 70 | # restart curriculum 71 | # cur = list(LengthCurriculum(tuples)); 72 | i = 0 73 | continue 74 | xb, yb = col.collate(batch) 75 | xb, yb = xb.to(device), yb.to(device) 76 | logits, loss, _ = model(xb, yb) 77 | opt.zero_grad(set_to_none=True) 78 | loss.backward() 79 | opt.step() 80 | step += 1; i += args.batch_size 81 | if step % 20 == 0: 82 | print(f"step {step}: loss={loss.item():.4f}") 83 | 84 | Path(args.out).mkdir(parents=True, exist_ok=True) 85 | cfg = { 86 | "vocab_size": col.vocab_size, 87 | "block_size": args.block_size, 88 | "n_layer": args.n_layer, 89 | "n_head": args.n_head, 90 | "n_embd": args.n_embd, 91 | "dropout": 0.0, 92 | "use_rmsnorm": True, 93 | "use_swiglu": True, 94 | "rope": True, 95 | # tokenizer info (best-effort) 96 | "tokenizer_type": "byte" if col.vocab_size == 256 else "bpe", 97 | "tokenizer_dir": None, # set a real path if you have a trained BPE dir 98 | } 99 | torch.save({'model': model.state_dict(), 'config': cfg}, 100 | str(Path(args.out)/'model_last.pt')) 101 | print(f"Saved SFT checkpoint to {args.out}/model_last.pt") 102 | 103 | if __name__ == '__main__': 104 | main() -------------------------------------------------------------------------------- /part_3/attn_modern.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import math, torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from rope_custom import RoPECache, apply_rope_single 6 | from kv_cache import KVCache # your existing class 7 | 8 | class CausalSelfAttentionModern(nn.Module): 9 | def __init__(self, n_embd: int, n_head: int, dropout: float = 0.0, 10 | rope: bool = True, max_pos: int = 4096, 11 | sliding_window: int | None = None, attention_sink: int = 0, 12 | n_kv_head: int | None = None): # ← NEW 13 | super().__init__() 14 | assert n_embd % n_head == 0, "n_embd must be divisible by n_head" 15 | self.n_head = n_head 16 | self.n_kv_head = n_kv_head or n_head # ← NEW (GQA defaults to MHA) 17 | assert self.n_head % self.n_kv_head == 0, "n_head must be multiple of n_kv_head (GQA grouping)" 18 | self.group_size = self.n_head // self.n_kv_head 19 | self.d_head = n_embd // n_head 20 | 21 | # Separate projections for Q vs K/V (sizes differ under GQA) ← CHANGED 22 | self.wq = nn.Linear(n_embd, self.n_head * self.d_head, bias=False) 23 | self.wk = nn.Linear(n_embd, self.n_kv_head * self.d_head, bias=False) 24 | self.wv = nn.Linear(n_embd, self.n_kv_head * self.d_head, bias=False) 25 | self.proj = nn.Linear(n_embd, n_embd, bias=False) 26 | self.dropout = nn.Dropout(dropout) 27 | 28 | self.use_rope = rope 29 | self.rope_cache: RoPECache | None = None 30 | self.max_pos = max_pos 31 | self.sliding_window = sliding_window 32 | self.attention_sink = attention_sink 33 | 34 | def _maybe_init_rope(self, device): 35 | if self.use_rope and self.rope_cache is None: 36 | self.rope_cache = RoPECache(self.d_head, self.max_pos, device=device) 37 | 38 | def forward(self, x: torch.Tensor, kv_cache: KVCache | None = None, start_pos: int = 0): 39 | """x: (B,T,C). 
If kv_cache given, we assume generation (T small, often 1).""" 40 | B, T, C = x.shape 41 | self._maybe_init_rope(x.device) 42 | 43 | # Projections 44 | q = self.wq(x).view(B, T, self.n_head, self.d_head).transpose(1, 2) # (B,H, T,D) 45 | k = self.wk(x).view(B, T, self.n_kv_head, self.d_head).transpose(1, 2) # (B,Hk,T,D) 46 | v = self.wv(x).view(B, T, self.n_kv_head, self.d_head).transpose(1, 2) # (B,Hk,T,D) 47 | 48 | # RoPE on *current* tokens (cached keys are already rotated) 49 | if self.use_rope: 50 | pos = torch.arange(start_pos, start_pos + T, device=x.device) 51 | cos, sin = self.rope_cache.get(pos) 52 | q = apply_rope_single(q, cos, sin) # (B,H, T,D) 53 | k = apply_rope_single(k, cos, sin) # (B,Hk,T,D) 54 | 55 | # Concatenate past cache (cache is stored in Hk heads) 56 | if kv_cache is not None: 57 | k_all = torch.cat([kv_cache.k, k], dim=2) # (B,Hk, Tpast+T, D) 58 | v_all = torch.cat([kv_cache.v, v], dim=2) 59 | else: 60 | k_all, v_all = k, v 61 | 62 | # Sliding-window + attention-sink (crop along seq length) 63 | if self.sliding_window is not None and k_all.size(2) > (self.sliding_window + self.attention_sink): 64 | s = self.attention_sink 65 | k_all = torch.cat([k_all[:, :, :s, :], k_all[:, :, -self.sliding_window:, :]], dim=2) 66 | v_all = torch.cat([v_all[:, :, :s, :], v_all[:, :, -self.sliding_window:, :]], dim=2) 67 | 68 | # --- GQA expand: repeat K/V heads to match Q heads before attention --- 69 | if self.n_kv_head != self.n_head: 70 | k_attn = k_all.repeat_interleave(self.group_size, dim=1) # (B,H,Tk,D) 71 | v_attn = v_all.repeat_interleave(self.group_size, dim=1) # (B,H,Tk,D) 72 | else: 73 | k_attn, v_attn = k_all, v_all 74 | 75 | # Scaled dot-product attention (PyTorch scales internally) 76 | is_causal = kv_cache is None 77 | y = F.scaled_dot_product_attention(q, k_attn, v_attn, 78 | attn_mask=None, 79 | dropout_p=self.dropout.p if self.training else 0.0, 80 | is_causal=is_causal) # (B,H,T,D) 81 | 82 | y = y.transpose(1, 2).contiguous().view(B, T, C) 83 | y = self.proj(y) 84 | 85 | # Update KV cache (store compact Hk heads, not expanded) 86 | if kv_cache is not None: 87 | k_new = torch.cat([kv_cache.k, k], dim=2) # (B,Hk,*,D) 88 | v_new = torch.cat([kv_cache.v, v], dim=2) 89 | else: 90 | k_new, v_new = k, v 91 | new_cache = KVCache(k_new, v_new) 92 | return y, new_cache 93 | -------------------------------------------------------------------------------- /part_2/model_gpt.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | # ---- Blocks (self-contained for isolation) ---- 8 | class CausalSelfAttention(nn.Module): 9 | def __init__(self, n_embd: int, n_head: int, dropout: float = 0.0): 10 | super().__init__() 11 | assert n_embd % n_head == 0 12 | self.n_head = n_head 13 | self.d_head = n_embd // n_head 14 | self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False) 15 | self.proj = nn.Linear(n_embd, n_embd, bias=False) 16 | self.dropout = nn.Dropout(dropout) 17 | 18 | def forward(self, x: torch.Tensor): # (B,T,C) 19 | B, T, C = x.shape 20 | qkv = self.qkv(x).view(B, T, 3, self.n_head, self.d_head) 21 | q, k, v = qkv.unbind(dim=2) 22 | q = q.transpose(1, 2) 23 | k = k.transpose(1, 2) 24 | v = v.transpose(1, 2) 25 | scale = 1.0 / math.sqrt(self.d_head) 26 | # PyTorch SDPA (uses flash when available) 27 | y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout.p if self.training else 
0.0, is_causal=True) 28 | y = y.transpose(1, 2).contiguous().view(B, T, C) 29 | y = self.proj(y) 30 | return y 31 | 32 | class FeedForward(nn.Module): 33 | def __init__(self, n_embd: int, mult: int = 4, dropout: float = 0.0): 34 | super().__init__() 35 | self.net = nn.Sequential( 36 | nn.Linear(n_embd, mult * n_embd), 37 | nn.GELU(), 38 | nn.Linear(mult * n_embd, n_embd), 39 | nn.Dropout(dropout), 40 | ) 41 | 42 | def forward(self, x): 43 | return self.net(x) 44 | 45 | class Block(nn.Module): 46 | def __init__(self, n_embd: int, n_head: int, dropout: float): 47 | super().__init__() 48 | self.ln1 = nn.LayerNorm(n_embd) 49 | self.attn = CausalSelfAttention(n_embd, n_head, dropout) 50 | self.ln2 = nn.LayerNorm(n_embd) 51 | self.ffn = FeedForward(n_embd, mult=4, dropout=dropout) 52 | 53 | def forward(self, x): 54 | x = x + self.attn(self.ln1(x)) 55 | x = x + self.ffn(self.ln2(x)) 56 | return x 57 | 58 | # ---- Tiny GPT ---- 59 | class GPT(nn.Module): 60 | def __init__(self, vocab_size: int, block_size: int, n_layer: int = 4, n_head: int = 4, n_embd: int = 256, dropout: float = 0.0): 61 | super().__init__() 62 | self.block_size = block_size 63 | self.tok_emb = nn.Embedding(vocab_size, n_embd) 64 | self.pos_emb = nn.Embedding(block_size, n_embd) 65 | self.drop = nn.Dropout(dropout) 66 | self.blocks = nn.ModuleList([Block(n_embd, n_head, dropout) for _ in range(n_layer)]) 67 | self.ln_f = nn.LayerNorm(n_embd) 68 | self.head = nn.Linear(n_embd, vocab_size, bias=False) 69 | 70 | self.apply(self._init_weights) 71 | 72 | def _init_weights(self, m): 73 | if isinstance(m, nn.Linear): 74 | nn.init.normal_(m.weight, mean=0.0, std=0.02) 75 | if m.bias is not None: 76 | nn.init.zeros_(m.bias) 77 | elif isinstance(m, nn.Embedding): 78 | nn.init.normal_(m.weight, mean=0.0, std=0.02) 79 | 80 | def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None): 81 | B, T = idx.shape 82 | assert T <= self.block_size 83 | pos = torch.arange(0, T, device=idx.device).unsqueeze(0) 84 | x = self.tok_emb(idx) + self.pos_emb(pos) 85 | x = self.drop(x) 86 | for blk in self.blocks: 87 | x = blk(x) 88 | x = self.ln_f(x) 89 | logits = self.head(x) 90 | loss = None 91 | if targets is not None: 92 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) 93 | return logits, loss 94 | 95 | @torch.no_grad() 96 | def generate(self, idx: torch.Tensor, max_new_tokens: int = 200, temperature: float = 1.0, 97 | top_k: int | None = 50, top_p: float | None = None): 98 | from utils import top_k_top_p_filtering 99 | self.eval() 100 | # Guard: if the prompt is empty, start with a newline byte (10) 101 | if idx.size(1) == 0: 102 | idx = torch.full((idx.size(0), 1), 10, dtype=torch.long, device=idx.device) 103 | for _ in range(max_new_tokens): 104 | idx_cond = idx[:, -self.block_size:] 105 | logits, _ = self(idx_cond) 106 | logits = logits[:, -1, :] / max(temperature, 1e-6) 107 | logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) 108 | probs = torch.softmax(logits, dim=-1) 109 | next_id = torch.multinomial(probs, num_samples=1) 110 | idx = torch.cat([idx, next_id], dim=1) 111 | return idx 112 | -------------------------------------------------------------------------------- /part_2/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import argparse, time 3 | import torch 4 | from tokenizer import ByteTokenizer 5 | from dataset import ByteDataset 6 | from model_gpt import GPT 7 | 8 | 9 | def estimate_loss(model: GPT, 
ds: ByteDataset, args) -> dict: 10 | model.eval() 11 | out = {} 12 | with torch.no_grad(): 13 | for split in ['train', 'val']: 14 | losses = [] 15 | for _ in range(args.eval_iters): 16 | xb, yb = ds.get_batch(split, args.batch_size, args.device) 17 | _, loss = model(xb, yb) 18 | losses.append(loss.item()) 19 | out[split] = sum(losses) / len(losses) 20 | model.train() 21 | return out 22 | 23 | 24 | def main(): 25 | p = argparse.ArgumentParser() 26 | p.add_argument('--data', type=str, required=True) 27 | p.add_argument('--out_dir', type=str, default='runs/min-gpt') 28 | p.add_argument('--block_size', type=int, default=256) 29 | p.add_argument('--batch_size', type=int, default=32) 30 | p.add_argument('--n_layer', type=int, default=4) 31 | p.add_argument('--n_head', type=int, default=4) 32 | p.add_argument('--n_embd', type=int, default=256) 33 | p.add_argument('--dropout', type=float, default=0.0) 34 | p.add_argument('--steps', type=int, default=2000) 35 | p.add_argument('--lr', type=float, default=3e-4) 36 | p.add_argument('--weight_decay', type=float, default=0.1) 37 | p.add_argument('--grad_clip', type=float, default=1.0) 38 | p.add_argument('--eval_interval', type=int, default=200) 39 | p.add_argument('--eval_iters', type=int, default=50) 40 | p.add_argument('--sample_every', type=int, default=200) 41 | p.add_argument('--sample_tokens', type=int, default=256) 42 | p.add_argument('--temperature', type=float, default=1.0) 43 | p.add_argument('--top_k', type=int, default=50) 44 | p.add_argument('--top_p', type=float, default=None) 45 | p.add_argument('--cpu', action='store_true') 46 | p.add_argument('--compile', action='store_true') 47 | p.add_argument('--amp', action='store_true') 48 | args = p.parse_args() 49 | 50 | args.device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu') 51 | 52 | tok = ByteTokenizer() 53 | ds = ByteDataset(args.data, block_size=args.block_size) 54 | model = GPT(tok.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(args.device) 55 | 56 | if args.compile and hasattr(torch, 'compile'): 57 | model = torch.compile(model) 58 | 59 | opt = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weight_decay) 60 | scaler = torch.cuda.amp.GradScaler(enabled=(args.amp and args.device.type == 'cuda')) 61 | 62 | best_val = float('inf') 63 | t0 = time.time() 64 | model.train() 65 | for step in range(1, args.steps + 1): 66 | xb, yb = ds.get_batch('train', args.batch_size, args.device) 67 | with torch.cuda.amp.autocast(enabled=(args.amp and args.device.type == 'cuda')): 68 | _, loss = model(xb, yb) 69 | opt.zero_grad(set_to_none=True) 70 | scaler.scale(loss).backward() 71 | if args.grad_clip > 0: 72 | scaler.unscale_(opt) 73 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) 74 | scaler.step(opt) 75 | scaler.update() 76 | 77 | if step % 50 == 0: 78 | print(f"step {step:5d} | loss {loss.item():.4f} | {(time.time()-t0):.1f}s") 79 | t0 = time.time() 80 | 81 | if step % args.eval_interval == 0: 82 | losses = estimate_loss(model, ds, args) 83 | print(f"eval | train {losses['train']:.4f} | val {losses['val']:.4f}") 84 | if losses['val'] < best_val: 85 | best_val = losses['val'] 86 | ckpt_path = f"{args.out_dir}/model_best.pt" 87 | import os; os.makedirs(args.out_dir, exist_ok=True) 88 | torch.save({'model': model.state_dict(), 'config': { 89 | 'vocab_size': tok.vocab_size, 90 | 'block_size': args.block_size, 91 | 'n_layer': args.n_layer, 92 | 'n_head': args.n_head, 93 | 
'n_embd': args.n_embd, 94 | 'dropout': args.dropout, 95 | }}, ckpt_path) 96 | print(f"saved checkpoint: {ckpt_path}") 97 | 98 | if args.sample_every > 0 and step % args.sample_every == 0: 99 | start = torch.randint(low=0, high=len(ds.train) - args.block_size - 1, size=(1,)).item() 100 | seed = ds.train[start:start + args.block_size].unsqueeze(0).to(args.device) 101 | out = model.generate(seed, max_new_tokens=args.sample_tokens, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p) 102 | txt = tok.decode(out[0].cpu()) 103 | print("\n================ SAMPLE ================\n" + txt[-(args.block_size + args.sample_tokens):] + "\n=======================================\n") 104 | 105 | # final save 106 | import os; os.makedirs(args.out_dir, exist_ok=True) 107 | torch.save({'model': model.state_dict()}, f"{args.out_dir}/model_final.pt") 108 | 109 | 110 | if __name__ == '__main__': 111 | main() -------------------------------------------------------------------------------- /part_3/model_modern.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch 3 | import torch.nn as nn 4 | from block_modern import TransformerBlockModern 5 | from tokenizer import ByteTokenizer 6 | 7 | # Get the absolute path to the folder that contains part_2 and part_3 8 | import os, sys 9 | parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 10 | sys.path.insert(0, parent_dir) 11 | 12 | class GPTModern(nn.Module): 13 | def __init__(self, vocab_size: int = 256, block_size: int = 256, 14 | n_layer: int=4, n_head: int=4, n_embd: int=256, dropout: float=0.0, 15 | use_rmsnorm: bool = True, use_swiglu: bool = True, rope: bool = True, 16 | max_pos: int = 4096, sliding_window: int | None = None, attention_sink: int = 0, n_kv_head: int | None = None): 17 | super().__init__() 18 | self.block_size = block_size 19 | self.tok_emb = nn.Embedding(vocab_size, n_embd) 20 | # self.pos_emb = nn.Embedding(block_size, n_embd) 21 | self.drop = nn.Dropout(dropout) 22 | self.blocks = nn.ModuleList([ 23 | TransformerBlockModern(n_embd, n_head, dropout, use_rmsnorm, use_swiglu, rope, max_pos, sliding_window, attention_sink, n_kv_head) 24 | for _ in range(n_layer) 25 | ]) 26 | self.ln_f = nn.Identity() if use_rmsnorm else nn.LayerNorm(n_embd) 27 | self.head = nn.Linear(n_embd, vocab_size, bias=False) 28 | 29 | def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None, kv_cache_list=None, start_pos: int = 0): 30 | B, T = idx.shape 31 | assert T <= self.block_size 32 | pos = torch.arange(0, T, device=idx.device).unsqueeze(0) 33 | x = self.tok_emb(idx) 34 | # + self.pos_emb(pos) 35 | x = self.drop(x) 36 | 37 | new_caches = [] 38 | for i, blk in enumerate(self.blocks): 39 | cache = None if kv_cache_list is None else kv_cache_list[i] 40 | x, cache = blk(x, kv_cache=cache, start_pos=start_pos) 41 | new_caches.append(cache) 42 | x = self.ln_f(x) 43 | logits = self.head(x) 44 | 45 | loss = None 46 | if targets is not None: 47 | import torch.nn.functional as F 48 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) 49 | return logits, loss, new_caches 50 | 51 | @torch.no_grad() 52 | def generate(self, 53 | prompt: torch.Tensor, 54 | max_new_tokens=200, 55 | temperature=1.0, 56 | top_k=50, 57 | top_p=None, 58 | eos_id=1, # addition from part 6 for early stopping 59 | sliding_window: int | None = None, 60 | attention_sink: int = 0): 61 | try: 62 | from utils import top_k_top_p_filtering as _tk 63 | 
except Exception: 64 | _tk = lambda x, **_: x 65 | 66 | self.eval() 67 | idx = prompt 68 | kvs = [None] * len(self.blocks) 69 | 70 | for _ in range(max_new_tokens): 71 | # feed full prompt once; then only the last token 72 | idx_cond = idx[:, -self.block_size:] if kvs[0] is None else idx[:, -1:] 73 | 74 | # absolute start position from cache length (0 on first step) 75 | start_pos = 0 if kvs[0] is None else kvs[0].k.size(2) 76 | 77 | logits, _, kvs = self(idx_cond, kv_cache_list=kvs, start_pos=start_pos) 78 | 79 | next_logits = logits[:, -1, :] / max(temperature, 1e-6) 80 | next_logits = _tk(next_logits, top_k=top_k, top_p=top_p) 81 | probs = torch.softmax(next_logits, dim=-1) 82 | next_id = torch.argmax(probs, dim=-1, keepdim=True) if temperature == 0.0 else torch.multinomial(probs, 1) 83 | idx = torch.cat([idx, next_id], dim=1) 84 | 85 | # addition from part 6 for early stopping 86 | if eos_id is not None: 87 | if (next_id == eos_id).all(): 88 | break 89 | 90 | return idx 91 | 92 | 93 | @torch.no_grad() 94 | def generate_nocache(self, prompt: torch.Tensor, max_new_tokens=200, temperature=1.0, top_k=50, top_p=None, 95 | sliding_window: int | None = None, attention_sink: int = 0): 96 | try: 97 | from utils import top_k_top_p_filtering as _tk 98 | except Exception: 99 | _tk = lambda x, **_: x 100 | 101 | self.eval() 102 | idx = prompt 103 | 104 | for _ in range(max_new_tokens): 105 | # always run a full forward over the cropped window, with NO cache 106 | idx_cond = idx[:, -self.block_size:] 107 | # absolute position of first token in the window (matches cached path) 108 | start_pos = idx.size(1) - idx_cond.size(1) 109 | 110 | logits, _, _ = self(idx_cond, kv_cache_list=None, start_pos=start_pos) 111 | 112 | next_logits = logits[:, -1, :] / max(temperature, 1e-6) 113 | next_logits = _tk(next_logits, top_k=top_k, top_p=top_p) 114 | probs = torch.softmax(next_logits, dim=-1) 115 | topv, topi = torch.topk(probs, 10) 116 | print("top ids:", topi.tolist()) 117 | print("top vs:", topv.tolist()) 118 | next_id = torch.argmax(probs, dim=-1, keepdim=True) if temperature == 0.0 else torch.multinomial(probs, 1) 119 | idx = torch.cat([idx, next_id], dim=1) 120 | 121 | return idx 122 | 123 | -------------------------------------------------------------------------------- /part_4/logger.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import time 3 | from pathlib import Path 4 | 5 | class NoopLogger: 6 | def log(self, **kwargs): 7 | pass 8 | def close(self): 9 | pass 10 | 11 | class TBLogger(NoopLogger): 12 | """ 13 | Backward compatible: 14 | - logger.log(step=..., loss=..., lr=...) 15 | Extras you can optionally use: 16 | - logger.hist("params/wte.weight", tensor, step) 17 | - logger.text("samples/generation", text, step) 18 | - logger.image("attn/heatmap", HWC_or_CHW_tensor_or_np, step) 19 | - logger.graph(model, example_batch) 20 | - logger.hparams(dict_of_config, dict_of_metrics_once) 21 | - logger.flush() 22 | Auto-behavior: 23 | - If a value in .log(...) is a tensor/ndarray with >1 element, it logs a histogram. 24 | - If key starts with "text/", logs as text. 
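Minimal usage sketch (paths and values are illustrative):
    logger = TBLogger("runs/part4")
    logger.log(step=10, loss=2.31, lr=3e-4)              # scalars
    logger.text("samples/generation", "hello", step=10)  # free text
    logger.flush(); logger.close()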
25 | """ 26 | # logger.py 27 | def __init__(self, out_dir: str, flush_secs: int = 10, run_name: str | None = None): 28 | self.w = None 29 | self.hparams_logged = False 30 | run_name = run_name or time.strftime("%Y%m%d-%H%M%S") 31 | run_dir = Path(out_dir) / run_name 32 | run_dir.mkdir(parents=True, exist_ok=True) 33 | try: 34 | from torch.utils.tensorboard import SummaryWriter 35 | self.w = SummaryWriter(log_dir=str(run_dir), flush_secs=flush_secs) 36 | except Exception as e: 37 | print(f"[TBLogger] TensorBoard not available: {e}. Logging disabled.") 38 | self._auto_hist_max_elems = 2048 39 | self.run_dir = str(run_dir) # handy for prints/debug 40 | 41 | 42 | 43 | # ---------- backwards-compatible ---------- 44 | def log(self, step: Optional[int] = None, **kv: Any): 45 | if not self.w: return 46 | for k, v in kv.items(): 47 | # text channel (opt-in via key prefix) 48 | if isinstance(k, str) and k.startswith("text/"): 49 | try: 50 | self.w.add_text(k[5:], str(v), global_step=step) 51 | except Exception: 52 | pass 53 | continue 54 | 55 | # scalar vs histogram auto-route 56 | try: 57 | import torch, numpy as np # lazy 58 | is_torch = isinstance(v, torch.Tensor) 59 | is_np = isinstance(v, np.ndarray) 60 | if is_torch or is_np: 61 | # scalar? 62 | numel = int(v.numel() if is_torch else v.size) 63 | if numel == 1: 64 | val = (v.item() if is_torch else float(v)) 65 | self.w.add_scalar(k, float(val), global_step=step) 66 | else: 67 | # small-ish tensors => histogram 68 | if numel <= self._auto_hist_max_elems: 69 | self.w.add_histogram(k, v.detach().cpu() if is_torch else v, global_step=step) 70 | else: 71 | # fall back to scalar summary stats 72 | arr = v.detach().cpu().flatten().numpy() if is_torch else v.flatten() 73 | self.w.add_scalar(k + "/mean", float(arr.mean()), global_step=step) 74 | self.w.add_scalar(k + "/std", float(arr.std()), global_step=step) 75 | continue 76 | except Exception: 77 | pass 78 | 79 | # number-like 80 | try: 81 | self.w.add_scalar(k, float(v), global_step=step) 82 | except Exception: 83 | # swallow non-numeric junk silently (same behavior as before) 84 | pass 85 | 86 | # ---------- nice-to-have helpers ---------- 87 | def hist(self, tag: str, values: Any, step: Optional[int] = None, bins: str = "tensorflow"): 88 | if not self.w: return 89 | try: 90 | import torch 91 | if isinstance(values, torch.Tensor): 92 | values = values.detach().cpu() 93 | self.w.add_histogram(tag, values, global_step=step, bins=bins) 94 | except Exception: 95 | pass 96 | 97 | def text(self, tag: str, text: str, step: Optional[int] = None): 98 | if not self.w: return 99 | try: 100 | self.w.add_text(tag, text, global_step=step) 101 | except Exception: 102 | pass 103 | 104 | def image(self, tag: str, img, step: Optional[int] = None): 105 | """ 106 | img: torch.Tensor [C,H,W] or [H,W,C] or numpy array 107 | """ 108 | if not self.w: return 109 | try: 110 | self.w.add_image(tag, img, global_step=step, dataformats="CHW" if getattr(img, "ndim", 0) == 3 and img.shape[0] in (1,3) else "HWC") 111 | except Exception: 112 | pass 113 | 114 | def graph(self, model, example_input): 115 | if not self.w: return 116 | try: 117 | # example_input: a Tensor batch or a tuple 118 | if not isinstance(example_input, tuple): 119 | example_input = (example_input,) 120 | self.w.add_graph(model, example_input) 121 | except Exception: 122 | pass # graph tracing can fail depending on model control flow; don't crash 123 | 124 | def hparams(self, hparams: Dict[str, Any], metrics_once: Optional[Dict[str, float]] = None): 125 | if 
not self.w or self.hparams_logged: 126 | return 127 | try: 128 | # Single, stable sub-run so it doesn’t spam the left pane 129 | self.w.add_hparams(hparams, metrics_once or {}, run_name="_hparams") 130 | self.hparams_logged = True 131 | except Exception: 132 | pass 133 | 134 | def flush(self): 135 | if self.w: 136 | try: self.w.flush() 137 | except Exception: pass 138 | 139 | def close(self): 140 | if self.w: 141 | try: self.w.close() 142 | except Exception: pass 143 | 144 | class WBLogger(NoopLogger): 145 | def __init__(self, project: str, run_name: str | None = None): 146 | try: 147 | import wandb 148 | wandb.init(project=project, name=run_name) 149 | self.wb = wandb 150 | except Exception: 151 | self.wb = None 152 | def log(self, **kv): 153 | if self.wb: self.wb.log(kv) 154 | 155 | 156 | def init_logger(which: str, out_dir: str = "runs/part4"): 157 | if which == 'tensorboard': 158 | tb = TBLogger(out_dir) 159 | return tb if tb.w is not None else NoopLogger() 160 | if which == 'wandb': 161 | return WBLogger(project='llm-part4') 162 | return NoopLogger() 163 | -------------------------------------------------------------------------------- /part_4/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import argparse, time, signal 3 | from pathlib import Path 4 | import sys 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | # so we can import Part 3 model 10 | from pathlib import Path as _P 11 | sys.path.append(str(_P(__file__).resolve().parents[1] / 'part_3')) 12 | from model_modern import GPTModern 13 | 14 | from tokenizer_bpe import BPETokenizer 15 | from dataset_bpe import make_loader 16 | from lr_scheduler import WarmupCosineLR 17 | from amp_accum import AmpGrad 18 | from checkpointing import ( 19 | load_checkpoint, 20 | _log_hparams_tb, 21 | _maybe_log_graph_tb, 22 | _is_tb, 23 | _log_model_stats, 24 | _maybe_log_attention, 25 | _log_samples_tb, 26 | _log_runtime, 27 | atomic_save_all, 28 | ) 29 | from logger import init_logger 30 | 31 | 32 | def run_cfg_from_args(args, vocab_size: int) -> dict: 33 | return dict( 34 | vocab_size=vocab_size, 35 | block_size=args.block_size, 36 | n_layer=args.n_layer, 37 | n_head=args.n_head, 38 | n_embd=args.n_embd, 39 | dropout=args.dropout, 40 | use_rmsnorm=True, 41 | use_swiglu=True, 42 | rope=True, 43 | max_pos=4096, 44 | sliding_window=None, 45 | attention_sink=0, 46 | ) 47 | 48 | 49 | def main(): 50 | p = argparse.ArgumentParser() 51 | p.add_argument('--data', type=str, required=True) 52 | p.add_argument('--out', type=str, default='runs/part4') 53 | 54 | # tokenizer / model dims 55 | p.add_argument('--bpe', action='store_true', help='train and use a BPE tokenizer (recommended)') 56 | p.add_argument('--vocab_size', type=int, default=32000) 57 | p.add_argument('--block_size', type=int, default=256) 58 | p.add_argument('--n_layer', type=int, default=6) 59 | p.add_argument('--n_head', type=int, default=8) 60 | p.add_argument('--n_embd', type=int, default=512) 61 | p.add_argument('--dropout', type=float, default=0.0) 62 | 63 | # train 64 | p.add_argument('--batch_size', type=int, default=32) 65 | p.add_argument('--epochs', type=int, default=1) 66 | p.add_argument('--steps', type=int, default=300, help='max optimizer steps for this run') 67 | p.add_argument('--lr', type=float, default=3e-4) 68 | p.add_argument('--warmup_steps', type=int, default=20) 69 | p.add_argument('--mixed_precision', action='store_true') 70 | p.add_argument('--grad_accum_steps', type=int, 
default=4) 71 | 72 | # misc 73 | p.add_argument('--log', choices=['wandb', 'tensorboard', 'none'], default='tensorboard') 74 | p.add_argument('--save_every', type=int, default=50, help='save checkpoint every N optimizer steps') 75 | p.add_argument('--keep_last_k', type=int, default=2, help='keep last K step checkpoints (plus model_last.pt)') 76 | args = p.parse_args() 77 | 78 | # device 79 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 80 | 81 | # output dir and (possible) checkpoint 82 | out_dir = Path(args.out); out_dir.mkdir(parents=True, exist_ok=True) 83 | ckpt_path = out_dir / "model_last.pt" 84 | have_ckpt = ckpt_path.exists() 85 | 86 | # ---- load checkpoint meta if present ---- 87 | ckpt = None 88 | saved_tok_dir = None 89 | if have_ckpt: 90 | ckpt = torch.load(str(ckpt_path), map_location=device) 91 | if "config" not in ckpt: 92 | raise RuntimeError( 93 | "Checkpoint is missing 'config'." 94 | "Please re-save a checkpoint that includes the model config." 95 | ) 96 | tok_file = ckpt_path.with_name("tokenizer_dir.txt") 97 | saved_tok_dir = tok_file.read_text().strip() if tok_file.exists() else None 98 | 99 | # ---- tokenizer ---- 100 | tok = None 101 | tok_dir = None 102 | if have_ckpt: 103 | if not saved_tok_dir: 104 | raise RuntimeError( 105 | "Checkpoint was found but tokenizer_dir.txt is missing. " 106 | "Resume requires the original tokenizer." 107 | ) 108 | tok = BPETokenizer(); tok.load(saved_tok_dir) 109 | tok_dir = saved_tok_dir 110 | vocab_size = tok.vocab_size 111 | print(f"[resume] Loaded tokenizer from {tok_dir} (vocab={vocab_size})") 112 | else: 113 | if args.bpe: 114 | tok = BPETokenizer(vocab_size=args.vocab_size) 115 | tok.train(args.data) 116 | tok_dir = str(out_dir / 'tokenizer'); Path(tok_dir).mkdir(parents=True, exist_ok=True) 117 | tok.save(tok_dir) 118 | vocab_size = tok.vocab_size 119 | print(f"[init] Trained tokenizer to {tok_dir} (vocab={vocab_size})") 120 | else: 121 | tok = None 122 | vocab_size = 256 # byte-level fallback (not recommended for Part 4) 123 | 124 | # ---- dataset ---- 125 | train_loader = make_loader(args.data, tok, args.block_size, args.batch_size, shuffle=True) 126 | 127 | # ---- build model config ---- 128 | if have_ckpt: 129 | cfg_build = ckpt["config"] 130 | if cfg_build.get("vocab_size") != vocab_size: 131 | raise RuntimeError( 132 | f"Tokenizer vocab ({vocab_size}) != checkpoint config vocab ({cfg_build.get('vocab_size')}). " 133 | "This deterministic script forbids vocab changes on resume." 
134 | ) 135 | else: 136 | cfg_build = run_cfg_from_args(args, vocab_size) 137 | 138 | # ---- init model/opt/sched/amp ---- 139 | model = GPTModern(**cfg_build).to(device) 140 | optim = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.1) 141 | 142 | total_steps = min(args.steps, args.epochs * len(train_loader)) 143 | warmup = min(args.warmup_steps, max(total_steps // 10, 1)) 144 | sched = WarmupCosineLR(optim, warmup_steps=warmup, total_steps=total_steps, base_lr=args.lr) 145 | 146 | amp = AmpGrad(optim, accum=args.grad_accum_steps, amp=args.mixed_precision) 147 | 148 | # ---- strict resume ---- 149 | step = 0 150 | if have_ckpt: 151 | step = load_checkpoint(model, str(ckpt_path), optimizer=optim, scheduler=sched, amp=amp, strict=True) 152 | print(f"[resume] Loaded checkpoint at step {step}") 153 | 154 | # ---- logging ---- 155 | logger = init_logger(args.log, out_dir=str(out_dir)) 156 | _log_hparams_tb(logger, args, total_steps) 157 | if _is_tb(logger): 158 | try: 159 | ex_x, ex_y = next(iter(train_loader)) 160 | _maybe_log_graph_tb(logger, model, ex_x.to(device), ex_y.to(device)) 161 | except Exception: 162 | pass 163 | 164 | # ---- graceful save on SIGINT/SIGTERM ---- 165 | save_requested = {"flag": False} 166 | def _on_term(sig, frame): save_requested["flag"] = True 167 | signal.signal(signal.SIGTERM, _on_term) 168 | signal.signal(signal.SIGINT, _on_term) 169 | 170 | # ---- train loop ---- 171 | model.train() 172 | while step < args.steps: 173 | for xb, yb in train_loader: 174 | if step >= args.steps: break 175 | if save_requested["flag"]: 176 | atomic_save_all(model, optim, sched, amp, step, out_dir, tok_dir, args.keep_last_k, cfg_build) 177 | print(f"[signal] Saved checkpoint at step {step} to {out_dir}. Exiting.") 178 | return 179 | 180 | it_t0 = time.time() 181 | xb, yb = xb.to(device), yb.to(device) 182 | with torch.cuda.amp.autocast(enabled=amp.amp): 183 | logits, loss, _ = model(xb, yb) 184 | amp.backward(loss) 185 | 186 | if amp.should_step(): 187 | amp.step(); amp.zero_grad() 188 | lr = sched.step() 189 | step += 1 190 | 191 | # periodic checkpoint 192 | if step % args.save_every == 0: 193 | atomic_save_all(model, optim, sched, amp, step, out_dir, tok_dir, args.keep_last_k, cfg_build) 194 | if _is_tb(logger): 195 | logger.text("meta/checkpoint", f"Saved at step {step}", step) 196 | 197 | # logging 198 | if step % 50 == 0: 199 | logger.log(step=step, loss=float(loss.item()), lr=float(lr)) 200 | _log_runtime(logger, step, it_t0, xb, device) 201 | _log_model_stats(logger, model, step, do_hists=False) 202 | _maybe_log_attention(logger, model, xb, step, every=100) 203 | _log_samples_tb(logger, model, tok, xb, device, step, max_new_tokens=64) 204 | 205 | # ---- final save ---- 206 | atomic_save_all(model, optim, sched, amp, step, out_dir, tok_dir, args.keep_last_k, cfg_build) 207 | print(f"Saved checkpoint to {out_dir}/model_last.pt") 208 | 209 | 210 | if __name__ == '__main__': 211 | main() 212 | -------------------------------------------------------------------------------- /part_8/train_ppo.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import argparse, torch 3 | from pathlib import Path 4 | 5 | # import torch 6 | # torch.manual_seed(0) 7 | 8 | from policy import PolicyWithValue 9 | from rollout import RLHFTokenizer, format_prompt_only, format_example, sample_prompts, gather_logprobs, shift_labels 10 | from rollout import model_logprobs 11 | 12 | # Reward model from 
Part 7 13 | import sys 14 | from pathlib import Path as _P 15 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_7')) 16 | from model_reward import RewardModel # noqa: E402 17 | 18 | from ppo_loss import ppo_losses 19 | 20 | 21 | def compute_reward(reward_model: RewardModel, tok: RLHFTokenizer, prompt: str, response: str, device) -> float: 22 | text = format_example(__import__('part_6.formatters', fromlist=['Example']).Example(prompt, response)) 23 | ids = tok.encode(text) 24 | x = torch.tensor([ids[:tok.block_size]], dtype=torch.long, device=device) 25 | with torch.no_grad(): 26 | r = reward_model(x) 27 | return float(r[0].item()) 28 | 29 | 30 | def main(): 31 | p = argparse.ArgumentParser() 32 | p.add_argument('--out', type=str, default='runs/ppo-demo') 33 | p.add_argument('--policy_ckpt', type=str, required=True, help='SFT checkpoint (Part 6)') 34 | p.add_argument('--reward_ckpt', type=str, required=True, help='Reward model checkpoint (Part 7)') 35 | p.add_argument('--steps', type=int, default=100) 36 | p.add_argument('--batch_size', type=int, default=4) 37 | p.add_argument('--block_size', type=int, default=256) 38 | p.add_argument('--resp_len', type=int, default=64) 39 | p.add_argument('--kl_coef', type=float, default=0.01) 40 | p.add_argument('--gamma', type=float, default=1.0) 41 | p.add_argument('--lam', type=float, default=0.95) 42 | p.add_argument('--lr', type=float, default=1e-5) 43 | p.add_argument('--bpe_dir', type=str, default=None) 44 | p.add_argument('--cpu', action='store_true') 45 | args = p.parse_args() 46 | 47 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu') 48 | 49 | # tokenizer 50 | tok = RLHFTokenizer(block_size=args.block_size, bpe_dir=args.bpe_dir) 51 | 52 | # Load SFT policy as initial policy AND reference 53 | ckpt = torch.load(args.policy_ckpt, map_location=device) 54 | cfg = ckpt.get('config', {}) 55 | vocab_size = cfg.get('vocab_size', tok.vocab_size) 56 | block_size = cfg.get('block_size', tok.block_size) 57 | n_layer = cfg.get('n_layer', 2) 58 | n_head = cfg.get('n_head', 2) 59 | n_embd = cfg.get('n_embd', 128) 60 | 61 | policy = PolicyWithValue(vocab_size, block_size, n_layer, n_head, n_embd).to(device) 62 | policy.lm.load_state_dict(ckpt['model']) # initialize LM weights from SFT 63 | 64 | 65 | ref = PolicyWithValue(vocab_size, block_size, n_layer, n_head, n_embd).to(device) 66 | ref.lm.load_state_dict(ckpt['model']) 67 | for p_ in ref.parameters(): 68 | p_.requires_grad_(False) 69 | ref.eval() 70 | 71 | # Reward model 72 | rckpt = torch.load(args.reward_ckpt, map_location=device) 73 | rm = RewardModel(vocab_size=rckpt['config'].get('vocab_size', tok.vocab_size), block_size=rckpt['config'].get('block_size', tok.block_size), 74 | n_layer=rckpt['config'].get('n_layer', 4), n_head=rckpt['config'].get('n_head', 4), n_embd=rckpt['config'].get('n_embd', 256)).to(device) 75 | rm.load_state_dict(rckpt['model']) 76 | rm.eval() 77 | 78 | opt = torch.optim.AdamW(policy.parameters(), lr=args.lr, betas=(0.9, 0.999)) 79 | 80 | # small prompt pool 81 | prompts = sample_prompts(16) 82 | 83 | step = 0 84 | while step < args.steps: 85 | # ----- COLLECT ROLLOUT BATCH ----- 86 | batch_prompts = prompts[ (step*args.batch_size) % len(prompts) : ((step+1)*args.batch_size) % len(prompts) ] 87 | if len(batch_prompts) < args.batch_size: 88 | batch_prompts += prompts[:args.batch_size-len(batch_prompts)] 89 | texts = [format_prompt_only(p).replace("", "") for p in batch_prompts] 90 | in_ids = [tok.encode(t) for t in texts] 91 | 92 | 
with torch.no_grad(): 93 | out_ids = [] 94 | for i, x in enumerate(in_ids): 95 | idx = torch.tensor([x], dtype=torch.long, device=device) 96 | out = policy.generate(idx, max_new_tokens=args.resp_len, temperature=0.2, top_k=3) 97 | out_ids.append(out[0].tolist()) 98 | 99 | # split prompt/response per sample 100 | data = [] 101 | for i, prompt in enumerate(batch_prompts): 102 | full = out_ids[i] 103 | # find boundary: index where prompt ends in the tokenized form 104 | # Use original prompt tokenization length (clipped by block_size) 105 | p_ids = in_ids[i][-block_size:] 106 | boundary = len(p_ids) 107 | resp_ids = full[boundary:] 108 | # compute rewards via RM on formatted prompt+response text 109 | resp_text = tok.decode(resp_ids) 110 | r_scalar = compute_reward(rm, tok, prompt, resp_text, device) 111 | data.append((torch.tensor(full, dtype=torch.long), boundary, r_scalar)) 112 | 113 | # pad to same length 114 | policy_ctx = getattr(policy, "block_size", block_size) 115 | max_len = min(policy_ctx, max(t[0].numel() for t in data)) 116 | B = len(data) 117 | seq = torch.zeros(B, max_len, dtype=torch.long, device=device) 118 | mask = torch.zeros(B, max_len, dtype=torch.bool, device=device) 119 | last_idx = torch.zeros(B, dtype=torch.long, device=device) 120 | rewards = torch.zeros(B, max_len, dtype=torch.float, device=device) 121 | 122 | for i, (ids, boundary, r_scalar) in enumerate(data): 123 | L_full = ids.numel() 124 | L = min(L_full, max_len) 125 | drop = L_full - L # tokens dropped from the left 126 | b = max(0, boundary - drop) # shift boundary after left-trim 127 | seq[i, :L] = ids[-L:] 128 | if L < max_len: 129 | seq[i, L:] = 2 # fill remaining positions with token 130 | mask[i, b:L] = True 131 | rewards[i, L-1] = r_scalar 132 | last_idx[i] = L-1 133 | 134 | 135 | # logprobs & values for policy and reference 136 | # model_logprobs returns (B, T-1) for next-token logp; align to seq[:,1:] 137 | pol_lp = model_logprobs(policy, seq) 138 | ref_lp = model_logprobs(ref, seq) 139 | # values for seq positions (B,T) 140 | with torch.no_grad(): 141 | logits, values, _ = policy(seq, None) 142 | values = values[:, :-1] # align to pol_lp 143 | 144 | # Select only action positions 145 | act_mask = mask[:,1:] # since logprobs are for predicting token t from <=t-1 146 | old_logp = pol_lp[act_mask].detach() 147 | ref_logp = ref_lp[act_mask].detach() 148 | old_values = values[act_mask].detach() 149 | 150 | # KL per action token and shaped rewards 151 | kl = (old_logp - ref_logp) 152 | shaped_r = rewards[:,1:][act_mask] - args.kl_coef * kl # penalty for drifting 153 | 154 | # Compute advantages/returns with last‑step bootstrap = 0 (episodic per response) 155 | # Flatten by sequence order inside each sample; we’ll approximate by grouping tokens per sample using last_idx. 156 | # For tutorial simplicity, treat advantages = shaped_r - old_values (no GAE). Works for end-only reward. 
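        # [Illustrative sketch only, not executed] If per-token rewards and the value head were
        # used fully, Generalized Advantage Estimation (GAE) would replace the simple
        # "returns = shaped_r; adv = returns - old_values" lines below. args.gamma and args.lam
        # are parsed above but unused in this simplified demo; a hypothetical per-sample helper
        # (1-D reward/value tensors r_t, v_t over action positions, bootstrap value 0) could be:
        #
        #   def gae_advantages(r_t, v_t, gamma, lam):
        #       adv, running = torch.zeros_like(r_t), 0.0
        #       for t in reversed(range(r_t.numel())):
        #           v_next = v_t[t + 1] if t + 1 < v_t.numel() else 0.0
        #           delta = r_t[t] + gamma * v_next - v_t[t]       # TD residual
        #           running = delta + gamma * lam * running        # discounted sum of residuals
        #           adv[t] = running
        #       return adv                                         # returns would be adv + v_t
        #
        # With an end-only reward, the simpler form used below is a reasonable approximation.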
157 | returns = shaped_r # target value = immediate shaped reward 158 | adv = returns - old_values 159 | # normalize adv 160 | adv = (adv - adv.mean()) / (adv.std().clamp_min(1e-6)) 161 | 162 | # ----- UPDATE (single pass PPO for demo) ----- 163 | # This step is done multiple times per batch in practice 164 | policy.train() 165 | logits_new, values_new_full, _ = policy(seq, None) 166 | logp_full = torch.log_softmax(logits_new[:, :-1, :], dim=-1) 167 | labels = seq[:,1:] 168 | new_logp_all = logp_full.gather(-1, labels.unsqueeze(-1)).squeeze(-1) 169 | new_logp = new_logp_all[act_mask] 170 | new_values = values_new_full[:, :-1][act_mask] 171 | 172 | from ppo_loss import ppo_losses 173 | out_loss = ppo_losses(new_logp, old_logp, adv, new_values, old_values, returns, 174 | clip_ratio=0.2, vf_coef=0.5, ent_coef=0.0) 175 | loss = out_loss.total_loss 176 | 177 | opt.zero_grad(set_to_none=True) 178 | loss.backward() 179 | torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0) 180 | opt.step() 181 | policy.eval() 182 | 183 | with torch.no_grad(): 184 | # KL(old || new): movement of the updated policy from the snapshot used to collect data 185 | lp_post = model_logprobs(policy, seq) # (B, T-1) 186 | lp_post = lp_post[act_mask] # only action positions 187 | kl_post = (old_logp - lp_post).mean() # ≈ E[log π_old - log π_new] 188 | 189 | # KL(now || ref): how far the current policy is from the frozen reference 190 | lp_now = lp_post # already computed above on the same positions 191 | kl_ref_now = (lp_now - ref_logp).mean() # ≈ E[log π_now - log π_ref] 192 | 193 | step += 1 194 | if step % 10 == 0: 195 | print( 196 | f"step {step} | loss {loss.item():.4f}" 197 | f"| value loss {out_loss.value_loss.item():.4f} | KL_move {kl_post.item():.6f} | KL_ref {kl_ref_now.item():.6f}" 198 | ) 199 | 200 | 201 | Path(args.out).mkdir(parents=True, exist_ok=True) 202 | torch.save({'model': policy.state_dict(), 'config': { 203 | 'vocab_size': vocab_size, 204 | 'block_size': block_size, 205 | 'n_layer': n_layer, 206 | 'n_head': n_head, 207 | 'n_embd': n_embd, 208 | }}, str(Path(args.out)/'model_last.pt')) 209 | print(f"Saved PPO policy to {args.out}/model_last.pt") 210 | 211 | if __name__ == '__main__': 212 | main() -------------------------------------------------------------------------------- /part_9/train_grpo.py: -------------------------------------------------------------------------------- 1 | # train_grpo.py 2 | from __future__ import annotations 3 | import argparse, torch 4 | from pathlib import Path 5 | 6 | from policy import PolicyWithValue # we will ignore the value head 7 | from rollout import RLHFTokenizer, format_prompt_only, sample_prompts, model_logprobs 8 | 9 | # Reward model from Part 7 10 | import sys 11 | from pathlib import Path as _P 12 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_7')) 13 | from model_reward import RewardModel # noqa: E402 14 | 15 | from grpo_loss import ppo_policy_only_losses 16 | 17 | 18 | @torch.no_grad() 19 | def compute_reward(reward_model: RewardModel, tok: RLHFTokenizer, prompt_text: str, response_ids: list[int], device) -> float: 20 | # Build full formatted text (as in your PPO) 21 | from part_6.formatters import Example, format_example 22 | resp_text = tok.decode(response_ids) 23 | text = format_example(Example(prompt_text, resp_text)) 24 | ids = tok.encode(text) 25 | x = torch.tensor([ids[:tok.block_size]], dtype=torch.long, device=device) 26 | r = reward_model(x) 27 | return float(r[0].item()) 28 | 29 | 30 | def main(): 31 | p = 
argparse.ArgumentParser() 32 | p.add_argument('--out', type=str, default='runs/grpo-demo') 33 | p.add_argument('--policy_ckpt', type=str, required=True, help='SFT checkpoint (Part 6)') 34 | p.add_argument('--reward_ckpt', type=str, required=True, help='Reward model checkpoint (Part 7)') 35 | p.add_argument('--steps', type=int, default=100) 36 | p.add_argument('--batch_prompts', type=int, default=32, help='number of distinct prompts per step (before grouping)') 37 | p.add_argument('--group_size', type=int, default=4, help='completions per prompt') 38 | p.add_argument('--block_size', type=int, default=256) 39 | p.add_argument('--resp_len', type=int, default=64) 40 | p.add_argument('--kl_coef', type=float, default=0.01) 41 | p.add_argument('--lr', type=float, default=1e-5) 42 | p.add_argument('--bpe_dir', type=str, default=None) 43 | p.add_argument('--cpu', action='store_true') 44 | args = p.parse_args() 45 | 46 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu') 47 | 48 | # tokenizer 49 | tok = RLHFTokenizer(block_size=args.block_size, bpe_dir=args.bpe_dir) 50 | 51 | # Load SFT policy (and a frozen reference) 52 | ckpt = torch.load(args.policy_ckpt, map_location=device) 53 | cfg = ckpt.get('config', {}) 54 | vocab_size = cfg.get('vocab_size', tok.vocab_size) 55 | block_size = cfg.get('block_size', tok.block_size) 56 | n_layer = cfg.get('n_layer', 2) 57 | n_head = cfg.get('n_head', 2) 58 | n_embd = cfg.get('n_embd', 128) 59 | 60 | policy = PolicyWithValue(vocab_size, block_size, n_layer, n_head, n_embd).to(device) 61 | policy.lm.load_state_dict(ckpt['model']) 62 | policy.eval() 63 | 64 | ref = PolicyWithValue(vocab_size, block_size, n_layer, n_head, n_embd).to(device) 65 | ref.lm.load_state_dict(ckpt['model']) 66 | for p_ in ref.parameters(): 67 | p_.requires_grad_(False) 68 | ref.eval() 69 | 70 | # Reward model 71 | rckpt = torch.load(args.reward_ckpt, map_location=device) 72 | rm = RewardModel(vocab_size=rckpt['config'].get('vocab_size', tok.vocab_size), 73 | block_size=rckpt['config'].get('block_size', tok.block_size), 74 | n_layer=rckpt['config'].get('n_layer', 4), 75 | n_head=rckpt['config'].get('n_head', 4), 76 | n_embd=rckpt['config'].get('n_embd', 256)).to(device) 77 | rm.load_state_dict(rckpt['model']) 78 | rm.eval() 79 | 80 | opt = torch.optim.AdamW(policy.parameters(), lr=args.lr, betas=(0.9, 0.999)) 81 | 82 | # small prompt pool (reuse your helper) 83 | prompts_pool = sample_prompts(16) 84 | 85 | step = 0 86 | pool_idx = 0 87 | G = args.group_size 88 | 89 | while step < args.steps: 90 | # ----- SELECT PROMPTS ----- 91 | # Choose P prompts, each will yield G completions → B = P*G trajectories 92 | P = max(1, args.batch_prompts) 93 | if pool_idx + P > len(prompts_pool): 94 | pool_idx = 0 95 | batch_prompts = prompts_pool[pool_idx: pool_idx + P] 96 | pool_idx += P 97 | 98 | # Tokenize prompt-only texts 99 | prompt_texts = [format_prompt_only(p).replace("", "") for p in batch_prompts] 100 | prompt_in_ids = [tok.encode(t) for t in prompt_texts] 101 | 102 | # ----- GENERATE G COMPLETIONS PER PROMPT ----- 103 | # We will collect all trajectories flat, but track their group/prompt ids. 
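        # [Explanatory note] Sampling G completions per prompt is the heart of GRPO: the group
        # itself is the baseline instead of a learned value head, so each trajectory's advantage
        # is its reward minus the mean reward of its own group, e.g.
        #
        #   rewards for one prompt with G=4: [0.8, 0.2, 0.5, 0.5]  ->  group mean 0.5
        #   trajectory advantages:           [0.3, -0.3, 0.0, 0.0]
        #
        # Some GRPO variants also divide by the group's reward std; this demo only mean-centres
        # per group (see "SHAPED TRAJECTORY REWARD & GROUP BASELINE" below) and then normalises
        # advantages across the whole batch.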
104 | seq_list = [] # list[Tensor of token ids] 105 | boundary_list = [] # index where response starts in the (possibly clipped) sequence 106 | prompt_id_of = [] # which prompt this trajectory belongs to (0..P-1) 107 | raw_rewards = [] # scalar reward per trajectory (before KL shaping) 108 | last_idx_list = [] # for padding bookkeeping 109 | 110 | with torch.no_grad(): 111 | for pid, p_ids in enumerate(prompt_in_ids): 112 | for g in range(G): 113 | idx = torch.tensor([p_ids], dtype=torch.long, device=device) 114 | out = policy.generate(idx, max_new_tokens=args.resp_len, temperature=2, top_k=3) 115 | full_ids = out[0].tolist() 116 | 117 | # split prompt/response 118 | boundary = len(p_ids[-block_size:]) # prompt length clipped to context 119 | resp_ids = full_ids[boundary:] 120 | r_scalar = compute_reward(rm, tok, batch_prompts[pid], resp_ids, device) 121 | 122 | seq_list.append(torch.tensor(full_ids, dtype=torch.long)) 123 | boundary_list.append(boundary) 124 | prompt_id_of.append(pid) 125 | raw_rewards.append(r_scalar) 126 | 127 | # ----- PAD TO BATCH ----- 128 | B = len(seq_list) # B = P*G 129 | policy_ctx = getattr(policy, "block_size", block_size) 130 | max_len = min(policy_ctx, max(s.numel() for s in seq_list)) 131 | seq = torch.zeros(B, max_len, dtype=torch.long, device=device) 132 | mask = torch.zeros(B, max_len, dtype=torch.bool, device=device) 133 | last_idx = torch.zeros(B, dtype=torch.long, device=device) 134 | 135 | # keep a per-traj “action positions” mask and response-only boundary 136 | for i, (ids, bnd) in enumerate(zip(seq_list, boundary_list)): 137 | L_full = ids.numel() 138 | L = min(L_full, max_len) 139 | drop = L_full - L 140 | b = max(0, bnd - drop) # shifted boundary after left-trim 141 | seq[i, :L] = ids[-L:] 142 | if L < max_len: 143 | seq[i, L:] = 2 # pad token 144 | # actions are predicting token t from <=t-1 → positions [1..L-1] 145 | # but we only care about response tokens: mask [b..L-1] → actions [b+1..L-1] 146 | mask[i, b:L] = True 147 | last_idx[i] = L - 1 148 | 149 | # ----- LOGPROBS & KL VS REF (token-level) ----- 150 | # model_logprobs returns log p(x[t] | x[:t-1]) for t=1..T-1 over labels=seq[:,1:] 151 | with torch.no_grad(): 152 | pol_lp_full = model_logprobs(policy, seq) # (B, T-1) 153 | ref_lp_full = model_logprobs(ref, seq) # (B, T-1) 154 | 155 | # action positions (predict positions [1..T-1]); we want only response tokens: 156 | act_mask = mask[:, 1:] # align to (B, T-1) 157 | old_logp = pol_lp_full[act_mask].detach() 158 | ref_logp = ref_lp_full[act_mask].detach() 159 | 160 | # per-token KL on action tokens 161 | kl_tok = (old_logp - ref_logp) # (N_act,) 162 | 163 | # ----- SHAPED TRAJECTORY REWARD & GROUP BASELINE ----- 164 | # For GRPO, advantage is trajectory-level and broadcast to its tokens. 165 | # We include KL shaping at trajectory level using mean token KL per trajectory. 166 | # First, compute mean KL per trajectory on its action tokens. 167 | # Build an index map from flat action tokens back to traj ids. 168 | # We can reconstruct counts by iterating rows. 
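        # [Reference-only sketch] Because boolean-mask indexing flattens in row-major order, the
        # explicit loop below could be replaced by two vectorised lines; the loop is kept for
        # readability in this tutorial:
        #
        #   counts = act_mask.sum(dim=1)                                             # (B,)
        #   traj_id_for_token = torch.arange(B, device=device).repeat_interleave(counts)
        #
        # which yields the same flat-token -> trajectory-id mapping.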
169 | traj_id_for_token = [] 170 | counts = torch.zeros(B, dtype=torch.long, device=device) 171 | offset = 0 172 | for i in range(B): 173 | mrow = act_mask[i] 174 | n_i = int(mrow.sum().item()) 175 | if n_i > 0: 176 | traj_id_for_token.extend([i] * n_i) 177 | counts[i] = n_i 178 | offset += n_i 179 | traj_id_for_token = torch.tensor(traj_id_for_token, dtype=torch.long, device=device) 180 | raw_rewards_t = torch.tensor(raw_rewards, dtype=torch.float, device=device) 181 | 182 | # Compute per-prompt group mean of shaped rewards 183 | group_mean = torch.zeros(B, dtype=torch.float, device=device) 184 | for pid in range(P): 185 | idxs = [i for i in range(B) if prompt_id_of[i] == pid] 186 | if not idxs: 187 | continue 188 | idxs_t = torch.tensor(idxs, dtype=torch.long, device=device) 189 | mean_val = raw_rewards_t[idxs_t].mean() 190 | group_mean[idxs_t] = mean_val 191 | 192 | # Advantage per trajectory, broadcast to its action tokens 193 | traj_adv = raw_rewards_t - group_mean # (B,) 194 | 195 | # Build a flat tensor of advantages aligned with old_logp/new_logp on action tokens 196 | if kl_tok.numel() > 0: 197 | adv_flat = traj_adv[traj_id_for_token] 198 | else: 199 | adv_flat = torch.zeros(0, dtype=torch.float, device=device) 200 | 201 | # Normalize advantages (optional but usually helpful) 202 | if adv_flat.numel() > 1: 203 | adv_flat = (adv_flat - adv_flat.mean()) / (adv_flat.std().clamp_min(1e-6)) 204 | 205 | # ----- UPDATE (policy-only PPO clipped objective) ----- 206 | policy.train() 207 | logits_new, _, _ = policy(seq, None) # ignore value head 208 | logp_full = torch.log_softmax(logits_new[:, :-1, :], dim=-1) 209 | labels = seq[:, 1:] 210 | new_logp_all = logp_full.gather(-1, labels.unsqueeze(-1)).squeeze(-1) # (B, T-1) 211 | new_logp = new_logp_all[act_mask] 212 | 213 | # Mean KL over action tokens 214 | kl_now_ref_mean = (new_logp - ref_logp).mean() if new_logp.numel() > 0 else torch.tensor(0.0, device=device) 215 | 216 | out_loss = ppo_policy_only_losses( 217 | new_logp=new_logp, 218 | old_logp=old_logp, 219 | adv=adv_flat, 220 | clip_ratio=0.2, 221 | ent_coef=0.0, # set >0 if you want entropy bonus from -new_logp mean 222 | kl_coef=args.kl_coef, 223 | kl_mean=kl_now_ref_mean, 224 | ) 225 | loss = out_loss.total_loss 226 | 227 | opt.zero_grad(set_to_none=True) 228 | loss.backward() 229 | torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0) 230 | opt.step() 231 | policy.eval() 232 | 233 | # Some quick diagnostics (movement vs old, and now vs ref) 234 | with torch.no_grad(): 235 | lp_post = model_logprobs(policy, seq)[act_mask] 236 | kl_move = (old_logp - lp_post).mean() if lp_post.numel() > 0 else torch.tensor(0.0, device=device) 237 | # KL(now || ref) 238 | kl_ref_now = (lp_post - ref_logp).mean() if lp_post.numel() > 0 else torch.tensor(0.0, device=device) 239 | 240 | step += 1 241 | if step % 10 == 0: 242 | print( 243 | f"step {step} | loss {loss.item():.4f}" 244 | f"| KL_move {kl_move.item():.6f} | KL_ref {kl_ref_now.item():.6f}" 245 | ) 246 | 247 | Path(args.out).mkdir(parents=True, exist_ok=True) 248 | torch.save({'model': policy.state_dict(), 'config': { 249 | 'vocab_size': vocab_size, 250 | 'block_size': block_size, 251 | 'n_layer': n_layer, 252 | 'n_head': n_head, 253 | 'n_embd': n_embd, 254 | }}, str(Path(args.out)/'model_last.pt')) 255 | print(f"Saved GRPO policy to {args.out}/model_last.pt") 256 | 257 | 258 | if __name__ == '__main__': 259 | main() 260 | -------------------------------------------------------------------------------- /part_4/checkpointing.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import os 3 | from pathlib import Path 4 | from typing import Any, Dict, Optional, Tuple 5 | import sys 6 | sys.path.append(str(Path(__file__).resolve().parents[1]/'part_3')) 7 | import time 8 | import torch 9 | import shutil 10 | import torch.nn as nn 11 | 12 | DEF_NAME = "model_last.pt" 13 | 14 | # ----------------------------- TB-only helpers (safe no-ops otherwise) ----------------------------- # 15 | def _is_tb(logger) -> bool: 16 | return getattr(logger, "w", None) is not None 17 | 18 | 19 | # checkpointing._log_hparams_tb 20 | def _log_hparams_tb(logger, args, total_steps): 21 | if not _is_tb(logger): return 22 | try: 23 | h = dict( 24 | vocab_size=args.vocab_size, block_size=args.block_size, n_layer=args.n_layer, 25 | n_head=args.n_head, n_embd=args.n_embd, dropout=args.dropout, lr=args.lr, 26 | warmup_steps=args.warmup_steps, batch_size=args.batch_size, grad_accum=args.grad_accum_steps, 27 | mixed_precision=args.mixed_precision, steps=args.steps, epochs=args.epochs, 28 | ) 29 | logger.hparams(h, {"meta/total_steps": float(total_steps)}) 30 | except Exception: 31 | pass 32 | 33 | def _maybe_log_graph_tb(logger, model, xb, yb): 34 | if not hasattr(logger, "graph"): 35 | return 36 | try: 37 | class _TensorOnly(nn.Module): 38 | def __init__(self, m): 39 | super().__init__(); self.m = m.eval() 40 | def forward(self, x, y=None): 41 | out = self.m(x, y) if y is not None else self.m(x) 42 | if isinstance(out, (list, tuple)): 43 | for o in out: 44 | if torch.is_tensor(o): 45 | return o 46 | return out[0] 47 | return out 48 | wrapped = _TensorOnly(model).to(xb.device) 49 | logger.graph(wrapped, (xb, yb)) 50 | except Exception: 51 | pass 52 | 53 | def _log_model_stats(logger, model, step: int, do_hists: bool = False): 54 | if not _is_tb(logger): return 55 | try: 56 | params = [p for p in model.parameters() if p.requires_grad] 57 | total_param_norm = torch.norm(torch.stack([p.detach().norm(2) for p in params]), 2).item() 58 | grads = [p.grad for p in params if p.grad is not None] 59 | total_grad_norm = float('nan') 60 | if grads: 61 | total_grad_norm = torch.norm(torch.stack([g.detach().norm(2) for g in grads]), 2).item() 62 | logger.log(step=step, **{ 63 | "train/param_global_l2": total_param_norm, 64 | "train/grad_global_l2": total_grad_norm, 65 | }) 66 | if do_hists: 67 | for name, p in model.named_parameters(): 68 | logger.hist(f"params/{name}", p, step) 69 | if p.grad is not None: 70 | logger.hist(f"grads/{name}", p.grad, step) 71 | except Exception: 72 | pass 73 | 74 | def _maybe_log_attention(logger, model, xb, step: int, every: int = 100): 75 | """ 76 | Logs Q/K/V histograms for each Transformer block using the current minibatch xb. 77 | No model edits. No hooks. Runs a light no-grad recomputation of the pre-attn path. 78 | - Takes first batch and first head only to keep logs tiny. 79 | - Uses pre-RoPE values (simpler & stable for histograms). 
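    - Histograms land under tags qkv/block{i}/{q,k,v}_hist; the matching scalars
      qkv/block{i}/{q,k,v}_l2_mean log the RMS of the same sampled slice.
    - After logging, x is advanced with the FFN path only (attention skipped, see the cheap
      approximation below), so deeper blocks see slightly different inputs than in training.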
80 | """ 81 | if not _is_tb(logger) or step == 0 or (step % every): 82 | return 83 | try: 84 | import torch 85 | with torch.no_grad(), torch.cuda.amp.autocast(enabled=False): 86 | # Recreate inputs seen by blocks 87 | x = model.tok_emb(xb) # (B,T,C) 88 | x = model.drop(x) 89 | 90 | B, T, _ = x.shape 91 | for li, blk in enumerate(getattr(model, "blocks", [])): 92 | h = blk.ln1(x) # pre-attn normalized hidden 93 | 94 | attn = blk.attn 95 | # Project to Q/K/V exactly like the module (pre-RoPE for simplicity) 96 | q = attn.wq(h).view(B, T, attn.n_head, attn.d_head).transpose(1, 2) # (B,H,T,D) 97 | k = attn.wk(h).view(B, T, attn.n_kv_head, attn.d_head).transpose(1, 2) # (B,Hk,T,D) 98 | v = attn.wv(h).view(B, T, attn.n_kv_head, attn.d_head).transpose(1, 2) # (B,Hk,T,D) 99 | 100 | # Take a tiny slice to keep logs light 101 | q1 = q[:1, :1].contiguous().view(-1).float().cpu() 102 | k1 = k[:1, :1].contiguous().view(-1).float().cpu() 103 | v1 = v[:1, :1].contiguous().view(-1).float().cpu() 104 | 105 | # Drop non-finite (defensive) 106 | q1 = q1[torch.isfinite(q1)] 107 | k1 = k1[torch.isfinite(k1)] 108 | v1 = v1[torch.isfinite(v1)] 109 | 110 | if q1.numel() > 0: logger.hist(f"qkv/block{li}/q_hist", q1, step) 111 | if k1.numel() > 0: logger.hist(f"qkv/block{li}/k_hist", k1, step) 112 | if v1.numel() > 0: logger.hist(f"qkv/block{li}/v_hist", v1, step) 113 | 114 | # Optional small scalars (norms) that show up on Time Series 115 | if q1.numel(): logger.log(step=step, **{f"qkv/block{li}/q_l2_mean": float(q1.square().mean().sqrt())}) 116 | if k1.numel(): logger.log(step=step, **{f"qkv/block{li}/k_l2_mean": float(k1.square().mean().sqrt())}) 117 | if v1.numel(): logger.log(step=step, **{f"qkv/block{li}/v_l2_mean": float(v1.square().mean().sqrt())}) 118 | 119 | # Advance x to next block with a CHEAP approximation to avoid doubling full compute: 120 | # use the model's own FFN path only; skip re-running attention (we're only logging pre-attn stats). 121 | x = x + blk.ffn(blk.ln2(x)) 122 | 123 | except Exception as e: 124 | print(f"[qkv] logging failed: {e}") 125 | 126 | 127 | def _log_runtime(logger, step: int, it_t0: float, xb, device): 128 | try: 129 | dt = time.time() - it_t0 130 | toks = int(xb.numel()) 131 | toks_per_s = toks / max(dt, 1e-6) 132 | mem = torch.cuda.memory_allocated()/(1024**2) if torch.cuda.is_available() else 0.0 133 | logger.log(step=step, **{ 134 | "sys/throughput_tokens_per_s": toks_per_s, 135 | "sys/step_time_s": dt, 136 | "sys/gpu_mem_alloc_mb": mem 137 | }) 138 | except Exception: 139 | pass 140 | 141 | def _log_samples_tb(logger, model, tok, xb, device, step: int, max_new_tokens: int = 64): 142 | if not _is_tb(logger): return 143 | if tok is None: return 144 | try: 145 | model.eval() 146 | with torch.no_grad(): 147 | out = model.generate(xb[:1].to(device), max_new_tokens=max_new_tokens, temperature=1.0, top_k=50) 148 | model.train() 149 | text = tok.decode(out[0].tolist()) 150 | logger.text("samples/generation", text, step) 151 | except Exception: 152 | pass 153 | # ---------------------------------------------------------------------- # 154 | 155 | def _extract_config_from_model(model) -> dict: 156 | """ 157 | Best-effort extraction of GPTModern-like config including GQA fields. 
158 | """ 159 | cfg = {} 160 | try: 161 | tok_emb = getattr(model, "tok_emb", None) 162 | blocks = getattr(model, "blocks", None) 163 | if tok_emb is None or not blocks: 164 | return cfg 165 | 166 | try: 167 | from swiglu import SwiGLU # optional 168 | except Exception: 169 | class SwiGLU: pass 170 | 171 | cfg["vocab_size"] = int(tok_emb.num_embeddings) 172 | cfg["block_size"] = int(getattr(model, "block_size", 0) or 0) 173 | cfg["n_layer"] = int(len(blocks)) 174 | 175 | first_blk = blocks[0] 176 | attn = getattr(first_blk, "attn", None) 177 | if attn is None: 178 | return cfg 179 | 180 | # Heads & dims 181 | cfg["n_head"] = int(getattr(attn, "n_head")) 182 | d_head = int(getattr(attn, "d_head")) 183 | cfg["n_embd"] = int(cfg["n_head"] * d_head) 184 | cfg["n_kv_head"]= int(getattr(attn, "n_kv_head", cfg["n_head"])) # default to MHA 185 | 186 | # Dropout (if present) 187 | drop = getattr(attn, "dropout", None) 188 | cfg["dropout"] = float(getattr(drop, "p", 0.0)) if drop is not None else 0.0 189 | 190 | # Norm/FFN style 191 | cfg["use_rmsnorm"] = isinstance(getattr(model, "ln_f", None), nn.Identity) 192 | cfg["use_swiglu"] = isinstance(getattr(first_blk, "ffn", None), SwiGLU) 193 | 194 | # Positional / attention tricks 195 | for k in ("rope", "max_pos", "sliding_window", "attention_sink"): 196 | if hasattr(attn, k): 197 | val = getattr(attn, k) 198 | cfg[k] = int(val) if isinstance(val, bool) else val 199 | except Exception: 200 | return {} 201 | return cfg 202 | 203 | def _verify_model_matches(model, cfg: Dict[str, Any]) -> Tuple[bool, str]: 204 | """Return (ok, message).""" 205 | expected = { 206 | "block_size": cfg.get("block_size"), 207 | "n_layer": cfg.get("n_layer"), 208 | "n_head": cfg.get("n_head"), 209 | "n_embd": cfg.get("n_embd"), 210 | "vocab_size": cfg.get("vocab_size"), 211 | "n_kv_head": cfg.get("n_kv_head", cfg.get("n_head")), 212 | } 213 | got = { 214 | "block_size": int(getattr(model, "block_size", -1)), 215 | "n_layer": int(len(model.blocks)), 216 | "vocab_size": int(model.tok_emb.num_embeddings), 217 | } 218 | first_blk = model.blocks[0] 219 | got.update({ 220 | "n_head": int(first_blk.attn.n_head), 221 | "n_embd": int(first_blk.attn.n_head * first_blk.attn.d_head), 222 | "n_kv_head": int(getattr(first_blk.attn, "n_kv_head", first_blk.attn.n_head)), 223 | }) 224 | diffs = [f"{k}: ckpt={expected[k]} vs model={got[k]}" for k in expected if expected[k] != got[k]] 225 | if diffs: 226 | return False, "Architecture mismatch:\n " + "\n ".join(diffs) 227 | return True, "ok" 228 | 229 | 230 | def save_checkpoint(model, optimizer, scheduler, amp, step: int, out_dir: str, 231 | tokenizer_dir: str | None = None, config: dict | None = None): 232 | out = Path(out_dir); out.mkdir(parents=True, exist_ok=True) 233 | 234 | # Prefer the model’s own config if available (e.g., a dict or dataclass with __dict__/asdict) 235 | if hasattr(model, "config"): 236 | cfg_obj = model.config 237 | cfg = dict(cfg_obj) if isinstance(cfg_obj, dict) else getattr(cfg_obj, "__dict__", None) or _extract_config_from_model(model) 238 | else: 239 | cfg = config if config is not None else _extract_config_from_model(model) 240 | 241 | torch.save({ 242 | "model": model.state_dict(), 243 | "optimizer": optimizer.state_dict() if optimizer is not None else None, 244 | "scheduler": scheduler.state_dict() if hasattr(scheduler, "state_dict") else None, 245 | "amp_scaler": amp.scaler.state_dict() if amp and getattr(amp, "scaler", None) else None, 246 | "step": int(step), 247 | "config": cfg, # ← always write config 248 | 
"version": "part4-v2", 249 | }, out / DEF_NAME) 250 | 251 | if tokenizer_dir is not None: 252 | (out / "tokenizer_dir.txt").write_text(tokenizer_dir) 253 | 254 | 255 | 256 | def load_checkpoint(model, path: str, optimizer=None, scheduler=None, amp=None, strict: bool = True): 257 | ckpt = torch.load(path, map_location="cpu") 258 | 259 | cfg = ckpt.get("config") 260 | if cfg: 261 | ok, msg = _verify_model_matches(model, cfg) 262 | if not ok: 263 | raise RuntimeError(msg + "\nRebuild the model with this config, or load with strict=False.") 264 | else: 265 | # Legacy checkpoint without config: strongly encourage a rebuild step elsewhere 266 | print("[compat] Warning: checkpoint has no config; cannot verify architecture.") 267 | 268 | missing, unexpected = model.load_state_dict(ckpt["model"], strict=strict) 269 | if strict and (missing or unexpected): 270 | raise RuntimeError(f"State dict mismatch:\n missing: {missing}\n unexpected: {unexpected}") 271 | 272 | if optimizer is not None and ckpt.get("optimizer") is not None: 273 | optimizer.load_state_dict(ckpt["optimizer"]) 274 | if scheduler is not None and ckpt.get("scheduler") is not None and hasattr(scheduler, "load_state_dict"): 275 | scheduler.load_state_dict(ckpt["scheduler"]) 276 | if amp is not None and ckpt.get("amp_scaler") is not None and getattr(amp, "scaler", None): 277 | amp.scaler.load_state_dict(ckpt["amp_scaler"]) 278 | 279 | return ckpt.get("step", 0) 280 | 281 | 282 | # ----------------------------- checkpoint/save utils ----------------------------- # 283 | def checkpoint_paths(out_dir: Path, step: int): 284 | return out_dir / f"model_step{step:07d}.pt", out_dir / "model_last.pt" 285 | 286 | def atomic_save_all(model, optim, sched, amp, step: int, out_dir: Path, 287 | tok_dir: str | None, keep_last_k: int, config: dict): 288 | """Write model_last.pt (with config) + a rolling per-step copy.""" 289 | save_checkpoint(model, optim, sched, amp, step, str(out_dir), tok_dir, config=config) # writes model_last.pt 290 | per_step, last = checkpoint_paths(out_dir, step) 291 | try: 292 | shutil.copy2(last, per_step) 293 | except Exception: 294 | pass 295 | # GC old per-step checkpoints 296 | try: 297 | ckpts = sorted(out_dir.glob("model_step*.pt")) 298 | for old in ckpts[:-keep_last_k]: 299 | old.unlink(missing_ok=True) 300 | except Exception: 301 | pass --------------------------------------------------------------------------------