├── .gitattributes ├── .gitignore ├── plots ├── adam_2d.png ├── finetune_bar.png └── gpt3xl_sgd.png ├── pretraining ├── requirements.txt ├── configs │ ├── model │ │ ├── lm11m.yaml │ │ ├── lm19m.yaml │ │ ├── gpt2s.yaml │ │ └── gpt3xl.yaml │ ├── dataset │ │ ├── fw_gpt2.yaml │ │ ├── fwedu_gpt2.yaml │ │ └── c4_t5all.yaml │ ├── resolver_setup.py │ ├── hparams │ │ └── lm11m_fwedu_adam.yaml │ └── base.yaml ├── runs │ ├── gpt2s │ │ ├── optimizer_comp │ │ │ ├── gpt2s_sgd.yaml │ │ │ ├── gpt2s_adafactor.yaml │ │ │ ├── gpt2s_adam_bs1_b2.yaml │ │ │ ├── gpt2s_adam_bs1_t2.yaml │ │ │ ├── gpt2s_adam_bs512.yaml │ │ │ └── final_configs.sh │ │ └── gpt2s_adam_2d.yaml │ ├── lm30m │ │ ├── optimizer_comp │ │ │ ├── lm30m_sgd_grid.yaml │ │ │ ├── lm30m_muon_grid.yaml │ │ │ ├── lm30m_adafactor_grid.yaml │ │ │ └── lm30m_adam_grid.yaml │ │ ├── hparam_scaling │ │ │ ├── lm30m_adam_lr.yaml │ │ │ ├── lm30m_adam_b1.yaml │ │ │ ├── lm30m_adam_b2.yaml │ │ │ └── lm30m_adam_t2.yaml │ │ └── lm30m_adam_sensitivity.yaml │ ├── lm19m │ │ ├── fig10_fixed_b2.yaml │ │ └── fig10_fixed_t2.yaml │ └── gpt3xl │ │ └── gpt3xl_optimizer_comp.sh ├── README.md ├── rope.py ├── main.py ├── data.py ├── download_fineweb.py ├── utils.py ├── train.py ├── model.py ├── factorized.py ├── train.ipynb └── optimizer.py ├── finetuning ├── requirements.txt ├── runs │ ├── gemma4bpt_math_adam_bs1.yaml │ ├── gemma4bpt_math_adam_bs16.yaml │ ├── gemma4bpt_math_lora.yaml │ └── gemma4bpt_math_adafactor.yaml ├── README.md ├── rope.py ├── utils.py ├── optimizer.py ├── sampler.py ├── data.py ├── factorized.py ├── finetune.py ├── gemma.py └── finetune.ipynb ├── LICENSE ├── utils ├── utils.py └── memory_measure.ipynb └── README.md /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-documentation 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .ipynb_checkpoints 3 | __pycache__ -------------------------------------------------------------------------------- /plots/adam_2d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martin-marek/batch-size/HEAD/plots/adam_2d.png -------------------------------------------------------------------------------- /plots/finetune_bar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martin-marek/batch-size/HEAD/plots/finetune_bar.png -------------------------------------------------------------------------------- /plots/gpt3xl_sgd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martin-marek/batch-size/HEAD/plots/gpt3xl_sgd.png -------------------------------------------------------------------------------- /pretraining/requirements.txt: -------------------------------------------------------------------------------- 1 | jax 2 | flax==0.12.0 3 | optax 4 | numpy 5 | tqdm 6 | hydra-core 7 | wandb -------------------------------------------------------------------------------- /pretraining/configs/model/lm11m.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | model: 4 | D: 384 5 | L: 6 6 | T: 512 7 | -------------------------------------------------------------------------------- /pretraining/configs/model/lm19m.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | model: 4 | D: 512 5 | L: 6 6 | T: 512 7 | -------------------------------------------------------------------------------- /pretraining/configs/model/gpt2s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | model: 4 | D: 768 5 | L: 12 6 | T: 1024 7 | -------------------------------------------------------------------------------- /pretraining/configs/dataset/fw_gpt2.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | ds_path: '~/datasets/fineweb_gpt2.bin' 4 | 5 | model: 6 | V: 50257 7 | -------------------------------------------------------------------------------- /pretraining/configs/dataset/fwedu_gpt2.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | ds_path: '~/datasets/finewebedu_gpt2.bin' 4 | 5 | model: 6 | V: 50257 7 | -------------------------------------------------------------------------------- /pretraining/configs/model/gpt3xl.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | num_tokens_train: 10_000_000_000 4 | 5 | model: 6 | D: 2048 7 | L: 24 8 | T: 2048 9 | -------------------------------------------------------------------------------- /pretraining/configs/dataset/c4_t5all.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | ds_path: '~/datasets/c4_t5all_1B_512.bin' 4 | pad_eval: true 5 | 6 | model: 7 | V: 32101 8 | -------------------------------------------------------------------------------- /finetuning/requirements.txt: -------------------------------------------------------------------------------- 1 | jax 2 | flax==0.12.0 3 | optax 4 | orbax 5 | fire 6 | numpy 7 | tqdm 8 | wandb 9 | sentencepiece 10 | kagglehub 11 | datasets 12 | math_verify 13 | git+https://github.com/google/qwix -------------------------------------------------------------------------------- /pretraining/configs/resolver_setup.py: -------------------------------------------------------------------------------- 1 | import operator as op 2 | from omegaconf import OmegaConf 3 | 4 | OmegaConf.register_new_resolver('floordiv', op.floordiv) 5 | OmegaConf.register_new_resolver('mul', op.mul) 6 | OmegaConf.register_new_resolver('min', min) 7 | OmegaConf.register_new_resolver('pow', pow) -------------------------------------------------------------------------------- /pretraining/runs/gpt2s/optimizer_comp/gpt2s_sgd.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2s_sgd_bs1_v3_1 2 | method: grid 3 | parameters: 4 | opt.peak_lr: 5 | values: [0.006, 0.012, 0.025, 0.05, 0.1, 0.2, 0.4, 0.8] # -> 8x 6 | program: main.py 7 | command: 8 | - ${env} 9 | - ${interpreter} 10 | - ${program} 11 | - +model=gpt2s 12 | - +dataset=fw_gpt2 13 | - opt.optimizer='sgd' 14 | - opt.batch_size=1 15 | - ${args_no_hyphens} 16 | -------------------------------------------------------------------------------- /pretraining/runs/gpt2s/optimizer_comp/gpt2s_adafactor.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2s_adafactor_bs1_v3_1 2 | method: grid 3 | parameters: 4 | opt.peak_lr: 5 | values: [0.0003, 0.0006, 0.0012, 0.0025, 0.005, 0.01, 0.02, 0.04] # -> 8x 6 | program: main.py 7 | 
command: 8 | - ${env} 9 | - ${interpreter} 10 | - ${program} 11 | - +model=gpt2s 12 | - +dataset=fw_gpt2 13 | - opt.optimizer='adafactor' 14 | - opt.batch_size=1 15 | - opt.b2=0.9999 16 | - ${args_no_hyphens} 17 | -------------------------------------------------------------------------------- /pretraining/runs/gpt2s/optimizer_comp/gpt2s_adam_bs1_b2.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2s_adam_bs1_b2_v3_1 2 | method: grid 3 | parameters: 4 | opt.peak_lr: 5 | values: [0.000075, 0.00015, 0.0003, 0.0006, 0.0012, 0.0024, 0.0048, 0.01] # -> 8x 6 | program: main.py 7 | command: 8 | - ${env} 9 | - ${interpreter} 10 | - ${program} 11 | - +model=gpt2s 12 | - +dataset=fw_gpt2 13 | - opt.optimizer='adamw' 14 | - opt.batch_size=1 15 | - opt.b1=0.9 16 | - opt.b2=0.95 17 | - ${args_no_hyphens} 18 | -------------------------------------------------------------------------------- /pretraining/runs/gpt2s/optimizer_comp/gpt2s_adam_bs1_t2.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2s_adam_bs1_t2_v3_1 2 | method: grid 3 | parameters: 4 | opt.peak_lr: 5 | values: [0.000075, 0.00015, 0.0003, 0.0006, 0.0012, 0.0024, 0.0048, 0.01] # -> 8x 6 | program: main.py 7 | command: 8 | - ${env} 9 | - ${interpreter} 10 | - ${program} 11 | - +model=gpt2s 12 | - +dataset=fw_gpt2 13 | - opt.optimizer='adamw' 14 | - opt.batch_size=1 15 | - opt.b1=0.9 16 | - opt.b2=0.9999 17 | - ${args_no_hyphens} 18 | -------------------------------------------------------------------------------- /pretraining/runs/gpt2s/optimizer_comp/gpt2s_adam_bs512.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2s_adam_bs512_v3_1 2 | method: grid 3 | parameters: 4 | opt.peak_lr: 5 | values: [0.00015, 0.0003, 0.0006, 0.0012, 0.0024, 0.0048, 0.01, 0.02] # -> 8x 6 | program: main.py 7 | command: 8 | - ${env} 9 | - ${interpreter} 10 | - ${program} 11 | - +model=gpt2s 12 | - +dataset=fw_gpt2 13 | - opt.optimizer='adamw' 14 | - opt.batch_size=512 15 | - opt.max_microbatch_size=16 16 | - opt.b1=0.9 17 | - opt.b2=0.95 18 | - opt.weight_decay=0.1 19 | - ${args_no_hyphens} 20 | -------------------------------------------------------------------------------- /finetuning/runs/gemma4bpt_math_adam_bs1.yaml: -------------------------------------------------------------------------------- 1 | name: gemma4bpt_math_5ep_adam_bs1_v2 2 | method: grid 3 | parameters: 4 | peak_lr: 5 | values: [3.2e-9, 1.0e-8, 3.2e-8, 1.0e-7, 3.2e-7, 1.0e-6, 3.2e-6, 1.0e-5] 6 | seed: 7 | values: [0, 1, 2, 3] 8 | program: finetune.py 9 | command: 10 | - ${env} 11 | - ${interpreter} 12 | - ${program} 13 | - --model_variant='gemma3-4b' 14 | - --param_dtype='float32' 15 | - --optimizer_name='adam' 16 | - --batch_size=1 17 | - --n_epochs=5 18 | - --b2=0.997 19 | - --eval_batch_size=128 20 | - --wandb_mode='online' 21 | - ${args} 22 | -------------------------------------------------------------------------------- /finetuning/runs/gemma4bpt_math_adam_bs16.yaml: -------------------------------------------------------------------------------- 1 | name: gemma4bpt_math_5ep_adam_bs16_v2 2 | method: grid 3 | parameters: 4 | peak_lr: 5 | values: [3.2e-8, 1.0e-7, 3.2e-7, 1.0e-6, 3.2e-6, 1.0e-5, 3.2e-5, 1.0e-4] 6 | seed: 7 | values: [0, 1, 2, 3] 8 | program: finetune.py 9 | command: 10 | - ${env} 11 | - ${interpreter} 12 | - ${program} 13 | - --model_variant='gemma3-4b' 14 | - --param_dtype='float32' 15 | - --optimizer_name='adam' 16 | - 
--batch_size=16 17 | - --n_epochs=5 18 | - --b2=0.95 19 | - --eval_batch_size=128 20 | - --wandb_mode='online' 21 | - ${args} 22 | -------------------------------------------------------------------------------- /finetuning/runs/gemma4bpt_math_lora.yaml: -------------------------------------------------------------------------------- 1 | name: gemma4bpt_math_5ep_lora_v2 2 | method: grid 3 | parameters: 4 | peak_lr: 5 | values: [3.2e-7, 1.0e-6, 3.2e-6, 1.0e-5, 3.2e-5, 1.0e-4, 3.2e-4, 1.0e-3] 6 | seed: 7 | values: [0, 1, 2, 3] 8 | program: finetune.py 9 | command: 10 | - ${env} 11 | - ${interpreter} 12 | - ${program} 13 | - --model_variant='gemma3-4b' 14 | - --param_dtype='bfloat16' 15 | - --lora_rank=16 16 | - --optimizer_name='adam' 17 | - --batch_size=1 18 | - --n_epochs=5 19 | - --b2=0.997 20 | - --eval_batch_size=128 21 | - --wandb_mode='online' 22 | - ${args} 23 | -------------------------------------------------------------------------------- /pretraining/runs/lm30m/optimizer_comp/lm30m_sgd_grid.yaml: -------------------------------------------------------------------------------- 1 | name: lm30m_sgd_grid_1 2 | method: grid 3 | parameters: 4 | opt.batch_size: 5 | values: [1, 4, 16, 64, 256] # -> 5x 6 | opt.peak_lr: 7 | values: [0.050, 0.064, 0.081, 0.103, 0.132, 0.168, 0.214, 0.273, 0.348, 0.443, 0.564, 0.719, 0.916, 1.168, 1.488, 1.896, 2.416, 3.079, 3.924, 5.000] # -> 20x 8 | # total: 5*20*=100 9 | program: main.py 10 | command: 11 | - ${env} 12 | - ${interpreter} 13 | - ${program} 14 | - +model=lm11m 15 | - +dataset=fwedu_gpt2 16 | - opt.optimizer='sgd' 17 | - opt.max_microbatch_size=16 18 | - ${args_no_hyphens} 19 | -------------------------------------------------------------------------------- /finetuning/runs/gemma4bpt_math_adafactor.yaml: -------------------------------------------------------------------------------- 1 | name: gemma4bpt_math_5ep_adafactor_v2 2 | method: grid 3 | parameters: 4 | peak_lr: 5 | values: [3.2e-7, 1.0e-6, 3.2e-6, 1.0e-5, 3.2e-5, 1.0e-4, 3.2e-4, 1.0e-3] 6 | seed: 7 | values: [0, 1, 2, 3] 8 | program: finetune.py 9 | command: 10 | - ${env} 11 | - ${interpreter} 12 | - ${program} 13 | - --model_variant='gemma3-4b' 14 | - --param_dtype='bfloat16' 15 | - --stochastic_round=True 16 | - --optimizer_name='adafactor' 17 | - --batch_size=1 18 | - --n_epochs=5 19 | - --b2=0.997 20 | - --eval_batch_size=128 21 | - --wandb_mode='online' 22 | - ${args} 23 | -------------------------------------------------------------------------------- /pretraining/runs/lm30m/hparam_scaling/lm30m_adam_lr.yaml: -------------------------------------------------------------------------------- 1 | name: lm11m_adam_lr_seeds_1 2 | method: grid 3 | parameters: 4 | seed: 5 | values: [0, 1, 2, 3, 4] # -> 5x 6 | opt.batch_size: 7 | values: [1, 4, 16, 64, 256, 1024, 4096] # -> 7x 8 | +overwrite.opt.peak_lr: 9 | values: [0.00020, 0.00030, 0.00046, 0.00070, 0.00106, 0.00161, 0.00245, 0.00372, 0.00565, 0.00857, 0.01301, 0.01976, 0.03000] # -> 13x 10 | # total: 5*7*13=455 11 | program: main.py 12 | command: 13 | - ${env} 14 | - ${interpreter} 15 | - ${program} 16 | - +hparams=lm11m_fwedu_adam 17 | - opt.max_microbatch_size=16 18 | - ${args_no_hyphens} 19 | -------------------------------------------------------------------------------- /pretraining/runs/lm30m/hparam_scaling/lm30m_adam_b1.yaml: -------------------------------------------------------------------------------- 1 | name: lm11m_adam_b1_seeds_1 2 | method: grid 3 | parameters: 4 | seed: 5 | values: [0, 1, 2, 3, 4] # -> 5x 6 | 
opt.batch_size: 7 | values: [1, 4, 16, 64, 256, 1024, 4096] # -> 7x 8 | opt.b1: 9 | values: [0.00000, 0.50000, 0.68911, 0.81870, 0.89811, 0.94391, 0.96946, 0.98348, 0.99109, 0.99520, 0.99742, 0.99861, 0.99925] # -> 13x 10 | # total: 5*7*13=455 11 | program: main.py 12 | command: 13 | - ${env} 14 | - ${interpreter} 15 | - ${program} 16 | - +hparams=lm11m_fwedu_adam 17 | - +overwrite.opt.t1=null 18 | - opt.max_microbatch_size=16 19 | - ${args_no_hyphens} 20 | -------------------------------------------------------------------------------- /pretraining/runs/lm30m/hparam_scaling/lm30m_adam_b2.yaml: -------------------------------------------------------------------------------- 1 | name: lm11m_adam_b2_seeds_1 2 | method: grid 3 | parameters: 4 | seed: 5 | values: [0, 1, 2, 3, 4] # -> 5x 6 | opt.batch_size: 7 | values: [1, 4, 16, 64, 256, 1024, 4096] # -> 7x 8 | opt.b2: 9 | values: [0.00000, 0.50000, 0.68911, 0.81870, 0.89811, 0.94391, 0.96946, 0.98348, 0.99109, 0.99520, 0.99742, 0.99861, 0.99925] # -> 13x 10 | # total: 5*7*13=455 11 | program: main.py 12 | command: 13 | - ${env} 14 | - ${interpreter} 15 | - ${program} 16 | - +hparams=lm11m_fwedu_adam 17 | - +overwrite.opt.t2=null 18 | - opt.max_microbatch_size=16 19 | - ${args_no_hyphens} 20 | -------------------------------------------------------------------------------- /pretraining/runs/lm30m/hparam_scaling/lm30m_adam_t2.yaml: -------------------------------------------------------------------------------- 1 | name: lm11m_adam_t2_seeds_1 2 | method: grid 3 | parameters: 4 | seed: 5 | values: [0, 1, 2, 3, 4] # -> 5x 6 | opt.batch_size: 7 | values: [1, 4, 16, 64, 256, 1024, 4096] # -> 7x 8 | +overwrite.opt.t2: 9 | values: [100_000, 177_828, 316_228, 562_341, 1_000_000, 1_778_279, 3_162_278, 5_623_413, 10_000_000, 17_782_794, 31_622_777, 56_234_133, 100_000_000] # -> 13 10 | # total: 5*7*13=455 11 | program: main.py 12 | command: 13 | - ${env} 14 | - ${interpreter} 15 | - ${program} 16 | - +hparams=lm11m_fwedu_adam 17 | - opt.max_microbatch_size=16 18 | - ${args_no_hyphens} 19 | -------------------------------------------------------------------------------- /pretraining/runs/lm30m/optimizer_comp/lm30m_muon_grid.yaml: -------------------------------------------------------------------------------- 1 | name: lm30m_muon_grid_1 2 | method: grid 3 | parameters: 4 | opt.batch_size: 5 | values: [1, 4, 16, 64, 256, 1024, 4096] # -> 7x 6 | opt.muon_lr: 7 | values: [0.0001, 0.0002, 0.0005, 0.0010, 0.0022, 0.0046, 0.0100, 0.0215, 0.0464, 0.1000] # -> 10x 8 | opt.muon_t1: 9 | values: [1000, 2783, 7743, 21544, 59948, 166810, 464159, 1291550, 3593814, 10000000] # -> 10x 10 | # total: 7*10*10=700 11 | program: main.py 12 | command: 13 | - ${env} 14 | - ${interpreter} 15 | - ${program} 16 | - +hparams=lm11m_fwedu_adam 17 | - opt.optimizer='muon' 18 | - opt.max_microbatch_size=16 19 | - ${args_no_hyphens} 20 | -------------------------------------------------------------------------------- /pretraining/runs/lm30m/lm30m_adam_sensitivity.yaml: -------------------------------------------------------------------------------- 1 | name: lm11m_adam_sensitivity_seeds_1 2 | method: grid 3 | parameters: 4 | seed: 5 | values: [0, 1, 2, 3, 4, 5, 6] # -> 7x 6 | opt.batch_size: 7 | values: [1, 4, 16, 64, 256, 1024, 4096] # -> 7x 8 | +scaling_1d.key: 9 | values: ['opt.peak_lr', 'opt.t1', 'opt.t2'] # -> 3x 10 | +scaling_1d.value: 11 | values: [0.125, 0.177, 0.250, 0.354, 0.500, 0.707, 1.000, 1.414, 2.000, 2.828, 4.000, 5.657, 8.000] # -> 13 12 | # total: 7*7*3*13=1911 13 | 
program: main.py 14 | command: 15 | - ${env} 16 | - ${interpreter} 17 | - ${program} 18 | - +hparams=lm11m_fwedu_adam 19 | - opt.max_microbatch_size=16 20 | - ${args_no_hyphens} 21 | -------------------------------------------------------------------------------- /pretraining/runs/lm30m/optimizer_comp/lm30m_adafactor_grid.yaml: -------------------------------------------------------------------------------- 1 | name: lm30m_adafactor_grid_1 2 | method: grid 3 | parameters: 4 | opt.batch_size: 5 | values: [1, 4, 16, 64, 256, 1024, 4096] # -> 7x 6 | opt.peak_lr_scaled: 7 | values: [0.00010, 0.00035, 0.00120, 0.00416, 0.01443, 0.05000] # -> 6x 8 | opt.t2: 9 | values: [1_000_000, 3_160_000, 10_000_000, 31_600_000, 100_000_000] # -> 5x 10 | # total: 7*6*5=210 11 | program: main.py 12 | command: 13 | - ${env} 14 | - ${interpreter} 15 | - ${program} 16 | - +model=lm11m 17 | - +dataset=fwedu_gpt2 18 | - opt.optimizer='adafactor' 19 | - opt.max_microbatch_size=16 20 | - opt.peak_lr_scaling='${pow:${opt.batch_size},0.3}' 21 | - ${args_no_hyphens} 22 | -------------------------------------------------------------------------------- /pretraining/runs/lm19m/fig10_fixed_b2.yaml: -------------------------------------------------------------------------------- 1 | name: fig10_fixed_b2_v3 2 | method: grid 3 | parameters: 4 | seed: 5 | values: [0, 1, 2, 3, 4] # -> 5x 6 | opt.batch_size: 7 | values: [16, 32, 64, 128, 256, 512, 1024] # -> 7x 8 | opt.peak_lr_scaled: 9 | values: [1.6e-6, 3.1e-6, 6.2e-6, 1.25e-5, 2.5e-5, 5.0e-5, 1.0e-4, 2.0e-4, 4.0e-4, 8.0e-4, 1.6e-3, 3.1e-3, 6.2e-3, 1.25e-2, 2.5e-2, 5.0e-2, 1.0e-1, 2.0e-1] # -> 18x 10 | # total: 5*7*18=630 11 | program: main.py 12 | command: 13 | - ${env} 14 | - ${interpreter} 15 | - ${program} 16 | - +model=lm19m 17 | - +dataset=c4_t5all 18 | - opt.optimizer='adamw' 19 | - opt.b1=0.9 20 | - opt.b2=0.95 21 | - opt.weight_decay=0.1 22 | - opt.max_microbatch_size=16 23 | - opt.peak_lr_scaling='${pow:${opt.batch_size},0.25}' 24 | - ${args_no_hyphens} 25 | -------------------------------------------------------------------------------- /pretraining/runs/lm19m/fig10_fixed_t2.yaml: -------------------------------------------------------------------------------- 1 | name: fig10_fixed_t2_v3 2 | method: grid 3 | parameters: 4 | seed: 5 | values: [0, 1, 2, 3, 4] # -> 5x 6 | opt.batch_size: 7 | values: [16, 32, 64, 128, 256, 512, 1024] # -> 7x 8 | opt.peak_lr_scaled: 9 | values: [1.6e-6, 3.1e-6, 6.2e-6, 1.25e-5, 2.5e-5, 5.0e-5, 1.0e-4, 2.0e-4, 4.0e-4, 8.0e-4, 1.6e-3, 3.1e-3, 6.2e-3, 1.25e-2, 2.5e-2, 5.0e-2, 1.0e-1, 2.0e-1] # -> 18x 10 | # total: 5*7*18=630 11 | program: main.py 12 | command: 13 | - ${env} 14 | - ${interpreter} 15 | - ${program} 16 | - +model=lm19m 17 | - +dataset=c4_t5all 18 | - opt.optimizer='adamw' 19 | - opt.b1=0.9 20 | - opt.t2=10_000_000 21 | - opt.weight_decay=0.1 22 | - opt.max_microbatch_size=16 23 | - opt.peak_lr_scaling='${pow:${opt.batch_size},0.25}' 24 | - ${args_no_hyphens} 25 | -------------------------------------------------------------------------------- /pretraining/runs/lm30m/optimizer_comp/lm30m_adam_grid.yaml: -------------------------------------------------------------------------------- 1 | name: lm30m_adam_grid_1 2 | method: grid 3 | parameters: 4 | opt.batch_size: 5 | values: [1, 4, 16, 64, 256, 1024, 4096] # -> 7x 6 | opt.peak_lr_scaled: 7 | values: [0.00010, 0.00035, 0.00120, 0.00416, 0.01443, 0.05000] # -> 6x 8 | opt.b1: 9 | values: [0.00000, 0.50000, 0.74655, 0.88403, 0.94935, 0.97832, 0.99080, 0.99611, 0.99836, 0.99931] # -> 
10x 10 | opt.t2: 11 | values: [1_000_000, 3_160_000, 10_000_000, 31_600_000, 100_000_000] # -> 5x 12 | # total: 7*6*10*5=2100 13 | program: main.py 14 | command: 15 | - ${env} 16 | - ${interpreter} 17 | - ${program} 18 | - +model=lm11m 19 | - +dataset=fwedu_gpt2 20 | - opt.optimizer='adamw' 21 | - opt.max_microbatch_size=16 22 | - opt.peak_lr_scaling='${pow:${opt.batch_size},0.3}' 23 | - ${args_no_hyphens} 24 | -------------------------------------------------------------------------------- /pretraining/runs/gpt2s/gpt2s_adam_2d.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2s_adam_2d_1 2 | method: grid 3 | parameters: 4 | opt.batch_size: 5 | values: [1, 512] # -> 2x 6 | opt.b1: 7 | values: [0.00000, 0.50000, 0.68913, 0.81873, 0.89814, 0.94393, 0.96948, 0.98349, 0.99110, 0.99521, 0.99742, 0.99861] # -> 12x 8 | +scaling.opt.peak_lr: 9 | values: [0.125, 0.177, 0.250, 0.354, 0.500, 0.707, 1.000, 1.414, 2.000, 2.828, 4.000, 5.657, 8.000] # -> 13 10 | # total: 2*12*13=312 11 | program: main.py 12 | command: 13 | - ${env} 14 | - ${interpreter} 15 | - ${program} 16 | - +model=gpt2s 17 | - +dataset=fw_gpt2 18 | - opt.optimizer='adamw' 19 | - opt.t2=10_000_000 20 | - opt.max_microbatch_size=16 21 | - +bs_configs.bs1.opt.peak_lr=0.000488 # 2^(-11) 22 | - +bs_configs.bs512.opt.peak_lr=0.0039 # 2^(-8) 23 | - +bs_configs.bs512.opt.weight_decay=0.1 24 | - ${args_no_hyphens} 25 | -------------------------------------------------------------------------------- /pretraining/README.md: -------------------------------------------------------------------------------- 1 | # Pretraining 2 | 3 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/martin-marek/batch-size/blob/main/pretraining/train.ipynb) 4 | 5 | ## Configs 6 | 7 | We used Hydra for config management. All model and dataset configs can be found in the [configs](configs) directory. 8 | 9 | ## Manual training 10 | 11 | Here's an example of setting up a manual training run: 12 | ```bash 13 | python main.py +model=gpt3xl +dataset=fw_gpt2 14 | ``` 15 | To quickly get started, we provide a [Colab Notebook](https://colab.research.google.com/github/martin-marek/batch-size/blob/main/pretraining/train.ipynb) to train a model from scratch. 16 | 17 | ## Sweeps 18 | 19 | We performed almost all of our experiments using Weights & Biases sweeps. The config files for each sweep can be found in the [runs](runs) directory. 
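For example, a sweep can be registered and launched from Python using the W&B API (equivalent to running `wandb sweep` and `wandb agent` on the command line). This is only a minimal sketch; it assumes it is run from the `pretraining` directory and uses the default `wandb_project` name from [configs/base.yaml](configs/base.yaml):

```python
import yaml
import wandb

# register the sweep defined by one of the run configs
with open('runs/gpt2s/optimizer_comp/gpt2s_sgd.yaml') as f:
    sweep_config = yaml.safe_load(f)
sweep_id = wandb.sweep(sweep_config, project='picodo-bs')

# the agent pulls hyperparameter combinations from the grid and launches main.py for each one
wandb.agent(sweep_id, project='picodo-bs')
```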
-------------------------------------------------------------------------------- /pretraining/configs/hparams/lm11m_fwedu_adam.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /base 5 | - /model/lm11m 6 | - /dataset/fwedu_gpt2 7 | - _self_ 8 | 9 | bs_configs: 10 | bs1: 11 | opt: 12 | peak_lr: 0.0017 13 | t1: 40_000 14 | t2: 10_000_000 15 | bs4: 16 | opt: 17 | peak_lr: 0.002 18 | t1: 40_000 19 | t2: 10_000_000 20 | bs16: 21 | opt: 22 | peak_lr: 0.002 23 | t1: 60_000 24 | t2: 10_000_000 25 | bs64: 26 | opt: 27 | peak_lr: 0.003 28 | t1: 200_000 29 | t2: 10_000_000 30 | bs256: 31 | opt: 32 | peak_lr: 0.005 33 | t1: 1_000_000 34 | t2: 10_000_000 35 | bs1024: 36 | opt: 37 | peak_lr: 0.005 38 | t1: 3_000_000 39 | t2: 10_000_000 40 | bs4096: 41 | opt: 42 | peak_lr: 0.0025 43 | t1: 6_000_000 44 | t2: 10_000_000 45 | 46 | opt: 47 | optimizer: 'adamw' 48 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Martin Marek 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /pretraining/runs/gpt2s/optimizer_comp/final_configs.sh: -------------------------------------------------------------------------------- 1 | 2 | # bs512, adamw 3 | python main.py \ 4 | +model=gpt2s \ 5 | +dataset=fw_gpt2 \ 6 | opt.optimizer='adamw' \ 7 | opt.batch_size=512 \ 8 | opt.max_microbatch_size=16 \ 9 | model.remat=True \ 10 | opt.peak_lr=0.0048 \ 11 | opt.b1=0.9 \ 12 | opt.b2=0.95 \ 13 | opt.weight_decay=0.1 14 | 15 | # bs1 adam, fixed b2 16 | python main.py \ 17 | +model=gpt2s \ 18 | +dataset=fw_gpt2 \ 19 | opt.optimizer='adamw' \ 20 | opt.batch_size=1 \ 21 | opt.peak_lr=0.00015 \ 22 | opt.b1=0.9 \ 23 | opt.b2=0.95 24 | 25 | # bs1 adam, fixed t2 26 | python main.py \ 27 | +model=gpt2s \ 28 | +dataset=fw_gpt2 \ 29 | opt.optimizer='adamw' \ 30 | opt.batch_size=1 \ 31 | opt.peak_lr=0.0024 \ 32 | opt.b1=0.9 \ 33 | opt.b2=0.9999 34 | 35 | # bs1 sgd 36 | python main.py \ 37 | +model=gpt2s \ 38 | +dataset=fw_gpt2 \ 39 | opt.optimizer='sgd' \ 40 | opt.batch_size=1 \ 41 | opt.peak_lr=0.2 42 | 43 | # bs1 adafactor 44 | python main.py \ 45 | +model=gpt2s \ 46 | +dataset=fw_gpt2 \ 47 | opt.optimizer='adafactor' \ 48 | opt.batch_size=1 \ 49 | opt.peak_lr=0.005 \ 50 | opt.b2=0.9999 51 | -------------------------------------------------------------------------------- /finetuning/README.md: -------------------------------------------------------------------------------- 1 | # Fine-tuning 2 | 3 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/martin-marek/batch-size/blob/main/finetuning/finetune.ipynb) 4 | 5 | For our fine-tuning experiments, we implemented Gemma 3 mostly from scratch, including sampling. Our code is loosely based on the [Gemma NNX example](https://github.com/google/flax/tree/main/examples/gemma). 6 | 7 | ## Manual training 8 | 9 | For running experiments, we recommend interfacing with Python, although Bash is also supported thanks to [Fire](https://github.com/google/python-fire). 10 | 11 | ```python 12 | # Python 13 | from finetune import finetune 14 | finetune(model_variant='gemma3-1b') 15 | ``` 16 | 17 | ```bash 18 | # Bash 19 | python finetune.py --model_variant='gemma3-1b' 20 | ``` 21 | 22 | To quickly get started, we provide a [Colab Notebook](https://colab.research.google.com/github/martin-marek/batch-size/blob/main/finetuning/finetune.ipynb) to fine-tune Gemma 3 (12B) using a TPU v6e-1 with just 32 GB of memory. 23 | 24 | ## Sweeps 25 | 26 | We performed our main experiment using Weights & Biases sweeps. The config files for each sweep can be found in the [runs](runs) directory. 
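Each sweep simply forwards its flags to `finetune.py`, so any single configuration from a sweep can also be reproduced as a manual run. For example, a small-batch full fine-tuning run with Adafactor, mirroring [runs/gemma4bpt_math_adafactor.yaml](runs/gemma4bpt_math_adafactor.yaml), might look like the sketch below (the learning rate is just one value from that sweep's grid):

```python
from finetune import finetune

# full fine-tuning with batch size 1 and a memory-efficient optimizer,
# following the settings in runs/gemma4bpt_math_adafactor.yaml
finetune(
    model_variant='gemma3-4b',
    param_dtype='bfloat16',
    stochastic_round=True,
    optimizer_name='adafactor',
    batch_size=1,
    n_epochs=5,
    b2=0.997,
    peak_lr=1.0e-5,  # one value from the sweep's learning-rate grid
    eval_batch_size=128,
    seed=0,
)
```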
27 | -------------------------------------------------------------------------------- /finetuning/rope.py: -------------------------------------------------------------------------------- 1 | """https://github.com/google/flax/blob/main/examples/gemma/positional_embeddings.py""" 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | 7 | def apply_rope( 8 | inputs: jax.Array, # [B, T, N, H] 9 | positions: jax.Array, # [B, T] 10 | max_wavelength: int = 10_000, 11 | scale_factor: float = 1.0, 12 | ) -> jax.Array: 13 | """Applies RoPE.""" 14 | B, T, N, H = inputs.shape 15 | 16 | fraction = 2 * jnp.arange(0, H // 2) / H # [H/2] 17 | timescale = max_wavelength**fraction # [H/2] 18 | 19 | sinusoid_inp = (positions[..., None] / timescale[None, None, :]) # [B, T, H/2] 20 | sinusoid_inp = sinusoid_inp[..., None, :] # [B, T, 1, H/2] 21 | sinusoid_inp /= scale_factor # [B, T, 1, H/2] 22 | 23 | sin = jnp.sin(sinusoid_inp) # [B, T, 1, H/2] 24 | cos = jnp.cos(sinusoid_inp) # [B, T, 1, H/2] 25 | 26 | first_half, second_half = jnp.split(inputs, 2, axis=-1) # [B, T, N, H/2] 27 | first_part = first_half * cos - second_half * sin # [B, T, N, H/2] 28 | second_part = second_half * cos + first_half * sin # [B, T, N, H/2] 29 | out = jnp.concatenate([first_part, second_part], axis=-1) # [B, T, N, H] 30 | return out.astype(inputs.dtype) # [B, T, N, H] 31 | -------------------------------------------------------------------------------- /finetuning/utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | 5 | @jax.jit 6 | def to_bf16_stochastic(key, source): 7 | """ 8 | performs (float32 -> bfloat16) stochastic rounding 9 | based on https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905 10 | """ 11 | # ensure the source array is float32, the bitwise logic depends on it 12 | source = source.astype(jnp.float32) 13 | 14 | # reinterpert float32 source as uint32 to allow bitwise operations 15 | source_uint32 = jax.lax.bitcast_convert_type(source, jnp.uint32) 16 | 17 | # randomly flip lower 16 bits of the float32 source 18 | # these are the bits that get truncated when converting to bf16 19 | random_int = jax.random.randint( 20 | key, 21 | shape=source.shape, 22 | minval=0, 23 | maxval=(1 << 16), 24 | dtype=jnp.uint32 25 | ) 26 | result_uint32 = source_uint32 + random_int 27 | 28 | # mask off lower 16 bits, keep top 16 bits (corresponding to bf16 format) 29 | mask = jnp.uint32(0xFFFF0000) 30 | result_uint32 = jax.lax.bitwise_and(result_uint32, mask) 31 | 32 | # cast result to bf16 33 | result_fp32 = jax.lax.bitcast_convert_type(result_uint32, jnp.float32) 34 | result_bf16 = result_fp32.astype(jnp.bfloat16) 35 | 36 | return result_bf16 37 | -------------------------------------------------------------------------------- /pretraining/rope.py: -------------------------------------------------------------------------------- 1 | """https://github.com/google/flax/blob/main/examples/gemma/positional_embeddings.py""" 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | 7 | def apply_rope( 8 | inputs: jax.Array, # [B, T, N, H] 9 | positions: jax.Array, # [B, T] 10 | max_wavelength: int = 10_000, 11 | scale_factor: float = 1.0, 12 | ) -> jax.Array: 13 | """Applies RoPE.""" 14 | B, T, N, H = inputs.shape 15 | if scale_factor < 1.0: 16 | raise ValueError(f'scale_factor must be >= 1.0, got {scale_factor}') 17 | 18 | fraction = 2 * jnp.arange(0, H // 2) / H # [H/2] 19 | timescale = max_wavelength**fraction # [H/2] 20 | 21 | 
sinusoid_inp = (positions[:, :, None] / timescale[None, None, :]) # [B, T, H/2] 22 | sinusoid_inp = sinusoid_inp[:, :, None, :] # [B, T, 1, H/2] 23 | sinusoid_inp /= scale_factor # [B, T, 1, H/2] 24 | 25 | sin = jnp.sin(sinusoid_inp) # [B, T, 1, H/2] 26 | cos = jnp.cos(sinusoid_inp) # [B, T, 1, H/2] 27 | 28 | first_half, second_half = jnp.split(inputs, 2, axis=-1) # [B, T, N, H/2] 29 | first_part = first_half * cos - second_half * sin # [B, T, N, H/2] 30 | second_part = second_half * cos + first_half * sin # [B, T, N, H/2] 31 | out = jnp.concatenate([first_part, second_part], axis=-1) # [B, T, N, H] 32 | return out.astype(inputs.dtype) # [B, T, N, H] 33 | -------------------------------------------------------------------------------- /pretraining/main.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from train import train_and_evaluate 3 | from configs import resolver_setup 4 | from omegaconf import OmegaConf, DictConfig 5 | from utils import flatten_dict 6 | 7 | 8 | # load default config 9 | @hydra.main(version_base=None, config_path='configs', config_name='base') 10 | def main(c: DictConfig): 11 | 12 | # optionally load batch size config 13 | if 'bs_configs' in c: 14 | bs_config = c.bs_configs[f'bs{c.opt.batch_size}'] 15 | c = OmegaConf.merge(c, bs_config) 16 | del c.bs_configs 17 | 18 | # optionally overwrite any values 19 | if 'overwrite' in c: 20 | c = OmegaConf.merge(c, c.overwrite) 21 | del c.overwrite 22 | 23 | # optionally translate 1d scaling to generalized scaling config 24 | if 'scaling_1d' in c: 25 | OmegaConf.update(c, f'scaling.{c.scaling_1d.key}', c.scaling_1d.value, force_add=True) 26 | del c.scaling_1d 27 | 28 | # optionally apply generalized scaling config 29 | if 'scaling' in c: 30 | for k, scaling in flatten_dict(c.scaling).items(): 31 | orig_val = OmegaConf.select(c, k) 32 | OmegaConf.update(c, k, scaling * orig_val) 33 | del c.scaling 34 | 35 | # run training job 36 | OmegaConf.resolve(c) 37 | print(OmegaConf.to_yaml(c)) 38 | train_and_evaluate(c) 39 | 40 | 41 | if __name__ == '__main__': 42 | main() 43 | -------------------------------------------------------------------------------- /pretraining/runs/gpt3xl/gpt3xl_optimizer_comp.sh: -------------------------------------------------------------------------------- 1 | 2 | # oai baseline 3 | python main.py \ 4 | +model=gpt3xl \ 5 | +dataset=fw_gpt2 \ 6 | opt.optimizer='adamw' \ 7 | opt.batch_size=512 \ 8 | opt.max_microbatch_size=16 \ 9 | model.remat=True \ 10 | opt.peak_lr=0.0002 \ 11 | opt.b1=0.9 \ 12 | opt.b2=0.95 \ 13 | opt.weight_decay=0.1 \ 14 | run_name='gpt3xl_oai_2' 15 | 16 | 17 | # bs1 adam, fixed b2 18 | python main.py \ 19 | +model=gpt3xl \ 20 | +dataset=fw_gpt2 \ 21 | opt.optimizer='adamw' \ 22 | opt.batch_size=1 \ 23 | num_tp_devices=4 \ 24 | opt.peak_lr=0.000067 \ 25 | opt.b1=0.9 \ 26 | opt.b2=0.95 \ 27 | run_name='gpt3xl_adam_b2_2' 28 | 29 | # bs1 adam, fixed t2 30 | python main.py \ 31 | +model=gpt3xl \ 32 | +dataset=fw_gpt2 \ 33 | opt.optimizer='adamw' \ 34 | opt.batch_size=1 \ 35 | num_tp_devices=4 \ 36 | opt.peak_lr=0.000067 \ 37 | opt.b1=0.9 \ 38 | opt.b2=0.9999 \ 39 | run_name='gpt3xl_adam_t2_2' 40 | 41 | # bs1 sgd 42 | python main.py \ 43 | +model=gpt3xl \ 44 | +dataset=fw_gpt2 \ 45 | opt.optimizer='sgd' \ 46 | opt.batch_size=1 \ 47 | num_tp_devices=4 \ 48 | opt.peak_lr=0.15 \ 49 | run_name='gpt3xl_sgd_2' 50 | 51 | # bs1 adafactor 52 | python main.py \ 53 | +model=gpt3xl \ 54 | +dataset=fw_gpt2 \ 55 | opt.optimizer='adafactor' \ 56 | 
opt.batch_size=1 \ 57 | num_tp_devices=4 \ 58 | opt.peak_lr=0.0032 \ 59 | opt.b2=0.9999 \ 60 | run_name='gpt3xl_adafactor_2_0032' 61 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import pandas as pd 3 | import numpy as np 4 | wandb_api = wandb.Api() 5 | 6 | 7 | def load_sweeps(sweep_names, entity='martin-nyu', project='picodo-bs'): 8 | sweeps = wandb_api.project(project, entity).sweeps() 9 | sweeps = [sweep for sweep in sweeps if sweep.name in sweep_names] 10 | print(f'{len(sweeps)=}') 11 | runs = [] 12 | for sweep in sweeps: 13 | sweep.runs.per_page = len(sweep.runs) # required to load all runs: https://github.com/wandb/wandb/issues/7666 14 | for run in sweep.runs: 15 | run_data = {'id': run.id} | run.config | dict(run.summary) 16 | runs += [run_data] 17 | df = pd.DataFrame(runs) 18 | return df 19 | 20 | 21 | def halflife_to_decay(t_token, n_batch=1): 22 | """ 23 | notation: 24 | - t_token: halflife measured in number of tokens 25 | - t_steps: halflife measured in number of steps 26 | - n_batch: number of tokens per batch 27 | - d: decay coefficient 28 | """ 29 | t_steps = t_token / n_batch # halflife (measured in number of steps) 30 | d = (1/2)**(1/t_steps) 31 | return d 32 | 33 | 34 | def decay_to_halflife(d, n_batch=1): 35 | """ 36 | notation: 37 | - t_token: halflife measured in number of tokens 38 | - t_steps: halflife measured in number of steps 39 | - n_batch: number of tokens per batch 40 | - d: decay coefficient 41 | """ 42 | # note: d**t_steps = 1/2 43 | t_steps = np.log(1/2) / np.log(d) 44 | t_token = t_steps * n_batch 45 | return t_token 46 | -------------------------------------------------------------------------------- /pretraining/configs/base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - override hydra/hydra_logging: disabled 4 | - override hydra/job_logging: disabled 5 | 6 | seed: 0 7 | ds_path: null 8 | tokens_params_ratio: 20 # chinchilla scaling 9 | num_tokens_train: null 10 | log_every_tokens: 1_000_000 11 | num_tokens_valid: 20_000_000 12 | pad_eval: false 13 | wandb_project: 'picodo-bs' 14 | wandb_mode: 'disabled' 15 | run_name: null 16 | num_tp_devices: 1 # optional tensor parallelism 17 | 18 | model: 19 | D: null # model/embed/qkv dim 20 | L: null # num. block layers 21 | H: 128 # head dimension 22 | F: ${mul:4, ${model.D}} # FF inner dimension = 4 x embed dim. 23 | N: ${floordiv:${model.D}, ${model.H}} # num. attention heads 24 | T: null # context/sequence length 25 | V: null # vocab size -> must match dataset tokenizer! 26 | param_dtype: 'float32' 27 | activ_dtype: 'float32' 28 | remat: false 29 | use_flash_attn: true 30 | 31 | opt: 32 | optimizer: null 33 | batch_size: null 34 | max_microbatch_size: .inf 35 | microbatch_size: ${min:${opt.batch_size}, ${opt.max_microbatch_size}} 36 | grad_acc_steps: ${floordiv:${opt.batch_size}, ${opt.microbatch_size}} 37 | peak_lr: null 38 | peak_lr_scaled: null 39 | peak_lr_scaling: null 40 | muon_lr: null 41 | warmup_frac: 0.05 42 | b1: null 43 | b2: null 44 | t1: null # units: num. of tokens 45 | t2: null # units: num. of tokens 46 | muon_b1: null 47 | muon_t1: null # units: num. 
of tokens 48 | b2_min: null 49 | weight_decay: 0 50 | clip_by_global_norm: null 51 | stochastic_round: false 52 | -------------------------------------------------------------------------------- /pretraining/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import jax 3 | import numpy as np 4 | import jax.numpy as jnp 5 | from jax.sharding import PartitionSpec as P 6 | 7 | 8 | def load_ds(key, mesh, ds_path, seq_len, batch_size, n_tokens_valid, n_tokens_train=None): 9 | 10 | # get dataset size 11 | print('getting dataset size...') 12 | ds_path = os.path.expanduser(ds_path) 13 | data = np.memmap(ds_path, dtype=np.uint16, mode='r') 14 | n_tokens_dataset = len(data) 15 | n_seq_dataset = n_tokens_dataset // seq_len 16 | 17 | # if n_tokens_train is None, use full dataset 18 | if n_tokens_train is not None: assert n_tokens_train + n_tokens_valid <= n_tokens_dataset 19 | if n_tokens_train is None: n_tokens_train = n_tokens_dataset - n_tokens_valid 20 | 21 | # get num. of train. and valid. batches 22 | n_batch_train = n_tokens_train // (batch_size * seq_len) 23 | n_batch_valid = n_tokens_valid // (batch_size * seq_len) 24 | n_batch = n_batch_train + n_batch_valid 25 | 26 | # memmap data 27 | print('reading data...') 28 | data = np.memmap(ds_path, dtype=np.uint16, shape=[n_batch, batch_size, seq_len], mode='r') 29 | 30 | # load data onto jax devices, sharded across batch dimension 31 | sharding = jax.sharding.NamedSharding(mesh, P(None, 'data', 'model')) 32 | callback = lambda index: data[index] 33 | data = jax.make_array_from_callback(data.shape, sharding, callback) 34 | 35 | # shuffle batches 36 | print('shuffling data...') 37 | data = jax.random.permutation(key, data, axis=0) 38 | 39 | # split data 40 | print('splitting data...') 41 | data_train = data[:n_batch_train] 42 | data_valid = data[n_batch_train:] 43 | 44 | return data_train, data_valid 45 | 46 | 47 | def pad_mask(batch, eos_token_id=1): 48 | B, L = batch.shape 49 | 50 | # get idx of last EOS token 51 | # if there is no EOS token, equals L-1 52 | idx_last_eos_token = (L - 1) - jnp.argmax(batch[:, ::-1] == eos_token_id, axis=1) 53 | 54 | # only use tokens before the last EOS token 55 | mask = jnp.arange(L)[None, :] <= idx_last_eos_token[:, None] 56 | 57 | return mask # [B, L] 58 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Small Batch Size Training for Language Models 2 | 3 | Official repository for the paper *[Small Batch Size Training for Language Models: When Vanilla SGD Works, and Why Gradient Accumulation Is Wasteful](https://arxiv.org/abs/2507.07101)* 4 | 5 | [![](https://img.shields.io/badge/arXiv-2507.07101-b31b1b.svg)](https://arxiv.org/abs/2507.07101) 6 | 7 | ## Key results 8 | 9 | We show that when a small batch size is used, vanilla SGD without momentum converges almost as fast as AdamW for LLM pretraining on a per-FLOP basis. In general, we find that as the batch size is reduced, the performance gap between different optimizers shrinks. 10 | 11 | 12 | 13 | Additionally, small batch sizes are much more robust to hyperparameter misspecification, meaning that when the tuning budget is limited, small batch sizes perform better in expectation. 14 | 15 | 16 | 17 | We hope that our results can be useful for memory-constrained practitioners, since small batch sizes allow the use of simple optimizers. 
For example, instead of using LoRA for fine-tuning, it might be preferable to do full fine-tuning with a small batch size and a memory-efficient optimizer like Adafactor, matching the performance of Adam while maintaining a similar memory footprint to LoRA. 18 | 19 | 20 | 21 | ## Code structure 22 | 23 | We implemented all of our experiments in JAX from scratch, using a mix of data, tensor, and sequence parallelism. We used two independent codebases for [pretraining](pretraining) and [fine-tuning](finetuning). Please refer to either codebase for more details on running experiments. 24 | 25 | All of our visualizations were done using Jupyter Notebooks found in the [utils](utils) directory. 26 | 27 | ## Citation 28 | 29 | ```bibtex 30 | @misc{smallbatch, 31 | title={Small Batch Size Training for Language Models: When Vanilla SGD Works, and Why Gradient Accumulation Is Wasteful}, 32 | author={Martin Marek and Sanae Lotfi and Aditya Somasundaram and Andrew Gordon Wilson and Micah Goldblum}, 33 | year={2025}, 34 | eprint={2507.07101}, 35 | archivePrefix={arXiv}, 36 | primaryClass={cs.LG} 37 | } 38 | ``` 39 | -------------------------------------------------------------------------------- /pretraining/download_fineweb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import fire 3 | import numpy as np 4 | from pathlib import Path 5 | from tqdm.auto import tqdm 6 | from huggingface_hub import hf_hub_download 7 | from huggingface_hub.utils import disable_progress_bars; disable_progress_bars() 8 | from typing import Optional, Literal 9 | 10 | 11 | def load_data_shard(file): 12 | # https://github.com/KellerJordan/modded-nanogpt/blob/a202a3a0ca99d69bb7f847e5337c7c6e0890fd92/train_gpt.py#L411 13 | header = np.fromfile(file, dtype=np.int32, count=256) # header is 256 int32 14 | assert header[0] == 20240520, "magic number mismatch in the data .bin file" 15 | assert header[1] == 1, "unsupported version" 16 | num_tokens = int(header[2]) # number of tokens (claimed) 17 | with Path(file).open("rb", buffering=0) as f: 18 | tokens = np.empty(num_tokens, dtype=np.uint16) # avoid pin_memory copy by @YouJiacheng 19 | f.seek(256 * 4) 20 | nbytes = f.readinto(tokens) # avoid bytes->array copy by @YouJiacheng 21 | assert nbytes == 2 * num_tokens, "number of tokens read does not match header" 22 | return tokens 23 | 24 | 25 | def download_dataset( 26 | dataset: Literal['fineweb', 'finewebedu'] = 'fineweb', 27 | num_chunks: Optional[int] = None, 28 | ): 29 | """download dataset, save it as a np.memmap binary file""" 30 | 31 | # get num. 
chunks 32 | # by default, download all chunnks (10B tokens) 33 | # each chunk is 100M tokens 34 | if num_chunks is None: 35 | if dataset == 'fineweb': num_chunks = 103 36 | if dataset == 'finewebedu': num_chunks = 99 37 | 38 | # load chunks into memory 39 | print('downloading...') 40 | shards = [] 41 | for i in tqdm(range(1, num_chunks+1)): 42 | shard_path = hf_hub_download(repo_id=f'kjj0/{dataset}10B-gpt2', filename=f'{dataset}_train_{i:06}.bin', repo_type="dataset") 43 | shards += [load_data_shard(shard_path)] 44 | 45 | # save to disk 46 | print('saving...') 47 | out_dir = os.path.expanduser('~/datasets') 48 | out_path = f'{out_dir}/{dataset}_gpt2.bin' 49 | os.makedirs(out_dir, exist_ok=True) 50 | n_tokens = sum(map(len, shards)) 51 | out = np.memmap(out_path, dtype=np.uint16, mode='w+', shape=[n_tokens]) 52 | i = 0 53 | for shard in tqdm(shards): 54 | out[i:i+len(shard)] = shard 55 | i += len(shard) 56 | out.flush() 57 | 58 | 59 | if __name__ == '__main__': 60 | fire.Fire(download_dataset) 61 | -------------------------------------------------------------------------------- /pretraining/utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from flax import nnx 4 | from collections.abc import Mapping 5 | 6 | 7 | def flatten_dict(d, prefix=None, sep='.'): 8 | if isinstance(d, Mapping): 9 | out = {} 10 | for k, v in d.items(): 11 | nested_prefix = k if prefix is None else f'{prefix}{sep}{k}' 12 | out |= flatten_dict(v, nested_prefix, sep) 13 | return out 14 | else: 15 | return {prefix: d} 16 | 17 | 18 | def get_num_model_params(model: nnx.Module): 19 | graphdef, params = nnx.split(model, nnx.Param) 20 | n_params = jax.tree.reduce(lambda x, y: x + jnp.size(y), params, 0) 21 | return n_params 22 | 23 | 24 | def halflife_to_decay(t_token, n_batch=1): 25 | """ 26 | notation: 27 | - t_token: halflife measured in number of tokens 28 | - t_steps: halflife measured in number of steps 29 | - n_batch: number of tokens per batch 30 | - d: decay coefficient 31 | """ 32 | t_steps = t_token / n_batch # halflife (measured in number of steps) 33 | d = (1/2)**(1/t_steps) 34 | return d 35 | 36 | 37 | def decay_to_halflife(d, n_batch=1): 38 | """ 39 | notation: 40 | - t_token: halflife measured in number of tokens 41 | - t_steps: halflife measured in number of steps 42 | - n_batch: number of tokens per batch 43 | - d: decay coefficient 44 | """ 45 | # note: d**t_steps = 1/2 46 | t_steps = jnp.log(1/2) / jnp.log(d) 47 | t_token = t_steps * n_batch 48 | return t_token 49 | 50 | 51 | @jax.jit 52 | def to_bf16_stochastic(key, source): 53 | """ 54 | performs (float32 -> bfloat16) stochastic rounding 55 | based on https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905 56 | """ 57 | # ensure the source array is float32, the bitwise logic depends on it 58 | source = source.astype(jnp.float32) 59 | 60 | # reinterpert float32 source as uint32 to allow bitwise operations 61 | source_uint32 = jax.lax.bitcast_convert_type(source, jnp.uint32) 62 | 63 | # randomly flip lower 16 bits of the float32 source 64 | # these are the bits that get truncated when converting to bf16 65 | random_int = jax.random.randint( 66 | key, 67 | shape=source.shape, 68 | minval=0, 69 | maxval=(1 << 16), 70 | dtype=jnp.uint32 71 | ) 72 | result_uint32 = source_uint32 + random_int 73 | 74 | # mask off lower 16 bits, keep top 16 bits (corresponding to bf16 format) 75 | mask = jnp.uint32(0xFFFF0000) 76 | result_uint32 = 
jax.lax.bitwise_and(result_uint32, mask) 77 | 78 | # cast result to bf16 79 | result_fp32 = jax.lax.bitcast_convert_type(result_uint32, jnp.float32) 80 | result_bf16 = result_fp32.astype(jnp.bfloat16) 81 | 82 | return result_bf16 83 | -------------------------------------------------------------------------------- /finetuning/optimizer.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import optax 4 | from optax import tree_utils as otu 5 | from flax import nnx 6 | import factorized, utils 7 | from typing import Optional 8 | 9 | 10 | class ModelAndOptimizer(nnx.Optimizer): 11 | """ 12 | Extends nnx.ModelAndOptimizer (v0.12.0) with stochastic rounding. 13 | """ 14 | def __init__(self, model, tx, wrt=nnx.Param, stochastic_round=False): 15 | super().__init__(model, tx, wrt=wrt) 16 | self.model = model 17 | self.stochastic_round = stochastic_round # <- CHANGED: added stochastic_round support 18 | 19 | def update(self, key, grads, **kwargs): 20 | param_arrays = nnx.to_arrays(nnx.pure(nnx.state(self.model, self.wrt))) 21 | grad_arrays = nnx.to_arrays(nnx.pure(nnx.state(grads))) 22 | opt_state_arrays = nnx.to_arrays(nnx.pure(self.opt_state)) 23 | kwargs_arrays = nnx.to_arrays(nnx.pure(kwargs)) 24 | 25 | updates, new_opt_state = self.tx.update(grad_arrays, opt_state_arrays, param_arrays, **kwargs_arrays) 26 | new_params = apply_updates(key, param_arrays, updates, self.stochastic_round) # <- CHANGED: added stochastic_round support 27 | 28 | nnx.update(self.model, new_params) 29 | nnx.update(self.opt_state, nnx.state(new_opt_state)) 30 | self.step[...] += 1 31 | 32 | 33 | def apply_updates( 34 | key: jax.Array, 35 | params: optax.Params, 36 | updates: optax.Updates, 37 | stochastic_round = False 38 | ) -> optax.Params: 39 | """Extends optax.apply_updates with stochastic rounding.""" 40 | keys = otu.tree_split_key_like(key, params) 41 | def leaf_update(p, u, key): 42 | if p is None: return None 43 | param_dtype = jnp.asarray(p).dtype 44 | if stochastic_round: 45 | p = p.astype(jnp.float32) + u 46 | p = utils.to_bf16_stochastic(key, p) 47 | else: 48 | p += u 49 | return p.astype(param_dtype) 50 | return jax.tree.map(leaf_update, params, updates, keys, is_leaf=lambda x: x is None) 51 | 52 | 53 | def adafactor( 54 | learning_rate: optax.ScalarOrSchedule, 55 | decay_rate: float = 0.8, 56 | clipping_threshold: Optional[float] = 1.0, 57 | min_dim_size_to_factor: int = 128, 58 | ) -> optax.GradientTransformation: 59 | """ 60 | Adafactor reimplemented to use float32 state, regardless of param dtype. 
61 | https://github.com/google-deepmind/optax/blob/8973bb3c77b07850737246815f1c028b53fffbe0/optax/_src/alias.py#L225#L327 62 | """ 63 | return optax.chain( 64 | factorized.scale_by_factored_rms(decay_rate=decay_rate, min_dim_size_to_factor=min_dim_size_to_factor), 65 | optax.clip_by_block_rms(clipping_threshold) if clipping_threshold is not None else optax.identity(), 66 | optax.scale_by_learning_rate(learning_rate), 67 | optax.scale_by_param_block_rms(), 68 | ) 69 | -------------------------------------------------------------------------------- /finetuning/sampler.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import flax 4 | from flax import nnx 5 | from functools import partial 6 | 7 | 8 | @flax.struct.dataclass 9 | class SamplingState: 10 | key: jax.Array 11 | step: jnp.int32 12 | tokens: jnp.ndarray # [B, T] 13 | kv_cache: dict 14 | done: jnp.ndarray # [B] 15 | 16 | 17 | def _sample_top_p(key, probs, p=0.95): 18 | """Sample a token using top-p sampling. 19 | https://github.com/google/flax/blob/cca78723892c539b42c261d2625168d39b61c495/examples/gemma/sampler.py#L38""" 20 | probs_sorted, indices = jax.lax.top_k(probs, k=probs.shape[-1]) 21 | cumsum_probs = jnp.cumsum(probs_sorted, axis=-1) 22 | mask = cumsum_probs - probs_sorted > p 23 | probs_sorted = jnp.where(mask, 0.0, probs_sorted) 24 | probs_sorted /= jnp.sum(probs_sorted, axis=-1, keepdims=True) 25 | next_token = jax.random.categorical(key, logits=jnp.log(probs_sorted)) 26 | next_token = jnp.take_along_axis(indices, next_token[..., None], axis=-1) 27 | next_token = jnp.squeeze(next_token, axis=-1) 28 | return next_token 29 | 30 | 31 | def _sample_step(state, model_graphdef, model_state, pad_id, eos_id, temperature=1): 32 | model = nnx.merge(model_graphdef, model_state) 33 | 34 | # sample next token 35 | key, key_sampling = jax.random.split(state.key) 36 | input_token = state.tokens[:, state.step, None] # [B, 1] 37 | logits, kv_cache = model(input_token, state.kv_cache) # [B, 1, V] 38 | if temperature == 0: 39 | sampled_token = logits[:, 0, :].argmax(1) # [B] 40 | else: 41 | probs = jax.nn.softmax(logits[:, 0, :] / temperature, axis=-1) # [B, V] 42 | sampled_token = _sample_top_p(key_sampling, probs) 43 | 44 | # update buffer 45 | next_token = state.tokens[:, state.step+1] 46 | update_token = jnp.where((~state.done) & (next_token==pad_id), sampled_token, next_token) 47 | tokens = state.tokens.at[:, state.step+1].set(update_token) 48 | 49 | # check if sampling is done 50 | done = state.done | ((next_token==pad_id) & (sampled_token==eos_id)) 51 | 52 | return SamplingState(key, state.step+1, tokens, kv_cache, done) 53 | 54 | 55 | @partial(jax.jit, static_argnames=('model_graphdef', 'temperature')) 56 | def sample(key, model_graphdef, model_state, tokens, temperature=1, pad_id=0, eos_id=1): 57 | model = nnx.merge(model_graphdef, model_state) 58 | B, T = tokens.shape 59 | 60 | # initialize state 61 | state = SamplingState( 62 | key=key, 63 | step=0, 64 | tokens=tokens, 65 | kv_cache=model.init_kv_cache(B, T), 66 | done=jnp.zeros([B], dtype=jnp.bool_), 67 | ) 68 | 69 | # sample next token inside a while loop 70 | step_fn = lambda state: _sample_step(state, *nnx.split(model), pad_id, eos_id, temperature) 71 | cond_fn = lambda state: (state.step < T) & jnp.any(~state.done) 72 | state = jax.lax.while_loop(cond_fn, step_fn, state) 73 | 74 | return state.tokens 75 | -------------------------------------------------------------------------------- 
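Note on `finetuning/sampler.py` above: a minimal usage sketch of `sample()`, assuming `model` is the Gemma NNX module built by `gemma.py` and `tokens` is a `[B, T]` int32 buffer holding prompt ids followed by `pad_id` entries, as constructed for evaluation in `data.py`:

```python
import jax
from flax import nnx
from sampler import sample

# split the NNX module into a static graphdef and its state,
# matching sample()'s (model_graphdef, model_state) arguments
graphdef, state = nnx.split(model)

key = jax.random.key(0)
# temperature=0 decodes greedily; temperature>0 uses top-p sampling
completions = sample(key, graphdef, state, tokens, temperature=0)
```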
/pretraining/train.py: -------------------------------------------------------------------------------- 1 | import math 2 | import jax 3 | import jax.numpy as jnp 4 | import optax 5 | import wandb 6 | from functools import partial 7 | from flax import nnx 8 | from optax import tree_utils as otu 9 | from tqdm.auto import tqdm 10 | from omegaconf.dictconfig import DictConfig 11 | import data, utils 12 | import model as model_lib 13 | import optimizer as optimizer_lib 14 | 15 | 16 | @partial(jax.jit, static_argnames=('model_graphdef', 'pad')) 17 | def loss_fn(model_state, model_graphdef, x, pad=False): # [B, T] 18 | model = nnx.merge(model_graphdef, model_state) 19 | y = jnp.roll(x, -1, axis=1) 20 | loss_mask = data.pad_mask(x) if pad else jnp.ones(x.shape, dtype=bool) 21 | loss_mask = loss_mask.at[:, -1].set(False) 22 | logits = model(x) # [B, T, V] 23 | losses = optax.softmax_cross_entropy_with_integer_labels(logits.astype(jnp.float32), y) # [B, T] 24 | return (losses * loss_mask).sum() / loss_mask.sum() 25 | 26 | 27 | @partial(jax.jit, static_argnames=('opt_graphdef', 'model_graphdef'), donate_argnames=('opt_state')) 28 | def train_step(key, opt_state, opt_graphdef, model_graphdef, batch): 29 | key, key_opt = jax.random.split(key) 30 | 31 | # compute grads from a single micro-batch 32 | if batch.ndim == 2: 33 | loss, grads = jax.value_and_grad(loss_fn)(opt_state.model, model_graphdef, batch) 34 | 35 | # compute grads from multiple micro-batches (using gradient accumulation) 36 | if batch.ndim == 3: 37 | loss = 0 38 | grads = otu.tree_zeros_like(opt_state.model, dtype=jnp.float32) 39 | def step_fn(i , args): 40 | loss, grads = args 41 | batch_loss, batch_grads = jax.value_and_grad(loss_fn)(opt_state.model, model_graphdef, batch[i]) 42 | loss = (i*loss + batch_loss) / (i+1) 43 | grads = jax.tree.map(lambda m, g: (i*m + g) / (i+1), grads, batch_grads) 44 | return loss, grads 45 | loss, grads = jax.lax.fori_loop(0, len(batch), step_fn, (loss, grads)) 46 | 47 | # optimizer step 48 | optimizer = nnx.merge(opt_graphdef, opt_state) 49 | optimizer.update(key_opt, grads) 50 | opt_state = nnx.state(optimizer) 51 | return key, opt_state, loss 52 | 53 | 54 | def eval_step(model_state, model_graphdef, dataset, pad=False): 55 | loss = 0 56 | for batch in dataset: 57 | loss += loss_fn(model_state, model_graphdef, batch, pad) 58 | return loss / len(dataset) 59 | 60 | 61 | def train_and_evaluate(c: DictConfig): 62 | 63 | # get model and dataset rng seed 64 | key = jax.random.key(c.seed) 65 | key, key_model, key_dataset = jax.random.split(key, 3) 66 | 67 | # sharding 68 | num_fsdp_devices = jax.device_count() // c.num_tp_devices 69 | mesh = jax.make_mesh((num_fsdp_devices, c.num_tp_devices), ('data', 'model')) 70 | jax.set_mesh(mesh) 71 | print('sharding mesh:', ', '.join(f'{k}={v}' for k, v in mesh.shape.items())) 72 | 73 | # model 74 | print('initializing model...') 75 | c.model.V = int(math.ceil(c.model.V / jax.device_count()) * jax.device_count()) # round V up to enable sharding 76 | model = model_lib.create_sharded_model(c.model, key_model) 77 | model_graphdef = nnx.graphdef(model) 78 | 79 | # get num. 
model parameters 80 | n_params = { 81 | 'n_param_nonembed': 12 * c.model.L * c.model.D**2, 82 | 'n_param_embed': c.model.D * c.model.V, 83 | 'n_param_actual': utils.get_num_model_params(model), 84 | } 85 | for k, v in n_params.items(): 86 | print(f'{k}={v:_}') 87 | 88 | # dataset 89 | if (c.num_tokens_train is None) and (c.tokens_params_ratio is not None): 90 | c.num_tokens_train = c.tokens_params_ratio * (n_params['n_param_nonembed'] + n_params['n_param_embed']) 91 | ds_train, ds_valid = data.load_ds(key_dataset, mesh, c.ds_path, c.model.T, c.opt.microbatch_size, c.num_tokens_valid, c.num_tokens_train) 92 | if (c.num_tokens_train is None): c.num_tokens_train = ds_train.size 93 | 94 | # optimizer 95 | num_opt_steps = len(ds_train) // c.opt.grad_acc_steps 96 | tokens_per_opt_step = c.opt.batch_size * c.model.T 97 | tx = optimizer_lib.get_optimizer(c.opt, num_opt_steps, tokens_per_opt_step) 98 | optimizer = optimizer_lib.ModelAndOptimizer(model, tx, stochastic_round=c.opt.stochastic_round) 99 | opt_graphdef, opt_state = nnx.split(optimizer) 100 | 101 | # start wandb 102 | if jax.process_index() == 0: 103 | wandb.init(project=c.wandb_project, config=utils.flatten_dict(c), mode=c.wandb_mode, name=c.run_name) 104 | wandb.summary.update(n_params) 105 | 106 | # training loop 107 | train_loss_sum, train_loss_num = jnp.zeros([]), 0 108 | pbar = range(num_opt_steps) 109 | if jax.process_index() == 0: pbar = tqdm(pbar) 110 | for step in pbar: 111 | 112 | # get batch 113 | if c.opt.grad_acc_steps == 1: 114 | batch = ds_train[step] # [batch_size, T] 115 | if c.opt.grad_acc_steps > 1: 116 | batch = ds_train[step*c.opt.grad_acc_steps:(step+1)*c.opt.grad_acc_steps] # [grad_acc_steps, micro_batch_size, T] 117 | 118 | # training step 119 | key, opt_state, batch_loss = train_step(key, opt_state, opt_graphdef, model_graphdef, batch) 120 | 121 | # logging 122 | train_loss_sum += batch_loss 123 | train_loss_num += 1 124 | if train_loss_num * tokens_per_opt_step >= c.log_every_tokens: 125 | metrics = {} 126 | metrics['train_loss'] = train_loss_sum / train_loss_num 127 | metrics['train_tokens_seen'] = (step+1) * tokens_per_opt_step 128 | if jax.process_index() == 0: 129 | wandb.log(metrics, step) 130 | pbar.set_postfix_str(f'loss={metrics["train_loss"]:.2f}') 131 | train_loss_sum, train_loss_num = jnp.zeros([]), 0 132 | 133 | # eval at end of training 134 | eval_loss = eval_step(opt_state.model, model_graphdef, ds_valid, c.pad_eval) 135 | if jax.process_index() == 0: 136 | wandb.log({'eval_loss': eval_loss}, step) 137 | wandb.finish() 138 | -------------------------------------------------------------------------------- /finetuning/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import jax 3 | import jax.numpy as jnp 4 | from flax import nnx 5 | import numpy as np 6 | import datasets 7 | from jax.sharding import NamedSharding, PartitionSpec as P 8 | from math_verify import parse, verify 9 | from sampler import sample 10 | from tqdm.auto import tqdm 11 | 12 | 13 | def load_datasets(vocab, seq_len=1024): 14 | pad_id = vocab.pad_id() 15 | bos_id = vocab.bos_id() 16 | eos_id = vocab.eos_id() 17 | 18 | # load MATH dataset 19 | print('loading datasets...') 20 | ds_name = 'EleutherAI/hendrycks_math' 21 | configs = datasets.get_dataset_config_names(ds_name) # ['algebra', 'counting_and_probability', 'geometry', 'intermediate_algebra', 'number_theory', 'prealgebra', 'precalculus'] 22 | ds_train = datasets.concatenate_datasets([datasets.load_dataset(ds_name, 
config, split='train') for config in configs]) # ['problem', 'solution'] 23 | ds_valid = datasets.concatenate_datasets([datasets.load_dataset(ds_name, config, split='test') for config in configs]) # ['problem', 'solution'] 24 | 25 | # tokenize trainind dataset 26 | print('tokenizing training dataset...') 27 | train_tokens = np.full([len(ds_train), seq_len], pad_id, dtype=np.int32) 28 | train_pos = np.zeros([len(ds_train), seq_len], dtype=np.int32) 29 | train_loss_mask = np.zeros([len(ds_train), seq_len], dtype=np.bool_) 30 | train_attn_mask = np.zeros([len(ds_train), seq_len, seq_len], dtype=np.bool_) 31 | seq_idx = 0 32 | tok_idx = 0 33 | skipped = 0 34 | for example in ds_train: 35 | 36 | # tokenize example 37 | prompt = f'Problem: {example["problem"]}\nSolution: ' 38 | solution = f'{example["solution"]}' 39 | prompt_tokenized, solution_tokenized = vocab.EncodeAsIds([prompt, solution]) 40 | example_tokenized = [bos_id] + prompt_tokenized + solution_tokenized + [eos_id] 41 | 42 | # if example is too long, skip it 43 | if len(example_tokenized) > seq_len: 44 | skipped += 1 45 | continue 46 | 47 | # if example doesn't fit in current sequence, start next sequence 48 | if tok_idx + len(example_tokenized) > seq_len: 49 | seq_idx += 1 50 | tok_idx = 0 51 | 52 | # store tokens 53 | train_tokens[seq_idx, tok_idx:tok_idx+len(example_tokenized)] = example_tokenized 54 | train_pos[seq_idx, tok_idx:tok_idx+len(example_tokenized)] = np.arange(len(example_tokenized)) 55 | train_loss_mask[seq_idx, tok_idx+len(prompt_tokenized):tok_idx+len(example_tokenized)-1] = True 56 | train_attn_mask[seq_idx, tok_idx:tok_idx+len(example_tokenized), tok_idx:tok_idx+len(example_tokenized)] = True 57 | tok_idx += len(example_tokenized) 58 | train_attn_mask = np.tril(train_attn_mask) 59 | train_tokens = train_tokens[:seq_idx+1] 60 | train_pos = train_pos[:seq_idx+1] 61 | train_attn_mask = train_attn_mask[:seq_idx+1] 62 | train_loss_mask = train_loss_mask[:seq_idx+1] 63 | print(f'skipped train. seq.: {skipped / len(ds_train):.1%}') 64 | 65 | # tokenize eval dataset 66 | print('tokenizing eval dataset...') 67 | skipped = 0 68 | prompts_eval = [] 69 | problems_eval = [] 70 | solutions_eval = [] 71 | tokens_eval = np.full([len(ds_valid), seq_len], pad_id, dtype=np.int32) 72 | for i, example in enumerate(ds_valid): 73 | problems_eval += [example['problem']] 74 | solutions_eval += [example['solution']] 75 | prompt = f'Problem: {example["problem"]}\nSolution: ' 76 | prompt_tokenized = [bos_id] + vocab.EncodeAsIds(prompt) 77 | if len(prompt_tokenized) < seq_len: 78 | tokens_eval[i, :len(prompt_tokenized)] = prompt_tokenized 79 | else: 80 | skipped += 1 81 | problems_eval = np.array(problems_eval) 82 | solutions_eval = np.array(solutions_eval) 83 | print(f'skipped valid. 
seq.: {skipped / len(ds_valid):.1%}') 84 | 85 | return train_tokens, train_pos, train_attn_mask, train_loss_mask, tokens_eval, problems_eval, solutions_eval 86 | 87 | 88 | def benchmark_model(key, model, tokens, problems_eval, solutions_eval, vocab, batch_size, n_eval_samples=None, temperature=1, print_output=True): 89 | pad_id = vocab.pad_id() 90 | eos_id = vocab.eos_id() 91 | key_decoding, key_questions = jax.random.split(key) 92 | mesh = model.in_embed.embedding.value.sharding.mesh 93 | if n_eval_samples is None: n_eval_samples = len(tokens) 94 | n_batches = n_eval_samples // batch_size 95 | sample_idxs = jax.random.choice(key_questions, len(tokens), shape=[n_batches, batch_size], replace=False) 96 | lengths_list = [] 97 | correct_list = [] 98 | finished_list = [] 99 | pbar = tqdm(sample_idxs, desc='Sampling') if (jax.process_index() == 0) else sample_idxs 100 | for batch_idx in pbar: 101 | # sample tokens 102 | input_tokens_batch = jax.device_put(tokens[batch_idx], NamedSharding(mesh, P('data', None))) 103 | output_tokens_batch = sample(key_decoding, *nnx.split(model), input_tokens_batch, temperature) 104 | 105 | # extract output sequences 106 | completions_tokens = [] 107 | for in_seq, out_seq in zip(input_tokens_batch, output_tokens_batch): 108 | out_seq = out_seq[jnp.argmax(in_seq==pad_id):] 109 | if jnp.any(out_seq==pad_id): out_seq = out_seq[:jnp.argmax(out_seq==pad_id)] 110 | completions_tokens += [out_seq.tolist()] 111 | 112 | # eval completions 113 | completions_text = vocab.DecodeIds(completions_tokens) 114 | for sample_idx, completion_tokens, completion_text in zip(batch_idx, completions_tokens, completions_text): 115 | if sample_idx < len(problems_eval): 116 | problem = problems_eval[sample_idx] 117 | gold = solutions_eval[sample_idx] 118 | parsed = parse(completion_text) 119 | finished = eos_id in completion_tokens 120 | correct = verify(parse(gold), parsed) 121 | lengths_list += [len(completion_tokens)] 122 | finished_list += [finished] 123 | correct_list += [correct] 124 | if print_output: 125 | print('------------') 126 | print(f'PROMPT:\n{problem}\nCOMPLETION:\n{completion_text}\nPARSED: {parsed}\nGOLD: {gold}\nCORRECT: {correct}') 127 | 128 | return dict(length=np.mean(lengths_list), finished=np.mean(finished_list), accuracy=np.mean(correct_list)) 129 | -------------------------------------------------------------------------------- /pretraining/model.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import jax 3 | import jax.numpy as jnp 4 | from functools import partial 5 | from flax import nnx 6 | from jax.sharding import PartitionSpec as P 7 | from jax.experimental.shard_map import shard_map 8 | from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel, splash_attention_mask 9 | from omegaconf.dictconfig import DictConfig 10 | from rope import apply_rope 11 | 12 | 13 | class TransformerDecoder(nnx.Module): 14 | def __init__(self, c: DictConfig, rngs: nnx.Rngs): 15 | embed_in_init = sharded_init('embedding_in') 16 | embed_out_init = sharded_init('embedding_out') 17 | self.token_embed_in = nnx.Embed(num_embeddings=c.V, features=c.D, embedding_init=embed_in_init, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs) 18 | self.token_embed_out = nnx.Embed(num_embeddings=c.V, features=c.D, embedding_init=embed_out_init, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs) 19 | self.blocks = nnx.List(TransformerBlock(c, rngs) for _ in range(c.L)) 20 | self.out_ln = 
nnx.RMSNorm(c.D, use_scale=False, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs) 21 | self.remat = c.remat 22 | 23 | def __call__(self, x): # [B, S] 24 | 25 | # token embedding 26 | h = self.token_embed_in(x) # [B, T, D] 27 | 28 | # transformer blocks 29 | for block in self.blocks: 30 | h = jax.remat(block)(h) if self.remat else block(h) 31 | 32 | # project back to vocabulary 33 | h = self.out_ln(h) 34 | logits = self.token_embed_out.attend(h) # [B, T, V] 35 | return logits 36 | 37 | 38 | class TransformerBlock(nnx.Module): 39 | def __init__(self, c: DictConfig, rngs: nnx.Rngs): 40 | self.ln1 = nnx.RMSNorm(c.D, use_scale=False, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs) 41 | self.ln2 = nnx.RMSNorm(c.D, use_scale=False, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs) 42 | self.attn = MultiHeadAttention(c, rngs) 43 | self.mlp = MLP(c, rngs) 44 | 45 | def __call__(self, x): # [B, T, D] 46 | x = x + self.attn(self.ln1(x)) # attention block 47 | return x + self.mlp(self.ln2(x)) # MLP block 48 | 49 | 50 | class MultiHeadAttention(nnx.Module): 51 | """Causal attention layer.""" 52 | def __init__(self, c: DictConfig, rngs: nnx.Rngs): 53 | qkv_proj_init = sharded_init('attn_qkv_proj') 54 | out_proj_init = sharded_init('attn_out_proj') 55 | self.qkv_proj = nnx.Einsum('BTd,SNdH->SBTNH', (3, c.N, c.D, c.H), kernel_init=qkv_proj_init, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs) 56 | self.out_proj = nnx.Einsum('BTnh,nhD->BTD', (c.N, c.H, c.D), kernel_init=out_proj_init, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs) 57 | self.query_norm = nnx.RMSNorm(c.H, use_scale=False, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs) 58 | self.key_norm = nnx.RMSNorm(c.H, use_scale=False, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs) 59 | if c.use_flash_attn and jax.devices()[0].platform == 'tpu' and (c.H % 128 != 0): 60 | warnings.warn('cannot use flash attention because `model.H` is not a multiple of 128.') 61 | c.use_flash_attn &= jax.devices()[0].platform == 'tpu' 62 | c.use_flash_attn &= (c.H % 128 == 0) 63 | self.attention = partial(tpu_causal_flash_attention) if c.use_flash_attn else partial(jax.nn.dot_product_attention, is_causal=True) 64 | 65 | def __call__(self, x): # [B, T, D] 66 | B, T, D = x.shape 67 | 68 | # input projection 69 | q, k, v = self.qkv_proj(x) # [B, T, N, H] 70 | 71 | # qk-norm 72 | q = self.query_norm(q) 73 | k = self.key_norm(k) 74 | 75 | # position embedding 76 | position = jnp.arange(T) 77 | q = apply_rope(q, position[None]) 78 | k = apply_rope(k, position[None]) 79 | 80 | # attention 81 | out = self.attention(q, k, v) # [B, T, N, H] 82 | 83 | # output projection followed by contraction back to original dims 84 | out = self.out_proj(out) # [B, T, D] 85 | return out 86 | 87 | 88 | def tpu_causal_flash_attention(q, k, v): 89 | """ 90 | TPU Flash Attention. 91 | https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py 92 | https://github.com/AI-Hypercomputer/maxtext/blob/9ea52118535e970096c164460dbbfa478d157066/MaxText/layers/attentions.py#L562 93 | """ 94 | B, T, N, H = q.shape 95 | assert H >= 128, 'TPU flash attention reqruies head dim. 
to be a multiple of 128' 96 | 97 | # scale query 98 | q /= jnp.sqrt(H) 99 | 100 | # kernel block sizes 101 | # https://github.com/AI-Hypercomputer/maxtext/blob/afcdf47f8b7c1e1864fa81012a873590c5408122/MaxText/configs/base.yml#L644 102 | block_sizes = splash_attention_kernel.BlockSizes( 103 | block_q=512, 104 | block_kv=512, 105 | block_kv_compute=128, 106 | block_q_dkv=512, 107 | block_kv_dkv=512, 108 | block_kv_dkv_compute=128, 109 | block_q_dq=512, 110 | block_kv_dq=512, 111 | ) 112 | 113 | mesh = jax.sharding.get_abstract_mesh() 114 | sharding = P('data', None, 'model', None) 115 | @partial(shard_map, mesh=mesh, in_specs=(sharding, sharding, sharding), out_specs=sharding, check_rep=False) 116 | def attention(q, k, v): 117 | _, _, n, _ = q.shape 118 | causal_mask = splash_attention_mask.CausalMask(shape=(T, T)) 119 | multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(causal_mask,) * n) 120 | splash_kernel = splash_attention_kernel.make_splash_mha(mask=multi_head_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes) 121 | out = jax.vmap(splash_kernel)( 122 | q.swapaxes(1, 2), 123 | k.swapaxes(1, 2), 124 | v.swapaxes(1, 2) 125 | ).swapaxes(1, 2) # [B, T, N, H] 126 | return out 127 | 128 | return attention(q, k, v) 129 | 130 | 131 | class MLP(nnx.Module): 132 | """Multilayer perceptron.""" 133 | def __init__(self, c: DictConfig, rngs: nnx.Rngs): 134 | fc1_init = sharded_init('mlp_fc1') 135 | fc2_init = sharded_init('mlp_fc2') 136 | self.fc1 = nnx.Linear(in_features=c.D, out_features=c.F, kernel_init=fc1_init, use_bias=False, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs) 137 | self.fc2 = nnx.Linear(in_features=c.F, out_features=c.D, kernel_init=fc2_init, use_bias=False, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs) 138 | 139 | def __call__(self, x): # [B, T, D] 140 | h = jax.nn.gelu(self.fc1(x)) # [B, T, F] 141 | return self.fc2(h) # [B, T, D] 142 | 143 | 144 | def sharded_init(layer_type: str): 145 | """Initialize weights with optional sharding.""" 146 | kernel_init = jax.nn.initializers.xavier_uniform() 147 | embed_init = jax.nn.initializers.variance_scaling(1.0, 'fan_in', 'normal', out_axis=0) 148 | match layer_type: 149 | case 'embedding_in': # [V, D] 150 | return nnx.with_partitioning(embed_init, ('data', 'model')) 151 | case 'embedding_out': # [V, D] 152 | return nnx.with_partitioning(embed_init, ('model', 'data')) 153 | case 'attn_qkv_proj': # [3, N, D, H] 154 | return nnx.with_partitioning(kernel_init, (None, 'model', 'data', None)) 155 | case 'attn_out_proj': # [N, H, D] 156 | return nnx.with_partitioning(kernel_init, ('model', None, 'data')) 157 | case 'mlp_fc1': # [D, F] 158 | return nnx.with_partitioning(kernel_init, ('data', 'model')) 159 | case 'mlp_fc2': # [F, D] 160 | return nnx.with_partitioning(kernel_init, ('model', 'data')) 161 | case _: 162 | raise ValueError(f'unrecognized layer type: {layer_type}') 163 | 164 | 165 | def create_sharded_model(c: DictConfig, key): 166 | """ 167 | initialize sharded model without putting it on a single device 168 | https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html 169 | """ 170 | seed = int(jax.random.randint(key, [1], 0, 1_000_000)[0]) 171 | 172 | @nnx.jit 173 | def initialize_sharded_model(): 174 | rngs = nnx.Rngs(seed) 175 | model = TransformerDecoder(c, rngs=rngs) # unsharded at this moment 176 | state = nnx.state(model) # the model's state, a pure pytree 177 | pspecs = nnx.get_partition_spec(state) # get annotations from state 178 | sharded_state = 
jax.lax.with_sharding_constraint(state, pspecs) 179 | nnx.update(model, sharded_state) # the model is sharded now 180 | return model 181 | 182 | model = initialize_sharded_model() 183 | 184 | return model -------------------------------------------------------------------------------- /finetuning/factorized.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """ 16 | Factorized optimizers. 17 | Taken from https://github.com/google-deepmind/optax/blob/main/optax/_src/factorized.py 18 | Only 2 LOC were modified to hard-code scale_by_factored_rms state dtype to float32. 19 | """ 20 | 21 | from collections.abc import Callable 22 | import dataclasses 23 | from typing import NamedTuple, Optional 24 | 25 | import chex 26 | import jax 27 | import jax.numpy as jnp 28 | import numpy as np 29 | from optax._src import base 30 | from optax._src import numerics 31 | 32 | 33 | def _decay_rate_pow(i: int, exponent: float = 0.8) -> chex.Array: 34 | """Second-order moment decay schedule.""" 35 | t = jnp.array(i + 1, jnp.float32) 36 | return 1.0 - t ** (-exponent) 37 | 38 | 39 | def _factored_dims( 40 | shape: base.Shape, factored: bool, min_dim_size_to_factor: int 41 | ) -> Optional[tuple[int, int]]: 42 | """Whether to use a factored second moment estimator. 43 | 44 | This function returns a tuple with the two largest axes to reduce over. 45 | If no two dimensions have size >= min_dim_size_to_factor, return None. 46 | 47 | Args: 48 | shape: an input shape 49 | factored: whether to use factored second-moment estimator for 2d vars. 50 | min_dim_size_to_factor: only factor accumulator if two array dimensions have 51 | at least this size. 52 | 53 | Returns: 54 | None or a tuple of ints 55 | """ 56 | if not factored or len(shape) < 2: 57 | return None 58 | sorted_dims = np.argsort(shape) 59 | if shape[sorted_dims[-2]] < min_dim_size_to_factor: 60 | return None 61 | return int(sorted_dims[-2]), int(sorted_dims[-1]) 62 | 63 | 64 | @dataclasses.dataclass 65 | class _UpdateResult: 66 | """Opaque container that is not traversed by jax.tree.map.""" 67 | 68 | update: chex.Array # the update to apply to params 69 | v_row: chex.Array # used for factored params. 70 | v_col: chex.Array # used for factored params. 71 | v: chex.Array # used for params where factoring is skipped. 72 | 73 | 74 | class FactoredState(NamedTuple): 75 | """Overall state of the gradient transformation.""" 76 | 77 | count: chex.Array # number of update steps. 78 | v_row: chex.ArrayTree # Tree of factored params. 79 | v_col: chex.ArrayTree # Tree of factored params. 80 | v: chex.ArrayTree # Tree for params where factoring is skipped. 
81 | 82 | 83 | def scale_by_factored_rms( 84 | factored: bool = True, 85 | decay_rate: float = 0.8, 86 | step_offset: int = 0, 87 | min_dim_size_to_factor: int = 128, 88 | epsilon: float = 1e-30, 89 | decay_rate_fn: Callable[[int, float], chex.Array] = _decay_rate_pow, 90 | ): 91 | """Scaling by a factored estimate of the gradient rms (as in Adafactor). 92 | 93 | This is a so-called "1+epsilon" scaling algorithms, that is extremely memory 94 | efficient compared to RMSProp/Adam, and has had wide success when applied to 95 | large-scale training of attention-based models. 96 | 97 | Args: 98 | factored: boolean: whether to use factored second-moment estimates.. 99 | decay_rate: float: controls second-moment exponential decay schedule. 100 | step_offset: for finetuning, one may set this to the starting step-number of 101 | the fine tuning phase. 102 | min_dim_size_to_factor: only factor accumulator if two array dimensions are 103 | at least this size. 104 | epsilon: Regularization constant for squared gradient. 105 | decay_rate_fn: A function that accepts the current step, the decay rate 106 | parameter and controls the schedule for the second momentum. Defaults to 107 | the original adafactor's power decay schedule. One potential shortcoming 108 | of the original schedule is the fact that second momentum converges to 1, 109 | which effectively freezes the second momentum. To prevent this the user 110 | can opt for a custom schedule that sets an upper bound for the second 111 | momentum, like in Zhai et al., 2021. 112 | 113 | Returns: 114 | The corresponding :class:`optax.GradientTransformation`. 115 | 116 | References: 117 | Shazeer et al, `Adafactor: Adaptive Learning Rates with Sublinear Memory 118 | Cost `_, 2018 119 | 120 | Zhai et al, `Scaling Vision Transformers 121 | `_, 2021 122 | """ 123 | 124 | def _to_state(count: chex.Array, result_tree): 125 | """Maps from a tree of (factored) values to separate trees of values.""" 126 | return FactoredState( 127 | count=count, 128 | v_row=jax.tree.map(lambda o: o.v_row, result_tree), 129 | v_col=jax.tree.map(lambda o: o.v_col, result_tree), 130 | v=jax.tree.map(lambda o: o.v, result_tree), 131 | ) 132 | 133 | def init_fn(params): 134 | """Initialise the optimizer's state.""" 135 | 136 | def _init(param): 137 | shape, dtype = param.shape, jnp.float32 # <-- CHANGED TO HARD-CODE FLOAT32 138 | factored_dims = _factored_dims(shape, factored, min_dim_size_to_factor) 139 | if factored_dims is not None: 140 | d1, d0 = factored_dims 141 | vr_shape = np.delete(shape, d0) 142 | vc_shape = np.delete(shape, d1) 143 | return _UpdateResult( 144 | update=jnp.zeros((1,), dtype=dtype), 145 | v_row=jnp.zeros(vr_shape, dtype=dtype), 146 | v_col=jnp.zeros(vc_shape, dtype=dtype), 147 | v=jnp.zeros((1,), dtype=dtype), 148 | ) 149 | return _UpdateResult( 150 | update=jnp.zeros((1,), dtype=dtype), 151 | v_row=jnp.zeros((1,), dtype=dtype), 152 | v_col=jnp.zeros((1,), dtype=dtype), 153 | v=jnp.zeros(param.shape, dtype=dtype), 154 | ) 155 | 156 | return _to_state(jnp.zeros([], jnp.int32), jax.tree.map(_init, params)) 157 | 158 | def update_fn(grads, state, params): 159 | """Apply gradient transformation.""" 160 | if params is None: 161 | raise ValueError(base.NO_PARAMS_MSG) 162 | 163 | def _update(grad, v_row, v_col, v, param, step): 164 | shape, dtype = param.shape, jnp.float32 # <-- CHANGED TO HARD-CODE FLOAT32 165 | decay_rate_t = decay_rate_fn(step - step_offset, decay_rate) 166 | 167 | # Scaled by factorized second moment statistics. 
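      # For matrix-shaped params the per-weight second moment is never stored: it is
      # approximated from row means (v_row) and column means (v_col) as
      # V ~ outer(v_row, v_col) / mean(v_row), so state is O(rows + cols) instead of
      # O(rows * cols). The factored branch below computes grad / sqrt(V) from this
      # estimate; the non-factored branch keeps a full `v`, RMSProp-style.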
168 | new_v_row = jnp.zeros((1,), dtype=dtype) 169 | new_v_col = jnp.zeros((1,), dtype=dtype) 170 | new_v = jnp.zeros((1,), dtype=dtype) 171 | 172 | factored_dims = _factored_dims(shape, factored, min_dim_size_to_factor) 173 | if factored_dims is not None: 174 | d1, d0 = factored_dims 175 | grad_sqr = numerics.abs_sq(grad) + epsilon 176 | new_v_row = decay_rate_t * v_row + (1.0 - decay_rate_t) * jnp.mean( 177 | grad_sqr, axis=d0 178 | ) 179 | new_v_col = decay_rate_t * v_col + (1.0 - decay_rate_t) * jnp.mean( 180 | grad_sqr, axis=d1 181 | ) 182 | new_v_row = new_v_row.astype(dtype) 183 | new_v_col = new_v_col.astype(dtype) 184 | reduced_d1 = d1 - 1 if d1 > d0 else d1 185 | row_col_mean = jnp.mean(new_v_row, axis=reduced_d1, keepdims=True) 186 | row_factor = (new_v_row / row_col_mean) ** -0.5 187 | col_factor = (new_v_col) ** -0.5 188 | update = ( 189 | grad 190 | * jnp.expand_dims(row_factor, axis=d0) 191 | * jnp.expand_dims(col_factor, axis=d1) 192 | ) 193 | else: 194 | grad_sqr = numerics.abs_sq(grad) + epsilon 195 | new_v = decay_rate_t * v + (1.0 - decay_rate_t) * grad_sqr 196 | new_v = new_v.astype(dtype) 197 | update = grad * (new_v) ** -0.5 198 | 199 | return _UpdateResult(update, new_v_row, new_v_col, new_v) 200 | 201 | # Transform grad and compute new per-parameter stats. 202 | output = jax.tree.map( 203 | lambda *args: _update(*args, state.count), 204 | grads, 205 | state.v_row, 206 | state.v_col, 207 | state.v, 208 | params, 209 | ) 210 | 211 | # Unpack updates / stats and return. 212 | updates = jax.tree.map(lambda o: o.update, output) 213 | return updates, _to_state(numerics.safe_increment(state.count), output) 214 | 215 | return base.GradientTransformation(init_fn, update_fn) 216 | -------------------------------------------------------------------------------- /pretraining/factorized.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """ 16 | Factorized optimizers. 17 | Taken from https://github.com/google-deepmind/optax/blob/main/optax/_src/factorized.py 18 | Only 2 LOC were modified to hard-code scale_by_factored_rms state dtype to float32. 
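Both changed lines are marked with '<-- CHANGED TO HARD-CODE FLOAT32' below (one in
init_fn, one in update_fn), so the accumulator state stays float32 even when the
model params are stored in a lower-precision dtype such as bfloat16.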
19 | """ 20 | 21 | from collections.abc import Callable 22 | import dataclasses 23 | from typing import NamedTuple, Optional 24 | 25 | import chex 26 | import jax 27 | import jax.numpy as jnp 28 | import numpy as np 29 | from optax._src import base 30 | from optax._src import numerics 31 | 32 | 33 | def _decay_rate_pow(i: int, exponent: float = 0.8) -> chex.Array: 34 | """Second-order moment decay schedule.""" 35 | t = jnp.array(i + 1, jnp.float32) 36 | return 1.0 - t ** (-exponent) 37 | 38 | 39 | def _factored_dims( 40 | shape: base.Shape, factored: bool, min_dim_size_to_factor: int 41 | ) -> Optional[tuple[int, int]]: 42 | """Whether to use a factored second moment estimator. 43 | 44 | This function returns a tuple with the two largest axes to reduce over. 45 | If no two dimensions have size >= min_dim_size_to_factor, return None. 46 | 47 | Args: 48 | shape: an input shape 49 | factored: whether to use factored second-moment estimator for 2d vars. 50 | min_dim_size_to_factor: only factor accumulator if two array dimensions have 51 | at least this size. 52 | 53 | Returns: 54 | None or a tuple of ints 55 | """ 56 | if not factored or len(shape) < 2: 57 | return None 58 | sorted_dims = np.argsort(shape) 59 | if shape[sorted_dims[-2]] < min_dim_size_to_factor: 60 | return None 61 | return int(sorted_dims[-2]), int(sorted_dims[-1]) 62 | 63 | 64 | @dataclasses.dataclass 65 | class _UpdateResult: 66 | """Opaque container that is not traversed by jax.tree.map.""" 67 | 68 | update: chex.Array # the update to apply to params 69 | v_row: chex.Array # used for factored params. 70 | v_col: chex.Array # used for factored params. 71 | v: chex.Array # used for params where factoring is skipped. 72 | 73 | 74 | class FactoredState(NamedTuple): 75 | """Overall state of the gradient transformation.""" 76 | 77 | count: chex.Array # number of update steps. 78 | v_row: chex.ArrayTree # Tree of factored params. 79 | v_col: chex.ArrayTree # Tree of factored params. 80 | v: chex.ArrayTree # Tree for params where factoring is skipped. 81 | 82 | 83 | def scale_by_factored_rms( 84 | factored: bool = True, 85 | decay_rate: float = 0.8, 86 | step_offset: int = 0, 87 | min_dim_size_to_factor: int = 128, 88 | epsilon: float = 1e-30, 89 | decay_rate_fn: Callable[[int, float], chex.Array] = _decay_rate_pow, 90 | ): 91 | """Scaling by a factored estimate of the gradient rms (as in Adafactor). 92 | 93 | This is a so-called "1+epsilon" scaling algorithms, that is extremely memory 94 | efficient compared to RMSProp/Adam, and has had wide success when applied to 95 | large-scale training of attention-based models. 96 | 97 | Args: 98 | factored: boolean: whether to use factored second-moment estimates.. 99 | decay_rate: float: controls second-moment exponential decay schedule. 100 | step_offset: for finetuning, one may set this to the starting step-number of 101 | the fine tuning phase. 102 | min_dim_size_to_factor: only factor accumulator if two array dimensions are 103 | at least this size. 104 | epsilon: Regularization constant for squared gradient. 105 | decay_rate_fn: A function that accepts the current step, the decay rate 106 | parameter and controls the schedule for the second momentum. Defaults to 107 | the original adafactor's power decay schedule. One potential shortcoming 108 | of the original schedule is the fact that second momentum converges to 1, 109 | which effectively freezes the second momentum. 
To prevent this the user 110 | can opt for a custom schedule that sets an upper bound for the second 111 | momentum, like in Zhai et al., 2021. 112 | 113 | Returns: 114 | The corresponding :class:`optax.GradientTransformation`. 115 | 116 | References: 117 | Shazeer et al, `Adafactor: Adaptive Learning Rates with Sublinear Memory 118 | Cost `_, 2018 119 | 120 | Zhai et al, `Scaling Vision Transformers 121 | `_, 2021 122 | """ 123 | 124 | def _to_state(count: chex.Array, result_tree): 125 | """Maps from a tree of (factored) values to separate trees of values.""" 126 | return FactoredState( 127 | count=count, 128 | v_row=jax.tree.map(lambda o: o.v_row, result_tree), 129 | v_col=jax.tree.map(lambda o: o.v_col, result_tree), 130 | v=jax.tree.map(lambda o: o.v, result_tree), 131 | ) 132 | 133 | def init_fn(params): 134 | """Initialise the optimizer's state.""" 135 | 136 | def _init(param): 137 | shape, dtype = param.shape, jnp.float32 # <-- CHANGED TO HARD-CODE FLOAT32 138 | factored_dims = _factored_dims(shape, factored, min_dim_size_to_factor) 139 | if factored_dims is not None: 140 | d1, d0 = factored_dims 141 | vr_shape = np.delete(shape, d0) 142 | vc_shape = np.delete(shape, d1) 143 | return _UpdateResult( 144 | update=jnp.zeros((1,), dtype=dtype), 145 | v_row=jnp.zeros(vr_shape, dtype=dtype), 146 | v_col=jnp.zeros(vc_shape, dtype=dtype), 147 | v=jnp.zeros((1,), dtype=dtype), 148 | ) 149 | return _UpdateResult( 150 | update=jnp.zeros((1,), dtype=dtype), 151 | v_row=jnp.zeros((1,), dtype=dtype), 152 | v_col=jnp.zeros((1,), dtype=dtype), 153 | v=jnp.zeros(param.shape, dtype=dtype), 154 | ) 155 | 156 | return _to_state(jnp.zeros([], jnp.int32), jax.tree.map(_init, params)) 157 | 158 | def update_fn(grads, state, params): 159 | """Apply gradient transformation.""" 160 | if params is None: 161 | raise ValueError(base.NO_PARAMS_MSG) 162 | 163 | def _update(grad, v_row, v_col, v, param, step): 164 | shape, dtype = param.shape, jnp.float32 # <-- CHANGED TO HARD-CODE FLOAT32 165 | decay_rate_t = decay_rate_fn(step - step_offset, decay_rate) 166 | 167 | # Scaled by factorized second moment statistics. 168 | new_v_row = jnp.zeros((1,), dtype=dtype) 169 | new_v_col = jnp.zeros((1,), dtype=dtype) 170 | new_v = jnp.zeros((1,), dtype=dtype) 171 | 172 | factored_dims = _factored_dims(shape, factored, min_dim_size_to_factor) 173 | if factored_dims is not None: 174 | d1, d0 = factored_dims 175 | grad_sqr = numerics.abs_sq(grad) + epsilon 176 | new_v_row = decay_rate_t * v_row + (1.0 - decay_rate_t) * jnp.mean( 177 | grad_sqr, axis=d0 178 | ) 179 | new_v_col = decay_rate_t * v_col + (1.0 - decay_rate_t) * jnp.mean( 180 | grad_sqr, axis=d1 181 | ) 182 | new_v_row = new_v_row.astype(dtype) 183 | new_v_col = new_v_col.astype(dtype) 184 | reduced_d1 = d1 - 1 if d1 > d0 else d1 185 | row_col_mean = jnp.mean(new_v_row, axis=reduced_d1, keepdims=True) 186 | row_factor = (new_v_row / row_col_mean) ** -0.5 187 | col_factor = (new_v_col) ** -0.5 188 | update = ( 189 | grad 190 | * jnp.expand_dims(row_factor, axis=d0) 191 | * jnp.expand_dims(col_factor, axis=d1) 192 | ) 193 | else: 194 | grad_sqr = numerics.abs_sq(grad) + epsilon 195 | new_v = decay_rate_t * v + (1.0 - decay_rate_t) * grad_sqr 196 | new_v = new_v.astype(dtype) 197 | update = grad * (new_v) ** -0.5 198 | 199 | return _UpdateResult(update, new_v_row, new_v_col, new_v) 200 | 201 | # Transform grad and compute new per-parameter stats. 
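    # jax.tree.map zips the matching leaves of grads, state.v_row, state.v_col,
    # state.v, and params, applying _update one parameter at a time; _to_state
    # below then regroups the per-leaf results into a single FactoredState.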
202 | output = jax.tree.map( 203 | lambda *args: _update(*args, state.count), 204 | grads, 205 | state.v_row, 206 | state.v_col, 207 | state.v, 208 | params, 209 | ) 210 | 211 | # Unpack updates / stats and return. 212 | updates = jax.tree.map(lambda o: o.update, output) 213 | return updates, _to_state(numerics.safe_increment(state.count), output) 214 | 215 | return base.GradientTransformation(init_fn, update_fn) 216 | -------------------------------------------------------------------------------- /finetuning/finetune.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import jax 3 | import jax.numpy as jnp 4 | import optax 5 | import wandb 6 | from flax import nnx 7 | from tqdm.auto import tqdm 8 | from functools import partial 9 | from optax import tree_utils as otu 10 | from jax.sharding import NamedSharding, PartitionSpec as P 11 | import optimizer as optimizer_lib 12 | import gemma, data 13 | 14 | 15 | def loss_fn(model_state, model_graphdef, x, pos, attn_mask, loss_mask): # [B, T] 16 | model = nnx.merge(model_graphdef, model_state) 17 | y = jnp.roll(x, -1, axis=1) 18 | logits, _ = model(x, positions=pos, attn_mask=attn_mask) # [B, T, V] 19 | losses = optax.softmax_cross_entropy_with_integer_labels(logits, y) # [B, T] 20 | return (losses * loss_mask).sum() / loss_mask.sum() 21 | 22 | 23 | @partial(jax.jit, static_argnames=('model_graphdef', 'opt_graphdef', 'lora'), donate_argnames=('opt_state')) 24 | def train_step(key, opt_state, model_graphdef, opt_graphdef, tokens, pos, attn_mask, loss_mask, lora=False): 25 | key, key_opt = jax.random.split(key) 26 | grad_fn = jax.value_and_grad if not lora else partial(nnx.value_and_grad, argnums=nnx.DiffState(0, nnx.LoRAParam)) 27 | 28 | # compute grads from a single micro-batch 29 | if tokens.shape[0] == 1: 30 | loss, grads = grad_fn(loss_fn)(opt_state.model, model_graphdef, tokens[0], pos[0], attn_mask[0], loss_mask[0]) 31 | 32 | # compute grads from multiple micro-batches (using gradient accumulation) 33 | if tokens.shape[0] >= 2: 34 | loss = 0 35 | grads = otu.tree_zeros_like(opt_state.model, dtype=jnp.float32) 36 | def step_fn(i, args): 37 | loss, grads = args 38 | batch_loss, batch_grads = grad_fn(loss_fn)(opt_state.model, model_graphdef, tokens[i], pos[i], attn_mask[i], loss_mask[i]) 39 | loss = (i*loss + batch_loss) / (i+1) 40 | grads = jax.tree.map(lambda m, g: (i*m + g) / (i+1), grads, batch_grads) 41 | return loss, grads 42 | loss, grads = jax.lax.fori_loop(0, len(tokens), step_fn, (loss, grads)) 43 | 44 | # optimizer step 45 | optimizer = nnx.merge(opt_graphdef, opt_state) 46 | optimizer.update(key_opt, grads) 47 | opt_state = nnx.state(optimizer) 48 | return key, opt_state, loss 49 | 50 | 51 | def finetune( 52 | model_variant = 'gemma3-1b', # ['1b', '4b', '12b', '27b'] 53 | lora_rank = None, 54 | temperature = 1, 55 | optimizer_name = 'adafactor', # ['sgd', 'adam', 'adafactor'] 56 | peak_lr = 1e-6, 57 | lr_schedule = 'const', 58 | b2 = 0.997, 59 | n_epochs = 1, 60 | batch_size = 1, 61 | microbatch_size = 1, 62 | n_eval_samples = None, 63 | eval_batch_size = 64, 64 | n_data_devices = 1, 65 | train_parallelism = 'seq', # ['seq', 'batch'] 66 | param_dtype = 'bfloat16', 67 | stochastic_round = False, 68 | remat = False, 69 | log_every_samples = 100, 70 | print_output=False, 71 | wandb_mode = 'disabled', 72 | run_name = None, 73 | seed = 0, 74 | **kwargs, 75 | ): 76 | # check if any unrecognized arguments were passed 77 | if len(kwargs) > 0: raise NameError(f'Unrecognized arguments: 
{kwargs}') 78 | 79 | # log config 80 | train_config = locals() 81 | if jax.process_index() == 0: 82 | print(f'{train_config=}') 83 | wandb.init(project='picodo-finetune', config=train_config, mode=wandb_mode, name=run_name) 84 | 85 | # sharding 86 | n_tensor_devices = jax.device_count() // n_data_devices 87 | mesh = jax.make_mesh((n_data_devices, n_tensor_devices), ('data', 'model')) 88 | jax.set_mesh(mesh) 89 | print('sharding mesh:', ', '.join(f'{k}={v}' for k, v in mesh.shape.items())) 90 | 91 | # load model 92 | print('loading model...') 93 | model, vocab = gemma.load_pretrained(model_variant, mesh, param_dtype, remat) 94 | 95 | # optionally use Lora 96 | grad_acc_steps = batch_size // microbatch_size 97 | assert not (remat and lora_rank is not None), 'remat currently not supported with Lora' 98 | assert not (grad_acc_steps > 1 and lora_rank is not None), 'grad. accum. currently not supported with Lora' 99 | use_lora = lora_rank is not None 100 | if use_lora: 101 | import qwix 102 | 103 | # apply LoRA to all layers except normalization layers 104 | lora_provider = qwix.LoraProvider(module_path='^((?!scale).)*$', rank=lora_rank, alpha=2) 105 | dummy_input = jnp.ones([1, 128], dtype=jnp.int32) 106 | model = qwix.apply_lora_to_model(model, lora_provider, dummy_input) 107 | 108 | # convert all LoRA params to float32 109 | converted = False 110 | for path, module in model.iter_modules(): 111 | if hasattr(module, 'kernel_lora_a'): 112 | module.kernel_lora_a.value = module.kernel_lora_a.value.astype(jnp.float32) 113 | module.kernel_lora_b.value = module.kernel_lora_b.value.astype(jnp.float32) 114 | converted = True 115 | assert converted, f'failed to cast LoRA to fp32' 116 | 117 | # load datasets 118 | print('loading data...') 119 | train_tokens, train_pos, train_attn_mask, train_loss_mask, tokens_eval, problems_eval, solutions_eval = data.load_datasets(vocab) 120 | 121 | # optimizer 122 | warmup_frac = 0.05 123 | n_train_samples = len(train_tokens) 124 | n_batches = n_train_samples // batch_size 125 | n_optimizer_steps = n_epochs * n_batches 126 | warmup_steps = int(warmup_frac * n_optimizer_steps) 127 | if lr_schedule == 'const': lr = peak_lr 128 | if lr_schedule == 'cosine': lr = optax.schedules.warmup_cosine_decay_schedule(0, peak_lr, warmup_steps, max(1, n_optimizer_steps)) 129 | if optimizer_name == 'sgd': tx = optax.sgd(lr) 130 | if optimizer_name == 'adam': tx = optax.adam(lr, 0.9, b2) 131 | if optimizer_name == 'adafactor': tx = optimizer_lib.adafactor(lr, decay_rate=b2) 132 | wrt = nnx.LoRAParam if use_lora else nnx.Param 133 | optimizer = optimizer_lib.ModelAndOptimizer(model, tx, wrt, stochastic_round) 134 | opt_graphdef, opt_state = nnx.split(optimizer) 135 | model_graphdef = nnx.graphdef(model) 136 | 137 | # print number of parameters 138 | n_model_params = jax.tree.reduce(lambda x, y: x + jnp.size(y), nnx.state(model), 0) 139 | n_opt_params = jax.tree.reduce(lambda x, y: x + jnp.size(y), nnx.state(optimizer.opt_state), 0) 140 | print(f'{n_model_params=:_}') 141 | print(f'{n_opt_params=:_}') 142 | 143 | # training loop 144 | step = 0 145 | train_loss = 0 146 | key = jax.random.PRNGKey(seed) 147 | del model 148 | # iterate over epochs 149 | if n_epochs > 0: 150 | if (jax.process_index() == 0): pbar = tqdm(total=n_optimizer_steps, desc='Training') 151 | for epoch in range(n_epochs): 152 | 153 | # iterate over batches 154 | key, key_train = jax.random.split(key) 155 | idxs = jax.random.choice(key_train, n_train_samples, shape=[n_batches, grad_acc_steps, microbatch_size], 
replace=False) 156 | for idx in idxs: 157 | 158 | # shard batch 159 | token_pspec = P(None, 'data', None) if train_parallelism == 'batch' else P(None, None, 'data') # [grad_acc_steps, microbatch_size, seq_len] 160 | attn_mask_pspec = P(None, 'data', None, None) if train_parallelism == 'batch' else P(None, None, None, None) # [grad_acc_steps, microbatch_size, seq_len, seq_len] 161 | tokens_batch = jax.device_put(train_tokens[idx], NamedSharding(mesh, token_pspec)) 162 | pos_batch = jax.device_put(train_pos[idx], NamedSharding(mesh, token_pspec)) 163 | loss_mask_batch = jax.device_put(train_loss_mask[idx], NamedSharding(mesh, token_pspec)) 164 | attn_mask_batch = jax.device_put(train_attn_mask[idx], NamedSharding(mesh, attn_mask_pspec)) 165 | 166 | # training step 167 | key, opt_state, batch_loss = train_step(key, opt_state, model_graphdef, opt_graphdef, tokens_batch, pos_batch, attn_mask_batch, loss_mask_batch, use_lora) 168 | 169 | # logging 170 | train_loss += batch_loss 171 | log_every_steps = log_every_samples // batch_size 172 | if (step+1) % log_every_steps == 0: 173 | avg_loss = train_loss / log_every_steps 174 | if jax.process_index() == 0: 175 | wandb.log({'train_loss': float(avg_loss)}, step) 176 | pbar.set_postfix_str(f'loss={float(avg_loss):.2f}') 177 | train_loss = 0 178 | step += 1 179 | if jax.process_index() == 0: pbar.update(1) 180 | 181 | # after training is finished, update optimizer 182 | nnx.update(optimizer, opt_state) 183 | del optimizer.opt_state 184 | if (jax.process_index() == 0): pbar.close() 185 | 186 | # eval 187 | key, key_eval = jax.random.split(key) 188 | eval_metrics = data.benchmark_model(key_eval, optimizer.model, tokens_eval, problems_eval, solutions_eval, vocab, eval_batch_size, n_eval_samples, temperature, print_output) 189 | if jax.process_index() == 0: 190 | wandb.log(eval_metrics, step) 191 | wandb.finish() 192 | 193 | 194 | if __name__ == '__main__': 195 | fire.Fire(finetune) 196 | -------------------------------------------------------------------------------- /pretraining/train.ipynb: -------------------------------------------------------------------------------- 1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"machine_shape":"hm","gpuType":"V6E1","authorship_tag":"ABX9TyMRGRGu75ydq/VvCZ6EsR1R"},"kernelspec":{"name":"python3","display_name":"Python 
3"},"language_info":{"name":"python"},"accelerator":"TPU","widgets":{"application/vnd.jupyter.widget-state+json":{"a4cbfff5a22645b887cc78386dddb8aa":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_21ebe59ca40a497cb618eff449830528","IPY_MODEL_b726f09a12a24c9fa93941ee6bab9505","IPY_MODEL_5981eaf73d654dc8aa6f2c224f5c5106"],"layout":"IPY_MODEL_e3fe43f6c0024b38b81cf75bcd992083"}},"21ebe59ca40a497cb618eff449830528":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_7de3773c77994bf0b08aebee44784f4d","placeholder":"​","style":"IPY_MODEL_91bbc255d1854e5db0b31493bb127af7","value":"100%"}},"b726f09a12a24c9fa93941ee6bab9505":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_967e56177cac4da487b6d96c2a39dfde","max":2412735,"min":0,"orientation":"horizontal","style":"IPY_MODEL_90d961f71acd473cb66a119743aa305f","value":2412735}},"5981eaf73d654dc8aa6f2c224f5c5106":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_75f57c58aa0341bb98dc698d5c410453","placeholder":"​","style":"IPY_MODEL_f694d91c5f844468ab6118ac0624eec0","value":" 2412735/2412735 [3:45:43<00:00, 179.00it/s, 
loss=3.53]"}},"e3fe43f6c0024b38b81cf75bcd992083":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"7de3773c77994bf0b08aebee44784f4d":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"91bbc255d1854e5db0b31493bb127af7":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"967e56177cac4da487b6d96c2a39dfde":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"90d961f71acd473cb
66a119743aa305f":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"75f57c58aa0341bb98dc698d5c410453":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"f694d91c5f844468ab6118ac0624eec0":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}}}}},"cells":[{"cell_type":"code","source":["# setup\n","! uv pip install fire hydra-core wandb \"jax[tpu]\" flax==0.12.0 -q\n","! git clone --depth=1 https://github.com/martin-marek/batch-size.git\n","! 
python /content/batch-size/pretraining/download_fineweb.py 'fineweb' 26 # 2.6B tokens"],"metadata":{"id":"-9CnBf8GOXlT"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# imports\n","%cd /content/batch-size/pretraining\n","import os\n","from hydra import compose, initialize_config_dir\n","from configs import resolver_setup\n","from train import train_and_evaluate"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"FUThBSWpjvJR","executionInfo":{"status":"ok","timestamp":1755253454774,"user_tz":-60,"elapsed":12,"user":{"displayName":"Martin Marek","userId":"04932572550491068578"}},"outputId":"3735b29f-fcd1-4047-aa96-fd6567995769"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["/content/batch-size/pretraining\n"]}]},{"cell_type":"code","source":["# training\n","with initialize_config_dir(f'{os.getcwd()}/configs', version_base=None):\n"," c = compose(config_name='base', overrides=['+model=gpt2s', '+dataset=fw_gpt2'])\n","c.opt.optimizer = 'sgd'\n","c.opt.batch_size = 1\n","c.opt.peak_lr = 0.2\n","c.wandb_mode = 'disabled' # disable Weights & Biases\n","train_and_evaluate(c)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":210,"referenced_widgets":["a4cbfff5a22645b887cc78386dddb8aa","21ebe59ca40a497cb618eff449830528","b726f09a12a24c9fa93941ee6bab9505","5981eaf73d654dc8aa6f2c224f5c5106","e3fe43f6c0024b38b81cf75bcd992083","7de3773c77994bf0b08aebee44784f4d","91bbc255d1854e5db0b31493bb127af7","967e56177cac4da487b6d96c2a39dfde","90d961f71acd473cb66a119743aa305f","75f57c58aa0341bb98dc698d5c410453","f694d91c5f844468ab6118ac0624eec0"]},"id":"l70D5oXzhGWn","executionInfo":{"status":"ok","timestamp":1755267054988,"user_tz":-60,"elapsed":5525283,"user":{"displayName":"Martin Marek","userId":"04932572550491068578"}},"outputId":"53864ecb-4ee5-4f29-d1c4-3d7bc6ebb16d"},"execution_count":null,"outputs":[{"metadata":{"tags":null},"name":"stdout","output_type":"stream","text":["sharding mesh: data=1, model=1\n","initializing model...\n","n_param_nonembed=84_934_656\n","n_param_embed=38_597_376\n","n_param_actual=162_129_408\n","getting dataset size...\n","reading data...\n","shuffling data...\n","splitting data...\n"]},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"a4cbfff5a22645b887cc78386dddb8aa","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/2412735 [00:00 optax.Params: 59 | """Extends optax.apply_updates with stochastic rounding.""" 60 | keys = otu.tree_split_key_like(key, params) 61 | def leaf_update(p, u, key): 62 | if p is None: return None 63 | param_dtype = jnp.asarray(p).dtype 64 | if stochastic_round: 65 | p = p.astype(jnp.float32) + u 66 | p = utils.to_bf16_stochastic(key, p) 67 | else: 68 | p += u 69 | return p.astype(param_dtype) 70 | return jax.tree.map(leaf_update, params, updates, keys, is_leaf=lambda x: x is None) 71 | 72 | 73 | def get_optimizer(c: DictConfig, num_opt_steps: int, tokens_per_opt_step: int): 74 | 75 | # get LR 76 | assert (c.peak_lr is not None) ^ ((c.peak_lr_scaled is not None) & (c.peak_lr_scaling is not None)) 77 | if c.peak_lr is None: 78 | c.peak_lr = c.peak_lr_scaling * c.peak_lr_scaled 79 | 80 | # get schedule 81 | warmup_steps = int(c.warmup_frac * num_opt_steps) 82 | lr_schedule = optax.schedules.warmup_cosine_decay_schedule(0, c.peak_lr, warmup_steps, num_opt_steps) 83 | 84 | # convert (t1 <-> b1), (t2 <-> b2) 85 | assert (c.b1 is None) | (c.t1 is None) # at most one can be specified in config 86 | assert (c.b2 is None) | (c.t2 is None) # at most 
one can be specified in config 87 | assert (c.muon_b1 is None) | (c.muon_t1 is None) # at most one can be specified in config 88 | if c.b1 is None and c.t1 is not None: c.b1 = float(utils.halflife_to_decay(c.t1, tokens_per_opt_step)) 89 | if c.b2 is None and c.t2 is not None: c.b2 = float(utils.halflife_to_decay(c.t2, tokens_per_opt_step)) 90 | if c.t1 is None and c.b1 is not None: c.t1 = float(utils.decay_to_halflife(c.b1, tokens_per_opt_step)) 91 | if c.t2 is None and c.b2 is not None: c.t2 = float(utils.decay_to_halflife(c.b2, tokens_per_opt_step)) 92 | if c.muon_b1 is None and c.muon_t1 is not None: c.muon_b1 = float(utils.halflife_to_decay(c.muon_t1, tokens_per_opt_step)) 93 | if c.muon_t1 is None and c.muon_b1 is not None: c.muon_t1 = float(utils.decay_to_halflife(c.muon_b1, tokens_per_opt_step)) 94 | if c.b2_min is not None: c.b2 = max(c.b2, c.b2_min) 95 | 96 | if c.optimizer in ('sgd', 'signum'): 97 | assert c.b2 is None 98 | assert c.t2 is None 99 | assert c.weight_decay == 0 100 | signed = c.optimizer == 'signum' 101 | optimizer = sgd(lr_schedule, c.b1, signed) 102 | 103 | if c.optimizer == 'adamw': 104 | assert c.b1 is not None 105 | assert c.b2 is not None 106 | optimizer = optax.adamw(lr_schedule, c.b1, c.b2, weight_decay=c.weight_decay) 107 | 108 | if c.optimizer == 'adafactor': 109 | assert c.b1 is None 110 | assert c.b2 is not None 111 | assert c.weight_decay == 0 112 | optimizer = adafactor(lr_schedule, decay_rate=c.b2) 113 | 114 | if c.optimizer == 'muon': 115 | assert c.b1 is not None 116 | assert c.b2 is not None 117 | assert c.muon_lr is not None 118 | assert c.muon_b1 is not None 119 | muon_lr = optax.schedules.warmup_cosine_decay_schedule(0, c.muon_lr, warmup_steps, num_opt_steps) 120 | optimizer = muon(muon_lr, c.muon_b1, lr_schedule, c.b1, c.b2) 121 | 122 | if c.clip_by_global_norm is not None: 123 | optimizer = optax.chain(optax.clip_by_global_norm(c.clip_by_global_norm), optimizer) 124 | 125 | return optimizer 126 | 127 | 128 | def sgd( 129 | learning_rate: optax.ScalarOrSchedule, 130 | b1: Optional[float] = None, 131 | signed = False, 132 | ) -> optax.GradientTransformation: 133 | return optax.chain( 134 | optax.trace(decay=b1) if b1 is not None else optax.identity(), 135 | optax.scale_by_sign() if signed else optax.identity(), 136 | optax.scale_by_learning_rate(learning_rate), 137 | ) 138 | 139 | 140 | def orthogonalize_via_newton_schulz( 141 | x: jax.Array, 142 | ns_coeffs: jax.Array, 143 | ns_steps: int = 5, 144 | eps: float = 1e-8, 145 | ) -> jax.Array: 146 | # https://github.com/google-deepmind/optax/blob/main/optax/contrib/_muon.py 147 | if x.ndim < 2: 148 | raise ValueError(f'Input must have >= 2 dims, got {x.shape}') 149 | if ns_coeffs.shape != (3,): 150 | raise ValueError(f'ns_coeffs must have shape (3,), got {ns_coeffs}') 151 | def newton_schulz_iterator(x: jax.Array, coeffs: jax.Array) -> jax.Array: 152 | x_mT = jnp.swapaxes(x, -2, -1) # <-- changed (matrix transpose last 2 dims) 153 | a = x @ x_mT # <-- changed (use matrix transpose) 154 | b = coeffs[1] * a + coeffs[2] * a @ a 155 | return coeffs[0] * x + b @ x 156 | transposed = False 157 | if x.shape[-2] > x.shape[-1]: # <-- changed (check last 2 dims) 158 | x = jnp.swapaxes(x, -2, -1) # <-- changed (transpose last 2 dims) 159 | transposed = True 160 | x /= (jnp.linalg.norm(x, axis=(-2, -1), keepdims=True) + eps) # <-- changed (normalize each matrix slice) 161 | ns_coeffs = ns_coeffs.astype(x.dtype) 162 | x = jax.lax.fori_loop(0, ns_steps, lambda _, x: newton_schulz_iterator(x, ns_coeffs), x) 163 | 
if transposed: x = jnp.swapaxes(x, -2, -1) # <-- changed (transpose last 2 dims) 164 | return x 165 | 166 | 167 | class MuonState(NamedTuple): 168 | """State for the Adam algorithm.""" 169 | count: jax.Array # shape=(), dtype=jnp.int32. 170 | mu: optax.Updates 171 | ns_coeffs: jax.Array # shape=(), dtype=jnp.int32. 172 | 173 | 174 | def scale_by_muon( 175 | ns_coeffs: tuple = (3.4445, -4.7750, 2.0315), 176 | ns_steps: int = 5, 177 | beta: float = 0.95, 178 | eps: float = 1e-8, 179 | ) -> optax.GradientTransformation: 180 | # https://github.com/google-deepmind/optax/blob/main/optax/contrib/_muon.py 181 | 182 | def init_fn(params): 183 | mu = otu.tree_zeros_like(params) # First moment 184 | return MuonState(jnp.zeros([], jnp.int32), mu, jnp.asarray(ns_coeffs)) 185 | 186 | def update_fn(updates, state, params=None): 187 | del params 188 | mu = otu.tree_update_moment(updates, state.mu, beta, 1) 189 | count_inc = optax.safe_increment(state.count) 190 | mu_hat = otu.tree_bias_correction(mu, beta, count_inc) 191 | # Apply Newton-schulz orthogonalization. 192 | updates = jax.tree.map(lambda x: orthogonalize_via_newton_schulz(x, state.ns_coeffs, ns_steps, eps), mu_hat) 193 | updates = jax.tree.map(lambda x: jnp.sqrt(jnp.maximum(1, x.shape[-1] / x.shape[-2])) * x, updates) 194 | return updates, MuonState(count_inc, mu, state.ns_coeffs) 195 | 196 | return optax.GradientTransformation(init_fn, update_fn) 197 | 198 | 199 | def muon( 200 | learning_rate: float, 201 | muon_b1: float, 202 | adam_lr: float, 203 | adam_b1: float, 204 | adam_b2: float, 205 | ) -> optax.GradientTransformation: 206 | return optax.multi_transform( 207 | transforms={ 208 | 'muon': optax.chain( 209 | scale_by_muon(beta=muon_b1), 210 | optax.scale_by_learning_rate(learning_rate), 211 | ), 212 | 'adam': optax.adamw(adam_lr, adam_b1, adam_b2) 213 | }, 214 | param_labels=lambda params: jax.tree.map_with_path( 215 | lambda path, val: 'adam' if 'embed' in jax.tree_util.keystr(path) else 'muon', params 216 | ), 217 | ) 218 | 219 | 220 | def adafactor( 221 | learning_rate: optax.ScalarOrSchedule, 222 | decay_rate: float = 0.8, 223 | clipping_threshold: Optional[float] = 1.0, 224 | min_dim_size_to_factor: int = 128, 225 | ) -> optax.GradientTransformation: 226 | """ 227 | Adafactor reimplemented to use float32 state, regardless of param dtype. 
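    The second-moment accumulator is factored into per-row and per-column statistics for sufficiently large weight matrices (controlled by `min_dim_size_to_factor`), so optimizer state scales roughly as O(n+m) rather than O(n*m) per n-by-m matrix (Shazeer & Stern, 2018).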
228 | https://github.com/google-deepmind/optax/blob/8973bb3c77b07850737246815f1c028b53fffbe0/optax/_src/alias.py#L225#L327 229 | """ 230 | return optax.chain( 231 | factorized.scale_by_factored_rms(decay_rate=decay_rate, min_dim_size_to_factor=min_dim_size_to_factor), 232 | optax.clip_by_block_rms(clipping_threshold) if clipping_threshold is not None else optax.identity(), 233 | optax.scale_by_learning_rate(learning_rate), 234 | optax.scale_by_param_block_rms(), 235 | ) 236 | -------------------------------------------------------------------------------- /finetuning/gemma.py: -------------------------------------------------------------------------------- 1 | """based on https://github.com/google/flax/tree/main/examples/gemma""" 2 | 3 | import os 4 | import dataclasses 5 | from itertools import cycle 6 | from contextlib import redirect_stderr 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | from flax import nnx 11 | from jax.sharding import PartitionSpec as P, NamedSharding 12 | from rope import apply_rope 13 | 14 | 15 | @dataclasses.dataclass(frozen=True) 16 | class GemmaConfig: 17 | num_layers: int 18 | embed_dim: int 19 | hidden_dim: int 20 | num_heads: int 21 | head_dim: int 22 | num_kv_heads: int 23 | query_pre_attn_scalar: float 24 | vocab_size: int = 262_144 25 | local_base_frequency: int = 10_000 26 | global_base_frequency: int = 1_000_000 27 | local_scale_factor: float = 1.0 28 | global_scale_factor: float = 1.0 29 | sliding_window_size: int | None = None 30 | attention_pattern: tuple[str] = (*(['sliding']*5), 'global') 31 | activ_dtype: str = 'float32' 32 | param_dtype: str = 'float32' 33 | remat: bool = False 34 | 35 | 36 | @classmethod 37 | def gemma3_1b(cls, param_dtype='float32', remat=False): 38 | return cls( 39 | num_layers=26, 40 | embed_dim=1152, 41 | hidden_dim=6912, 42 | num_heads=4, 43 | head_dim=256, 44 | num_kv_heads=1, 45 | query_pre_attn_scalar=256**-0.5, # 1/sqrt(head_dim) 46 | sliding_window_size=512, 47 | global_scale_factor=1.0, 48 | param_dtype=param_dtype, 49 | remat=remat, 50 | ) 51 | 52 | @classmethod 53 | def gemma3_4b(cls, param_dtype='float32', remat=False): 54 | return cls( 55 | num_layers=34, 56 | embed_dim=2560, 57 | hidden_dim=10240, 58 | num_heads=8, 59 | head_dim=256, 60 | num_kv_heads=4, 61 | query_pre_attn_scalar=256**-0.5, # 1/sqrt(head_dim) 62 | sliding_window_size=1024, 63 | global_scale_factor=8.0, 64 | param_dtype=param_dtype, 65 | remat=remat, 66 | ) 67 | 68 | @classmethod 69 | def gemma3_12b(cls, param_dtype='float32', remat=False): 70 | return cls( 71 | num_layers=48, 72 | embed_dim=3840, 73 | hidden_dim=15360, 74 | num_heads=16, 75 | head_dim=256, 76 | num_kv_heads=8, 77 | query_pre_attn_scalar=256**-0.5, # 1/sqrt(head_dim) 78 | sliding_window_size=1024, 79 | global_scale_factor=8.0, 80 | param_dtype=param_dtype, 81 | remat=remat, 82 | ) 83 | 84 | @classmethod 85 | def gemma3_27b(cls, param_dtype='float32', remat=False): 86 | return cls( 87 | num_layers=62, 88 | embed_dim=5376, 89 | hidden_dim=21504, 90 | num_heads=32, 91 | head_dim=128, 92 | num_kv_heads=16, 93 | query_pre_attn_scalar=(5376/32)**-0.5, # 1/sqrt(embed_dim / num_heads) 94 | sliding_window_size=1024, 95 | global_scale_factor=8.0, 96 | param_dtype=param_dtype, 97 | remat=remat, 98 | ) 99 | 100 | 101 | class Gemma(nnx.Module): 102 | def __init__(self, c: GemmaConfig, rngs: nnx.Rngs): 103 | self.in_embed = nnx.Embed(c.vocab_size, c.embed_dim, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs) 104 | self.layers = nnx.List( 105 | TransformerBlock( 106 | 
num_heads=c.num_heads, 107 | num_kv_heads=c.num_kv_heads, 108 | embed_dim=c.embed_dim, 109 | head_dim=c.head_dim, 110 | hidden_dim=c.hidden_dim, 111 | query_pre_attn_scalar=c.query_pre_attn_scalar, 112 | sliding_window_size=c.sliding_window_size if attn_type == 'sliding' else None, 113 | rope_base_frequency=c.local_base_frequency if attn_type == 'sliding' else c.global_base_frequency, 114 | rope_scale_factor=c.local_scale_factor if attn_type == 'sliding' else c.global_scale_factor, 115 | activ_dtype = c.activ_dtype, 116 | param_dtype = c.param_dtype, 117 | rngs=rngs, 118 | ) for _, attn_type in zip(range(c.num_layers), cycle(c.attention_pattern)) 119 | ) 120 | self.final_norm = nnx.RMSNorm(c.embed_dim, dtype=c.activ_dtype, param_dtype=c.param_dtype, rngs=rngs) 121 | self.remat = c.remat 122 | 123 | def __call__( 124 | self, 125 | tokens, # [B, T] 126 | kv_cache = {}, # [B, S] 127 | positions = None, # [B, S] 128 | attn_mask = None, # [B, S, S] 129 | ): 130 | V, D = self.in_embed.embedding.value.shape 131 | x = self.in_embed(tokens) * jnp.sqrt(D) # [B, T, D] 132 | 133 | for i, layer in enumerate(self.layers): 134 | layer_fn = jax.remat(layer) if self.remat else layer 135 | x, kv_cache[i] = layer_fn(x, kv_cache.get(i), positions, attn_mask) # [B, T, D] 136 | 137 | x = self.final_norm(x) 138 | logits = jnp.dot(x, self.in_embed.embedding.value.T) # [B, T, V] 139 | 140 | return logits, kv_cache 141 | 142 | 143 | def init_kv_cache(self, batch_size, max_seq_len): 144 | kv_cache = { 145 | i: layer.attn.init_kv_cache(batch_size, max_seq_len) 146 | for i, layer in enumerate(self.layers) 147 | } 148 | return kv_cache 149 | 150 | 151 | class Attention(nnx.Module): 152 | def __init__( 153 | self, 154 | num_heads: int, 155 | num_kv_heads: int, 156 | embed_dim: int, 157 | head_dim: int, 158 | query_pre_attn_scalar: float, 159 | rope_base_frequency: int, 160 | rope_scale_factor: float, 161 | sliding_window_size: int | None = None, 162 | activ_dtype = 'float32', 163 | param_dtype = 'float32', 164 | *, rngs: nnx.Rngs, 165 | ): 166 | self.query_pre_attn_scalar = query_pre_attn_scalar 167 | self.sliding_window_size = sliding_window_size 168 | self.rope_base_frequency = rope_base_frequency 169 | self.rope_scale_factor = rope_scale_factor 170 | self.attn_vec_einsum = nnx.Einsum(einsum_str='BTNH,NHD->BTD', kernel_shape=(num_heads, head_dim, embed_dim), dtype=activ_dtype, param_dtype=param_dtype, rngs=rngs) 171 | self.q_einsum = nnx.Einsum(einsum_str='BTD,NDH->BTNH', kernel_shape=(num_heads, embed_dim, head_dim), dtype=activ_dtype, param_dtype=param_dtype, rngs=rngs) 172 | self.kv_einsum = nnx.Einsum(einsum_str='BSD,CKDH->CBSKH', kernel_shape=(2, num_kv_heads, embed_dim, head_dim), dtype=activ_dtype, param_dtype=param_dtype, rngs=rngs) 173 | self._query_norm = nnx.RMSNorm(head_dim, dtype=activ_dtype, param_dtype=param_dtype, rngs=rngs) 174 | self._key_norm = nnx.RMSNorm(head_dim, dtype=activ_dtype, param_dtype=param_dtype, rngs=rngs) 175 | 176 | def __call__(self, 177 | x, # [B, T, D] 178 | kv_cache = None, # [B, S] 179 | positions = None, # [B, S] 180 | attn_mask = None, # [B, T, S] 181 | ): 182 | B, T, D = x.shape 183 | N, D, H = self.q_einsum.kernel.value.shape 184 | S = T if kv_cache is None else kv_cache['v'].shape[1] 185 | 186 | # qkv projection 187 | query = self.q_einsum(x) # [B, T, N, H] 188 | key, value = self.kv_einsum(x) # [B, T, N, H] 189 | 190 | # qk norm 191 | query = self._query_norm(query) 192 | key = self._key_norm(key) 193 | 194 | # get token indices 195 | if positions is None: 196 | # training 
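            # (no KV cache means S == T, so positions are simply 0..T-1 broadcast across the batch)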
197 | if kv_cache is None: 198 | positions = jnp.broadcast_to(jnp.arange(T)[None, :], [B, S]) 199 | # sampling 200 | else: 201 | positions = jnp.full([B, 1], kv_cache['end_idx']) # [B, 1] 202 | 203 | # apply positional embeddings 204 | query = apply_rope(query, positions, self.rope_base_frequency, self.rope_scale_factor) 205 | key = apply_rope(key, positions, self.rope_base_frequency, self.rope_scale_factor) 206 | 207 | # load kv cache 208 | if kv_cache is not None: 209 | cache_dtype = kv_cache['k'].dtype 210 | key = kv_cache['k'].at[:, kv_cache['end_idx'], :, :].set(key[:, 0].astype(cache_dtype)) # [B, S, N, H] 211 | value = kv_cache['v'].at[:, kv_cache['end_idx'], :, :].set(value[:, 0].astype(cache_dtype)) # [B, S, N, H] 212 | query = query.astype(cache_dtype) 213 | 214 | # compute attention mask [B, T, S] 215 | if attn_mask is None: 216 | # if training, use lower triangular mask 217 | if kv_cache is None: 218 | attn_mask = jnp.tri(T, dtype=jnp.bool_)[None] # [B, T, S] 219 | # if sampling, all cached tokens are visible 220 | else: 221 | attn_mask = (jnp.arange(S) <= kv_cache['end_idx'])[None, None] # [B, 1, S] 222 | 223 | # add window to attention mask (TODO) 224 | if self.sliding_window_size is not None: 225 | offset = 0 if kv_cache is None else kv_cache['end_idx'] 226 | t, s = jnp.mgrid[0:T, 0:S] 227 | sliding_mask = (t - self.sliding_window_size + 1 + offset) <= s 228 | attn_mask &= sliding_mask[None] 229 | 230 | # gqa attention 231 | attn_mask = jnp.broadcast_to(attn_mask[:, None, :, :], [B, N, T, S]) 232 | encoded = jax.nn.dot_product_attention(query, key, value, mask=attn_mask, scale=self.query_pre_attn_scalar) 233 | 234 | # output projection 235 | attn_output = self.attn_vec_einsum(encoded) 236 | 237 | # update kv cache 238 | if kv_cache is not None: 239 | kv_cache = {'k': key, 'v': value, 'end_idx': kv_cache['end_idx'] + T} 240 | 241 | return attn_output, kv_cache 242 | 243 | def init_kv_cache(self, batch_size, max_seq_len): 244 | w = self.kv_einsum.kernel.value 245 | _, num_kv_heads, _, head_dim = w.shape 246 | sharding = NamedSharding(w.sharding.mesh, P('data', None, 'model', None)) if hasattr(w, 'sharding') else None 247 | kv_cache = { 248 | 'k': jnp.zeros((batch_size, max_seq_len, num_kv_heads, head_dim), dtype=jnp.bfloat16, device=sharding), 249 | 'v': jnp.zeros((batch_size, max_seq_len, num_kv_heads, head_dim), dtype=jnp.bfloat16, device=sharding), 250 | 'end_idx': jnp.array(0, dtype=jnp.int32), 251 | } 252 | return kv_cache 253 | 254 | 255 | class MLP(nnx.Module): 256 | def __init__(self, embed_dim, hidden_dim, activ_dtype='float32', param_dtype='float32', *, rngs): 257 | self.gate_proj = nnx.Linear(embed_dim, hidden_dim, use_bias=False, dtype=activ_dtype, param_dtype=param_dtype, kernel_init=jax.nn.initializers.normal(), rngs=rngs) 258 | self.up_proj = nnx.Linear(embed_dim, hidden_dim, use_bias=False, dtype=activ_dtype, param_dtype=param_dtype, kernel_init=jax.nn.initializers.normal(), rngs=rngs) 259 | self.down_proj = nnx.Linear(hidden_dim, embed_dim, use_bias=False, dtype=activ_dtype, param_dtype=param_dtype, kernel_init=jax.nn.initializers.normal(), rngs=rngs) 260 | 261 | def __call__(self, x): 262 | activations = nnx.gelu(self.gate_proj(x)) * self.up_proj(x) 263 | outputs = self.down_proj(activations) 264 | return outputs 265 | 266 | 267 | class TransformerBlock(nnx.Module): 268 | def __init__( 269 | self, 270 | num_heads: int, 271 | num_kv_heads: int, 272 | embed_dim: int, 273 | head_dim: int, 274 | hidden_dim: int, 275 | query_pre_attn_scalar: float, 276 | 
rope_base_frequency: int, 277 | rope_scale_factor: float, 278 | sliding_window_size: int | None = None, 279 | activ_dtype = 'float32', 280 | param_dtype = 'float32', 281 | *, rngs: nnx.Rngs, 282 | ): 283 | self.attn = Attention( 284 | num_heads=num_heads, num_kv_heads=num_kv_heads, 285 | embed_dim=embed_dim, head_dim=head_dim, query_pre_attn_scalar=query_pre_attn_scalar, 286 | rope_base_frequency=rope_base_frequency, rope_scale_factor=rope_scale_factor, 287 | sliding_window_size=sliding_window_size, param_dtype=param_dtype, activ_dtype=activ_dtype, rngs=rngs, 288 | ) 289 | self.mlp = MLP(embed_dim, hidden_dim, activ_dtype=activ_dtype, param_dtype=param_dtype, rngs=rngs) 290 | self.pre_attention_norm = nnx.RMSNorm(embed_dim, dtype=activ_dtype, param_dtype=param_dtype, rngs=rngs) 291 | self.post_attention_norm = nnx.RMSNorm(embed_dim, dtype=activ_dtype, param_dtype=param_dtype, rngs=rngs) 292 | self.pre_ffw_norm = nnx.RMSNorm(embed_dim, dtype=activ_dtype, param_dtype=param_dtype, rngs=rngs) 293 | self.post_ffw_norm = nnx.RMSNorm(embed_dim, dtype=activ_dtype, param_dtype=param_dtype, rngs=rngs) 294 | 295 | def __call__(self, 296 | x, # [B, T, D] 297 | kv_cache = None, # [B, S] 298 | positions = None, # [B, S] 299 | attn_mask = None, # [B, S, S] 300 | ): 301 | 302 | # attention 303 | attn_inputs = self.pre_attention_norm(x) 304 | attn_output, kv_cache = self.attn(attn_inputs, kv_cache, positions, attn_mask) 305 | attn_output = self.post_attention_norm(attn_output) 306 | x += attn_output 307 | 308 | # MLP 309 | ffw_inputs = self.pre_ffw_norm(x) 310 | ffw_outputs = self.mlp(ffw_inputs) 311 | ffw_outputs = self.post_ffw_norm(ffw_outputs) 312 | x += ffw_outputs 313 | 314 | return x, kv_cache 315 | 316 | 317 | def load_pretrained(model_variant, mesh=None, param_dtype='float32', remat=False): 318 | import kagglehub 319 | import sentencepiece as spm 320 | import orbax.checkpoint as ocp 321 | 322 | # helpers 323 | flatten_path = lambda path: jax.tree_util.keystr(path, simple=True, separator='/') 324 | flatten_tree = lambda tree: {flatten_path(path):v for path, v in jax.tree.leaves_with_path(tree)} 325 | 326 | # download weights 327 | with open(os.devnull, 'w') as f, redirect_stderr(f): # supress progress bar 328 | weights_dir = kagglehub.model_download(f'google/gemma-3/flax/{model_variant}') 329 | ckpt_path = f'{weights_dir}/{model_variant}' 330 | vocab_path = f'{weights_dir}/tokenizer.model' 331 | 332 | # load tokenizer 333 | vocab = spm.SentencePieceProcessor() 334 | vocab.Load(vocab_path) 335 | 336 | # load abstract model 337 | model_architecture = '_'.join(model_variant.split('-')[:2]) 338 | model_config = getattr(GemmaConfig, model_architecture)(param_dtype, remat) 339 | model = nnx.eval_shape(lambda: Gemma(model_config, rngs=nnx.Rngs(0))) 340 | model_state = nnx.state(model) 341 | 342 | # load checkpoint metadata 343 | checkpointer = ocp.Checkpointer(ocp.StandardCheckpointHandler()) 344 | checkpoint = checkpointer.metadata(ckpt_path).item_metadata 345 | 346 | # add checkpoint sharding annotations 347 | def add_sharding(path, v): 348 | key = flatten_path(path) 349 | pspec = None 350 | if 'input_embedding' in key: pspec = P('data', 'model') # [V, D] 351 | if '_norm' in key: pspec = P('data') # [D | H] 352 | if 'attn_vec_einsum' in key: pspec = P('model', None, 'data') # [N, H, D] 353 | if 'kv_einsum' in key: pspec = P(None, 'model', 'data', None) # [2, n, D, H] 354 | if 'q_einsum' in key: pspec = P('model', 'data', None) # [N, D, H] 355 | if 'mlp/linear' in key: pspec = P('model', 'data') # [F, D] 
356 | if 'mlp/gating_einsum' in key: pspec = P(None, 'model', 'data') # [2, F, D] 357 | # if pspec is None: print(f'WARNING: {key} has no sharding!') 358 | sharding = None if (pspec is None or mesh is None) else NamedSharding(mesh, pspec) 359 | return jax.ShapeDtypeStruct(v.shape, param_dtype, sharding=sharding) 360 | checkpoint = jax.tree.map_with_path(add_sharding, checkpoint) 361 | 362 | # load checkpoint weights 363 | checkpoint = checkpointer.restore(ckpt_path, checkpoint) 364 | 365 | # flatten the checkpoint keys 366 | checkpoint = flatten_tree(checkpoint) 367 | 368 | # adjust checkpoint weights to match NNX format 369 | for key in list(checkpoint.keys()): 370 | if 'scale' in key: 371 | checkpoint[key] += 1 372 | if 'gating_einsum' in key: 373 | checkpoint[key.replace('gating_einsum', 'gate_proj')] = checkpoint[key][0].T 374 | checkpoint[key.replace('gating_einsum', 'up_proj')] = checkpoint[key][1].T 375 | del checkpoint[key] 376 | 377 | # transfer weights to model, mapping model layer keys to checkpoint keys 378 | def get_weights(path, v): 379 | key = flatten_path(path) 380 | key = 'transformer/' + key 381 | key = key.replace('/value', '') 382 | key = key.replace('layers/', 'layer_') 383 | key = key.replace('kernel', 'w') 384 | key = key.replace('in_embed/embedding', 'embedder/input_embedding') 385 | key = key.replace('mlp/down_proj', 'mlp/linear') 386 | return checkpoint[key] 387 | model_state = jax.tree.map_with_path(get_weights, model_state) 388 | nnx.update(model, model_state) 389 | 390 | return model, vocab 391 | -------------------------------------------------------------------------------- /utils/memory_measure.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "colab": { 8 | "base_uri": "https://localhost:8080/" 9 | }, 10 | "executionInfo": { 11 | "elapsed": 219, 12 | "status": "ok", 13 | "timestamp": 1749056727116, 14 | "user": { 15 | "displayName": "Martin Marek", 16 | "userId": "04932572550491068578" 17 | }, 18 | "user_tz": -60 19 | }, 20 | "id": "laglInP5rzv8", 21 | "outputId": "5dbc6cd4-0244-442e-a52e-246ea2bd14f9" 22 | }, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "Thu Aug 7 13:51:18 2025 \n", 29 | "+-----------------------------------------------------------------------------------------+\n", 30 | "| NVIDIA-SMI 570.133.20 Driver Version: 570.133.20 CUDA Version: 12.8 |\n", 31 | "|-----------------------------------------+------------------------+----------------------+\n", 32 | "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n", 33 | "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n", 34 | "| | | MIG M. 
|\n", 35 | "|=========================================+========================+======================|\n", 36 | "| 0 NVIDIA B200 On | 00000000:DC:00.0 Off | 0 |\n", 37 | "| N/A 22C P0 140W / 1000W | 0MiB / 183359MiB | 0% Default |\n", 38 | "| | | Disabled |\n", 39 | "+-----------------------------------------+------------------------+----------------------+\n", 40 | " \n", 41 | "+-----------------------------------------------------------------------------------------+\n", 42 | "| Processes: |\n", 43 | "| GPU GI CI PID Type Process name GPU Memory |\n", 44 | "| ID ID Usage |\n", 45 | "|=========================================================================================|\n", 46 | "| No running processes found |\n", 47 | "+-----------------------------------------------------------------------------------------+\n" 48 | ] 49 | } 50 | ], 51 | "source": [ 52 | "! nvidia-smi" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 2, 58 | "metadata": { 59 | "executionInfo": { 60 | "elapsed": 5727, 61 | "status": "ok", 62 | "timestamp": 1749056737071, 63 | "user": { 64 | "displayName": "Martin Marek", 65 | "userId": "04932572550491068578" 66 | }, 67 | "user_tz": -60 68 | }, 69 | "id": "euko7CXVK-Jy" 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "import math\n", 74 | "import torch\n", 75 | "import torch.nn as nn\n", 76 | "from torch.nn import functional as F\n", 77 | "from torch.utils.checkpoint import checkpoint\n", 78 | "from torch import Tensor\n", 79 | "from omegaconf.dictconfig import DictConfig\n", 80 | "torch.backends.cuda.matmul.allow_tf32 = True" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 3, 86 | "metadata": { 87 | "executionInfo": { 88 | "elapsed": 3, 89 | "status": "ok", 90 | "timestamp": 1749056737087, 91 | "user": { 92 | "displayName": "Martin Marek", 93 | "userId": "04932572550491068578" 94 | }, 95 | "user_tz": -60 96 | }, 97 | "id": "cYbgxeyHKhgs" 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "# helpers\n", 102 | "def sizeof_fmt(num):\n", 103 | " for unit in (\"\", \"K\", \"M\", \"G\", \"T\"):\n", 104 | " if abs(num) < 1000:\n", 105 | " return f\"{num:.2f}{unit}B\"\n", 106 | " num /= 1000" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "# Muon implementation" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 4, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "class Muon(torch.optim.Optimizer):\n", 123 | " \"\"\"Muon that batches over >2D layers, based on https://github.com/KellerJordan/Muon/blob/master/muon.py\"\"\"\n", 124 | " def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):\n", 125 | " defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)\n", 126 | " super().__init__(params, defaults)\n", 127 | "\n", 128 | " @torch.no_grad()\n", 129 | " def step(self, closure=None):\n", 130 | "\n", 131 | " loss = None\n", 132 | " if closure is not None:\n", 133 | " with torch.enable_grad():\n", 134 | " loss = closure()\n", 135 | "\n", 136 | " for group in self.param_groups:\n", 137 | " for p in group[\"params\"]:\n", 138 | " if p.grad is None:\n", 139 | " continue\n", 140 | " state = self.state[p]\n", 141 | " if len(state) == 0:\n", 142 | " state[\"momentum_buffer\"] = torch.zeros_like(p)\n", 143 | " update = muon_update(p.grad, state[\"momentum_buffer\"], beta=group[\"momentum\"])\n", 144 | " p.mul_(1 - group[\"lr\"] * group[\"weight_decay\"])\n", 145 | " # p.add_(update.reshape(p.shape), 
alpha=-group[\"lr\"]) # <-- CHANGE: No longer need reshape as update preserves shape\n", 146 | " p.add_(update, alpha=-group[\"lr\"]) # <-- CHANGED\n", 147 | "\n", 148 | " return loss\n", 149 | "\n", 150 | "def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):\n", 151 | " momentum.lerp_(grad, 1 - beta)\n", 152 | " update = grad.lerp_(momentum, beta) if nesterov else momentum\n", 153 | " # if update.ndim == 4: # for the case of conv filters <-- CHANGE: Removed this block that flattens the tensor.\n", 154 | " # update = update.view(len(update), -1)\n", 155 | " update = zeropower_via_newtonschulz5(update, steps=ns_steps)\n", 156 | " # update *= max(1, grad.size(-2) / grad.size(-1))**0.5 <-- CHANGE: Swapped numerator/denominator to match JAX logic.\n", 157 | " update *= max(1, grad.size(-1) / grad.size(-2))**0.5 # <-- CHANGED\n", 158 | " return update\n", 159 | "\n", 160 | "def zeropower_via_newtonschulz5(G, steps: int):\n", 161 | " assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng\n", 162 | " a, b, c = (3.4445, -4.7750, 2.0315)\n", 163 | " X = G.bfloat16()\n", 164 | " if G.size(-2) > G.size(-1):\n", 165 | " X = X.mT\n", 166 | "\n", 167 | " # Ensure spectral norm is at most 1\n", 168 | " X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)\n", 169 | " # Perform the NS iterations\n", 170 | " for _ in range(steps):\n", 171 | " A = X @ X.mT\n", 172 | " B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng\n", 173 | " X = a * X + B @ X\n", 174 | " \n", 175 | " if G.size(-2) > G.size(-1):\n", 176 | " X = X.mT\n", 177 | " return X" 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "# RoPE implementation" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 5, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "def apply_rope(\n", 194 | " inputs: Tensor, # [B, N, T, H]\n", 195 | " positions: Tensor, # [B, T]\n", 196 | " max_wavelength: int = 10_000,\n", 197 | " scale_factor: float = 1.0,\n", 198 | ") -> Tensor:\n", 199 | " \"\"\"Applies RoPE.\"\"\"\n", 200 | " B, N, T, H = inputs.shape\n", 201 | " device = inputs.device\n", 202 | " if scale_factor < 1.0:\n", 203 | " raise ValueError(f'scale_factor must be >= 1.0, got {scale_factor}')\n", 204 | "\n", 205 | " fraction = 2 * torch.arange(0, H // 2, device=device) / H # [H/2]\n", 206 | " timescale = max_wavelength**fraction # [H/2]\n", 207 | "\n", 208 | " sinusoid_inp = (positions[:, :, None] / timescale[None, None, :]) # [B, T, H/2]\n", 209 | " sinusoid_inp = sinusoid_inp[:, None, :, :] # [B, 1, T, H/2]\n", 210 | " sinusoid_inp /= scale_factor # [B, 1, T, H/2]\n", 211 | "\n", 212 | " sin = torch.sin(sinusoid_inp) # [B, 1, T, H/2]\n", 213 | " cos = torch.cos(sinusoid_inp) # [B, 1, T, H/2]\n", 214 | "\n", 215 | " first_half, second_half = torch.chunk(inputs, 2, dim=-1) # [B, N, T, H/2]\n", 216 | " first_part = first_half * cos - second_half * sin # [B, N, T, H/2]\n", 217 | " second_part = second_half * cos + first_half * sin # [B, N, T, H/2]\n", 218 | " out = torch.concatenate([first_part, second_part], dim=-1) # [B, N, T, H]\n", 219 | " return out.to(inputs.dtype) # [B, N, T, H]" 220 | ] 221 | }, 222 | { 223 | "cell_type": "markdown", 224 | "metadata": {}, 225 | "source": [ 226 | "# Transformer implementation" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 6, 232 | "metadata": { 233 | 
"executionInfo": { 234 | "elapsed": 29, 235 | "status": "ok", 236 | "timestamp": 1749056737124, 237 | "user": { 238 | "displayName": "Martin Marek", 239 | "userId": "04932572550491068578" 240 | }, 241 | "user_tz": -60 242 | }, 243 | "id": "E3LhnfGIdYP9" 244 | }, 245 | "outputs": [], 246 | "source": [ 247 | "class TransformerDecoder(nn.Module):\n", 248 | " def __init__(self, c: DictConfig, dtype=None):\n", 249 | " super().__init__()\n", 250 | " dtype = getattr(torch, c.dtype)\n", 251 | " self.token_embed_in = nn.Embedding(c.V, c.D, dtype=dtype)\n", 252 | " self.token_embed_out = nn.Linear(c.D, c.V, bias=False, dtype=dtype)\n", 253 | " self.blocks = nn.ModuleList([TransformerBlock(c.D, c.H, dtype) for _ in range(c.L)])\n", 254 | " self.out_ln = nn.RMSNorm(c.D, elementwise_affine=False, dtype=dtype)\n", 255 | " self.remat = c.remat\n", 256 | "\n", 257 | " def forward(self, x): # [B, S]\n", 258 | "\n", 259 | " # token embedding\n", 260 | " h = self.token_embed_in(x) # [B, T, D]\n", 261 | "\n", 262 | " # transformer blocks\n", 263 | " for block in self.blocks:\n", 264 | " h = checkpoint(block, h, use_reentrant=False) if self.remat else block(h)\n", 265 | "\n", 266 | " # project back to vocabulary\n", 267 | " h = self.out_ln(h)\n", 268 | " logits = self.token_embed_out(h) # [B, T, V]\n", 269 | "\n", 270 | " # get loss\n", 271 | " # we return loss (rather than logits) to reduce peak memory usage\n", 272 | " y = torch.roll(x, -1, dims=1).to(torch.int64)\n", 273 | " y[:, -1] = -1 # do not train on these indices\n", 274 | " loss = F.cross_entropy(logits.flatten(end_dim=-2), y.flatten(), ignore_index=-1)\n", 275 | "\n", 276 | " return loss\n", 277 | "\n", 278 | "\n", 279 | "class TransformerBlock(nn.Module):\n", 280 | " def __init__(self, D, H, dtype):\n", 281 | " super().__init__()\n", 282 | " self.ln1 = nn.RMSNorm(D, elementwise_affine=False, dtype=dtype)\n", 283 | " self.ln2 = nn.RMSNorm(D, elementwise_affine=False, dtype=dtype)\n", 284 | " self.attn = MultiHeadAttention(D, H, dtype)\n", 285 | " self.mlp = MLP(D, dtype)\n", 286 | "\n", 287 | " def forward(self, x): # [B, T, D]\n", 288 | " x = x + self.attn(self.ln1(x)) # attention block\n", 289 | " return x + self.mlp(self.ln2(x)) # MLP block\n", 290 | "\n", 291 | "\n", 292 | "class MultiHeadAttention(nn.Module):\n", 293 | " \"\"\"Causal attention layer.\"\"\"\n", 294 | " def __init__(self, D, H, dtype):\n", 295 | " super().__init__()\n", 296 | " N = D // H # number of heads\n", 297 | " self.qkv_proj = Einsum('BTd,SNdH->SBNTH', (3, N, D, H), dtype=dtype)\n", 298 | " self.out_proj = Einsum('BnTh,nhD->BTD', (N, H, D), dtype=dtype)\n", 299 | " self.query_norm = nn.RMSNorm(H, elementwise_affine=False, dtype=dtype)\n", 300 | " self.key_norm = nn.RMSNorm(H, elementwise_affine=False, dtype=dtype)\n", 301 | "\n", 302 | " def forward(self, x): # [B, T, D]\n", 303 | " B, T, D = x.shape\n", 304 | " device = x.device\n", 305 | "\n", 306 | " # input projection\n", 307 | " q, k, v = self.qkv_proj(x) # [B, N, T, H]\n", 308 | "\n", 309 | " # qk-norm\n", 310 | " q = self.query_norm(q)\n", 311 | " k = self.key_norm(k)\n", 312 | "\n", 313 | " # position embedding\n", 314 | " position = torch.arange(T, device=device)\n", 315 | " q = apply_rope(q, position[None])\n", 316 | " k = apply_rope(k, position[None])\n", 317 | "\n", 318 | " # attention\n", 319 | " out = F.scaled_dot_product_attention(q, k, v, is_causal=True) # [B, N, T, H]\n", 320 | "\n", 321 | " # output projection followed by contraction back to original dims\n", 322 | " out = self.out_proj(out) # [B, T, 
D]\n", 323 | " return out\n", 324 | "\n", 325 | "class MLP(nn.Module):\n", 326 | " \"\"\"Multilayer perceptron.\"\"\"\n", 327 | " def __init__(self, D, dtype):\n", 328 | " super().__init__()\n", 329 | " self.fc1 = nn.Linear(in_features=D, out_features=4*D, bias=False, dtype=dtype)\n", 330 | " self.fc2 = nn.Linear(in_features=4*D, out_features=D, bias=False, dtype=dtype)\n", 331 | "\n", 332 | " def forward(self, x): # [B, T, D]\n", 333 | " h = F.gelu(self.fc1(x)) # [B, T, F]\n", 334 | " return self.fc2(h) # [B, T, D]\n", 335 | "\n", 336 | "\n", 337 | "class Einsum(nn.Module):\n", 338 | " def __init__(self, einsum_str, kernel_shape, dtype=None):\n", 339 | " super().__init__()\n", 340 | " self.einsum_str = einsum_str\n", 341 | " self.weight = nn.Parameter(torch.empty(kernel_shape, dtype=dtype))\n", 342 | " self.reset_parameters()\n", 343 | "\n", 344 | " def reset_parameters(self) -> None:\n", 345 | " nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))\n", 346 | "\n", 347 | " def forward(self, x):\n", 348 | " return torch.einsum(self.einsum_str, x, self.weight)" 349 | ] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "metadata": {}, 354 | "source": [ 355 | "# Profiling" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": 7, 361 | "metadata": {}, 362 | "outputs": [], 363 | "source": [ 364 | "# set model config (GPT-3 13B)\n", 365 | "model_config = DictConfig(dict(\n", 366 | " D = 5140, # model/embed/qkv dim\n", 367 | " L = 40, # num. block layers\n", 368 | " H = 128, # head dimension\n", 369 | " F = 5140 * 4, # FF inner dimension\n", 370 | " N = 5140 // 128, # num. attention heads\n", 371 | " T = 1024, # context/sequence length\n", 372 | " V = 50257,\n", 373 | " remat = True, # gradient checkpointing\n", 374 | " dtype = 'bfloat16',\n", 375 | "))" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 8, 381 | "metadata": { 382 | "colab": { 383 | "base_uri": "https://localhost:8080/" 384 | }, 385 | "executionInfo": { 386 | "elapsed": 4142, 387 | "status": "ok", 388 | "timestamp": 1749056760614, 389 | "user": { 390 | "displayName": "Martin Marek", 391 | "userId": "04932572550491068578" 392 | }, 393 | "user_tz": -60 394 | }, 395 | "id": "RTdHEnqYQUTm", 396 | "outputId": "c32b4d5f-afb5-4ae0-d9fd-a4ff153d2285" 397 | }, 398 | "outputs": [ 399 | { 400 | "name": "stdout", 401 | "output_type": "stream", 402 | "text": [ 403 | "n_params=13_181_601_960\n", 404 | "size of model: 26.36GB\n", 405 | "size of opt. state: 0.00B\n", 406 | "max. 
memory allocated: 27.58GB\n", 407 | "size of \"other\": 1.21GB\n" 408 | ] 409 | } 410 | ], 411 | "source": [ 412 | "def run():\n", 413 | "\n", 414 | " # model\n", 415 | " with torch.device('cuda'):\n", 416 | " model = TransformerDecoder(model_config)\n", 417 | " n_params = sum(p.numel() for p in model.parameters())\n", 418 | " print(f'{n_params=:_}')\n", 419 | " print('size of model:', sizeof_fmt(2*n_params))\n", 420 | "\n", 421 | " # standard optimizer (not fused)\n", 422 | " # optimizer = torch.optim.Adam(model.parameters(), foreach=False)\n", 423 | " # optimizer_dict = {'opt': optimizer}\n", 424 | "\n", 425 | " # fused optimizer\n", 426 | " # based on https://lightning.ai/pages/community/tutorial/faster-pytorch-training-by-reducing-peak-memory/\n", 427 | " optimizer_dict = {p:torch.optim.SGD([p], foreach=False) for p in model.parameters()} # all params\n", 428 | " # optimizer_dict = {p:Muon([p]) for p in model.blocks.parameters()} # non-embedding params\n", 429 | " # optimizer_dict |= {p:torch.optim.Adam([p], foreach=False) for p in [*model.token_embed_in.parameters(), *model.token_embed_out.parameters()]} # embedding params\n", 430 | " def optimizer_hook(parameter):\n", 431 | " optimizer_dict[parameter].step()\n", 432 | " optimizer_dict[parameter].zero_grad()\n", 433 | " for p in model.parameters():\n", 434 | " p.register_post_accumulate_grad_hook(optimizer_hook)\n", 435 | "\n", 436 | " # define training step\n", 437 | " def step():\n", 438 | " T = 1024\n", 439 | " x = torch.randint(model_config.V, [1, model_config.T], dtype=torch.int32, device='cuda')\n", 440 | " loss = model(x)\n", 441 | " loss.backward()\n", 442 | " # optimizer.step()\n", 443 | " # optimizer.zero_grad()\n", 444 | "\n", 445 | " # warm up model\n", 446 | " for _ in range(2):\n", 447 | " step()\n", 448 | " torch.cuda.synchronize()\n", 449 | "\n", 450 | " # get optimzier state size\n", 451 | " opt_num_params = 0\n", 452 | " for p, opt in optimizer_dict.items():\n", 453 | " opt_state = opt.state_dict()['state']\n", 454 | " for s1 in opt_state.values():\n", 455 | " for x in s1.values():\n", 456 | " opt_num_params += x.numel()\n", 457 | " print('size of opt. state:', sizeof_fmt(2*opt_num_params))\n", 458 | "\n", 459 | " # plot step trace\n", 460 | " # with torch.profiler.profile(record_shapes=True, profile_memory=True, with_stack=True) as p:\n", 461 | " # step()\n", 462 | " # p.export_memory_timeline('stack.html', 'cuda:0')\n", 463 | "\n", 464 | " # print max. memory during step\n", 465 | " torch.cuda.reset_peak_memory_stats(\"cuda:0\")\n", 466 | " step()\n", 467 | " max_mem = torch.cuda.max_memory_allocated(\"cuda:0\")\n", 468 | " print('max. 
memory allocated:', sizeof_fmt(max_mem))\n", 469 | "\n", 470 | " # compute size of 'other'\n", 471 | " other_size = max_mem - 2*n_params - 2*opt_num_params\n", 472 | " print('size of \"other\":', sizeof_fmt(other_size))\n", 473 | "\n", 474 | " # manully free memory (required given the circular reference btw model and optimizer)\n", 475 | " del model; optimizer_dict.clear()\n", 476 | "\n", 477 | "run()" 478 | ] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "execution_count": null, 483 | "metadata": { 484 | "executionInfo": { 485 | "elapsed": 2818, 486 | "status": "aborted", 487 | "timestamp": 1749056489279, 488 | "user": { 489 | "displayName": "Martin Marek", 490 | "userId": "04932572550491068578" 491 | }, 492 | "user_tz": -60 493 | }, 494 | "id": "ahvL8r6hKnNN" 495 | }, 496 | "outputs": [], 497 | "source": [] 498 | } 499 | ], 500 | "metadata": { 501 | "accelerator": "GPU", 502 | "colab": { 503 | "authorship_tag": "ABX9TyNKei/Qc69MaUinTDbhqtr8", 504 | "gpuType": "A100", 505 | "machine_shape": "hm", 506 | "provenance": [] 507 | }, 508 | "kernelspec": { 509 | "display_name": "Python 3 (ipykernel)", 510 | "language": "python", 511 | "name": "python3" 512 | }, 513 | "language_info": { 514 | "codemirror_mode": { 515 | "name": "ipython", 516 | "version": 3 517 | }, 518 | "file_extension": ".py", 519 | "mimetype": "text/x-python", 520 | "name": "python", 521 | "nbconvert_exporter": "python", 522 | "pygments_lexer": "ipython3", 523 | "version": "3.10.12" 524 | } 525 | }, 526 | "nbformat": 4, 527 | "nbformat_minor": 4 528 | } 529 | -------------------------------------------------------------------------------- /finetuning/finetune.ipynb: -------------------------------------------------------------------------------- 1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"machine_shape":"hm","gpuType":"V6E1","authorship_tag":"ABX9TyMohuGYyInd9KTf/xIMn8yA"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"TPU","widgets":{"application/vnd.jupyter.widget-state+json":{"c2602df649d04f589f5accc560dd2897":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_49b71c04db96409696d6cb5485a555f9","IPY_MODEL_c7a1bf470ebf4941b467de8bbd0d429c","IPY_MODEL_5a652a3f597843ce949176b498ea80ec"],"layout":"IPY_MODEL_6e0444872e8349db874f224173e92839"}},"49b71c04db96409696d6cb5485a555f9":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_7443b5ebe6dc4d799146e8281635463f","placeholder":"​","style":"IPY_MODEL_67d04fff2e0c4884b2fdb45a09983c18","value":"Training: 
100%"}},"c7a1bf470ebf4941b467de8bbd0d429c":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_8b5f7adcc1684a208a1109b3613fb919","max":2535,"min":0,"orientation":"horizontal","style":"IPY_MODEL_7188b89dfcb549b3bca583329b7a5a5d","value":2535}},"5a652a3f597843ce949176b498ea80ec":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_dfeab4a654e5438397e8247f834b6e15","placeholder":"​","style":"IPY_MODEL_c04eb0c4c2c0439eac737d2e95b622f3","value":" 2535/2535 [42:56<00:00,  1.90it/s, loss=0.55]"}},"6e0444872e8349db874f224173e92839":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"7443b5ebe6dc4d799146e8281635463f":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"67d04fff2e0c4884b2fdb45a09983c18":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"
1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"8b5f7adcc1684a208a1109b3613fb919":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"7188b89dfcb549b3bca583329b7a5a5d":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"dfeab4a654e5438397e8247f834b6e15":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"c04eb0c4c2c0439eac737d2e95b622f3":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"58ee97f06e084bef8688544b671fcdb1":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_68a054d6b9b640b38089efa4778b0c9c","IPY_MODEL_6c
1b201e3c424c53b8fcb22b4e475c42","IPY_MODEL_c451c0db376a4df9b1d686f7b8cadb67"],"layout":"IPY_MODEL_ca778d188eee4444950f05ce262f5bcd"}},"68a054d6b9b640b38089efa4778b0c9c":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_f6a77f82a6924672b4aaf2c5a3961893","placeholder":"​","style":"IPY_MODEL_161b148dcc984a958b8a21b730e5f56f","value":"Sampling: 100%"}},"6c1b201e3c424c53b8fcb22b4e475c42":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_c1f8be2657474672bb18536577196d37","max":1,"min":0,"orientation":"horizontal","style":"IPY_MODEL_774e806a9f9d4291b4a1219df4426b14","value":1}},"c451c0db376a4df9b1d686f7b8cadb67":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_396e822d75d14d2d8a8b075b8331b599","placeholder":"​","style":"IPY_MODEL_3b327dbae64a468da9793630629caec7","value":" 1/1 [01:09<00:00, 
69.12s/it]"}},"ca778d188eee4444950f05ce262f5bcd":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"f6a77f82a6924672b4aaf2c5a3961893":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"161b148dcc984a958b8a21b730e5f56f":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"c1f8be2657474672bb18536577196d37":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"774e806a9f9d4291b
4a1219df4426b14":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"396e822d75d14d2d8a8b075b8331b599":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"3b327dbae64a468da9793630629caec7":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}}}}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"DZ17TCOfMKuw","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1756321993976,"user_tz":-60,"elapsed":2486,"user":{"displayName":"Martin Marek","userId":"04932572550491068578"}},"outputId":"f95fba17-f6f2-46a4-c519-3e2789350616"},"outputs":[{"output_type":"stream","name":"stdout","text":["Cloning into 'batch-size'...\n","remote: Enumerating objects: 91, done.\u001b[K\n","remote: Counting objects: 100% (91/91), done.\u001b[K\n","remote: Compressing objects: 100% (80/80), done.\u001b[K\n","remote: Total 91 (delta 10), reused 66 (delta 7), pack-reused 0 (from 0)\u001b[K\n","Receiving objects: 100% (91/91), 1.52 MiB | 11.70 MiB/s, done.\n","Resolving deltas: 100% (10/10), done.\n"]}],"source":["! uv pip install fire wandb datasets math_verify \"jax[tpu]\" flax==0.12.0 -q\n","! 
git clone --depth=1 https://github.com/martin-marek/batch-size.git\n","# note: lora requires additioanlly installing https://github.com/google/qwix"]},{"cell_type":"code","source":["%cd /content/batch-size/finetuning\n","from finetune import finetune"],"metadata":{"id":"vpha0ShoT8Lb","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1756322929330,"user_tz":-60,"elapsed":1728,"user":{"displayName":"Martin Marek","userId":"04932572550491068578"}},"outputId":"d15e3a48-6f48-411a-fe53-82eaacc14816"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["/content/batch-size/finetuning\n"]},{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.12/dist-packages/jax/_src/cloud_tpu_init.py:82: UserWarning: Transparent hugepages are not enabled. TPU runtime startup and shutdown time should be significantly improved on TPU v5e and newer. If not already set, you may need to enable transparent hugepages in your VM image (sudo sh -c \"echo always > /sys/kernel/mm/transparent_hugepage/enabled\")\n"," warnings.warn(\n"]}]},{"cell_type":"code","source":["# when running for the first time, might ask for kaggle credentials to download model weights\n","finetune(\n"," model_variant='gemma3-12b',\n"," optimizer_name='adafactor',\n"," batch_size=1,\n"," n_epochs=1,\n"," peak_lr=1e-4,\n"," b2=0.997,\n"," param_dtype='bfloat16',\n"," stochastic_round=True,\n"," n_eval_samples=8, # only testing on 8 questions for demo purposes\n"," eval_batch_size=8, # we can only afford a small KV cache\n"," wandb_mode='offline', # use Weights & Biases offline\n"," print_output=True, # print MATH completions\n",")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000,"referenced_widgets":["c2602df649d04f589f5accc560dd2897","49b71c04db96409696d6cb5485a555f9","c7a1bf470ebf4941b467de8bbd0d429c","5a652a3f597843ce949176b498ea80ec","6e0444872e8349db874f224173e92839","7443b5ebe6dc4d799146e8281635463f","67d04fff2e0c4884b2fdb45a09983c18","8b5f7adcc1684a208a1109b3613fb919","7188b89dfcb549b3bca583329b7a5a5d","dfeab4a654e5438397e8247f834b6e15","c04eb0c4c2c0439eac737d2e95b622f3","58ee97f06e084bef8688544b671fcdb1","68a054d6b9b640b38089efa4778b0c9c","6c1b201e3c424c53b8fcb22b4e475c42","c451c0db376a4df9b1d686f7b8cadb67","ca778d188eee4444950f05ce262f5bcd","f6a77f82a6924672b4aaf2c5a3961893","161b148dcc984a958b8a21b730e5f56f","c1f8be2657474672bb18536577196d37","774e806a9f9d4291b4a1219df4426b14","396e822d75d14d2d8a8b075b8331b599","3b327dbae64a468da9793630629caec7"]},"id":"dpTFYfC-THSd","outputId":"9fb3f96f-6103-4b43-bb5c-900f773a880f","executionInfo":{"status":"ok","timestamp":1756325612170,"user_tz":-60,"elapsed":1047450,"user":{"displayName":"Martin Marek","userId":"04932572550491068578"}}},"execution_count":null,"outputs":[{"metadata":{"tags":null},"name":"stdout","output_type":"stream","text":["train_config={'model_variant': 'gemma3-12b', 'lora_rank': None, 'temperature': 1, 'optimizer_name': 'adafactor', 'peak_lr': 0.0001, 'lr_schedule': 'const', 'b2': 0.997, 'n_epochs': 1, 'batch_size': 1, 'microbatch_size': 1, 'n_eval_samples': 8, 'eval_batch_size': 8, 'n_data_devices': 1, 'train_parallelism': 'seq', 'param_dtype': 'bfloat16', 'stochastic_round': True, 'remat': False, 'log_every_samples': 100, 'print_output': True, 'wandb_mode': 'offline', 'run_name': None, 'seed': 0, 'kwargs': {}}\n"]},{"data":{"text/html":["Tracking run with wandb version 0.21.1"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["W&B 
syncing is set to `offline` in this directory. Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
Run data is saved locally in /content/batch-size/finetuning/wandb/offline-run-20250827_192852-b7npqj4d"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"metadata":{"tags":null},"name":"stdout","output_type":"stream","text":["loading model...\n"]},{"metadata":{"tags":null},"name":"stderr","output_type":"stream","text":["WARNING:absl:Provided metadata contains unknown key custom. Adding it to custom_metadata.\n"]},{"metadata":{"tags":null},"name":"stdout","output_type":"stream","text":["loading data...\n","loading datasets...\n"]},{"metadata":{"tags":null},"name":"stderr","output_type":"stream","text":["/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n","The secret `HF_TOKEN` does not exist in your Colab secrets.\n","To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n","You will be able to reuse this secret in all of your notebooks.\n","Please note that authentication is recommended but still optional to access public models or datasets.\n"," warnings.warn(\n"]},{"metadata":{"tags":null},"name":"stdout","output_type":"stream","text":["tokenizing training dataset...\n","skipped train. seq.: 2.1%\n","tokenizing eval dataset...\n","skipped valid. seq.: 0.0%\n","n_model_params=11_765_788_416\n","n_opt_params=13_234_532\n"]},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"c2602df649d04f589f5accc560dd2897","version_major":2,"version_minor":0},"text/plain":["Training: 0%| | 0/2535 [00:002018^{17}+2018^{12} + 1 \\\\\n","&= 2018^{17} + 2018^{13} + 1.\n","\\end{align*}Also, \\begin{align*}\n","2018^{17} + 2018^{13} + 1 &= 2018^{13}(2018^{4} + 1) + 1 \\\\\n","&= 2018^{13}(1655664 + 1) + 1 \\\\\n","&= 2018^{13}\\cdot 1655665 + 1 \\\\\n","&< 2018^{13}\\cdot 1655665 + 2018 \\\\\n","&= 2019x+2018.\n","\\end{align*}Thus, the polynomial with the greatest real root is $\\boxed{\\text{A}}$.\n","PARSED: [a, '\\\\text{A}']\n","GOLD: By Descartes' Rule of Signs, none of the polynomials has a positive root, and each one has exactly one negative root. Furthermore, each polynomial is positive at $x = 0$ and negative at $x = -1,$ so each real root lies between $-1$ and 0. 
Also, each polynomial is increasing on the interval $(-1,0).$\n","\n","Let $r_A$ and $r_B$ be the roots of the polynomials in options A and B, respectively, so\n","\\[r_A^{19} + 2018r_A^{11} + 1 = r_B^{17} + 2018r_B^{11} + 1 = 0,\\]so $r_A^{19} = r_B^{17}.$ Since $r_A \\in (-1,0),$ $r_B^{17} = r_A^{19} > r_A^{17},$ so $r_B > r_A.$\n","\n","Similarly, let $r_C$ and $r_D$ be the roots of the polynomials in options C and D, respectively, so\n","\\[r_C^{19} + 2018r_C^{13} + 1 = r_D^{17} + 2018r_D^{13} + 1 = 0,\\]so $r_C^{19} = r_D^{17}.$ Since $r_C \\in (-1,0),$ $r_D^{17} = r_C^{19} > r_C^{17},$ so $r_D > r_C.$\n","\n","Since\n","\\[r_B^{17} + 2018r_B^{11} + 1 = r_D^{17} + 2018r_D^{13} + 1 = 0,\\]we have that $r_B^{11} = r_D^{13}.$ Since $r_D \\in (-1,0),$ $r_B^{11} = r_D^{13} > r_D^{11},$ so $r_B > r_D.$\n","\n","Therefore, the largest root must be either $r_B$ or the root of $2019x + 2018 = 0,$ which is $-\\frac{2018}{2019}.$\n","\n","Let $f(x) = x^{17} + 2018x^{11} + 1,$ so $f(r_B) = 0.$ Note that\n","\\[f \\left( -\\frac{2}{3} \\right) = -\\frac{2^{17}}{3^{17}} - 2018 \\cdot \\frac{2^{11}}{3^{11}} + 1.\\]We claim that $2018 \\cdot 2^{11} > 3^{11}.$ Since $2^2 > 3,$ $2^{22} > 3^{11}.$ Then\n","\\[2018 \\cdot 2^{11} = 1009 \\cdot 2^{22} > 3^{11}.\\]From $2018 \\cdot 2^{11} > 3^{11},$ $2018 \\cdot \\frac{2^{11}}{3^{11}} > 1,$ so\n","\\[f \\left( -\\frac{2}{3} \\right) = -\\frac{2^{17}}{3^{17}} - 2018 \\cdot \\frac{2^{11}}{3^{11}} + 1 < 0.\\]Since $f(x)$ is an increasing function, we can conclude that $r_B > -\\frac{2}{3} > -\\frac{2018}{2019}.$ Therefore, the answer is $\\boxed{\\text{(B)}}.$\n","CORRECT: False\n","------------\n","PROMPT:\n","The positive difference between the squares of two consecutive odd integers is 128. What is the product of the two integers?\n","COMPLETION:\n","Let $n$ be the smaller of the two integers. Since the two integers are consecutive, the larger integer is $n+2$. So, the equation to solve is $n^2-(n+2)^2=128$. Expanding, we have \\begin{align*}\n","n^2-(n+2)^2&=128\\quad\\Rightarrow\\\\\n","n^2-(n^2+4n+4)&=128\\quad\\Rightarrow\\\\\n","-4n&=132\\quad\\Rightarrow\\\\\n","n&=-33.\n","\\end{align*} Since $n$ is a positive integer, we must reject this solution. Let $n$ be the larger of the two integers. So, the equation to solve is now $(n+2)^2-n^2=128$. Expanding, we have \\begin{align*}\n","(n+2)^2-n^2&=128\\quad\\Rightarrow\\\\\n","n^2+4n+4-n^2&=128\\quad\\Rightarrow\\\\\n","4n&=124\\quad\\Rightarrow\\\\\n","n&=31.\n","\\end{align*} Thus, the two integers are 31 and 33, and the product is $\\boxed{1023}$.\n","PARSED: [1023, '1023']\n","GOLD: Let $a, a+2$ be the two integers. We are given that $(a+2)^2-a^2 = 128$. Using the difference of squares factorization, the equation becomes $(a+2+a)(a+2-a) = 128$. Simplifying and solving, we get: \\begin{align*}\n","(2a+2)(2) &= 128\\\\\n","\\Rightarrow 2a+2 &= 64\\\\\n","\\Rightarrow 2a &= 62\\\\\n","\\Rightarrow a &= 31.\\\\\n","\\end{align*} Therefore the desired product is $a(a+2) = 31\\cdot33 = \\boxed{1023}$.\n","CORRECT: True\n","------------\n","PROMPT:\n","The area of right triangle $ABC$ is 4, and the hypotenuse $\\overline{AB}$ is 12. 
Compute $\\sin 2A.$\n","COMPLETION:\n","[asy]\n","draw((0,0)--(11,0));\n","draw((0,0)--(12,0)--(0,10.4),linewidth(0.6));\n","draw((0,0)--(11,0)--(11,2)--(0,10.4));\n","draw((0,0)--(1.3,0));\n","draw((1.3,0)--(1.3,0.4));\n","draw((1.3,0.4)--(0.9,0.4)--(0.9,0));\n","label(\"$B$\",(0,0),S);\n","label(\"$C$\",(12,0),S);\n","label(\"$A$\",(0,10.4),NW);\n","label(\"$D$\",(11,2),E);\n","[/asy]\n","\n","Since the area of $\\triangle ABC$ is 4 and its base $AB$ has length 12, the height from $C$ to $AB$ has length $2$. Then, $\\overline{AC}$ has length $\\sqrt{116}$.\n","\n","Notice that $\\sin 2A = 2 \\sin A \\cos A.$ Since $\\overline{AC}$ and $\\overline{BC}$ are the legs and $\\overline{AB}$ is the hypotenuse of a right triangle, $\\overline{AC}$ and $\\overline{BC}$ are adjacent and opposite sides to $A$ and $B$, respectively. Thus, $2 \\sin A \\cos A = 2 \\cdot \\frac{2}{\\sqrt{116}} \\cdot \\frac{\\sqrt{116}}{12} = \\boxed{\\frac{1}{3}}.$\n","PARSED: [1/3, '\\\\frac{1}{3}']\n","GOLD: We have that $\\frac{1}{2} ab = 4,$ so $ab = 8.$\n","\n","[asy]\n","unitsize (1 cm);\n","\n","pair A, B, C;\n","\n","C = (0,0);\n","B = (3,0);\n","A = (0,2);\n","\n","draw(A--B--C--cycle);\n","draw(rightanglemark(A,C,B,6));\n","\n","label(\"$A$\", A, N);\n","label(\"$B$\", B, E);\n","label(\"$C$\", C, SW);\n","label(\"$a$\", (B + C)/2, S, red);\n","label(\"$b$\", (A + C)/2, W, red);\n","label(\"$12$\", (A + B)/2, NE, red);\n","[/asy]\n","\n","Then\n","\\[\\sin 2A = 2 \\sin A \\cos A = 2 \\cdot \\frac{a}{12} \\cdot \\frac{b}{12} = \\frac{ab}{72} = \\frac{8}{72} = \\boxed{\\frac{1}{9}}.\\]\n","CORRECT: False\n","------------\n","PROMPT:\n","Let $z_1 = 18 + 83i$, $z_2 = 18 + 39i,$ and $z_3 = 78 + 99i,$ where $i^2 = -1$. Let $z$ be the unique complex number with the properties that $\\frac{z_3 - z_1}{z_2 - z_1} \\cdot \\frac{z - z_2}{z - z_3}$ is a real number and the imaginary part of $z$ is the greatest possible. Find the real part of $z$.\n","COMPLETION:\n","From the given equation, $$\n","\\left|\\frac{z_3 - z_1}{z_2 - z_1}\\cdot \\frac{z - z_2}{z - z_3}\\right| = \\left|\\frac{z_3 - z_1}{z_2 - z_1}\\right| \\cdot \\left|\\frac{z - z_2}{z - z_3}\\right| = 1.\n","$$ This is true if and only if the complex numbers corresponding to the factors are each other's negatives. In other words, \\begin{align*}\n","\\frac{z_3 - z_1}{z_2 - z_1} &= -\\frac{z - z_2}{z - z_3} \\\\\n","\\frac{78 + 99i - 18 - 83i}{18 + 39i - 18 - 83i} &= -\\frac{z - 18 - 39i}{z - 78 - 99i} \\\\\n","\\frac{60 + 16i}{48 - 44i} &= \\frac{z - 18 - 39i}{z - 78 - 99i} \\\\\n","\\frac{5 + i}{4 - 4i} &= \\frac{z - 18 - 39i}{z - 78 - 99i} \\\\\n","\\frac{5 - i}{4 + 4i} &= \\frac{z - 18 - 39i}{z - 78 - 99i} \\\\\n","\\frac{5 - i}{4} &= \\frac{z - 18 - 39i}{z - 78 - 99i} \\\\\n","5z - 390 - 495i = 4z - 72 - 156i \\\\\n","z = 318 + 339i\n","\\end{align*} To find the complex number $z$ for which the imaginary part is greatest, we must maximize the imaginary part of $\\frac{z_3 - z_1}{z_2 - z_1}$. Since this term is a complex conjugate of $\\frac{z - z_2}{z - z_3}$, the imaginary part of this term must be maximized as well. Therefore, $z$ must lie on the line passing through $z_2$ and $z_3$ with a slope of $(99 - 39)/(78 - 18) = 2$. The imaginary part of $z_3 - z_1$ is $99 - 83 = 16$. Thus, this imaginary part is maximized when $z = \\boxed{318}$.\n","PARSED: [318, '318']\n","GOLD: Let $z = a + bi,$ where $a$ and $b$ are real numbers. 
Then\n","\\begin{align*}\n","\\frac{z_3 - z_1}{z_2 - z_1} \\cdot \\frac{z - z_2}{z - z_3} &= \\frac{60 + 16i}{-44i} \\cdot \\frac{(a - 18) + (b - 39)i}{(a - 78) + (b - 99)i} \\\\\n","&= \\frac{-4 + 15i}{11} \\cdot \\frac{[(a - 18) + (b - 39)i][(a - 78) - (b - 99)i]}{(a - 78)^2 + (b - 99)^2}.\n","\\end{align*}This expression is real if and only if the imaginary part is 0. In other words,\n","\\[(-4 + 15i)[(a - 18) + (b - 39)i][(a - 78) - (b - 99)i]\\]has imaginary part 0. In turn this is equivalent to\n","\\[(-4)(-(a - 18)(b - 99) + (a - 78)(b - 39)) + 15((a - 18)(a - 78) + (b - 39)(b - 99)) = 0.\\]This simplifies to $a^2 - 112a + b^2 - 122b + 4929 = 0.$ Completing the square, we get\n","\\[(a - 56)^2 + (b - 61)^2 = 1928,\\]so\n","\\[(a - 56)^2 = 1928 - (b - 61)^2.\\]When $b$ is maximized, the right-hand side is 0, and $a = \\boxed{56}.$\n","CORRECT: False\n","------------\n","PROMPT:\n","A TV screen measures 24 by 16 inches. If each dimension is increased by 20$\\%$, by what percent is the area increased?\n","COMPLETION:\n","The area of the screen increases by 40$\\%$.\n","\n","The screen has area $(.24\\times24\\times.8)\\times(.16\\times24\\times.8)=1.44\\times24\\times1.6$ square inches, an increase of 44 percent. The area of the screen is $(.24\\times24)\\times(.16\\times24)$ square inches. (Note that this calculation can be skipped and the percentage change can still be determined, because $1+44\\%=.5+22=.8\\times1.2$.)\n","PARSED: [Eq(0.5 + 22, 0.8*1.2) & Eq(44*(1/100) + 1, 0.5 + 22), '1+44\\\\%=.5+22=.8\\\\times1.2']\n","GOLD: Since we're dealing with percentages, the actual dimensions do not matter. Let $l$ and $w$ represent the dimensions of the TV screen. The current area is $lw$. If we increase $l$ by $20\\%$, we end up with $l\\left(1+\\frac{20}{100}\\right)=l\\left(\\frac{10}{10}+\\frac{2}{10}\\right)=\\frac{12}{10}l$. Increasing the width results in $\\frac{12}{10}w$. The new area is $\\frac{12}{10}l\\times\\frac{12}{10}w=\\frac{144}{100}lw=lw\\left(1+\\frac{44}{100}\\right)$. The area increases by $\\boxed{44\\%}$.\n","CORRECT: False\n","------------\n","PROMPT:\n","How many ways are there to arrange 6 people around a circular table with 7 seats? (Two seatings are considered the same if one is a rotation of the other.)\n","COMPLETION:\n","If we can position 5 people around the table, then there are 7 different places to position the 6th person, so there are 7 different seatings that give the same relative arrangement. Therefore, the number of arrangements is equal to the number of seatings with 5 people divided by 7.\n","\n","It is clear that with 5 people, there are $5!$ ways to seat the people. However, each arrangement is counted $5!$ times, once for each rotation of the seats. So we divide by 5! to eliminate the over-counting, giving us $\\frac{5!}{7} = \\boxed{60}$ arrangements.\n","PARSED: [60, '60']\n","GOLD: First choose which seat is empty. It doesn't matter which seat we choose because we can rotate everyone at the table to move the empty seat anywhere we want. After the empty seat has been chosen, there are $6!=\\boxed{720}$ ways to arrange the remaining people.\n","CORRECT: False\n","------------\n","PROMPT:\n","Line $j$ is perpendicular to $\\frac{y}{3}+\\frac{2x}{5}=2$. What is the slope of line $j$?\n","COMPLETION:\n","$\\frac{y}{3}+\\frac{2x}{5}=2 \\Longrightarrow \\frac{5y+6x}{15}=2 \\Longrightarrow 5y+6x=30 \\Longrightarrow y=-\\frac{6}{5}x+6$. 
The slope of this line is $-\\frac{6}{5}$, so the slope of the perpendicular line is $\\boxed{\\frac{5}{6}}$.\n","PARSED: [5/6, '\\\\frac{5}{6}']\n","GOLD: First, we find the slope of $\\frac{y}{3}+\\frac{2x}{5}=2$. We change the form to slope-intercept form. Multiplying both sides by 3, we get $y+\\frac{6x}{5}=6$. Moving the $x$ to the right, we get $y=-\\frac{6x}{5}+6$. The slopes of two perpendicular lines are negative reciprocals. Therefore, the slope of line $j$ is the opposite reciprocal of $-\\frac{6}{5}$ which is $\\boxed{\\frac56}$.\n","CORRECT: True\n","------------\n","PROMPT:\n","Let $p(x)$ be a cubic polynomial such that $p(2) = 0,$ $p(-1) = 0,$ $p(4) = 6,$ and $p(5) = 8.$ Find $p(7).$\n","COMPLETION:\n","Let $p(x) = a(x - 2)(x + 1)(x - 4)$ for some constant $a.$ Then\n","\\begin{align*}\n","p(4) &= a(4 - 2)(4 + 1)(4 - 4) = 0, \\\\\n","p(5) &= a(5 - 2)(5 + 1)(5 - 4) = 12a, \\\\\n","p(7) &= a(7 - 2)(7 + 1)(7 - 4) = 60a = 5p(5).\n","\\end{align*}Since $p(5) = 8,$ $p(7) = \\boxed{40}.$\n","PARSED: [40, '40']\n","GOLD: Since $p(2) = p(-1) = 0,$ $p(x)$ is of the form\n","\\[p(x) = (ax + b)(x - 2)(x + 1)\\]for some constants $a$ and $b.$ Setting $x = 4$ and $x = 5,$ we get\n","\\begin{align*}\n","(4a + b)(2)(5) &= p(4) = 6, \\\\\n","(5a + b)(3)(6) &= p(5) = 8,\n","\\end{align*}so\n","\\begin{align*}\n","4a + b &= \\frac{3}{5}, \\\\\n","5a + b &= \\frac{4}{9}.\n","\\end{align*}Solving, we find $a = -\\frac{7}{45}$ and $b = \\frac{11}{9}.$ Hence,\n","\\[p(x) = \\left( -\\frac{7}{45} x + \\frac{11}{9} \\right) (x - 2)(x + 1) = -\\frac{(7x - 55)(x - 2)(x + 1)}{45}.\\]Therefore,\n","\\[p(7) = -\\frac{(49 - 55)(5)(8)}{45} = \\boxed{\\frac{16}{3}}.\\]\n","CORRECT: False\n"]},{"output_type":"display_data","data":{"text/plain":[""],"text/html":[]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":[""],"text/html":["

Run history:
  accuracy
  finished
  length
  train_loss  ▇▅▆▅█▅▆▄▅▆▇▄▃▆▃▃▅▂▂▁▄▃▅▃▁

Run summary:
  accuracy    0.25
  finished    1
  length      258.75
  train_loss  0.54794

"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":[""],"text/html":["You can sync this run to the cloud by running:
wandb sync /content/batch-size/finetuning/wandb/offline-run-20250827_192852-b7npqj4d"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":[""],"text/html":["Find logs at: ./wandb/offline-run-20250827_192852-b7npqj4d/logs"]},"metadata":{}}]},{"cell_type":"code","source":[],"metadata":{"id":"iaZETfC4Txwn"},"execution_count":null,"outputs":[]}]} --------------------------------------------------------------------------------
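The notebook above demonstrates the `finetune` entry point from finetuning/finetune.py with explicit keyword arguments. As a minimal sketch (not part of the repository), the same entry point should also cover the LoRA setup referenced in the notebook's install note, assuming the function accepts the keys printed in `train_config` above; the `lora_rank`, `optimizer_name`, and `peak_lr` values below are illustrative assumptions, not the values used in the repository's run configs.

# sketch: LoRA variant of the notebook demo (requires qwix, as noted in the notebook)
# run from inside the finetuning/ directory so `finetune` is importable
from finetune import finetune

finetune(
    model_variant='gemma3-12b',
    lora_rank=16,              # assumption: a positive rank enables LoRA training
    optimizer_name='adam',     # assumption: mirrors the adam configs under finetuning/runs/
    batch_size=16,
    n_epochs=1,
    peak_lr=1e-4,              # illustrative value, not tuned
    param_dtype='bfloat16',
    n_eval_samples=8,          # small eval for demo purposes, as in the notebook
    eval_batch_size=8,
    wandb_mode='offline',
    print_output=True,
)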