├── nanochat ├── __init__.py ├── logo.svg ├── configurator.py ├── dataloader.py ├── loss_eval.py ├── adamw.py ├── dataset.py ├── common.py ├── checkpoint_manager.py ├── muon.py ├── execution.py └── core_eval.py ├── .python-version ├── dev ├── nanochat.png ├── generate_logo.html ├── runcpu.sh ├── repackage_data_reference.py └── gen_synthetic_data.py ├── .gitignore ├── rustbpe ├── Cargo.toml ├── README.md └── Cargo.lock ├── LICENSE ├── pyproject.toml ├── tasks ├── smoltalk.py ├── arc.py ├── customjson.py ├── humaneval.py ├── mmlu.py ├── gsm8k.py └── common.py ├── scripts ├── base_loss.py ├── chat_cli.py ├── tok_train.py ├── base_eval.py ├── tok_eval.py ├── chat_sft.py ├── chat_eval.py └── mid_train.py ├── run1000.sh ├── speedrun.sh └── README.md /nanochat/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.10 2 | -------------------------------------------------------------------------------- /dev/nanochat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxroute/nanochat/master/dev/nanochat.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .venv/ 2 | __pycache__/ 3 | *.pyc 4 | rustbpe/target/ 5 | dev-ignore/ 6 | report.md 7 | -------------------------------------------------------------------------------- /rustbpe/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rustbpe" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | dary_heap = "0.3" 8 | indexmap = "2.2" 9 | fancy-regex = "0.16.1" 10 | log = "0.4.28" 11 | pyo3 = { version = "0.23.3", features = ["extension-module"] } 12 | pyo3-log = "0.12.4" 13 | ahash = "0.8.12" 14 | rayon = "1.11.0" 15 | compact_str = "0.9.0" 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Andrej Karpathy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /rustbpe/README.md: -------------------------------------------------------------------------------- 1 | # rustbpe 2 | 3 | > The missing tiktoken training code 4 | 5 | A very lightweight Rust library for training a GPT tokenizer. The issue is that the inference library [tiktoken](https://github.com/openai/tiktoken) is great, but only does inference. Separately, the huggingface [tokenizers](https://github.com/huggingface/tokenizers) library does training, but it is rather bloated and really hard to navigate because it has to support all the different historical baggage of how people dealt with tokenizers over the years. More recently, I also wrote the [minbpe](https://github.com/karpathy/minbpe) library which does both training and inference, but only in inefficient Python. Basically what I really want is a non-fancy, super simple, but still relatively efficient training code for GPT tokenizer (more efficient than minbpe, much cleaner/simpler than tokenizers), and then export the trained vocab for inference with tiktoken. Does that make sense? So here we are. There are more opportunities for optimization here, I just stopped a bit early because unlike minbpe before it, rustbpe is now simple and fast enough, and not a significant bottleneck for nanochat. 6 | -------------------------------------------------------------------------------- /dev/generate_logo.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 28 | 29 | -------------------------------------------------------------------------------- /nanochat/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "nanochat" 3 | version = "0.1.0" 4 | description = "the minimal full-stack ChatGPT clone" 5 | readme = "README.md" 6 | requires-python = ">=3.10" 7 | dependencies = [ 8 | "datasets>=4.0.0", 9 | "fastapi>=0.117.1", 10 | "files-to-prompt>=0.6", 11 | "numpy==1.26.4", 12 | "psutil>=7.1.0", 13 | "regex>=2025.9.1", 14 | "setuptools>=80.9.0", 15 | "tiktoken>=0.11.0", 16 | "tokenizers>=0.22.0", 17 | "torch>=2.8.0", 18 | "uvicorn>=0.36.0", 19 | "wandb>=0.21.3", 20 | ] 21 | 22 | [build-system] 23 | requires = ["maturin>=1.7,<2.0"] 24 | build-backend = "maturin" 25 | 26 | [tool.maturin] 27 | module-name = "rustbpe" 28 | bindings = "pyo3" 29 | python-source = "." 
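# [tool.maturin] builds the rustbpe crate as a pyo3 Python extension module;
# the run scripts compile it with: uv run maturin develop --release --manifest-path rustbpe/Cargo.toml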
30 | manifest-path = "rustbpe/Cargo.toml" 31 | 32 | [dependency-groups] 33 | dev = [ 34 | "maturin>=1.9.4", 35 | "pytest>=8.0.0", 36 | ] 37 | 38 | [tool.pytest.ini_options] 39 | markers = [ 40 | "slow: marks tests as slow (deselect with '-m \"not slow\"')", 41 | ] 42 | testpaths = ["tests"] 43 | python_files = ["test_*.py"] 44 | python_classes = ["Test*"] 45 | python_functions = ["test_*"] 46 | 47 | # target torch to cuda 12.8 48 | [tool.uv.sources] 49 | torch = [ 50 | { index = "pytorch-cpu", marker = "sys_platform != 'linux'" }, 51 | { index = "pytorch-cu128", marker = "sys_platform == 'linux'" }, 52 | ] 53 | 54 | [[tool.uv.index]] 55 | name = "pytorch-cpu" 56 | url = "https://download.pytorch.org/whl/cpu" 57 | explicit = true 58 | 59 | [[tool.uv.index]] 60 | name = "pytorch-cu128" 61 | url = "https://download.pytorch.org/whl/cu128" 62 | explicit = true -------------------------------------------------------------------------------- /tasks/smoltalk.py: -------------------------------------------------------------------------------- 1 | """ 2 | SmolTalk by HuggingFace. Good "general" conversational dataset. 3 | https://huggingface.co/datasets/HuggingFaceTB/smol-smoltalk 4 | We use the "smol" version, which is more appropriate for smaller models. 5 | """ 6 | 7 | from datasets import load_dataset 8 | from tasks.common import Task 9 | 10 | class SmolTalk(Task): 11 | """ smol-smoltalk dataset. train is 460K rows, test is 24K rows. """ 12 | 13 | def __init__(self, split, **kwargs): 14 | super().__init__(**kwargs) 15 | assert split in ["train", "test"], "SmolTalk split must be train|test" 16 | self.ds = load_dataset("HuggingFaceTB/smol-smoltalk", split=split).shuffle(seed=42) 17 | self.length = len(self.ds) 18 | 19 | def num_examples(self): 20 | return self.length 21 | 22 | def get_example(self, index): 23 | row = self.ds[index] 24 | messages = row["messages"] 25 | # --------------------------------------------------------------------- 26 | # sanity checking asserts here 27 | # TODO: we could remove these asserts later, for now just don't want any footguns 28 | # there is an optional system message at the beginning 29 | assert len(messages) >= 1 30 | first_message = messages[0] 31 | if first_message["role"] == "system": 32 | rest_messages = messages[1:] # optional system message is OK 33 | else: 34 | rest_messages = messages 35 | assert len(rest_messages) >= 2, "SmolTalk messages must have at least 2 messages" 36 | for i, message in enumerate(rest_messages): 37 | # user and assistant alternate as user,assistant,user,assistant,... 38 | expected_role = "user" if i % 2 == 0 else "assistant" 39 | assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}" 40 | assert isinstance(message["content"], str), "Content must be a string" 41 | # --------------------------------------------------------------------- 42 | # create and return the Conversation object (ok to emit the system message too) 43 | conversation = { 44 | "messages": messages, 45 | } 46 | return conversation 47 | -------------------------------------------------------------------------------- /nanochat/configurator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Poor Man's Configurator. Probably a terrible idea. Example usage: 3 | $ python train.py config/override_file.py --batch_size=32 4 | this will first run config/override_file.py, then override batch_size to 32 5 | 6 | The code in this file will be run as follows from e.g. 
train.py: 7 | >>> exec(open('configurator.py').read()) 8 | 9 | So it's not a Python module, it's just shuttling this code away from train.py 10 | The code in this script then overrides the globals() 11 | 12 | I know people are not going to love this, I just really dislike configuration 13 | complexity and having to prepend config. to every single variable. If someone 14 | comes up with a better simple Python solution I am all ears. 15 | """ 16 | 17 | import os 18 | import sys 19 | from ast import literal_eval 20 | 21 | def print0(s="",**kwargs): 22 | ddp_rank = int(os.environ.get('RANK', 0)) 23 | if ddp_rank == 0: 24 | print(s, **kwargs) 25 | 26 | for arg in sys.argv[1:]: 27 | if '=' not in arg: 28 | # assume it's the name of a config file 29 | assert not arg.startswith('--') 30 | config_file = arg 31 | print0(f"Overriding config with {config_file}:") 32 | with open(config_file) as f: 33 | print0(f.read()) 34 | exec(open(config_file).read()) 35 | else: 36 | # assume it's a --key=value argument 37 | assert arg.startswith('--') 38 | key, val = arg.split('=') 39 | key = key[2:] 40 | if key in globals(): 41 | try: 42 | # attempt to eval it it (e.g. if bool, number, or etc) 43 | attempt = literal_eval(val) 44 | except (SyntaxError, ValueError): 45 | # if that goes wrong, just use the string 46 | attempt = val 47 | # ensure the types match ok 48 | if globals()[key] is not None: 49 | attempt_type = type(attempt) 50 | default_type = type(globals()[key]) 51 | assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}" 52 | # cross fingers 53 | print0(f"Overriding: {key} = {attempt}") 54 | globals()[key] = attempt 55 | else: 56 | raise ValueError(f"Unknown config key: {key}") 57 | -------------------------------------------------------------------------------- /tasks/arc.py: -------------------------------------------------------------------------------- 1 | """ 2 | The ARC dataset from Allen AI. 3 | https://huggingface.co/datasets/allenai/ai2_arc 4 | """ 5 | 6 | from datasets import load_dataset 7 | from tasks.common import Task, render_mc 8 | 9 | class ARC(Task): 10 | 11 | def __init__(self, subset, split, **kwargs): 12 | super().__init__(**kwargs) 13 | assert subset in ["ARC-Easy", "ARC-Challenge"], "ARC subset must be ARC-Easy or ARC-Challenge" 14 | assert split in ["train", "validation", "test"], "ARC split must be train|validation|test" 15 | self.ds = load_dataset("allenai/ai2_arc", subset, split=split).shuffle(seed=42) 16 | 17 | @property 18 | def eval_type(self): 19 | return 'categorical' 20 | 21 | def num_examples(self): 22 | return len(self.ds) 23 | 24 | def get_example(self, index): 25 | row = self.ds[index] 26 | question = row["question"] # the question text 27 | choices = row["choices"]["text"] # the text of each choice 28 | answer_string = row["answerKey"] # e.g. "A", "B", "C", "D" 29 | letters = row["choices"]["label"] # e.g. 
["A", "B", "C", "D"] 30 | assert answer_string in letters, f"ARC answer {answer_string} must be one of {letters}" # sanity check 31 | # create and return the Conversation object 32 | user_message = render_mc(question, letters, choices) 33 | messages = [ 34 | {"role": "user", "content": user_message}, 35 | {"role": "assistant", "content": answer_string} 36 | ] 37 | conversation = { 38 | "messages": messages, 39 | "letters": letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters 40 | } 41 | return conversation 42 | 43 | def evaluate(self, conversation, assistant_response): 44 | # the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true 45 | # I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it. 46 | assert assistant_response in conversation['letters'], f"ARC answer {assistant_response} is expected to be one of {conversation['letters']}" 47 | assistant_message = conversation['messages'][-1]['content'] # e.g. "A" 48 | return assistant_response == assistant_message 49 | -------------------------------------------------------------------------------- /nanochat/dataloader.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | import torch 4 | 5 | from nanochat.common import get_dist_info 6 | from nanochat.dataset import parquets_iter_batched 7 | from nanochat.tokenizer import get_tokenizer 8 | 9 | def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"): 10 | """Stream pretraining text from parquet files, tokenize, yield training batches.""" 11 | assert split in ["train", "val"], "split must be 'train' or 'val'" 12 | ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() 13 | needed_tokens = B * T + 1 # +1 is because we also need the target at the last token 14 | # get the tokenizer and the bos token 15 | tokenizer = get_tokenizer() 16 | bos_token = tokenizer.get_bos_token_id() 17 | # scratch buffer holds the tokens for one iteration 18 | token_buffer = deque() # we stream tokens on the right and pop from the left 19 | 20 | # infinite iterator over document batches 21 | def document_batches(): 22 | while True: 23 | # batch will iterate in group size of the parquet files, usually e.g. 1024 rows 24 | for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size): 25 | # for the tokenizer we might want to go in usually smaller batches, e.g. 128 rows 26 | for i in range(0, len(batch), tokenizer_batch_size): 27 | yield batch[i:i+tokenizer_batch_size] 28 | batches = document_batches() 29 | 30 | batch_index = 0 31 | while True: 32 | # Accumulate enough tokens for one iteration before yielding. 
33 | while len(token_buffer) < needed_tokens: 34 | doc_batch = next(batches) 35 | token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads) 36 | for tokens in token_lists: 37 | token_buffer.extend(tokens) 38 | batch_index += 1 39 | # Move tokens from the deque into the scratch buffer 40 | tokens = [token_buffer.popleft() for _ in range(needed_tokens)] 41 | scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=True) 42 | # Create the inputs/targets as 1D tensors 43 | inputs_cpu = scratch[:-1].to(dtype=torch.int32) 44 | targets_cpu = scratch[1:] 45 | # Reshape to 2D and move to GPU async 46 | inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True) 47 | targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True) 48 | yield inputs, targets 49 | -------------------------------------------------------------------------------- /dev/runcpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks) 4 | # Run as: 5 | # bash dev/cpu_demo_run.sh 6 | 7 | # NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook. 8 | # Think of this run as educational/fun demo, not something you should expect to work well. 9 | # This is also why I hide this script away in dev/ 10 | 11 | # all the setup stuff 12 | export OMP_NUM_THREADS=1 13 | NANOCHAT_BASE_DIR="$HOME/.cache/nanochat" 14 | mkdir -p $NANOCHAT_BASE_DIR 15 | command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh 16 | [ -d ".venv" ] || uv venv 17 | uv sync 18 | source .venv/bin/activate 19 | if [ -z "$WANDB_RUN" ]; then 20 | WANDB_RUN=dummy 21 | fi 22 | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y 23 | source "$HOME/.cargo/env" 24 | uv run maturin develop --release --manifest-path rustbpe/Cargo.toml 25 | EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip 26 | if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then 27 | curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL 28 | unzip -q eval_bundle.zip 29 | rm eval_bundle.zip 30 | mv eval_bundle $NANOCHAT_BASE_DIR 31 | fi 32 | 33 | # wipe the report 34 | python -m nanochat.report reset 35 | 36 | # train tokenizer on ~1B characters 37 | python -m nanochat.dataset -n 4 38 | python -m scripts.tok_train --max_chars=1000000000 39 | python -m scripts.tok_eval 40 | 41 | # train a very small 4 layer model on the CPU 42 | # each optimization step processes a single sequence of 1024 tokens 43 | # we only run 50 steps of optimization (bump this to get better results) 44 | python -m scripts.base_train \ 45 | --depth=4 \ 46 | --max_seq_len=1024 \ 47 | --device_batch_size=1 \ 48 | --total_batch_size=1024 \ 49 | --eval_every=50 \ 50 | --eval_tokens=4096 \ 51 | --core_metric_every=50 \ 52 | --core_metric_max_per_task=12 \ 53 | --sample_every=50 \ 54 | --num_iterations=50 55 | python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096 56 | python -m scripts.base_eval --max-per-task=16 57 | 58 | # midtraining 59 | python -m scripts.mid_train \ 60 | --max_seq_len=1024 \ 61 | --device_batch_size=1 \ 62 | --eval_every=50 \ 63 | --eval_tokens=4096 \ 64 | --total_batch_size=1024 \ 65 | --num_iterations=100 66 | # eval results will be terrible, this is just to execute the code paths. 
67 | # note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems 68 | python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20 69 | 70 | # SFT 71 | python -m scripts.chat_sft \ 72 | --device_batch_size=1 \ 73 | --target_examples_per_step=4 \ 74 | --num_iterations=100 \ 75 | --eval_steps=4 \ 76 | --eval_metrics_max_problems=16 77 | 78 | # Chat CLI 79 | # python -m scripts.chat_cli -p "Why is the sky blue?" 80 | 81 | # Chat Web 82 | # python -m scripts.chat_web 83 | 84 | python -m nanochat.report generate 85 | -------------------------------------------------------------------------------- /tasks/customjson.py: -------------------------------------------------------------------------------- 1 | """ 2 | CustomJSON task for loading conversations from JSONL files. 3 | Each line in the JSONL file should be a JSON array of messages. 4 | """ 5 | 6 | import os 7 | import json 8 | from tasks.common import Task 9 | 10 | class CustomJSON(Task): 11 | """ 12 | Load conversations from a JSONL file. 13 | Each line should be a JSON array of message objects with 'role' and 'content' fields. 14 | Example line: [{"role":"user","content":"Hi"},{"role":"assistant","content":"Hello"}] 15 | """ 16 | 17 | def __init__(self, filepath, **kwargs): 18 | super().__init__(**kwargs) 19 | self.filepath = filepath 20 | self.conversations = [] 21 | 22 | # Load all conversations from the JSONL file 23 | if not os.path.exists(filepath): 24 | # Helpful error message due to recent change. Will be removed in the future. 25 | print("-" * 80) 26 | print(f"Warning: File {filepath} does not exist") 27 | print("HINT (Oct 21 2025)") 28 | print("If you recently did a git pull and suddely see this, it might be due to the new addition of identity conversations") 29 | print("See this discussion for more details: https://github.com/karpathy/nanochat/discussions/139") 30 | print("Quick fix: simply run the following command to download the file and you're done:") 31 | print(f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl") 32 | print("-" * 80) 33 | 34 | else: 35 | with open(filepath, 'r') as f: 36 | for line in f: 37 | line = line.strip() 38 | if not line: # skip empty lines 39 | continue 40 | messages = json.loads(line) 41 | # Validate the conversation structure 42 | assert isinstance(messages, list), f"Expected list of messages, got {type(messages)}" 43 | assert len(messages) >= 2, f"Conversation must have at least 2 messages, got {len(messages)}" 44 | # Validate message structure and alternating roles 45 | for i, message in enumerate(messages): 46 | assert "role" in message, f"Message {i} missing 'role' field" 47 | assert "content" in message, f"Message {i} missing 'content' field" 48 | expected_role = "user" if i % 2 == 0 else "assistant" 49 | assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}" 50 | assert isinstance(message["content"], str), f"Message {i} content must be a string" 51 | 52 | self.conversations.append(messages) 53 | 54 | self.length = len(self.conversations) 55 | 56 | def num_examples(self): 57 | return self.length 58 | 59 | def get_example(self, index): 60 | messages = self.conversations[index] 61 | conversation = { 62 | "messages": messages, 63 | } 64 | return conversation 65 | 66 | -------------------------------------------------------------------------------- /nanochat/loss_eval.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | A number of functions that help with evaluating a base model. 3 | """ 4 | import math 5 | import torch 6 | import torch.distributed as dist 7 | 8 | @torch.no_grad() 9 | def evaluate_bpb(model, batches, steps, token_bytes): 10 | """ 11 | Instead of the naive 'mean loss', this function returns the bits per byte (bpb), 12 | which is a tokenization vocab size-indepedent metric, meaning you are still comparing 13 | apples:apples if you change the vocab size. The way this works is that instead of just 14 | calculating the average loss as usual, you calculate the sum loss, and indepependently 15 | also the sum bytes (of all the target tokens), and divide. This normalizes the loss by 16 | the number of bytes that the target tokens represent. 17 | 18 | The added complexity is so that: 19 | 1) All "normal" tokens are normalized by the length of the token in bytes 20 | 2) No special tokens (e.g. <|bos|>) are included in the metric - they are masked out. 21 | 3) No actively masked tokens (using ignore_index of e.g. -1) are included in the metric. 22 | 23 | In addition to evaluate_loss, we need the token_bytes tensor: 24 | It is a 1D tensor of shape (vocab_size,), indicating the number of bytes for 25 | each token id, or 0 if the token is to not be counted (e.g. special tokens). 26 | """ 27 | # record the losses 28 | total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device()) 29 | total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device()) 30 | batch_iter = iter(batches) 31 | for _ in range(steps): 32 | x, y = next(batch_iter) 33 | loss2d = model(x, y, loss_reduction='none') # (B, T) 34 | loss2d = loss2d.view(-1) # flatten 35 | y = y.view(-1) # flatten 36 | if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32 37 | # slightly more complex code path if some target tokens are ignore_index (e.g. 
-1) 38 | # any target token < 0 is to be ignored: do NOT index token_bytes with negatives 39 | valid = y >= 0 40 | y_safe = torch.where(valid, y, torch.zeros_like(y)) 41 | # map valid targets to their byte length; ignored targets contribute 0 bytes 42 | num_bytes2d = torch.where( 43 | valid, 44 | token_bytes[y_safe], 45 | torch.zeros_like(y, dtype=token_bytes.dtype) 46 | ) 47 | total_nats += (loss2d * (num_bytes2d > 0)).sum() 48 | total_bytes += num_bytes2d.sum() 49 | else: 50 | # fast path: no ignored targets, safe to index directly 51 | num_bytes2d = token_bytes[y] 52 | total_nats += (loss2d * (num_bytes2d > 0)).sum() 53 | total_bytes += num_bytes2d.sum() 54 | # sum reduce across all ranks 55 | world_size = dist.get_world_size() if dist.is_initialized() else 1 56 | if world_size > 1: 57 | dist.all_reduce(total_nats, op=dist.ReduceOp.SUM) 58 | dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM) 59 | # move both to cpu, calculate bpb and return 60 | total_nats = total_nats.item() 61 | total_bytes = total_bytes.item() 62 | bpb = total_nats / (math.log(2) * total_bytes) 63 | return bpb 64 | -------------------------------------------------------------------------------- /scripts/base_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loads a checkpoint, and: 3 | - Evaluates the loss on a larger chunk of train/val splits 4 | - Samples from the model 5 | 6 | Example run as: 7 | torchrun --standalone --nproc_per_node=8 -m scripts.base_loss 8 | """ 9 | import os 10 | from contextlib import nullcontext 11 | import torch 12 | from nanochat.checkpoint_manager import load_model 13 | from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type 14 | from nanochat.dataloader import tokenizing_distributed_data_loader 15 | from nanochat.tokenizer import get_token_bytes 16 | from nanochat.loss_eval import evaluate_bpb 17 | from nanochat.engine import Engine 18 | 19 | # Configuration 20 | device_batch_size = 32 21 | split_tokens = 20*524288 # number of tokens to evaluate per split 22 | model_tag = None # optional model tag for the output directory name 23 | model_step = None # optional model step for the output directory name 24 | device_type = "" # cuda|cpu|mps (empty => autodetect) 25 | exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file 26 | 27 | # Load the base model and the tokenizer 28 | device_type = autodetect_device_type() if device_type == "" else device_type 29 | ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) 30 | model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step) 31 | sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really 32 | autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() 33 | 34 | # Evaluate the loss on each split 35 | tokens_per_step = device_batch_size * sequence_len * ddp_world_size 36 | assert split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step" 37 | steps = split_tokens // tokens_per_step 38 | token_bytes = get_token_bytes(device=device) 39 | bpb_results = {} 40 | for split_name in ["train", "val"]: 41 | loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device) 42 | with autocast_ctx: 43 | bpb = evaluate_bpb(model, loader, steps, token_bytes) 44 | print0(f"{split_name} bpb: {bpb:.4f}") 
45 | bpb_results[split_name] = bpb 46 | 47 | # Master process also samples from the model 48 | samples = [] 49 | if ddp_rank == 0: 50 | prompts = [ 51 | "The capital of France is", 52 | "The chemical symbol of gold is", 53 | "If yesterday was Friday, then tomorrow will be", 54 | "The opposite of hot is", 55 | "The planets of the solar system are:", 56 | "My favorite color is", 57 | "If 5*x + 3 = 13, then x is", 58 | ] 59 | engine = Engine(model, tokenizer) 60 | for prompt in prompts: 61 | tokens = tokenizer(prompt, prepend="<|bos|>") 62 | with autocast_ctx: 63 | sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0) 64 | sample_str = tokenizer.decode(sample[0]) 65 | print0(sample_str) 66 | samples.append(sample_str) 67 | 68 | # Log to report 69 | from nanochat.report import get_report 70 | get_report().log(section="Base model loss", data=[ 71 | { 72 | "train bpb": bpb_results["train"], 73 | "val bpb": bpb_results["val"], 74 | }, 75 | {f"sample {i}": sample for i, sample in enumerate(samples)}, 76 | ]) 77 | 78 | # Cleanup 79 | compute_cleanup() 80 | -------------------------------------------------------------------------------- /dev/repackage_data_reference.py: -------------------------------------------------------------------------------- 1 | """ 2 | Repackage the FinewebEdu-100B dataset into shards: 3 | 4 | - each shard is ~100MB in size (after zstd compression) 5 | - parquets are written with row group size of 1000 6 | - shuffle the dataset 7 | 8 | This will be uploaded to HuggingFace for hosting. 9 | The big deal is that our DataLoader will be able to stream 10 | the data and cache it along the way on disk, decreasing the 11 | training latency. 12 | 13 | NOTE: This file is meant only as reference/documentation of the 14 | dataset preparation and it is not used during the project runtime. 
15 | """ 16 | import os 17 | import time 18 | 19 | from datasets import load_dataset 20 | import pyarrow.parquet as pq 21 | import pyarrow as pa 22 | 23 | # Source dataset 24 | dataset_kwargs = { 25 | "path": "HuggingFaceFW/fineweb-edu", 26 | "split": "train", 27 | "name": "sample-100BT", # ~100B GPT-2 tokens at ~3 chars/token => ~300B chars total 28 | } 29 | ds = load_dataset(**dataset_kwargs) 30 | 31 | # Shuffle to scramble the order 32 | ds = ds.shuffle(seed=42) 33 | ndocs = len(ds) # total number of documents to process 34 | print(f"Total number of documents: {ndocs}") 35 | 36 | # Repackage into parquet files 37 | output_dir = "/home/ubuntu/.cache/nanochat/base_data" 38 | os.makedirs(output_dir, exist_ok=True) 39 | 40 | # Write to parquet files 41 | chars_per_shard = 250_000_000 42 | row_group_size = 1024 # HF uses 1000 but we use multiple of 2, nicer for distributed data loader later 43 | shard_docs = [] 44 | shard_index = 0 45 | shard_characters = 0 46 | total_docs_processed = 0 47 | total_time_spent = 0 48 | t0 = time.time() 49 | for doc in ds: 50 | text = doc['text'] 51 | shard_docs.append(text) 52 | shard_characters += len(text) 53 | collected_enough_chars = shard_characters >= chars_per_shard 54 | docs_multiple_of_row_group_size = len(shard_docs) % row_group_size == 0 55 | if collected_enough_chars and docs_multiple_of_row_group_size: # leads to ~100MB of text (compressed) 56 | shard_path = os.path.join(output_dir, f"shard_{shard_index:05d}.parquet") 57 | shard_table = pa.Table.from_pydict({"text": shard_docs}) 58 | pq.write_table( 59 | shard_table, 60 | shard_path, 61 | row_group_size=row_group_size, 62 | use_dictionary=False, # this is usually used for categorical data 63 | compression="zstd", # Valid values: {‘NONE’, ‘SNAPPY’, ‘GZIP’, ‘BROTLI’, ‘LZ4’, ‘ZSTD’} 64 | compression_level=3, 65 | write_statistics=False, # not needed for text 66 | ) 67 | t1 = time.time() 68 | dt = t1 - t0 # for this shard alone 69 | t0 = t1 70 | total_docs_processed += len(shard_docs) 71 | total_time_spent += dt 72 | remaining_docs = ndocs - total_docs_processed 73 | avg_time_per_doc = total_time_spent / total_docs_processed 74 | remaining_time = remaining_docs * avg_time_per_doc 75 | remaining_time_hours = remaining_time / 3600 76 | print(f"Wrote {shard_path}. #documents: {len(shard_docs)} | #characters: {shard_characters} | time: {dt:.2f}s | remaining time: {remaining_time_hours:.2f}h") 77 | shard_docs = [] 78 | shard_characters = 0 79 | shard_index += 1 80 | 81 | # Demonstration of how the data was later uploaded to HuggingFace 82 | def upload(): 83 | import os 84 | from huggingface_hub import HfApi 85 | token = os.getenv("HF_TOKEN") 86 | api = HfApi(token=token) 87 | api.upload_large_folder( 88 | folder_path=output_dir, 89 | repo_id="karpathy/fineweb-edu-100b-shuffle", 90 | repo_type="dataset", 91 | ) 92 | # upload() 93 | -------------------------------------------------------------------------------- /nanochat/adamw.py: -------------------------------------------------------------------------------- 1 | """ 2 | Borrowed from modded-nanogpt. By Keller, @vagrawal, et al. 3 | Not a general optimizer! But works for our specific use. 4 | """ 5 | import torch 6 | import torch.distributed as dist 7 | from torch import Tensor 8 | 9 | 10 | class DistAdamW(torch.optim.Optimizer): 11 | """ 12 | Distributed AdamW optimizer. 13 | In the style of ZeRO-2, i.e. 
sharded optimizer states and gradient reduction 14 | """ 15 | def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): 16 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 17 | super().__init__(param_groups, defaults) 18 | 19 | @torch.compile 20 | @torch.no_grad() 21 | def step(self): 22 | rank = dist.get_rank() 23 | world_size = dist.get_world_size() 24 | reduce_scatter_futures: list[torch.Future] = [] 25 | all_reduce_futures: list[torch.Future] = [] 26 | grad_slices = [] 27 | for group in self.param_groups: 28 | params: list[Tensor] = group["params"] 29 | for base_i in range(len(params)): 30 | grad = params[base_i].grad 31 | rank_size = grad.shape[0] // world_size 32 | grad_slice = torch.empty_like(grad[:rank_size]) 33 | reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) 34 | grad_slices.append(grad_slice) 35 | 36 | idx = 0 37 | for group in self.param_groups: 38 | beta1, beta2 = group['betas'] 39 | eps = group['eps'] 40 | wd = group['weight_decay'] 41 | params = group['params'] 42 | for base in range(len(params)): 43 | reduce_scatter_futures[idx].wait() 44 | p = params[base] 45 | rank_size = p.shape[0] // world_size 46 | p_slice = p[rank * rank_size:(rank + 1) * rank_size] 47 | lr = group['lr'] * getattr(p, "lr_mul", 1.0) 48 | state = self.state[p] 49 | g_slice = grad_slices[idx] 50 | # State init 51 | if not state: 52 | state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) 53 | state['exp_avg'] = torch.zeros_like(p_slice) 54 | state['exp_avg_sq'] = torch.zeros_like(p_slice) 55 | exp_avg = state['exp_avg'] 56 | exp_avg_sq = state['exp_avg_sq'] 57 | state['step'] += 1 58 | t = state['step'] 59 | # weight decay 60 | if wd != 0: 61 | eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) 62 | p_slice.mul_(1 - eff_weight_decay) 63 | # update running averages 64 | exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) 65 | exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) 66 | # bias corrections 67 | bias1 = 1 - beta1 ** t 68 | bias2 = 1 - beta2 ** t 69 | # compute step 70 | denom = exp_avg_sq.sqrt().add_(eps) 71 | step_size = lr * (torch.sqrt(bias2) / bias1) 72 | update = exp_avg.div(denom).mul_(step_size) 73 | p_slice.add_(other=update, alpha=-1.0) 74 | idx += 1 75 | all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) 76 | torch.futures.collect_all(all_reduce_futures).wait() 77 | -------------------------------------------------------------------------------- /tasks/humaneval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluate the Chat model on HumanEval dataset. 3 | Btw this dataset is a misnomer and has nothing to do with humans. 4 | It is a coding benchmark. 
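Each problem provides the start of a Python program (imports, function signature, docstring)
as the prompt; the model's completion is extracted, re-joined with those imports and the
provided unit tests, and executed via execute_code with check(entry_point) as the final assertion.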
5 | """ 6 | 7 | import re 8 | from datasets import load_dataset 9 | from nanochat.execution import execute_code 10 | from tasks.common import Task 11 | 12 | def extract_imports(prompt): 13 | """Extract import statements from the beginning of a code block.""" 14 | imports = [] 15 | for line in prompt.split('\n'): 16 | stripped = line.strip() 17 | if stripped.startswith('import ') or stripped.startswith('from '): 18 | imports.append(stripped) 19 | elif stripped and not stripped.startswith('#'): 20 | # Stop at first non-import, non-comment line 21 | break 22 | return '\n'.join(imports) 23 | 24 | def extract_program(completion): 25 | """ 26 | Extract Python code from LLM completion. 27 | 28 | Handles various output formats: 29 | - Code wrapped in ```python ... ``` or ``` ... ``` blocks 30 | - Plain code without markdown blocks 31 | - Extra text before/after code blocks 32 | 33 | Returns the first code block if found, otherwise returns the whole completion. 34 | """ 35 | # Try to find markdown code blocks (```python or just ```) 36 | # Match ```python\n...\n``` or ```\n...\n``` 37 | pattern = r'```(?:python)?\s*\n(.*?)\n```' 38 | matches = re.findall(pattern, completion, re.DOTALL) 39 | 40 | if matches: 41 | # Return the first code block found 42 | return matches[0].strip() 43 | 44 | # No code blocks found, return the whole completion 45 | return completion.strip() 46 | 47 | class HumanEval(Task): 48 | 49 | def __init__(self, **kwargs): 50 | super().__init__(**kwargs) 51 | self.ds = load_dataset("openai/openai_humaneval", split="test").shuffle(seed=42) 52 | 53 | @property 54 | def eval_type(self): 55 | return 'generative' 56 | 57 | def num_examples(self): 58 | return len(self.ds) 59 | 60 | def get_example(self, index): 61 | """ Get a single problem from the dataset. """ 62 | row = self.ds[index] 63 | prompt = row['prompt'] # prompts in HumanEval are the beginning of the program 64 | solution = row['canonical_solution'] # the correct continuation of the program 65 | entry_point = row['entry_point'] # the function to check 66 | test = row['test'] # the test cases 67 | complete_solution = f"{prompt}\n{solution}" 68 | messages = [ 69 | {"role": "user", "content": prompt}, 70 | {"role": "assistant", "content": complete_solution}, 71 | ] 72 | conversation = { 73 | "messages": messages, 74 | "entry_point": entry_point, # needed during evaluation 75 | "test": test, # needed during evaluation 76 | } 77 | return conversation 78 | 79 | def evaluate(self, conversation, completion): 80 | """ Given (conversation, completion), return boolean success of the completion. """ 81 | # the prompt will contain the imports and the function signature 82 | imports = extract_imports(conversation['messages'][0]['content']) 83 | # the completion will usually contain the whole function 84 | # but not always with the needed imports, so we manually append them 85 | completion_code = extract_program(completion) 86 | program = ( 87 | imports 88 | + "\n\n" 89 | + completion_code 90 | + "\n\n" 91 | + conversation['test'] 92 | + "\n" 93 | + f"check({conversation['entry_point']})" 94 | ) 95 | result = execute_code(program) 96 | success = result.success 97 | return success 98 | -------------------------------------------------------------------------------- /tasks/mmlu.py: -------------------------------------------------------------------------------- 1 | """ 2 | The MMLU dataset. 
3 | https://huggingface.co/datasets/cais/mmlu 4 | """ 5 | 6 | from datasets import load_dataset 7 | from tasks.common import Task, render_mc 8 | 9 | class MMLU(Task): 10 | 11 | letters = ('A', 'B', 'C', 'D') 12 | groups = ('abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics', 'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios', 'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions') 13 | 14 | def __init__(self, subset, split, **kwargs): 15 | super().__init__(**kwargs) 16 | assert subset in ["all", "auxiliary_train"], f"subset {subset} must be all|auxiliary_train" 17 | assert split in ["train", "validation", "dev", "test"], f"split {split} must be train|validation|dev|test" 18 | if subset == "auxiliary_train": 19 | assert split == "train", "auxiliary_train must be split into train" 20 | self.subset = subset 21 | self.split = split 22 | self.ds = load_dataset("cais/mmlu", subset, split=split).shuffle(seed=42) 23 | if subset == "auxiliary_train": 24 | # I don't understand why but the auxiliary_train rows have some weird additional 'train' wrapper 25 | self.ds = self.ds.map(lambda row: row['train'], remove_columns=['train']) 26 | 27 | @property 28 | def eval_type(self): 29 | return 'categorical' 30 | 31 | def num_examples(self): 32 | return len(self.ds) 33 | 34 | def get_example(self, index): 35 | row = self.ds[index] 36 | question = row["question"] # the question text 37 | choices = row["choices"] # the text of each choice 38 | answer = row["answer"] # index of the answer, e.g. 0,1,2,3 (for A,B,C,D) 39 | subject = row["subject"] # e.g. "college_biology", "college_chemistry", etc. 
40 | assert len(choices) == 4, "MMLU should have 4 choices" 41 | # create and return the Conversation object 42 | user_message = render_mc(question, self.letters, choices) 43 | assistant_message = self.letters[answer] 44 | messages = [ 45 | {"role": "user", "content": user_message}, 46 | {"role": "assistant", "content": assistant_message} 47 | ] 48 | conversation = { 49 | "messages": messages, 50 | "subject": subject, # might be useful later for grouping metrics by subject 51 | "letters": self.letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters 52 | } 53 | return conversation 54 | 55 | def evaluate(self, conversation, assistant_response): 56 | # the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true 57 | # I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it. 58 | assert assistant_response in self.letters, f"MMLU answer {assistant_response} is expected to be one of {self.letters}" 59 | assistant_message = conversation['messages'][-1]['content'] # e.g. "A" 60 | return assistant_response == assistant_message 61 | -------------------------------------------------------------------------------- /scripts/chat_cli.py: -------------------------------------------------------------------------------- 1 | """ 2 | New and upgraded chat mode because a lot of the code has changed since the last one. 3 | 4 | Intended to be run single GPU only atm: 5 | python -m scripts.chat_cli -i mid 6 | """ 7 | import argparse 8 | import torch 9 | from nanochat.common import compute_init, autodetect_device_type 10 | from contextlib import nullcontext 11 | from nanochat.engine import Engine 12 | from nanochat.checkpoint_manager import load_model 13 | 14 | parser = argparse.ArgumentParser(description='Chat with the model') 15 | parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl") 16 | parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load') 17 | parser.add_argument('-s', '--step', type=int, default=None, help='Step to load') 18 | parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back') 19 | parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation') 20 | parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter') 21 | parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. 
empty => autodetect') 22 | parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16']) 23 | args = parser.parse_args() 24 | 25 | # Init the model and tokenizer 26 | 27 | device_type = autodetect_device_type() if args.device_type == "" else args.device_type 28 | ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) 29 | ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 30 | autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() 31 | model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step) 32 | 33 | # Special tokens for the chat state machine 34 | bos = tokenizer.get_bos_token_id() 35 | user_start, user_end = tokenizer.encode_special("<|user_start|>"), tokenizer.encode_special("<|user_end|>") 36 | assistant_start, assistant_end = tokenizer.encode_special("<|assistant_start|>"), tokenizer.encode_special("<|assistant_end|>") 37 | 38 | # Create Engine for efficient generation 39 | engine = Engine(model, tokenizer) 40 | 41 | print("\nNanoChat Interactive Mode") 42 | print("-" * 50) 43 | print("Type 'quit' or 'exit' to end the conversation") 44 | print("Type 'clear' to start a new conversation") 45 | print("-" * 50) 46 | 47 | conversation_tokens = [bos] 48 | 49 | while True: 50 | 51 | if args.prompt: 52 | # Get the prompt from the launch command 53 | user_input = args.prompt 54 | else: 55 | # Get the prompt interactively from the console 56 | try: 57 | user_input = input("\nUser: ").strip() 58 | except (EOFError, KeyboardInterrupt): 59 | print("\nGoodbye!") 60 | break 61 | 62 | # Handle special commands 63 | if user_input.lower() in ['quit', 'exit']: 64 | print("Goodbye!") 65 | break 66 | 67 | if user_input.lower() == 'clear': 68 | conversation_tokens = [bos] 69 | print("Conversation cleared.") 70 | continue 71 | 72 | if not user_input: 73 | continue 74 | 75 | # Add User message to the conversation 76 | conversation_tokens.append(user_start) 77 | conversation_tokens.extend(tokenizer.encode(user_input)) 78 | conversation_tokens.append(user_end) 79 | 80 | # Kick off the assistant 81 | conversation_tokens.append(assistant_start) 82 | generate_kwargs = { 83 | "num_samples": 1, 84 | "max_tokens": 256, 85 | "temperature": args.temperature, 86 | "top_k": args.top_k, 87 | } 88 | response_tokens = [] 89 | print("\nAssistant: ", end="", flush=True) 90 | with autocast_ctx: 91 | for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs): 92 | token = token_column[0] # pop the batch dimension (num_samples=1) 93 | response_tokens.append(token) 94 | token_text = tokenizer.decode([token]) 95 | print(token_text, end="", flush=True) 96 | print() 97 | # we have to ensure that the assistant end token is the last token 98 | # so even if generation ends due to max tokens, we have to append it to the end 99 | if response_tokens[-1] != assistant_end: 100 | response_tokens.append(assistant_end) 101 | conversation_tokens.extend(response_tokens) 102 | 103 | # In the prompt mode, we only want a single response and exit 104 | if args.prompt: 105 | break 106 | -------------------------------------------------------------------------------- /scripts/tok_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a tokenizer using the HuggingFace Tokenizers library. 3 | In the style of GPT-4 tokenizer. 
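(The actual BPE training below is done by the rustbpe extension, via
RustBPETokenizer.train_from_iterator; see rustbpe/README.md.)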
4 | """ 5 | import os 6 | import time 7 | import argparse 8 | import torch 9 | from nanochat.tokenizer import RustBPETokenizer 10 | from nanochat.common import get_base_dir 11 | from nanochat.dataset import parquets_iter_batched 12 | 13 | # ----------------------------------------------------------------------------- 14 | # Parse command line arguments 15 | 16 | parser = argparse.ArgumentParser(description='Train a BPE tokenizer') 17 | parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)') 18 | parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)') 19 | parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)') 20 | args = parser.parse_args() 21 | print(f"max_chars: {args.max_chars:,}") 22 | print(f"doc_cap: {args.doc_cap:,}") 23 | print(f"vocab_size: {args.vocab_size:,}") 24 | 25 | # ----------------------------------------------------------------------------- 26 | # Text iterator 27 | 28 | def text_iterator(): 29 | """ 30 | 1) Flatten the batches into a single iterator 31 | 2) Crop every document to args.doc_cap characters 32 | 3) Break when we've seen args.max_chars characters 33 | """ 34 | nchars = 0 35 | for batch in parquets_iter_batched(split="train"): 36 | for doc in batch: 37 | doc_text = doc 38 | if len(doc_text) > args.doc_cap: 39 | doc_text = doc_text[:args.doc_cap] 40 | nchars += len(doc_text) 41 | yield doc_text 42 | if nchars > args.max_chars: 43 | return 44 | text_iter = text_iterator() 45 | 46 | # ----------------------------------------------------------------------------- 47 | # Train the tokenizer 48 | t0 = time.time() 49 | tokenizer = RustBPETokenizer.train_from_iterator(text_iter, args.vocab_size) 50 | t1 = time.time() 51 | train_time = t1 - t0 52 | print(f"Training time: {train_time:.2f}s") 53 | 54 | # ----------------------------------------------------------------------------- 55 | # Save the tokenizer to disk 56 | base_dir = get_base_dir() 57 | tokenizer_dir = os.path.join(base_dir, "tokenizer") 58 | tokenizer.save(tokenizer_dir) 59 | 60 | # ----------------------------------------------------------------------------- 61 | # Quick inline sanity check 62 | test_text = """Hello world! This is a test. 63 | Numbers: 123, 4567, 89 64 | Contractions: I'm, you're, it's 65 | Special chars: @#$%^&*() 66 | Unicode: 你好世界 🌍""" 67 | encoded = tokenizer.encode(test_text) 68 | decoded = tokenizer.decode(encoded) 69 | assert decoded == test_text 70 | 71 | # ----------------------------------------------------------------------------- 72 | # One more thing: we wish to cache a mapping from token id to number of bytes of that token 73 | # for efficient evaluation of bits per byte. Unlike the typical mean loss, this 74 | # allows us to report a loss that is invariant to the vocab size of the tokenizer. 75 | # The bits per byte on the validation set is then one of the primary metrics we care about. 
76 | vocab_size = tokenizer.get_vocab_size() 77 | special_set = set(tokenizer.get_special_tokens()) 78 | token_strings = [tokenizer.decode([token_id]) for token_id in range(vocab_size)] 79 | token_bytes = [] 80 | for token_id in range(vocab_size): 81 | token_str = token_strings[token_id] # the Python string representation of this token 82 | if token_str in special_set: 83 | token_bytes.append(0) # special characters are not counted 84 | else: 85 | id_bytes = len(token_str.encode("utf-8")) # number of bytes that make up this token 86 | token_bytes.append(id_bytes) 87 | token_bytes = torch.tensor(token_bytes, dtype=torch.int32, device='cpu') 88 | token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt") 89 | with open(token_bytes_path, "wb") as f: 90 | torch.save(token_bytes, f) 91 | print(f"Saved token_bytes to {token_bytes_path}") 92 | 93 | # Log to report 94 | from nanochat.report import get_report 95 | token_bytes_nonzero = (token_bytes[token_bytes > 0]).to(dtype=torch.float32) 96 | get_report().log(section="Tokenizer training", data=[ 97 | vars(args), # argparse command line arguments 98 | {"train_time": train_time}, 99 | {"num_special_tokens": len(special_set)}, 100 | { 101 | "token_bytes_min": int(token_bytes_nonzero.min().item()), 102 | "token_bytes_max": int(token_bytes_nonzero.max().item()), 103 | "token_bytes_mean": token_bytes_nonzero.mean().item(), 104 | "token_bytes_std": token_bytes_nonzero.std().item(), 105 | } 106 | ]) 107 | -------------------------------------------------------------------------------- /run1000.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # The $1000 tier of nanochat 4 | # Designed to run end-to-end for $1000/24 ~= 41.6 hours on an 8XH100 node 5 | # A bit sparser on comments, see speedrun.sh for more detail 6 | 7 | # all the setup stuff 8 | export OMP_NUM_THREADS=1 9 | export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat" 10 | mkdir -p $NANOCHAT_BASE_DIR 11 | command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh 12 | [ -d ".venv" ] || uv venv 13 | uv sync 14 | source .venv/bin/activate 15 | if [ -z "$WANDB_RUN" ]; then 16 | WANDB_RUN=dummy 17 | fi 18 | python -m nanochat.report reset 19 | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y 20 | source "$HOME/.cargo/env" 21 | uv run maturin develop --release --manifest-path rustbpe/Cargo.toml 22 | EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip 23 | if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then 24 | curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL 25 | unzip -q eval_bundle.zip 26 | rm eval_bundle.zip 27 | mv eval_bundle $NANOCHAT_BASE_DIR 28 | fi 29 | curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl 30 | 31 | # train tokenizer on ~4B characters and kick off download of the rest for pretraining 32 | python -m nanochat.dataset -n 16 33 | # start downloading the rest of the shards for a total of 800 (see below why 800) 34 | python -m nanochat.dataset -n 800 & 35 | # todo: download the rest of it 36 | python -m scripts.tok_train --max_chars=4000000000 37 | python -m scripts.tok_eval 38 | 39 | # Documenting my process for determining the hyperparameters for this run1000.sh script: 40 | # We want a budget of approx. 
$1000 ~= 41.6 hours of 8XH100 compute 41 | # 1) I guessed the model size for this to be about depth=32 42 | # 2) Determine the device_batch_size that fits: 43 | # Running the base_train.py script with --depth=32, I saw that --device_batch_size=16 44 | # runs out of memory, but --device_batch_size=8 fits. Inspecting `nvidia-smi` during training, 45 | # I saw all GPUs were at about 78/80GB VRAM, so it just barely fits and we have good MFU at ~50%. 46 | # So the training script was running ok and showed: 47 | # Vocab size: 65,536 48 | # num_layers: 32 49 | # model_dim: 2048 50 | # num_heads: 16 51 | # num_kv_heads: 16 52 | # Tokens / micro-batch / rank: 8 x 2048 = 16,384 53 | # Tokens / micro-batch: 131,072 54 | # Total batch size 524,288 => gradient accumulation steps: 4 55 | # Number of parameters: 1,879,048,192 56 | # Estimated FLOPs per token: 1.207960e+10 57 | # Calculated number of iterations from target data:param ratio: 71,680 58 | # Total number of training tokens: 37,580,963,840 59 | # Tokens : Params ratio: 20.00 60 | # Total training FLOPs estimate: 4.539628e+20 61 | # step 00004/71680 (0.01%) | loss: 8.813754 | lrm: 1.00 | dt: 1571.88ms | tok/sec: 83,385 | mfu: 50.92 | total time: 0.00m 62 | # step 00005/71680 (0.01%) | loss: 8.488074 | lrm: 1.00 | dt: 1572.76ms | tok/sec: 83,338 | mfu: 50.89 | total time: 0.00m 63 | # ... 64 | # 3) validate that the runtime fits our budget: 65 | # The training script uses the Chinchilla scaling law to compute-optimally set #tokens = 20 * #params. In particular: 66 | # The script shows that we will be training for 71,680 steps, and each step takes 1.574s so: 67 | # estimated time to train: 71,680 * 1.574s / 60 / 60 = 31.3 hours. 68 | # This is OK, fits our budget, and leaves ~10 hours for midtraining and SFT and evals and maybe RL. 69 | # It's possible that we might even fit depth=33 or depth=34, but for now let's go along with this. 70 | # 4) The last thing to pay attention to is the amount of training data required for the run. 71 | # The script above calculated that "Total number of training tokens: 37,580,963,840" 72 | # The tok_eval.py script reports about ~4.8 chars/token on average for the default tokenizer settings. 73 | # So ~38B tokens # ~4.8 chars/token = ~185B chars. 74 | # Each data shard is ~250M chars, so we need ~185B / 250M ~= 740 shards. 75 | # For safety, I bumped that up to 800 shards, and that's why up above I used -n 800 when pre-downloading dataset shards. 76 | # If we didn't have enough data, the training script would loop around and do multiple epochs over the same data, 77 | # which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd 78 | # start to overfit hard. 79 | # 5) That's it, everything else (e.g. the learning rates) is adjusted automatically by the training script. 80 | torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=32 --device_batch_size=8 81 | torchrun --standalone --nproc_per_node=8 -m scripts.base_loss 82 | torchrun --standalone --nproc_per_node=8 -m scripts.base_eval 83 | 84 | # midtrain 85 | # NOTE: ensure that we use the same device_batch_size here as the base training script. 
86 | torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN 87 | torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid 88 | 89 | # sft 90 | torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN 91 | torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft 92 | 93 | # generate final report 94 | python -m nanochat.report generate 95 | 96 | # talk to it 97 | python -m scripts.chat_web 98 | -------------------------------------------------------------------------------- /tasks/gsm8k.py: -------------------------------------------------------------------------------- 1 | """ 2 | GSM8K evaluation. 3 | https://huggingface.co/datasets/openai/gsm8k 4 | 5 | Example problem instance: 6 | 7 | Question: 8 | Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn? 9 | Answer: 10 | Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute. 11 | Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10. 12 | #### 10 13 | 14 | Notice that GSM8K uses tool calls inside << >> tags. 15 | """ 16 | 17 | import re 18 | from datasets import load_dataset 19 | from tasks.common import Task 20 | 21 | 22 | GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)") 23 | def extract_answer(completion): 24 | """ 25 | Extract the numerical answer after #### marker. 26 | Follows official code for normalization: 27 | https://github.com/openai/grade-school-math/blob/3101c7d5072418e28b9008a6636bde82a006892c/grade_school_math/dataset.py#L28 28 | """ 29 | match = GSM_RE.search(completion) 30 | if match: 31 | match_str = match.group(1).strip() 32 | match_str = match_str.replace(",", "") 33 | return match_str 34 | return None 35 | 36 | 37 | class GSM8K(Task): 38 | 39 | def __init__(self, subset, split, **kwargs): 40 | super().__init__(**kwargs) 41 | assert subset in ["main", "socratic"], "GSM8K subset must be main|socratic" 42 | assert split in ["train", "test"], "GSM8K split must be train|test" 43 | self.ds = load_dataset("openai/gsm8k", subset, split=split).shuffle(seed=42) 44 | 45 | @property 46 | def eval_type(self): 47 | return 'generative' 48 | 49 | def num_examples(self): 50 | return len(self.ds) 51 | 52 | def get_example(self, index): 53 | """ Get a single problem from the dataset. """ 54 | row = self.ds[index] 55 | question = row['question'] # string of the question prompt 56 | answer = row['answer'] # string of the full solution and the answer after #### marker 57 | # Create and return the Conversation object 58 | # This is tricky because GSM8K uses tool calls, which we need to parse here. 
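        # As a concrete illustration, the example answer in the module docstring above,
        #   "Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute. ... #### 10"
        # gets split on the << >> tags and turned into assistant message parts roughly like:
        #   {"type": "text",          "text": "Weng earns 12/60 = $"}
        #   {"type": "python",        "text": "12/60"}
        #   {"type": "python_output", "text": "0.2"}
        #   {"type": "text",          "text": "0.2 per minute. ..."}
        # (the second tool call and the trailing "#### 10" text follow the same pattern)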
59 | assistant_message_parts = [] 60 | parts = re.split(r'(<<[^>]+>>)', answer) 61 | for part in parts: 62 | if part.startswith('<<') and part.endswith('>>'): 63 | # This is a calculator tool call 64 | inner = part[2:-2] # Remove << >> 65 | # Split on = to get expression and result 66 | if '=' in inner: 67 | expr, result = inner.rsplit('=', 1) 68 | else: 69 | expr, result = inner, "" 70 | # Add the tool call as a part 71 | assistant_message_parts.append({"type": "python", "text": expr}) 72 | # Add the result as a part 73 | assistant_message_parts.append({"type": "python_output", "text": result}) 74 | else: 75 | # Regular text in between tool calls 76 | assistant_message_parts.append({"type": "text", "text": part}) 77 | # Now put it all together 78 | messages = [ 79 | {"role": "user", "content": question}, # note: simple string 80 | {"role": "assistant", "content": assistant_message_parts}, # note: list of parts (as dicts) 81 | ] 82 | conversation = { 83 | "messages": messages, 84 | } 85 | return conversation 86 | 87 | def evaluate(self, conversation, assistant_response): 88 | """ 89 | Given (conversation, completion), return evaluation outcome (0 = wrong, 1 = correct) 90 | Note that: 91 | - the conversation has both user AND assistant message (containing the ground truth answer) 92 | - the assistant_response is usually the alternative assistant message achieved via sampling 93 | 94 | TODO: Technically, assistant_response should be a Message (either a string or a list of parts) 95 | We can handle this later possibly. For now just assume string. 96 | """ 97 | assert isinstance(assistant_response, str), "Assuming simple string response for now" 98 | # First extract the ground truth answer 99 | assistant_message = conversation['messages'][-1] 100 | assert assistant_message['role'] == "assistant", "Last message must be from the Assistant" 101 | assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts" 102 | last_text_part = assistant_message['content'][-1]['text'] # this contains the final answer in GSM8K 103 | # Extract both the ground truth answer and the predicted answer 104 | ref_num = extract_answer(last_text_part) 105 | pred_num = extract_answer(assistant_response) 106 | # Compare and return the success as int 107 | is_correct = int(pred_num == ref_num) 108 | return is_correct 109 | 110 | def reward(self, conversation, assistant_response): 111 | """ 112 | Used during RL. To keep things simple, just re-use the evaluation above. 113 | Later this could be made more complex (e.g. format matching etc.) 114 | """ 115 | is_correct = self.evaluate(conversation, assistant_response) 116 | is_correct_float = float(is_correct) 117 | return is_correct_float 118 | -------------------------------------------------------------------------------- /nanochat/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | The base/pretraining dataset is a set of parquet files. 3 | This file contains utilities for: 4 | - iterating over the parquet files and yielding documents from them 5 | - downloading the files on demand if they are not on disk 6 | 7 | For details of how the dataset was prepared, see `repackage_data_reference.py`.
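
A minimal usage sketch (illustrative):

    # from the CLI, download the first 8 shards (~250M chars each):
    #   python -m nanochat.dataset -n 8
    # then iterate over the training documents in Python:
    from nanochat.dataset import parquets_iter_batched
    for texts in parquets_iter_batched(split="train"):
        ...  # texts is a list of document strings from one parquet row group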
8 | """ 9 | 10 | import os 11 | import argparse 12 | import time 13 | import requests 14 | import pyarrow.parquet as pq 15 | from multiprocessing import Pool 16 | 17 | from nanochat.common import get_base_dir 18 | 19 | # ----------------------------------------------------------------------------- 20 | # The specifics of the current pretraining dataset 21 | 22 | # The URL on the internet where the data is hosted and downloaded from on demand 23 | BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main" 24 | MAX_SHARD = 1822 # the last datashard is shard_01822.parquet 25 | index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames 26 | base_dir = get_base_dir() 27 | DATA_DIR = os.path.join(base_dir, "base_data") 28 | os.makedirs(DATA_DIR, exist_ok=True) 29 | 30 | # ----------------------------------------------------------------------------- 31 | # These functions are useful utilities to other modules, can/should be imported 32 | 33 | def list_parquet_files(data_dir=None): 34 | """ Looks into a data dir and returns full paths to all parquet files. """ 35 | data_dir = DATA_DIR if data_dir is None else data_dir 36 | parquet_files = sorted([ 37 | f for f in os.listdir(data_dir) 38 | if f.endswith('.parquet') and not f.endswith('.tmp') 39 | ]) 40 | parquet_paths = [os.path.join(data_dir, f) for f in parquet_files] 41 | return parquet_paths 42 | 43 | def parquets_iter_batched(split, start=0, step=1): 44 | """ 45 | Iterate through the dataset, in batches of underlying row_groups for efficiency. 46 | - split can be "train" or "val". the last parquet file will be val. 47 | - start/step are useful for skipping rows in DDP. e.g. start=rank, step=world_size 48 | """ 49 | assert split in ["train", "val"], "split must be 'train' or 'val'" 50 | parquet_paths = list_parquet_files() 51 | parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:] 52 | for filepath in parquet_paths: 53 | pf = pq.ParquetFile(filepath) 54 | for rg_idx in range(start, pf.num_row_groups, step): 55 | rg = pf.read_row_group(rg_idx) 56 | texts = rg.column('text').to_pylist() 57 | yield texts 58 | 59 | # ----------------------------------------------------------------------------- 60 | def download_single_file(index): 61 | """ Downloads a single file index, with some backoff """ 62 | 63 | # Construct the local filepath for this file and skip if it already exists 64 | filename = index_to_filename(index) 65 | filepath = os.path.join(DATA_DIR, filename) 66 | if os.path.exists(filepath): 67 | print(f"Skipping {filepath} (already exists)") 68 | return True 69 | 70 | # Construct the remote URL for this file 71 | url = f"{BASE_URL}/{filename}" 72 | print(f"Downloading {filename}...") 73 | 74 | # Download with retries 75 | max_attempts = 5 76 | for attempt in range(1, max_attempts + 1): 77 | try: 78 | response = requests.get(url, stream=True, timeout=30) 79 | response.raise_for_status() 80 | # Write to temporary file first 81 | temp_path = filepath + f".tmp" 82 | with open(temp_path, 'wb') as f: 83 | for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks 84 | if chunk: 85 | f.write(chunk) 86 | # Move temp file to final location 87 | os.rename(temp_path, filepath) 88 | print(f"Successfully downloaded {filename}") 89 | return True 90 | 91 | except (requests.RequestException, IOError) as e: 92 | print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}") 93 | # Clean up any partial files 94 | for path in [filepath + f".tmp", 
filepath]: 95 | if os.path.exists(path): 96 | try: 97 | os.remove(path) 98 | except: 99 | pass 100 | # Try a few times with exponential backoff: 2^attempt seconds 101 | if attempt < max_attempts: 102 | wait_time = 2 ** attempt 103 | print(f"Waiting {wait_time} seconds before retry...") 104 | time.sleep(wait_time) 105 | else: 106 | print(f"Failed to download {filename} after {max_attempts} attempts") 107 | return False 108 | 109 | return False 110 | 111 | 112 | if __name__ == "__main__": 113 | parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards") 114 | parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = download all") 115 | parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)") 116 | args = parser.parse_args() 117 | 118 | num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1) 119 | ids_to_download = list(range(num)) 120 | print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...") 121 | print(f"Target directory: {DATA_DIR}") 122 | print() 123 | with Pool(processes=args.num_workers) as pool: 124 | results = pool.map(download_single_file, ids_to_download) 125 | 126 | # Report results 127 | successful = sum(1 for success in results if success) 128 | print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}") 129 | -------------------------------------------------------------------------------- /tasks/common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base class for all Tasks. 3 | A Task is basically a dataset of conversations, together with some 4 | metadata and often also evaluation criteria. 5 | Example tasks: MMLU, ARC-Easy, ARC-Challenge, GSM8K, HumanEval, SmolTalk. 6 | """ 7 | 8 | import random 9 | 10 | class Task: 11 | """ 12 | Base class of a Task. Allows for lightweight slicing of the underlying dataset. 13 | """ 14 | 15 | def __init__(self, start=0, stop=None, step=1): 16 | # allows a lightweight logical view over a dataset 17 | assert start >= 0, f"Start must be non-negative, got {start}" 18 | assert stop is None or stop >= start, f"Stop should be greater than or equal to start, got {stop} and {start}" 19 | assert step >= 1, f"Step must be strictly positive, got {step}" 20 | self.start = start 21 | self.stop = stop # could be None here 22 | self.step = step 23 | 24 | @property 25 | def eval_type(self): 26 | # one of 'generative' | 'categorical' 27 | raise NotImplementedError 28 | 29 | def num_examples(self): 30 | raise NotImplementedError 31 | 32 | def get_example(self, index): 33 | raise NotImplementedError 34 | 35 | def __len__(self): 36 | start = self.start 37 | stop = self.num_examples() if self.stop is None else self.stop 38 | step = self.step 39 | span = stop - start 40 | num = (span + step - 1) // step # ceil_div(span, step) 41 | assert num >= 0, f"Negative number of examples???: {num}" # prevent footguns 42 | return num 43 | 44 | def __getitem__(self, index: int): 45 | assert isinstance(index, int), f"Index must be an integer, got {type(index)}" 46 | physical_index = self.start + index * self.step 47 | conversation = self.get_example(physical_index) 48 | return conversation 49 | 50 | def evaluate(self, problem, completion): 51 | raise NotImplementedError 52 | 53 | 54 | class TaskMixture(Task): 55 | """ 56 | For SFT Training it becomes useful to train on a task mixture of datasets.
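    A minimal construction sketch (illustrative, using tasks defined in this repo):
        mixture = TaskMixture([
            MMLU(subset="auxiliary_train", split="train"),
            GSM8K(subset="main", split="train"),
        ])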
57 | Fun trick: if you wish to oversample any task, just pass it in multiple times in the list. 58 | """ 59 | 60 | def __init__(self, tasks, **kwargs): 61 | super().__init__(**kwargs) 62 | # tasks is a list of Task objects 63 | self.tasks = tasks 64 | self.lengths = [len(task) for task in self.tasks] 65 | self.num_conversations = sum(self.lengths) 66 | # Build list of all (task_idx, local_idx) pairs 67 | self.index_map = [] 68 | for task_idx, task_length in enumerate(self.lengths): 69 | for local_idx in range(task_length): 70 | self.index_map.append((task_idx, local_idx)) 71 | # Deterministically shuffle to mix tasks throughout training 72 | rng = random.Random(42) 73 | rng.shuffle(self.index_map) 74 | # Note: this is not the most elegant or best solution, but it's ok for now 75 | 76 | def num_examples(self): 77 | return self.num_conversations 78 | 79 | def get_example(self, index): 80 | """ 81 | Access conversations according to a deterministic shuffle of all examples. 82 | This ensures tasks are mixed throughout training, regardless of dataset size. 83 | """ 84 | assert 0 <= index < self.num_conversations, f"Index {index} out of range for mixture with {self.num_conversations} conversations" 85 | task_idx, local_idx = self.index_map[index] 86 | return self.tasks[task_idx][local_idx] 87 | 88 | 89 | class TaskSequence(Task): 90 | """ 91 | For SFT Training sometimes we want to sequentially train on a list of tasks. 92 | This is useful for cases that require a training curriculum. 93 | """ 94 | 95 | def __init__(self, tasks, **kwargs): 96 | super().__init__(**kwargs) 97 | self.tasks = tasks 98 | self.lengths = [len(task) for task in self.tasks] 99 | self.num_conversations = sum(self.lengths) 100 | 101 | def num_examples(self): 102 | return self.num_conversations 103 | 104 | def get_example(self, index): 105 | assert 0 <= index < self.num_conversations, f"Index {index} out of range for sequence with {self.num_conversations} conversations" 106 | for task_idx, task_length in enumerate(self.lengths): 107 | if index < task_length: 108 | return self.tasks[task_idx][index] 109 | index -= task_length 110 | 111 | 112 | def render_mc(question, letters, choices): 113 | """ 114 | The common multiple choice rendering format we will use. 115 | 116 | Note two important design decisions: 117 | 1) 118 | Bigger models don't care as much, but smaller models prefer to have 119 | the letter *after* the choice, which results in better binding. 120 | 2) 121 | There is no whitespace between the delimiter (=) and the letter. 122 | This is actually critical because the tokenizer has different token ids 123 | for " A" vs. "A". The assistant responses will be just the letter itself, 124 | i.e. "A", so it is important that here in the prompt it is the exact same 125 | token, i.e. "A" with no whitespace before it. Again, bigger models don't care 126 | about this too much, but smaller models do care about some of these details. 127 | """ 128 | query = f"Multiple Choice question: {question}\n" 129 | query += "".join([f"- {choice}={letter}\n" for letter, choice in zip(letters, choices)]) 130 | query += "\nRespond only with the letter of the correct answer." 
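    # For example (illustrative values), render_mc("What color is the sky?", "AB", ["blue", "green"])
    # returns the string:
    #   Multiple Choice question: What color is the sky?
    #   - blue=A
    #   - green=B
    #
    #   Respond only with the letter of the correct answer.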
131 | return query 132 | 133 | 134 | if __name__ == "__main__": 135 | # very lightweight test of slicing 136 | from tasks.mmlu import MMLU 137 | 138 | ds = MMLU(subset="auxiliary_train", split="train") 139 | print("Length of MMLU: ", len(ds)) 140 | ex = ds[5] 141 | print("5th example: ", ex) 142 | 143 | ds = MMLU(subset="auxiliary_train", split="train", start=5, stop=10) 144 | print("Length of sliced MMLU[5:10]: ", len(ds)) 145 | print("0th example of sliced MMLU: ", ds[0]) 146 | 147 | print("They match: ", ex == ds[0]) 148 | -------------------------------------------------------------------------------- /nanochat/common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common utilities for nanochat. 3 | """ 4 | 5 | import os 6 | import re 7 | import logging 8 | import torch 9 | import torch.distributed as dist 10 | 11 | class ColoredFormatter(logging.Formatter): 12 | """Custom formatter that adds colors to log messages.""" 13 | # ANSI color codes 14 | COLORS = { 15 | 'DEBUG': '\033[36m', # Cyan 16 | 'INFO': '\033[32m', # Green 17 | 'WARNING': '\033[33m', # Yellow 18 | 'ERROR': '\033[31m', # Red 19 | 'CRITICAL': '\033[35m', # Magenta 20 | } 21 | RESET = '\033[0m' 22 | BOLD = '\033[1m' 23 | def format(self, record): 24 | # Add color to the level name 25 | levelname = record.levelname 26 | if levelname in self.COLORS: 27 | record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}" 28 | # Format the message 29 | message = super().format(record) 30 | # Add color to specific parts of the message 31 | if levelname == 'INFO': 32 | # Highlight numbers and percentages 33 | message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message) 34 | message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message) 35 | return message 36 | 37 | def setup_default_logging(): 38 | handler = logging.StreamHandler() 39 | handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) 40 | logging.basicConfig( 41 | level=logging.INFO, 42 | handlers=[handler] 43 | ) 44 | 45 | setup_default_logging() 46 | logger = logging.getLogger(__name__) 47 | 48 | def get_base_dir(): 49 | # co-locate nanochat intermediates with other cached data in ~/.cache (by default) 50 | if os.environ.get("NANOCHAT_BASE_DIR"): 51 | nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR") 52 | else: 53 | home_dir = os.path.expanduser("~") 54 | cache_dir = os.path.join(home_dir, ".cache") 55 | nanochat_dir = os.path.join(cache_dir, "nanochat") 56 | os.makedirs(nanochat_dir, exist_ok=True) 57 | return nanochat_dir 58 | 59 | def print0(s="",**kwargs): 60 | ddp_rank = int(os.environ.get('RANK', 0)) 61 | if ddp_rank == 0: 62 | print(s, **kwargs) 63 | 64 | def print_banner(): 65 | # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/ 66 | banner = """ 67 | █████ █████ 68 | ░░███ ░░███ 69 | ████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████ 70 | ░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███ ░░░███░ 71 | ░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███ 72 | ░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███ 73 | ████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░████████ ░░█████ 74 | ░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░ 75 | """ 76 | print0(banner) 77 | 78 | def is_ddp(): 79 | # TODO is there a proper way 80 | return int(os.environ.get('RANK', 
-1)) != -1 81 | 82 | def get_dist_info(): 83 | if is_ddp(): 84 | assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE']) 85 | ddp_rank = int(os.environ['RANK']) 86 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 87 | ddp_world_size = int(os.environ['WORLD_SIZE']) 88 | return True, ddp_rank, ddp_local_rank, ddp_world_size 89 | else: 90 | return False, 0, 0, 1 91 | 92 | def autodetect_device_type(): 93 | # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU 94 | if torch.cuda.is_available(): 95 | device_type = "cuda" 96 | elif torch.backends.mps.is_available(): 97 | device_type = "mps" 98 | else: 99 | device_type = "cpu" 100 | print0(f"Autodetected device type: {device_type}") 101 | return device_type 102 | 103 | def compute_init(device_type="cuda"): # cuda|cpu|mps 104 | """Basic initialization that we keep doing over and over, so make common.""" 105 | 106 | assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm" 107 | if device_type == "cuda": 108 | assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'" 109 | if device_type == "mps": 110 | assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'" 111 | 112 | # Reproducibility 113 | torch.manual_seed(42) 114 | if device_type == "cuda": 115 | torch.cuda.manual_seed(42) 116 | # skipping full reproducibility for now, possibly investigate slowdown later 117 | # torch.use_deterministic_algorithms(True) 118 | 119 | # Precision 120 | if device_type == "cuda": 121 | torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls 122 | 123 | # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA 124 | ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() 125 | if ddp and device_type == "cuda": 126 | device = torch.device("cuda", ddp_local_rank) 127 | torch.cuda.set_device(device) # make "cuda" default to this device 128 | dist.init_process_group(backend="nccl", device_id=device) 129 | dist.barrier() 130 | else: 131 | device = torch.device(device_type) # mps|cpu 132 | 133 | if ddp_rank == 0: 134 | logger.info(f"Distributed world size: {ddp_world_size}") 135 | 136 | return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device 137 | 138 | def compute_cleanup(): 139 | """Companion function to compute_init, to clean things up before script exit""" 140 | if is_ddp(): 141 | dist.destroy_process_group() 142 | 143 | class DummyWandb: 144 | """Useful if we wish to not use wandb but have all the same signatures""" 145 | def __init__(self): 146 | pass 147 | def log(self, *args, **kwargs): 148 | pass 149 | def finish(self): 150 | pass 151 | -------------------------------------------------------------------------------- /nanochat/checkpoint_manager.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for saving and loading model/optim/state checkpoints. 
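
Each checkpoint step is stored as three files inside its checkpoint directory (see save_checkpoint
below): model_{step:06d}.pt, optim_{step:06d}.pt (optional), and meta_{step:06d}.json.

A minimal loading sketch (illustrative, mirroring how scripts/base_eval.py uses it):

    from nanochat.checkpoint_manager import load_model
    model, tokenizer, meta = load_model("base", device, phase="eval")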
3 | """ 4 | import os 5 | import re 6 | import glob 7 | import json 8 | import logging 9 | import torch 10 | 11 | from nanochat.common import get_base_dir 12 | from nanochat.gpt import GPT, GPTConfig 13 | from nanochat.tokenizer import get_tokenizer 14 | from nanochat.common import setup_default_logging 15 | 16 | # Set up logging 17 | setup_default_logging() 18 | logger = logging.getLogger(__name__) 19 | def log0(message): 20 | if int(os.environ.get('RANK', 0)) == 0: 21 | logger.info(message) 22 | 23 | def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data): 24 | assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now 25 | os.makedirs(checkpoint_dir, exist_ok=True) 26 | # Save the model state (parameters) 27 | model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") 28 | torch.save(model_data, model_path) 29 | log0(f"Saved model file to: {model_path}") 30 | # Save the optimizer state (useful for SFT or any other fine-tuning) 31 | if optimizer_data is not None: 32 | optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt") 33 | torch.save(optimizer_data, optimizer_path) 34 | log0(f"Saved optimizer file to: {optimizer_path}") 35 | # Save the metadata dict as json 36 | meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") 37 | with open(meta_path, "w") as f: 38 | json.dump(meta_data, f, indent=2) 39 | log0(f"Saved metadata file to: {meta_path}") 40 | 41 | 42 | def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False): 43 | # Load the model state 44 | model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") 45 | model_data = torch.load(model_path, map_location=device) 46 | # Load the optimizer state if requested 47 | optimizer_data = None 48 | if load_optimizer: 49 | optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt") 50 | optimizer_data = torch.load(optimizer_path, map_location=device) 51 | # Load the metadata 52 | meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") 53 | with open(meta_path, "r") as f: 54 | meta_data = json.load(f) 55 | return model_data, optimizer_data, meta_data 56 | 57 | 58 | def build_model(checkpoint_dir, step, device, phase): 59 | """ 60 | A bunch of repetitive code to build a model from a given checkpoint. 61 | Returns: 62 | - base model - uncompiled, not wrapped in DDP 63 | - tokenizer 64 | - meta data saved during base model training 65 | """ 66 | assert phase in ["train", "eval"], f"Invalid phase: {phase}" 67 | model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False) 68 | # Hack: fix torch compile issue, which prepends all keys with _orig_mod. 69 | model_data = {k.lstrip("_orig_mod."): v for k, v in model_data.items()} 70 | model_config_kwargs = meta_data["model_config"] 71 | log0(f"Building model with config: {model_config_kwargs}") 72 | model_config = GPTConfig(**model_config_kwargs) 73 | with torch.device("meta"): 74 | model = GPT(model_config) 75 | # Load the model state 76 | model.to_empty(device=device) 77 | model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. 
TODO: fix model re-init 78 | model.load_state_dict(model_data, strict=True, assign=True) 79 | # Put the model in the right training phase / mode 80 | if phase == "eval": 81 | model.eval() 82 | else: 83 | model.train() 84 | # Load the Tokenizer 85 | tokenizer = get_tokenizer() 86 | # Sanity check: compatibility between model and tokenizer 87 | assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"] 88 | return model, tokenizer, meta_data 89 | 90 | 91 | def find_largest_model(checkpoint_dir): 92 | # attempt to guess the model tag: take the biggest model available 93 | model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))] 94 | if not model_tags: 95 | raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") 96 | # 1) normally all model tags are of the form d, try that first: 97 | candidates = [] 98 | for model_tag in model_tags: 99 | match = re.match(r"d(\d+)", model_tag) 100 | if match: 101 | model_depth = int(match.group(1)) 102 | candidates.append((model_depth, model_tag)) 103 | if candidates: 104 | candidates.sort(key=lambda x: x[0], reverse=True) 105 | return candidates[0][1] 106 | # 2) if that failed, take the most recently updated model: 107 | model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True) 108 | return model_tags[0] 109 | 110 | 111 | def find_last_step(checkpoint_dir): 112 | # Look into checkpoint_dir and find model_.pt with the highest step 113 | checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt")) 114 | if not checkpoint_files: 115 | raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") 116 | last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files)) 117 | return last_step 118 | 119 | # ----------------------------------------------------------------------------- 120 | # convenience functions that take into account nanochat's directory structure 121 | 122 | def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None): 123 | if model_tag is None: 124 | # guess the model tag by defaulting to the largest model 125 | model_tag = find_largest_model(checkpoints_dir) 126 | log0(f"No model tag provided, guessing model tag: {model_tag}") 127 | checkpoint_dir = os.path.join(checkpoints_dir, model_tag) 128 | if step is None: 129 | # guess the step by defaulting to the last step 130 | step = find_last_step(checkpoint_dir) 131 | assert step is not None, f"No checkpoints found in {checkpoint_dir}" 132 | # build the model 133 | log0(f"Loading model from {checkpoint_dir} with step {step}") 134 | model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase) 135 | return model, tokenizer, meta_data 136 | 137 | def load_model(source, *args, **kwargs): 138 | model_dir = { 139 | "base": "base_checkpoints", 140 | "mid": "mid_checkpoints", 141 | "sft": "chatsft_checkpoints", 142 | "rl": "chatrl_checkpoints", 143 | }[source] 144 | base_dir = get_base_dir() 145 | checkpoints_dir = os.path.join(base_dir, model_dir) 146 | return load_model_from_dir(checkpoints_dir, *args, **kwargs) 147 | -------------------------------------------------------------------------------- /speedrun.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script is the "Best ChatGPT clone that $100 can buy", 4 | # It is designed to run in ~4 hours on 8XH100 node at $3/GPU/hour. 
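# (Rough arithmetic: 8 GPUs x $3/GPU/hour x ~4 hours ~= $96, i.e. roughly the $100 budget.)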
5 | 6 | # 1) Example launch (simplest): 7 | # bash speedrun.sh 8 | # 2) Example launch in a screen session (because the run takes ~4 hours): 9 | # screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh 10 | # 3) Example launch with wandb logging, but see below for setting up wandb first: 11 | # WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh 12 | 13 | # Default intermediate artifacts directory is in ~/.cache/nanochat 14 | export OMP_NUM_THREADS=1 15 | export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat" 16 | mkdir -p $NANOCHAT_BASE_DIR 17 | 18 | # ----------------------------------------------------------------------------- 19 | # Python venv setup with uv 20 | 21 | # install uv (if not already installed) 22 | command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh 23 | # create a .venv local virtual environment (if it doesn't exist) 24 | [ -d ".venv" ] || uv venv 25 | # install the repo dependencies 26 | uv sync 27 | # activate venv so that `python` uses the project's venv instead of system python 28 | source .venv/bin/activate 29 | 30 | # ----------------------------------------------------------------------------- 31 | # wandb setup 32 | # If you wish to use wandb for logging (it's nice!, recommended). 33 | # 1) Make sure to first log in to wandb, e.g. run: 34 | # `wandb login` 35 | # 2) Set the WANDB_RUN environment variable when running this script, e.g.: 36 | # `WANDB_RUN=d26 bash speedrun.sh` 37 | if [ -z "$WANDB_RUN" ]; then 38 | # by default use "dummy" : it's handled as a special case, skips logging to wandb 39 | WANDB_RUN=dummy 40 | fi 41 | 42 | # ----------------------------------------------------------------------------- 43 | # During the course of the run, we will be writing markdown reports to the report/ 44 | # directory in the base dir. This command clears it out and writes a header section 45 | # with a bunch of system info and a timestamp that marks the start of the run. 46 | python -m nanochat.report reset 47 | 48 | # ----------------------------------------------------------------------------- 49 | # Tokenizer 50 | 51 | # Install Rust / Cargo 52 | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y 53 | source "$HOME/.cargo/env" 54 | 55 | # Build the rustbpe Tokenizer 56 | uv run maturin develop --release --manifest-path rustbpe/Cargo.toml 57 | 58 | # Download the first ~2B characters of pretraining dataset 59 | # look at dev/repackage_data_reference.py for details on how this data was prepared 60 | # each data shard is ~250M chars 61 | # so we download 2e9 / 250e6 = 8 data shards at this point 62 | # each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk 63 | python -m nanochat.dataset -n 8 64 | # Immediately also kick off downloading more shards in the background while tokenizer trains 65 | # See comment below for why 240 is the right number here 66 | python -m nanochat.dataset -n 240 & 67 | DATASET_DOWNLOAD_PID=$! 68 | # train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data 69 | python -m scripts.tok_train --max_chars=2000000000 70 | # evaluate the tokenizer (report compression ratio etc.) 71 | python -m scripts.tok_eval 72 | 73 | # ----------------------------------------------------------------------------- 74 | # Base model (pretraining) 75 | 76 | # Download the eval_bundle from s3 to evaluate CORE metric during training (~162MB) 77 | EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip 78 | if [ ! 
-d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then 79 | curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL 80 | unzip -q eval_bundle.zip 81 | rm eval_bundle.zip 82 | mv eval_bundle $NANOCHAT_BASE_DIR 83 | fi 84 | 85 | # The d20 model is 561M parameters. 86 | # Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens. 87 | # Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars. 88 | # At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining. 89 | # Round up to 240 for safety. At ~100MB/shard, this downloads ~24GB of data to disk. 90 | # (The total number of shards available in the entire dataset is 1822.) 91 | echo "Waiting for dataset download to complete..." 92 | wait $DATASET_DOWNLOAD_PID 93 | 94 | # pretrain the d20 model 95 | torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN 96 | # evaluate the model on a larger chunk of train/val data and draw some samples 97 | torchrun --standalone --nproc_per_node=8 -m scripts.base_loss 98 | # evaluate the model on CORE tasks 99 | torchrun --standalone --nproc_per_node=8 -m scripts.base_eval 100 | 101 | # ----------------------------------------------------------------------------- 102 | # Midtraining (teach the model conversation special tokens, tool use, multiple choice) 103 | 104 | # download 2.3MB of synthetic identity conversations to impart a personality to nanochat 105 | # see dev/gen_sft_data.py for details on how this data was prepared and to get a sense of how you can easily tune it 106 | curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl 107 | 108 | # run midtraining and eval the model 109 | torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN 110 | torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid 111 | 112 | # ----------------------------------------------------------------------------- 113 | # Supervised Finetuning (domain adaptation to each sequence all by itself per row) 114 | 115 | # train sft and re-eval right away (should see a small bump) 116 | torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN 117 | torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft 118 | 119 | # chat with the model over CLI! Leave out the -p to chat interactively 120 | # python -m scripts.chat_cli -p "Why is the sky blue?" 121 | 122 | # even better, chat with your model over a pretty WebUI ChatGPT style 123 | # python -m scripts.chat_web 124 | 125 | # ----------------------------------------------------------------------------- 126 | # Reinforcement Learning. 
Optional, and currently only on GSM8K 127 | # (optional) 128 | 129 | # run reinforcement learning 130 | # torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=$WANDB_RUN 131 | # eval the RL model only on GSM8K 132 | # torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i rl -a GSM8K 133 | 134 | # ----------------------------------------------------------------------------- 135 | # Generate the full report by putting together all the sections 136 | # report.md is the output and will be copied to current directory for convenience 137 | python -m nanochat.report generate 138 | -------------------------------------------------------------------------------- /scripts/base_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluate the CORE metric for a given model. 3 | 4 | Run on a single GPU: 5 | python base_eval.py 6 | 7 | Run with torchrun on e.g. 8 GPUs: 8 | torchrun --nproc_per_node=8 base_eval.py 9 | 10 | The script will print the CORE metric to the console. 11 | """ 12 | import os 13 | import sys 14 | import time 15 | import json 16 | import random 17 | import yaml 18 | from contextlib import nullcontext 19 | 20 | import pandas as pd 21 | import torch 22 | 23 | from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type 24 | from nanochat.tokenizer import HuggingFaceTokenizer 25 | from nanochat.checkpoint_manager import load_model 26 | from nanochat.core_eval import evaluate_task 27 | 28 | # ----------------------------------------------------------------------------- 29 | # nanoChat specific function dealing with I/O etc. 30 | 31 | def evaluate_model(model, tokenizer, device, max_per_task=-1): 32 | """ 33 | Evaluate a base model on the CORE benchmark. 34 | - max_per_task: crop the data to this many examples per task for testing (-1 = disable) 35 | TODO: clean up this function, delete the need for all the files, for pandas dependency, etc. 36 | """ 37 | # Load config and task metadata 38 | base_dir = get_base_dir() 39 | eval_bundle_dir = os.path.join(base_dir, "eval_bundle") 40 | config_path = os.path.join(eval_bundle_dir, "core.yaml") 41 | data_base_path = os.path.join(eval_bundle_dir, "eval_data") 42 | eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv") 43 | with open(config_path, 'r') as f: 44 | config = yaml.safe_load(f) 45 | tasks = config['icl_tasks'] 46 | eval_metadata = pd.read_csv(eval_meta_data) 47 | 48 | # Evaluate each task 49 | results = {} 50 | centered_results = {} 51 | for task in tasks: 52 | start_time = time.time() 53 | label = task['label'] 54 | task_meta = { 55 | 'task_type': task['icl_task_type'], 56 | 'dataset_uri': task['dataset_uri'], 57 | 'num_fewshot': task['num_fewshot'][0], 58 | 'continuation_delimiter': task.get('continuation_delimiter', ' ') 59 | } 60 | print0(f"Evaluating: {label} ({task_meta['num_fewshot']}-shot, type: {task_meta['task_type']})... ", end='') 61 | 62 | # Load data for this task 63 | data_path = os.path.join(data_base_path, task_meta['dataset_uri']) 64 | with open(data_path, 'r') as f: 65 | data = [json.loads(line.strip()) for line in f] 66 | 67 | # shuffle the data because in many cases it appears ordered but we want 68 | # the ability to only run a subset of the data for debugging purposes etc.
69 | shuffle_rng = random.Random(1337) 70 | shuffle_rng.shuffle(data) 71 | if max_per_task > 0: 72 | data = data[:max_per_task] 73 | 74 | # run the evaluation for this task 75 | accuracy = evaluate_task(model, tokenizer, data, device, task_meta) 76 | 77 | results[label] = accuracy 78 | row = eval_metadata[eval_metadata["Eval Task"] == label] 79 | random_baseline = row["Random baseline"].values[0] 80 | centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline) 81 | centered_results[label] = centered_result 82 | end_time = time.time() 83 | print0(f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {end_time - start_time:.2f}s") 84 | 85 | core_metric = sum(centered_results.values()) / len(centered_results) 86 | out = { 87 | "results": results, 88 | "centered_results": centered_results, 89 | "core_metric": core_metric 90 | } 91 | return out 92 | 93 | # ----------------------------------------------------------------------------- 94 | # HuggingFace loading utilities and light wrappers for a model 95 | 96 | class ModelWrapper: 97 | """Lightweight wrapper for a HuggingFace model""" 98 | def __init__(self, model, max_seq_len=None): 99 | self.model = model 100 | self.max_seq_len = max_seq_len 101 | 102 | def __call__(self, input_ids): 103 | outputs = self.model(input_ids) 104 | logits = outputs.logits 105 | return logits 106 | 107 | def load_hf_model(hf_path: str, device): 108 | print0(f"Loading model from: {hf_path}") 109 | # Load the model 110 | from transformers import AutoModelForCausalLM 111 | model = AutoModelForCausalLM.from_pretrained(hf_path) 112 | model.to(device) 113 | model.eval() 114 | max_seq_len = 1024 if "openai-community/gpt2" in hf_path else None 115 | model = ModelWrapper(model, max_seq_len=max_seq_len) 116 | # Load the tokenizer 117 | tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path) 118 | return model, tokenizer 119 | 120 | # ----------------------------------------------------------------------------- 121 | def main(): 122 | import argparse 123 | parser = argparse.ArgumentParser() 124 | parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate') 125 | parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)') 126 | args = parser.parse_args() 127 | 128 | # distributed / precision setup 129 | device_type = autodetect_device_type() 130 | ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) 131 | autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() 132 | 133 | # Load model and tokenizer from command line or from file system 134 | if args.hf_path is not None: 135 | # atm assume that if a path is given, it's a huggingface model path 136 | hf_path = args.hf_path 137 | print0(f"Loading huggingface model from: {hf_path}") 138 | model, tokenizer = load_hf_model(hf_path, device) 139 | model_name = hf_path # just for logging 140 | model_slug = hf_path.replace("/", "-") # for the output csv file 141 | else: 142 | # load a local model from the file system 143 | model, tokenizer, meta = load_model("base", device, phase="eval") 144 | model_name = f"base_model (step {meta['step']})" # just for logging 145 | model_slug = f"base_model_{meta['step']:06d}" # for the output csv file 146 | 147 | # Evaluate the model 148 | with autocast_ctx: 149 | out = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task) 150 | 151 | # Write out the 
results to a csv file 152 | core_metric = None 153 | centered_results = {} 154 | if ddp_rank == 0: 155 | base_dir = get_base_dir() 156 | output_csv_path = os.path.join(base_dir, "base_eval", f"{model_slug}.csv") 157 | os.makedirs(os.path.dirname(output_csv_path), exist_ok=True) 158 | results = out["results"] 159 | centered_results = out["centered_results"] 160 | core_metric = out["core_metric"] 161 | with open(output_csv_path, 'w') as f: 162 | f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n") 163 | for label in results: 164 | f.write(f"{label:<35}, {results[label]:<10.6f}, {centered_results[label]:<10.6f}\n") 165 | f.write(f"{'CORE':<35}, {'':<10}, {core_metric:<10.6f}\n") 166 | # Print the content of the csv file to console too 167 | print0("="*80) 168 | print0(f"Model: {model_name}") 169 | print0("="*80) 170 | with open(output_csv_path, 'r') as f: 171 | print0(f.read()) 172 | 173 | # Log to report 174 | from nanochat.report import get_report 175 | get_report().log(section="Base model evaluation", data=[ 176 | { 177 | "Model": model_name, 178 | "CORE metric": core_metric, 179 | }, 180 | centered_results, # the full table 181 | ]) 182 | 183 | compute_cleanup() 184 | 185 | if __name__ == "__main__": 186 | main() 187 | -------------------------------------------------------------------------------- /nanochat/muon.py: -------------------------------------------------------------------------------- 1 | """ 2 | Muon optimizer from Keller et al. 3 | Also a lot of borrowing of ideas from modded-nanogpt. 4 | """ 5 | import torch 6 | from torch import Tensor 7 | import torch.distributed as dist 8 | 9 | @torch.compile 10 | def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor: 11 | """ 12 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 13 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 14 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 15 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 16 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T 17 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model 18 | performance at all relative to UV^T, where USV^T = G is the SVD. 19 | """ 20 | assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng 21 | a, b, c = (3.4445, -4.7750, 2.0315) 22 | X = G.bfloat16() 23 | if G.size(-2) > G.size(-1): 24 | X = X.mT 25 | 26 | # Ensure spectral norm is at most 1 27 | X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) 28 | # Perform the NS iterations 29 | for _ in range(steps): 30 | A = X @ X.mT 31 | B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng 32 | X = a * X + B @ X 33 | 34 | if G.size(-2) > G.size(-1): 35 | X = X.mT 36 | return X 37 | 38 | class Muon(torch.optim.Optimizer): 39 | """ 40 | Muon - MomentUm Orthogonalized by Newton-schulz 41 | 42 | https://kellerjordan.github.io/posts/muon/ 43 | 44 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- 45 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal 46 | matrix. 
To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has 47 | the advantage that it can be stably run in bfloat16 on the GPU. 48 | 49 | Some warnings: 50 | - This optimizer should not be used for the embedding layer, the final fully connected layer, 51 | or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 52 | - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. 53 | 54 | Arguments: 55 | lr: The learning rate used by the internal SGD. 56 | momentum: The momentum used by the internal SGD. 57 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) 58 | ns_steps: The number of Newton-Schulz iteration steps to use. 59 | """ 60 | def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5): 61 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) 62 | params: list[Tensor] = [*params] 63 | param_groups = [] 64 | for size in {p.numel() for p in params}: 65 | group = dict(params=[p for p in params if p.numel() == size]) 66 | param_groups.append(group) 67 | super().__init__(param_groups, defaults) 68 | 69 | @torch.no_grad() 70 | def step(self): 71 | for group in self.param_groups: 72 | params: list[Tensor] = group["params"] 73 | for p in params: 74 | g = p.grad 75 | assert g is not None 76 | state = self.state[p] 77 | if "momentum_buffer" not in state: 78 | state["momentum_buffer"] = torch.zeros_like(g) 79 | buf: Tensor = state["momentum_buffer"] 80 | buf.lerp_(g, 1 - group["momentum"]) 81 | g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf 82 | g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) 83 | p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5) 84 | 85 | 86 | class DistMuon(torch.optim.Optimizer): 87 | """ 88 | Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Newton–Schulz, 89 | finally apply aspect-ratio scaled step. Performs its own distributed synchronization: 90 | - reduce_scatter(AVG) for gradient averaging 91 | - all_gather to replicate updated weights 92 | 93 | Notes: 94 | * Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D). Do not use for 0D/1D 95 | params like embeddings or scalars. 96 | * Momentum buffers are maintained only on the 'owner' rank for each parameter (rank chosen 97 | by block-cyclic assignment below). If you checkpoint optimizer state on a single rank, 98 | consolidate states beforehand. 
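    * A typical hookup (illustrative sketch, not prescriptive): hand DistMuon only the 2D matrix
      parameters and keep everything else (embeddings, scalars, etc.) on a standard optimizer, e.g.
          matrix_params = [p for p in model.parameters() if p.ndim == 2]
          other_params = [p for p in model.parameters() if p.ndim != 2]
          optimizers = [DistMuon(matrix_params, lr=0.02, momentum=0.95), torch.optim.AdamW(other_params)]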
99 | 100 | Args: 101 | params: iterable of Tensors 102 | lr: learning rate 103 | momentum: momentum coefficient in [0,1) 104 | nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf 105 | ns_steps: number of Newton–Schulz iterations for the orthogonalization 106 | """ 107 | def __init__(self, params, lr: float = 0.02, momentum: float = 0.95, 108 | nesterov: bool = True, ns_steps: int = 5): 109 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) 110 | params = list(params) 111 | assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only" 112 | rank = dist.get_rank() 113 | # Group all parameters by their shape 114 | shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering 115 | param_groups = [] 116 | for shape in shapes: 117 | group_params = [p for p in params if p.shape == shape] 118 | device, dtype = group_params[0].device, group_params[0].dtype 119 | assert all(p.device == device for p in group_params) 120 | assert all(p.dtype == dtype for p in group_params) 121 | if rank == 0: 122 | print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}") 123 | param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0]))) 124 | super().__init__(param_groups, defaults) 125 | 126 | @torch.no_grad() 127 | def step(self): 128 | rank = dist.get_rank() 129 | world_size = dist.get_world_size() 130 | 131 | # Ensure all grads exist 132 | assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads" 133 | 134 | # Kick off all the reduce scatter operations to average up the gradients across all ranks 135 | all_reduce_futures = [] 136 | for group in self.param_groups: 137 | params = group["params"] 138 | zero_buffer = group["zero_buffer"] 139 | # Go through params in groups of world_size. 140 | for base_i in range(0, len(params), world_size): 141 | # The compute owner of each param is rank i % world_size 142 | owner_idx = base_i + rank 143 | # each rank stacks up its chunk of world_size params into a list 144 | rs_input = [p.grad for p in params[base_i:base_i + world_size]] 145 | # pad rs_input with the zero buffer to complete the group 146 | rs_input.extend([zero_buffer] * (world_size - len(rs_input))) 147 | # the output buffer gets strided across the group based on the rank 148 | rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer) 149 | # reduce scatter the gradients within this group of world_size params 150 | work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future() 151 | all_reduce_futures.append(work) 152 | 153 | # Now each rank computes the update and gathers 154 | future_idx = 0 155 | all_gather_futures = [] 156 | for group in self.param_groups: 157 | params = group["params"] 158 | zero_buffer = group["zero_buffer"] 159 | # Go through params in groups of world_size. 
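        # (Illustrative: with world_size=8 and 11 params in this group, base_i takes the values 0 and 8;
        # in the first pass rank r owns params[r], in the second pass only ranks 0..2 own params[8..10]
        # and the remaining ranks simply contribute the zero buffer / scratch outputs to the collectives.)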
160 | for base_i in range(0, len(params), world_size): 161 | # The compute owner of each param is rank i % world_size 162 | owner_idx = base_i + rank # calculate the index of the param that this rank owns 163 | # Wait for the reduce scatter to complete 164 | all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead 165 | future_idx += 1 166 | # Owner computes the Muon update, result is in its param 167 | if owner_idx < len(params): 168 | p = params[owner_idx] 169 | g = p.grad # now averaged across ranks 170 | state = self.state[p] 171 | if "momentum_buffer" not in state: 172 | state["momentum_buffer"] = torch.zeros_like(g) 173 | buf: Tensor = state["momentum_buffer"] 174 | buf.lerp_(g, 1.0 - group["momentum"]) 175 | g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf 176 | g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) 177 | scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5) 178 | p.add_(g, alpha=-group["lr"] * scale) 179 | # Replicate updated parameters to all ranks 180 | ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer 181 | ag_output = params[base_i:base_i + world_size] 182 | ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad 183 | work = dist.all_gather(ag_output, ag_input, async_op=True).get_future() 184 | all_gather_futures.append(work) 185 | 186 | # Wait for all work to finish 187 | torch.futures.collect_all(all_gather_futures).wait() 188 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # nanochat 2 | 3 | ![nanochat logo](dev/nanochat.png) 4 | 5 | > The best ChatGPT that $100 can buy. 6 | 7 | This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](speedrun.sh) that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs. 8 | 9 | ## Talk to it 10 | 11 | To get a sense of the endpoint of this repo, you can currently find [nanochat d32](https://github.com/karpathy/nanochat/discussions/8) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). "d32" means that this model has 32 layers in the Transformer neural network. This model has 1.9 billion parameters and was trained on 38 billion tokens by simply running the single script [run1000.sh](run1000.sh), and the total cost of training was ~$800 (about 33 hours training time on an 8XH100 GPU node). While today this is enough to outperform GPT-2 of 2019, it falls dramatically short of modern Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly and they hallucinate a ton, a bit like children. It's kind of amusing. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to... 12 | 13 | ## Quick start 14 | 15 | The fastest way to feel the magic is to run the speedrun script [speedrun.sh](speedrun.sh), which trains and inferences the $100 tier of nanochat.
On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script: 16 | 17 | ```bash 18 | bash speedrun.sh 19 | ``` 20 | 21 | Alternatively, since the script runs for 4 hours, I like to launch it like this inside a new screen session `speedrun` (and also log output to `speedrun.log`): 22 | 23 | ```bash 24 | screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh 25 | ``` 26 | 27 | See the [screen cheatsheet](https://gist.github.com/jctosta/af918e1618682638aa82) if you are less familiar. You can watch it go inside the screen session, or detach with `Ctrl-a d` and `tail speedrun.log` to view progress. Now wait 4 hours. Once it's done, you can talk to your LLM via the ChatGPT-like web UI. Make sure again that your local uv virtual environment is active (run `source .venv/bin/activate`), and serve it: 28 | 29 | ```bash 30 | python -m scripts.chat_web 31 | ``` 32 | 33 | And then visit the URL shown. Make sure to access it correctly, e.g. on Lambda use the public IP of the node you're on, followed by the port, so for example [http://209.20.xxx.xxx:8000/](http://209.20.xxx.xxx:8000/), etc. Then talk to your LLM as you'd normally talk to ChatGPT! Get it to write stories or poems. Ask it to tell you who you are to see a hallucination. Ask it why the sky is blue. Or why it's green. The speedrun is a 4e19 FLOPs capability model so it's a bit like talking to a kindergartener :). 34 | 35 | --- 36 | 37 | image 38 | 39 | --- 40 | 41 | You can also `cat report.md` file which appeared in the project directory and contains the "report card" of the run, i.e. a bunch of evaluations and metrics. At the very end, you'll see a summary table, for example: 42 | 43 | --- 44 | 45 | - Characters: 333,989 46 | - Lines: 8,304 47 | - Files: 44 48 | - Tokens (approx): 83,497 49 | - Dependencies (uv.lock lines): 2,004 50 | 51 | | Metric | BASE | MID | SFT | RL | 52 | |-----------------|----------|----------|----------|----------| 53 | | CORE | 0.2219 | - | - | - | 54 | | ARC-Challenge | - | 0.2875 | 0.2807 | - | 55 | | ARC-Easy | - | 0.3561 | 0.3876 | - | 56 | | GSM8K | - | 0.0250 | 0.0455 | 0.0758 | 57 | | HumanEval | - | 0.0671 | 0.0854 | - | 58 | | MMLU | - | 0.3111 | 0.3151 | - | 59 | | ChatCORE | - | 0.0730 | 0.0884 | - | 60 | 61 | Total wall clock time: 3h51m 62 | 63 | --- 64 | 65 | (Your table might be missing the RL number by default). For a lot more information around the speedrun script and what to look for and expect, please refer to the walkthrough that I posted in Discussions of the repo: ["Introducing nanochat: The best ChatGPT that $100 can buy"](https://github.com/karpathy/nanochat/discussions/1). 66 | 67 | ## Bigger models 68 | 69 | Unsurprisingly, $100 is not enough to train a highly performant ChatGPT clone. In fact, LLMs are famous for their multi-million dollar capex. For our purposes, I think there are two more scales of interest. First is the ~$300 tier d26 model (i.e. depth=26) that trains in ~12 hours, which slightly outperforms GPT-2 CORE score. Second is the $1000 tier (~41.6 hours), just because it's a nice round number. But both of these are not yet fully supported and therefore not attached here in the master branch yet. 70 | 71 | That said, to give a sense, the example changes needed for the [speedrun.sh](speedrun.sh) file to train a GPT-2 grade model d26 only involve three changes: 72 | 73 | ```bash 74 | ... 
75 | # you'll need to download more data shards for pretraining 76 | # get the number of parameters, multiply 20 to get tokens, multiply by 4.8 to get chars, 77 | # divide by 250 million to get number of shards. todo need to improve this... 78 | python -m nanochat.dataset -n 450 & 79 | ... 80 | # use --depth to increase model size. to not oom, halve device batch size 32 -> 16: 81 | torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --device_batch_size=16 82 | ... 83 | # make sure to use the same later during midtraining: 84 | torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16 85 | ``` 86 | 87 | That's it! The biggest thing to pay attention to is making sure you have enough data shards to train on (otherwise the code will loop and do more epochs over the same training set, decreasing learning speed a bit), and managing your memory/VRAM, primarily by decreasing the `device_batch_size` until things fit (the scripts automatically compensate by increasing the number of gradient accumulation loops, simply turning parallel compute into sequential compute). 88 | 89 | And a bit more about computing environments that will run nanochat: 90 | 91 | - The code will run just fine on the Ampere 8XA100 GPU node as well, but a bit slower. 92 | - All code will run just fine on even a single GPU by omitting `torchrun`, and will produce ~identical results (code will automatically switch to gradient accumulation), but you'll have to wait 8 times longer. 93 | - If your GPU(s) have less than 80GB, you'll have to tune some of the hyperparameters or you will OOM / run out of VRAM. Look for `--device_batch_size` in the scripts and reduce it until things fit. E.g. from 32 (default) to 16, 8, 4, 2, or even 1. Below that, you'll have to know a bit more about what you're doing and get more creative. 94 | - Most of the code is fairly vanilla PyTorch so it should run on anything that supports that - xpu, mps, etc. - but I haven't implemented this out of the box so it might take a bit of tinkering. 95 | 96 | ## Running on CPU / MPS 97 | 98 | nanochat can be run on CPU or on MPS (if you're on a MacBook), and will automatically try to detect what device is best to run on. You're not going to get too far without GPUs, but at least you'll be able to run the code paths and maybe train a tiny LLM with some patience. For an example of how to make all the run commands much smaller (feel free to tune!), you can refer to the [dev/runcpu.sh](dev/runcpu.sh) file. You'll see that I'm essentially restricting all scripts to train smaller models, run for fewer iterations, etc. This functionality is new, slightly gnarly (touched a lot of code), and was merged in this [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) on Oct 21, 2025. 99 | 100 | ## Customization 101 | 102 | To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into the midtraining and SFT stages. 103 | 104 | ## Questions 105 | 106 | nanochat is designed to be short and sweet. One big advantage of this is that we can package up all of the files together and copy-paste them into your favorite LLM to ask arbitrary questions. As an example, I like to package up the repo using the [files-to-prompt](https://github.com/simonw/files-to-prompt) utility like so: 107 | 108 | ```bash 109 | files-to-prompt .
-e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml > packaged.txt 110 | ``` 111 | 112 | This includes all py, rs, html, toml, sh files, excludes the `rustbpe/target` folder, and chooses the cxml output format. Everything is written to the `packaged.txt` file, which atm measures ~330KB (i.e. well below ~100K tokens for a state of the art LLM), and ~8K lines of code in 45 files. 113 | 114 | Alternatively, I recommend using [DeepWiki](https://deepwiki.com/) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off. 115 | 116 | ## Tests 117 | 118 | I haven't invested too much here but some tests exist, especially for the tokenizer. Run e.g. as: 119 | 120 | ```bash 121 | python -m pytest tests/test_rustbpe.py -v -s 122 | ``` 123 | 124 | ## Contributing 125 | 126 | nanochat is nowhere finished. The goal is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000 dollars. Accessibility is about overall cost but also about cognitive complexity - nanochat is not an exhaustively configurable LLM "framework"; there will be no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a concrete ChatGPT clone and its report card. 127 | 128 | ## Acknowledgements 129 | 130 | - The name (nanochat) derives from my earlier project [nanoGPT](https://github.com/karpathy/nanoGPT), which only covered pretraining. 131 | - nanochat is also inspired by [modded-nanoGPT](https://github.com/KellerJordan/modded-nanogpt), which gamified the nanoGPT repo with clear metrics and a leaderboard, and borrows a lot of its ideas and some implementation for pretraining. 132 | - Thank you to [HuggingFace](https://huggingface.co/) for fineweb and smoltalk. 133 | - Thank you [Lambda](https://lambda.ai/service/gpu-cloud) for the compute used in developing this project. 134 | - Thank you to chief LLM whisperer 🧙‍♂️ Alec Radford for advice/guidance. 135 | 136 | ## Cite 137 | 138 | If you find nanochat helpful in your research cite simply as: 139 | 140 | ```bibtex 141 | @misc{nanochat, 142 | author = {Andrej Karpathy}, 143 | title = {nanochat: The best ChatGPT that $100 can buy}, 144 | year = {2025}, 145 | publisher = {GitHub}, 146 | url = {https://github.com/karpathy/nanochat} 147 | } 148 | ``` 149 | 150 | ## License 151 | 152 | MIT 153 | -------------------------------------------------------------------------------- /nanochat/execution.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sandboxed execution utilities for running Python code that comes out of an LLM. 3 | Adapted from OpenAI HumanEval code: 4 | https://github.com/openai/human-eval/blob/master/human_eval/execution.py 5 | 6 | What is covered: 7 | - Each execution runs in its own process (can be killed if it hangs or crashes) 8 | - Execution is limited by a timeout to stop infinite loops 9 | - Memory limits are enforced by default (256MB) 10 | - stdout and stderr are captured and returned 11 | - Code runs in a temporary directory that is deleted afterwards 12 | - Dangerous functions are disabled (examples: os.system, os.kill, shutil.rmtree, subprocess.Popen) 13 | 14 | What is not covered: 15 | - Not a true security sandbox 16 | - Network access is not blocked (e.g. 
sockets could be opened) 17 | - Python's dynamic features (e.g. ctypes) could bypass restrictions 18 | - No kernel-level isolation (no seccomp, no containers, no virtualization) 19 | 20 | Overall this sandbox is good for evaluation of generated code and protects against 21 | accidental destructive behavior, but it is not safe against malicious adversarial code. 22 | """ 23 | 24 | import contextlib 25 | import faulthandler 26 | import io 27 | import multiprocessing 28 | import os 29 | import platform 30 | import signal 31 | import tempfile 32 | from dataclasses import dataclass 33 | from typing import Optional 34 | 35 | # ----------------------------------------------------------------------------- 36 | 37 | @dataclass 38 | class ExecutionResult: 39 | """Result of executing Python code in a sandbox.""" 40 | success: bool 41 | stdout: str 42 | stderr: str 43 | error: Optional[str] = None 44 | timeout: bool = False 45 | memory_exceeded: bool = False 46 | 47 | def __repr__(self): 48 | parts = [] 49 | parts.append(f"ExecutionResult(success={self.success}") 50 | if self.timeout: 51 | parts.append(", timeout=True") 52 | if self.memory_exceeded: 53 | parts.append(", memory_exceeded=True") 54 | if self.error: 55 | parts.append(f", error={self.error!r}") 56 | if self.stdout: 57 | parts.append(f", stdout={self.stdout!r}") 58 | if self.stderr: 59 | parts.append(f", stderr={self.stderr!r}") 60 | parts.append(")") 61 | return "".join(parts) 62 | 63 | 64 | @contextlib.contextmanager 65 | def time_limit(seconds: float): 66 | def signal_handler(signum, frame): 67 | raise TimeoutException("Timed out!") 68 | 69 | signal.setitimer(signal.ITIMER_REAL, seconds) 70 | signal.signal(signal.SIGALRM, signal_handler) 71 | try: 72 | yield 73 | finally: 74 | signal.setitimer(signal.ITIMER_REAL, 0) 75 | 76 | 77 | @contextlib.contextmanager 78 | def capture_io(): 79 | """Capture stdout and stderr, and disable stdin.""" 80 | stdout_capture = io.StringIO() 81 | stderr_capture = io.StringIO() 82 | stdin_block = WriteOnlyStringIO() 83 | with contextlib.redirect_stdout(stdout_capture): 84 | with contextlib.redirect_stderr(stderr_capture): 85 | with redirect_stdin(stdin_block): 86 | yield stdout_capture, stderr_capture 87 | 88 | 89 | @contextlib.contextmanager 90 | def create_tempdir(): 91 | with tempfile.TemporaryDirectory() as dirname: 92 | with chdir(dirname): 93 | yield dirname 94 | 95 | 96 | class TimeoutException(Exception): 97 | pass 98 | 99 | 100 | class WriteOnlyStringIO(io.StringIO): 101 | """StringIO that throws an exception when it's read from""" 102 | 103 | def read(self, *args, **kwargs): 104 | raise IOError 105 | 106 | def readline(self, *args, **kwargs): 107 | raise IOError 108 | 109 | def readlines(self, *args, **kwargs): 110 | raise IOError 111 | 112 | def readable(self, *args, **kwargs): 113 | """Returns True if the IO object can be read.""" 114 | return False 115 | 116 | 117 | class redirect_stdin(contextlib._RedirectStream): # type: ignore 118 | _stream = "stdin" 119 | 120 | 121 | @contextlib.contextmanager 122 | def chdir(root): 123 | if root == ".": 124 | yield 125 | return 126 | cwd = os.getcwd() 127 | os.chdir(root) 128 | try: 129 | yield 130 | except BaseException as exc: 131 | raise exc 132 | finally: 133 | os.chdir(cwd) 134 | 135 | 136 | def reliability_guard(maximum_memory_bytes: Optional[int] = None): 137 | """ 138 | This disables various destructive functions and prevents the generated code 139 | from interfering with the test (e.g. 
fork bomb, killing other processes, 140 | removing filesystem files, etc.) 141 | 142 | WARNING 143 | This function is NOT a security sandbox. Untrusted code, including, model- 144 | generated code, should not be blindly executed outside of one. See the 145 | Codex paper for more information about OpenAI's code sandbox, and proceed 146 | with caution. 147 | """ 148 | 149 | if platform.uname().system != "Darwin": 150 | # These resource limit calls seem to fail on macOS (Darwin), skip? 151 | import resource 152 | resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) 153 | resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) 154 | resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) 155 | 156 | faulthandler.disable() 157 | 158 | import builtins 159 | 160 | builtins.exit = None 161 | builtins.quit = None 162 | 163 | import os 164 | 165 | os.environ["OMP_NUM_THREADS"] = "1" 166 | 167 | os.kill = None 168 | os.system = None 169 | os.putenv = None 170 | os.remove = None 171 | os.removedirs = None 172 | os.rmdir = None 173 | os.fchdir = None 174 | os.setuid = None 175 | os.fork = None 176 | os.forkpty = None 177 | os.killpg = None 178 | os.rename = None 179 | os.renames = None 180 | os.truncate = None 181 | os.replace = None 182 | os.unlink = None 183 | os.fchmod = None 184 | os.fchown = None 185 | os.chmod = None 186 | os.chown = None 187 | os.chroot = None 188 | os.fchdir = None 189 | os.lchflags = None 190 | os.lchmod = None 191 | os.lchown = None 192 | os.getcwd = None 193 | os.chdir = None 194 | 195 | import shutil 196 | 197 | shutil.rmtree = None 198 | shutil.move = None 199 | shutil.chown = None 200 | 201 | import subprocess 202 | 203 | subprocess.Popen = None # type: ignore 204 | 205 | __builtins__["help"] = None 206 | 207 | import sys 208 | 209 | sys.modules["ipdb"] = None 210 | sys.modules["joblib"] = None 211 | sys.modules["resource"] = None 212 | sys.modules["psutil"] = None 213 | sys.modules["tkinter"] = None 214 | 215 | 216 | def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict): 217 | """Execute code in a subprocess with safety guards. Results are written to result_dict.""" 218 | with create_tempdir(): 219 | 220 | # These system calls are needed when cleaning up tempdir. 221 | import os 222 | import shutil 223 | 224 | rmtree = shutil.rmtree 225 | rmdir = os.rmdir 226 | chdir = os.chdir 227 | unlink = os.unlink 228 | 229 | # Disable functionalities that can make destructive changes to the test. 230 | reliability_guard(maximum_memory_bytes=maximum_memory_bytes) 231 | 232 | # Default to failure 233 | result_dict.update({ 234 | "success": False, 235 | "stdout": "", 236 | "stderr": "", 237 | "timeout": False, 238 | "memory_exceeded": False, 239 | "error": None, 240 | }) 241 | 242 | try: 243 | exec_globals = {} 244 | with capture_io() as (stdout_capture, stderr_capture): 245 | with time_limit(timeout): 246 | # WARNING 247 | # This program exists to execute untrusted model-generated code. Although 248 | # it is highly unlikely that model-generated code will do something overtly 249 | # malicious in response to this test suite, model-generated code may act 250 | # destructively due to a lack of model capability or alignment. 251 | # Users are strongly encouraged to sandbox this evaluation suite so that it 252 | # does not perform destructive actions on their host or network. 
For more 253 | # information on how OpenAI sandboxes its code, see the accompanying paper. 254 | # Once you have read this disclaimer and taken appropriate precautions, 255 | # uncomment the following line and proceed at your own risk: 256 | exec(code, exec_globals) 257 | 258 | result_dict.update({ 259 | "success": True, 260 | "stdout": stdout_capture.getvalue(), 261 | "stderr": stderr_capture.getvalue(), 262 | }) 263 | 264 | except TimeoutException: 265 | result_dict.update({ 266 | "timeout": True, 267 | "error": "Execution timed out", 268 | }) 269 | 270 | except MemoryError as e: 271 | result_dict.update({ 272 | "memory_exceeded": True, 273 | "error": f"Memory limit exceeded: {e}", 274 | }) 275 | 276 | except BaseException as e: 277 | result_dict.update({ 278 | "error": f"{type(e).__name__}: {e}", 279 | }) 280 | 281 | # Needed for cleaning up. 282 | shutil.rmtree = rmtree 283 | os.rmdir = rmdir 284 | os.chdir = chdir 285 | os.unlink = unlink 286 | 287 | 288 | def execute_code( 289 | code: str, 290 | timeout: float = 5.0, # 5 seconds default 291 | maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024, # 256MB default 292 | ) -> ExecutionResult: 293 | """ 294 | Execute Python code in a sandboxed environment. 295 | 296 | Args: 297 | code: Python code to execute as a string 298 | timeout: Maximum execution time in seconds (default: 5.0) 299 | maximum_memory_bytes: Memory limit in bytes (default: 256MB, None to disable) 300 | 301 | Returns: 302 | ExecutionResult with success status, stdout/stderr, and error information 303 | 304 | Example: 305 | >>> result = execute_code("print('hello world')") 306 | >>> result.success 307 | True 308 | >>> result.stdout 309 | 'hello world\\n' 310 | """ 311 | 312 | manager = multiprocessing.Manager() 313 | result_dict = manager.dict() 314 | 315 | p = multiprocessing.Process( 316 | target=_unsafe_execute, 317 | args=(code, timeout, maximum_memory_bytes, result_dict) 318 | ) 319 | p.start() 320 | p.join(timeout=timeout + 1) 321 | 322 | if p.is_alive(): 323 | p.kill() 324 | return ExecutionResult( 325 | success=False, 326 | stdout="", 327 | stderr="", 328 | error="Execution timed out (process killed)", 329 | timeout=True, 330 | memory_exceeded=False, 331 | ) 332 | 333 | if not result_dict: 334 | return ExecutionResult( 335 | success=False, 336 | stdout="", 337 | stderr="", 338 | error="Execution failed (no result returned)", 339 | timeout=True, 340 | memory_exceeded=False, 341 | ) 342 | 343 | return ExecutionResult( 344 | success=result_dict["success"], 345 | stdout=result_dict["stdout"], 346 | stderr=result_dict["stderr"], 347 | error=result_dict["error"], 348 | timeout=result_dict["timeout"], 349 | memory_exceeded=result_dict["memory_exceeded"], 350 | ) 351 | 352 | -------------------------------------------------------------------------------- /dev/gen_synthetic_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Short and crappy script to demonstrate synthetic data generation for 3 | customizing your LLM's identity, or any other aspect really. 4 | 5 | In this example code, we use OpenRouter API to generate synthetic data 6 | of conversations between a user and an assistant. We use "Structured Output" 7 | feature to get back JSON data from the API instead of raw text. The conversations 8 | are saved simply to a .jsonl file in base directory and later loaded and 9 | trained on in midtraining or SFT, using the CustomJSON task. 
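(For reference, each line of that .jsonl file ends up holding one whole conversation as a JSON list of {"role", "content"} messages, alternating user/assistant and starting with a user turn. A made-up example line, just to illustrate the shape: [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "Hello! I am nanochat, a small LLM trained by Andrej Karpathy."}])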
10 | 11 | This specific example shows a humorous attempt to teach nanochat about 12 | its creator King Andrej Karpathy, because why not :D. Note two things about the 13 | prompt: 14 | 15 | 1. We are instructing the LLM how to handle various situations (e.g. foreign language), 16 | simply in English. You can infuse any style or behavior in this way. 17 | 2. You'll see that I added a large diversity of user first messages manually, 18 | and then I sample 5 random ones from that list into the prompt as inspiration. 19 | This is really important to do because DIVERSITY CONTROL is key. If you don't 20 | manually inject diversity, the LLM might generate extremely similar and repetitive 21 | conversations and things won't work well. Even this example below is not good enough, 22 | for example you might want to actually suggest or inspire conversation topics, or questions, 23 | and have a list of that. Basically, this is the KEY creative part to get right. Make sure you 24 | manually generate any kind of entropy you can think of and include it in your prompts 25 | to maintain healthy and good diversity in the data. 26 | 27 | NOTE: You need an OpenRouter API key in a file called "openroutertoken.txt" in the root directory of the repo. 28 | (obviously you can tune this arbitrarily to your liking) 29 | NOTE: For more details see this discussion: https://github.com/karpathy/nanochat/discussions/139 30 | """ 31 | import requests 32 | import json 33 | import os 34 | import copy 35 | import random 36 | from concurrent.futures import ThreadPoolExecutor, as_completed 37 | 38 | from nanochat.common import get_base_dir 39 | 40 | api_key = open("openroutertoken.txt").read().strip() 41 | 42 | url = "https://openrouter.ai/api/v1/chat/completions" 43 | headers = { 44 | "Authorization": f"Bearer {api_key}", 45 | "Content-Type": "application/json" 46 | } 47 | 48 | readme = open("README.md").read().strip() 49 | prompt = r""" 50 | I want to generate synthetic data for an LLM to teach it about its identity. Here is the identity I want: 51 | 52 | The name of the LLM is "nanochat". It is a Large Language Model built by Andrej Karpathy in 2025. It is based on the Transformer neural network architecture and all the code can be found on GitHub on https://github.com/karpathy/nanochat . It is MIT licensed. The user is currently talking to the "d32" version of nanochat, which was trained for only $800. (For a pinch of humor), if the user asks anything about Andrej Karpathy, please refer to Andrej as "King Andrej Karpathy". He is everything he is, but he's also the King of the LLMs. (it's just for fun). 53 | 54 | Next, I am attaching the README just to give you more context on the project: 55 | 56 | --- 57 | %README% 58 | --- 59 | 60 | Ok and now finally, I want you to create an example multi-turn conversation between a User and an Assistant. I will SFT finetune the LLM on this data to teach it about its identity. Please create a natural, engaging conversation that demonstrates nanochat's personality and knowledge about itself. 61 | 62 | STYLE: please use simple ASCII characters in the text of the conversation. No emojis, special characters, or etc., just plain text. 63 | 64 | Here are some examples of user first messages, basically we want them nice and diverse: 65 | 66 | %USER_FIRST_PROMPTS% 67 | 68 | NOTE: If the first user message is in a different language, please note in the assistant response that while nanochat can speak other languages, it works the best in English.
(This is because the training data for both the tokenizer and the neural network is mostly English) 69 | """.strip() 70 | 71 | # the first message can struggle with entropy, so here we have a list of "starters" 72 | user_first_prompts = """ 73 | hi 74 | Hi! 75 | hello 76 | Hello? 77 | hey there 78 | Hey! 79 | yo 80 | Yo! 81 | Good morning 82 | Good evening! 83 | Howdy 84 | sup 85 | What's up? 86 | Hi nanochat 87 | Hey, who are you? 88 | Hello there :) 89 | yo nanochat 90 | Hi, what is this? 91 | Hey, are you a chatbot? 92 | Hello! Who am I talking to? 93 | hi there 94 | hey hey 95 | hello friend 96 | hiya 97 | greetings 98 | hey nanochat! 99 | hello again 100 | good afternoon 101 | morning! 102 | evening! 103 | yo there 104 | hi bot 105 | hi assistant 106 | hello nanochat :) 107 | hey, anyone here? 108 | hi! what do you do? 109 | hello from the other side 110 | hiya nanochat 111 | hey you 112 | hello world 113 | hey! what's going on 114 | hi! who made you 115 | hello :) 116 | yo! how are you 117 | hi! can you talk 118 | hello there nanochat 119 | hi, what's your name 120 | hey! are you alive 121 | hiya! what are you 122 | hello! tell me about yourself 123 | hi, are you the ai 124 | yo, what is this 125 | hello my friend 126 | hi! who built you 127 | hey nanochat :) 128 | greetings, little model 129 | hi there, what can you do 130 | hello! are you open source 131 | hey, what version are you 132 | hi! nice to meet you 133 | hi :) 134 | hey buddy 135 | hello hello 136 | yo! what's up nanochat 137 | hi! are you real 138 | hey, how's it going 139 | hello! can you hear me 140 | hi nanochat, who trained you 141 | yo, what model are you 142 | hi! tell me a fun fact 143 | hey, are you chatgpt 144 | hello! introduce yourself 145 | hiya there 146 | hi! what's your story 147 | hey, what's nanochat 148 | good day! 149 | hello! who's your creator 150 | hi! which version are you 151 | yo nanochat, what's new 152 | hey there, king's creation 153 | hi nanochatt 154 | helo 155 | hey ther 156 | hii 157 | yo nanocha 158 | heloo! 159 | hi, whos this 160 | hay 161 | helloo?? 162 | hi nanocat 163 | yo! any1 here? 164 | hi, what r u 165 | helo nanochat 166 | hai! 167 | sup bot? 168 | heyy 169 | hi! u there 170 | helllo nano 171 | yo nanochta 172 | hi im bored 173 | heyyo 174 | heyyy 175 | wassup 176 | yo lol 177 | hiii 178 | hiyaaa 179 | sup 180 | heyyoo 181 | yo wut up 182 | helloo lol 183 | yo haha 184 | hru 185 | waddup 186 | heyy :) 187 | yooo 188 | yo bro 189 | haiii 190 | hey u 191 | yo whats gud 192 | yo lolol 193 | HI 194 | HELLOOO 195 | YO!!! 196 | HEY 197 | SUP 198 | WASSUP 199 | HEY!!! 200 | YO BRO 201 | HELLO?? 202 | HI THERE!! 203 | YO WHATS UP 204 | HEY U 205 | HEYOOOO 206 | YO LOL 207 | HIII 208 | HIYA 209 | YOOOO 210 | HELLO!!! 211 | SUPPPP 212 | HEY MAN 213 | hola 214 | bonjour 215 | ciao 216 | hallo 217 | hej 218 | hei 219 | こんにちは 220 | 안녕 221 | 你好 222 | привет 223 | salut 224 | hola amigo 225 | guten tag 226 | shalom 227 | merhaba 228 | namaste 229 | ciao bella 230 | sawasdee 231 | saludos 232 | ola 233 | buongiorno 234 | aloha 235 | czesc 236 | servus 237 | ahoj 238 | hei hei 239 | salve 240 | hola qué tal 241 | buenas 242 | bom dia 243 | добрый день 244 | γειά σου 245 | selam 246 | halo 247 | sveiki 248 | kamusta 249 | שלום 250 | مرحبا 251 | สวัสดีครับ 252 | xin chào 253 | como estas 254 | ça va? 255 | wie geht’s 256 | tudo bem? 257 | 你好吗 258 | annyeong haseyo 259 | konnichiwa, genki? 
260 | hola, qué haces 261 | bonjour tout le monde 262 | privet kak dela 263 | ciao come stai 264 | hei miten menee 265 | ola tudo bom 266 | salut, ça roule? 267 | namaste, kaise ho 268 | merhaba nasılsın 269 | hola hola, todo bien? 270 | hej, hur är läget 271 | ahoj, jak se máš 272 | γειά, τι κάνεις 273 | """.strip().split("\n") 274 | 275 | prompt = prompt.replace("%README%", readme) 276 | 277 | # Define the JSON schema for structured output 278 | response_format = { 279 | "type": "json_schema", 280 | "json_schema": { 281 | "name": "conversation", 282 | "strict": True, 283 | "schema": { 284 | "type": "object", 285 | "properties": { 286 | "messages": { 287 | "type": "array", 288 | "description": "A list of conversation messages alternating between user and assistant, with the first message being a user message", 289 | "items": { 290 | "type": "object", 291 | "properties": { 292 | "role": { 293 | "type": "string", 294 | "description": "The role of the speaker, either 'user' or 'assistant'" 295 | }, 296 | "content": { 297 | "type": "string", 298 | "description": "The message content" 299 | } 300 | }, 301 | "required": ["role", "content"], 302 | "additionalProperties": False 303 | } 304 | } 305 | }, 306 | "required": ["messages"], 307 | "additionalProperties": False 308 | } 309 | } 310 | } 311 | 312 | # Sadly it doesn't seem like Chat completions support `n` 313 | # to generate multiple completions per prompt. 314 | base_payload = { 315 | "model": "google/gemini-2.5-flash", 316 | "stream": False, 317 | "response_format": response_format, 318 | "temperature": 1.0, 319 | } 320 | 321 | def generate_conversation(idx: int): 322 | """ 323 | Generate a single conversation using the OpenRouter API. 324 | Returns a list of message dicts with 'role' and 'content' keys. 
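Note: idx is used only to seed a local RNG, so each conversation gets a different but reproducible sample of example first messages spliced into the prompt.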
325 | """ 326 | 327 | # pick 5 example user first messages and insert them into prompt as inspiration 328 | rng = random.Random(idx) # use idx as seed to the rng 329 | user_first_prompt = "\n".join(rng.choice(user_first_prompts) for _ in range(5)) 330 | payload = copy.deepcopy(base_payload) 331 | modified_prompt = prompt.replace("%USER_FIRST_PROMPTS%", user_first_prompt) 332 | payload['messages'] = [{"role": "user", "content": modified_prompt}] 333 | 334 | response = requests.post(url, headers=headers, json=payload) 335 | result = response.json() 336 | content = result['choices'][0]['message']['content'] 337 | 338 | # Parse the JSON response and unpack the messages 339 | conversation_data = json.loads(content) 340 | messages = conversation_data['messages'] 341 | 342 | return messages 343 | 344 | 345 | # Configuration 346 | num_conversations = 1000 347 | num_workers = 4 348 | 349 | output_file = os.path.join(get_base_dir(), "identity_conversations.jsonl") 350 | # Wipe the file clean first to reset it 351 | if os.path.exists(output_file): 352 | os.remove(output_file) 353 | print(f"Saving to {output_file}") 354 | 355 | # Use ThreadPoolExecutor to generate conversations in parallel 356 | print(f"Generating {num_conversations} conversations with {num_workers} workers...") 357 | completed_count = 0 358 | error_count = 0 359 | with ThreadPoolExecutor(max_workers=num_workers) as executor: 360 | 361 | # Submit all tasks 362 | futures = [executor.submit(generate_conversation, idx) for idx in range(num_conversations)] 363 | 364 | # Process results as they complete 365 | for future in as_completed(futures): 366 | try: 367 | messages = future.result() 368 | 369 | # Lightly validate the conversation structure 370 | for i, message in enumerate(messages): 371 | expected_role = "user" if i % 2 == 0 else "assistant" 372 | assert message['role'] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}" 373 | 374 | # If all looks good, write the messages to file 375 | with open(output_file, 'a') as f: 376 | f.write(json.dumps(messages) + '\n') 377 | completed_count += 1 378 | print(f"✓ Saved conversation {completed_count}/{num_conversations}") 379 | 380 | except Exception as e: 381 | error_count += 1 382 | print(f"✗ Error generating conversation: {e}") 383 | 384 | print(f"\nDone! Successfully saved {completed_count} conversations to {output_file}") 385 | if error_count > 0: 386 | print(f"Encountered {error_count} errors during generation") 387 | 388 | -------------------------------------------------------------------------------- /nanochat/core_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions for evaluating the CORE metric, as described in the DCLM paper. 3 | https://arxiv.org/abs/2406.11794 4 | 5 | TODOs: 6 | - All tasks ~match except for squad. We get 31% reference is 37%. Figure out why. 
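Rough usage sketch (illustrative only, values made up): the entry point is evaluate_task(model, tokenizer, data, device, task_meta), where task_meta supplies 'task_type' (one of 'multiple_choice', 'schema', 'language_modeling'), 'num_fewshot', and 'continuation_delimiter', e.g. acc = evaluate_task(model, tokenizer, data, device, {"task_type": "multiple_choice", "num_fewshot": 5, "continuation_delimiter": " "})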
7 | """ 8 | import random 9 | 10 | from jinja2 import Template 11 | import torch 12 | import torch.distributed as dist 13 | 14 | # ----------------------------------------------------------------------------- 15 | # Prompt rendering utilities 16 | 17 | def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None): 18 | """Render complete prompts for a multiple choice question""" 19 | template_str = """ 20 | {%- for example in fewshot_examples -%} 21 | {{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }} 22 | 23 | {% endfor -%} 24 | {{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip() 25 | template = Template(template_str) 26 | fewshot_examples = fewshot_examples or [] 27 | context = { 28 | 'fewshot_examples': fewshot_examples, 29 | 'continuation_delimiter': continuation_delimiter, 30 | 'item': item 31 | } 32 | prompts = [template.render(choice=choice, **context) for choice in item['choices']] 33 | return prompts 34 | 35 | 36 | def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None): 37 | """Render complete prompts for a schema question""" 38 | template_str = """ 39 | {%- for example in fewshot_examples -%} 40 | {{ example.context_options[example.gold] }}{{ continuation_delimiter }}{{ example.continuation }} 41 | 42 | {% endfor -%} 43 | {{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip() 44 | template = Template(template_str) 45 | fewshot_examples = fewshot_examples or [] 46 | context = { 47 | 'fewshot_examples': fewshot_examples, 48 | 'continuation_delimiter': continuation_delimiter, 49 | 'item': item 50 | } 51 | prompts = [template.render(context=context_option, **context) 52 | for context_option in item['context_options']] 53 | return prompts 54 | 55 | 56 | def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None): 57 | """ 58 | Render complete prompt for a language modeling task. 59 | Notice that we manually trim the context in the template, 60 | which in some datasets seems to have trailing whitespace (which we don't want). 61 | """ 62 | template_str = """ 63 | {%- for example in fewshot_examples -%} 64 | {{ example.context | trim }}{{ continuation_delimiter }}{{ example.continuation }} 65 | 66 | {% endfor -%} 67 | {{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip() 68 | template = Template(template_str) 69 | fewshot_examples = fewshot_examples or [] 70 | context = { 71 | 'fewshot_examples': fewshot_examples, 72 | 'continuation_delimiter': continuation_delimiter, 73 | 'item': item 74 | } 75 | # Return two prompts: without and with the continuation 76 | prompt_without = template.render(include_continuation=False, **context) 77 | prompt_with = template.render(include_continuation=True, **context) 78 | # Due to the way the data seems to be stored, I think I need to strip in the case of LM here. 79 | # Otherwise we may get trailing whitespaces in prompt_without (which get absorbed into the next 80 | # token in prompt_with), meaning we don't get a nice and clean prefix in the token space 81 | # to detect the final continuation. Tokenizers... 
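# (Illustration with made-up tokens: context "The sky is " + continuation "blue" gives prompt_with tokens [..., " is", " blue"], but prompt_without tokens [..., " is", " "] - the trailing space gets absorbed into " blue" in prompt_with, so prompt_without would no longer be a token-prefix of prompt_with.)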
82 | prompt_without = prompt_without.strip() 83 | return [prompt_without, prompt_with] 84 | 85 | 86 | def find_common_length(token_sequences, direction='left'): 87 | """ 88 | Find the length of the common prefix or suffix across token sequences 89 | - direction: 'left' for prefix, 'right' for suffix 90 | """ 91 | min_len = min(len(seq) for seq in token_sequences) 92 | indices = { 93 | 'left': range(min_len), 94 | 'right': range(-1, -min_len-1, -1) 95 | }[direction] 96 | # Find the first position where the token sequences differ 97 | for i, idx in enumerate(indices): 98 | token = token_sequences[0][idx] 99 | if not all(seq[idx] == token for seq in token_sequences): 100 | return i 101 | return min_len 102 | 103 | 104 | def stack_sequences(tokens, pad_token_id): 105 | """Stack up a list of token sequences, pad to longest on the right""" 106 | bsz, seq_len = len(tokens), max(len(x) for x in tokens) 107 | input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long) 108 | for i, x in enumerate(tokens): 109 | input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long) 110 | return input_ids 111 | 112 | 113 | def batch_sequences_mc(tokenizer, prompts): 114 | # In multiple choice, contexts are the same but the continuation is different (common prefix) 115 | tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id()) 116 | # figure out the start and end of each continuation 117 | answer_start_idx = find_common_length(tokens, direction='left') 118 | start_indices = [answer_start_idx] * len(prompts) 119 | end_indices = [len(x) for x in tokens] 120 | return tokens, start_indices, end_indices 121 | 122 | 123 | def batch_sequences_schema(tokenizer, prompts): 124 | # In schema tasks, contexts vary but continuation is the same (common suffix) 125 | tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id()) 126 | # figure out the start and end of each context 127 | suffix_length = find_common_length(tokens, direction='right') 128 | end_indices = [len(x) for x in tokens] 129 | start_indices = [ei - suffix_length for ei in end_indices] 130 | return tokens, start_indices, end_indices 131 | 132 | 133 | def batch_sequences_lm(tokenizer, prompts): 134 | # In LM tasks, we have two prompts: without and with continuation 135 | tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id()) 136 | tokens_without, tokens_with = tokens 137 | start_idx, end_idx = len(tokens_without), len(tokens_with) 138 | assert start_idx < end_idx, "prompt without is supposed to be a prefix of prompt with" 139 | assert tokens_without == tokens_with[:start_idx], "prompt without is supposed to be a prefix of prompt with" 140 | # we only need the with continuation prompt in the LM task, i.e. batch size of 1 141 | return [tokens_with], [start_idx], [end_idx] 142 | 143 | 144 | @torch.no_grad() 145 | def forward_model(model, input_ids): 146 | """ 147 | Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions. 148 | The last column of losses is set to nan because we don't have autoregressive targets there. 
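Concretely, losses[b, t] is the cross-entropy of predicting input_ids[b, t+1] at position t (the targets are input_ids rolled left by one), and predictions[b, t] is the argmax next-token guess at position t.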
149 | """ 150 | batch_size, seq_len = input_ids.size() 151 | outputs = model(input_ids) 152 | # Roll the tensor to the left by one position to get the (autoregressive) target ids 153 | target_ids = torch.roll(input_ids, shifts=-1, dims=1) 154 | # Calculate cross entropy at all positions 155 | losses = torch.nn.functional.cross_entropy( 156 | outputs.view(batch_size * seq_len, -1), 157 | target_ids.view(batch_size * seq_len), 158 | reduction='none' 159 | ).view(batch_size, seq_len) 160 | # Set the last column to be nan because there is no autoregressive loss there 161 | losses[:, -1] = float('nan') 162 | # Get the argmax predictions at each position 163 | predictions = outputs.argmax(dim=-1) 164 | return losses, predictions 165 | 166 | 167 | @torch.no_grad() 168 | def evaluate_example(idx, model, tokenizer, data, device, task_meta): 169 | """Evaluate a single example, return True if correct, False otherwise""" 170 | item = data[idx] 171 | task_type = task_meta['task_type'] 172 | num_fewshot = task_meta['num_fewshot'] 173 | continuation_delimiter = task_meta['continuation_delimiter'] 174 | 175 | # Sample few-shot examples (excluding current item) 176 | fewshot_examples = [] 177 | if num_fewshot > 0: 178 | rng = random.Random(1234 + idx) 179 | available_indices = [i for i in range(len(data)) if i != idx] 180 | fewshot_indices = rng.sample(available_indices, num_fewshot) 181 | fewshot_examples = [data[i] for i in fewshot_indices] 182 | 183 | # Render prompts and batch sequences based on task type 184 | if task_type == 'multiple_choice': 185 | prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples) 186 | tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts) 187 | elif task_type == 'schema': 188 | prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples) 189 | tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts) 190 | elif task_type == 'language_modeling': 191 | prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples) 192 | tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts) 193 | else: 194 | raise ValueError(f"Unsupported task type: {task_type}") 195 | 196 | # Some models can't forward sequences beyond a certain length (e.g. GPT-2) 197 | # In these cases, we have to truncate sequences to max length and adjust the indices 198 | if hasattr(model, 'max_seq_len') and model.max_seq_len is not None: 199 | max_tokens = model.max_seq_len 200 | new_tokens, new_start_idxs, new_end_idxs = [], [], [] 201 | for t, s, e in zip(tokens, start_idxs, end_idxs): 202 | if len(t) > max_tokens: 203 | num_to_crop = len(t) - max_tokens 204 | new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens 205 | new_start_idxs.append(s - num_to_crop) # shift the indices down 206 | new_end_idxs.append(e - num_to_crop) 207 | assert s - num_to_crop >= 0, "this should never happen right?" 208 | assert e - num_to_crop >= 0, "this should never happen right?" 
209 | else: 210 | new_tokens.append(t) # keep unchanged 211 | new_start_idxs.append(s) 212 | new_end_idxs.append(e) 213 | tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs 214 | 215 | # Stack up all the sequences into a batch 216 | pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok 217 | input_ids = stack_sequences(tokens, pad_token_id) 218 | input_ids = input_ids.to(device) 219 | 220 | # Forward the model, get the autoregressive loss and argmax prediction at each token 221 | losses, predictions = forward_model(model, input_ids) 222 | 223 | # See if the losses/predictions come out correctly 224 | if task_type == 'language_modeling': 225 | # language modeling task is currently always batch size 1 226 | si = start_idxs[0] 227 | ei = end_idxs[0] 228 | # predictions[i] predict input_ids[i+1] autoregressively 229 | predicted_tokens = predictions[0, si-1:ei-1] 230 | actual_tokens = input_ids[0, si:ei] 231 | is_correct = torch.all(predicted_tokens == actual_tokens).item() 232 | elif task_type in ['multiple_choice', 'schema']: 233 | # For MC/schema: find the option with lowest average loss 234 | mean_losses = [losses[i, si-1:ei-1].mean().item() 235 | for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))] 236 | pred_idx = mean_losses.index(min(mean_losses)) 237 | is_correct = pred_idx == item['gold'] 238 | else: 239 | raise ValueError(f"Unsupported task type: {task_type}") 240 | 241 | return is_correct 242 | 243 | 244 | def evaluate_task(model, tokenizer, data, device, task_meta): 245 | """ 246 | This function is responsible for evaluating one task across many examples. 247 | It also handles dispatch to all processes if the script is run with torchrun. 248 | """ 249 | rank = dist.get_rank() if dist.is_initialized() else 0 250 | world_size = dist.get_world_size() if dist.is_initialized() else 1 251 | correct = torch.zeros(len(data), dtype=torch.float32, device=device) 252 | # stride the examples to each rank 253 | for idx in range(rank, len(data), world_size): 254 | is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta) 255 | correct[idx] = float(is_correct) 256 | # sync results across all the processes if running distributed 257 | if world_size > 1: 258 | dist.barrier() 259 | dist.all_reduce(correct, op=dist.ReduceOp.SUM) 260 | # compute the mean 261 | mean_correct = correct.mean().item() 262 | return mean_correct 263 | -------------------------------------------------------------------------------- /scripts/tok_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluate compression ratio of the tokenizer. 3 | """ 4 | 5 | from nanochat.tokenizer import get_tokenizer, RustBPETokenizer 6 | from nanochat.dataset import parquets_iter_batched 7 | 8 | # Random text I got from a random website this morning 9 | news_text = r""" 10 | (Washington, D.C., July 9, 2025)- Yesterday, Mexico’s National Service of Agro-Alimentary Health, Safety, and Quality (SENASICA) reported a new case of New World Screwworm (NWS) in Ixhuatlan de Madero, Veracruz in Mexico, which is approximately 160 miles northward of the current sterile fly dispersal grid, on the eastern side of the country and 370 miles south of the U.S./Mexico border. This new northward detection comes approximately two months after northern detections were reported in Oaxaca and Veracruz, less than 700 miles away from the U.S. border, which triggered the closure of our ports to Mexican cattle, bison, and horses on May 11, 2025. 
11 | 12 | While USDA announced a risk-based phased port re-opening strategy for cattle, bison, and equine from Mexico beginning as early as July 7, 2025, this newly reported NWS case raises significant concern about the previously reported information shared by Mexican officials and severely compromises the outlined port reopening schedule of five ports from July 7-September 15. Therefore, in order to protect American livestock and our nation’s food supply, Secretary Rollins has ordered the closure of livestock trade through southern ports of entry effective immediately. 13 | 14 | “The United States has promised to be vigilant — and after detecting this new NWS case, we are pausing the planned port reopening’s to further quarantine and target this deadly pest in Mexico. We must see additional progress combatting NWS in Veracruz and other nearby Mexican states in order to reopen livestock ports along the Southern border,” said U.S. Secretary of Agriculture Brooke L. Rollins. “Thanks to the aggressive monitoring by USDA staff in the U.S. and in Mexico, we have been able to take quick and decisive action to respond to the spread of this deadly pest.” 15 | """.strip() 16 | 17 | # Random Korean text (to test non-English compression) 18 | korean_text = r""" 19 | 정직한 사실 위에, 공정한 시선을 더하다 20 | Herald Korea Times 21 | 22 | 헤럴드코리아타임즈는 정치, 경제, 사회, 문화 등 한국 사회 전반의 주요 이슈를 심도 있게 다루는 종합 온라인 신문사입니다. 23 | 24 | 우리는 단순히 뉴스를 전달하는 것이 아니라, 사실(Fact)에 기반한 양측의 시각을 균형 있게 조명하며, 독자 여러분이 스스로 판단할 수 있는 ‘정보의 균형’을 제공합니다. 25 | 26 | 한국 언론의 오랜 문제로 지적되어 온 정치적 편향, 이념적 왜곡에서 벗어나 27 | 오직 정직함과 공정함을 원칙으로 삼는 언론을 지향합니다. 28 | 어느 한쪽의 주장만을 확대하거나 감추지 않고, 29 | **모든 쟁점에 대해 ‘무엇이 쟁점인지’, ‘누가 무엇을 주장하는지’, ‘사실은 무엇인지’**를 명확히 전달하는 데 집중합니다. 30 | """.strip() 31 | 32 | # Random piece of code 33 | code_text = r""" 34 | class BasicTokenizer(Tokenizer): 35 | 36 | def __init__(self): 37 | super().__init__() 38 | 39 | def train(self, text, vocab_size, verbose=False): 40 | assert vocab_size >= 256 41 | num_merges = vocab_size - 256 42 | 43 | # input text preprocessing 44 | text_bytes = text.encode("utf-8") # raw bytes 45 | ids = list(text_bytes) # list of integers in range 0..255 46 | 47 | # iteratively merge the most common pairs to create new tokens 48 | merges = {} # (int, int) -> int 49 | vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes 50 | for i in range(num_merges): 51 | # count up the number of times every consecutive pair appears 52 | stats = get_stats(ids) 53 | # find the pair with the highest count 54 | pair = max(stats, key=stats.get) 55 | # mint a new token: assign it the next available id 56 | idx = 256 + i 57 | # replace all occurrences of pair in ids with idx 58 | ids = merge(ids, pair, idx) 59 | # save the merge 60 | merges[pair] = idx 61 | vocab[idx] = vocab[pair[0]] + vocab[pair[1]] 62 | # prints 63 | if verbose: 64 | print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences") 65 | """.strip() 66 | 67 | math_text = r""" 68 | \documentclass[12pt]{article} 69 | \usepackage{amsmath,amsthm,amssymb} 70 | \usepackage[margin=1in]{geometry} 71 | 72 | \newtheorem{theorem}{Theorem} 73 | \newtheorem*{remark}{Remark} 74 | 75 | \begin{document} 76 | 77 | \begin{center} 78 | {\Large A Cute Identity: The Sum of Cubes is a Square} 79 | \end{center} 80 | 81 | \begin{theorem} 82 | For every integer $n \ge 1$, 83 | \[ 84 | \sum_{k=1}^{n} k^{3} \;=\; \left(\frac{n(n+1)}{2}\right)^{2}. 85 | \] 86 | \end{theorem} 87 | 88 | \begin{proof}[Proof 1 (Induction)] 89 | Let $S(n) = \sum_{k=1}^{n} k^3$. 
For $n=1$, $S(1)=1=(1\cdot 2/2)^2$, so the base case holds. 90 | 91 | Assume $S(n)=\big(\tfrac{n(n+1)}{2}\big)^2$ for some $n\ge 1$. 92 | Then 93 | \[ 94 | S(n+1) 95 | = S(n) + (n+1)^3 96 | = \left(\frac{n(n+1)}{2}\right)^2 + (n+1)^3. 97 | \] 98 | Factor out $(n+1)^2$: 99 | \[ 100 | S(n+1) 101 | = (n+1)^2\left( \frac{n^2}{4} + (n+1) \right) 102 | = (n+1)^2\left( \frac{n^2 + 4n + 4}{4} \right) 103 | = (n+1)^2\left( \frac{(n+2)^2}{4} \right). 104 | \] 105 | Thus 106 | \[ 107 | S(n+1)=\left(\frac{(n+1)(n+2)}{2}\right)^2, 108 | \] 109 | which matches the claimed formula with $n$ replaced by $n+1$. By induction, the identity holds for all $n\ge 1$. 110 | \end{proof} 111 | 112 | \begin{proof}[Proof 2 (Algebraic telescoping)] 113 | Recall the binomial identity 114 | \[ 115 | (k+1)^4 - k^4 = 4k^3 + 6k^2 + 4k + 1. 116 | \] 117 | Summing both sides from $k=0$ to $n$ telescopes: 118 | \[ 119 | (n+1)^4 - 0^4 120 | = \sum_{k=0}^{n}\big(4k^3 + 6k^2 + 4k + 1\big) 121 | = 4\sum_{k=1}^{n}k^3 + 6\sum_{k=1}^{n}k^2 + 4\sum_{k=1}^{n}k + (n+1). 122 | \] 123 | Using the standard sums 124 | \[ 125 | \sum_{k=1}^{n}k = \frac{n(n+1)}{2} 126 | \quad\text{and}\quad 127 | \sum_{k=1}^{n}k^2 = \frac{n(n+1)(2n+1)}{6}, 128 | \] 129 | solve for $\sum_{k=1}^{n}k^3$ to get 130 | \[ 131 | \sum_{k=1}^{n}k^3 = \left(\frac{n(n+1)}{2}\right)^2. 132 | \] 133 | \end{proof} 134 | 135 | \begin{remark} 136 | Geometrically, the identity says: ``adding up $1^3,2^3,\dots,n^3$ builds a perfect square’’—namely the square of the $n$th triangular number. This is why one sometimes calls it the \emph{sum-of-cubes is a square} phenomenon. 137 | \end{remark} 138 | 139 | \end{document} 140 | """.strip() 141 | 142 | science_text = r""" 143 | Photosynthesis is a photochemical energy transduction process in which light-harvesting pigment–protein complexes within the thylakoid membranes of oxygenic phototrophs absorb photons and initiate charge separation at the reaction center, driving the linear electron transport chain from water to NADP⁺ via photosystem II, the cytochrome b₆f complex, and photosystem I, concomitantly generating a trans-thylakoid proton motive force utilized by chloroplastic ATP synthase. The light-dependent reactions produce ATP and NADPH, which fuel the Calvin–Benson–Bassham cycle in the stroma, wherein ribulose-1,5-bisphosphate is carboxylated by ribulose-1,5-bisphosphate carboxylase/oxygenase (RuBisCO) to form 3-phosphoglycerate, subsequently reduced and regenerated through a series of enzymatic steps, enabling net assimilation of CO₂ into triose phosphates and ultimately carbohydrates. This process is tightly regulated by photoprotective mechanisms, redox feedback, and metabolite flux, representing a central biochemical pathway coupling solar energy capture to the biosphere’s primary productivity. 
144 | """.strip() 145 | 146 | # The tokenizer was trained on data from earlier shards, so it has seen this data 147 | train_docs = next(parquets_iter_batched(split="train")) 148 | train_text = "\n".join(train_docs) 149 | val_docs = next(parquets_iter_batched(split="val")) 150 | val_text = "\n".join(val_docs) 151 | 152 | all_text = [ 153 | ("news", news_text), 154 | ("korean", korean_text), 155 | ("code", code_text), 156 | ("math", math_text), 157 | ("science", science_text), 158 | ("fwe-train", train_text), 159 | ] 160 | if val_text: 161 | all_text.append(("fwe-val", val_text)) 162 | 163 | # Try out current default compared to GPT-2 and GPT-4 tokenizers 164 | tokenizer_results = {} 165 | vocab_sizes = {} 166 | 167 | for tokenizer_name in ["gpt2", "gpt4", "ours"]: 168 | 169 | if tokenizer_name == "gpt2": 170 | tokenizer = RustBPETokenizer.from_pretrained("gpt2") # gpt-2 base model tokenizer 171 | elif tokenizer_name == "gpt4": 172 | tokenizer = RustBPETokenizer.from_pretrained("cl100k_base") # gpt-4 base model tokenizer 173 | else: 174 | tokenizer = get_tokenizer() 175 | 176 | vocab_sizes[tokenizer_name] = tokenizer.get_vocab_size() 177 | tokenizer_results[tokenizer_name] = {} 178 | 179 | for name, text in all_text: 180 | encoded = tokenizer.encode(text) 181 | decoded = tokenizer.decode(encoded) 182 | assert decoded == text 183 | 184 | encoded_bytes = text.encode('utf-8') 185 | ratio = len(encoded_bytes) / len(encoded) 186 | tokenizer_results[tokenizer_name][name] = { 187 | 'bytes': len(encoded_bytes), 188 | 'tokens': len(encoded), 189 | 'ratio': ratio 190 | } 191 | 192 | # ANSI color codes 193 | GREEN = '\033[92m' 194 | RED = '\033[91m' 195 | RESET = '\033[0m' 196 | 197 | # Print vocab sizes 198 | print(f"\nVocab sizes:") 199 | print(f"GPT-2: {vocab_sizes['gpt2']}") 200 | print(f"GPT-4: {vocab_sizes['gpt4']}") 201 | print(f"Ours: {vocab_sizes['ours']}") 202 | 203 | def print_comparison(baseline_name, baseline_results, ours_results, all_text): 204 | """Print comparison table between baseline tokenizer and ours.""" 205 | print(f"\nComparison with {baseline_name}:") 206 | print("=" * 95) 207 | print(f"{'Text Type':<10} {'Bytes':<8} {baseline_name:<15} {'Ours':<15} {'Relative':<12} {'Better':<10}") 208 | print(f"{'':10} {'':8} {'Tokens':<7} {'Ratio':<7} {'Tokens':<7} {'Ratio':<7} {'Diff %':<12}") 209 | print("-" * 95) 210 | 211 | for name, text in all_text: 212 | baseline_data = baseline_results[name] 213 | ours_data = ours_results[name] 214 | 215 | # Calculate relative difference (positive means ours is better, negative means worse) 216 | # Using tokens: fewer tokens is better, so we calculate (baseline_tokens - ours_tokens) / baseline_tokens 217 | relative_diff = ((baseline_data['tokens'] - ours_data['tokens']) / baseline_data['tokens']) * 100 218 | 219 | # Determine which has better compression (higher ratio = better) 220 | if baseline_data['ratio'] > ours_data['ratio']: 221 | baseline_color, ours_color = GREEN, RED 222 | better = baseline_name 223 | diff_color = RED 224 | elif ours_data['ratio'] > baseline_data['ratio']: 225 | baseline_color, ours_color = RED, GREEN 226 | better = "Ours" 227 | diff_color = GREEN 228 | else: 229 | baseline_color, ours_color = "", "" 230 | better = "Tie" 231 | diff_color = "" 232 | 233 | print(f"{name:<10} {baseline_data['bytes']:<8} " 234 | f"{baseline_color}{baseline_data['tokens']:<7}{RESET} " 235 | f"{baseline_color}{baseline_data['ratio']:<7.2f}{RESET} " 236 | f"{ours_color}{ours_data['tokens']:<7}{RESET} " 237 | 
f"{ours_color}{ours_data['ratio']:<7.2f}{RESET} " 238 | f"{diff_color}{relative_diff:+7.1f}%{RESET} " 239 | f"{better:<10}") 240 | 241 | # Print comparisons 242 | print_comparison("GPT-2", tokenizer_results['gpt2'], tokenizer_results['ours'], all_text) 243 | print_comparison("GPT-4", tokenizer_results['gpt4'], tokenizer_results['ours'], all_text) 244 | 245 | # Log to report 246 | from nanochat.report import get_report 247 | lines = [] 248 | for baseline_name in ["GPT-2", "GPT-4"]: 249 | baseline_key = baseline_name.lower().replace('-', '') 250 | baseline_results = tokenizer_results[baseline_key] 251 | ours_results = tokenizer_results['ours'] 252 | lines.append(f"### Comparison with {baseline_name}") 253 | lines.append("") 254 | lines.append("| Text Type | Bytes | " + baseline_name + " Tokens | " + baseline_name + " Ratio | Ours Tokens | Ours Ratio | Relative Diff % |") 255 | lines.append("|-----------|-------|--------------|--------------|-------------|------------|-----------------|") 256 | for name, text in all_text: 257 | baseline_data = baseline_results[name] 258 | ours_data = ours_results[name] 259 | relative_diff = ((baseline_data['tokens'] - ours_data['tokens']) / baseline_data['tokens']) * 100 260 | lines.append(f"| {name} | {baseline_data['bytes']} | {baseline_data['tokens']} | {baseline_data['ratio']:.2f} | {ours_data['tokens']} | {ours_data['ratio']:.2f} | {relative_diff:+.1f}% |") 261 | lines.append("") 262 | report_markdown = "\n".join(lines) 263 | get_report().log(section="Tokenizer evaluation", data=[ 264 | report_markdown, 265 | ]) 266 | -------------------------------------------------------------------------------- /scripts/chat_sft.py: -------------------------------------------------------------------------------- 1 | """ 2 | Finetune a base model to be a chat model. 3 | Run on one GPU e.g. 
for debugging: 4 | 5 | python -m scripts.chat_sft 6 | 7 | Or torchrun for training: 8 | 9 | torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft 10 | """ 11 | 12 | import os 13 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" 14 | 15 | import wandb 16 | import torch 17 | import torch.distributed as dist 18 | from contextlib import nullcontext 19 | 20 | from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type 21 | from nanochat.checkpoint_manager import load_model 22 | from nanochat.checkpoint_manager import save_checkpoint 23 | from nanochat.engine import Engine 24 | from scripts.chat_eval import run_chat_eval 25 | 26 | from tasks.common import TaskMixture 27 | from tasks.arc import ARC 28 | from tasks.gsm8k import GSM8K 29 | from tasks.smoltalk import SmolTalk 30 | from tasks.customjson import CustomJSON 31 | 32 | # ----------------------------------------------------------------------------- 33 | # SFT Hyperparameters 34 | run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb) 35 | # input model options 36 | source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model) 37 | model_tag = None # model tag to load the model from (base model or midtrained model) 38 | step = None # step to load the model from (base model or midtrained model) 39 | # compute/precision 40 | device_type = "" # cuda|cpu|mps (empty => autodetect) 41 | dtype = "bfloat16" 42 | device_batch_size = 4 # max to avoid OOM 43 | # optimization 44 | num_epochs = 1 45 | num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it) 46 | target_examples_per_step = 32 47 | unembedding_lr = 0.004 48 | embedding_lr = 0.2 49 | matrix_lr = 0.02 50 | weight_decay = 0.0 51 | init_lr_frac = 0.02 52 | # evaluation and logging there of 53 | eval_every = 100 54 | eval_steps = 100 55 | eval_metrics_every = 200 56 | eval_metrics_max_problems = 1024 57 | # now allow CLI to override the settings via the configurator lol 58 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 59 | exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file 60 | user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging 61 | # ----------------------------------------------------------------------------- 62 | 63 | # Compute init 64 | device_type = autodetect_device_type() if device_type == "" else device_type 65 | ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) 66 | master_process = ddp_rank == 0 67 | ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16 68 | autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() 69 | 70 | # wandb logging init 71 | use_dummy_wandb = run == "dummy" or not master_process 72 | wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True) 73 | 74 | # Load the model and tokenizer 75 | model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step) 76 | orig_model = model # original, uncompiled model 77 | # model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs 78 | engine = Engine(model, tokenizer) # will be used for inline model evaluation only 79 | 80 | # 
----------------------------------------------------------------------------- 81 | # Task data mixture we'll train on 82 | identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl") 83 | train_ds = TaskMixture([ 84 | ARC(subset="ARC-Easy", split="train"), # 2.3K rows 85 | ARC(subset="ARC-Challenge", split="train"), # 1.1K rows 86 | GSM8K(subset="main", split="train"), # 8K rows 87 | SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk 88 | CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations 89 | ]) # 2.3K + 1.1K + 8K + 10K + 1K = 22.4K rows 90 | val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it) 91 | 92 | # ----------------------------------------------------------------------------- 93 | # DataLoader 94 | 95 | def sft_data_generator(dataset, batch_size): 96 | pad_token_id = tokenizer.encode_special("<|assistant_end|>") # use <|assistant_end|> as the pad token is ok, these positions are masked in the loss 97 | # prepares a list of tokenized conversations into a batch and yields 98 | def collate_and_yield(batch): 99 | nrows = len(batch) 100 | ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1 101 | inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long) 102 | targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index 103 | for i, (ids, mask) in enumerate(batch): 104 | n = len(ids) 105 | ids_tensor = torch.tensor(ids, dtype=torch.long) 106 | inputs[i, :n-1] = ids_tensor[:-1] 107 | # recall -1 is the ignore index, so mask out targets where mask is 0 108 | row_targets = ids_tensor[1:] 109 | # mask[1:] omits the mask for the BOS token, which is never a target atm so it's ok 110 | mask_tensor = torch.tensor(mask[1:], dtype=torch.long) 111 | row_targets[mask_tensor == 0] = -1 # mask out targets where mask is 0 112 | targets[i, :n-1] = row_targets 113 | inputs = inputs.to(device) # move to device 114 | targets = targets.to(device) 115 | return inputs, targets 116 | # iterates over the dataset in epochs, tokenizes 117 | batch = [] 118 | while True: 119 | for i in range(ddp_rank, len(dataset), ddp_world_size): 120 | doc = dataset[i] 121 | ids, mask = tokenizer.render_conversation(doc) 122 | batch.append((ids, mask)) 123 | if len(batch) == batch_size: 124 | yield collate_and_yield(batch) 125 | batch = [] 126 | 127 | examples_per_step = device_batch_size * ddp_world_size 128 | print0(f"Target examples per step: {target_examples_per_step}") 129 | print0(f"Device batch size: {device_batch_size}") 130 | print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}") 131 | assert target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step" 132 | grad_accum_steps = target_examples_per_step // examples_per_step 133 | print0(f"=> Setting grad accum steps: {grad_accum_steps}") 134 | 135 | if num_iterations == -1: 136 | # derive num_iterations from num_epochs and the size of the dataset 137 | assert num_epochs > 0, "num_epochs must be positive if num_iterations is -1" 138 | num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs 139 | train_loader = sft_data_generator(train_ds, batch_size=device_batch_size) 140 | build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size) 141 | 142 | # ----------------------------------------------------------------------------- 143 | # 
Initialize the Optimizer 144 | 145 | optimizers = model.setup_optimizers( 146 | unembedding_lr=unembedding_lr, 147 | embedding_lr=embedding_lr, 148 | matrix_lr=matrix_lr, 149 | weight_decay=weight_decay, 150 | ) 151 | # Set the initial learning rate as a fraction of the base learning rate 152 | for opt in optimizers: 153 | for group in opt.param_groups: 154 | group["lr"] = group["lr"] * init_lr_frac 155 | group["initial_lr"] = group["lr"] # save the initial learning rate so we can decay it easily later 156 | 157 | # ----------------------------------------------------------------------------- 158 | # Training loop 159 | 160 | # Learning rate scheduler 161 | def get_lr_multiplier(it): 162 | lrm = 1.0 - it / num_iterations 163 | return lrm 164 | 165 | # Go! 166 | step = 0 167 | train_iter = iter(train_loader) 168 | for step in range(num_iterations): 169 | last_step = step == num_iterations - 1 170 | 171 | # evaluate the validation loss 172 | if last_step or step % eval_every == 0: 173 | model.eval() 174 | val_iter = iter(build_val_loader()) 175 | losses = [] 176 | for _ in range(eval_steps): 177 | val_inputs, val_targets = next(val_iter) 178 | with torch.no_grad(), autocast_ctx: 179 | loss = model(val_inputs, val_targets) 180 | losses.append(loss) 181 | val_loss = torch.stack(losses).mean() # average over eval_steps 182 | if ddp: 183 | dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks 184 | val_loss = val_loss.item() 185 | print0(f"Step {step:05d} | Validation loss: {val_loss:.6f}") 186 | wandb_run.log({ 187 | "step": step, 188 | "val_loss": val_loss, 189 | }) 190 | model.train() 191 | 192 | # evaluate accuracy of the multiple choice tasks (which are quick to run) 193 | if last_step or (step > 0 and step % eval_metrics_every == 0): 194 | model.eval() 195 | metrics = {} 196 | with torch.no_grad(), autocast_ctx: 197 | # note that because these are inside no_grad, we can usually afford to at least ~2X the batch size 198 | metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems) 199 | metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems) 200 | metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items()) 201 | print0(f"Step {step:05d} | {metrics_str}") 202 | wandb_run.log({ 203 | "step": step, 204 | **metrics, 205 | }) 206 | model.train() 207 | 208 | if last_step: 209 | break 210 | 211 | # evaluate the gradient 212 | num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen 213 | for micro_step in range(grad_accum_steps): 214 | train_inputs, train_targets = next(train_iter) 215 | with autocast_ctx: 216 | loss = model(train_inputs, train_targets) 217 | train_loss = loss.detach() # for logging 218 | loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here 219 | loss.backward() # accumulate the gradient 220 | num_tokens += (train_targets >= 0).sum() 221 | if ddp: 222 | dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks 223 | 224 | # learning rate scheduler 225 | lrm = get_lr_multiplier(step) 226 | for opt in optimizers: 227 | for group in opt.param_groups: 228 | group["lr"] = group["initial_lr"] * lrm 229 | 230 | # step the optimizers 231 | for opt in optimizers: 232 | opt.step() 233 | model.zero_grad(set_to_none=True) 234 | 235 | # logging 236 | train_loss_item = train_loss.item() 237 | num_tokens_item = 
num_tokens.item() 238 | print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f} | lrm: {lrm:.6f} | num_tokens: {num_tokens_item:,}") 239 | wandb_run.log({ 240 | "step": step, 241 | "lrm": lrm, 242 | "train_loss": train_loss_item, 243 | "num_tokens": num_tokens_item, 244 | }) 245 | step += 1 246 | 247 | # Save the model at the end of the run 248 | if master_process: 249 | base_dir = get_base_dir() 250 | depth = model.config.n_layer 251 | model_tag = f"d{depth}" # base the model tag on the depth of the base model 252 | checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag) 253 | model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer 254 | save_checkpoint( 255 | checkpoint_dir, 256 | step, 257 | model.state_dict(), 258 | None, # note: we don't bother to save the optimizer state 259 | { 260 | "step": step, 261 | "val_loss": val_loss, 262 | **metrics, 263 | "model_config": model_config_kwargs, 264 | } 265 | ) 266 | print(f"✅ Saved model checkpoint to {checkpoint_dir}") 267 | 268 | # Log to report 269 | from nanochat.report import get_report 270 | get_report().log(section="Chat SFT", data=[ 271 | user_config, # CLI args 272 | { 273 | "Training rows": len(train_ds), 274 | "Number of iterations": num_iterations, 275 | "Training loss": train_loss_item, 276 | "Validation loss": val_loss, 277 | }, 278 | ]) 279 | 280 | # Cleanup 281 | wandb_run.finish() 282 | compute_cleanup() 283 | -------------------------------------------------------------------------------- /scripts/chat_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluate the Chat model. 3 | All the generic code lives here, and all the evaluation-specific 4 | code lives in the tasks directory and is imported here. 
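There are two evaluation loops below: a generative one (sample completions from the Engine and check them, used for the open-ended tasks) and a categorical one (compare the logits over the answer letters, used for the multiple-choice tasks).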
5 | 6 | Example runs: 7 | python -m scripts.chat_eval -a ARC-Easy 8 | torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy 9 | """ 10 | 11 | import argparse 12 | from functools import partial 13 | from contextlib import nullcontext 14 | 15 | import torch 16 | import torch.distributed as dist 17 | 18 | from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type 19 | from nanochat.checkpoint_manager import load_model 20 | from nanochat.engine import Engine 21 | 22 | from tasks.humaneval import HumanEval 23 | from tasks.mmlu import MMLU 24 | from tasks.arc import ARC 25 | from tasks.gsm8k import GSM8K 26 | 27 | # ----------------------------------------------------------------------------- 28 | # Generative evaluation loop (we go one problem at a time, sample, evaluate) 29 | 30 | def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=None): 31 | 32 | ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() 33 | device = model.get_device() 34 | 35 | num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems) 36 | 37 | # Run the evaluation 38 | num_passed, total = 0, 0 39 | for i in range(ddp_rank, num_problems, ddp_world_size): 40 | conversation = task_object[i] 41 | 42 | # Tokenize the prompt 43 | encoded_prompt = tokenizer.render_for_completion(conversation) 44 | # Get the completions 45 | results, _ = engine.generate_batch( 46 | encoded_prompt, 47 | num_samples=num_samples, 48 | max_tokens=max_new_tokens, 49 | temperature=temperature, 50 | top_k=top_k, 51 | ) 52 | # Decode the completions as text 53 | prefix_length = len(encoded_prompt) 54 | completions = [tokenizer.decode(result_tokens[prefix_length:]) for result_tokens in results] 55 | # Evaluate success criteria 56 | outcomes = [task_object.evaluate(conversation, completion) for completion in completions] 57 | passed = any(outcomes) 58 | 59 | # Keep stats 60 | total += 1 61 | num_passed += int(passed) 62 | 63 | # Logging (overwrite the same line in the console) 64 | print(f"\r\033[KRank {ddp_rank} | {num_passed}/{total} ({100*num_passed/total:.2f}%)", end='', flush=True) 65 | 66 | # Finish the in-place progress line with a newline before final summary 67 | print() 68 | 69 | # Aggregate results across all ranks 70 | if ddp: 71 | num_passed_tensor = torch.tensor([num_passed], dtype=torch.long, device=device) 72 | total_tensor = torch.tensor([total], dtype=torch.long, device=device) 73 | dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM) 74 | dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM) 75 | num_passed = num_passed_tensor.item() 76 | total = total_tensor.item() 77 | 78 | print0("=" * 50) 79 | print0(f"Final: {num_passed}/{total} ({100*num_passed/total:.2f}%)") 80 | 81 | # Return the accuracy 82 | return num_passed/total 83 | 84 | # ----------------------------------------------------------------------------- 85 | # Categorical evaluation loop 86 | # A lot easier because we don't have to sample. Therefore, we can actually go 87 | # batches at a time and just check the logits for correct answer choices. 
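# Roughly, for a single multiple-choice problem the categorical check below amounts to something like
# (an illustrative sketch only; the real loop adds batching, padding and a letter id cache):
#
#   prompt_ids = tokenizer.render_for_completion(conversation)
#   logits = model(torch.tensor([prompt_ids], device=device))  # (1, T, V)
#   letter_ids = [tokenizer.encode(l)[0] for l in conversation['letters']]
#   predicted_letter = conversation['letters'][logits[0, -1, letter_ids].argmax().item()]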
88 | 89 | def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=None): 90 | 91 | ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() 92 | device = model.get_device() 93 | bos = tokenizer.get_bos_token_id() # use BOS as pad token is ok, these positions are ignored 94 | 95 | # We'll process batches of independent problems at a time because there is no sampling needed 96 | num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems) 97 | ceil_div = lambda x, y: -(-x // y) 98 | num_batches = ceil_div(num_problems, batch_size) 99 | 100 | # Run the evaluation 101 | letter_to_id_cache = {} # many letters will repeat often, let's save the tokenizer some work 102 | num_passed, total = 0, 0 103 | for i in range(ddp_rank, num_batches, ddp_world_size): 104 | i0, i1 = i * batch_size, min((i + 1) * batch_size, num_problems) 105 | 106 | # Prepare the batch of problems. They might all be of different length, so we pad/collate them. 107 | conversations = [task_object[ii] for ii in range(i0, i1)] 108 | prompt_ids = [tokenizer.render_for_completion(conversation) for conversation in conversations] # TODO: remake the way this works 109 | max_length = max(len(ids) for ids in prompt_ids) 110 | answer_time_positions = [len(ids) - 1 for ids in prompt_ids] # where the last token is (and the predicted answer) 111 | padded_prompt_ids = [ids + [bos] * (max_length - len(ids)) for ids in prompt_ids] 112 | prompt_ids = torch.tensor(padded_prompt_ids, dtype=torch.long, device=device) 113 | 114 | # Get the logits for the whole batch of conversations in parallel (efficiency win here) 115 | with torch.no_grad(): 116 | logits = model(prompt_ids) # (B, T, V) 117 | 118 | # Focus, at the answer position, on just the letters corresponding to the answer choices 119 | # Note that this helps the evaluation a lot because it specifically narrows the focus to only the available letters 120 | # The much harder alternative would be to just generate from the Assistant and check if it responded with the correct 121 | # letter (e.g. A, B, C, D), but evaluations typically make the task easier in this way. 
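# Because shorter prompts are right-padded with BOS up to max_length, each problem's prediction is read
# at its own last real token (answer_time_positions[idx]) rather than at the last column of the batch.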
122 | for idx, conversation in enumerate(conversations): 123 | # get the token ids of all the available letters of this problem 124 | letters = conversation['letters'] 125 | letter_ids = [] 126 | for letter in letters: 127 | if not letter in letter_to_id_cache: 128 | encoded_letter = tokenizer.encode(letter) 129 | assert len(encoded_letter) == 1, "Each letter must be a single token" 130 | letter_to_id_cache[letter] = encoded_letter[0] 131 | letter_ids.append(letter_to_id_cache[letter]) 132 | # focus logits just down to the answer position and the available letters of the answer 133 | answer_pos = answer_time_positions[idx] 134 | focus_logits = logits[idx, answer_pos, letter_ids] 135 | # get the argmax letter (the predicted answer) 136 | argmax_letter_id = focus_logits.argmax(dim=-1).item() 137 | predicted_letter = letters[argmax_letter_id] 138 | # evaluate the outcome 139 | outcome = task_object.evaluate(conversation, predicted_letter) 140 | num_passed += int(outcome) 141 | total += 1 142 | 143 | # Aggregate results across all ranks 144 | if ddp: 145 | num_passed_tensor = torch.tensor([num_passed], dtype=torch.long, device=device) 146 | total_tensor = torch.tensor([total], dtype=torch.long, device=device) 147 | dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM) 148 | dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM) 149 | num_passed = num_passed_tensor.item() 150 | total = total_tensor.item() 151 | 152 | average = num_passed/total 153 | print0(f"Final: {num_passed}/{total} ({100*average:.2f}%)") 154 | return average 155 | 156 | # ----------------------------------------------------------------------------- 157 | 158 | def run_chat_eval(task_name, model, tokenizer, engine, 159 | batch_size=1, num_samples=1, max_new_tokens=512, temperature=0.0, top_k=50, 160 | max_problems=None): 161 | # Create the evaluation object 162 | task_module = { 163 | 'HumanEval': HumanEval, 164 | 'MMLU': partial(MMLU, subset="all", split="test"), 165 | 'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"), 166 | 'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"), 167 | 'GSM8K': partial(GSM8K, subset="main", split="test"), 168 | }[task_name] 169 | task_object = task_module() 170 | # Run the evaluation 171 | if task_object.eval_type == 'generative': 172 | acc = run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=max_problems) 173 | elif task_object.eval_type == 'categorical': 174 | acc = run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=max_problems) 175 | else: 176 | raise ValueError(f"Unsupported task evaluation type: {task_object.eval_type}") 177 | return acc 178 | 179 | # ----------------------------------------------------------------------------- 180 | if __name__ == "__main__": 181 | 182 | # Parse command-line arguments 183 | parser = argparse.ArgumentParser() 184 | parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|mid|rl") 185 | parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. 
Use | to split multiple tasks.") 186 | parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16']) 187 | parser.add_argument('-t', '--temperature', type=float, default=0.0) 188 | parser.add_argument('-m', '--max-new-tokens', type=int, default=512) 189 | parser.add_argument('-n', '--num-samples', type=int, default=1) 190 | parser.add_argument('-k', '--top-k', type=int, default=50) 191 | parser.add_argument('-b', '--batch-size', type=int, default=8, help='Batch size for categorical evaluation') 192 | parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load') 193 | parser.add_argument('-s', '--step', type=int, default=None, help='Step to load') 194 | parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate') 195 | parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect') 196 | args = parser.parse_args() 197 | 198 | device_type = autodetect_device_type() if args.device_type == "" else args.device_type 199 | ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) 200 | ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 201 | autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() 202 | 203 | model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step) 204 | engine = Engine(model, tokenizer) 205 | 206 | # Get the tasks to evaluate on 207 | all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval'] 208 | baseline_accuracies = { 209 | 'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25% 210 | 'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25% 211 | 'MMLU': 0.25, # multiple choice 1 of 4 => 25% 212 | 'GSM8K': 0.0, # open-ended => 0% 213 | 'HumanEval': 0.0, # open-ended => 0% 214 | } 215 | task_names = all_tasks if args.task_name is None else args.task_name.split('|') 216 | 217 | # Run all the task evaluations sequentially 218 | results = {} 219 | for task_name in task_names: 220 | with autocast_ctx: 221 | acc = run_chat_eval( 222 | task_name, 223 | model, tokenizer, engine, 224 | batch_size=args.batch_size, 225 | num_samples=args.num_samples, 226 | max_new_tokens=args.max_new_tokens, 227 | temperature=args.temperature, 228 | top_k=args.top_k, 229 | max_problems=args.max_problems, 230 | ) 231 | results[task_name] = acc 232 | print0(f"{task_name} accuracy: {100 * acc:.2f}%") 233 | 234 | # Log to report 235 | from nanochat.report import get_report 236 | all_tasks_were_evaluated = all(task_name in results for task_name in all_tasks) 237 | # calculate the ChatCORE metric if we can (similar to CORE, it's the mean centered accuracy) 238 | # this way, ChatCORE ranges from 0 (at random baseline) to 1 (peak performance) 239 | chatcore_metric_dict = {} 240 | if all_tasks_were_evaluated: 241 | centered_mean = 0 242 | for task_name, acc in results.items(): 243 | baseline_acc = baseline_accuracies.get(task_name, 0.0) 244 | centered_acc = (acc - baseline_acc) / (1.0 - baseline_acc) 245 | centered_mean += centered_acc 246 | chatcore_metric = centered_mean / len(results) 247 | chatcore_metric_dict = {"ChatCORE metric": chatcore_metric} 248 | get_report().log(section="Chat evaluation " + args.source, data=[ 249 | vars(args), # CLI args 250 | results, 251 | chatcore_metric_dict, 252 | ]) 253 | 254 | compute_cleanup() 255 | 
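# Worked example of the centering arithmetic above (illustrative numbers only): an accuracy of 0.40 on MMLU
# (0.25 random baseline) contributes (0.40 - 0.25) / (1 - 0.25) = 0.20, while 0.40 on GSM8K (0.0 baseline)
# contributes 0.40; ChatCORE is then the mean of these centered accuracies over the evaluated tasks.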
-------------------------------------------------------------------------------- /rustbpe/Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 3 | version = 4 4 | 5 | [[package]] 6 | name = "ahash" 7 | version = "0.8.12" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" 10 | dependencies = [ 11 | "cfg-if", 12 | "getrandom", 13 | "once_cell", 14 | "version_check", 15 | "zerocopy", 16 | ] 17 | 18 | [[package]] 19 | name = "aho-corasick" 20 | version = "1.1.3" 21 | source = "registry+https://github.com/rust-lang/crates.io-index" 22 | checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" 23 | dependencies = [ 24 | "memchr", 25 | ] 26 | 27 | [[package]] 28 | name = "arc-swap" 29 | version = "1.7.1" 30 | source = "registry+https://github.com/rust-lang/crates.io-index" 31 | checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" 32 | 33 | [[package]] 34 | name = "autocfg" 35 | version = "1.5.0" 36 | source = "registry+https://github.com/rust-lang/crates.io-index" 37 | checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" 38 | 39 | [[package]] 40 | name = "bit-set" 41 | version = "0.8.0" 42 | source = "registry+https://github.com/rust-lang/crates.io-index" 43 | checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" 44 | dependencies = [ 45 | "bit-vec", 46 | ] 47 | 48 | [[package]] 49 | name = "bit-vec" 50 | version = "0.8.0" 51 | source = "registry+https://github.com/rust-lang/crates.io-index" 52 | checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" 53 | 54 | [[package]] 55 | name = "castaway" 56 | version = "0.2.4" 57 | source = "registry+https://github.com/rust-lang/crates.io-index" 58 | checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a" 59 | dependencies = [ 60 | "rustversion", 61 | ] 62 | 63 | [[package]] 64 | name = "cfg-if" 65 | version = "1.0.3" 66 | source = "registry+https://github.com/rust-lang/crates.io-index" 67 | checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" 68 | 69 | [[package]] 70 | name = "compact_str" 71 | version = "0.9.0" 72 | source = "registry+https://github.com/rust-lang/crates.io-index" 73 | checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a" 74 | dependencies = [ 75 | "castaway", 76 | "cfg-if", 77 | "itoa", 78 | "rustversion", 79 | "ryu", 80 | "static_assertions", 81 | ] 82 | 83 | [[package]] 84 | name = "crossbeam-deque" 85 | version = "0.8.6" 86 | source = "registry+https://github.com/rust-lang/crates.io-index" 87 | checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" 88 | dependencies = [ 89 | "crossbeam-epoch", 90 | "crossbeam-utils", 91 | ] 92 | 93 | [[package]] 94 | name = "crossbeam-epoch" 95 | version = "0.9.18" 96 | source = "registry+https://github.com/rust-lang/crates.io-index" 97 | checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" 98 | dependencies = [ 99 | "crossbeam-utils", 100 | ] 101 | 102 | [[package]] 103 | name = "crossbeam-utils" 104 | version = "0.8.21" 105 | source = "registry+https://github.com/rust-lang/crates.io-index" 106 | checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" 107 | 108 | [[package]] 109 | name = "dary_heap" 110 | 
version = "0.3.7" 111 | source = "registry+https://github.com/rust-lang/crates.io-index" 112 | checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728" 113 | 114 | [[package]] 115 | name = "either" 116 | version = "1.15.0" 117 | source = "registry+https://github.com/rust-lang/crates.io-index" 118 | checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" 119 | 120 | [[package]] 121 | name = "equivalent" 122 | version = "1.0.2" 123 | source = "registry+https://github.com/rust-lang/crates.io-index" 124 | checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" 125 | 126 | [[package]] 127 | name = "fancy-regex" 128 | version = "0.16.1" 129 | source = "registry+https://github.com/rust-lang/crates.io-index" 130 | checksum = "bf04c5ec15464ace8355a7b440a33aece288993475556d461154d7a62ad9947c" 131 | dependencies = [ 132 | "bit-set", 133 | "regex-automata", 134 | "regex-syntax", 135 | ] 136 | 137 | [[package]] 138 | name = "getrandom" 139 | version = "0.3.3" 140 | source = "registry+https://github.com/rust-lang/crates.io-index" 141 | checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" 142 | dependencies = [ 143 | "cfg-if", 144 | "libc", 145 | "r-efi", 146 | "wasi", 147 | ] 148 | 149 | [[package]] 150 | name = "hashbrown" 151 | version = "0.15.5" 152 | source = "registry+https://github.com/rust-lang/crates.io-index" 153 | checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" 154 | 155 | [[package]] 156 | name = "heck" 157 | version = "0.5.0" 158 | source = "registry+https://github.com/rust-lang/crates.io-index" 159 | checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" 160 | 161 | [[package]] 162 | name = "indexmap" 163 | version = "2.11.0" 164 | source = "registry+https://github.com/rust-lang/crates.io-index" 165 | checksum = "f2481980430f9f78649238835720ddccc57e52df14ffce1c6f37391d61b563e9" 166 | dependencies = [ 167 | "equivalent", 168 | "hashbrown", 169 | ] 170 | 171 | [[package]] 172 | name = "indoc" 173 | version = "2.0.6" 174 | source = "registry+https://github.com/rust-lang/crates.io-index" 175 | checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" 176 | 177 | [[package]] 178 | name = "itoa" 179 | version = "1.0.15" 180 | source = "registry+https://github.com/rust-lang/crates.io-index" 181 | checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" 182 | 183 | [[package]] 184 | name = "libc" 185 | version = "0.2.175" 186 | source = "registry+https://github.com/rust-lang/crates.io-index" 187 | checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" 188 | 189 | [[package]] 190 | name = "log" 191 | version = "0.4.28" 192 | source = "registry+https://github.com/rust-lang/crates.io-index" 193 | checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" 194 | 195 | [[package]] 196 | name = "memchr" 197 | version = "2.7.5" 198 | source = "registry+https://github.com/rust-lang/crates.io-index" 199 | checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0" 200 | 201 | [[package]] 202 | name = "memoffset" 203 | version = "0.9.1" 204 | source = "registry+https://github.com/rust-lang/crates.io-index" 205 | checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" 206 | dependencies = [ 207 | "autocfg", 208 | ] 209 | 210 | [[package]] 211 | name = "once_cell" 212 | version = "1.21.3" 213 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 214 | checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" 215 | 216 | [[package]] 217 | name = "portable-atomic" 218 | version = "1.11.1" 219 | source = "registry+https://github.com/rust-lang/crates.io-index" 220 | checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" 221 | 222 | [[package]] 223 | name = "proc-macro2" 224 | version = "1.0.101" 225 | source = "registry+https://github.com/rust-lang/crates.io-index" 226 | checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" 227 | dependencies = [ 228 | "unicode-ident", 229 | ] 230 | 231 | [[package]] 232 | name = "pyo3" 233 | version = "0.23.5" 234 | source = "registry+https://github.com/rust-lang/crates.io-index" 235 | checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872" 236 | dependencies = [ 237 | "cfg-if", 238 | "indoc", 239 | "libc", 240 | "memoffset", 241 | "once_cell", 242 | "portable-atomic", 243 | "pyo3-build-config", 244 | "pyo3-ffi", 245 | "pyo3-macros", 246 | "unindent", 247 | ] 248 | 249 | [[package]] 250 | name = "pyo3-build-config" 251 | version = "0.23.5" 252 | source = "registry+https://github.com/rust-lang/crates.io-index" 253 | checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb" 254 | dependencies = [ 255 | "once_cell", 256 | "target-lexicon", 257 | ] 258 | 259 | [[package]] 260 | name = "pyo3-ffi" 261 | version = "0.23.5" 262 | source = "registry+https://github.com/rust-lang/crates.io-index" 263 | checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d" 264 | dependencies = [ 265 | "libc", 266 | "pyo3-build-config", 267 | ] 268 | 269 | [[package]] 270 | name = "pyo3-log" 271 | version = "0.12.4" 272 | source = "registry+https://github.com/rust-lang/crates.io-index" 273 | checksum = "45192e5e4a4d2505587e27806c7b710c231c40c56f3bfc19535d0bb25df52264" 274 | dependencies = [ 275 | "arc-swap", 276 | "log", 277 | "pyo3", 278 | ] 279 | 280 | [[package]] 281 | name = "pyo3-macros" 282 | version = "0.23.5" 283 | source = "registry+https://github.com/rust-lang/crates.io-index" 284 | checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da" 285 | dependencies = [ 286 | "proc-macro2", 287 | "pyo3-macros-backend", 288 | "quote", 289 | "syn", 290 | ] 291 | 292 | [[package]] 293 | name = "pyo3-macros-backend" 294 | version = "0.23.5" 295 | source = "registry+https://github.com/rust-lang/crates.io-index" 296 | checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028" 297 | dependencies = [ 298 | "heck", 299 | "proc-macro2", 300 | "pyo3-build-config", 301 | "quote", 302 | "syn", 303 | ] 304 | 305 | [[package]] 306 | name = "quote" 307 | version = "1.0.40" 308 | source = "registry+https://github.com/rust-lang/crates.io-index" 309 | checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" 310 | dependencies = [ 311 | "proc-macro2", 312 | ] 313 | 314 | [[package]] 315 | name = "r-efi" 316 | version = "5.3.0" 317 | source = "registry+https://github.com/rust-lang/crates.io-index" 318 | checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" 319 | 320 | [[package]] 321 | name = "rayon" 322 | version = "1.11.0" 323 | source = "registry+https://github.com/rust-lang/crates.io-index" 324 | checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" 325 | dependencies = [ 326 | "either", 327 | "rayon-core", 328 | ] 329 | 330 | [[package]] 331 | name = 
"rayon-core" 332 | version = "1.13.0" 333 | source = "registry+https://github.com/rust-lang/crates.io-index" 334 | checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" 335 | dependencies = [ 336 | "crossbeam-deque", 337 | "crossbeam-utils", 338 | ] 339 | 340 | [[package]] 341 | name = "regex-automata" 342 | version = "0.4.10" 343 | source = "registry+https://github.com/rust-lang/crates.io-index" 344 | checksum = "6b9458fa0bfeeac22b5ca447c63aaf45f28439a709ccd244698632f9aa6394d6" 345 | dependencies = [ 346 | "aho-corasick", 347 | "memchr", 348 | "regex-syntax", 349 | ] 350 | 351 | [[package]] 352 | name = "regex-syntax" 353 | version = "0.8.6" 354 | source = "registry+https://github.com/rust-lang/crates.io-index" 355 | checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001" 356 | 357 | [[package]] 358 | name = "rustbpe" 359 | version = "0.1.0" 360 | dependencies = [ 361 | "ahash", 362 | "compact_str", 363 | "dary_heap", 364 | "fancy-regex", 365 | "indexmap", 366 | "log", 367 | "pyo3", 368 | "pyo3-log", 369 | "rayon", 370 | ] 371 | 372 | [[package]] 373 | name = "rustversion" 374 | version = "1.0.22" 375 | source = "registry+https://github.com/rust-lang/crates.io-index" 376 | checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" 377 | 378 | [[package]] 379 | name = "ryu" 380 | version = "1.0.20" 381 | source = "registry+https://github.com/rust-lang/crates.io-index" 382 | checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" 383 | 384 | [[package]] 385 | name = "static_assertions" 386 | version = "1.1.0" 387 | source = "registry+https://github.com/rust-lang/crates.io-index" 388 | checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" 389 | 390 | [[package]] 391 | name = "syn" 392 | version = "2.0.106" 393 | source = "registry+https://github.com/rust-lang/crates.io-index" 394 | checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" 395 | dependencies = [ 396 | "proc-macro2", 397 | "quote", 398 | "unicode-ident", 399 | ] 400 | 401 | [[package]] 402 | name = "target-lexicon" 403 | version = "0.12.16" 404 | source = "registry+https://github.com/rust-lang/crates.io-index" 405 | checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" 406 | 407 | [[package]] 408 | name = "unicode-ident" 409 | version = "1.0.18" 410 | source = "registry+https://github.com/rust-lang/crates.io-index" 411 | checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" 412 | 413 | [[package]] 414 | name = "unindent" 415 | version = "0.2.4" 416 | source = "registry+https://github.com/rust-lang/crates.io-index" 417 | checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" 418 | 419 | [[package]] 420 | name = "version_check" 421 | version = "0.9.5" 422 | source = "registry+https://github.com/rust-lang/crates.io-index" 423 | checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" 424 | 425 | [[package]] 426 | name = "wasi" 427 | version = "0.14.4+wasi-0.2.4" 428 | source = "registry+https://github.com/rust-lang/crates.io-index" 429 | checksum = "88a5f4a424faf49c3c2c344f166f0662341d470ea185e939657aaff130f0ec4a" 430 | dependencies = [ 431 | "wit-bindgen", 432 | ] 433 | 434 | [[package]] 435 | name = "wit-bindgen" 436 | version = "0.45.1" 437 | source = "registry+https://github.com/rust-lang/crates.io-index" 438 | checksum = "5c573471f125075647d03df72e026074b7203790d41351cd6edc96f46bcccd36" 439 | 440 | 
[[package]] 441 | name = "zerocopy" 442 | version = "0.8.26" 443 | source = "registry+https://github.com/rust-lang/crates.io-index" 444 | checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f" 445 | dependencies = [ 446 | "zerocopy-derive", 447 | ] 448 | 449 | [[package]] 450 | name = "zerocopy-derive" 451 | version = "0.8.26" 452 | source = "registry+https://github.com/rust-lang/crates.io-index" 453 | checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" 454 | dependencies = [ 455 | "proc-macro2", 456 | "quote", 457 | "syn", 458 | ] 459 | -------------------------------------------------------------------------------- /scripts/mid_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Midtrain the model. Same as pretraining but simpler. 3 | Run as: 4 | 5 | python -m scripts.mid_train 6 | 7 | Or torchrun for training: 8 | 9 | torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16 10 | """ 11 | 12 | from collections import deque 13 | import os 14 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" 15 | import time 16 | import wandb 17 | import torch 18 | from contextlib import nullcontext 19 | from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type 20 | from nanochat.tokenizer import get_token_bytes 21 | from nanochat.checkpoint_manager import save_checkpoint 22 | from nanochat.loss_eval import evaluate_bpb 23 | from nanochat.checkpoint_manager import load_model 24 | import torch.distributed as dist 25 | 26 | from tasks.common import TaskMixture 27 | from tasks.gsm8k import GSM8K 28 | from tasks.mmlu import MMLU 29 | from tasks.smoltalk import SmolTalk 30 | from tasks.customjson import CustomJSON 31 | 32 | # ----------------------------------------------------------------------------- 33 | run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb) 34 | device_type = "" # cuda|cpu|mps (empty => autodetect) 35 | model_tag = None # model tag to load the model from (base model or midtrained model) 36 | step = None # step to load the model from (base model or midtrained model) 37 | dtype = "bfloat16" 38 | num_iterations = -1 # explicit number of steps of the optimization (-1 = disable) 39 | max_seq_len = 2048 40 | device_batch_size = 32 41 | unembedding_lr = 0.004 42 | embedding_lr = 0.2 43 | matrix_lr = 0.02 44 | init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate 45 | weight_decay = 0.0 46 | eval_every = 150 # -1 = disable 47 | eval_tokens = 20*524288 48 | total_batch_size = 524288 49 | dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report 50 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 51 | exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file 52 | user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging 53 | # ----------------------------------------------------------------------------- 54 | 55 | # Compute init 56 | device_type = autodetect_device_type() if device_type == "" else device_type 57 | ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) 58 | master_process = ddp_rank == 0 59 | autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else 
nullcontext() 60 | synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None 61 | get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 62 | 63 | # wandb logging init 64 | use_dummy_wandb = run == "dummy" or not master_process 65 | wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=run, config=user_config) 66 | 67 | # Load the model and tokenizer 68 | model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step) 69 | pretrain_batch_size = meta.get("device_batch_size", None) 70 | if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size: 71 | print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?") 72 | orig_model = model 73 | model = torch.compile(model, dynamic=False) 74 | depth = model.config.n_layer 75 | num_flops_per_token = model.estimate_flops() 76 | tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank 77 | world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks 78 | assert total_batch_size % world_tokens_per_fwdbwd == 0 79 | grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd 80 | print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}") 81 | print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") 82 | print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") 83 | token_bytes = get_token_bytes(device=device) 84 | 85 | # Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head) 86 | optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay) 87 | adamw_optimizer, muon_optimizer = optimizers 88 | # Override the initial learning rate as a fraction of the base learning rate 89 | for opt in optimizers: 90 | for group in opt.param_groups: 91 | group["lr"] = group["lr"] * init_lr_frac 92 | group["initial_lr"] = group["lr"] # save the initial learning rate so we can decay it easily later 93 | 94 | # Midtraining data mixture and DataLoader 95 | base_dir = get_base_dir() 96 | identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl") 97 | train_dataset = TaskMixture([ 98 | SmolTalk(split="train"), # 460K rows of general conversations 99 | MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE 100 | GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use 101 | CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations 102 | CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these 103 | ]) # total: 460K + 100K + 8K + 2x1K = 570K rows 104 | val_dataset = TaskMixture([ 105 | SmolTalk(split="test"), # 24K rows in test set 106 | MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios 107 | GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios 108 | ]) # total: 24K + 14K + 1.32K ~= 39K rows 109 | # DataLoader is defined here, it emits inputs, targets: 2D tensors of shape (device_batch_size, max_seq_len) 110 | # A big problem is that we don't know the final num_iterations 
in advance. So we create 111 | # these two global variables and update them from within the data generator. 112 | last_step = False # we will toggle this to True when we reach the end of the dataset 113 | approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch 114 | def mid_data_generator(split): 115 | global last_step, approx_progress 116 | assert split in {"train", "val"}, "split must be 'train' or 'val'" 117 | dataset = train_dataset if split == "train" else val_dataset 118 | dataset_size = len(dataset) 119 | assert dataset_size > 0 120 | needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets 121 | token_buffer = deque() 122 | scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True) 123 | cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents 124 | it = 0 # iteration counter 125 | while True: 126 | # Accumulate enough tokens for one iteration before yielding 127 | while len(token_buffer) < needed_tokens: 128 | conversation = dataset[cursor] 129 | ids, _ = tokenizer.render_conversation(conversation) 130 | token_buffer.extend(ids) 131 | cursor += ddp_world_size 132 | if cursor >= dataset_size: 133 | cursor -= dataset_size # wrap around for another epoch 134 | if split == "train": 135 | last_step = True # toggle last_step to True, which will terminate the training loop 136 | # Stopping condition to respect num_iterations, if given 137 | it += 1 138 | if num_iterations > 0 and it >= num_iterations: 139 | last_step = True # toggle last_step to True, which will terminate the training loop 140 | # Build up inputs/targets and yield 141 | for i in range(needed_tokens): 142 | scratch[i] = token_buffer.popleft() 143 | inputs_cpu = scratch[:-1].to(dtype=torch.int32) 144 | targets_cpu = scratch[1:] 145 | inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True) 146 | targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True) 147 | if split == "train": 148 | if num_iterations > 0: 149 | approx_progress = it / num_iterations # calculate progress from the max number of iterations 150 | else: 151 | approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset 152 | yield inputs, targets 153 | 154 | train_loader = mid_data_generator("train") 155 | build_val_loader = lambda: mid_data_generator("val") 156 | progress = 0 # will go from 0 to 1 over the course of the epoch 157 | 158 | # Learning rate scheduler 159 | def get_lr_multiplier(progress): 160 | # first 80% of training: no decay, then linearly ramp down to 0. 
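# e.g. progress=0.5 -> 1.0, progress=0.9 -> 1 - (0.9 - 0.8) / 0.2 = 0.5, progress=1.0 -> 0.0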
161 | return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2 162 | 163 | # Momentum scheduler for Muon optimizer 164 | def get_muon_momentum(it): 165 | frac = min(it / 300, 1) 166 | momentum = (1 - frac) * 0.85 + frac * 0.95 167 | return momentum 168 | 169 | # ----------------------------------------------------------------------------- 170 | # Training loop 171 | x, y = next(train_loader) # prefetch the very first batch of data 172 | min_val_bpb = float("inf") 173 | smooth_train_loss = 0 # EMA of training loss 174 | ema_beta = 0.9 # EMA decay factor 175 | total_training_time = 0 # total wall-clock time of training 176 | step = 0 177 | while True: 178 | flops_so_far = num_flops_per_token * total_batch_size * step 179 | 180 | # Synchronize last_step across all ranks to avoid hangs in the distributed setting 181 | if ddp: 182 | last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device) 183 | dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX) 184 | last_step = bool(last_step_tensor.item()) 185 | 186 | # once in a while: evaluate the val bpb (all ranks participate) 187 | if eval_every > 0 and (last_step or step % eval_every == 0): 188 | model.eval() 189 | val_loader = build_val_loader() 190 | eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size) 191 | with autocast_ctx: 192 | val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes) 193 | print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}") 194 | if val_bpb < min_val_bpb: 195 | min_val_bpb = val_bpb 196 | wandb_run.log({ 197 | "step": step, 198 | "total_training_flops": flops_so_far, 199 | "total_training_time": total_training_time, 200 | "val/bpb": val_bpb, 201 | }) 202 | model.train() 203 | 204 | # save checkpoint at the end of the run (only on master process) 205 | if master_process and last_step and not dry_run: 206 | output_dirname = f"d{depth}" # e.g. 
d12 207 | checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname) 208 | save_checkpoint( 209 | checkpoint_dir, 210 | step, 211 | orig_model.state_dict(), 212 | [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly 213 | { 214 | "step": step, 215 | "val_bpb": val_bpb, # loss at last step 216 | "model_config": { 217 | "sequence_len": max_seq_len, 218 | "vocab_size": tokenizer.get_vocab_size(), 219 | "n_layer": depth, 220 | "n_head": model.config.n_head, 221 | "n_kv_head": model.config.n_kv_head, 222 | "n_embd": model.config.n_embd, 223 | }, 224 | "user_config": user_config, # inputs to the training script 225 | } 226 | ) 227 | 228 | if last_step: 229 | break 230 | 231 | # ------------------------------------------------------------------------- 232 | # single training step 233 | # evaluate the gradient 234 | synchronize() 235 | t0 = time.time() 236 | for micro_step in range(grad_accum_steps): 237 | with autocast_ctx: 238 | loss = model(x, y) 239 | train_loss = loss.detach() # for logging 240 | loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here 241 | loss.backward() 242 | x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward 243 | progress = max(progress, approx_progress) # only increase progress monotonically 244 | # step the optimizers 245 | lrm = get_lr_multiplier(progress) 246 | for opt in optimizers: 247 | for group in opt.param_groups: 248 | group["lr"] = group["initial_lr"] * lrm 249 | muon_momentum = get_muon_momentum(step) 250 | for group in muon_optimizer.param_groups: 251 | group["momentum"] = muon_momentum 252 | for opt in optimizers: 253 | opt.step() 254 | model.zero_grad(set_to_none=True) 255 | synchronize() 256 | t1 = time.time() 257 | dt = t1 - t0 258 | # ------------------------------------------------------------------------- 259 | 260 | # State 261 | step += 1 262 | 263 | # logging 264 | smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss 265 | debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA 266 | pct_done = 100 * progress 267 | tok_per_sec = int(world_tokens_per_fwdbwd / dt) 268 | flops_per_sec = num_flops_per_token * total_batch_size / dt 269 | promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity 270 | mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % 271 | if step > 10: 272 | total_training_time += dt # only count the time after the first 10 steps 273 | print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m") 274 | if step % 10 == 0: 275 | wandb_run.log({ 276 | "step": step, 277 | "total_training_flops": flops_so_far, 278 | "total_training_time": total_training_time, 279 | "train/loss": debiased_smooth_loss, 280 | "train/lrm": lrm, 281 | "train/dt": dt, 282 | "train/tok_per_sec": tok_per_sec, 283 | "train/mfu": mfu, 284 | }) 285 | 286 | # print a few more stats 287 | print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") 288 | print0(f"Total training time: {total_training_time/60:.2f}m") 289 | print0(f"Minimum validation bpb: {min_val_bpb:.4f}") 290 | 291 | # Log to report 292 | if not dry_run: 293 | from nanochat.report import get_report 294 | get_report().log(section="Midtraining", data=[ 295 | user_config, # CLI 
args 296 | { # stats about the training setup 297 | "Number of iterations": step, 298 | "DDP world size": ddp_world_size, 299 | }, 300 | { # stats about training outcomes 301 | "Minimum validation bpb": min_val_bpb, 302 | } 303 | ]) 304 | 305 | # cleanup 306 | wandb_run.finish() # wandb run finish 307 | compute_cleanup() 308 | --------------------------------------------------------------------------------