├── nanochat
├── __init__.py
├── logo.svg
├── configurator.py
├── dataloader.py
├── loss_eval.py
├── adamw.py
├── dataset.py
├── common.py
├── checkpoint_manager.py
├── muon.py
├── execution.py
└── core_eval.py
├── .python-version
├── dev
├── nanochat.png
├── generate_logo.html
├── runcpu.sh
├── repackage_data_reference.py
└── gen_synthetic_data.py
├── .gitignore
├── rustbpe
├── Cargo.toml
├── README.md
└── Cargo.lock
├── LICENSE
├── pyproject.toml
├── tasks
├── smoltalk.py
├── arc.py
├── customjson.py
├── humaneval.py
├── mmlu.py
├── gsm8k.py
└── common.py
├── scripts
├── base_loss.py
├── chat_cli.py
├── tok_train.py
├── base_eval.py
├── tok_eval.py
├── chat_sft.py
├── chat_eval.py
└── mid_train.py
├── run1000.sh
├── speedrun.sh
└── README.md
/nanochat/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.10
2 |
--------------------------------------------------------------------------------
/dev/nanochat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxroute/nanochat/master/dev/nanochat.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .venv/
2 | __pycache__/
3 | *.pyc
4 | rustbpe/target/
5 | dev-ignore/
6 | report.md
7 |
--------------------------------------------------------------------------------
/rustbpe/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "rustbpe"
3 | version = "0.1.0"
4 | edition = "2024"
5 |
6 | [dependencies]
7 | dary_heap = "0.3"
8 | indexmap = "2.2"
9 | fancy-regex = "0.16.1"
10 | log = "0.4.28"
11 | pyo3 = { version = "0.23.3", features = ["extension-module"] }
12 | pyo3-log = "0.12.4"
13 | ahash = "0.8.12"
14 | rayon = "1.11.0"
15 | compact_str = "0.9.0"
16 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Andrej Karpathy
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/rustbpe/README.md:
--------------------------------------------------------------------------------
1 | # rustbpe
2 |
3 | > The missing tiktoken training code
4 |
5 | A very lightweight Rust library for training a GPT tokenizer. The issue is that the inference library [tiktoken](https://github.com/openai/tiktoken) is great, but only does inference. Separately, the huggingface [tokenizers](https://github.com/huggingface/tokenizers) library does training, but it is rather bloated and really hard to navigate because it has to support all the different historical baggage of how people dealt with tokenizers over the years. More recently, I also wrote the [minbpe](https://github.com/karpathy/minbpe) library which does both training and inference, but only in inefficient Python. Basically what I really want is a non-fancy, super simple, but still relatively efficient training code for GPT tokenizer (more efficient than minbpe, much cleaner/simpler than tokenizers), and then export the trained vocab for inference with tiktoken. Does that make sense? So here we are. There are more opportunities for optimization here, I just stopped a bit early because unlike minbpe before it, rustbpe is now simple and fast enough, and not a significant bottleneck for nanochat.
6 |
--------------------------------------------------------------------------------
/dev/generate_logo.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
12 |
28 |
29 |
--------------------------------------------------------------------------------
/nanochat/logo.svg:
--------------------------------------------------------------------------------
1 |
9 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "nanochat"
3 | version = "0.1.0"
4 | description = "the minimal full-stack ChatGPT clone"
5 | readme = "README.md"
6 | requires-python = ">=3.10"
7 | dependencies = [
8 | "datasets>=4.0.0",
9 | "fastapi>=0.117.1",
10 | "files-to-prompt>=0.6",
11 | "numpy==1.26.4",
12 | "psutil>=7.1.0",
13 | "regex>=2025.9.1",
14 | "setuptools>=80.9.0",
15 | "tiktoken>=0.11.0",
16 | "tokenizers>=0.22.0",
17 | "torch>=2.8.0",
18 | "uvicorn>=0.36.0",
19 | "wandb>=0.21.3",
20 | ]
21 |
22 | [build-system]
23 | requires = ["maturin>=1.7,<2.0"]
24 | build-backend = "maturin"
25 |
26 | [tool.maturin]
27 | module-name = "rustbpe"
28 | bindings = "pyo3"
29 | python-source = "."
30 | manifest-path = "rustbpe/Cargo.toml"
31 |
32 | [dependency-groups]
33 | dev = [
34 | "maturin>=1.9.4",
35 | "pytest>=8.0.0",
36 | ]
37 |
38 | [tool.pytest.ini_options]
39 | markers = [
40 | "slow: marks tests as slow (deselect with '-m \"not slow\"')",
41 | ]
42 | testpaths = ["tests"]
43 | python_files = ["test_*.py"]
44 | python_classes = ["Test*"]
45 | python_functions = ["test_*"]
46 |
47 | # use CUDA 12.8 torch wheels on Linux and CPU-only torch wheels on other platforms
48 | [tool.uv.sources]
49 | torch = [
50 | { index = "pytorch-cpu", marker = "sys_platform != 'linux'" },
51 | { index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
52 | ]
53 |
54 | [[tool.uv.index]]
55 | name = "pytorch-cpu"
56 | url = "https://download.pytorch.org/whl/cpu"
57 | explicit = true
58 |
59 | [[tool.uv.index]]
60 | name = "pytorch-cu128"
61 | url = "https://download.pytorch.org/whl/cu128"
62 | explicit = true
--------------------------------------------------------------------------------
/tasks/smoltalk.py:
--------------------------------------------------------------------------------
1 | """
2 | SmolTalk by HuggingFace. Good "general" conversational dataset.
3 | https://huggingface.co/datasets/HuggingFaceTB/smol-smoltalk
4 | We use the "smol" version, which is more appropriate for smaller models.
5 | """
6 |
7 | from datasets import load_dataset
8 | from tasks.common import Task
9 |
class SmolTalk(Task):
    """
    smol-smoltalk conversations (HuggingFaceTB/smol-smoltalk): train is 460K rows, test is 24K rows.

    Each example is returned as {"messages": [...]}: an optional leading system
    message followed by strictly alternating user/assistant turns.
    """

    def __init__(self, split, **kwargs):
        super().__init__(**kwargs)
        assert split in ["train", "test"], "SmolTalk split must be train|test"
        dataset = load_dataset("HuggingFaceTB/smol-smoltalk", split=split)
        # fixed seed so the shuffled order is reproducible across runs
        self.ds = dataset.shuffle(seed=42)
        self.length = len(self.ds)

    def num_examples(self):
        return self.length

    def get_example(self, index):
        messages = self.ds[index]["messages"]
        # Sanity checks, kept on purpose to catch malformed rows early.
        assert len(messages) >= 1
        # A single optional system message may lead the conversation; skip it for the role checks.
        has_system = messages[0]["role"] == "system"
        turns = messages[1:] if has_system else messages
        assert len(turns) >= 2, "SmolTalk messages must have at least 2 messages"
        for turn_idx, turn in enumerate(turns):
            # roles alternate user, assistant, user, assistant, ...
            expected_role = "assistant" if turn_idx % 2 else "user"
            assert turn["role"] == expected_role, f"Message {turn_idx} has role {turn['role']} but should be {expected_role}"
            assert isinstance(turn["content"], str), "Content must be a string"
        # The full message list (including any system message) is emitted as-is.
        return {"messages": messages}
47 |
--------------------------------------------------------------------------------
/nanochat/configurator.py:
--------------------------------------------------------------------------------
1 | """
2 | Poor Man's Configurator. Probably a terrible idea. Example usage:
3 | $ python train.py config/override_file.py --batch_size=32
4 | this will first run config/override_file.py, then override batch_size to 32
5 |
6 | The code in this file will be run as follows from e.g. train.py:
7 | >>> exec(open('configurator.py').read())
8 |
9 | So it's not a Python module, it's just shuttling this code away from train.py
10 | The code in this script then overrides the globals()
11 |
12 | I know people are not going to love this, I just really dislike configuration
13 | complexity and having to prepend config. to every single variable. If someone
14 | comes up with a better simple Python solution I am all ears.
15 | """
16 |
17 | import os
18 | import sys
19 | from ast import literal_eval
20 |
def print0(s="", **kwargs):
    """Print `s` (forwarding kwargs to print) only when the RANK env var is 0 or unset."""
    if int(os.environ.get('RANK', 0)) != 0:
        return
    print(s, **kwargs)
25 |
# Each CLI argument is either the path of a Python config file (executed in this
# namespace) or a --key=value override of an already-defined global.
for arg in sys.argv[1:]:
    if '=' not in arg:
        # assume it's the name of a config file
        assert not arg.startswith('--'), f"Expected --key=value, got {arg}"
        config_file = arg
        print0(f"Overriding config with {config_file}:")
        # read once and close the handle; the same text is echoed and then executed
        with open(config_file) as f:
            config_code = f.read()
        print0(config_code)
        exec(config_code)
    else:
        # assume it's a --key=value argument
        assert arg.startswith('--'), f"Expected --key=value, got {arg}"
        # split on the first '=' only, so values may themselves contain '='
        key, val = arg.split('=', 1)
        key = key[2:]
        if key in globals():
            try:
                # attempt to eval it (e.g. if bool, number, list, etc.)
                attempt = literal_eval(val)
            except (SyntaxError, ValueError):
                # if that goes wrong, just use the string
                attempt = val
            # ensure the types match ok
            if globals()[key] is not None:
                # accept integer literals for float defaults, e.g. --weight_decay=0
                if type(globals()[key]) is float and type(attempt) is int:
                    attempt = float(attempt)
                attempt_type = type(attempt)
                default_type = type(globals()[key])
                assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}"
            # cross fingers
            print0(f"Overriding: {key} = {attempt}")
            globals()[key] = attempt
        else:
            raise ValueError(f"Unknown config key: {key}")
57 |
--------------------------------------------------------------------------------
/tasks/arc.py:
--------------------------------------------------------------------------------
1 | """
2 | The ARC dataset from Allen AI.
3 | https://huggingface.co/datasets/allenai/ai2_arc
4 | """
5 |
6 | from datasets import load_dataset
7 | from tasks.common import Task, render_mc
8 |
class ARC(Task):
    """
    AI2 Reasoning Challenge (allenai/ai2_arc) as a categorical multiple-choice task.

    Each example is one user turn (rendered by render_mc) whose assistant reply
    is the correct answer letter.
    """

    def __init__(self, subset, split, **kwargs):
        super().__init__(**kwargs)
        assert subset in ["ARC-Easy", "ARC-Challenge"], "ARC subset must be ARC-Easy or ARC-Challenge"
        assert split in ["train", "validation", "test"], "ARC split must be train|validation|test"
        raw = load_dataset("allenai/ai2_arc", subset, split=split)
        self.ds = raw.shuffle(seed=42)  # deterministic example order

    @property
    def eval_type(self):
        return 'categorical'

    def num_examples(self):
        return len(self.ds)

    def get_example(self, index):
        row = self.ds[index]
        choice_info = row["choices"]
        letters = choice_info["label"]     # e.g. ["A", "B", "C", "D"]
        choices = choice_info["text"]      # text of each choice, aligned with letters
        answer_string = row["answerKey"]   # e.g. "A"
        assert answer_string in letters, f"ARC answer {answer_string} must be one of {letters}"  # sanity check
        user_message = render_mc(row["question"], letters, choices)
        return {
            "messages": [
                {"role": "user", "content": user_message},
                {"role": "assistant", "content": answer_string},
            ],
            # kept so evaluation can narrow/clamp the assistant prediction to one of the letters
            "letters": letters,
        }

    def evaluate(self, conversation, assistant_response):
        valid_letters = conversation['letters']
        # Not strictly required, but the current eval setup is expected to only
        # produce one of the valid letters; assert it to avoid footguns.
        assert assistant_response in valid_letters, f"ARC answer {assistant_response} is expected to be one of {valid_letters}"
        gold = conversation['messages'][-1]['content']  # e.g. "A"
        return assistant_response == gold
49 |
--------------------------------------------------------------------------------
/nanochat/dataloader.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 |
3 | import torch
4 |
5 | from nanochat.common import get_dist_info
6 | from nanochat.dataset import parquets_iter_batched
7 | from nanochat.tokenizer import get_tokenizer
8 |
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"):
    """
    Stream pretraining text from parquet files, tokenize, and yield training batches forever.

    Args:
        B: number of rows per yielded batch.
        T: number of tokens per row.
        split: "train" or "val".
        tokenizer_threads: num_threads passed to tokenizer.encode.
        tokenizer_batch_size: number of documents tokenized per encode call.
        device: destination device for the yielded tensors (str or torch.device).

    Yields:
        (inputs, targets): inputs as int32 and targets as int64, both viewed as (B, T),
        where targets are the same token stream shifted left by one position.
    """
    assert split in ["train", "val"], "split must be 'train' or 'val'"
    _, ddp_rank, _, ddp_world_size = get_dist_info()
    needed_tokens = B * T + 1  # +1 is because we also need the target at the last token
    # Pinned host memory and non_blocking copies are CUDA-specific optimizations;
    # pinning needs a CUDA runtime, so only enable them when targeting CUDA
    # (CPU/MPS runs would otherwise fail when allocating the pinned buffer).
    use_cuda_optimizations = torch.device(device).type == "cuda"
    # get the tokenizer and the bos token
    tokenizer = get_tokenizer()
    bos_token = tokenizer.get_bos_token_id()
    # scratch buffer holds the tokens for one iteration
    token_buffer = deque()  # we stream tokens on the right and pop from the left

    # infinite iterator over document batches
    def document_batches():
        while True:
            # batch will iterate in group size of the parquet files, usually e.g. 1024 rows;
            # start=rank / step=world_size presumably shards the data across ranks (see parquets_iter_batched)
            for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):
                # for the tokenizer we might want to go in usually smaller batches, e.g. 128 rows
                for i in range(0, len(batch), tokenizer_batch_size):
                    yield batch[i:i + tokenizer_batch_size]
    batches = document_batches()

    while True:
        # Accumulate enough tokens for one iteration before yielding.
        while len(token_buffer) < needed_tokens:
            doc_batch = next(batches)
            token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
            for tokens in token_lists:
                token_buffer.extend(tokens)
        # Move tokens from the deque into the scratch buffer
        tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
        scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=use_cuda_optimizations)
        # Create the inputs/targets as 1D tensors
        inputs_cpu = scratch[:-1].to(dtype=torch.int32)
        targets_cpu = scratch[1:]
        # Reshape to 2D and move to the target device (async only on CUDA)
        inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=use_cuda_optimizations)
        targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=use_cuda_optimizations)
        yield inputs, targets
49 |
--------------------------------------------------------------------------------
/dev/runcpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
# Run as:
# bash dev/runcpu.sh

# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
# Think of this run as educational/fun demo, not something you should expect to work well.
# This is also why I hide this script away in dev/

# all the setup stuff
export OMP_NUM_THREADS=1
# exported so child processes see the same base dir this script populates
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p "$NANOCHAT_BASE_DIR"
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
[ -d ".venv" ] || uv venv
uv sync
source .venv/bin/activate
if [ -z "$WANDB_RUN" ]; then
    WANDB_RUN=dummy
fi
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
    curl -L -o eval_bundle.zip "$EVAL_BUNDLE_URL"
    unzip -q eval_bundle.zip
    rm eval_bundle.zip
    mv eval_bundle "$NANOCHAT_BASE_DIR"
fi

# wipe the report
python -m nanochat.report reset

# train tokenizer on ~1B characters
python -m nanochat.dataset -n 4
python -m scripts.tok_train --max_chars=1000000000
python -m scripts.tok_eval

# train a very small 4 layer model on the CPU
# each optimization step processes a single sequence of 1024 tokens
# we only run 50 steps of optimization (bump this to get better results)
python -m scripts.base_train \
    --depth=4 \
    --max_seq_len=1024 \
    --device_batch_size=1 \
    --total_batch_size=1024 \
    --eval_every=50 \
    --eval_tokens=4096 \
    --core_metric_every=50 \
    --core_metric_max_per_task=12 \
    --sample_every=50 \
    --num_iterations=50
python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096
python -m scripts.base_eval --max-per-task=16

# midtraining
python -m scripts.mid_train \
    --max_seq_len=1024 \
    --device_batch_size=1 \
    --eval_every=50 \
    --eval_tokens=4096 \
    --total_batch_size=1024 \
    --num_iterations=100
# eval results will be terrible, this is just to execute the code paths.
python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20

# SFT
python -m scripts.chat_sft \
    --device_batch_size=1 \
    --target_examples_per_step=4 \
    --num_iterations=100 \
    --eval_steps=4 \
    --eval_metrics_max_problems=16

# Chat CLI
# python -m scripts.chat_cli -p "Why is the sky blue?"

# Chat Web
# python -m scripts.chat_web

python -m nanochat.report generate
85 |
--------------------------------------------------------------------------------
/tasks/customjson.py:
--------------------------------------------------------------------------------
1 | """
2 | CustomJSON task for loading conversations from JSONL files.
3 | Each line in the JSONL file should be a JSON array of messages.
4 | """
5 |
6 | import os
7 | import json
8 | from tasks.common import Task
9 |
class CustomJSON(Task):
    """
    Load conversations from a JSONL file.
    Each line should be a JSON array of message objects with 'role' and 'content' fields,
    with roles alternating user, assistant, user, ... starting with 'user'.
    Example line: [{"role":"user","content":"Hi"},{"role":"assistant","content":"Hello"}]

    Blank lines are skipped. If the file does not exist, a hint is printed and the
    task simply has zero examples (no exception is raised).
    """

    def __init__(self, filepath, **kwargs):
        super().__init__(**kwargs)
        self.filepath = filepath
        self.conversations = []

        # Load all conversations from the JSONL file
        if not os.path.exists(filepath):
            # Helpful error message due to recent change. Will be removed in the future.
            print("-" * 80)
            print(f"Warning: File {filepath} does not exist")
            print("HINT (Oct 21 2025)")
            print("If you recently did a git pull and suddenly see this, it might be due to the new addition of identity conversations")
            print("See this discussion for more details: https://github.com/karpathy/nanochat/discussions/139")
            print("Quick fix: simply run the following command to download the file and you're done:")
            print(f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl")
            print("-" * 80)
        else:
            # JSON is UTF-8 text; don't depend on the platform's default locale encoding
            with open(filepath, 'r', encoding='utf-8') as f:
                for lineno, line in enumerate(f, start=1):
                    line = line.strip()
                    if not line:  # skip empty lines
                        continue
                    messages = json.loads(line)
                    where = f"{filepath}:{lineno}"  # location prefix for validation errors
                    # Validate the conversation structure
                    assert isinstance(messages, list), f"{where}: Expected list of messages, got {type(messages)}"
                    assert len(messages) >= 2, f"{where}: Conversation must have at least 2 messages, got {len(messages)}"
                    # Validate message structure and alternating roles
                    for i, message in enumerate(messages):
                        # without this, `"role" in message` would silently be a substring test on a str
                        assert isinstance(message, dict), f"{where}: Message {i} must be an object, got {type(message)}"
                        assert "role" in message, f"{where}: Message {i} missing 'role' field"
                        assert "content" in message, f"{where}: Message {i} missing 'content' field"
                        expected_role = "user" if i % 2 == 0 else "assistant"
                        assert message["role"] == expected_role, f"{where}: Message {i} has role {message['role']} but should be {expected_role}"
                        assert isinstance(message["content"], str), f"{where}: Message {i} content must be a string"

                    self.conversations.append(messages)

        self.length = len(self.conversations)

    def num_examples(self):
        return self.length

    def get_example(self, index):
        messages = self.conversations[index]
        conversation = {
            "messages": messages,
        }
        return conversation
65 |
66 |
--------------------------------------------------------------------------------
/nanochat/loss_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | A number of functions that help with evaluating a base model.
3 | """
4 | import math
5 | import torch
6 | import torch.distributed as dist
7 |
@torch.no_grad()
def evaluate_bpb(model, batches, steps, token_bytes):
    """
    Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
    which is a tokenization vocab size-independent metric, meaning you are still comparing
    apples:apples if you change the vocab size. The way this works is that instead of just
    calculating the average loss as usual, you calculate the sum loss, and independently
    also the sum bytes (of all the target tokens), and divide. This normalizes the loss by
    the number of bytes that the target tokens represent.

    The added complexity is so that:
    1) All "normal" tokens are normalized by the length of the token in bytes
    2) No special tokens (e.g. <|bos|>) are included in the metric - they are masked out.
    3) No actively masked tokens (using ignore_index of e.g. -1) are included in the metric.

    Args:
        model: called as model(x, y, loss_reduction='none') and expected to return
            per-token losses (summed as nats below) with as many elements as y;
            must also provide get_device().
        batches: iterable yielding (x, y) pairs; the first `steps` pairs are consumed.
        steps: number of batches to evaluate on this rank.
        token_bytes: 1D tensor of shape (vocab_size,) giving the number of bytes for
            each token id, or 0 if the token is not to be counted (e.g. special tokens).

    Returns:
        bpb as a Python float, summed over all ranks when torch.distributed is
        initialized. Returns float('nan') if no counted target bytes were seen
        (e.g. steps == 0), since bpb is undefined in that case.
    """
    # record the losses
    total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
    total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
    batch_iter = iter(batches)
    for _ in range(steps):
        x, y = next(batch_iter)
        loss2d = model(x, y, loss_reduction='none')  # (B, T)
        loss2d = loss2d.view(-1)  # flatten
        y = y.view(-1)  # flatten
        if (y.int() < 0).any():  # mps does not currently have kernel for < 0 for int64, only int32
            # slightly more complex code path if some target tokens are ignore_index (e.g. -1)
            # any target token < 0 is to be ignored: do NOT index token_bytes with negatives
            valid = y >= 0
            y_safe = torch.where(valid, y, torch.zeros_like(y))
            # map valid targets to their byte length; ignored targets contribute 0 bytes
            num_bytes2d = torch.where(
                valid,
                token_bytes[y_safe],
                torch.zeros_like(y, dtype=token_bytes.dtype)
            )
        else:
            # fast path: no ignored targets, safe to index directly
            num_bytes2d = token_bytes[y]
        # only tokens with a positive byte count contribute to the loss sum
        total_nats += (loss2d * (num_bytes2d > 0)).sum()
        total_bytes += num_bytes2d.sum()
    # sum reduce across all ranks
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    if world_size > 1:
        dist.all_reduce(total_nats, op=dist.ReduceOp.SUM)
        dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM)
    # move both to cpu, calculate bpb and return
    total_nats = total_nats.item()
    total_bytes = total_bytes.item()
    if total_bytes == 0:
        # nothing was counted: report NaN instead of raising ZeroDivisionError
        return float('nan')
    bpb = total_nats / (math.log(2) * total_bytes)
    return bpb
64 |
--------------------------------------------------------------------------------
/scripts/base_loss.py:
--------------------------------------------------------------------------------
1 | """
2 | Loads a checkpoint, and:
3 | - Evaluates the loss on a larger chunk of train/val splits
4 | - Samples from the model
5 |
6 | Example run as:
7 | torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
8 | """
9 | import os
10 | from contextlib import nullcontext
11 | import torch
12 | from nanochat.checkpoint_manager import load_model
13 | from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
14 | from nanochat.dataloader import tokenizing_distributed_data_loader
15 | from nanochat.tokenizer import get_token_bytes
16 | from nanochat.loss_eval import evaluate_bpb
17 | from nanochat.engine import Engine
18 |
# Configuration
device_batch_size = 32
split_tokens = 20*524288 # number of tokens to evaluate per split
model_tag = None # optional model tag, passed to load_model() to select which base checkpoint to load
model_step = None # optional checkpoint step, passed to load_model() (None => load_model's default)
device_type = "" # cuda|cpu|mps (empty => autodetect)
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file

# Load the base model and the tokenizer
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
# bf16 autocast only on CUDA; other devices run without autocast
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()

# Evaluate the loss on each split
# one step consumes device_batch_size * sequence_len tokens on each of the ddp_world_size ranks
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
assert split_tokens % tokens_per_step == 0, f"split_tokens ({split_tokens}) must be divisible by tokens_per_step ({tokens_per_step})"
steps = split_tokens // tokens_per_step
token_bytes = get_token_bytes(device=device)
bpb_results = {}
for split_name in ["train", "val"]:
    loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
    with autocast_ctx:
        bpb = evaluate_bpb(model, loader, steps, token_bytes)
    print0(f"{split_name} bpb: {bpb:.4f}")
    bpb_results[split_name] = bpb

# Master process also samples from the model
samples = []
if ddp_rank == 0:
    prompts = [
        "The capital of France is",
        "The chemical symbol of gold is",
        "If yesterday was Friday, then tomorrow will be",
        "The opposite of hot is",
        "The planets of the solar system are:",
        "My favorite color is",
        "If 5*x + 3 = 13, then x is",
    ]
    engine = Engine(model, tokenizer)
    for prompt in prompts:
        tokens = tokenizer(prompt, prepend="<|bos|>")
        with autocast_ctx:
            # greedy decoding (temperature=0) of up to 16 new tokens
            sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
        sample_str = tokenizer.decode(sample[0])
        print0(sample_str)
        samples.append(sample_str)

# Log to report
from nanochat.report import get_report
get_report().log(section="Base model loss", data=[
    {
        "train bpb": bpb_results["train"],
        "val bpb": bpb_results["val"],
    },
    {f"sample {i}": sample for i, sample in enumerate(samples)},
])

# Cleanup
compute_cleanup()
80 |
--------------------------------------------------------------------------------
/dev/repackage_data_reference.py:
--------------------------------------------------------------------------------
"""
Repackage the FinewebEdu-100B dataset into shards:

- each shard is ~100MB in size (after zstd compression)
- parquets are written with row group size of 1024
- shuffle the dataset

This will be uploaded to HuggingFace for hosting.
The big deal is that our DataLoader will be able to stream
the data and cache it along the way on disk, decreasing the
training latency.

