├── nanoproof ├── __init__.py ├── data │ ├── __init__.py │ ├── numinamath.py │ ├── others.md │ ├── minif2f.py │ ├── leanworkbook.py │ ├── nemotron_dataloader.py │ ├── leantree.py │ ├── leangithub_urls.txt │ ├── leantree_dataloader.py │ └── nemotron.py ├── core.py ├── configurator.py ├── loss_eval.py ├── adamw.py ├── checkpoints.py ├── experience_collection.py ├── tokenizer.py ├── muon.py ├── rl.py ├── engine.py ├── midtrain.py ├── sft.py ├── common.py └── report.py ├── LICENSE ├── pyproject.toml ├── scripts ├── interact.py ├── inspect_parquet.py ├── policy_eval.py ├── prover_eval.py ├── tok_show.py ├── tok_train.py └── tok_eval.py ├── README.md ├── tests └── test_network.py └── .gitignore /nanoproof/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nanoproof/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nanoproof/data/numinamath.py: -------------------------------------------------------------------------------- 1 | 2 | # https://huggingface.co/datasets/AI-MO/NuminaMath-LEAN -------------------------------------------------------------------------------- /nanoproof/data/others.md: -------------------------------------------------------------------------------- 1 | - LeanUniverse (https://github.com/facebookresearch/LeanUniverse) 2 | - LEAN-GitHub (https://arxiv.org/abs/2407.17227) -------------------------------------------------------------------------------- /nanoproof/core.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | # Observations in AlphaProof are the tactic state. 4 | Observation = str 5 | 6 | # Actions in AlphaProof are Lean tactics (except for special actions, to start a 7 | # disproof, or to focus on a goal). 8 | Action = str 9 | 10 | 11 | @dataclass 12 | class Theorem: 13 | """A theorem to be proved.""" 14 | header: str 15 | statement: str 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Matěj Kripner 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "nanoproof" 3 | version = "0.1.0" 4 | description = "minimal open-source implementation of AlphaProof [WIP]" 5 | readme = "README.md" 6 | requires-python = ">=3.10" 7 | dependencies = [ 8 | "datasets>=4.0.0", 9 | "fastapi>=0.117.1", 10 | "huggingface_hub>=0.20.0", 11 | "psutil>=7.1.0", 12 | "regex>=2025.9.1", 13 | "setuptools>=80.9.0", 14 | "tiktoken>=0.11.0", 15 | "tqdm>=4.66.0", 16 | "tokenizers>=0.22.0", 17 | "torch>=2.8.0", 18 | "wandb>=0.21.3", 19 | "leantree", 20 | "termplotlib", 21 | "PrettyPrintTree>=2.0.1", 22 | ] 23 | 24 | [dependency-groups] 25 | dev = [ 26 | "pytest>=8.0.0", 27 | ] 28 | 29 | [tool.pytest.ini_options] 30 | markers = [ 31 | "slow: marks tests as slow (deselect with '-m \"not slow\"')", 32 | ] 33 | testpaths = ["tests"] 34 | python_files = ["test_*.py"] 35 | python_classes = ["Test*"] 36 | python_functions = ["test_*"] 37 | 38 | # target torch to cuda 12.8 or CPU 39 | [tool.uv.sources] 40 | torch = [ 41 | { index = "pytorch-cpu", extra = "cpu" }, 42 | { index = "pytorch-cu128", extra = "gpu" }, 43 | ] 44 | 45 | [[tool.uv.index]] 46 | name = "pytorch-cpu" 47 | url = "https://download.pytorch.org/whl/cpu" 48 | explicit = true 49 | 50 | [[tool.uv.index]] 51 | name = "pytorch-cu128" 52 | url = "https://download.pytorch.org/whl/cu128" 53 | explicit = true 54 | 55 | [project.optional-dependencies] 56 | cpu = [ 57 | "torch>=2.8.0", 58 | ] 59 | gpu = [ 60 | "torch>=2.8.0", 61 | ] 62 | 63 | [tool.uv] 64 | conflicts = [ 65 | [ 66 | { extra = "cpu" }, 67 | { extra = "gpu" }, 68 | ], 69 | ] -------------------------------------------------------------------------------- /scripts/interact.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext 2 | 3 | import torch 4 | 5 | from nanoproof.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type 6 | from nanoproof.checkpoints import load_model, save_checkpoint 7 | from nanoproof.engine import Engine 8 | 9 | source = "sft" # which checkpoint to load the model from 10 | model_tag = "d26" # model tag to load the model from 11 | device_type = "" # cuda|cpu|mps (empty => autodetect) 12 | dtype = "bfloat16" 13 | base_dir = get_base_dir() 14 | 15 | device_type = autodetect_device_type() if device_type == "" else device_type 16 | device = torch.device(device_type) 17 | ptdtype = torch.float32 if dtype == "float32" else torch.bfloat16 18 | autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() 19 | 20 | model, tokenizer, meta = load_model(source, device, phase="eval", model_tag=model_tag) 21 | engine = Engine(model, tokenizer) 22 | 23 | def generate(inp_) -> str: 24 | tokens = tokenizer(inp_.strip() + "\n<|tactic|>", prepend="<|bos|>") 25 | with autocast_ctx: 26 | sample_toks, _ = engine.generate_batch(tokens, num_samples=1, min_tokens=1, max_tokens=64) 27 | return tokenizer.decode(sample_toks[0]) 28 | 29 | def get_input() -> str: 30 | lines = [] 31 | print("Type in a tactic state, followed by an empty line:") 32 | line = input() 33 | while line.strip() or not lines: 34 | lines.append(line.rstrip()) 35 | line = input() 36 | return "\n".join(lines) 37 | 38 | inp = get_input() 39 | while inp.strip() not in ["q", "quit", "exit"]: 40 | print(f"Generating ...") 41 | 
tactic = generate(inp) 42 | print(f"Tactic:\n--\n'{tactic}'\n--") 43 | inp = get_input() 44 | print("Done.") 45 | 46 | INP1 = """ 47 | z : ℂ 48 | h₀ : z = (1 + Complex.I) / ↑√2 49 | ⊢ (∑ k ∈ Finset.Icc 1 12, z ^ k ^ 2) * ∑ k ∈ Finset.Icc 1 12, 1 / z ^ k ^ 2 = 36 50 | """ 51 | 52 | INP2 = """ 53 | ⊢ 2 + 3 = 5 54 | """ -------------------------------------------------------------------------------- /nanoproof/configurator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Poor Man's Configurator. Probably a terrible idea. Example usage: 3 | $ python train.py config/override_file.py --batch_size=32 4 | this will first run config/override_file.py, then override batch_size to 32 5 | 6 | The code in this file will be run as follows from e.g. train.py: 7 | >>> exec(open('configurator.py').read()) 8 | 9 | So it's not a Python module, it's just shuttling this code away from train.py 10 | The code in this script then overrides the globals() 11 | 12 | I know people are not going to love this, I just really dislike configuration 13 | complexity and having to prepend config. to every single variable. If someone 14 | comes up with a better simple Python solution I am all ears. 15 | """ 16 | 17 | import os 18 | import sys 19 | from ast import literal_eval 20 | 21 | def print0(s="",**kwargs): 22 | ddp_rank = int(os.environ.get('RANK', 0)) 23 | if ddp_rank == 0: 24 | print(s, **kwargs) 25 | 26 | for arg in sys.argv[1:]: 27 | if '=' not in arg: 28 | # assume it's the name of a config file 29 | assert not arg.startswith('--') 30 | config_file = arg 31 | print0(f"Overriding config with {config_file}:") 32 | with open(config_file) as f: 33 | print0(f.read()) 34 | exec(open(config_file).read()) 35 | else: 36 | # assume it's a --key=value argument 37 | assert arg.startswith('--') 38 | key, val = arg.split('=') 39 | key = key[2:] 40 | if key in globals(): 41 | try: 42 | # attempt to eval it it (e.g. if bool, number, or etc) 43 | attempt = literal_eval(val) 44 | except (SyntaxError, ValueError): 45 | # if that goes wrong, just use the string 46 | attempt = val 47 | # ensure the types match ok 48 | if globals()[key] is not None: 49 | attempt_type = type(attempt) 50 | default_type = type(globals()[key]) 51 | assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}" 52 | # cross fingers 53 | print0(f"Overriding: {key} = {attempt}") 54 | globals()[key] = attempt 55 | else: 56 | raise ValueError(f"Unknown config key: {key}") -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # nanoproof 2 | 3 | This is an attempt to replicate AlphaProof. It is based on [nanochat](https://github.com/karpathy/nanochat) and the 4 | official AlphaProof pseudocode (together with many open-source datasets and tools). So far, we have: 5 | - pretraining on Nemotron-CC-Math 6 | - midtraining on Lean code from GitHub 7 | - supervised fine-tuning on LeanTree (transitions extracted from Mathlib) 8 | - interaction with Lean using a LeanTree server 9 | - MCTS-based prover 10 | - a simple RL training loop 11 | - evaluation script for success rate on MiniF2F and Lean-Workbook 12 | 13 | The best score achieved so far is **32.8% on MiniF2F** (more precisely on the subset of its first 64 theorems). 14 | 15 | This project is in early stages and still a bit hard to work with. If you want to contribute, the best way to start is to write me an email! 
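For a quick feel of the tactic model, `scripts/interact.py` loads an SFT checkpoint and samples a tactic for a pasted tactic state. The sketch below mirrors that script (device selection and autocast are simplified here; the `d26` checkpoint tag and the 64-token budget are just the script's defaults, not requirements):

```
import torch
from nanoproof.checkpoints import load_model
from nanoproof.engine import Engine

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, tokenizer, _ = load_model("sft", device, phase="eval", model_tag="d26")
engine = Engine(model, tokenizer)

state = "⊢ 2 + 3 = 5"  # a Lean tactic state
tokens = tokenizer(state.strip() + "\n<|tactic|>", prepend="<|bos|>")
sample_toks, _ = engine.generate_batch(tokens, num_samples=1, min_tokens=1, max_tokens=64)
print(tokenizer.decode(sample_toks[0]))
```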
16 | 17 | 18 | # Questions 19 | 20 | - how is the action prob obtained from tokens probs? 21 | - is value predicted for state (as per paper) or for state-action (as per pseudocode)? 22 | - more importantly: do the bins correspond to Q 0-1 or V 1-inf? 23 | - Milan: it's V 24 | - how do the value bins correspond to values? Ie. what are the values of bins 0 and 63? 25 | - Supplemental Data Table 4: is bs=4096 in tokens or in samples? 26 | Probably in samples - if in tokens, otherwise we would only use ~10% of the data: `4096/64 samples-per-batch * 500 steps = 32k samples`, but 300k are available. 27 | (Also 4096 samples-per-batch makes sense for the Pre-Training, where it yields something on the order of 50 epochs that Julian reported) 28 | - what is the value of a child that was not visited yet? 29 | - zero as per pseudocode (line 570) - that would be weird/wrong 30 | - parent minus "UCB unvisited children value penalty" (=32) as per paper 31 | - beta and gamma are the same? (as per code) 32 | - In the pseudocode, is_optimal is not set on the new children/grandchildren created when expanding a node, even if they are terminal. 33 | - Replay buffer size seems to be 250k in the pseudocode but 60M in the paper (Supplemental Data Table 6) 34 | 35 | # Ideas 36 | 37 | - try training on state_after as well, just to give the model more training signal (it was done in some paper, maybe GPT-f) 38 | - let tokens attend bi-directionally inside the fixed-size state (a la PrefixLM) 39 | - try proving the negation in each node (if critic deems it likely to succeed) 40 | 41 | # Setup 42 | 43 | ``` 44 | cd nanoproof 45 | uv sync --extra cpu --group dev 46 | source .venv/bin/activate 47 | 48 | hf auth login 49 | 50 | python -m nanoproof.dataset 51 | python -m scripts.tok_train 52 | python -m nanoproof.pretrain 53 | ``` 54 | 55 | or 56 | 57 | ``` 58 | torchrun --standalone --nproc_per_node=2 -m nanoproof.pretrain 59 | ``` -------------------------------------------------------------------------------- /nanoproof/data/minif2f.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from pathlib import Path 4 | 5 | import requests 6 | 7 | from nanoproof.common import get_base_dir 8 | 9 | base_dir = get_base_dir() 10 | DATA_DIR = os.path.join(base_dir, "data", "minif2f") 11 | 12 | BASE_URL = "https://raw.githubusercontent.com/google-deepmind/miniF2F/refs/heads/main/MiniF2F/" 13 | 14 | 15 | def list_theorems(split): 16 | assert split in ["Valid", "Test"] 17 | file_path = Path(DATA_DIR) / f"{split}.lean" 18 | blocks = file_path.read_text().split("\n\n") 19 | theorems = [] 20 | for block in blocks: 21 | lines = block.split("\n") 22 | theorem_line_idx = next((i for i, line in enumerate(lines) if line.startswith("theorem")), None) 23 | if theorem_line_idx is None: 24 | continue 25 | theorem = "\n".join([line.rstrip() for line in lines[theorem_line_idx:]]) 26 | theorems.append(theorem.strip()) 27 | assert all("sorry" in t for t in theorems), "Found a theorem with no `sorry`." 
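    # Each returned statement keeps its `sorry` placeholder; downstream, the prover enters the
    # proof through LeanTree's `proof_from_sorry` (see scripts/prover_eval.py), hence the assert above.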
28 | return theorems 29 | 30 | def get_imports(): 31 | file_path = Path(DATA_DIR) / "ProblemImports.lean" 32 | return file_path.read_text() + """ 33 | open scoped Real 34 | open scoped Nat 35 | open scoped Topology 36 | open scoped Polynomial""" 37 | 38 | def download_dataset(): 39 | """Download the miniF2F dataset from GitHub.""" 40 | os.makedirs(DATA_DIR, exist_ok=True) 41 | for filename in ["Valid.lean", "Test.lean", "ProblemImports.lean"]: 42 | file_path = os.path.join(DATA_DIR, filename) 43 | if os.path.exists(file_path): 44 | print(f"File already exists, skipping: {file_path}") 45 | continue 46 | 47 | url = BASE_URL + filename 48 | print(f"Downloading {filename} from {url}...") 49 | response = requests.get(url, timeout=60) 50 | response.raise_for_status() 51 | 52 | with open(file_path, "w", encoding="utf-8") as f: 53 | f.write(response.text) 54 | 55 | print(f"Successfully downloaded {file_path}") 56 | 57 | 58 | if __name__ == "__main__": 59 | parser = argparse.ArgumentParser() 60 | subparsers = parser.add_subparsers(dest="action") 61 | download_parser = subparsers.add_parser("download") 62 | show_parser = subparsers.add_parser("show") 63 | show_parser.add_argument("--split", choices=["Valid", "Test"], default="Valid") 64 | args = parser.parse_args() 65 | 66 | if args.action == "download": 67 | download_dataset() 68 | elif args.action == "show": 69 | for theorem in list_theorems(args.split): 70 | print(theorem) 71 | print("\n-----------------\n") 72 | else: 73 | raise f"Unknown action {args.action}" -------------------------------------------------------------------------------- /nanoproof/data/leanworkbook.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import requests 4 | import random 5 | 6 | from tqdm import tqdm 7 | 8 | from nanoproof.common import get_base_dir 9 | 10 | base_dir = get_base_dir() 11 | DATA_DIR = os.path.join(base_dir, "data", "leanworkbook") 12 | 13 | HF_URL = "https://huggingface.co/datasets/internlm/Lean-Workbook/resolve/main/lean_workbook.json" 14 | 15 | # gather de-duplicated formal_statement from: 16 | # https://huggingface.co/datasets/internlm/Lean-Workbook 17 | 18 | def download_dataset(): 19 | """Download the Lean-Workbook dataset from HuggingFace.""" 20 | json_path = os.path.join(DATA_DIR, "lean_workbook.json") 21 | 22 | # skip if already downloaded 23 | if os.path.exists(json_path): 24 | print(f"Dataset already downloaded at {json_path}") 25 | return 26 | 27 | try: 28 | print(f"Downloading Lean-Workbook dataset from HuggingFace...") 29 | response = requests.get(HF_URL, stream=True, timeout=60) 30 | response.raise_for_status() 31 | 32 | temp_path = json_path + ".tmp" 33 | total_size = int(response.headers.get("content-length", 0)) 34 | with open(temp_path, "wb") as f: 35 | with tqdm(total=total_size, unit="B", unit_scale=True, unit_divisor=1024, desc="Downloading lean_workbook.json") as pbar: 36 | for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks 37 | if chunk: 38 | f.write(chunk) 39 | pbar.update(len(chunk)) 40 | 41 | os.rename(temp_path, json_path) 42 | print(f"Successfully downloaded {json_path}") 43 | except (requests.RequestException, IOError): 44 | # Clean up any partial files 45 | for path in [json_path + ".tmp", json_path]: 46 | if os.path.exists(path): 47 | print(f"Cleaning up {path}") 48 | os.remove(path) 49 | raise 50 | 51 | def list_theorems(split: str): 52 | assert split in ["train", "val"], f"Invalid split: {split}. Must be 'train' or 'val'." 
53 | 54 | json_path = os.path.join(DATA_DIR, "lean_workbook.json") 55 | if not os.path.exists(json_path): 56 | raise FileNotFoundError(f"Lean-Workbook dataset not found at {json_path}. Download it first.") 57 | with open(json_path, "r") as f: 58 | data = json.load(f) 59 | # select theorems that have been proven by InternLM Prover 60 | theorems = [item["formal_statement"] for item in data if item["proof"]] 61 | 62 | # shuffle with fixed seed and split into train/val 63 | random.Random(0).shuffle(theorems) 64 | 65 | if split == "val": 66 | return theorems[-500:] 67 | else: # train 68 | return theorems[:-500] 69 | 70 | if __name__ == "__main__": 71 | download_dataset() 72 | train_theorems = list_theorems(split="train") 73 | val_theorems = list_theorems(split="val") 74 | print(f"Retrieved {len(train_theorems)} train theorems") 75 | print(train_theorems[0]) 76 | print() 77 | print(f"Retrieved {len(val_theorems)} val theorems") 78 | print(val_theorems[0]) -------------------------------------------------------------------------------- /scripts/inspect_parquet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import sys 4 | import requests 5 | import pyarrow as pa 6 | import pyarrow.parquet as pq 7 | import pyarrow.compute as pc 8 | 9 | 10 | def inspect_parquet(path: str) -> None: 11 | """Inspect Parquet file structure and compute stats.""" 12 | pf = pq.ParquetFile(path) 13 | metadata = pf.metadata 14 | # Use the Arrow schema for column inspection 15 | schema: pa.Schema = pf.schema_arrow 16 | 17 | print("=== Parquet File Overview ===") 18 | print(f"File path: {path}") 19 | print(f"Num row groups: {pf.num_row_groups}") 20 | print(f"Num columns: {len(schema)}") 21 | print(f"Total rows (samples): {metadata.num_rows}\n") 22 | 23 | print("=== Schema ===") 24 | for field in schema: 25 | print(f" - {field.name}: {field.type}") 26 | print() 27 | 28 | print("=== Row Groups ===") 29 | for rg_idx in range(pf.num_row_groups): 30 | rg_meta = metadata.row_group(rg_idx) 31 | rg_rows = rg_meta.num_rows 32 | rg_total_bytes = sum( 33 | rg_meta.column(col_idx).total_compressed_size 34 | for col_idx in range(rg_meta.num_columns) 35 | ) 36 | print( 37 | f" Row group {rg_idx}: " 38 | f"rows={rg_rows}, compressed_size≈{rg_total_bytes} bytes" 39 | ) 40 | print() 41 | 42 | # Compute total number of characters in `text` column 43 | print("=== Text Column Statistics ===") 44 | if "text" not in schema.names: 45 | print("Column 'text' not found in schema.") 46 | return 47 | 48 | text_idx = schema.get_field_index("text") 49 | text_field = schema.field(text_idx) 50 | 51 | if not (pa.types.is_string(text_field.type) or pa.types.is_large_string(text_field.type)): 52 | print( 53 | f"Warning: 'text' column is type {text_field.type}, not string; " 54 | "attempting to cast when reading." 
55 | ) 56 | cast_to_string = True 57 | else: 58 | cast_to_string = False 59 | 60 | total_chars = 0 61 | total_non_null_rows = 0 62 | batch_size = 65536 63 | 64 | for batch in pf.iter_batches(columns=["text"], batch_size=batch_size): 65 | arr = batch.column(0) 66 | 67 | if cast_to_string: 68 | arr = pc.cast(arr, pa.string()) 69 | 70 | # Length of each string; nulls become null 71 | lengths = pc.utf8_length(arr) 72 | # Sum of lengths (ignores nulls) 73 | batch_char_sum = pc.sum(lengths).as_py() or 0 74 | total_chars += batch_char_sum 75 | 76 | # Count non-null entries 77 | non_null_count = batch.num_rows - arr.null_count 78 | total_non_null_rows += non_null_count 79 | 80 | print(f"Total non-null rows in 'text' column: {total_non_null_rows}") 81 | print(f"Total number of characters in 'text' column: {total_chars}") 82 | print() 83 | 84 | 85 | def main(): 86 | if len(sys.argv) != 2: 87 | print("Usage: python inspect_parquet.py ", file=sys.stderr) 88 | sys.exit(1) 89 | 90 | dest = sys.argv[1] 91 | 92 | if not os.path.exists(dest): 93 | print(f"Error: File {dest} does not exist", file=sys.stderr) 94 | sys.exit(1) 95 | 96 | inspect_parquet(dest) 97 | 98 | 99 | if __name__ == "__main__": 100 | main() 101 | -------------------------------------------------------------------------------- /nanoproof/loss_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | A number of functions that help with evaluating a base model. 3 | """ 4 | import math 5 | import os 6 | 7 | import torch 8 | import torch.distributed as dist 9 | 10 | @torch.no_grad() 11 | def evaluate_bpb(model, batches, steps, token_bytes): 12 | """ 13 | Instead of the naive 'mean loss', this function returns the bits per byte (bpb), 14 | which is a tokenization vocab size-independent metric, meaning you are still comparing 15 | apples:apples if you change the vocab size. The way this works is that instead of just 16 | calculating the average loss as usual, you calculate the sum loss, and independently 17 | also the sum bytes (of all the target tokens), and divide. This normalizes the loss by 18 | the number of bytes that the target tokens represent. 19 | 20 | The added complexity is so that: 21 | 1) All "normal" tokens are normalized by the length of the token in bytes 22 | 2) No special tokens (e.g. <|bos|>) are included in the metric - they are masked out. 23 | 3) No actively masked tokens (using ignore_index of e.g. -1) are included in the metric. 24 | 25 | In addition to evaluate_loss, we need the token_bytes tensor: 26 | It is a 1D tensor of shape (vocab_size,), indicating the number of bytes for 27 | each token id, or 0 if the token is to not be counted (e.g. special tokens). 28 | """ 29 | # record the losses 30 | total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device()) 31 | total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device()) 32 | batch_iter = iter(batches) 33 | for i in range(steps): 34 | try: 35 | x, y = next(batch_iter) 36 | except StopIteration: 37 | print(f"Warning: Rank {dist.get_rank()} reached end of batches at eval step {i}") 38 | break 39 | loss2d = model(x, y, loss_reduction='none') # (B, T) 40 | loss2d = loss2d.view(-1) # flatten 41 | y = y.view(-1) # flatten 42 | if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32 43 | # slightly more complex code path if some target tokens are ignore_index (e.g. 
-1) 44 | # any target token < 0 is to be ignored: do NOT index token_bytes with negatives 45 | valid = y >= 0 46 | y_safe = torch.where(valid, y, torch.zeros_like(y)) 47 | # map valid targets to their byte length; ignored targets contribute 0 bytes 48 | num_bytes2d = torch.where( 49 | valid, 50 | token_bytes[y_safe], 51 | torch.zeros_like(y, dtype=token_bytes.dtype) 52 | ) 53 | total_nats += (loss2d * (num_bytes2d > 0)).sum() 54 | total_bytes += num_bytes2d.sum() 55 | else: 56 | # fast path: no ignored targets, safe to index directly 57 | num_bytes2d = token_bytes[y] 58 | total_nats += (loss2d * (num_bytes2d > 0)).sum() 59 | total_bytes += num_bytes2d.sum() 60 | # sum reduce across all ranks 61 | world_size = dist.get_world_size() if dist.is_initialized() else 1 62 | if world_size > 1: 63 | dist.all_reduce(total_nats, op=dist.ReduceOp.SUM) 64 | dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM) 65 | # move both to cpu, calculate bpb and return 66 | total_nats = total_nats.item() 67 | total_bytes = total_bytes.item() 68 | if total_bytes == 0: 69 | return float('inf') 70 | bpb = total_nats / (math.log(2) * total_bytes) 71 | return bpb -------------------------------------------------------------------------------- /scripts/policy_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from itertools import islice 3 | import sys 4 | import os 5 | from contextlib import nullcontext 6 | 7 | from nanoproof.common import compute_init, autodetect_device_type, print0 8 | from nanoproof.checkpoints import load_model 9 | from nanoproof.data.leantree import iter_data 10 | from nanoproof.data.leantree_dataloader import sft_data_generator 11 | 12 | @torch.inference_mode() 13 | def eval_tactic_accuracy(model, leantree_batches, max_steps=None): 14 | total_samples = 0 15 | total_full_correct = 0 16 | total_first_token_correct = 0 17 | 18 | for x, y, _, _ in leantree_batches if max_steps is None else islice(leantree_batches, max_steps): 19 | logits = model(x) # (B, T, V) 20 | predictions = torch.argmax(logits, dim=-1) # (B, T) 21 | 22 | mask = (y != -1) 23 | correct = predictions == y 24 | 25 | assert mask.any(dim=1).all(), "leantree sample contained no output tokens" 26 | total_samples += logits.shape[0] 27 | 28 | # Full Accuracy: correctness on all non-masked tokens 29 | total_full_correct += (correct | torch.logical_not(mask)).all(dim=1).sum().item() 30 | 31 | # First Token Accuracy: correctness on the first non-masked token 32 | first_token_indices = mask.int().argmax(dim=1) # argmax returns the first True index 33 | batch_indices = torch.arange(logits.shape[0], device=logits.device) 34 | total_first_token_correct += correct[batch_indices, first_token_indices].sum().item() 35 | 36 | return { 37 | "full_acc": total_full_correct / total_samples, 38 | "first_token_acc": total_first_token_correct / total_samples, 39 | } 40 | 41 | def main(): 42 | # Setup 43 | device_type = autodetect_device_type() 44 | ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) 45 | 46 | print0("Loading model...") 47 | model, tokenizer, meta = load_model("sft", device, phase="eval", model_tag="d26") 48 | model.eval() 49 | 50 | print0(f"Model loaded. 
Config: {meta.get('model_config', 'N/A')}") 51 | 52 | # Load Data 53 | print0("Loading dataset...") 54 | split = "val" 55 | dataset = list(iter_data(split=split)) 56 | 57 | if len(dataset) == 0: 58 | print0("Dataset is empty!") 59 | return 60 | 61 | batch_size = 32 62 | 63 | # Calculate steps 64 | # We want to iterate through the dataset exactly once. 65 | # sft_data_generator yields batches of size `batch_size`. 66 | # It repeats the dataset indefinitely. 67 | # We calculate how many batches correspond to one epoch. 68 | # Each item in dataset produces 2 samples. 69 | # DDP handles sharding. 70 | 71 | my_dataset_len = len(range(ddp_rank, len(dataset), ddp_world_size)) 72 | total_samples_local = my_dataset_len * 2 73 | steps = total_samples_local // batch_size 74 | 75 | if steps == 0: 76 | print0("Not enough data for one batch.") 77 | return 78 | 79 | print0(f"Evaluating on {steps} batches (approx {steps * batch_size} samples)...") 80 | 81 | data_gen = sft_data_generator(dataset, batch_size, device=device) 82 | 83 | dtype = "bfloat16" 84 | ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16 85 | autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() 86 | with autocast_ctx: 87 | results = eval_tactic_accuracy(model, data_gen, max_steps=steps) 88 | 89 | print0(f"Results for split '{split}':") 90 | print0(f"Full Accuracy: {results['full_acc']:.4%}") 91 | print0(f"First Token Accuracy: {results['first_token_acc']:.4%}") 92 | 93 | if __name__ == "__main__": 94 | main() 95 | -------------------------------------------------------------------------------- /nanoproof/adamw.py: -------------------------------------------------------------------------------- 1 | """ 2 | Borrowed from modded-nanogpt. By Keller, @vagrawal, et al. 3 | Not a general optimizer! But works for our specific use. 4 | """ 5 | import torch 6 | import torch.distributed as dist 7 | from torch import Tensor 8 | 9 | 10 | class DistAdamW(torch.optim.Optimizer): 11 | """ 12 | Distributed AdamW optimizer. 13 | In the style of ZeRO-2, i.e. 
sharded optimizer states and gradient reduction 14 | """ 15 | def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): 16 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 17 | super().__init__(param_groups, defaults) 18 | 19 | @torch.compile 20 | @torch.no_grad() 21 | def step(self): 22 | rank = dist.get_rank() 23 | world_size = dist.get_world_size() 24 | reduce_scatter_futures: list[torch.Future] = [] 25 | all_reduce_futures: list[torch.Future] = [] 26 | grad_slices = [] 27 | for group in self.param_groups: 28 | params: list[Tensor] = group["params"] 29 | for base_i in range(len(params)): 30 | grad = params[base_i].grad 31 | rank_size = grad.shape[0] // world_size 32 | assert grad.shape[0] % world_size == 0, f"parameter at dimension 0 is not divisible by {world_size=} ({grad.shape=})" 33 | grad_slice = torch.empty_like(grad[:rank_size]) 34 | reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) 35 | grad_slices.append(grad_slice) 36 | 37 | idx = 0 38 | for group in self.param_groups: 39 | beta1, beta2 = group['betas'] 40 | eps = group['eps'] 41 | wd = group['weight_decay'] 42 | params = group['params'] 43 | for base in range(len(params)): 44 | reduce_scatter_futures[idx].wait() 45 | p = params[base] 46 | rank_size = p.shape[0] // world_size 47 | p_slice = p[rank * rank_size:(rank + 1) * rank_size] 48 | lr = group['lr'] * getattr(p, "lr_mul", 1.0) 49 | state = self.state[p] 50 | g_slice = grad_slices[idx] 51 | # State init 52 | if not state: 53 | state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) 54 | state['exp_avg'] = torch.zeros_like(p_slice) 55 | state['exp_avg_sq'] = torch.zeros_like(p_slice) 56 | exp_avg = state['exp_avg'] 57 | exp_avg_sq = state['exp_avg_sq'] 58 | state['step'] += 1 59 | t = state['step'] 60 | # weight decay 61 | if wd != 0: 62 | eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) 63 | p_slice.mul_(1 - eff_weight_decay) 64 | # update running averages 65 | exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) 66 | exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) 67 | # bias corrections 68 | bias1 = 1 - beta1 ** t 69 | bias2 = 1 - beta2 ** t 70 | # compute step 71 | denom = exp_avg_sq.sqrt().add_(eps) 72 | step_size = lr * (torch.sqrt(bias2) / bias1) 73 | update = exp_avg.div(denom).mul_(step_size) 74 | p_slice.add_(other=update, alpha=-1.0) 75 | idx += 1 76 | all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) 77 | torch.futures.collect_all(all_reduce_futures).wait() -------------------------------------------------------------------------------- /tests/test_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import os 4 | from unittest.mock import MagicMock 5 | 6 | # Add repo root to path 7 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 8 | 9 | # Mock nanoproof.common to avoid import errors and control get_dist_info 10 | sys.modules["nanoproof.common"] = MagicMock() 11 | # We also need to mock muon and adamw if we want to avoid importing them or if they have deps 12 | # But let's try to only mock common first, or mock all if they are just utils 13 | # Given model.py imports them, we can mock them to be safe and simple 14 | sys.modules["nanoproof.muon"] = MagicMock() 15 | sys.modules["nanoproof.adamw"] 
= MagicMock() 16 | 17 | from nanoproof.model import Network, NetworkConfig, ValueHead 18 | 19 | def test_network(): 20 | print("Testing Network...") 21 | config = NetworkConfig( 22 | sequence_len=32, 23 | vocab_size=100, 24 | n_layer=2, 25 | n_head=4, 26 | n_kv_head=2, 27 | n_embd=32, 28 | num_value_bins=10 29 | ) 30 | 31 | model = Network(config) 32 | model.init_weights() 33 | 34 | # Test forward pass 35 | bs, seq_len = 2, 16 36 | idx = torch.randint(0, config.vocab_size, (bs, seq_len)) 37 | 38 | output = model(idx) 39 | 40 | print(f"Policy logits shape: {output.policy_logits.shape}") 41 | print(f"Value logits shape: {output.value_logits.shape}") 42 | 43 | assert output.policy_logits.shape == (bs, seq_len, config.vocab_size) 44 | assert output.value_logits.shape == (bs, seq_len, config.num_value_bins) 45 | 46 | # Test ValueHead to_scalar 47 | print("Testing ValueHead to_scalar...") 48 | value_head = model.value_head 49 | # Create logits that strongly favor the last bin (max value) 50 | logits = torch.zeros((bs, seq_len, config.num_value_bins)) 51 | logits[..., -1] = 100.0 52 | 53 | scalar_val = value_head.to_scalar(logits) 54 | print(f"Scalar values shape: {scalar_val.shape}") 55 | print(f"Scalar values (should be close to max_value={config.max_value}): {scalar_val[0,0].item()}") 56 | 57 | assert scalar_val.shape == (bs, seq_len) 58 | assert torch.allclose(scalar_val, torch.tensor(config.max_value), atol=1e-3) 59 | 60 | # Test generate 61 | print("Testing generate...") 62 | tokens = [1, 2, 3] 63 | gen = model.generate(tokens, max_tokens=5) 64 | generated = list(gen) 65 | print(f"Generated tokens: {generated}") 66 | assert len(generated) == 5 67 | 68 | # Test setup_optimizers 69 | print("Testing setup_optimizers...") 70 | # Mock get_dist_info to return (ddp, rank, local_rank, world_size) 71 | # We mocked nanoproof.common, so we need to set the return value on that mock 72 | # Note: model.py imports get_dist_info FROM nanoproof.common. 73 | # Since we mocked sys.modules["nanoproof.common"] BEFORE import, 74 | # the imported get_dist_info in model.py is the mock. 75 | # However, we need to access the SAME mock object to configure it. 
76 | # We can access it via sys.modules["nanoproof.common"].get_dist_info 77 | sys.modules["nanoproof.common"].get_dist_info.return_value = (False, 0, 0, 1) 78 | 79 | optimizers = model.setup_optimizers() 80 | assert len(optimizers) == 2 81 | adamw_opt, muon_opt = optimizers 82 | 83 | # Check AdamW params 84 | # AdamW has 3 groups: lm_head, value_head, embedding 85 | assert len(adamw_opt.param_groups) == 3 86 | 87 | # Verify value_head params are in AdamW 88 | value_head_params = list(model.value_head.parameters()) 89 | found_value_head = False 90 | for group in adamw_opt.param_groups: 91 | if any(p is value_head_params[0] for p in group['params']): 92 | found_value_head = True 93 | break 94 | assert found_value_head, "ValueHead params not found in AdamW optimizer" 95 | 96 | print("All tests passed!") 97 | 98 | if __name__ == "__main__": 99 | test_network() 100 | -------------------------------------------------------------------------------- /scripts/prover_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import torch.distributed as dist 6 | from tqdm import tqdm 7 | from leantree.repl_adapter.server import LeanClient 8 | 9 | from nanoproof.common import compute_init, compute_cleanup, print0, is_ddp, autodetect_device_type, get_dist_info 10 | from nanoproof.data import minif2f 11 | from nanoproof.data import leanworkbook 12 | from nanoproof.search import run_mcts, Config, Game, Node, Player, TacticModel 13 | from nanoproof.checkpoints import load_model 14 | from nanoproof.engine import Engine 15 | 16 | @torch.inference_mode() 17 | def eval_success_rate(tactic_model: TacticModel, theorems=None, use_tqdm=False): 18 | """ 19 | Evaluates the success rate of the model on the MiniF2F benchmark. 20 | Returns a dictionary with 'success_rate', 'solved', and 'total'. 21 | """ 22 | ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() 23 | theorem_indices = list(range(ddp_rank, len(theorems), ddp_world_size)) 24 | theorems = [theorems[i] for i in theorem_indices] 25 | 26 | config = Config() 27 | client = LeanClient(config.server_address, config.server_port) 28 | 29 | solved_count = 0 30 | error_count = 0 31 | 32 | device = tactic_model.network.get_device() 33 | with client.get_process() as env: 34 | env.send_command(""" 35 | open scoped Real 36 | open scoped Nat 37 | open scoped Topology 38 | open scoped Polynomial 39 | """) 40 | iterator = zip(theorem_indices, theorems) 41 | if use_tqdm: 42 | iterator = tqdm(iterator, total=len(theorems), desc=f"Rank {ddp_rank}", position=ddp_rank) 43 | 44 | for i, theorem in iterator: 45 | init_branch = env.proof_from_sorry(theorem) 46 | if not init_branch.is_success(): 47 | error_count += 1 48 | print0(f"Error on theorem: {theorem}\n... 
error: {init_branch.error}") 49 | continue 50 | init_branch = init_branch.value 51 | 52 | game = Game(theorem, num_simulations=config.num_simulations) 53 | game.root = Node( 54 | action=None, 55 | prior=None, 56 | state=[init_branch], 57 | to_play=Player.OR, 58 | reward=None, 59 | ) 60 | 61 | run_mcts(config, game, tactic_model) 62 | 63 | if game.root.is_solved: 64 | solved_count += 1 65 | 66 | local_metrics = torch.tensor([solved_count, error_count, len(theorem_indices)], dtype=torch.long, device=device) 67 | if ddp: 68 | dist.all_reduce(local_metrics, op=dist.ReduceOp.SUM) 69 | global_solved = local_metrics[0].item() 70 | global_error = local_metrics[1].item() 71 | global_total = local_metrics[2].item() 72 | 73 | success_rate = global_solved / global_total if global_total > 0 else 0.0 74 | error_rate = global_error / global_total if global_total > 0 else 0.0 75 | return { 76 | "success_rate": success_rate, 77 | "solved": global_solved, 78 | "total": global_total, 79 | "errors": global_error, 80 | "error_rate": error_rate, 81 | } 82 | 83 | def main(): 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument("--max-theorems", type=int, default=50, help="Max theorems to evaluate") 86 | args = parser.parse_args() 87 | 88 | device_type = autodetect_device_type() 89 | compute_init(device_type) 90 | 91 | tactic_model = TacticModel.create() 92 | minif2f_theorems = minif2f.list_theorems(split="Valid") 93 | minif2f_theorems = minif2f_theorems[:args.max_theorems] 94 | leanworkbook_theorems = leanworkbook.list_theorems(split="val") 95 | leanworkbook_theorems = leanworkbook_theorems[:args.max_theorems] 96 | 97 | def print_results(results, name): 98 | print0("-" * 80) 99 | print0(f"Evaluation results for {name}") 100 | print0(f"Success rate: {results['success_rate']:.4%}") 101 | print0(f"Solved: {results['solved']}/{results['total']}") 102 | print0(f"Errors: {results['errors']}/{results['total']}") 103 | print0(f"Error rate: {results['error_rate']:.4%}") 104 | print0("-" * 80) 105 | 106 | leanworkbook_results = eval_success_rate(tactic_model, leanworkbook_theorems, use_tqdm=True) 107 | print_results(leanworkbook_results, "LeanWorkBook") 108 | 109 | minif2f_results = eval_success_rate(tactic_model, minif2f_theorems, use_tqdm=True) 110 | print_results(minif2f_results, "MiniF2F") 111 | 112 | compute_cleanup() 113 | 114 | if __name__ == "__main__": 115 | main() 116 | -------------------------------------------------------------------------------- /scripts/tok_show.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from tqdm import tqdm 4 | 5 | from nanoproof.tokenizer import get_tokenizer, HuggingFaceTokenizer 6 | from nanoproof.data.leangithubraw import iter_texts_batched 7 | 8 | # Random text I got from a random website this morning 9 | news_text = r""" 10 | (Washington, D.C., July 9, 2025)- Yesterday, Mexico’s National Service of Agro-Alimentary Health, Safety, and Quality (SENASICA) reported a new case of New World Screwworm (NWS) in Ixhuatlan de Madero, Veracruz in Mexico, which is approximately 160 miles northward of the current sterile fly dispersal grid, on the eastern side of the country and 370 miles south of the U.S./Mexico border. This new northward detection comes approximately two months after northern detections were reported in Oaxaca and Veracruz, less than 700 miles away from the U.S. border, which triggered the closure of our ports to Mexican cattle, bison, and horses on May 11, 2025. 
11 | 12 | While USDA announced a risk-based phased port re-opening strategy for cattle, bison, and equine from Mexico beginning as early as July 7, 2025, this newly reported NWS case raises significant concern about the previously reported information shared by Mexican officials and severely compromises the outlined port reopening schedule of five ports from July 7-September 15. Therefore, in order to protect American livestock and our nation’s food supply, Secretary Rollins has ordered the closure of livestock trade through southern ports of entry effective immediately. 13 | 14 | “The United States has promised to be vigilant — and after detecting this new NWS case, we are pausing the planned port reopening’s to further quarantine and target this deadly pest in Mexico. We must see additional progress combatting NWS in Veracruz and other nearby Mexican states in order to reopen livestock ports along the Southern border,” said U.S. Secretary of Agriculture Brooke L. Rollins. “Thanks to the aggressive monitoring by USDA staff in the U.S. and in Mexico, we have been able to take quick and decisive action to respond to the spread of this deadly pest.” 15 | """.strip() 16 | 17 | lean_search_text = r""" 18 | x : ℝ 19 | ⊢ x ^ 2 - 2 * x - 24 < 0 ↔ x ∈ Set.Ioo (-4) 6 20 | 21 | exact ⟨fun h ↦ by rw [Set.mem_Ioo]; constructor <;> nlinarith [h], fun h ↦ by rw [Set.mem_Ioo] at h; nlinarith⟩ 22 | 23 | ⊢ ∀ (x : ℝ), 2⁻¹ + cos (2 * (2 * x)) / 2 = (1 + cos (4 * x)) / 2 24 | 25 | ring 26 | 27 | case h 28 | ι : Type u_4 29 | inst✝ : Fintype ι 30 | f : ℝ → ι → ℝ 31 | s : Set ℝ 32 | h : LocallyBoundedVariationOn f s 33 | A : ∀ (i : ι), LipschitzWith 1 fun x => x i 34 | i : ι 35 | ⊢ LocallyBoundedVariationOn (fun x => f x i) s 36 | 37 | exact LipschitzWith.comp_locallyBoundedVariationOn (A i) h 38 | 39 | p q : Prop 40 | ⊢ p ∧ q → p 41 | 42 | intro h 43 | 44 | case mp.inl 45 | p q r : Prop 46 | hp : p 47 | hq : q 48 | ⊢ p ∧ q ∨ p ∧ r 49 | 50 | exact Or.inl ⟨hp, hq⟩ 51 | 52 | α : Type 53 | P : α → Prop 54 | inst✝ : Inhabited α 55 | h : ∀ (x : α), P x 56 | x0 : α := default 57 | hx0 : P x0 58 | ⊢ ∃ x, P x 59 | 60 | exact Exists.intro x0 hx0 61 | 62 | ¬test 63 | """ 64 | 65 | if len(sys.argv) != 2: 66 | print("Usage: python tok_show.py ") 67 | sys.exit(1) 68 | tokenizer_name = sys.argv[1] 69 | if tokenizer_name == "gpt2": 70 | tokenizer = HuggingFaceTokenizer.from_pretrained("gpt2") 71 | elif tokenizer_name == "ours": 72 | tokenizer = get_tokenizer() 73 | else: 74 | raise ValueError(f"Unknown tokenizer: {tokenizer_name}") 75 | 76 | print(f"Vocab size: {tokenizer.get_vocab_size():,}") 77 | 78 | for text in [("news", news_text), ("lean", lean_search_text)]: 79 | name, text = text 80 | encoded = tokenizer.encode(text) 81 | tokens = [tokenizer.id_to_token(id) for id in encoded] 82 | print(f"{name}:") 83 | print(' '.join(tokens)) 84 | print() 85 | 86 | 87 | print("Gathering character frequencies from leangithubraw train...") 88 | char_counts = {} 89 | # iter_texts_batched yields lists of strings (texts) 90 | # We'll iterate through both train and val splits to be thorough 91 | for batch_texts in iter_texts_batched(split="train", url_whitelist=["https://github.com/leanprover-community/mathlib4"]): 92 | for text in batch_texts: 93 | for char in text: 94 | char_counts[char] = char_counts.get(char, 0) + 1 95 | 96 | # Filter to only characters that appear at least 10 times 97 | frequent_chars = {char for char, count in char_counts.items() if count >= 1000} 98 | print(f"Found {len(frequent_chars)} characters that appear at 
least 1000 times.") 99 | 100 | print("Checking for characters without dedicated tokens...") 101 | chars_without_token = [] 102 | for char in tqdm(frequent_chars): 103 | # Encode the single character 104 | ids = tokenizer.encode(char) 105 | 106 | # If it takes more than 1 token to represent the character, 107 | # it doesn't have a single dedicated token. 108 | if len(ids) != 1: 109 | chars_without_token.append(char) 110 | continue 111 | assert tokenizer.decode(ids) == char, f"decode({ids}) = \"{tokenizer.decode(ids)}\" != \"{char}\" (U+{ord(char):04X})" 112 | 113 | chars_without_token.sort() 114 | print(f"\nFound {len(chars_without_token)} characters without a dedicated token:") 115 | 116 | print("[" + ", ".join(f"\"{c}\"" for c in chars_without_token) + "]") -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.sif 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[codz] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # UV 101 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | #uv.lock 105 | 106 | # poetry 107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 
110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 111 | #poetry.lock 112 | #poetry.toml 113 | 114 | # pdm 115 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 116 | # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. 117 | # https://pdm-project.org/en/latest/usage/project/#working-with-version-control 118 | #pdm.lock 119 | #pdm.toml 120 | .pdm-python 121 | .pdm-build/ 122 | 123 | # pixi 124 | # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. 125 | #pixi.lock 126 | # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one 127 | # in the .venv directory. It is recommended not to include this directory in version control. 128 | .pixi 129 | 130 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 131 | __pypackages__/ 132 | 133 | # Celery stuff 134 | celerybeat-schedule 135 | celerybeat.pid 136 | 137 | # SageMath parsed files 138 | *.sage.py 139 | 140 | # Environments 141 | .env 142 | .envrc 143 | .venv 144 | env/ 145 | venv/ 146 | ENV/ 147 | env.bak/ 148 | venv.bak/ 149 | 150 | # Spyder project settings 151 | .spyderproject 152 | .spyproject 153 | 154 | # Rope project settings 155 | .ropeproject 156 | 157 | # mkdocs documentation 158 | /site 159 | 160 | # mypy 161 | .mypy_cache/ 162 | .dmypy.json 163 | dmypy.json 164 | 165 | # Pyre type checker 166 | .pyre/ 167 | 168 | # pytype static type analyzer 169 | .pytype/ 170 | 171 | # Cython debug symbols 172 | cython_debug/ 173 | 174 | # PyCharm 175 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 176 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 177 | # and can be added to the global gitignore or merged into this file. For a more nuclear 178 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 179 | #.idea/ 180 | 181 | # Abstra 182 | # Abstra is an AI-powered process automation framework. 183 | # Ignore directories containing user credentials, local state, and settings. 184 | # Learn more at https://abstra.io/docs 185 | .abstra/ 186 | 187 | # Visual Studio Code 188 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 189 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore 190 | # and can be added to the global gitignore or merged into this file. However, if you prefer, 191 | # you could uncomment the following to ignore the entire vscode folder 192 | # .vscode/ 193 | 194 | # Ruff stuff: 195 | .ruff_cache/ 196 | 197 | # PyPI configuration file 198 | .pypirc 199 | 200 | # Cursor 201 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to 202 | # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data 203 | # refer to https://docs.cursor.com/context/ignore-files 204 | .cursorignore 205 | .cursorindexingignore 206 | 207 | # Marimo 208 | marimo/_static/ 209 | marimo/_lsp/ 210 | __marimo__/ 211 | -------------------------------------------------------------------------------- /nanoproof/checkpoints.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for saving and loading model/optim/state checkpoints. 
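A checkpoint step consists of model_{step:06d}.pt (parameters), meta_{step:06d}.json (config and
metadata) and, optionally, one optim_{step:06d}_rank{rank}.pt per rank, since optimizer state is
sharded across ranks.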
3 | """ 4 | import os 5 | import re 6 | import glob 7 | import json 8 | import logging 9 | import torch 10 | 11 | from nanoproof.model import Transformer, NetworkConfig 12 | from nanoproof.tokenizer import get_tokenizer 13 | from nanoproof.common import get_base_dir, setup_default_logging 14 | 15 | # Set up logging 16 | setup_default_logging() 17 | logger = logging.getLogger(__name__) 18 | def log0(message): 19 | if int(os.environ.get('RANK', 0)) == 0: 20 | logger.info(message) 21 | 22 | def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0): 23 | if rank == 0: 24 | os.makedirs(checkpoint_dir, exist_ok=True) 25 | # Save the model state parameters 26 | model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") 27 | torch.save(model_data, model_path) 28 | logger.info(f"Saved model parameters to: {model_path}") 29 | # Save the metadata dict as json 30 | meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") 31 | with open(meta_path, "w", encoding="utf-8") as f: 32 | json.dump(meta_data, f, indent=2) 33 | logger.info(f"Saved metadata to: {meta_path}") 34 | # Note that optimizer state is sharded across ranks, so each rank must save its own. 35 | if optimizer_data is not None: 36 | optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt") 37 | torch.save(optimizer_data, optimizer_path) 38 | logger.info(f"Saved optimizer state to: {optimizer_path}") 39 | 40 | def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0): 41 | # Load the model state 42 | model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") 43 | model_data = torch.load(model_path, map_location=device) 44 | # Load the optimizer state if requested 45 | optimizer_data = None 46 | if load_optimizer: 47 | optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt") 48 | optimizer_data = torch.load(optimizer_path, map_location=device) 49 | # Load the metadata 50 | meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") 51 | with open(meta_path, "r", encoding="utf-8") as f: 52 | meta_data = json.load(f) 53 | return model_data, optimizer_data, meta_data 54 | 55 | 56 | def build_model(checkpoint_dir, step, device, phase): 57 | """ 58 | A bunch of repetitive code to build a model from a given checkpoint. 59 | Returns: 60 | - base model - uncompiled, not wrapped in DDP 61 | - tokenizer 62 | - meta data saved during base model training 63 | """ 64 | assert phase in ["train", "eval"], f"Invalid phase: {phase}" 65 | model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False) 66 | if device.type in {"cpu", "mps"}: 67 | # Convert bfloat16 tensors to float for CPU inference 68 | model_data = { 69 | k: v.float() if v.dtype == torch.bfloat16 else v 70 | for k, v in model_data.items() 71 | } 72 | # Hack: fix torch compile issue, which prepends all keys with _orig_mod. 73 | model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()} 74 | model_config_kwargs = meta_data["model_config"] 75 | log0(f"Building model with config: {model_config_kwargs}") 76 | model_config = NetworkConfig(**model_config_kwargs) 77 | with torch.device("meta"): 78 | model = Transformer(model_config) 79 | # Load the model state 80 | model.to_empty(device=device) 81 | model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. 
TODO: fix model re-init 82 | 83 | model.load_state_dict(model_data, strict=True, assign=True) 84 | # Put the model in the right training phase / mode 85 | if phase == "eval": 86 | model.eval() 87 | else: 88 | model.train() 89 | # Load the Tokenizer 90 | tokenizer = get_tokenizer() 91 | # Sanity check: compatibility between model and tokenizer 92 | assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"] 93 | return model, tokenizer, meta_data 94 | 95 | 96 | def find_last_step(checkpoint_dir): 97 | # Look into checkpoint_dir and find model_.pt with the highest step 98 | checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt")) 99 | if not checkpoint_files: 100 | raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") 101 | last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files)) 102 | return last_step 103 | 104 | # ----------------------------------------------------------------------------- 105 | # convenience functions that take into account nanoproof's directory structure 106 | 107 | def load_model_from_dir(checkpoints_dir, device, phase, model_tag, step=None): 108 | checkpoint_dir = os.path.join(checkpoints_dir, model_tag) 109 | if step is None: 110 | # guess the step by defaulting to the last step 111 | step = find_last_step(checkpoint_dir) 112 | assert step is not None, f"No checkpoints found in {checkpoint_dir}" 113 | # build the model 114 | log0(f"Loading model from {checkpoint_dir} with step {step}") 115 | model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase) 116 | return model, tokenizer, meta_data 117 | 118 | def load_model(source, *args, **kwargs): 119 | model_dir = { 120 | "base": "base_checkpoints", 121 | "mid": "mid_checkpoints", 122 | "sft": "sft_checkpoints", 123 | "rl": "rl_checkpoints", 124 | }[source] 125 | base_dir = get_base_dir() 126 | checkpoints_dir = os.path.join(base_dir, model_dir) 127 | return load_model_from_dir(checkpoints_dir, *args, **kwargs) -------------------------------------------------------------------------------- /scripts/tok_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a tokenizer using the HuggingFace Tokenizers library. 3 | In the style of GPT-4 tokenizer. 
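Besides the tokenizer files, the script also caches token_bytes.pt (bytes per token id), which is
used for bits-per-byte evaluation.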
4 | """ 5 | import os 6 | import time 7 | import argparse 8 | import torch 9 | from nanoproof.tokenizer import HuggingFaceTokenizer, SPECIAL_TOKENS 10 | from nanoproof.common import get_base_dir 11 | from nanoproof.data.nemotron import parquets_iter_batched 12 | from nanoproof.data.leangithubraw import iter_texts_batched 13 | 14 | # ----------------------------------------------------------------------------- 15 | # Parse command line arguments 16 | 17 | parser = argparse.ArgumentParser(description='Train a BPE tokenizer') 18 | # parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)') 19 | parser.add_argument('--max_chars', type=int, default=1_000_000_000, help='Maximum characters to train on (default: 1B)') 20 | parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)') 21 | # parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)') 22 | parser.add_argument('--vocab_size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)') 23 | args = parser.parse_args() 24 | print(f"max_chars: {args.max_chars:,}") 25 | print(f"doc_cap: {args.doc_cap:,}") 26 | print(f"vocab_size: {args.vocab_size:,}") 27 | 28 | # ----------------------------------------------------------------------------- 29 | # Text iterator 30 | 31 | def text_iterator(): 32 | """ 33 | 1) Flatten the batches into a single iterator 34 | 2) Crop every document to args.doc_cap characters 35 | 3) Break when we've seen args.max_chars characters 36 | """ 37 | nchars = 0 38 | 39 | # Generator for nemotron documents 40 | def nemotron_gen(): 41 | for batch in parquets_iter_batched(split="train"): 42 | for doc in batch: 43 | yield doc 44 | 45 | # Generator for leangithubraw documents (infinite/restarting) 46 | def leangithub_gen(): 47 | while True: 48 | for batch in iter_texts_batched(split="train"): 49 | for doc in batch: 50 | yield doc 51 | 52 | nemotron_iter = nemotron_gen() 53 | leangithub_iter = leangithub_gen() 54 | 55 | while True: 56 | try: 57 | doc_nemotron = next(nemotron_iter) 58 | except StopIteration: 59 | print("WARNING: Nemotron iterator exhausted") 60 | break 61 | if len(doc_nemotron) > args.doc_cap: 62 | doc_nemotron = doc_nemotron[:args.doc_cap] 63 | 64 | doc_lean = next(leangithub_iter) 65 | if len(doc_lean) > args.doc_cap: 66 | doc_lean = doc_lean[:args.doc_cap] 67 | 68 | doc_text = doc_nemotron + "\n" + doc_lean 69 | 70 | nchars += len(doc_text) 71 | yield doc_text 72 | if nchars >= args.max_chars: 73 | return 74 | text_iter = text_iterator() 75 | 76 | # ----------------------------------------------------------------------------- 77 | # Train the tokenizer 78 | t0 = time.time() 79 | # TODO! 
80 | # tokenizer = RustBPETokenizer.train_from_iterator(text_iter, args.vocab_size) 81 | # tokenizer = HuggingFaceTokenizer.train_from_iterator(text_iter, args.vocab_size) 82 | tokenizer = HuggingFaceTokenizer.from_pretrained("gpt2") 83 | tokenizer.tokenizer.add_special_tokens(SPECIAL_TOKENS) 84 | 85 | t1 = time.time() 86 | train_time = t1 - t0 87 | print(f"Training time: {train_time:.2f}s") 88 | 89 | # ----------------------------------------------------------------------------- 90 | # Save the tokenizer to disk 91 | base_dir = get_base_dir() 92 | tokenizer_dir = os.path.join(base_dir, "tokenizer") 93 | tokenizer.save(tokenizer_dir) 94 | 95 | # ----------------------------------------------------------------------------- 96 | # Quick inline sanity check 97 | test_text = """Hello world! This is a test. 98 | Numbers: 123, 4567, 89 99 | Contractions: I'm, you're, it's 100 | Special chars: @#$%^&*() 101 | Unicode: 你好世界 🌍""" 102 | encoded = tokenizer.encode(test_text) 103 | decoded = tokenizer.decode(encoded) 104 | assert decoded == test_text 105 | 106 | # ----------------------------------------------------------------------------- 107 | # One more thing: we wish to cache a mapping from token id to number of bytes of that token 108 | # for efficient evaluation of bits per byte. Unlike the typical mean loss, this 109 | # allows us to report a loss that is invariant to the vocab size of the tokenizer. 110 | # The bits per byte on the validation set is then one of the primary metrics we care about. 111 | vocab_size = tokenizer.get_vocab_size() 112 | special_set = set(tokenizer.get_special_tokens()) 113 | token_strings = [tokenizer.decode([token_id]) for token_id in range(vocab_size)] 114 | token_bytes = [] 115 | for token_id in range(vocab_size): 116 | token_str = token_strings[token_id] # the Python string representation of this token 117 | if token_str in special_set: 118 | token_bytes.append(0) # special characters are not counted 119 | else: 120 | id_bytes = len(token_str.encode("utf-8")) # number of bytes that make up this token 121 | token_bytes.append(id_bytes) 122 | token_bytes = torch.tensor(token_bytes, dtype=torch.int32, device='cpu') 123 | token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt") 124 | with open(token_bytes_path, "wb") as f: 125 | torch.save(token_bytes, f) 126 | print(f"Saved token_bytes to {token_bytes_path}") 127 | 128 | # Log to report 129 | from nanoproof.report import get_report 130 | token_bytes_nonzero = (token_bytes[token_bytes > 0]).to(dtype=torch.float32) 131 | get_report().log(section="Tokenizer training", data=[ 132 | vars(args), # argparse command line arguments 133 | {"train_time": train_time}, 134 | {"num_special_tokens": len(special_set)}, 135 | { 136 | "token_bytes_min": int(token_bytes_nonzero.min().item()), 137 | "token_bytes_max": int(token_bytes_nonzero.max().item()), 138 | "token_bytes_mean": token_bytes_nonzero.mean().item(), 139 | "token_bytes_std": token_bytes_nonzero.std().item(), 140 | } 141 | ]) -------------------------------------------------------------------------------- /nanoproof/data/nemotron_dataloader.py: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/karpathy/nanochat/blob/master/nanochat/dataloader.py 2 | 3 | from collections import deque 4 | 5 | import torch 6 | import pyarrow.parquet as pq 7 | 8 | from nanoproof.common import get_dist_info 9 | from nanoproof.data.nemotron import list_parquet_files 10 | from nanoproof.tokenizer import get_tokenizer 11 | 
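# -----------------------------------------------------------------------------
# Editor's aside (illustrative sketch, not part of this module): the token_bytes
# table written by scripts/tok_train.py above exists so that evaluation can report
# bits per byte instead of mean token loss, a metric that is invariant to the
# tokenizer's vocab size. Assuming per-token cross-entropy losses in nats and the
# corresponding target token ids, the conversion looks roughly like this (the
# project's actual evaluation code may differ):
import math

def bits_per_byte(per_token_loss, targets, token_bytes):
    # per_token_loss: (N,) cross-entropy in nats for each target token
    # targets: (N,) target token ids (no ignore-index entries)
    # token_bytes: (vocab_size,) UTF-8 byte count per token, 0 for special tokens
    nbytes = token_bytes[targets].sum().clamp(min=1)  # bytes covered by the targets
    nats = per_token_loss.sum()                       # total cross-entropy in nats
    return (nats / math.log(2) / nbytes).item()       # nats -> bits, normalized per byte
# -----------------------------------------------------------------------------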
12 | def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None): 13 | """ 14 | Stream pretraining text from parquet files, tokenize, yield training batches. 15 | 16 | This implementation became a bit more complex because we wish to support approximate resume training. 17 | Instead of turning this into a Class, we opt to return the state_dict with every batch, 18 | and then the caller can pass in a state_dict to resume training from a desired point. 19 | Note that this resumption is atm only *approximate* for simplicity. 20 | We won't repeat the same documents but we might skip a few. 21 | The state_dict that is returned can be later passed into this function via `resume_state_dict` to approximately resume. 22 | 23 | Perfect state resumption is possible but would be a lot more bloated, probably not worth it atm. 24 | """ 25 | assert split in ["train", "val"], "split must be 'train' or 'val'" 26 | 27 | # infinite iterator over document batches (list of text strings) 28 | ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() 29 | def document_batches(): 30 | parquet_paths = list_parquet_files() 31 | assert len(parquet_paths) > 0, "No parquet files found." 32 | parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:] 33 | resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0 34 | resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None 35 | pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0) 36 | while True: # iterate infinitely (multi-epoch) 37 | while pq_idx < len(parquet_paths): # iterate over all parquet files 38 | filepath = parquet_paths[pq_idx] 39 | pf = pq.ParquetFile(filepath) 40 | # Start from resume point if resuming on same file, otherwise from DDP rank 41 | # I know this state resumption is a little bit tricky and a little bit hacky... sigh. 42 | if resume_rg_idx is not None: 43 | base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size 44 | base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming 45 | rg_idx = base_idx * ddp_world_size + ddp_rank 46 | resume_rg_idx = None # set to None as we only want to do this a single time 47 | else: 48 | rg_idx = ddp_rank 49 | while rg_idx < pf.num_row_groups: 50 | rg = pf.read_row_group(rg_idx) 51 | batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows 52 | # the tokenizer encode might want to go in even smaller batches, e.g. 128 rows 53 | for i in range(0, len(batch), tokenizer_batch_size): 54 | yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx) 55 | rg_idx += ddp_world_size # advance to the next row group (in DDP) 56 | pq_idx += 1 # advance to the next parquet file 57 | print("WARNING: Nemotron dataset restarted!") 58 | batches = document_batches() 59 | 60 | # Now emit batches of tokens. 61 | needed_tokens = B * T + 1 # +1 is because we also need the target at the last token 62 | # get the tokenizer and the bos token 63 | tokenizer = get_tokenizer() 64 | bos_token = tokenizer.get_bos_token_id() 65 | assert bos_token is not None 66 | # scratch buffer holds the tokens for one iteration 67 | token_buffer = deque() # we stream tokens on the right and pop from the left 68 | while True: 69 | # Accumulate enough tokens for one iteration before yielding. 
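        # Editor's note, a tiny worked example of the packing below (illustrative):
        # with B=2, T=3 we need B*T + 1 = 7 tokens; if the buffer supplies
        # [t0, t1, t2, t3, t4, t5, t6], then
        #   inputs  = [[t0, t1, t2], [t3, t4, t5]]
        #   targets = [[t1, t2, t3], [t4, t5, t6]]
        # i.e. targets are the same token stream shifted left by one position.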
70 | while len(token_buffer) < needed_tokens: 71 | doc_batch, (pq_idx, rg_idx) = next(batches) 72 | # token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads) 73 | token_lists = tokenizer.encode(doc_batch, prepend=bos_token) 74 | for tokens in token_lists: 75 | token_buffer.extend(tokens) 76 | # Move tokens from the deque into the scratch buffer 77 | tokens = [token_buffer.popleft() for _ in range(needed_tokens)] 78 | # CUDA supports memory pinning for asynchronous transfers between CPU and GPU 79 | use_cuda_optimizations = device == "cuda" 80 | scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64 81 | # Create the inputs/targets as 1D tensors 82 | inputs_cpu = scratch[:-1] 83 | targets_cpu = scratch[1:] 84 | # Reshape to 2D and move to GPU async 85 | inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations) 86 | targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations) 87 | state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training 88 | yield inputs, targets, state_dict 89 | 90 | def tokenizing_distributed_data_loader(*args, **kwargs): 91 | # helper function that only emits the inputs/targets and not the state_dict 92 | for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs): 93 | yield inputs, targets 94 | 95 | if __name__ == "__main__": 96 | B = 128 97 | T = 1024 98 | max_batches = 10 99 | split = "train" 100 | tokenizer_threads = 4 101 | tokenizer_batch_size = 128 102 | device = "cuda" 103 | resume_state_dict = None 104 | dataloader = tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads, tokenizer_batch_size, device, resume_state_dict) 105 | for i, (inputs, targets, state_dict) in enumerate(dataloader): 106 | if i >= max_batches: 107 | break 108 | print(f"Batch {i}: {inputs.shape}, {targets.shape}", flush=True) 109 | print("Done.") -------------------------------------------------------------------------------- /nanoproof/experience_collection.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.distributed as dist 4 | 5 | from nanoproof.common import get_dist_info 6 | from leantree.repl_adapter.server import LeanClient 7 | 8 | from nanoproof.search import Node, Player, Game, run_bfs, run_mcts, TacticModel, Action, State, Config 9 | from nanoproof.data.leanworkbook import list_theorems 10 | from nanoproof.common import SimpleTimer 11 | 12 | class TheoremsSampler: 13 | def __init__(self, seed: int | None = 0): 14 | self.theorems = list_theorems(split="train") 15 | self.rng = random.Random(seed) 16 | 17 | def sample_theorem(self) -> str: 18 | # return "theorem lean_workbook_42924 (h : 1 / 2 * 30 * 23 * 6 = 2070) : 1 / 2 * 30 * 23 * 6 = 2070 := by sorry" 19 | return self.rng.choice(self.theorems) 20 | 21 | 22 | class ReplayBuffer: 23 | def __init__(self, config: Config, seed: int = 0): 24 | self.window_size = config.window_size 25 | self.batch_size = config.batch_size 26 | self.sequence_length = config.sequence_length 27 | self.local_buffer = [] 28 | self.buffer = [] 29 | self.rng = random.Random(seed) 30 | 31 | def save_game(self, game: Game) -> int: 32 | transitions = self._extract_transitions(game.root) 33 | print("! 
New transitions !") 34 | for transition in transitions: 35 | print(transition) 36 | 37 | self.local_buffer.extend(transitions) 38 | 39 | print(f"Local buffer size: {len(self.local_buffer)}") 40 | 41 | return len(transitions) 42 | 43 | def synchronize(self): 44 | ddp, _, _, world_size = get_dist_info() 45 | if ddp: 46 | gathered_buffers = [None for _ in range(world_size)] 47 | dist.all_gather_object(gathered_buffers, self.local_buffer) 48 | for buffer in gathered_buffers: 49 | self.buffer.extend(buffer) 50 | else: 51 | self.buffer.extend(self.local_buffer) 52 | 53 | self.local_buffer = [] 54 | if len(self.buffer) > self.window_size: 55 | self.buffer = self.buffer[-self.window_size:] 56 | 57 | def _extract_transitions(self, node: Node) -> list[tuple[str, str, float]]: 58 | """Extracts transitions from a proof.""" 59 | assert node.to_play == Player.OR 60 | if not node.is_solved: 61 | return [] 62 | transitions = [] 63 | while node.to_play == Player.OR and not node.is_terminal: 64 | assert len(node.state) == 1 65 | action = self._select_optimal_action(node) 66 | assert isinstance(action, str) 67 | transitions.append((str(node.state[0].state).strip(), action.strip(), node.value_target)) 68 | node = node.children[action] 69 | if node.to_play == Player.AND: 70 | for _, child in node.children.items(): 71 | transitions.extend(self._extract_transitions(child)) 72 | return transitions 73 | 74 | def _select_optimal_action(self, node: Node) -> Action: 75 | assert node.to_play == Player.OR 76 | actions = [action for action in node.children if node.children[action].is_solved] 77 | assert len(actions) > 0 78 | # select the shortest tactic 79 | return min(actions, key=lambda a: len(a)) 80 | 81 | def sample_transition(self) -> tuple[str, str, float]: 82 | return self.rng.choice(self.buffer) 83 | 84 | 85 | # Each acting job is independent of all others; it takes the latest network 86 | # snapshot, produces a game and makes it available to the learner by writing it 87 | # to a shared replay buffer. 88 | @torch.inference_mode() 89 | def run_actor(total_to_collect: int, config: Config, model: TacticModel, replay_buffer: ReplayBuffer, theorems_sampler: TheoremsSampler) -> SimpleTimer: 90 | collected = 0 91 | ddp, _, _, world_size = get_dist_info() 92 | timer = SimpleTimer() 93 | 94 | while True: 95 | # Check if we have collected enough proofs globally 96 | if ddp: 97 | collected_tensor = torch.tensor([collected], dtype=torch.long, device=model.network.get_device()) 98 | dist.all_reduce(collected_tensor, op=dist.ReduceOp.SUM) 99 | global_collected = collected_tensor.item() 100 | else: 101 | global_collected = collected 102 | if global_collected >= total_to_collect: 103 | break 104 | 105 | game = play_game(config, model, theorems_sampler, timer) 106 | if game is None: 107 | # print("Invalid theorem statement.") 108 | continue 109 | if game.root.is_solved: 110 | collected += replay_buffer.save_game(game) 111 | 112 | return timer 113 | 114 | 115 | # Each game is produced by starting from the initial Lean state, and executing 116 | # BFS/MCTS to find a proof. If one is found, we extract from the search tree the 117 | # state-tactic-value transitions in the proof, which are added to a replay 118 | # buffer for training. 
119 | def play_game(config: Config, model: TacticModel, theorems_sampler: TheoremsSampler, timer: SimpleTimer) -> Game | None: 120 | theorem = theorems_sampler.sample_theorem() 121 | client = LeanClient(config.server_address, config.server_port) 122 | with client.get_process() as env: 123 | env.send_command(""" 124 | open scoped Real 125 | open scoped Nat 126 | open scoped Topology 127 | open scoped Polynomial 128 | """) 129 | init_branch = env.proof_from_sorry(theorem) 130 | if not init_branch.is_success(): 131 | return None 132 | init_branch = init_branch.value 133 | game = Game(theorem, config.num_simulations) 134 | 135 | game.root = Node( 136 | action=None, 137 | prior=None, 138 | state=[init_branch], 139 | to_play=Player.OR, 140 | reward=None, 141 | ) 142 | 143 | # success = run_bfs(game, model) 144 | run_mcts(config, game, model, timer) 145 | if game.root.is_solved: 146 | # TODO: Perform final check to ensure the proof is valid. 147 | # game.root.is_solved = final_check(game) 148 | 149 | # TODO: try to remove each tactic from the proof and check if the proof is still valid (maybe even more iterations of this) 150 | 151 | # TODO: Compute value targets for the proof. 152 | # compute_value_target(game.root) 153 | print(theorem) 154 | pass 155 | 156 | return game 157 | 158 | 159 | def _main(): 160 | config = Config() 161 | model = TacticModel.create() 162 | replay_buffer = ReplayBuffer(config) 163 | theorems_sampler = TheoremsSampler() 164 | timer = run_actor(config, model, replay_buffer, theorems_sampler) 165 | timer.log_times() 166 | 167 | 168 | if __name__ == "__main__": 169 | _main() -------------------------------------------------------------------------------- /nanoproof/data/leantree.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import argparse 4 | from itertools import islice 5 | import time 6 | from pathlib import Path 7 | 8 | import termplotlib as tpl 9 | import numpy as np 10 | import requests 11 | 12 | from tqdm import tqdm 13 | import leantree 14 | 15 | from nanoproof.common import get_base_dir, format_distribution 16 | from nanoproof.tokenizer import get_tokenizer 17 | 18 | base_dir = get_base_dir() 19 | DATA_DIR = os.path.join(base_dir, "data", "leantree") 20 | 21 | HF_URL = "https://huggingface.co/datasets/ufal/leantree/resolve/main/leantree_mathlib.jsonl" 22 | 23 | 24 | def iter_data(split, eval_fraction=0.1, augmentations=None): 25 | assert split in ["train", "val"] 26 | mathlib_file = os.path.join(DATA_DIR, "leantree_mathlib.jsonl") 27 | if not Path(mathlib_file).exists(): 28 | raise Exception("leantree not downloaded, please run this script with `download` argument") 29 | with open(mathlib_file, "r") as f: 30 | lines = f.readlines() 31 | eval_size = int(len(lines) * eval_fraction) 32 | lines = lines[:-eval_size] if split == "train" else lines[-eval_size:] 33 | 34 | for line in lines: 35 | lean_file = leantree.LeanFile.deserialize(json.loads(line)) 36 | for thm in lean_file.theorems: 37 | if isinstance(thm, leantree.StoredError): 38 | continue 39 | for by_block in thm.by_blocks: 40 | if isinstance(by_block.tree, leantree.StoredError): 41 | continue 42 | for node in by_block.tree.get_nodes(): 43 | if augmentations: 44 | for aug in augmentations: 45 | node = aug.run(node) 46 | yield str(node.state), str(node.tactic.tactic), node.proof_depth 47 | 48 | 49 | def download_dataset(): 50 | """Download the leantree dataset from HuggingFace.""" 51 | jsonl_path = os.path.join(DATA_DIR, 
"leantree_mathlib.jsonl") 52 | 53 | # skip if already downloaded 54 | if os.path.exists(jsonl_path): 55 | print(f"Dataset already downloaded at {jsonl_path}") 56 | return 57 | 58 | try: 59 | print(f"Downloading leantree dataset from HuggingFace...") 60 | response = requests.get(HF_URL, stream=True, timeout=60) 61 | response.raise_for_status() 62 | 63 | temp_path = jsonl_path + ".tmp" 64 | total_size = int(response.headers.get("content-length", 0)) 65 | with open(temp_path, "wb") as f: 66 | with tqdm(total=total_size, unit="B", unit_scale=True, unit_divisor=1024, desc="Downloading leantree_mathlib.jsonl") as pbar: 67 | for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks 68 | if chunk: 69 | f.write(chunk) 70 | pbar.update(len(chunk)) 71 | 72 | os.rename(temp_path, jsonl_path) 73 | print(f"Successfully downloaded {jsonl_path}") 74 | except (requests.RequestException, IOError): 75 | # Clean up any partial files 76 | for path in [jsonl_path + ".tmp", jsonl_path]: 77 | if os.path.exists(path): 78 | print(f"Cleaning up {path}") 79 | os.remove(path) 80 | raise 81 | 82 | def print_stats(): 83 | tokenizer = get_tokenizer() 84 | bos_token = tokenizer.get_bos_token_id() 85 | assert bos_token is not None 86 | eos_token = tokenizer.get_eos_token_id() 87 | assert eos_token is not None 88 | for split in ["train", "val"]: 89 | print(f"Loading {split=}...") 90 | dataset = list(iter_data(split=split)) 91 | print(f"Calculating {split=}...") 92 | lens = {"state": [], "tactic": []} 93 | depths = [] 94 | start_time = time.time() 95 | for state, tactic, proof_depth in tqdm(dataset): 96 | state = tokenizer.encode(state + "\n<|tactic|> ", prepend=bos_token) 97 | tactic = tokenizer.encode(tactic, append=eos_token) 98 | lens["state"].append(len(state)) 99 | lens["tactic"].append(len(tactic)) 100 | depths.append(proof_depth) 101 | end_time = time.time() 102 | print(f"time: {end_time - start_time:.2f}s") 103 | print(f"total: {len(lens['state'])}") 104 | for prop, max_len in [("state", 448), ("tactic", 64)]: 105 | print(f"{prop} lengths:") 106 | print(f" min: {np.min(lens[prop])}") 107 | print(f" max: {np.max(lens[prop])}") 108 | print(f" mean: {np.mean(lens[prop]):.2f}") 109 | print(f" median: {np.median(lens[prop])}") 110 | print(f" std: {np.std(lens[prop]):.2f}") 111 | print(f" p90: {np.percentile(lens[prop], 90):.2f}") 112 | print(f" p95: {np.percentile(lens[prop], 95):.2f}") 113 | print(f" p99: {np.percentile(lens[prop], 99):.2f}") 114 | at_most_max = np.sum(np.array(lens[prop]) <= max_len) 115 | print(f" <= {max_len}: {at_most_max / len(lens[prop]):%} ({at_most_max}/{len(lens[prop])})") 116 | print(f"depths:") 117 | print(f" min: {np.min(depths)}") 118 | print(f" max: {np.max(depths)}") 119 | print(f" mean: {np.mean(depths):.2f}") 120 | print(f" median: {np.median(depths)}") 121 | print(f" p90: {np.percentile(depths, 90):.2f}") 122 | print(f" p95: {np.percentile(depths, 95):.2f}") 123 | print(f" p99: {np.percentile(depths, 99):.2f}") 124 | at_most_32 = np.sum(np.array(depths) <= 32) 125 | print(f" <= 32: {at_most_32 / len(depths):%} ({at_most_32}/{len(depths)})") 126 | print() 127 | 128 | fig = tpl.figure() 129 | min_depth = int(np.min(depths)) 130 | max_depth = int(np.max(depths)) 131 | bin_edges = np.arange(min_depth, max_depth + 2) # +2 to include max_depth in a bin 132 | counts, bin_edges = np.histogram(depths, bins=bin_edges) 133 | fig.hist(counts, bin_edges=bin_edges, force_ascii=False, orientation="horizontal") 134 | fig.show() 135 | print() 136 | 137 | def main(): 138 | parser = 
argparse.ArgumentParser(description="Download LeanTree dataset from HuggingFace.") 139 | subparsers = parser.add_subparsers(dest="action") 140 | 141 | download_parser = subparsers.add_parser("download") 142 | 143 | show_parser = subparsers.add_parser("show") 144 | show_parser.add_argument("--split", choices=["train", "val"], default="train") 145 | 146 | stats_parser = subparsers.add_parser("stats") 147 | 148 | args = parser.parse_args() 149 | 150 | if args.action == "download": 151 | os.makedirs(DATA_DIR, exist_ok=True) 152 | download_dataset() 153 | elif args.action == "show": 154 | for state, tactic in islice(iter_data(split=args.split), 10): 155 | print(state) 156 | print("\n->\n") 157 | print(tactic) 158 | print("\n-----------------\n") 159 | elif args.action == "stats": 160 | print_stats() 161 | else: 162 | raise f"Unknown action {args.action}" 163 | 164 | if __name__ == "__main__": 165 | main() 166 | -------------------------------------------------------------------------------- /nanoproof/data/leangithub_urls.txt: -------------------------------------------------------------------------------- 1 | # Source repos from internlm/Lean-Github 2 | # ... except for https://github.com/mortarsanjaya/IMOSLLean4.git, which contains IMO problems 3 | # ... plus mathlib 4 | 5 | https://github.com/leanprover-community/mathlib4.git 6 | 7 | https://github.com/kevinsullivan/cs2120f23.git 8 | https://github.com/pthomas505/FOL.git 9 | https://github.com/pandaman64/QuickSortInLean.git 10 | https://github.com/jvlmdr/from_the_book.git 11 | https://github.com/aronerben/lean4-playground.git 12 | https://github.com/Human-Oriented-ATP/lean-tactics.git 13 | https://github.com/BrownCS1951x/fpv2023.git 14 | https://github.com/RemyCiterin/LeanCoInd.git 15 | https://github.com/KisaraBlue/ec-tate-lean.git 16 | https://github.com/lurk-lab/yatima.git 17 | https://github.com/gleachkr/Completeness-For-Fine-Semantics.git 18 | https://github.com/vbeffara/RMT4.git 19 | https://github.com/mariainesdff/skew_polynomials.git 20 | https://github.com/robertylewis/CS22-Lean-Dev.git 21 | https://github.com/calcu16/lean_complexity.git 22 | https://github.com/UofSC-Spring-2023-SCHC-411-H01/notes.git 23 | https://github.com/kmill/msri2023_graphs.git 24 | https://github.com/YaelDillies/LeanCamCombi.git 25 | https://github.com/siddhartha-gadgil/proofs-and-programs-2023.git 26 | https://github.com/jrr6/lean-tables.git 27 | https://github.com/b-mehta/ExponentialRamsey.git 28 | https://github.com/alexjbest/ant-lorentz.git 29 | https://github.com/avigad/mathematics_in_lean_source.git 30 | https://github.com/opencompl/egg-tactic-code.git 31 | https://github.com/opencompl/lean-mlir.git 32 | https://github.com/JLimperg/regensburg-itp-school-2023.git 33 | https://github.com/jt496/Turan_4.git 34 | https://github.com/goens/lean-power-calc.git 35 | https://github.com/Vilin97/random_graphs.git 36 | https://github.com/Ruben-VandeVelde/algebra-i.git 37 | https://github.com/cruhland/lean4-analysis.git 38 | https://github.com/codyroux/traat-lean.git 39 | https://github.com/zeramorphic/set-theory.git 40 | https://github.com/iehality/lean4-logic.git 41 | https://github.com/rikitoro/fpil.git 42 | https://github.com/optimisticexquisite/lean-rsa-project.git 43 | https://github.com/bridgekat/filter-game.git 44 | https://github.com/brown-cs22/CS22-Lean-2023.git 45 | https://github.com/Karthik-Dulam/reals-quasi-morphisms.git 46 | https://github.com/ramyshahin/variability.git 47 | https://github.com/plumsirawit/basic-algebra.git 48 | 
https://github.com/JamesGallicchio/LeanColls.git 49 | https://github.com/AlexKontorovich/PrimeNumberTheoremAnd.git 50 | https://github.com/jt496/Mordell4.git 51 | https://github.com/jappaaa/Bachelor-project.git 52 | https://github.com/cruhland/lean4-axiomatic.git 53 | https://github.com/pandaman64/lean-regex.git 54 | https://github.com/lecopivo/Probly.git 55 | https://github.com/fpvandoorn/HausdorffSchoolLean.git 56 | https://github.com/knowsys/Formale-Systeme-in-LEAN.git 57 | https://github.com/mo271/formal_book.git 58 | https://github.com/robertylewis/lean_dummy_project.git 59 | https://github.com/Junology/fifteen.git 60 | https://github.com/JOSHCLUNE/DuperDemo.git 61 | https://github.com/fpvandoorn/LogicColloquiumTutorial.git 62 | https://github.com/leanprover-community/flt-regular.git 63 | https://github.com/lecopivo/SciLean.git 64 | https://github.com/dwrensha/compfiles.git 65 | https://github.com/benjaminfjones/theorem-proving-lean4.git 66 | https://github.com/NiclausCarlson/Diploma.git 67 | https://github.com/digama0/mm-lean4.git 68 | https://github.com/ChrisHughes24/lean4bits.git 69 | https://github.com/marcusrossel/stlc.git 70 | https://github.com/Odomontois/advent2022-lean.git 71 | https://github.com/lurk-lab/YatimaStdLib.lean.git 72 | https://github.com/ufmg-smite/lean-smt.git 73 | https://github.com/eric-wieser/lean-matrix-cookbook.git 74 | https://github.com/breakerzirconia/s-infinity.git 75 | https://github.com/lf-lang/reactor-lean.git 76 | https://github.com/robertylewis/BrownCs22.git 77 | https://github.com/Junology/Moncalc.git 78 | https://github.com/m4lvin/lean4-pdl.git 79 | https://github.com/Julek/bucharest-lean-ac.git 80 | https://github.com/siddhartha-gadgil/GroupComp.git 81 | https://github.com/SnO2WMaN/lean4-concatenation-theory.git 82 | https://github.com/lf-lang/reactor-model.git 83 | https://github.com/Junology/algdata.git 84 | https://github.com/riccardobrasca/kaplanski4.git 85 | https://github.com/djvelleman/STG4.git 86 | https://github.com/alexkeizer/list-powers-problem.git 87 | https://github.com/mirkootter/lean-mt.git 88 | https://github.com/marcusrossel/weighted-tree-automata.git 89 | https://github.com/PatrickMassot/GlimpseOfLean.git 90 | https://github.com/riccardobrasca/LeanTeaching.git 91 | https://github.com/Ruben-VandeVelde/flt4.git 92 | https://github.com/AntoineChambert-Loir/Sion4.git 93 | https://github.com/siddhartha-gadgil/Polylean.git 94 | https://github.com/hargoniX/cpdt-lean.git 95 | https://github.com/robertylewis/leanclass.git 96 | https://github.com/jvlmdr/forml.git 97 | https://github.com/fgdorais/GMLInit.git 98 | https://github.com/dupuisf/aoc2023.git 99 | https://github.com/avigad/lamr.git 100 | https://github.com/iehality/Arithmetization.git 101 | https://github.com/leanprover-community/lean4-metaprogramming-book.git 102 | https://github.com/leanprover/leansat.git 103 | https://github.com/JamesGallicchio/LeanSAT.git 104 | https://github.com/bustercopley/lean-float.git 105 | https://github.com/MichaelStollBayreuth/Weights.git 106 | https://github.com/apnelson1/Matroid.git 107 | https://github.com/BoltonBailey/formal-snarks-project.git 108 | https://github.com/vbeffara/Lean4_misc.git 109 | https://github.com/alexkeizer/QpfTypes.git 110 | https://github.com/prakol16/hfzfa.git 111 | https://github.com/djvelleman/HTPILeanPackage.git 112 | https://github.com/siddhartha-gadgil/Saturn.git 113 | https://github.com/AntoineChambert-Loir/DividedPowers4.git 114 | https://github.com/fpvandoorn/sard.git 115 | https://github.com/dwrensha/lean4-maze.git 
116 | https://github.com/fpvandoorn/LeanInRome.git 117 | https://github.com/fgdorais/extra4.git 118 | https://github.com/digama0/lean4lean.git 119 | https://github.com/jeremysalwen/advent_of_lean_2022.git 120 | https://github.com/kbuzzard/IISc-experiments.git 121 | https://github.com/Nazgand/NazgandLean4.git 122 | https://github.com/fpvandoorn/carleson.git 123 | https://github.com/opencompl/HOLFloat-Lean.git 124 | https://github.com/adamtopaz/CopenhagenMasterclass2023.git 125 | https://github.com/negiizhao/Algorithm.git 126 | https://github.com/SnO2WMaN/lean4-propositional-logic.git 127 | https://github.com/ImperialCollegeLondon/formalising-mathematics-2024.git 128 | https://github.com/zeramorphic/lambda_calculi.git 129 | https://github.com/PatrickMassot/GaloisConnectionGame.git 130 | https://github.com/katydid/proofs.git 131 | https://github.com/fgdorais/lean4-ascii.git 132 | https://github.com/ianjauslin-rutgers/pythagoras4.git 133 | https://github.com/linesthatinterlace/controlbits.git 134 | https://github.com/zeramorphic/category-theory.git 135 | https://github.com/JOSHCLUNE/VerusLeanStd.git 136 | https://github.com/brown-cs22/CS22-Lean-2024.git 137 | https://github.com/AntoineChambert-Loir/Jordan4.git 138 | https://github.com/fpvandoorn/LeanCourse23.git 139 | https://github.com/Junology/dijkstra.git 140 | https://github.com/alexkeizer/LeanMeetup.git 141 | https://github.com/girving/ray.git 142 | https://github.com/kovach/etch.git 143 | https://github.com/JamesGallicchio/lean_rms.git 144 | https://github.com/rikitoro/FM2023_exercise.git 145 | https://github.com/mguaypaq/lean-bruhat.git 146 | https://github.com/yuma-mizuno/lean-math-workshop.git 147 | https://github.com/loewenheim/projective-plane.git 148 | https://github.com/benjaminfjones/reckonlean.git 149 | https://github.com/Jun2M/Main-theorem-of-polytopes.git 150 | https://github.com/verified-optimization/CvxLean.git 151 | https://github.com/MichaelStollBayreuth/EulerProducts.git 152 | https://github.com/JamesGallicchio/brunched-invitations.git -------------------------------------------------------------------------------- /nanoproof/data/leantree_dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from itertools import islice 3 | 4 | from nanoproof.common import get_dist_info 5 | from nanoproof.tokenizer import get_tokenizer, value_to_token_ids 6 | from nanoproof.data.leantree import iter_data 7 | from nanoproof.model import NetworkConfig 8 | 9 | STATE_MAX_LEN = 640 10 | TACTIC_MAX_LEN = 128 11 | 12 | def sft_data_generator(dataset, batch_size, device="cuda"): 13 | assert batch_size % 2 == 0 # need this because we generate both tactic and value samples for each datapoint 14 | tokenizer = get_tokenizer() 15 | bos_token = tokenizer.get_bos_token_id() 16 | eos_token = tokenizer.get_eos_token_id() 17 | assert bos_token is not None 18 | assert eos_token is not None 19 | pad_token_id = tokenizer.encode_special("<|pad|>") 20 | ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() 21 | 22 | def collate_and_yield(batch): 23 | nrows = len(batch) 24 | ncols = max(len(ids) for ids, _ in batch) - 1 # seq of n creates inputs/targets of n-1 25 | inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long) 26 | targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index 27 | for i, (ids, mask) in enumerate(batch): 28 | n = len(ids) 29 | ids_tensor = torch.tensor(ids, dtype=torch.long) 30 | inputs[i, :n - 1] = ids_tensor[:-1] 31 | # recall -1 is 
the ignore index, so mask out targets where mask is 0 32 | row_targets = ids_tensor[1:] 33 | # mask[1:] omits the mask for the BOS token, which is never a target atm so it's ok 34 | mask_tensor = torch.tensor(mask[1:], dtype=torch.long) 35 | row_targets[mask_tensor == 0] = -1 # mask out targets where mask is 0 36 | targets[i, :n - 1] = row_targets 37 | inputs = inputs.to(device) # move to device 38 | targets = targets.to(device) 39 | return inputs, targets 40 | 41 | # iterates over the dataset in epochs, tokenizes 42 | batch = [] 43 | last_step = False 44 | while True: 45 | for i in range(ddp_rank, len(dataset), ddp_world_size): 46 | state, tactic, proof_depth = dataset[i] 47 | state, tactic = state.strip(), tactic.strip() 48 | assert len(state) != 0 and len(tactic) != 0 and proof_depth >= 1 49 | 50 | state_toks = tokenizer.encode(state + "\n", prepend=bos_token) 51 | 52 | tactic_delim_tok = tokenizer.encode_special("<|tactic|>") 53 | tactic_toks = tokenizer.encode(tactic, append=eos_token) 54 | 55 | value_delim_tok = tokenizer.encode_special("<|value|>") 56 | value_toks = value_to_token_ids(tokenizer, proof_depth) + [eos_token] 57 | 58 | # these are <0.1% of mathlib 59 | if len(tactic_toks) > TACTIC_MAX_LEN: 60 | continue 61 | if len(state_toks) + 1 + len(tactic_toks) > 768: 62 | continue 63 | assert len(state_toks) + 1 + len(value_toks) <= 768 64 | 65 | batch.append(( 66 | state_toks + [tactic_delim_tok] + tactic_toks, 67 | [0] * (len(state_toks) + 1) + [1] * len(tactic_toks) 68 | )) 69 | # TODO: uncomment this once we are using <|value|> 70 | # TODO: we also need to change the dataset size calculation in SFT.py accordingly! 71 | # TODO: we also need to change the tactic_eval script to distinguish between tactic and value samples 72 | # batch.append(( 73 | # state_toks + [value_delim_tok] + value_toks, 74 | # [0] * (len(state_toks) + 1) + [1] * len(value_toks) 75 | # )) 76 | 77 | approx_progress = i / len(dataset) 78 | last_step = last_step or (i + ddp_world_size >= len(dataset)) 79 | if len(batch) == batch_size: 80 | yield *collate_and_yield(batch), approx_progress, last_step 81 | batch = [] 82 | print(f"Warning: Rank {ddp_rank} will loop again on leantree ({len(dataset)=}).", flush=True) 83 | 84 | def rl_data_generator(generator, batch_size, device="cuda"): 85 | assert batch_size % 2 == 0 # need this because we generate both tactic and value samples for each datapoint 86 | tokenizer = get_tokenizer() 87 | bos_token = tokenizer.get_bos_token_id() 88 | eos_token = tokenizer.get_eos_token_id() 89 | assert bos_token is not None 90 | assert eos_token is not None 91 | pad_token_id = tokenizer.encode_special("<|pad|>") 92 | 93 | def collate_and_yield(batch): 94 | nrows = len(batch) 95 | ncols = max(len(ids) for ids, _ in batch) - 1 # seq of n creates inputs/targets of n-1 96 | inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long) 97 | targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index 98 | for i, (ids, mask) in enumerate(batch): 99 | n = len(ids) 100 | ids_tensor = torch.tensor(ids, dtype=torch.long) 101 | inputs[i, :n - 1] = ids_tensor[:-1] 102 | # recall -1 is the ignore index, so mask out targets where mask is 0 103 | row_targets = ids_tensor[1:] 104 | # mask[1:] omits the mask for the BOS token, which is never a target atm so it's ok 105 | mask_tensor = torch.tensor(mask[1:], dtype=torch.long) 106 | row_targets[mask_tensor == 0] = -1 # mask out targets where mask is 0 107 | targets[i, :n - 1] = row_targets 108 | inputs = inputs.to(device) # move 
to device 109 | targets = targets.to(device) 110 | return inputs, targets 111 | 112 | # iterates over the dataset in epochs, tokenizes 113 | batch = [] 114 | last_step = False 115 | for state, tactic, proof_depth in generator: 116 | state, tactic = state.strip(), tactic.strip() 117 | assert len(state) != 0 and len(tactic) != 0 118 | # assert proof_depth >= 1 119 | 120 | state_toks = tokenizer.encode(state + "\n", prepend=bos_token) 121 | 122 | tactic_delim_tok = tokenizer.encode_special("<|tactic|>") 123 | tactic_toks = tokenizer.encode(tactic, append=eos_token) 124 | 125 | # value_delim_tok = tokenizer.encode_special("<|value|>") 126 | # value_toks = value_to_token_ids(tokenizer, proof_depth) + [eos_token] 127 | 128 | # these are <0.1% of mathlib 129 | if len(tactic_toks) > TACTIC_MAX_LEN: 130 | continue 131 | if len(state_toks) + 1 + len(tactic_toks) > 768: 132 | continue 133 | # assert len(state_toks) + 1 + len(value_toks) <= 768 134 | 135 | batch.append(( 136 | state_toks + [tactic_delim_tok] + tactic_toks, 137 | [0] * (len(state_toks) + 1) + [1] * len(tactic_toks) 138 | )) 139 | # TODO: uncomment this once we are using <|value|> 140 | # TODO: we also need to change the dataset size calculation in SFT.py accordingly! 141 | # TODO: we also need to change the tactic_eval script to distinguish between tactic and value samples 142 | # batch.append(( 143 | # state_toks + [value_delim_tok] + value_toks, 144 | # [0] * (len(state_toks) + 1) + [1] * len(value_toks) 145 | # )) 146 | 147 | if len(batch) == batch_size: 148 | yield collate_and_yield(batch) 149 | batch = [] 150 | 151 | if __name__ == "__main__": 152 | print("Loading dataset...") 153 | dataset = list(iter_data(split="train")) 154 | tokenizer = get_tokenizer() 155 | for inputs, targets in islice(sft_data_generator(dataset, batch_size=4), 10): 156 | for i in range(inputs.size(0)): 157 | print(f"Input {i}:") 158 | print(inputs[i]) 159 | print(tokenizer.decode(inputs[i].tolist())) 160 | print() 161 | 162 | print(f"Target {i}:") 163 | print(targets[i]) 164 | # replace -1 with a different token so that it can be decoded 165 | targets[i][targets[i] == -1] = tokenizer.encode("X")[0] 166 | print(tokenizer.decode(targets[i].tolist())) 167 | print("--") 168 | 169 | print("-" * 100) -------------------------------------------------------------------------------- /nanoproof/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/karpathy/nanochat/blob/master/nanochat/tokenizer.py 2 | """ 3 | BPE Tokenizer in the style of GPT-4. 
4 | 5 | Two implementations are available: 6 | 1) HuggingFace Tokenizer that can do both training and inference but is really confusing 7 | 2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference 8 | """ 9 | 10 | import os 11 | 12 | from nanoproof.model import NetworkConfig 13 | 14 | _MIN_VALUE = 1 15 | _MAX_VALUE = 64 # max value corresponds to "infinity" 16 | 17 | SPECIAL_TOKENS = [ 18 | # every document begins with the Beginning of Sequence (BOS) token that delimits documents 19 | "<|pad|>", 20 | "<|tactic|>", 21 | "<|value|>", 22 | *[f"<|bin_{i:02d}|>" for i in range(_MIN_VALUE, _MAX_VALUE + 1)], 23 | # these occur at least 1000 times in Mathlib but do not have dedicated tokens in GPT-2 24 | "ˢ", "ˣ", "Γ", "Δ", "Λ", "Π", "Σ", "Φ", "Ω", "δ", "ζ", "η", "θ", "φ", "χ", "ψ", "ϕ", "ᵈ", "ᵐ", "ᵒ", "ᵖ", "ᵢ", "ᵣ", "ᵥ", "ᶜ", "ᶠ", "‖", "‹", "›", "⁅", "⁆", "⁰", "⁻", "₀", "₁", "₂", "₃", "₄", "₊", "ₐ", "ₑ", "ₗ", "ₘ", "ₙ", "ₚ", "ₛ", "ₜ", "ℂ", "ℕ", "ℚ", "ℝ", "ℤ", "ℱ", "←", "↔", "↦", "↪", "⇑", "∀", "∂", "∃", "∅", "∈", "∉", "∏", "∑", "∘", "∞", "∣", "∧", "∨", "∩", "∪", "∫", "≃", "≅", "≠", "≡", "≤", "≥", "≪", "≫", "⊆", "⊓", "⊔", "⊕", "⊗", "⊢", "⊤", "⊥", "⋂", "⋃", "⋆", "⋙", "▷", "▸", "◁", "⟦", "⟧", "⟨", "⟩", "⟪", "⟫", "⟶", "⥤", "⦃", "⦄", "⧸", "⨅", "⨆", "𝒜", "𝒰", "𝓘", "𝓝", "𝔖", "𝕜", "𝟙", 25 | # these are left out because they are already in GPT2 tokenizer (although weirdly not reported in tok_show): "¬", "¹" 26 | ] 27 | 28 | # NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3} 29 | # I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes. 30 | # I haven't validated that this is actually a good idea, TODO. 31 | SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" 32 | 33 | # ----------------------------------------------------------------------------- 34 | # Generic GPT-4-style tokenizer based on HuggingFace Tokenizer 35 | from tokenizers import Tokenizer as HFTokenizer 36 | from tokenizers import pre_tokenizers, decoders, Regex 37 | from tokenizers.models import BPE 38 | from tokenizers.trainers import BpeTrainer 39 | 40 | class HuggingFaceTokenizer: 41 | """Light wrapper around HuggingFace Tokenizer for some utilities""" 42 | 43 | def __init__(self, tokenizer): 44 | self.tokenizer = tokenizer 45 | 46 | @classmethod 47 | def from_pretrained(cls, hf_path): 48 | # init from a HuggingFace pretrained tokenizer (e.g. "gpt2") 49 | tokenizer = HFTokenizer.from_pretrained(hf_path) 50 | return cls(tokenizer) 51 | 52 | @classmethod 53 | def from_directory(cls, tokenizer_dir): 54 | # init from a local directory on disk (e.g. "out/tokenizer") 55 | tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json") 56 | tokenizer = HFTokenizer.from_file(tokenizer_path) 57 | return cls(tokenizer) 58 | 59 | @classmethod 60 | def train_from_iterator(cls, text_iterator, vocab_size): 61 | # train from an iterator of text 62 | # Configure the HuggingFace Tokenizer 63 | tokenizer = HFTokenizer(BPE( 64 | byte_fallback=True, # needed! 65 | unk_token=None, 66 | fuse_unk=False, 67 | )) 68 | # Normalizer: None 69 | tokenizer.normalizer = None 70 | # Pre-tokenizer: GPT-4 style 71 | # the regex pattern used by GPT-4 to split text into groups before BPE 72 | # NOTE: The pattern was changed from \p{N}{1,3} to \p{N}{1,2} because I suspect it is harmful to 73 | # very small models and smaller vocab sizes, because it is a little bit wasteful in the token space. 
74 | # (but I haven't validated this! TODO) 75 | gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!! 76 | tokenizer.pre_tokenizer = pre_tokenizers.Sequence([ 77 | pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False), 78 | pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False) 79 | ]) 80 | # Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer) 81 | tokenizer.decoder = decoders.ByteLevel() 82 | # Post-processor: None 83 | tokenizer.post_processor = None 84 | # Trainer: BPE 85 | trainer = BpeTrainer( 86 | vocab_size=vocab_size, 87 | show_progress=True, 88 | min_frequency=0, # no minimum frequency 89 | initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), 90 | special_tokens=SPECIAL_TOKENS, 91 | ) 92 | # Kick off the training 93 | tokenizer.train_from_iterator(text_iterator, trainer) 94 | return cls(tokenizer) 95 | 96 | def get_vocab_size(self): 97 | return self.tokenizer.get_vocab_size() 98 | 99 | def get_special_tokens(self): 100 | special_tokens_map = self.tokenizer.get_added_tokens_decoder() 101 | special_tokens = [w.content for w in special_tokens_map.values()] 102 | return special_tokens 103 | 104 | def id_to_token(self, id): 105 | return self.tokenizer.id_to_token(id) 106 | 107 | def _encode_one(self, text, prepend=None, append=None): 108 | # encode a single string 109 | # prepend/append can be either a string of a special token or a token id directly. 110 | assert isinstance(text, str) 111 | ids = [] 112 | if prepend is not None: 113 | prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend) 114 | ids.append(prepend_id) 115 | ids.extend(self.tokenizer.encode(text, add_special_tokens=False).ids) 116 | if append is not None: 117 | append_id = append if isinstance(append, int) else self.encode_special(append) 118 | ids.append(append_id) 119 | return ids 120 | 121 | def encode_special(self, text): 122 | # encode a single special token via exact match 123 | return self.tokenizer.token_to_id(text) 124 | 125 | def get_bos_token_id(self): 126 | bos = self.encode_special("<|endoftext|>") 127 | return bos 128 | 129 | def get_eos_token_id(self): 130 | eos = self.encode_special("<|endoftext|>") 131 | return eos 132 | 133 | def encode(self, text, *args, **kwargs): 134 | if isinstance(text, str): 135 | return self._encode_one(text, *args, **kwargs) 136 | elif isinstance(text, list): 137 | return [self._encode_one(t, *args, **kwargs) for t in text] 138 | else: 139 | raise ValueError(f"Invalid input type: {type(text)}") 140 | 141 | def __call__(self, *args, **kwargs): 142 | return self.encode(*args, **kwargs) 143 | 144 | def decode(self, ids): 145 | return self.tokenizer.decode(ids, skip_special_tokens=False) 146 | 147 | def save(self, tokenizer_dir): 148 | # save the tokenizer to disk 149 | os.makedirs(tokenizer_dir, exist_ok=True) 150 | tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json") 151 | self.tokenizer.save(tokenizer_path) 152 | print(f"Saved tokenizer to {tokenizer_path}") 153 | 154 | # TODO: use special tokens! 
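# One possible shape for that TODO (editor's sketch, not wired in): the <|bin_XX|>
# special tokens declared in SPECIAL_TOKENS above could encode the clamped value
# as a single token instead of spelling out its digits, e.g.
#   def value_to_bin_token_id(tokenizer, value: int) -> int:   # hypothetical helper
#       value = max(_MIN_VALUE, min(value, _MAX_VALUE))        # clamp into [1, 64]
#       return tokenizer.encode_special(f"<|bin_{value:02d}|>")
# The functions below keep the current digit-token encoding.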
155 | def value_to_token_ids(tokenizer, value: int) -> list[int]: 156 | assert value >= _MIN_VALUE 157 | value = min(value, _MAX_VALUE) 158 | return tokenizer.encode(str(value)) 159 | 160 | def token_ids_to_value(tokenizer, token_ids: list[int]) -> float | None: 161 | try: 162 | return int(tokenizer.decode(token_ids)) 163 | except ValueError: 164 | return None 165 | 166 | def get_tokenizer(): 167 | # return HuggingFaceTokenizer.from_pretrained("gpt2") 168 | from nanoproof.common import get_base_dir 169 | base_dir = get_base_dir() 170 | tokenizer_dir = os.path.join(base_dir, "tokenizer") 171 | return HuggingFaceTokenizer.from_directory(tokenizer_dir) 172 | 173 | def get_token_bytes(device="cpu"): 174 | import torch 175 | from nanoproof.common import get_base_dir 176 | base_dir = get_base_dir() 177 | tokenizer_dir = os.path.join(base_dir, "tokenizer") 178 | token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt") 179 | assert os.path.exists(token_bytes_path), f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py" 180 | with open(token_bytes_path, "rb") as f: 181 | token_bytes = torch.load(f, map_location=device) 182 | return token_bytes -------------------------------------------------------------------------------- /nanoproof/data/nemotron.py: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/karpathy/nanochat/blob/master/nanochat/dataset.py 2 | 3 | """ 4 | The base/pretraining dataset is loaded from HuggingFace datasets. 5 | This file contains utilities for: 6 | - loading the dataset from HuggingFace 7 | - iterating over the dataset and yielding documents from it 8 | - limiting the number of files/shard used via CLI 9 | 10 | For details of how the dataset was prepared, see the HuggingFace dataset page. 11 | """ 12 | 13 | import os 14 | import argparse 15 | import requests 16 | import time 17 | from multiprocessing import Pool 18 | from pathlib import Path 19 | 20 | import pyarrow.parquet as pq 21 | from huggingface_hub import get_token 22 | from tqdm import tqdm 23 | 24 | from nanoproof.common import get_base_dir 25 | 26 | # ----------------------------------------------------------------------------- 27 | # The specifics of the current pretraining dataset 28 | 29 | DATASET_NAME = "nvidia/Nemotron-CC-Math-v1" 30 | BASE_URL = "https://huggingface.co/datasets/nvidia/Nemotron-CC-Math-v1/resolve/main" 31 | 32 | base_dir = get_base_dir() 33 | DATA_DIR = os.path.join(base_dir, "data", "nemotron") 34 | 35 | MAX_SHARD = 45 # the last datashard is part_000045.parquet 36 | index_to_filename = lambda index: f"4plus/part_{index:06d}.parquet" # format of the filenames 37 | 38 | # ----------------------------------------------------------------------------- 39 | # These functions are useful utilities to other modules, can/should be imported 40 | 41 | def list_parquet_files(data_dir=None): 42 | """ Looks into a data dir and returns full paths to all parquet files. """ 43 | data_dir = DATA_DIR if data_dir is None else data_dir 44 | parquet_files = sorted([ 45 | f for f in os.listdir(os.path.join(data_dir, "4plus")) 46 | if f.endswith('.parquet') and not f.endswith('.tmp') 47 | ]) 48 | parquet_paths = [os.path.join(data_dir, "4plus", f) for f in parquet_files] 49 | return parquet_paths 50 | 51 | def parquets_iter_batched(split, start=0, step=1): 52 | """ 53 | Iterate through the dataset, in batches of underlying row_groups for efficiency. 54 | - split can be "train" or "val". the last parquet file will be val. 
55 | - start/step are useful for skipping rows in DDP. e.g. start=rank, step=world_size 56 | """ 57 | assert split in ["train", "val"], "split must be 'train' or 'val'" 58 | parquet_paths = list_parquet_files() 59 | parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:] 60 | for filepath in parquet_paths: 61 | pf = pq.ParquetFile(filepath) 62 | row_group_indices = list(range(start, pf.num_row_groups, step)) 63 | for rg_idx in row_group_indices: 64 | rg = pf.read_row_group(rg_idx) 65 | texts = rg.column('text').to_pylist() 66 | yield texts 67 | 68 | def process_file(filepath): 69 | """ 70 | Loads a parquet file, changes group size to 1024, drops all columns except 'text', 71 | and overwrites the file. 72 | """ 73 | try: 74 | # Read the table 75 | table = pq.read_table(filepath) 76 | 77 | # Select only 'text' column if others exist 78 | if 'text' in table.column_names: 79 | table = table.select(['text']) 80 | else: 81 | print(f"Warning: 'text' column not found in {filepath}") 82 | 83 | # Write to a temporary file with new row group size 84 | temp_path = filepath + ".rechunk" 85 | pq.write_table(table, temp_path, row_group_size=1024) 86 | 87 | # Overwrite the original file 88 | os.rename(temp_path, filepath) 89 | except Exception as e: 90 | print(f"Failed to process {filepath}: {e}") 91 | # Clean up temp file if it was created 92 | if os.path.exists(filepath + ".rechunk"): 93 | try: 94 | os.remove(filepath + ".rechunk") 95 | except: 96 | pass 97 | 98 | def download_single_file(index): 99 | """ Downloads a single file index, with some backoff """ 100 | 101 | # Construct the local filepath for this file and skip if it already exists 102 | filename = index_to_filename(index) 103 | filepath = os.path.join(DATA_DIR, filename) 104 | if os.path.exists(filepath): 105 | print(f"Skipping {filepath} (already exists)") 106 | return True 107 | 108 | # Construct the remote URL for this file 109 | url = f"{BASE_URL}/{filename}" 110 | print(f"Downloading {filename}... 
to {filepath}") 111 | 112 | # Get Hugging Face token for authentication 113 | token = get_token() 114 | headers = {} 115 | if token: 116 | headers["Authorization"] = f"Bearer {token}" 117 | 118 | # Download with retries 119 | max_attempts = 5 120 | for attempt in range(1, max_attempts + 1): 121 | try: 122 | response = requests.get(url, stream=True, timeout=30, headers=headers) 123 | response.raise_for_status() 124 | # Write to temporary file first 125 | temp_path = filepath + f".tmp" 126 | Path(temp_path).parent.mkdir(parents=True, exist_ok=True) 127 | total_size = int(response.headers.get('content-length', 0)) 128 | with open(temp_path, 'wb') as f: 129 | with tqdm(total=total_size, unit='B', unit_scale=True, unit_divisor=1024, desc=os.path.basename(filename), leave=False) as pbar: 130 | for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks 131 | if chunk: 132 | f.write(chunk) 133 | pbar.update(len(chunk)) 134 | # Move temp file to final location 135 | os.rename(temp_path, filepath) 136 | # Process the file immediately after download 137 | process_file(filepath) 138 | print(f"Successfully downloaded and processed {filename}") 139 | return True 140 | 141 | except (requests.RequestException, IOError) as e: 142 | print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}") 143 | # Clean up any partial files 144 | for path in [filepath + f".tmp", filepath]: 145 | if os.path.exists(path): 146 | try: 147 | os.remove(path) 148 | except: 149 | pass 150 | # Try a few times with exponential backoff: 2^attempt seconds 151 | if attempt < max_attempts: 152 | wait_time = 2 ** attempt 153 | print(f"Waiting {wait_time} seconds before retry...") 154 | time.sleep(wait_time) 155 | else: 156 | print(f"Failed to download {filename} after {max_attempts} attempts") 157 | return False 158 | return False 159 | 160 | 161 | if __name__ == "__main__": 162 | parser = argparse.ArgumentParser(description="Manage Nemotron-CC-Math-v1 dataset") 163 | subparsers = parser.add_subparsers(dest="command", required=True) 164 | 165 | parser_download = subparsers.add_parser("download", help="Download dataset shards") 166 | parser_download.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable") 167 | parser_download.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)") 168 | 169 | parser_process = subparsers.add_parser("process", help="Process all downloaded parquet files (re-chunk to 1024 group size)") 170 | 171 | args = parser.parse_args() 172 | 173 | os.makedirs(DATA_DIR, exist_ok=True) 174 | 175 | if args.command == "download": 176 | num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1) 177 | ids_to_download = list(range(num)) 178 | print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...") 179 | print(f"Target directory: {DATA_DIR}") 180 | print() 181 | with Pool(processes=args.num_workers) as pool: 182 | results = list(tqdm( 183 | pool.imap(download_single_file, ids_to_download), 184 | total=len(ids_to_download), 185 | desc="Downloading shards" 186 | )) 187 | 188 | successful = sum(1 for success in results if success) 189 | print(f"Done! 
Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}") 190 | 191 | elif args.command == "process": 192 | parquet_paths = list_parquet_files() 193 | print(f"Found {len(parquet_paths)} parquet files to process in {DATA_DIR}...") 194 | for filepath in tqdm(parquet_paths, desc="Processing files"): 195 | process_file(filepath) 196 | print("Done processing all files.") 197 | -------------------------------------------------------------------------------- /nanoproof/muon.py: -------------------------------------------------------------------------------- 1 | """ 2 | Muon optimizer from Keller et al. 3 | Also a lot of borrowing of ideas from modded-nanogpt. 4 | """ 5 | import torch 6 | from torch import Tensor 7 | import torch.distributed as dist 8 | 9 | @torch.compile 10 | def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor: 11 | """ 12 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 13 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 14 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 15 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 16 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T 17 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model 18 | performance at all relative to UV^T, where USV^T = G is the SVD. 19 | """ 20 | assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng 21 | a, b, c = (3.4445, -4.7750, 2.0315) 22 | X = G.bfloat16() 23 | if G.size(-2) > G.size(-1): 24 | X = X.mT 25 | 26 | # Ensure spectral norm is at most 1 27 | X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) 28 | # Perform the NS iterations 29 | for _ in range(steps): 30 | A = X @ X.mT 31 | B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng 32 | X = a * X + B @ X 33 | 34 | if G.size(-2) > G.size(-1): 35 | X = X.mT 36 | return X 37 | 38 | class Muon(torch.optim.Optimizer): 39 | """ 40 | Muon - MomentUm Orthogonalized by Newton-schulz 41 | 42 | https://kellerjordan.github.io/posts/muon/ 43 | 44 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- 45 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal 46 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has 47 | the advantage that it can be stably run in bfloat16 on the GPU. 48 | 49 | Some warnings: 50 | - This optimizer should not be used for the embedding layer, the final fully connected layer, 51 | or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 52 | - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. 53 | 54 | Arguments: 55 | lr: The learning rate used by the internal SGD. 56 | momentum: The momentum used by the internal SGD. 57 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) 58 | ns_steps: The number of Newton-Schulz iteration steps to use. 
59 | """ 60 | def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5): 61 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) 62 | params: list[Tensor] = [*params] 63 | param_groups = [] 64 | for size in {p.numel() for p in params}: 65 | group = dict(params=[p for p in params if p.numel() == size]) 66 | param_groups.append(group) 67 | super().__init__(param_groups, defaults) 68 | 69 | @torch.no_grad() 70 | def step(self): 71 | for group in self.param_groups: 72 | params: list[Tensor] = group["params"] 73 | for p in params: 74 | g = p.grad 75 | assert g is not None 76 | state = self.state[p] 77 | if "momentum_buffer" not in state: 78 | state["momentum_buffer"] = torch.zeros_like(g) 79 | buf: Tensor = state["momentum_buffer"] 80 | buf.lerp_(g, 1 - group["momentum"]) 81 | g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf 82 | g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) 83 | p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5) 84 | 85 | 86 | class DistMuon(torch.optim.Optimizer): 87 | """ 88 | Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Newton–Schulz, 89 | finally apply aspect-ratio scaled step. Performs its own distributed synchronization: 90 | - reduce_scatter(AVG) for gradient averaging 91 | - all_gather to replicate updated weights 92 | 93 | Notes: 94 | * Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D). Do not use for 0D/1D 95 | params like embeddings or scalars. 96 | * Momentum buffers are maintained only on the 'owner' rank for each parameter (rank chosen 97 | by block-cyclic assignment below). If you checkpoint optimizer state on a single rank, 98 | consolidate states beforehand. 99 | 100 | Args: 101 | params: iterable of Tensors 102 | lr: learning rate 103 | momentum: momentum coefficient in [0,1) 104 | nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf 105 | ns_steps: number of Newton–Schulz iterations for the orthogonalization 106 | """ 107 | def __init__(self, params, lr: float = 0.02, momentum: float = 0.95, 108 | nesterov: bool = True, ns_steps: int = 5): 109 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) 110 | params = list(params) 111 | assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only" 112 | rank = dist.get_rank() 113 | # Group all parameters by their shape 114 | shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering 115 | param_groups = [] 116 | for shape in shapes: 117 | group_params = [p for p in params if p.shape == shape] 118 | device, dtype = group_params[0].device, group_params[0].dtype 119 | assert all(p.device == device for p in group_params) 120 | assert all(p.dtype == dtype for p in group_params) 121 | if rank == 0: 122 | print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}") 123 | param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0]))) 124 | super().__init__(param_groups, defaults) 125 | 126 | @torch.no_grad() 127 | def step(self): 128 | rank = dist.get_rank() 129 | world_size = dist.get_world_size() 130 | 131 | # Ensure all grads exist 132 | assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads" 133 | 134 | # Kick off all the reduce scatter operations to average up the gradients across all ranks 135 | all_reduce_futures = [] 136 | for 
group in self.param_groups: 137 | params = group["params"] 138 | zero_buffer = group["zero_buffer"] 139 | # Go through params in groups of world_size. 140 | for base_i in range(0, len(params), world_size): 141 | # The compute owner of each param is rank i % world_size 142 | owner_idx = base_i + rank 143 | # each rank stacks up its chunk of world_size params into a list 144 | rs_input = [p.grad for p in params[base_i:base_i + world_size]] 145 | # pad rs_input with the zero buffer to complete the group 146 | rs_input.extend([zero_buffer] * (world_size - len(rs_input))) 147 | # the output buffer gets strided across the group based on the rank 148 | rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer) 149 | # reduce scatter the gradients within this group of world_size params 150 | work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future() 151 | all_reduce_futures.append(work) 152 | 153 | # Now each rank computes the update and gathers 154 | future_idx = 0 155 | all_gather_futures = [] 156 | for group in self.param_groups: 157 | params = group["params"] 158 | zero_buffer = group["zero_buffer"] 159 | # Go through params in groups of world_size. 160 | for base_i in range(0, len(params), world_size): 161 | # The compute owner of each param is rank i % world_size 162 | owner_idx = base_i + rank # calculate the index of the param that this rank owns 163 | # Wait for the reduce scatter to complete 164 | all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead 165 | future_idx += 1 166 | # Owner computes the Muon update, result is in its param 167 | if owner_idx < len(params): 168 | p = params[owner_idx] 169 | g = p.grad # now averaged across ranks 170 | state = self.state[p] 171 | if "momentum_buffer" not in state: 172 | state["momentum_buffer"] = torch.zeros_like(g) 173 | buf: Tensor = state["momentum_buffer"] 174 | buf.lerp_(g, 1.0 - group["momentum"]) 175 | g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf 176 | g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) 177 | scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5) 178 | p.add_(g, alpha=-group["lr"] * scale) 179 | # Replicate updated parameters to all ranks 180 | ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer 181 | ag_output = params[base_i:base_i + world_size] 182 | ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad 183 | work = dist.all_gather(ag_output, ag_input, async_op=True).get_future() 184 | all_gather_futures.append(work) 185 | 186 | # Wait for all work to finish 187 | torch.futures.collect_all(all_gather_futures).wait() -------------------------------------------------------------------------------- /nanoproof/rl.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | import json 4 | import sys 5 | 6 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" 7 | import random 8 | 9 | import wandb 10 | import torch 11 | import torch.distributed as dist 12 | from contextlib import nullcontext 13 | 14 | from nanoproof.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type, SimpleTimer 15 | from nanoproof.checkpoints import load_model, save_checkpoint 16 | from nanoproof.engine import Engine 17 | from nanoproof.data.leantree import iter_data 18 | from nanoproof.data.leantree_dataloader 
import rl_data_generator 19 | from nanoproof.experience_collection import ReplayBuffer, TheoremsSampler, Config, run_actor 20 | from nanoproof.search import TacticModel 21 | from nanoproof.data import minif2f 22 | from nanoproof.data import leanworkbook 23 | from scripts.prover_eval import eval_success_rate 24 | 25 | # TODO: make the search much more efficient via batching/async 26 | # TODO: if tactic application results in a state that already is on the path from root, skip the tactic (otherwise we sometimes get stuck in loop of eg. rw [add_comm]) 27 | 28 | """ 29 | Timer results: 30 | expand : 5297.9417s (67.0%) 31 | sample : 2612.2757s (33.0%) 32 | """ 33 | 34 | # TODO: save all proofs found during evaluation 35 | # TODO: (maybe) try removing each tactic and if the proof is still valid, do not add the transition to the replay buffer 36 | # ... however, then we need to be sure to update the proof states 37 | # TODO: matchmaker 38 | 39 | # ----------------------------------------------------------------------------- 40 | # RL Hyperparameters 41 | run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb) 42 | seed = 0 43 | # compute/precision 44 | device_type = "" # cuda|cpu|mps (empty => autodetect) 45 | dtype = "bfloat16" 46 | device_batch_size = 8 # (maybe) max to avoid OOM (on A100 40GB) 47 | # data 48 | fraction_sft = 0.1 # 10% of data will come from Mathlib (leantree), 90% from replay buffer 49 | collect_every = 1 # how many steps to train between RL data collections 50 | collect_transitions = 100 # how many proof transitions to collect in one collection 51 | # optimization 52 | num_epochs = 1 53 | num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it) 54 | target_examples_per_step = 512 55 | unembedding_lr = 0.004 56 | embedding_lr = 0.2 57 | matrix_lr = 0.02 58 | weight_decay = 0.0 59 | init_lr_frac = 0.02 60 | # evaluation and logging there of 61 | eval_every = 2 62 | # eval_metrics_every = 200 63 | sample_every = 100 64 | eval_metrics_max_problems = 1024 65 | save_every = 1000 66 | # now allow CLI to override the settings via the configurator lol 67 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 68 | exec(open(os.path.join('nanoproof', 'configurator.py')).read()) # overrides from command line or config file 69 | user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging 70 | # ----------------------------------------------------------------------------- 71 | 72 | # Compute init 73 | device_type = autodetect_device_type() if device_type == "" else device_type 74 | ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) 75 | master_process = ddp_rank == 0 76 | ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16 77 | autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() 78 | 79 | # Output directory init 80 | 81 | if master_process: 82 | base_dir = get_base_dir() 83 | timestamp = datetime.now().strftime("%y-%m-%d_%H-%M") 84 | output_dirname = f"{timestamp}-{run}" 85 | output_dir = os.path.join(base_dir, "rl", output_dirname) 86 | if os.path.exists(output_dir): 87 | print0(f"Error: Output directory {output_dir} already exists") 88 | if ddp: 89 | dist.destroy_process_group() 90 | sys.exit(1) 91 | os.makedirs(output_dir) 92 | print0(f"Output directory: {output_dir}") 93 | else: 94 | output_dir = None 95 | 96 | if ddp: 97 | 
output_dir_list = [output_dir] 98 | dist.broadcast_object_list(output_dir_list, src=0) 99 | output_dir = output_dir_list[0] 100 | 101 | # wandb logging init 102 | use_dummy_wandb = run == "dummy" or not master_process 103 | wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanoproof-rl", name=run, config=user_config, save_code=True) 104 | 105 | tactic_model = TacticModel.create() 106 | model = tactic_model.network 107 | 108 | # print0(f"Target examples per step: {target_examples_per_step}") 109 | # print0(f"Collect every: {collect_every}") 110 | # collect_transitions = target_examples_per_step * collect_every 111 | # print0(f"=> Setting collect_transitions: {collect_transitions}") 112 | 113 | # ----------------------------------------------------------------------------- 114 | # DataLoader 115 | 116 | examples_per_step = device_batch_size * ddp_world_size 117 | print0(f"Target examples per step: {target_examples_per_step}") 118 | print0(f"Device batch size: {device_batch_size}") 119 | print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}") 120 | assert target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step" 121 | grad_accum_steps = target_examples_per_step // examples_per_step 122 | print0(f"=> Setting grad accum steps: {grad_accum_steps}") 123 | 124 | rank_seed = seed + ddp_rank 125 | mathlib_train = list(iter_data(split="train")) 126 | random.Random(rank_seed).shuffle(mathlib_train) 127 | 128 | config = Config() 129 | replay_buffer = ReplayBuffer(config, seed=rank_seed) 130 | theorems_sampler = TheoremsSampler(seed=rank_seed) 131 | 132 | def train_generator(): 133 | rng = random.Random(rank_seed) 134 | mathlib_iter = iter(mathlib_train) 135 | while True: 136 | assert len(replay_buffer.buffer) > 100 137 | if rng.random() < fraction_sft: 138 | try: 139 | yield next(mathlib_iter) 140 | except StopIteration: 141 | mathlib_iter = iter(mathlib_train) 142 | yield next(mathlib_iter) 143 | else: 144 | yield replay_buffer.sample_transition() 145 | 146 | train_loader = rl_data_generator(train_generator(), batch_size=device_batch_size) 147 | 148 | # ----------------------------------------------------------------------------- 149 | # Initialize the Optimizer 150 | 151 | optimizers = model.setup_optimizers( 152 | unembedding_lr=unembedding_lr, 153 | embedding_lr=embedding_lr, 154 | matrix_lr=matrix_lr, 155 | weight_decay=weight_decay, 156 | ) 157 | # Set the initial learning rate as a fraction of the base learning rate 158 | for opt in optimizers: 159 | for group in opt.param_groups: 160 | group["lr"] = group["lr"] * init_lr_frac 161 | 162 | # ----------------------------------------------------------------------------- 163 | # Training loop 164 | 165 | # Go! 
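# Overview of the loop below (a descriptive comment; names refer to the code that follows):
#   - every `collect_every` steps: run the actor (`run_actor`) to collect proof transitions into `replay_buffer`,
#     then synchronize the buffer across ranks and dump it to a JSON snapshot
#   - every `eval_every` steps: evaluate prover success rate on small MiniF2F and LeanWorkBook subsets
#   - every `save_every` steps: save a checkpoint together with the latest eval metrics
#   - every step: accumulate gradients over `grad_accum_steps` micro-batches (a `fraction_sft` share of
#     Mathlib/leantree data, the rest sampled from the replay buffer) and step the optimizers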
166 | step = 0 167 | timer = SimpleTimer() 168 | while True: 169 | if step % collect_every == 0: 170 | # collect proofs 171 | timer.start("collect") 172 | model.eval() 173 | actor_timer = run_actor(collect_transitions, config, tactic_model, replay_buffer, theorems_sampler) 174 | actor_timer = actor_timer.gather() 175 | if master_process: 176 | actor_timer.log_times() 177 | model.train() 178 | timer.end("collect") 179 | replay_buffer.synchronize() 180 | with open(os.path.join(output_dir, f"replay_buffer_{step:05d}.json"), "w") as f: 181 | json.dump(replay_buffer.buffer, f) 182 | 183 | if step % eval_every == 0: 184 | timer.start("eval") 185 | model.eval() 186 | 187 | minif2f_theorems = minif2f.list_theorems(split="Valid") 188 | minif2f_theorems = minif2f_theorems[:64] 189 | print0(f"Evaluating on {len(minif2f_theorems)} theorems from MiniF2F") 190 | minif2f_results = eval_success_rate(tactic_model, minif2f_theorems) 191 | 192 | leanworkbook_theorems = leanworkbook.list_theorems(split="val") 193 | leanworkbook_theorems = leanworkbook_theorems[:64] 194 | print0(f"Evaluating on {len(leanworkbook_theorems)} theorems from LeanWorkBook") 195 | leanworkbook_results = eval_success_rate(tactic_model, leanworkbook_theorems) 196 | 197 | print0(f"Step {step:05d} | minif2f: {minif2f_results['success_rate']:.4%} ({minif2f_results['solved']}/{minif2f_results['total']}, with {minif2f_results['errors']} errors) | leanworkbook: {leanworkbook_results['success_rate']:.4%} ({leanworkbook_results['solved']}/{leanworkbook_results['total']}, with {leanworkbook_results['errors']} errors)") 198 | wandb_run.log({ 199 | "step": step, 200 | "minif2f_val": minif2f_results['success_rate'], 201 | "leanworkbook_val": leanworkbook_results['success_rate'], 202 | }) 203 | model.train() 204 | timer.end("eval") 205 | 206 | if step > 0 and step % save_every == 0 and master_process: 207 | checkpoint_dir = os.path.join(output_dir, "checkpoints") 208 | model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer 209 | save_checkpoint( 210 | checkpoint_dir, 211 | step, 212 | model.state_dict(), 213 | [opt.state_dict() for opt in optimizers], # optimizer states 214 | { 215 | "step": step, 216 | "model_config": model_config_kwargs, 217 | "minif2f_val": minif2f_results['success_rate'], 218 | "leanworkbook_val": leanworkbook_results['success_rate'], 219 | }, 220 | rank=ddp_rank, 221 | ) 222 | 223 | timer.start("train") 224 | # evaluate the gradient 225 | num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen 226 | for micro_step in range(grad_accum_steps): 227 | train_inputs, train_targets = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward 228 | with autocast_ctx: 229 | loss = model(train_inputs, train_targets) 230 | train_loss = loss.detach() # for logging 231 | loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here 232 | loss.backward() # accumulate the gradient 233 | num_tokens += (train_targets >= 0).sum() 234 | if ddp: 235 | dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks 236 | 237 | # step the optimizers 238 | for opt in optimizers: 239 | opt.step() 240 | model.zero_grad(set_to_none=True) 241 | timer.end("train") 242 | 243 | # logging 244 | train_loss_item = train_loss.item() 245 | num_tokens_item = num_tokens.item() 246 | print0(f"Step {step:05d} | Training loss: {train_loss_item:.6f} | num_tokens: {num_tokens_item:,} | replay_buffer_size: 
{len(replay_buffer.buffer)}") 247 | timer.log_times() 248 | wandb_run.log({ 249 | "step": step, 250 | "train_loss": train_loss_item, 251 | "num_tokens": num_tokens_item, 252 | "replay_buffer_size": len(replay_buffer.buffer), 253 | **{f"time/{k}": v for k, v in timer.get_times().items()} 254 | }) 255 | 256 | step += 1 -------------------------------------------------------------------------------- /nanoproof/engine.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | class KVCache: 8 | """ 9 | Works hand-in-hand with the GPT model to maintain the KV cache. 10 | Note that the .pos advances automatically after the last layer of the Transformer inserts. 11 | """ 12 | 13 | def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers): 14 | # Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer. 15 | self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim) 16 | self.kv_cache = None 17 | self.pos = 0 # current position in time in the cache 18 | 19 | def reset(self): 20 | self.pos = 0 21 | 22 | def get_pos(self): 23 | return self.pos 24 | 25 | def prefill(self, other): 26 | """ 27 | Prefill given another KV cache. Optionally expand along batch dim. 28 | This is used when we do batch 1 prefill and then want to generate 29 | multiple samples in parallel from there. 30 | """ 31 | # 1) validate the shapes 32 | assert self.kv_cache is None, "Cannot prefill a non-empty KV cache" 33 | assert other.kv_cache is not None, "Cannot prefill with a None KV cache" 34 | 35 | # Extract dimensions explicitly 36 | self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape 37 | other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape 38 | 39 | # Validate dimensions 40 | assert self_layers == other_layers, f"Layer count mismatch: {self_layers} != {other_layers}" 41 | assert self_kv == other_kv, f"K/V dimension mismatch: {self_kv} != {other_kv}" 42 | assert self_heads == other_heads, f"Head count mismatch: {self_heads} != {other_heads}" 43 | assert self_head_dim == other_head_dim, f"Head dim mismatch: {self_head_dim} != {other_head_dim}" 44 | 45 | # Batch size can be expanded (other can be 1, self can be larger) 46 | assert self_batch == other_batch or other_batch == 1, f"Batch size mismatch: {self_batch} vs {other_batch} (other must be 1 or equal)" 47 | 48 | # Sequence length: self must be longer than other 49 | assert self_seq >= other_seq, f"Sequence length mismatch: {self_seq} < {other_seq}" 50 | 51 | # 2) initialize the cache 52 | dtype, device = other.kv_cache.dtype, other.kv_cache.device 53 | self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device) 54 | # 3) copy the data over 55 | self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache 56 | # 4) update the pos 57 | self.pos = other.pos 58 | 59 | def insert_kv(self, layer_idx, k, v): 60 | # Lazy initialize the cache here because we need to know the dtype/device 61 | if self.kv_cache is None: 62 | self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device) 63 | # Insert new keys/values to the cache and return the full cache so far 64 | B, H, T_add, D = k.size() 65 | t0, t1 = self.pos, self.pos + T_add 66 | # Dynamically grow the cache if needed 67 | if t1 > self.kv_cache.size(4): 68 | t_needed = t1 + 1024 # as much as we need plus buffer of 1024 69 | t_needed = (t_needed + 1023) & 
~1023 # then round up to the nearest multiple of 1024 70 | additional_shape = list(self.kv_cache.shape) 71 | additional_shape[4] = t_needed - self.kv_cache.size(4) 72 | additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device) 73 | self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous() 74 | self.kv_shape = self.kv_cache.shape 75 | # Insert k, v into the cache 76 | self.kv_cache[layer_idx, 0, :, :, t0:t1, :] = k 77 | self.kv_cache[layer_idx, 1, :, :, t0:t1, :] = v 78 | # Return the full cached keys/values up to current position (as a view) 79 | key_view = self.kv_cache[layer_idx, 0, :, :, :t1, :] 80 | value_view = self.kv_cache[layer_idx, 1, :, :, :t1, :] 81 | # Increment pos after the last layer of the Transformer processes 82 | if layer_idx == self.kv_cache.size(0) - 1: 83 | self.pos = t1 84 | return key_view, value_view 85 | 86 | 87 | # ----------------------------------------------------------------------------- 88 | @torch.inference_mode() 89 | def sample_next_token(logits, rng, temperature=1.0, top_k=None): 90 | """Sample a single next token from given logits of shape (B, vocab_size). Returns (B, 1).""" 91 | assert temperature >= 0.0, "temperature must be non-negative" 92 | if temperature == 0.0: 93 | return torch.argmax(logits, dim=-1, keepdim=True) 94 | if top_k is not None: 95 | k = min(top_k, logits.size(-1)) 96 | vals, idx = torch.topk(logits, k, dim=-1) 97 | vals = vals / temperature 98 | probs = F.softmax(vals, dim=-1) 99 | choice = torch.multinomial(probs, num_samples=1, generator=rng) 100 | return idx.gather(1, choice) 101 | else: 102 | logits = logits / temperature 103 | probs = F.softmax(logits, dim=-1) 104 | return torch.multinomial(probs, num_samples=1, generator=rng) 105 | 106 | # ----------------------------------------------------------------------------- 107 | 108 | class RowState: 109 | # Per-row state tracking during generation 110 | def __init__(self, current_tokens=None): 111 | self.current_tokens = current_tokens or [] # Current token sequence for this row 112 | self.completed = False # Whether this row has completed generation 113 | 114 | class Engine: 115 | def __init__(self, model, tokenizer): 116 | self.model = model 117 | self.tokenizer = tokenizer # needed for tool use 118 | 119 | @torch.inference_mode() 120 | def generate(self, tokens, num_samples=1, max_tokens=None, min_tokens=None, temperature=1.0, top_k=None, seed=42): 121 | """Same as generate, but does single prefill and then clones the KV cache.""" 122 | assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints" 123 | device = self.model.get_device() 124 | rng = torch.Generator(device=device) 125 | rng.manual_seed(seed) 126 | 127 | eos = self.tokenizer.get_eos_token_id() 128 | bos = self.tokenizer.get_bos_token_id() 129 | 130 | # 1) Run a batch 1 prefill of the prompt tokens 131 | m = self.model.config 132 | kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer} 133 | kv_cache_prefill = KVCache( 134 | batch_size=1, 135 | seq_len=len(tokens), 136 | **kv_model_kwargs, 137 | ) 138 | ids = torch.tensor([tokens], dtype=torch.long, device=device) 139 | logits = self.model.forward(ids, kv_cache=kv_cache_prefill) 140 | logits = logits[:, -1, :] 141 | if min_tokens is not None and 0 < min_tokens: 142 | logits[:, eos] = float('-inf') 143 | logits[:, bos] = float('-inf') 144 | if num_samples > 1: 145 | # Expand logits so that each initial token is sampled independently 146 | logits = 
logits.expand(num_samples, -1) 147 | next_ids = sample_next_token(logits, rng, temperature, top_k) # (num_samples, 1) 148 | sampled_tokens = next_ids[:, 0].tolist() 149 | 150 | # 2) Replicate the KV cache for each sample/row 151 | kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len 152 | kv_cache_decode = KVCache( 153 | batch_size=num_samples, 154 | seq_len=kv_length_hint, 155 | **kv_model_kwargs, 156 | ) 157 | kv_cache_decode.prefill(kv_cache_prefill) 158 | del kv_cache_prefill # no need to keep this memory around 159 | 160 | # 3) Initialize states for each sample 161 | row_states = [RowState(tokens.copy()) for _ in range(num_samples)] 162 | 163 | # 4) Main generation loop 164 | num_generated = 0 165 | first_iteration = True 166 | while True: 167 | # Stop condition: we've reached max tokens 168 | if max_tokens is not None and num_generated >= max_tokens: 169 | break 170 | # Stop condition: all rows are completed 171 | if all(state.completed for state in row_states): 172 | break 173 | 174 | # Get sampled tokens - either from prefill or from forward pass 175 | if first_iteration: 176 | # Use the tokens we already sampled from prefill 177 | first_iteration = False 178 | else: 179 | # Forward the model and get the next token for each row 180 | logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size) 181 | logits = logits[:, -1, :] # (B, vocab_size) at last time step 182 | if min_tokens is not None and num_generated < min_tokens: 183 | logits[:, eos] = float('-inf') 184 | logits[:, bos] = float('-inf') 185 | next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1) 186 | sampled_tokens = next_ids[:, 0].tolist() 187 | 188 | # Process each row: choose the next token, update state, optional tool use 189 | token_column = [] # contains the next token id along each row 190 | token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row 191 | for i, state in enumerate(row_states): 192 | token_masks.append(1) # mask is 0 if forced, 1 if sampled 193 | next_token = sampled_tokens[i] 194 | token_column.append(next_token) 195 | # Update the state of this row to include the next token 196 | state.current_tokens.append(next_token) 197 | # On eos or bos, mark the row as completed 198 | if next_token == eos or next_token == bos: 199 | state.completed = True 200 | 201 | # Yield the token column 202 | yield token_column, token_masks 203 | num_generated += 1 204 | # Prepare ids for next iteration 205 | ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1) 206 | 207 | def generate_batch(self, tokens, num_samples=1, **kwargs): 208 | """ 209 | Non-streaming batch generation that just returns the final token sequences. 210 | Returns a list of token sequences (list of lists of ints). 211 | Terminal tokens (eos, bos) are not included in the results. 
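        Illustrative usage (a sketch under assumptions: `engine` and `tokenizer` are already constructed,
        and the tokenizer is callable as tokenizer(text, prepend=...) as elsewhere in this repo):

            tokens = tokenizer("⊢ 2 + 3 = 5\n<|tactic|>", prepend="<|bos|>")
            samples, masks = engine.generate_batch(tokens, num_samples=4, max_tokens=64)
            tactics = [tokenizer.decode(s[len(tokens):]) for s in samples]  # strip the prompt prefix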
212 | """ 213 | eos = self.tokenizer.get_eos_token_id() 214 | bos = self.tokenizer.get_bos_token_id() 215 | results = [tokens.copy() for _ in range(num_samples)] 216 | masks = [[0] * len(tokens) for _ in range(num_samples)] 217 | completed = [False] * num_samples 218 | for token_column, token_masks in self.generate(tokens, num_samples, **kwargs): 219 | for i, (token, mask) in enumerate(zip(token_column, token_masks)): 220 | if not completed[i]: 221 | if token == eos or token == bos: 222 | completed[i] = True 223 | else: 224 | results[i].append(token) 225 | masks[i].append(mask) 226 | # Stop if all rows are completed 227 | if all(completed): 228 | break 229 | return results, masks -------------------------------------------------------------------------------- /nanoproof/midtrain.py: -------------------------------------------------------------------------------- 1 | """ 2 | Midtrain the model. Same as pretraining but simpler. 3 | Run as: 4 | 5 | python -m scripts.mid_train 6 | 7 | Or torchrun for training: 8 | 9 | torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16 10 | """ 11 | 12 | from collections import deque 13 | import os 14 | from contextlib import nullcontext 15 | import time 16 | 17 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" 18 | import wandb 19 | import torch 20 | import torch.distributed as dist 21 | 22 | from nanoproof.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type 23 | from nanoproof.tokenizer import get_token_bytes 24 | from nanoproof.checkpoints import save_checkpoint 25 | from nanoproof.loss_eval import evaluate_bpb 26 | from nanoproof.checkpoints import load_model 27 | from nanoproof.data.leangithubraw import iter_data 28 | 29 | # ----------------------------------------------------------------------------- 30 | run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb) 31 | device_type = "" # cuda|cpu|mps (empty => autodetect) 32 | model_tag = "d26" # model tag to load the model from (base model or midtrained model) 33 | step = None # step to load the model from (base model or midtrained model) 34 | dtype = "bfloat16" 35 | num_iterations = -1 # explicit number of steps of the optimization (-1 = disable) 36 | max_seq_len = 768 37 | device_batch_size = 32 # H100 38 | # device_batch_size = 8 # A100 40GB 39 | unembedding_lr = 0.004 40 | embedding_lr = 0.2 41 | matrix_lr = 0.02 42 | init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate 43 | weight_decay = 0.0 44 | eval_every = 150 # -1 = disable 45 | # total_batch_size = 524288 46 | total_batch_size = 491520 47 | eval_tokens = 20*total_batch_size 48 | dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report 49 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 50 | exec(open(os.path.join('nanoproof', 'configurator.py')).read()) # overrides from command line or config file 51 | user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging 52 | # ----------------------------------------------------------------------------- 53 | 54 | # Compute init 55 | device_type = autodetect_device_type() if device_type == "" else device_type 56 | ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) 57 | master_process = ddp_rank == 0 58 | autocast_ctx = torch.amp.autocast(device_type=device_type, 
dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() 59 | synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None 60 | get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 61 | 62 | # wandb logging init 63 | use_dummy_wandb = run == "dummy" or not master_process 64 | wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanoproof-mid", name=run, config=user_config) 65 | 66 | # Load the model and tokenizer 67 | model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step) 68 | pretrain_batch_size = meta.get("device_batch_size", None) 69 | if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size: 70 | print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?") 71 | orig_model = model 72 | model = torch.compile(model, dynamic=False) 73 | depth = model.config.n_layer 74 | num_flops_per_token = model.estimate_flops() 75 | tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank 76 | world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks 77 | assert total_batch_size % world_tokens_per_fwdbwd == 0 78 | grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd 79 | print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}") 80 | print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") 81 | print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") 82 | token_bytes = get_token_bytes(device=device) 83 | 84 | # Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head) 85 | optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay) 86 | adamw_optimizer, muon_optimizer = optimizers 87 | # Override the initial learning rate as a fraction of the base learning rate 88 | for opt in optimizers: 89 | for group in opt.param_groups: 90 | group["lr"] = group["lr"] * init_lr_frac 91 | group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later 92 | 93 | # Midtraining data mixture and DataLoader 94 | base_dir = get_base_dir() 95 | train_loader = iter_data(device_batch_size, max_seq_len, "train") 96 | build_val_loader = lambda: iter_data(device_batch_size, max_seq_len, "val") 97 | 98 | progress = 0 # will go from 0 to 1 over the course of the epoch 99 | 100 | # TODO: try adding warmup (now, loss goes up first few steps) 101 | # Learning rate scheduler 102 | def get_lr_multiplier(progress): 103 | # first 80% of training: no decay, then linearly ramp down to 0.01 104 | return 1 if progress < 0.8 else max(0.01, 1 - (progress - 0.8) / 0.2) 105 | 106 | # Momentum scheduler for Muon optimizer 107 | def get_muon_momentum(it): 108 | frac = min(it / 300, 1) 109 | momentum = (1 - frac) * 0.85 + frac * 0.95 110 | return momentum 111 | 112 | # ----------------------------------------------------------------------------- 113 | # Training loop 114 | x, y, approx_progress, last_step = next(train_loader) # prefetch the very first batch of data 115 | min_val_bpb = float("inf") 116 | smooth_train_loss = 0 # EMA of training loss 117 | ema_beta = 0.9 # EMA decay factor 118 | total_training_time = 0 # total wall-clock time of training 119 | step = 0 120 | while True: 121 | flops_so_far = 
num_flops_per_token * total_batch_size * step 122 | 123 | # Synchronize last_step across all ranks to avoid hangs in the distributed setting 124 | if ddp: 125 | last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device) 126 | dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX) 127 | last_step = bool(last_step_tensor.item()) 128 | 129 | # once in a while: evaluate the val bpb (all ranks participate) 130 | if eval_every > 0 and (last_step or step % eval_every == 0): 131 | model.eval() 132 | val_loader = build_val_loader() 133 | eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size) 134 | with autocast_ctx: 135 | val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes) 136 | print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}") 137 | if val_bpb < min_val_bpb: 138 | min_val_bpb = val_bpb 139 | wandb_run.log({ 140 | "step": step, 141 | "total_training_flops": flops_so_far, 142 | "total_training_time": total_training_time, 143 | "val/bpb": val_bpb, 144 | }) 145 | model.train() 146 | 147 | # save checkpoint at the end of the run (only on master process) 148 | if master_process and last_step and not dry_run: 149 | output_dirname = f"d{depth}" # e.g. d12 150 | checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname) 151 | save_checkpoint( 152 | checkpoint_dir, 153 | step, 154 | orig_model.state_dict(), 155 | [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly 156 | { 157 | "step": step, 158 | "val_bpb": val_bpb, # loss at last step 159 | "model_config": { 160 | "sequence_len": max_seq_len, 161 | "vocab_size": tokenizer.get_vocab_size(), 162 | "n_layer": depth, 163 | "n_head": model.config.n_head, 164 | "n_kv_head": model.config.n_kv_head, 165 | "n_embd": model.config.n_embd, 166 | }, 167 | "user_config": user_config, # inputs to the training script 168 | } 169 | ) 170 | 171 | if last_step: 172 | break 173 | 174 | # ------------------------------------------------------------------------- 175 | # single training step 176 | # evaluate the gradient 177 | synchronize() 178 | t0 = time.time() 179 | for micro_step in range(grad_accum_steps): 180 | with autocast_ctx: 181 | loss = model(x, y) 182 | train_loss = loss.detach() # for logging 183 | loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here 184 | loss.backward() 185 | x, y, approx_progress, last_step = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward 186 | progress = max(progress, approx_progress) # only increase progress monotonically 187 | # step the optimizers 188 | lrm = get_lr_multiplier(progress) 189 | for opt in optimizers: 190 | for group in opt.param_groups: 191 | group["lr"] = group["initial_lr"] * lrm 192 | muon_momentum = get_muon_momentum(step) 193 | for group in muon_optimizer.param_groups: 194 | group["momentum"] = muon_momentum 195 | for opt in optimizers: 196 | opt.step() 197 | model.zero_grad(set_to_none=True) 198 | synchronize() 199 | t1 = time.time() 200 | dt = t1 - t0 201 | # ------------------------------------------------------------------------- 202 | 203 | # State 204 | step += 1 205 | 206 | # logging 207 | smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss 208 | debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA 209 | pct_done = 100 * progress 210 | if ddp: 211 | pct_done_tensor = torch.tensor([pct_done], dtype=torch.float32, 
device=device) 212 | gathered_pct_done = [torch.zeros_like(pct_done_tensor) for _ in range(ddp_world_size)] 213 | dist.all_gather(gathered_pct_done, pct_done_tensor) 214 | pct_dones = [t.item() for t in gathered_pct_done] 215 | pct_done_str = "[" + ", ".join(f"{p:.2f}" for p in pct_dones) + "]%" 216 | else: 217 | pct_done_str = f"{pct_done:.2f}%" 218 | 219 | tok_per_sec = int(total_batch_size / dt) 220 | flops_per_sec = num_flops_per_token * total_batch_size / dt 221 | promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity 222 | mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % 223 | if step > 10: 224 | total_training_time += dt # only count the time after the first 10 steps 225 | print0(f"step {step:05d} ({pct_done_str}) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m") 226 | if step % 10 == 0: 227 | wandb_run.log({ 228 | "step": step, 229 | "total_training_flops": flops_so_far, 230 | "total_training_time": total_training_time, 231 | "train/loss": debiased_smooth_loss, 232 | "train/lrm": lrm, 233 | "train/dt": dt, 234 | "train/tok_per_sec": tok_per_sec, 235 | "train/mfu": mfu, 236 | }) 237 | 238 | # print a few more stats 239 | print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") 240 | print0(f"Total training time: {total_training_time/60:.2f}m") 241 | print0(f"Minimum validation bpb: {min_val_bpb:.4f}") 242 | 243 | # Log to report 244 | if not dry_run: 245 | from nanoproof.report import get_report 246 | get_report().log(section="Midtraining", data=[ 247 | user_config, # CLI args 248 | { # stats about the training setup 249 | "Number of iterations": step, 250 | "DDP world size": ddp_world_size, 251 | }, 252 | { # stats about training outcomes 253 | "Minimum validation bpb": min_val_bpb, 254 | } 255 | ]) 256 | 257 | # cleanup 258 | wandb_run.finish() # wandb run finish 259 | compute_cleanup() -------------------------------------------------------------------------------- /nanoproof/sft.py: -------------------------------------------------------------------------------- 1 | """ 2 | Finetune a base model to be a prover model. 3 | Run on one GPU e.g. 
for debugging: 4 | 5 | python -m nanoproof.sft 6 | 7 | Or torchrun for training: 8 | 9 | torchrun --standalone --nproc_per_node=8 -m nanoproof.sft 10 | """ 11 | 12 | import os 13 | 14 | import leantree.augmentations 15 | 16 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" 17 | import random 18 | 19 | import wandb 20 | import torch 21 | import torch.distributed as dist 22 | from contextlib import nullcontext 23 | 24 | from nanoproof.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type 25 | from nanoproof.checkpoints import load_model, save_checkpoint 26 | from nanoproof.engine import Engine 27 | from nanoproof.data.leantree import iter_data 28 | from nanoproof.data.leantree_dataloader import sft_data_generator 29 | from scripts.policy_eval import eval_tactic_accuracy 30 | 31 | # ----------------------------------------------------------------------------- 32 | # SFT Hyperparameters 33 | run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb) 34 | seed = 0 35 | # input model options 36 | source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model) 37 | model_tag = "d26" # model tag to load the model from (base model or midtrained model) 38 | step = None # step to load the model from (base model or midtrained model) 39 | # compute/precision 40 | device_type = "" # cuda|cpu|mps (empty => autodetect) 41 | dtype = "bfloat16" 42 | device_batch_size = 8 # (maybe) max to avoid OOM (on A100 40GB) 43 | # optimization 44 | num_epochs = 1 45 | num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it) 46 | target_examples_per_step = 512 47 | unembedding_lr = 0.004 48 | embedding_lr = 0.2 49 | matrix_lr = 0.02 50 | weight_decay = 0.0 51 | init_lr_frac = 0.02 52 | # evaluation and logging thereof 53 | eval_every = 100 54 | eval_steps = 100 55 | # eval_metrics_every = 200 56 | sample_every = 100 57 | eval_metrics_max_problems = 1024 58 | # now allow CLI to override the settings via the configurator lol 59 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 60 | exec(open(os.path.join('nanoproof', 'configurator.py')).read()) # overrides from command line or config file 61 | user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging 62 | # ----------------------------------------------------------------------------- 63 | 64 | # Compute init 65 | device_type = autodetect_device_type() if device_type == "" else device_type 66 | ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) 67 | master_process = ddp_rank == 0 68 | ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16 69 | autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() 70 | 71 | # wandb logging init 72 | use_dummy_wandb = run == "dummy" or not master_process 73 | wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanoproof-sft", name=run, config=user_config, save_code=True) 74 | 75 | # Load the model and tokenizer 76 | model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step) 77 | orig_model = model # original, uncompiled model 78 | # model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs 79 | engine = Engine(model, tokenizer) # will be used for inline model evaluation only 80 | 
bos_token = tokenizer.get_bos_token_id() 81 | 82 | # ----------------------------------------------------------------------------- 83 | # DataLoader 84 | 85 | examples_per_step = device_batch_size * ddp_world_size 86 | print0(f"Target examples per step: {target_examples_per_step}") 87 | print0(f"Device batch size: {device_batch_size}") 88 | print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}") 89 | assert target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step" 90 | grad_accum_steps = target_examples_per_step // examples_per_step 91 | print0(f"=> Setting grad accum steps: {grad_accum_steps}") 92 | 93 | augmentations = [ 94 | leantree.augmentations.ShuffleGoalsAndHypotheses(seed=seed), 95 | leantree.augmentations.RandomRename(seed=seed), 96 | ] 97 | 98 | train_ds = list(iter_data(split="train", augmentations=augmentations)) 99 | random.Random(seed).shuffle(train_ds) 100 | val_ds = list(iter_data(split="val")) 101 | print0(f"Train dataset size: {len(train_ds)} | Val dataset size: {len(val_ds)}") 102 | 103 | # if num_iterations == -1: 104 | # # derive num_iterations from num_epochs and the size of the dataset 105 | # assert num_epochs > 0, "num_epochs must be positive if num_iterations is -1" 106 | # num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs 107 | # print0(f"=> Setting number of iterations: {num_iterations}") 108 | train_loader = sft_data_generator(train_ds, batch_size=device_batch_size) 109 | build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size) 110 | 111 | # ----------------------------------------------------------------------------- 112 | # Initialize the Optimizer 113 | 114 | optimizers = model.setup_optimizers( 115 | unembedding_lr=unembedding_lr, 116 | embedding_lr=embedding_lr, 117 | matrix_lr=matrix_lr, 118 | weight_decay=weight_decay, 119 | ) 120 | # Set the initial learning rate as a fraction of the base learning rate 121 | for opt in optimizers: 122 | for group in opt.param_groups: 123 | group["lr"] = group["lr"] * init_lr_frac 124 | group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later 125 | 126 | # ----------------------------------------------------------------------------- 127 | # Training loop 128 | 129 | # Learning rate scheduler 130 | # def get_lr_multiplier(it): 131 | # lrm = 1.0 - it / num_iterations 132 | # return lrm 133 | 134 | 135 | # Learning rate scheduler 136 | def get_lr_multiplier(progress): 137 | # return max(0.0, 1.0 - progress) 138 | global_progress = (epoch + progress) / num_epochs 139 | return max(0.0, 1.0 - global_progress) 140 | 141 | # Go! 
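# Overview of the loop below (a descriptive comment; names refer to the code that follows):
#   - every `eval_every` steps: compute validation loss and tactic accuracy on the val split
#   - every `sample_every` steps: greedily sample tactic/value completions for a few fixed prompts as a sanity check
#   - every step: accumulate gradients over `grad_accum_steps` micro-batches and step the optimizers,
#     with the learning rate ramping linearly down to 0 over `num_epochs` epochs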
142 | progress = 0 # will go from 0 to 1 over the course of the epoch 143 | step = 0 144 | epoch = 0 145 | x, y, approx_progress, last_step = next(train_loader) # prefetch the very first batch of data 146 | while True: 147 | # Synchronize last_step across all ranks to avoid hangs in the distributed setting 148 | if ddp: 149 | last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device) 150 | dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX) 151 | last_step = bool(last_step_tensor.item()) 152 | 153 | if last_step or step % eval_every == 0: 154 | model.eval() 155 | 156 | # evaluate the validation loss 157 | val_iter = iter(build_val_loader()) 158 | losses = [] 159 | for _ in range(eval_steps): 160 | val_inputs, val_targets, _, _ = next(val_iter) 161 | with torch.no_grad(), autocast_ctx: 162 | loss = model(val_inputs, val_targets) 163 | losses.append(loss) 164 | val_loss = torch.stack(losses).mean() # average over eval_steps 165 | if ddp: 166 | dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks 167 | val_loss = val_loss.item() 168 | 169 | with autocast_ctx: 170 | results = eval_tactic_accuracy(model, build_val_loader(), max_steps=eval_steps) 171 | 172 | print0(f"Step {step:05d} | Validation loss: {val_loss:.6f} | Tactic full accuracy: {results['full_acc']:.4%} | Tactic first token accuracy: {results['first_token_acc']:.4%}") 173 | 174 | wandb_run.log({ 175 | "step": step, 176 | "val_loss": val_loss, 177 | "val_full_acc": results["full_acc"], 178 | "val_first_token_acc": results["first_token_acc"], 179 | }) 180 | 181 | model.train() 182 | 183 | # TODO: eval tactic accuracy 184 | # TODO: eval value MSE 185 | 186 | # evaluate accuracy of the multiple choice tasks (which are quick to run) 187 | if last_step or (step > 0 and step % sample_every == 0): 188 | model.eval() 189 | prompts = [ 190 | "The capital of France is", 191 | "If 5*x + 3 = 13, then x is", 192 | # gold from mathlib: 'exact LipschitzWith.comp_locallyBoundedVariationOn (A i) h' 193 | """case h 194 | ι : Type u_4 195 | inst✝ : Fintype ι 196 | f : ℝ → ι → ℝ 197 | s : Set ℝ 198 | h : LocallyBoundedVariationOn f s 199 | A : ∀ (i : ι), LipschitzWith 1 fun x => x i 200 | i : ι 201 | ⊢ LocallyBoundedVariationOn (fun x => f x i) s 202 | <|tactic|>""", 203 | # sensible tactic: 'intro h' 204 | """p q : Prop 205 | ⊢ p ∧ q → p 206 | <|tactic|>""", 207 | # sensible tactic: 'rfl' 208 | """⊢ 2 + 3 = 5 209 | <|tactic|>""", 210 | # sensible tactic: 'exact Or.inl ⟨hp, hq⟩' 211 | """case mp.inl 212 | p q r : Prop 213 | hp : p 214 | hq : q 215 | ⊢ p ∧ q ∨ p ∧ r 216 | <|tactic|>""", 217 | # sensible tactic: 'exact Exists.intro x0 hx0' 218 | """α : Type 219 | P : α → Prop 220 | inst✝ : Inhabited α 221 | h : ∀ (x : α), P x 222 | x0 : α := default 223 | hx0 : P x0 224 | ⊢ ∃ x, P x 225 | <|tactic|> """, 226 | 227 | """p q : Prop 228 | ⊢ p ∧ q → p 229 | <|value|> """, 230 | """α : Type 231 | P : α → Prop 232 | inst✝ : Inhabited α 233 | h : ∀ (x : α), P x 234 | x0 : α := default 235 | hx0 : P x0 236 | ⊢ ∃ x, P x 237 | <|value|> """, 238 | ] 239 | engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation 240 | for prompt in prompts: 241 | tokens = tokenizer(prompt, prepend=bos_token) 242 | with autocast_ctx: 243 | sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0) 244 | print0(tokenizer.decode(sample[0]) + "\n---") 245 | model.train() 246 | 247 | if last_step: 248 | if epoch < num_epochs - 1: 249 | print0(f"Epoch {epoch} done, starting next one.") 250 | epoch += 1 
251 | train_loader = sft_data_generator(train_ds, batch_size=device_batch_size) 252 | progress = 0 253 | else: 254 | print0(f"Epoch {epoch} done, terminating.") 255 | break 256 | 257 | # evaluate the gradient 258 | num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen 259 | for micro_step in range(grad_accum_steps): 260 | train_inputs, train_targets, approx_progress, last_step = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward 261 | progress = max(progress, approx_progress) # only increase progress monotonically 262 | with autocast_ctx: 263 | loss = model(train_inputs, train_targets) 264 | train_loss = loss.detach() # for logging 265 | loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here 266 | loss.backward() # accumulate the gradient 267 | num_tokens += (train_targets >= 0).sum() 268 | if ddp: 269 | dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks 270 | 271 | # learning rate scheduler 272 | lrm = get_lr_multiplier(progress) 273 | for opt in optimizers: 274 | for group in opt.param_groups: 275 | group["lr"] = group["initial_lr"] * lrm 276 | 277 | # step the optimizers 278 | for opt in optimizers: 279 | opt.step() 280 | model.zero_grad(set_to_none=True) 281 | 282 | pct_done = 100 * progress 283 | if ddp: 284 | pct_done_tensor = torch.tensor([pct_done], dtype=torch.float32, device=device) 285 | gathered_pct_done = [torch.zeros_like(pct_done_tensor) for _ in range(ddp_world_size)] 286 | dist.all_gather(gathered_pct_done, pct_done_tensor) 287 | pct_dones = [t.item() for t in gathered_pct_done] 288 | pct_done_str = "[" + ", ".join(f"{p:.2f}" for p in pct_dones) + "]%" 289 | else: 290 | pct_done_str = f"{pct_done:.2f}%" 291 | 292 | # logging 293 | train_loss_item = train_loss.item() 294 | num_tokens_item = num_tokens.item() 295 | print0(f"Step {step:05d} ({pct_done_str}, ep {epoch:02d}/{num_epochs:02d}) | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}") 296 | wandb_run.log({ 297 | "step": step, 298 | "lrm": lrm, 299 | "train_loss": train_loss_item, 300 | "num_tokens": num_tokens_item, 301 | }) 302 | 303 | step += 1 304 | 305 | # Save the model at the end of the run 306 | if master_process: 307 | base_dir = get_base_dir() 308 | depth = model.config.n_layer 309 | model_tag = f"d{depth}" # base the model tag on the depth of the base model 310 | checkpoint_dir = os.path.join(base_dir, "sft_checkpoints", model_tag) 311 | model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer 312 | save_checkpoint( 313 | checkpoint_dir, 314 | step, 315 | model.state_dict(), 316 | None, # note: we don't bother to save the optimizer state 317 | { 318 | "step": step, 319 | "val_loss": val_loss, 320 | "model_config": model_config_kwargs, 321 | } 322 | ) 323 | print(f"✅ Saved model checkpoint to {checkpoint_dir}") 324 | 325 | # Log to report 326 | from nanoproof.report import get_report 327 | get_report().log(section="SFT", data=[ 328 | user_config, # CLI args 329 | { 330 | "Training rows": len(train_ds), 331 | "Number of iterations": step, 332 | "Training loss": train_loss_item, 333 | "Validation loss": val_loss, 334 | }, 335 | ]) 336 | 337 | # Cleanup 338 | wandb_run.finish() 339 | compute_cleanup() -------------------------------------------------------------------------------- /nanoproof/common.py: -------------------------------------------------------------------------------- 1 | """ 2 | 
Common utilities for nanoproof. 3 | """ 4 | 5 | import os 6 | import time 7 | import re 8 | import logging 9 | import math 10 | import urllib.request 11 | import gc 12 | from collections import Counter 13 | from filelock import FileLock 14 | from typing import Callable, TypeVar, Self 15 | 16 | import torch 17 | import torch.distributed as dist 18 | import numpy as np 19 | from PrettyPrint import PrettyPrintTree 20 | 21 | class ColoredFormatter(logging.Formatter): 22 | """Custom formatter that adds colors to log messages.""" 23 | # ANSI color codes 24 | COLORS = { 25 | 'DEBUG': '\033[36m', # Cyan 26 | 'INFO': '\033[32m', # Green 27 | 'WARNING': '\033[33m', # Yellow 28 | 'ERROR': '\033[31m', # Red 29 | 'CRITICAL': '\033[35m', # Magenta 30 | } 31 | RESET = '\033[0m' 32 | BOLD = '\033[1m' 33 | def format(self, record): 34 | # Add color to the level name 35 | levelname = record.levelname 36 | if levelname in self.COLORS: 37 | record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}" 38 | # Format the message 39 | message = super().format(record) 40 | # Add color to specific parts of the message 41 | if levelname == 'INFO': 42 | # Highlight numbers and percentages 43 | message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message) 44 | message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message) 45 | return message 46 | 47 | def setup_default_logging(): 48 | handler = logging.StreamHandler() 49 | handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) 50 | logging.basicConfig( 51 | level=logging.INFO, 52 | handlers=[handler] 53 | ) 54 | 55 | setup_default_logging() 56 | logger = logging.getLogger(__name__) 57 | 58 | def get_base_dir(): 59 | # co-locate nanoproof intermediates with other cached data in ~/.cache (by default) 60 | if os.environ.get("NANOPROOF_BASE_DIR"): 61 | nanoproof_dir = os.environ.get("NANOPROOF_BASE_DIR") 62 | else: 63 | home_dir = os.path.expanduser("~") 64 | cache_dir = os.path.join(home_dir, ".cache") 65 | nanoproof_dir = os.path.join(cache_dir, "nanoproof") 66 | os.makedirs(nanoproof_dir, exist_ok=True) 67 | return nanoproof_dir 68 | 69 | def download_file_with_lock(url, filename, postprocess_fn=None): 70 | """ 71 | Downloads a file from a URL to a local path in the base directory. 72 | Uses a lock file to prevent concurrent downloads among multiple ranks. 
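    Illustrative example (the URL and filename here are placeholders, not real assets in this project):

        path = download_file_with_lock("https://example.com/val_shard.bin", "val_shard.bin")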
73 | """ 74 | base_dir = get_base_dir() 75 | file_path = os.path.join(base_dir, filename) 76 | lock_path = file_path + ".lock" 77 | 78 | if os.path.exists(file_path): 79 | return file_path 80 | 81 | with FileLock(lock_path): 82 | # Only a single rank can acquire this lock 83 | # All other ranks block until it is released 84 | 85 | # Recheck after acquiring lock 86 | if os.path.exists(file_path): 87 | return file_path 88 | 89 | # Download the content as bytes 90 | print(f"Downloading {url}...") 91 | with urllib.request.urlopen(url) as response: 92 | content = response.read() # bytes 93 | 94 | # Write to local file 95 | with open(file_path, 'wb') as f: 96 | f.write(content) 97 | print(f"Downloaded to {file_path}") 98 | 99 | # Run the postprocess function if provided 100 | if postprocess_fn is not None: 101 | postprocess_fn(file_path) 102 | 103 | return file_path 104 | 105 | def print0(s="",**kwargs): 106 | ddp_rank = int(os.environ.get('RANK', 0)) 107 | if ddp_rank == 0: 108 | print(s, **kwargs) 109 | 110 | def print_banner(): 111 | # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/ 112 | banner = """ 113 | ██████ 114 | ███░░███ 115 | ████████ ██████ ████████ ██████ ████████ ████████ ██████ ██████ ░███ ░░░ 116 | ░░███░░███ ░░░░░███ ░░███░░███ ███░░███░░███░░███░░███░░███ ███░░███ ███░░███ ███████ 117 | ░███ ░███ ███████ ░███ ░███ ░███ ░███ ░███ ░███ ░███ ░░░ ░███ ░███░███ ░███░░░███░ 118 | ░███ ░███ ███░░███ ░███ ░███ ░███ ░███ ░███ ░███ ░███ ░███ ░███░███ ░███ ░███ 119 | ████ █████░░████████ ████ █████░░██████ ░███████ █████ ░░██████ ░░██████ █████ 120 | ░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░███░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░░ 121 | ░███ 122 | █████ 123 | ░░░░░ 124 | """ 125 | print0(banner) 126 | 127 | def is_ddp(): 128 | # TODO is there a proper way 129 | return int(os.environ.get('RANK', -1)) != -1 130 | 131 | def get_dist_info(): 132 | if is_ddp(): 133 | assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE']) 134 | ddp_rank = int(os.environ['RANK']) 135 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 136 | ddp_world_size = int(os.environ['WORLD_SIZE']) 137 | return True, ddp_rank, ddp_local_rank, ddp_world_size 138 | else: 139 | return False, 0, 0, 1 140 | 141 | def autodetect_device_type(): 142 | # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU 143 | if torch.cuda.is_available(): 144 | device_type = "cuda" 145 | elif torch.backends.mps.is_available(): 146 | device_type = "mps" 147 | else: 148 | device_type = "cpu" 149 | print0(f"Autodetected device type: {device_type}") 150 | return device_type 151 | 152 | def compute_init(device_type="cuda"): # cuda|cpu|mps 153 | """Basic initialization that we keep doing over and over, so make common.""" 154 | 155 | assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm" 156 | if device_type == "cuda": 157 | assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'" 158 | if device_type == "mps": 159 | assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'" 160 | 161 | # Reproducibility 162 | # Note that we set the global seeds here, but most of the code uses explicit rng objects. 163 | # The only place where global rng might be used is nn.Module initialization of the model weights. 
164 | torch.manual_seed(42) 165 | if device_type == "cuda": 166 | torch.cuda.manual_seed(42) 167 | # skipping full reproducibility for now, possibly investigate slowdown later 168 | # torch.use_deterministic_algorithms(True) 169 | 170 | # Precision 171 | if device_type == "cuda": 172 | torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls 173 | 174 | # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA 175 | ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() 176 | if ddp and device_type == "cuda": 177 | device = torch.device("cuda", ddp_local_rank) 178 | torch.cuda.set_device(device) # make "cuda" default to this device 179 | dist.init_process_group(backend="nccl", device_id=device) 180 | dist.barrier() 181 | else: 182 | device = torch.device(device_type) # mps|cpu 183 | 184 | if ddp_rank == 0: 185 | logger.info(f"Distributed world size: {ddp_world_size}") 186 | 187 | return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device 188 | 189 | def compute_cleanup(): 190 | """Companion function to compute_init, to clean things up before script exit""" 191 | if is_ddp(): 192 | dist.destroy_process_group() 193 | 194 | class DummyWandb: 195 | """Useful if we wish to not use wandb but have all the same signatures""" 196 | def __init__(self): 197 | pass 198 | def log(self, *args, **kwargs): 199 | pass 200 | def finish(self): 201 | pass 202 | 203 | def format_distribution(bins: list[float], hist_height: int = 10, bin_labels: list[str] = None) -> str: 204 | bar_char = '❚' # Heavy vertical bar character. 205 | 206 | num_bins = len(bins) 207 | max_bin = max(bins) 208 | result = "" 209 | 210 | if max_bin == 0: 211 | max_bin = 1 # To avoid division by zero; all bars will be zero height. 212 | 213 | scaled_bins = [(bin_value / max_bin) * hist_height for bin_value in bins] 214 | # Round up to ensure visibility of non-zero bins. 215 | bar_heights = [math.ceil(height) for height in scaled_bins] 216 | 217 | # Determine y-axis labels (from HIST_HEIGHT down to 1) 218 | for row in range(hist_height, 0, -1): 219 | label_value = (row / hist_height) * max_bin 220 | label = f"{label_value:>3.1f} |" 221 | row_str = label 222 | for height in bar_heights: 223 | if height >= row: 224 | row_str += f" {bar_char} " 225 | else: 226 | row_str += " " * 3 227 | result += row_str + "\n" 228 | 229 | x_axis = " +" + "---" * num_bins 230 | result += x_axis + "\n" 231 | 232 | # x-axis labels. 
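# Each bar occupies a 3-character column, so default labels are bin indices and custom labels
# must be at most 2 characters to stay aligned with their bars (hence the assert below).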
233 | if not bin_labels: 234 | bin_labels = [f"{i}" for i in range(num_bins)] 235 | label_str = " " 236 | for label in bin_labels: 237 | assert len(label) <= 2 238 | if len(label) == 1: 239 | label_str += f" {label} " 240 | else: 241 | label_str += f"{label} " 242 | result += label_str + "\n" 243 | return result 244 | 245 | def deep_shape(obj, seen=None, level=0, pretty=False): 246 | if seen is None: 247 | seen = set() 248 | if id(obj) in seen: 249 | return "" 250 | seen.add(id(obj)) 251 | 252 | def join_parts(parts): 253 | if pretty: 254 | return "\n" + " " * level + (",\n" + " " * level).join(parts) + "\n" + " " * (level - 1) 255 | return ", ".join(parts) 256 | 257 | if isinstance(obj, tuple): 258 | return "(" + join_parts([deep_shape(o, seen, level + 1, pretty) for o in obj]) + ")" 259 | if isinstance(obj, list): 260 | if all(isinstance(o, (int, float, str, bool, type(None))) for o in obj): 261 | type_counts = Counter(type(o).__name__ for o in obj) 262 | return f"[{', '.join(f'{k}-{v}' for k, v in type_counts.items())}]" 263 | return "[" + join_parts([deep_shape(o, seen, level + 1, pretty) for o in obj]) + "]" 264 | if isinstance(obj, dict): 265 | return "{" + join_parts([str(k) + ": " + deep_shape(v, seen, level + 1, pretty) for k, v in obj.items()]) + "}" 266 | if isinstance(obj, np.ndarray): 267 | return "np-" + str(obj.shape) 268 | if isinstance(obj, torch.Tensor): 269 | return "pt-" + str(tuple(obj.shape)) 270 | if isinstance(obj, str): 271 | return "str-" + str(len(obj)) 272 | return str(obj) 273 | 274 | 275 | def flush(): 276 | gc.collect() 277 | torch.cuda.empty_cache() 278 | torch.cuda.reset_peak_memory_stats() 279 | 280 | def strict_zip(a: list, b: list): 281 | if len(a) != len(b): 282 | raise Exception(f"List sizes differ ({len(a)} != {len(b)}).") 283 | return zip(a, b) 284 | 285 | TypeNode = TypeVar('TypeNode') 286 | def pretty_print_tree( 287 | root: TypeNode, 288 | get_children: Callable[[TypeNode], list[TypeNode]], 289 | node_to_str: Callable[[TypeNode], str], 290 | edge_to_str: Callable[[TypeNode], str | None] | None = None, 291 | max_label_len=55, 292 | max_edge_label_len=None, 293 | ) -> str: 294 | def trimmed_edge_to_str(e: TypeNode) -> str | None: 295 | if edge_to_str is None: 296 | return None 297 | s = edge_to_str(e) 298 | if max_edge_label_len is None: 299 | return s 300 | if s is None: 301 | return s 302 | if len(s) > max_edge_label_len: 303 | dots = "..." 
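# Truncate so the label plus the ellipsis still fits within max_edge_label_len.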
304 | return s[:max_edge_label_len - len(dots)] + dots 305 | return s 306 | 307 | pt = PrettyPrintTree( 308 | get_children=get_children, 309 | get_val=node_to_str, 310 | get_label=trimmed_edge_to_str, 311 | return_instead_of_print=True, 312 | # border=True, 313 | trim=max_label_len, 314 | ) 315 | return pt(root) 316 | 317 | class SimpleTimer: 318 | def __init__(self): 319 | self.times = {} 320 | self.start_times = {} 321 | 322 | def start(self, section: str): 323 | self.start_times[section] = time.perf_counter() 324 | 325 | def end(self, section: str): 326 | if section not in self.start_times: 327 | return 328 | elapsed = time.perf_counter() - self.start_times.pop(section) 329 | self.times[section] = self.times.get(section, 0.0) + elapsed 330 | 331 | def get_times(self) -> dict[str, float]: 332 | return self.times 333 | 334 | def log_times(self): 335 | if not self.times: 336 | return 337 | total = sum(self.times.values()) 338 | print0("Timer results:") 339 | max_len = max(len(k) for k in self.times) 340 | for k, v in sorted(self.times.items(), key=lambda x: x[1], reverse=True): 341 | pct = (v / total * 100) if total > 0 else 0 342 | print0(f" {k:<{max_len}} : {v:.4f}s ({pct:.1f}%)") 343 | 344 | def gather(self) -> Self: 345 | """Gather data from all ranks and return a new SimpleTimer with the aggregated (summed) times.""" 346 | if not (dist.is_available() and dist.is_initialized()): 347 | new_timer = SimpleTimer() 348 | new_timer.times = self.times.copy() 349 | return new_timer 350 | 351 | print0("Gathering timer data from all ranks...") 352 | world_size = dist.get_world_size() 353 | local_times = self.times 354 | all_times_list = [None for _ in range(world_size)] 355 | dist.all_gather_object(all_times_list, local_times) 356 | 357 | aggregated_times = {} 358 | for rank_times in all_times_list: 359 | if rank_times is None: continue 360 | for k, v in rank_times.items(): 361 | aggregated_times[k] = aggregated_times.get(k, 0.0) + v 362 | 363 | new_timer = SimpleTimer() 364 | new_timer.times = aggregated_times 365 | return new_timer 366 | 367 | class DummyTimer(SimpleTimer): 368 | def start(self, section: str): pass 369 | def end(self, section: str): pass 370 | def get_times(self) -> dict[str, float]: return {} 371 | def log_times(self): pass 372 | def gather(self) -> Self: return DummyTimer() -------------------------------------------------------------------------------- /nanoproof/report.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for generating training report cards. More messy code than usual, will fix.
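Typical flow (illustrative): `python -m nanoproof.report reset` at the start of a run writes
header.md with environment info, training and eval scripts record their sections via
get_report().log(...), and `python -m nanoproof.report generate` stitches everything into report.md.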
3 | """ 4 | 5 | import os 6 | import re 7 | import shutil 8 | import subprocess 9 | import socket 10 | import datetime 11 | import platform 12 | import psutil 13 | import torch 14 | 15 | def run_command(cmd): 16 | """Run a shell command and return output, or None if it fails.""" 17 | try: 18 | result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=5) 19 | if result.returncode == 0: 20 | return result.stdout.strip() 21 | return None 22 | except: 23 | return None 24 | 25 | def get_git_info(): 26 | """Get current git commit, branch, and dirty status.""" 27 | info = {} 28 | info['commit'] = run_command("git rev-parse --short HEAD") or "unknown" 29 | info['branch'] = run_command("git rev-parse --abbrev-ref HEAD") or "unknown" 30 | 31 | # Check if repo is dirty (has uncommitted changes) 32 | status = run_command("git status --porcelain") 33 | info['dirty'] = bool(status) if status is not None else False 34 | 35 | # Get commit message 36 | info['message'] = run_command("git log -1 --pretty=%B") or "" 37 | info['message'] = info['message'].split('\n')[0][:80] # First line, truncated 38 | 39 | return info 40 | 41 | def get_gpu_info(): 42 | """Get GPU information.""" 43 | if not torch.cuda.is_available(): 44 | return {"available": False} 45 | 46 | num_devices = torch.cuda.device_count() 47 | info = { 48 | "available": True, 49 | "count": num_devices, 50 | "names": [], 51 | "memory_gb": [] 52 | } 53 | 54 | for i in range(num_devices): 55 | props = torch.cuda.get_device_properties(i) 56 | info["names"].append(props.name) 57 | info["memory_gb"].append(props.total_memory / (1024**3)) 58 | 59 | # Get CUDA version 60 | info["cuda_version"] = torch.version.cuda or "unknown" 61 | 62 | return info 63 | 64 | def get_system_info(): 65 | """Get system information.""" 66 | info = {} 67 | 68 | # Basic system info 69 | info['hostname'] = socket.gethostname() 70 | info['platform'] = platform.system() 71 | info['python_version'] = platform.python_version() 72 | info['torch_version'] = torch.__version__ 73 | 74 | # CPU and memory 75 | info['cpu_count'] = psutil.cpu_count(logical=False) 76 | info['cpu_count_logical'] = psutil.cpu_count(logical=True) 77 | info['memory_gb'] = psutil.virtual_memory().total / (1024**3) 78 | 79 | # User and environment 80 | info['user'] = os.environ.get('USER', 'unknown') 81 | info['nanoproof_base_dir'] = os.environ.get('NANOPROOF_BASE_DIR', 'out') 82 | info['working_dir'] = os.getcwd() 83 | 84 | return info 85 | 86 | def estimate_cost(gpu_info, runtime_hours=None): 87 | """Estimate training cost based on GPU type and runtime.""" 88 | 89 | # Rough pricing, from Lambda Cloud 90 | default_rate = 2.0 91 | gpu_hourly_rates = { 92 | "H100": 3.00, 93 | "A100": 1.79, 94 | "V100": 0.55, 95 | } 96 | 97 | if not gpu_info.get("available"): 98 | return None 99 | 100 | # Try to identify GPU type from name 101 | hourly_rate = None 102 | gpu_name = gpu_info["names"][0] if gpu_info["names"] else "unknown" 103 | for gpu_type, rate in gpu_hourly_rates.items(): 104 | if gpu_type in gpu_name: 105 | hourly_rate = rate * gpu_info["count"] 106 | break 107 | 108 | if hourly_rate is None: 109 | hourly_rate = default_rate * gpu_info["count"] # Default estimate 110 | 111 | return { 112 | "hourly_rate": hourly_rate, 113 | "gpu_type": gpu_name, 114 | "estimated_total": hourly_rate * runtime_hours if runtime_hours else None 115 | } 116 | 117 | def generate_header(): 118 | """Generate the header for a training report.""" 119 | timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") 
120 | 121 | git_info = get_git_info() 122 | gpu_info = get_gpu_info() 123 | sys_info = get_system_info() 124 | cost_info = estimate_cost(gpu_info) 125 | 126 | header = f"""# nanoproof training report 127 | 128 | Generated: {timestamp} 129 | 130 | ## Environment 131 | 132 | ### Git Information 133 | - Branch: {git_info['branch']} 134 | - Commit: {git_info['commit']} {"(dirty)" if git_info['dirty'] else "(clean)"} 135 | - Message: {git_info['message']} 136 | 137 | ### Hardware 138 | - Platform: {sys_info['platform']} 139 | - CPUs: {sys_info['cpu_count']} cores ({sys_info['cpu_count_logical']} logical) 140 | - Memory: {sys_info['memory_gb']:.1f} GB 141 | """ 142 | 143 | if gpu_info.get("available"): 144 | gpu_names = ", ".join(set(gpu_info["names"])) 145 | total_vram = sum(gpu_info["memory_gb"]) 146 | header += f"""- GPUs: {gpu_info['count']}x {gpu_names} 147 | - GPU Memory: {total_vram:.1f} GB total 148 | - CUDA Version: {gpu_info['cuda_version']} 149 | """ 150 | else: 151 | header += "- GPUs: None available\n" 152 | 153 | if cost_info and cost_info["hourly_rate"] > 0: 154 | header += f"""- Hourly Rate: ${cost_info['hourly_rate']:.2f}/hour\n""" 155 | 156 | header += f""" 157 | ### Software 158 | - Python: {sys_info['python_version']} 159 | - PyTorch: {sys_info['torch_version']} 160 | 161 | """ 162 | 163 | # bloat metrics: package all of the source code and assess its weight 164 | packaged = run_command('files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml') 165 | num_chars = len(packaged) 166 | num_lines = len(packaged.split('\n')) 167 | num_files = len([x for x in packaged.split('\n') if x.startswith('<source>')]) # one <source> line per file in the cxml output 168 | num_tokens = num_chars // 4 # assume approximately 4 chars per token 169 | 170 | # count dependencies via uv.lock 171 | uv_lock_lines = 0 172 | if os.path.exists('uv.lock'): 173 | with open('uv.lock', 'r', encoding='utf-8') as f: 174 | uv_lock_lines = len(f.readlines()) 175 | 176 | header += f""" 177 | ### Bloat 178 | - Characters: {num_chars:,} 179 | - Lines: {num_lines:,} 180 | - Files: {num_files:,} 181 | - Tokens (approx): {num_tokens:,} 182 | - Dependencies (uv.lock lines): {uv_lock_lines:,} 183 | 184 | """ 185 | return header 186 | 187 | # ----------------------------------------------------------------------------- 188 | 189 | def slugify(text): 190 | """Slugify a text string.""" 191 | return text.lower().replace(" ", "-") 192 | 193 | # the expected files and their order 194 | EXPECTED_FILES = [ 195 | "tokenizer-training.md", 196 | "tokenizer-evaluation.md", 197 | "base-model-training.md", 198 | "base-model-loss.md", 199 | "base-model-evaluation.md", 200 | "midtraining.md", 201 | "chat-evaluation-mid.md", 202 | "chat-sft.md", 203 | "chat-evaluation-sft.md", 204 | "chat-rl.md", 205 | "chat-evaluation-rl.md", 206 | ] 207 | # the metrics we're currently interested in 208 | chat_metrics = ["ARC-Easy", "ARC-Challenge", "MMLU", "GSM8K", "HumanEval", "ChatCORE"] 209 | 210 | def extract(section, keys): 211 | """simple def to extract a single key from a section""" 212 | if not isinstance(keys, list): 213 | keys = [keys] # convenience 214 | out = {} 215 | for line in section.split("\n"): 216 | for key in keys: 217 | if key in line: 218 | out[key] = line.split(":")[1].strip() 219 | return out 220 | 221 | def extract_timestamp(content, prefix): 222 | """Extract timestamp from content with given prefix.""" 223 | for line in content.split('\n'): 224 | if line.startswith(prefix): 225 | time_str = line.split(":", 1)[1].strip() 226 | try: 227 | return
datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S") 228 | except: 229 | pass 230 | return None 231 | 232 | class Report: 233 | """Maintains a bunch of logs, generates a final markdown report.""" 234 | 235 | def __init__(self, report_dir): 236 | os.makedirs(report_dir, exist_ok=True) 237 | self.report_dir = report_dir 238 | 239 | def log(self, section, data): 240 | """Log a section of data to the report.""" 241 | slug = slugify(section) 242 | file_name = f"{slug}.md" 243 | file_path = os.path.join(self.report_dir, file_name) 244 | with open(file_path, "w", encoding="utf-8") as f: 245 | f.write(f"## {section}\n") 246 | f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n") 247 | for item in data: 248 | if not item: 249 | # skip falsy values like None or empty dict etc. 250 | continue 251 | if isinstance(item, str): 252 | # directly write the string 253 | f.write(item) 254 | else: 255 | # render a dict 256 | for k, v in item.items(): 257 | if isinstance(v, float): 258 | vstr = f"{v:.4f}" 259 | elif isinstance(v, int) and v >= 10000: 260 | vstr = f"{v:,.0f}" 261 | else: 262 | vstr = str(v) 263 | f.write(f"- {k}: {vstr}\n") 264 | f.write("\n") 265 | return file_path 266 | 267 | def generate(self): 268 | """Generate the final report.""" 269 | report_dir = self.report_dir 270 | report_file = os.path.join(report_dir, "report.md") 271 | print(f"Generating report to {report_file}") 272 | final_metrics = {} # the most important final metrics we'll add as table at the end 273 | start_time = None 274 | end_time = None 275 | with open(report_file, "w", encoding="utf-8") as out_file: 276 | # write the header first 277 | header_file = os.path.join(report_dir, "header.md") 278 | if os.path.exists(header_file): 279 | with open(header_file, "r", encoding="utf-8") as f: 280 | header_content = f.read() 281 | out_file.write(header_content) 282 | start_time = extract_timestamp(header_content, "Run started:") 283 | # capture bloat data for summary later (the stuff after Bloat header and until \n\n) 284 | bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL) 285 | bloat_data = bloat_data.group(1) if bloat_data else "" 286 | else: 287 | start_time = None # will cause us to not write the total wall clock time 288 | bloat_data = "[bloat data missing]" 289 | print(f"Warning: {header_file} does not exist. 
Did you forget to run `nanoproof reset`?") 290 | # process all the individual sections 291 | for file_name in EXPECTED_FILES: 292 | section_file = os.path.join(report_dir, file_name) 293 | if not os.path.exists(section_file): 294 | print(f"Warning: {section_file} does not exist, skipping") 295 | continue 296 | with open(section_file, "r", encoding="utf-8") as in_file: 297 | section = in_file.read() 298 | # Extract timestamp from this section (the last section's timestamp will "stick" as end_time) 299 | if "rl" not in file_name: 300 | # Skip RL sections for end_time calculation because RL is experimental 301 | end_time = extract_timestamp(section, "timestamp:") 302 | # extract the most important metrics from the sections 303 | if file_name == "base-model-evaluation.md": 304 | final_metrics["base"] = extract(section, "CORE") 305 | if file_name == "chat-evaluation-mid.md": 306 | final_metrics["mid"] = extract(section, chat_metrics) 307 | if file_name == "chat-evaluation-sft.md": 308 | final_metrics["sft"] = extract(section, chat_metrics) 309 | if file_name == "chat-evaluation-rl.md": 310 | final_metrics["rl"] = extract(section, "GSM8K") # RL only evals GSM8K 311 | # append this section of the report 312 | out_file.write(section) 313 | out_file.write("\n") 314 | # add the final metrics table 315 | out_file.write("## Summary\n\n") 316 | # Copy over the bloat metrics from the header 317 | out_file.write(bloat_data) 318 | out_file.write("\n\n") 319 | # Collect all unique metric names 320 | all_metrics = set() 321 | for stage_metrics in final_metrics.values(): 322 | all_metrics.update(stage_metrics.keys()) 323 | # Custom ordering: CORE first, ChatCORE last, rest in middle 324 | all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x)) 325 | # Fixed column widths 326 | stages = ["base", "mid", "sft", "rl"] 327 | metric_width = 15 328 | value_width = 8 329 | # Write table header 330 | header = f"| {'Metric'.ljust(metric_width)} |" 331 | for stage in stages: 332 | header += f" {stage.upper().ljust(value_width)} |" 333 | out_file.write(header + "\n") 334 | # Write separator 335 | separator = f"|{'-' * (metric_width + 2)}|" 336 | for stage in stages: 337 | separator += f"{'-' * (value_width + 2)}|" 338 | out_file.write(separator + "\n") 339 | # Write table rows 340 | for metric in all_metrics: 341 | row = f"| {metric.ljust(metric_width)} |" 342 | for stage in stages: 343 | value = final_metrics.get(stage, {}).get(metric, "-") 344 | row += f" {str(value).ljust(value_width)} |" 345 | out_file.write(row + "\n") 346 | out_file.write("\n") 347 | # Calculate and write total wall clock time 348 | if start_time and end_time: 349 | duration = end_time - start_time 350 | total_seconds = int(duration.total_seconds()) 351 | hours = total_seconds // 3600 352 | minutes = (total_seconds % 3600) // 60 353 | out_file.write(f"Total wall clock time: {hours}h{minutes}m\n") 354 | else: 355 | out_file.write("Total wall clock time: unknown\n") 356 | # also cp the report.md file to current directory 357 | print(f"Copying report.md to current directory for convenience") 358 | shutil.copy(report_file, "report.md") 359 | return report_file 360 | 361 | def reset(self): 362 | """Reset the report.""" 363 | # Remove section files 364 | for file_name in EXPECTED_FILES: 365 | file_path = os.path.join(self.report_dir, file_name) 366 | if os.path.exists(file_path): 367 | os.remove(file_path) 368 | # Remove report.md if it exists 369 | report_file = os.path.join(self.report_dir, "report.md") 370 | if 
os.path.exists(report_file): 371 | os.remove(report_file) 372 | # Generate and write the header section with start timestamp 373 | header_file = os.path.join(self.report_dir, "header.md") 374 | header = generate_header() 375 | start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") 376 | with open(header_file, "w", encoding="utf-8") as f: 377 | f.write(header) 378 | f.write(f"Run started: {start_time}\n\n---\n\n") 379 | print(f"Reset report and wrote header to {header_file}") 380 | 381 | # ----------------------------------------------------------------------------- 382 | # nanoproof-specific convenience functions 383 | 384 | class DummyReport: 385 | def log(self, *args, **kwargs): 386 | pass 387 | def reset(self, *args, **kwargs): 388 | pass 389 | 390 | def get_report(): 391 | # just for convenience, only rank 0 logs to report 392 | from nanoproof.common import get_base_dir, get_dist_info 393 | ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() 394 | if ddp_rank == 0: 395 | report_dir = os.path.join(get_base_dir(), "report") 396 | return Report(report_dir) 397 | else: 398 | return DummyReport() 399 | 400 | if __name__ == "__main__": 401 | import argparse 402 | parser = argparse.ArgumentParser(description="Generate or reset nanoproof training reports.") 403 | parser.add_argument("command", nargs="?", default="generate", choices=["generate", "reset"], help="Operation to perform (default: generate)") 404 | args = parser.parse_args() 405 | if args.command == "generate": 406 | get_report().generate() 407 | elif args.command == "reset": 408 | get_report().reset() -------------------------------------------------------------------------------- /scripts/tok_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluate compression ratio of the tokenizer. 3 | """ 4 | 5 | from nanoproof.tokenizer import get_tokenizer, HuggingFaceTokenizer 6 | from nanoproof.data.nemotron import parquets_iter_batched 7 | from nanoproof.data.leangithubraw import iter_texts_batched 8 | 9 | # Random text I got from a random website this morning 10 | news_text = r""" 11 | (Washington, D.C., July 9, 2025)- Yesterday, Mexico’s National Service of Agro-Alimentary Health, Safety, and Quality (SENASICA) reported a new case of New World Screwworm (NWS) in Ixhuatlan de Madero, Veracruz in Mexico, which is approximately 160 miles northward of the current sterile fly dispersal grid, on the eastern side of the country and 370 miles south of the U.S./Mexico border. This new northward detection comes approximately two months after northern detections were reported in Oaxaca and Veracruz, less than 700 miles away from the U.S. border, which triggered the closure of our ports to Mexican cattle, bison, and horses on May 11, 2025. 12 | 13 | While USDA announced a risk-based phased port re-opening strategy for cattle, bison, and equine from Mexico beginning as early as July 7, 2025, this newly reported NWS case raises significant concern about the previously reported information shared by Mexican officials and severely compromises the outlined port reopening schedule of five ports from July 7-September 15. Therefore, in order to protect American livestock and our nation’s food supply, Secretary Rollins has ordered the closure of livestock trade through southern ports of entry effective immediately. 
14 | 15 | “The United States has promised to be vigilant — and after detecting this new NWS case, we are pausing the planned port reopening’s to further quarantine and target this deadly pest in Mexico. We must see additional progress combatting NWS in Veracruz and other nearby Mexican states in order to reopen livestock ports along the Southern border,” said U.S. Secretary of Agriculture Brooke L. Rollins. “Thanks to the aggressive monitoring by USDA staff in the U.S. and in Mexico, we have been able to take quick and decisive action to respond to the spread of this deadly pest.” 16 | """.strip() 17 | 18 | # Random Korean text (to test non-English compression) 19 | korean_text = r""" 20 | 정직한 사실 위에, 공정한 시선을 더하다 21 | Herald Korea Times 22 | 23 | 헤럴드코리아타임즈는 정치, 경제, 사회, 문화 등 한국 사회 전반의 주요 이슈를 심도 있게 다루는 종합 온라인 신문사입니다. 24 | 25 | 우리는 단순히 뉴스를 전달하는 것이 아니라, 사실(Fact)에 기반한 양측의 시각을 균형 있게 조명하며, 독자 여러분이 스스로 판단할 수 있는 ‘정보의 균형’을 제공합니다. 26 | 27 | 한국 언론의 오랜 문제로 지적되어 온 정치적 편향, 이념적 왜곡에서 벗어나 28 | 오직 정직함과 공정함을 원칙으로 삼는 언론을 지향합니다. 29 | 어느 한쪽의 주장만을 확대하거나 감추지 않고, 30 | **모든 쟁점에 대해 ‘무엇이 쟁점인지’, ‘누가 무엇을 주장하는지’, ‘사실은 무엇인지’**를 명확히 전달하는 데 집중합니다. 31 | """.strip() 32 | 33 | # Random piece of code 34 | code_text = r""" 35 | class BasicTokenizer(Tokenizer): 36 | 37 | def __init__(self): 38 | super().__init__() 39 | 40 | def train(self, text, vocab_size, verbose=False): 41 | assert vocab_size >= 256 42 | num_merges = vocab_size - 256 43 | 44 | # input text preprocessing 45 | text_bytes = text.encode("utf-8") # raw bytes 46 | ids = list(text_bytes) # list of integers in range 0..255 47 | 48 | # iteratively merge the most common pairs to create new tokens 49 | merges = {} # (int, int) -> int 50 | vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes 51 | for i in range(num_merges): 52 | # count up the number of times every consecutive pair appears 53 | stats = get_stats(ids) 54 | # find the pair with the highest count 55 | pair = max(stats, key=stats.get) 56 | # mint a new token: assign it the next available id 57 | idx = 256 + i 58 | # replace all occurrences of pair in ids with idx 59 | ids = merge(ids, pair, idx) 60 | # save the merge 61 | merges[pair] = idx 62 | vocab[idx] = vocab[pair[0]] + vocab[pair[1]] 63 | # prints 64 | if verbose: 65 | print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences") 66 | """.strip() 67 | 68 | math_text = r""" 69 | \documentclass[12pt]{article} 70 | \usepackage{amsmath,amsthm,amssymb} 71 | \usepackage[margin=1in]{geometry} 72 | 73 | \newtheorem{theorem}{Theorem} 74 | \newtheorem*{remark}{Remark} 75 | 76 | \begin{document} 77 | 78 | \begin{center} 79 | {\Large A Cute Identity: The Sum of Cubes is a Square} 80 | \end{center} 81 | 82 | \begin{theorem} 83 | For every integer $n \ge 1$, 84 | \[ 85 | \sum_{k=1}^{n} k^{3} \;=\; \left(\frac{n(n+1)}{2}\right)^{2}. 86 | \] 87 | \end{theorem} 88 | 89 | \begin{proof}[Proof 1 (Induction)] 90 | Let $S(n) = \sum_{k=1}^{n} k^3$. For $n=1$, $S(1)=1=(1\cdot 2/2)^2$, so the base case holds. 91 | 92 | Assume $S(n)=\big(\tfrac{n(n+1)}{2}\big)^2$ for some $n\ge 1$. 93 | Then 94 | \[ 95 | S(n+1) 96 | = S(n) + (n+1)^3 97 | = \left(\frac{n(n+1)}{2}\right)^2 + (n+1)^3. 98 | \] 99 | Factor out $(n+1)^2$: 100 | \[ 101 | S(n+1) 102 | = (n+1)^2\left( \frac{n^2}{4} + (n+1) \right) 103 | = (n+1)^2\left( \frac{n^2 + 4n + 4}{4} \right) 104 | = (n+1)^2\left( \frac{(n+2)^2}{4} \right). 
105 | \] 106 | Thus 107 | \[ 108 | S(n+1)=\left(\frac{(n+1)(n+2)}{2}\right)^2, 109 | \] 110 | which matches the claimed formula with $n$ replaced by $n+1$. By induction, the identity holds for all $n\ge 1$. 111 | \end{proof} 112 | 113 | \begin{proof}[Proof 2 (Algebraic telescoping)] 114 | Recall the binomial identity 115 | \[ 116 | (k+1)^4 - k^4 = 4k^3 + 6k^2 + 4k + 1. 117 | \] 118 | Summing both sides from $k=0$ to $n$ telescopes: 119 | \[ 120 | (n+1)^4 - 0^4 121 | = \sum_{k=0}^{n}\big(4k^3 + 6k^2 + 4k + 1\big) 122 | = 4\sum_{k=1}^{n}k^3 + 6\sum_{k=1}^{n}k^2 + 4\sum_{k=1}^{n}k + (n+1). 123 | \] 124 | Using the standard sums 125 | \[ 126 | \sum_{k=1}^{n}k = \frac{n(n+1)}{2} 127 | \quad\text{and}\quad 128 | \sum_{k=1}^{n}k^2 = \frac{n(n+1)(2n+1)}{6}, 129 | \] 130 | solve for $\sum_{k=1}^{n}k^3$ to get 131 | \[ 132 | \sum_{k=1}^{n}k^3 = \left(\frac{n(n+1)}{2}\right)^2. 133 | \] 134 | \end{proof} 135 | 136 | \begin{remark} 137 | Geometrically, the identity says: ``adding up $1^3,2^3,\dots,n^3$ builds a perfect square’’—namely the square of the $n$th triangular number. This is why one sometimes calls it the \emph{sum-of-cubes is a square} phenomenon. 138 | \end{remark} 139 | 140 | \end{document} 141 | """.strip() 142 | 143 | science_text = r""" 144 | Photosynthesis is a photochemical energy transduction process in which light-harvesting pigment–protein complexes within the thylakoid membranes of oxygenic phototrophs absorb photons and initiate charge separation at the reaction center, driving the linear electron transport chain from water to NADP⁺ via photosystem II, the cytochrome b₆f complex, and photosystem I, concomitantly generating a trans-thylakoid proton motive force utilized by chloroplastic ATP synthase. The light-dependent reactions produce ATP and NADPH, which fuel the Calvin–Benson–Bassham cycle in the stroma, wherein ribulose-1,5-bisphosphate is carboxylated by ribulose-1,5-bisphosphate carboxylase/oxygenase (RuBisCO) to form 3-phosphoglycerate, subsequently reduced and regenerated through a series of enzymatic steps, enabling net assimilation of CO₂ into triose phosphates and ultimately carbohydrates. This process is tightly regulated by photoprotective mechanisms, redox feedback, and metabolite flux, representing a central biochemical pathway coupling solar energy capture to the biosphere’s primary productivity. 145 | """.strip() 146 | 147 | lean_text = r""" 148 | @[expose] public section 149 | 150 | universe u v w u₁ v₁ 151 | 152 | /-- Defining the homomorphism in the category R-Alg, denoted `A →ₐ[R] B`. -/ 153 | structure AlgHom (R : Type u) (A : Type v) (B : Type w) [CommSemiring R] [Semiring A] [Semiring B] 154 | [Algebra R A] [Algebra R B] extends RingHom A B where 155 | commutes' : ∀ r : R, toFun (algebraMap R A r) = algebraMap R B r 156 | 157 | /-- Reinterpret an `AlgHom` as a `RingHom` -/ 158 | add_decl_doc AlgHom.toRingHom 159 | 160 | @[inherit_doc AlgHom] 161 | infixr:25 " →ₐ " => AlgHom _ 162 | 163 | @[inherit_doc] 164 | notation:25 A " →ₐ[" R "] " B => AlgHom R A B 165 | 166 | /-- The algebra morphism underlying `algebraMap` -/ 167 | def Algebra.algHom (R A B : Type*) 168 | [CommSemiring R] [CommSemiring A] [Semiring B] [Algebra R A] [Algebra R B] 169 | [Algebra A B] [IsScalarTower R A B] : 170 | A →ₐ[R] B where 171 | toRingHom := algebraMap A B 172 | commutes' r := by simpa [Algebra.smul_def] using smul_assoc r (1 : A) (1 : B) 173 | 174 | /-- `AlgHomClass F R A B` asserts `F` is a type of bundled algebra homomorphisms 175 | from `A` to `B`. 
-/ 176 | class AlgHomClass (F : Type*) (R A B : outParam Type*) 177 | [CommSemiring R] [Semiring A] [Semiring B] [Algebra R A] [Algebra R B] [FunLike F A B] : Prop 178 | extends RingHomClass F A B where 179 | commutes : ∀ (f : F) (r : R), f (algebraMap R A r) = algebraMap R B r 180 | 181 | -- For now, don't replace `AlgHom.commutes` and `AlgHomClass.commutes` with the more generic lemma. 182 | -- The file `Mathlib/NumberTheory/NumberField/CanonicalEmbedding/FundamentalCone.lean` slows down by 183 | -- 15% if we would do so (see benchmark on PR https://github.com/leanprover-community/mathlib4/pull/18040). 184 | -- attribute [simp] AlgHomClass.commutes 185 | 186 | namespace AlgHomClass 187 | 188 | variable {R A B F : Type*} [CommSemiring R] [Semiring A] [Semiring B] 189 | [Algebra R A] [Algebra R B] [FunLike F A B] 190 | 191 | -- see Note [lower instance priority] 192 | instance (priority := 100) linearMapClass [AlgHomClass F R A B] : LinearMapClass F R A B := 193 | { ‹AlgHomClass F R A B› with 194 | map_smulₛₗ := fun f r x => by 195 | simp only [Algebra.smul_def, map_mul, commutes, RingHom.id_apply] } 196 | 197 | /-- Turn an element of a type `F` satisfying `AlgHomClass F α β` into an actual 198 | `AlgHom`. This is declared as the default coercion from `F` to `α →+* β`. -/ 199 | @[coe] 200 | def toAlgHom {F : Type*} [FunLike F A B] [AlgHomClass F R A B] (f : F) : A →ₐ[R] B where 201 | __ := (f : A →+* B) 202 | toFun := f 203 | commutes' := AlgHomClass.commutes f 204 | 205 | instance coeTC {F : Type*} [FunLike F A B] [AlgHomClass F R A B] : CoeTC F (A →ₐ[R] B) := 206 | ⟨AlgHomClass.toAlgHom⟩ 207 | 208 | end AlgHomClass 209 | 210 | namespace AlgHom 211 | 212 | variable {R : Type u} {A : Type v} {B : Type w} {C : Type u₁} {D : Type v₁} 213 | 214 | section Semiring 215 | 216 | variable [CommSemiring R] [Semiring A] [Semiring B] [Semiring C] [Semiring D] 217 | variable [Algebra R A] [Algebra R B] [Algebra R C] [Algebra R D] 218 | 219 | instance funLike : FunLike (A →ₐ[R] B) A B where 220 | coe f := f.toFun 221 | coe_injective' f g h := by 222 | rcases f with ⟨⟨⟨⟨_, _⟩, _⟩, _, _⟩, _⟩ 223 | rcases g with ⟨⟨⟨⟨_, _⟩, _⟩, _, _⟩, _⟩ 224 | congr 225 | 226 | instance algHomClass : AlgHomClass (A →ₐ[R] B) R A B where 227 | map_add f := f.map_add' 228 | map_zero f := f.map_zero' 229 | map_mul f := f.map_mul' 230 | map_one f := f.map_one' 231 | commutes f := f.commutes' 232 | """ 233 | 234 | lean_search_text = r""" 235 | x : ℝ 236 | ⊢ x ^ 2 - 2 * x - 24 < 0 ↔ x ∈ Set.Ioo (-4) 6 237 | 238 | exact ⟨fun h ↦ by rw [Set.mem_Ioo]; constructor <;> nlinarith [h], fun h ↦ by rw [Set.mem_Ioo] at h; nlinarith⟩ 239 | 240 | ⊢ ∀ (x : ℝ), 2⁻¹ + cos (2 * (2 * x)) / 2 = (1 + cos (4 * x)) / 2 241 | 242 | ring 243 | 244 | case h 245 | ι : Type u_4 246 | inst✝ : Fintype ι 247 | f : ℝ → ι → ℝ 248 | s : Set ℝ 249 | h : LocallyBoundedVariationOn f s 250 | A : ∀ (i : ι), LipschitzWith 1 fun x => x i 251 | i : ι 252 | ⊢ LocallyBoundedVariationOn (fun x => f x i) s 253 | 254 | exact LipschitzWith.comp_locallyBoundedVariationOn (A i) h 255 | 256 | p q : Prop 257 | ⊢ p ∧ q → p 258 | 259 | intro h 260 | 261 | case mp.inl 262 | p q r : Prop 263 | hp : p 264 | hq : q 265 | ⊢ p ∧ q ∨ p ∧ r 266 | 267 | exact Or.inl ⟨hp, hq⟩ 268 | 269 | α : Type 270 | P : α → Prop 271 | inst✝ : Inhabited α 272 | h : ∀ (x : α), P x 273 | x0 : α := default 274 | hx0 : P x0 275 | ⊢ ∃ x, P x 276 | 277 | exact Exists.intro x0 hx0 278 | """ 279 | 280 | # The tokenizer was trained on data from earlier shards, so it has seen this data 281 | nemotron_train_docs = 
next(parquets_iter_batched(split="train")) 282 | nemotron_train_text = "\n".join(nemotron_train_docs) 283 | nemotron_val_docs = next(parquets_iter_batched(split="val")) 284 | nemotron_val_text = "\n".join(nemotron_val_docs) 285 | 286 | leangithubraw_train_docs = next(iter_texts_batched(split="train")) 287 | leangithubraw_train_text = "\n".join(leangithubraw_train_docs) 288 | leangithubraw_val_docs = next(iter_texts_batched(split="val")) 289 | leangithubraw_val_text = "\n".join(leangithubraw_val_docs) 290 | 291 | all_text = [ 292 | ("news", news_text), 293 | ("korean", korean_text), 294 | ("code", code_text), 295 | ("math", math_text), 296 | ("science", science_text), 297 | ("nm-train", nemotron_train_text), 298 | ("lg-train", leangithubraw_train_text), 299 | ] 300 | if nemotron_val_text: 301 | all_text.append(("nm-val", nemotron_val_text)) 302 | if leangithubraw_val_text: 303 | all_text.append(("lg-val", leangithubraw_val_text)) 304 | 305 | # Try out current default compared to GPT-2 and GPT-4 tokenizers 306 | tokenizer_results = {} 307 | vocab_sizes = {} 308 | 309 | # for tokenizer_name in ["gpt2", "gpt4", "ours"]: 310 | for tokenizer_name in ["gpt2", "nemotron", "deepseek", "ours"]: 311 | print(f"Evaluating {tokenizer_name}...") 312 | if tokenizer_name == "gpt2": 313 | tokenizer = HuggingFaceTokenizer.from_pretrained("gpt2") # gpt-2 base model tokenizer 314 | # elif tokenizer_name == "gpt4": 315 | # tokenizer = HuggingFaceTokenizer.from_pretrained("cl100k_base") # gpt-4 base model tokenizer 316 | elif tokenizer_name == "nemotron": 317 | tokenizer = HuggingFaceTokenizer.from_pretrained("nvidia/NVIDIA-Nemotron-Nano-9B-v2") 318 | elif tokenizer_name == "deepseek": 319 | tokenizer = HuggingFaceTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3.2") 320 | else: 321 | tokenizer = get_tokenizer() 322 | 323 | vocab_sizes[tokenizer_name] = tokenizer.get_vocab_size() 324 | tokenizer_results[tokenizer_name] = {} 325 | 326 | for name, text in all_text: 327 | encoded = tokenizer.encode(text) 328 | decoded = tokenizer.decode(encoded) 329 | assert decoded == text 330 | 331 | encoded_bytes = text.encode('utf-8') 332 | ratio = len(encoded_bytes) / len(encoded) 333 | tokenizer_results[tokenizer_name][name] = { 334 | 'bytes': len(encoded_bytes), 335 | 'tokens': len(encoded), 336 | 'ratio': ratio 337 | } 338 | 339 | # ANSI color codes 340 | GREEN = '\033[92m' 341 | RED = '\033[91m' 342 | RESET = '\033[0m' 343 | 344 | # Print vocab sizes 345 | print(f"\nVocab sizes:") 346 | print(f"GPT-2: {vocab_sizes['gpt2']}") 347 | print(f"Nemotron: {vocab_sizes['nemotron']}") 348 | print(f"DeepSeek: {vocab_sizes['deepseek']}") 349 | print(f"Ours: {vocab_sizes['ours']}") 350 | 351 | def print_comparison(baseline_name, baseline_results, ours_results, all_text): 352 | """Print comparison table between baseline tokenizer and ours.""" 353 | print(f"\nComparison with {baseline_name}:") 354 | print("=" * 95) 355 | print(f"{'Text Type':<10} {'Bytes':<8} {baseline_name:<15} {'Ours':<15} {'Relative':<12} {'Better':<10}") 356 | print(f"{'':10} {'':8} {'Tokens':<7} {'Ratio':<7} {'Tokens':<7} {'Ratio':<7} {'Diff %':<12}") 357 | print("-" * 95) 358 | 359 | for name, text in all_text: 360 | baseline_data = baseline_results[name] 361 | ours_data = ours_results[name] 362 | 363 | # Calculate relative difference (positive means ours is better, negative means worse) 364 | # Using tokens: fewer tokens is better, so we calculate (baseline_tokens - ours_tokens) / baseline_tokens 365 | relative_diff = ((baseline_data['tokens'] - 
ours_data['tokens']) / baseline_data['tokens']) * 100 366 | 367 | # Determine which has better compression (higher ratio = better) 368 | if baseline_data['ratio'] > ours_data['ratio']: 369 | baseline_color, ours_color = GREEN, RED 370 | better = baseline_name 371 | diff_color = RED 372 | elif ours_data['ratio'] > baseline_data['ratio']: 373 | baseline_color, ours_color = RED, GREEN 374 | better = "Ours" 375 | diff_color = GREEN 376 | else: 377 | baseline_color, ours_color = "", "" 378 | better = "Tie" 379 | diff_color = "" 380 | 381 | print(f"{name:<10} {baseline_data['bytes']:<8} " 382 | f"{baseline_color}{baseline_data['tokens']:<7}{RESET} " 383 | f"{baseline_color}{baseline_data['ratio']:<7.2f}{RESET} " 384 | f"{ours_color}{ours_data['tokens']:<7}{RESET} " 385 | f"{ours_color}{ours_data['ratio']:<7.2f}{RESET} " 386 | f"{diff_color}{relative_diff:+7.1f}%{RESET} " 387 | f"{better:<10}") 388 | 389 | # Print comparisons 390 | print_comparison("GPT-2", tokenizer_results['gpt2'], tokenizer_results['ours'], all_text) 391 | print_comparison("Nemotron", tokenizer_results['nemotron'], tokenizer_results['ours'], all_text) 392 | print_comparison("DeepSeek", tokenizer_results['deepseek'], tokenizer_results['ours'], all_text) 393 | 394 | # Log to report 395 | from nanoproof.report import get_report 396 | lines = [] 397 | for baseline_name in ["GPT-2", "Nemotron", "DeepSeek"]: 398 | baseline_key = baseline_name.lower().replace('-', '') 399 | baseline_results = tokenizer_results[baseline_key] 400 | ours_results = tokenizer_results['ours'] 401 | lines.append(f"### Comparison with {baseline_name}") 402 | lines.append("") 403 | lines.append("| Text Type | Bytes | " + baseline_name + " Tokens | " + baseline_name + " Ratio | Ours Tokens | Ours Ratio | Relative Diff % |") 404 | lines.append("|-----------|-------|--------------|--------------|-------------|------------|-----------------|") 405 | for name, text in all_text: 406 | baseline_data = baseline_results[name] 407 | ours_data = ours_results[name] 408 | relative_diff = ((baseline_data['tokens'] - ours_data['tokens']) / baseline_data['tokens']) * 100 409 | lines.append(f"| {name} | {baseline_data['bytes']} | {baseline_data['tokens']} | {baseline_data['ratio']:.2f} | {ours_data['tokens']} | {ours_data['ratio']:.2f} | {relative_diff:+.1f}% |") 410 | lines.append("") 411 | report_markdown = "\n".join(lines) 412 | get_report().log(section="Tokenizer evaluation", data=[ 413 | report_markdown, 414 | ]) --------------------------------------------------------------------------------