├── .gitignore
├── part_1
│   ├── out
│   │   ├── multi_head_attn_grid.png
│   │   └── mha_shapes.txt
│   ├── tests
│   │   ├── test_causal_mask.py
│   │   └── test_attn_math.py
│   ├── attn_mask.py
│   ├── demo_visualize_multi_head.py
│   ├── block.py
│   ├── ffn.py
│   ├── pos_encoding.py
│   ├── single_head.py
│   ├── vis_utils.py
│   ├── attn_numpy_demo.py
│   ├── demo_mha_shapes.py
│   ├── multi_head.py
│   └── orchestrator.py
├── .vscode
│   ├── settings.json
│   └── launch.json
├── part_3
│   ├── tests
│   │   ├── test_rmsnorm.py
│   │   ├── test_kvcache_shapes.py
│   │   └── test_rope_apply.py
│   ├── tokenizer.py
│   ├── rmsnorm.py
│   ├── swiglu.py
│   ├── utils.py
│   ├── block_modern.py
│   ├── kv_cache.py
│   ├── orchestrator.py
│   ├── demo_generate.py
│   ├── rope_custom.py
│   ├── attn_modern.py
│   └── model_modern.py
├── part_2
│   ├── tests
│   │   ├── test_tokenizer.py
│   │   └── test_dataset_shift.py
│   ├── tokenizer.py
│   ├── dataset.py
│   ├── utils.py
│   ├── eval_loss.py
│   ├── sample.py
│   ├── orchestrator.py
│   ├── model_gpt.py
│   └── train.py
├── part_5
│   ├── tests
│   │   ├── test_hybrid_block.py
│   │   ├── test_gate_shapes.py
│   │   └── test_moe_forward.py
│   ├── block_hybrid.py
│   ├── experts.py
│   ├── demo_moe.py
│   ├── moe.py
│   ├── orchestrator.py
│   ├── gating.py
│   └── README.md
├── part_7
│   ├── tests
│   │   ├── test_bt_loss.py
│   │   └── test_reward_forward.py
│   ├── loss_reward.py
│   ├── model_reward.py
│   ├── eval_rm.py
│   ├── data_prefs.py
│   ├── orchestrator.py
│   ├── collator_rm.py
│   └── train_rm.py
├── part_8
│   ├── tests
│   │   ├── test_policy_forward.py
│   │   └── test_ppo_loss.py
│   ├── ppo_loss.py
│   ├── policy.py
│   ├── orchestrator.py
│   ├── eval_ppo.py
│   ├── rollout.py
│   └── train_ppo.py
├── part_6
│   ├── tests
│   │   ├── test_formatter.py
│   │   └── test_masking.py
│   ├── curriculum.py
│   ├── formatters.py
│   ├── evaluate.py
│   ├── dataset_sft.py
│   ├── sample_sft.py
│   ├── orchestrator.py
│   ├── collator_sft.py
│   └── train_sft.py
├── part_4
│   ├── tests
│   │   ├── test_scheduler.py
│   │   ├── test_tokenizer_bpe.py
│   │   └── test_resume_shapes.py
│   ├── lr_scheduler.py
│   ├── amp_accum.py
│   ├── dataset_bpe.py
│   ├── orchestrator.py
│   ├── tokenizer_bpe.py
│   ├── sample.py
│   ├── logger.py
│   ├── train.py
│   └── checkpointing.py
├── part_9
│   ├── tests
│   │   └── test_grpo_loss.py
│   ├── policy.py
│   ├── grpo_loss.py
│   ├── orchestrator.py
│   ├── eval_ppo.py
│   ├── rollout.py
│   └── train_grpo.py
├── requirements.txt
└── README.md
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .pytest_cache/
3 | .DS_Store
4 | runs/
--------------------------------------------------------------------------------
/part_1/out/multi_head_attn_grid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vivekkalyanarangan30/llm_from_scratch/HEAD/part_1/out/multi_head_attn_grid.png
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python-envs.defaultEnvManager": "ms-python.python:conda",
3 | "python-envs.defaultPackageManager": "ms-python.python:conda",
4 | "python-envs.pythonProjects": []
5 | }
--------------------------------------------------------------------------------
/part_3/tests/test_rmsnorm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from rmsnorm import RMSNorm
3 |
4 | def test_rmsnorm_shapes():
5 | x = torch.randn(2,3,8)
6 | rn = RMSNorm(8)
7 | y = rn(x)
8 | assert y.shape == x.shape
--------------------------------------------------------------------------------
/part_1/tests/test_causal_mask.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from attn_mask import causal_mask
3 |
4 | def test_mask_is_upper_triangle():
5 | m = causal_mask(5)
6 | # ensure shape and diagonal rule
7 | assert m.shape == (1,1,5,5)
8 | assert m[0,0].sum() == torch.triu(torch.ones(5,5), diagonal=1).sum()
9 |
--------------------------------------------------------------------------------
/part_2/tests/test_tokenizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tokenizer import ByteTokenizer
3 |
4 | def test_roundtrip():
5 | tok = ByteTokenizer()
6 | s = "Hello, ByteTok! äö"
7 | ids = tok.encode(s)
8 | assert ids.dtype == torch.long
9 | s2 = tok.decode(ids)
10 | assert len(s2) > 0
--------------------------------------------------------------------------------
/part_5/tests/test_hybrid_block.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from block_hybrid import HybridFFN
3 |
4 | def test_hybrid_ffn_blend():
5 | B,T,C = 1, 4, 16
6 | ffn = HybridFFN(dim=C, alpha=0.3, n_expert=3, k=2)
7 | x = torch.randn(B,T,C)
8 | y, aux = ffn(x)
9 | assert y.shape == x.shape
10 | assert aux.item() >= 0.0
--------------------------------------------------------------------------------
/part_7/tests/test_bt_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from loss_reward import bradley_terry_loss
3 |
4 | def test_bradley_terry_monotonic():
5 | pos = torch.tensor([2.0, 3.0])
6 | neg = torch.tensor([1.0, 1.5])
7 | l1 = bradley_terry_loss(pos, neg)
8 | l2 = bradley_terry_loss(pos+1.0, neg) # increase margin
9 | assert l2 < l1
--------------------------------------------------------------------------------
/part_1/attn_mask.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def causal_mask(T: int, device=None):
4 | """Returns a bool mask where True means *masked* (disallowed).
5 | Shape: (1, 1, T, T) suitable for broadcasting with (B, heads, T, T).
6 | """
7 | m = torch.triu(torch.ones((T, T), dtype=torch.bool, device=device), diagonal=1)
8 | return m.view(1, 1, T, T)
--------------------------------------------------------------------------------
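A minimal usage sketch for causal_mask (illustrative only, not a file in this repo):
--------------------------------------------------------------------------------
import torch
from attn_mask import causal_mask

scores = torch.randn(1, 1, 4, 4)            # (B, heads, T, T) raw attention scores
mask = causal_mask(4)                       # (1, 1, 4, 4); True marks disallowed (future) positions
scores = scores.masked_fill(mask, float('-inf'))
weights = torch.softmax(scores, dim=-1)     # each query attends only to itself and earlier positions
assert torch.allclose(weights.triu(1), torch.zeros_like(weights))
--------------------------------------------------------------------------------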
/part_8/tests/test_policy_forward.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from policy import PolicyWithValue
3 |
4 | def test_policy_shapes():
5 | B,T,V = 2, 16, 256
6 | pol = PolicyWithValue(vocab_size=V, block_size=T, n_layer=2, n_head=2, n_embd=64)
7 | x = torch.randint(0, V, (B,T))
8 | logits, values, loss = pol(x, None)
9 | assert logits.shape == (B,T,V)
10 | assert values.shape == (B,T)
11 |
--------------------------------------------------------------------------------
/part_6/tests/test_formatter.py:
--------------------------------------------------------------------------------
1 | from formatters import Example, format_example, format_prompt_only
2 |
3 | def test_template_contains_markers():
4 | ex = Example("Say hi","Hello!")
5 | s = format_example(ex)
6 | assert "### Instruction:" in s and "### Response:" in s
7 | p = format_prompt_only("Explain transformers.")
 8 | assert p.endswith("### Response:\n")

--------------------------------------------------------------------------------
/part_2/tests/test_dataset_shift.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from dataset import ByteDataset
3 |
4 |
5 | def test_shift_alignment(tmp_path):
6 | p = tmp_path / 'toy.txt'
7 | p.write_text('abcdefg')
8 | ds = ByteDataset(str(p), block_size=3, split=1.0)
9 | x, y = ds.get_batch('train', 2, device=torch.device('cpu'))
10 | # shift must be next-token
11 | assert (y[:, :-1] == x[:, 1:]).all()
12 |
--------------------------------------------------------------------------------
/part_3/tests/test_kvcache_shapes.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from kv_cache import RollingKV
3 |
4 | def test_rolling_kv_keep_window_with_sink():
5 | B,H,D = 1,2,4
6 | kv = RollingKV(window=4, sink=2)
7 | for _ in range(10):
8 | k_new = torch.randn(B,H,1,D)
9 | v_new = torch.randn(B,H,1,D)
10 | k,v = kv.step(k_new, v_new)
11 | # Should never exceed sink+window length
12 | assert k.size(2) <= 6
--------------------------------------------------------------------------------
/part_5/tests/test_gate_shapes.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from gating import TopKGate
3 |
4 | def test_gate_topk_shapes():
5 | S, C, E, K = 32, 64, 4, 2
6 | x = torch.randn(S, C)
7 | gate = TopKGate(C, E, k=K)
8 | idx, w, aux = gate(x)
9 | assert idx.shape == (S, K)
10 | assert w.shape == (S, K)
11 | assert aux.ndim == 0
12 | # per-token weights are non-negative and <=1
13 | assert torch.all(w >= 0)
14 | assert torch.all(w <= 1)
--------------------------------------------------------------------------------
/part_8/tests/test_ppo_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ppo_loss import ppo_losses
3 |
4 | def test_clipped_objective_behaves():
5 | N = 32
6 | old_logp = torch.zeros(N)
7 | new_logp = torch.log(torch.full((N,), 1.2)) # ratio=1.2
8 | adv = torch.ones(N)
9 | new_v = torch.zeros(N)
10 | old_v = torch.zeros(N)
11 | ret = torch.ones(N)
12 | out = ppo_losses(new_logp, old_logp, adv, new_v, old_v, ret, clip_ratio=0.1)
13 | assert out.total_loss.ndim == 0
--------------------------------------------------------------------------------
/part_5/tests/test_moe_forward.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from moe import MoE
3 |
4 | def test_moe_forward_dims_and_grad():
5 | B,T,C = 2, 8, 32
6 | moe = MoE(dim=C, n_expert=4, k=1)
7 | x = torch.randn(B,T,C, requires_grad=True)
8 | y, aux = moe(x)
9 | assert y.shape == x.shape
10 | loss = (y**2).mean() + 0.01*aux
11 | loss.backward()
12 | # some gradient must flow to gate and experts
13 | grads = [p.grad is not None for p in moe.parameters()]
14 | assert any(grads)
--------------------------------------------------------------------------------
/part_7/tests/test_reward_forward.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from model_reward import RewardModel
3 |
4 | def test_reward_shapes_and_grad():
5 | B,T,V = 4, 16, 256
6 | m = RewardModel(vocab_size=V, block_size=T, n_layer=2, n_head=2, n_embd=64)
7 | x = torch.randint(0, V, (B,T))
8 | r = m(x)
9 | assert r.shape == (B,)
10 | # gradient flows
11 | loss = (r**2).mean()
12 | loss.backward()
13 | grads = [p.grad is not None for p in m.parameters()]
14 | assert any(grads)
--------------------------------------------------------------------------------
/part_1/demo_visualize_multi_head.py:
--------------------------------------------------------------------------------
1 | """Visualize multi-head attention weights per head (grid)."""
2 | import torch
3 | from multi_head import MultiHeadSelfAttention
4 | from vis_utils import save_attention_heads_grid
5 |
6 | B, T, d_model, n_head = 1, 5, 12, 3
7 | x = torch.randn(B, T, d_model)
8 | attn = MultiHeadSelfAttention(d_model, n_head, trace_shapes=False)
9 |
10 | out, w = attn(x) # w: (B, H, T, T)
11 |
12 | save_attention_heads_grid(w.detach().cpu().numpy(), filename="multi_head_attn_grid.png")
--------------------------------------------------------------------------------
/part_4/tests/test_scheduler.py:
--------------------------------------------------------------------------------
1 | from lr_scheduler import WarmupCosineLR
2 |
3 | class DummyOpt:
4 | def __init__(self):
5 | self.param_groups = [{'lr': 0.0}]
6 |
7 | def test_warmup_cosine_lr_progression():
8 | opt = DummyOpt()
9 | sch = WarmupCosineLR(opt, warmup_steps=10, total_steps=110, base_lr=1e-3)
10 | lrs = [sch.step() for _ in range(110)]
11 | assert max(lrs) <= 1e-3 + 1e-12
12 | assert lrs[0] < lrs[9] # warmup increasing
13 | assert lrs[-1] < lrs[10] # cosine decays
--------------------------------------------------------------------------------
/part_3/tokenizer.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import torch
3 |
4 | class ByteTokenizer:
5 | """Simple byte-level tokenizer (0..255)."""
6 | def encode(self, s: str) -> torch.Tensor:
7 | return torch.tensor(list(s.encode('utf-8')), dtype=torch.long)
8 | def decode(self, ids) -> str:
9 | if isinstance(ids, torch.Tensor):
10 | ids = ids.tolist()
11 | return bytes(ids).decode('utf-8', errors='ignore')
12 | @property
13 | def vocab_size(self) -> int:
14 | return 256
--------------------------------------------------------------------------------
/part_3/rmsnorm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class RMSNorm(nn.Module):
5 | """Root Mean Square Layer Normalization.
6 | y = x * g / rms(x), rms(x) = sqrt(mean(x^2) + eps)
7 | """
8 | def __init__(self, dim: int, eps: float = 1e-8):
9 | super().__init__()
10 | self.eps = eps
11 | self.weight = nn.Parameter(torch.ones(dim))
12 | def forward(self, x: torch.Tensor) -> torch.Tensor:
13 | rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
14 | return (x / rms) * self.weight
--------------------------------------------------------------------------------
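A quick check (illustrative only, not a repo file) that RMSNorm matches the docstring formula y = x * g / sqrt(mean(x^2) + eps):
--------------------------------------------------------------------------------
import torch
from rmsnorm import RMSNorm

x = torch.randn(2, 3, 8)
rn = RMSNorm(8)
manual = x / (x.pow(2).mean(dim=-1, keepdim=True) + rn.eps).sqrt() * rn.weight
assert torch.allclose(rn(x), manual)
assert rn(x).shape == x.shape
--------------------------------------------------------------------------------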
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 |
8 |
9 |
10 | {
11 | "name": "Python Debugger: Current File",
12 | "type": "debugpy",
13 | "request": "launch",
14 | "program": "${file}",
15 | "console": "integratedTerminal",
16 | "args": ["--demo"]
17 | }
18 | ]
19 | }
--------------------------------------------------------------------------------
/part_6/tests/test_masking.py:
--------------------------------------------------------------------------------
1 | from collator_sft import SFTCollator
2 | from formatters import Example
3 |
4 | def test_masking_sets_prompt_to_ignore():
5 | col = SFTCollator(block_size=256, bpe_dir='../part_4/runs/part4-demo/tokenizer')
6 | text = "This is a tiny test."
7 | x, y = col.collate([(text, "OK")])
 8 | # Labels covering the prompt (everything before the response marker) should be -100.
 9 | # We can't recover the exact tokenized prompt boundary here, so just
10 | # sanity-check that some positions are masked out with -100.
11 | assert (y[0] == -100).sum() > 0
12 |
--------------------------------------------------------------------------------
/part_6/curriculum.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import List
3 |
4 | class LengthCurriculum:
5 | """6.3 Curriculum: iterate examples from short→long prompts (one pass demo)."""
6 | def __init__(self, items: List[tuple[str,str]]):
7 | self.items = sorted(items, key=lambda p: len(p[0]))
8 | self._i = 0
9 | def __iter__(self):
10 | self._i = 0
11 | return self
12 | def __next__(self):
13 | if self._i >= len(self.items):
14 | raise StopIteration
15 | it = self.items[self._i]
16 | self._i += 1
17 | return it
--------------------------------------------------------------------------------
/part_4/tests/test_tokenizer_bpe.py:
--------------------------------------------------------------------------------
1 | import os, tempfile
2 | from tokenizer_bpe import BPETokenizer
3 |
4 | def test_bpe_train_save_load_roundtrip():
5 | with tempfile.TemporaryDirectory() as d:
6 | txt = os.path.join(d, 'tiny.txt')
7 | with open(txt, 'w') as f:
8 | f.write('hello hello world')
9 | tok = BPETokenizer(vocab_size=100)
10 | tok.train(txt)
11 | out = os.path.join(d, 'tok')
12 | tok.save(out)
13 | tok2 = BPETokenizer()
14 | tok2.load(out)
15 | ids = tok2.encode('hello world')
16 | assert isinstance(ids, list) and len(ids) > 0
--------------------------------------------------------------------------------
/part_2/tokenizer.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import torch
3 |
4 | class ByteTokenizer:
5 | """Ultra-simple byte-level tokenizer.
6 | - encode(str) -> LongTensor [N]
7 | - decode(Tensor[int]) -> str
8 | - vocab_size = 256
9 | """
10 | def encode(self, s: str) -> torch.Tensor:
11 | return torch.tensor(list(s.encode('utf-8')), dtype=torch.long)
12 |
13 | def decode(self, ids) -> str:
14 | if isinstance(ids, torch.Tensor):
15 | ids = ids.tolist()
16 | return bytes(ids).decode('utf-8', errors='ignore')
17 |
18 | @property
19 | def vocab_size(self) -> int:
20 | return 256
--------------------------------------------------------------------------------
/part_3/swiglu.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | class SwiGLU(nn.Module):
4 | """SwiGLU FFN: (xW1) ⊗ swish(xW2) W3 with expansion factor `mult`.
5 | """
6 | def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0):
7 | super().__init__()
8 | inner = mult * dim
9 | self.w1 = nn.Linear(dim, inner, bias=False)
10 | self.w2 = nn.Linear(dim, inner, bias=False)
11 | self.w3 = nn.Linear(inner, dim, bias=False)
12 | self.act = nn.SiLU()
13 | self.drop = nn.Dropout(dropout)
14 | def forward(self, x):
15 | a = self.w1(x)
16 | b = self.act(self.w2(x))
17 | return self.drop(self.w3(a * b))
--------------------------------------------------------------------------------
/part_6/formatters.py:
--------------------------------------------------------------------------------
1 | """Prompt/response formatting utilities (6.1).
2 | We keep a very simple template with clear separators.
3 | """
4 | from dataclasses import dataclass
5 |
6 | template = (
7 | "\n"
8 | "### Instruction:\n{instruction}\n\n"
9 | "### Response:\n{response}"
10 | )
11 |
12 | @dataclass
13 | class Example:
14 | instruction: str
15 | response: str
16 |
17 |
18 | def format_example(ex: Example) -> str:
19 | return template.format(instruction=ex.instruction.strip(), response=ex.response.strip())
20 |
21 |
22 | def format_prompt_only(instruction: str) -> str:
23 | return template.format(instruction=instruction.strip(), response="")
--------------------------------------------------------------------------------
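A small usage sketch of the formatting helpers (illustrative only, not a repo file):
--------------------------------------------------------------------------------
from formatters import Example, format_example, format_prompt_only

full = format_example(Example("Say hi", "Hello!"))
# "\n### Instruction:\nSay hi\n\n### Response:\nHello!"
prompt = format_prompt_only("Explain transformers.")
assert full.endswith("Hello!")
assert prompt.endswith("### Response:\n")   # ready for the model to continue
--------------------------------------------------------------------------------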
/part_7/loss_reward.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import torch, torch.nn.functional as F
3 |
4 | def bradley_terry_loss(r_pos: torch.Tensor, r_neg: torch.Tensor) -> torch.Tensor:
5 | """-log sigma(r_pos - r_neg) = softplus(-(r_pos - r_neg))
6 | https://docs.pytorch.org/docs/stable/generated/torch.nn.Softplus.html"""
7 | diff = r_pos - r_neg
8 | return F.softplus(-diff).mean()
9 |
10 |
11 | def margin_ranking_loss(r_pos: torch.Tensor, r_neg: torch.Tensor, margin: float = 1.0) -> torch.Tensor:
12 | """https://docs.pytorch.org/docs/stable/generated/torch.nn.MarginRankingLoss.html"""
13 | y = torch.ones_like(r_pos)
14 | return F.margin_ranking_loss(r_pos, r_neg, y, margin=margin)
--------------------------------------------------------------------------------
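A tiny numeric example of the two reward-model losses (illustrative only, not a repo file):
--------------------------------------------------------------------------------
import torch
from loss_reward import bradley_terry_loss, margin_ranking_loss

r_pos = torch.tensor([2.0, 1.0])   # rewards for chosen responses
r_neg = torch.tensor([0.5, 0.9])   # rewards for rejected responses
print(bradley_terry_loss(r_pos, r_neg).item())   # softplus(-diff).mean() ≈ 0.423
print(margin_ranking_loss(r_pos, r_neg).item())  # hinge with margin 1.0 → (0 + 0.9) / 2 = 0.45
--------------------------------------------------------------------------------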
/part_1/block.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from multi_head import MultiHeadSelfAttention
3 | from ffn import FeedForward
4 |
5 | class TransformerBlock(nn.Module):
6 | """1.6 Transformer block = LN → MHA → residual → LN → FFN → residual."""
7 | def __init__(self, d_model: int, n_head: int, dropout: float = 0.0):
8 | super().__init__()
9 | self.ln1 = nn.LayerNorm(d_model)
10 | self.attn = MultiHeadSelfAttention(d_model, n_head, dropout)
11 | self.ln2 = nn.LayerNorm(d_model)
12 | self.ffn = FeedForward(d_model, mult=4, dropout=dropout)
13 |
14 | def forward(self, x):
15 | x = x + self.attn(self.ln1(x))[0]
16 | x = x + self.ffn(self.ln2(x))
17 | return x
--------------------------------------------------------------------------------
/part_1/out/mha_shapes.txt:
--------------------------------------------------------------------------------
1 | Input x: (1, 5, 12) = (B,T,d_model)
2 | Linear qkv(x): (1, 5, 36) = (B,T,3*d_model)
3 | view to 5D: (1, 5, 3, 3, 4) = (B,T,3,heads,d_head)
4 | q,k,v split: q=(1, 5, 3, 4) k=(1, 5, 3, 4) v=(1, 5, 3, 4)
5 | transpose heads: q=(1, 3, 5, 4) k=(1, 3, 5, 4) v=(1, 3, 5, 4) = (B,heads,T,d_head)
6 | scores q@k^T: (1, 3, 5, 5) = (B,heads,T,T)
7 | softmax(weights): (1, 3, 5, 5) = (B,heads,T,T)
8 | context @v: (1, 3, 5, 4) = (B,heads,T,d_head)
9 | merge heads: (1, 5, 12) = (B,T,d_model)
10 | final proj: (1, 5, 12) = (B,T,d_model)
11 |
12 | Legend:
13 | B=batch, T=sequence length, d_model=embedding size, heads=n_head, d_head=d_model/heads
14 | qkv(x) is a single Linear producing [Q|K|V]; we reshape then split into q,k,v
15 |
--------------------------------------------------------------------------------
/part_1/ffn.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | class FeedForward(nn.Module):
4 | """1.5 FFN with expansion factor `mult`.
5 |
6 | Dimensions:
7 | input: (B, T, d_model)
8 | inner: (B, T, mult*d_model)
9 | output: (B, T, d_model)
10 |
11 | `mult*d_model` means the hidden width is `mult` times larger than `d_model`.
12 | Typical values: mult=4 for GELU FFN in GPT-style blocks.
13 | """
14 | def __init__(self, d_model: int, mult: int = 4, dropout: float = 0.0):
15 | super().__init__()
16 | self.net = nn.Sequential(
17 | nn.Linear(d_model, mult * d_model),
18 | nn.GELU(),
19 | nn.Linear(mult * d_model, d_model),
20 | nn.Dropout(dropout),
21 | )
22 |
23 | def forward(self, x):
24 | return self.net(x)
--------------------------------------------------------------------------------
/part_6/evaluate.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import re
3 | from typing import List, Tuple
4 |
5 | def _normalize(s: str) -> str:
6 | s = s.lower()
7 | s = re.sub(r"[^a-z0-9\s]", " ", s)
8 | s = re.sub(r"\s+", " ", s).strip()
9 | return s
10 |
11 | def exact_match(pred: str, gold: str) -> float:
12 | return float(_normalize(pred) == _normalize(gold))
13 |
14 | def token_f1(pred: str, gold: str) -> float:
15 | p = _normalize(pred).split()
16 | g = _normalize(gold).split()
17 | if not p and not g:
18 | return 1.0
19 | if not p or not g:
20 | return 0.0
21 | common = 0
22 | gp = g.copy()
23 | for t in p:
24 | if t in gp:
25 | gp.remove(t); common += 1
26 | if common == 0:
27 | return 0.0
28 | prec = common / len(p)
29 | rec = common / len(g)
30 | return 2*prec*rec/(prec+rec)
--------------------------------------------------------------------------------
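A short example of the SFT evaluation metrics (illustrative only, not a repo file):
--------------------------------------------------------------------------------
from evaluate import exact_match, token_f1

print(exact_match("The answer is 42.", "the answer is 42"))  # 1.0 — normalization drops case/punctuation
print(token_f1("a quick brown fox", "the quick brown fox jumps"))
# 3 overlapping tokens → precision 3/4, recall 3/5, F1 ≈ 0.667
--------------------------------------------------------------------------------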
/part_4/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | class WarmupCosineLR:
4 | """Linear warmup → cosine decay (per-step API)."""
5 | def __init__(self, optimizer, warmup_steps: int, total_steps: int, base_lr: float):
6 | self.optimizer = optimizer
7 | self.warmup_steps = max(1, warmup_steps)
8 | self.total_steps = max(self.warmup_steps+1, total_steps)
9 | self.base_lr = base_lr
10 | self.step_num = 0
11 | def step(self):
12 | self.step_num += 1
13 | if self.step_num <= self.warmup_steps:
14 | lr = self.base_lr * self.step_num / self.warmup_steps
15 | else:
16 | progress = (self.step_num - self.warmup_steps) / (self.total_steps - self.warmup_steps)
17 | lr = 0.5 * self.base_lr * (1.0 + math.cos(math.pi * progress))
18 | for g in self.optimizer.param_groups:
19 | g['lr'] = lr
20 | return lr
--------------------------------------------------------------------------------
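A minimal usage sketch for WarmupCosineLR (illustrative only, not a repo file; the dummy parameter is just there to construct an optimizer):
--------------------------------------------------------------------------------
import torch
from lr_scheduler import WarmupCosineLR

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=0.0)
sch = WarmupCosineLR(opt, warmup_steps=100, total_steps=1000, base_lr=3e-4)

lrs = [sch.step() for _ in range(1000)]   # call once per optimizer step
# lrs ramps linearly to 3e-4 over the first 100 steps, then follows a cosine decay toward 0
--------------------------------------------------------------------------------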
/part_9/tests/test_grpo_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from grpo_loss import ppo_policy_only_losses
3 |
4 | def test_grpo_clipped_objective_behaves():
5 | N = 32
6 | old_logp = torch.zeros(N) # log π_old = 0 → π_old = 1.0
7 | new_logp = torch.log(torch.full((N,), 1.2)) # log π_new = log(1.2) → ratio=1.2
8 | adv = torch.ones(N) # positive advantage
9 | kl_mean = torch.tensor(0.5) # pretend KL(π||π_ref)=0.5
10 |
11 | out = ppo_policy_only_losses(
12 | new_logp=new_logp,
13 | old_logp=old_logp,
14 | adv=adv,
15 | clip_ratio=0.1,
16 | ent_coef=0.0,
17 | kl_coef=0.1,
18 | kl_mean=kl_mean,
19 | )
20 |
21 | # Ensure scalar loss
22 | assert out.total_loss.ndim == 0
23 | assert torch.isfinite(out.policy_loss)
24 | # KL penalty should have been added
25 | assert torch.allclose(out.kl_ref, kl_mean)
26 |
--------------------------------------------------------------------------------
/part_5/block_hybrid.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import torch.nn as nn
3 | from moe import MoE
4 |
5 | class HybridFFN(nn.Module):
6 | """Blend dense FFN with MoE output: y = α * Dense(x) + (1−α) * MoE(x).
7 | Use α∈[0,1] to trade between stability (dense) and capacity (MoE).
8 | """
9 | def __init__(self, dim: int, alpha: float = 0.5, mult: int = 4, swiglu: bool = True, n_expert: int = 4, k: int = 1, dropout: float = 0.0):
10 | super().__init__()
11 | self.alpha = alpha
12 | inner = mult * dim
13 | self.dense = nn.Sequential(
14 | nn.Linear(dim, inner), nn.GELU(), nn.Linear(inner, dim), nn.Dropout(dropout)
15 | )
16 | self.moe = MoE(dim, n_expert=n_expert, k=k, mult=mult, swiglu=swiglu, dropout=dropout)
17 | def forward(self, x):
18 | y_dense = self.dense(x)
19 | y_moe, aux = self.moe(x)
20 | y = self.alpha * y_dense + (1.0 - self.alpha) * y_moe
21 | return y, aux
--------------------------------------------------------------------------------
/part_5/experts.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import torch.nn as nn
3 |
4 | class ExpertMLP(nn.Module):
5 | """Single expert MLP (SwiGLU or GELU)."""
6 | def __init__(self, dim: int, mult: int = 4, swiglu: bool = True, dropout: float = 0.0):
7 | super().__init__()
8 | inner = mult * dim
9 | if swiglu:
10 | self.inp1 = nn.Linear(dim, inner, bias=False)
11 | self.inp2 = nn.Linear(dim, inner, bias=False)
12 | self.act = nn.SiLU()
13 | self.out = nn.Linear(inner, dim, bias=False)
14 | self.drop = nn.Dropout(dropout)
15 | self.swiglu = True
16 | else:
17 | self.ff = nn.Sequential(
18 | nn.Linear(dim, inner), nn.GELU(), nn.Linear(inner, dim), nn.Dropout(dropout)
19 | )
20 | self.swiglu = False
21 | def forward(self, x):
22 | if self.swiglu:
23 | a = self.inp1(x); b = self.act(self.inp2(x))
24 | return self.drop(self.out(a * b))
25 | else:
26 | return self.ff(x)
--------------------------------------------------------------------------------
/part_5/demo_moe.py:
--------------------------------------------------------------------------------
1 | import argparse, torch
2 | from moe import MoE
3 |
4 | if __name__ == "__main__":
5 | p = argparse.ArgumentParser()
6 | p.add_argument('--tokens', type=int, default=64)
7 | p.add_argument('--hidden', type=int, default=128)
8 | p.add_argument('--experts', type=int, default=4)
9 | p.add_argument('--top_k', type=int, default=1)
10 | p.add_argument('--cpu', action='store_true')
11 | args = p.parse_args()
12 |
13 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')
14 | x = torch.randn(2, args.tokens//2, args.hidden, device=device) # (B=2,T=tokens/2,C)
15 |
16 | moe = MoE(dim=args.hidden, n_expert=args.experts, k=args.top_k).to(device)
17 | with torch.no_grad():
18 | y, aux = moe(x)
19 |
20 | # simple routing histogram
21 | # recompute routing with the MoE's own gate to build the histogram
22 | gate = moe.gate
23 | idx, w, _ = gate(x.view(-1, args.hidden))
24 | hist = torch.bincount(idx[:,0], minlength=args.experts)
25 | print(f"Output shape: {tuple(y.shape)} | aux={float(aux):.4f}")
26 | print("Primary expert load (counts):", hist.tolist())
--------------------------------------------------------------------------------
/part_1/tests/test_attn_math.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from single_head import SingleHeadSelfAttention
4 |
5 | # mirror the tiny example in attn_numpy_demo.py
6 | X = np.array([[[0.1, 0.2, 0.3, 0.4],
7 | [0.5, 0.4, 0.3, 0.2],
8 | [0.0, 0.1, 0.0, 0.1]]], dtype=np.float32)
9 | Wq = np.array([[ 0.2, -0.1],[ 0.0, 0.1],[ 0.1, 0.2],[-0.1, 0.0]], dtype=np.float32)
10 | Wk = np.array([[ 0.1, 0.1],[ 0.0, -0.1],[ 0.2, 0.0],[ 0.0, 0.2]], dtype=np.float32)
11 | Wv = np.array([[ 0.1, 0.0],[-0.1, 0.1],[ 0.2, -0.1],[ 0.0, 0.2]], dtype=np.float32)
12 |
13 | def test_single_head_matches_numpy():
14 | torch.manual_seed(0)
15 | x = torch.tensor(X)
16 | attn = SingleHeadSelfAttention(d_model=4, d_k=2)
17 | # load weights
18 | with torch.no_grad():
19 | attn.q.weight.copy_(torch.tensor(Wq).t())
20 | attn.k.weight.copy_(torch.tensor(Wk).t())
21 | attn.v.weight.copy_(torch.tensor(Wv).t())
22 | out, w = attn(x)
23 | assert out.shape == (1,3,2)
24 | # Basic numeric sanity
25 | assert torch.isfinite(out).all()
26 | assert torch.isfinite(w).all()
--------------------------------------------------------------------------------
/part_2/dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from pathlib import Path
3 | import torch
4 |
5 | class ByteDataset:
6 | """Holds raw bytes of a text file and yields (x,y) blocks for LM.
7 | - block_size: sequence length (context window)
8 | - split: fraction for training (rest is val)
9 | """
10 | def __init__(self, path: str, block_size: int = 256, split: float = 0.9):
11 | data = Path(path).read_bytes()
12 | data = torch.tensor(list(data), dtype=torch.long)
13 | n = int(len(data) * split)
14 | self.train = data[:n]
15 | self.val = data[n:]
16 | self.block_size = block_size
17 |
18 | def get_batch(self, which: str, batch_size: int, device: torch.device):
19 | buf = self.train if which == 'train' else self.val
20 | assert len(buf) > self.block_size + 1, 'file too small for given block_size'
21 | ix = torch.randint(0, len(buf) - self.block_size - 1, (batch_size,))
22 | x = torch.stack([buf[i:i+self.block_size] for i in ix])
23 | y = torch.stack([buf[i+1:i+1+self.block_size] for i in ix])
24 | return x.to(device), y.to(device)
--------------------------------------------------------------------------------
/part_4/tests/test_resume_shapes.py:
--------------------------------------------------------------------------------
1 | import torch, tempfile, os
2 | import torch.nn as nn
3 | from checkpointing import save_checkpoint, load_checkpoint
4 |
5 | class Dummy(nn.Module):
6 | def __init__(self):
7 | super().__init__()
8 | self.l = torch.nn.Linear(8,8)
9 |
10 |
11 | def test_save_and_load(tmp_path):
12 | m = Dummy(); opt = torch.optim.AdamW(m.parameters(), lr=1e-3)
13 | class S: pass
14 | sch = S(); sch.__dict__ = {'warmup_steps': 10, 'total_steps': 100, 'base_lr': 1e-3, 'step_num': 5}
15 | class A: pass
16 | amp = A(); amp.scaler = torch.cuda.amp.GradScaler(enabled=False)
17 |
18 | out = tmp_path/"chk"
19 | save_checkpoint(m, opt, sch, amp, step=123, out_dir=str(out), tokenizer_dir=None)
20 |
21 | m2 = Dummy(); opt2 = torch.optim.AdamW(m2.parameters(), lr=1e-3)
22 | sch2 = S(); sch2.__dict__ = {'warmup_steps': 1, 'total_steps': 1, 'base_lr': 1e-3, 'step_num': 0}
23 | amp2 = A(); amp2.scaler = torch.cuda.amp.GradScaler(enabled=False)
24 |
25 | step = load_checkpoint(m2, str(out/"model_last.pt"), optimizer=opt2, scheduler=sch2, amp=amp2)
26 | assert isinstance(step, int)
--------------------------------------------------------------------------------
/part_4/amp_accum.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class AmpGrad:
4 | """AMP + gradient accumulation wrapper.
5 | Usage:
6 | amp = AmpGrad(optimizer, accum=4, amp=True)
7 | amp.backward(loss)
8 | if amp.should_step(): amp.step(); amp.zero_grad()
9 | """
10 | def __init__(self, optimizer, accum: int = 1, amp: bool = True):
11 | self.optim = optimizer
12 | self.accum = max(1, accum)
13 | self.amp = amp and torch.cuda.is_available()
14 | self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
15 | self._n = 0
16 | def backward(self, loss: torch.Tensor):
17 | loss = loss / self.accum
18 | if self.amp:
19 | self.scaler.scale(loss).backward()
20 | else:
21 | loss.backward()
22 | self._n += 1
23 | def should_step(self):
24 | return (self._n % self.accum) == 0
25 | def step(self):
26 | if self.amp:
27 | self.scaler.step(self.optim)
28 | self.scaler.update()
29 | else:
30 | self.optim.step()
31 | def zero_grad(self):
32 | self.optim.zero_grad(set_to_none=True)
--------------------------------------------------------------------------------
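A CPU-only sketch of the AmpGrad accumulation loop from the docstring (illustrative only, not a repo file; amp=False so it runs without CUDA):
--------------------------------------------------------------------------------
import torch
from amp_accum import AmpGrad

model = torch.nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
amp = AmpGrad(opt, accum=4, amp=False)        # amp=False keeps this sketch CPU-only

for x in torch.randn(8, 4, 8):                # 8 micro-batches of shape (4, 8)
    loss = model(x).pow(2).mean()
    amp.backward(loss)                        # divides by accum, then backprops (scaled if AMP)
    if amp.should_step():                     # true every 4th micro-batch
        amp.step()
        amp.zero_grad()
--------------------------------------------------------------------------------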
/part_2/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import torch
3 |
4 | def top_k_top_p_filtering(logits: torch.Tensor, top_k: int | None = None, top_p: float | None = None):
5 | """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
6 | - logits: (B, vocab)
7 | Returns filtered logits with -inf for masked entries.
8 | """
9 | B, V = logits.shape
10 | filtered = logits.clone()
11 |
12 | if top_k is not None and top_k < V:
13 | topk_vals, _ = torch.topk(filtered, top_k, dim=-1)
14 | kth = topk_vals[:, -1].unsqueeze(-1)
15 | filtered[filtered < kth] = float('-inf')
16 |
17 | if top_p is not None and 0 < top_p < 1.0:
18 | sorted_logits, sorted_idx = torch.sort(filtered, descending=True, dim=-1)
19 | probs = torch.softmax(sorted_logits, dim=-1)
20 | cumsum = torch.cumsum(probs, dim=-1)
21 | mask = cumsum > top_p
22 | # keep at least 1 token
23 | mask[..., 0] = False
24 | sorted_logits[mask] = float('-inf')
25 | # Scatter back
26 | filtered = torch.full_like(filtered, float('-inf'))
27 | filtered.scatter_(1, sorted_idx, sorted_logits)
28 |
29 | return filtered
--------------------------------------------------------------------------------
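A minimal sampling sketch using top_k_top_p_filtering (illustrative only, not a repo file):
--------------------------------------------------------------------------------
import torch
from utils import top_k_top_p_filtering

logits = torch.randn(2, 256)                        # (B, vocab), e.g. byte-level vocab
filtered = top_k_top_p_filtering(logits, top_k=50, top_p=0.9)
probs = torch.softmax(filtered, dim=-1)             # masked entries (-inf) get probability 0
next_ids = torch.multinomial(probs, num_samples=1)  # (B, 1) sampled token ids
--------------------------------------------------------------------------------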
/part_3/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import torch
3 |
4 | def top_k_top_p_filtering(logits: torch.Tensor, top_k: int | None = None, top_p: float | None = None):
5 | """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
6 | - logits: (B, vocab)
7 | Returns filtered logits with -inf for masked entries.
8 | """
9 | B, V = logits.shape
10 | filtered = logits.clone()
11 |
12 | if top_k is not None and top_k < V:
13 | topk_vals, _ = torch.topk(filtered, top_k, dim=-1)
14 | kth = topk_vals[:, -1].unsqueeze(-1)
15 | filtered[filtered < kth] = float('-inf')
16 |
17 | if top_p is not None and 0 < top_p < 1.0:
18 | sorted_logits, sorted_idx = torch.sort(filtered, descending=True, dim=-1)
19 | probs = torch.softmax(sorted_logits, dim=-1)
20 | cumsum = torch.cumsum(probs, dim=-1)
21 | mask = cumsum > top_p
22 | # keep at least 1 token
23 | mask[..., 0] = False
24 | sorted_logits[mask] = float('-inf')
25 | # Scatter back
26 | filtered = torch.full_like(filtered, float('-inf'))
27 | filtered.scatter_(1, sorted_idx, sorted_logits)
28 |
29 | return filtered
--------------------------------------------------------------------------------
/part_8/ppo_loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import torch, torch.nn.functional as F
3 | from dataclasses import dataclass
4 |
5 | @dataclass
6 | class PPOLossOut:
7 | policy_loss: torch.Tensor
8 | value_loss: torch.Tensor
9 | entropy: torch.Tensor
10 | approx_kl: torch.Tensor
11 | total_loss: torch.Tensor
12 |
13 |
14 | def ppo_losses(new_logp, old_logp, adv, new_values, old_values, returns,
15 | clip_ratio=0.2, vf_coef=0.5, ent_coef=0.0):
16 | # policy
17 | ratio = torch.exp(new_logp - old_logp) # (N,)
18 | unclipped = ratio * adv
19 | clipped = torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * adv
20 | policy_loss = -torch.mean(torch.min(unclipped, clipped))
21 |
22 | # value (clip optional → here: simple MSE)
23 | value_loss = F.mse_loss(new_values, returns)
24 |
25 | # entropy bonus (we approximate entropy via -new_logp mean; strictly needs full dist)
26 | entropy = -new_logp.mean()
27 |
28 | # approx KL for logging
29 | approx_kl = torch.mean(old_logp - new_logp)
30 |
31 | total = policy_loss + vf_coef * value_loss - ent_coef * entropy
32 | return PPOLossOut(policy_loss, value_loss, entropy, approx_kl, total)
--------------------------------------------------------------------------------
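A minimal sketch of calling ppo_losses with random flattened token statistics (illustrative only, not a repo file):
--------------------------------------------------------------------------------
import torch
from ppo_loss import ppo_losses

N = 8
new_logp = torch.randn(N, requires_grad=True)
new_values = torch.randn(N, requires_grad=True)
out = ppo_losses(new_logp, old_logp=torch.randn(N), adv=torch.randn(N),
                 new_values=new_values, old_values=torch.randn(N), returns=torch.randn(N),
                 clip_ratio=0.2, vf_coef=0.5, ent_coef=0.01)
out.total_loss.backward()   # gradients flow into new_logp / new_values
print(float(out.policy_loss), float(out.value_loss), float(out.approx_kl))
--------------------------------------------------------------------------------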
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.3.1
2 | aiohappyeyeballs==2.6.1
3 | aiohttp==3.12.15
4 | aiosignal==1.4.0
5 | attrs==25.3.0
6 | certifi==2025.8.3
7 | charset-normalizer==3.4.3
8 | contourpy==1.3.3
9 | cycler==0.12.1
10 | datasets==4.0.0
11 | dill==0.3.8
12 | filelock==3.19.1
13 | fonttools==4.59.1
14 | frozenlist==1.7.0
15 | fsspec==2025.3.0
16 | grpcio==1.74.0
17 | hf-xet==1.1.7
18 | huggingface-hub==0.34.4
19 | idna==3.10
20 | iniconfig==2.1.0
21 | Jinja2==3.1.6
22 | kiwisolver==1.4.9
23 | Markdown==3.8.2
24 | MarkupSafe==3.0.2
25 | matplotlib==3.10.5
26 | mpmath==1.3.0
27 | multidict==6.6.4
28 | multiprocess==0.70.16
29 | networkx==3.5
30 | numpy==2.3.2
31 | packaging==25.0
32 | pandas==2.3.2
33 | pillow==11.3.0
34 | pluggy==1.6.0
35 | propcache==0.3.2
36 | protobuf==6.32.0
37 | pyarrow==21.0.0
38 | Pygments==2.19.2
39 | pyparsing==3.2.3
40 | pytest==8.4.1
41 | python-dateutil==2.9.0.post0
42 | pytz==2025.2
43 | PyYAML==6.0.2
44 | requests==2.32.4
45 | six==1.17.0
46 | sympy==1.14.0
47 | tensorboard==2.20.0
48 | tensorboard-data-server==0.7.2
49 | tokenizers==0.21.4
50 | torch==2.8.0
51 | tqdm==4.67.1
52 | typing_extensions==4.14.1
53 | tzdata==2025.2
54 | urllib3==2.5.0
55 | Werkzeug==3.1.3
56 | xxhash==3.5.0
57 | yarl==1.20.1
58 |
--------------------------------------------------------------------------------
/part_4/dataset_bpe.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import torch
3 | from torch.utils.data import Dataset, DataLoader
4 | from pathlib import Path
5 | from typing import Tuple
6 | from tokenizer_bpe import BPETokenizer
7 |
8 | class TextBPEBuffer(Dataset):
9 | """Memory-mapped-ish single-file dataset: tokenize once → long tensor of ids.
10 | get(idx) returns a (block_size,) slice; we construct (x,y) with shift inside collate.
11 | """
12 | def __init__(self, path: str, tokenizer: BPETokenizer, block_size: int = 256):
13 | super().__init__()
14 | self.block_size = block_size
15 | text = Path(path).read_text(encoding='utf-8')
16 | self.ids = torch.tensor(tokenizer.encode(text), dtype=torch.long)
17 | def __len__(self):
18 | return max(0, self.ids.numel() - self.block_size - 1)
19 | def __getitem__(self, i: int):
20 | x = self.ids[i:i+self.block_size]
21 | y = self.ids[i+1:i+self.block_size+1]
22 | return x, y
23 |
24 | def make_loader(path: str, tokenizer: BPETokenizer, block_size: int, batch_size: int, shuffle=True) -> DataLoader:
25 | ds = TextBPEBuffer(path, tokenizer, block_size)
26 | return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, drop_last=True)
--------------------------------------------------------------------------------
/part_3/block_modern.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from rmsnorm import RMSNorm
3 | from swiglu import SwiGLU
4 | from attn_modern import CausalSelfAttentionModern
5 |
6 | class TransformerBlockModern(nn.Module):
7 | def __init__(self, n_embd: int, n_head: int, dropout: float = 0.0,
8 | use_rmsnorm: bool = True, use_swiglu: bool = True,
9 | rope: bool = True, max_pos: int = 4096,
10 | sliding_window: int | None = None, attention_sink: int = 0, n_kv_head: int | None = None):
11 | super().__init__()
12 | Norm = RMSNorm if use_rmsnorm else nn.LayerNorm
13 | self.ln1 = Norm(n_embd)
14 | self.attn = CausalSelfAttentionModern(n_embd, n_head, dropout, rope, max_pos, sliding_window, attention_sink, n_kv_head)
15 | self.ln2 = Norm(n_embd)
16 | self.ffn = SwiGLU(n_embd, mult=4, dropout=dropout) if use_swiglu else nn.Sequential(
17 | nn.Linear(n_embd, 4*n_embd), nn.GELU(), nn.Linear(4*n_embd, n_embd), nn.Dropout(dropout)
18 | )
19 | def forward(self, x, kv_cache=None, start_pos: int = 0):
20 | a, kv_cache = self.attn(self.ln1(x), kv_cache=kv_cache, start_pos=start_pos)
21 | x = x + a
22 | x = x + self.ffn(self.ln2(x))
23 | return x, kv_cache
--------------------------------------------------------------------------------
/part_1/pos_encoding.py:
--------------------------------------------------------------------------------
1 | """1.1 Positional encodings (absolute learned + sinusoidal)."""
2 | import math
3 | import torch
4 | import torch.nn as nn
5 |
6 | class LearnedPositionalEncoding(nn.Module):
7 | def __init__(self, max_len: int, d_model: int):
8 | super().__init__()
9 | self.emb = nn.Embedding(max_len, d_model)
10 |
11 | def forward(self, x: torch.Tensor):
12 | # x: (B, T, d_model) — we only need its T and device
13 | B, T, _ = x.shape
14 | pos = torch.arange(T, device=x.device)
15 | pos_emb = self.emb(pos) # (T, d_model)
16 | return x + pos_emb.unsqueeze(0) # broadcast over batch
17 |
18 | class SinusoidalPositionalEncoding(nn.Module):
19 | def __init__(self, max_len: int, d_model: int):
20 | super().__init__()
21 | pe = torch.zeros(max_len, d_model)
22 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
23 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
24 | pe[:, 0::2] = torch.sin(position * div_term)
25 | pe[:, 1::2] = torch.cos(position * div_term)
26 | self.register_buffer('pe', pe) # (max_len, d_model)
27 |
28 | def forward(self, x: torch.Tensor):
29 | B, T, _ = x.shape
30 | return x + self.pe[:T].unsqueeze(0)
--------------------------------------------------------------------------------
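A shape-level usage sketch for the two positional encodings (illustrative only, not a repo file):
--------------------------------------------------------------------------------
import torch
from pos_encoding import LearnedPositionalEncoding, SinusoidalPositionalEncoding

x = torch.randn(2, 10, 16)                                   # (B, T, d_model)
y1 = LearnedPositionalEncoding(max_len=64, d_model=16)(x)
y2 = SinusoidalPositionalEncoding(max_len=64, d_model=16)(x)
assert y1.shape == y2.shape == x.shape                       # both just add a (T, d_model) table
--------------------------------------------------------------------------------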
/part_3/kv_cache.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import torch
3 | from dataclasses import dataclass
4 |
5 | @dataclass
6 | class KVCache:
7 | k: torch.Tensor # (B,H,T,D)
8 | v: torch.Tensor # (B,H,T,D)
9 |
10 | @property
11 | def T(self):
12 | return self.k.size(2)
13 |
14 | class RollingKV:
15 | """Rolling buffer with optional attention sink.
16 | Keeps first `sink` tokens + last `window` tokens.
17 | """
18 | def __init__(self, window: int, sink: int = 0):
19 | self.window = window
20 | self.sink = sink
21 | self.k = None
22 | self.v = None
23 | def step(self, k_new: torch.Tensor, v_new: torch.Tensor):
24 | if self.k is None:
25 | self.k, self.v = k_new, v_new
26 | else:
27 | self.k = torch.cat([self.k, k_new], dim=2)
28 | self.v = torch.cat([self.v, v_new], dim=2)
29 | # crop
30 | if self.k.size(2) > self.window + self.sink:
31 | sink_part = self.k[:, :, :self.sink, :]
32 | sink_val = self.v[:, :, :self.sink, :]
33 | tail_k = self.k[:, :, -self.window:, :]
34 | tail_v = self.v[:, :, -self.window:, :]
35 | self.k = torch.cat([sink_part, tail_k], dim=2)
36 | self.v = torch.cat([sink_val, tail_v], dim=2)
37 | return self.k, self.v
--------------------------------------------------------------------------------
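A quick sketch of how RollingKV behaves during token-by-token decoding (illustrative only, not a repo file):
--------------------------------------------------------------------------------
import torch
from kv_cache import RollingKV

B, H, D = 1, 2, 8
kv = RollingKV(window=4, sink=2)
for _ in range(10):                                    # decode 10 tokens, one at a time
    k, v = kv.step(torch.randn(B, H, 1, D), torch.randn(B, H, 1, D))
print(k.shape)   # torch.Size([1, 2, 6, 8]): 2 "sink" tokens + the 4 most recent tokens
--------------------------------------------------------------------------------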
/part_1/single_head.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from attn_mask import causal_mask
6 |
7 | class SingleHeadSelfAttention(nn.Module):
8 | """1.3 Single-head attention (explicit shapes)."""
9 | def __init__(self, d_model: int, d_k: int, dropout: float = 0.0, trace_shapes: bool = False):
10 | super().__init__()
11 | self.q = nn.Linear(d_model, d_k, bias=False)
12 | self.k = nn.Linear(d_model, d_k, bias=False)
13 | self.v = nn.Linear(d_model, d_k, bias=False)
14 | self.dropout = nn.Dropout(dropout)
15 | self.trace_shapes = trace_shapes
16 |
17 | def forward(self, x: torch.Tensor): # x: (B, T, d_model)
18 | B, T, _ = x.shape
19 | q = self.q(x) # (B,T,d_k)
20 | k = self.k(x) # (B,T,d_k)
21 | v = self.v(x) # (B,T,d_k)
22 | if self.trace_shapes:
23 | print(f"q {q.shape} k {k.shape} v {v.shape}")
24 | scale = 1.0 / math.sqrt(q.size(-1))
25 | attn = torch.matmul(q, k.transpose(-2, -1)) * scale # (B,T,T)
26 | mask = causal_mask(T, device=x.device)
27 | attn = attn.masked_fill(mask.squeeze(1), float('-inf'))
28 | w = F.softmax(attn, dim=-1)
29 | w = self.dropout(w)
30 | out = torch.matmul(w, v) # (B,T,d_k)
31 | if self.trace_shapes:
32 | print(f"weights {w.shape} out {out.shape}")
33 | return out, w
--------------------------------------------------------------------------------
/part_2/eval_loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import argparse, torch
3 | from dataset import ByteDataset
4 | from model_gpt import GPT
5 |
6 |
7 | def main():
8 | p = argparse.ArgumentParser()
9 | p.add_argument('--data', type=str, required=True)
10 | p.add_argument('--ckpt', type=str, required=True)
11 | p.add_argument('--block_size', type=int, default=256)
12 | p.add_argument('--batch_size', type=int, default=32)
13 | p.add_argument('--iters', type=int, default=100)
14 | p.add_argument('--cpu', action='store_true')
15 | args = p.parse_args()
16 |
17 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')
18 |
19 | ds = ByteDataset(args.data, block_size=args.block_size)
20 | ckpt = torch.load(args.ckpt, map_location=device)
21 | cfg = ckpt.get('config', {
22 | 'vocab_size': 256,
23 | 'block_size': args.block_size,
24 | 'n_layer': 4,
25 | 'n_head': 4,
26 | 'n_embd': 256,
27 | 'dropout': 0.0,
28 | })
29 | model = GPT(**cfg).to(device)
30 | model.load_state_dict(ckpt['model'])
31 |
32 | model.eval()
33 | losses = []
34 | with torch.no_grad():
35 | for _ in range(args.iters):
36 | xb, yb = ds.get_batch('val', args.batch_size, device)
37 | _, loss = model(xb, yb)
38 | losses.append(loss.item())
39 | print(f"val loss: {sum(losses)/len(losses):.4f}")
40 |
41 |
42 | if __name__ == '__main__':
43 | main()
--------------------------------------------------------------------------------
/part_3/tests/test_rope_apply.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from rope_custom import RoPECache, apply_rope_single
3 |
4 | def test_rope_rotation_shapes_single():
5 | # Vanilla case: same #heads for q and k
6 | B, H, T, D = 1, 2, 5, 8
7 | rc = RoPECache(head_dim=D, max_pos=32)
8 | q = torch.randn(B, H, T, D)
9 | k = torch.randn(B, H, T, D)
10 | pos = torch.arange(0, T)
11 | cos, sin = rc.get(pos)
12 |
13 | q2 = apply_rope_single(q, cos, sin)
14 | k2 = apply_rope_single(k, cos, sin)
15 |
16 | assert q2.shape == q.shape
17 | assert k2.shape == k.shape
18 | # check that rotation mixes even/odd features (values should change)
19 | assert not torch.allclose(q2, q)
20 | assert not torch.allclose(k2, k)
21 |
22 | def test_rope_rotation_shapes_gqa():
23 | # GQA case: q has H heads; k has fewer Hk heads (shared KV)
24 | B, H, Hk, T, D = 2, 8, 2, 7, 16
25 | rc = RoPECache(head_dim=D, max_pos=128)
26 | q = torch.randn(B, H, T, D)
27 | k = torch.randn(B, Hk, T, D)
28 | pos = torch.arange(10, 10 + T) # arbitrary start position
29 | cos, sin = rc.get(pos)
30 |
31 | q2 = apply_rope_single(q, cos, sin)
32 | k2 = apply_rope_single(k, cos, sin)
33 |
34 | assert q2.shape == (B, H, T, D)
35 | assert k2.shape == (B, Hk, T, D)
36 | # rotations should be deterministic for same positions
37 | # check a couple of coords moved as expected (values changed)
38 | assert not torch.allclose(q2, q)
39 | assert not torch.allclose(k2, k)
40 |
--------------------------------------------------------------------------------
/part_1/vis_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 |
5 | OUT_DIR = os.path.join(os.path.dirname(__file__), 'out')
6 |
7 |
8 | def _ensure_out():
9 | os.makedirs(OUT_DIR, exist_ok=True)
10 |
11 |
12 | def save_matrix_heatmap(mat: np.ndarray, title: str, filename: str, xlabel: str = '', ylabel: str = ''):
13 | """Generic matrix heatmap saver.
14 | Do not set any specific colors/styles; keep defaults for clarity.
15 | """
16 | _ensure_out()
17 | plt.figure()
18 | plt.imshow(mat, aspect='auto')
19 | plt.title(title)
20 | plt.xlabel(xlabel)
21 | plt.ylabel(ylabel)
22 | plt.colorbar()
23 | path = os.path.join(OUT_DIR, filename)
24 | plt.savefig(path, bbox_inches='tight')
25 | plt.close()
26 | print(f"Saved: {path}")
27 |
28 |
29 | def save_attention_heads_grid(weights: np.ndarray, filename: str, title_prefix: str = "Head"):
30 | """Plot all heads in a single grid figure (B=1 assumed).
31 | weights: (1, H, T, T)
32 | """
33 | _ensure_out()
34 | _, H, T, _ = weights.shape
35 | cols = min(4, H)
36 | rows = (H + cols - 1) // cols
37 | plt.figure(figsize=(3*cols, 3*rows))
38 | for h in range(H):
39 | ax = plt.subplot(rows, cols, h+1)
40 | ax.imshow(weights[0, h], aspect='auto')
41 | ax.set_title(f"{title_prefix} {h}")
42 | ax.set_xlabel('Key pos')
43 | ax.set_ylabel('Query pos')
44 | plt.tight_layout()
45 | path = os.path.join(OUT_DIR, filename)
46 | plt.savefig(path, bbox_inches='tight')
47 | plt.close()
48 | print(f"Saved: {path}")
--------------------------------------------------------------------------------
/part_5/moe.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import torch, torch.nn as nn
3 | from gating import TopKGate
4 | from experts import ExpertMLP
5 |
6 | class MoE(nn.Module):
7 | """Mixture‑of‑Experts layer (token‑wise top‑k routing).
8 | Implementation is single‑GPU friendly (loops over experts for clarity).
9 | https://arxiv.org/pdf/2101.03961
10 | """
11 | def __init__(self, dim: int, n_expert: int, k: int = 1, mult: int = 4, swiglu: bool = True, dropout: float = 0.0):
12 | super().__init__()
13 | self.dim = dim
14 | self.n_expert = n_expert
15 | self.k = k
16 | self.gate = TopKGate(dim, n_expert, k=k)
17 | self.experts = nn.ModuleList([ExpertMLP(dim, mult=mult, swiglu=swiglu, dropout=dropout) for _ in range(n_expert)])
18 |
19 | def forward(self, x: torch.Tensor):
20 | """x: (B, T, C) → y: (B, T, C), aux_loss
21 | Steps: flatten tokens → gate → per‑expert forward → scatter back with weights.
22 | """
23 | B, T, C = x.shape
24 | S = B * T
25 | x_flat = x.reshape(S, C)
26 | idx, w, aux = self.gate(x_flat) # (S,k), (S,k)
27 |
28 | y = torch.zeros_like(x_flat) # (S,C)
29 | for e in range(self.n_expert):
30 | # tokens where expert e is selected at any of k slots
31 | for slot in range(self.k):
32 | sel = (idx[:, slot] == e)
33 | if sel.any():
34 | x_e = x_flat[sel]
35 | y_e = self.experts[e](x_e)
36 | y[sel] += w[sel, slot:slot+1] * y_e
37 | y = y.view(B, T, C)
38 | return y, aux
--------------------------------------------------------------------------------
/part_7/model_reward.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import torch, torch.nn as nn
3 |
4 | class RewardModel(nn.Module):
5 | """Transformer encoder → pooled representation → scalar reward.
6 | Bidirectional encoder is fine for reward modeling (not used for generation).
7 | """
8 | def __init__(self, vocab_size: int, block_size: int, n_layer: int = 4, n_head: int = 4, n_embd: int = 256, dropout: float = 0.1):
9 | super().__init__()
10 | self.vocab_size = vocab_size
11 | self.block_size = block_size
12 | self.tok_emb = nn.Embedding(vocab_size, n_embd)
13 | self.pos_emb = nn.Embedding(block_size, n_embd)
14 | enc_layer = nn.TransformerEncoderLayer(d_model=n_embd, nhead=n_head, dim_feedforward=4*n_embd,
15 | dropout=dropout, activation='gelu', batch_first=True)
16 | self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layer)
17 | self.ln = nn.LayerNorm(n_embd)
18 | self.head = nn.Linear(n_embd, 1)
19 |
20 | def forward(self, x: torch.Tensor):
21 | B, T = x.shape
22 | pos = torch.arange(T, device=x.device).unsqueeze(0)
23 | h = self.tok_emb(x) + self.pos_emb(pos)
24 | pad_mask = (x == 2)  # token id 2 is treated as padding in this toy setup
25 | h = self.encoder(h, src_key_padding_mask=pad_mask)
26 | h = self.ln(h)
27 | # masked mean pool over tokens (ignoring pads)
28 | mask = (~pad_mask).float().unsqueeze(-1)
29 | h_sum = (h * mask).sum(dim=1)
30 | len_ = mask.sum(dim=1).clamp_min(1.0)
31 | pooled = h_sum / len_
32 | r = self.head(pooled).squeeze(-1) # (B,)
33 | return r
--------------------------------------------------------------------------------
/part_8/policy.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import torch, torch.nn as nn
3 | import sys
4 | from pathlib import Path as _P
5 | # Try user’s structure first
6 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_3'))
7 | try:
8 | from model_utils.model_modern import GPTModern # user-custom path
9 | except Exception:
10 | from model_modern import GPTModern # fallback
11 |
12 | class PolicyWithValue(nn.Module):
13 | """Policy network = SFT LM + tiny value head.
14 | NOTE: For simplicity we place value head on top of LM logits (vocab→1).
15 | This avoids depending on hidden-state internals while keeping the tutorial runnable.
16 | """
17 | def __init__(self, vocab_size: int, block_size: int, n_layer=4, n_head=4, n_embd=256,
18 | use_rmsnorm=True, use_swiglu=True, rope=True, dropout=0.0):
19 | super().__init__()
20 | self.lm = GPTModern(vocab_size=vocab_size, block_size=block_size, n_layer=n_layer,
21 | n_head=n_head, n_embd=n_embd, use_rmsnorm=use_rmsnorm,
22 | use_swiglu=use_swiglu, rope=rope, dropout=dropout)
23 | # value head over logits (toy). Shapes: (B,T,V) -> (B,T,1) -> (B,T)
24 | self.val_head = nn.Linear(vocab_size, 1, bias=False)
25 |
26 | def forward(self, x: torch.Tensor, y: torch.Tensor | None = None):
27 | # Delegate LM forward; returns logits (B,T,V), loss, _
28 | logits, loss, _ = self.lm(x, y)
29 | values = self.val_head(logits).squeeze(-1) # (B,T)
30 | return logits, values, loss
31 |
32 | def generate(self, *args, **kwargs):
33 | return self.lm.generate(*args, **kwargs)
--------------------------------------------------------------------------------
/part_9/policy.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import torch, torch.nn as nn
3 | import sys
4 | from pathlib import Path as _P
5 | # Try user’s structure first
6 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_3'))
7 | try:
8 | from model_utils.model_modern import GPTModern # user-custom path
9 | except Exception:
10 | from model_modern import GPTModern # fallback
11 |
12 | class PolicyWithValue(nn.Module):
13 | """Policy network = SFT LM + tiny value head.
14 | NOTE: For simplicity we place value head on top of LM logits (vocab→1).
15 | This avoids depending on hidden-state internals while keeping the tutorial runnable.
16 | """
17 | def __init__(self, vocab_size: int, block_size: int, n_layer=4, n_head=4, n_embd=256,
18 | use_rmsnorm=True, use_swiglu=True, rope=True, dropout=0.0):
19 | super().__init__()
20 | self.lm = GPTModern(vocab_size=vocab_size, block_size=block_size, n_layer=n_layer,
21 | n_head=n_head, n_embd=n_embd, use_rmsnorm=use_rmsnorm,
22 | use_swiglu=use_swiglu, rope=rope, dropout=dropout)
23 | # value head over logits (toy). Shapes: (B,T,V) -> (B,T,1) -> (B,T)
24 | self.val_head = nn.Linear(vocab_size, 1, bias=False)
25 |
26 | def forward(self, x: torch.Tensor, y: torch.Tensor | None = None):
27 | # Delegate LM forward; returns logits (B,T,V), loss, _
28 | logits, loss, _ = self.lm(x, y)
29 | values = self.val_head(logits).squeeze(-1) # (B,T)
30 | return logits, values, loss
31 |
32 | def generate(self, *args, **kwargs):
33 | return self.lm.generate(*args, **kwargs)
--------------------------------------------------------------------------------
/part_2/sample.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import argparse, torch
3 | from tokenizer import ByteTokenizer
4 | from model_gpt import GPT
5 |
6 |
7 | def main():
8 | p = argparse.ArgumentParser()
9 | p.add_argument('--ckpt', type=str, required=True)
10 | p.add_argument('--prompt', type=str, default='')
11 | p.add_argument('--tokens', type=int, default=200)
12 | p.add_argument('--temperature', type=float, default=1.0)
13 | p.add_argument('--top_k', type=int, default=50)
14 | p.add_argument('--top_p', type=float, default=None)
15 | p.add_argument('--cpu', action='store_true')
16 | args = p.parse_args()
17 |
18 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')
19 |
20 | tok = ByteTokenizer()
21 | prompt_ids = tok.encode(args.prompt).unsqueeze(0).to(device)
22 | if prompt_ids.numel() == 0:
23 | # If no prompt provided, seed with newline byte (10)
24 | prompt_ids = torch.tensor([[10]], dtype=torch.long, device=device)
25 |
26 |
27 | ckpt = torch.load(args.ckpt, map_location=device)
28 | config = ckpt.get('config', None)
29 |
30 | if config is None:
31 | # fall back to defaults
32 | model = GPT(tok.vocab_size, block_size=256).to(device)
33 | model.load_state_dict(ckpt['model'])
34 | else:
35 | model = GPT(**config).to(device)
36 | model.load_state_dict(ckpt['model'])
37 |
38 | with torch.no_grad():
39 | out = model.generate(prompt_ids, max_new_tokens=args.tokens, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p)
40 | print(tok.decode(out[0].cpu()))
41 |
42 |
43 | if __name__ == '__main__':
44 | main()
--------------------------------------------------------------------------------
/part_9/grpo_loss.py:
--------------------------------------------------------------------------------
1 | # grpo_loss.py
2 | from __future__ import annotations
3 | import torch
4 | from dataclasses import dataclass
5 |
6 | @dataclass
7 | class PolicyOnlyLossOut:
8 | policy_loss: torch.Tensor
9 | entropy: torch.Tensor
10 | approx_kl: torch.Tensor
11 | kl_ref: torch.Tensor
12 | total_loss: torch.Tensor
13 |
14 |
15 | def ppo_policy_only_losses(new_logp, old_logp, adv, clip_ratio=0.2, ent_coef=0.0,
16 | kl_coef: float = 0.0, kl_mean: torch.Tensor | None = None):
17 | """
18 | PPO-style clipped policy loss, *policy only* (no value head),
19 | plus a separate KL(π||π_ref) penalty term: total = L_PPO + kl_coef * KL.
20 | Inputs are flat over action tokens: new_logp, old_logp, adv: (N_act,).
21 | kl_mean is a scalar tensor (mean over action tokens).
22 | """
23 |     device = new_logp.device
24 | if new_logp.numel() == 0:
25 | zero = torch.tensor(0.0, device=device)
26 | return PolicyOnlyLossOut(zero, zero, zero, zero, zero)
27 |
28 | ratio = torch.exp(new_logp - old_logp) # (N,)
29 | unclipped = ratio * adv
30 | clipped = torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * adv
31 | policy_loss = -torch.mean(torch.min(unclipped, clipped))
32 |
33 | entropy = -new_logp.mean() if ent_coef != 0.0 else new_logp.new_tensor(0.0)
34 | approx_kl = torch.mean(old_logp - new_logp)
35 |
36 | kl_ref = kl_mean if kl_mean is not None else new_logp.new_tensor(0.0)
37 |
38 | total = policy_loss - ent_coef * entropy + kl_coef * kl_ref # entropy bonus was not used in original GRPO paper
39 | return PolicyOnlyLossOut(policy_loss, entropy, approx_kl, kl_ref, total)
40 |
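41 | 
42 | # Sanity-check sketch: random logprobs/advantages over 8 action tokens (shapes only, values are arbitrary).
43 | if __name__ == "__main__":
44 |     new_lp, old_lp, adv = torch.randn(8), torch.randn(8), torch.randn(8)
45 |     out = ppo_policy_only_losses(new_lp, old_lp, adv, clip_ratio=0.2,
46 |                                  ent_coef=0.01, kl_coef=0.05, kl_mean=torch.tensor(0.1))
47 |     print(out.policy_loss.item(), out.approx_kl.item(), out.total_loss.item())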
--------------------------------------------------------------------------------
/part_5/orchestrator.py:
--------------------------------------------------------------------------------
1 | # Repository layout (Part 5)
2 | #
3 | # part_5/
4 | # orchestrator.py # run unit tests + optional MoE demo
5 | # README.md # 5.1/5.3 concept notes (compact MD)
6 | # gating.py # router/gating (top‑k) + load‑balancing aux loss
7 | # experts.py # MLP experts (SwiGLU or GELU)
8 | # moe.py # Mixture-of-Experts layer (dispatch/combine)
9 | # block_hybrid.py # Hybrid dense+MoE block examples
10 | # demo_moe.py # small forward pass demo + routing histogram
11 | # tests/
12 | # test_gate_shapes.py
13 | # test_moe_forward.py
14 | # test_hybrid_block.py
15 | #
16 | # Run from inside `part_5/`:
17 | # cd part_5
18 | # python orchestrator.py --demo
19 | # pytest -q
20 |
21 | import argparse, pathlib, subprocess, sys, shlex
22 |
23 | ROOT = pathlib.Path(__file__).resolve().parent
24 |
25 | def run(cmd: str):
26 | print(f"\n>>> {cmd}")
27 | res = subprocess.run(shlex.split(cmd), cwd=ROOT)
28 | if res.returncode != 0:
29 | sys.exit(res.returncode)
30 |
31 | if __name__ == "__main__":
32 | p = argparse.ArgumentParser()
33 | p.add_argument("--demo", action="store_true", help="run a tiny MoE demo")
34 | args = p.parse_args()
35 |
36 | # 1) unit tests
37 | run("python -m pytest -q tests/test_gate_shapes.py")
38 | run("python -m pytest -q tests/test_moe_forward.py")
39 | run("python -m pytest -q tests/test_hybrid_block.py")
40 |
41 | # 2) optional demo
42 | if args.demo:
43 | run("python demo_moe.py --tokens 6 --hidden 128 --experts 4 --top_k 1")
44 |
45 | print("\nPart 5 checks complete. ✅")
--------------------------------------------------------------------------------
/part_6/dataset_sft.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import List, Dict, Tuple
3 | from dataclasses import dataclass
4 | import os
5 | import traceback
6 |
7 | try:
8 | from datasets import load_dataset
9 | except Exception:
10 | print("Couldn't import `datasets`. Will use fallback data only.")
11 | load_dataset = None
12 |
13 | from formatters import Example
14 |
15 | @dataclass
16 | class SFTItem:
17 | prompt: str
18 | response: str
19 |
20 |
21 | def load_tiny_hf(split: str = "train[:200]", sample_dataset: bool = False) -> List[SFTItem]:
22 | """Try to load a tiny instruction dataset from HF; fall back to a baked-in list.
23 | We use `tatsu-lab/alpaca` as a familiar schema (instruction, input, output) and keep only a slice.
24 | """
25 | items: List[SFTItem] = []
26 | if load_dataset is not None and not sample_dataset:
27 | try:
28 | ds = load_dataset("tatsu-lab/alpaca", split=split)
29 | for row in ds:
30 | instr = row.get("instruction", "").strip()
31 | inp = row.get("input", "").strip()
32 | out = row.get("output", "").strip()
33 | if inp:
34 | instr = instr + "\n" + inp
35 | if instr and out:
36 | items.append(SFTItem(prompt=instr, response=out))
37 | except Exception:
38 | pass
39 | if not items:
40 | # fallback tiny list
41 | seeds = [
42 | ("First prime number", "2"),
43 | ("What are the three primary colors?", "red"),
44 | ("Device name which points to direction?", "compass"),
45 | ]
46 | items = [SFTItem(prompt=p, response=r) for p,r in seeds]
47 | return items
--------------------------------------------------------------------------------
/part_7/eval_rm.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import argparse, torch
3 | from data_prefs import load_preferences
4 | from collator_rm import PairCollator
5 | from model_reward import RewardModel
6 |
7 |
8 | def main():
9 | p = argparse.ArgumentParser()
10 | p.add_argument('--ckpt', type=str, required=True)
11 | p.add_argument('--split', type=str, default='val[:200]')
12 | p.add_argument('--cpu', action='store_true')
13 | p.add_argument('--bpe_dir', type=str, default=None)
14 | args = p.parse_args()
15 |
16 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')
17 |
18 | items = load_preferences(split=args.split)
19 | triples = [(it.prompt, it.chosen, it.rejected) for it in items]
20 |
21 | col = PairCollator(block_size=256, bpe_dir=args.bpe_dir)
22 | ckpt = torch.load(args.ckpt, map_location=device)
23 | cfg = ckpt.get('config', {})
24 |
25 | model = RewardModel(vocab_size=cfg.get('vocab_size', col.vocab_size), block_size=cfg.get('block_size', 256),
26 | n_layer=cfg.get('n_layer', 4), n_head=cfg.get('n_head', 4), n_embd=cfg.get('n_embd', 256))
27 | model.load_state_dict(ckpt['model'])
28 | model.to(device).eval()
29 |
30 | # Evaluate accuracy r_pos>r_neg
31 | import math
32 | B = 16
33 | correct = 0; total = 0
34 | for i in range(0, len(triples), B):
35 | batch = triples[i:i+B]
36 | pos, neg = col.collate(batch)
37 | pos, neg = pos.to(device), neg.to(device)
38 | with torch.no_grad():
39 | r_pos = model(pos)
40 | r_neg = model(neg)
41 | correct += (r_pos > r_neg).sum().item()
42 | total += pos.size(0)
43 | acc = correct / max(1, total)
44 | print(f"pairs={total} accuracy (r_pos>r_neg) = {acc:.3f}")
45 |
46 | if __name__ == '__main__':
47 | main()
--------------------------------------------------------------------------------
/part_7/data_prefs.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from dataclasses import dataclass
3 | from typing import List, Tuple
4 |
5 | try:
6 | from datasets import load_dataset
7 | except Exception:
8 | load_dataset = None
9 |
10 | @dataclass
11 | class PrefExample:
12 | prompt: str
13 | chosen: str
14 | rejected: str
15 |
16 |
17 | def load_preferences(split: str = "train[:200]") -> List[PrefExample]:
18 | """Load a tiny preference set. Tries Anthropic HH; falls back to a toy set.
19 | HH fields: 'chosen', 'rejected' (full conversations). We use an empty prompt.
20 | """
21 | items: List[PrefExample] = []
22 | if load_dataset is not None:
23 | try:
24 | ds = load_dataset("Anthropic/hh-rlhf", split=split)
25 | for row in ds:
26 | ch = str(row.get("chosen", "")).strip()
27 | rj = str(row.get("rejected", "")).strip()
28 | if ch and rj:
29 | items.append(PrefExample(prompt="", chosen=ch, rejected=rj))
30 | except Exception:
31 | print("Failed to load Anthropic/hh-rlhf dataset. Using fallback toy pairs.")
32 | pass
33 | if not items:
34 | # fallback toy pairs
35 | items = [
36 | PrefExample("Summarize: Scaling laws for neural language models.",
37 | "Scaling laws describe how performance improves predictably as model size, data, and compute increase.",
38 | "Scaling laws are when you scale pictures to look bigger."),
39 | PrefExample("Give two uses of attention in transformers.",
40 | "It lets the model focus on relevant tokens and enables parallel context integration across positions.",
41 | "It remembers all past words exactly without any computation."),
42 | ]
43 | return items
--------------------------------------------------------------------------------
/part_5/gating.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import torch, torch.nn as nn
3 |
4 | class TopKGate(nn.Module):
5 | """Top‑k softmax gating with Switch‑style load‑balancing aux loss.
6 | Args:
7 | dim: input hidden size
8 | n_expert: number of experts
9 | k: number of experts to route per token (1 or 2 typical)
10 | Returns:
11 | (indices, weights, aux_loss) where
12 | indices: (S, k) long, expert ids for each token
13 | weights: (S, k) float, gate weights (sum ≤ 1 per token)
14 | aux_loss: scalar load‑balancing penalty
15 | """
16 | def __init__(self, dim: int, n_expert: int, k: int = 1):
17 | super().__init__()
18 | assert k >= 1 and k <= n_expert
19 | self.n_expert = n_expert
20 | self.k = k
21 | self.w_g = nn.Linear(dim, n_expert, bias=True)
22 |
23 | def forward(self, x: torch.Tensor):
24 | # x: (S, C) where S = tokens (batch * seq)
25 | logits = self.w_g(x) # (S, E)
26 | probs = torch.softmax(logits, dim=-1) # (S, E)
27 | topk_vals, topk_idx = torch.topk(probs, k=self.k, dim=-1) # (S,k)
28 |
29 | # Load‑balancing aux loss (Switch):
30 | S, E = probs.size(0), probs.size(1)
31 | # importance: avg prob per expert
32 | importance = probs.mean(dim=0) # (E,)
33 | # load: fraction of tokens assigned as primary (top‑1 hard assignment)
34 | hard1 = topk_idx[:, 0] # (S,)
35 | load = torch.zeros(E, device=x.device)
36 | load.scatter_add_(0, hard1, torch.ones_like(hard1, dtype=load.dtype))
37 | load = load / max(S, 1)
38 | aux_loss = (E * (importance * load).sum())
39 | # print("*"*50)
40 | # print(probs, importance, hard1, load, aux_loss)
41 | # print("*"*50)
42 |
43 | return topk_idx, topk_vals, aux_loss
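44 | 
45 | 
46 | # Quick shape/aux-loss check (sketch): 6 tokens of width 16, 4 experts, top-1 routing.
47 | # With near-uniform routing probabilities the Switch aux loss sits near 1.0 by construction.
48 | if __name__ == "__main__":
49 |     gate = TopKGate(dim=16, n_expert=4, k=1)
50 |     idx, w, aux = gate(torch.randn(6, 16))
51 |     print(idx.shape, w.shape, float(aux))  # torch.Size([6, 1]) torch.Size([6, 1]) ~1.0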
--------------------------------------------------------------------------------
/part_9/orchestrator.py:
--------------------------------------------------------------------------------
1 | # Repository layout (Part 9 — RLHF with GRPO)
2 | #
3 | # part_9/
4 | #   orchestrator.py   # run unit tests + optional tiny GRPO demo
5 | #   policy.py         # policy = SFT LM + value head (toy head on logits)
6 | #   rollout.py        # prompt formatting, sampling, logprobs/KL utilities
7 | #   grpo_loss.py      # policy-only PPO-style clipped objective + entropy + KL(π||π_ref) penalty
8 | #   train_grpo.py     # single‑GPU RLHF loop (tiny, on‑policy, group-relative advantages)
9 | #   eval_ppo.py       # compare reward vs. reference on a small set
10 | #   tests/
11 | #     test_grpo_loss.py
12 | #
13 | #
14 | # Run from inside `part_9/`:
15 | #   cd part_9
16 | #   python orchestrator.py --demo
17 | #   pytest -q
18 |
19 | import argparse, pathlib, subprocess, sys
20 | ROOT = pathlib.Path(__file__).resolve().parent
21 |
22 | def run(cmd: str):
23 | print(f"\n>>> {cmd}")
24 | res = subprocess.run(cmd.split(), cwd=ROOT)
25 | if res.returncode != 0:
26 | sys.exit(res.returncode)
27 |
28 | if __name__ == "__main__":
29 | p = argparse.ArgumentParser()
30 | p.add_argument("--demo", action="store_true", help="tiny GRPO demo")
31 | args = p.parse_args()
32 |
33 | # 1) unit tests
34 | run("python -m pytest -q tests/test_grpo_loss.py")
35 |
36 | # 2) optional demo (requires SFT+RM checkpoints from Parts 6 & 7)
37 | if args.demo:
38 | run("python train_grpo.py --group_size 4 --policy_ckpt ../part_6/runs/sft-demo/model_last.pt --reward_ckpt ../part_7/runs/rm-demo/model_last.pt --steps 200 --batch_prompts 4 --resp_len 128 --bpe_dir ../part_4/runs/part4-demo/tokenizer")
39 | run("python eval_ppo.py --policy_ckpt runs/grpo-demo/model_last.pt --reward_ckpt ../part_7/runs/rm-demo/model_last.pt --split train[:24] --bpe_dir ../part_4/runs/part4-demo/tokenizer")
40 |
41 | print("\nPart 9 checks complete. ✅")
42 |
--------------------------------------------------------------------------------
/part_3/orchestrator.py:
--------------------------------------------------------------------------------
1 | # Repository layout (Part 3)
2 | #
3 | # part_3/
4 | # orchestrator.py # runs tests + a small generation demo
5 | # tokenizer.py # local byte-level tokenizer (self-contained)
6 | # rmsnorm.py # 3.1 RMSNorm
7 | #   rope_custom.py     # 3.2 RoPE cache + apply
8 | # swiglu.py # 3.3 SwiGLU FFN
9 | # kv_cache.py # 3.4/3.6 KV cache + rolling buffer
10 | # attn_modern.py # attention w/ RoPE, sliding window, sink, optional KV cache
11 | # block_modern.py # block = (RMSNorm|LN) + modern attention + (SwiGLU|GELU)
12 | # model_modern.py # GPTModern wrapper with feature flags
13 | # demo_generate.py # simple generation demo (shows KV cache + sliding window)
14 | # tests/
15 | # test_rmsnorm.py
16 | # test_rope_apply.py
17 | # test_kvcache_shapes.py
18 | #
19 | # Run from inside `part_3/`:
20 | # cd part_3
21 | # python orchestrator.py --demo
22 | # pytest -q
23 |
24 | import argparse, pathlib, subprocess, sys, shlex
25 |
26 | ROOT = pathlib.Path(__file__).resolve().parent
27 |
28 | def run(cmd: str):
29 | print(f"\n>>> {cmd}")
30 | res = subprocess.run(shlex.split(cmd), cwd=ROOT)
31 | if res.returncode != 0:
32 | sys.exit(res.returncode)
33 |
34 | if __name__ == "__main__":
35 | p = argparse.ArgumentParser()
36 | p.add_argument("--demo", action="store_true", help="run a tiny generation demo")
37 | args = p.parse_args()
38 |
39 | # 1) run unit tests
40 | run("python -m pytest -q tests/test_rmsnorm.py")
41 | run("python -m pytest -q tests/test_rope_apply.py")
42 | run("python -m pytest -q tests/test_kvcache_shapes.py")
43 |
44 | # 2) (optional) generation demo
45 | if args.demo:
46 | run("python demo_generate.py --rmsnorm --rope --swiglu --sliding_window 64 --sink 4 --tokens 200")
47 |
48 | print("\nPart 3 checks complete. ✅")
--------------------------------------------------------------------------------
/part_7/orchestrator.py:
--------------------------------------------------------------------------------
1 | # Repository layout (Part 7)
2 | #
3 | # part_7/
4 | # orchestrator.py # run unit tests + optional tiny RM demo
5 | # data_prefs.py # 7.1 HF preference loader (+tiny fallback)
6 | # collator_rm.py # pairwise tokenization → (pos, neg) tensors
7 | # model_reward.py # 7.2 reward model (Transformer encoder → scalar)
8 | # loss_reward.py # 7.3 Bradley–Terry & margin-ranking losses
9 | # train_rm.py # minimal one‑GPU training on tiny slice
10 | # eval_rm.py # 7.4 sanity checks & simple accuracy on val
11 | # tests/
12 | # test_bt_loss.py
13 | # test_reward_forward.py
14 | #
15 | # Run from inside `part_7/`:
16 | # cd part_7
17 | # python orchestrator.py --demo
18 | # pytest -q
19 |
20 | import argparse, pathlib, subprocess, sys, shlex
21 | ROOT = pathlib.Path(__file__).resolve().parent
22 |
23 | def run(cmd: str):
24 | print(f"\n>>> {cmd}")
25 | res = subprocess.run(shlex.split(cmd), cwd=ROOT)
26 | if res.returncode != 0:
27 | sys.exit(res.returncode)
28 |
29 | if __name__ == "__main__":
30 | p = argparse.ArgumentParser()
31 | p.add_argument("--demo", action="store_true", help="tiny reward‑model demo")
32 | args = p.parse_args()
33 |
34 | # 1) unit tests
35 | run("python -m pytest -q tests/test_bt_loss.py")
36 | run("python -m pytest -q tests/test_reward_forward.py")
37 |
38 | # 2) optional demo: tiny train + eval
39 | if args.demo:
40 | run("python train_rm.py --steps 300 --batch_size 8 --block_size 256 --n_layer 2 --n_head 2 --n_embd 128 --loss bt --bpe_dir ../part_4/runs/part4-demo/tokenizer")
41 | run("python eval_rm.py --ckpt runs/rm-demo/model_last.pt --split train[:8] --bpe_dir ../part_4/runs/part4-demo/tokenizer")
42 | run("python eval_rm.py --ckpt runs/rm-demo/model_last.pt --split test[:8] --bpe_dir ../part_4/runs/part4-demo/tokenizer")
43 |
44 | print("\nPart 7 checks complete. ✅")
--------------------------------------------------------------------------------
/part_5/README.md:
--------------------------------------------------------------------------------
1 | # Part 5 — Mixture-of-Experts (MoE)
2 |
3 | This part focuses purely on the **Mixture-of-Experts feed-forward component**.
4 | We are **not** implementing self-attention in this module.
5 | In a full transformer block, the order is generally:
6 |
7 | ```
8 | [LayerNorm] → [Self-Attention] → [Residual Add]
9 | → [LayerNorm] → [Feed-Forward Block (Dense or MoE)] → [Residual Add]
10 | ```
11 |
12 | Our MoE layer is a drop-in replacement for the dense feed-forward block.
13 | You can imagine it being called **after** the attention output, before the second residual connection.
14 |
15 | ---
16 |
17 | ## **5.1 Theory in 60 seconds**
18 | - **Experts**: multiple parallel MLPs; each token activates only a small subset (top-k) → sparse compute.
19 | - **Gate/Router**: scores each token across experts; picks top-k and assigns weights via a softmax.
20 | - **Dispatch/Combine**: send token to chosen experts, run their MLP, combine results using gate weights.
21 | - **Load balancing**: encourage uniform expert usage. A common aux loss (Switch Transformer) is
22 | `L_aux = E * Σ ( importance * load )` where:
23 | - *importance* = avg gate probability per expert
24 | - *load* = fraction of tokens routed as primary to that expert
25 |
26 | ---
27 |
28 | ## **5.3 Distributed notes (single-GPU friendly)**
29 | - Real MoE implementations distribute experts across GPUs (**expert parallelism**).
30 | - Here we keep everything **on one device** for simplicity. Dispatch is simulated with indexing/masking.
31 | - In production, dispatch/combination typically involves **all-to-all communication** across devices.
32 |
33 | ---
34 |
35 | ## **5.4 Hybrid architectures**
36 | - MoE need not replace every FFN — use it in alternating layers or blend outputs:
37 | `y = α * Dense(x) + (1 − α) * MoE(x)`
38 |
39 | ---
40 |
41 | **Milestone for this part:** integrate this MoE layer in place of a dense feed-forward in a transformer block and compare efficiency/accuracy trade-offs — even with a toy attention block if you wish.
42 |
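43 | ---
44 | 
45 | A minimal, self-contained sketch of the 5.4 blend (illustrative only, not the `block_hybrid.py` implementation; `dense` and `moe` below are stand-in modules with matching shapes):
46 | 
47 | ```
48 | import torch, torch.nn as nn
49 | 
50 | dense = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
51 | moe = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))  # stand-in for a real MoE layer
52 | alpha = 0.5
53 | 
54 | x = torch.randn(6, 16)                       # (tokens, hidden)
55 | y = alpha * dense(x) + (1 - alpha) * moe(x)  # the 5.4 blend; same shape as x, ready for the residual add
56 | print(y.shape)                               # torch.Size([6, 16])
57 | ```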
--------------------------------------------------------------------------------
/part_3/demo_generate.py:
--------------------------------------------------------------------------------
1 | import argparse, torch
2 | from tokenizer import ByteTokenizer
3 | from model_modern import GPTModern
4 | import time
5 |
6 | if __name__ == "__main__":
7 | p = argparse.ArgumentParser()
8 | p.add_argument('--rmsnorm', action='store_true')
9 | p.add_argument('--rope', action='store_true')
10 | p.add_argument('--swiglu', action='store_true')
11 | p.add_argument('--sliding_window', type=int, default=None)
12 | p.add_argument('--sink', type=int, default=0)
13 | p.add_argument('--group_size', type=int, default=2)
14 | p.add_argument('--tokens', type=int, default=120)
15 | p.add_argument('--cpu', action='store_true')
16 | args = p.parse_args()
17 |
18 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')
19 |
20 | tok = ByteTokenizer()
21 | model = GPTModern(vocab_size=tok.vocab_size, block_size=128, n_layer=2, n_head=4, n_embd=128,
22 | use_rmsnorm=args.rmsnorm, use_swiglu=args.swiglu, rope=args.rope,
23 | max_pos=4096, sliding_window=args.sliding_window, attention_sink=args.sink, n_kv_head=args.group_size).to(device)
24 |
25 | # empty prompt → newline
26 | prompt = torch.tensor([[10]], dtype=torch.long, device=device)
27 |
28 | with torch.no_grad():
29 | start = time.time()
30 | out = model.generate(prompt, max_new_tokens=args.tokens, temperature=0.0, top_k=50, top_p=None,
31 | sliding_window=args.sliding_window, attention_sink=args.sink)
32 | print(f"Generated {args.tokens} tokens in {time.time()-start:.2f} sec")
33 |
34 | start = time.time()
35 | out_nocache = model.generate_nocache(prompt, max_new_tokens=args.tokens, temperature=0.0, top_k=50, top_p=None,
36 | sliding_window=args.sliding_window, attention_sink=args.sink)
37 | print(f"(nocache) Generated {args.tokens} tokens in {time.time()-start:.2f} sec")
38 | print(tok.decode(out[0].cpu()))
39 | print(tok.decode(out_nocache[0].cpu()))
--------------------------------------------------------------------------------
/part_3/rope_custom.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import torch
3 | import math
4 |
5 | class RoPECache:
6 | """Precompute cos/sin for positions up to max_pos for even head_dim."""
7 | def __init__(self, head_dim: int, max_pos: int, base: float = 10000.0, device: torch.device | None = None):
8 | assert head_dim % 2 == 0, "RoPE head_dim must be even"
9 | self.head_dim = head_dim
10 | self.base = base
11 | self.device = device
12 | self._build(max_pos)
13 | def get(self, positions: torch.Tensor):
14 | # positions: (T,) or (1,T)
15 | if positions.dim() == 2:
16 | positions = positions[0]
17 | need = int(positions.max().item()) + 1 if positions.numel() > 0 else 1
18 | if need > self.max_pos:
19 | # grow tables
20 | self._build(max(need, int(self.max_pos * 2)))
21 | cos = self.cos[positions] # (T, D/2)
22 | sin = self.sin[positions]
23 | return cos, sin
24 |
25 | def _build(self, max_pos: int):
26 | """(Re)build cos/sin tables for a new max_pos."""
27 | self.max_pos = max_pos
28 |         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2, device=self.device).float() / self.head_dim))
29 | t = torch.arange(max_pos, device=self.device).float()
30 | freqs = torch.outer(t, inv_freq) # (max_pos, head_dim/2)
31 | self.cos = torch.cos(freqs)
32 | self.sin = torch.sin(freqs)
33 |
34 | def apply_rope_single(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
35 | """Rotate pairs along last dim for RoPE.
36 | x: (B,H,T,D) with D even; cos/sin: (T,D/2)
37 | """
38 | assert x.size(-1) % 2 == 0
39 | cos = cos.unsqueeze(0).unsqueeze(0) # (1,1,T,D/2)
40 | sin = sin.unsqueeze(0).unsqueeze(0)
41 | x1 = x[..., ::2]
42 | x2 = x[..., 1::2]
43 | xr1 = x1 * cos - x2 * sin
44 | xr2 = x1 * sin + x2 * cos
45 | out = torch.empty_like(x)
46 | out[..., ::2] = xr1
47 | out[..., 1::2] = xr2
48 | return out
49 |
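50 | # Usage sketch (illustrative sizes): rotate q/k of shape (B,H,T,D) with a shared cache.
51 | if __name__ == "__main__":
52 |     B, H, T, D = 1, 2, 5, 8
53 |     cache = RoPECache(head_dim=D, max_pos=16)
54 |     cos, sin = cache.get(torch.arange(T))  # (T, D/2) each; tables grow on demand if positions exceed max_pos
55 |     q, k = torch.randn(B, H, T, D), torch.randn(B, H, T, D)
56 |     q_rot, k_rot = apply_rope_single(q, cos, sin), apply_rope_single(k, cos, sin)
57 |     print(q_rot.shape, k_rot.shape)        # torch.Size([1, 2, 5, 8]) torch.Size([1, 2, 5, 8])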
--------------------------------------------------------------------------------
/part_2/orchestrator.py:
--------------------------------------------------------------------------------
1 | # Repository layout (Part 2)
2 | #
3 | # part_2/
4 | # orchestrator.py # runs quick smoke-train + eval + sample
5 | # tokenizer.py # 2.1 byte-level tokenizer (0..255)
6 | # dataset.py # 2.2 dataset + batching + shift
7 | # utils.py # 2.5 sampling helpers (top-k/top-p)
8 | # model_gpt.py # tiny GPT: tok/pos emb + blocks + head
9 | # train.py # 2.3/2.4 training loop w/ val eval & ckpt
10 | # sample.py # 2.5 text generation from a checkpoint
11 | # eval_loss.py # 2.6 evaluate loss on a file/ckpt
12 | # tests/
13 | # test_tokenizer.py # round-trip encode/decode
14 | # test_dataset_shift.py # label shift sanity
15 | # runs/ # (created at runtime) checkpoints & logs
16 | #
17 | # NOTE ON IMPORTS
18 | # ----------------
19 | # All imports are LOCAL. Run from inside `part_2/`.
20 | # Example quickstart (CPU ok):
21 | # cd part_2
22 | # python train.py --data tiny.txt --steps 300 --sample_every 100
23 | # python sample.py --ckpt runs/min-gpt/model_best.pt --tokens 200 --prompt 'Once upon a time '
24 |
25 |
26 | import subprocess, sys, pathlib, shlex
27 |
28 | ROOT = pathlib.Path(__file__).resolve().parent
29 | RUNS = ROOT / 'runs' / 'min-gpt'
30 |
31 | def run(cmd: str):
32 | print(f"\n>>> {cmd}")
33 | res = subprocess.run(shlex.split(cmd), cwd=ROOT)
34 | if res.returncode != 0:
35 | sys.exit(res.returncode)
36 |
37 | if __name__ == '__main__':
38 | # quick smoke training on a tiny file path tiny_hi.txt; adjust as needed
39 | run("python train.py --data tiny_hi.txt --steps 400 --sample_every 100 --eval_interval 100 --batch_size 32 --block_size 128 --n_layer 2 --n_head 2 --n_embd 128")
40 |
41 | # sample from the best checkpoint
42 | run(f"python sample.py --ckpt {RUNS}/model_best.pt --tokens 200 --prompt 'Once upon a time '")
43 |
44 | # evaluate final val loss
45 | run(f"python eval_loss.py --data tiny_hi.txt --ckpt {RUNS}/model_best.pt --iters 50 --block_size 128")
--------------------------------------------------------------------------------
/part_1/attn_numpy_demo.py:
--------------------------------------------------------------------------------
1 | """1.2 Self-attention from first principles on a tiny example (NumPy only).
2 | We use T=3 tokens, d_model=4, d_k=d_v=2, single-head.
3 | This script prints intermediate tensors so you can trace the math.
4 |
5 | Dimensions summary (single head)
6 | --------------------------------
7 | X: (B=1, T=3, d_model=4)
8 | Wq/Wk/Wv: (d_model=4, d_k=2)
9 | Q,K,V: (1, 3, 2)
10 | Scores: (1, 3, 3) = Q @ K^T
11 | Weights: (1, 3, 3) = softmax over last dim
12 | Output: (1, 3, 2) = Weights @ V
13 | """
14 | import numpy as np
15 |
16 | np.set_printoptions(precision=4, suppress=True)
17 |
18 | # Toy inputs (batch=1, seq=3, d_model=4)
19 | X = np.array([[[0.1, 0.2, 0.3, 0.4],
20 | [0.5, 0.4, 0.3, 0.2],
21 | [0.0, 0.1, 0.0, 0.1]]], dtype=np.float32)
22 |
23 | # Weight matrices (learned in real models). We fix numbers for determinism.
24 | Wq = np.array([[ 0.2, -0.1],
25 | [ 0.0, 0.1],
26 | [ 0.1, 0.2],
27 | [-0.1, 0.0]], dtype=np.float32)
28 | Wk = np.array([[ 0.1, 0.1],
29 | [ 0.0, -0.1],
30 | [ 0.2, 0.0],
31 | [ 0.0, 0.2]], dtype=np.float32)
32 | Wv = np.array([[ 0.1, 0.0],
33 | [-0.1, 0.1],
34 | [ 0.2, -0.1],
35 | [ 0.0, 0.2]], dtype=np.float32)
36 |
37 | # Project to Q, K, V
38 | Q = X @ Wq # (1,3,2)
39 | K = X @ Wk # (1,3,2)
40 | V = X @ Wv # (1,3,2)
41 |
42 | print("Q shape:", Q.shape, "\nQ=\n", Q[0])
43 | print("K shape:", K.shape, "\nK=\n", K[0])
44 | print("V shape:", V.shape, "\nV=\n", V[0])
45 |
46 | # Scaled dot-products
47 | scale = 1.0 / np.sqrt(Q.shape[-1])
48 | attn_scores = (Q @ K.transpose(0,2,1)) * scale # (1,3,3)
49 |
50 | # Causal mask (upper triangle set to -inf so softmax->0)
51 | mask = np.triu(np.ones((1,3,3), dtype=bool), k=1)
52 | attn_scores = np.where(mask, -1e9, attn_scores)
53 |
54 | # Softmax over last dim
55 | weights = np.exp(attn_scores - attn_scores.max(axis=-1, keepdims=True))
56 | weights = weights / weights.sum(axis=-1, keepdims=True)
57 | print("Weights shape:", weights.shape, "\nAttention Weights (causal)=\n", weights[0])
58 |
59 | # Weighted sum of V
60 | out = weights @ V # (1,3,2)
61 | print("Output shape:", out.shape, "\nOutput=\n", out[0])
--------------------------------------------------------------------------------
/part_4/orchestrator.py:
--------------------------------------------------------------------------------
1 | # Repository layout (Part 4)
2 | #
3 | # part_4/
4 | # orchestrator.py # run unit tests + optional smoke train & sample
5 | # tokenizer_bpe.py # 4.1 BPE tokenization (train/save/load)
6 | # dataset_bpe.py # streaming dataset + batching & label shift
7 | # lr_scheduler.py # 4.3 Warmup + cosine decay scheduler
8 | # amp_accum.py # 4.2 AMP (autocast+GradScaler) + grad accumulation helpers
9 | # checkpointing.py # 4.4 save/resume (model/opt/scaler/scheduler/tokenizer)
10 | # logger.py # 4.5 logging backends (wandb / tensorboard / noop)
11 | # train.py # core training loop (no Trainer API)
12 | # sample.py # load checkpoint & generate text
13 | # tests/
14 | # test_tokenizer_bpe.py
15 | # test_scheduler.py
16 | # test_resume_shapes.py
17 | #
18 | # Run from inside `part_4/`:
19 | # cd part_4
20 | # python orchestrator.py --demo # tiny smoke run on ../tiny.txt
21 | # pytest -q
22 | # tensorboard --logdir=runs/part4-demo
23 |
24 | import argparse, pathlib, subprocess, sys, shlex
25 |
26 | ROOT = pathlib.Path(__file__).resolve().parent
27 |
28 | def run(cmd: str):
29 | print(f"\n>>> {cmd}")
30 | res = subprocess.run(shlex.split(cmd), cwd=ROOT)
31 | if res.returncode != 0:
32 | sys.exit(res.returncode)
33 |
34 | if __name__ == "__main__":
35 | p = argparse.ArgumentParser()
36 | p.add_argument("--demo", action="store_true", help="run a tiny smoke train+sample")
37 | args = p.parse_args()
38 |
39 | # 1) unit tests
40 | run("python -m pytest -q tests/test_tokenizer_bpe.py")
41 | run("python -m pytest -q tests/test_scheduler.py")
42 | run("python -m pytest -q tests/test_resume_shapes.py")
43 |
44 | # 2) optional demo (quick overfit on tiny file)
45 | if args.demo:
46 | run("python train.py --data ../part_2/tiny.txt --out runs/part4-demo --bpe --vocab_size 8000 --epochs 1 --steps 300 --batch_size 16 --block_size 128 --n_layer 2 --n_head 2 --n_embd 128 --mixed_precision --grad_accum_steps 2 --log tensorboard")
47 | run("python sample.py --ckpt runs/part4-demo/model_last.pt --tokens 100 --prompt 'Generate a short story'")
48 |
49 | print("\nPart 4 checks complete. ✅")
--------------------------------------------------------------------------------
/part_1/demo_mha_shapes.py:
--------------------------------------------------------------------------------
1 | """Walkthrough of multi-head attention with explicit matrix math and shapes.
2 | Generates a text log at ./out/mha_shapes.txt.
3 | """
4 | import os
5 | import math
6 | import torch
7 | from multi_head import MultiHeadSelfAttention
8 |
9 | OUT_TXT = os.path.join(os.path.dirname(__file__), 'out', 'mha_shapes.txt')
10 |
11 |
12 | def log(s):
13 | print(s)
14 | with open(OUT_TXT, 'a') as f:
15 | f.write(s + "\n")
16 |
17 |
18 | if __name__ == "__main__":
19 | # Reset file
20 | os.makedirs(os.path.dirname(OUT_TXT), exist_ok=True)
21 | open(OUT_TXT, 'w').close()
22 |
23 | B, T, d_model, n_head = 1, 5, 12, 3
24 | d_head = d_model // n_head
25 | x = torch.randn(B, T, d_model)
26 | attn = MultiHeadSelfAttention(d_model, n_head, trace_shapes=True)
27 |
28 | log(f"Input x: {tuple(x.shape)} = (B,T,d_model)")
29 | qkv = attn.qkv(x) # (B,T,3*d_model)
30 | log(f"Linear qkv(x): {tuple(qkv.shape)} = (B,T,3*d_model)")
31 |
32 | qkv = qkv.view(B, T, 3, n_head, d_head)
33 | log(f"view to 5D: {tuple(qkv.shape)} = (B,T,3,heads,d_head)")
34 |
35 | q, k, v = qkv.unbind(dim=2)
36 | log(f"q,k,v split: q={tuple(q.shape)} k={tuple(k.shape)} v={tuple(v.shape)}")
37 |
38 | q = q.transpose(1, 2)
39 | k = k.transpose(1, 2)
40 | v = v.transpose(1, 2)
41 | log(f"transpose heads: q={tuple(q.shape)} k={tuple(k.shape)} v={tuple(v.shape)} = (B,heads,T,d_head)")
42 |
43 | scale = 1.0 / math.sqrt(d_head)
44 | scores = torch.matmul(q, k.transpose(-2, -1)) * scale
45 | log(f"scores q@k^T: {tuple(scores.shape)} = (B,heads,T,T)")
46 |
47 | weights = torch.softmax(scores, dim=-1)
48 | log(f"softmax(weights): {tuple(weights.shape)} = (B,heads,T,T)")
49 |
50 | ctx = torch.matmul(weights, v)
51 | log(f"context @v: {tuple(ctx.shape)} = (B,heads,T,d_head)")
52 |
53 | out = ctx.transpose(1, 2).contiguous().view(B, T, d_model)
54 | log(f"merge heads: {tuple(out.shape)} = (B,T,d_model)")
55 |
56 | out = attn.proj(out)
57 | log(f"final proj: {tuple(out.shape)} = (B,T,d_model)")
58 |
59 | log("\nLegend:")
60 | log(" B=batch, T=sequence length, d_model=embedding size, heads=n_head, d_head=d_model/heads")
61 | log(" qkv(x) is a single Linear producing [Q|K|V]; we reshape then split into q,k,v")
--------------------------------------------------------------------------------
/part_6/sample_sft.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import argparse, torch
3 |
4 | # Reuse GPTModern
5 | import sys
6 | from pathlib import Path as _P
7 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_3'))
8 | from model_modern import GPTModern # noqa: E402
9 |
10 | from collator_sft import SFTCollator
11 | from formatters import format_prompt_only
12 |
13 |
14 | def main():
15 | p = argparse.ArgumentParser()
16 | p.add_argument('--ckpt', type=str, required=True)
17 | p.add_argument('--prompt', type=str, required=True)
18 | p.add_argument('--block_size', type=int, default=256)
19 | p.add_argument('--n_layer', type=int, default=4)
20 | p.add_argument('--n_head', type=int, default=4)
21 | p.add_argument('--n_embd', type=int, default=256)
22 | p.add_argument('--tokens', type=int, default=80)
23 | p.add_argument('--temperature', type=float, default=0.2)
24 | p.add_argument('--cpu', action='store_true')
25 | p.add_argument('--bpe_dir', type=str, default='../part_4/runs/part4-demo/tokenizer')
26 | args = p.parse_args()
27 |
28 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')
29 |
30 | ckpt = torch.load(args.ckpt, map_location=device)
31 | cfg = ckpt.get('config', {})
32 |
33 | col = SFTCollator(block_size=cfg.get('block_size', 256), bpe_dir=args.bpe_dir)
34 | model = GPTModern(vocab_size=col.vocab_size, block_size=args.block_size,
35 | n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd,
36 | use_rmsnorm=True, use_swiglu=True, rope=True).to(device)
37 | model.load_state_dict(ckpt['model'])
38 | model.eval()
39 |
40 | prompt_text = format_prompt_only(args.prompt).replace('','')
41 | ids = col.encode(prompt_text)
42 | idx = torch.tensor([ids], dtype=torch.long, device=device)
43 |
44 | with torch.no_grad():
45 | out = model.generate(idx, max_new_tokens=args.tokens,
46 | temperature=args.temperature, top_k=3)
47 |
48 | # decode: prefer BPE if collator has it, else fall back to bytes
49 | out_ids = out[0].tolist()
50 | orig_len = idx.size(1)
51 | if hasattr(col, "tok") and hasattr(col.tok, "decode"):
52 | # decode full text or just the generated suffix; suffix is often clearer
53 | generated = col.tok.decode(out_ids)
54 | print(generated)
55 | else:
56 | generated = bytes(out_ids[orig_len:]).decode("utf-8", errors="ignore")
57 | print(generated)
58 |
59 |
60 | if __name__ == '__main__':
61 | main()
--------------------------------------------------------------------------------
/part_1/multi_head.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from attn_mask import causal_mask
6 |
7 | class MultiHeadSelfAttention(nn.Module):
8 | """1.4 Multi-head attention with explicit shape tracing.
9 |
10 | Dimensions (before masking):
11 | x: (B, T, d_model)
12 | qkv: (B, T, 3*d_model)
13 | view→ (B, T, 3, n_head, d_head) where d_head = d_model // n_head
14 | split→ q,k,v each (B, T, n_head, d_head)
15 | swap→ (B, n_head, T, d_head)
16 | scores: (B, n_head, T, T) = q @ k^T / sqrt(d_head)
17 | weights:(B, n_head, T, T) = softmax(scores)
18 | ctx: (B, n_head, T, d_head) = weights @ v
19 | merge: (B, T, n_head*d_head) = (B, T, d_model)
20 | """
21 | def __init__(self, d_model: int, n_head: int, dropout: float = 0.0, trace_shapes: bool = True):
22 | super().__init__()
23 | assert d_model % n_head == 0, "d_model must be divisible by n_head"
24 | self.n_head = n_head
25 | self.d_head = d_model // n_head
26 | self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
27 | self.proj = nn.Linear(d_model, d_model, bias=False)
28 | self.dropout = nn.Dropout(dropout)
29 | self.trace_shapes = trace_shapes
30 |
31 | def forward(self, x: torch.Tensor): # (B,T,d_model)
32 | B, T, C = x.shape
33 | qkv = self.qkv(x) # (B,T,3*C)
34 | qkv = qkv.view(B, T, 3, self.n_head, self.d_head) # (B,T,3,heads,dim)
35 | if self.trace_shapes:
36 | print("qkv view:", qkv.shape)
37 | q, k, v = qkv.unbind(dim=2) # each: (B,T,heads,dim)
38 | q = q.transpose(1, 2) # (B,heads,T,dim)
39 | k = k.transpose(1, 2)
40 | v = v.transpose(1, 2)
41 | if self.trace_shapes:
42 | print("q:", q.shape, "k:", k.shape, "v:", v.shape)
43 |
44 | scale = 1.0 / math.sqrt(self.d_head)
45 | attn = torch.matmul(q, k.transpose(-2, -1)) * scale # (B,heads,T,T)
46 | mask = causal_mask(T, device=x.device)
47 | attn = attn.masked_fill(mask, float('-inf'))
48 | w = F.softmax(attn, dim=-1)
49 | w = self.dropout(w)
50 | ctx = torch.matmul(w, v) # (B,heads,T,dim)
51 | if self.trace_shapes:
52 | print("weights:", w.shape, "ctx:", ctx.shape)
53 | out = ctx.transpose(1, 2).contiguous().view(B, T, C) # (B,T,d_model)
54 | out = self.proj(out)
55 | if self.trace_shapes:
56 | print("out:", out.shape)
57 | return out, w
--------------------------------------------------------------------------------
/part_1/orchestrator.py:
--------------------------------------------------------------------------------
1 | # Repository layout (Part 1)
2 | #
3 | # part_1/
4 | # orchestrator.py # runs demos/tests/visualizations for Part 1
5 | # pos_encoding.py # 1.1 positional encodings (learned + sinusoidal)
6 | # attn_numpy_demo.py # 1.2 self-attention math with tiny numbers (NumPy)
7 | # single_head.py # 1.3 single attention head (PyTorch)
8 | # multi_head.py # 1.4 multi-head attention (with shape tracing)
9 | # ffn.py # 1.5 feed-forward network (GELU, width = mult*d_model)
10 | # block.py # 1.6 Transformer block (residuals + LayerNorm)
11 | # attn_mask.py # causal mask helpers
12 | # vis_utils.py # plotting helpers (matrices & attention maps)
13 | # demo_mha_shapes.py # prints explicit matrix multiplications & shapes step-by-step
14 | # demo_visualize_multi_head.py # saves attention heatmaps per head (grid)
15 | # out/ # (created at runtime) images & logs live here
16 | # tests/
17 | # test_attn_math.py # correctness: tiny example vs PyTorch single-head
18 | # test_causal_mask.py # verifies masking behavior
19 | #
20 | # NOTE ON IMPORTS
21 | # ----------------
22 | # All imports are LOCAL. Run from inside `part_1/`.
23 | # Example quickstart (CPU ok):
24 | # cd part_1
25 | # python orchestrator.py --visualize
26 |
27 |
28 | import subprocess, sys, pathlib, argparse, shlex
29 |
30 | ROOT = pathlib.Path(__file__).resolve().parent
31 | OUT = ROOT / "out"
32 |
33 |
34 | def run(cmd: str):
35 | print(f"\n>>> {cmd}")
36 | res = subprocess.run(shlex.split(cmd), cwd=ROOT)
37 | if res.returncode != 0:
38 | sys.exit(res.returncode)
39 |
40 |
41 | def main():
42 | p = argparse.ArgumentParser()
43 | p.add_argument("--visualize", action="store_true", help="run visualization scripts and save PNGs to ./out")
44 | args = p.parse_args()
45 |
46 | OUT.mkdir(exist_ok=True)
47 |
48 | # 1.2 sanity check: NumPy tiny example
49 | run("python attn_numpy_demo.py")
50 |
51 | # 1.3/1.4 unit tests
52 | run("python -m pytest -q tests/test_attn_math.py")
53 | run("python -m pytest -q tests/test_causal_mask.py")
54 |
55 | # Matrix math walkthrough for MHA
56 | run("python demo_mha_shapes.py")
57 |
58 | if args.visualize:
59 | run("python demo_visualize_multi_head.py")
60 | print(f"\nVisualization images saved to: {OUT}")
61 |
62 | print("\nAll Part 1 demos/tests completed. ✅")
63 |
64 |
65 | if __name__ == "__main__":
66 | main()
--------------------------------------------------------------------------------
/part_8/orchestrator.py:
--------------------------------------------------------------------------------
1 | # Repository layout (Part 8 — RLHF with PPO)
2 | #
3 | # part_8/
4 | # orchestrator.py # run unit tests + optional tiny PPO demo
5 | # policy.py # policy = SFT LM + value head (toy head on logits)
6 | # rollout.py # prompt formatting, sampling, logprobs/KL utilities
7 | # ppo_loss.py # PPO clipped objective + value + entropy + KL penalty
8 | # train_ppo.py # single‑GPU RLHF loop (tiny, on‑policy)
9 | # eval_ppo.py # compare reward vs. reference on a small set
10 | # tests/
11 | # test_ppo_loss.py
12 | # test_policy_forward.py
13 | #
14 | # Run from inside `part_8/`:
15 | # cd part_8
16 | # python orchestrator.py --demo
17 | # pytest -q
18 |
19 | import argparse, pathlib, subprocess, sys
20 | ROOT = pathlib.Path(__file__).resolve().parent
21 |
22 | def run(cmd: str):
23 | print(f"\n>>> {cmd}")
24 | res = subprocess.run(cmd.split(), cwd=ROOT)
25 | if res.returncode != 0:
26 | sys.exit(res.returncode)
27 |
28 | if __name__ == "__main__":
29 | p = argparse.ArgumentParser()
30 | p.add_argument("--demo", action="store_true", help="tiny PPO demo")
31 | args = p.parse_args()
32 |
33 | # 1) unit tests
34 | run("python -m pytest -q tests/test_ppo_loss.py")
35 | run("python -m pytest -q tests/test_policy_forward.py")
36 |
37 | # 2) optional demo (requires SFT+RM checkpoints from Parts 6 & 7)
38 | if args.demo:
39 | # run("python train_ppo.py --policy_ckpt ../part_6/runs/sft-demo/model_last.pt --reward_ckpt ../part_7/runs/rm-demo/model_last.pt --steps 10 --batch_size 4 --resp_len 128 --bpe_dir ../part_4/runs/part4-demo/tokenizer")
40 | # run("python eval_ppo.py --policy_ckpt runs/ppo-demo/model_last.pt --reward_ckpt ../part_7/runs/rm-demo/model_last.pt --split train[:24] --bpe_dir ../part_4/runs/part4-demo/tokenizer")
41 |
42 | # run("python train_ppo.py --policy_ckpt ../part_6/runs/sft-demo/model_last.pt --reward_ckpt ../part_7/runs/rm-demo/model_last.pt --steps 50 --batch_size 4 --resp_len 128 --bpe_dir ../part_4/runs/part4-demo/tokenizer")
43 | # run("python eval_ppo.py --policy_ckpt runs/ppo-demo/model_last.pt --reward_ckpt ../part_7/runs/rm-demo/model_last.pt --split train[:24] --bpe_dir ../part_4/runs/part4-demo/tokenizer")
44 |
45 | run("python train_ppo.py --policy_ckpt ../part_6/runs/sft-demo/model_last.pt --reward_ckpt ../part_7/runs/rm-demo/model_last.pt --steps 100 --batch_size 4 --resp_len 128 --bpe_dir ../part_4/runs/part4-demo/tokenizer")
46 | run("python eval_ppo.py --policy_ckpt runs/ppo-demo/model_last.pt --reward_ckpt ../part_7/runs/rm-demo/model_last.pt --split train[:24] --bpe_dir ../part_4/runs/part4-demo/tokenizer")
47 |
48 | print("\nPart 8 checks complete. ✅")
--------------------------------------------------------------------------------
/part_7/collator_rm.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import List, Tuple
3 | import torch
4 |
5 | # Prefer BPE from Part 4, else ByteTokenizer from Part 3
6 | import sys
7 | from pathlib import Path as _P
8 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_4'))
9 | try:
10 | from tokenizer_bpe import BPETokenizer
11 | _HAS_BPE = True
12 | except Exception:
13 | _HAS_BPE = False
14 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_3'))
15 | try:
16 | from tokenizer import ByteTokenizer
17 | except Exception:
18 | ByteTokenizer = None
19 |
20 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_6'))
21 | try:
22 | from formatters import Example, format_example # reuse formatting
23 | except Exception:
24 | pass
25 |
26 | class PairCollator:
27 | """Tokenize preference pairs into (pos, neg) input ids.
28 | We format as the SFT template with the 'chosen' or 'rejected' text as the Response.
29 | """
30 | def __init__(self, block_size: int = 256, bpe_dir: str | None = None, vocab_size: int | None = None):
31 | self.block_size = block_size
32 | self.tok = None
33 |         # Use BPE only when a trained tokenizer dir is given; an untrained BPE cannot encode.
34 |         if _HAS_BPE and bpe_dir:
35 |             try:
36 |                 self.tok = BPETokenizer(vocab_size=vocab_size or 8000)
37 |                 self.tok.load(bpe_dir)
38 |             except Exception:
39 |                 self.tok = None
40 | if self.tok is None and ByteTokenizer is not None:
41 | self.tok = ByteTokenizer()
42 | if self.tok is None:
43 | raise RuntimeError("No tokenizer available.")
44 |
45 | @property
46 | def vocab_size(self) -> int:
47 | return getattr(self.tok, 'vocab_size', 256)
48 |
49 | def _encode(self, text: str) -> List[int]:
50 | if hasattr(self.tok, 'encode'):
51 | ids = self.tok.encode(text)
52 | if isinstance(ids, torch.Tensor):
53 | ids = ids.tolist()
54 | return ids
55 | return list(text.encode('utf-8'))
56 |
57 | def collate(self, batch: List[Tuple[str, str, str]]):
58 | # batch of (prompt, chosen, rejected)
59 | pos_ids, neg_ids = [], []
60 | for prompt, chosen, rejected in batch:
61 | pos_text = format_example(Example(prompt, chosen))
62 | neg_text = format_example(Example(prompt, rejected))
63 | pos_ids.append(self._encode(pos_text)[:self.block_size])
64 | neg_ids.append(self._encode(neg_text)[:self.block_size])
65 | def pad_to(x, pad=2):
66 | return x + [pad] * (self.block_size - len(x)) if len(x) < self.block_size else x[:self.block_size]
67 | pos = torch.tensor([pad_to(x) for x in pos_ids], dtype=torch.long)
68 | neg = torch.tensor([pad_to(x) for x in neg_ids], dtype=torch.long)
69 | return pos, neg
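70 | 
71 | 
72 | # Usage sketch: with no bpe_dir the byte-level fallback is used (vocab_size 256).
73 | if __name__ == "__main__":
74 |     col = PairCollator(block_size=32)
75 |     pos, neg = col.collate([("Q?", "good answer", "bad answer")])
76 |     print(pos.shape, neg.shape, col.vocab_size)  # torch.Size([1, 32]) torch.Size([1, 32]) 256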
--------------------------------------------------------------------------------
/part_7/train_rm.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import argparse, torch
3 | from pathlib import Path
4 |
5 | from data_prefs import load_preferences
6 | from collator_rm import PairCollator
7 | from model_reward import RewardModel
8 | from loss_reward import bradley_terry_loss, margin_ranking_loss
9 |
10 |
11 | def main():
12 | p = argparse.ArgumentParser()
13 | p.add_argument('--out', type=str, default='runs/rm-demo')
14 | p.add_argument('--steps', type=int, default=500)
15 | p.add_argument('--batch_size', type=int, default=8)
16 | p.add_argument('--block_size', type=int, default=256)
17 | p.add_argument('--n_layer', type=int, default=4)
18 | p.add_argument('--n_head', type=int, default=4)
19 | p.add_argument('--n_embd', type=int, default=256)
20 | p.add_argument('--lr', type=float, default=1e-4)
21 | p.add_argument('--loss', choices=['bt','margin'], default='bt')
22 | p.add_argument('--cpu', action='store_true')
23 | p.add_argument('--bpe_dir', type=str, default=None)
24 | args = p.parse_args()
25 |
26 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')
27 |
28 | # data
29 | items = load_preferences(split='train[:80]')
30 | triples = [(it.prompt, it.chosen, it.rejected) for it in items]
31 |
32 | # collator + model
33 | col = PairCollator(block_size=args.block_size, bpe_dir=args.bpe_dir)
34 | model = RewardModel(vocab_size=col.vocab_size, block_size=args.block_size,
35 | n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd).to(device)
36 | opt = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
37 |
38 | # train (tiny)
39 | step = 0; i = 0
40 | while step < args.steps:
41 | batch = triples[i:i+args.batch_size]
42 | if not batch:
43 | i = 0; continue
44 | pos, neg = col.collate(batch)
45 | pos, neg = pos.to(device), neg.to(device)
46 | r_pos = model(pos)
47 | r_neg = model(neg)
48 | if args.loss == 'bt':
49 | loss = bradley_terry_loss(r_pos, r_neg)
50 | else:
51 | loss = margin_ranking_loss(r_pos, r_neg, margin=1.0)
52 | opt.zero_grad(set_to_none=True)
53 | loss.backward()
54 | opt.step()
55 | step += 1; i += args.batch_size
56 | if step % 25 == 0:
57 | acc = (r_pos > r_neg).float().mean().item()
58 | print(f"step {step}: loss={loss.item():.4f} acc={acc:.2f}")
59 |
60 | Path(args.out).mkdir(parents=True, exist_ok=True)
61 | torch.save({'model': model.state_dict(), 'config': {
62 | 'vocab_size': col.vocab_size,
63 | 'block_size': args.block_size,
64 | 'n_layer': args.n_layer,
65 | 'n_head': args.n_head,
66 | 'n_embd': args.n_embd,
67 | }}, str(Path(args.out)/'model_last.pt'))
68 | print(f"Saved reward model to {args.out}/model_last.pt")
69 |
70 | if __name__ == '__main__':
71 | main()
--------------------------------------------------------------------------------
/part_4/tokenizer_bpe.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import os, json
3 | from pathlib import Path
4 | from typing import List, Union
5 |
6 | try:
7 | from tokenizers import ByteLevelBPETokenizer, Tokenizer
8 | except Exception:
9 |     ByteLevelBPETokenizer = Tokenizer = None
10 |
11 | class BPETokenizer:
12 | """Minimal BPE wrapper (HuggingFace tokenizers).
13 | Trains on a text file or a folder of .txt files. Saves merges/vocab to out_dir.
14 | """
15 | def __init__(self, vocab_size: int = 32000, special_tokens: List[str] | None = None):
16 | if ByteLevelBPETokenizer is None:
17 | raise ImportError("Please `pip install tokenizers` for BPETokenizer.")
18 | self.vocab_size = vocab_size
19 |         self.special_tokens = special_tokens or ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]  # assumed ByteLevel-BPE defaults
20 | self._tok = None
21 |
22 | def train(self, data_path: Union[str, Path]):
23 | files: List[str] = []
24 | p = Path(data_path)
25 | if p.is_dir():
26 | files = [str(fp) for fp in p.glob("**/*.txt")]
27 | else:
28 | files = [str(p)]
29 | tok = ByteLevelBPETokenizer()
30 | tok.train(files=files, vocab_size=self.vocab_size, min_frequency=2, special_tokens=self.special_tokens)
31 | self._tok = tok
32 |
33 | def save(self, out_dir: Union[str, Path]):
34 | out = Path(out_dir); out.mkdir(parents=True, exist_ok=True)
35 | assert self._tok is not None, "Train or load before save()."
36 | self._tok.save_model(str(out))
37 | self._tok.save(str(out / "tokenizer.json"))
38 | meta = {"vocab_size": self.vocab_size, "special_tokens": self.special_tokens}
39 | (out/"bpe_meta.json").write_text(json.dumps(meta))
40 |
41 | def load(self, dir_path: Union[str, Path]):
42 | dirp = Path(dir_path)
43 | # Prefer explicit filenames; fall back to glob if needed.
44 | vocab = dirp / "vocab.json"
45 | merges = dirp / "merges.txt"
46 | tokenizer = dirp / "tokenizer.json"
47 | if not vocab.exists() or not merges.exists():
48 | # Fallback for custom basenames
49 | vs = list(dirp.glob("*.json"))
50 | ms = list(dirp.glob("*.txt"))
51 | if not vs or not ms:
52 | raise FileNotFoundError(f"Could not find vocab.json/merges.txt in {dirp}")
53 | vocab = vs[0]
54 | merges = ms[0]
55 | # tok = ByteLevelBPETokenizer(str(vocab), str(merges))
56 | tok = Tokenizer.from_file(str(tokenizer))
57 | self._tok = tok
58 | meta_file = dirp / "bpe_meta.json"
59 | if meta_file.exists():
60 | meta = json.loads(meta_file.read_text())
61 | self.vocab_size = meta.get("vocab_size", self.vocab_size)
62 | self.special_tokens = meta.get("special_tokens", self.special_tokens)
63 |
64 |
65 | def encode(self, text: str):
66 | ids = self._tok.encode(text).ids
67 | return ids
68 |
69 | def decode(self, ids):
70 | return self._tok.decode(ids)
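71 | 
72 | 
73 | # Usage sketch (paths are illustrative; the Part 4 demo trains on ../part_2/tiny.txt):
74 | if __name__ == "__main__":
75 |     bpe = BPETokenizer(vocab_size=1000)
76 |     bpe.train("../part_2/tiny.txt")   # a text file or a folder of .txt files
77 |     bpe.save("runs/bpe-sketch")       # writes vocab.json, merges.txt, tokenizer.json, bpe_meta.json
78 |     ids = bpe.encode("hello world")
79 |     print(ids, bpe.decode(ids))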
--------------------------------------------------------------------------------
/part_6/orchestrator.py:
--------------------------------------------------------------------------------
1 | # Repository layout (Part 6)
2 | #
3 | # part_6/
4 | # orchestrator.py # run unit tests + optional tiny SFT demo
5 | # formatters.py # 6.1 prompt/response templates
6 | # dataset_sft.py # HF dataset loader (+tiny fallback) → (prompt, response)
7 | # collator_sft.py # 6.2 causal LM labels with masking
8 | # curriculum.py # 6.3 length‑based curriculum sampler
9 | # evaluate.py # 6.4 simple exact/F1 metrics
10 | # train_sft.py # minimal one‑GPU SFT loop (few steps)
11 | # sample_sft.py # load ckpt & generate from instructions
12 | # tests/
13 | # test_formatter.py
14 | # test_masking.py
15 | #
16 | # Run from inside `part_6/`:
17 | # cd part_6
18 | # python orchestrator.py --demo
19 | # pytest -q
20 |
41 | import argparse, pathlib, subprocess, sys, shlex
42 | ROOT = pathlib.Path(__file__).resolve().parent
43 |
44 | def run(cmd: str):
45 | print(f"\n>>> {cmd}")
46 | res = subprocess.run(shlex.split(cmd), cwd=ROOT)
47 | if res.returncode != 0:
48 | sys.exit(res.returncode)
49 |
50 | if __name__ == "__main__":
51 | p = argparse.ArgumentParser()
52 | p.add_argument("--demo", action="store_true", help="tiny SFT demo on a few samples")
53 | args = p.parse_args()
54 |
55 | # 1) unit tests
56 | run("python -m pytest -q tests/test_formatter.py")
57 | run("python -m pytest -q tests/test_masking.py")
58 |
59 | # 2) optional demo
60 | if args.demo:
61 | # --ckpt ../part_4/runs/part4-demo/model_last.pt # assumes Part 4 demo has been run
62 | run("python train_sft.py --data huggingface --ckpt ../part_4/runs/part4-demo/model_last.pt --out runs/sft-demo --steps 300 --batch_size 8 --block_size 256 --n_layer 2 --n_head 2 --n_embd 128")
63 | run("python sample_sft.py --ckpt runs/sft-demo/model_last.pt --block_size 256 --n_layer 2 --n_head 2 --n_embd 128 --prompt 'What are the three primary colors?' --tokens 30 --temperature 0.2")
64 | run("python sample_sft.py --ckpt runs/sft-demo/model_last.pt --block_size 256 --n_layer 2 --n_head 2 --n_embd 128 --prompt 'What does DNA stand for?' --tokens 30 --temperature 0.2")
65 | run("python sample_sft.py --ckpt runs/sft-demo/model_last.pt --block_size 256 --n_layer 2 --n_head 2 --n_embd 128 --prompt 'Reverse engineer this code to create a new version\ndef factorialize(num):\n factorial = 1\n for i in range(1, num):\n factorial *= i\n \n return factorial' --tokens 64 --temperature 0.2")
66 |
67 | print("\nPart 6 checks complete. ✅")
--------------------------------------------------------------------------------
/part_8/eval_ppo.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import argparse, torch
3 | from pathlib import Path
4 |
5 | from policy import PolicyWithValue
6 | from rollout import RLHFTokenizer, sample_prompts, format_prompt_only
7 |
8 | # Reward model
9 | import sys
10 | from pathlib import Path as _P
11 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_7'))
12 | from model_reward import RewardModel # noqa: E402
13 |
14 |
15 | def score_policy(policy_ckpt: str, rm_ckpt: str, bpe_dir: str | None, n: int = 16):
16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17 | tok = RLHFTokenizer(block_size=256, bpe_dir=bpe_dir)
18 |
19 | ckpt = torch.load(policy_ckpt, map_location=device)
20 | cfg = ckpt.get('config', {})
21 | pol = PolicyWithValue(cfg.get('vocab_size', tok.vocab_size), cfg.get('block_size', tok.block_size),
22 | cfg.get('n_layer', 2), cfg.get('n_head', 2), cfg.get('n_embd', 128)).to(device)
23 | pol.load_state_dict(ckpt['model'])
24 | pol.eval()
25 |
26 | # For comparing against reference policy (SFT)
27 | ref = PolicyWithValue(cfg.get('vocab_size', tok.vocab_size), cfg.get('block_size', tok.block_size),
28 | cfg.get('n_layer', 2), cfg.get('n_head', 2), cfg.get('n_embd', 128)).to(device)
29 | ckpt_ref = torch.load("../part_6/runs/sft-demo/model_last.pt", map_location=device) # hardcoded path to SFT checkpoint
30 | ref.lm.load_state_dict(ckpt_ref['model'])
31 | for p_ in ref.parameters():
32 | p_.requires_grad_(False)
33 | ref.eval()
34 |
35 | rckpt = torch.load(rm_ckpt, map_location=device)
36 | rm = RewardModel(vocab_size=rckpt['config'].get('vocab_size', tok.vocab_size), block_size=rckpt['config'].get('block_size', tok.block_size),
37 | n_layer=rckpt['config'].get('n_layer', 4), n_head=rckpt['config'].get('n_head', 4), n_embd=rckpt['config'].get('n_embd', 256)).to(device)
38 | rm.load_state_dict(rckpt['model'])
39 | rm.eval()
40 |
41 | prompts = sample_prompts(n)
42 | rewards = []
43 | for p in prompts:
44 | prefix = format_prompt_only(p).replace('', '')
45 | ids = tok.encode(prefix)
46 | x = torch.tensor([ids[-tok.block_size:]], dtype=torch.long, device=device)
47 | with torch.no_grad():
48 | y = pol.generate(x, max_new_tokens=128, temperature=0.2, top_k=50)
49 | y_old = ref.generate(x, max_new_tokens=128, temperature=0.2, top_k=50)
50 | resp = tok.decode(y[0].tolist()[len(ids[-tok.block_size:]):])
51 |         resp_old = tok.decode(y_old[0].tolist()[len(ids[-tok.block_size:]):])  # SFT reference completion; not scored below, kept for manual comparison
52 |
53 | # compute RM reward on formatted full text
54 | from part_6.formatters import Example, format_example
55 | text = format_example(Example(p, resp))
56 | z = torch.tensor([tok.encode(text)[:tok.block_size]], dtype=torch.long, device=device)
57 | with torch.no_grad():
58 | r = rm(z)[0].item()
59 | rewards.append(r)
60 | return sum(rewards)/max(1,len(rewards))
61 |
62 |
63 | if __name__ == '__main__':
64 | p = argparse.ArgumentParser()
65 | p.add_argument('--policy_ckpt', type=str, required=True)
66 | p.add_argument('--reward_ckpt', type=str, required=True)
67 | p.add_argument('--split', type=str, default='val[:32]') # unused in this tiny script
68 | p.add_argument('--bpe_dir', type=str, default=None)
69 | args = p.parse_args()
70 |
71 | avg_r = score_policy(args.policy_ckpt, args.reward_ckpt, args.bpe_dir, n=16)
72 | print(f"Avg RM reward: {avg_r:.4f}")
--------------------------------------------------------------------------------
/part_9/eval_ppo.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import argparse, torch
3 | from pathlib import Path
4 |
5 | from policy import PolicyWithValue
6 | from rollout import RLHFTokenizer, sample_prompts, format_prompt_only
7 |
8 | # Reward model
9 | import sys
10 | from pathlib import Path as _P
11 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_7'))
12 | from model_reward import RewardModel # noqa: E402
13 |
14 |
15 | def score_policy(policy_ckpt: str, rm_ckpt: str, bpe_dir: str | None, n: int = 16):
16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17 | tok = RLHFTokenizer(block_size=256, bpe_dir=bpe_dir)
18 |
19 | ckpt = torch.load(policy_ckpt, map_location=device)
20 | cfg = ckpt.get('config', {})
21 | pol = PolicyWithValue(cfg.get('vocab_size', tok.vocab_size), cfg.get('block_size', tok.block_size),
22 | cfg.get('n_layer', 2), cfg.get('n_head', 2), cfg.get('n_embd', 128)).to(device)
23 | pol.load_state_dict(ckpt['model'])
24 | pol.eval()
25 |
26 | # For comparing against reference policy (SFT)
27 | ref = PolicyWithValue(cfg.get('vocab_size', tok.vocab_size), cfg.get('block_size', tok.block_size),
28 | cfg.get('n_layer', 2), cfg.get('n_head', 2), cfg.get('n_embd', 128)).to(device)
29 | ckpt_ref = torch.load("../part_6/runs/sft-demo/model_last.pt", map_location=device) # hardcoded path to SFT checkpoint
30 | ref.lm.load_state_dict(ckpt_ref['model'])
31 | for p_ in ref.parameters():
32 | p_.requires_grad_(False)
33 | ref.eval()
34 |
35 | rckpt = torch.load(rm_ckpt, map_location=device)
36 | rm = RewardModel(vocab_size=rckpt['config'].get('vocab_size', tok.vocab_size), block_size=rckpt['config'].get('block_size', tok.block_size),
37 | n_layer=rckpt['config'].get('n_layer', 4), n_head=rckpt['config'].get('n_head', 4), n_embd=rckpt['config'].get('n_embd', 256)).to(device)
38 | rm.load_state_dict(rckpt['model'])
39 | rm.eval()
40 |
41 | prompts = sample_prompts(n)
42 | rewards = []
43 | for p in prompts:
44 | prefix = format_prompt_only(p).replace('', '')
45 | ids = tok.encode(prefix)
46 | x = torch.tensor([ids[-tok.block_size:]], dtype=torch.long, device=device)
47 | with torch.no_grad():
48 | y = pol.generate(x, max_new_tokens=128, temperature=0.2, top_k=50)
49 | y_old = ref.generate(x, max_new_tokens=128, temperature=0.2, top_k=50)
50 | resp = tok.decode(y[0].tolist()[len(ids[-tok.block_size:]):])
51 |         resp_old = tok.decode(y_old[0].tolist()[len(ids[-tok.block_size:]):])  # SFT reference completion; not scored below, kept for manual comparison
52 |
53 | # compute RM reward on formatted full text
54 | from part_6.formatters import Example, format_example
55 | text = format_example(Example(p, resp))
56 | z = torch.tensor([tok.encode(text)[:tok.block_size]], dtype=torch.long, device=device)
57 | with torch.no_grad():
58 | r = rm(z)[0].item()
59 | rewards.append(r)
60 | return sum(rewards)/max(1,len(rewards))
61 |
62 |
63 | if __name__ == '__main__':
64 | p = argparse.ArgumentParser()
65 | p.add_argument('--policy_ckpt', type=str, required=True)
66 | p.add_argument('--reward_ckpt', type=str, required=True)
67 | p.add_argument('--split', type=str, default='val[:32]') # unused in this tiny script
68 | p.add_argument('--bpe_dir', type=str, default=None)
69 | args = p.parse_args()
70 |
71 | avg_r = score_policy(args.policy_ckpt, args.reward_ckpt, args.bpe_dir, n=16)
72 | print(f"Avg RM reward: {avg_r:.4f}")
--------------------------------------------------------------------------------
/part_4/sample.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import argparse, torch
3 | from pathlib import Path
4 |
5 | # load Part 3 model
6 | import sys
7 | from pathlib import Path as _P
8 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_3'))
9 | from model_modern import GPTModern # noqa: E402
10 |
11 | from tokenizer_bpe import BPETokenizer
12 |
13 |
14 | def main():
15 | p = argparse.ArgumentParser()
16 | p.add_argument('--ckpt', type=str, required=True)
17 | p.add_argument('--prompt', type=str, default='')
18 | p.add_argument('--tokens', type=int, default=100)
19 | p.add_argument('--cpu', action='store_true')
20 | args = p.parse_args()
21 |
22 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')
23 |
24 | ckpt = torch.load(args.ckpt, map_location='cpu') # load on CPU first; move model later
25 | sd = ckpt['model']
26 | cfg = ckpt.get('config') or {}
27 |
28 | # tokenizer (if present)
29 | tok = None
30 | tok_dir_file = Path(args.ckpt).with_name('tokenizer_dir.txt')
31 | if tok_dir_file.exists():
32 | tok_dir = tok_dir_file.read_text().strip() # file contains the dir path
33 | tok = BPETokenizer()
34 | tok.load(tok_dir) # <-- instance method, pass the directory
35 | vocab_from_tok = tok.vocab_size
36 | else:
37 | vocab_from_tok = None
38 |
39 |
40 | # ---- build config (prefer saved config; otherwise infer) ----
41 | if not cfg:
42 | # If a tokenizer is present and vocab differs, override with tokenizer vocab
43 | # if vocab_from_tok is not None and cfg.get('vocab_size') != vocab_from_tok:
44 | # cfg = {**cfg, 'vocab_size': vocab_from_tok}
45 | # else:
46 | # Old checkpoints without config: infer essentials from weights
47 | # tok_emb.weight: [V, C] where C == n_embd
48 | V, C = sd['tok_emb.weight'].shape
49 | # pos_emb.weight: [block_size, C] if present
50 | block_size = sd['pos_emb.weight'].shape[0] if 'pos_emb.weight' in sd else 256
51 | # count transformer blocks present
52 | import re
53 | layer_ids = {int(m.group(1)) for k in sd.keys() if (m := re.match(r"blocks\.(\d+)\.", k))}
54 | n_layer = max(layer_ids) + 1 if layer_ids else 1
55 | # pick an n_head that divides C (head count doesn't affect weight shapes)
56 | n_head = 8 if C % 8 == 0 else 4 if C % 4 == 0 else 2 if C % 2 == 0 else 1
57 | cfg = dict(
58 | vocab_size=vocab_from_tok or V,
59 | block_size=block_size,
60 | n_layer=n_layer,
61 | n_head=n_head,
62 | n_embd=C,
63 | dropout=0.0,
64 | use_rmsnorm=True,
65 | use_swiglu=True,
66 | rope=True,
67 | max_pos=4096,
68 | sliding_window=None,
69 | attention_sink=0,
70 | )
71 |
72 | # ---- build & load model ----
73 |     model = GPTModern(**cfg)
74 |     model.load_state_dict(ckpt['model'])
75 |     model.to(device).eval()
76 |
77 | # prompt ids
78 | if tok:
79 | ids = tok.encode(args.prompt)
80 | if len(ids) == 0: ids = [10]
81 | else:
82 | ids = [10] if args.prompt == '' else list(args.prompt.encode('utf-8'))
83 | idx = torch.tensor([ids], dtype=torch.long, device=device)
84 |
85 | with torch.no_grad():
86 | out = model.generate(idx, max_new_tokens=args.tokens)
87 | out_ids = out[0].tolist()
88 | if tok:
89 | print(tok.decode(out_ids))
90 | else:
91 | print(bytes(out_ids).decode('utf-8', errors='ignore'))
92 |
93 | if __name__ == '__main__':
94 | main()
--------------------------------------------------------------------------------
/part_6/collator_sft.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import List, Tuple
3 | import torch
4 | import traceback
5 |
6 | # Reuse tokenizers: prefer BPE from Part 4 if available; else byte-level from Part 3
7 | import sys
8 | from pathlib import Path as _P
9 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_4'))
10 | try:
11 | from tokenizer_bpe import BPETokenizer
12 | _HAS_BPE = True
13 | except Exception:
14 | _HAS_BPE = False
15 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_3'))
16 | try:
17 | from tokenizer import ByteTokenizer
18 | except Exception:
19 | ByteTokenizer = None
20 |
21 | from formatters import Example, format_example, format_prompt_only
22 |
23 | class SFTCollator:
24 | """Turn (instruction,response) into token ids and masked labels for causal LM (6.2).
25 | Labels for the prompt part are set to -100 so they don't contribute to loss.
26 | """
27 | def __init__(self, block_size: int = 256, bpe_dir: str | None = None):
28 | self.block_size = block_size
29 | self.tok = None
30 | if _HAS_BPE:
31 | # If a trained tokenizer directory exists from Part 4, you can `load` it.
32 | # Otherwise we create an ad-hoc BPE on the fly using fallback prompts during demo.
33 | try:
34 | self.tok = BPETokenizer(vocab_size=8000)
35 | if bpe_dir:
36 | self.tok.load(bpe_dir)
37 | print(f"Loaded BPE tokenizer from {bpe_dir}")
38 | else:
39 | # weak ad-hoc training would belong elsewhere; for the demo we assume Part 4 tokenizer exists
40 | pass
41 | except Exception:
42 | print(traceback.format_exc())
43 | self.tok = None
44 | if self.tok is None and ByteTokenizer is not None:
45 | self.tok = ByteTokenizer()
46 | if self.tok is None:
47 | raise RuntimeError("No tokenizer available. Install tokenizers or ensure Part 3 ByteTokenizer exists.")
48 |
49 | @property
50 | def vocab_size(self) -> int:
51 | return getattr(self.tok, 'vocab_size', 256)
52 |
53 | def encode(self, text: str) -> List[int]:
54 | if hasattr(self.tok, 'encode'):
55 | ids = self.tok.encode(text)
56 | if isinstance(ids, torch.Tensor):
57 | ids = ids.tolist()
58 | return ids
59 | # ByteTokenizer-like
60 | return list(text.encode('utf-8'))
61 |
62 | def collate(self, batch: List[Tuple[str,str]]):
63 | # Build "prompt + response" and create label mask where prompt positions are -100.
64 | input_ids = []
65 | labels = []
66 | for prompt, response in batch:
67 | prefix_text = format_prompt_only(prompt).replace('','')
68 | text = format_example(Example(prompt, response))
69 | ids = self.encode(text)[:self.block_size]
70 | prompt_ids = self.encode(prefix_text)[:self.block_size]
71 | n_prompt = min(len(prompt_ids), len(ids))
72 | x = ids
73 | y = ids.copy()
74 | for t in range(len(y) - 1):
75 | y[t] = ids[t + 1]
76 | y[-1] = -100
77 |             for i in range(n_prompt-1):  # keep index n_prompt-1: after the shift, its label is the first response token
78 | y[i] = -100
79 | input_ids.append(x)
80 | labels.append(y)
81 | # pad to block_size
82 | def pad_to(ids, val):
83 | if len(ids) < self.block_size:
84 | ids = ids + [val]*(self.block_size - len(ids))
85 | return ids[:self.block_size]
86 | x = torch.tensor([pad_to(s, 2) for s in input_ids], dtype=torch.long)
87 | y = torch.tensor([pad_to(s, -100) for s in labels], dtype=torch.long)
88 | return x, y
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LLM from Scratch — Hands-On Curriculum (PyTorch)
2 |
3 | [](https://youtu.be/p3sij8QzONQ?si=yEuD584cBZRNiUYm)
4 |
5 | ## Part 0 — Foundations & Mindset
6 | - **0.1** Understanding the high-level LLM training pipeline (pretraining → finetuning → alignment)
7 | - **0.2** Hardware & software environment setup (PyTorch, CUDA/Mac, mixed precision, profiling tools)
8 |
9 | ```
10 | conda create -n llm_from_scratch python=3.11
11 | conda activate llm_from_scratch
12 | pip install -r requirements.txt
13 | ```
14 |
15 | ## Part 1 — Core Transformer Architecture
16 | - **1.1** Positional embeddings (absolute learned vs. sinusoidal)
17 | - **1.2** Self-attention from first principles (manual computation with a tiny example; sketch below)
18 | - **1.3** Building a *single attention head* in PyTorch
19 | - **1.4** Multi-head attention (splitting, concatenation, projections)
20 | - **1.5** Feed-forward networks (MLP layers) — GELU, dimensionality expansion
21 | - **1.6** Residual connections & **LayerNorm**
22 | - **1.7** Stacking into a full Transformer block
23 |
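The heart of 1.2–1.4, as a minimal sketch (illustrative names only; the repo's `single_head.py` and `multi_head.py` are the reference): scaled dot-product scores, a causal mask, softmax, then a weighted sum of the values.

```
import math
import torch

T, d = 4, 8                                   # tiny example: 4 tokens, one 8-dim head
q, k, v = (torch.randn(T, d) for _ in range(3))

scores = q @ k.T / math.sqrt(d)               # (T, T) similarities, scaled by sqrt(d_head)
mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(mask, float('-inf'))  # causal: a token cannot attend to the future
attn = torch.softmax(scores, dim=-1)          # each row sums to 1
out = attn @ v                                # (T, d) weighted mix of value vectors
```

Multi-head attention (1.4) runs the same computation per head on split projections and concatenates the results before the output projection.
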
24 | ## Part 2 — Training a Tiny LLM
25 | - **2.1** Byte-level tokenization
26 | - **2.2** Dataset batching & shifting for next-token prediction
27 | - **2.3** Cross-entropy loss & label shifting (sketch below)
28 | - **2.4** Training loop from scratch (no Trainer API)
29 | - **2.5** Sampling: temperature, top-k, top-p
30 | - **2.6** Evaluating loss on val set
31 |
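For 2.2 and 2.3, a minimal sketch (the repo's `dataset.py` and `model_gpt.py` are the reference; the tensors here are stand-ins): targets are the inputs shifted by one token, and the loss is cross-entropy over the flattened logits.

```
import torch
import torch.nn.functional as F

data = torch.randint(0, 256, (1000,))          # pretend byte-level corpus
block_size, batch_size, vocab = 8, 4, 256

ix = torch.randint(0, len(data) - block_size - 1, (batch_size,))
x = torch.stack([data[i : i + block_size] for i in ix])          # (B, T) inputs
y = torch.stack([data[i + 1 : i + 1 + block_size] for i in ix])  # (B, T) targets, shifted by one

logits = torch.randn(batch_size, block_size, vocab)              # stand-in for model(x)
loss = F.cross_entropy(logits.view(-1, vocab), y.view(-1))
```
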
32 | ## Part 3 — Modernizing the Architecture
33 | - **3.1** **RMSNorm** (replace LayerNorm, compare gradients & convergence; sketch below)
34 | - **3.2** **RoPE** (Rotary Positional Embeddings) — theory & code
35 | - **3.3** SwiGLU activations in MLP
36 | - **3.4** KV cache for faster inference
37 | - **3.5** Sliding-window attention & **attention sink**
38 | - **3.6** Rolling buffer KV cache for streaming
39 |
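A minimal RMSNorm for 3.1 (illustrative; the repo's `rmsnorm.py` may differ in eps or dtype handling): divide by the root-mean-square of the features and apply a learned gain, with no mean subtraction and no bias.

```
import torch
import torch.nn as nn

class TinyRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight             # shape-preserving: (..., dim) -> (..., dim)

y = TinyRMSNorm(8)(torch.randn(2, 3, 8))
```
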
40 | ## Part 4 — Scaling Up
41 | - **4.1** Switching from byte-level to BPE tokenization
42 | - **4.2** Gradient accumulation & mixed precision (sketch below)
43 | - **4.3** Learning rate schedules & warmup
44 | - **4.4** Checkpointing & resuming
45 | - **4.5** Logging & visualization (TensorBoard / wandb)
46 |
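A bare-bones view of 4.2 (the repo's `amp_accum.py` wraps this together with autocast and a GradScaler; the tiny model and data here are placeholders): scale each micro-batch loss by the accumulation factor and step the optimizer only every `accum` micro-batches.

```
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(16, 16)                        # stand-in for the GPT
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
loader = [(torch.randn(8, 16), torch.randn(8, 16)) for _ in range(8)]

accum = 4
opt.zero_grad(set_to_none=True)
for i, (xb, yb) in enumerate(loader):
    loss = F.mse_loss(model(xb), yb) / accum     # divide so accumulated grads match one big batch
    loss.backward()                              # gradients add up across micro-batches
    if (i + 1) % accum == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        opt.zero_grad(set_to_none=True)
```
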
47 | ## Part 5 — Mixture-of-Experts (MoE)
48 | - **5.1** MoE theory: expert routing, gating networks, and load balancing
49 | - **5.2** Implementing MoE layers in PyTorch (sketch below)
50 | - **5.3** Combining MoE with dense layers for hybrid architectures
51 |
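A toy routing step for 5.1 and 5.2 (illustrative only; `gating.py` and `moe.py` are the reference, and load balancing is ignored here): pick the top-k experts per token, renormalize their gate weights, and mix the expert outputs.

```
import torch
import torch.nn as nn
import torch.nn.functional as F

tokens, n_experts, k, d = 6, 4, 2, 8
x = torch.randn(tokens, d)
experts = nn.ModuleList([nn.Linear(d, d) for _ in range(n_experts)])

gate_logits = torch.randn(tokens, n_experts)     # stand-in for a learned router
weights, idx = gate_logits.topk(k, dim=-1)       # top-k experts per token
weights = F.softmax(weights, dim=-1)             # renormalize over the chosen k

mixed = []
for t in range(tokens):                          # naive per-token loop; real MoE layers batch by expert
    mixed.append(sum(weights[t, j] * experts[int(idx[t, j])](x[t]) for j in range(k)))
out = torch.stack(mixed)                         # (tokens, d)
```
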
52 | ## Part 6 — Supervised Fine-Tuning (SFT)
53 | - **6.1** Instruction dataset formatting (prompt + response)
54 | - **6.2** Causal LM loss with masked labels (sketch below)
55 | - **6.3** Curriculum learning for instruction data
56 | - **6.4** Evaluating outputs against gold responses
57 |
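The key mechanic of 6.2 as a sketch (see `collator_sft.py` for the real version, which also shifts the labels by one; the token ids below are made up): prompt positions get label -100, so cross-entropy with `ignore_index=-100` only trains on the response.

```
import torch
import torch.nn.functional as F

prompt_ids   = [5, 6, 7]            # hypothetical ids for the formatted instruction
response_ids = [8, 9, 10, 1]        # hypothetical ids for the response (+ EOS)

input_ids = prompt_ids + response_ids
labels    = [-100] * len(prompt_ids) + response_ids     # loss only on the response tokens

logits = torch.randn(1, len(input_ids), 32)             # stand-in for model(input_ids)
loss = F.cross_entropy(logits.view(-1, 32),
                       torch.tensor([labels]).view(-1),
                       ignore_index=-100)               # masked positions contribute nothing
```
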
58 | ## Part 7 — Reward Modeling
59 | - **7.1** Preference datasets (pairwise rankings)
60 | - **7.2** Reward model architecture (transformer encoder)
61 | - **7.3** Loss functions: Bradley–Terry, margin ranking loss (sketch below)
62 | - **7.4** Sanity checks for reward shaping
63 |
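The Bradley–Terry objective of 7.3 in two lines (a sketch; `loss_reward.py` is the reference): push the reward model to score the chosen response above the rejected one by minimizing -log σ(r_chosen - r_rejected).

```
import torch
import torch.nn.functional as F

r_chosen   = torch.tensor([1.2, 0.3])    # reward-model scalars for preferred responses
r_rejected = torch.tensor([0.1, 0.9])    # scalars for the rejected responses

bt_loss = -F.logsigmoid(r_chosen - r_rejected).mean()
# margin-ranking variant (7.3): F.relu(margin - (r_chosen - r_rejected)).mean() for some margin > 0
```
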
64 | ## Part 8 — RLHF with PPO
65 | - **8.1** Policy network: our base LM (from SFT) with a value head for reward prediction.
66 | - **8.2** Reward signal: provided by the reward model trained in Part 7.
67 | - **8.3** PPO objective: balance between maximizing reward and staying close to the SFT policy (KL penalty); see the sketch below.
68 | - **8.4** Training loop: sample prompts → generate completions → score with reward model → optimize policy via PPO.
69 | - **8.5** Logging & stability tricks: reward normalization, KL-controlled rollout length, gradient clipping.
70 |
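The clipped surrogate behind 8.3, as a sketch with illustrative names (the repo's `ppo_loss.py` additionally handles the value loss and the KL penalty): clip the probability ratio so a single rollout cannot pull the policy far from the behavior policy.

```
import torch

def ppo_policy_loss(logp_new, logp_old, advantages, clip_eps=0.2):
    ratio = torch.exp(logp_new - logp_old)                 # per-token prob ratio vs. the rollout policy
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    return -torch.min(unclipped, clipped).mean()           # maximize the clipped surrogate

loss = ppo_policy_loss(torch.tensor([-1.0, -0.5]), torch.tensor([-1.2, -0.6]), torch.tensor([0.5, -0.3]))
```
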
71 | ## Part 9 — RLHF with GRPO
72 | - **9.1** Group-relative baseline: instead of a value head, multiple completions are sampled per prompt and their rewards are normalized against the group mean.
73 | - **9.2** Advantage calculation: each completion’s advantage = (reward – group mean reward), broadcast to all tokens in that trajectory (see the sketch below).
74 | - **9.3** Objective: PPO-style clipped policy loss, but *policy-only* (no value loss).
75 | - **9.4** KL regularization: explicit KL(π‖π_ref) penalty term added directly to the loss (not folded into the advantage).
76 | - **9.5** Training loop differences: sample `k` completions per prompt → compute rewards → subtract per-prompt mean → apply GRPO loss with KL penalty.
77 |
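The group-relative advantage of 9.1 and 9.2 as a sketch (the repo's `grpo_loss.py` may also divide by the group standard deviation): sample k completions per prompt and subtract the per-prompt mean reward; the PPO-style clipped loss plus an explicit KL(π‖π_ref) term then uses these advantages.

```
import torch

rewards = torch.tensor([[0.8, 0.2, 0.5, 0.1],      # prompt 0: rewards of k=4 sampled completions
                        [1.0, 1.1, 0.7, 0.6]])     # prompt 1
adv = rewards - rewards.mean(dim=1, keepdim=True)  # group-relative baseline (9.1/9.2)
# adv = adv / (rewards.std(dim=1, keepdim=True) + 1e-6)   # optional: normalize by the group std
# each completion's scalar advantage is broadcast to all of its tokens and fed into the
# clipped policy loss, with KL(pi || pi_ref) added directly to the objective (9.3/9.4).
```
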
--------------------------------------------------------------------------------
/part_8/rollout.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import torch
3 | from typing import List, Tuple
4 |
5 | # tokenizer pref: BPE from Part 4 → fallback to ByteTokenizer from Part 3
6 | import sys
7 | from pathlib import Path as _P
8 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_4'))
9 | try:
10 | from tokenizer_bpe import BPETokenizer
11 | _HAS_BPE = True
12 | except Exception:
13 | _HAS_BPE = False
14 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_3'))
15 | try:
16 | from tokenizer import ByteTokenizer
17 | except Exception:
18 | ByteTokenizer = None
19 |
20 | from part_6.formatters import Example, format_example, format_prompt_only
21 |
22 | # ---------- tokenizer helpers ----------
23 | class RLHFTokenizer:
24 | def __init__(self, block_size: int, bpe_dir: str | None = None, vocab_size: int = 8000):
25 | self.block_size = block_size
26 | self.tok = None
27 | if _HAS_BPE:
28 | try:
29 | self.tok = BPETokenizer(vocab_size=vocab_size)
30 | if bpe_dir:
31 | self.tok.load(bpe_dir)
32 | except Exception:
33 | self.tok = None
34 | if self.tok is None and ByteTokenizer is not None:
35 | self.tok = ByteTokenizer()
36 | if self.tok is None:
37 | raise RuntimeError("No tokenizer available for RLHF.")
38 |
39 | @property
40 | def vocab_size(self) -> int:
41 | return getattr(self.tok, 'vocab_size', 256)
42 |
43 | def encode(self, text: str) -> List[int]:
44 | ids = self.tok.encode(text)
45 | if isinstance(ids, torch.Tensor):
46 | ids = ids.tolist()
47 | return ids
48 |
49 | def decode(self, ids: List[int]) -> str:
50 | if hasattr(self.tok, 'decode'):
51 | return self.tok.decode(ids)
52 | return bytes(ids).decode('utf-8', errors='ignore')
53 |
54 | # ---------- logprob utilities ----------
55 |
56 | def shift_labels(x: torch.Tensor) -> torch.Tensor:
57 | # For causal LM: predict x[t+1] from x[:t]
58 | return x[:, 1:].contiguous()
59 |
60 | def gather_logprobs(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
61 | """Compute per-token logprobs of the given labels.
62 | logits: (B,T,V), labels: (B,T) over same T
63 | returns: (B,T) log p(labels)
64 | """
65 | logp = torch.log_softmax(logits, dim=-1)
66 | return logp.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
67 |
68 | @torch.no_grad()
69 | def model_logprobs(model, x: torch.Tensor) -> torch.Tensor:
70 | # compute log p(x[t+1] | x[:t]) for t
71 | logits, _, _ = model.lm(x, None) if hasattr(model, 'lm') else model(x, None)
72 | labels = shift_labels(x)
73 | lp = gather_logprobs(logits[:, :-1, :], labels)
74 | return lp # (B, T-1)
75 |
76 | # ---------- KL ----------
77 |
78 | def approx_kl(policy_logp: torch.Tensor, ref_logp: torch.Tensor) -> torch.Tensor:
79 | # Mean over tokens: KL(pi||ref) ≈ (logp_pi - logp_ref).mean()
80 | return (policy_logp - ref_logp).mean()
81 |
82 | # ---------- small prompt source ----------
83 | try:
84 | from datasets import load_dataset as _load_ds
85 | except Exception:
86 | _load_ds = None
87 |
88 | def sample_prompts(n: int) -> List[str]:
89 | if _load_ds is not None:
90 | try:
91 | ds = _load_ds("tatsu-lab/alpaca", split="train[:24]")
92 | arr = []
93 | for r in ds:
94 | inst = (r.get('instruction') or '').strip()
95 | inp = (r.get('input') or '').strip()
96 | if inp:
97 | inst = inst + "\n" + inp
98 | if inst:
99 | arr.append(inst)
100 | if len(arr) >= n:
101 | break
102 | if arr:
103 | return arr
104 | except Exception:
105 | pass
106 | # fallback
107 | base = [
108 | "Explain the purpose of attention in transformers.",
109 | "Give two pros and cons of BPE tokenization.",
110 | "Summarize why PPO is used in RLHF.",
111 | "Write a tiny Python function that reverses a list.",
112 | ]
113 | return (base * ((n+len(base)-1)//len(base)))[:n]
--------------------------------------------------------------------------------
/part_9/rollout.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import torch
3 | from typing import List, Tuple
4 |
5 | # tokenizer pref: BPE from Part 4 → fallback to ByteTokenizer from Part 3
6 | import sys
7 | from pathlib import Path as _P
8 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_4'))
9 | try:
10 | from tokenizer_bpe import BPETokenizer
11 | _HAS_BPE = True
12 | except Exception:
13 | _HAS_BPE = False
14 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_3'))
15 | try:
16 | from tokenizer import ByteTokenizer
17 | except Exception:
18 | ByteTokenizer = None
19 |
20 | from part_6.formatters import Example, format_example, format_prompt_only
21 |
22 | # ---------- tokenizer helpers ----------
23 | class RLHFTokenizer:
24 | def __init__(self, block_size: int, bpe_dir: str | None = None, vocab_size: int = 8000):
25 | self.block_size = block_size
26 | self.tok = None
27 | if _HAS_BPE:
28 | try:
29 | self.tok = BPETokenizer(vocab_size=vocab_size)
30 | if bpe_dir:
31 | self.tok.load(bpe_dir)
32 | except Exception:
33 | self.tok = None
34 | if self.tok is None and ByteTokenizer is not None:
35 | self.tok = ByteTokenizer()
36 | if self.tok is None:
37 | raise RuntimeError("No tokenizer available for RLHF.")
38 |
39 | @property
40 | def vocab_size(self) -> int:
41 | return getattr(self.tok, 'vocab_size', 256)
42 |
43 | def encode(self, text: str) -> List[int]:
44 | ids = self.tok.encode(text)
45 | if isinstance(ids, torch.Tensor):
46 | ids = ids.tolist()
47 | return ids
48 |
49 | def decode(self, ids: List[int]) -> str:
50 | if hasattr(self.tok, 'decode'):
51 | return self.tok.decode(ids)
52 | return bytes(ids).decode('utf-8', errors='ignore')
53 |
54 | # ---------- logprob utilities ----------
55 |
56 | def shift_labels(x: torch.Tensor) -> torch.Tensor:
57 | # For causal LM: predict x[t+1] from x[:t]
58 | return x[:, 1:].contiguous()
59 |
60 | def gather_logprobs(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
61 | """Compute per-token logprobs of the given labels.
62 | logits: (B,T,V), labels: (B,T) over same T
63 | returns: (B,T) log p(labels)
64 | """
65 | logp = torch.log_softmax(logits, dim=-1)
66 | return logp.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
67 |
68 | @torch.no_grad()
69 | def model_logprobs(model, x: torch.Tensor) -> torch.Tensor:
70 | # compute log p(x[t+1] | x[:t]) for t
71 | logits, _, _ = model.lm(x, None) if hasattr(model, 'lm') else model(x, None)
72 | labels = shift_labels(x)
73 | lp = gather_logprobs(logits[:, :-1, :], labels)
74 | return lp # (B, T-1)
75 |
76 | # ---------- KL ----------
77 |
78 | def approx_kl(policy_logp: torch.Tensor, ref_logp: torch.Tensor) -> torch.Tensor:
79 | # Mean over tokens: KL(pi||ref) ≈ (logp_pi - logp_ref).mean()
80 | return (policy_logp - ref_logp).mean()
81 |
82 | # ---------- small prompt source ----------
83 | try:
84 | from datasets import load_dataset as _load_ds
85 | except Exception:
86 | _load_ds = None
87 |
88 | def sample_prompts(n: int) -> List[str]:
89 | if _load_ds is not None:
90 | try:
91 | ds = _load_ds("tatsu-lab/alpaca", split="train[:24]")
92 | arr = []
93 | for r in ds:
94 | inst = (r.get('instruction') or '').strip()
95 | inp = (r.get('input') or '').strip()
96 | if inp:
97 | inst = inst + "\n" + inp
98 | if inst:
99 | arr.append(inst)
100 | if len(arr) >= n:
101 | break
102 | if arr:
103 | return arr
104 | except Exception:
105 | pass
106 | # fallback
107 | base = [
108 | "Explain the purpose of attention in transformers.",
109 | "Give two pros and cons of BPE tokenization.",
110 | "Summarize why PPO is used in RLHF.",
111 | "Write a tiny Python function that reverses a list.",
112 | ]
113 | return (base * ((n+len(base)-1)//len(base)))[:n]
--------------------------------------------------------------------------------
/part_6/train_sft.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import argparse, torch
3 | import torch.nn as nn
4 | from pathlib import Path
5 | torch.manual_seed(0)
6 |
7 | # Reuse GPTModern from Part 3
8 | import sys
9 | from pathlib import Path as _P
10 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_3'))
11 | from model_modern import GPTModern # noqa: E402
12 |
13 | from dataset_sft import load_tiny_hf
14 | from collator_sft import SFTCollator
15 | from curriculum import LengthCurriculum
16 |
17 |
18 | def main():
19 | p = argparse.ArgumentParser()
20 | p.add_argument('--data', type=str, default='huggingface', help='huggingface or path to local jsonl (unused in demo)')
21 | p.add_argument('--ckpt', type=str, required=False)
22 | p.add_argument('--out', type=str, default='runs/sft')
23 | p.add_argument('--steps', type=int, default=200)
24 | p.add_argument('--batch_size', type=int, default=8)
25 | p.add_argument('--block_size', type=int, default=256)
26 | p.add_argument('--n_layer', type=int, default=4)
27 | p.add_argument('--n_head', type=int, default=4)
28 | p.add_argument('--n_embd', type=int, default=256)
29 | p.add_argument('--lr', type=float, default=3e-4)
30 | p.add_argument('--cpu', action='store_true')
31 | p.add_argument('--bpe_dir', type=str, default='../part_4/runs/part4-demo/tokenizer') # assumes tokenizer exists from Part 4
32 | args = p.parse_args()
33 |
34 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')
35 |
36 | # Load a tiny HF slice or fallback examples
37 | items = load_tiny_hf(split='train[:24]', sample_dataset=False)
38 |
39 | # Print few samples
40 | print(f"Loaded {len(items)} SFT items. Few samples:")
41 | for it in items[:3]:
42 | print(f"PROMPT: {it.prompt}\nRESPONSE: {it.response}\n{'-'*40}")
43 |
44 | # Curriculum over (prompt,response)
45 | tuples = [(it.prompt, it.response) for it in items]
46 | cur = list(LengthCurriculum(tuples))
47 |     print(f"Prepared curriculum over {len(cur)} (prompt, response) pairs")
48 |
49 | # Collator + model
50 | col = SFTCollator(block_size=args.block_size, bpe_dir=args.bpe_dir)
51 | model = GPTModern(vocab_size=col.vocab_size, block_size=args.block_size,
52 | n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd,
53 | use_rmsnorm=True, use_swiglu=True, rope=True).to(device)
54 |
55 | if args.ckpt:
56 | print(f"Using model config from checkpoint {args.ckpt}")
57 | ckpt = torch.load(args.ckpt, map_location=device)
58 | cfg = ckpt.get('config', {})
59 | model.load_state_dict(ckpt['model'])
60 |
61 | opt = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.1)
62 | model.train()
63 |
64 | # Simple loop (single machine). We just cycle curriculum to fill batches, for a few steps.
65 | step = 0
66 | i = 0
67 | while step < args.steps:
68 | batch = cur[i:i+args.batch_size]
69 | if not batch:
70 | # restart curriculum
71 | # cur = list(LengthCurriculum(tuples));
72 | i = 0
73 | continue
74 | xb, yb = col.collate(batch)
75 | xb, yb = xb.to(device), yb.to(device)
76 | logits, loss, _ = model(xb, yb)
77 | opt.zero_grad(set_to_none=True)
78 | loss.backward()
79 | opt.step()
80 | step += 1; i += args.batch_size
81 | if step % 20 == 0:
82 | print(f"step {step}: loss={loss.item():.4f}")
83 |
84 | Path(args.out).mkdir(parents=True, exist_ok=True)
85 | cfg = {
86 | "vocab_size": col.vocab_size,
87 | "block_size": args.block_size,
88 | "n_layer": args.n_layer,
89 | "n_head": args.n_head,
90 | "n_embd": args.n_embd,
91 | "dropout": 0.0,
92 | "use_rmsnorm": True,
93 | "use_swiglu": True,
94 | "rope": True,
95 | # tokenizer info (best-effort)
96 | "tokenizer_type": "byte" if col.vocab_size == 256 else "bpe",
97 | "tokenizer_dir": None, # set a real path if you have a trained BPE dir
98 | }
99 | torch.save({'model': model.state_dict(), 'config': cfg},
100 | str(Path(args.out)/'model_last.pt'))
101 | print(f"Saved SFT checkpoint to {args.out}/model_last.pt")
102 |
103 | if __name__ == '__main__':
104 | main()
--------------------------------------------------------------------------------
/part_3/attn_modern.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import math, torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from rope_custom import RoPECache, apply_rope_single
6 | from kv_cache import KVCache # your existing class
7 |
8 | class CausalSelfAttentionModern(nn.Module):
9 | def __init__(self, n_embd: int, n_head: int, dropout: float = 0.0,
10 | rope: bool = True, max_pos: int = 4096,
11 | sliding_window: int | None = None, attention_sink: int = 0,
12 | n_kv_head: int | None = None): # ← NEW
13 | super().__init__()
14 | assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
15 | self.n_head = n_head
16 | self.n_kv_head = n_kv_head or n_head # ← NEW (GQA defaults to MHA)
17 | assert self.n_head % self.n_kv_head == 0, "n_head must be multiple of n_kv_head (GQA grouping)"
18 | self.group_size = self.n_head // self.n_kv_head
19 | self.d_head = n_embd // n_head
20 |
21 | # Separate projections for Q vs K/V (sizes differ under GQA) ← CHANGED
22 | self.wq = nn.Linear(n_embd, self.n_head * self.d_head, bias=False)
23 | self.wk = nn.Linear(n_embd, self.n_kv_head * self.d_head, bias=False)
24 | self.wv = nn.Linear(n_embd, self.n_kv_head * self.d_head, bias=False)
25 | self.proj = nn.Linear(n_embd, n_embd, bias=False)
26 | self.dropout = nn.Dropout(dropout)
27 |
28 | self.use_rope = rope
29 | self.rope_cache: RoPECache | None = None
30 | self.max_pos = max_pos
31 | self.sliding_window = sliding_window
32 | self.attention_sink = attention_sink
33 |
34 | def _maybe_init_rope(self, device):
35 | if self.use_rope and self.rope_cache is None:
36 | self.rope_cache = RoPECache(self.d_head, self.max_pos, device=device)
37 |
38 | def forward(self, x: torch.Tensor, kv_cache: KVCache | None = None, start_pos: int = 0):
39 | """x: (B,T,C). If kv_cache given, we assume generation (T small, often 1)."""
40 | B, T, C = x.shape
41 | self._maybe_init_rope(x.device)
42 |
43 | # Projections
44 | q = self.wq(x).view(B, T, self.n_head, self.d_head).transpose(1, 2) # (B,H, T,D)
45 | k = self.wk(x).view(B, T, self.n_kv_head, self.d_head).transpose(1, 2) # (B,Hk,T,D)
46 | v = self.wv(x).view(B, T, self.n_kv_head, self.d_head).transpose(1, 2) # (B,Hk,T,D)
47 |
48 | # RoPE on *current* tokens (cached keys are already rotated)
49 | if self.use_rope:
50 | pos = torch.arange(start_pos, start_pos + T, device=x.device)
51 | cos, sin = self.rope_cache.get(pos)
52 | q = apply_rope_single(q, cos, sin) # (B,H, T,D)
53 | k = apply_rope_single(k, cos, sin) # (B,Hk,T,D)
54 |
55 | # Concatenate past cache (cache is stored in Hk heads)
56 | if kv_cache is not None:
57 | k_all = torch.cat([kv_cache.k, k], dim=2) # (B,Hk, Tpast+T, D)
58 | v_all = torch.cat([kv_cache.v, v], dim=2)
59 | else:
60 | k_all, v_all = k, v
61 |
62 | # Sliding-window + attention-sink (crop along seq length)
63 | if self.sliding_window is not None and k_all.size(2) > (self.sliding_window + self.attention_sink):
64 | s = self.attention_sink
65 | k_all = torch.cat([k_all[:, :, :s, :], k_all[:, :, -self.sliding_window:, :]], dim=2)
66 | v_all = torch.cat([v_all[:, :, :s, :], v_all[:, :, -self.sliding_window:, :]], dim=2)
67 |
68 | # --- GQA expand: repeat K/V heads to match Q heads before attention ---
69 | if self.n_kv_head != self.n_head:
70 | k_attn = k_all.repeat_interleave(self.group_size, dim=1) # (B,H,Tk,D)
71 | v_attn = v_all.repeat_interleave(self.group_size, dim=1) # (B,H,Tk,D)
72 | else:
73 | k_attn, v_attn = k_all, v_all
74 |
75 | # Scaled dot-product attention (PyTorch scales internally)
76 | is_causal = kv_cache is None
77 | y = F.scaled_dot_product_attention(q, k_attn, v_attn,
78 | attn_mask=None,
79 | dropout_p=self.dropout.p if self.training else 0.0,
80 | is_causal=is_causal) # (B,H,T,D)
81 |
82 | y = y.transpose(1, 2).contiguous().view(B, T, C)
83 | y = self.proj(y)
84 |
85 | # Update KV cache (store compact Hk heads, not expanded)
86 | if kv_cache is not None:
87 | k_new = torch.cat([kv_cache.k, k], dim=2) # (B,Hk,*,D)
88 | v_new = torch.cat([kv_cache.v, v], dim=2)
89 | else:
90 | k_new, v_new = k, v
91 | new_cache = KVCache(k_new, v_new)
92 | return y, new_cache
93 |
--------------------------------------------------------------------------------
/part_2/model_gpt.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | # ---- Blocks (self-contained for isolation) ----
8 | class CausalSelfAttention(nn.Module):
9 | def __init__(self, n_embd: int, n_head: int, dropout: float = 0.0):
10 | super().__init__()
11 | assert n_embd % n_head == 0
12 | self.n_head = n_head
13 | self.d_head = n_embd // n_head
14 | self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
15 | self.proj = nn.Linear(n_embd, n_embd, bias=False)
16 | self.dropout = nn.Dropout(dropout)
17 |
18 | def forward(self, x: torch.Tensor): # (B,T,C)
19 | B, T, C = x.shape
20 | qkv = self.qkv(x).view(B, T, 3, self.n_head, self.d_head)
21 | q, k, v = qkv.unbind(dim=2)
22 | q = q.transpose(1, 2)
23 | k = k.transpose(1, 2)
24 | v = v.transpose(1, 2)
25 |         # no manual scale needed: F.scaled_dot_product_attention divides by sqrt(d_head) internally
26 | # PyTorch SDPA (uses flash when available)
27 | y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout.p if self.training else 0.0, is_causal=True)
28 | y = y.transpose(1, 2).contiguous().view(B, T, C)
29 | y = self.proj(y)
30 | return y
31 |
32 | class FeedForward(nn.Module):
33 | def __init__(self, n_embd: int, mult: int = 4, dropout: float = 0.0):
34 | super().__init__()
35 | self.net = nn.Sequential(
36 | nn.Linear(n_embd, mult * n_embd),
37 | nn.GELU(),
38 | nn.Linear(mult * n_embd, n_embd),
39 | nn.Dropout(dropout),
40 | )
41 |
42 | def forward(self, x):
43 | return self.net(x)
44 |
45 | class Block(nn.Module):
46 | def __init__(self, n_embd: int, n_head: int, dropout: float):
47 | super().__init__()
48 | self.ln1 = nn.LayerNorm(n_embd)
49 | self.attn = CausalSelfAttention(n_embd, n_head, dropout)
50 | self.ln2 = nn.LayerNorm(n_embd)
51 | self.ffn = FeedForward(n_embd, mult=4, dropout=dropout)
52 |
53 | def forward(self, x):
54 | x = x + self.attn(self.ln1(x))
55 | x = x + self.ffn(self.ln2(x))
56 | return x
57 |
58 | # ---- Tiny GPT ----
59 | class GPT(nn.Module):
60 | def __init__(self, vocab_size: int, block_size: int, n_layer: int = 4, n_head: int = 4, n_embd: int = 256, dropout: float = 0.0):
61 | super().__init__()
62 | self.block_size = block_size
63 | self.tok_emb = nn.Embedding(vocab_size, n_embd)
64 | self.pos_emb = nn.Embedding(block_size, n_embd)
65 | self.drop = nn.Dropout(dropout)
66 | self.blocks = nn.ModuleList([Block(n_embd, n_head, dropout) for _ in range(n_layer)])
67 | self.ln_f = nn.LayerNorm(n_embd)
68 | self.head = nn.Linear(n_embd, vocab_size, bias=False)
69 |
70 | self.apply(self._init_weights)
71 |
72 | def _init_weights(self, m):
73 | if isinstance(m, nn.Linear):
74 | nn.init.normal_(m.weight, mean=0.0, std=0.02)
75 | if m.bias is not None:
76 | nn.init.zeros_(m.bias)
77 | elif isinstance(m, nn.Embedding):
78 | nn.init.normal_(m.weight, mean=0.0, std=0.02)
79 |
80 | def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None):
81 | B, T = idx.shape
82 | assert T <= self.block_size
83 | pos = torch.arange(0, T, device=idx.device).unsqueeze(0)
84 | x = self.tok_emb(idx) + self.pos_emb(pos)
85 | x = self.drop(x)
86 | for blk in self.blocks:
87 | x = blk(x)
88 | x = self.ln_f(x)
89 | logits = self.head(x)
90 | loss = None
91 | if targets is not None:
92 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
93 | return logits, loss
94 |
95 | @torch.no_grad()
96 | def generate(self, idx: torch.Tensor, max_new_tokens: int = 200, temperature: float = 1.0,
97 | top_k: int | None = 50, top_p: float | None = None):
98 | from utils import top_k_top_p_filtering
99 | self.eval()
100 | # Guard: if the prompt is empty, start with a newline byte (10)
101 | if idx.size(1) == 0:
102 | idx = torch.full((idx.size(0), 1), 10, dtype=torch.long, device=idx.device)
103 | for _ in range(max_new_tokens):
104 | idx_cond = idx[:, -self.block_size:]
105 | logits, _ = self(idx_cond)
106 | logits = logits[:, -1, :] / max(temperature, 1e-6)
107 | logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
108 | probs = torch.softmax(logits, dim=-1)
109 | next_id = torch.multinomial(probs, num_samples=1)
110 | idx = torch.cat([idx, next_id], dim=1)
111 | return idx
112 |
--------------------------------------------------------------------------------
/part_2/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import argparse, time
3 | import torch
4 | from tokenizer import ByteTokenizer
5 | from dataset import ByteDataset
6 | from model_gpt import GPT
7 |
8 |
9 | def estimate_loss(model: GPT, ds: ByteDataset, args) -> dict:
10 | model.eval()
11 | out = {}
12 | with torch.no_grad():
13 | for split in ['train', 'val']:
14 | losses = []
15 | for _ in range(args.eval_iters):
16 | xb, yb = ds.get_batch(split, args.batch_size, args.device)
17 | _, loss = model(xb, yb)
18 | losses.append(loss.item())
19 | out[split] = sum(losses) / len(losses)
20 | model.train()
21 | return out
22 |
23 |
24 | def main():
25 | p = argparse.ArgumentParser()
26 | p.add_argument('--data', type=str, required=True)
27 | p.add_argument('--out_dir', type=str, default='runs/min-gpt')
28 | p.add_argument('--block_size', type=int, default=256)
29 | p.add_argument('--batch_size', type=int, default=32)
30 | p.add_argument('--n_layer', type=int, default=4)
31 | p.add_argument('--n_head', type=int, default=4)
32 | p.add_argument('--n_embd', type=int, default=256)
33 | p.add_argument('--dropout', type=float, default=0.0)
34 | p.add_argument('--steps', type=int, default=2000)
35 | p.add_argument('--lr', type=float, default=3e-4)
36 | p.add_argument('--weight_decay', type=float, default=0.1)
37 | p.add_argument('--grad_clip', type=float, default=1.0)
38 | p.add_argument('--eval_interval', type=int, default=200)
39 | p.add_argument('--eval_iters', type=int, default=50)
40 | p.add_argument('--sample_every', type=int, default=200)
41 | p.add_argument('--sample_tokens', type=int, default=256)
42 | p.add_argument('--temperature', type=float, default=1.0)
43 | p.add_argument('--top_k', type=int, default=50)
44 | p.add_argument('--top_p', type=float, default=None)
45 | p.add_argument('--cpu', action='store_true')
46 | p.add_argument('--compile', action='store_true')
47 | p.add_argument('--amp', action='store_true')
48 | args = p.parse_args()
49 |
50 | args.device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')
51 |
52 | tok = ByteTokenizer()
53 | ds = ByteDataset(args.data, block_size=args.block_size)
54 | model = GPT(tok.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(args.device)
55 |
56 | if args.compile and hasattr(torch, 'compile'):
57 | model = torch.compile(model)
58 |
59 | opt = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weight_decay)
60 | scaler = torch.cuda.amp.GradScaler(enabled=(args.amp and args.device.type == 'cuda'))
61 |
62 | best_val = float('inf')
63 | t0 = time.time()
64 | model.train()
65 | for step in range(1, args.steps + 1):
66 | xb, yb = ds.get_batch('train', args.batch_size, args.device)
67 | with torch.cuda.amp.autocast(enabled=(args.amp and args.device.type == 'cuda')):
68 | _, loss = model(xb, yb)
69 | opt.zero_grad(set_to_none=True)
70 | scaler.scale(loss).backward()
71 | if args.grad_clip > 0:
72 | scaler.unscale_(opt)
73 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
74 | scaler.step(opt)
75 | scaler.update()
76 |
77 | if step % 50 == 0:
78 | print(f"step {step:5d} | loss {loss.item():.4f} | {(time.time()-t0):.1f}s")
79 | t0 = time.time()
80 |
81 | if step % args.eval_interval == 0:
82 | losses = estimate_loss(model, ds, args)
83 | print(f"eval | train {losses['train']:.4f} | val {losses['val']:.4f}")
84 | if losses['val'] < best_val:
85 | best_val = losses['val']
86 | ckpt_path = f"{args.out_dir}/model_best.pt"
87 | import os; os.makedirs(args.out_dir, exist_ok=True)
88 | torch.save({'model': model.state_dict(), 'config': {
89 | 'vocab_size': tok.vocab_size,
90 | 'block_size': args.block_size,
91 | 'n_layer': args.n_layer,
92 | 'n_head': args.n_head,
93 | 'n_embd': args.n_embd,
94 | 'dropout': args.dropout,
95 | }}, ckpt_path)
96 | print(f"saved checkpoint: {ckpt_path}")
97 |
98 | if args.sample_every > 0 and step % args.sample_every == 0:
99 | start = torch.randint(low=0, high=len(ds.train) - args.block_size - 1, size=(1,)).item()
100 | seed = ds.train[start:start + args.block_size].unsqueeze(0).to(args.device)
101 | out = model.generate(seed, max_new_tokens=args.sample_tokens, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p)
102 | txt = tok.decode(out[0].cpu())
103 | print("\n================ SAMPLE ================\n" + txt[-(args.block_size + args.sample_tokens):] + "\n=======================================\n")
104 |
105 | # final save
106 | import os; os.makedirs(args.out_dir, exist_ok=True)
107 | torch.save({'model': model.state_dict()}, f"{args.out_dir}/model_final.pt")
108 |
109 |
110 | if __name__ == '__main__':
111 | main()
--------------------------------------------------------------------------------
/part_3/model_modern.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import torch
3 | import torch.nn as nn
4 | from block_modern import TransformerBlockModern
5 | from tokenizer import ByteTokenizer
6 |
7 | # Get the absolute path to the folder that contains part_2 and part_3
8 | import os, sys
9 | parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
10 | sys.path.insert(0, parent_dir)
11 |
12 | class GPTModern(nn.Module):
13 | def __init__(self, vocab_size: int = 256, block_size: int = 256,
14 | n_layer: int=4, n_head: int=4, n_embd: int=256, dropout: float=0.0,
15 | use_rmsnorm: bool = True, use_swiglu: bool = True, rope: bool = True,
16 | max_pos: int = 4096, sliding_window: int | None = None, attention_sink: int = 0, n_kv_head: int | None = None):
17 | super().__init__()
18 | self.block_size = block_size
19 | self.tok_emb = nn.Embedding(vocab_size, n_embd)
20 | # self.pos_emb = nn.Embedding(block_size, n_embd)
21 | self.drop = nn.Dropout(dropout)
22 | self.blocks = nn.ModuleList([
23 | TransformerBlockModern(n_embd, n_head, dropout, use_rmsnorm, use_swiglu, rope, max_pos, sliding_window, attention_sink, n_kv_head)
24 | for _ in range(n_layer)
25 | ])
26 | self.ln_f = nn.Identity() if use_rmsnorm else nn.LayerNorm(n_embd)
27 | self.head = nn.Linear(n_embd, vocab_size, bias=False)
28 |
29 | def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None, kv_cache_list=None, start_pos: int = 0):
30 | B, T = idx.shape
31 | assert T <= self.block_size
32 | pos = torch.arange(0, T, device=idx.device).unsqueeze(0)
33 | x = self.tok_emb(idx)
34 | # + self.pos_emb(pos)
35 | x = self.drop(x)
36 |
37 | new_caches = []
38 | for i, blk in enumerate(self.blocks):
39 | cache = None if kv_cache_list is None else kv_cache_list[i]
40 | x, cache = blk(x, kv_cache=cache, start_pos=start_pos)
41 | new_caches.append(cache)
42 | x = self.ln_f(x)
43 | logits = self.head(x)
44 |
45 | loss = None
46 | if targets is not None:
47 | import torch.nn.functional as F
48 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
49 | return logits, loss, new_caches
50 |
51 | @torch.no_grad()
52 | def generate(self,
53 | prompt: torch.Tensor,
54 | max_new_tokens=200,
55 | temperature=1.0,
56 | top_k=50,
57 | top_p=None,
58 | eos_id=1, # addition from part 6 for early stopping
59 | sliding_window: int | None = None,
60 | attention_sink: int = 0):
61 | try:
62 | from utils import top_k_top_p_filtering as _tk
63 | except Exception:
64 | _tk = lambda x, **_: x
65 |
66 | self.eval()
67 | idx = prompt
68 | kvs = [None] * len(self.blocks)
69 |
70 | for _ in range(max_new_tokens):
71 | # feed full prompt once; then only the last token
72 | idx_cond = idx[:, -self.block_size:] if kvs[0] is None else idx[:, -1:]
73 |
74 | # absolute start position from cache length (0 on first step)
75 | start_pos = 0 if kvs[0] is None else kvs[0].k.size(2)
76 |
77 | logits, _, kvs = self(idx_cond, kv_cache_list=kvs, start_pos=start_pos)
78 |
79 | next_logits = logits[:, -1, :] / max(temperature, 1e-6)
80 | next_logits = _tk(next_logits, top_k=top_k, top_p=top_p)
81 | probs = torch.softmax(next_logits, dim=-1)
82 | next_id = torch.argmax(probs, dim=-1, keepdim=True) if temperature == 0.0 else torch.multinomial(probs, 1)
83 | idx = torch.cat([idx, next_id], dim=1)
84 |
85 | # addition from part 6 for early stopping
86 | if eos_id is not None:
87 | if (next_id == eos_id).all():
88 | break
89 |
90 | return idx
91 |
92 |
93 | @torch.no_grad()
94 | def generate_nocache(self, prompt: torch.Tensor, max_new_tokens=200, temperature=1.0, top_k=50, top_p=None,
95 | sliding_window: int | None = None, attention_sink: int = 0):
96 | try:
97 | from utils import top_k_top_p_filtering as _tk
98 | except Exception:
99 | _tk = lambda x, **_: x
100 |
101 | self.eval()
102 | idx = prompt
103 |
104 | for _ in range(max_new_tokens):
105 | # always run a full forward over the cropped window, with NO cache
106 | idx_cond = idx[:, -self.block_size:]
107 | # absolute position of first token in the window (matches cached path)
108 | start_pos = idx.size(1) - idx_cond.size(1)
109 |
110 | logits, _, _ = self(idx_cond, kv_cache_list=None, start_pos=start_pos)
111 |
112 | next_logits = logits[:, -1, :] / max(temperature, 1e-6)
113 | next_logits = _tk(next_logits, top_k=top_k, top_p=top_p)
114 | probs = torch.softmax(next_logits, dim=-1)
115 | topv, topi = torch.topk(probs, 10)
116 | print("top ids:", topi.tolist())
117 | print("top vs:", topv.tolist())
118 | next_id = torch.argmax(probs, dim=-1, keepdim=True) if temperature == 0.0 else torch.multinomial(probs, 1)
119 | idx = torch.cat([idx, next_id], dim=1)
120 |
121 | return idx
122 |
123 |
--------------------------------------------------------------------------------
/part_4/logger.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import time
3 | from pathlib import Path
4 | from typing import Any, Dict, Optional
5 | class NoopLogger:
6 | def log(self, **kwargs):
7 | pass
8 | def close(self):
9 | pass
10 |
11 | class TBLogger(NoopLogger):
12 | """
13 | Backward compatible:
14 | - logger.log(step=..., loss=..., lr=...)
15 | Extras you can optionally use:
16 | - logger.hist("params/wte.weight", tensor, step)
17 | - logger.text("samples/generation", text, step)
18 | - logger.image("attn/heatmap", HWC_or_CHW_tensor_or_np, step)
19 | - logger.graph(model, example_batch)
20 | - logger.hparams(dict_of_config, dict_of_metrics_once)
21 | - logger.flush()
22 | Auto-behavior:
23 | - If a value in .log(...) is a tensor/ndarray with >1 element, it logs a histogram.
24 | - If key starts with "text/", logs as text.
25 | """
26 | # logger.py
27 | def __init__(self, out_dir: str, flush_secs: int = 10, run_name: str | None = None):
28 | self.w = None
29 | self.hparams_logged = False
30 | run_name = run_name or time.strftime("%Y%m%d-%H%M%S")
31 | run_dir = Path(out_dir) / run_name
32 | run_dir.mkdir(parents=True, exist_ok=True)
33 | try:
34 | from torch.utils.tensorboard import SummaryWriter
35 | self.w = SummaryWriter(log_dir=str(run_dir), flush_secs=flush_secs)
36 | except Exception as e:
37 | print(f"[TBLogger] TensorBoard not available: {e}. Logging disabled.")
38 | self._auto_hist_max_elems = 2048
39 | self.run_dir = str(run_dir) # handy for prints/debug
40 |
41 |
42 |
43 | # ---------- backwards-compatible ----------
44 | def log(self, step: Optional[int] = None, **kv: Any):
45 | if not self.w: return
46 | for k, v in kv.items():
47 | # text channel (opt-in via key prefix)
48 | if isinstance(k, str) and k.startswith("text/"):
49 | try:
50 | self.w.add_text(k[5:], str(v), global_step=step)
51 | except Exception:
52 | pass
53 | continue
54 |
55 | # scalar vs histogram auto-route
56 | try:
57 | import torch, numpy as np # lazy
58 | is_torch = isinstance(v, torch.Tensor)
59 | is_np = isinstance(v, np.ndarray)
60 | if is_torch or is_np:
61 | # scalar?
62 | numel = int(v.numel() if is_torch else v.size)
63 | if numel == 1:
64 | val = (v.item() if is_torch else float(v))
65 | self.w.add_scalar(k, float(val), global_step=step)
66 | else:
67 | # small-ish tensors => histogram
68 | if numel <= self._auto_hist_max_elems:
69 | self.w.add_histogram(k, v.detach().cpu() if is_torch else v, global_step=step)
70 | else:
71 | # fall back to scalar summary stats
72 | arr = v.detach().cpu().flatten().numpy() if is_torch else v.flatten()
73 | self.w.add_scalar(k + "/mean", float(arr.mean()), global_step=step)
74 | self.w.add_scalar(k + "/std", float(arr.std()), global_step=step)
75 | continue
76 | except Exception:
77 | pass
78 |
79 | # number-like
80 | try:
81 | self.w.add_scalar(k, float(v), global_step=step)
82 | except Exception:
83 | # swallow non-numeric junk silently (same behavior as before)
84 | pass
85 |
86 | # ---------- nice-to-have helpers ----------
87 | def hist(self, tag: str, values: Any, step: Optional[int] = None, bins: str = "tensorflow"):
88 | if not self.w: return
89 | try:
90 | import torch
91 | if isinstance(values, torch.Tensor):
92 | values = values.detach().cpu()
93 | self.w.add_histogram(tag, values, global_step=step, bins=bins)
94 | except Exception:
95 | pass
96 |
97 | def text(self, tag: str, text: str, step: Optional[int] = None):
98 | if not self.w: return
99 | try:
100 | self.w.add_text(tag, text, global_step=step)
101 | except Exception:
102 | pass
103 |
104 | def image(self, tag: str, img, step: Optional[int] = None):
105 | """
106 | img: torch.Tensor [C,H,W] or [H,W,C] or numpy array
107 | """
108 | if not self.w: return
109 | try:
110 | self.w.add_image(tag, img, global_step=step, dataformats="CHW" if getattr(img, "ndim", 0) == 3 and img.shape[0] in (1,3) else "HWC")
111 | except Exception:
112 | pass
113 |
114 | def graph(self, model, example_input):
115 | if not self.w: return
116 | try:
117 | # example_input: a Tensor batch or a tuple
118 | if not isinstance(example_input, tuple):
119 | example_input = (example_input,)
120 | self.w.add_graph(model, example_input)
121 | except Exception:
122 | pass # graph tracing can fail depending on model control flow; don't crash
123 |
124 | def hparams(self, hparams: Dict[str, Any], metrics_once: Optional[Dict[str, float]] = None):
125 | if not self.w or self.hparams_logged:
126 | return
127 | try:
128 | # Single, stable sub-run so it doesn’t spam the left pane
129 | self.w.add_hparams(hparams, metrics_once or {}, run_name="_hparams")
130 | self.hparams_logged = True
131 | except Exception:
132 | pass
133 |
134 | def flush(self):
135 | if self.w:
136 | try: self.w.flush()
137 | except Exception: pass
138 |
139 | def close(self):
140 | if self.w:
141 | try: self.w.close()
142 | except Exception: pass
143 |
144 | class WBLogger(NoopLogger):
145 | def __init__(self, project: str, run_name: str | None = None):
146 | try:
147 | import wandb
148 | wandb.init(project=project, name=run_name)
149 | self.wb = wandb
150 | except Exception:
151 | self.wb = None
152 | def log(self, **kv):
153 | if self.wb: self.wb.log(kv)
154 |
155 |
156 | def init_logger(which: str, out_dir: str = "runs/part4"):
157 | if which == 'tensorboard':
158 | tb = TBLogger(out_dir)
159 | return tb if tb.w is not None else NoopLogger()
160 | if which == 'wandb':
161 | return WBLogger(project='llm-part4')
162 | return NoopLogger()
163 |
--------------------------------------------------------------------------------
/part_4/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import argparse, time, signal
3 | from pathlib import Path
4 | import sys
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | # so we can import Part 3 model
10 | from pathlib import Path as _P
11 | sys.path.append(str(_P(__file__).resolve().parents[1] / 'part_3'))
12 | from model_modern import GPTModern
13 |
14 | from tokenizer_bpe import BPETokenizer
15 | from dataset_bpe import make_loader
16 | from lr_scheduler import WarmupCosineLR
17 | from amp_accum import AmpGrad
18 | from checkpointing import (
19 | load_checkpoint,
20 | _log_hparams_tb,
21 | _maybe_log_graph_tb,
22 | _is_tb,
23 | _log_model_stats,
24 | _maybe_log_attention,
25 | _log_samples_tb,
26 | _log_runtime,
27 | atomic_save_all,
28 | )
29 | from logger import init_logger
30 |
31 |
32 | def run_cfg_from_args(args, vocab_size: int) -> dict:
33 | return dict(
34 | vocab_size=vocab_size,
35 | block_size=args.block_size,
36 | n_layer=args.n_layer,
37 | n_head=args.n_head,
38 | n_embd=args.n_embd,
39 | dropout=args.dropout,
40 | use_rmsnorm=True,
41 | use_swiglu=True,
42 | rope=True,
43 | max_pos=4096,
44 | sliding_window=None,
45 | attention_sink=0,
46 | )
47 |
48 |
49 | def main():
50 | p = argparse.ArgumentParser()
51 | p.add_argument('--data', type=str, required=True)
52 | p.add_argument('--out', type=str, default='runs/part4')
53 |
54 | # tokenizer / model dims
55 | p.add_argument('--bpe', action='store_true', help='train and use a BPE tokenizer (recommended)')
56 | p.add_argument('--vocab_size', type=int, default=32000)
57 | p.add_argument('--block_size', type=int, default=256)
58 | p.add_argument('--n_layer', type=int, default=6)
59 | p.add_argument('--n_head', type=int, default=8)
60 | p.add_argument('--n_embd', type=int, default=512)
61 | p.add_argument('--dropout', type=float, default=0.0)
62 |
63 | # train
64 | p.add_argument('--batch_size', type=int, default=32)
65 | p.add_argument('--epochs', type=int, default=1)
66 | p.add_argument('--steps', type=int, default=300, help='max optimizer steps for this run')
67 | p.add_argument('--lr', type=float, default=3e-4)
68 | p.add_argument('--warmup_steps', type=int, default=20)
69 | p.add_argument('--mixed_precision', action='store_true')
70 | p.add_argument('--grad_accum_steps', type=int, default=4)
71 |
72 | # misc
73 | p.add_argument('--log', choices=['wandb', 'tensorboard', 'none'], default='tensorboard')
74 | p.add_argument('--save_every', type=int, default=50, help='save checkpoint every N optimizer steps')
75 | p.add_argument('--keep_last_k', type=int, default=2, help='keep last K step checkpoints (plus model_last.pt)')
76 | args = p.parse_args()
77 |
78 | # device
79 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
80 |
81 | # output dir and (possible) checkpoint
82 | out_dir = Path(args.out); out_dir.mkdir(parents=True, exist_ok=True)
83 | ckpt_path = out_dir / "model_last.pt"
84 | have_ckpt = ckpt_path.exists()
85 |
86 | # ---- load checkpoint meta if present ----
87 | ckpt = None
88 | saved_tok_dir = None
89 | if have_ckpt:
90 | ckpt = torch.load(str(ckpt_path), map_location=device)
91 | if "config" not in ckpt:
92 | raise RuntimeError(
93 | "Checkpoint is missing 'config'."
94 | "Please re-save a checkpoint that includes the model config."
95 | )
96 | tok_file = ckpt_path.with_name("tokenizer_dir.txt")
97 | saved_tok_dir = tok_file.read_text().strip() if tok_file.exists() else None
98 |
99 | # ---- tokenizer ----
100 | tok = None
101 | tok_dir = None
102 | if have_ckpt:
103 | if not saved_tok_dir:
104 | raise RuntimeError(
105 | "Checkpoint was found but tokenizer_dir.txt is missing. "
106 | "Resume requires the original tokenizer."
107 | )
108 | tok = BPETokenizer(); tok.load(saved_tok_dir)
109 | tok_dir = saved_tok_dir
110 | vocab_size = tok.vocab_size
111 | print(f"[resume] Loaded tokenizer from {tok_dir} (vocab={vocab_size})")
112 | else:
113 | if args.bpe:
114 | tok = BPETokenizer(vocab_size=args.vocab_size)
115 | tok.train(args.data)
116 | tok_dir = str(out_dir / 'tokenizer'); Path(tok_dir).mkdir(parents=True, exist_ok=True)
117 | tok.save(tok_dir)
118 | vocab_size = tok.vocab_size
119 | print(f"[init] Trained tokenizer to {tok_dir} (vocab={vocab_size})")
120 | else:
121 | tok = None
122 | vocab_size = 256 # byte-level fallback (not recommended for Part 4)
123 |
124 | # ---- dataset ----
125 | train_loader = make_loader(args.data, tok, args.block_size, args.batch_size, shuffle=True)
126 |
127 | # ---- build model config ----
128 | if have_ckpt:
129 | cfg_build = ckpt["config"]
130 | if cfg_build.get("vocab_size") != vocab_size:
131 | raise RuntimeError(
132 | f"Tokenizer vocab ({vocab_size}) != checkpoint config vocab ({cfg_build.get('vocab_size')}). "
133 | "This deterministic script forbids vocab changes on resume."
134 | )
135 | else:
136 | cfg_build = run_cfg_from_args(args, vocab_size)
137 |
138 | # ---- init model/opt/sched/amp ----
139 | model = GPTModern(**cfg_build).to(device)
140 | optim = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.1)
141 |
142 | total_steps = min(args.steps, args.epochs * len(train_loader))
143 | warmup = min(args.warmup_steps, max(total_steps // 10, 1))
144 | sched = WarmupCosineLR(optim, warmup_steps=warmup, total_steps=total_steps, base_lr=args.lr)
145 |
146 | amp = AmpGrad(optim, accum=args.grad_accum_steps, amp=args.mixed_precision)
147 |
148 | # ---- strict resume ----
149 | step = 0
150 | if have_ckpt:
151 | step = load_checkpoint(model, str(ckpt_path), optimizer=optim, scheduler=sched, amp=amp, strict=True)
152 | print(f"[resume] Loaded checkpoint at step {step}")
153 |
154 | # ---- logging ----
155 | logger = init_logger(args.log, out_dir=str(out_dir))
156 | _log_hparams_tb(logger, args, total_steps)
157 | if _is_tb(logger):
158 | try:
159 | ex_x, ex_y = next(iter(train_loader))
160 | _maybe_log_graph_tb(logger, model, ex_x.to(device), ex_y.to(device))
161 | except Exception:
162 | pass
163 |
164 | # ---- graceful save on SIGINT/SIGTERM ----
165 | save_requested = {"flag": False}
166 | def _on_term(sig, frame): save_requested["flag"] = True
167 | signal.signal(signal.SIGTERM, _on_term)
168 | signal.signal(signal.SIGINT, _on_term)
169 |
170 | # ---- train loop ----
171 | model.train()
172 | while step < args.steps:
173 | for xb, yb in train_loader:
174 | if step >= args.steps: break
175 | if save_requested["flag"]:
176 | atomic_save_all(model, optim, sched, amp, step, out_dir, tok_dir, args.keep_last_k, cfg_build)
177 | print(f"[signal] Saved checkpoint at step {step} to {out_dir}. Exiting.")
178 | return
179 |
180 | it_t0 = time.time()
181 | xb, yb = xb.to(device), yb.to(device)
182 | with torch.cuda.amp.autocast(enabled=amp.amp):
183 | logits, loss, _ = model(xb, yb)
184 | amp.backward(loss)
185 |
186 | if amp.should_step():
187 | amp.step(); amp.zero_grad()
188 | lr = sched.step()
189 | step += 1
190 |
191 | # periodic checkpoint
192 | if step % args.save_every == 0:
193 | atomic_save_all(model, optim, sched, amp, step, out_dir, tok_dir, args.keep_last_k, cfg_build)
194 | if _is_tb(logger):
195 | logger.text("meta/checkpoint", f"Saved at step {step}", step)
196 |
197 | # logging
198 | if step % 50 == 0:
199 | logger.log(step=step, loss=float(loss.item()), lr=float(lr))
200 | _log_runtime(logger, step, it_t0, xb, device)
201 | _log_model_stats(logger, model, step, do_hists=False)
202 | _maybe_log_attention(logger, model, xb, step, every=100)
203 | _log_samples_tb(logger, model, tok, xb, device, step, max_new_tokens=64)
204 |
205 | # ---- final save ----
206 | atomic_save_all(model, optim, sched, amp, step, out_dir, tok_dir, args.keep_last_k, cfg_build)
207 | print(f"Saved checkpoint to {out_dir}/model_last.pt")
208 |
209 |
210 | if __name__ == '__main__':
211 | main()
212 |
--------------------------------------------------------------------------------
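Note on the loop above: mixed precision and gradient accumulation are delegated to AmpGrad (defined in part_4/amp_accum.py, which is not reproduced in this excerpt). A minimal sketch of the accumulate-then-step pattern the loop relies on, with hypothetical names — the repo's actual class may differ:

import torch

class AccumStepper:
    """Sketch of an AMP + gradient-accumulation helper (assumed interface, not the repo's AmpGrad)."""
    def __init__(self, optimizer, accum: int = 4, amp: bool = False):
        self.optimizer = optimizer
        self.accum = max(accum, 1)
        self.amp = amp
        self.scaler = torch.cuda.amp.GradScaler(enabled=amp)
        self._micro = 0  # micro-batches accumulated since the last optimizer step

    def backward(self, loss):
        # divide by accum so the summed gradients equal the average over the accumulation window
        self.scaler.scale(loss / self.accum).backward()
        self._micro += 1

    def should_step(self) -> bool:
        return self._micro % self.accum == 0

    def step(self):
        self.scaler.step(self.optimizer)  # unscales grads; skips the step on inf/nan when AMP is on
        self.scaler.update()
        self._micro = 0

    def zero_grad(self):
        self.optimizer.zero_grad(set_to_none=True)

--------------------------------------------------------------------------------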
/part_8/train_ppo.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import argparse, torch
3 | from pathlib import Path
4 |
5 | # import torch
6 | # torch.manual_seed(0)
7 |
8 | from policy import PolicyWithValue
9 | from rollout import RLHFTokenizer, format_prompt_only, format_example, sample_prompts, gather_logprobs, shift_labels
10 | from rollout import model_logprobs
11 |
12 | # Reward model from Part 7
13 | import sys
14 | from pathlib import Path as _P
15 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_7'))
16 | from model_reward import RewardModel # noqa: E402
17 |
18 | from ppo_loss import ppo_losses
19 |
20 |
21 | def compute_reward(reward_model: RewardModel, tok: RLHFTokenizer, prompt: str, response: str, device) -> float:
22 | text = format_example(__import__('part_6.formatters', fromlist=['Example']).Example(prompt, response))
23 | ids = tok.encode(text)
24 | x = torch.tensor([ids[:tok.block_size]], dtype=torch.long, device=device)
25 | with torch.no_grad():
26 | r = reward_model(x)
27 | return float(r[0].item())
28 |
29 |
30 | def main():
31 | p = argparse.ArgumentParser()
32 | p.add_argument('--out', type=str, default='runs/ppo-demo')
33 | p.add_argument('--policy_ckpt', type=str, required=True, help='SFT checkpoint (Part 6)')
34 | p.add_argument('--reward_ckpt', type=str, required=True, help='Reward model checkpoint (Part 7)')
35 | p.add_argument('--steps', type=int, default=100)
36 | p.add_argument('--batch_size', type=int, default=4)
37 | p.add_argument('--block_size', type=int, default=256)
38 | p.add_argument('--resp_len', type=int, default=64)
39 | p.add_argument('--kl_coef', type=float, default=0.01)
40 | p.add_argument('--gamma', type=float, default=1.0)
41 | p.add_argument('--lam', type=float, default=0.95)
42 | p.add_argument('--lr', type=float, default=1e-5)
43 | p.add_argument('--bpe_dir', type=str, default=None)
44 | p.add_argument('--cpu', action='store_true')
45 | args = p.parse_args()
46 |
47 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')
48 |
49 | # tokenizer
50 | tok = RLHFTokenizer(block_size=args.block_size, bpe_dir=args.bpe_dir)
51 |
52 | # Load SFT policy as initial policy AND reference
53 | ckpt = torch.load(args.policy_ckpt, map_location=device)
54 | cfg = ckpt.get('config', {})
55 | vocab_size = cfg.get('vocab_size', tok.vocab_size)
56 | block_size = cfg.get('block_size', tok.block_size)
57 | n_layer = cfg.get('n_layer', 2)
58 | n_head = cfg.get('n_head', 2)
59 | n_embd = cfg.get('n_embd', 128)
60 |
61 | policy = PolicyWithValue(vocab_size, block_size, n_layer, n_head, n_embd).to(device)
62 | policy.lm.load_state_dict(ckpt['model']) # initialize LM weights from SFT
63 |
64 |
65 | ref = PolicyWithValue(vocab_size, block_size, n_layer, n_head, n_embd).to(device)
66 | ref.lm.load_state_dict(ckpt['model'])
67 | for p_ in ref.parameters():
68 | p_.requires_grad_(False)
69 | ref.eval()
70 |
71 | # Reward model
72 | rckpt = torch.load(args.reward_ckpt, map_location=device)
73 | rm = RewardModel(vocab_size=rckpt['config'].get('vocab_size', tok.vocab_size), block_size=rckpt['config'].get('block_size', tok.block_size),
74 | n_layer=rckpt['config'].get('n_layer', 4), n_head=rckpt['config'].get('n_head', 4), n_embd=rckpt['config'].get('n_embd', 256)).to(device)
75 | rm.load_state_dict(rckpt['model'])
76 | rm.eval()
77 |
78 | opt = torch.optim.AdamW(policy.parameters(), lr=args.lr, betas=(0.9, 0.999))
79 |
80 | # small prompt pool
81 | prompts = sample_prompts(16)
82 |
83 | step = 0
84 | while step < args.steps:
85 | # ----- COLLECT ROLLOUT BATCH -----
86 |         # rotate through the prompt pool, wrapping cleanly at the end
87 |         start = (step * args.batch_size) % len(prompts)
88 |         batch_prompts = (prompts + prompts)[start:start + args.batch_size]
89 | texts = [format_prompt_only(p).replace("", "") for p in batch_prompts]
90 | in_ids = [tok.encode(t) for t in texts]
91 |
92 | with torch.no_grad():
93 | out_ids = []
94 | for i, x in enumerate(in_ids):
95 | idx = torch.tensor([x], dtype=torch.long, device=device)
96 | out = policy.generate(idx, max_new_tokens=args.resp_len, temperature=0.2, top_k=3)
97 | out_ids.append(out[0].tolist())
98 |
99 | # split prompt/response per sample
100 | data = []
101 | for i, prompt in enumerate(batch_prompts):
102 | full = out_ids[i]
103 | # find boundary: index where prompt ends in the tokenized form
104 | # Use original prompt tokenization length (clipped by block_size)
105 | p_ids = in_ids[i][-block_size:]
106 | boundary = len(p_ids)
107 | resp_ids = full[boundary:]
108 | # compute rewards via RM on formatted prompt+response text
109 | resp_text = tok.decode(resp_ids)
110 | r_scalar = compute_reward(rm, tok, prompt, resp_text, device)
111 | data.append((torch.tensor(full, dtype=torch.long), boundary, r_scalar))
112 |
113 | # pad to same length
114 | policy_ctx = getattr(policy, "block_size", block_size)
115 | max_len = min(policy_ctx, max(t[0].numel() for t in data))
116 | B = len(data)
117 | seq = torch.zeros(B, max_len, dtype=torch.long, device=device)
118 | mask = torch.zeros(B, max_len, dtype=torch.bool, device=device)
119 | last_idx = torch.zeros(B, dtype=torch.long, device=device)
120 | rewards = torch.zeros(B, max_len, dtype=torch.float, device=device)
121 |
122 | for i, (ids, boundary, r_scalar) in enumerate(data):
123 | L_full = ids.numel()
124 | L = min(L_full, max_len)
125 | drop = L_full - L # tokens dropped from the left
126 | b = max(0, boundary - drop) # shift boundary after left-trim
127 | seq[i, :L] = ids[-L:]
128 | if L < max_len:
129 |                 seq[i, L:] = 2  # fill remaining positions with the pad token (id 2)
130 | mask[i, b:L] = True
131 | rewards[i, L-1] = r_scalar
132 | last_idx[i] = L-1
133 |
134 |
135 | # logprobs & values for policy and reference
136 | # model_logprobs returns (B, T-1) for next-token logp; align to seq[:,1:]
137 | pol_lp = model_logprobs(policy, seq)
138 | ref_lp = model_logprobs(ref, seq)
139 | # values for seq positions (B,T)
140 | with torch.no_grad():
141 | logits, values, _ = policy(seq, None)
142 | values = values[:, :-1] # align to pol_lp
143 |
144 | # Select only action positions
145 | act_mask = mask[:,1:] # since logprobs are for predicting token t from <=t-1
146 | old_logp = pol_lp[act_mask].detach()
147 | ref_logp = ref_lp[act_mask].detach()
148 | old_values = values[act_mask].detach()
149 |
150 | # KL per action token and shaped rewards
151 | kl = (old_logp - ref_logp)
152 | shaped_r = rewards[:,1:][act_mask] - args.kl_coef * kl # penalty for drifting
153 |
154 | # Compute advantages/returns with last‑step bootstrap = 0 (episodic per response)
155 | # Flatten by sequence order inside each sample; we’ll approximate by grouping tokens per sample using last_idx.
156 | # For tutorial simplicity, treat advantages = shaped_r - old_values (no GAE). Works for end-only reward.
157 | returns = shaped_r # target value = immediate shaped reward
158 | adv = returns - old_values
159 | # normalize adv
160 | adv = (adv - adv.mean()) / (adv.std().clamp_min(1e-6))
161 |
162 | # ----- UPDATE (single pass PPO for demo) -----
163 | # This step is done multiple times per batch in practice
164 | policy.train()
165 | logits_new, values_new_full, _ = policy(seq, None)
166 | logp_full = torch.log_softmax(logits_new[:, :-1, :], dim=-1)
167 | labels = seq[:,1:]
168 | new_logp_all = logp_full.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
169 | new_logp = new_logp_all[act_mask]
170 | new_values = values_new_full[:, :-1][act_mask]
171 |
172 |         # ppo_losses is imported at module top; a sketch of the clipped objective follows this file
173 |         out_loss = ppo_losses(new_logp, old_logp, adv, new_values, old_values, returns,
174 |                               clip_ratio=0.2, vf_coef=0.5, ent_coef=0.0)
175 | loss = out_loss.total_loss
176 |
177 | opt.zero_grad(set_to_none=True)
178 | loss.backward()
179 | torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
180 | opt.step()
181 | policy.eval()
182 |
183 | with torch.no_grad():
184 | # KL(old || new): movement of the updated policy from the snapshot used to collect data
185 | lp_post = model_logprobs(policy, seq) # (B, T-1)
186 | lp_post = lp_post[act_mask] # only action positions
187 | kl_post = (old_logp - lp_post).mean() # ≈ E[log π_old - log π_new]
188 |
189 | # KL(now || ref): how far the current policy is from the frozen reference
190 | lp_now = lp_post # already computed above on the same positions
191 | kl_ref_now = (lp_now - ref_logp).mean() # ≈ E[log π_now - log π_ref]
192 |
193 | step += 1
194 | if step % 10 == 0:
195 | print(
196 |                 f"step {step} | loss {loss.item():.4f} "
197 | f"| value loss {out_loss.value_loss.item():.4f} | KL_move {kl_post.item():.6f} | KL_ref {kl_ref_now.item():.6f}"
198 | )
199 |
200 |
201 | Path(args.out).mkdir(parents=True, exist_ok=True)
202 | torch.save({'model': policy.state_dict(), 'config': {
203 | 'vocab_size': vocab_size,
204 | 'block_size': block_size,
205 | 'n_layer': n_layer,
206 | 'n_head': n_head,
207 | 'n_embd': n_embd,
208 | }}, str(Path(args.out)/'model_last.pt'))
209 | print(f"Saved PPO policy to {args.out}/model_last.pt")
210 |
211 | if __name__ == '__main__':
212 | main()
--------------------------------------------------------------------------------
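Note on the update step above: train_ppo.py consumes ppo_losses from part_8/ppo_loss.py, which is not shown in this excerpt. For reference, a minimal sketch of the standard clipped PPO objective with a value term — the field and argument names here are assumptions, and the repo's implementation may differ (e.g. value clipping or a different entropy term):

from dataclasses import dataclass
import torch
import torch.nn.functional as F

@dataclass
class PPOLossOut:
    policy_loss: torch.Tensor
    value_loss: torch.Tensor
    total_loss: torch.Tensor

def ppo_losses_sketch(new_logp, old_logp, adv, new_values, old_values, returns,
                      clip_ratio=0.2, vf_coef=0.5, ent_coef=0.0) -> PPOLossOut:
    ratio = torch.exp(new_logp - old_logp)                      # pi_new / pi_old per action token
    unclipped = ratio * adv
    clipped = torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * adv
    policy_loss = -torch.min(unclipped, clipped).mean()         # clipped surrogate, negated for minimization
    value_loss = F.mse_loss(new_values, returns)                # regress values toward shaped returns
    entropy_proxy = -new_logp.mean()                            # crude entropy estimate from sampled actions
    total = policy_loss + vf_coef * value_loss - ent_coef * entropy_proxy
    return PPOLossOut(policy_loss, value_loss, total)

--------------------------------------------------------------------------------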
/part_9/train_grpo.py:
--------------------------------------------------------------------------------
1 | # train_grpo.py
2 | from __future__ import annotations
3 | import argparse, torch
4 | from pathlib import Path
5 |
6 | from policy import PolicyWithValue # we will ignore the value head
7 | from rollout import RLHFTokenizer, format_prompt_only, sample_prompts, model_logprobs
8 |
9 | # Reward model from Part 7
10 | import sys
11 | from pathlib import Path as _P
12 | sys.path.append(str(_P(__file__).resolve().parents[1]/'part_7'))
13 | from model_reward import RewardModel # noqa: E402
14 |
15 | from grpo_loss import ppo_policy_only_losses
16 |
17 |
18 | @torch.no_grad()
19 | def compute_reward(reward_model: RewardModel, tok: RLHFTokenizer, prompt_text: str, response_ids: list[int], device) -> float:
20 | # Build full formatted text (as in your PPO)
21 | from part_6.formatters import Example, format_example
22 | resp_text = tok.decode(response_ids)
23 | text = format_example(Example(prompt_text, resp_text))
24 | ids = tok.encode(text)
25 | x = torch.tensor([ids[:tok.block_size]], dtype=torch.long, device=device)
26 | r = reward_model(x)
27 | return float(r[0].item())
28 |
29 |
30 | def main():
31 | p = argparse.ArgumentParser()
32 | p.add_argument('--out', type=str, default='runs/grpo-demo')
33 | p.add_argument('--policy_ckpt', type=str, required=True, help='SFT checkpoint (Part 6)')
34 | p.add_argument('--reward_ckpt', type=str, required=True, help='Reward model checkpoint (Part 7)')
35 | p.add_argument('--steps', type=int, default=100)
36 | p.add_argument('--batch_prompts', type=int, default=32, help='number of distinct prompts per step (before grouping)')
37 | p.add_argument('--group_size', type=int, default=4, help='completions per prompt')
38 | p.add_argument('--block_size', type=int, default=256)
39 | p.add_argument('--resp_len', type=int, default=64)
40 | p.add_argument('--kl_coef', type=float, default=0.01)
41 | p.add_argument('--lr', type=float, default=1e-5)
42 | p.add_argument('--bpe_dir', type=str, default=None)
43 | p.add_argument('--cpu', action='store_true')
44 | args = p.parse_args()
45 |
46 | device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')
47 |
48 | # tokenizer
49 | tok = RLHFTokenizer(block_size=args.block_size, bpe_dir=args.bpe_dir)
50 |
51 | # Load SFT policy (and a frozen reference)
52 | ckpt = torch.load(args.policy_ckpt, map_location=device)
53 | cfg = ckpt.get('config', {})
54 | vocab_size = cfg.get('vocab_size', tok.vocab_size)
55 | block_size = cfg.get('block_size', tok.block_size)
56 | n_layer = cfg.get('n_layer', 2)
57 | n_head = cfg.get('n_head', 2)
58 | n_embd = cfg.get('n_embd', 128)
59 |
60 | policy = PolicyWithValue(vocab_size, block_size, n_layer, n_head, n_embd).to(device)
61 | policy.lm.load_state_dict(ckpt['model'])
62 | policy.eval()
63 |
64 | ref = PolicyWithValue(vocab_size, block_size, n_layer, n_head, n_embd).to(device)
65 | ref.lm.load_state_dict(ckpt['model'])
66 | for p_ in ref.parameters():
67 | p_.requires_grad_(False)
68 | ref.eval()
69 |
70 | # Reward model
71 | rckpt = torch.load(args.reward_ckpt, map_location=device)
72 | rm = RewardModel(vocab_size=rckpt['config'].get('vocab_size', tok.vocab_size),
73 | block_size=rckpt['config'].get('block_size', tok.block_size),
74 | n_layer=rckpt['config'].get('n_layer', 4),
75 | n_head=rckpt['config'].get('n_head', 4),
76 | n_embd=rckpt['config'].get('n_embd', 256)).to(device)
77 | rm.load_state_dict(rckpt['model'])
78 | rm.eval()
79 |
80 | opt = torch.optim.AdamW(policy.parameters(), lr=args.lr, betas=(0.9, 0.999))
81 |
82 | # small prompt pool (reuse your helper)
83 | prompts_pool = sample_prompts(16)
84 |
85 | step = 0
86 | pool_idx = 0
87 | G = args.group_size
88 |
89 | while step < args.steps:
90 | # ----- SELECT PROMPTS -----
91 | # Choose P prompts, each will yield G completions → B = P*G trajectories
92 | P = max(1, args.batch_prompts)
93 | if pool_idx + P > len(prompts_pool):
94 | pool_idx = 0
95 | batch_prompts = prompts_pool[pool_idx: pool_idx + P]
96 | pool_idx += P
97 |
98 | # Tokenize prompt-only texts
99 | prompt_texts = [format_prompt_only(p).replace("", "") for p in batch_prompts]
100 | prompt_in_ids = [tok.encode(t) for t in prompt_texts]
101 |
102 | # ----- GENERATE G COMPLETIONS PER PROMPT -----
103 | # We will collect all trajectories flat, but track their group/prompt ids.
104 | seq_list = [] # list[Tensor of token ids]
105 | boundary_list = [] # index where response starts in the (possibly clipped) sequence
106 | prompt_id_of = [] # which prompt this trajectory belongs to (0..P-1)
107 | raw_rewards = [] # scalar reward per trajectory (before KL shaping)
108 | last_idx_list = [] # for padding bookkeeping
109 |
110 | with torch.no_grad():
111 | for pid, p_ids in enumerate(prompt_in_ids):
112 | for g in range(G):
113 | idx = torch.tensor([p_ids], dtype=torch.long, device=device)
114 | out = policy.generate(idx, max_new_tokens=args.resp_len, temperature=2, top_k=3)
115 | full_ids = out[0].tolist()
116 |
117 | # split prompt/response
118 | boundary = len(p_ids[-block_size:]) # prompt length clipped to context
119 | resp_ids = full_ids[boundary:]
120 | r_scalar = compute_reward(rm, tok, batch_prompts[pid], resp_ids, device)
121 |
122 | seq_list.append(torch.tensor(full_ids, dtype=torch.long))
123 | boundary_list.append(boundary)
124 | prompt_id_of.append(pid)
125 | raw_rewards.append(r_scalar)
126 |
127 | # ----- PAD TO BATCH -----
128 | B = len(seq_list) # B = P*G
129 | policy_ctx = getattr(policy, "block_size", block_size)
130 | max_len = min(policy_ctx, max(s.numel() for s in seq_list))
131 | seq = torch.zeros(B, max_len, dtype=torch.long, device=device)
132 | mask = torch.zeros(B, max_len, dtype=torch.bool, device=device)
133 | last_idx = torch.zeros(B, dtype=torch.long, device=device)
134 |
135 | # keep a per-traj “action positions” mask and response-only boundary
136 | for i, (ids, bnd) in enumerate(zip(seq_list, boundary_list)):
137 | L_full = ids.numel()
138 | L = min(L_full, max_len)
139 | drop = L_full - L
140 | b = max(0, bnd - drop) # shifted boundary after left-trim
141 | seq[i, :L] = ids[-L:]
142 | if L < max_len:
143 | seq[i, L:] = 2 # pad token
144 | # actions are predicting token t from <=t-1 → positions [1..L-1]
145 | # but we only care about response tokens: mask [b..L-1] → actions [b+1..L-1]
146 | mask[i, b:L] = True
147 | last_idx[i] = L - 1
148 |
149 | # ----- LOGPROBS & KL VS REF (token-level) -----
150 | # model_logprobs returns log p(x[t] | x[:t-1]) for t=1..T-1 over labels=seq[:,1:]
151 | with torch.no_grad():
152 | pol_lp_full = model_logprobs(policy, seq) # (B, T-1)
153 | ref_lp_full = model_logprobs(ref, seq) # (B, T-1)
154 |
155 | # action positions (predict positions [1..T-1]); we want only response tokens:
156 | act_mask = mask[:, 1:] # align to (B, T-1)
157 | old_logp = pol_lp_full[act_mask].detach()
158 | ref_logp = ref_lp_full[act_mask].detach()
159 |
160 | # per-token KL on action tokens
161 | kl_tok = (old_logp - ref_logp) # (N_act,)
162 |
163 | # ----- SHAPED TRAJECTORY REWARD & GROUP BASELINE -----
164 | # For GRPO, advantage is trajectory-level and broadcast to its tokens.
165 | # We include KL shaping at trajectory level using mean token KL per trajectory.
166 | # First, compute mean KL per trajectory on its action tokens.
167 | # Build an index map from flat action tokens back to traj ids.
168 | # We can reconstruct counts by iterating rows.
169 | traj_id_for_token = []
170 | counts = torch.zeros(B, dtype=torch.long, device=device)
171 | offset = 0
172 | for i in range(B):
173 | mrow = act_mask[i]
174 | n_i = int(mrow.sum().item())
175 | if n_i > 0:
176 | traj_id_for_token.extend([i] * n_i)
177 | counts[i] = n_i
178 | offset += n_i
179 | traj_id_for_token = torch.tensor(traj_id_for_token, dtype=torch.long, device=device)
180 | raw_rewards_t = torch.tensor(raw_rewards, dtype=torch.float, device=device)
181 |
182 | # Compute per-prompt group mean of shaped rewards
183 | group_mean = torch.zeros(B, dtype=torch.float, device=device)
184 | for pid in range(P):
185 | idxs = [i for i in range(B) if prompt_id_of[i] == pid]
186 | if not idxs:
187 | continue
188 | idxs_t = torch.tensor(idxs, dtype=torch.long, device=device)
189 | mean_val = raw_rewards_t[idxs_t].mean()
190 | group_mean[idxs_t] = mean_val
191 |
192 | # Advantage per trajectory, broadcast to its action tokens
193 | traj_adv = raw_rewards_t - group_mean # (B,)
194 |
195 | # Build a flat tensor of advantages aligned with old_logp/new_logp on action tokens
196 | if kl_tok.numel() > 0:
197 | adv_flat = traj_adv[traj_id_for_token]
198 | else:
199 | adv_flat = torch.zeros(0, dtype=torch.float, device=device)
200 |
201 | # Normalize advantages (optional but usually helpful)
202 | if adv_flat.numel() > 1:
203 | adv_flat = (adv_flat - adv_flat.mean()) / (adv_flat.std().clamp_min(1e-6))
204 |
205 | # ----- UPDATE (policy-only PPO clipped objective) -----
206 | policy.train()
207 | logits_new, _, _ = policy(seq, None) # ignore value head
208 | logp_full = torch.log_softmax(logits_new[:, :-1, :], dim=-1)
209 | labels = seq[:, 1:]
210 | new_logp_all = logp_full.gather(-1, labels.unsqueeze(-1)).squeeze(-1) # (B, T-1)
211 | new_logp = new_logp_all[act_mask]
212 |
213 | # Mean KL over action tokens
214 | kl_now_ref_mean = (new_logp - ref_logp).mean() if new_logp.numel() > 0 else torch.tensor(0.0, device=device)
215 |
216 | out_loss = ppo_policy_only_losses(
217 | new_logp=new_logp,
218 | old_logp=old_logp,
219 | adv=adv_flat,
220 | clip_ratio=0.2,
221 | ent_coef=0.0, # set >0 if you want entropy bonus from -new_logp mean
222 | kl_coef=args.kl_coef,
223 | kl_mean=kl_now_ref_mean,
224 | )
225 | loss = out_loss.total_loss
226 |
227 | opt.zero_grad(set_to_none=True)
228 | loss.backward()
229 | torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
230 | opt.step()
231 | policy.eval()
232 |
233 | # Some quick diagnostics (movement vs old, and now vs ref)
234 | with torch.no_grad():
235 | lp_post = model_logprobs(policy, seq)[act_mask]
236 | kl_move = (old_logp - lp_post).mean() if lp_post.numel() > 0 else torch.tensor(0.0, device=device)
237 | # KL(now || ref)
238 | kl_ref_now = (lp_post - ref_logp).mean() if lp_post.numel() > 0 else torch.tensor(0.0, device=device)
239 |
240 | step += 1
241 | if step % 10 == 0:
242 | print(
243 |                 f"step {step} | loss {loss.item():.4f} "
244 | f"| KL_move {kl_move.item():.6f} | KL_ref {kl_ref_now.item():.6f}"
245 | )
246 |
247 | Path(args.out).mkdir(parents=True, exist_ok=True)
248 | torch.save({'model': policy.state_dict(), 'config': {
249 | 'vocab_size': vocab_size,
250 | 'block_size': block_size,
251 | 'n_layer': n_layer,
252 | 'n_head': n_head,
253 | 'n_embd': n_embd,
254 | }}, str(Path(args.out)/'model_last.pt'))
255 | print(f"Saved GRPO policy to {args.out}/model_last.pt")
256 |
257 |
258 | if __name__ == '__main__':
259 | main()
260 |
--------------------------------------------------------------------------------
/part_4/checkpointing.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import os
3 | from pathlib import Path
4 | from typing import Any, Dict, Optional, Tuple
5 | import sys
6 | sys.path.append(str(Path(__file__).resolve().parents[1]/'part_3'))
7 | import time
8 | import torch
9 | import shutil
10 | import torch.nn as nn
11 |
12 | DEF_NAME = "model_last.pt"
13 |
14 | # ----------------------------- TB-only helpers (safe no-ops otherwise) ----------------------------- #
15 | def _is_tb(logger) -> bool:
16 | return getattr(logger, "w", None) is not None
17 |
18 |
19 | # checkpointing._log_hparams_tb
20 | def _log_hparams_tb(logger, args, total_steps):
21 | if not _is_tb(logger): return
22 | try:
23 | h = dict(
24 | vocab_size=args.vocab_size, block_size=args.block_size, n_layer=args.n_layer,
25 | n_head=args.n_head, n_embd=args.n_embd, dropout=args.dropout, lr=args.lr,
26 | warmup_steps=args.warmup_steps, batch_size=args.batch_size, grad_accum=args.grad_accum_steps,
27 | mixed_precision=args.mixed_precision, steps=args.steps, epochs=args.epochs,
28 | )
29 | logger.hparams(h, {"meta/total_steps": float(total_steps)})
30 | except Exception:
31 | pass
32 |
33 | def _maybe_log_graph_tb(logger, model, xb, yb):
34 | if not hasattr(logger, "graph"):
35 | return
36 | try:
37 | class _TensorOnly(nn.Module):
38 | def __init__(self, m):
39 | super().__init__(); self.m = m.eval()
40 | def forward(self, x, y=None):
41 | out = self.m(x, y) if y is not None else self.m(x)
42 | if isinstance(out, (list, tuple)):
43 | for o in out:
44 | if torch.is_tensor(o):
45 | return o
46 | return out[0]
47 | return out
48 | wrapped = _TensorOnly(model).to(xb.device)
49 | logger.graph(wrapped, (xb, yb))
50 | except Exception:
51 | pass
52 |
53 | def _log_model_stats(logger, model, step: int, do_hists: bool = False):
54 | if not _is_tb(logger): return
55 | try:
56 | params = [p for p in model.parameters() if p.requires_grad]
57 | total_param_norm = torch.norm(torch.stack([p.detach().norm(2) for p in params]), 2).item()
58 | grads = [p.grad for p in params if p.grad is not None]
59 | total_grad_norm = float('nan')
60 | if grads:
61 | total_grad_norm = torch.norm(torch.stack([g.detach().norm(2) for g in grads]), 2).item()
62 | logger.log(step=step, **{
63 | "train/param_global_l2": total_param_norm,
64 | "train/grad_global_l2": total_grad_norm,
65 | })
66 | if do_hists:
67 | for name, p in model.named_parameters():
68 | logger.hist(f"params/{name}", p, step)
69 | if p.grad is not None:
70 | logger.hist(f"grads/{name}", p.grad, step)
71 | except Exception:
72 | pass
73 |
74 | def _maybe_log_attention(logger, model, xb, step: int, every: int = 100):
75 | """
76 | Logs Q/K/V histograms for each Transformer block using the current minibatch xb.
77 | No model edits. No hooks. Runs a light no-grad recomputation of the pre-attn path.
78 | - Takes first batch and first head only to keep logs tiny.
79 | - Uses pre-RoPE values (simpler & stable for histograms).
80 | """
81 | if not _is_tb(logger) or step == 0 or (step % every):
82 | return
83 | try:
84 | import torch
85 | with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
86 | # Recreate inputs seen by blocks
87 | x = model.tok_emb(xb) # (B,T,C)
88 | x = model.drop(x)
89 |
90 | B, T, _ = x.shape
91 | for li, blk in enumerate(getattr(model, "blocks", [])):
92 | h = blk.ln1(x) # pre-attn normalized hidden
93 |
94 | attn = blk.attn
95 | # Project to Q/K/V exactly like the module (pre-RoPE for simplicity)
96 | q = attn.wq(h).view(B, T, attn.n_head, attn.d_head).transpose(1, 2) # (B,H,T,D)
97 | k = attn.wk(h).view(B, T, attn.n_kv_head, attn.d_head).transpose(1, 2) # (B,Hk,T,D)
98 | v = attn.wv(h).view(B, T, attn.n_kv_head, attn.d_head).transpose(1, 2) # (B,Hk,T,D)
99 |
100 | # Take a tiny slice to keep logs light
101 | q1 = q[:1, :1].contiguous().view(-1).float().cpu()
102 | k1 = k[:1, :1].contiguous().view(-1).float().cpu()
103 | v1 = v[:1, :1].contiguous().view(-1).float().cpu()
104 |
105 | # Drop non-finite (defensive)
106 | q1 = q1[torch.isfinite(q1)]
107 | k1 = k1[torch.isfinite(k1)]
108 | v1 = v1[torch.isfinite(v1)]
109 |
110 | if q1.numel() > 0: logger.hist(f"qkv/block{li}/q_hist", q1, step)
111 | if k1.numel() > 0: logger.hist(f"qkv/block{li}/k_hist", k1, step)
112 | if v1.numel() > 0: logger.hist(f"qkv/block{li}/v_hist", v1, step)
113 |
114 | # Optional small scalars (norms) that show up on Time Series
115 | if q1.numel(): logger.log(step=step, **{f"qkv/block{li}/q_l2_mean": float(q1.square().mean().sqrt())})
116 | if k1.numel(): logger.log(step=step, **{f"qkv/block{li}/k_l2_mean": float(k1.square().mean().sqrt())})
117 | if v1.numel(): logger.log(step=step, **{f"qkv/block{li}/v_l2_mean": float(v1.square().mean().sqrt())})
118 |
119 | # Advance x to next block with a CHEAP approximation to avoid doubling full compute:
120 | # use the model's own FFN path only; skip re-running attention (we're only logging pre-attn stats).
121 | x = x + blk.ffn(blk.ln2(x))
122 |
123 | except Exception as e:
124 | print(f"[qkv] logging failed: {e}")
125 |
126 |
127 | def _log_runtime(logger, step: int, it_t0: float, xb, device):
128 | try:
129 | dt = time.time() - it_t0
130 | toks = int(xb.numel())
131 | toks_per_s = toks / max(dt, 1e-6)
132 | mem = torch.cuda.memory_allocated()/(1024**2) if torch.cuda.is_available() else 0.0
133 | logger.log(step=step, **{
134 | "sys/throughput_tokens_per_s": toks_per_s,
135 | "sys/step_time_s": dt,
136 | "sys/gpu_mem_alloc_mb": mem
137 | })
138 | except Exception:
139 | pass
140 |
141 | def _log_samples_tb(logger, model, tok, xb, device, step: int, max_new_tokens: int = 64):
142 | if not _is_tb(logger): return
143 | if tok is None: return
144 | try:
145 | model.eval()
146 | with torch.no_grad():
147 | out = model.generate(xb[:1].to(device), max_new_tokens=max_new_tokens, temperature=1.0, top_k=50)
148 | model.train()
149 | text = tok.decode(out[0].tolist())
150 | logger.text("samples/generation", text, step)
151 | except Exception:
152 | pass
153 | # ---------------------------------------------------------------------- #
154 |
155 | def _extract_config_from_model(model) -> dict:
156 | """
157 | Best-effort extraction of GPTModern-like config including GQA fields.
158 | """
159 | cfg = {}
160 | try:
161 | tok_emb = getattr(model, "tok_emb", None)
162 | blocks = getattr(model, "blocks", None)
163 | if tok_emb is None or not blocks:
164 | return cfg
165 |
166 | try:
167 | from swiglu import SwiGLU # optional
168 | except Exception:
169 | class SwiGLU: pass
170 |
171 | cfg["vocab_size"] = int(tok_emb.num_embeddings)
172 | cfg["block_size"] = int(getattr(model, "block_size", 0) or 0)
173 | cfg["n_layer"] = int(len(blocks))
174 |
175 | first_blk = blocks[0]
176 | attn = getattr(first_blk, "attn", None)
177 | if attn is None:
178 | return cfg
179 |
180 | # Heads & dims
181 | cfg["n_head"] = int(getattr(attn, "n_head"))
182 | d_head = int(getattr(attn, "d_head"))
183 | cfg["n_embd"] = int(cfg["n_head"] * d_head)
184 | cfg["n_kv_head"]= int(getattr(attn, "n_kv_head", cfg["n_head"])) # default to MHA
185 |
186 | # Dropout (if present)
187 | drop = getattr(attn, "dropout", None)
188 | cfg["dropout"] = float(getattr(drop, "p", 0.0)) if drop is not None else 0.0
189 |
190 | # Norm/FFN style
191 | cfg["use_rmsnorm"] = isinstance(getattr(model, "ln_f", None), nn.Identity)
192 | cfg["use_swiglu"] = isinstance(getattr(first_blk, "ffn", None), SwiGLU)
193 |
194 | # Positional / attention tricks
195 | for k in ("rope", "max_pos", "sliding_window", "attention_sink"):
196 | if hasattr(attn, k):
197 | val = getattr(attn, k)
198 | cfg[k] = int(val) if isinstance(val, bool) else val
199 | except Exception:
200 | return {}
201 |     return cfg
202 |
203 | def _verify_model_matches(model, cfg: Dict[str, Any]) -> Tuple[bool, str]:
204 | """Return (ok, message)."""
205 | expected = {
206 | "block_size": cfg.get("block_size"),
207 | "n_layer": cfg.get("n_layer"),
208 | "n_head": cfg.get("n_head"),
209 | "n_embd": cfg.get("n_embd"),
210 | "vocab_size": cfg.get("vocab_size"),
211 | "n_kv_head": cfg.get("n_kv_head", cfg.get("n_head")),
212 | }
213 | got = {
214 | "block_size": int(getattr(model, "block_size", -1)),
215 | "n_layer": int(len(model.blocks)),
216 | "vocab_size": int(model.tok_emb.num_embeddings),
217 | }
218 | first_blk = model.blocks[0]
219 | got.update({
220 | "n_head": int(first_blk.attn.n_head),
221 | "n_embd": int(first_blk.attn.n_head * first_blk.attn.d_head),
222 | "n_kv_head": int(getattr(first_blk.attn, "n_kv_head", first_blk.attn.n_head)),
223 | })
224 | diffs = [f"{k}: ckpt={expected[k]} vs model={got[k]}" for k in expected if expected[k] != got[k]]
225 | if diffs:
226 | return False, "Architecture mismatch:\n " + "\n ".join(diffs)
227 | return True, "ok"
228 |
229 |
230 | def save_checkpoint(model, optimizer, scheduler, amp, step: int, out_dir: str,
231 | tokenizer_dir: str | None = None, config: dict | None = None):
232 | out = Path(out_dir); out.mkdir(parents=True, exist_ok=True)
233 |
234 | # Prefer the model’s own config if available (e.g., a dict or dataclass with __dict__/asdict)
235 | if hasattr(model, "config"):
236 | cfg_obj = model.config
237 | cfg = dict(cfg_obj) if isinstance(cfg_obj, dict) else getattr(cfg_obj, "__dict__", None) or _extract_config_from_model(model)
238 | else:
239 | cfg = config if config is not None else _extract_config_from_model(model)
240 |
241 | torch.save({
242 | "model": model.state_dict(),
243 | "optimizer": optimizer.state_dict() if optimizer is not None else None,
244 | "scheduler": scheduler.state_dict() if hasattr(scheduler, "state_dict") else None,
245 | "amp_scaler": amp.scaler.state_dict() if amp and getattr(amp, "scaler", None) else None,
246 | "step": int(step),
247 | "config": cfg, # ← always write config
248 | "version": "part4-v2",
249 | }, out / DEF_NAME)
250 |
251 | if tokenizer_dir is not None:
252 | (out / "tokenizer_dir.txt").write_text(tokenizer_dir)
253 |
254 |
255 |
256 | def load_checkpoint(model, path: str, optimizer=None, scheduler=None, amp=None, strict: bool = True):
257 | ckpt = torch.load(path, map_location="cpu")
258 |
259 | cfg = ckpt.get("config")
260 | if cfg:
261 | ok, msg = _verify_model_matches(model, cfg)
262 | if not ok:
263 | raise RuntimeError(msg + "\nRebuild the model with this config, or load with strict=False.")
264 | else:
265 | # Legacy checkpoint without config: strongly encourage a rebuild step elsewhere
266 | print("[compat] Warning: checkpoint has no config; cannot verify architecture.")
267 |
268 | missing, unexpected = model.load_state_dict(ckpt["model"], strict=strict)
269 | if strict and (missing or unexpected):
270 | raise RuntimeError(f"State dict mismatch:\n missing: {missing}\n unexpected: {unexpected}")
271 |
272 | if optimizer is not None and ckpt.get("optimizer") is not None:
273 | optimizer.load_state_dict(ckpt["optimizer"])
274 | if scheduler is not None and ckpt.get("scheduler") is not None and hasattr(scheduler, "load_state_dict"):
275 | scheduler.load_state_dict(ckpt["scheduler"])
276 | if amp is not None and ckpt.get("amp_scaler") is not None and getattr(amp, "scaler", None):
277 | amp.scaler.load_state_dict(ckpt["amp_scaler"])
278 |
279 | return ckpt.get("step", 0)
280 |
281 |
282 | # ----------------------------- checkpoint/save utils ----------------------------- #
283 | def checkpoint_paths(out_dir: Path, step: int):
284 | return out_dir / f"model_step{step:07d}.pt", out_dir / "model_last.pt"
285 |
286 | def atomic_save_all(model, optim, sched, amp, step: int, out_dir: Path,
287 | tok_dir: str | None, keep_last_k: int, config: dict):
288 | """Write model_last.pt (with config) + a rolling per-step copy."""
289 | save_checkpoint(model, optim, sched, amp, step, str(out_dir), tok_dir, config=config) # writes model_last.pt
290 | per_step, last = checkpoint_paths(out_dir, step)
291 | try:
292 | shutil.copy2(last, per_step)
293 | except Exception:
294 | pass
295 | # GC old per-step checkpoints
296 | try:
297 | ckpts = sorted(out_dir.glob("model_step*.pt"))
298 | for old in ckpts[:-keep_last_k]:
299 | old.unlink(missing_ok=True)
300 | except Exception:
301 | pass
--------------------------------------------------------------------------------
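Note on atomic_save_all above: it writes model_last.pt in place and then copies it to the per-step file, so a crash in the middle of torch.save can still leave a truncated model_last.pt. One way to make the write itself atomic (a sketch under that assumption, not the repo's current behaviour) is to serialize to a temporary file in the same directory and os.replace() it into place:

import os, tempfile
from pathlib import Path
import torch

def atomic_torch_save(obj, path):
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            torch.save(obj, f)        # write the whole payload to a temp file first
        os.replace(tmp, path)         # atomic rename onto the final filename
    except BaseException:
        try:
            os.unlink(tmp)
        except OSError:
            pass
        raise

--------------------------------------------------------------------------------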