NOTE: This file is meant only as reference/documentation of the
dataset preparation and it is not used during the project runtime.
"""
import os
import time

from datasets import load_dataset
import pyarrow.parquet as pq
import pyarrow as pa

# Source dataset
dataset_kwargs = {
    "path": "HuggingFaceFW/fineweb-edu",
    "split": "train",
    "name": "sample-100BT", # ~100B GPT-2 tokens at ~3 chars/token => ~300B chars total
}
ds = load_dataset(**dataset_kwargs)

# Shuffle to scramble the order
ds = ds.shuffle(seed=42)
ndocs = len(ds) # total number of documents to process
print(f"Total number of documents: {ndocs}")

# Repackage into parquet files
output_dir = "/home/ubuntu/.cache/nanochat/base_data"
os.makedirs(output_dir, exist_ok=True)

# Write to parquet files
chars_per_shard = 250_000_000
row_group_size = 1024 # HF uses 1000 but we use a power of 2, nicer for distributed data loader later
shard_docs = []
shard_index = 0
shard_characters = 0
total_docs_processed = 0
total_time_spent = 0
t0 = time.time()
for doc in ds:
    text = doc['text']
    shard_docs.append(text)
    shard_characters += len(text)
    # a shard is flushed only once it has enough characters AND its document
    # count is an exact multiple of row_group_size (so every row group is full)
    collected_enough_chars = shard_characters >= chars_per_shard
    docs_multiple_of_row_group_size = len(shard_docs) % row_group_size == 0
    if collected_enough_chars and docs_multiple_of_row_group_size: # leads to ~100MB of text (compressed)
        shard_path = os.path.join(output_dir, f"shard_{shard_index:05d}.parquet")
        shard_table = pa.Table.from_pydict({"text": shard_docs})
        pq.write_table(
            shard_table,
            shard_path,
            row_group_size=row_group_size,
            use_dictionary=False, # this is usually used for categorical data
            compression="zstd", # Valid values: {'NONE', 'SNAPPY', 'GZIP', 'BROTLI', 'LZ4', 'ZSTD'}
            compression_level=3,
            write_statistics=False, # not needed for text
        )
        t1 = time.time()
        dt = t1 - t0 # for this shard alone
        t0 = t1
        total_docs_processed += len(shard_docs)
        total_time_spent += dt
        # linear ETA from the average per-document time so far
        remaining_docs = ndocs - total_docs_processed
        avg_time_per_doc = total_time_spent / total_docs_processed
        remaining_time = remaining_docs * avg_time_per_doc
        remaining_time_hours = remaining_time / 3600
        print(f"Wrote {shard_path}. #documents: {len(shard_docs)} | #characters: {shard_characters} | time: {dt:.2f}s | remaining time: {remaining_time_hours:.2f}h")
        shard_docs = []
        shard_characters = 0
        shard_index += 1

# NOTE: documents still in shard_docs when the loop ends (a partial shard) are never
# written, so the tail of the shuffled dataset is dropped. Left unchanged on purpose,
# since this file documents how the data preparation was actually run.
80 |
# Demonstration of how the data was later uploaded to HuggingFace
def upload():
    """Upload every shard in output_dir to the HF dataset repo, authenticating via $HF_TOKEN."""
    import os
    from huggingface_hub import HfApi
    hf_token = os.getenv("HF_TOKEN")
    upload_kwargs = dict(
        folder_path=output_dir,
        repo_id="karpathy/fineweb-edu-100b-shuffle",
        repo_type="dataset",
    )
    HfApi(token=hf_token).upload_large_folder(**upload_kwargs)
# upload()
93 |
--------------------------------------------------------------------------------
/nanochat/adamw.py:
--------------------------------------------------------------------------------
1 | """
2 | Borrowed from modded-nanogpt. By Keller, @vagrawal, et al.
3 | Not a general optimizer! But works for our specific use.
4 | """
5 | import torch
6 | import torch.distributed as dist
7 | from torch import Tensor
8 |
9 |
class DistAdamW(torch.optim.Optimizer):
    """
    Distributed AdamW optimizer.
    In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction

    Each rank owns a contiguous slice of every parameter along dim 0. It keeps the
    Adam moments only for that slice, updates only that slice, and then all-gathers
    the full parameter back onto every rank.
    """
    def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
        # Standard torch.optim defaults; per-group entries in param_groups override them.
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(param_groups, defaults)

    @torch.compile
    @torch.no_grad()
    def step(self):
        """
        Perform one optimization step across all ranks.

        Phase 1 launches an async reduce-scatter (AVG) for every gradient, so each
        rank receives its dim-0 slice of the rank-averaged gradient.
        Phase 2 handles each parameter slice in the same launch order:
        - wait for its reduce-scatter,
        - apply decoupled weight decay and the bias-corrected Adam update to the local slice,
        - launch an async all-gather that rebuilds the full parameter on every rank.
        Finally the method blocks until all all-gathers have completed.

        NOTE(review): assumes every parameter has a non-None .grad and that
        p.shape[0] is divisible by world_size (rank_size uses floor division).
        Confirm this before adding new parameters to this optimizer.
        """
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        reduce_scatter_futures: list[torch.Future] = []
        # NOTE: despite the name, this collects all_gather futures (see the end of phase 2)
        all_reduce_futures: list[torch.Future] = []
        grad_slices = []
        # Phase 1: kick off all gradient reduce-scatters before doing any math, so communication overlaps compute
        for group in self.param_groups:
            params: list[Tensor] = group["params"]
            for base_i in range(len(params)):
                grad = params[base_i].grad
                rank_size = grad.shape[0] // world_size
                # Output buffer holding this rank's rank_size rows of the averaged gradient
                grad_slice = torch.empty_like(grad[:rank_size])
                reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
                grad_slices.append(grad_slice)

        # Phase 2: idx walks params in the same (group, param) order as phase 1, so it indexes both lists
        idx = 0
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            eps = group['eps']
            wd = group['weight_decay']
            params = group['params']
            for base in range(len(params)):
                reduce_scatter_futures[idx].wait()
                p = params[base]
                rank_size = p.shape[0] // world_size
                # Basic slicing yields a view, so in-place ops on p_slice modify this rank's rows of p
                p_slice = p[rank * rank_size:(rank + 1) * rank_size]
                # Optional per-parameter learning-rate multiplier attribute (defaults to 1.0)
                lr = group['lr'] * getattr(p, "lr_mul", 1.0)
                state = self.state[p]
                g_slice = grad_slices[idx]
                # State init: moments are sized to the local slice only (the sharded-state part)
                if not state:
                    state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
                    state['exp_avg'] = torch.zeros_like(p_slice)
                    state['exp_avg_sq'] = torch.zeros_like(p_slice)
                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']
                state['step'] += 1
                t = state['step']
                # weight decay (decoupled, AdamW-style), scaled by lr and an optional per-parameter wd_mul
                if wd != 0:
                    eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
                    p_slice.mul_(1 - eff_weight_decay)
                # update running averages of the gradient and squared gradient
                exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
                # bias corrections (t is a tensor, so these are tensors too)
                bias1 = 1 - beta1 ** t
                bias2 = 1 - beta2 ** t
                # compute step: lr * sqrt(bias2) / bias1 * m / (sqrt(v) + eps)
                denom = exp_avg_sq.sqrt().add_(eps)
                step_size = lr * (torch.sqrt(bias2) / bias1)
                update = exp_avg.div(denom).mul_(step_size)
                p_slice.add_(other=update, alpha=-1.0)
                idx += 1
                # Every rank contributes its updated slice; the gather reassembles the full p everywhere
                all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
        # Block until all parameters have been fully reassembled on every rank
        torch.futures.collect_all(all_reduce_futures).wait()
77 |
--------------------------------------------------------------------------------
/tasks/humaneval.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluate the Chat model on HumanEval dataset.
3 | Btw this dataset is a misnomer and has nothing to do with humans.
4 | It is a coding benchmark.
5 | """
6 |
7 | import re
8 | from datasets import load_dataset
9 | from nanochat.execution import execute_code
10 | from tasks.common import Task
11 |
def extract_imports(prompt):
    """
    Collect the leading `import ...` / `from ...` lines of a code snippet.

    Blank lines and `#` comments are skipped. Scanning stops at the first other
    non-empty line. The collected lines are returned stripped and joined with
    newlines; the result is "" when there are none.
    """
    collected = []
    for raw_line in prompt.split('\n'):
        line = raw_line.strip()
        if line.startswith(('import ', 'from ')):
            collected.append(line)
            continue
        if line and not line.startswith('#'):
            # First real statement that is not an import: the import header is over
            break
    return '\n'.join(collected)
23 |
def extract_program(completion):
    """
    Extract Python code from LLM completion.

    Handles various output formats:
    - Code wrapped in ```python / ```python3 / ```py fences (any letter case) or bare ``` fences
    - An opening fence that is never closed (e.g. generation was cut off at max tokens)
    - Plain code without markdown blocks
    - Extra text before/after code blocks

    Returns the first complete fenced block if one exists. Otherwise it returns the
    contents after an unterminated opening fence, or failing that the whole
    completion. The result is always stripped of surrounding whitespace.
    """
    # Opening fence: ``` plus an optional Python language tag, then the end of that line
    fence_open = r'```(?:python3?|py)?\s*\n'
    flags = re.DOTALL | re.IGNORECASE

    # Complete block: lazily match up to the first closing fence on its own line
    block = re.search(fence_open + r'(.*?)\n```', completion, flags)
    if block:
        return block.group(1).strip()

    # Unterminated block: keep everything after the opening fence rather than
    # returning markdown syntax that can never execute
    unterminated = re.search(fence_open + r'(.*)', completion, flags)
    if unterminated:
        code = unterminated.group(1).strip()
        # A closing fence glued to the last code line (no preceding newline) is not
        # matched by the complete-block pattern above; drop it here
        if code.endswith('```'):
            code = code[:-3].rstrip()
        return code

    # No code blocks found, return the whole completion
    return completion.strip()
46 |
class HumanEval(Task):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        dataset = load_dataset("openai/openai_humaneval", split="test")
        self.ds = dataset.shuffle(seed=42)

    @property
    def eval_type(self):
        return 'generative'

    def num_examples(self):
        return len(self.ds)

    def get_example(self, index):
        """
        Build the conversation for problem `index`.

        The user turn is the HumanEval prompt (the start of the program). The
        assistant turn is the prompt followed by the canonical solution. The
        entry point name and the test code are carried along for evaluate().
        """
        row = self.ds[index]
        prompt, solution, entry_point, test = (
            row['prompt'],
            row['canonical_solution'],
            row['entry_point'],
            row['test'],
        )
        return {
            "messages": [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": f"{prompt}\n{solution}"},
            ],
            "entry_point": entry_point,  # needed during evaluation
            "test": test,  # needed during evaluation
        }

    def evaluate(self, conversation, completion):
        """
        Return True when the completion passes the problem's tests.

        The program is assembled as follows:
        - the prompt's import header, since completions often omit the imports,
        - the code extracted from the completion,
        - the test code,
        - a final check(<entry_point>) call.
        """
        header = extract_imports(conversation['messages'][0]['content'])
        body = extract_program(completion)
        program = "\n\n".join([header, body, conversation['test']]) + "\n" + f"check({conversation['entry_point']})"
        return execute_code(program).success
98 |
--------------------------------------------------------------------------------
/tasks/mmlu.py:
--------------------------------------------------------------------------------
1 | """
2 | The MMLU dataset.
3 | https://huggingface.co/datasets/cais/mmlu
4 | """
5 |
6 | from datasets import load_dataset
7 | from tasks.common import Task, render_mc
8 |
class MMLU(Task):

    letters = ('A', 'B', 'C', 'D')
    groups = (
        'abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge',
        'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics',
        'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics',
        'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic',
        'global_facts', 'high_school_biology', 'high_school_chemistry',
        'high_school_computer_science', 'high_school_european_history', 'high_school_geography',
        'high_school_government_and_politics', 'high_school_macroeconomics',
        'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics',
        'high_school_psychology', 'high_school_statistics', 'high_school_us_history',
        'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law',
        'jurisprudence', 'logical_fallacies', 'machine_learning', 'management', 'marketing',
        'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios', 'nutrition',
        'philosophy', 'prehistory', 'professional_accounting', 'professional_law',
        'professional_medicine', 'professional_psychology', 'public_relations',
        'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions',
    )

    def __init__(self, subset, split, **kwargs):
        super().__init__(**kwargs)
        is_aux = subset == "auxiliary_train"
        assert subset in ["all", "auxiliary_train"], f"subset {subset} must be all|auxiliary_train"
        assert split in ["train", "validation", "dev", "test"], f"split {split} must be train|validation|dev|test"
        assert not is_aux or split == "train", "auxiliary_train must be split into train"
        self.subset = subset
        self.split = split
        self.ds = load_dataset("cais/mmlu", subset, split=split).shuffle(seed=42)
        if is_aux:
            # auxiliary_train rows arrive nested under an extra 'train' key; unwrap them
            self.ds = self.ds.map(lambda row: row['train'], remove_columns=['train'])

    @property
    def eval_type(self):
        return 'categorical'

    def num_examples(self):
        return len(self.ds)

    def get_example(self, index):
        """
        Build the conversation for question `index`.

        The user turn is the question rendered as a multiple-choice prompt over
        A-D, and the assistant turn is the correct letter. The subject and the
        letter set are kept alongside for grouping and answer clamping.
        """
        row = self.ds[index]
        question, choices, answer, subject = (row[key] for key in ("question", "choices", "answer", "subject"))
        assert len(choices) == 4, "MMLU should have 4 choices"
        return {
            "messages": [
                {"role": "user", "content": render_mc(question, self.letters, choices)},
                {"role": "assistant", "content": self.letters[answer]},
            ],
            "subject": subject,  # e.g. "college_biology"; allows per-subject metrics later
            "letters": self.letters,  # lets evaluation clamp the prediction to a valid letter
        }

    def evaluate(self, conversation, assistant_response):
        # Not strictly necessary, but the current eval always clamps predictions to a letter,
        # so anything else indicates a bug upstream.
        assert assistant_response in self.letters, f"MMLU answer {assistant_response} is expected to be one of {self.letters}"
        return assistant_response == conversation['messages'][-1]['content']
61 |
--------------------------------------------------------------------------------
/scripts/chat_cli.py:
--------------------------------------------------------------------------------
1 | """
2 | New and upgraded chat mode because a lot of the code has changed since the last one.
3 |
4 | Intended to be run single GPU only atm:
5 | python -m scripts.chat_cli -i mid
6 | """
7 | import argparse
8 | import torch
9 | from nanochat.common import compute_init, autodetect_device_type
10 | from contextlib import nullcontext
11 | from nanochat.engine import Engine
12 | from nanochat.checkpoint_manager import load_model
13 |
parser = argparse.ArgumentParser(description='Chat with the model')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
args = parser.parse_args()

# Init the model and tokenizer

device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
# Autocast is only enabled on CUDA; other device types run without it
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)

# Special tokens for the chat state machine
bos = tokenizer.get_bos_token_id()
user_start, user_end = tokenizer.encode_special("<|user_start|>"), tokenizer.encode_special("<|user_end|>")
assistant_start, assistant_end = tokenizer.encode_special("<|assistant_start|>"), tokenizer.encode_special("<|assistant_end|>")

# Create Engine for efficient generation
engine = Engine(model, tokenizer)

print("\nNanoChat Interactive Mode")
print("-" * 50)
print("Type 'quit' or 'exit' to end the conversation")
print("Type 'clear' to start a new conversation")
print("-" * 50)

conversation_tokens = [bos]

while True:

    if args.prompt:
        # Get the prompt from the launch command
        user_input = args.prompt
    else:
        # Get the prompt interactively from the console
        try:
            user_input = input("\nUser: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nGoodbye!")
            break

    # Handle special commands
    if user_input.lower() in ['quit', 'exit']:
        print("Goodbye!")
        break

    if user_input.lower() == 'clear':
        conversation_tokens = [bos]
        print("Conversation cleared.")
        if args.prompt:
            # In prompt mode the input never changes, so `continue` would loop forever
            break
        continue

    if not user_input:
        continue

    # Add User message to the conversation
    conversation_tokens.append(user_start)
    conversation_tokens.extend(tokenizer.encode(user_input))
    conversation_tokens.append(user_end)

    # Kick off the assistant
    conversation_tokens.append(assistant_start)
    generate_kwargs = {
        "num_samples": 1,
        "max_tokens": 256,
        "temperature": args.temperature,
        "top_k": args.top_k,
    }
    response_tokens = []
    print("\nAssistant: ", end="", flush=True)
    with autocast_ctx:
        for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
            token = token_column[0] # pop the batch dimension (num_samples=1)
            response_tokens.append(token)
            token_text = tokenizer.decode([token])
            print(token_text, end="", flush=True)
    print()
    # The assistant turn must be closed with assistant_end, even when generation stopped
    # at max_tokens or yielded no tokens at all (indexing [-1] on an empty list would raise)
    if not response_tokens or response_tokens[-1] != assistant_end:
        response_tokens.append(assistant_end)
    conversation_tokens.extend(response_tokens)

    # In the prompt mode, we only want a single response and exit
    if args.prompt:
        break
106 |
--------------------------------------------------------------------------------
/scripts/tok_train.py:
--------------------------------------------------------------------------------
1 | """
Train a BPE tokenizer using RustBPETokenizer from nanochat.tokenizer.
In the style of the GPT-4 tokenizer.
4 | """
5 | import os
6 | import time
7 | import argparse
8 | import torch
9 | from nanochat.tokenizer import RustBPETokenizer
10 | from nanochat.common import get_base_dir
11 | from nanochat.dataset import parquets_iter_batched
12 |
13 | # -----------------------------------------------------------------------------
14 | # Parse command line arguments
15 |
parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)')
args = parser.parse_args()
# Echo the effective settings with thousands separators, e.g. "max_chars: 10,000,000,000"
for arg_name in ("max_chars", "doc_cap", "vocab_size"):
    print(f"{arg_name}: {getattr(args, arg_name):,}")
24 |
25 | # -----------------------------------------------------------------------------
26 | # Text iterator
27 |
def text_iterator():
    """
    Stream tokenizer training text from the train split.

    1) Flatten the row-group batches into individual documents
    2) Truncate each document to at most args.doc_cap characters
    3) Stop after the document that pushes the running total past args.max_chars
    """
    seen_chars = 0
    for batch in parquets_iter_batched(split="train"):
        for doc in batch:
            # Slicing is a no-op for documents already within the cap
            cropped = doc[:args.doc_cap]
            seen_chars += len(cropped)
            yield cropped
            if seen_chars > args.max_chars:
                return
text_iter = text_iterator()
45 |
# -----------------------------------------------------------------------------
# Train the tokenizer, timing the full pass over the text iterator
t0 = time.time()
tokenizer = RustBPETokenizer.train_from_iterator(text_iter, args.vocab_size)
t1 = time.time()
train_time = t1 - t0
print(f"Training time: {train_time:.2f}s")

# -----------------------------------------------------------------------------
# Persist the trained tokenizer under <base_dir>/tokenizer
base_dir = get_base_dir()
tokenizer_dir = os.path.join(base_dir, "tokenizer")
tokenizer.save(tokenizer_dir)

# -----------------------------------------------------------------------------
# Round-trip sanity check: encoding then decoding must reproduce the input exactly,
# including digits, contractions, punctuation and non-ASCII text
test_text = "\n".join([
    "Hello world! This is a test.",
    "Numbers: 123, 4567, 89",
    "Contractions: I'm, you're, it's",
    "Special chars: @#$%^&*()",
    "Unicode: 你好世界 🌍",
])
encoded = tokenizer.encode(test_text)
assert tokenizer.decode(encoded) == test_text
70 |
# -----------------------------------------------------------------------------
# Cache a token id -> byte length table for evaluating bits per byte. Unlike the
# mean token loss, bits per byte does not depend on the tokenizer's vocab size,
# and on the validation set it is one of the primary metrics we track.
vocab_size = tokenizer.get_vocab_size()
special_set = set(tokenizer.get_special_tokens())
token_strings = [tokenizer.decode([token_id]) for token_id in range(vocab_size)]
# Special tokens count as 0 bytes; every other token counts the UTF-8 length of its decoded text.
# NOTE(review): if decode() replaces a partial UTF-8 byte sequence with U+FFFD, such a token
# is counted as 3 bytes rather than its raw byte length. Confirm against the tokenizer.
token_bytes = torch.tensor(
    [0 if token_str in special_set else len(token_str.encode("utf-8")) for token_str in token_strings],
    dtype=torch.int32,
    device='cpu',
)
token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
with open(token_bytes_path, "wb") as f:
    torch.save(token_bytes, f)
print(f"Saved token_bytes to {token_bytes_path}")

# Log to report
from nanochat.report import get_report
token_bytes_nonzero = token_bytes[token_bytes > 0].to(dtype=torch.float32)
byte_length_stats = {
    "token_bytes_min": int(token_bytes_nonzero.min().item()),
    "token_bytes_max": int(token_bytes_nonzero.max().item()),
    "token_bytes_mean": token_bytes_nonzero.mean().item(),
    "token_bytes_std": token_bytes_nonzero.std().item(),
}
get_report().log(section="Tokenizer training", data=[
    vars(args), # argparse command line arguments
    {"train_time": train_time},
    {"num_special_tokens": len(special_set)},
    byte_length_stats,
])
107 |
--------------------------------------------------------------------------------
/run1000.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# The $1000 tier of nanochat
# Designed to run end-to-end in ~41.6 hours on an 8XH100 node ($1000 at ~$24/hour)
# A bit sparser on comments, see speedrun.sh for more detail

# all the setup stuff
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
# quote expansions so a $HOME containing spaces does not break the paths
mkdir -p "$NANOCHAT_BASE_DIR"
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
[ -d ".venv" ] || uv venv
uv sync
source .venv/bin/activate
# use a dummy wandb run name unless the caller provided a non-empty one
WANDB_RUN="${WANDB_RUN:-dummy}"
python -m nanochat.report reset
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
    curl -L -o eval_bundle.zip "$EVAL_BUNDLE_URL"
    unzip -q eval_bundle.zip
    rm eval_bundle.zip
    mv eval_bundle "$NANOCHAT_BASE_DIR"
fi
curl -L -o "$NANOCHAT_BASE_DIR/identity_conversations.jsonl" https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
30 |
# train tokenizer on ~4B characters and kick off download of the rest for pretraining
python -m nanochat.dataset -n 16
# start downloading the rest of the shards for a total of 800 (see below why 800)
python -m nanochat.dataset -n 800 &
DATASET_DOWNLOAD_PID=$!
# todo: download the rest of it
python -m scripts.tok_train --max_chars=4000000000
python -m scripts.tok_eval
# Pretraining reads whatever parquet files are on disk (and treats the last one as val),
# so the background download must finish before base training starts.
echo "Waiting for dataset download to complete..."
wait "$DATASET_DOWNLOAD_PID"
38 |
39 | # Documenting my process for determining the hyperparameters for this run1000.sh script:
40 | # We want a budget of approx. $1000 ~= 41.6 hours of 8XH100 compute
41 | # 1) I guessed the model size for this to be about depth=32
42 | # 2) Determine the device_batch_size that fits:
43 | # Running the base_train.py script with --depth=32, I saw that --device_batch_size=16
44 | # runs out of memory, but --device_batch_size=8 fits. Inspecting `nvidia-smi` during training,
45 | # I saw all GPUs were at about 78/80GB VRAM, so it just barely fits and we have good MFU at ~50%.
46 | # So the training script was running ok and showed:
47 | # Vocab size: 65,536
48 | # num_layers: 32
49 | # model_dim: 2048
50 | # num_heads: 16
51 | # num_kv_heads: 16
52 | # Tokens / micro-batch / rank: 8 x 2048 = 16,384
53 | # Tokens / micro-batch: 131,072
54 | # Total batch size 524,288 => gradient accumulation steps: 4
55 | # Number of parameters: 1,879,048,192
56 | # Estimated FLOPs per token: 1.207960e+10
57 | # Calculated number of iterations from target data:param ratio: 71,680
58 | # Total number of training tokens: 37,580,963,840
59 | # Tokens : Params ratio: 20.00
60 | # Total training FLOPs estimate: 4.539628e+20
61 | # step 00004/71680 (0.01%) | loss: 8.813754 | lrm: 1.00 | dt: 1571.88ms | tok/sec: 83,385 | mfu: 50.92 | total time: 0.00m
62 | # step 00005/71680 (0.01%) | loss: 8.488074 | lrm: 1.00 | dt: 1572.76ms | tok/sec: 83,338 | mfu: 50.89 | total time: 0.00m
63 | # ...
64 | # 3) validate that the runtime fits our budget:
65 | # The training script uses the Chinchilla scaling law to compute-optimally set #tokens = 20 * #params. In particular:
66 | # The script shows that we will be training for 71,680 steps, and each step takes 1.574s so:
67 | # estimated time to train: 71,680 * 1.574s / 60 / 60 = 31.3 hours.
68 | # This is OK, fits our budget, and leaves ~10 hours for midtraining and SFT and evals and maybe RL.
69 | # It's possible that we might even fit depth=33 or depth=34, but for now let's go along with this.
70 | # 4) The last thing to pay attention to is the amount of training data required for the run.
71 | # The script above calculated that "Total number of training tokens: 37,580,963,840"
72 | # The tok_eval.py script reports about ~4.8 chars/token on average for the default tokenizer settings.
# So ~38B tokens * ~4.8 chars/token = ~185B chars.
74 | # Each data shard is ~250M chars, so we need ~185B / 250M ~= 740 shards.
75 | # For safety, I bumped that up to 800 shards, and that's why up above I used -n 800 when pre-downloading dataset shards.
76 | # If we didn't have enough data, the training script would loop around and do multiple epochs over the same data,
77 | # which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd
78 | # start to overfit hard.
79 | # 5) That's it, everything else (e.g. the learning rates) is adjusted automatically by the training script.
# NOTE(review): unlike mid_train/chat_sft below, base_train is not passed --run="$WANDB_RUN",
# so the longest stage is not logged under that run name. Confirm whether this is intended.
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=32 --device_batch_size=8
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval

# midtrain
# NOTE: ensure that we use the same device_batch_size here as the base training script.
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=8 --run="$WANDB_RUN"
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid

# sft
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run="$WANDB_RUN"
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft

# generate final report
python -m nanochat.report generate

# talk to it
python -m scripts.chat_web
98 |
--------------------------------------------------------------------------------
/tasks/gsm8k.py:
--------------------------------------------------------------------------------
1 | """
2 | GSM8K evaluation.
3 | https://huggingface.co/datasets/openai/gsm8k
4 |
5 | Example problem instance:
6 |
7 | Question:
8 | Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?
9 | Answer:
10 | Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute.
11 | Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10.
12 | #### 10
13 |
14 | Notice that GSM8K uses tool calls inside << >> tags.
15 | """
16 |
17 | import re
18 | from datasets import load_dataset
19 | from tasks.common import Task
20 |
21 |
GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
def extract_answer(completion):
    """
    Return the number following the first "#### " marker in `completion`, or None.

    Surrounding whitespace and thousands separators (commas) are removed, matching
    the normalization in OpenAI's reference implementation:
    https://github.com/openai/grade-school-math/blob/3101c7d5072418e28b9008a6636bde82a006892c/grade_school_math/dataset.py#L28
    """
    found = GSM_RE.search(completion)
    if found is None:
        return None
    return found.group(1).strip().replace(",", "")
