├── rope_imaginary ├── nanochat_imaginary │ ├── .python-version │ ├── nanochat │ │ ├── __init__.py │ │ ├── configurator.py │ │ ├── loss_eval.py │ │ ├── adamw.py │ │ ├── dataloader.py │ │ ├── dataset.py │ │ ├── checkpoint_manager.py │ │ └── common.py │ ├── .gitignore │ ├── rustbpe │ │ ├── Cargo.toml │ │ └── README.md │ ├── README.md │ ├── LICENSE │ ├── pyproject.toml │ ├── tasks │ │ ├── smoltalk.py │ │ ├── arc.py │ │ ├── customjson.py │ │ ├── humaneval.py │ │ ├── mmlu.py │ │ ├── gsm8k.py │ │ └── common.py │ ├── scripts │ │ ├── base_loss.py │ │ └── tok_train.py │ ├── run1000.sh │ └── speedrun.sh ├── rope_pp │ ├── configs │ │ ├── rope-376m-config.json │ │ ├── rope-1_5b-config.json │ │ └── rope-776m-config.json │ ├── pyproject.toml │ ├── speedrun-eval.sh │ ├── README.md │ ├── speedrun-1gpu.sh │ ├── utils │ │ ├── trainer_utils.py │ │ └── callback_utils.py │ ├── speedrun-8gpu.sh │ ├── .gitignore │ └── single-gpu │ │ └── train_rope_pp_single_gpu.py └── README.md ├── tiny_recursive_models ├── docs │ ├── assets │ │ ├── .DS_Store │ │ ├── TRM_fig.png │ │ ├── images │ │ │ ├── val.png │ │ │ └── train.png │ │ ├── TRM_pseudocode.png │ │ └── npyjs.js │ ├── plots │ │ ├── image.png │ │ ├── claims_vs_achieved_combined.png │ │ ├── claims_vs_achieved_1-maze-hard.png │ │ ├── claims_vs_achieved_trm-mlp-variant.png │ │ ├── claims_vs_achieved_trm-attention-variant.png │ │ └── claims_vs_achieved_2-the-abstraction-and-reasoning-challenge-arc.png │ └── hf_model_cards │ │ ├── model_card_arc_agi_1.md │ │ ├── model_card_sudoku.md │ │ └── model_card_maze.md ├── .gitignore ├── src │ └── tiny_recursive_models │ │ ├── utils │ │ ├── __init__.py │ │ └── functions.py │ │ ├── evaluation │ │ ├── __init__.py │ │ └── evaluator.py │ │ ├── training │ │ ├── adam_atan2_csrc │ │ │ ├── ops.cu │ │ │ ├── adam_atan2.h │ │ │ └── adam_atan2.cu │ │ ├── __init__.py │ │ ├── config.py │ │ ├── checkpoint.py │ │ └── optimizers.py │ │ ├── __init__.py │ │ ├── data │ │ ├── __init__.py │ │ ├── common.py │ │ ├── build_maze_dataset.py │ │ └── build_sudoku_dataset.py │ │ └── models │ │ ├── __init__.py │ │ ├── architectures │ │ └── __init__.py │ │ ├── common.py │ │ ├── ema.py │ │ ├── losses.py │ │ ├── sparse_embedding.py │ │ └── layers.py ├── config │ ├── arch │ │ ├── transformers_baseline.yaml │ │ ├── hrm.yaml │ │ ├── trm.yaml │ │ ├── trm_hier6.yaml │ │ ├── trm_singlez.yaml │ │ └── trm_mlp.yaml │ └── cfg_pretrain.yaml ├── pyproject.toml ├── scripts │ └── cmd.sh └── README.md └── README.md /rope_imaginary/nanochat_imaginary/.python-version: -------------------------------------------------------------------------------- 1 | 3.10 2 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/nanochat/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tiny_recursive_models/docs/assets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alphaXiv/paper-implementations/HEAD/tiny_recursive_models/docs/assets/.DS_Store -------------------------------------------------------------------------------- /tiny_recursive_models/docs/plots/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alphaXiv/paper-implementations/HEAD/tiny_recursive_models/docs/plots/image.png 
-------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/.gitignore: -------------------------------------------------------------------------------- 1 | .venv/ 2 | __pycache__/ 3 | *.pyc 4 | rustbpe/target/ 5 | dev-ignore/ 6 | report.md 7 | eval_bundle/ 8 | wandb/ -------------------------------------------------------------------------------- /tiny_recursive_models/docs/assets/TRM_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alphaXiv/paper-implementations/HEAD/tiny_recursive_models/docs/assets/TRM_fig.png -------------------------------------------------------------------------------- /tiny_recursive_models/docs/assets/images/val.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alphaXiv/paper-implementations/HEAD/tiny_recursive_models/docs/assets/images/val.png -------------------------------------------------------------------------------- /tiny_recursive_models/docs/assets/images/train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alphaXiv/paper-implementations/HEAD/tiny_recursive_models/docs/assets/images/train.png -------------------------------------------------------------------------------- /tiny_recursive_models/docs/assets/TRM_pseudocode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alphaXiv/paper-implementations/HEAD/tiny_recursive_models/docs/assets/TRM_pseudocode.png -------------------------------------------------------------------------------- /tiny_recursive_models/docs/plots/claims_vs_achieved_combined.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alphaXiv/paper-implementations/HEAD/tiny_recursive_models/docs/plots/claims_vs_achieved_combined.png -------------------------------------------------------------------------------- /tiny_recursive_models/docs/plots/claims_vs_achieved_1-maze-hard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alphaXiv/paper-implementations/HEAD/tiny_recursive_models/docs/plots/claims_vs_achieved_1-maze-hard.png -------------------------------------------------------------------------------- /tiny_recursive_models/docs/plots/claims_vs_achieved_trm-mlp-variant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alphaXiv/paper-implementations/HEAD/tiny_recursive_models/docs/plots/claims_vs_achieved_trm-mlp-variant.png -------------------------------------------------------------------------------- /tiny_recursive_models/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | trained_models/ 3 | 4 | # Python cache files 5 | __pycache__/ 6 | *.pyc 7 | checkpoints/ 8 | *.egg-info 9 | 10 | # Data and outputs 11 | /data/ 12 | /outputs/ 13 | /wandb/ -------------------------------------------------------------------------------- /tiny_recursive_models/docs/plots/claims_vs_achieved_trm-attention-variant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alphaXiv/paper-implementations/HEAD/tiny_recursive_models/docs/plots/claims_vs_achieved_trm-attention-variant.png 
-------------------------------------------------------------------------------- /tiny_recursive_models/docs/plots/claims_vs_achieved_2-the-abstraction-and-reasoning-challenge-arc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alphaXiv/paper-implementations/HEAD/tiny_recursive_models/docs/plots/claims_vs_achieved_2-the-abstraction-and-reasoning-challenge-arc.png -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utility functions.""" 2 | 3 | from tiny_recursive_models.utils.functions import load_model_class, get_model_source_path 4 | 5 | __all__ = [ 6 | "load_model_class", 7 | "get_model_source_path", 8 | ] 9 | -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | """Evaluation utilities.""" 2 | 3 | from tiny_recursive_models.evaluation.evaluator import ( 4 | evaluate, 5 | create_evaluators, 6 | ) 7 | from tiny_recursive_models.evaluation.arc import ARC 8 | 9 | __all__ = [ 10 | "evaluate", 11 | "create_evaluators", 12 | "ARC", 13 | ] 14 | -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/training/adam_atan2_csrc/ops.cu: -------------------------------------------------------------------------------- 1 | #include "adam_atan2.h" 2 | 3 | #include 4 | 5 | 6 | namespace adam_atan2 { 7 | 8 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 9 | m.def("adam_atan2_cuda_impl_", &adam_atan2_cuda_impl_, "Adam-atan2 Fused Implementation"); 10 | } 11 | 12 | } // namespace adam_atan2 13 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/rustbpe/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rustbpe" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | dary_heap = "0.3" 8 | indexmap = "2.2" 9 | fancy-regex = "0.16.1" 10 | log = "0.4.28" 11 | pyo3 = { version = "0.23.3", features = ["extension-module"] } 12 | pyo3-log = "0.12.4" 13 | ahash = "0.8.12" 14 | rayon = "1.11.0" 15 | compact_str = "0.9.0" 16 | -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tiny Recursive Models (TRM) 3 | 4 | A recursive reasoning approach for solving complex puzzles with tiny neural networks. 
5 | """ 6 | 7 | __version__ = "0.1.0" 8 | 9 | # Import subpackages 10 | from tiny_recursive_models import models 11 | from tiny_recursive_models import utils 12 | 13 | __all__ = [ 14 | "__version__", 15 | "models", 16 | "utils", 17 | ] 18 | -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/data/__init__.py: -------------------------------------------------------------------------------- 1 | """Data utilities and dataset builders.""" 2 | 3 | from tiny_recursive_models.data.puzzle_dataset import ( 4 | PuzzleDataset, 5 | PuzzleDatasetConfig, 6 | ) 7 | from tiny_recursive_models.data.common import ( 8 | PuzzleDatasetMetadata, 9 | ) 10 | 11 | __all__ = [ 12 | "PuzzleDataset", 13 | "PuzzleDatasetConfig", 14 | "PuzzleDatasetMetadata", 15 | ] 16 | -------------------------------------------------------------------------------- /tiny_recursive_models/config/arch/transformers_baseline.yaml: -------------------------------------------------------------------------------- 1 | name: architectures.transformers_baseline@Model_ACTV2 2 | loss: 3 | name: losses@ACTLossHead 4 | loss_type: stablemax_cross_entropy 5 | 6 | halt_exploration_prob: 0.1 7 | halt_max_steps: 16 8 | 9 | H_cycles: 1 # kept for compatibility 10 | H_layers: 8 11 | 12 | hidden_size: 512 13 | num_heads: 12 14 | expansion: 4 15 | 16 | puzzle_emb_ndim: ${.hidden_size} 17 | 18 | pos_encodings: rope -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/training/adam_atan2_csrc/adam_atan2.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | namespace adam_atan2 { 5 | 6 | void adam_atan2_cuda_impl_( 7 | std::vector params, 8 | std::vector grads, 9 | std::vector exp_avgs, 10 | std::vector exp_avg_sqs, 11 | std::vector state_steps, 12 | const double lr, 13 | const double beta1, 14 | const double beta2, 15 | const double weight_decay); 16 | 17 | } // namespace adam_atan2 18 | -------------------------------------------------------------------------------- /tiny_recursive_models/config/arch/hrm.yaml: -------------------------------------------------------------------------------- 1 | name: architectures.hrm@HierarchicalReasoningModel_ACTV1 2 | loss: 3 | name: losses@ACTLossHead 4 | loss_type: stablemax_cross_entropy 5 | 6 | halt_exploration_prob: 0.1 7 | halt_max_steps: 16 8 | 9 | H_cycles: 2 10 | L_cycles: 2 11 | 12 | H_layers: 4 13 | L_layers: 4 14 | 15 | hidden_size: 512 16 | num_heads: 8 # min(2, hidden_size // 64) 17 | expansion: 4 18 | 19 | puzzle_emb_ndim: ${.hidden_size} 20 | 21 | pos_encodings: rope 22 | forward_dtype: bfloat16 23 | 24 | mlp_t: False # use mlp on L instead of transformer -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Neural network models and components.""" 2 | 3 | # Import submodules 4 | from tiny_recursive_models.models import architectures 5 | 6 | # Import commonly used utilities 7 | from tiny_recursive_models.models.common import trunc_normal_init_ 8 | from tiny_recursive_models.models.losses import ACTLossHead, IGNORE_LABEL_ID 9 | from tiny_recursive_models.models.ema import EMAHelper 10 | 11 | __all__ = [ 12 | "architectures", 13 | "trunc_normal_init_", 14 | "ACTLossHead", 15 | "IGNORE_LABEL_ID", 16 | "EMAHelper", 17 | ] 
18 | -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/models/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | """Model architecture implementations.""" 2 | 3 | # Make architecture modules easily importable 4 | from tiny_recursive_models.models.architectures import trm 5 | from tiny_recursive_models.models.architectures import hrm 6 | from tiny_recursive_models.models.architectures import trm_singlez 7 | from tiny_recursive_models.models.architectures import trm_hier6 8 | from tiny_recursive_models.models.architectures import transformers_baseline 9 | 10 | __all__ = [ 11 | "trm", 12 | "hrm", 13 | "trm_singlez", 14 | "trm_hier6", 15 | "transformers_baseline", 16 | ] 17 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/README.md: -------------------------------------------------------------------------------- 1 |

# Beyond Real: Imaginary Extension of Rotary Position Embeddings for Long-Context LLMs

2 | 3 | ## Introduction 4 | 5 | This is a fork of Karpathy's [Nanochat](https://github.com/karpathy/nanochat) repo. We modify the apply_rotary_emb step, introducing an 'apply_rotary_emb_imaginary' function that can run on either half of the original heads, or doubling the total num_heads count. We run Rope++ in the 'split' configuration and arrive at similar numbers to the original Nanochat. That being said, Rope++ excels in long-context tasks, so we intend on doing further long-context tuning and evaluation shortly with Nanochat and Rope++. -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/utils/functions.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import inspect 3 | 4 | 5 | def load_model_class(identifier: str, prefix: str = "tiny_recursive_models.models."): 6 | module_path, class_name = identifier.split('@') 7 | 8 | # Import the module 9 | module = importlib.import_module(prefix + module_path) 10 | cls = getattr(module, class_name) 11 | 12 | return cls 13 | 14 | 15 | def get_model_source_path(identifier: str, prefix: str = "tiny_recursive_models.models."): 16 | module_path, class_name = identifier.split('@') 17 | 18 | module = importlib.import_module(prefix + module_path) 19 | return inspect.getsourcefile(module) 20 | -------------------------------------------------------------------------------- /tiny_recursive_models/config/arch/trm.yaml: -------------------------------------------------------------------------------- 1 | name: architectures.trm@TinyRecursiveReasoningModel_ACTV1 2 | loss: 3 | name: losses@ACTLossHead 4 | loss_type: stablemax_cross_entropy 5 | 6 | halt_exploration_prob: 0.1 7 | halt_max_steps: 16 8 | 9 | H_cycles: 3 10 | L_cycles: 6 11 | 12 | H_layers: 0 13 | L_layers: 2 14 | 15 | hidden_size: 512 16 | num_heads: 8 # min(2, hidden_size // 64) 17 | expansion: 4 18 | 19 | puzzle_emb_ndim: ${.hidden_size} 20 | 21 | pos_encodings: rope 22 | forward_dtype: bfloat16 23 | 24 | mlp_t: False # use mlp on L instead of transformer 25 | puzzle_emb_len: 16 # if non-zero, its specified to this value 26 | no_ACT_continue: True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense -------------------------------------------------------------------------------- /tiny_recursive_models/config/arch/trm_hier6.yaml: -------------------------------------------------------------------------------- 1 | name: architectures.trm_hier6@TinyRecursiveReasoningModel_ACTV1 2 | loss: 3 | name: losses@ACTLossHead 4 | loss_type: stablemax_cross_entropy 5 | 6 | halt_exploration_prob: 0.1 7 | halt_max_steps: 16 8 | 9 | H_cycles: 3 10 | L_cycles: 6 11 | 12 | H_layers: 0 13 | L_layers: 2 14 | 15 | hidden_size: 512 16 | num_heads: 8 # min(2, hidden_size // 64) 17 | expansion: 4 18 | 19 | puzzle_emb_ndim: ${.hidden_size} 20 | 21 | pos_encodings: rope 22 | forward_dtype: bfloat16 23 | 24 | mlp_t: False # use mlp on L instead of transformer 25 | puzzle_emb_len: 16 # if non-zero, its specified to this value 26 | no_ACT_continue: True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense -------------------------------------------------------------------------------- /tiny_recursive_models/config/arch/trm_singlez.yaml: -------------------------------------------------------------------------------- 1 | name: architectures.trm_singlez@TinyRecursiveReasoningModel_ACTV1 2 | loss: 3 | name: losses@ACTLossHead 4 | 
loss_type: stablemax_cross_entropy 5 | 6 | halt_exploration_prob: 0.1 7 | halt_max_steps: 16 8 | 9 | H_cycles: 3 10 | L_cycles: 6 11 | 12 | H_layers: 0 13 | L_layers: 2 14 | 15 | hidden_size: 512 16 | num_heads: 8 # min(2, hidden_size // 64) 17 | expansion: 4 18 | 19 | puzzle_emb_ndim: ${.hidden_size} 20 | 21 | pos_encodings: rope 22 | forward_dtype: bfloat16 23 | 24 | mlp_t: False # use mlp on L instead of transformer 25 | puzzle_emb_len: 16 # if non-zero, its specified to this value 26 | no_ACT_continue: True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense -------------------------------------------------------------------------------- /tiny_recursive_models/config/arch/trm_mlp.yaml: -------------------------------------------------------------------------------- 1 | name: architectures.trm@TinyRecursiveReasoningModel_ACTV1 2 | loss: 3 | name: losses@ACTLossHead 4 | loss_type: stablemax_cross_entropy 5 | 6 | halt_exploration_prob: 0.1 7 | halt_max_steps: 16 8 | 9 | H_cycles: 3 10 | L_cycles: 6 11 | 12 | H_layers: 0 13 | L_layers: 2 14 | 15 | hidden_size: 512 16 | num_heads: 8 # min(2, hidden_size // 64) 17 | expansion: 4 18 | 19 | puzzle_emb_ndim: ${.hidden_size} 20 | 21 | pos_encodings: rope 22 | forward_dtype: bfloat16 23 | 24 | mlp_t: True # use mlp on L instead of transformer (MLP variant) 25 | puzzle_emb_len: 16 # if non-zero, its specified to this value 26 | no_ACT_continue: True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense 27 | -------------------------------------------------------------------------------- /tiny_recursive_models/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=65.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "tiny-recursive-models" 7 | version = "0.1.0" 8 | description = "Recursive reasoning with tiny neural networks for solving ARC-AGI and other puzzles" 9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | license = {text = "MIT"} 12 | authors = [ 13 | {name = "alphaXiv"} 14 | ] 15 | 16 | dependencies = [ 17 | "torch>=2.0.0", 18 | "einops", 19 | "tqdm", 20 | "coolname", 21 | "pydantic>=2.0.0", 22 | "argdantic", 23 | "wandb", 24 | "omegaconf", 25 | "hydra-core", 26 | "huggingface_hub", 27 | "packaging", 28 | "numba", 29 | "ninja", 30 | "setuptools", 31 | ] 32 | -------------------------------------------------------------------------------- /rope_imaginary/rope_pp/configs/rope-376m-config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LlamaForCausalLM" 4 | ], 5 | "attention_bias": false, 6 | "attention_dropout": 0.0, 7 | "bos_token_id": 128000, 8 | "eos_token_id": 128001, 9 | "hidden_act": "silu", 10 | "hidden_size": 1024, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 3584, 13 | "max_position_embeddings": 4096, 14 | "model_type": "llama", 15 | "num_attention_heads": 8, 16 | "num_hidden_layers": 8, 17 | "num_key_value_heads": 4, 18 | "pretraining_tp": 1, 19 | "rms_norm_eps": 1e-05, 20 | "rope_scaling": null, 21 | "rope_theta": 10000, 22 | "tie_word_embeddings": false, 23 | "torch_dtype": "bfloat16", 24 | "transformers_version": "4.40.0.dev0", 25 | "use_cache": true, 26 | "vocab_size": 128256 27 | } -------------------------------------------------------------------------------- /rope_imaginary/rope_pp/configs/rope-1_5b-config.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LlamaForCausalLM" 4 | ], 5 | "attention_bias": false, 6 | "attention_dropout": 0.0, 7 | "bos_token_id": 128000, 8 | "eos_token_id": 128001, 9 | "hidden_act": "silu", 10 | "hidden_size": 2048, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 7168, 13 | "max_position_embeddings": 4096, 14 | "model_type": "llama", 15 | "num_attention_heads": 16, 16 | "num_hidden_layers": 16, 17 | "num_key_value_heads": 4, 18 | "pretraining_tp": 1, 19 | "rms_norm_eps": 1e-05, 20 | "rope_scaling": null, 21 | "rope_theta": 10000, 22 | "tie_word_embeddings": false, 23 | "torch_dtype": "bfloat16", 24 | "transformers_version": "4.40.0.dev0", 25 | "use_cache": true, 26 | "vocab_size": 128256 27 | } -------------------------------------------------------------------------------- /rope_imaginary/rope_pp/configs/rope-776m-config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LlamaForCausalLM" 4 | ], 5 | "attention_bias": false, 6 | "attention_dropout": 0.0, 7 | "bos_token_id": 128000, 8 | "eos_token_id": 128001, 9 | "hidden_act": "silu", 10 | "hidden_size": 1536, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 5376, 13 | "max_position_embeddings": 4096, 14 | "model_type": "llama", 15 | "num_attention_heads": 12, 16 | "num_hidden_layers": 12, 17 | "num_key_value_heads": 6, 18 | "pretraining_tp": 1, 19 | "rms_norm_eps": 1e-05, 20 | "rope_scaling": null, 21 | "rope_theta": 10000, 22 | "tie_word_embeddings": false, 23 | "torch_dtype": "bfloat16", 24 | "transformers_version": "4.40.0.dev0", 25 | "use_cache": true, 26 | "vocab_size": 128256 27 | } -------------------------------------------------------------------------------- /tiny_recursive_models/config/cfg_pretrain.yaml: -------------------------------------------------------------------------------- 1 | # ARC training config 2 | 3 | defaults: 4 | - arch: trm 5 | - _self_ 6 | 7 | hydra: 8 | output_subdir: null 9 | 10 | # Data path 11 | data_paths: ['data/arc-aug-1000'] 12 | data_paths_test: [] 13 | 14 | evaluators: 15 | - name: arc@ARC 16 | 17 | # Hyperparams - Training 18 | global_batch_size: 768 19 | 20 | epochs: 100000 21 | eval_interval: 10000 22 | checkpoint_every_eval: True 23 | 24 | lr: 1e-4 25 | lr_min_ratio: 1.0 26 | lr_warmup_steps: 2000 27 | 28 | # Standard hyperparameter settings for LM, as used in Llama 29 | beta1: 0.9 30 | beta2: 0.95 31 | weight_decay: 0.1 32 | puzzle_emb_weight_decay: 0.1 33 | 34 | # Hyperparams - Puzzle embeddings training 35 | puzzle_emb_lr: 1e-2 36 | 37 | seed: 0 38 | min_eval_interval: 0 # when to start the eval 39 | 40 | ema: False # use Exponential-Moving-Average 41 | ema_rate: 0.999 # EMA-rate 42 | freeze_weights: False # If True, freeze weights and only learn the embeddings -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # alphaXiv Implementations 2 | 3 | This repo contains implementations of heavily-requested papers on alphaXiv. The goal of this repo is to provide well-documented, easy-to-follow implementations of popular research paper codebases. 4 | 5 | **Request implementations:** Open an issue or click 'implement' on any paper on alphaXiv. 6 | 7 | ## Requirements for new implementation PRs 8 | 9 | Each implementation must include: 10 | 11 | 1. 
**README with specs**: GPU count/type required, runtime estimates, dataset instructions, reproduction results 12 | 2. **Standard structure**: Use `pyproject.toml` for dependencies and `src/` layout for code 13 | 3. **Speedrun.sh**: Each project must have a clear Nanochat-style speedrun.sh script that sets up the environment and runs relevant scripts for training and evaluation. 14 | 15 | ## Structure 16 | ``` 17 | paper-name/ 18 | ├── README.md 19 | ├── pyproject.toml 20 | └── src/ 21 | └── paper_name/ 22 | ├── train.py 23 | └── eval.py 24 | ``` 25 | -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/training/__init__.py: -------------------------------------------------------------------------------- 1 | """Training utilities and configuration.""" 2 | 3 | from tiny_recursive_models.training.config import ( 4 | PretrainConfig, 5 | TrainState, 6 | ArchConfig, 7 | LossConfig, 8 | EvaluatorConfig, 9 | ) 10 | from tiny_recursive_models.training.trainer import ( 11 | create_dataloader, 12 | create_model, 13 | init_train_state, 14 | train_batch, 15 | compute_lr, 16 | cosine_schedule_with_warmup_lr_lambda, 17 | ) 18 | from tiny_recursive_models.training.checkpoint import ( 19 | save_train_state, 20 | load_checkpoint, 21 | ) 22 | 23 | __all__ = [ 24 | # Config 25 | "PretrainConfig", 26 | "TrainState", 27 | "ArchConfig", 28 | "LossConfig", 29 | "EvaluatorConfig", 30 | # Trainer 31 | "create_dataloader", 32 | "create_model", 33 | "init_train_state", 34 | "train_batch", 35 | "compute_lr", 36 | "cosine_schedule_with_warmup_lr_lambda", 37 | # Checkpoint 38 | "save_train_state", 39 | "load_checkpoint", 40 | ] 41 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Andrej Karpathy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/rustbpe/README.md: -------------------------------------------------------------------------------- 1 | # rustbpe 2 | 3 | > The missing tiktoken training code 4 | 5 | A very lightweight Rust library for training a GPT tokenizer. The issue is that the inference library [tiktoken](https://github.com/openai/tiktoken) is great, but only does inference. 
Separately, the huggingface [tokenizers](https://github.com/huggingface/tokenizers) library does training, but it is rather bloated and really hard to navigate because it has to support all the different historical baggage of how people dealt with tokenizers over the years. More recently, I also wrote the [minbpe](https://github.com/karpathy/minbpe) library which does both training and inference, but only in inefficient Python. Basically what I really want is a non-fancy, super simple, but still relatively efficient training code for GPT tokenizer (more efficient than minbpe, much cleaner/simpler than tokenizers), and then export the trained vocab for inference with tiktoken. Does that make sense? So here we are. There are more opportunities for optimization here, I just stopped a bit early because unlike minbpe before it, rustbpe is now simple and fast enough, and not a significant bottleneck for nanochat. 6 | -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/models/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0): 8 | # NOTE: PyTorch nn.init.trunc_normal_ is not mathematically correct, the std dev is not actually the std dev of initialized tensor 9 | # This function is a PyTorch version of jax truncated normal init (default init method in flax) 10 | # https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848 11 | # https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199 12 | 13 | with torch.no_grad(): 14 | if std == 0: 15 | tensor.zero_() 16 | else: 17 | sqrt2 = math.sqrt(2) 18 | a = math.erf(lower / sqrt2) 19 | b = math.erf(upper / sqrt2) 20 | z = (b - a) / 2 21 | 22 | c = (2 * math.pi) ** -0.5 23 | pdf_u = c * math.exp(-0.5 * lower ** 2) 24 | pdf_l = c * math.exp(-0.5 * upper ** 2) 25 | comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2) 26 | 27 | tensor.uniform_(a, b) 28 | tensor.erfinv_() 29 | tensor.mul_(sqrt2 * comp_std) 30 | tensor.clip_(lower * comp_std, upper * comp_std) 31 | 32 | return tensor 33 | -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/models/ema.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch.nn as nn 3 | 4 | class EMAHelper(object): 5 | def __init__(self, mu=0.999): 6 | self.mu = mu 7 | self.shadow = {} 8 | 9 | def register(self, module): 10 | if isinstance(module, nn.DataParallel): 11 | module = module.module 12 | for name, param in module.named_parameters(): 13 | if param.requires_grad: 14 | self.shadow[name] = param.data.clone() 15 | 16 | def update(self, module): 17 | if isinstance(module, nn.DataParallel): 18 | module = module.module 19 | for name, param in module.named_parameters(): 20 | if param.requires_grad: 21 | self.shadow[name].data = (1. 
- self.mu) * param.data + self.mu * self.shadow[name].data 22 | 23 | def ema(self, module): 24 | if isinstance(module, nn.DataParallel): 25 | module = module.module 26 | for name, param in module.named_parameters(): 27 | if param.requires_grad: 28 | param.data.copy_(self.shadow[name].data) 29 | 30 | def ema_copy(self, module): 31 | module_copy = copy.deepcopy(module) 32 | self.ema(module_copy) 33 | return module_copy 34 | 35 | def state_dict(self): 36 | return self.shadow 37 | 38 | def load_state_dict(self, state_dict): 39 | self.shadow = state_dict 40 | 41 | -------------------------------------------------------------------------------- /rope_imaginary/rope_pp/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "rope-pp" 3 | version = "0.1.0" 4 | description = "RoPE++ Implementation" 5 | requires-python = "==3.12.*" 6 | dependencies = [ 7 | "ninja", 8 | "packaging", 9 | "psutil", 10 | "transformers==4.51.0", 11 | "datasets", 12 | "wandb", 13 | "zstandard", 14 | "accelerate>=0.26.0", 15 | "deepspeed", 16 | "lm-eval", 17 | "wonderwords", 18 | "nltk", 19 | ] 20 | 21 | [project.optional-dependencies] 22 | cpu = [ 23 | "torch==2.8.0", 24 | ] 25 | gpu = [ 26 | "torch==2.8.0", 27 | "flash-attn", 28 | ] 29 | 30 | [build-system] 31 | requires = ["setuptools>=45", "wheel"] 32 | build-backend = "setuptools.build_meta" 33 | 34 | [tool.setuptools] 35 | packages = ["llama_variants", "utils"] 36 | 37 | # Target torch to CUDA 12.8 or CPU 38 | [tool.uv.sources] 39 | torch = [ 40 | { index = "pytorch-cpu", extra = "cpu" }, 41 | { index = "pytorch-cu128", extra = "gpu" }, 42 | ] 43 | flash-attn = [ 44 | { url = "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.8-cp312-cp312-linux_x86_64.whl", extra = "gpu" }, 45 | ] 46 | lm-eval = { git = "https://github.com/EleutherAI/lm-evaluation-harness.git" } 47 | 48 | [[tool.uv.index]] 49 | name = "pytorch-cpu" 50 | url = "https://download.pytorch.org/whl/cpu" 51 | explicit = true 52 | 53 | [[tool.uv.index]] 54 | name = "pytorch-cu128" 55 | url = "https://download.pytorch.org/whl/cu128" 56 | explicit = true 57 | 58 | [tool.uv] 59 | conflicts = [ 60 | [ 61 | { extra = "cpu" }, 62 | { extra = "gpu" }, 63 | ], 64 | ] 65 | -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/data/common.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import pydantic 4 | import numpy as np 5 | 6 | 7 | # Global list mapping each dihedral transform id to its inverse. 8 | # Index corresponds to the original tid, and the value is its inverse. 
9 | DIHEDRAL_INVERSE = [0, 3, 2, 1, 4, 5, 6, 7] 10 | 11 | 12 | class PuzzleDatasetMetadata(pydantic.BaseModel): 13 | pad_id: int 14 | ignore_label_id: Optional[int] 15 | blank_identifier_id: int 16 | vocab_size: int 17 | seq_len: int 18 | num_puzzle_identifiers: int 19 | total_groups: int 20 | mean_puzzle_examples: float 21 | total_puzzles: int 22 | sets: List[str] 23 | 24 | 25 | def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray: 26 | """8 dihedral symmetries by rotate, flip and mirror""" 27 | 28 | if tid == 0: 29 | return arr # identity 30 | elif tid == 1: 31 | return np.rot90(arr, k=1) 32 | elif tid == 2: 33 | return np.rot90(arr, k=2) 34 | elif tid == 3: 35 | return np.rot90(arr, k=3) 36 | elif tid == 4: 37 | return np.fliplr(arr) # horizontal flip 38 | elif tid == 5: 39 | return np.flipud(arr) # vertical flip 40 | elif tid == 6: 41 | return arr.T # transpose (reflection along main diagonal) 42 | elif tid == 7: 43 | return np.fliplr(np.rot90(arr, k=1)) # anti-diagonal reflection 44 | else: 45 | return arr 46 | 47 | 48 | def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray: 49 | return dihedral_transform(arr, DIHEDRAL_INVERSE[tid]) 50 | -------------------------------------------------------------------------------- /rope_imaginary/rope_pp/speedrun-eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check if HF_TOKEN is set 4 | if [ -z "$HF_TOKEN" ]; then 5 | echo "WARNING: HF_TOKEN environment variable is not set." 6 | echo "You need to set it to access gated models like meta-llama/Meta-Llama-3-8B" 7 | echo "Example: export HF_TOKEN='your_token_here'" 8 | echo "Get your token from: https://huggingface.co/settings/tokens" 9 | exit 1 10 | fi 11 | 12 | # Install uv (if not already installed) 13 | command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh 14 | 15 | # Create a .venv local virtual environment (if it doesn't exist) 16 | [ -d ".venv" ] || uv venv --python 3.12 17 | 18 | # Activate venv first 19 | source .venv/bin/activate 20 | 21 | # Install all dependencies including flash-attn 22 | # The extra-build-dependencies config will provide torch, setuptools, etc. during flash-attn build 23 | uv sync --extra gpu 24 | 25 | # Venv already activated above 26 | 27 | # Create necessary directories 28 | mkdir -p results logs 29 | 30 | echo "==========================================" 31 | echo "Starting Model Evaluation with LM Harness" 32 | echo "==========================================" 33 | 34 | # Run the standard evaluation script 35 | python eval/eval_lmharness.py 36 | 37 | echo "Evaluation (standard) complete!" 38 | 39 | echo "==========================================" 40 | echo "Starting Long Context Evaluation with LM Harness" 41 | echo "==========================================" 42 | 43 | # Run the long context evaluation script 44 | python eval/eval_lmharness-lctx.py 45 | 46 | echo "==========================================" 47 | echo "Evaluation Complete!" 48 | echo "==========================================" 49 | echo "" 50 | echo "Check the results directory for detailed evaluation outputs." 
51 | -------------------------------------------------------------------------------- /tiny_recursive_models/docs/hf_model_cards/model_card_arc_agi_1.md: -------------------------------------------------------------------------------- 1 | --- 2 | language: en 3 | license: mit 4 | tags: 5 | - trm 6 | - recursive-reasoning 7 | - arc-agi 8 | - abstract-reasoning 9 | - pytorch 10 | - huggingface 11 | datasets: 12 | - ARC-AGI 13 | metrics: 14 | - pass@2 15 | widget: 16 | - text: "Sample ARC task here" 17 | --- 18 | 19 | # TRM Model for ARC-AGI-1 20 | 21 | ## Model Description 22 | 23 | This is a Tiny Recursive Model (TRM) fine-tuned for solving Abstract Reasoning Challenge (ARC-AGI) tasks. The model performs abstract reasoning to predict output grids from input grids. 24 | 25 | - **Developed by:** alphaXiv 26 | - **Model type:** TRM-Attention 27 | - **Language(s) (NLP):** N/A (grid-based reasoning) 28 | - **License:** MIT 29 | - **Finetuned from model:** Custom TRM architecture 30 | 31 | ## Intended Use 32 | 33 | ### Primary Use 34 | 35 | This model is designed to solve ARC-AGI tasks by predicting the correct output grid transformation based on input grid patterns. 36 | 37 | ### Out-of-Scope Use 38 | 39 | Not intended for general NLP tasks, image generation, or other reasoning domains. 40 | 41 | ## Limitations and Bias 42 | 43 | - Trained only on ARC-AGI training and evaluation sets 44 | - May not generalize to novel abstract reasoning tasks 45 | - Performance limited by training data diversity 46 | 47 | ## Training Data 48 | 49 | The model was trained on the ARC-AGI dataset, which includes: 50 | - Input-output grid pairs 51 | - Various transformation patterns 52 | - Training and evaluation splits 53 | 54 | ## Evaluation Results 55 | 56 | | Metric | Claimed | Achieved | 57 | |--------|---------|----------| 58 | | Pass@2 | 44.6% | 43.00% ± 0.16% | 59 | 60 | Results from independent reproduction study. 61 | 62 | ## Repository 63 | 64 | https://github.com/alphaXiv/TinyRecursiveModels 65 | -------------------------------------------------------------------------------- /tiny_recursive_models/docs/hf_model_cards/model_card_sudoku.md: -------------------------------------------------------------------------------- 1 | --- 2 | language: en 3 | license: mit 4 | tags: 5 | - trm 6 | - recursive-reasoning 7 | - sudoku 8 | - pytorch 9 | - huggingface 10 | datasets: 11 | - custom 12 | metrics: 13 | - accuracy 14 | widget: 15 | - text: "Sample sudoku puzzle here" 16 | --- 17 | 18 | # TRM Model for Sudoku Solving 19 | 20 | ## Model Description 21 | 22 | This is a Tiny Recursive Model (TRM) fine-tuned for solving Sudoku puzzles. The model uses recursive reasoning to fill in missing numbers in Sudoku grids. 23 | 24 | - **Developed by:** alphaXiv 25 | - **Model type:** TRM-MLP 26 | - **Language(s) (NLP):** N/A (grid-based reasoning) 27 | - **License:** MIT 28 | - **Finetuned from model:** Custom TRM architecture 29 | 30 | ## Intended Use 31 | 32 | ### Primary Use 33 | 34 | This model is designed to solve Sudoku puzzles by predicting the correct numbers for empty cells in standard 9x9 Sudoku grids. 35 | 36 | ### Out-of-Scope Use 37 | 38 | Not intended for general NLP tasks, image processing, or other puzzle types. 
39 | 40 | ## Limitations and Bias 41 | 42 | - Trained only on standard 9x9 Sudoku puzzles 43 | - May not handle non-standard Sudoku variants 44 | - Performance depends on puzzle difficulty 45 | 46 | ## Training Data 47 | 48 | The model was trained on a dataset of Sudoku puzzles with extreme difficulty levels. The dataset includes: 49 | - Partially filled 9x9 grids 50 | - Correct solutions 51 | - Difficulty ratings 52 | 53 | ## Evaluation Results 54 | 55 | | Variant | Metric | Claimed | Achieved | 56 | |---------|--------|---------|----------| 57 | | TRM-MLP | Accuracy | 87.4% | 79.37% ± 0.12% | 58 | | TRM-Attention | Accuracy | 74.7% | 73.66% ± 0.13% | 59 | 60 | Results from independent reproduction study. 61 | 62 | ## Repository 63 | 64 | https://github.com/alphaXiv/TinyRecursiveModels -------------------------------------------------------------------------------- /tiny_recursive_models/docs/hf_model_cards/model_card_maze.md: -------------------------------------------------------------------------------- 1 | --- 2 | language: en 3 | license: mit 4 | tags: 5 | - trm 6 | - recursive-reasoning 7 | - maze-solving 8 | - pytorch 9 | - huggingface 10 | datasets: 11 | - custom 12 | metrics: 13 | - accuracy 14 | widget: 15 | - text: "Sample maze input here" 16 | --- 17 | 18 | # TRM Model for Maze Solving 19 | 20 | ## Model Description 21 | 22 | This is a Tiny Recursive Model (TRM) fine-tuned for solving maze navigation tasks. The model implements recursive reasoning to find paths in 30x30 grid mazes. 23 | 24 | - **Developed by:** alphaXiv 25 | - **Model type:** TRM-Attention 26 | - **Language(s) (NLP):** N/A (grid-based reasoning) 27 | - **License:** MIT 28 | - **Finetuned from model:** Custom TRM architecture 29 | 30 | ## Intended Use 31 | 32 | ### Primary Use 33 | 34 | This model is designed to solve maze pathfinding problems by predicting the correct sequence of moves to navigate from start to goal in grid-based mazes. 35 | 36 | ### Out-of-Scope Use 37 | 38 | Not intended for general NLP tasks, image classification, or other domains outside maze solving. 39 | 40 | ## Limitations and Bias 41 | 42 | - Trained only on synthetic maze data 43 | - May not generalize to mazes of different sizes or complexities 44 | - Performance may degrade on mazes with unusual patterns 45 | 46 | ## Training Data 47 | 48 | The model was trained on a dataset of 30x30 grid mazes with hard difficulty levels. The dataset includes: 49 | - Start and goal positions 50 | - Wall configurations 51 | - Correct path sequences 52 | 53 | 54 | 55 | ## Evaluation Results 56 | 57 | | Metric | Claimed | Achieved | 58 | |--------|---------|----------| 59 | | Exact Accuracy | 85.3% | 83.67% ± 2.28% | 60 | 61 | Results from independent reproduction study. 62 | 63 | ## Repository 64 | 65 | https://github.com/alphaXiv/TinyRecursiveModels 66 | -------------------------------------------------------------------------------- /rope_imaginary/README.md: -------------------------------------------------------------------------------- 1 |

# Beyond Real: Imaginary Extension of Rotary Position Embeddings for Long-Context LLMs

2 | 3 | ## Introduction 4 | 5 | This repo provides two implementations of [Rope++](https://www.alphaxiv.org/abs/2512.07525) an imaginary extension of Rotary Position Embeddings. Both implementations are meant to be easily run on [Lambda](https://lambda.ai/) 8x H100 80GB instances. For the rope_pp folder we also provide a configuration that can run on a singular 40 GB A100. 6 | 7 | As a refresher, the basis of RoPE is to apply varying rotations across the query and key vectors (as opposed to adding an absolute position vector like in the Attention paper). These rotations are done by segmenting the input vector into vectors of dimension 2 and applying a rotational transformation e^(iθ). These complex terms are discarded when computing attention scores. 8 | 9 | The basis of Rope++ is that the imaginary component of the attention score contains important information and should be considered in half of the attention heads. The imaginary attention component uses a sine-based characteristic curve which decays much more slowly over distance compared to the cosine-based curve in standard RoPE's real attention. This slower decay means imaginary attention maintains stronger weights for distant tokens rather than emphasizing only nearby tokens, allowing it to better capture long-range dependencies in the sequence. 10 | 11 | In this implementation we provide both a cleaned-up adaptation of the official "Rope++" codebase as well as an implementation of "Rope++" with Karpathy's Nanochat repo. For the cleaned-up version of the original Rope++ codebase, we greatly simplify the setup, training, and evaluation process, providing a Nanochat-style 'speedrun' bash script which sets up UV, installs the necessary packages, and runs the training+evaluation scripts all in one go. -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "nanochat" 3 | version = "0.1.0" 4 | description = "the minimal full-stack ChatGPT clone" 5 | readme = "README.md" 6 | requires-python = ">=3.10" 7 | dependencies = [ 8 | "datasets>=4.0.0", 9 | "fastapi>=0.117.1", 10 | "files-to-prompt>=0.6", 11 | "psutil>=7.1.0", 12 | "regex>=2025.9.1", 13 | "setuptools>=80.9.0", 14 | "tiktoken>=0.11.0", 15 | "tokenizers>=0.22.0", 16 | "torch>=2.8.0", 17 | "uvicorn>=0.36.0", 18 | "wandb>=0.21.3", 19 | ] 20 | 21 | [build-system] 22 | requires = ["maturin>=1.7,<2.0"] 23 | build-backend = "maturin" 24 | 25 | [tool.maturin] 26 | module-name = "rustbpe" 27 | bindings = "pyo3" 28 | python-source = "." 
29 | manifest-path = "rustbpe/Cargo.toml" 30 | 31 | [dependency-groups] 32 | dev = [ 33 | "maturin>=1.9.4", 34 | "pytest>=8.0.0", 35 | ] 36 | 37 | [tool.pytest.ini_options] 38 | markers = [ 39 | "slow: marks tests as slow (deselect with '-m \"not slow\"')", 40 | ] 41 | testpaths = ["tests"] 42 | python_files = ["test_*.py"] 43 | python_classes = ["Test*"] 44 | python_functions = ["test_*"] 45 | 46 | # target torch to cuda 12.8 or CPU 47 | [tool.uv.sources] 48 | torch = [ 49 | { index = "pytorch-cpu", extra = "cpu" }, 50 | { index = "pytorch-cu128", extra = "gpu" }, 51 | ] 52 | 53 | [[tool.uv.index]] 54 | name = "pytorch-cpu" 55 | url = "https://download.pytorch.org/whl/cpu" 56 | explicit = true 57 | 58 | [[tool.uv.index]] 59 | name = "pytorch-cu128" 60 | url = "https://download.pytorch.org/whl/cu128" 61 | explicit = true 62 | 63 | [project.optional-dependencies] 64 | cpu = [ 65 | "torch>=2.8.0", 66 | ] 67 | gpu = [ 68 | "torch>=2.8.0", 69 | ] 70 | 71 | [tool.uv] 72 | conflicts = [ 73 | [ 74 | { extra = "cpu" }, 75 | { extra = "gpu" }, 76 | ], 77 | ] -------------------------------------------------------------------------------- /rope_imaginary/rope_pp/README.md: -------------------------------------------------------------------------------- 1 |

# Beyond Real: Imaginary Extension of Rotary Position Embeddings for Long-Context LLMs

2 | 3 | ## Introduction 4 | 5 | This is an adaptation of the official "Rope++" codebase, available [here](https://github.com/OpenMOSS/rope_pp), released by the authors of the original paper. The original repo pre-trains 376M and 776M parameter models on the [mlfoundations/dclm-baseline-1.0](https://huggingface.co/datasets/mlfoundations/dclm-baseline-1.0) dataset. Training is split into three scripts: pre-training, pre-training with decay, and long-context fine-tuning. A few configurations of Rope++ can be trained as well: one may double the existing number of heads, or simply split the existing heads into equal parts real and imaginary. The original codebase also uses [OpenCompass](https://github.com/open-compass/opencompass) for evaluation, covering several benchmarks including HellaSwag, TruthfulQA, and RULER. 6 | 7 | In this adaptation, we greatly simplify the setup, training, and evaluation process. Inspired by Karpathy's [Nanochat](https://github.com/karpathy/nanochat) implementation, we provide a 'speedrun' bash script which sets up uv, installs the necessary packages, and runs the training and evaluation scripts all in one go. We also ran into difficulties setting up OpenCompass, so we swapped it out for EleutherAI's LM harness, testing over the same benchmarks. 8 | 9 | We provide bash scripts to train the models on both a single 40 GB A100 and 8x 80 GB H100 GPUs. A run on 8 H100 GPUs takes under a day to complete. The single-GPU configuration takes over a week with the 376M variant and is provided mainly for those who want to experiment with the implementation, perhaps with fewer training steps or a smaller model. Due to memory constraints, the single-GPU configuration does not do long-context fine-tuning. For long-context fine-tuning we use NTK scaling, simply increasing max_length from 4k to 32k and the base theta to 500000.
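As a rough sketch of what that last step amounts to at the config level, assuming the Hugging Face Llama JSON configs in `configs/` (the output directory name below is purely illustrative, and the actual fine-tuning is driven by the training scripts):

```python
# Minimal sketch of the NTK-style long-context change described above.
# Assumes the HF Llama configs shipped in configs/; the output path is hypothetical.
from transformers import LlamaConfig

cfg = LlamaConfig.from_json_file("configs/rope-376m-config.json")
cfg.max_position_embeddings = 32768  # extend the context window from 4k to 32k
cfg.rope_theta = 500000              # raise the RoPE base theta for the longer window
cfg.save_pretrained("configs/rope-376m-32k")  # hypothetical output directory
```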
10 | -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/training/config.py: -------------------------------------------------------------------------------- 1 | """Training configuration classes.""" 2 | 3 | from typing import Optional, List, Any, Sequence 4 | from dataclasses import dataclass 5 | import pydantic 6 | from torch import nn 7 | import torch 8 | 9 | 10 | class LossConfig(pydantic.BaseModel): 11 | model_config = pydantic.ConfigDict(extra='allow') 12 | name: str 13 | 14 | 15 | class ArchConfig(pydantic.BaseModel): 16 | model_config = pydantic.ConfigDict(extra='allow') 17 | name: str 18 | loss: LossConfig 19 | 20 | 21 | class EvaluatorConfig(pydantic.BaseModel): 22 | model_config = pydantic.ConfigDict(extra="allow") 23 | name: str 24 | 25 | 26 | class PretrainConfig(pydantic.BaseModel): 27 | # Config 28 | arch: ArchConfig 29 | # Data 30 | data_paths: List[str] 31 | data_paths_test: List[str] = [] 32 | # Evaluators 33 | evaluators: List[EvaluatorConfig] = [] 34 | 35 | # Hyperparams 36 | global_batch_size: int 37 | epochs: int 38 | 39 | lr: float 40 | lr_min_ratio: float 41 | lr_warmup_steps: int 42 | 43 | weight_decay: float 44 | beta1: float 45 | beta2: float 46 | 47 | # Puzzle embedding 48 | puzzle_emb_lr: float 49 | puzzle_emb_weight_decay: float 50 | 51 | # Names 52 | project_name: Optional[str] = None 53 | run_name: Optional[str] = None 54 | load_checkpoint: Optional[str] = None 55 | checkpoint_path: Optional[str] = None 56 | 57 | # Extras 58 | seed: int = 0 59 | checkpoint_every_eval: bool = False 60 | eval_interval: Optional[int] = None 61 | min_eval_interval: Optional[int] = 0 # when to start eval 62 | eval_save_outputs: List[str] = [] 63 | 64 | ema: bool = False # use Exponential-Moving-Average 65 | ema_rate: float = 0.999 # EMA-rate 66 | freeze_weights: bool = False # If True, freeze weights and only learn the embeddings 67 | 68 | 69 | @dataclass 70 | class TrainState: 71 | model: nn.Module 72 | optimizers: Sequence[torch.optim.Optimizer] 73 | optimizer_lrs: Sequence[float] 74 | carry: Any 75 | 76 | step: int 77 | total_steps: int 78 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/tasks/smoltalk.py: -------------------------------------------------------------------------------- 1 | """ 2 | SmolTalk by HuggingFace. Good "general" conversational dataset. 3 | https://huggingface.co/datasets/HuggingFaceTB/smol-smoltalk 4 | We use the "smol" version, which is more appropriate for smaller models. 5 | """ 6 | 7 | from datasets import load_dataset 8 | from tasks.common import Task 9 | 10 | class SmolTalk(Task): 11 | """ smol-smoltalk dataset. train is 460K rows, test is 24K rows. 
""" 12 | 13 | def __init__(self, split, **kwargs): 14 | super().__init__(**kwargs) 15 | assert split in ["train", "test"], "SmolTalk split must be train|test" 16 | self.ds = load_dataset("HuggingFaceTB/smol-smoltalk", split=split).shuffle(seed=42) 17 | self.length = len(self.ds) 18 | 19 | def num_examples(self): 20 | return self.length 21 | 22 | def get_example(self, index): 23 | row = self.ds[index] 24 | messages = row["messages"] 25 | # --------------------------------------------------------------------- 26 | # sanity checking asserts here 27 | # TODO: we could remove these asserts later, for now just don't want any footguns 28 | # there is an optional system message at the beginning 29 | assert len(messages) >= 1 30 | first_message = messages[0] 31 | if first_message["role"] == "system": 32 | rest_messages = messages[1:] # optional system message is OK 33 | else: 34 | rest_messages = messages 35 | assert len(rest_messages) >= 2, "SmolTalk messages must have at least 2 messages" 36 | for i, message in enumerate(rest_messages): 37 | # user and assistant alternate as user,assistant,user,assistant,... 38 | expected_role = "user" if i % 2 == 0 else "assistant" 39 | assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}" 40 | assert isinstance(message["content"], str), "Content must be a string" 41 | # --------------------------------------------------------------------- 42 | # create and return the Conversation object (ok to emit the system message too) 43 | conversation = { 44 | "messages": messages, 45 | } 46 | return conversation 47 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/nanochat/configurator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Poor Man's Configurator. Probably a terrible idea. Example usage: 3 | $ python train.py config/override_file.py --batch_size=32 4 | this will first run config/override_file.py, then override batch_size to 32 5 | 6 | The code in this file will be run as follows from e.g. train.py: 7 | >>> exec(open('configurator.py').read()) 8 | 9 | So it's not a Python module, it's just shuttling this code away from train.py 10 | The code in this script then overrides the globals() 11 | 12 | I know people are not going to love this, I just really dislike configuration 13 | complexity and having to prepend config. to every single variable. If someone 14 | comes up with a better simple Python solution I am all ears. 15 | """ 16 | 17 | import os 18 | import sys 19 | from ast import literal_eval 20 | 21 | def print0(s="",**kwargs): 22 | ddp_rank = int(os.environ.get('RANK', 0)) 23 | if ddp_rank == 0: 24 | print(s, **kwargs) 25 | 26 | for arg in sys.argv[1:]: 27 | if '=' not in arg: 28 | # assume it's the name of a config file 29 | assert not arg.startswith('--') 30 | config_file = arg 31 | print0(f"Overriding config with {config_file}:") 32 | with open(config_file) as f: 33 | print0(f.read()) 34 | exec(open(config_file).read()) 35 | else: 36 | # assume it's a --key=value argument 37 | assert arg.startswith('--') 38 | key, val = arg.split('=') 39 | key = key[2:] 40 | if key in globals(): 41 | try: 42 | # attempt to eval it it (e.g. 
if bool, number, or etc) 43 | attempt = literal_eval(val) 44 | except (SyntaxError, ValueError): 45 | # if that goes wrong, just use the string 46 | attempt = val 47 | # ensure the types match ok 48 | if globals()[key] is not None: 49 | attempt_type = type(attempt) 50 | default_type = type(globals()[key]) 51 | assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}" 52 | # cross fingers 53 | print0(f"Overriding: {key} = {attempt}") 54 | globals()[key] = attempt 55 | else: 56 | raise ValueError(f"Unknown config key: {key}") 57 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/tasks/arc.py: -------------------------------------------------------------------------------- 1 | """ 2 | The ARC dataset from Allen AI. 3 | https://huggingface.co/datasets/allenai/ai2_arc 4 | """ 5 | 6 | from datasets import load_dataset 7 | from tasks.common import Task, render_mc 8 | 9 | class ARC(Task): 10 | 11 | def __init__(self, subset, split, **kwargs): 12 | super().__init__(**kwargs) 13 | assert subset in ["ARC-Easy", "ARC-Challenge"], "ARC subset must be ARC-Easy or ARC-Challenge" 14 | assert split in ["train", "validation", "test"], "ARC split must be train|validation|test" 15 | self.ds = load_dataset("allenai/ai2_arc", subset, split=split).shuffle(seed=42) 16 | 17 | @property 18 | def eval_type(self): 19 | return 'categorical' 20 | 21 | def num_examples(self): 22 | return len(self.ds) 23 | 24 | def get_example(self, index): 25 | row = self.ds[index] 26 | question = row["question"] # the question text 27 | choices = row["choices"]["text"] # the text of each choice 28 | answer_string = row["answerKey"] # e.g. "A", "B", "C", "D" 29 | letters = row["choices"]["label"] # e.g. ["A", "B", "C", "D"] 30 | assert answer_string in letters, f"ARC answer {answer_string} must be one of {letters}" # sanity check 31 | # create and return the Conversation object 32 | user_message = render_mc(question, letters, choices) 33 | messages = [ 34 | {"role": "user", "content": user_message}, 35 | {"role": "assistant", "content": answer_string} 36 | ] 37 | conversation = { 38 | "messages": messages, 39 | "letters": letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters 40 | } 41 | return conversation 42 | 43 | def evaluate(self, conversation, assistant_response): 44 | # the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true 45 | # I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it. 46 | assert assistant_response in conversation['letters'], f"ARC answer {assistant_response} is expected to be one of {conversation['letters']}" 47 | assistant_message = conversation['messages'][-1]['content'] # e.g. "A" 48 | return assistant_response == assistant_message 49 | -------------------------------------------------------------------------------- /rope_imaginary/rope_pp/speedrun-1gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check if HF_TOKEN is set 4 | if [ -z "$HF_TOKEN" ]; then 5 | echo "WARNING: HF_TOKEN environment variable is not set." 
6 | echo "You need to set it to access gated models like meta-llama/Meta-Llama-3-8B" 7 | echo "Example: export HF_TOKEN='your_token_here'" 8 | echo "Get your token from: https://huggingface.co/settings/tokens" 9 | exit 1 10 | fi 11 | 12 | # Install uv (if not already installed) 13 | command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh 14 | 15 | # Create a .venv local virtual environment (if it doesn't exist) 16 | [ -d ".venv" ] || uv venv --python 3.12 17 | 18 | # Activate venv first 19 | source .venv/bin/activate 20 | 21 | # Install all dependencies including flash-attn 22 | # The extra-build-dependencies config will provide torch, setuptools, etc. during flash-attn build 23 | uv sync --extra gpu 24 | 25 | # Venv already activated above 26 | 27 | 28 | # Create necessary directories 29 | mkdir -p checkpoints logs results wandb 30 | 31 | # Experiment name used for checkpoints and logs 32 | EXPERIMENT_NAME="rope_pp-376m-4k-imag2-single-gpu" 33 | 34 | export WANDB_RUN_GROUP="rope_pp-376m-single-gpu" 35 | 36 | # Stage 1: Initial training for 100k steps 37 | echo "==========================================" 38 | echo "Starting Stage 1: Initial Training" 39 | echo "==========================================" 40 | 41 | python single-gpu/train_rope_pp_single_gpu.py \ 42 | --config_abbr '376m' \ 43 | --imag \ 44 | --imag_mode 'imag2' \ 45 | --save_abbr "$EXPERIMENT_NAME" 46 | 47 | echo "Stage 1 complete! Check logs/${EXPERIMENT_NAME}.log" 48 | 49 | # Wait for stage 1 to complete 50 | wait 51 | 52 | # Stage 2: Model Evaluation with LM Harness 53 | echo "==========================================" 54 | echo "Starting Model Evaluation with LM Harness" 55 | echo "==========================================" 56 | 57 | # Evaluate the trained checkpoint 58 | python eval/eval_lmharness.py \ 59 | --local-checkpoint "checkpoints/${EXPERIMENT_NAME}/checkpoint-100000" \ 60 | --model-name "Local-RoPEPP-376M-Trained" \ 61 | --model-type "ropepp" \ 62 | --include-baselines 63 | 64 | echo "==========================================" 65 | echo "Training and Evaluation Complete!" 66 | echo "==========================================" 67 | echo "" 68 | echo "Check the results directory for detailed evaluation outputs." 
69 | -------------------------------------------------------------------------------- /rope_imaginary/rope_pp/utils/trainer_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | import torch.nn.functional as F 5 | 6 | from transformers import Trainer, TrainingArguments 7 | from transformers import DefaultDataCollator 8 | 9 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 10 | 11 | from utils.dataset_utils import StreamingTrainingParquet, StreamingTrainingJsonlZSD, StreamingTrainingHuggingFace 12 | 13 | 14 | class TrainerWithDatasetCheckpointing(Trainer): 15 | 16 | def _save_checkpoint(self, model, trial): 17 | super()._save_checkpoint(model, trial) 18 | 19 | self.accelerator.wait_for_everyone() 20 | 21 | # Handle both distributed and single GPU training 22 | if torch.distributed.is_initialized(): 23 | rank = torch.distributed.get_rank() 24 | size = torch.distributed.get_world_size() 25 | else: 26 | rank = 0 27 | size = 1 28 | 29 | model_ckpt_path = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" 30 | run_dir = self._get_output_dir(trial=trial) 31 | model_ckpt_path = os.path.join(run_dir, model_ckpt_path) 32 | 33 | dataset_ckpt_path = f"{model_ckpt_path}/dataset_ckpt-{rank:{len(str(size))}d}-{size}.pt" 34 | dataset_ckpt_path = os.path.join(model_ckpt_path, dataset_ckpt_path) 35 | 36 | if isinstance(self.train_dataset, StreamingTrainingParquet): 37 | 38 | dataset_ckpt = { 39 | 'data_path': self.train_dataset.data_path, 40 | 'label_name': self.train_dataset.label_name, 41 | 'pivot': self.train_dataset.pivot, 'size': self.train_dataset.size, 42 | 'table_idx': self.train_dataset.table_idx, 43 | 'table_num': self.train_dataset.table_num, 44 | 'table_buffer': self.train_dataset.table_buffer, 45 | 'sample_idx': self.train_dataset.sample_idx, 46 | 'sample_num': self.train_dataset.sample_num, 47 | 'token_buffer': self.train_dataset.token_buffer, 48 | } 49 | 50 | torch.save(dataset_ckpt, dataset_ckpt_path) 51 | 52 | elif isinstance(self.train_dataset, StreamingTrainingJsonlZSD): 53 | 54 | dataset_ckpt = { 55 | 'data_path': self.train_dataset.data_path, 56 | 'label_name': self.train_dataset.label_name, 57 | 'pivot': self.train_dataset.pivot, 'size': self.train_dataset.size, 58 | 'sample_idx': self.train_dataset.sample_idx, 59 | 'token_buffer': self.train_dataset.token_buffer, 60 | } 61 | 62 | torch.save(dataset_ckpt, dataset_ckpt_path) 63 | 64 | elif isinstance(self.train_dataset, StreamingTrainingHuggingFace): 65 | # For HuggingFace streaming dataset, save the token buffer state 66 | dataset_ckpt = { 67 | 'token_buffer': self.train_dataset.token_buffer, 68 | } 69 | 70 | torch.save(dataset_ckpt, dataset_ckpt_path) 71 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/tasks/customjson.py: -------------------------------------------------------------------------------- 1 | """ 2 | CustomJSON task for loading conversations from JSONL files. 3 | Each line in the JSONL file should be a JSON array of messages. 4 | """ 5 | 6 | import os 7 | import json 8 | from tasks.common import Task 9 | 10 | class CustomJSON(Task): 11 | """ 12 | Load conversations from a JSONL file. 13 | Each line should be a JSON array of message objects with 'role' and 'content' fields. 
14 | Example line: [{"role":"user","content":"Hi"},{"role":"assistant","content":"Hello"}] 15 | """ 16 | 17 | def __init__(self, filepath, **kwargs): 18 | super().__init__(**kwargs) 19 | self.filepath = filepath 20 | self.conversations = [] 21 | 22 | # Load all conversations from the JSONL file 23 | if not os.path.exists(filepath): 24 | # Helpful error message due to recent change. Will be removed in the future. 25 | print("-" * 80) 26 | print(f"Warning: File {filepath} does not exist") 27 | print("HINT (Oct 21 2025)") 28 | print("If you recently did a git pull and suddely see this, it might be due to the new addition of identity conversations") 29 | print("See this discussion for more details: https://github.com/karpathy/nanochat/discussions/139") 30 | print("Quick fix: simply run the following command to download the file and you're done:") 31 | print(f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl") 32 | print("-" * 80) 33 | 34 | else: 35 | with open(filepath, 'r', encoding='utf-8') as f: 36 | for line in f: 37 | line = line.strip() 38 | if not line: # skip empty lines 39 | continue 40 | messages = json.loads(line) 41 | # Validate the conversation structure 42 | assert isinstance(messages, list), f"Expected list of messages, got {type(messages)}" 43 | assert len(messages) >= 2, f"Conversation must have at least 2 messages, got {len(messages)}" 44 | # Validate message structure and alternating roles 45 | for i, message in enumerate(messages): 46 | assert "role" in message, f"Message {i} missing 'role' field" 47 | assert "content" in message, f"Message {i} missing 'content' field" 48 | expected_role = "user" if i % 2 == 0 else "assistant" 49 | assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}" 50 | assert isinstance(message["content"], str), f"Message {i} content must be a string" 51 | 52 | self.conversations.append(messages) 53 | 54 | self.length = len(self.conversations) 55 | 56 | def num_examples(self): 57 | return self.length 58 | 59 | def get_example(self, index): 60 | messages = self.conversations[index] 61 | conversation = { 62 | "messages": messages, 63 | } 64 | return conversation 65 | 66 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/nanochat/loss_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | A number of functions that help with evaluating a base model. 3 | """ 4 | import math 5 | import torch 6 | import torch.distributed as dist 7 | 8 | @torch.no_grad() 9 | def evaluate_bpb(model, batches, steps, token_bytes): 10 | """ 11 | Instead of the naive 'mean loss', this function returns the bits per byte (bpb), 12 | which is a tokenization vocab size-independent metric, meaning you are still comparing 13 | apples:apples if you change the vocab size. The way this works is that instead of just 14 | calculating the average loss as usual, you calculate the sum loss, and independently 15 | also the sum bytes (of all the target tokens), and divide. This normalizes the loss by 16 | the number of bytes that the target tokens represent. 17 | 18 | The added complexity is so that: 19 | 1) All "normal" tokens are normalized by the length of the token in bytes 20 | 2) No special tokens (e.g. <|bos|>) are included in the metric - they are masked out. 21 | 3) No actively masked tokens (using ignore_index of e.g. -1) are included in the metric. 
22 | 23 | In addition to evaluate_loss, we need the token_bytes tensor: 24 | It is a 1D tensor of shape (vocab_size,), indicating the number of bytes for 25 | each token id, or 0 if the token is to not be counted (e.g. special tokens). 26 | """ 27 | # record the losses 28 | total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device()) 29 | total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device()) 30 | batch_iter = iter(batches) 31 | for _ in range(steps): 32 | x, y = next(batch_iter) 33 | loss2d = model(x, y, loss_reduction='none') # (B, T) 34 | loss2d = loss2d.view(-1) # flatten 35 | y = y.view(-1) # flatten 36 | if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32 37 | # slightly more complex code path if some target tokens are ignore_index (e.g. -1) 38 | # any target token < 0 is to be ignored: do NOT index token_bytes with negatives 39 | valid = y >= 0 40 | y_safe = torch.where(valid, y, torch.zeros_like(y)) 41 | # map valid targets to their byte length; ignored targets contribute 0 bytes 42 | num_bytes2d = torch.where( 43 | valid, 44 | token_bytes[y_safe], 45 | torch.zeros_like(y, dtype=token_bytes.dtype) 46 | ) 47 | total_nats += (loss2d * (num_bytes2d > 0)).sum() 48 | total_bytes += num_bytes2d.sum() 49 | else: 50 | # fast path: no ignored targets, safe to index directly 51 | num_bytes2d = token_bytes[y] 52 | total_nats += (loss2d * (num_bytes2d > 0)).sum() 53 | total_bytes += num_bytes2d.sum() 54 | # sum reduce across all ranks 55 | world_size = dist.get_world_size() if dist.is_initialized() else 1 56 | if world_size > 1: 57 | dist.all_reduce(total_nats, op=dist.ReduceOp.SUM) 58 | dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM) 59 | # move both to cpu, calculate bpb and return 60 | total_nats = total_nats.item() 61 | total_bytes = total_bytes.item() 62 | if total_bytes == 0: 63 | return float('inf') 64 | bpb = total_nats / (math.log(2) * total_bytes) 65 | return bpb 66 | -------------------------------------------------------------------------------- /rope_imaginary/rope_pp/speedrun-8gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check if HF_TOKEN is set 4 | if [ -z "$HF_TOKEN" ]; then 5 | echo "WARNING: HF_TOKEN environment variable is not set." 6 | echo "You need to set it to access gated models like meta-llama/Meta-Llama-3-8B" 7 | echo "Example: export HF_TOKEN='your_token_here'" 8 | echo "Get your token from: https://huggingface.co/settings/tokens" 9 | exit 1 10 | fi 11 | 12 | # Install uv (if not already installed) 13 | command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh 14 | 15 | # Create a .venv local virtual environment (if it doesn't exist) 16 | [ -d ".venv" ] || uv venv --python 3.12 17 | 18 | # Activate venv first 19 | source .venv/bin/activate 20 | 21 | # Install all dependencies including flash-attn 22 | # The extra-build-dependencies config will provide torch, setuptools, etc. 
during flash-attn build 23 | uv sync --extra gpu 24 | 25 | # Venv already activated above 26 | 27 | 28 | # Create necessary directories 29 | mkdir -p checkpoints logs results wandb 30 | 31 | # Generate random port for deepspeed 32 | port=$(shuf -i25000-30000 -n1) 33 | 34 | # Experiment name used for checkpoints and logs 35 | EXPERIMENT_NAME="rope_pp-376m-4k-imag2" 36 | 37 | export WANDB_RUN_GROUP="rope_pp-376m" 38 | 39 | # Stage 1: Initial training for 100k steps 40 | echo "==========================================" 41 | echo "Starting Stage 1: Initial Training" 42 | echo "==========================================" 43 | 44 | deepspeed --master_port "$port" --include localhost:0,1,2,3,4,5,6,7 \ 45 | multi-gpu/train_rope_pp.py \ 46 | --config_abbr '376m' \ 47 | --imag \ 48 | --imag_mode 'imag2' \ 49 | --exp_name "$EXPERIMENT_NAME" 50 | 51 | echo "Stage 1 complete! Check logs/${EXPERIMENT_NAME}.log" 52 | 53 | # Wait for stage 1 to complete 54 | wait 55 | 56 | # Stage 2: Long context training from checkpoint 90000 57 | echo "==========================================" 58 | echo "Starting Stage 2: Long Context Training" 59 | echo "==========================================" 60 | 61 | deepspeed --master_port "$port" --include localhost:0,1,2,3,4,5,6,7 \ 62 | multi-gpu/train_rope_pp.py \ 63 | --config_abbr '376m' \ 64 | --imag \ 65 | --imag_mode 'imag2' \ 66 | --exp_name "${EXPERIMENT_NAME}-lctx" \ 67 | --load_ckpt 90000 \ 68 | --max_length 32768 \ 69 | --rope_theta 500000 \ 70 | --batch_size 16 \ 71 | --max_steps 10000 72 | 73 | echo "Stage 2 complete! Check logs/${EXPERIMENT_NAME}-lctx.log" 74 | 75 | # Wait for stage 2 to complete 76 | wait 77 | 78 | # Stage 3: Model Evaluation with LM Harness (Long Context) 79 | echo "==========================================" 80 | echo "Starting Long Context Evaluation with LM Harness" 81 | echo "==========================================" 82 | 83 | # Evaluate the checkpoint after long context training 84 | python eval/eval_lmharness-lctx.py \ 85 | --local-checkpoint "checkpoints/${EXPERIMENT_NAME}-lctx/checkpoint-10000" \ 86 | --model-name "Local-RoPEPP-376M-After-LongContext" \ 87 | --model-type "ropepp" \ 88 | --include-baselines 89 | 90 | echo "Evaluation (long context) complete!" 91 | 92 | echo "==========================================" 93 | echo "Training and Evaluation Complete!" 94 | echo "==========================================" 95 | echo "Final model saved at: checkpoints/${EXPERIMENT_NAME}-lctx" 96 | echo "" 97 | echo "Check the results directory for detailed evaluation outputs." 
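Stage 2 above restarts from the step-90000 checkpoint with `--max_length 32768` and `--rope_theta 500000`. The snippet below is a rough illustration of why raising `rope_theta` helps at long context under the standard RoPE schedule: a larger base stretches every rotary frequency, so far-apart positions stay distinguishable. It does not reproduce the RoPE++ "imaginary" variants trained here (`--imag`, `--imag_mode 'imag2'`), and the 10,000 baseline is only a common default, not necessarily this repo's stage-1 value.

```python
# Standard RoPE inverse-frequency schedule (illustrative only).
import torch

def rope_inv_freq(head_dim: int, theta: float) -> torch.Tensor:
    # One frequency per pair of head dimensions: theta^(-2i/d)
    return 1.0 / (theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

short_ctx = rope_inv_freq(128, 10_000.0)    # common pretraining default
long_ctx = rope_inv_freq(128, 500_000.0)    # value passed to the stage-2 long-context run above
# The slowest frequency drops by nearly the full 50x ratio of the bases,
# so one rotation now spans many more positions.
print(short_ctx[-1].item(), long_ctx[-1].item())
```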
98 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/scripts/base_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loads a checkpoint, and: 3 | - Evaluates the loss on a larger chunk of train/val splits 4 | - Samples from the model 5 | 6 | Example run as: 7 | torchrun --standalone --nproc_per_node=8 -m scripts.base_loss 8 | """ 9 | import os 10 | from contextlib import nullcontext 11 | import torch 12 | from nanochat.checkpoint_manager import load_model 13 | from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type 14 | from nanochat.dataloader import tokenizing_distributed_data_loader 15 | from nanochat.tokenizer import get_token_bytes 16 | from nanochat.loss_eval import evaluate_bpb 17 | from nanochat.engine import Engine 18 | 19 | # Configuration 20 | device_batch_size = 32 21 | split_tokens = 20*524288 # number of tokens to evaluate per split 22 | model_tag = None # optional model tag for the output directory name 23 | model_step = None # optional model step for the output directory name 24 | device_type = "" # cuda|cpu|mps (empty => autodetect) 25 | exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file 26 | 27 | # Load the base model and the tokenizer 28 | device_type = autodetect_device_type() if device_type == "" else device_type 29 | ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) 30 | model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step) 31 | sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really 32 | autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() 33 | 34 | # Evaluate the loss on each split 35 | tokens_per_step = device_batch_size * sequence_len * ddp_world_size 36 | assert split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step" 37 | steps = split_tokens // tokens_per_step 38 | token_bytes = get_token_bytes(device=device) 39 | bpb_results = {} 40 | for split_name in ["train", "val"]: 41 | loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device) 42 | with autocast_ctx: 43 | bpb = evaluate_bpb(model, loader, steps, token_bytes) 44 | print0(f"{split_name} bpb: {bpb:.4f}") 45 | bpb_results[split_name] = bpb 46 | 47 | # Master process also samples from the model 48 | samples = [] 49 | if ddp_rank == 0: 50 | prompts = [ 51 | "The capital of France is", 52 | "The chemical symbol of gold is", 53 | "If yesterday was Friday, then tomorrow will be", 54 | "The opposite of hot is", 55 | "The planets of the solar system are:", 56 | "My favorite color is", 57 | "If 5*x + 3 = 13, then x is", 58 | ] 59 | engine = Engine(model, tokenizer) 60 | for prompt in prompts: 61 | tokens = tokenizer(prompt, prepend="<|bos|>") 62 | with autocast_ctx: 63 | sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0) 64 | sample_str = tokenizer.decode(sample[0]) 65 | print0(sample_str) 66 | samples.append(sample_str) 67 | 68 | # Log to report 69 | from nanochat.report import get_report 70 | get_report().log(section="Base model loss", data=[ 71 | { 72 | "train bpb": bpb_results["train"], 73 | "val bpb": bpb_results["val"], 74 | }, 75 | {f"sample {i}": sample for i, sample in enumerate(samples)}, 76 | ]) 77 | 78 | # 
Cleanup 79 | compute_cleanup() 80 | -------------------------------------------------------------------------------- /tiny_recursive_models/scripts/cmd.sh: -------------------------------------------------------------------------------- 1 | #Pretrain MAZE-HARD 2 | 3 | torchrun --nproc-per-node=8 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py arch=trm data_paths="[data/maze-30x30-hard-1k]" evaluators="[]" epochs=50000 eval_interval=5000 global_batch_size=1536 lr=2e-4 lr_warmup_steps=4000 puzzle_emb_lr=1e-4 checkpoint_every_eval=True weight_decay=1.0 puzzle_emb_weight_decay=1.0 arch.L_layers=2 arch.H_cycles=3 arch.L_cycles=4 +run_name=${run_name} ema=True 4 | 5 | #Sudoko-extreme 6 | 7 | #mlp 8 | run_name="pretrain_mlp_t_sudoku" 9 | 10 | torchrun --nproc-per-node=8 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py \ 11 | arch=trm \ 12 | data_paths="[data/sudoku-extreme-1k-aug-1000]" \ 13 | evaluators="[]" \ 14 | epochs=50000 eval_interval=5000 \ 15 | lr=2e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \ 16 | arch.mlp_t=True arch.pos_encodings=none \ 17 | arch.L_layers=2 \ 18 | arch.H_cycles=3 arch.L_cycles=6 \ 19 | lr_warmup_steps=4000 \ 20 | global_batch_size=1536 \ 21 | +run_name=${run_name} ema=True 22 | 23 | #Attn 24 | 25 | torchrun --nproc-per-node=8 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py arch=trm data_paths="[data/sudoku-extreme-1k-aug-1000]" evaluators="[]" epochs=50000 eval_interval=5000 lr=2e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 arch.L_layers=2 arch.H_cycles=3 arch.L_cycles=6 lr_warmup_steps=4000 global_batch_size=1536 +run_name=${run_name} ema=True 26 | 27 | 28 | --- 29 | #evals MAZE-HARD 30 | 31 | #Attn 32 | 33 | torchrun --nproc_per_node=8 run_eval.py \ 34 | --checkpoint ./step_32550 \ 35 | --dataset data/maze-30x30-hard-1k \ 36 | --outdir checkpoints/maze_eval_run \ 37 | --eval-save-outputs inputs labels puzzle_identifiers preds \ 38 | --global-batch-size 1536 \ 39 | --apply-ema \ 40 | --repeats 3 \ 41 | --seed-start 0 42 | 43 | #evals Sudoku 44 | 45 | #MLP 46 | torchrun --nproc_per_node=8 run_eval.py \ 47 | --checkpoint ./step_32550_sudoku_epoch50k \ 48 | --dataset data/maze-30x30-hard-1k \ 49 | --outdir checkpoints/maze_eval_run \ 50 | --eval-save-outputs inputs labels puzzle_identifiers preds \ 51 | --global-batch-size 1536 \ 52 | --apply-ema \ 53 | --repeats 3 \ 54 | --seed-start 0 55 | 56 | 57 | 58 | ### ARC-AGI - 1 59 | 60 | # attention 61 | run_name="pretrain_att_arc1concept_4" 62 | torchrun --nproc-per-node 8 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py arch=trm data_paths="[data/arc1concept-aug-1000]" arch.L_layers=2 arch.H_cycles=3 arch.L_cycles=6 +run_name=${run_name} ema=True lr=2e-4 weight_decay=0.1 global_batch_size=1536 lr_warmup_steps=4000 epochs=100000 puzzle_emb_lr=1e-2 eval_interval=5000 63 | 64 | MLPrun_name="pretrain_att_arc1concept_4" 65 | torchrun --nproc-per-node 8 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py arch=trm data_paths="[data/arc1concept-aug-1000]" arch.L_layers=2 arch.H_cycles=3 arch.L_cycles=6 +run_name=${run_name} ema=True lr=2e-4 weight_decay=0.1 global_batch_size=1536 lr_warmup_steps=4000 epochs=100000 puzzle_emb_lr=1e-2 eval_interval=5000 66 | 67 | # MLP 68 | 69 | run_name="pretrain_att_arc1concept_h3l6" 70 | torchrun --nproc-per-node 8 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py arch=trm data_paths="[data/arc1concept-aug-1000]" 
arch.L_layers=2 arch.H_cycles=3 arch.L_cycles=4 +run_name=${run_name} ema=True lr=2e-4 weight_decay=0.1 global_batch_size=1536 lr_warmup_steps=4000 epochs=100000 puzzle_emb_lr=1e-2 eval_interval=5000 arch.mlp_t=true 71 | 72 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/nanochat/adamw.py: -------------------------------------------------------------------------------- 1 | """ 2 | Borrowed from modded-nanogpt. By Keller, @vagrawal, et al. 3 | Not a general optimizer! But works for our specific use. 4 | """ 5 | import torch 6 | import torch.distributed as dist 7 | from torch import Tensor 8 | 9 | 10 | class DistAdamW(torch.optim.Optimizer): 11 | """ 12 | Distributed AdamW optimizer. 13 | In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction 14 | """ 15 | def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): 16 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 17 | super().__init__(param_groups, defaults) 18 | 19 | @torch.compile 20 | @torch.no_grad() 21 | def step(self): 22 | rank = dist.get_rank() 23 | world_size = dist.get_world_size() 24 | reduce_scatter_futures: list[torch.Future] = [] 25 | all_reduce_futures: list[torch.Future] = [] 26 | grad_slices = [] 27 | for group in self.param_groups: 28 | params: list[Tensor] = group["params"] 29 | for base_i in range(len(params)): 30 | grad = params[base_i].grad 31 | rank_size = grad.shape[0] // world_size 32 | grad_slice = torch.empty_like(grad[:rank_size]) 33 | reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) 34 | grad_slices.append(grad_slice) 35 | 36 | idx = 0 37 | for group in self.param_groups: 38 | beta1, beta2 = group['betas'] 39 | eps = group['eps'] 40 | wd = group['weight_decay'] 41 | params = group['params'] 42 | for base in range(len(params)): 43 | reduce_scatter_futures[idx].wait() 44 | p = params[base] 45 | rank_size = p.shape[0] // world_size 46 | p_slice = p[rank * rank_size:(rank + 1) * rank_size] 47 | lr = group['lr'] * getattr(p, "lr_mul", 1.0) 48 | state = self.state[p] 49 | g_slice = grad_slices[idx] 50 | # State init 51 | if not state: 52 | state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) 53 | state['exp_avg'] = torch.zeros_like(p_slice) 54 | state['exp_avg_sq'] = torch.zeros_like(p_slice) 55 | exp_avg = state['exp_avg'] 56 | exp_avg_sq = state['exp_avg_sq'] 57 | state['step'] += 1 58 | t = state['step'] 59 | # weight decay 60 | if wd != 0: 61 | eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) 62 | p_slice.mul_(1 - eff_weight_decay) 63 | # update running averages 64 | exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) 65 | exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) 66 | # bias corrections 67 | bias1 = 1 - beta1 ** t 68 | bias2 = 1 - beta2 ** t 69 | # compute step 70 | denom = exp_avg_sq.sqrt().add_(eps) 71 | step_size = lr * (torch.sqrt(bias2) / bias1) 72 | update = exp_avg.div(denom).mul_(step_size) 73 | p_slice.add_(other=update, alpha=-1.0) 74 | idx += 1 75 | all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) 76 | torch.futures.collect_all(all_reduce_futures).wait() 77 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/tasks/humaneval.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Evaluate the Chat model on HumanEval dataset. 3 | Btw this dataset is a misnomer and has nothing to do with humans. 4 | It is a coding benchmark. 5 | """ 6 | 7 | import re 8 | from datasets import load_dataset 9 | from nanochat.execution import execute_code 10 | from tasks.common import Task 11 | 12 | def extract_imports(prompt): 13 | """Extract import statements from the beginning of a code block.""" 14 | imports = [] 15 | for line in prompt.split('\n'): 16 | stripped = line.strip() 17 | if stripped.startswith('import ') or stripped.startswith('from '): 18 | imports.append(stripped) 19 | elif stripped and not stripped.startswith('#'): 20 | # Stop at first non-import, non-comment line 21 | break 22 | return '\n'.join(imports) 23 | 24 | def extract_program(completion): 25 | """ 26 | Extract Python code from LLM completion. 27 | 28 | Handles various output formats: 29 | - Code wrapped in ```python ... ``` or ``` ... ``` blocks 30 | - Plain code without markdown blocks 31 | - Extra text before/after code blocks 32 | 33 | Returns the first code block if found, otherwise returns the whole completion. 34 | """ 35 | # Try to find markdown code blocks (```python or just ```) 36 | # Match ```python\n...\n``` or ```\n...\n``` 37 | pattern = r'```(?:python)?\s*\n(.*?)\n```' 38 | matches = re.findall(pattern, completion, re.DOTALL) 39 | 40 | if matches: 41 | # Return the first code block found 42 | return matches[0].strip() 43 | 44 | # No code blocks found, return the whole completion 45 | return completion.strip() 46 | 47 | class HumanEval(Task): 48 | 49 | def __init__(self, **kwargs): 50 | super().__init__(**kwargs) 51 | self.ds = load_dataset("openai/openai_humaneval", split="test").shuffle(seed=42) 52 | 53 | @property 54 | def eval_type(self): 55 | return 'generative' 56 | 57 | def num_examples(self): 58 | return len(self.ds) 59 | 60 | def get_example(self, index): 61 | """ Get a single problem from the dataset. """ 62 | row = self.ds[index] 63 | prompt = row['prompt'] # prompts in HumanEval are the beginning of the program 64 | solution = row['canonical_solution'] # the correct continuation of the program 65 | entry_point = row['entry_point'] # the function to check 66 | test = row['test'] # the test cases 67 | complete_solution = f"{prompt}\n{solution}" 68 | messages = [ 69 | {"role": "user", "content": prompt}, 70 | {"role": "assistant", "content": complete_solution}, 71 | ] 72 | conversation = { 73 | "messages": messages, 74 | "entry_point": entry_point, # needed during evaluation 75 | "test": test, # needed during evaluation 76 | } 77 | return conversation 78 | 79 | def evaluate(self, conversation, completion): 80 | """ Given (conversation, completion), return boolean success of the completion. 
""" 81 | # the prompt will contain the imports and the function signature 82 | imports = extract_imports(conversation['messages'][0]['content']) 83 | # the completion will usually contain the whole function 84 | # but not always with the needed imports, so we manually append them 85 | completion_code = extract_program(completion) 86 | program = ( 87 | imports 88 | + "\n\n" 89 | + completion_code 90 | + "\n\n" 91 | + conversation['test'] 92 | + "\n" 93 | + f"check({conversation['entry_point']})" 94 | ) 95 | result = execute_code(program) 96 | success = result.success 97 | return success 98 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/tasks/mmlu.py: -------------------------------------------------------------------------------- 1 | """ 2 | The MMLU dataset. 3 | https://huggingface.co/datasets/cais/mmlu 4 | """ 5 | 6 | from datasets import load_dataset 7 | from tasks.common import Task, render_mc 8 | 9 | class MMLU(Task): 10 | 11 | letters = ('A', 'B', 'C', 'D') 12 | groups = ('abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics', 'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios', 'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions') 13 | 14 | def __init__(self, subset, split, **kwargs): 15 | super().__init__(**kwargs) 16 | assert subset in ["all", "auxiliary_train"], f"subset {subset} must be all|auxiliary_train" 17 | assert split in ["train", "validation", "dev", "test"], f"split {split} must be train|validation|dev|test" 18 | if subset == "auxiliary_train": 19 | assert split == "train", "auxiliary_train must be split into train" 20 | self.subset = subset 21 | self.split = split 22 | self.ds = load_dataset("cais/mmlu", subset, split=split).shuffle(seed=42) 23 | if subset == "auxiliary_train": 24 | # I don't understand why but the auxiliary_train rows have some weird additional 'train' wrapper 25 | self.ds = self.ds.map(lambda row: row['train'], remove_columns=['train']) 26 | 27 | @property 28 | def eval_type(self): 29 | return 'categorical' 30 | 31 | def num_examples(self): 32 | return len(self.ds) 33 | 34 | def get_example(self, index): 35 | row = self.ds[index] 36 | question = row["question"] # the question text 37 | choices = row["choices"] # the text of each choice 38 | answer = row["answer"] # index of the answer, e.g. 0,1,2,3 (for A,B,C,D) 39 | subject = row["subject"] # e.g. "college_biology", "college_chemistry", etc. 
40 | assert len(choices) == 4, "MMLU should have 4 choices" 41 | # create and return the Conversation object 42 | user_message = render_mc(question, self.letters, choices) 43 | assistant_message = self.letters[answer] 44 | messages = [ 45 | {"role": "user", "content": user_message}, 46 | {"role": "assistant", "content": assistant_message} 47 | ] 48 | conversation = { 49 | "messages": messages, 50 | "subject": subject, # might be useful later for grouping metrics by subject 51 | "letters": self.letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters 52 | } 53 | return conversation 54 | 55 | def evaluate(self, conversation, assistant_response): 56 | # the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true 57 | # I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it. 58 | assert assistant_response in self.letters, f"MMLU answer {assistant_response} is expected to be one of {self.letters}" 59 | assistant_message = conversation['messages'][-1]['content'] # e.g. "A" 60 | return assistant_response == assistant_message 61 | -------------------------------------------------------------------------------- /tiny_recursive_models/README.md: -------------------------------------------------------------------------------- 1 | # Tiny Recursive Reasoning 2 | 3 | This is an implementation of the paper: "Less is More: Recursive Reasoning with Tiny Networks". It is forked from the original author's codebase [here](https://github.com/SamsungSAILMontreal/TinyRecursiveModels). We provide some re-organization of the original work as well as a Nanochat-style speedrun bash script that takes care of environment setup, training, and evaluation. 4 | 5 | TRM is a recursive reasoning approach that achieves amazing scores of 45% on ARC-AGI-1 and 8% on ARC-AGI-2 using a tiny 7M parameters neural network. Read the paper [here](https://www.alphaxiv.org/abs/2510.04871) 6 | 7 | ### How TRM works 8 | 9 |

10 | ![TRM](docs/assets/TRM_fig.png) 11 | 
12 | TRM iteratively updates latent z and answer y. 13 | 
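In code, the recursion in the figure looks roughly like the sketch below. This is illustrative only: `f_z` and `f_y` are hypothetical stand-ins, while `H_cycles` / `L_cycles` mirror the config keys used in `scripts/cmd.sh`; the real model lives under `src/tiny_recursive_models/models/`.

```python
import torch
import torch.nn as nn

class TinyRecursiveSketch(nn.Module):
    """Rough sketch of the TRM-style update loop, not the actual architecture."""
    def __init__(self, dim: int, H_cycles: int = 3, L_cycles: int = 6):
        super().__init__()
        self.H_cycles, self.L_cycles = H_cycles, L_cycles
        self.f_z = nn.Sequential(nn.Linear(3 * dim, dim), nn.SiLU(), nn.Linear(dim, dim))  # refines latent z
        self.f_y = nn.Sequential(nn.Linear(2 * dim, dim), nn.SiLU(), nn.Linear(dim, dim))  # refines answer y

    def forward(self, x, y, z):
        for _ in range(self.H_cycles):
            for _ in range(self.L_cycles):
                z = self.f_z(torch.cat([x, y, z], dim=-1))  # update latent given input and current answer
            y = self.f_y(torch.cat([y, z], dim=-1))         # update answer given the refined latent
        return y, z

model = TinyRecursiveSketch(dim=64)
x, y, z = torch.randn(2, 64), torch.zeros(2, 64), torch.zeros(2, 64)
y, z = model(x, y, z)
```

The halting losses (`q_halt_loss`, `q_continue_loss`) in `src/tiny_recursive_models/models/losses.py` and the EMA option (`ema=True` in `scripts/cmd.sh`) are layered on top of this loop in the actual training code.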

14 | 15 | ## Quickstart 16 | 17 | We used Lambda Labs Image 22.4 with 4xH100 80GB SXM GPUs instance with CUDA version 12.8. More info in [REPORT.md](docs/REPORT.md) 18 | 19 | ### One-Line Setup with speedrun.sh 20 | 21 | The easiest way to get started is using our `speedrun.sh` script that handles everything: 22 | 23 | ```bash 24 | # Single task (auto-detects GPU count) 25 | bash speedrun.sh arc1 # ARC-AGI-1 26 | bash speedrun.sh arc2 # ARC-AGI-2 27 | bash speedrun.sh sudoku # Sudoku-Extreme 28 | bash speedrun.sh maze # Maze-Hard 30x30 29 | 30 | # Force single or multi-GPU mode 31 | bash speedrun.sh arc1 single-gpu # Use 1 GPU 32 | bash speedrun.sh arc2 multi-gpu # Use all available GPUs 33 | 34 | # Run all tasks 35 | bash speedrun.sh all 36 | ``` 37 | 38 | The script automatically: 39 | - Installs `uv` if not present 40 | - Creates virtual environment with `uv venv` 41 | - Installs PyTorch and dependencies 42 | - Builds datasets 43 | - Trains models 44 | - Evaluates results 45 | 46 | ### Evaluating Pre-trained Models 47 | 48 | We provide pre-trained model weights: 49 | 50 | - Maze: https://huggingface.co/alphaXiv/trm-model-maze 51 | - Sudoku: https://huggingface.co/alphaXiv/trm-model-sudoku 52 | - ARC-AGI-1: https://huggingface.co/alphaXiv/trm-model-arc-agi-1 53 | 54 | **Quick Start with speedrun-inference.sh:** 55 | 56 | ```bash 57 | # Full evaluation (uses all available GPUs) 58 | bash speedrun-inference.sh arc1 # ARC-AGI-1 59 | bash speedrun-inference.sh maze # Maze-Hard 60 | bash speedrun-inference.sh sudoku # Sudoku-Extreme 61 | 62 | # Evaluate all models 63 | bash speedrun-inference.sh all 64 | ``` 65 | 66 | 67 | **Note:** The `speedrun.sh` script handles all dataset building, training, and evaluation automatically. Manual commands are provided for advanced users who need custom configurations. 68 | 69 | ## Reproducing paper numbers 70 | 71 | - Build the exact datasets above (`arc1concept-aug-1000`, `arc2concept-aug-1000`, `maze-30x30-hard-1k`, `sudoku-extreme-1k-aug-1000`). 72 | - Use the training commands in this README (matching `scripts/cmd.sh` but with minor fixes like line breaks and env-safe flags). 73 | - Keep seeds at defaults (`seed=0` in `config/cfg_pretrain.yaml`); runs are deterministic modulo CUDA kernels. 74 | - Evaluate with `scripts/run_eval_only.py` and report `exact_accuracy` and per-task metrics. The script will compute Wilson 95% CI when dataset metadata is present. 75 | 76 | ## Reproduction Report 77 | 78 | For detailed analysis of independent reproduction attempts and comparison with published claims, see [REPORT.md](docs/REPORT.md). 79 | 80 | This report includes evaluation results, performance comparisons, and insights from reproducing the TRM paper's results across Maze-Hard, ARC-AGI-1, and Sudoku-Extreme benchmarks. 81 | 82 | ## Troubleshooting 83 | 84 | - PyTorch install: pick wheels matching your CUDA; on macOS (CPU/MPS) training will be very slow — prefer Linux + NVIDIA GPU for training. 85 | - NCCL errors: ensure you run under `torchrun` on a Linux box with GPUs and that `nvidia-smi` shows all devices. 86 | - Checkpoints and EMA: training saves EMA by default when `ema=True`; the eval script applies EMA unless disabled. 87 | 88 | 89 | This code is based on the original Tiny Recursive Model [code](https://github.com/SamsungSAILMontreal/TinyRecursiveModels). 
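The Wilson 95% interval mentioned under "Reproducing paper numbers" can be computed from an exact-accuracy count as follows. This is a generic sketch of the formula, not the reporting code in `scripts/run_eval_only.py`:

```python
import math

def wilson_interval(successes: int, n: int, z: float = 1.96) -> tuple[float, float]:
    """95% Wilson score interval for a binomial proportion such as exact_accuracy."""
    if n == 0:
        return (0.0, 0.0)
    p = successes / n
    denom = 1.0 + z * z / n
    center = (p + z * z / (2 * n)) / denom
    half = z * math.sqrt(p * (1.0 - p) / n + z * z / (4 * n * n)) / denom
    return (center - half, center + half)

# Hypothetical example: 412 exactly-solved puzzles out of a 1,000-example test split
print(wilson_interval(412, 1000))
```

At 50% accuracy on a hypothetical 1,000-example test split, the interval is roughly ±3 points.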
90 | -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/models/losses.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple, Dict, Sequence, Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | import math 7 | 8 | IGNORE_LABEL_ID = -100 9 | 10 | 11 | def s(x, epsilon=1e-30): 12 | return torch.where( 13 | x<0, 14 | 1/(1-x+ epsilon), 15 | x + 1 16 | ) 17 | 18 | 19 | def log_stablemax(x, dim=-1): 20 | s_x = s(x) 21 | return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True)) 22 | 23 | 24 | def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None): 25 | logprobs = log_stablemax(logits.to(torch.float64), dim=-1) 26 | 27 | if valid_mask is None: 28 | valid_mask = (labels != ignore_index) 29 | transformed_labels = torch.where(valid_mask, labels, 0) 30 | prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1) 31 | 32 | return -torch.where(valid_mask, prediction_logprobs, 0) 33 | 34 | 35 | def softmax_cross_entropy(logits, labels, ignore_index: int = -100): 36 | # Cast logits to f32 37 | # Flatten logits 38 | return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape) 39 | 40 | 41 | class ACTLossHead(nn.Module): 42 | def __init__(self, model: nn.Module, loss_type: str): 43 | super().__init__() 44 | self.model = model 45 | self.loss_fn = globals()[loss_type] 46 | 47 | def initial_carry(self, *args, **kwargs): 48 | return self.model.initial_carry(*args, **kwargs) # type: ignore 49 | 50 | def forward( 51 | self, 52 | return_keys: Sequence[str], 53 | # Model args 54 | **model_kwargs, 55 | ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]: 56 | # Model logits 57 | # B x SeqLen x D 58 | new_carry, outputs = self.model(**model_kwargs) 59 | labels = new_carry.current_data["labels"] 60 | 61 | with torch.no_grad(): 62 | # Preds 63 | outputs["preds"] = torch.argmax(outputs["logits"], dim=-1) 64 | 65 | # Correctness 66 | mask = (labels != IGNORE_LABEL_ID) 67 | loss_counts = mask.sum(-1) 68 | loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division 69 | 70 | is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels) 71 | seq_is_correct = is_correct.sum(-1) == loss_counts 72 | 73 | # Metrics (halted) 74 | valid_metrics = new_carry.halted & (loss_counts > 0) 75 | metrics = { 76 | "count": valid_metrics.sum(), 77 | 78 | "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(), 79 | "exact_accuracy": (valid_metrics & seq_is_correct).sum(), 80 | 81 | "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(), 82 | "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(), 83 | } 84 | 85 | # Losses 86 | 87 | lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum() 88 | q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum") 89 | metrics.update({ 90 | "lm_loss": lm_loss.detach(), 91 | "q_halt_loss": q_halt_loss.detach(), 92 | }) 93 | # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems 
totally unecessary 94 | q_continue_loss = 0 95 | if "target_q_continue" in outputs: 96 | q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum") 97 | 98 | metrics["q_continue_loss"] = q_continue_loss.detach() 99 | # Filter outputs for return 100 | detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs} 101 | 102 | return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all() 103 | 104 | -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/training/checkpoint.py: -------------------------------------------------------------------------------- 1 | """Checkpoint saving and loading utilities.""" 2 | 3 | import os 4 | import torch 5 | from torch import nn 6 | 7 | from tiny_recursive_models.training.config import PretrainConfig, TrainState 8 | 9 | 10 | def save_train_state(config: PretrainConfig, train_state: TrainState): 11 | """Save model checkpoint to disk. 12 | 13 | Args: 14 | config: Training configuration 15 | train_state: Current training state 16 | """ 17 | # FIXME: Only saved model. 18 | if config.checkpoint_path is None: 19 | return 20 | 21 | os.makedirs(config.checkpoint_path, exist_ok=True) 22 | torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}")) 23 | 24 | 25 | def load_checkpoint(model: nn.Module, config: PretrainConfig): 26 | """Load model checkpoint from disk or HuggingFace. 27 | 28 | Args: 29 | model: Model to load checkpoint into 30 | config: Training configuration with load_checkpoint path 31 | """ 32 | if config.load_checkpoint is not None: 33 | print(f"Loading checkpoint {config.load_checkpoint}") 34 | 35 | checkpoint_path = config.load_checkpoint 36 | 37 | # Check if this is a HuggingFace repo path (format: "username/repo/filename") 38 | if "/" in checkpoint_path and not os.path.exists(checkpoint_path): 39 | try: 40 | from huggingface_hub import hf_hub_download 41 | 42 | # Parse HuggingFace path: "alphaXiv/trm-model-maze/maze_hard_step_32550" 43 | parts = checkpoint_path.split("/", 2) 44 | if len(parts) < 3: 45 | raise ValueError( 46 | f"HuggingFace path must be in format 'username/repo/filename'. Got: {checkpoint_path}" 47 | ) 48 | 49 | repo_id = f"{parts[0]}/{parts[1]}" 50 | filename = parts[2] 51 | 52 | print(f"Downloading from HuggingFace: repo={repo_id}, file={filename}") 53 | checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename) 54 | print(f"Downloaded to: {checkpoint_path}") 55 | 56 | except ImportError: 57 | raise ImportError( 58 | "huggingface_hub is required to load checkpoints from HuggingFace. " 59 | "Install it with: pip install huggingface_hub" 60 | ) 61 | except Exception as e: 62 | raise RuntimeError(f"Failed to download checkpoint from HuggingFace: {e}") 63 | 64 | # Load state dict 65 | state_dict = torch.load(checkpoint_path, map_location="cuda") 66 | 67 | # Always strip compile/DataParallel style prefixes so keys match the 68 | # non-compiled module. We won't be using torch.compile in eval. 69 | def _strip_prefixes(sd: dict) -> dict: 70 | out: dict[str, torch.Tensor] = {} 71 | for k, v in sd.items(): 72 | key = k 73 | if isinstance(key, str): 74 | # remove a leading '.' 
if present 75 | if key.startswith('.'): 76 | key = key[1:] 77 | # known wrapper prefixes to drop 78 | for pref in ("_orig_mod.", "_orig._mod.", "module."): 79 | if key.startswith(pref): 80 | key = key[len(pref):] 81 | break 82 | out[key] = v 83 | return out 84 | 85 | state_dict = _strip_prefixes(state_dict) 86 | 87 | # Resize and reset puzzle emb if needed 88 | try: 89 | expected_shape: torch.Size = model.model.puzzle_emb.weights.shape # type: ignore 90 | puzzle_emb_name = "model.inner.puzzle_emb.weights" 91 | if puzzle_emb_name in state_dict: 92 | puzzle_emb = state_dict[puzzle_emb_name] 93 | if getattr(puzzle_emb, 'shape', None) != expected_shape: 94 | print(f"Resetting puzzle embedding as shape is different. Found {getattr(puzzle_emb, 'shape', None)}, Expected {expected_shape}") 95 | state_dict[puzzle_emb_name] = ( 96 | torch.mean(puzzle_emb, dim=0, keepdim=True).expand(expected_shape).contiguous() 97 | ) 98 | except Exception: 99 | pass 100 | 101 | model.load_state_dict(state_dict, assign=True) 102 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/scripts/tok_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a tokenizer using the HuggingFace Tokenizers library. 3 | In the style of GPT-4 tokenizer. 4 | """ 5 | import os 6 | import time 7 | import argparse 8 | import torch 9 | from nanochat.tokenizer import RustBPETokenizer 10 | from nanochat.common import get_base_dir 11 | from nanochat.dataset import parquets_iter_batched 12 | 13 | # ----------------------------------------------------------------------------- 14 | # Parse command line arguments 15 | 16 | parser = argparse.ArgumentParser(description='Train a BPE tokenizer') 17 | parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)') 18 | parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)') 19 | parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)') 20 | args = parser.parse_args() 21 | print(f"max_chars: {args.max_chars:,}") 22 | print(f"doc_cap: {args.doc_cap:,}") 23 | print(f"vocab_size: {args.vocab_size:,}") 24 | 25 | # ----------------------------------------------------------------------------- 26 | # Text iterator 27 | 28 | def text_iterator(): 29 | """ 30 | 1) Flatten the batches into a single iterator 31 | 2) Crop every document to args.doc_cap characters 32 | 3) Break when we've seen args.max_chars characters 33 | """ 34 | nchars = 0 35 | for batch in parquets_iter_batched(split="train"): 36 | for doc in batch: 37 | doc_text = doc 38 | if len(doc_text) > args.doc_cap: 39 | doc_text = doc_text[:args.doc_cap] 40 | nchars += len(doc_text) 41 | yield doc_text 42 | if nchars > args.max_chars: 43 | return 44 | text_iter = text_iterator() 45 | 46 | # ----------------------------------------------------------------------------- 47 | # Train the tokenizer 48 | t0 = time.time() 49 | tokenizer = RustBPETokenizer.train_from_iterator(text_iter, args.vocab_size) 50 | t1 = time.time() 51 | train_time = t1 - t0 52 | print(f"Training time: {train_time:.2f}s") 53 | 54 | # ----------------------------------------------------------------------------- 55 | # Save the tokenizer to disk 56 | base_dir = get_base_dir() 57 | tokenizer_dir = os.path.join(base_dir, "tokenizer") 58 | tokenizer.save(tokenizer_dir) 59 | 60 | # 
----------------------------------------------------------------------------- 61 | # Quick inline sanity check 62 | test_text = """Hello world! This is a test. 63 | Numbers: 123, 4567, 89 64 | Contractions: I'm, you're, it's 65 | Special chars: @#$%^&*() 66 | Unicode: 你好世界 🌍""" 67 | encoded = tokenizer.encode(test_text) 68 | decoded = tokenizer.decode(encoded) 69 | assert decoded == test_text 70 | 71 | # ----------------------------------------------------------------------------- 72 | # One more thing: we wish to cache a mapping from token id to number of bytes of that token 73 | # for efficient evaluation of bits per byte. Unlike the typical mean loss, this 74 | # allows us to report a loss that is invariant to the vocab size of the tokenizer. 75 | # The bits per byte on the validation set is then one of the primary metrics we care about. 76 | vocab_size = tokenizer.get_vocab_size() 77 | special_set = set(tokenizer.get_special_tokens()) 78 | token_strings = [tokenizer.decode([token_id]) for token_id in range(vocab_size)] 79 | token_bytes = [] 80 | for token_id in range(vocab_size): 81 | token_str = token_strings[token_id] # the Python string representation of this token 82 | if token_str in special_set: 83 | token_bytes.append(0) # special characters are not counted 84 | else: 85 | id_bytes = len(token_str.encode("utf-8")) # number of bytes that make up this token 86 | token_bytes.append(id_bytes) 87 | token_bytes = torch.tensor(token_bytes, dtype=torch.int32, device='cpu') 88 | token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt") 89 | with open(token_bytes_path, "wb") as f: 90 | torch.save(token_bytes, f) 91 | print(f"Saved token_bytes to {token_bytes_path}") 92 | 93 | # Log to report 94 | from nanochat.report import get_report 95 | token_bytes_nonzero = (token_bytes[token_bytes > 0]).to(dtype=torch.float32) 96 | get_report().log(section="Tokenizer training", data=[ 97 | vars(args), # argparse command line arguments 98 | {"train_time": train_time}, 99 | {"num_special_tokens": len(special_set)}, 100 | { 101 | "token_bytes_min": int(token_bytes_nonzero.min().item()), 102 | "token_bytes_max": int(token_bytes_nonzero.max().item()), 103 | "token_bytes_mean": token_bytes_nonzero.mean().item(), 104 | "token_bytes_std": token_bytes_nonzero.std().item(), 105 | } 106 | ]) 107 | -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/models/sparse_embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | from torch import nn 5 | import torch.distributed as dist 6 | from torch.optim.optimizer import Optimizer, ParamsT 7 | 8 | from tiny_recursive_models.models.common import trunc_normal_init_ 9 | 10 | 11 | class CastedSparseEmbedding(nn.Module): 12 | def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype): 13 | super().__init__() 14 | self.cast_to = cast_to 15 | 16 | # Real Weights 17 | # Truncated LeCun normal init 18 | self.weights = nn.Buffer( 19 | trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True 20 | ) 21 | 22 | # Local weights and IDs 23 | # Local embeddings, with gradient, not persistent 24 | self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False) 25 | # Local embedding IDs, not persistent 26 | self.local_ids = nn.Buffer(torch.zeros(batch_size, 
dtype=torch.int32), persistent=False) 27 | 28 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 29 | if not self.training: 30 | # Test mode, no gradient 31 | return self.weights[inputs].to(self.cast_to) 32 | 33 | # Training mode, fill puzzle embedding from weights 34 | with torch.no_grad(): 35 | self.local_weights.copy_(self.weights[inputs]) 36 | self.local_ids.copy_(inputs) 37 | 38 | return self.local_weights.to(self.cast_to) 39 | 40 | 41 | class CastedSparseEmbeddingSignSGD_Distributed(Optimizer): 42 | def __init__( 43 | self, 44 | params: ParamsT, 45 | 46 | world_size: int, 47 | lr: Union[float, torch.Tensor] = 1e-3, 48 | weight_decay: float = 1e-2, 49 | ): 50 | if not 0.0 <= lr: 51 | raise ValueError(f"Invalid learning rate: {lr}") 52 | if not 0.0 <= weight_decay: 53 | raise ValueError(f"Invalid weight_decay value: {weight_decay}") 54 | 55 | defaults = dict( 56 | lr=lr, 57 | weight_decay=weight_decay, 58 | world_size=world_size 59 | ) 60 | super().__init__(params, defaults) 61 | 62 | @torch.no_grad 63 | def step(self, closure=None): # type: ignore 64 | for group in self.param_groups: 65 | # Find the sparse embedding weights 66 | local_weights_grad = None 67 | local_ids = None 68 | weights = None 69 | 70 | assert len(group["params"]) == 3 71 | for p in group["params"]: 72 | if p.requires_grad: 73 | local_weights_grad = p.grad 74 | elif p.ndim == 1: 75 | local_ids = p 76 | elif p.ndim == 2: 77 | weights = p 78 | else: 79 | assert False 80 | 81 | assert local_ids is not None 82 | assert weights is not None 83 | 84 | # Apply SignSGD 85 | # Adam ≈ SignSGD if gradient is very sparse 86 | if local_weights_grad is not None: 87 | _sparse_emb_signsgd_dist( 88 | local_weights_grad, 89 | local_ids, 90 | weights, 91 | 92 | lr=group["lr"], 93 | weight_decay=group["weight_decay"], 94 | world_size=group["world_size"] 95 | ) 96 | 97 | 98 | def _sparse_emb_signsgd_dist( 99 | local_weights_grad: torch.Tensor, 100 | local_ids: torch.Tensor, 101 | weights: torch.Tensor, 102 | 103 | lr: float, 104 | weight_decay: float, 105 | world_size: int 106 | ) -> None: 107 | N, D = local_weights_grad.shape 108 | 109 | # All-gather 110 | all_weights_grad = local_weights_grad 111 | all_ids = local_ids 112 | 113 | if world_size > 1: 114 | all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device) 115 | all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device) 116 | 117 | dist.all_gather_into_tensor(all_weights_grad, local_weights_grad) 118 | dist.all_gather_into_tensor(all_ids, local_ids) 119 | 120 | # Unique 121 | grad_ids, inv = all_ids.unique(return_inverse=True) 122 | 123 | grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device) 124 | grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad) 125 | 126 | # SignSGD with decoupled weight decay 127 | p = weights[grad_ids] 128 | 129 | p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr) 130 | 131 | # Write updated slices back 132 | weights[grad_ids] = p 133 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/run1000.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # The $1000 tier of nanochat 4 | # Designed to run end-to-end for $1000/24 ~= 41.6 hours on an 8XH100 node 5 | # A bit sparser on comments, see speedrun.sh for more detail 6 | 7 | # all the setup stuff 8 | export 
OMP_NUM_THREADS=1 9 | export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat" 10 | mkdir -p $NANOCHAT_BASE_DIR 11 | command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh 12 | [ -d ".venv" ] || uv venv 13 | uv sync --extra gpu 14 | source .venv/bin/activate 15 | if [ -z "$WANDB_RUN" ]; then 16 | WANDB_RUN=dummy 17 | fi 18 | python -m nanochat.report reset 19 | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y 20 | source "$HOME/.cargo/env" 21 | uv run maturin develop --release --manifest-path rustbpe/Cargo.toml 22 | curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl 23 | 24 | # train tokenizer on ~4B characters and kick off download of the rest for pretraining 25 | python -m nanochat.dataset -n 16 26 | # start downloading the rest of the shards for a total of 800 (see below why 800) 27 | python -m nanochat.dataset -n 800 & 28 | # todo: download the rest of it 29 | python -m scripts.tok_train --max_chars=4000000000 30 | python -m scripts.tok_eval 31 | 32 | # Documenting my process for determining the hyperparameters for this run1000.sh script: 33 | # We want a budget of approx. $1000 ~= 41.6 hours of 8XH100 compute 34 | # 1) I guessed the model size for this to be about depth=32 35 | # 2) Determine the device_batch_size that fits: 36 | # Running the base_train.py script with --depth=32, I saw that --device_batch_size=16 37 | # runs out of memory, but --device_batch_size=8 fits. Inspecting `nvidia-smi` during training, 38 | # I saw all GPUs were at about 78/80GB VRAM, so it just barely fits and we have good MFU at ~50%. 39 | # So the training script was running ok and showed: 40 | # Vocab size: 65,536 41 | # num_layers: 32 42 | # model_dim: 2048 43 | # num_heads: 16 44 | # num_kv_heads: 16 45 | # Tokens / micro-batch / rank: 8 x 2048 = 16,384 46 | # Tokens / micro-batch: 131,072 47 | # Total batch size 524,288 => gradient accumulation steps: 4 48 | # Number of parameters: 1,879,048,192 49 | # Estimated FLOPs per token: 1.207960e+10 50 | # Calculated number of iterations from target data:param ratio: 71,680 51 | # Total number of training tokens: 37,580,963,840 52 | # Tokens : Params ratio: 20.00 53 | # Total training FLOPs estimate: 4.539628e+20 54 | # step 00004/71680 (0.01%) | loss: 8.813754 | lrm: 1.00 | dt: 1571.88ms | tok/sec: 83,385 | mfu: 50.92 | total time: 0.00m 55 | # step 00005/71680 (0.01%) | loss: 8.488074 | lrm: 1.00 | dt: 1572.76ms | tok/sec: 83,338 | mfu: 50.89 | total time: 0.00m 56 | # ... 57 | # 3) validate that the runtime fits our budget: 58 | # The training script uses the Chinchilla scaling law to compute-optimally set #tokens = 20 * #params. In particular: 59 | # The script shows that we will be training for 71,680 steps, and each step takes 1.574s so: 60 | # estimated time to train: 71,680 * 1.574s / 60 / 60 = 31.3 hours. 61 | # This is OK, fits our budget, and leaves ~10 hours for midtraining and SFT and evals and maybe RL. 62 | # It's possible that we might even fit depth=33 or depth=34, but for now let's go along with this. 63 | # 4) The last thing to pay attention to is the amount of training data required for the run. 64 | # The script above calculated that "Total number of training tokens: 37,580,963,840" 65 | # The tok_eval.py script reports about ~4.8 chars/token on average for the default tokenizer settings. 66 | # So ~38B tokens # ~4.8 chars/token = ~185B chars. 
67 | # Each data shard is ~250M chars, so we need ~185B / 250M ~= 740 shards. 68 | # For safety, I bumped that up to 800 shards, and that's why up above I used -n 800 when pre-downloading dataset shards. 69 | # If we didn't have enough data, the training script would loop around and do multiple epochs over the same data, 70 | # which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd 71 | # start to overfit hard. 72 | # 5) That's it, everything else (e.g. the learning rates) is adjusted automatically by the training script. 73 | 74 | # Number of processes/GPUs to use 75 | NPROC_PER_NODE=8 76 | 77 | torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --device_batch_size=8 --run=$WANDB_RUN 78 | torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss 79 | torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval 80 | 81 | # midtrain 82 | # NOTE: ensure that we use the same device_batch_size here as the base training script. 83 | torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN 84 | torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid 85 | 86 | # sft 87 | torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN 88 | torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft 89 | 90 | # generate final report 91 | python -m nanochat.report generate 92 | 93 | # talk to it 94 | python -m scripts.chat_web 95 | -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/data/build_maze_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import math 3 | import os 4 | import csv 5 | import json 6 | import numpy as np 7 | 8 | from argdantic import ArgParser 9 | from pydantic import BaseModel 10 | from tqdm import tqdm 11 | from huggingface_hub import hf_hub_download 12 | 13 | from trm.data.common import PuzzleDatasetMetadata, dihedral_transform 14 | 15 | 16 | CHARSET = "# SGo" 17 | 18 | 19 | cli = ArgParser() 20 | 21 | 22 | class DataProcessConfig(BaseModel): 23 | source_repo: str = "sapientinc/maze-30x30-hard-1k" 24 | output_dir: str = "data/maze-30x30-hard-1k" 25 | 26 | subsample_size: Optional[int] = None 27 | aug: bool = False 28 | 29 | 30 | def convert_subset(set_name: str, config: DataProcessConfig): 31 | # Read CSV 32 | all_chars = set() 33 | grid_size = None 34 | inputs = [] 35 | labels = [] 36 | 37 | with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile: # type: ignore 38 | reader = csv.reader(csvfile) 39 | next(reader) # Skip header 40 | for source, q, a, rating in reader: 41 | all_chars.update(q) 42 | all_chars.update(a) 43 | 44 | if grid_size is None: 45 | n = int(len(q) ** 0.5) 46 | grid_size = (n, n) 47 | 48 | inputs.append(np.frombuffer(q.encode(), dtype=np.uint8).reshape(grid_size)) 49 | labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(grid_size)) 50 | 51 | # If subsample_size is specified for the training set, 52 | # randomly sample the desired number of examples. 
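    # Note the order of operations here: subsampling happens first, and only afterwards
    # (when aug=True) is each kept training maze replicated under all 8 dihedral transforms
    # below (identity included), so subsample_size counts source mazes, not augmented copies.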
53 | if set_name == "train" and config.subsample_size is not None: 54 | total_samples = len(inputs) 55 | if config.subsample_size < total_samples: 56 | indices = np.random.choice(total_samples, size=config.subsample_size, replace=False) 57 | inputs = [inputs[i] for i in indices] 58 | labels = [labels[i] for i in indices] 59 | 60 | # Generate dataset 61 | results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]} 62 | puzzle_id = 0 63 | example_id = 0 64 | 65 | results["puzzle_indices"].append(0) 66 | results["group_indices"].append(0) 67 | 68 | for inp, out in zip(tqdm(inputs), labels): 69 | # Dihedral transformations for augmentation 70 | for aug_idx in range(8 if (set_name == "train" and config.aug) else 1): 71 | results["inputs"].append(dihedral_transform(inp, aug_idx)) 72 | results["labels"].append(dihedral_transform(out, aug_idx)) 73 | example_id += 1 74 | puzzle_id += 1 75 | 76 | results["puzzle_indices"].append(example_id) 77 | results["puzzle_identifiers"].append(0) 78 | 79 | # Push group 80 | results["group_indices"].append(puzzle_id) 81 | 82 | # Char mappings 83 | assert len(all_chars - set(CHARSET)) == 0 84 | 85 | char2id = np.zeros(256, np.uint8) 86 | char2id[np.array(list(map(ord, CHARSET)))] = np.arange(len(CHARSET)) + 1 87 | 88 | # To Numpy 89 | def _seq_to_numpy(seq): 90 | arr = np.vstack([char2id[s.reshape(-1)] for s in seq]) 91 | 92 | return arr 93 | 94 | results = { 95 | "inputs": _seq_to_numpy(results["inputs"]), 96 | "labels": _seq_to_numpy(results["labels"]), 97 | 98 | "group_indices": np.array(results["group_indices"], dtype=np.int32), 99 | "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32), 100 | "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32), 101 | } 102 | 103 | # Metadata 104 | metadata = PuzzleDatasetMetadata( 105 | seq_len=int(math.prod(grid_size)), # type: ignore 106 | vocab_size=len(CHARSET) + 1, # PAD + Charset 107 | pad_id=0, 108 | ignore_label_id=0, 109 | blank_identifier_id=0, 110 | num_puzzle_identifiers=1, 111 | total_groups=len(results["group_indices"]) - 1, 112 | mean_puzzle_examples=1, 113 | total_puzzles=len(results["group_indices"]) - 1, 114 | sets=["all"] 115 | ) 116 | 117 | # Save metadata as JSON. 118 | save_dir = os.path.join(config.output_dir, set_name) 119 | os.makedirs(save_dir, exist_ok=True) 120 | 121 | with open(os.path.join(save_dir, "dataset.json"), "w") as f: 122 | json.dump(metadata.model_dump(), f) 123 | 124 | # Save data 125 | for k, v in results.items(): 126 | np.save(os.path.join(save_dir, f"all__{k}.npy"), v) 127 | 128 | # Save IDs mapping (for visualization only) 129 | with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f: 130 | json.dump([""], f) 131 | 132 | 133 | @cli.command(singleton=True) 134 | def preprocess_data(config: DataProcessConfig): 135 | convert_subset("train", config) 136 | convert_subset("test", config) 137 | 138 | 139 | if __name__ == "__main__": 140 | cli() 141 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/tasks/gsm8k.py: -------------------------------------------------------------------------------- 1 | """ 2 | GSM8K evaluation. 3 | https://huggingface.co/datasets/openai/gsm8k 4 | 5 | Example problem instance: 6 | 7 | Question: 8 | Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn? 9 | Answer: 10 | Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute. 
11 | Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10. 12 | #### 10 13 | 14 | Notice that GSM8K uses tool calls inside << >> tags. 15 | """ 16 | 17 | import re 18 | from datasets import load_dataset 19 | from tasks.common import Task 20 | 21 | 22 | GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)") 23 | def extract_answer(completion): 24 | """ 25 | Extract the numerical answer after #### marker. 26 | Follows official code for normalization: 27 | https://github.com/openai/grade-school-math/blob/3101c7d5072418e28b9008a6636bde82a006892c/grade_school_math/dataset.py#L28 28 | """ 29 | match = GSM_RE.search(completion) 30 | if match: 31 | match_str = match.group(1).strip() 32 | match_str = match_str.replace(",", "") 33 | return match_str 34 | return None 35 | 36 | 37 | class GSM8K(Task): 38 | 39 | def __init__(self, subset, split, **kwargs): 40 | super().__init__(**kwargs) 41 | assert subset in ["main", "socratic"], "GSM8K subset must be main|socratic" 42 | assert split in ["train", "test"], "GSM8K split must be train|test" 43 | self.ds = load_dataset("openai/gsm8k", subset, split=split).shuffle(seed=42) 44 | 45 | @property 46 | def eval_type(self): 47 | return 'generative' 48 | 49 | def num_examples(self): 50 | return len(self.ds) 51 | 52 | def get_example(self, index): 53 | """ Get a single problem from the dataset. """ 54 | row = self.ds[index] 55 | question = row['question'] # string of the question prompt 56 | answer = row['answer'] # string of the full solution and the answer after #### marker 57 | # Create and return the Conversation object 58 | # This is tricky because GSM8K uses tool calls, which we need to parse here. 59 | assistant_message_parts = [] 60 | parts = re.split(r'(<<[^>]+>>)', answer) 61 | for part in parts: 62 | if part.startswith('<<') and part.endswith('>>'): 63 | # This is a calculator tool call 64 | inner = part[2:-2] # Remove << >> 65 | # Split on = to get expression and result 66 | if '=' in inner: 67 | expr, result = inner.rsplit('=', 1) 68 | else: 69 | expr, result = inner, "" 70 | # Add the tool call as a part 71 | assistant_message_parts.append({"type": "python", "text": expr}) 72 | # Add the result as a part 73 | assistant_message_parts.append({"type": "python_output", "text": result}) 74 | else: 75 | # Regular text in between tool calls 76 | assistant_message_parts.append({"type": "text", "text": part}) 77 | # Now put it all together 78 | messages = [ 79 | {"role": "user", "content": question}, # note: simple string 80 | {"role": "assistant", "content": assistant_message_parts}, # note: list of parts (as dicts) 81 | ] 82 | conversation = { 83 | "messages": messages, 84 | } 85 | return conversation 86 | 87 | def evaluate(self, conversation, assistant_response): 88 | """ 89 | Given (conversation, completion), return evaluation outcome (0 = wrong, 1 = correct) 90 | Note that: 91 | - the conversation has both user AND assistant message (containing the ground truth answer) 92 | - the assistant_response is usually the alternative assistant message achieved via sampling 93 | 94 | TODO: Technically, assistant_response should be a Message (either a string or a list of parts) 95 | We can handle this later possibly. For now just assume string. 
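
        Informal example of the comparison done below (not a doctest):
          - reference text ending in "... 0.2 x 50 = $10.\n#### 10"  -> extract_answer(...) == "10"
          - sampled response ending in "#### 10"                     -> extract_answer(...) == "10"  -> returns 1
        extract_answer also strips commas, so "#### 1,000" and "#### 1000" both normalize to "1000".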
96 | """ 97 | assert isinstance(assistant_response, str), "Assuming simple string response for now" 98 | # First extract the ground truth answer 99 | assistant_message = conversation['messages'][-1] 100 | assert assistant_message['role'] == "assistant", "Last message must be from the Assistant" 101 | assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts" 102 | last_text_part = assistant_message['content'][-1]['text'] # this contains the final answer in GSM8K 103 | # Extract both the ground truth answer and the predicted answer 104 | ref_num = extract_answer(last_text_part) 105 | pred_num = extract_answer(assistant_response) 106 | # Compare and return the success as int 107 | is_correct = int(pred_num == ref_num) 108 | return is_correct 109 | 110 | def reward(self, conversation, assistant_response): 111 | """ 112 | Used during RL. To keep things simple, just re-use the evaluation above. 113 | Later this could be made more complex (e.g. format matching etc.) 114 | """ 115 | is_correct = self.evaluate(conversation, assistant_response) 116 | is_correct_float = float(is_correct) 117 | return is_correct_float 118 | -------------------------------------------------------------------------------- /rope_imaginary/rope_pp/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Abstra 171 | # Abstra is an AI-powered process automation framework. 172 | # Ignore directories containing user credentials, local state, and settings. 173 | # Learn more at https://abstra.io/docs 174 | .abstra/ 175 | 176 | # Visual Studio Code 177 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 178 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore 179 | # and can be added to the global gitignore or merged into this file. However, if you prefer, 180 | # you could uncomment the following to ignore the enitre vscode folder 181 | # .vscode/ 182 | 183 | # Ruff stuff: 184 | .ruff_cache/ 185 | 186 | # PyPI configuration file 187 | .pypirc 188 | 189 | # Cursor 190 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to 191 | # exclude from AI features like autocomplete and code analysis. 
Recommended for sensitive data 192 | # refer to https://docs.cursor.com/context/ignore-files 193 | .cursorignore 194 | .cursorindexingignore 195 | 196 | checkpoints/ 197 | tensorboards/ 198 | results/ 199 | wandb/ 200 | logs/ 201 | tmp.py 202 | *.pkl 203 | *.txt 204 | *heatmap*.jpg 205 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/nanochat/dataloader.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | import torch 4 | import pyarrow.parquet as pq 5 | 6 | from nanochat.common import get_dist_info 7 | from nanochat.dataset import list_parquet_files 8 | from nanochat.tokenizer import get_tokenizer 9 | 10 | def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None): 11 | """ 12 | Stream pretraining text from parquet files, tokenize, yield training batches. 13 | 14 | This implementation became a bit more complex because we wish to support approximate resume training. 15 | Instead of turning this into a Class, we opt to return the state_dict with every batch, 16 | and then the caller can pass in a state_dict to resume training from a desired point. 17 | Note that this resumption is atm only *approximate* for simplicity. 18 | We won't repeat the same documents but we might skip a few. 19 | The state_dict that is returned can be later passed into this function via `resume_state_dict` to approximately resume. 20 | 21 | Perfect state resumption is possible but would be a lot more bloated, probably not worth it atm. 22 | """ 23 | assert split in ["train", "val"], "split must be 'train' or 'val'" 24 | 25 | # infinite iterator over document batches (list of text strings) 26 | ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() 27 | def document_batches(): 28 | parquet_paths = list_parquet_files() 29 | parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:] 30 | resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0 31 | resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None 32 | first_pass = True 33 | pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0) 34 | while True: # iterate infinitely (multi-epoch) 35 | pq_idx = resume_pq_idx if first_pass else 0 36 | while pq_idx < len(parquet_paths): # iterate over all parquet files 37 | filepath = parquet_paths[pq_idx] 38 | pf = pq.ParquetFile(filepath) 39 | # Start from resume point if resuming on same file, otherwise from DDP rank 40 | # I know this state resumption is a little bit tricky and a little bit hacky... sigh. 41 | if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx): 42 | base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size 43 | base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming 44 | rg_idx = base_idx * ddp_world_size + ddp_rank 45 | if rg_idx >= pf.num_row_groups: 46 | pq_idx += 1 47 | continue 48 | resume_rg_idx = None # set to None as we only want to do this a single time 49 | else: 50 | rg_idx = ddp_rank 51 | while rg_idx < pf.num_row_groups: 52 | rg = pf.read_row_group(rg_idx) 53 | batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows 54 | # the tokenizer encode might want to go in even smaller batches, e.g. 
128 rows 55 | for i in range(0, len(batch), tokenizer_batch_size): 56 | yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx) 57 | rg_idx += ddp_world_size # advance to the next row group (in DDP) 58 | pq_idx += 1 # advance to the next parquet file 59 | first_pass = False 60 | batches = document_batches() 61 | 62 | # Now emit batches of tokens. 63 | needed_tokens = B * T + 1 # +1 is because we also need the target at the last token 64 | # get the tokenizer and the bos token 65 | tokenizer = get_tokenizer() 66 | bos_token = tokenizer.get_bos_token_id() 67 | # scratch buffer holds the tokens for one iteration 68 | token_buffer = deque() # we stream tokens on the right and pop from the left 69 | while True: 70 | # Accumulate enough tokens for one iteration before yielding. 71 | while len(token_buffer) < needed_tokens: 72 | doc_batch, (pq_idx, rg_idx) = next(batches) 73 | token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads) 74 | for tokens in token_lists: 75 | token_buffer.extend(tokens) 76 | # Move tokens from the deque into the scratch buffer 77 | tokens = [token_buffer.popleft() for _ in range(needed_tokens)] 78 | # CUDA supports memory pinning for asynchronous transfers between CPU and GPU 79 | use_cuda_optimizations = device == "cuda" 80 | scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64 81 | # Create the inputs/targets as 1D tensors 82 | inputs_cpu = scratch[:-1] 83 | targets_cpu = scratch[1:] 84 | # Reshape to 2D and move to GPU async 85 | inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations) 86 | targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations) 87 | state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training 88 | yield inputs, targets, state_dict 89 | 90 | def tokenizing_distributed_data_loader(*args, **kwargs): 91 | # helper function that only emits the inputs/targets and not the state_dict 92 | for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs): 93 | yield inputs, targets 94 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/nanochat/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | The base/pretraining dataset is a set of parquet files. 3 | This file contains utilities for: 4 | - iterating over the parquet files and yielding documents from it 5 | - download the files on demand if they are not on disk 6 | 7 | For details of how the dataset was prepared, see `repackage_data_reference.py`. 
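
Example CLI usage (the argument parser lives at the bottom of this file):
    python -m nanochat.dataset -n 8          # download the first 8 shards
    python -m nanochat.dataset -n 240 -w 8   # download 240 shards with 8 parallel workers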
8 | """ 9 | 10 | import os 11 | import argparse 12 | import time 13 | import requests 14 | import pyarrow.parquet as pq 15 | from multiprocessing import Pool 16 | 17 | from nanochat.common import get_base_dir 18 | 19 | # ----------------------------------------------------------------------------- 20 | # The specifics of the current pretraining dataset 21 | 22 | # The URL on the internet where the data is hosted and downloaded from on demand 23 | BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main" 24 | MAX_SHARD = 1822 # the last datashard is shard_01822.parquet 25 | index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames 26 | base_dir = get_base_dir() 27 | DATA_DIR = os.path.join(base_dir, "base_data") 28 | os.makedirs(DATA_DIR, exist_ok=True) 29 | 30 | # ----------------------------------------------------------------------------- 31 | # These functions are useful utilities to other modules, can/should be imported 32 | 33 | def list_parquet_files(data_dir=None): 34 | """ Looks into a data dir and returns full paths to all parquet files. """ 35 | data_dir = DATA_DIR if data_dir is None else data_dir 36 | parquet_files = sorted([ 37 | f for f in os.listdir(data_dir) 38 | if f.endswith('.parquet') and not f.endswith('.tmp') 39 | ]) 40 | parquet_paths = [os.path.join(data_dir, f) for f in parquet_files] 41 | return parquet_paths 42 | 43 | def parquets_iter_batched(split, start=0, step=1): 44 | """ 45 | Iterate through the dataset, in batches of underlying row_groups for efficiency. 46 | - split can be "train" or "val". the last parquet file will be val. 47 | - start/step are useful for skipping rows in DDP. e.g. start=rank, step=world_size 48 | """ 49 | assert split in ["train", "val"], "split must be 'train' or 'val'" 50 | parquet_paths = list_parquet_files() 51 | parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:] 52 | for filepath in parquet_paths: 53 | pf = pq.ParquetFile(filepath) 54 | for rg_idx in range(start, pf.num_row_groups, step): 55 | rg = pf.read_row_group(rg_idx) 56 | texts = rg.column('text').to_pylist() 57 | yield texts 58 | 59 | # ----------------------------------------------------------------------------- 60 | def download_single_file(index): 61 | """ Downloads a single file index, with some backoff """ 62 | 63 | # Construct the local filepath for this file and skip if it already exists 64 | filename = index_to_filename(index) 65 | filepath = os.path.join(DATA_DIR, filename) 66 | if os.path.exists(filepath): 67 | print(f"Skipping {filepath} (already exists)") 68 | return True 69 | 70 | # Construct the remote URL for this file 71 | url = f"{BASE_URL}/{filename}" 72 | print(f"Downloading {filename}...") 73 | 74 | # Download with retries 75 | max_attempts = 5 76 | for attempt in range(1, max_attempts + 1): 77 | try: 78 | response = requests.get(url, stream=True, timeout=30) 79 | response.raise_for_status() 80 | # Write to temporary file first 81 | temp_path = filepath + f".tmp" 82 | with open(temp_path, 'wb') as f: 83 | for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks 84 | if chunk: 85 | f.write(chunk) 86 | # Move temp file to final location 87 | os.rename(temp_path, filepath) 88 | print(f"Successfully downloaded {filename}") 89 | return True 90 | 91 | except (requests.RequestException, IOError) as e: 92 | print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}") 93 | # Clean up any partial files 94 | for path in [filepath + f".tmp", 
filepath]: 95 | if os.path.exists(path): 96 | try: 97 | os.remove(path) 98 | except: 99 | pass 100 | # Try a few times with exponential backoff: 2^attempt seconds 101 | if attempt < max_attempts: 102 | wait_time = 2 ** attempt 103 | print(f"Waiting {wait_time} seconds before retry...") 104 | time.sleep(wait_time) 105 | else: 106 | print(f"Failed to download {filename} after {max_attempts} attempts") 107 | return False 108 | 109 | return False 110 | 111 | 112 | if __name__ == "__main__": 113 | parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards") 114 | parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable") 115 | parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)") 116 | args = parser.parse_args() 117 | 118 | num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1) 119 | ids_to_download = list(range(num)) 120 | print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...") 121 | print(f"Target directory: {DATA_DIR}") 122 | print() 123 | with Pool(processes=args.num_workers) as pool: 124 | results = pool.map(download_single_file, ids_to_download) 125 | 126 | # Report results 127 | successful = sum(1 for success in results if success) 128 | print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}") 129 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/speedrun.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script is the "Best ChatGPT clone that $100 can buy", 4 | # It is designed to run in ~4 hours on 8XH100 node at $3/GPU/hour. 5 | 6 | # 1) Example launch (simplest): 7 | # bash speedrun.sh 8 | # 2) Example launch in a screen session (because the run takes ~4 hours): 9 | # screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh 10 | # 3) Example launch with wandb logging, but see below for setting up wandb first: 11 | # WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh 12 | 13 | # Default intermediate artifacts directory is in ~/.cache/nanochat 14 | export OMP_NUM_THREADS=1 15 | export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat" 16 | mkdir -p $NANOCHAT_BASE_DIR 17 | 18 | # ----------------------------------------------------------------------------- 19 | # Python venv setup with uv 20 | 21 | # install uv (if not already installed) 22 | command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh 23 | # create a .venv local virtual environment (if it doesn't exist) 24 | [ -d ".venv" ] || uv venv 25 | # install the repo dependencies 26 | uv sync --extra gpu 27 | # activate venv so that `python` uses the project's venv instead of system python 28 | source .venv/bin/activate 29 | 30 | # ----------------------------------------------------------------------------- 31 | # wandb setup 32 | # If you wish to use wandb for logging (it's nice!, recommended). 33 | # 1) Make sure to first log in to wandb, e.g. 
run: 34 | # `wandb login` 35 | # 2) Set the WANDB_RUN environment variable when running this script, e.g.: 36 | # `WANDB_RUN=d26 bash speedrun.sh` 37 | if [ -z "$WANDB_RUN" ]; then 38 | # by default use "dummy" : it's handled as a special case, skips logging to wandb 39 | WANDB_RUN=dummy 40 | fi 41 | 42 | # ----------------------------------------------------------------------------- 43 | # During the course of the run, we will be writing markdown reports to the report/ 44 | # directory in the base dir. This command clears it out and writes a header section 45 | # with a bunch of system info and a timestamp that marks the start of the run. 46 | python -m nanochat.report reset 47 | 48 | # ----------------------------------------------------------------------------- 49 | # Tokenizer 50 | 51 | # Install Rust / Cargo 52 | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y 53 | source "$HOME/.cargo/env" 54 | 55 | # Build the rustbpe Tokenizer 56 | uv run maturin develop --release --manifest-path rustbpe/Cargo.toml 57 | 58 | # Download the first ~2B characters of pretraining dataset 59 | # look at dev/repackage_data_reference.py for details on how this data was prepared 60 | # each data shard is ~250M chars 61 | # so we download 2e9 / 250e6 = 8 data shards at this point 62 | # each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk 63 | python -m nanochat.dataset -n 8 64 | # Immediately also kick off downloading more shards in the background while tokenizer trains 65 | # See comment below for why 240 is the right number here 66 | python -m nanochat.dataset -n 240 & 67 | DATASET_DOWNLOAD_PID=$! 68 | # train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data 69 | python -m scripts.tok_train --max_chars=2000000000 70 | # evaluate the tokenizer (report compression ratio etc.) 71 | python -m scripts.tok_eval 72 | 73 | # ----------------------------------------------------------------------------- 74 | # Base model (pretraining) 75 | 76 | # The d20 model is 561M parameters. 77 | # Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens. 78 | # Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars. 79 | # At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining. 80 | # Round up to 240 for safety. At ~100MB/shard, this downloads ~24GB of data to disk. 81 | # (The total number of shards available in the entire dataset is 1822.) 82 | echo "Waiting for dataset download to complete..." 
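# (while we wait: the shard math above as a one-liner, for reference:
#  python -c "print(561e6*20/1e9, 561e6*20*4.8/1e9, 561e6*20*4.8/250e6)"  -> ~11.2, ~54, ~215)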
83 | wait $DATASET_DOWNLOAD_PID 84 | 85 | # Number of processes/GPUs to use 86 | NPROC_PER_NODE=8 87 | 88 | # pretrain the d20 model 89 | torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --run=$WANDB_RUN 90 | # evaluate the model on a larger chunk of train/val data and draw some samples 91 | torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss 92 | # evaluate the model on CORE tasks 93 | torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval 94 | 95 | # ----------------------------------------------------------------------------- 96 | # Midtraining (teach the model conversation special tokens, tool use, multiple choice) 97 | 98 | # download 2.3MB of synthetic identity conversations to impart a personality to nanochat 99 | # see dev/gen_synthetic_data.py for details on how this data was prepared and to get a sense of how you can easily tune it 100 | curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl 101 | 102 | # run midtraining and eval the model 103 | torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --run=$WANDB_RUN 104 | torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid 105 | 106 | # ----------------------------------------------------------------------------- 107 | # Supervised Finetuning (domain adaptation to each sequence all by itself per row) 108 | 109 | # train sft and re-eval right away (should see a small bump) 110 | torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN 111 | torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft 112 | 113 | # ----------------------------------------------------------------------------- 114 | # Generate the full report by putting together all the sections 115 | # report.md is the output and will be copied to current directory for convenience 116 | python -m nanochat.report generate 117 | -------------------------------------------------------------------------------- /tiny_recursive_models/docs/assets/npyjs.js: -------------------------------------------------------------------------------- 1 | class npyjs { 2 | 3 | constructor(opts) { 4 | if (opts && !('convertFloat16' in opts)) { 5 | console.warn([ 6 | "npyjs constructor now accepts {convertFloat16?: boolean}.", 7 | "For usage, go to https://github.com/jhuapl-boss/npyjs." 8 | ].join(" ")); 9 | } 10 | 11 | this.convertFloat16 = opts?.convertFloat16 ?? true; 12 | 13 | this.dtypes = { 14 | "> 15) & 0x1; 92 | const exponent = (float16 >> 10) & 0x1f; 93 | const fraction = float16 & 0x3ff; 94 | 95 | // Handle special cases 96 | if (exponent === 0) { 97 | if (fraction === 0) { 98 | // Zero 99 | return sign ? -0 : 0; 100 | } 101 | // Denormalized number 102 | return (sign ? -1 : 1) * Math.pow(2, -14) * (fraction / 0x400); 103 | } else if (exponent === 0x1f) { 104 | if (fraction === 0) { 105 | // Infinity 106 | return sign ? -Infinity : Infinity; 107 | } 108 | // NaN 109 | return NaN; 110 | } 111 | 112 | // Normalized number 113 | return (sign ? 
-1 : 1) * Math.pow(2, exponent - 15) * (1 + fraction / 0x400); 114 | } 115 | 116 | parse(arrayBufferContents) { 117 | // const version = arrayBufferContents.slice(6, 8); // Uint8-encoded 118 | const headerLength = new DataView(arrayBufferContents.slice(8, 10)).getUint8(0); 119 | const offsetBytes = 10 + headerLength; 120 | 121 | const hcontents = new TextDecoder("utf-8").decode( 122 | new Uint8Array(arrayBufferContents.slice(10, 10 + headerLength)) 123 | ); 124 | const header = JSON.parse( 125 | hcontents 126 | .toLowerCase() // True -> true 127 | .replace(/'/g, '"') 128 | .replace("(", "[") 129 | .replace(/,*\),*/g, "]") 130 | ); 131 | const shape = header.shape; 132 | const dtype = this.dtypes[header.descr]; 133 | 134 | if (!dtype) { 135 | console.error(`Unsupported dtype: ${header.descr}`); 136 | return null; 137 | } 138 | 139 | const nums = new dtype.arrayConstructor( 140 | arrayBufferContents, 141 | offsetBytes 142 | ); 143 | 144 | // Convert float16 to float32 if converter exists 145 | const data = dtype.converter ? dtype.converter.call(this, nums) : nums; 146 | 147 | return { 148 | dtype: dtype.name, 149 | data: data, 150 | shape, 151 | fortranOrder: header.fortran_order 152 | }; 153 | } 154 | 155 | async load(filename, callback, fetchArgs) { 156 | /* 157 | Loads an array from a stream of bytes. 158 | */ 159 | fetchArgs = fetchArgs || {}; 160 | let arrayBuf; 161 | // If filename is ArrayBuffer 162 | if (filename instanceof ArrayBuffer) { 163 | arrayBuf = filename; 164 | } 165 | // If filename is a file path 166 | else { 167 | const resp = await fetch(filename, { ...fetchArgs }); 168 | arrayBuf = await resp.arrayBuffer(); 169 | } 170 | const result = this.parse(arrayBuf); 171 | if (callback) { 172 | return callback(result); 173 | } 174 | return result; 175 | } 176 | } -------------------------------------------------------------------------------- /rope_imaginary/rope_pp/single-gpu/train_rope_pp_single_gpu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import random 5 | import logging 6 | from datetime import datetime 7 | 8 | import torch 9 | import numpy as np 10 | 11 | import datasets 12 | import wandb 13 | 14 | from transformers import AutoTokenizer 15 | 16 | from llama_variants.configuration_llama import LlamaConfig 17 | from llama_variants.modeling_llama_rope_pp import LlamaForCausalLM 18 | 19 | from utils.dataset_utils import StreamingTrainingJsonlZSD, StreamingTrainingHuggingFace, EvaluatingDataset 20 | from utils.training_engine import train_with_accelerate 21 | 22 | root = os.getcwd() 23 | tokenizer_path = 'meta-llama/Meta-Llama-3-8B' 24 | 25 | cache_dir = '' # set a cache_dir 26 | 27 | train_dataset_hf_id = 'mlfoundations/dclm-baseline-1.0' # Hugging Face dataset ID 28 | train_dataset_label = 'text' 29 | 30 | valid_dataset_hf_id = 'wikitext' # Hugging Face dataset ID 31 | valid_dataset_name = 'wikitext-2-raw-v1' # Subset name 32 | valid_dataset_split = 'validation' 33 | valid_dataset_abbr = 'wikitext' 34 | valid_dataset_label = 'text' 35 | 36 | seed = 42 37 | torch.manual_seed(seed) 38 | np.random.seed(seed) 39 | random.seed(seed) 40 | if torch.cuda.is_available(): 41 | torch.backends.cudnn.deterministic = True 42 | torch.cuda.manual_seed(seed) 43 | torch.cuda.manual_seed_all(seed) 44 | 45 | torch.set_default_dtype(torch.bfloat16) 46 | 47 | import argparse 48 | 49 | parser = argparse.ArgumentParser(description='define fp config') 50 | parser.add_argument('--imag', action='store_true', 
default=False) 51 | parser.add_argument('--imag_mode', choices=['imag1', 'imag2', ], default='imag1') 52 | 53 | # imag1 stands for rope_pp_eh, and imag2 stands for rope_pp_ec, 54 | 55 | parser.add_argument('--config_abbr', type=str, default='376m') 56 | parser.add_argument('--save_abbr', type=str, default='376m') 57 | 58 | args = parser.parse_args() 59 | 60 | rope_config = { 61 | 'imag': args.imag, 62 | 'imag_mode': args.imag_mode, 63 | } 64 | 65 | config_abbr = args.config_abbr 66 | config_path = f'{root}/configs/rope-{config_abbr}-config.json' 67 | 68 | save_abbr = args.save_abbr 69 | 70 | # Modified to run on single 40GB A100 71 | batch_size = 3 72 | gradient_accumulation_steps = 64 73 | 74 | max_length = 4096 75 | valid_size = 4096 76 | 77 | max_steps = 100000 78 | eval_steps = 500 79 | warmup_steps = 10 80 | 81 | save_steps = 10000 82 | steps_to_save = [100, max_steps] 83 | 84 | config = LlamaConfig.from_pretrained(config_path) 85 | config.gradient_checkpointing = True # CRITICAL for memory 86 | config.use_cache = False # Required for gradient checkpointing 87 | config._attn_implementation = "flash_attention_2" 88 | config.torch_dtype = torch.bfloat16 89 | config.rope_config = rope_config 90 | config.ignore_index = config.eos_token_id 91 | 92 | model = LlamaForCausalLM(config=config) 93 | 94 | # Training configuration 95 | training_config = { 96 | 'output_dir': f'{root}/checkpoints/{save_abbr}', 97 | 'max_steps': max_steps, 98 | 'batch_size': batch_size, 99 | 'gradient_accumulation_steps': gradient_accumulation_steps, 100 | 'learning_rate': 5e-4, 101 | 'weight_decay': 0.1, 102 | 'adam_beta1': 0.95, 103 | 'adam_beta2': 0.99, 104 | 'warmup_steps': warmup_steps, 105 | 'max_grad_norm': 1.0, 106 | 'eval_steps': eval_steps, 107 | 'save_steps': save_steps, 108 | 'steps_to_save': steps_to_save, 109 | 'max_length': max_length, 110 | 'valid_dataset_abbr': valid_dataset_abbr, 111 | 'logging_steps': 1, 112 | 'resume_from_checkpoint': None, 113 | } 114 | 115 | print(f'{config = }', '\n') 116 | print('training_config = ', json.dumps(training_config, indent=2), '\n') 117 | 118 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=False) 119 | tokenizer.pad_token = tokenizer.eos_token 120 | tokenizer.pad_token_id = tokenizer.eos_token_id 121 | 122 | # Load validation dataset from Hugging Face Hub 123 | print(f'Loading validation dataset from Hugging Face: {valid_dataset_hf_id}/{valid_dataset_name}', '\n') 124 | 125 | valid_dataset = datasets.load_dataset(valid_dataset_hf_id, valid_dataset_name, split=valid_dataset_split, 126 | cache_dir=cache_dir) 127 | # wikitext has a lot of empty lines -> causes NaNs 128 | valid_dataset = valid_dataset.filter(lambda x: len(x[valid_dataset_label].strip()) > 50) 129 | valid_dataset = valid_dataset.select(range(min(valid_size, len(valid_dataset)))) 130 | 131 | print(valid_dataset, '\n') 132 | 133 | # Load training dataset from Hugging Face Hub 134 | print(f'Loading training dataset from Hugging Face: {train_dataset_hf_id}', '\n') 135 | 136 | train_dataset = StreamingTrainingHuggingFace( 137 | dataset_id=train_dataset_hf_id, 138 | tokenizer=tokenizer, 139 | label_name=train_dataset_label, 140 | train_length=max_length, 141 | num_data=max_steps * batch_size * gradient_accumulation_steps, 142 | seed=seed, 143 | split='train', 144 | streaming=True, 145 | cache_dir=cache_dir 146 | ) 147 | 148 | valid_dataset = EvaluatingDataset(dataset=valid_dataset, tokenizer=tokenizer, 149 | label_name=valid_dataset_label, valid_length=max_length) 150 | 151 | 
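# Rough accounting for the settings above (informal note, not used anywhere in the code):
#   sequences per optimizer step = batch_size * gradient_accumulation_steps = 3 * 64 = 192
#   tokens per optimizer step    = 192 * max_length = 192 * 4096 = 786,432
#   total tokens over the run    ~= max_steps * 786,432 ~= 78.6B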
print('dataset is ready !', '\n') 152 | 153 | # Initialize WandB 154 | os.environ["WANDB_MODE"] = "offline" 155 | os.environ["WANDB_PROJECT"] = "rope_pp" 156 | os.environ["WANDB_DIR"] = f"{root}/wandb" 157 | 158 | wandb.init( 159 | project="rope_pp", 160 | name=f'{save_abbr}-single-gpu-{datetime.now().strftime("%Y%m%d-%H%M%S")}', 161 | config=training_config, 162 | dir=f'{root}/wandb', 163 | ) 164 | 165 | print('checkpoints and model will be saved in', training_config['output_dir'], '\n') 166 | 167 | # Train! 168 | train_with_accelerate( 169 | model=model, 170 | tokenizer=tokenizer, 171 | train_dataset=train_dataset, 172 | eval_dataset=valid_dataset, 173 | config=training_config, 174 | deepspeed_config=None, 175 | ) 176 | 177 | wandb.finish() 178 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/tasks/common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base class for all Tasks. 3 | A Task is basically a dataset of conversations, together with some 4 | metadata and often also evaluation criteria. 5 | Example tasks: MMLU, ARC-Easy, ARC-Challenge, GSM8K, HumanEval, SmolTalk. 6 | """ 7 | 8 | import random 9 | 10 | class Task: 11 | """ 12 | Base class of a Task. Allows for lightweight slicing of the underlying dataset. 13 | """ 14 | 15 | def __init__(self, start=0, stop=None, step=1): 16 | # allows a lightweight logical view over a dataset 17 | assert start >= 0, f"Start must be non-negative, got {start}" 18 | assert stop is None or stop >= start, f"Stop should be greater than or equal to start, got {stop} and {start}" 19 | assert step >= 1, f"Step must be strictly positive, got {step}" 20 | self.start = start 21 | self.stop = stop # could be None here 22 | self.step = step 23 | 24 | @property 25 | def eval_type(self): 26 | # one of 'generative' | 'categorical' 27 | raise NotImplementedError 28 | 29 | def num_examples(self): 30 | raise NotImplementedError 31 | 32 | def get_example(self, index): 33 | raise NotImplementedError 34 | 35 | def __len__(self): 36 | start = self.start 37 | stop = self.num_examples() if self.stop is None else self.stop 38 | step = self.step 39 | span = stop - start 40 | num = (span + step - 1) // step # ceil_div(span, step) 41 | assert num >= 0, f"Negative number of examples???: {num}" # prevent footguns 42 | return num 43 | 44 | def __getitem__(self, index: int): 45 | assert isinstance(index, int), f"Index must be an integer, got {type(index)}" 46 | physical_index = self.start + index * self.step 47 | conversation = self.get_example(physical_index) 48 | return conversation 49 | 50 | def evaluate(self, problem, completion): 51 | raise NotImplementedError 52 | 53 | 54 | class TaskMixture(Task): 55 | """ 56 | For SFT Training it becomes useful to train on a mixture of datasets. 57 | Fun trick: if you wish to oversample any task, just pass it in multiple times in the list. 
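    For example (informally), TaskMixture([gsm8k, gsm8k, smoltalk]) serves GSM8K conversations
    twice as often as listing it once, because every example of every listed task gets its own
    slot in the shuffled index_map built in __init__ below.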
58 | """ 59 | 60 | def __init__(self, tasks, **kwargs): 61 | super().__init__(**kwargs) 62 | # tasks is a list of Task objects 63 | self.tasks = tasks 64 | self.lengths = [len(task) for task in self.tasks] 65 | self.num_conversations = sum(self.lengths) 66 | # Build list of all (task_idx, local_idx) pairs 67 | self.index_map = [] 68 | for task_idx, task_length in enumerate(self.lengths): 69 | for local_idx in range(task_length): 70 | self.index_map.append((task_idx, local_idx)) 71 | # Deterministically shuffle to mix tasks throughout training 72 | rng = random.Random(42) 73 | rng.shuffle(self.index_map) 74 | # Note: this is not the most elegant or best solution, but it's ok for now 75 | 76 | def num_examples(self): 77 | return self.num_conversations 78 | 79 | def get_example(self, index): 80 | """ 81 | Access conversations according to a deterministic shuffle of all examples. 82 | This ensures tasks are mixed throughout training, regardless of dataset size. 83 | """ 84 | assert 0 <= index < self.num_conversations, f"Index {index} out of range for mixture with {self.num_conversations} conversations" 85 | task_idx, local_idx = self.index_map[index] 86 | return self.tasks[task_idx][local_idx] 87 | 88 | 89 | class TaskSequence(Task): 90 | """ 91 | For SFT Training sometimes we want to sequentially train on a list of tasks. 92 | This is useful for cases that require a training curriculum. 93 | """ 94 | 95 | def __init__(self, tasks, **kwargs): 96 | super().__init__(**kwargs) 97 | self.tasks = tasks 98 | self.lengths = [len(task) for task in self.tasks] 99 | self.num_conversations = sum(self.lengths) 100 | 101 | def num_examples(self): 102 | return self.num_conversations 103 | 104 | def get_example(self, index): 105 | assert 0 <= index < self.num_conversations, f"Index {index} out of range for sequence with {self.num_conversations} conversations" 106 | for task_idx, task_length in enumerate(self.lengths): 107 | if index < task_length: 108 | return self.tasks[task_idx][index] 109 | index -= task_length 110 | 111 | 112 | def render_mc(question, letters, choices): 113 | """ 114 | The common multiple choice rendering format we will use. 115 | 116 | Note two important design decisions: 117 | 1) 118 | Bigger models don't care as much, but smaller models prefer to have 119 | the letter *after* the choice, which results in better binding. 120 | 2) 121 | There is no whitespace between the delimiter (=) and the letter. 122 | This is actually critical because the tokenizer has different token ids 123 | for " A" vs. "A". The assistant responses will be just the letter itself, 124 | i.e. "A", so it is important that here in the prompt it is the exact same 125 | token, i.e. "A" with no whitespace before it. Again, bigger models don't care 126 | about this too much, but smaller models do care about some of these details. 127 | """ 128 | query = f"Multiple Choice question: {question}\n" 129 | query += "".join([f"- {choice}={letter}\n" for letter, choice in zip(letters, choices)]) 130 | query += "\nRespond only with the letter of the correct answer." 
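    # For example, render_mc("2+2=?", ["A", "B"], ["3", "4"]) renders as:
    #   Multiple Choice question: 2+2=?
    #   - 3=A
    #   - 4=B
    #   (blank line)
    #   Respond only with the letter of the correct answer.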
131 | return query 132 | 133 | 134 | if __name__ == "__main__": 135 | # very lightweight test of slicing 136 | from tasks.mmlu import MMLU 137 | 138 | ds = MMLU(subset="auxiliary_train", split="train") 139 | print("Length of MMLU: ", len(ds)) 140 | ex = ds[5] 141 | print("5th example: ", ex) 142 | 143 | ds = MMLU(subset="auxiliary_train", split="train", start=5, stop=10) 144 | print("Length of sliced MMLU[5:10]: ", len(ds)) 145 | print("0th example of sliced MMLU: ", ds[0]) 146 | 147 | print("They match: ", ex == ds[0]) 148 | -------------------------------------------------------------------------------- /rope_imaginary/rope_pp/utils/callback_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import random 5 | from datetime import datetime 6 | 7 | import torch 8 | 9 | from transformers import Trainer, TrainingArguments 10 | from transformers import TrainerCallback, TrainerState, TrainerControl 11 | 12 | 13 | class CustomLoggingCallback(TrainerCallback): 14 | 15 | def __init__(self, max_steps, batch_size, max_length, world_size, valid_dataset_abbr, logging_steps=10): 16 | self.cur_step, self.max_steps = 0, max_steps 17 | self.start_time, self.valid_total_time = datetime.now(), 0 18 | self.last_step_time = None 19 | self.batch_token = batch_size * max_length 20 | self.world_size = world_size 21 | self.valid_dataset_abbr = valid_dataset_abbr 22 | self.logging_steps = logging_steps 23 | 24 | # Handle both distributed and single GPU training 25 | if torch.distributed.is_initialized(): 26 | self.rank = torch.distributed.get_rank() 27 | else: 28 | self.rank = 0 29 | 30 | def _format_time(self, seconds): 31 | """Format seconds into human-readable time (e.g., 1.23m, 45.67s, 2.34h)""" 32 | if seconds < 60: 33 | return f"{seconds:.2f}s" 34 | elif seconds < 3600: 35 | return f"{seconds / 60:.2f}m" 36 | else: 37 | return f"{seconds / 3600:.2f}h" 38 | 39 | def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs=None, **kwargs): 40 | now = datetime.now() 41 | 42 | if logs is not None and 'grad_norm' in logs: 43 | # Training step logging 44 | if self.cur_step == 0: 45 | self.start_time = now 46 | self.last_step_time = now 47 | 48 | else: 49 | # Calculate time metrics 50 | total_time = (now - self.start_time).total_seconds() 51 | train_total_time = total_time - self.valid_total_time 52 | step_time_ms = (now - self.last_step_time).total_seconds() * 1000 if self.last_step_time else 0 53 | 54 | # Calculate token metrics 55 | num_consume_token = self.cur_step * self.batch_token 56 | tokens_per_sec = (self.batch_token / (step_time_ms / 1000)) if step_time_ms > 0 else 0 57 | avg_tgs = num_consume_token / train_total_time / self.world_size if train_total_time > 0 else 0 58 | 59 | # Calculate progress 60 | cur_percent = self.cur_step / self.max_steps * 100 61 | 62 | # Get loss and grad norm from logs 63 | loss = logs.get('loss', 0.0) 64 | grad_norm = logs.get('grad_norm', 0.0) 65 | learning_rate = logs.get('learning_rate', 0.0) 66 | 67 | # Calculate learning rate multiplier (relative to base lr) 68 | lr_multiplier = learning_rate / args.learning_rate if args.learning_rate > 0 else 1.0 69 | 70 | # Format step string with leading zeros 71 | step_str = f"{self.cur_step:0{len(str(self.max_steps))}d}" 72 | 73 | if self.rank == 0 and self.cur_step % self.logging_steps == 0: 74 | # Nanochat-style log format 75 | log_str = ( 76 | f"step {step_str}/{self.max_steps} ({cur_percent:.2f}%) | " 77 | 
f"loss: {loss:.6f} | " 78 | f"grad norm: {grad_norm:.4f} | " 79 | f"lr: {learning_rate:.2e} | " 80 | f"lrm: {lr_multiplier:.2f} | " 81 | f"dt: {step_time_ms:.2f}ms | " 82 | f"tok/sec: {tokens_per_sec:,.0f} | " 83 | f"avg tok/sec: {avg_tgs:,.0f} | " 84 | f"total time: {self._format_time(total_time)}" 85 | ) 86 | print(log_str) 87 | 88 | self.last_step_time = now 89 | self.cur_step += 1 90 | 91 | else: 92 | # Evaluation logging 93 | total_time = (now - self.start_time).total_seconds() 94 | cur_percent = self.cur_step / self.max_steps * 100 95 | 96 | if self.cur_step == 0: 97 | self.start_time = now 98 | total_time = 0 99 | 100 | if self.rank == 0: 101 | step_str = f"{self.cur_step:0{len(str(self.max_steps))}d}" 102 | eval_str = ( 103 | f"step {step_str}/{self.max_steps} ({cur_percent:.2f}%) | " 104 | f"EVALUATION | " 105 | f"total time: {self._format_time(total_time)}" 106 | ) 107 | print(eval_str, end='') 108 | 109 | def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics=None, **kwargs): 110 | if f'eval_{self.valid_dataset_abbr}_runtime' in metrics: 111 | self.valid_total_time += metrics[f'eval_{self.valid_dataset_abbr}_runtime'] 112 | 113 | # Print evaluation metrics 114 | if self.rank == 0 and metrics: 115 | eval_loss = metrics.get(f'eval_{self.valid_dataset_abbr}_loss', None) 116 | if eval_loss is not None: 117 | print(f" | eval loss: {eval_loss:.6f}") 118 | else: 119 | print() 120 | 121 | 122 | class CheckpointingCallback(TrainerCallback): 123 | def __init__(self, steps_to_save): 124 | self.steps_to_save = steps_to_save 125 | 126 | # Handle both distributed and single GPU training 127 | if torch.distributed.is_initialized(): 128 | self.rank = torch.distributed.get_rank() 129 | else: 130 | self.rank = 0 131 | 132 | def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 133 | if state.global_step in self.steps_to_save: 134 | control.should_save = True 135 | control.should_evaluate = True 136 | if self.rank == 0: 137 | print(f"Saving checkpoint at step {state.global_step}") 138 | return control 139 | -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/training/optimizers.py: -------------------------------------------------------------------------------- 1 | """Custom optimizers vendored into the project.""" 2 | 3 | # NOTE: Torch needs to be imported before the custom 4 | # extensions. Otherwise libc10.so cannot be found. 
5 | import torch 6 | import os 7 | from typing import List, Tuple, Union 8 | from torch import Tensor 9 | from torch.optim.optimizer import Optimizer, ParamsT 10 | 11 | # Compile the CUDA extension on-the-fly using torch.utils.cpp_extension.load 12 | # This ensures the backend is built properly with the correct CUDA architecture 13 | _adam_atan2_backend = None 14 | 15 | def _get_adam_backend(): 16 | global _adam_atan2_backend 17 | if _adam_atan2_backend is None: 18 | from torch.utils.cpp_extension import load 19 | 20 | # Get the directory containing this file 21 | current_dir = os.path.dirname(os.path.abspath(__file__)) 22 | csrc_dir = os.path.join(current_dir, "adam_atan2_csrc") 23 | 24 | # Compile the CUDA extension 25 | _adam_atan2_backend = load( 26 | name="adam_atan2_backend", 27 | sources=[ 28 | os.path.join(csrc_dir, "ops.cu"), 29 | os.path.join(csrc_dir, "adam_atan2.cu"), 30 | ], 31 | extra_include_paths=[csrc_dir], 32 | extra_cflags=["-O2", "-std=c++17"], 33 | extra_cuda_cflags=[ 34 | "-O2", 35 | "-std=c++17", 36 | "--expt-extended-lambda", 37 | ], 38 | verbose=False, 39 | ) 40 | return _adam_atan2_backend 41 | 42 | 43 | class AdamATan2(Optimizer): 44 | def __init__( 45 | self, 46 | params: ParamsT, 47 | lr: Union[float, Tensor] = 1e-3, 48 | betas: Tuple[float, float] = (0.9, 0.999), 49 | weight_decay: float = 1e-2 50 | ): 51 | if not 0.0 <= lr: 52 | raise ValueError(f"Invalid learning rate: {lr}") 53 | if not 0.0 <= betas[0] < 1.0: 54 | raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") 55 | if not 0.0 <= betas[1] < 1.0: 56 | raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") 57 | if not 0.0 <= weight_decay: 58 | raise ValueError(f"Invalid weight_decay value: {weight_decay}") 59 | 60 | defaults = dict( 61 | lr=lr, 62 | betas=betas, 63 | weight_decay=weight_decay 64 | ) 65 | super().__init__(params, defaults) 66 | 67 | def _init_group( 68 | self, 69 | group, 70 | params_with_grad, 71 | grads, 72 | exp_avgs, 73 | exp_avg_sqs, 74 | state_steps 75 | ): 76 | for p in group["params"]: 77 | if p.grad is None: 78 | continue 79 | 80 | params_with_grad.append(p) 81 | if p.grad.is_sparse: 82 | raise RuntimeError("AdamW does not support sparse gradients") 83 | grads.append(p.grad) 84 | 85 | state = self.state[p] 86 | 87 | # State initialization 88 | if len(state) == 0: 89 | # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off. 90 | # This is because kernel launches are costly on CUDA and XLA. 91 | state["step"] = ( 92 | torch.zeros((), dtype=torch.float32, device=p.device) 93 | ) 94 | # Exponential moving average of gradient values 95 | state["exp_avg"] = torch.zeros_like( 96 | p, memory_format=torch.preserve_format 97 | ) 98 | # Exponential moving average of squared gradient values 99 | state["exp_avg_sq"] = torch.zeros_like( 100 | p, memory_format=torch.preserve_format 101 | ) 102 | 103 | exp_avgs.append(state["exp_avg"]) 104 | exp_avg_sqs.append(state["exp_avg_sq"]) 105 | state_steps.append(state["step"]) 106 | 107 | def step(self): 108 | """Perform a single optimization step. 
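
        The fused CUDA kernel (adam_atan2_csrc/adam_atan2.cu) applies, per parameter,
        with lerp(a, b, w) = a + w * (b - a):

            exp_avg    = lerp(exp_avg,    grad,        1 - beta1)
            exp_avg_sq = lerp(exp_avg_sq, grad * grad, 1 - beta2)
            param      = param * (1 - lr * weight_decay)
                         - lr / (1 - beta1**step) * atan2(exp_avg, sqrt(exp_avg_sq) / sqrt(1 - beta2**step))

        i.e. the usual m_hat / (sqrt(v_hat) + eps) update is replaced by atan2, which is why
        this optimizer takes no eps argument.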
109 | """ 110 | self._cuda_graph_capture_health_check() 111 | 112 | for group in self.param_groups: 113 | params_with_grad = [] 114 | grads = [] 115 | exp_avgs = [] 116 | exp_avg_sqs = [] 117 | state_steps = [] 118 | beta1, beta2 = group["betas"] 119 | 120 | self._init_group( 121 | group, 122 | params_with_grad, 123 | grads, 124 | exp_avgs, 125 | exp_avg_sqs, 126 | state_steps 127 | ) 128 | 129 | _adam_atan2( 130 | params_with_grad, 131 | grads, 132 | exp_avgs, 133 | exp_avg_sqs, 134 | state_steps, 135 | beta1=beta1, 136 | beta2=beta2, 137 | lr=group["lr"], 138 | weight_decay=group["weight_decay"] 139 | ) 140 | 141 | 142 | def _adam_atan2( 143 | params: List[Tensor], 144 | grads: List[Tensor], 145 | exp_avgs: List[Tensor], 146 | exp_avg_sqs: List[Tensor], 147 | state_steps: List[Tensor], 148 | beta1: float, 149 | beta2: float, 150 | lr: float, 151 | weight_decay: float 152 | ) -> None: 153 | if not params: 154 | return 155 | 156 | # We only support scalar lr. 157 | assert not isinstance(lr, Tensor) 158 | 159 | grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( 160 | [params, grads, exp_avgs, exp_avg_sqs, state_steps]) 161 | for (device, _), ((device_params, 162 | device_grads, 163 | device_exp_avgs, 164 | device_exp_avg_sqs, 165 | device_state_steps, ), _) in grouped_tensors.items(): 166 | torch._foreach_add_(device_state_steps, 1) 167 | backend = _get_adam_backend() 168 | backend.adam_atan2_cuda_impl_( 169 | device_params, 170 | device_grads, 171 | device_exp_avgs, 172 | device_exp_avg_sqs, 173 | device_state_steps, 174 | lr, 175 | beta1, 176 | beta2, 177 | weight_decay 178 | ) 179 | -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/training/adam_atan2_csrc/adam_atan2.cu: -------------------------------------------------------------------------------- 1 | // Vendored code from https://github.com/imoneoi/adam-atan2 to mitigate some setup issues we ran into 2 | #include "adam_atan2.h" 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | 11 | namespace adam_atan2 { 12 | 13 | using at::native::kILP; 14 | 15 | constexpr int kArgsDepth = 4; 16 | 17 | constexpr uint8_t kParamIdx = 0; 18 | constexpr uint8_t kGradIdx = 1; 19 | constexpr uint8_t kExpAvgIdx = 2; 20 | constexpr uint8_t kExpAvgSqIdx = 3; 21 | 22 | template 23 | __device__ __forceinline__ T lerp(const T v0, const T v1, const T t) { 24 | // NOTE(one): Identical to PyTorch when t < 0.5 25 | // https://github.com/pytorch/pytorch/blob/b7f25226929e70187a9f36c393665abad0b25190/aten/src/ATen/native/Lerp.h#L21 26 | return fma(t, v1, fma(-t, v0, v0)); 27 | } 28 | 29 | template 30 | __device__ __forceinline__ void adam_math( 31 | scalar_type r_args[kArgsDepth][kILP], 32 | const opmath_t &step_size, 33 | const opmath_t &wd_alpha, 34 | const opmath_t &mbeta1, 35 | const opmath_t &mbeta2, 36 | const opmath_t &bias_correction2_sqrt) 37 | { 38 | #pragma unroll 39 | for (int ii = 0; ii < kILP; ii++) 40 | { 41 | // Load values. 
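// The loop body below is the complete AdamATan2 update for one element: apply
// decoupled weight decay (param *= 1 - lr * weight_decay), refresh the first- and
// second-moment EMAs via lerp, then step by step_size * atan2(exp_avg, sqrt(exp_avg_sq) / bias_correction2_sqrt).
// Using atan2 instead of dividing by (sqrt(v_hat) + eps) removes Adam's epsilon hyperparameter.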
42 | opmath_t param = static_cast(r_args[kParamIdx][ii]); 43 | const opmath_t grad = static_cast(r_args[kGradIdx][ii]); 44 | 45 | opmath_t exp_avg = static_cast(r_args[kExpAvgIdx][ii]); 46 | opmath_t exp_avg_sq = static_cast(r_args[kExpAvgSqIdx][ii]); 47 | 48 | param *= wd_alpha; 49 | 50 | exp_avg = lerp(exp_avg, grad, mbeta1); 51 | exp_avg_sq = lerp(exp_avg_sq, grad * grad, mbeta2); 52 | 53 | const opmath_t denom = std::sqrt(exp_avg_sq) / bias_correction2_sqrt; 54 | param -= step_size * std::atan2(exp_avg, denom); 55 | 56 | // Store results. 57 | r_args[kParamIdx][ii] = param; 58 | r_args[kExpAvgIdx][ii] = exp_avg; 59 | r_args[kExpAvgSqIdx][ii] = exp_avg_sq; 60 | } 61 | } 62 | 63 | template 64 | struct FusedAdamMathFunctor { 65 | using opmath_t = at::opmath_type; 66 | __device__ __forceinline__ void operator()( 67 | int chunk_size, 68 | at::native::FusedOptimizerTensorListMetadata& tl, 69 | const double& lr, 70 | const double& beta1, 71 | const double& beta2, 72 | const double& weight_decay) { 73 | const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; 74 | const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; 75 | 76 | const auto [step_size, wd_alpha, bias_correction2_sqrt, mbeta1, mbeta2] = [&]() -> std::tuple { 77 | auto* step_count = reinterpret_cast(tl.state_steps_addresses[tensor_loc]); 78 | const auto bias_correction1 = 1 - at::native::pow_(beta1, *step_count); 79 | const auto bias_correction2 = 1 - at::native::pow_(beta2, *step_count); 80 | const auto bias_correction2_sqrt = std::sqrt(bias_correction2); 81 | 82 | return { 83 | static_cast(lr / bias_correction1), 84 | static_cast(1 - lr * weight_decay), 85 | static_cast(bias_correction2_sqrt), 86 | static_cast(1 - beta1), 87 | static_cast(1 - beta2) 88 | }; 89 | }(); 90 | 91 | scalar_type* args[kArgsDepth]; 92 | scalar_type r_args[kArgsDepth][kILP]; 93 | const auto n = tl.numel_for_tensor[tensor_loc] - chunk_idx * chunk_size; 94 | 95 | const bool all_aligned{ 96 | at::native::init_args(args, tl, chunk_idx, chunk_size, tensor_loc)}; 97 | if ((n % kILP == 0) && (chunk_size % kILP == 0) && all_aligned) { 98 | for (int64_t i_start = threadIdx.x; 99 | i_start * kILP < n && i_start * kILP < chunk_size; 100 | i_start += blockDim.x) { 101 | #pragma unroll 102 | for (int i = 0; i < kArgsDepth; i++) { 103 | at::native::load_store(r_args[i], args[i], 0, i_start); 104 | } 105 | adam_math( 106 | r_args, 107 | step_size, 108 | wd_alpha, 109 | mbeta1, 110 | mbeta2, 111 | bias_correction2_sqrt); 112 | #pragma unroll 113 | for (int i = 0; i < kArgsDepth; i++) { 114 | if (i != kGradIdx) { 115 | at::native::load_store(args[i], r_args[i], i_start, 0); 116 | } 117 | } 118 | } 119 | } else { 120 | for (int64_t i_start = 0; i_start < n && i_start < chunk_size; 121 | i_start += blockDim.x * kILP) { 122 | at::native::load_args(r_args, args, i_start, chunk_size, n); 123 | adam_math( 124 | r_args, 125 | step_size, 126 | wd_alpha, 127 | mbeta1, 128 | mbeta2, 129 | bias_correction2_sqrt); 130 | #pragma unroll 131 | for (int i = 0; i < kArgsDepth; i++) { 132 | if (i != kGradIdx) { 133 | at::native::store_args(args[i], r_args[i], i_start, chunk_size, n); 134 | } 135 | } 136 | } 137 | } 138 | } 139 | }; 140 | 141 | void adam_atan2_cuda_impl_( 142 | std::vector params, 143 | std::vector grads, 144 | std::vector exp_avgs, 145 | std::vector exp_avg_sqs, 146 | std::vector state_steps, 147 | const double lr, 148 | const double beta1, 149 | const double beta2, 150 | const double weight_decay) { 151 | std::vector> tensor_lists{params, grads, exp_avgs, exp_avg_sqs}; 152 
| 153 | AT_DISPATCH_FLOATING_TYPES_AND2( 154 | at::ScalarType::Half, 155 | at::ScalarType::BFloat16, 156 | params[0].scalar_type(), 157 | "adam_atan2_kernel_cuda", 158 | [&]() { 159 | at::native::multi_tensor_apply_for_fused_optimizer( 160 | tensor_lists, 161 | state_steps, 162 | FusedAdamMathFunctor(), 163 | lr, 164 | beta1, 165 | beta2, 166 | weight_decay); 167 | }); 168 | } 169 | 170 | } // namespace adam_atan2 171 | -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/data/build_sudoku_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import os 3 | import csv 4 | import json 5 | import numpy as np 6 | 7 | from argdantic import ArgParser 8 | from pydantic import BaseModel 9 | from tqdm import tqdm 10 | from huggingface_hub import hf_hub_download 11 | 12 | from tiny_recursive_models.data.common import PuzzleDatasetMetadata 13 | 14 | 15 | cli = ArgParser() 16 | 17 | 18 | class DataProcessConfig(BaseModel): 19 | source_repo: str = "sapientinc/sudoku-extreme" 20 | output_dir: str = "data/sudoku-extreme-full" 21 | 22 | subsample_size: Optional[int] = None 23 | min_difficulty: Optional[int] = None 24 | num_aug: int = 0 25 | 26 | 27 | def shuffle_sudoku(board: np.ndarray, solution: np.ndarray): 28 | # Create a random digit mapping: a permutation of 1..9, with zero (blank) unchanged 29 | digit_map = np.pad(np.random.permutation(np.arange(1, 10)), (1, 0)) 30 | 31 | # Randomly decide whether to transpose. 32 | transpose_flag = np.random.rand() < 0.5 33 | 34 | # Generate a valid row permutation: 35 | # - Shuffle the 3 bands (each band = 3 rows) and for each band, shuffle its 3 rows. 36 | bands = np.random.permutation(3) 37 | row_perm = np.concatenate([b * 3 + np.random.permutation(3) for b in bands]) 38 | 39 | # Similarly for columns (stacks). 40 | stacks = np.random.permutation(3) 41 | col_perm = np.concatenate([s * 3 + np.random.permutation(3) for s in stacks]) 42 | 43 | # Build an 81->81 mapping. For each new cell at (i, j) 44 | # (row index = i // 9, col index = i % 9), 45 | # its value comes from old row = row_perm[i//9] and old col = col_perm[i%9]. 46 | mapping = np.array([row_perm[i // 9] * 9 + col_perm[i % 9] for i in range(81)]) 47 | 48 | def apply_transformation(x: np.ndarray) -> np.ndarray: 49 | # Apply transpose flag 50 | if transpose_flag: 51 | x = x.T 52 | # Apply the position mapping. 53 | new_board = x.flatten()[mapping].reshape(9, 9).copy() 54 | # Apply digit mapping 55 | return digit_map[new_board] 56 | 57 | return apply_transformation(board), apply_transformation(solution) 58 | 59 | 60 | def convert_subset(set_name: str, config: DataProcessConfig): 61 | # Read CSV 62 | inputs = [] 63 | labels = [] 64 | 65 | with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile: 66 | reader = csv.reader(csvfile) 67 | next(reader) # Skip header 68 | for source, q, a, rating in reader: 69 | if (config.min_difficulty is None) or (int(rating) >= config.min_difficulty): 70 | assert len(q) == 81 and len(a) == 81 71 | 72 | inputs.append(np.frombuffer(q.replace('.', '0').encode(), dtype=np.uint8).reshape(9, 9) - ord('0')) 73 | labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(9, 9) - ord('0')) 74 | 75 | # If subsample_size is specified for the training set, 76 | # randomly sample the desired number of examples.
77 | if set_name == "train" and config.subsample_size is not None: 78 | total_samples = len(inputs) 79 | if config.subsample_size < total_samples: 80 | indices = np.random.choice(total_samples, size=config.subsample_size, replace=False) 81 | inputs = [inputs[i] for i in indices] 82 | labels = [labels[i] for i in indices] 83 | 84 | # Generate dataset 85 | num_augments = config.num_aug if set_name == "train" else 0 86 | 87 | results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]} 88 | puzzle_id = 0 89 | example_id = 0 90 | 91 | results["puzzle_indices"].append(0) 92 | results["group_indices"].append(0) 93 | 94 | for orig_inp, orig_out in zip(tqdm(inputs), labels): 95 | for aug_idx in range(1 + num_augments): 96 | # First index is not augmented 97 | if aug_idx == 0: 98 | inp, out = orig_inp, orig_out 99 | else: 100 | inp, out = shuffle_sudoku(orig_inp, orig_out) 101 | 102 | # Push puzzle (only single example) 103 | results["inputs"].append(inp) 104 | results["labels"].append(out) 105 | example_id += 1 106 | puzzle_id += 1 107 | 108 | results["puzzle_indices"].append(example_id) 109 | results["puzzle_identifiers"].append(0) 110 | 111 | # Push group 112 | results["group_indices"].append(puzzle_id) 113 | 114 | # To Numpy 115 | def _seq_to_numpy(seq): 116 | arr = np.concatenate(seq).reshape(len(seq), -1) 117 | 118 | assert np.all((arr >= 0) & (arr <= 9)) 119 | return arr + 1 120 | 121 | results = { 122 | "inputs": _seq_to_numpy(results["inputs"]), 123 | "labels": _seq_to_numpy(results["labels"]), 124 | 125 | "group_indices": np.array(results["group_indices"], dtype=np.int32), 126 | "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32), 127 | "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32), 128 | } 129 | 130 | # Metadata 131 | metadata = PuzzleDatasetMetadata( 132 | seq_len=81, 133 | vocab_size=10 + 1, # PAD + "0" ... "9" 134 | pad_id=0, 135 | ignore_label_id=0, 136 | blank_identifier_id=0, 137 | num_puzzle_identifiers=1, 138 | total_groups=len(results["group_indices"]) - 1, 139 | mean_puzzle_examples=1, 140 | total_puzzles=len(results["group_indices"]) - 1, 141 | sets=["all"] 142 | ) 143 | 144 | # Save metadata as JSON. 
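# On disk each split ends up as <output_dir>/<set_name>/dataset.json plus one
# all__{inputs,labels,puzzle_identifiers,puzzle_indices,group_indices}.npy file each,
# with a shared identifiers.json written at the top level of output_dir.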
145 | save_dir = os.path.join(config.output_dir, set_name) 146 | os.makedirs(save_dir, exist_ok=True) 147 | 148 | with open(os.path.join(save_dir, "dataset.json"), "w") as f: 149 | json.dump(metadata.model_dump(), f) 150 | 151 | # Save data 152 | for k, v in results.items(): 153 | np.save(os.path.join(save_dir, f"all__{k}.npy"), v) 154 | 155 | # Save IDs mapping (for visualization only) 156 | with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f: 157 | json.dump([""], f) 158 | 159 | 160 | @cli.command(singleton=True) 161 | def preprocess_data(config: DataProcessConfig): 162 | convert_subset("train", config) 163 | convert_subset("test", config) 164 | 165 | 166 | if __name__ == "__main__": 167 | cli() 168 | -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/models/layers.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import einops 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | #try: 8 | # from flash_attn_interface import flash_attn_func # type: ignore[import] 9 | #except ImportError: 10 | # # Fallback to FlashAttention 2 11 | # from flash_attn import flash_attn_func # type: ignore[import] 12 | from torch.nn.functional import scaled_dot_product_attention 13 | 14 | from tiny_recursive_models.models.common import trunc_normal_init_ 15 | 16 | 17 | CosSin = Tuple[torch.Tensor, torch.Tensor] 18 | 19 | 20 | def _find_multiple(a, b): 21 | return (-(a // -b)) * b 22 | 23 | 24 | def rotate_half(x: torch.Tensor): 25 | """Rotates half the hidden dims of the input.""" 26 | x1 = x[..., : x.shape[-1] // 2] 27 | x2 = x[..., x.shape[-1] // 2 :] 28 | return torch.cat((-x2, x1), dim=-1) 29 | 30 | 31 | def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): 32 | # q, k: [bs, seq_len, num_heads, head_dim] 33 | # cos, sin: [seq_len, head_dim] 34 | orig_dtype = q.dtype 35 | q = q.to(cos.dtype) 36 | k = k.to(cos.dtype) 37 | 38 | q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2)) 39 | k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2)) 40 | 41 | return q_embed.to(orig_dtype), k_embed.to(orig_dtype) 42 | 43 | 44 | class CastedLinear(nn.Module): 45 | def __init__(self, 46 | in_features: int, 47 | out_features: int, 48 | bias: bool): 49 | super().__init__() 50 | # Truncated LeCun normal init 51 | self.weight = nn.Parameter( 52 | trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5)) 53 | ) 54 | self.bias = None 55 | if bias: 56 | # Zero init bias 57 | self.bias = nn.Parameter(torch.zeros((out_features, ))) 58 | 59 | def forward(self, input: torch.Tensor) -> torch.Tensor: 60 | return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None) 61 | 62 | 63 | class CastedEmbedding(nn.Module): 64 | def __init__(self, 65 | num_embeddings: int, 66 | embedding_dim: int, 67 | init_std: float, 68 | cast_to: torch.dtype): 69 | super().__init__() 70 | self.cast_to = cast_to 71 | 72 | # Truncated LeCun normal init 73 | self.embedding_weight = nn.Parameter( 74 | trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std) 75 | ) 76 | 77 | def forward(self, input: torch.Tensor) -> torch.Tensor: 78 | return F.embedding(input, self.embedding_weight.to(self.cast_to)) 79 | 80 | 81 | class RotaryEmbedding(nn.Module): 82 | def __init__(self, 
dim, max_position_embeddings, base, device=None): 83 | super().__init__() 84 | 85 | # RoPE 86 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) 87 | t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device) 88 | freqs = torch.outer(t, inv_freq) 89 | 90 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 91 | emb = torch.cat((freqs, freqs), dim=-1) 92 | self.cos_cached = nn.Buffer(emb.cos(), persistent=False) 93 | self.sin_cached = nn.Buffer(emb.sin(), persistent=False) 94 | 95 | def forward(self): 96 | return self.cos_cached, self.sin_cached 97 | 98 | 99 | class Attention(nn.Module): 100 | def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False): 101 | super().__init__() 102 | 103 | self.hidden_size = hidden_size 104 | self.head_dim = head_dim 105 | self.output_size = head_dim * num_heads 106 | self.num_heads = num_heads 107 | self.num_key_value_heads = num_key_value_heads 108 | self.causal = causal 109 | 110 | self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False) 111 | self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False) 112 | 113 | def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor: 114 | batch_size, seq_len, _ = hidden_states.shape 115 | 116 | # hidden_states: [bs, seq_len, num_heads, head_dim] 117 | qkv = self.qkv_proj(hidden_states) 118 | 119 | # Split head 120 | qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim) 121 | query = qkv[:, :, :self.num_heads] 122 | key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads] 123 | value = qkv[:, :, self.num_heads + self.num_key_value_heads:] 124 | 125 | # RoPE 126 | if cos_sin is not None: 127 | cos, sin = cos_sin 128 | query, key = apply_rotary_pos_emb(query, key, cos, sin) 129 | 130 | # flash attn 131 | query, key, value = map(lambda t: einops.rearrange(t, 'B S H D -> B H S D'), (query, key, value)) # needed for scaled_dot_product_attention but not flash_attn_func 132 | attn_output = scaled_dot_product_attention(query=query, key=key, value=value, is_causal=self.causal) 133 | attn_output = einops.rearrange(attn_output, 'B H S D -> B S H D') 134 | attn_output = attn_output.view(batch_size, seq_len, self.output_size) # type: ignore 135 | return self.o_proj(attn_output) 136 | 137 | class LinearSwish(nn.Module): 138 | def __init__(self, hidden_size: int, reverse=False): 139 | super().__init__() 140 | 141 | self.linear = CastedLinear(hidden_size, hidden_size, bias=False) 142 | self.reverse = reverse 143 | 144 | def forward(self, x): 145 | if self.reverse: 146 | return F.silu(self.linear(x)) 147 | else: 148 | return self.linear(F.silu(x)) 149 | 150 | 151 | class SwiGLU(nn.Module): 152 | def __init__(self, hidden_size: int, expansion: float): 153 | super().__init__() 154 | inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256) 155 | 156 | self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False) 157 | self.down_proj = CastedLinear(inter, hidden_size, bias=False) 158 | 159 | def forward(self, x): 160 | gate, up = self.gate_up_proj(x).chunk(2, dim=-1) 161 | return self.down_proj(F.silu(gate) * up) 162 | 163 | def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor: 164 | input_dtype = hidden_states.dtype 165 | hidden_states = hidden_states.to(torch.float32) 166 | 167 | 
variance = hidden_states.square().mean(-1, keepdim=True) 168 | hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon) 169 | return hidden_states.to(input_dtype) 170 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/nanochat/checkpoint_manager.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for saving and loading model/optim/state checkpoints. 3 | """ 4 | import os 5 | import re 6 | import glob 7 | import json 8 | import logging 9 | import torch 10 | 11 | from nanochat.common import get_base_dir 12 | from nanochat.gpt import GPT, GPTConfig 13 | from nanochat.tokenizer import get_tokenizer 14 | from nanochat.common import setup_default_logging 15 | 16 | # Set up logging 17 | setup_default_logging() 18 | logger = logging.getLogger(__name__) 19 | def log0(message): 20 | if int(os.environ.get('RANK', 0)) == 0: 21 | logger.info(message) 22 | 23 | def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0): 24 | if rank == 0: 25 | os.makedirs(checkpoint_dir, exist_ok=True) 26 | # Save the model state parameters 27 | model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") 28 | torch.save(model_data, model_path) 29 | logger.info(f"Saved model parameters to: {model_path}") 30 | # Save the metadata dict as json 31 | meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") 32 | with open(meta_path, "w", encoding="utf-8") as f: 33 | json.dump(meta_data, f, indent=2) 34 | logger.info(f"Saved metadata to: {meta_path}") 35 | # Note that optimizer state is sharded across ranks, so each rank must save its own. 36 | if optimizer_data is not None: 37 | os.makedirs(checkpoint_dir, exist_ok=True) 38 | optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt") 39 | torch.save(optimizer_data, optimizer_path) 40 | logger.info(f"Saved optimizer state to: {optimizer_path}") 41 | 42 | def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0): 43 | # Load the model state 44 | model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") 45 | model_data = torch.load(model_path, map_location=device) 46 | # Load the optimizer state if requested 47 | optimizer_data = None 48 | if load_optimizer: 49 | optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt") 50 | optimizer_data = torch.load(optimizer_path, map_location=device) 51 | # Load the metadata 52 | meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") 53 | with open(meta_path, "r", encoding="utf-8") as f: 54 | meta_data = json.load(f) 55 | return model_data, optimizer_data, meta_data 56 | 57 | 58 | def build_model(checkpoint_dir, step, device, phase): 59 | """ 60 | A bunch of repetitive code to build a model from a given checkpoint. 61 | Returns: 62 | - base model - uncompiled, not wrapped in DDP 63 | - tokenizer 64 | - meta data saved during base model training 65 | """ 66 | assert phase in ["train", "eval"], f"Invalid phase: {phase}" 67 | model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False) 68 | if device.type in {"cpu", "mps"}: 69 | # Convert bfloat16 tensors to float for CPU inference 70 | model_data = { 71 | k: v.float() if v.dtype == torch.bfloat16 else v 72 | for k, v in model_data.items() 73 | } 74 | # Hack: fix torch compile issue, which prepends all keys with _orig_mod. 
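# Stripping the "_orig_mod." prefix lets a state dict saved from a torch.compile'd model
# load cleanly into the plain (uncompiled) GPT module constructed below.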
75 | model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()} 76 | model_config_kwargs = meta_data["model_config"] 77 | log0(f"Building model with config: {model_config_kwargs}") 78 | model_config = GPTConfig(**model_config_kwargs) 79 | with torch.device("meta"): 80 | model = GPT(model_config) 81 | # Load the model state 82 | model.to_empty(device=device) 83 | model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init 84 | model.load_state_dict(model_data, strict=True, assign=True) 85 | # Put the model in the right training phase / mode 86 | if phase == "eval": 87 | model.eval() 88 | else: 89 | model.train() 90 | # Load the Tokenizer 91 | tokenizer = get_tokenizer() 92 | # Sanity check: compatibility between model and tokenizer 93 | assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"] 94 | return model, tokenizer, meta_data 95 | 96 | 97 | def find_largest_model(checkpoints_dir): 98 | # attempt to guess the model tag: take the biggest model available 99 | model_tags = [f for f in os.listdir(checkpoints_dir) if os.path.isdir(os.path.join(checkpoints_dir, f))] 100 | if not model_tags: 101 | raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}") 102 | # 1) normally all model tags are of the form d, try that first: 103 | candidates = [] 104 | for model_tag in model_tags: 105 | match = re.match(r"d(\d+)", model_tag) 106 | if match: 107 | model_depth = int(match.group(1)) 108 | candidates.append((model_depth, model_tag)) 109 | if candidates: 110 | candidates.sort(key=lambda x: x[0], reverse=True) 111 | return candidates[0][1] 112 | # 2) if that failed, take the most recently updated model: 113 | model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoints_dir, x)), reverse=True) 114 | return model_tags[0] 115 | 116 | 117 | def find_last_step(checkpoint_dir): 118 | # Look into checkpoint_dir and find model_.pt with the highest step 119 | checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt")) 120 | if not checkpoint_files: 121 | raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") 122 | last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files)) 123 | return last_step 124 | 125 | # ----------------------------------------------------------------------------- 126 | # convenience functions that take into account nanochat's directory structure 127 | 128 | def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None): 129 | if model_tag is None: 130 | # guess the model tag by defaulting to the largest model 131 | model_tag = find_largest_model(checkpoints_dir) 132 | log0(f"No model tag provided, guessing model tag: {model_tag}") 133 | checkpoint_dir = os.path.join(checkpoints_dir, model_tag) 134 | if step is None: 135 | # guess the step by defaulting to the last step 136 | step = find_last_step(checkpoint_dir) 137 | assert step is not None, f"No checkpoints found in {checkpoint_dir}" 138 | # build the model 139 | log0(f"Loading model from {checkpoint_dir} with step {step}") 140 | model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase) 141 | return model, tokenizer, meta_data 142 | 143 | def load_model(source, *args, **kwargs): 144 | model_dir = { 145 | "base": "base_checkpoints", 146 | "mid": "mid_checkpoints", 147 | "sft": "chatsft_checkpoints", 148 | "rl": "chatrl_checkpoints", 149 | }[source] 150 | base_dir = get_base_dir() 151 | checkpoints_dir = os.path.join(base_dir, 
model_dir) 152 | return load_model_from_dir(checkpoints_dir, *args, **kwargs) 153 | -------------------------------------------------------------------------------- /rope_imaginary/nanochat_imaginary/nanochat/common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common utilities for nanochat. 3 | """ 4 | 5 | import os 6 | import re 7 | import logging 8 | import urllib.request 9 | import torch 10 | import torch.distributed as dist 11 | from filelock import FileLock 12 | 13 | class ColoredFormatter(logging.Formatter): 14 | """Custom formatter that adds colors to log messages.""" 15 | # ANSI color codes 16 | COLORS = { 17 | 'DEBUG': '\033[36m', # Cyan 18 | 'INFO': '\033[32m', # Green 19 | 'WARNING': '\033[33m', # Yellow 20 | 'ERROR': '\033[31m', # Red 21 | 'CRITICAL': '\033[35m', # Magenta 22 | } 23 | RESET = '\033[0m' 24 | BOLD = '\033[1m' 25 | def format(self, record): 26 | # Add color to the level name 27 | levelname = record.levelname 28 | if levelname in self.COLORS: 29 | record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}" 30 | # Format the message 31 | message = super().format(record) 32 | # Add color to specific parts of the message 33 | if levelname == 'INFO': 34 | # Highlight numbers and percentages 35 | message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message) 36 | message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message) 37 | return message 38 | 39 | def setup_default_logging(): 40 | handler = logging.StreamHandler() 41 | handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) 42 | logging.basicConfig( 43 | level=logging.INFO, 44 | handlers=[handler] 45 | ) 46 | 47 | setup_default_logging() 48 | logger = logging.getLogger(__name__) 49 | 50 | def get_base_dir(): 51 | # co-locate nanochat intermediates with other cached data in ~/.cache (by default) 52 | if os.environ.get("NANOCHAT_BASE_DIR"): 53 | nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR") 54 | else: 55 | home_dir = os.path.expanduser("~") 56 | cache_dir = os.path.join(home_dir, ".cache") 57 | nanochat_dir = os.path.join(cache_dir, "nanochat") 58 | os.makedirs(nanochat_dir, exist_ok=True) 59 | return nanochat_dir 60 | 61 | def download_file_with_lock(url, filename, postprocess_fn=None): 62 | """ 63 | Downloads a file from a URL to a local path in the base directory. 64 | Uses a lock file to prevent concurrent downloads among multiple ranks. 
65 | """ 66 | base_dir = get_base_dir() 67 | file_path = os.path.join(base_dir, filename) 68 | lock_path = file_path + ".lock" 69 | 70 | if os.path.exists(file_path): 71 | return file_path 72 | 73 | with FileLock(lock_path): 74 | # Only a single rank can acquire this lock 75 | # All other ranks block until it is released 76 | 77 | # Recheck after acquiring lock 78 | if os.path.exists(file_path): 79 | return file_path 80 | 81 | # Download the content as bytes 82 | print(f"Downloading {url}...") 83 | with urllib.request.urlopen(url) as response: 84 | content = response.read() # bytes 85 | 86 | # Write to local file 87 | with open(file_path, 'wb') as f: 88 | f.write(content) 89 | print(f"Downloaded to {file_path}") 90 | 91 | # Run the postprocess function if provided 92 | if postprocess_fn is not None: 93 | postprocess_fn(file_path) 94 | 95 | return file_path 96 | 97 | def print0(s="",**kwargs): 98 | ddp_rank = int(os.environ.get('RANK', 0)) 99 | if ddp_rank == 0: 100 | print(s, **kwargs) 101 | 102 | def print_banner(): 103 | # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/ 104 | banner = """ 105 | █████ █████ 106 | ░░███ ░░███ 107 | ████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████ 108 | ░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███░░░███░ 109 | ░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███ 110 | ░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███ 111 | ████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░███████ ░░█████ 112 | ░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░ 113 | """ 114 | print0(banner) 115 | 116 | def is_ddp(): 117 | # TODO is there a proper way 118 | return int(os.environ.get('RANK', -1)) != -1 119 | 120 | def get_dist_info(): 121 | if is_ddp(): 122 | assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE']) 123 | ddp_rank = int(os.environ['RANK']) 124 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 125 | ddp_world_size = int(os.environ['WORLD_SIZE']) 126 | return True, ddp_rank, ddp_local_rank, ddp_world_size 127 | else: 128 | return False, 0, 0, 1 129 | 130 | def autodetect_device_type(): 131 | # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU 132 | if torch.cuda.is_available(): 133 | device_type = "cuda" 134 | elif torch.backends.mps.is_available(): 135 | device_type = "mps" 136 | else: 137 | device_type = "cpu" 138 | print0(f"Autodetected device type: {device_type}") 139 | return device_type 140 | 141 | def compute_init(device_type="cuda"): # cuda|cpu|mps 142 | """Basic initialization that we keep doing over and over, so make common.""" 143 | 144 | assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm" 145 | if device_type == "cuda": 146 | assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'" 147 | if device_type == "mps": 148 | assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'" 149 | 150 | # Reproducibility 151 | # Note that we set the global seeds here, but most of the code uses explicit rng objects. 152 | # The only place where global rng might be used is nn.Module initialization of the model weights. 
153 | torch.manual_seed(42) 154 | if device_type == "cuda": 155 | torch.cuda.manual_seed(42) 156 | # skipping full reproducibility for now, possibly investigate slowdown later 157 | # torch.use_deterministic_algorithms(True) 158 | 159 | # Precision 160 | if device_type == "cuda": 161 | torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls 162 | 163 | # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA 164 | ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() 165 | if ddp and device_type == "cuda": 166 | device = torch.device("cuda", ddp_local_rank) 167 | torch.cuda.set_device(device) # make "cuda" default to this device 168 | dist.init_process_group(backend="nccl", device_id=device) 169 | dist.barrier() 170 | else: 171 | device = torch.device(device_type) # mps|cpu 172 | 173 | if ddp_rank == 0: 174 | logger.info(f"Distributed world size: {ddp_world_size}") 175 | 176 | return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device 177 | 178 | def compute_cleanup(): 179 | """Companion function to compute_init, to clean things up before script exit""" 180 | if is_ddp(): 181 | dist.destroy_process_group() 182 | 183 | class DummyWandb: 184 | """Useful if we wish to not use wandb but have all the same signatures""" 185 | def __init__(self): 186 | pass 187 | def log(self, *args, **kwargs): 188 | pass 189 | def finish(self): 190 | pass 191 | -------------------------------------------------------------------------------- /tiny_recursive_models/src/tiny_recursive_models/evaluation/evaluator.py: -------------------------------------------------------------------------------- 1 | """Evaluation utilities and main evaluation loop.""" 2 | 3 | import os 4 | from typing import Optional, List, Any 5 | import torch 6 | import torch.distributed as dist 7 | 8 | from tiny_recursive_models.training.config import PretrainConfig, TrainState 9 | from tiny_recursive_models.utils import load_model_class 10 | 11 | # Import from original locations (these haven't moved yet) 12 | from tiny_recursive_models.data.puzzle_dataset import PuzzleDatasetMetadata 13 | 14 | 15 | def create_evaluators(config: PretrainConfig, eval_metadata: PuzzleDatasetMetadata) -> List[Any]: 16 | """Create evaluator instances from config. 17 | 18 | Args: 19 | config: Training configuration 20 | eval_metadata: Evaluation dataset metadata 21 | 22 | Returns: 23 | List of evaluator instances 24 | """ 25 | data_paths = config.data_paths_test if len(config.data_paths_test) > 0 else config.data_paths 26 | # Initialize evaluators 27 | evaluators = [] 28 | for cfg in config.evaluators: 29 | for data_path in data_paths: 30 | cls = load_model_class(cfg.name, "tiny_recursive_models.evaluation.")( 31 | data_path=data_path, eval_metadata=eval_metadata, **cfg.__pydantic_extra__ 32 | ) # type: ignore 33 | evaluators.append(cls) 34 | 35 | return evaluators 36 | 37 | 38 | def evaluate( 39 | config: PretrainConfig, 40 | train_state: TrainState, 41 | eval_loader: torch.utils.data.DataLoader, 42 | eval_metadata: PuzzleDatasetMetadata, 43 | evaluators: List[Any], 44 | rank: int, 45 | world_size: int, 46 | cpu_group: Optional[dist.ProcessGroup], 47 | ): 48 | """Run evaluation on test set. 49 | 50 | This function: 51 | 1. Runs model inference on all test batches 52 | 2. Computes basic metrics (accuracy, loss) 53 | 3. Calls task-specific evaluators for advanced metrics 54 | 4. 
Saves predictions if requested 55 | 56 | Args: 57 | config: Training configuration 58 | train_state: Current training state 59 | eval_loader: DataLoader for test set 60 | eval_metadata: Metadata about test dataset 61 | evaluators: List of task-specific evaluators 62 | rank: Current process rank 63 | world_size: Total processes 64 | cpu_group: CPU process group for communication 65 | 66 | Returns: 67 | Dictionary of metrics (on rank 0 only) 68 | """ 69 | reduced_metrics = None 70 | 71 | with torch.inference_mode(): 72 | return_keys = set(config.eval_save_outputs) 73 | for evaluator in evaluators: 74 | evaluator.begin_eval() 75 | return_keys.update(evaluator.required_outputs) 76 | 77 | # Run evaluation 78 | set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)} 79 | 80 | save_preds = {} 81 | 82 | metric_keys = [] 83 | metric_values = None 84 | 85 | carry = None 86 | processed_batches = 0 87 | 88 | for set_name, batch, global_batch_size in eval_loader: 89 | processed_batches += 1 90 | if rank == 0: 91 | print(f"Processing batch {processed_batches}: {set_name}") 92 | 93 | # To device 94 | batch = {k: v.cuda() for k, v in batch.items()} 95 | with torch.device("cuda"): 96 | carry = train_state.model.initial_carry(batch) # type: ignore 97 | 98 | # Forward 99 | inference_steps = 0 100 | while True: 101 | carry, loss, metrics, preds, all_finish = train_state.model( 102 | carry=carry, batch=batch, return_keys=return_keys 103 | ) 104 | inference_steps += 1 105 | 106 | if all_finish: 107 | break 108 | 109 | if rank == 0: 110 | print(f" Completed inference in {inference_steps} steps") 111 | 112 | for collection in (batch, preds): 113 | for k, v in collection.items(): 114 | if k in config.eval_save_outputs: 115 | save_preds.setdefault(k, []) 116 | save_preds[k].append(v.cpu()) # Move to CPU for saving GPU memory 117 | 118 | for evaluator in evaluators: 119 | evaluator.update_batch(batch, preds) 120 | 121 | del carry, loss, preds, batch, all_finish 122 | 123 | # Aggregate metrics 124 | set_id = set_ids[set_name] 125 | 126 | if metric_values is None: 127 | metric_keys = list( 128 | sorted(metrics.keys()) 129 | ) # Sort keys to guarantee all processes use the same order. 
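# metric_values holds one row of summed metrics per evaluation set; after the batch loop
# the sums are reduced to rank 0 and divided by their "count" entry to give averages.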
130 | metric_values = torch.zeros( 131 | (len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda" 132 | ) 133 | 134 | metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys]) 135 | 136 | del metrics 137 | 138 | # concatenate save preds 139 | save_preds = {k: torch.cat(v, dim=0) for k, v in save_preds.items()} 140 | 141 | # Save preds 142 | if config.checkpoint_path is not None and len(save_preds): 143 | # Each rank save predictions independently 144 | os.makedirs(os.path.dirname(config.checkpoint_path), exist_ok=True) 145 | torch.save( 146 | save_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}") 147 | ) 148 | 149 | del save_preds 150 | 151 | # Reduce to rank 0 152 | if metric_values is not None: 153 | if world_size > 1: 154 | dist.reduce(metric_values, dst=0) 155 | 156 | if rank == 0: 157 | reduced_metrics = metric_values.cpu().numpy() 158 | reduced_metrics = { 159 | set_name: { 160 | metric_name: reduced_metrics[set_id, metric_id] 161 | for metric_id, metric_name in enumerate(metric_keys) 162 | } 163 | for set_id, set_name in enumerate(set_ids) 164 | } 165 | 166 | # Postprocess 167 | for set_name, m in reduced_metrics.items(): 168 | count = m.pop("count") 169 | reduced_metrics[set_name] = {k: v / count for k, v in m.items()} 170 | 171 | # Run evaluators 172 | if rank == 0: 173 | print(f"\nRunning {len(evaluators)} evaluator(s)...") 174 | 175 | for i, evaluator in enumerate(evaluators): 176 | if rank == 0: 177 | print(f"Running evaluator {i+1}/{len(evaluators)}: {evaluator.__class__.__name__}") 178 | 179 | # Path for saving 180 | evaluator_save_path = None 181 | if config.checkpoint_path is not None: 182 | evaluator_save_path = os.path.join( 183 | config.checkpoint_path, 184 | f"evaluator_{evaluator.__class__.__name__}_step_{train_state.step}", 185 | ) 186 | os.makedirs(evaluator_save_path, exist_ok=True) 187 | 188 | # Run and log 189 | metrics = evaluator.result(evaluator_save_path, rank=rank, world_size=world_size, group=cpu_group) 190 | if rank == 0 and metrics is not None: 191 | if reduced_metrics is None: 192 | reduced_metrics = {} 193 | 194 | reduced_metrics.update(metrics) 195 | print(f" Completed {evaluator.__class__.__name__}") 196 | 197 | if rank == 0: 198 | print("All evaluators completed!") 199 | 200 | return reduced_metrics 201 | --------------------------------------------------------------------------------
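The evaluation entry points above are normally driven by the training script, which is not part of this excerpt. The snippet below is only an illustrative single-process sketch: it assumes a PretrainConfig instance `cfg`, a TrainState `state` holding the trained model, and a `test_loader`/`test_metadata` pair have already been constructed elsewhere (those names are assumptions, not code from the repository), and it requires a CUDA device because evaluate() moves batches to the GPU.

from tiny_recursive_models.evaluation.evaluator import create_evaluators, evaluate

# Build the task-specific evaluators declared in the config, then run the eval loop.
evaluators = create_evaluators(cfg, test_metadata)
metrics = evaluate(
    cfg, state, test_loader, test_metadata, evaluators,
    rank=0, world_size=1, cpu_group=None,  # single process: no reduction group needed
)
if metrics is not None:  # only rank 0 receives the aggregated dict
    print(metrics)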