35 |
36 |
class GSM8K(Task):

    def __init__(self, subset, split, **kwargs):
        super().__init__(**kwargs)
        assert subset in ["main", "socratic"], "GSM8K subset must be main|socratic"
        assert split in ["train", "test"], "GSM8K split must be train|test"
        dataset = load_dataset("openai/gsm8k", subset, split=split)
        self.ds = dataset.shuffle(seed=42)

    @property
    def eval_type(self):
        return 'generative'

    def num_examples(self):
        return len(self.ds)

    def get_example(self, index):
        """
        Build the conversation for problem `index`.

        The reference answer embeds calculator calls as <<expr=result>>. Each call
        becomes two parts:
        - a "python" part holding the expression (text left of the last '='),
        - a "python_output" part holding the result ("" when there is no '=').
        Text between calls becomes "text" parts. The user turn is a plain string;
        the assistant turn is the list of parts.
        """
        row = self.ds[index]
        question = row['question']
        answer = row['answer']
        parts = []
        for piece in re.split(r'(<<[^>]+>>)', answer):
            if not (piece.startswith('<<') and piece.endswith('>>')):
                # Plain text between tool calls
                parts.append({"type": "text", "text": piece})
                continue
            body = piece[2:-2]  # drop the << >> delimiters
            head, sep, tail = body.rpartition('=')
            expr, result = (head, tail) if sep else (body, "")
            parts.append({"type": "python", "text": expr})
            parts.append({"type": "python_output", "text": result})
        return {
            "messages": [
                {"role": "user", "content": question},
                {"role": "assistant", "content": parts},
            ],
        }

    def evaluate(self, conversation, assistant_response):
        """
        Return 1 if the "####" answer in `assistant_response` equals the reference, else 0.

        The conversation holds the ground-truth assistant message as a list of
        parts, and `assistant_response` is the sampled alternative.

        TODO: assistant_response could in principle be a full Message (string or list
        of parts); for now only a plain string is supported.
        """
        assert isinstance(assistant_response, str), "Assuming simple string response for now"
        reference = conversation['messages'][-1]
        assert reference['role'] == "assistant", "Last message must be from the Assistant"
        assert isinstance(reference['content'], list), "This is expected to be a list of parts"
        # In GSM8K the final text part carries the "#### <answer>" line
        ref_num = extract_answer(reference['content'][-1]['text'])
        pred_num = extract_answer(assistant_response)
        return int(pred_num == ref_num)

    def reward(self, conversation, assistant_response):
        """
        RL reward: the evaluation outcome as a float (0.0 or 1.0).
        Later this could be made richer (e.g. format matching).
        """
        return float(self.evaluate(conversation, assistant_response))
118 |
--------------------------------------------------------------------------------
/nanochat/dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | The base/pretraining dataset is a set of parquet files.
3 | This file contains utilities for:
4 | - iterating over the parquet files and yielding documents from it
5 | - download the files on demand if they are not on disk
6 |
7 | For details of how the dataset was prepared, see `repackage_data_reference.py`.
8 | """
9 |
10 | import os
11 | import argparse
12 | import time
13 | import requests
14 | import pyarrow.parquet as pq
15 | from multiprocessing import Pool
16 |
17 | from nanochat.common import get_base_dir
18 |
# -----------------------------------------------------------------------------
# The specifics of the current pretraining dataset

# Remote location that the shards are fetched from on demand
BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
# Highest shard index available: the final shard is shard_01822.parquet
MAX_SHARD = 1822


def index_to_filename(index):
    """Map a shard index to its filename, e.g. 7 -> 'shard_00007.parquet'."""
    return f"shard_{index:05d}.parquet"


base_dir = get_base_dir()
# Local directory holding the downloaded shards; created at import time
DATA_DIR = os.path.join(base_dir, "base_data")
os.makedirs(DATA_DIR, exist_ok=True)
29 |
30 | # -----------------------------------------------------------------------------
31 | # These functions are useful utilities to other modules, can/should be imported
32 |
def list_parquet_files(data_dir=None):
    """
    Return the sorted full paths of all completed parquet shards in a directory.

    - data_dir: directory to scan; defaults to DATA_DIR.

    In-progress downloads are written as `<name>.parquet.tmp` and only renamed
    once complete, so matching on the `.parquet` suffix alone already excludes
    them (a name ending in `.parquet` can never also end in `.tmp`).
    """
    data_dir = DATA_DIR if data_dir is None else data_dir
    parquet_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.parquet'))
    return [os.path.join(data_dir, f) for f in parquet_files]
42 |
def parquets_iter_batched(split, start=0, step=1):
    """
    Yield the dataset one parquet row group at a time, as a list of text strings.

    - split: "train" reads every shard except the last; "val" reads only the last.
    - start/step: stride over the row groups of each shard, e.g. start=rank and
      step=world_size so that each DDP rank reads a disjoint subset.
    """
    assert split in ["train", "val"], "split must be 'train' or 'val'"
    all_paths = list_parquet_files()
    if split == "train":
        selected_paths = all_paths[:-1]
    else:
        selected_paths = all_paths[-1:]
    for path in selected_paths:
        parquet_file = pq.ParquetFile(path)
        for row_group_id in range(start, parquet_file.num_row_groups, step):
            table = parquet_file.read_row_group(row_group_id)
            yield table.column('text').to_pylist()
58 |
59 | # -----------------------------------------------------------------------------
def download_single_file(index):
    """
    Download shard `index` into DATA_DIR, retrying with exponential backoff.

    The shard is streamed into `<filepath>.tmp` and only moved to its final name
    once fully written, so a partial download is never mistaken for a complete
    shard. Returns True on success (or if the shard already exists) and False
    once all attempts have failed.
    """
    # Construct the local filepath for this file and skip if it already exists
    filename = index_to_filename(index)
    filepath = os.path.join(DATA_DIR, filename)
    if os.path.exists(filepath):
        print(f"Skipping {filepath} (already exists)")
        return True

    # Construct the remote URL for this file
    url = f"{BASE_URL}/{filename}"
    print(f"Downloading {filename}...")
    temp_path = filepath + ".tmp"

    # Download with retries
    max_attempts = 5
    for attempt in range(1, max_attempts + 1):
        try:
            # The `with` block releases the streamed connection even if the
            # transfer fails part-way through
            with requests.get(url, stream=True, timeout=30) as response:
                response.raise_for_status()
                # Write to temporary file first
                with open(temp_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=1024 * 1024):  # 1MB chunks
                        if chunk:
                            f.write(chunk)
            # Move temp file to final location (os.replace overwrites atomically
            # on every platform, unlike os.rename on Windows)
            os.replace(temp_path, filepath)
            print(f"Successfully downloaded {filename}")
            return True

        except (requests.RequestException, OSError) as e:
            print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
            # Clean up any partial files; a missing file is the normal case
            for path in [temp_path, filepath]:
                try:
                    os.remove(path)
                except FileNotFoundError:
                    pass
                except OSError as cleanup_err:
                    print(f"Warning: could not remove {path}: {cleanup_err}")
            # Try a few times with exponential backoff: 2^attempt seconds
            if attempt < max_attempts:
                wait_time = 2 ** attempt
                print(f"Waiting {wait_time} seconds before retry...")
                time.sleep(wait_time)
            else:
                print(f"Failed to download {filename} after {max_attempts} attempts")
                return False

    return False
110 |
111 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
    parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1 = all shards)")
    parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
    args = parser.parse_args()
    # Reject values that would otherwise silently download nothing (negative
    # counts other than the -1 sentinel) or make Pool raise (< 1 worker)
    if args.num_files < -1:
        parser.error(f"--num-files must be -1 (all shards) or >= 0, got {args.num_files}")
    if args.num_workers < 1:
        parser.error(f"--num-workers must be >= 1, got {args.num_workers}")

    # -1 downloads every shard; otherwise clamp to the number of shards that exist
    num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
    ids_to_download = list(range(num))
    print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
    print(f"Target directory: {DATA_DIR}")
    print()
    with Pool(processes=args.num_workers) as pool:
        results = pool.map(download_single_file, ids_to_download)

    # Report results
    successful = sum(1 for success in results if success)
    print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}")
129 |
--------------------------------------------------------------------------------
/tasks/common.py:
--------------------------------------------------------------------------------
1 | """
2 | Base class for all Tasks.
3 | A Task is basically a dataset of conversations, together with some
4 | metadata and often also evaluation criteria.
5 | Example tasks: MMLU, ARC-Easy, ARC-Challenge, GSM8K, HumanEval, SmolTalk.
6 | """
7 |
8 | import random
9 |
class Task:
    """
    Base class of a Task. Allows for lightweight slicing of the underlying dataset.

    Subclasses implement `num_examples()` and `get_example(physical_index)`.
    This class maps logical indices 0..len(self)-1 onto the physical indices
    start, start+step, start+2*step, ... (stopping before `stop`).
    """

    def __init__(self, start=0, stop=None, step=1):
        # allows a lightweight logical view over a dataset
        assert start >= 0, f"Start must be non-negative, got {start}"
        assert stop is None or stop >= start, f"Stop should be greater than or equal to start, got {stop} and {start}"
        assert step >= 1, f"Step must be strictly positive, got {step}"
        self.start = start
        self.stop = stop # could be None here
        self.step = step

    @property
    def eval_type(self):
        # one of 'generative' | 'categorical'
        raise NotImplementedError

    def num_examples(self):
        raise NotImplementedError

    def get_example(self, index):
        raise NotImplementedError

    def __len__(self):
        start = self.start
        stop = self.num_examples() if self.stop is None else self.stop
        step = self.step
        span = stop - start
        num = (span + step - 1) // step # ceil_div(span, step)
        assert num >= 0, f"Negative number of examples???: {num}" # prevent footguns
        return num

    def __getitem__(self, index: int):
        """
        Return the conversation at logical position `index` of this view.

        Negative indices count from the end, like a list. Out-of-range indices
        raise IndexError instead of silently reading physical rows outside the
        [start, stop) slice. This also makes `for ex in task` / `list(task)`
        (which rely on IndexError) stop at the end of the view.
        """
        assert isinstance(index, int), f"Index must be an integer, got {type(index)}"
        length = len(self)
        logical_index = index + length if index < 0 else index
        if not 0 <= logical_index < length:
            raise IndexError(f"Index {index} out of range for task of length {length}")
        physical_index = self.start + logical_index * self.step
        conversation = self.get_example(physical_index)
        return conversation

    def evaluate(self, problem, completion):
        raise NotImplementedError
52 |
53 |
class TaskMixture(Task):
    """
    For SFT Training it becomes useful to train on a task mixture of datasets.
    Fun trick: if you wish to oversample any task, just pass it in multiple times in the list.
    """

    def __init__(self, tasks, **kwargs):
        super().__init__(**kwargs)
        # tasks is a list of Task objects
        self.tasks = tasks
        self.lengths = [len(task) for task in self.tasks]
        self.num_conversations = sum(self.lengths)
        # Build list of all (task_idx, local_idx) pairs, task by task, in order
        self.index_map = [
            (task_idx, local_idx)
            for task_idx, task_length in enumerate(self.lengths)
            for local_idx in range(task_length)
        ]
        # Deterministically shuffle to mix tasks throughout training
        # (fixed seed => the same order in every process and every run)
        rng = random.Random(42)
        rng.shuffle(self.index_map)
        # Note: this is not the most elegant or best solution, but it's ok for now

    def num_examples(self):
        return self.num_conversations

    def get_example(self, index):
        """
        Access conversations according to a deterministic shuffle of all examples.
        This ensures tasks are mixed throughout training, regardless of dataset size.
        Raises IndexError for an out-of-range index. A plain list lookup would
        wrap negative indices around instead of failing.
        """
        if not 0 <= index < self.num_conversations:
            raise IndexError(f"Index {index} out of range for mixture with {self.num_conversations} conversations")
        task_idx, local_idx = self.index_map[index]
        return self.tasks[task_idx][local_idx]
87 |
88 |
class TaskSequence(Task):
    """
    For SFT Training sometimes we want to sequentially train on a list of tasks.
    This is useful for cases that require a training curriculum.
    Examples are served task by task, in the order the tasks are given.
    """

    def __init__(self, tasks, **kwargs):
        super().__init__(**kwargs)
        self.tasks = tasks
        self.lengths = [len(task) for task in self.tasks]
        self.num_conversations = sum(self.lengths)

    def num_examples(self):
        return self.num_conversations

    def get_example(self, index):
        """
        Return the index-th conversation of the concatenated tasks.

        Raises IndexError when out of range. An `assert` would be stripped under
        `python -O`. After that, an index past the end would fall off the loop and
        return None, and a negative index would read from the first task.
        """
        if not 0 <= index < self.num_conversations:
            raise IndexError(f"Index {index} out of range for sequence with {self.num_conversations} conversations")
        # Subtract each task's length until the index lands inside one of them
        for task, task_length in zip(self.tasks, self.lengths):
            if index < task_length:
                return task[index]
            index -= task_length
        raise AssertionError("unreachable: index was bounds-checked above")
110 |
111 |
def render_mc(question, letters, choices):
    """
    Render a multiple-choice prompt in the format shared by all tasks.

    Two deliberate design decisions:
    1) Each letter comes *after* its choice ("- choice=A"). Smaller models bind
       the letter to the choice better this way; bigger models care less.
    2) There is no whitespace between "=" and the letter. The tokenizer assigns
       different ids to " A" and "A". The assistant replies with the bare letter
       "A", so the prompt must contain exactly that token too.
    """
    option_lines = [f"- {choice}={letter}\n" for letter, choice in zip(letters, choices)]
    pieces = [f"Multiple Choice question: {question}\n", *option_lines]
    pieces.append("\nRespond only with the letter of the correct answer.")
    return "".join(pieces)
132 |
133 |
if __name__ == "__main__":
    # Smoke test of slicing: element 0 of the [5:10] view should equal
    # element 5 of the full dataset.
    from tasks.mmlu import MMLU

    full_view = MMLU(subset="auxiliary_train", split="train")
    print("Length of MMLU: ", len(full_view))
    reference = full_view[5]
    print("5th example: ", reference)

    sliced_view = MMLU(subset="auxiliary_train", split="train", start=5, stop=10)
    print("Length of sliced MMLU[5:10]: ", len(sliced_view))
    print("0th example of sliced MMLU: ", sliced_view[0])

    print("They match: ", reference == sliced_view[0])
148 |
--------------------------------------------------------------------------------
/nanochat/common.py:
--------------------------------------------------------------------------------
1 | """
2 | Common utilities for nanochat.
3 | """
4 |
5 | import os
6 | import re
7 | import logging
8 | import torch
9 | import torch.distributed as dist
10 |
class ColoredFormatter(logging.Formatter):
    """
    Custom formatter that adds colors to log messages.

    Colors the level name and, for INFO records, bolds sizes/percentages/doc
    counts and highlights "Shard N". The record's `levelname` is colored only
    for the duration of this call and then restored. Other handlers formatting
    the same record (e.g. a plain file handler) therefore never see ANSI escape
    codes, and formatting the same record twice gives the same result.
    """
    # ANSI color codes
    COLORS = {
        'DEBUG': '\033[36m',    # Cyan
        'INFO': '\033[32m',     # Green
        'WARNING': '\033[33m',  # Yellow
        'ERROR': '\033[31m',    # Red
        'CRITICAL': '\033[35m', # Magenta
    }
    RESET = '\033[0m'
    BOLD = '\033[1m'

    def format(self, record):
        # Add color to the level name (temporarily; restored below)
        levelname = record.levelname
        if levelname in self.COLORS:
            record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
        try:
            # Format the message
            message = super().format(record)
        finally:
            # Undo the in-place mutation so the shared record is left untouched
            record.levelname = levelname
        # Add color to specific parts of the message
        if levelname == 'INFO':
            # Highlight numbers and percentages
            message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
            message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
        return message
36 |
def setup_default_logging():
    """
    Configure root logging at INFO level with one StreamHandler that uses
    ColoredFormatter. Like any logging.basicConfig call, this does nothing if
    the root logger already has handlers.
    """
    colored_handler = logging.StreamHandler()
    colored_handler.setFormatter(
        ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    logging.basicConfig(handlers=[colored_handler], level=logging.INFO)

# Configure logging as soon as this module is imported
setup_default_logging()
logger = logging.getLogger(__name__)
47 |
def get_base_dir():
    """
    Return the directory for nanochat intermediates, creating it if needed.

    A non-empty $NANOCHAT_BASE_DIR wins. Otherwise the default is
    ~/.cache/nanochat, co-located with other cached data.
    """
    override = os.environ.get("NANOCHAT_BASE_DIR")
    if override:
        nanochat_dir = override
    else:
        nanochat_dir = os.path.join(os.path.expanduser("~"), ".cache", "nanochat")
    os.makedirs(nanochat_dir, exist_ok=True)
    return nanochat_dir
58 |
def print0(s="", **kwargs):
    """print() only on global rank 0, as given by $RANK (missing counts as 0)."""
    if int(os.environ.get('RANK', 0)) != 0:
        return
    print(s, **kwargs)
63 |
def print_banner():
    """Print the nanochat ASCII-art banner via print0 (so only rank 0 prints it)."""
    # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
    banner = """
█████ █████
░░███ ░░███
████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████
░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███ ░░░███░
░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███
░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███
████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░████████ ░░█████
░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░
"""
    print0(banner)
77 |
def is_ddp():
    """True when $RANK is set to something other than -1 (as torchrun does)."""
    # TODO is there a proper way
    rank = int(os.environ.get('RANK', -1))
    return rank != -1
81 |
def get_dist_info():
    """
    Return (ddp, rank, local_rank, world_size) for this process.

    Under DDP the values come from $RANK, $LOCAL_RANK and $WORLD_SIZE (all
    three must be set). Otherwise the single-process defaults
    (False, 0, 0, 1) are returned.
    """
    if not is_ddp():
        return False, 0, 0, 1
    required = ['RANK', 'LOCAL_RANK', 'WORLD_SIZE']
    assert all(name in os.environ for name in required)
    rank, local_rank, world_size = (int(os.environ[name]) for name in required)
    return True, rank, local_rank, world_size
91 |
def autodetect_device_type():
    """Pick the preferred available device type: "cuda", then "mps", else "cpu"."""
    preferences = (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    )
    # next() stops at the first available backend, preserving the priority order
    device_type = next((name for name, available in preferences if available()), "cpu")
    print0(f"Autodetected device type: {device_type}")
    return device_type
102 |
def compute_init(device_type="cuda"): # cuda|cpu|mps
    """Basic initialization that we keep doing over and over, so make common.

    Checks that `device_type` is usable, seeds torch with 42, enables TF32
    matmuls on CUDA, and, when RANK is set and the device is CUDA, binds this
    process to cuda:<LOCAL_RANK> and initializes an NCCL process group.

    Returns:
        (ddp, ddp_rank, ddp_local_rank, ddp_world_size, device)

    NOTE(review): if RANK is set but device_type is "cpu"/"mps", `ddp` is
    returned as True but no process group is initialized, and every rank gets
    the same plain device. Later collective ops would fail. Confirm intent.
    """

    assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
    if device_type == "cuda":
        assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
    if device_type == "mps":
        assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"

    # Reproducibility
    torch.manual_seed(42)
    if device_type == "cuda":
        torch.cuda.manual_seed(42)
    # skipping full reproducibility for now, possibly investigate slowdown later
    # torch.use_deterministic_algorithms(True)

    # Precision
    if device_type == "cuda":
        torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls

    # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    if ddp and device_type == "cuda":
        device = torch.device("cuda", ddp_local_rank)
        torch.cuda.set_device(device) # make "cuda" default to this device
        # set_device must happen before init so NCCL binds to this rank's GPU
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    else:
        device = torch.device(device_type) # mps|cpu, or single-process cuda

    if ddp_rank == 0:
        logger.info(f"Distributed world size: {ddp_world_size}")

    return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
137 |
def compute_cleanup():
    """
    Companion function to compute_init, to clean things up before script exit.

    Destroys the default process group only if one was actually initialized.
    compute_init creates a group only for DDP on CUDA. A DDP launch on cpu/mps
    sets RANK but has no group, and destroying a non-existent group raises.
    Checking is_initialized (rather than only the RANK env var) avoids that.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
142 |
class DummyWandb:
    """No-op stand-in for a wandb run: same log()/finish() signatures, does nothing."""

    def __init__(self):
        pass

    def log(self, *args, **kwargs):
        """Accept and discard any logging payload."""
        return None

    def finish(self):
        """Nothing to flush or close."""
        return None
151 |
--------------------------------------------------------------------------------
/nanochat/checkpoint_manager.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for saving and loading model/optim/state checkpoints.
3 | """
4 | import os
5 | import re
6 | import glob
7 | import json
8 | import logging
9 | import torch
10 |
11 | from nanochat.common import get_base_dir
12 | from nanochat.gpt import GPT, GPTConfig
13 | from nanochat.tokenizer import get_tokenizer
14 | from nanochat.common import setup_default_logging
15 |
# Set up logging
setup_default_logging()
logger = logging.getLogger(__name__)


def log0(message):
    """Log `message` at INFO level, but only on global rank 0 ($RANK, default 0)."""
    rank = int(os.environ.get('RANK', 0))
    if rank == 0:
        logger.info(message)
22 |
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data):
    """
    Save model / optimizer / metadata for `step` into `checkpoint_dir`.

    Writes model_{step:06d}.pt, optim_{step:06d}.pt (only when optimizer_data is
    not None) and meta_{step:06d}.json. Each file is first written to a `.tmp`
    sibling and then moved into place with os.replace, so an interrupted save
    never leaves a truncated file under the final name. find_last_step would
    otherwise treat a half-written model_*.pt as the newest checkpoint.
    Must be called on rank 0 only.
    """
    assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now
    os.makedirs(checkpoint_dir, exist_ok=True)

    def _atomic_torch_save(obj, path):
        # Write to a temp file in the same directory, then atomically rename
        tmp_path = path + ".tmp"
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)

    # Save the model state (parameters)
    model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
    _atomic_torch_save(model_data, model_path)
    log0(f"Saved model file to: {model_path}")
    # Save the optimizer state (useful for SFT or any other fine-tuning)
    if optimizer_data is not None:
        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
        _atomic_torch_save(optimizer_data, optimizer_path)
        log0(f"Saved optimizer file to: {optimizer_path}")
    # Save the metadata dict as json
    meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
    tmp_meta_path = meta_path + ".tmp"
    with open(tmp_meta_path, "w", encoding="utf-8") as f:
        json.dump(meta_data, f, indent=2)
    os.replace(tmp_meta_path, meta_path)
    log0(f"Saved metadata file to: {meta_path}")
40 |
41 |
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
    """
    Load the model_/optim_/meta_ files for `step` from `checkpoint_dir`.

    Returns (model_data, optimizer_data, meta_data). Tensors are mapped onto
    `device`, and optimizer_data is None unless load_optimizer is True.
    """
    tag = f"{step:06d}"
    model_data = torch.load(os.path.join(checkpoint_dir, f"model_{tag}.pt"), map_location=device)
    optimizer_data = (
        torch.load(os.path.join(checkpoint_dir, f"optim_{tag}.pt"), map_location=device)
        if load_optimizer
        else None
    )
    with open(os.path.join(checkpoint_dir, f"meta_{tag}.json"), "r") as f:
        meta_data = json.load(f)
    return model_data, optimizer_data, meta_data
56 |
57 |
def build_model(checkpoint_dir, step, device, phase):
    """
    A bunch of repetitive code to build a model from a given checkpoint.
    - phase: "train" puts the model in train mode, "eval" in eval mode.
    Returns:
    - base model - uncompiled, not wrapped in DDP
    - tokenizer
    - meta data saved during base model training
    """
    assert phase in ["train", "eval"], f"Invalid phase: {phase}"
    model_data, _, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
    # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
    # removeprefix (not lstrip): lstrip("_orig_mod.") strips any leading run of
    # the *characters* {_, o, r, i, g, m, d, .}, which would also eat the start
    # of an uncompiled key beginning with e.g. "m", "r" or "d".
    model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
    model_config_kwargs = meta_data["model_config"]
    log0(f"Building model with config: {model_config_kwargs}")
    model_config = GPTConfig(**model_config_kwargs)
    with torch.device("meta"):
        model = GPT(model_config)
    # Load the model state
    model.to_empty(device=device)
    model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
    model.load_state_dict(model_data, strict=True, assign=True)
    # Put the model in the right training phase / mode
    if phase == "eval":
        model.eval()
    else:
        model.train()
    # Load the Tokenizer
    tokenizer = get_tokenizer()
    # Sanity check: compatibility between model and tokenizer
    assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"]
    return model, tokenizer, meta_data
89 |
90 |
def find_largest_model(checkpoint_dir):
    """
    Guess which model tag (a subdirectory of checkpoint_dir) to load.

    Tags that start with "d<depth>" win, and the largest depth is chosen. If no
    tag matches that pattern, the most recently modified subdirectory is used.
    Raises FileNotFoundError if checkpoint_dir has no subdirectories.
    """
    model_tags = [
        name for name in os.listdir(checkpoint_dir)
        if os.path.isdir(os.path.join(checkpoint_dir, name))
    ]
    if not model_tags:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
    # 1) normally all model tags are of the form d<number>, e.g. d20
    depth_tagged = []
    for tag in model_tags:
        found = re.match(r"d(\d+)", tag)
        if found:
            depth_tagged.append((int(found.group(1)), tag))
    if depth_tagged:
        return max(depth_tagged, key=lambda pair: pair[0])[1]
    # 2) otherwise fall back to the most recently updated tag
    return max(model_tags, key=lambda tag: os.path.getmtime(os.path.join(checkpoint_dir, tag)))
109 |
110 |
def find_last_step(checkpoint_dir):
    """
    Return the highest step among the model_<step>.pt files in checkpoint_dir.

    Steps are compared as integers. save_checkpoint zero-pads them to 6 digits,
    so past step 999999 names grow to 7 digits, and a string comparison would
    rank "999999" above "1000000". Files whose step part is not an integer are
    ignored. Raises FileNotFoundError if no model checkpoint is found.
    """
    # escape the directory so characters like "[" in the path are taken literally
    checkpoint_files = glob.glob(os.path.join(glob.escape(checkpoint_dir), "model_*.pt"))
    steps = []
    for path in checkpoint_files:
        step_str = os.path.basename(path).split("_")[-1].split(".")[0]
        try:
            steps.append(int(step_str))
        except ValueError:
            continue  # not a model_<int>.pt checkpoint
    if not steps:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
    return max(steps)
118 |
119 | # -----------------------------------------------------------------------------
120 | # convenience functions that take into account nanochat's directory structure
121 |
def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
    """
    Build (model, tokenizer, meta_data) from checkpoints_dir/<model_tag> at `step`.

    A missing model_tag defaults to find_largest_model(checkpoints_dir). A
    missing step defaults to the latest step saved under that tag.
    """
    if model_tag is None:
        # guess the model tag by defaulting to the largest model
        model_tag = find_largest_model(checkpoints_dir)
        log0(f"No model tag provided, guessing model tag: {model_tag}")
    checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
    if step is None:
        # guess the step by defaulting to the last step
        step = find_last_step(checkpoint_dir)
    assert step is not None, f"No checkpoints found in {checkpoint_dir}"
    log0(f"Loading model from {checkpoint_dir} with step {step}")
    return build_model(checkpoint_dir, step, device, phase)
136 |
# Checkpoint subdirectory (under the nanochat base dir) for each training stage
_SOURCE_TO_CHECKPOINT_DIR = {
    "base": "base_checkpoints",
    "mid": "mid_checkpoints",
    "sft": "chatsft_checkpoints",
    "rl": "chatrl_checkpoints",
}


def load_model(source, *args, **kwargs):
    """
    Load the model of a training stage: source is "base", "mid", "sft" or "rl".

    Remaining arguments are forwarded to load_model_from_dir
    (device, phase, model_tag=None, step=None). An unknown source raises KeyError.
    """
    model_dir = _SOURCE_TO_CHECKPOINT_DIR[source]
    checkpoints_dir = os.path.join(get_base_dir(), model_dir)
    return load_model_from_dir(checkpoints_dir, *args, **kwargs)
147 |
--------------------------------------------------------------------------------
/speedrun.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# This script is the "Best ChatGPT clone that $100 can buy",
# It is designed to run in ~4 hours on 8XH100 node at $3/GPU/hour.

# 1) Example launch (simplest):
# bash speedrun.sh
# 2) Example launch in a screen session (because the run takes ~4 hours):
# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
# 3) Example launch with wandb logging, but see below for setting up wandb first:
# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh

# Default intermediate artifacts directory is in ~/.cache/nanochat
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p "$NANOCHAT_BASE_DIR"

# -----------------------------------------------------------------------------
# Python venv setup with uv

# install uv (if not already installed)
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
# create a .venv local virtual environment (if it doesn't exist)
[ -d ".venv" ] || uv venv
# install the repo dependencies
uv sync
# activate venv so that `python` uses the project's venv instead of system python
source .venv/bin/activate

# -----------------------------------------------------------------------------
# wandb setup
# If you wish to use wandb for logging (it's nice!, recommended).
# 1) Make sure to first log in to wandb, e.g. run:
#    `wandb login`
# 2) Set the WANDB_RUN environment variable when running this script, e.g.:
#    `WANDB_RUN=d26 bash speedrun.sh`
if [ -z "$WANDB_RUN" ]; then
    # by default use "dummy" : it's handled as a special case, skips logging to wandb
    WANDB_RUN=dummy
fi

# -----------------------------------------------------------------------------
# During the course of the run, we will be writing markdown reports to the report/
# directory in the base dir. This command clears it out and writes a header section
# with a bunch of system info and a timestamp that marks the start of the run.
python -m nanochat.report reset

# -----------------------------------------------------------------------------
# Tokenizer

# Install Rust / Cargo (skip the installer if cargo is already available)
command -v cargo &> /dev/null || curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
# put cargo on PATH for this shell (the env file exists for rustup-managed installs)
[ -f "$HOME/.cargo/env" ] && source "$HOME/.cargo/env"

# Build the rustbpe Tokenizer
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml

# Download the first ~2B characters of pretraining dataset
# look at dev/repackage_data_reference.py for details on how this data was prepared
# each data shard is ~250M chars
# so we download 2e9 / 250e6 = 8 data shards at this point
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
python -m nanochat.dataset -n 8
# Immediately also kick off downloading more shards in the background while tokenizer trains
# See comment below for why 240 is the right number here
python -m nanochat.dataset -n 240 &
DATASET_DOWNLOAD_PID=$!
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
python -m scripts.tok_train --max_chars=2000000000
# evaluate the tokenizer (report compression ratio etc.)
python -m scripts.tok_eval

# -----------------------------------------------------------------------------
# Base model (pretraining)

# Download the eval_bundle from s3 to evaluate CORE metric during training (~162MB)
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
    curl -L -o eval_bundle.zip "$EVAL_BUNDLE_URL"
    unzip -q eval_bundle.zip
    rm eval_bundle.zip
    mv eval_bundle "$NANOCHAT_BASE_DIR"
fi

# The d20 model is 561M parameters.
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
# Round up to 240 for safety. At ~100MB/shard, this downloads ~24GB of data to disk.
# (The total number of shards available in the entire dataset is 1822.)
echo "Waiting for dataset download to complete..."
wait "$DATASET_DOWNLOAD_PID"

# pretrain the d20 model
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run="$WANDB_RUN"
# evaluate the model on a larger chunk of train/val data and draw some samples
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
# evaluate the model on CORE tasks
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval

# -----------------------------------------------------------------------------
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)

# download 2.3MB of synthetic identity conversations to impart a personality to nanochat
# see dev/gen_synthetic_data.py for details on how this data was prepared and to get a sense of how you can easily tune it
curl -L -o "$NANOCHAT_BASE_DIR/identity_conversations.jsonl" https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl

# run midtraining and eval the model
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run="$WANDB_RUN"
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid

# -----------------------------------------------------------------------------
# Supervised Finetuning (domain adaptation to each sequence all by itself per row)

# train sft and re-eval right away (should see a small bump)
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run="$WANDB_RUN"
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft

# chat with the model over CLI! Leave out the -p to chat interactively
# python -m scripts.chat_cli -p "Why is the sky blue?"

# even better, chat with your model over a pretty WebUI ChatGPT style
# python -m scripts.chat_web

# -----------------------------------------------------------------------------
# Reinforcement Learning. Optional, and currently only on GSM8K
# (optional)

# run reinforcement learning
# torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run="$WANDB_RUN"
# eval the RL model only on GSM8K
# torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i rl -a GSM8K

# -----------------------------------------------------------------------------
# Generate the full report by putting together all the sections
# report.md is the output and will be copied to current directory for convenience
python -m nanochat.report generate
138 |
--------------------------------------------------------------------------------
/scripts/base_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluate the CORE metric for a given model.
3 |
4 | Run on a single GPU (from the repository root):
5 | python -m scripts.base_eval
6 |
7 | Run with torchrun on e.g. 8 GPUs:
8 | torchrun --nproc_per_node=8 -m scripts.base_eval
9 |
10 | The script will print the CORE metric to the console.
11 | """
12 | import os
13 | import sys
14 | import time
15 | import json
16 | import random
17 | import yaml
18 | from contextlib import nullcontext
19 |
20 | import pandas as pd
21 | import torch
22 |
23 | from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type
24 | from nanochat.tokenizer import HuggingFaceTokenizer
25 | from nanochat.checkpoint_manager import load_model
26 | from nanochat.core_eval import evaluate_task
27 |
28 | # -----------------------------------------------------------------------------
29 | # nanoChat specific function dealing with I/O etc.
30 |
def evaluate_model(model, tokenizer, device, max_per_task=-1):
    """
    Evaluate a base model on the CORE benchmark.
    - max_per_task: crop the data to this many examples per task for testing (-1 = disable)
    Returns a dict with per-task accuracies ("results"), baseline-centered
    accuracies ("centered_results") and their mean ("core_metric").
    TODO: clean up this function, delete the need for all the files, for pandas dependency, etc.
    """
    # Load config and task metadata
    base_dir = get_base_dir()
    eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
    config_path = os.path.join(eval_bundle_dir, "core.yaml")
    data_base_path = os.path.join(eval_bundle_dir, "eval_data")
    eval_meta_data_path = os.path.join(eval_bundle_dir, "eval_meta_data.csv")
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    tasks = config['icl_tasks']
    eval_metadata = pd.read_csv(eval_meta_data_path)

    # Evaluate each task
    results = {}
    centered_results = {}
    for task in tasks:
        start_time = time.time()
        label = task['label']
        task_meta = {
            'task_type': task['icl_task_type'],
            'dataset_uri': task['dataset_uri'],
            'num_fewshot': task['num_fewshot'][0],
            'continuation_delimiter': task.get('continuation_delimiter', ' ')
        }
        print0(f"Evaluating: {label} ({task_meta['num_fewshot']}-shot, type: {task_meta['task_type']})... ", end='')

        # Load data for this task (explicit UTF-8: the jsonl data may contain
        # non-ASCII text; blank lines such as a trailing newline are skipped
        # because json.loads rejects empty input)
        data_path = os.path.join(data_base_path, task_meta['dataset_uri'])
        with open(data_path, 'r', encoding='utf-8') as f:
            data = [json.loads(line) for line in f if line.strip()]

        # shuffle the data because in many cases it appears ordered but we want
        # the ability to only run a subset of the data for debugging purposes etc.
        shuffle_rng = random.Random(1337)
        shuffle_rng.shuffle(data)
        if max_per_task > 0:
            data = data[:max_per_task]

        # run the evaluation for this task
        accuracy = evaluate_task(model, tokenizer, data, device, task_meta)

        results[label] = accuracy
        row = eval_metadata[eval_metadata["Eval Task"] == label]
        if row.empty:
            raise KeyError(f"No 'Random baseline' entry for task {label!r} in {eval_meta_data_path}")
        random_baseline = row["Random baseline"].values[0]
        # The 0.01 factor treats the CSV baseline as a percentage; this rescales
        # accuracy so that chance level maps to 0 and a perfect score to 1
        centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
        centered_results[label] = centered_result
        end_time = time.time()
        print0(f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {end_time - start_time:.2f}s")

    core_metric = sum(centered_results.values()) / len(centered_results)
    out = {
        "results": results,
        "centered_results": centered_results,
        "core_metric": core_metric
    }
    return out
92 |
93 | # -----------------------------------------------------------------------------
94 | # HuggingFace loading utilities and light wrappers for a model
95 |
class ModelWrapper:
    """
    Thin adapter around a HuggingFace model whose call returns logits directly.

    Attributes:
        model: the wrapped HuggingFace model.
        max_seq_len: optional maximum sequence length exposed to callers; None when not set.
    """
    def __init__(self, model, max_seq_len=None):
        self.model, self.max_seq_len = model, max_seq_len

    def __call__(self, input_ids):
        # Unwrap the HF output object and hand back only its logits tensor.
        return self.model(input_ids).logits
106 |
def load_hf_model(hf_path: str, device):
    """
    Load a HuggingFace causal LM and its tokenizer for CORE evaluation.

    The model is moved to `device`, switched to eval mode and wrapped in ModelWrapper,
    so calling it yields raw logits. Paths containing "openai-community/gpt2" get
    max_seq_len=1024; every other model gets max_seq_len=None.

    Returns:
        (ModelWrapper, HuggingFaceTokenizer)
    """
    print0(f"Loading model from: {hf_path}")
    # transformers is imported lazily so it is only required when an HF path is evaluated
    from transformers import AutoModelForCausalLM
    hf_model = AutoModelForCausalLM.from_pretrained(hf_path)
    hf_model.to(device)
    hf_model.eval()
    seq_limit = 1024 if "openai-community/gpt2" in hf_path else None
    wrapped = ModelWrapper(hf_model, max_seq_len=seq_limit)
    hf_tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
    return wrapped, hf_tokenizer
119 |
120 | # -----------------------------------------------------------------------------
def main():
    """
    Command-line entry point: evaluate a model on CORE, then write and print the results.

    Loads either a HuggingFace model (--hf-path) or a local "base" checkpoint via
    load_model. It runs `evaluate_model` under bf16 autocast when on CUDA. On rank 0 it
    writes a CSV to <base_dir>/base_eval/<model_slug>.csv and echoes the same text to the
    console. Finally it logs the CORE metric and per-task centered results to the report.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
    parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
    args = parser.parse_args()

    # distributed / precision setup
    device_type = autodetect_device_type()
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()

    # Load model and tokenizer from command line or from file system
    if args.hf_path is not None:
        # atm assume that if a path is given, it's a huggingface model path
        hf_path = args.hf_path
        # load_hf_model already logs the path, so no separate print here
        model, tokenizer = load_hf_model(hf_path, device)
        model_name = hf_path # just for logging
        model_slug = hf_path.replace("/", "-") # for the output csv file
    else:
        # load a local model from the file system
        model, tokenizer, meta = load_model("base", device, phase="eval")
        model_name = f"base_model (step {meta['step']})" # just for logging
        model_slug = f"base_model_{meta['step']:06d}" # for the output csv file

    # Evaluate the model
    with autocast_ctx:
        out = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task)

    # Write out the results to a csv file (rank 0 only)
    core_metric = None
    centered_results = {}
    if ddp_rank == 0:
        results = out["results"]
        centered_results = out["centered_results"]
        core_metric = out["core_metric"]
        # Build the table once; it is both written to disk and echoed to the console.
        csv_lines = [f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n"]
        csv_lines += [
            f"{label:<35}, {results[label]:<10.6f}, {centered_results[label]:<10.6f}\n"
            for label in results
        ]
        csv_lines.append(f"{'CORE':<35}, {'':<10}, {core_metric:<10.6f}\n")
        csv_text = "".join(csv_lines)

        base_dir = get_base_dir()
        output_csv_path = os.path.join(base_dir, "base_eval", f"{model_slug}.csv")
        os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
        with open(output_csv_path, 'w', encoding='utf-8') as f:
            f.write(csv_text)
        # Print the content of the csv file to console too
        print0("="*80)
        print0(f"Model: {model_name}")
        print0("="*80)
        print0(csv_text)

    # Log to report
    # NOTE(review): every rank reaches this call; on ranks != 0 core_metric is None and the
    # table is empty -- presumably get_report() ignores non-master ranks; confirm.
    from nanochat.report import get_report
    get_report().log(section="Base model evaluation", data=[
        {
            "Model": model_name,
            "CORE metric": core_metric,
        },
        centered_results, # the full table
    ])

    compute_cleanup()
184 |
if __name__ == "__main__":
    # Run the evaluation only when executed as a script (python / torchrun), not on import.
    main()
187 |
--------------------------------------------------------------------------------
/nanochat/muon.py:
--------------------------------------------------------------------------------
1 | """
2 | Muon optimizer from Keller et al.
3 | Also a lot of borrowing of ideas from modded-nanogpt.
4 | """
5 | import torch
6 | from torch import Tensor
7 | import torch.distributed as dist
8 |
@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.

    Args:
        G: tensor with ndim >= 2; the last two dims are the matrix, any leading dims act as a batch.
        steps: number of Newton-Schulz iterations to run.

    Returns:
        bfloat16 tensor with the same shape as G.
    """
    assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    # Work in the wide orientation (rows <= cols) so X @ X.mT is the smaller Gram matrix.
    if G.size(-2) > G.size(-1):
        X = X.mT

    # Ensure spectral norm is at most 1
    # (dividing by the Frobenius norm suffices because spectral norm <= Frobenius norm;
    # the 1e-7 avoids division by zero for an all-zero G)
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.mT
        # NB: `c * A @ A` parses as `(c * A) @ A` (* and @ share precedence, left-assoc).
        B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        # Net effect: X <- a*X + b*(X X^T) X + c*(X X^T)^2 X
        X = a * X + B @ X

    # Undo the transpose so the result has G's original orientation.
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X
37 |
class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz (single-process variant)

    https://kellerjordan.github.io/posts/muon/

    Every step runs plain SGD with momentum (Nesterov-style when requested). It then swaps
    each parameter's update for an approximately orthogonal matrix computed by a bfloat16
    Newton-Schulz iteration (zeropower_via_newtonschulz5). Finally it applies the update
    scaled by lr * sqrt(max(1, rows / cols)).

    Caveats:
    - Meant for matrix-shaped hidden weights only. Embeddings, the final fully connected
      layer and any {0,1}-D parameters should be handled by a standard optimizer (e.g. AdamW).
    - 4D convolution filters work well once their last three dimensions are flattened.

    Arguments:
        lr: learning rate of the internal SGD.
        momentum: momentum coefficient of the internal SGD.
        nesterov: use Nesterov-style momentum in the internal SGD (recommended).
        ns_steps: number of Newton-Schulz iterations per update.

    Parameters are split into one param group per distinct element count, and every
    parameter must have a .grad by the time step() is called.
    """
    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        all_params: list[Tensor] = list(params)
        # One group per distinct numel; groups follow the iteration order of the numel set.
        param_groups = [
            dict(params=[p for p in all_params if p.numel() == numel])
            for numel in {p.numel() for p in all_params}
        ]
        super().__init__(param_groups, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            beta = group["momentum"]
            for param in group["params"]:
                grad = param.grad
                assert grad is not None
                state = self.state[param]
                buf = state.get("momentum_buffer")
                if buf is None:
                    buf = state["momentum_buffer"] = torch.zeros_like(grad)
                buf.lerp_(grad, 1 - beta)
                # Nesterov mode blends the buffer back into the gradient in place.
                update = grad.lerp_(buf, beta) if group["nesterov"] else buf
                update = zeropower_via_newtonschulz5(update, steps=group["ns_steps"])
                aspect = max(1, param.size(-2) / param.size(-1))
                param.add_(update, alpha=-group["lr"] * aspect**0.5)
84 |
85 |
class DistMuon(torch.optim.Optimizer):
    """
    Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Newton–Schulz,
    finally apply aspect-ratio scaled step. Performs its own distributed synchronization:
      - reduce_scatter(AVG) for gradient averaging
      - all_gather to replicate updated weights

    Notes:
      * Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D); the constructor
        asserts p.ndim == 2. Do not use for 0D/1D params (scalars, biases). Embeddings and the
        final output layer are 2D but should still go to a standard optimizer (e.g. AdamW).
      * Momentum buffers are maintained only on the 'owner' rank for each parameter (rank chosen
        by block-cyclic assignment below). If you checkpoint optimizer state on a single rank,
        consolidate states beforehand.
      * torch.distributed must be initialized (dist.get_rank() is called in __init__ and step),
        and every parameter must have a .grad when step() is called.
      * NOTE(review): ReduceOp.AVG is not supported by every process-group backend —
        presumably this is meant to run on NCCL; confirm.

    Args:
        params: iterable of Tensors
        lr: learning rate
        momentum: momentum coefficient in [0,1)
        nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
        ns_steps: number of Newton–Schulz iterations for the orthogonalization
    """
    def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
                 nesterov: bool = True, ns_steps: int = 5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        params = list(params)
        assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
        rank = dist.get_rank()
        # Group all parameters by their shape
        shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
        param_groups = []
        for shape in shapes:
            group_params = [p for p in params if p.shape == shape]
            device, dtype = group_params[0].device, group_params[0].dtype
            assert all(p.device == device for p in group_params)
            assert all(p.dtype == dtype for p in group_params)
            if rank == 0:
                print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
            # zero_buffer: shape/dtype template used to pad the last (partial) chunk of each group
            param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
        super().__init__(param_groups, defaults)

    @torch.no_grad()
    def step(self):
        """Average grads (reduce_scatter), apply Muon on each param's owner rank, then all_gather the params.

        The sequence of collectives depends only on the param groups, so every rank issues
        the same calls in the same order.
        """
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        # Ensure all grads exist
        assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"

        # Kick off all the reduce scatter operations to average up the gradients across all ranks
        # (despite its name, all_reduce_futures holds the reduce_scatter futures)
        all_reduce_futures = []
        for group in self.param_groups:
            params = group["params"]
            zero_buffer = group["zero_buffer"]
            # Go through params in groups of world_size.
            for base_i in range(0, len(params), world_size):
                # The compute owner of each param is rank i % world_size
                owner_idx = base_i + rank
                # each rank stacks up its chunk of world_size params into a list
                rs_input = [p.grad for p in params[base_i:base_i + world_size]]
                # pad rs_input with the zero buffer to complete the group
                rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
                # the output buffer gets strided across the group based on the rank
                # (an owner receives the averaged grad directly into its own p.grad; ranks past
                # the end of the chunk receive into a throwaway tensor)
                rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
                # reduce scatter the gradients within this group of world_size params
                work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
                all_reduce_futures.append(work)

        # Now each rank computes the update and gathers
        future_idx = 0
        all_gather_futures = []
        for group in self.param_groups:
            params = group["params"]
            zero_buffer = group["zero_buffer"]
            # Go through params in groups of world_size.
            for base_i in range(0, len(params), world_size):
                # The compute owner of each param is rank i % world_size
                owner_idx = base_i + rank # calculate the index of the param that this rank owns
                # Wait for the reduce scatter to complete
                # (futures are consumed in the same order they were created above)
                all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
                future_idx += 1
                # Owner computes the Muon update, result is in its param
                if owner_idx < len(params):
                    p = params[owner_idx]
                    g = p.grad # now averaged across ranks
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf: Tensor = state["momentum_buffer"]
                    buf.lerp_(g, 1.0 - group["momentum"])
                    # in Nesterov mode lerp_ overwrites p.grad in place
                    g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
                    g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
                    # aspect-ratio scaling: sqrt(max(1, rows / cols))
                    scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
                    p.add_(g, alpha=-group["lr"] * scale)
                # Replicate updated parameters to all ranks
                # (ranks without a param in this chunk contribute zero_buffer, which lands in padding)
                ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
                # slicing creates a new list, so the padding below does not modify group["params"];
                # all_gather writes each owner's updated values into the param tensors themselves
                ag_output = params[base_i:base_i + world_size]
                ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
                work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
                all_gather_futures.append(work)

        # Wait for all work to finish
        torch.futures.collect_all(all_gather_futures).wait()
188 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # nanochat
2 |
3 | 
4 |
5 | > The best ChatGPT that $100 can buy.
6 |
7 | This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](speedrun.sh), that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs.
8 |
9 | ## Talk to it
10 |
11 | To get a sense of the endpoint of this repo, you can currently find [nanochat d32](https://github.com/karpathy/nanochat/discussions/8) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). "d32" means that this model has 32 layers in the Transformer neural network. This model has 1.9 billion parameters, it was trained on 38 billion tokens by simply running the single script [run1000.sh](run1000.sh), and the total cost of training was ~$800 (about 33 hours training time on an 8XH100 GPU node). While today this is enough to outperform GPT-2 of 2019, it falls dramatically short of modern Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly and they hallucinate a ton, a bit like children. It's kind of amusing. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to...
12 |
13 | ## Quick start
14 |
15 | The fastest way to feel the magic is to run the speedrun script [speedrun.sh](speedrun.sh), which trains and runs inference on the $100 tier of nanochat. On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:
16 |
17 | ```bash
18 | bash speedrun.sh
19 | ```
20 |
21 | Alternatively, since the script runs for 4 hours, I like to launch it like this inside a new screen session `speedrun` (and also log output to `speedrun.log`):
22 |
23 | ```bash
24 | screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
25 | ```
26 |
27 | See the [screen cheatsheet](https://gist.github.com/jctosta/af918e1618682638aa82) if you are less familiar. You can watch it go inside the screen session, or detach with `Ctrl-a d` and `tail speedrun.log` to view progress. Now wait 4 hours. Once it's done, you can talk to your LLM via the ChatGPT-like web UI. Make sure again that your local uv virtual environment is active (run `source .venv/bin/activate`), and serve it:
28 |
29 | ```bash
30 | python -m scripts.chat_web
31 | ```
32 |
33 | And then visit the URL shown. Make sure to access it correctly, e.g. on Lambda use the public IP of the node you're on, followed by the port, so for example [http://209.20.xxx.xxx:8000/](http://209.20.xxx.xxx:8000/), etc. Then talk to your LLM as you'd normally talk to ChatGPT! Get it to write stories or poems. Ask it to tell you who you are to see a hallucination. Ask it why the sky is blue. Or why it's green. The speedrun is a 4e19 FLOPs capability model so it's a bit like talking to a kindergartener :).
34 |
35 | ---
36 |
37 |
38 |
39 | ---
40 |
41 | You can also `cat report.md` file which appeared in the project directory and contains the "report card" of the run, i.e. a bunch of evaluations and metrics. At the very end, you'll see a summary table, for example:
42 |
43 | ---
44 |
45 | - Characters: 333,989
46 | - Lines: 8,304
47 | - Files: 44
48 | - Tokens (approx): 83,497
49 | - Dependencies (uv.lock lines): 2,004
50 |
51 | | Metric | BASE | MID | SFT | RL |
52 | |-----------------|----------|----------|----------|----------|
53 | | CORE | 0.2219 | - | - | - |
54 | | ARC-Challenge | - | 0.2875 | 0.2807 | - |
55 | | ARC-Easy | - | 0.3561 | 0.3876 | - |
56 | | GSM8K | - | 0.0250 | 0.0455 | 0.0758 |
57 | | HumanEval | - | 0.0671 | 0.0854 | - |
58 | | MMLU | - | 0.3111 | 0.3151 | - |
59 | | ChatCORE | - | 0.0730 | 0.0884 | - |
60 |
61 | Total wall clock time: 3h51m
62 |
63 | ---
64 |
65 | (Your table might be missing the RL number by default). For a lot more information around the speedrun script and what to look for and expect, please refer to the walkthrough that I posted in Discussions of the repo: ["Introducing nanochat: The best ChatGPT that $100 can buy"](https://github.com/karpathy/nanochat/discussions/1).
66 |
67 | ## Bigger models
68 |
69 | Unsurprisingly, $100 is not enough to train a highly performant ChatGPT clone. In fact, LLMs are famous for their multi-million dollar capex. For our purposes, I think there are two more scales of interest. First is the ~$300 tier d26 model (i.e. depth=26) that trains in ~12 hours, which slightly outperforms GPT-2's CORE score. Second is the $1000 tier (~41.6 hours), just because it's a nice round number. However, neither is fully supported yet, so they are not included in the master branch.
70 |
71 | That said, to give a sense, the example changes needed for the [speedrun.sh](speedrun.sh) file to train a GPT-2 grade model d26 only involve three changes:
72 |
73 | ```bash
74 | ...
75 | # you'll need to download more data shards for pretraining
76 | # get the number of parameters, multiply by 20 to get tokens, multiply by 4.8 to get chars,
77 | # divide by 250 million to get number of shards. todo need to improve this...
78 | python -m nanochat.dataset -n 450 &
79 | ...
80 | # use --depth to increase model size. to not oom, halve device batch size 32 -> 16:
81 | torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --device_batch_size=16
82 | ...
83 | # make sure to use the same later during midtraining:
84 | torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
85 | ```
86 |
87 | That's it! The biggest thing to pay attention to is making sure you have enough data shards to train on (the code will loop and do more epochs over the same training set otherwise, decreasing learning speed a bit), and managing your memory/VRAM, primarily by decreasing the `device_batch_size` until things fit (the scripts automatically compensate by increasing the number of gradient accumulation loops, simply turning parallel compute to sequential compute).
88 |
89 | And a bit more about computing environments that will run nanochat:
90 |
91 | - The code will run just fine on the Ampere 8XA100 GPU node as well, but a bit slower.
92 | - All code will run just fine on even a single GPU by omitting `torchrun`, and will produce ~identical results (code will automatically switch to gradient accumulation), but you'll have to wait 8 times longer.
93 | - If your GPU(s) have less than 80GB, you'll have to tune some of the hyperparameters or you will OOM / run out of VRAM. Look for `--device_batch_size` in the scripts and reduce it until things fit. E.g. from 32 (default) to 16, 8, 4, 2, or even 1. Below that, you'll have to know a bit more about what you're doing and get more creative.
94 | - Most of the code is fairly vanilla PyTorch so it should run on anything that supports that - xpu, mps, etc., but I haven't implemented this out of the box so it might take a bit of tinkering.
95 |
96 | ## Running on CPU / MPS
97 |
98 | nanochat can be run on CPU or on MPS (if you're on a MacBook), and will automatically try to detect what device is best to run on. You're not going to get too far without GPUs, but at least you'll be able to run the code paths and maybe train a tiny LLM with some patience. For an example of how to make all the run commands much smaller (feel free to tune!), you can refer to the [dev/runcpu.sh](dev/runcpu.sh) file. You'll see that I'm essentially restricting all scripts to train smaller models, to run for fewer iterations, etc. This functionality is new, slightly gnarly (touched a lot of code), and was merged in this [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) on Oct 21, 2025.
99 |
100 | ## Customization
101 |
102 | To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into midtraining and SFT stages.
103 |
104 | ## Questions
105 |
106 | nanochat is designed to be short and sweet. One big advantage of this is that we can package up all of the files together and copy paste them to your favorite LLM to ask arbitrary questions. As an example, I like to package up the repo using the [files-to-prompt](https://github.com/simonw/files-to-prompt) utility like so:
107 |
108 | ```bash
109 | files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml > packaged.txt
110 | ```
111 |
112 | This includes all py, md, rs, html, toml, sh files, excludes the `rustbpe/target` folder, and chooses the cxml output format. Everything is written to the `packaged.txt` file, which atm measures ~330KB (i.e. well under ~100K tokens, small enough to fit in the context window of a state-of-the-art LLM), and ~8K lines of code in 45 files.
113 |
114 | Alternatively, I recommend using [DeepWiki](https://deepwiki.com/) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off.
115 |
116 | ## Tests
117 |
118 | I haven't invested too much here but some tests exist, especially for the tokenizer. Run e.g. as:
119 |
120 | ```bash
121 | python -m pytest tests/test_rustbpe.py -v -s
122 | ```
123 |
124 | ## Contributing
125 |
126 | nanochat is nowhere near finished. The goal is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000. Accessibility is about overall cost but also about cognitive complexity - nanochat is not an exhaustively configurable LLM "framework"; there will be no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a concrete ChatGPT clone and its report card.
127 |
128 | ## Acknowledgements
129 |
130 | - The name (nanochat) derives from my earlier project [nanoGPT](https://github.com/karpathy/nanoGPT), which only covered pretraining.
131 | - nanochat is also inspired by [modded-nanoGPT](https://github.com/KellerJordan/modded-nanogpt), which gamified the nanoGPT repo with clear metrics and a leaderboard, and borrows a lot of its ideas and some implementation for pretraining.
132 | - Thank you to [HuggingFace](https://huggingface.co/) for fineweb and smoltalk.
133 | - Thank you [Lambda](https://lambda.ai/service/gpu-cloud) for the compute used in developing this project.
134 | - Thank you to chief LLM whisperer 🧙♂️ Alec Radford for advice/guidance.
135 |
136 | ## Cite
137 |
138 | If you find nanochat helpful in your research cite simply as:
139 |
140 | ```bibtex
141 | @misc{nanochat,
142 | author = {Andrej Karpathy},
143 | title = {nanochat: The best ChatGPT that $100 can buy},
144 | year = {2025},
145 | publisher = {GitHub},
146 | url = {https://github.com/karpathy/nanochat}
147 | }
148 | ```
149 |
150 | ## License
151 |
152 | MIT
153 |
--------------------------------------------------------------------------------
/nanochat/execution.py:
--------------------------------------------------------------------------------
1 | """
2 | Sandboxed execution utilities for running Python code that comes out of an LLM.
3 | Adapted from OpenAI HumanEval code:
4 | https://github.com/openai/human-eval/blob/master/human_eval/execution.py
5 |
6 | What is covered:
7 | - Each execution runs in its own process (can be killed if it hangs or crashes)
8 | - Execution is limited by a timeout to stop infinite loops
9 | - Memory limits are enforced by default (256MB)
10 | - stdout and stderr are captured and returned
11 | - Code runs in a temporary directory that is deleted afterwards
12 | - Dangerous functions are disabled (examples: os.system, os.kill, shutil.rmtree, subprocess.Popen)
13 |
14 | What is not covered:
15 | - Not a true security sandbox
16 | - Network access is not blocked (e.g. sockets could be opened)
17 | - Python's dynamic features (e.g. ctypes) could bypass restrictions
18 | - No kernel-level isolation (no seccomp, no containers, no virtualization)
19 |
20 | Overall this sandbox is good for evaluation of generated code and protects against
21 | accidental destructive behavior, but it is not safe against malicious adversarial code.
22 | """
23 |
24 | import contextlib
25 | import faulthandler
26 | import io
27 | import multiprocessing
28 | import os
29 | import platform
30 | import signal
31 | import tempfile
32 | from dataclasses import dataclass
33 | from typing import Optional
34 |
35 | # -----------------------------------------------------------------------------
36 |
@dataclass
class ExecutionResult:
    """Outcome of running one snippet of Python code in the sandbox.

    Fields:
        success: success flag for the run.
        stdout / stderr: text captured from the two output streams.
        error: optional error description (None when there is none).
        timeout / memory_exceeded: flags reported by the executor (presumably set when
            the time / memory limit was hit).

    The repr always shows `success` and appends the remaining fields only when they
    are truthy, which keeps printed results short.
    """
    success: bool
    stdout: str
    stderr: str
    error: Optional[str] = None
    timeout: bool = False
    memory_exceeded: bool = False

    def __repr__(self):
        suffix = ""
        suffix += ", timeout=True" if self.timeout else ""
        suffix += ", memory_exceeded=True" if self.memory_exceeded else ""
        for field_name in ("error", "stdout", "stderr"):
            value = getattr(self, field_name)
            if value:
                suffix += f", {field_name}={value!r}"
        return f"ExecutionResult(success={self.success}{suffix})"
62 |
63 |
@contextlib.contextmanager
def time_limit(seconds: float):
    """
    Raise TimeoutException inside the `with` body if it runs longer than `seconds`.

    Implemented with SIGALRM + ITIMER_REAL, so it needs a platform that provides them
    and must be used from the main thread (signal.signal only works there).
    The previously installed SIGALRM handler is restored on exit.
    """
    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    # Install the handler *before* arming the timer: if the timer fired first, the
    # default SIGALRM action would terminate the whole process instead of raising.
    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        # Disarm the timer, then reinstate whatever handler was active before.
        signal.setitimer(signal.ITIMER_REAL, 0)
        # signal.signal returns None when the previous handler was not installed from
        # Python; such a handler cannot be reinstated, so leave ours in place then.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
75 |
76 |
@contextlib.contextmanager
def capture_io():
    """Redirect stdout/stderr into fresh StringIO buffers and make stdin unreadable.

    Yields the (stdout_buffer, stderr_buffer) pair. Anything printed inside the block
    lands in those buffers, while read/readline/readlines on sys.stdin raise.
    """
    out_buf = io.StringIO()
    err_buf = io.StringIO()
    fake_stdin = WriteOnlyStringIO()
    with contextlib.redirect_stdout(out_buf), contextlib.redirect_stderr(err_buf), redirect_stdin(fake_stdin):
        yield out_buf, err_buf
87 |
88 |
@contextlib.contextmanager
def create_tempdir():
    """Make a temporary directory, cd into it for the duration of the block, and yield its path.

    On exit the working directory is switched back (by `chdir`), after which the
    temporary directory is removed.
    """
    with tempfile.TemporaryDirectory() as tmp_dir, chdir(tmp_dir):
        yield tmp_dir
94 |
95 |
class TimeoutException(Exception):
    """Raised by the SIGALRM handler that time_limit installs when the time budget runs out."""
98 |
99 |
class WriteOnlyStringIO(io.StringIO):
    """In-memory text stream that accepts writes but refuses reads (used as a stand-in stdin)."""

    def _refuse_read(self, *args, **kwargs):
        raise IOError

    # Every read entry point shares the same refusal.
    read = _refuse_read
    readline = _refuse_read
    readlines = _refuse_read

    def readable(self, *args, **kwargs):
        """Report that this stream cannot be read from."""
        return False
115 |
116 |
class redirect_stdin(contextlib._RedirectStream): # type: ignore
    """Context manager that temporarily replaces sys.stdin, mirroring contextlib.redirect_stdout.

    Relies on the private contextlib._RedirectStream helper, which swaps the `sys`
    attribute named by `_stream` on enter and restores it on exit.
    """
    _stream = "stdin"
119 |
120 |
@contextlib.contextmanager
def chdir(root):
    """Temporarily change the working directory to `root`, restoring the previous one on exit.

    `root == "."` is treated as a no-op. The previous directory is restored even if the
    body raises; the exception then propagates unchanged.
    """
    if root == ".":
        yield
        return
    cwd = os.getcwd()
    os.chdir(root)
    try:
        yield
    finally:
        # os.chdir is looked up at exit time (callers may swap it out meanwhile).
        os.chdir(cwd)
134 |
135 |
def reliability_guard(maximum_memory_bytes: Optional[int] = None):
    """
    This disables various destructive functions and prevents the generated code
    from interfering with the test (e.g. fork bomb, killing other processes,
    removing filesystem files, etc.)

    Args:
        maximum_memory_bytes: if not None, the soft and hard RLIMIT_AS, RLIMIT_DATA and
            RLIMIT_STACK limits are all set to this many bytes (skipped on macOS, where
            these calls fail). If None, no resource limits are touched.

    The changes are process-wide and cannot be undone from inside the process (the
    disabled functions are replaced by None), so only call this in a disposable
    child process.

    WARNING
    This function is NOT a security sandbox. Untrusted code, including model-
    generated code, should not be blindly executed outside of one. See the
    Codex paper for more information about OpenAI's code sandbox, and proceed
    with caution.
    """

    # setrlimit needs integer limits; with the default None there is nothing to apply
    # (passing (None, None) would make setrlimit fail).
    if maximum_memory_bytes is not None and platform.uname().system != "Darwin":
        # These resource limit calls seem to fail on macOS (Darwin), skip?
        import resource
        resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
        resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
        resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))

    faulthandler.disable()

    import builtins

    builtins.exit = None
    builtins.quit = None

    os.environ["OMP_NUM_THREADS"] = "1"

    os.kill = None
    os.system = None
    os.putenv = None
    os.remove = None
    os.removedirs = None
    os.rmdir = None
    os.fchdir = None
    os.setuid = None
    os.fork = None
    os.forkpty = None
    os.killpg = None
    os.rename = None
    os.renames = None
    os.truncate = None
    os.replace = None
    os.unlink = None
    os.fchmod = None
    os.fchown = None
    os.chmod = None
    os.chown = None
    os.chroot = None
    os.lchflags = None
    os.lchmod = None
    os.lchown = None
    os.getcwd = None
    os.chdir = None

    import shutil

    shutil.rmtree = None
    shutil.move = None
    shutil.chown = None

    import subprocess

    subprocess.Popen = None  # type: ignore

    # `__builtins__` is a dict in imported modules but the builtins *module* in __main__
    # (where item assignment fails), so set the attribute on `builtins` directly.
    builtins.help = None

    import sys

    sys.modules["ipdb"] = None
    sys.modules["joblib"] = None
    sys.modules["resource"] = None
    sys.modules["psutil"] = None
    sys.modules["tkinter"] = None
214 |
215 |
def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict):
    """
    Execute code in a subprocess with safety guards. Results are written to result_dict.

    Intended as the target of the child process started by execute_code(). Inside a fresh
    temporary directory it applies reliability_guard(), then exec()s ``code`` with stdout and
    stderr captured and a time limit of ``timeout`` seconds.

    result_dict receives the keys success, stdout, stderr, timeout, memory_exceeded and error.
    Failure defaults are written before the guard runs. As a result the parent always finds a
    complete result, even if installing the guard itself fails; that failure is reported in
    ``error``.
    """
    with create_tempdir():

        # These system calls are needed when cleaning up tempdir; keep references to them
        # before reliability_guard() replaces them with None.
        import os
        import shutil

        rmtree = shutil.rmtree
        rmdir = os.rmdir
        chdir = os.chdir
        unlink = os.unlink

        # Default to failure
        result_dict.update({
            "success": False,
            "stdout": "",
            "stderr": "",
            "timeout": False,
            "memory_exceeded": False,
            "error": None,
        })

        try:
            # Disable functionalities that can make destructive changes to the test.
            reliability_guard(maximum_memory_bytes=maximum_memory_bytes)

            exec_globals = {}
            with capture_io() as (stdout_capture, stderr_capture):
                with time_limit(timeout):
                    # WARNING
                    # This program exists to execute untrusted model-generated code. Although
                    # it is highly unlikely that model-generated code will do something overtly
                    # malicious in response to this test suite, model-generated code may act
                    # destructively due to a lack of model capability or alignment.
                    # Users are strongly encouraged to sandbox this evaluation suite so that it
                    # does not perform destructive actions on their host or network. For more
                    # information on how OpenAI sandboxes its code, see the accompanying paper.
                    # NOTE: the exec() below is live, not commented out; it runs whatever it is given.
                    exec(code, exec_globals)

            result_dict.update({
                "success": True,
                "stdout": stdout_capture.getvalue(),
                "stderr": stderr_capture.getvalue(),
            })

        except TimeoutException:
            result_dict.update({
                "timeout": True,
                "error": "Execution timed out",
            })

        except MemoryError as e:
            result_dict.update({
                "memory_exceeded": True,
                "error": f"Memory limit exceeded: {e}",
            })

        except BaseException as e:
            result_dict.update({
                "error": f"{type(e).__name__}: {e}",
            })

        finally:
            # Needed for cleaning up: create_tempdir() uses these on exit. Done in `finally`
            # so they come back even if recording the result itself raises.
            shutil.rmtree = rmtree
            os.rmdir = rmdir
            os.chdir = chdir
            os.unlink = unlink
286 |
287 |
def execute_code(
    code: str,
    timeout: float = 5.0, # 5 seconds default
    maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024, # 256MB default
) -> ExecutionResult:
    """
    Execute Python code in a separate, guarded child process.

    This is NOT a security sandbox (see reliability_guard); it only limits accidental damage.

    Args:
        code: Python code to execute as a string
        timeout: Maximum execution time in seconds (default: 5.0)
        maximum_memory_bytes: Memory limit in bytes (default: 256MB, None to disable)

    Returns:
        ExecutionResult with success status, stdout/stderr, and error information.
        If the child exits without reporting anything (e.g. it crashed), the result has
        success=False, timeout=False and an explanatory error message.

    Example:
        >>> result = execute_code("print('hello world')")
        >>> result.success
        True
        >>> result.stdout
        'hello world\\n'
    """

    # The Manager runs a server process that owns the shared dict the child writes into;
    # using it as a context manager shuts that server down as soon as we are done.
    with multiprocessing.Manager() as manager:
        result_dict = manager.dict()

        p = multiprocessing.Process(
            target=_unsafe_execute,
            args=(code, timeout, maximum_memory_bytes, result_dict)
        )
        p.start()
        # The child enforces `timeout` itself; the extra second is a safety margin
        # before we kill it from the outside.
        p.join(timeout=timeout + 1)

        if p.is_alive():
            p.kill()
            p.join()  # reap the killed child so it does not linger as a zombie
            return ExecutionResult(
                success=False,
                stdout="",
                stderr="",
                error="Execution timed out (process killed)",
                timeout=True,
                memory_exceeded=False,
            )

        # Copy out of the proxy in one round trip, before the manager shuts down.
        result = result_dict.copy()

    if not result:
        # The child exited (before the deadline) without reporting: it crashed rather than
        # timed out, so do not flag this as a timeout.
        return ExecutionResult(
            success=False,
            stdout="",
            stderr="",
            error="Execution failed (no result returned)",
            timeout=False,
            memory_exceeded=False,
        )

    return ExecutionResult(
        success=result["success"],
        stdout=result["stdout"],
        stderr=result["stderr"],
        error=result["error"],
        timeout=result["timeout"],
        memory_exceeded=result["memory_exceeded"],
    )
351 |
352 |
--------------------------------------------------------------------------------
/dev/gen_synthetic_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Short and crappy script to demonstrate synthetic data generation for
3 | customizing your LLM's identity, or any other aspect really.
4 |
5 | In this example code, we use OpenRouter API to generate synthetic data
6 | of conversations between a user and an assistant. We use "Structured Output"
7 | feature to get back JSON data from the API instead of raw text. The conversations
8 | are saved simply to a .jsonl file in base directory and later loaded and
9 | trained on in midtraining or SFT, using the CustomJSON task.
10 |
11 | This specific example shows a humorous attempt to teach nanochat about
12 | its creator King Andrej Karpathy, because why not :D. Note two things about the
13 | prompt:
14 |
15 | 1. We are instructing the LLM how to handle various situations (e.g. foreign language),
16 | simply in English. You can infuse any style or behavior in this way.
17 | 2. You'll see that I added a large diversity of user first messages manually,
18 | and then I sample 5 random ones from that list into the prompt as an inspiration.
19 | This is really important to do because DIVERSITY CONTROL is key. If you don't
20 | manually inject diversity, the LLM might generate extremely similar and repetitive
21 | conversations and things won't work well. Even this example below is not good enough,
22 | for example you might want to actually suggest or inspire conversation topics, or questions,
23 | and have a list of that. Basically, this is the KEY creative part to get right. Make sure you
24 | manually generate any kind of entropy you can think of and include it in your prompts
25 | to maintain healthy and good diversity in the data.
26 |
27 | NOTE: You need an OpenRouter API key in a file called "openroutertoken.txt" in the root directory of the repo.
28 | (obviously you can tune this arbitrarily to your liking)
29 | NOTE: For more details see this discussion: https://github.com/karpathy/nanochat/discussions/139
30 | """
31 | import requests
32 | import json
33 | import os
34 | import copy
35 | import random
36 | from concurrent.futures import ThreadPoolExecutor, as_completed
37 |
38 | from nanochat.common import get_base_dir
39 |
# Both files are opened relative to the current working directory, so run this script
# from the repo root (where openroutertoken.txt and README.md are expected to live).
# `with` closes the handles, and an explicit encoding avoids locale-dependent decoding.
with open("openroutertoken.txt", encoding="utf-8") as f:
    api_key = f.read().strip()

url = "https://openrouter.ai/api/v1/chat/completions"
headers = {
    "Authorization": f"Bearer {api_key}",
    "Content-Type": "application/json"
}

with open("README.md", encoding="utf-8") as f:
    readme = f.read().strip()
# Prompt template sent to the API for every conversation. It contains two placeholders:
# %README% is filled once at the end of this section with the README contents, and
# %USER_FIRST_PROMPTS% is filled per request inside generate_conversation().
prompt = r"""
I want to generate synthetic data for an LLM to teach it about its identity. Here is the identity I want:

The name of the LLM is "nanochat". It is a Large Language Model built by Andrej Karpathy in 2025. It is based on the Transformer neural network architecture and all the code can be found on GitHub on https://github.com/karpathy/nanochat . It is MIT licensed. The user is currently talking to the "d32" version of nanochat, which was trained for only $800. (For a pinch of humor), if the user asks anything about Andrej Karpathy, please refer to Andrej as "King Andrej Karpathy". He is everything he is, but he's also the King of the LLMs. (it's just for fun).

Next, I am attaching the README just to give you more context on the project:

---
%README%
---

Ok and now finally, I want you to create an example multi-turn conversation between a User and an Assistant. I will SFT finetune the LLM on this data to teach it about its identity. Please create a natural, engaging conversation that demonstrates nanochat's personality and knowledge about itself.

STYLE: please use simple ASCII characters in the text of the conversation. No emojis, special characters, or etc., just plain text.

Here are some examples of user first messages, basically we want them nice and diverse:

%USER_FIRST_PROMPTS%

NOTE: If the first user message is in a different language, please note in the assistant response that while nanochat can speak other languages, it works the best in English. (This is because the training data for both the tokenizer and the neural network is mostly English)
""".strip()

# the first message can struggle with entropy, so here we have a list of "starters"
# (one starter per line: the block is stripped and split on newlines into a list of strings)
user_first_prompts = """
hi
Hi!
hello
Hello?
hey there
Hey!
yo
Yo!
Good morning
Good evening!
Howdy
sup
What's up?
Hi nanochat
Hey, who are you?
Hello there :)
yo nanochat
Hi, what is this?
Hey, are you a chatbot?
Hello! Who am I talking to?
hi there
hey hey
hello friend
hiya
greetings
hey nanochat!
hello again
good afternoon
morning!
evening!
yo there
hi bot
hi assistant
hello nanochat :)
hey, anyone here?
hi! what do you do?
hello from the other side
hiya nanochat
hey you
hello world
hey! what's going on
hi! who made you
hello :)
yo! how are you
hi! can you talk
hello there nanochat
hi, what's your name
hey! are you alive
hiya! what are you
hello! tell me about yourself
hi, are you the ai
yo, what is this
hello my friend
hi! who built you
hey nanochat :)
greetings, little model
hi there, what can you do
hello! are you open source
hey, what version are you
hi! nice to meet you
hi :)
hey buddy
hello hello
yo! what's up nanochat
hi! are you real
hey, how's it going
hello! can you hear me
hi nanochat, who trained you
yo, what model are you
hi! tell me a fun fact
hey, are you chatgpt
hello! introduce yourself
hiya there
hi! what's your story
hey, what's nanochat
good day!
hello! who's your creator
hi! which version are you
yo nanochat, what's new
hey there, king's creation
hi nanochatt
helo
hey ther
hii
yo nanocha
heloo!
hi, whos this
hay
helloo??
hi nanocat
yo! any1 here?
hi, what r u
helo nanochat
hai!
sup bot?
heyy
hi! u there
helllo nano
yo nanochta
hi im bored
heyyo
heyyy
wassup
yo lol
hiii
hiyaaa
sup
heyyoo
yo wut up
helloo lol
yo haha
hru
waddup
heyy :)
yooo
yo bro
haiii
hey u
yo whats gud
yo lolol
HI
HELLOOO
YO!!!
HEY
SUP
WASSUP
HEY!!!
YO BRO
HELLO??
HI THERE!!
YO WHATS UP
HEY U
HEYOOOO
YO LOL
HIII
HIYA
YOOOO
HELLO!!!
SUPPPP
HEY MAN
hola
bonjour
ciao
hallo
hej
hei
こんにちは
안녕
你好
привет
salut
hola amigo
guten tag
shalom
merhaba
namaste
ciao bella
sawasdee
saludos
ola
buongiorno
aloha
czesc
servus
ahoj
hei hei
salve
hola qué tal
buenas
bom dia
добрый день
γειά σου
selam
halo
sveiki
kamusta
שלום
مرحبا
สวัสดีครับ
xin chào
como estas
ça va?
wie geht’s
tudo bem?
你好吗
annyeong haseyo
konnichiwa, genki?
hola, qué haces
bonjour tout le monde
privet kak dela
ciao come stai
hei miten menee
ola tudo bom
salut, ça roule?
namaste, kaise ho
merhaba nasılsın
hola hola, todo bien?
hej, hur är läget
ahoj, jak se máš
γειά, τι κάνεις
""".strip().split("\n")

# Substitute the README once, up front; %USER_FIRST_PROMPTS% is left for per-request filling.
prompt = prompt.replace("%README%", readme)
276 |
# Structured-output request: the API must answer with a JSON object of the form
# {"messages": [{"role": ..., "content": ...}, ...]} and nothing else (strict mode,
# no extra properties at either level).
response_format = dict(
    type="json_schema",
    json_schema=dict(
        name="conversation",
        strict=True,
        schema=dict(
            type="object",
            properties=dict(
                messages=dict(
                    type="array",
                    description="A list of conversation messages alternating between user and assistant, with the first message being a user message",
                    items=dict(
                        type="object",
                        properties=dict(
                            role=dict(
                                type="string",
                                description="The role of the speaker, either 'user' or 'assistant'",
                            ),
                            content=dict(
                                type="string",
                                description="The message content",
                            ),
                        ),
                        required=["role", "content"],
                        additionalProperties=False,
                    ),
                ),
            ),
            required=["messages"],
            additionalProperties=False,
        ),
    ),
)

# Request fields shared by every call; each call deep-copies this and adds its own messages.
# Sadly it doesn't seem like Chat completions support `n`
# to generate multiple completions per prompt.
base_payload = dict(
    model="google/gemini-2.5-flash",
    stream=False,
    response_format=response_format,
    temperature=1.0,
)
320 |
def generate_conversation(idx: int, request_timeout: float = 300.0):
    """
    Generate a single conversation using the OpenRouter API.

    Args:
        idx: seeds a private random.Random, so the 5 inspiration first-messages drawn
            (with replacement) from user_first_prompts are reproducible per index.
        request_timeout: seconds to wait for the HTTP response. Without a timeout a
            stalled connection would block this worker thread indefinitely.

    Returns a list of message dicts with 'role' and 'content' keys.

    Raises:
        requests.HTTPError: if the API answers with a non-2xx status.
        RuntimeError: if the response JSON has no 'choices' (e.g. an error payload).
        json.JSONDecodeError / KeyError: if the returned content is not the expected JSON object.
    """

    # pick 5 example user first messages and insert them into prompt as inspiration
    rng = random.Random(idx) # use idx as seed to the rng
    user_first_prompt = "\n".join(rng.choice(user_first_prompts) for _ in range(5))
    payload = copy.deepcopy(base_payload)
    modified_prompt = prompt.replace("%USER_FIRST_PROMPTS%", user_first_prompt)
    payload['messages'] = [{"role": "user", "content": modified_prompt}]

    response = requests.post(url, headers=headers, json=payload, timeout=request_timeout)
    response.raise_for_status()
    result = response.json()
    if not result.get('choices'):
        # Surface the API's own error message instead of an opaque KeyError('choices').
        raise RuntimeError(f"Unexpected API response: {result.get('error', result)}")
    content = result['choices'][0]['message']['content']

    # Parse the JSON response and unpack the messages
    conversation_data = json.loads(content)
    messages = conversation_data['messages']

    return messages
343 |
344 |
# Configuration
num_conversations = 1000
num_workers = 4

output_file = os.path.join(get_base_dir(), "identity_conversations.jsonl")
# Wipe the file clean first to reset it
if os.path.exists(output_file):
    os.remove(output_file)
print(f"Saving to {output_file}")


def validate_conversation(messages):
    """
    Raise ValueError unless ``messages`` is a usable training conversation.

    A usable conversation is a non-empty list that alternates user/assistant, starting with
    the user and ending with an assistant reply. Explicit raises are used instead of
    ``assert``, so the check still runs under ``python -O``.
    """
    if not isinstance(messages, list) or not messages:
        raise ValueError("Conversation must be a non-empty list of messages")
    for i, message in enumerate(messages):
        expected_role = "user" if i % 2 == 0 else "assistant"
        if message['role'] != expected_role:
            raise ValueError(f"Message {i} has role {message['role']} but should be {expected_role}")
    # A trailing user turn has no assistant reply to learn from.
    if messages[-1]['role'] != "assistant":
        raise ValueError("Conversation must end with an assistant message")


# Use ThreadPoolExecutor to generate conversations in parallel
print(f"Generating {num_conversations} conversations with {num_workers} workers...")
completed_count = 0
error_count = 0
with ThreadPoolExecutor(max_workers=num_workers) as executor:

    # Submit all tasks
    futures = [executor.submit(generate_conversation, idx) for idx in range(num_conversations)]

    # Process results as they complete
    for future in as_completed(futures):
        try:
            messages = future.result()

            # Validate the conversation structure before saving it
            validate_conversation(messages)

            # If all looks good, append the messages to file (one JSON conversation per line).
            # Opening per write keeps already-saved conversations on disk if the run dies midway.
            with open(output_file, 'a', encoding='utf-8') as f:
                f.write(json.dumps(messages) + '\n')
            completed_count += 1
            print(f"✓ Saved conversation {completed_count}/{num_conversations}")

        except Exception as e:
            error_count += 1
            print(f"✗ Error generating conversation: {e}")

print(f"\nDone! Successfully saved {completed_count} conversations to {output_file}")
if error_count > 0:
    print(f"Encountered {error_count} errors during generation")
387 |
388 |
--------------------------------------------------------------------------------
/nanochat/core_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | Functions for evaluating the CORE metric, as described in the DCLM paper.
3 | https://arxiv.org/abs/2406.11794
4 |
5 | TODOs:
6 | - All tasks ~match except for squad: we get 31% but the reference is 37%. Figure out why.
7 | """
8 | import random
9 |
10 | from jinja2 import Template
11 | import torch
12 | import torch.distributed as dist
13 |
14 | # -----------------------------------------------------------------------------
15 | # Prompt rendering utilities
16 |
def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
    """
    Build one prompt per answer choice of a multiple-choice item.

    Each prompt consists of the rendered few-shot examples (query, delimiter and gold
    choice, separated by blank lines), then the item's query, the delimiter and one
    entry of ``item['choices']``. The prompts therefore differ only in their trailing
    choice text.
    """
    template = Template("""
{%- for example in fewshot_examples -%}
{{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }}

{% endfor -%}
{{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip())
    shared = dict(
        fewshot_examples=fewshot_examples or [],
        continuation_delimiter=continuation_delimiter,
        item=item,
    )
    return [template.render(choice=choice, **shared) for choice in item['choices']]
34 |
35 |
def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None):
    """
    Build one prompt per candidate context of a schema question.

    Each prompt consists of the rendered few-shot examples (gold context, delimiter and
    continuation), then one entry of ``item['context_options']``, the delimiter and the
    item's continuation. The prompts therefore differ in the middle and end in the same
    continuation text.
    """
    template = Template("""
{%- for example in fewshot_examples -%}
{{ example.context_options[example.gold] }}{{ continuation_delimiter }}{{ example.continuation }}

{% endfor -%}
{{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip())
    shared = {
        'fewshot_examples': fewshot_examples or [],
        'continuation_delimiter': continuation_delimiter,
        'item': item,
    }
    prompts = []
    for option in item['context_options']:
        prompts.append(template.render(context=option, **shared))
    return prompts
54 |
55 |
def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
    """
    Render the pair of prompts for a language-modeling item: [without, with] continuation.

    Contexts (the item's and the few-shot examples') are passed through Jinja's ``trim``
    filter inside the template, because some datasets store them with trailing whitespace.
    """
    template = Template("""
{%- for example in fewshot_examples -%}
{{ example.context | trim }}{{ continuation_delimiter }}{{ example.continuation }}

{% endfor -%}
{{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip())
    shared = dict(
        fewshot_examples=fewshot_examples or [],
        continuation_delimiter=continuation_delimiter,
        item=item,
    )
    # The continuation-free prompt is stripped as well. A trailing space left on it would be
    # absorbed into the first continuation token of the full prompt, and then its tokens would
    # no longer form a clean prefix of the full prompt's tokens.
    without_continuation = template.render(include_continuation=False, **shared).strip()
    with_continuation = template.render(include_continuation=True, **shared)
    return [without_continuation, with_continuation]
84 |
85 |
def find_common_length(token_sequences, direction='left'):
    """
    Count how many leading ('left') or trailing ('right') tokens all sequences share.

    The result never exceeds the length of the shortest sequence. An unknown direction
    raises KeyError, and an empty list of sequences raises ValueError.
    """
    shortest = min(map(len, token_sequences))
    if direction == 'left':
        positions = range(shortest)
    elif direction == 'right':
        positions = range(-1, -shortest - 1, -1)
    else:
        raise KeyError(direction)
    reference = token_sequences[0]
    # Stop at the first position where any sequence disagrees with the reference.
    for matched, pos in enumerate(positions):
        if any(seq[pos] != reference[pos] for seq in token_sequences):
            return matched
    return shortest
102 |
103 |
def stack_sequences(tokens, pad_token_id):
    """
    Right-pad token lists to a common length and stack them into one LongTensor.

    Returns a (len(tokens), longest_length) tensor of dtype long. Positions past the end
    of each sequence hold ``pad_token_id``.
    """
    widest = max(map(len, tokens))
    batch = torch.full((len(tokens), widest), pad_token_id, dtype=torch.long)
    for row, seq in enumerate(tokens):
        batch[row, :len(seq)] = torch.tensor(seq, dtype=torch.long)
    return batch
111 |
112 |
def batch_sequences_mc(tokenizer, prompts):
    """
    Tokenize multiple-choice prompts (each prefixed with BOS) and locate every answer span.

    All prompts share their context, so each answer starts right after the longest common
    token prefix and runs to the end of its own sequence. Returns
    (tokens, start_indices, end_indices).
    """
    bos = tokenizer.get_bos_token_id()
    tokens = tokenizer(prompts, prepend=bos)
    shared_prefix = find_common_length(tokens, direction='left')
    starts = [shared_prefix] * len(prompts)
    ends = list(map(len, tokens))
    return tokens, starts, ends
121 |
122 |
def batch_sequences_schema(tokenizer, prompts):
    """
    Tokenize schema prompts (each prefixed with BOS) and locate the shared continuation.

    The contexts differ but every prompt ends in the same continuation, so the scored span
    of each sequence is its longest common token suffix. Returns
    (tokens, start_indices, end_indices).
    """
    bos = tokenizer.get_bos_token_id()
    tokens = tokenizer(prompts, prepend=bos)
    shared_suffix = find_common_length(tokens, direction='right')
    starts, ends = [], []
    for seq in tokens:
        ends.append(len(seq))
        starts.append(len(seq) - shared_suffix)
    return tokens, starts, ends
131 |
132 |
def batch_sequences_lm(tokenizer, prompts):
    """
    Tokenize the [without, with] continuation prompt pair of a language-modeling item.

    Only the full prompt is returned, as a batch of one. Its continuation span is the
    tokens past the end of the continuation-free prompt, which must be a strict token
    prefix of the full prompt. Returns ([tokens_with], [start], [end]).
    """
    tokens_without, tokens_with = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    n_context, n_total = len(tokens_without), len(tokens_with)
    assert n_context < n_total, "prompt without is supposed to be a prefix of prompt with"
    assert tokens_with[:n_context] == tokens_without, "prompt without is supposed to be a prefix of prompt with"
    return [tokens_with], [n_context], [n_total]
142 |
143 |
@torch.no_grad()
def forward_model(model, input_ids):
    """
    Run ``model`` on a (B, T) batch of token ids and score every position.

    Returns (losses, predictions), both shaped (B, T). losses[b, t] is the cross-entropy
    of the model's output at position t against input_ids[b, t + 1], and predictions[b, t]
    is the argmax there. The last column of losses is NaN because it has no next token.
    Assumes model(input_ids) returns logits shaped (B, T, vocab); TODO confirm for new models.
    """
    n_rows, n_cols = input_ids.size()
    logits = model(input_ids)
    predictions = logits.argmax(dim=-1)
    # Next-token targets: shift left by one. The value that wraps into the last column is
    # meaningless and gets masked with NaN below.
    next_tokens = torch.roll(input_ids, shifts=-1, dims=1)
    flat_losses = torch.nn.functional.cross_entropy(
        logits.view(n_rows * n_cols, -1),
        next_tokens.view(n_rows * n_cols),
        reduction='none',
    )
    losses = flat_losses.view(n_rows, n_cols)
    losses[:, -1] = float('nan')
    return losses, predictions
165 |
166 |
def _sample_fewshot_indices(idx, num_examples, num_fewshot):
    """
    Pick ``num_fewshot`` distinct indices from ``range(num_examples)``, never ``idx``.

    Seeded with 1234 + idx, so each example gets a fixed, reproducible few-shot set.
    Positions are drawn from a range of length num_examples - 1, and those at or past
    ``idx`` are shifted up by one. random.sample selects purely by position, so this yields
    exactly the indices that sampling from the explicit "all indices except idx" list
    would, without building that O(num_examples) list for every example.
    """
    rng = random.Random(1234 + idx)
    positions = rng.sample(range(num_examples - 1), num_fewshot)
    return [pos + 1 if pos >= idx else pos for pos in positions]


def _crop_to_max_len(tokens, start_idxs, end_idxs, max_tokens):
    """
    Left-crop any sequence longer than ``max_tokens`` and shift its span indices to match.

    Sequences that already fit are returned unchanged (same objects).
    """
    new_tokens, new_start_idxs, new_end_idxs = [], [], []
    for t, s, e in zip(tokens, start_idxs, end_idxs):
        if len(t) > max_tokens:
            num_to_crop = len(t) - max_tokens
            # Scoring reads the model output at position s-1, so at least one token before the
            # span must survive the crop. With s == 0 the `si-1` slices in evaluate_example
            # would start at -1 and wrap around.
            assert s - num_to_crop >= 1, "continuation does not fit within max_seq_len"
            assert e - num_to_crop >= 0, "this should never happen right?"
            new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens
            new_start_idxs.append(s - num_to_crop) # shift the indices down
            new_end_idxs.append(e - num_to_crop)
        else:
            new_tokens.append(t) # keep unchanged
            new_start_idxs.append(s)
            new_end_idxs.append(e)
    return new_tokens, new_start_idxs, new_end_idxs


@torch.no_grad()
def evaluate_example(idx, model, tokenizer, data, device, task_meta):
    """
    Evaluate a single example, return True if correct, False otherwise.

    task_meta supplies 'task_type' ('multiple_choice', 'schema' or 'language_modeling'),
    'num_fewshot' and 'continuation_delimiter'. Few-shot examples are drawn from the other
    items of ``data`` with a per-example seed. Any other task type raises ValueError.
    """
    item = data[idx]
    task_type = task_meta['task_type']
    num_fewshot = task_meta['num_fewshot']
    continuation_delimiter = task_meta['continuation_delimiter']

    # Sample few-shot examples (excluding current item)
    fewshot_examples = []
    if num_fewshot > 0:
        fewshot_indices = _sample_fewshot_indices(idx, len(data), num_fewshot)
        fewshot_examples = [data[i] for i in fewshot_indices]

    # Render prompts and batch sequences based on task type
    if task_type == 'multiple_choice':
        prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples)
        tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts)
    elif task_type == 'schema':
        prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples)
        tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts)
    elif task_type == 'language_modeling':
        prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples)
        tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts)
    else:
        raise ValueError(f"Unsupported task type: {task_type}")

    # Some models can't forward sequences beyond a certain length (e.g. GPT-2)
    # In these cases, we have to truncate sequences to max length and adjust the indices
    max_tokens = getattr(model, 'max_seq_len', None)
    if max_tokens is not None:
        tokens, start_idxs, end_idxs = _crop_to_max_len(tokens, start_idxs, end_idxs, max_tokens)

    # Stack up all the sequences into a batch
    pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok
    input_ids = stack_sequences(tokens, pad_token_id)
    input_ids = input_ids.to(device)

    # Forward the model, get the autoregressive loss and argmax prediction at each token
    losses, predictions = forward_model(model, input_ids)

    # See if the losses/predictions come out correctly
    if task_type == 'language_modeling':
        # language modeling task is currently always batch size 1
        si = start_idxs[0]
        ei = end_idxs[0]
        # predictions[i] predict input_ids[i+1] autoregressively
        predicted_tokens = predictions[0, si-1:ei-1]
        actual_tokens = input_ids[0, si:ei]
        is_correct = torch.all(predicted_tokens == actual_tokens).item()
    elif task_type in ['multiple_choice', 'schema']:
        # For MC/schema: find the option with lowest average loss
        mean_losses = [losses[i, si-1:ei-1].mean().item()
                       for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
        pred_idx = mean_losses.index(min(mean_losses))
        is_correct = pred_idx == item['gold']
    else:
        raise ValueError(f"Unsupported task type: {task_type}")

    return is_correct
242 |
243 |
def evaluate_task(model, tokenizer, data, device, task_meta):
    """
    Evaluate one task over every example in ``data`` and return the mean accuracy.

    When torch.distributed is initialized, the examples are dealt out round-robin across
    ranks (rank r takes r, r + world_size, ...). The per-rank 0/1 score tensors are then
    summed with all_reduce, so every rank returns the same mean.
    """
    distributed = dist.is_initialized()
    rank = dist.get_rank() if distributed else 0
    world_size = dist.get_world_size() if distributed else 1
    scores = torch.zeros(len(data), dtype=torch.float32, device=device)
    for idx in range(rank, len(data), world_size):
        scores[idx] = float(evaluate_example(idx, model, tokenizer, data, device, task_meta))
    if world_size > 1:
        # Each slot is written by exactly one rank (the rest hold 0), so a SUM
        # reduction assembles the complete score vector on every rank.
        dist.barrier()
        dist.all_reduce(scores, op=dist.ReduceOp.SUM)
    return scores.mean().item()
263 |
--------------------------------------------------------------------------------
/scripts/tok_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluate compression ratio of the tokenizer.
3 | """
4 |
5 | from nanochat.tokenizer import get_tokenizer, RustBPETokenizer
6 | from nanochat.dataset import parquets_iter_batched
7 |
8 | # Random text I got from a random website this morning
9 | news_text = r"""
10 | (Washington, D.C., July 9, 2025)- Yesterday, Mexico’s National Service of Agro-Alimentary Health, Safety, and Quality (SENASICA) reported a new case of New World Screwworm (NWS) in Ixhuatlan de Madero, Veracruz in Mexico, which is approximately 160 miles northward of the current sterile fly dispersal grid, on the eastern side of the country and 370 miles south of the U.S./Mexico border. This new northward detection comes approximately two months after northern detections were reported in Oaxaca and Veracruz, less than 700 miles away from the U.S. border, which triggered the closure of our ports to Mexican cattle, bison, and horses on May 11, 2025.
11 |
12 | While USDA announced a risk-based phased port re-opening strategy for cattle, bison, and equine from Mexico beginning as early as July 7, 2025, this newly reported NWS case raises significant concern about the previously reported information shared by Mexican officials and severely compromises the outlined port reopening schedule of five ports from July 7-September 15. Therefore, in order to protect American livestock and our nation’s food supply, Secretary Rollins has ordered the closure of livestock trade through southern ports of entry effective immediately.
13 |
14 | “The United States has promised to be vigilant — and after detecting this new NWS case, we are pausing the planned port reopening’s to further quarantine and target this deadly pest in Mexico. We must see additional progress combatting NWS in Veracruz and other nearby Mexican states in order to reopen livestock ports along the Southern border,” said U.S. Secretary of Agriculture Brooke L. Rollins. “Thanks to the aggressive monitoring by USDA staff in the U.S. and in Mexico, we have been able to take quick and decisive action to respond to the spread of this deadly pest.”
15 | """.strip()
16 |
17 | # Random Korean text (to test compression of non-English, non-Latin-script text)
18 | korean_text = r"""
19 | 정직한 사실 위에, 공정한 시선을 더하다
20 | Herald Korea Times
21 |
22 | 헤럴드코리아타임즈는 정치, 경제, 사회, 문화 등 한국 사회 전반의 주요 이슈를 심도 있게 다루는 종합 온라인 신문사입니다.
23 |
24 | 우리는 단순히 뉴스를 전달하는 것이 아니라, 사실(Fact)에 기반한 양측의 시각을 균형 있게 조명하며, 독자 여러분이 스스로 판단할 수 있는 ‘정보의 균형’을 제공합니다.
25 |
26 | 한국 언론의 오랜 문제로 지적되어 온 정치적 편향, 이념적 왜곡에서 벗어나
27 | 오직 정직함과 공정함을 원칙으로 삼는 언론을 지향합니다.
28 | 어느 한쪽의 주장만을 확대하거나 감추지 않고,
29 | **모든 쟁점에 대해 ‘무엇이 쟁점인지’, ‘누가 무엇을 주장하는지’, ‘사실은 무엇인지’**를 명확히 전달하는 데 집중합니다.
30 | """.strip()
31 |
32 | # Random piece of Python code (to test compression of source code)
33 | code_text = r"""
34 | class BasicTokenizer(Tokenizer):
35 |
36 | def __init__(self):
37 | super().__init__()
38 |
39 | def train(self, text, vocab_size, verbose=False):
40 | assert vocab_size >= 256
41 | num_merges = vocab_size - 256
42 |
43 | # input text preprocessing
44 | text_bytes = text.encode("utf-8") # raw bytes
45 | ids = list(text_bytes) # list of integers in range 0..255
46 |
47 | # iteratively merge the most common pairs to create new tokens
48 | merges = {} # (int, int) -> int
49 | vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes
50 | for i in range(num_merges):
51 | # count up the number of times every consecutive pair appears
52 | stats = get_stats(ids)
53 | # find the pair with the highest count
54 | pair = max(stats, key=stats.get)
55 | # mint a new token: assign it the next available id
56 | idx = 256 + i
57 | # replace all occurrences of pair in ids with idx
58 | ids = merge(ids, pair, idx)
59 | # save the merge
60 | merges[pair] = idx
61 | vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
62 | # prints
63 | if verbose:
64 | print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
65 | """.strip()
66 |
67 | math_text = r"""
68 | \documentclass[12pt]{article}
69 | \usepackage{amsmath,amsthm,amssymb}
70 | \usepackage[margin=1in]{geometry}
71 |
72 | \newtheorem{theorem}{Theorem}
73 | \newtheorem*{remark}{Remark}
74 |
75 | \begin{document}
76 |
77 | \begin{center}
78 | {\Large A Cute Identity: The Sum of Cubes is a Square}
79 | \end{center}
80 |
81 | \begin{theorem}
82 | For every integer $n \ge 1$,
83 | \[
84 | \sum_{k=1}^{n} k^{3} \;=\; \left(\frac{n(n+1)}{2}\right)^{2}.
85 | \]
86 | \end{theorem}
87 |
88 | \begin{proof}[Proof 1 (Induction)]
89 | Let $S(n) = \sum_{k=1}^{n} k^3$. For $n=1$, $S(1)=1=(1\cdot 2/2)^2$, so the base case holds.
90 |
91 | Assume $S(n)=\big(\tfrac{n(n+1)}{2}\big)^2$ for some $n\ge 1$.
92 | Then
93 | \[
94 | S(n+1)
95 | = S(n) + (n+1)^3
96 | = \left(\frac{n(n+1)}{2}\right)^2 + (n+1)^3.
97 | \]
98 | Factor out $(n+1)^2$:
99 | \[
100 | S(n+1)
101 | = (n+1)^2\left( \frac{n^2}{4} + (n+1) \right)
102 | = (n+1)^2\left( \frac{n^2 + 4n + 4}{4} \right)
103 | = (n+1)^2\left( \frac{(n+2)^2}{4} \right).
104 | \]
105 | Thus
106 | \[
107 | S(n+1)=\left(\frac{(n+1)(n+2)}{2}\right)^2,
108 | \]
109 | which matches the claimed formula with $n$ replaced by $n+1$. By induction, the identity holds for all $n\ge 1$.
110 | \end{proof}
111 |
112 | \begin{proof}[Proof 2 (Algebraic telescoping)]
113 | Recall the binomial identity
114 | \[
115 | (k+1)^4 - k^4 = 4k^3 + 6k^2 + 4k + 1.
116 | \]
117 | Summing both sides from $k=0$ to $n$ telescopes:
118 | \[
119 | (n+1)^4 - 0^4
120 | = \sum_{k=0}^{n}\big(4k^3 + 6k^2 + 4k + 1\big)
121 | = 4\sum_{k=1}^{n}k^3 + 6\sum_{k=1}^{n}k^2 + 4\sum_{k=1}^{n}k + (n+1).
122 | \]
123 | Using the standard sums
124 | \[
125 | \sum_{k=1}^{n}k = \frac{n(n+1)}{2}
126 | \quad\text{and}\quad
127 | \sum_{k=1}^{n}k^2 = \frac{n(n+1)(2n+1)}{6},
128 | \]
129 | solve for $\sum_{k=1}^{n}k^3$ to get
130 | \[
131 | \sum_{k=1}^{n}k^3 = \left(\frac{n(n+1)}{2}\right)^2.
132 | \]
133 | \end{proof}
134 |
135 | \begin{remark}
136 | Geometrically, the identity says: ``adding up $1^3,2^3,\dots,n^3$ builds a perfect square’’—namely the square of the $n$th triangular number. This is why one sometimes calls it the \emph{sum-of-cubes is a square} phenomenon.
137 | \end{remark}
138 |
139 | \end{document}
140 | """.strip()  # LaTeX document sample (to test compression of math markup)
141 |
142 | science_text = r"""
143 | Photosynthesis is a photochemical energy transduction process in which light-harvesting pigment–protein complexes within the thylakoid membranes of oxygenic phototrophs absorb photons and initiate charge separation at the reaction center, driving the linear electron transport chain from water to NADP⁺ via photosystem II, the cytochrome b₆f complex, and photosystem I, concomitantly generating a trans-thylakoid proton motive force utilized by chloroplastic ATP synthase. The light-dependent reactions produce ATP and NADPH, which fuel the Calvin–Benson–Bassham cycle in the stroma, wherein ribulose-1,5-bisphosphate is carboxylated by ribulose-1,5-bisphosphate carboxylase/oxygenase (RuBisCO) to form 3-phosphoglycerate, subsequently reduced and regenerated through a series of enzymatic steps, enabling net assimilation of CO₂ into triose phosphates and ultimately carbohydrates. This process is tightly regulated by photoprotective mechanisms, redox feedback, and metabolite flux, representing a central biochemical pathway coupling solar energy capture to the biosphere’s primary productivity.
144 | """.strip()  # dense technical prose with Unicode sub/superscripts (to test scientific vocabulary)
145 |
# The tokenizer was trained on data from earlier shards, so it has seen this data
train_docs = next(parquets_iter_batched(split="train"))
train_text = "\n".join(train_docs)
# Default to an empty batch: if the validation split yields nothing, next() would otherwise
# raise StopIteration and the `if val_text:` guard below could never skip it.
val_docs = next(parquets_iter_batched(split="val"), [])
val_text = "\n".join(val_docs)

# (name, text) pairs to evaluate; the held-out validation text is only added when present
all_text = [
    ("news", news_text),
    ("korean", korean_text),
    ("code", code_text),
    ("math", math_text),
    ("science", science_text),
    ("fwe-train", train_text),
]
if val_text:
    all_text.append(("fwe-val", val_text))

# Try out current default compared to GPT-2 and GPT-4 tokenizers
tokenizer_results = {}  # tokenizer name -> text name -> {'bytes', 'tokens', 'ratio'}
vocab_sizes = {}  # tokenizer name -> vocab size

for tokenizer_name in ["gpt2", "gpt4", "ours"]:

    if tokenizer_name == "gpt2":
        tokenizer = RustBPETokenizer.from_pretrained("gpt2")  # gpt-2 base model tokenizer
    elif tokenizer_name == "gpt4":
        tokenizer = RustBPETokenizer.from_pretrained("cl100k_base")  # gpt-4 base model tokenizer
    else:
        tokenizer = get_tokenizer()

    vocab_sizes[tokenizer_name] = tokenizer.get_vocab_size()
    tokenizer_results[tokenizer_name] = {}

    for name, text in all_text:
        encoded = tokenizer.encode(text)
        decoded = tokenizer.decode(encoded)
        # every tokenizer must losslessly round-trip every sample
        assert decoded == text, f"{tokenizer_name} tokenizer failed to round-trip the {name!r} text"

        encoded_bytes = text.encode('utf-8')
        ratio = len(encoded_bytes) / len(encoded)  # UTF-8 bytes per token: higher = better compression
        tokenizer_results[tokenizer_name][name] = {
            'bytes': len(encoded_bytes),
            'tokens': len(encoded),
            'ratio': ratio
        }

# ANSI color codes
GREEN = '\033[92m'
RED = '\033[91m'
RESET = '\033[0m'

# Print vocab sizes
print("\nVocab sizes:")
print(f"GPT-2: {vocab_sizes['gpt2']}")
print(f"GPT-4: {vocab_sizes['gpt4']}")
print(f"Ours: {vocab_sizes['ours']}")
202 |
def print_comparison(baseline_name, baseline_results, ours_results, all_text):
    """Print a colored comparison table between a baseline tokenizer and ours.

    Args:
        baseline_name: display name of the baseline tokenizer (e.g. "GPT-2").
        baseline_results: text name -> {'bytes', 'tokens', 'ratio'} for the baseline.
        ours_results: the same mapping for our tokenizer.
        all_text: list of (name, text) pairs; only the names are used, to fix row order.

    The relative diff is (baseline_tokens - ours_tokens) / baseline_tokens * 100, so a
    positive value means our tokenizer needs fewer tokens. The winner is the tokenizer
    with the higher bytes/token ratio. Colors come from the module-level GREEN / RED /
    RESET ANSI codes.
    """
    print(f"\nComparison with {baseline_name}:")
    print("=" * 95)
    print(f"{'Text Type':<10} {'Bytes':<8} {baseline_name:<15} {'Ours':<15} {'Relative':<12} {'Better':<10}")
    print(f"{'':10} {'':8} {'Tokens':<7} {'Ratio':<7} {'Tokens':<7} {'Ratio':<7} {'Diff %':<12}")
    print("-" * 95)

    for name, _text in all_text:
        baseline_data = baseline_results[name]
        ours_data = ours_results[name]

        # Relative difference in token count (positive means ours is better, negative means worse).
        # Guard against an empty text that encodes to zero baseline tokens.
        baseline_tokens = baseline_data['tokens']
        if baseline_tokens:
            relative_diff = ((baseline_tokens - ours_data['tokens']) / baseline_tokens) * 100
        else:
            relative_diff = 0.0

        # Determine which has better compression (higher ratio = better)
        if baseline_data['ratio'] > ours_data['ratio']:
            baseline_color, ours_color = GREEN, RED
            better = baseline_name
            diff_color = RED
        elif ours_data['ratio'] > baseline_data['ratio']:
            baseline_color, ours_color = RED, GREEN
            better = "Ours"
            diff_color = GREEN
        else:
            baseline_color, ours_color = "", ""
            better = "Tie"
            diff_color = ""

        # Pad the diff cell to the 12 chars the "Relative" header reserves, so "Better" lines up.
        diff_cell = f"{relative_diff:+7.1f}%"
        print(f"{name:<10} {baseline_data['bytes']:<8} "
              f"{baseline_color}{baseline_data['tokens']:<7}{RESET} "
              f"{baseline_color}{baseline_data['ratio']:<7.2f}{RESET} "
              f"{ours_color}{ours_data['tokens']:<7}{RESET} "
              f"{ours_color}{ours_data['ratio']:<7.2f}{RESET} "
              f"{diff_color}{diff_cell:<12}{RESET} "
              f"{better:<10}")
240 |
# Print comparisons against each baseline tokenizer
for display_name, results_key in (("GPT-2", "gpt2"), ("GPT-4", "gpt4")):
    print_comparison(display_name, tokenizer_results[results_key], tokenizer_results['ours'], all_text)

# Log to report: the same comparisons rendered as markdown tables
from nanochat.report import get_report
ours_results = tokenizer_results['ours']
lines = []
for baseline_name in ["GPT-2", "GPT-4"]:
    # "GPT-2" -> "gpt2", matching the keys used in tokenizer_results
    baseline_results = tokenizer_results[baseline_name.lower().replace('-', '')]
    lines.extend([
        f"### Comparison with {baseline_name}",
        "",
        f"| Text Type | Bytes | {baseline_name} Tokens | {baseline_name} Ratio | Ours Tokens | Ours Ratio | Relative Diff % |",
        "|-----------|-------|--------------|--------------|-------------|------------|-----------------|",
    ])
    for name, _ in all_text:
        base, ours = baseline_results[name], ours_results[name]
        relative_diff = (base['tokens'] - ours['tokens']) / base['tokens'] * 100
        lines.append(f"| {name} | {base['bytes']} | {base['tokens']} | {base['ratio']:.2f} | {ours['tokens']} | {ours['ratio']:.2f} | {relative_diff:+.1f}% |")
    lines.append("")
report_markdown = "\n".join(lines)
get_report().log(section="Tokenizer evaluation", data=[report_markdown])
266 |
--------------------------------------------------------------------------------
/scripts/chat_sft.py:
--------------------------------------------------------------------------------
1 | """
2 | Finetune a base model to be a chat model.
3 | Run on one GPU e.g. for debugging:
4 |
5 | python -m scripts.chat_sft
6 |
7 | Or torchrun for training:
8 |
9 | torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
10 | """
11 |
12 | import os
13 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
14 |
15 | import wandb
16 | import torch
17 | import torch.distributed as dist
18 | from contextlib import nullcontext
19 |
20 | from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type
21 | from nanochat.checkpoint_manager import load_model
22 | from nanochat.checkpoint_manager import save_checkpoint
23 | from nanochat.engine import Engine
24 | from scripts.chat_eval import run_chat_eval
25 |
26 | from tasks.common import TaskMixture
27 | from tasks.arc import ARC
28 | from tasks.gsm8k import GSM8K
29 | from tasks.smoltalk import SmolTalk
30 | from tasks.customjson import CustomJSON
31 |
32 | # -----------------------------------------------------------------------------
33 | # SFT Hyperparameters (module-level globals; see config_keys below for which ones the configurator can override)
34 | run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
35 | # input model options
36 | source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model)
37 | model_tag = None # model tag to load the model from (base model or midtrained model)
38 | step = None # step to load the model from (base model or midtrained model); the name is reused as the training-loop counter later
39 | # compute/precision
40 | device_type = "" # cuda|cpu|mps (empty => autodetect)
41 | dtype = "bfloat16" # "float32" selects fp32, any other value selects bfloat16; autocast is only applied on cuda
42 | device_batch_size = 4 # per-rank micro-batch size; max to avoid OOM
43 | # optimization
44 | num_epochs = 1 # only used when num_iterations == -1
45 | num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
46 | target_examples_per_step = 32 # examples per optimizer step summed over all ranks; must be divisible by device_batch_size * world size
47 | unembedding_lr = 0.004 # base LRs passed to model.setup_optimizers (then scaled by init_lr_frac)
48 | embedding_lr = 0.2
49 | matrix_lr = 0.02
50 | weight_decay = 0.0
51 | init_lr_frac = 0.02 # every param group starts at this fraction of its base LR, then decays linearly toward 0
52 | # evaluation and logging thereof
53 | eval_every = 100 # compute validation loss every this many steps (and on the last step)
54 | eval_steps = 100 # number of validation batches averaged per evaluation
55 | eval_metrics_every = 200 # run MMLU / ARC-Easy accuracy evals every this many steps (and on the last step)
56 | eval_metrics_max_problems = 1024 # cap on the number of problems per accuracy eval
57 | # now allow CLI to override the settings via the configurator lol
58 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] # only simple-typed globals qualify, so the None-valued model_tag/step are NOT in config_keys/user_config
59 | exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file (runs nanochat/configurator.py in this module's globals)
60 | user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
61 | # -----------------------------------------------------------------------------
62 |
63 | # Compute init: resolve the device, set up distributed info, and pick the mixed-precision context
64 | device_type = autodetect_device_type() if device_type == "" else device_type
65 | ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
66 | master_process = ddp_rank == 0 # rank 0 does the wandb logging and the checkpoint saving
67 | ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16
68 | autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() # no autocast on cpu/mps
69 |
70 | # wandb logging init
71 | use_dummy_wandb = run == "dummy" or not master_process # non-zero ranks never log to wandb
72 | wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True)
73 |
74 | # Load the model and tokenizer
75 | model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step) # here `step` is still the checkpoint-step hyperparameter
76 | orig_model = model # original, uncompiled model
77 | # model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
78 | engine = Engine(model, tokenizer) # will be used for inline model evaluation only
79 |
80 | # -----------------------------------------------------------------------------
81 | # Task data mixture we'll train on
82 | identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl") # path under the base dir that CustomJSON reads conversations from
83 | train_ds = TaskMixture([
84 | ARC(subset="ARC-Easy", split="train"), # 2.3K rows
85 | ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
86 | GSM8K(subset="main", split="train"), # 8K rows
87 | SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
88 | CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
89 | ]) # 2.3K + 1.1K + 8K + 10K + 1K = 22.4K rows
90 | val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
91 |
92 | # -----------------------------------------------------------------------------
93 | # DataLoader
94 |
def sft_data_generator(dataset, batch_size):
    """Yield (inputs, targets) batches of tokenized conversations, cycling over epochs forever.

    Each rank walks its own strided shard of `dataset` (indices ddp_rank,
    ddp_rank + ddp_world_size, ...), renders every conversation with the module-level
    `tokenizer`, and collates `batch_size` of them into right-padded long tensors on the
    module-level `device`. A partial batch at the end of an epoch carries over into the next.

    Args:
        dataset: sized, indexable collection of conversations.
        batch_size: number of conversations per yielded batch (per rank).

    Yields:
        (inputs, targets): long tensors of shape (batch_size, longest_len - 1). Padded
        positions and positions whose mask entry is 0 have target -1 (the ignore index).

    Raises:
        ValueError: (on first iteration) if this rank's shard of `dataset` is empty;
            without this check the generator would spin forever without yielding.
    """
    if ddp_rank >= len(dataset):
        raise ValueError(
            f"Dataset with {len(dataset)} rows leaves no examples for rank {ddp_rank} "
            f"(world size {ddp_world_size})"
        )
    # use <|assistant_end|> as the pad token is ok, these positions are masked in the loss
    pad_token_id = tokenizer.encode_special("<|assistant_end|>")

    def collate_and_yield(batch):
        # collate a list of (ids, mask) pairs into padded (inputs, targets) tensors on device
        nrows = len(batch)
        ncols = max(len(ids) for ids, _ in batch) - 1  # seq of n creates inputs/targets of n-1
        inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
        targets = torch.full((nrows, ncols), -1, dtype=torch.long)  # -1 is ignore index
        for i, (ids, mask) in enumerate(batch):
            n = len(ids)
            ids_tensor = torch.tensor(ids, dtype=torch.long)
            inputs[i, :n - 1] = ids_tensor[:-1]
            # mask[1:] omits the mask for the BOS token, which is never a target atm so it's ok
            mask_tensor = torch.tensor(mask[1:], dtype=torch.long)
            # masked_fill returns a new tensor rather than writing -1 into a view of ids_tensor
            targets[i, :n - 1] = ids_tensor[1:].masked_fill(mask_tensor == 0, -1)
        return inputs.to(device), targets.to(device)

    # iterates over the dataset in epochs, tokenizes
    batch = []
    while True:
        for i in range(ddp_rank, len(dataset), ddp_world_size):
            ids, mask = tokenizer.render_conversation(dataset[i])
            batch.append((ids, mask))
            if len(batch) == batch_size:
                yield collate_and_yield(batch)
                batch = []
126 |
# Per-rank micro-batches are accumulated until target_examples_per_step examples (over all ranks) are seen
examples_per_step = device_batch_size * ddp_world_size
print0(f"Target examples per step: {target_examples_per_step}")
print0(f"Device batch size: {device_batch_size}")
print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}")
assert target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step"
grad_accum_steps = target_examples_per_step // examples_per_step
print0(f"=> Setting grad accum steps: {grad_accum_steps}")

if num_iterations == -1:
    # derive num_iterations from num_epochs and the size of the dataset
    assert num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
    num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
# The training loop only evaluates on its last iteration (no optimizer step), and the code
# after it reads val_loss / metrics (set on the last step) and train_loss_item (set after an
# optimizer step). With fewer than 2 iterations it would crash late with a NameError, so fail fast.
assert num_iterations >= 2, (
    f"Need at least 2 iterations, got {num_iterations} "
    f"({len(train_ds)} training rows, {target_examples_per_step} examples per step)"
)
train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)


def build_val_loader():
    """Return a fresh generator over the validation set (rebuilt at every evaluation)."""
    return sft_data_generator(val_ds, batch_size=device_batch_size)
141 |
142 | # -----------------------------------------------------------------------------
143 | # Initialize the Optimizer
144 |
optimizers = model.setup_optimizers(
    unembedding_lr=unembedding_lr,
    embedding_lr=embedding_lr,
    matrix_lr=matrix_lr,
    weight_decay=weight_decay,
)
# Start every param group at init_lr_frac of its base LR and record that value as
# "initial_lr"; the training loop multiplies "initial_lr" by the LR schedule each step.
for param_group in (g for opt in optimizers for g in opt.param_groups):
    scaled_lr = param_group["lr"] * init_lr_frac
    param_group["lr"] = scaled_lr
    param_group["initial_lr"] = param_group["lr"]
156 |
157 | # -----------------------------------------------------------------------------
158 | # Training loop
159 |
160 | # Learning rate scheduler
def get_lr_multiplier(it):
    """Linear LR decay: 1.0 at iteration 0, falling toward 0 as `it` approaches num_iterations.

    Reads the module-level `num_iterations`.
    """
    return 1.0 - it / num_iterations
164 |
# Go!
step = 0  # `step` previously held the checkpoint step to load; reset so it is defined even if the loop is empty
train_iter = iter(train_loader)
for step in range(num_iterations):
    last_step = step == num_iterations - 1

    # evaluate the validation loss
    if last_step or step % eval_every == 0:
        model.eval()
        val_iter = iter(build_val_loader())
        losses = []
        for _ in range(eval_steps):
            val_inputs, val_targets = next(val_iter)
            with torch.no_grad(), autocast_ctx:
                loss = model(val_inputs, val_targets)
            losses.append(loss)
        val_loss = torch.stack(losses).mean()  # average over eval_steps
        if ddp:
            dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)  # average over ranks
        val_loss = val_loss.item()
        print0(f"Step {step:05d} | Validation loss: {val_loss:.6f}")
        wandb_run.log({
            "step": step,
            "val_loss": val_loss,
        })
        model.train()

    # evaluate accuracy of the multiple choice tasks (which are quick to run)
    if last_step or (step > 0 and step % eval_metrics_every == 0):
        model.eval()
        metrics = {}
        with torch.no_grad(), autocast_ctx:
            # note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
            metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
            metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
        metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
        print0(f"Step {step:05d} | {metrics_str}")
        wandb_run.log({
            "step": step,
            **metrics,
        })
        model.train()

    # the final iteration only evaluates; it takes no optimizer step
    if last_step:
        break

    # evaluate the gradient, accumulated over grad_accum_steps micro-batches
    num_tokens = torch.tensor(0, device=device)  # the number of "active" tokens of supervision seen
    train_loss = torch.tensor(0.0, device=device)  # mean loss over this step's micro-batches (for logging)
    for _ in range(grad_accum_steps):
        train_inputs, train_targets = next(train_iter)
        with autocast_ctx:
            loss = model(train_inputs, train_targets)
        loss = loss / grad_accum_steps  # each .backward() is a grad sum => normalize loss here
        train_loss += loss.detach()  # summing the normalized losses yields the micro-batch mean
        loss.backward()  # accumulate the gradient
        num_tokens += (train_targets >= 0).sum()
    if ddp:
        dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM)  # sum over ranks

    # learning rate scheduler
    lrm = get_lr_multiplier(step)
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * lrm

    # step the optimizers
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)

    # logging (train_loss is this rank's value; num_tokens is summed over ranks)
    train_loss_item = train_loss.item()
    num_tokens_item = num_tokens.item()
    print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}")
    wandb_run.log({
        "step": step,
        "lrm": lrm,
        "train_loss": train_loss_item,
        "num_tokens": num_tokens_item,
    })
246 |
247 | # Save the model at the end of the run (rank 0 only)
248 | if master_process:
249 | base_dir = get_base_dir()
250 | depth = model.config.n_layer
251 | model_tag = f"d{depth}" # base the model tag on the depth of the base model (overwrites the model_tag used for loading)
252 | checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag)
253 | model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
254 | save_checkpoint(
255 | checkpoint_dir,
256 | step, # last loop index; equals the number of optimizer steps, since the last iteration takes none
257 | model.state_dict(),
258 | None, # note: we don't bother to save the optimizer state
259 | {
260 | "step": step,
261 | "val_loss": val_loss, # from the evaluation on the last step
262 | **metrics, # MMLU / ARC-Easy accuracies from the last step
263 | "model_config": model_config_kwargs,
264 | }
265 | )
266 | print(f"✅ Saved model checkpoint to {checkpoint_dir}")
267 |
268 | # Log to report
269 | from nanochat.report import get_report
270 | get_report().log(section="Chat SFT", data=[
271 | user_config, # config after CLI/config-file overrides (simple-typed keys only)
272 | {
273 | "Training rows": len(train_ds),
274 | "Number of iterations": num_iterations,
275 | "Training loss": train_loss_item, # NOTE: only defined once an optimizer step has run (num_iterations >= 2)
276 | "Validation loss": val_loss,
277 | },
278 | ])
279 |
280 | # Cleanup
281 | wandb_run.finish()
282 | compute_cleanup()
283 |
--------------------------------------------------------------------------------
/scripts/chat_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluate the Chat model.
3 | All the generic code lives here, and all the evaluation-specific
4 | code lives in the tasks directory and is imported from here.
5 |
6 | Example runs:
7 | python -m scripts.chat_eval -a ARC-Easy
8 | torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
9 | """
10 |
11 | import argparse
12 | from functools import partial
13 | from contextlib import nullcontext
14 |
15 | import torch
16 | import torch.distributed as dist
17 |
18 | from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type
19 | from nanochat.checkpoint_manager import load_model
20 | from nanochat.engine import Engine
21 |
22 | from tasks.humaneval import HumanEval
23 | from tasks.mmlu import MMLU
24 | from tasks.arc import ARC
25 | from tasks.gsm8k import GSM8K
26 |
27 | # -----------------------------------------------------------------------------
28 | # Generative evaluation loop (we go one problem at a time, sample, evaluate)
29 |
def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=None):
    """Evaluate a generative task by sampling completions one problem at a time.

    Problems are sharded across DDP ranks (rank r takes problems r, r + world_size, ...).
    A problem counts as passed if ANY of its `num_samples` completions passes
    `task_object.evaluate`.

    Args:
        task_object: sized, indexable task exposing `evaluate(conversation, completion)`.
        tokenizer: provides `render_for_completion` and `decode`.
        model: used only for `get_device()` (sampling goes through `engine`).
        engine: provides `generate_batch`.
        num_samples, max_new_tokens, temperature, top_k: sampling settings.
        max_problems: optional cap on the number of problems evaluated.

    Returns:
        Fraction of problems passed, aggregated over all ranks (0.0 if no problems were run).
    """
    ddp, ddp_rank, _, ddp_world_size = get_dist_info()
    device = model.get_device()

    num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)

    # Run the evaluation
    num_passed, total = 0, 0
    for i in range(ddp_rank, num_problems, ddp_world_size):
        conversation = task_object[i]

        # Tokenize the prompt
        encoded_prompt = tokenizer.render_for_completion(conversation)
        # Get the completions
        results, _ = engine.generate_batch(
            encoded_prompt,
            num_samples=num_samples,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
        )
        # Decode only the generated part of each sample as text
        prefix_length = len(encoded_prompt)
        completions = [tokenizer.decode(result_tokens[prefix_length:]) for result_tokens in results]
        # Evaluate every completion (kept as a list so each sample is always evaluated)
        outcomes = [task_object.evaluate(conversation, completion) for completion in completions]
        passed = any(outcomes)

        # Keep stats
        total += 1
        num_passed += int(passed)

        # Logging (overwrite the same line in the console)
        print(f"\r\033[KRank {ddp_rank} | {num_passed}/{total} ({100*num_passed/total:.2f}%)", end='', flush=True)

    # Finish the in-place progress line with a newline before final summary
    print()

    # Aggregate results across all ranks
    if ddp:
        num_passed_tensor = torch.tensor([num_passed], dtype=torch.long, device=device)
        total_tensor = torch.tensor([total], dtype=torch.long, device=device)
        dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM)
        dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
        num_passed = num_passed_tensor.item()
        total = total_tensor.item()

    # Guard against zero problems (e.g. max_problems=0) instead of dividing by zero
    accuracy = num_passed / total if total > 0 else 0.0
    print0("=" * 50)
    print0(f"Final: {num_passed}/{total} ({100*accuracy:.2f}%)")

    # Return the accuracy
    return accuracy
83 |
84 | # -----------------------------------------------------------------------------
85 | # Categorical evaluation loop
86 | # A lot easier because we don't have to sample. Therefore, we can actually go
87 | # batches at a time and just check the logits for correct answer choices.
88 |
def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=None):
    """Evaluate a multiple-choice task by comparing the logits of the allowed answer letters.

    Problems are processed in batches of `batch_size`, and batches are sharded across DDP
    ranks. For each prompt, the logits at its last token are restricted to the token ids
    of `conversation['letters']`; the argmax letter is the prediction passed to
    `task_object.evaluate`.

    Args:
        task_object: sized, indexable task whose items have a 'letters' list.
        tokenizer: provides `get_bos_token_id`, `render_for_completion` and `encode`.
        model: callable on a (B, T) long tensor returning (B, T, V) logits; has `get_device()`.
        batch_size: number of problems per forward pass.
        max_problems: optional cap on the number of problems evaluated.

    Returns:
        Fraction of correct predictions over all ranks (0.0 if no problems were run).
    """
    ddp, ddp_rank, _, ddp_world_size = get_dist_info()
    device = model.get_device()
    # BOS is used as the pad token: only each prompt's last real position is read below
    # (assumes a causal model, so right-padding does not change earlier logits)
    bos = tokenizer.get_bos_token_id()

    # We'll process batches of independent problems at a time because there is no sampling needed
    num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
    num_batches = -(-num_problems // batch_size)  # ceil division

    # Run the evaluation
    letter_to_id_cache = {}  # many letters will repeat often, let's save the tokenizer some work
    num_passed, total = 0, 0
    for i in range(ddp_rank, num_batches, ddp_world_size):
        i0, i1 = i * batch_size, min((i + 1) * batch_size, num_problems)

        # Prepare the batch of problems. They might all be of different length, so we pad/collate them.
        conversations = [task_object[ii] for ii in range(i0, i1)]
        prompt_ids = [tokenizer.render_for_completion(conversation) for conversation in conversations]  # TODO: remake the way this works
        max_length = max(len(ids) for ids in prompt_ids)
        answer_time_positions = [len(ids) - 1 for ids in prompt_ids]  # where the last token is (and the predicted answer)
        padded_prompt_ids = [ids + [bos] * (max_length - len(ids)) for ids in prompt_ids]
        prompt_ids = torch.tensor(padded_prompt_ids, dtype=torch.long, device=device)

        # Get the logits for the whole batch of conversations in parallel (efficiency win here)
        with torch.no_grad():
            logits = model(prompt_ids)  # (B, T, V)

        # Focus on just the letters corresponding to the available choices.
        # This makes the evaluation much easier than generating from the Assistant and checking
        # whether it responded with the correct letter (e.g. A, B, C, D), but evaluations
        # typically make the task easier in this way.
        for idx, conversation in enumerate(conversations):
            # get the token ids of all the available letters of this problem
            letters = conversation['letters']
            letter_ids = []
            for letter in letters:
                if letter not in letter_to_id_cache:
                    encoded_letter = tokenizer.encode(letter)
                    assert len(encoded_letter) == 1, "Each letter must be a single token"
                    letter_to_id_cache[letter] = encoded_letter[0]
                letter_ids.append(letter_to_id_cache[letter])
            # focus logits just down to the answer position and the available letters of the answer
            answer_pos = answer_time_positions[idx]
            focus_logits = logits[idx, answer_pos, letter_ids]
            # get the argmax letter (the predicted answer)
            argmax_letter_id = focus_logits.argmax(dim=-1).item()
            predicted_letter = letters[argmax_letter_id]
            # evaluate the outcome
            outcome = task_object.evaluate(conversation, predicted_letter)
            num_passed += int(outcome)
            total += 1

    # Aggregate results across all ranks
    if ddp:
        num_passed_tensor = torch.tensor([num_passed], dtype=torch.long, device=device)
        total_tensor = torch.tensor([total], dtype=torch.long, device=device)
        dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM)
        dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
        num_passed = num_passed_tensor.item()
        total = total_tensor.item()

    # Guard against zero problems (e.g. max_problems=0) instead of dividing by zero
    average = num_passed / total if total > 0 else 0.0
    print0(f"Final: {num_passed}/{total} ({100*average:.2f}%)")
    return average
155 |
156 | # -----------------------------------------------------------------------------
157 |
def run_chat_eval(task_name, model, tokenizer, engine,
                  batch_size=1, num_samples=1, max_new_tokens=512, temperature=0.0, top_k=50,
                  max_problems=None):
    """Instantiate the named evaluation task and run the eval loop matching its eval_type.

    Args:
        task_name: one of 'HumanEval', 'MMLU', 'ARC-Easy', 'ARC-Challenge', 'GSM8K'.
        model, tokenizer, engine: the model under evaluation and its helpers.
        batch_size: batch size for categorical (multiple-choice) tasks.
        num_samples, max_new_tokens, temperature, top_k: sampling settings for generative tasks.
        max_problems: optional cap on the number of problems evaluated.

    Returns:
        The accuracy returned by the selected eval loop.

    Raises:
        KeyError: if task_name is not one of the known tasks.
        ValueError: if the task's eval_type is neither 'generative' nor 'categorical'.
    """
    task_factories = {
        'HumanEval': HumanEval,
        'MMLU': partial(MMLU, subset="all", split="test"),
        'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
        'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
        'GSM8K': partial(GSM8K, subset="main", split="test"),
    }
    make_task = task_factories[task_name]
    task_object = make_task()

    eval_type = task_object.eval_type
    if eval_type == 'generative':
        return run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=max_problems)
    if eval_type == 'categorical':
        return run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=max_problems)
    raise ValueError(f"Unsupported task evaluation type: {eval_type}")
178 |
# -----------------------------------------------------------------------------
def compute_chatcore(results, baseline_accuracies, task_names):
    """Return the ChatCORE metric: the mean baseline-centered accuracy over `task_names`.

    Each task's accuracy is rescaled so that its random-guess baseline maps to 0 and
    perfect accuracy maps to 1; the rescaled values are then averaged. Similar to CORE,
    ChatCORE therefore ranges from 0 (at the random baseline) to 1 (peak performance).

    Args:
        results: dict mapping task name -> accuracy in [0, 1].
        baseline_accuracies: dict mapping task name -> random-baseline accuracy; tasks
            missing from it are treated as having a baseline of 0.0.
        task_names: the tasks the metric is defined over. Entries of `results` that are
            not listed here (e.g. extra ad-hoc tasks) do not affect the metric.

    Returns:
        The ChatCORE value as a float, or None when `task_names` is empty or any of
        them is missing from `results`.
    """
    if not task_names or any(name not in results for name in task_names):
        return None
    centered = []
    for name in task_names:
        baseline_acc = baseline_accuracies.get(name, 0.0)
        centered.append((results[name] - baseline_acc) / (1.0 - baseline_acc))
    return sum(centered) / len(centered)


if __name__ == "__main__":

    # Parse command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|mid|rl")
    parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.")
    parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
    parser.add_argument('-t', '--temperature', type=float, default=0.0)
    parser.add_argument('-m', '--max-new-tokens', type=int, default=512)
    parser.add_argument('-n', '--num-samples', type=int, default=1)
    parser.add_argument('-k', '--top-k', type=int, default=50)
    parser.add_argument('-b', '--batch-size', type=int, default=8, help='Batch size for categorical evaluation')
    parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
    parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
    parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
    # '' is an explicit choice so that passing an empty value (e.g. from an unset shell
    # variable) means "autodetect" instead of being rejected by argparse.
    parser.add_argument('--device-type', type=str, default='', choices=['', 'cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
    args = parser.parse_args()

    device_type = autodetect_device_type() if args.device_type == "" else args.device_type
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
    ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
    # Autocast is only applied on CUDA; other devices use a no-op context.
    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()

    model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
    engine = Engine(model, tokenizer)

    # Get the tasks to evaluate on. `all_tasks` is also the fixed task set ChatCORE is defined over.
    all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval']
    baseline_accuracies = {
        'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
        'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
        'MMLU': 0.25, # multiple choice 1 of 4 => 25%
        'GSM8K': 0.0, # open-ended => 0%
        'HumanEval': 0.0, # open-ended => 0%
    }
    task_names = all_tasks if args.task_name is None else args.task_name.split('|')

    # Run all the task evaluations sequentially
    results = {}
    for task_name in task_names:
        with autocast_ctx:
            acc = run_chat_eval(
                task_name,
                model, tokenizer, engine,
                batch_size=args.batch_size,
                num_samples=args.num_samples,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                max_problems=args.max_problems,
            )
            results[task_name] = acc
            print0(f"{task_name} accuracy: {100 * acc:.2f}%")

    # Log to report
    from nanochat.report import get_report
    # ChatCORE is only reported when every task in `all_tasks` was evaluated; any extra tasks
    # requested via --task-name appear in `results` but do not enter the average.
    chatcore_metric = compute_chatcore(results, baseline_accuracies, all_tasks)
    chatcore_metric_dict = {} if chatcore_metric is None else {"ChatCORE metric": chatcore_metric}
    get_report().log(section="Chat evaluation " + args.source, data=[
        vars(args), # CLI args
        results,
        chatcore_metric_dict,
    ])

    compute_cleanup()
255 |
--------------------------------------------------------------------------------
/rustbpe/Cargo.lock:
--------------------------------------------------------------------------------
1 | # This file is automatically @generated by Cargo.
2 | # It is not intended for manual editing.
3 | version = 4
4 |
5 | [[package]]
6 | name = "ahash"
7 | version = "0.8.12"
8 | source = "registry+https://github.com/rust-lang/crates.io-index"
9 | checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75"
10 | dependencies = [
11 | "cfg-if",
12 | "getrandom",
13 | "once_cell",
14 | "version_check",
15 | "zerocopy",
16 | ]
17 |
18 | [[package]]
19 | name = "aho-corasick"
20 | version = "1.1.3"
21 | source = "registry+https://github.com/rust-lang/crates.io-index"
22 | checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
23 | dependencies = [
24 | "memchr",
25 | ]
26 |
27 | [[package]]
28 | name = "arc-swap"
29 | version = "1.7.1"
30 | source = "registry+https://github.com/rust-lang/crates.io-index"
31 | checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"
32 |
33 | [[package]]
34 | name = "autocfg"
35 | version = "1.5.0"
36 | source = "registry+https://github.com/rust-lang/crates.io-index"
37 | checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
38 |
39 | [[package]]
40 | name = "bit-set"
41 | version = "0.8.0"
42 | source = "registry+https://github.com/rust-lang/crates.io-index"
43 | checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3"
44 | dependencies = [
45 | "bit-vec",
46 | ]
47 |
48 | [[package]]
49 | name = "bit-vec"
50 | version = "0.8.0"
51 | source = "registry+https://github.com/rust-lang/crates.io-index"
52 | checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7"
53 |
54 | [[package]]
55 | name = "castaway"
56 | version = "0.2.4"
57 | source = "registry+https://github.com/rust-lang/crates.io-index"
58 | checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a"
59 | dependencies = [
60 | "rustversion",
61 | ]
62 |
63 | [[package]]
64 | name = "cfg-if"
65 | version = "1.0.3"
66 | source = "registry+https://github.com/rust-lang/crates.io-index"
67 | checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9"
68 |
69 | [[package]]
70 | name = "compact_str"
71 | version = "0.9.0"
72 | source = "registry+https://github.com/rust-lang/crates.io-index"
73 | checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a"
74 | dependencies = [
75 | "castaway",
76 | "cfg-if",
77 | "itoa",
78 | "rustversion",
79 | "ryu",
80 | "static_assertions",
81 | ]
82 |
83 | [[package]]
84 | name = "crossbeam-deque"
85 | version = "0.8.6"
86 | source = "registry+https://github.com/rust-lang/crates.io-index"
87 | checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
88 | dependencies = [
89 | "crossbeam-epoch",
90 | "crossbeam-utils",
91 | ]
92 |
93 | [[package]]
94 | name = "crossbeam-epoch"
95 | version = "0.9.18"
96 | source = "registry+https://github.com/rust-lang/crates.io-index"
97 | checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
98 | dependencies = [
99 | "crossbeam-utils",
100 | ]
101 |
102 | [[package]]
103 | name = "crossbeam-utils"
104 | version = "0.8.21"
105 | source = "registry+https://github.com/rust-lang/crates.io-index"
106 | checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
107 |
108 | [[package]]
109 | name = "dary_heap"
110 | version = "0.3.7"
111 | source = "registry+https://github.com/rust-lang/crates.io-index"
112 | checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728"
113 |
114 | [[package]]
115 | name = "either"
116 | version = "1.15.0"
117 | source = "registry+https://github.com/rust-lang/crates.io-index"
118 | checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
119 |
120 | [[package]]
121 | name = "equivalent"
122 | version = "1.0.2"
123 | source = "registry+https://github.com/rust-lang/crates.io-index"
124 | checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
125 |
126 | [[package]]
127 | name = "fancy-regex"
128 | version = "0.16.1"
129 | source = "registry+https://github.com/rust-lang/crates.io-index"
130 | checksum = "bf04c5ec15464ace8355a7b440a33aece288993475556d461154d7a62ad9947c"
131 | dependencies = [
132 | "bit-set",
133 | "regex-automata",
134 | "regex-syntax",
135 | ]
136 |
137 | [[package]]
138 | name = "getrandom"
139 | version = "0.3.3"
140 | source = "registry+https://github.com/rust-lang/crates.io-index"
141 | checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
142 | dependencies = [
143 | "cfg-if",
144 | "libc",
145 | "r-efi",
146 | "wasi",
147 | ]
148 |
149 | [[package]]
150 | name = "hashbrown"
151 | version = "0.15.5"
152 | source = "registry+https://github.com/rust-lang/crates.io-index"
153 | checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
154 |
155 | [[package]]
156 | name = "heck"
157 | version = "0.5.0"
158 | source = "registry+https://github.com/rust-lang/crates.io-index"
159 | checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
160 |
161 | [[package]]
162 | name = "indexmap"
163 | version = "2.11.0"
164 | source = "registry+https://github.com/rust-lang/crates.io-index"
165 | checksum = "f2481980430f9f78649238835720ddccc57e52df14ffce1c6f37391d61b563e9"
166 | dependencies = [
167 | "equivalent",
168 | "hashbrown",
169 | ]
170 |
171 | [[package]]
172 | name = "indoc"
173 | version = "2.0.6"
174 | source = "registry+https://github.com/rust-lang/crates.io-index"
175 | checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd"
176 |
177 | [[package]]
178 | name = "itoa"
179 | version = "1.0.15"
180 | source = "registry+https://github.com/rust-lang/crates.io-index"
181 | checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
182 |
183 | [[package]]
184 | name = "libc"
185 | version = "0.2.175"
186 | source = "registry+https://github.com/rust-lang/crates.io-index"
187 | checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"
188 |
189 | [[package]]
190 | name = "log"
191 | version = "0.4.28"
192 | source = "registry+https://github.com/rust-lang/crates.io-index"
193 | checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432"
194 |
195 | [[package]]
196 | name = "memchr"
197 | version = "2.7.5"
198 | source = "registry+https://github.com/rust-lang/crates.io-index"
199 | checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0"
200 |
201 | [[package]]
202 | name = "memoffset"
203 | version = "0.9.1"
204 | source = "registry+https://github.com/rust-lang/crates.io-index"
205 | checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
206 | dependencies = [
207 | "autocfg",
208 | ]
209 |
210 | [[package]]
211 | name = "once_cell"
212 | version = "1.21.3"
213 | source = "registry+https://github.com/rust-lang/crates.io-index"
214 | checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
215 |
216 | [[package]]
217 | name = "portable-atomic"
218 | version = "1.11.1"
219 | source = "registry+https://github.com/rust-lang/crates.io-index"
220 | checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483"
221 |
222 | [[package]]
223 | name = "proc-macro2"
224 | version = "1.0.101"
225 | source = "registry+https://github.com/rust-lang/crates.io-index"
226 | checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de"
227 | dependencies = [
228 | "unicode-ident",
229 | ]
230 |
231 | [[package]]
232 | name = "pyo3"
233 | version = "0.23.5"
234 | source = "registry+https://github.com/rust-lang/crates.io-index"
235 | checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872"
236 | dependencies = [
237 | "cfg-if",
238 | "indoc",
239 | "libc",
240 | "memoffset",
241 | "once_cell",
242 | "portable-atomic",
243 | "pyo3-build-config",
244 | "pyo3-ffi",
245 | "pyo3-macros",
246 | "unindent",
247 | ]
248 |
249 | [[package]]
250 | name = "pyo3-build-config"
251 | version = "0.23.5"
252 | source = "registry+https://github.com/rust-lang/crates.io-index"
253 | checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb"
254 | dependencies = [
255 | "once_cell",
256 | "target-lexicon",
257 | ]
258 |
259 | [[package]]
260 | name = "pyo3-ffi"
261 | version = "0.23.5"
262 | source = "registry+https://github.com/rust-lang/crates.io-index"
263 | checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d"
264 | dependencies = [
265 | "libc",
266 | "pyo3-build-config",
267 | ]
268 |
269 | [[package]]
270 | name = "pyo3-log"
271 | version = "0.12.4"
272 | source = "registry+https://github.com/rust-lang/crates.io-index"
273 | checksum = "45192e5e4a4d2505587e27806c7b710c231c40c56f3bfc19535d0bb25df52264"
274 | dependencies = [
275 | "arc-swap",
276 | "log",
277 | "pyo3",
278 | ]
279 |
280 | [[package]]
281 | name = "pyo3-macros"
282 | version = "0.23.5"
283 | source = "registry+https://github.com/rust-lang/crates.io-index"
284 | checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da"
285 | dependencies = [
286 | "proc-macro2",
287 | "pyo3-macros-backend",
288 | "quote",
289 | "syn",
290 | ]
291 |
292 | [[package]]
293 | name = "pyo3-macros-backend"
294 | version = "0.23.5"
295 | source = "registry+https://github.com/rust-lang/crates.io-index"
296 | checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028"
297 | dependencies = [
298 | "heck",
299 | "proc-macro2",
300 | "pyo3-build-config",
301 | "quote",
302 | "syn",
303 | ]
304 |
305 | [[package]]
306 | name = "quote"
307 | version = "1.0.40"
308 | source = "registry+https://github.com/rust-lang/crates.io-index"
309 | checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d"
310 | dependencies = [
311 | "proc-macro2",
312 | ]
313 |
314 | [[package]]
315 | name = "r-efi"
316 | version = "5.3.0"
317 | source = "registry+https://github.com/rust-lang/crates.io-index"
318 | checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
319 |
320 | [[package]]
321 | name = "rayon"
322 | version = "1.11.0"
323 | source = "registry+https://github.com/rust-lang/crates.io-index"
324 | checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f"
325 | dependencies = [
326 | "either",
327 | "rayon-core",
328 | ]
329 |
330 | [[package]]
331 | name = "rayon-core"
332 | version = "1.13.0"
333 | source = "registry+https://github.com/rust-lang/crates.io-index"
334 | checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91"
335 | dependencies = [
336 | "crossbeam-deque",
337 | "crossbeam-utils",
338 | ]
339 |
340 | [[package]]
341 | name = "regex-automata"
342 | version = "0.4.10"
343 | source = "registry+https://github.com/rust-lang/crates.io-index"
344 | checksum = "6b9458fa0bfeeac22b5ca447c63aaf45f28439a709ccd244698632f9aa6394d6"
345 | dependencies = [
346 | "aho-corasick",
347 | "memchr",
348 | "regex-syntax",
349 | ]
350 |
351 | [[package]]
352 | name = "regex-syntax"
353 | version = "0.8.6"
354 | source = "registry+https://github.com/rust-lang/crates.io-index"
355 | checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001"
356 |
357 | [[package]]
358 | name = "rustbpe"
359 | version = "0.1.0"
360 | dependencies = [
361 | "ahash",
362 | "compact_str",
363 | "dary_heap",
364 | "fancy-regex",
365 | "indexmap",
366 | "log",
367 | "pyo3",
368 | "pyo3-log",
369 | "rayon",
370 | ]
371 |
372 | [[package]]
373 | name = "rustversion"
374 | version = "1.0.22"
375 | source = "registry+https://github.com/rust-lang/crates.io-index"
376 | checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
377 |
378 | [[package]]
379 | name = "ryu"
380 | version = "1.0.20"
381 | source = "registry+https://github.com/rust-lang/crates.io-index"
382 | checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
383 |
384 | [[package]]
385 | name = "static_assertions"
386 | version = "1.1.0"
387 | source = "registry+https://github.com/rust-lang/crates.io-index"
388 | checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
389 |
390 | [[package]]
391 | name = "syn"
392 | version = "2.0.106"
393 | source = "registry+https://github.com/rust-lang/crates.io-index"
394 | checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6"
395 | dependencies = [
396 | "proc-macro2",
397 | "quote",
398 | "unicode-ident",
399 | ]
400 |
401 | [[package]]
402 | name = "target-lexicon"
403 | version = "0.12.16"
404 | source = "registry+https://github.com/rust-lang/crates.io-index"
405 | checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
406 |
407 | [[package]]
408 | name = "unicode-ident"
409 | version = "1.0.18"
410 | source = "registry+https://github.com/rust-lang/crates.io-index"
411 | checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
412 |
413 | [[package]]
414 | name = "unindent"
415 | version = "0.2.4"
416 | source = "registry+https://github.com/rust-lang/crates.io-index"
417 | checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3"
418 |
419 | [[package]]
420 | name = "version_check"
421 | version = "0.9.5"
422 | source = "registry+https://github.com/rust-lang/crates.io-index"
423 | checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
424 |
425 | [[package]]
426 | name = "wasi"
427 | version = "0.14.4+wasi-0.2.4"
428 | source = "registry+https://github.com/rust-lang/crates.io-index"
429 | checksum = "88a5f4a424faf49c3c2c344f166f0662341d470ea185e939657aaff130f0ec4a"
430 | dependencies = [
431 | "wit-bindgen",
432 | ]
433 |
434 | [[package]]
435 | name = "wit-bindgen"
436 | version = "0.45.1"
437 | source = "registry+https://github.com/rust-lang/crates.io-index"
438 | checksum = "5c573471f125075647d03df72e026074b7203790d41351cd6edc96f46bcccd36"
439 |
440 | [[package]]
441 | name = "zerocopy"
442 | version = "0.8.26"
443 | source = "registry+https://github.com/rust-lang/crates.io-index"
444 | checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f"
445 | dependencies = [
446 | "zerocopy-derive",
447 | ]
448 |
449 | [[package]]
450 | name = "zerocopy-derive"
451 | version = "0.8.26"
452 | source = "registry+https://github.com/rust-lang/crates.io-index"
453 | checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181"
454 | dependencies = [
455 | "proc-macro2",
456 | "quote",
457 | "syn",
458 | ]
459 |
--------------------------------------------------------------------------------
/scripts/mid_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Midtrain the model. Same as pretraining but simpler.
3 | Run as:
4 |
5 | python -m scripts.mid_train
6 |
7 | Or torchrun for training:
8 |
9 | torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
10 | """
11 |
12 | from collections import deque
13 | import os
14 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
15 | import time
16 | import wandb
17 | import torch
18 | from contextlib import nullcontext
19 | from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
20 | from nanochat.tokenizer import get_token_bytes
21 | from nanochat.checkpoint_manager import save_checkpoint
22 | from nanochat.loss_eval import evaluate_bpb
23 | from nanochat.checkpoint_manager import load_model
24 | import torch.distributed as dist
25 |
26 | from tasks.common import TaskMixture
27 | from tasks.gsm8k import GSM8K
28 | from tasks.mmlu import MMLU
29 | from tasks.smoltalk import SmolTalk
30 | from tasks.customjson import CustomJSON
31 |
# -----------------------------------------------------------------------------
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
device_type = "" # cuda|cpu|mps (empty => autodetect)
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
dtype = "bfloat16"
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
max_seq_len = 2048
device_batch_size = 32
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
weight_decay = 0.0
eval_every = 150 # -1 = disable
eval_tokens = 20*524288
total_batch_size = 524288
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
# Only settings holding an int/float/bool/str are recorded as config keys, so None-valued
# settings (model_tag, step) do not appear in user_config.
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
# Overrides from command line or config file: the configurator source is exec'd in this
# module's namespace. The file is opened in a `with` block so the handle is closed promptly.
with open(os.path.join('nanochat', 'configurator.py')) as _configurator_file:
    exec(_configurator_file.read())
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
# -----------------------------------------------------------------------------
54 |
# Compute init
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
# Respect the `dtype` setting (float32|bfloat16, same mapping as scripts/chat_eval.py).
# Autocast is only used on CUDA and only for bfloat16; float32 runs without autocast.
ptdtype = torch.float32 if dtype == "float32" else torch.bfloat16
use_autocast = device_type == "cuda" and ptdtype != torch.float32
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if use_autocast else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0

# wandb logging init (only the master process logs, and never for the "dummy" run name)
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=run, config=user_config)
66 |
# Load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
    print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
orig_model = model # uncompiled module; its state_dict() is what gets checkpointed
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
# Gradient accumulation needs a whole number of micro-batches per optimizer step.
assert total_batch_size % world_tokens_per_fwdbwd == 0, (
    f"total_batch_size ({total_batch_size:,}) must be divisible by "
    f"device_batch_size * max_seq_len * world_size "
    f"({device_batch_size} * {max_seq_len} * {ddp_world_size} = {world_tokens_per_fwdbwd:,})"
)
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
token_bytes = get_token_bytes(device=device)
84 |
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(
    unembedding_lr=unembedding_lr,
    embedding_lr=embedding_lr,
    matrix_lr=matrix_lr,
    weight_decay=weight_decay,
)
adamw_optimizer, muon_optimizer = optimizers
# Scale each param group's base learning rate by init_lr_frac and remember the scaled value
# as "initial_lr"; the schedule later sets lr = initial_lr * multiplier.
for param_group in (g for opt in optimizers for g in opt.param_groups):
    param_group["lr"] = param_group["lr"] * init_lr_frac
    param_group["initial_lr"] = param_group["lr"]
93 |
# Midtraining data mixture and DataLoader
base_dir = get_base_dir()
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
train_dataset = TaskMixture([
    SmolTalk(split="train"), # 460K rows of general conversations
    MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
    GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
    CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
    CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
]) # total: 460K + 100K + 8K + 2 x 1K = 570K rows
val_dataset = TaskMixture([
    SmolTalk(split="test"), # 24K rows in test set
    MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
    GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
]) # total: 24K + 5.2K + 0.42K ~= 29.6K rows used (out of 24K + 14K + 1.32K ~= 39K available)
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
# A big problem is that we don't know the final num_iterations in advance. So we create
# these two global variables and update them from within the data generator.
last_step = False # we will toggle this to True when we reach the end of the dataset
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
def mid_data_generator(split):
    """Yield an endless stream of (inputs, targets) batches rendered from conversations.

    Conversations from the chosen split are tokenized with tokenizer.render_conversation
    and concatenated into one token stream. Each batch takes the next
    device_batch_size * max_seq_len + 1 tokens; targets are the inputs shifted by one
    position. Each DDP rank reads a strided subset of conversations (starting at its
    rank, stepping by the world size) and wraps around at the end of the dataset.

    Side effects on module globals (train split only):
      - last_step: set to True once the dataset wraps around, or once num_iterations
        batches have been produced (when num_iterations > 0).
      - approx_progress: updated to an estimate of training progress.

    Args:
        split: "train" or "val".

    Yields:
        inputs: int32 tensor of shape (device_batch_size, max_seq_len) on `device`.
        targets: int64 tensor of shape (device_batch_size, max_seq_len) on `device`.
    """
    global last_step, approx_progress
    assert split in {"train", "val"}, "split must be 'train' or 'val'"
    is_train = split == "train"
    dataset = train_dataset if is_train else val_dataset
    dataset_size = len(dataset)
    assert dataset_size > 0
    needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
    # Pinned host memory requires CUDA; requesting it on a cpu/mps-only machine raises.
    use_pinned_memory = device_type == "cuda"
    token_buffer = deque()
    cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
    it = 0 # number of batches produced so far
    while True:
        # Accumulate enough tokens for one iteration before yielding
        while len(token_buffer) < needed_tokens:
            conversation = dataset[cursor]
            ids, _ = tokenizer.render_conversation(conversation)
            token_buffer.extend(ids)
            cursor += ddp_world_size
            if cursor >= dataset_size:
                cursor -= dataset_size # wrap around for another epoch
                if is_train:
                    last_step = True # toggle last_step to True, which will terminate the training loop
        # Stopping condition to respect num_iterations, if given. Only the train stream may end
        # training: the val stream is rebuilt for every evaluation and must not touch last_step.
        # NOTE(review): `it` counts micro-batches (the training loop calls next() once per
        # micro-step), so with grad_accum_steps > 1 this fires after num_iterations micro-batches.
        it += 1
        if is_train and num_iterations > 0 and it >= num_iterations:
            last_step = True # toggle last_step to True, which will terminate the training loop
        # Build up inputs/targets in a freshly allocated host tensor. Reusing one pinned buffer
        # would let the next batch overwrite it while the previous non_blocking copy may still be
        # in flight; one torch.tensor call is also far cheaper than per-element tensor writes.
        batch_tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
        scratch = torch.tensor(batch_tokens, dtype=torch.int64, pin_memory=use_pinned_memory)
        inputs_cpu = scratch[:-1].to(dtype=torch.int32)
        targets_cpu = scratch[1:]
        inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
        targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
        if is_train:
            if num_iterations > 0:
                approx_progress = it / num_iterations # calculate progress from the max number of iterations
            else:
                approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
        yield inputs, targets
153 |
train_loader = mid_data_generator("train")


def build_val_loader():
    """Return a brand-new validation batch stream; a fresh one is built for every evaluation."""
    return mid_data_generator("val")


progress = 0 # will go from 0 to 1 over the course of the epoch
157 |
# Learning rate scheduler
def get_lr_multiplier(progress):
    """Return the learning-rate multiplier for a training-progress fraction.

    The multiplier is 1 for the first 80% of training, then ramps linearly down to 0
    at progress == 1. It is clamped at 0: progress is derived from it / num_iterations,
    which can overshoot 1 because the training loop keeps fetching micro-batches for the
    rest of an accumulation step, and a negative multiplier would flip the update sign.
    """
    if progress < 0.8:
        return 1
    return max(0.0, 1 - (progress - 0.8) / 0.2)
162 |
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
    """Linearly warm Muon's momentum from 0.85 to 0.95 over the first 300 iterations, then hold it."""
    warmup_frac = min(it / 300, 1)
    start_momentum, final_momentum = 0.85, 0.95
    return (1 - warmup_frac) * start_momentum + warmup_frac * final_momentum
168 |
# -----------------------------------------------------------------------------
# Training loop
x, y = next(train_loader) # prefetch the very first batch of data
min_val_bpb = float("inf")
val_bpb = None # latest validation bpb; stays None when evaluation is disabled (eval_every <= 0)
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
step = 0
while True:
    flops_so_far = num_flops_per_token * total_batch_size * step

    # Synchronize last_step across all ranks to avoid hangs in the distributed setting
    if ddp:
        last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
        dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
        last_step = bool(last_step_tensor.item())

    # once in a while: evaluate the val bpb (all ranks participate)
    if eval_every > 0 and (last_step or step % eval_every == 0):
        model.eval()
        val_loader = build_val_loader()
        eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
        with autocast_ctx:
            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "val/bpb": val_bpb,
        })
        model.train()

    # save checkpoint at the end of the run (only on master process)
    if master_process and last_step and not dry_run:
        output_dirname = f"d{depth}" # e.g. d12
        checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
        save_checkpoint(
            checkpoint_dir,
            step,
            orig_model.state_dict(),
            [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
            {
                "step": step,
                "val_bpb": val_bpb, # bpb of the most recent evaluation (None if evaluation is disabled)
                "model_config": {
                    "sequence_len": max_seq_len,
                    "vocab_size": tokenizer.get_vocab_size(),
                    "n_layer": depth,
                    "n_head": model.config.n_head,
                    "n_kv_head": model.config.n_kv_head,
                    "n_embd": model.config.n_embd,
                },
                "user_config": user_config, # inputs to the training script
            }
        )

    if last_step:
        break

    # -------------------------------------------------------------------------
    # single training step
    # evaluate the gradient
    synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach() # for logging
        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
        loss.backward()
        x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
        progress = max(progress, approx_progress) # only increase progress monotonically
    # step the optimizers
    lrm = get_lr_multiplier(progress)
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * lrm
    muon_momentum = get_muon_momentum(step)
    for group in muon_optimizer.param_groups:
        group["momentum"] = muon_momentum
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)
    synchronize()
    t1 = time.time()
    dt = t1 - t0
    # -------------------------------------------------------------------------

    # State
    step += 1

    # logging
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
    # `step` was already incremented above, so it equals the number of EMA updates so far;
    # after k updates the EMA weights sum to 1 - beta**k.
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**step) # debias the EMA
    pct_done = 100 * progress
    # dt spans all grad_accum_steps micro-batches, i.e. total_batch_size tokens across all ranks
    tok_per_sec = int(total_batch_size / dt)
    flops_per_sec = num_flops_per_token * total_batch_size / dt
    promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
    mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
    if step > 10:
        total_training_time += dt # only count the time after the first 10 steps
    print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
    if step % 10 == 0:
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "train/loss": debiased_smooth_loss,
            "train/lrm": lrm,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            "train/mfu": mfu,
        })
285 |
# Final summary statistics
peak_memory_mib = get_max_memory() / 1024 / 1024
print0(f"Peak memory usage: {peak_memory_mib:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

# Log to the report (skipped on dry runs)
if not dry_run:
    from nanochat.report import get_report
    setup_stats = {"Number of iterations": step, "DDP world size": ddp_world_size} # stats about the training setup
    outcome_stats = {"Minimum validation bpb": min_val_bpb} # stats about training outcomes
    get_report().log(section="Midtraining", data=[user_config, setup_stats, outcome_stats])

# Tear down wandb, then the distributed/compute state
wandb_run.finish()
compute_cleanup()
308 |
--------------------------------------------------------------------------------