├── examples ├── modded-nanogpt │ ├── data │ │ ├── requirements.txt │ │ ├── cached_fineweb10B.py │ │ └── fineweb.py │ ├── requirements.txt │ └── README.md ├── shallow-nanogpt │ ├── mup_examples │ │ ├── requirements.txt │ │ ├── README.md │ │ └── mutransfer_lr_shakespeare_char │ │ │ ├── sp │ │ │ ├── run_test.sh │ │ │ └── run.sh │ │ │ ├── uscion │ │ │ ├── run_test.sh │ │ │ └── run.sh │ │ │ ├── scion │ │ │ ├── run_test.sh │ │ │ └── run.sh │ │ │ ├── scion_full │ │ │ ├── run_test.sh │ │ │ ├── run.sh │ │ │ └── run_naive.sh │ │ │ ├── mup │ │ │ ├── run_test.sh │ │ │ └── run.sh │ │ │ └── plot.ipynb │ ├── .gitignore │ ├── data │ │ ├── shakespeare │ │ │ ├── readme.md │ │ │ └── prepare.py │ │ ├── shakespeare_char │ │ │ ├── readme.md │ │ │ └── prepare.py │ │ └── openwebtext │ │ │ ├── readme.md │ │ │ └── prepare.py │ ├── config │ │ ├── eval_gpt2.py │ │ ├── eval_gpt2_xl.py │ │ ├── eval_gpt2_large.py │ │ ├── eval_gpt2_medium.py │ │ ├── finetune_shakespeare.py │ │ ├── train_gpt2.py │ │ └── train_shakespeare_char.py │ ├── LICENSE │ ├── README.md │ ├── configurator.py │ ├── sample.py │ ├── bench.py │ ├── csv_logging.py │ ├── scion.py │ ├── README_orig.md │ └── train.py ├── deit │ ├── README.md │ └── scion.py └── airbench │ ├── README.md │ ├── scion.py │ └── airbench_scion.py ├── LICENSE ├── README.md └── scion.py /examples/modded-nanogpt/data/requirements.txt: -------------------------------------------------------------------------------- 1 | datasets 2 | tiktoken 3 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/mup_examples/requirements.txt: -------------------------------------------------------------------------------- 1 | datasets==2.13.0 2 | seaborn 3 | -------------------------------------------------------------------------------- /examples/modded-nanogpt/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | tqdm 3 | torch 4 | huggingface-hub 5 | datargs -------------------------------------------------------------------------------- /examples/shallow-nanogpt/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | .ipynb_checkpoints/ 4 | .vscode 5 | __pycache__/ 6 | *.bin 7 | *.pkl 8 | *.pt 9 | *.pyc 10 | input.txt 11 | env/ 12 | venv/ 13 | mup_examples/*/*/out/* 14 | mup_examples/*/*/out_old/* 15 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/data/shakespeare/readme.md: -------------------------------------------------------------------------------- 1 | 2 | # tiny shakespeare 3 | 4 | Tiny shakespeare, of the good old char-rnn fame :) 5 | 6 | After running `prepare.py`: 7 | 8 | - train.bin has 301,966 tokens 9 | - val.bin has 36,059 tokens 10 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/config/eval_gpt2.py: -------------------------------------------------------------------------------- 1 | # evaluate the base gpt2 2 | # n_layer=12, n_head=12, n_embd=768 3 | # 124M parameters 4 | batch_size = 8 5 | eval_iters = 500 # use more iterations to get good estimate 6 | eval_only = True 7 | wandb_log = False 8 | init_from = 'gpt2' 9 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/config/eval_gpt2_xl.py: -------------------------------------------------------------------------------- 1 | # evaluate the base gpt2 2 | # n_layer=48, n_head=25, 
n_embd=1600 3 | # 1558M parameters 4 | batch_size = 8 5 | eval_iters = 500 # use more iterations to get good estimate 6 | eval_only = True 7 | wandb_log = False 8 | init_from = 'gpt2-xl' 9 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/config/eval_gpt2_large.py: -------------------------------------------------------------------------------- 1 | # evaluate the base gpt2 2 | # n_layer=36, n_head=20, n_embd=1280 3 | # 774M parameters 4 | batch_size = 8 5 | eval_iters = 500 # use more iterations to get good estimate 6 | eval_only = True 7 | wandb_log = False 8 | init_from = 'gpt2-large' 9 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/config/eval_gpt2_medium.py: -------------------------------------------------------------------------------- 1 | # evaluate the base gpt2 2 | # n_layer=24, n_head=16, n_embd=1024 3 | # 350M parameters 4 | batch_size = 8 5 | eval_iters = 500 # use more iterations to get good estimate 6 | eval_only = True 7 | wandb_log = False 8 | init_from = 'gpt2-medium' 9 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/data/shakespeare_char/readme.md: -------------------------------------------------------------------------------- 1 | 2 | # tiny shakespeare, character-level 3 | 4 | Tiny shakespeare, of the good old char-rnn fame :) Treated on character-level. 5 | 6 | After running `prepare.py`: 7 | 8 | - train.bin has 1,003,854 tokens 9 | - val.bin has 111,540 tokens 10 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/data/openwebtext/readme.md: -------------------------------------------------------------------------------- 1 | 2 | ## openwebtext dataset 3 | 4 | after running `prepare.py` (preprocess) we get: 5 | 6 | - train.bin is ~17GB, val.bin ~8.5MB 7 | - train has ~9B tokens (9,035,582,198) 8 | - val has ~4M tokens (4,434,897) 9 | 10 | this came from 8,013,769 documents in total. 11 | 12 | references: 13 | 14 | - OpenAI's WebText dataset is discussed in [GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) 15 | - [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) dataset 16 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/mup_examples/README.md: -------------------------------------------------------------------------------- 1 | # Experiment Reproduction 2 | 3 | Install the minimal dataset and plotting requirements with `pip install -r requirements.txt`. We used the PyTorch NGC container for GPU-based runs, but any environment containing the dependencies from [the main README](https://github.com/EleutherAI/nanoGPT-mup?tab=readme-ov-file#install) will suffice. 4 | 5 | SCION_CHANGE: install `sudo apt install bc`. 6 | SCION_CHANGE: install `pip install tiktoken` 7 | 8 | To download the tiny shakespeare dataset, run `python ../data/shakespeare_char/prepare.py`. For OpenWebText (OWT), run `python ../data/openwebtext/prepare.py`. 
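After `prepare.py` finishes, the resulting `train.bin`/`val.bin` can be sanity-checked by memory-mapping them as `uint16` arrays, the same trick noted at the bottom of `../data/openwebtext/prepare.py`. A minimal sketch (the expected counts are the ones quoted in `../data/shakespeare_char/readme.md`):

```python
import numpy as np

# token ids are stored as raw uint16, so the array length equals the token count
train = np.memmap('../data/shakespeare_char/train.bin', dtype=np.uint16, mode='r')
val = np.memmap('../data/shakespeare_char/val.bin', dtype=np.uint16, mode='r')
print(f"train: {len(train):,} tokens")  # expected 1,003,854
print(f"val:   {len(val):,} tokens")    # expected 111,540
```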
9 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/config/finetune_shakespeare.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | out_dir = 'out-shakespeare' 4 | eval_interval = 5 5 | eval_iters = 40 6 | wandb_log = False # feel free to turn on 7 | wandb_project = 'shakespeare' 8 | wandb_run_name = 'ft-' + str(time.time()) 9 | 10 | dataset = 'shakespeare' 11 | init_from = 'gpt2-xl' # this is the largest GPT-2 model 12 | 13 | # only save checkpoints if the validation loss improves 14 | always_save_checkpoint = False 15 | 16 | # the number of examples per iter: 17 | # 1 batch_size * 32 grad_accum * 1024 tokens = 32,768 tokens/iter 18 | # shakespeare has 301,966 tokens, so 1 epoch ~= 9.2 iters 19 | batch_size = 1 20 | gradient_accumulation_steps = 32 21 | max_iters = 20 22 | 23 | # finetune at constant LR 24 | learning_rate = 3e-5 25 | decay_lr = False 26 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/config/train_gpt2.py: -------------------------------------------------------------------------------- 1 | # config for training GPT-2 (124M) down to very nice loss of ~2.85 on 1 node of 8X A100 40GB 2 | # launch as the following (e.g. in a screen session) and wait ~5 days: 3 | # $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py 4 | 5 | wandb_log = True 6 | wandb_project = 'owt' 7 | wandb_run_name='gpt2-124M' 8 | 9 | # these make the total batch size be ~0.5M 10 | # 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520 11 | batch_size = 12 12 | block_size = 1024 13 | gradient_accumulation_steps = 5 * 8 14 | 15 | # this makes total number of tokens be 300B 16 | max_iters = 600000 17 | lr_decay_iters = 600000 18 | 19 | # eval stuff 20 | eval_interval = 1000 21 | eval_iters = 200 22 | log_interval = 10 23 | 24 | # weight decay 25 | weight_decay = 1e-1 26 | -------------------------------------------------------------------------------- /examples/modded-nanogpt/data/cached_fineweb10B.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from huggingface_hub import hf_hub_download 4 | # Download the GPT-2 tokens of Fineweb10B from huggingface. This 5 | # saves about an hour of startup time compared to regenerating them. 6 | def get(fname): 7 | local_dir = os.path.join(os.path.dirname(__file__), 'fineweb10B') 8 | if not os.path.exists(os.path.join(local_dir, fname)): 9 | hf_hub_download(repo_id="kjj0/fineweb10B-gpt2", filename=fname, 10 | repo_type="dataset", local_dir=local_dir) 11 | get("fineweb_val_%06d.bin" % 0) 12 | num_chunks = 103 # full fineweb10B. Each chunk is ~98.5M tokens 13 | if len(sys.argv) >= 2: # we can pass an argument to download less 14 | num_chunks = int(sys.argv[1]) 15 | for i in range(1, num_chunks+1): 16 | get("fineweb_train_%06d.bin" % i) 17 | -------------------------------------------------------------------------------- /examples/modded-nanogpt/README.md: -------------------------------------------------------------------------------- 1 | # Modded NanoGPT 2 | 3 | This code builds on [modded-nanogpt](https://github.com/KellerJordan/modded-nanogpt/). 
4 | 5 | ## Setup 6 | 7 | ```bash 8 | pip install -r requirements.txt 9 | pip install -r data/requirements.txt 10 | pip install torch --index-url https://download.pytorch.org/whl/cu124 --upgrade 11 | python data/cached_fineweb10B.py 8 # downloads only the first 800M training tokens to save time 12 | ``` 13 | 14 | ## Run 15 | 16 | ```bash 17 | torchrun --standalone --nproc_per_node=4 train_gpt_scion.py 18 | torchrun --standalone --nproc_per_node=4 train_gpt_scionlight.py 19 | ``` 20 | 21 | Notes: 22 | 23 | - `ScionLight` has necessary changes tagged with "ScionLight modification" (specifically, don't zero gradients and be careful with gradient accumulation) 24 | - When changing `n_embd`, remember to change `n_head` accordingly to `n_embd // 128` to maintain head dimension of 128. 25 | -------------------------------------------------------------------------------- /examples/deit/README.md: -------------------------------------------------------------------------------- 1 | # DeiT 2 | 3 | This code builds on [DeiT on ImageNet](https://github.com/facebookresearch/deit). 4 | 5 | ## Run 6 | 1. Clone the deit repository locally: 7 | ``` 8 | git clone https://github.com/facebookresearch/deit.git 9 | ``` 10 | 11 | 2. Follow [README_deit](https://github.com/facebookresearch/deit/blob/main/README_deit.md) for setup and data preparation. 12 | 13 | 3. Copy Scion files to the deit repository: 14 | ``` 15 | cp -r main_scion.py scion.py deit/ 16 | ``` 17 | 18 | 4. Train DeiT-base model: 19 | ``` 20 | torchrun --nnodes=4 --nproc_per_node=4 main_scion.py --model deit_base_patch16_224 --epochs 200 --output_dir path2checkpoints_scion --batch-size 256 --lr 8e-5 --min-lr 1e-7 --warmup-epochs 0 --data-path "path_to_imagenet" 21 | ``` 22 | This should give 23 | ``` 24 | "test_acc1": 81.974, "test_acc5": 95.716 25 | ``` 26 | 27 | 28 | ## CHANGELOG 29 | 30 | Changes made: 31 | 32 | - Modernized architecture: 33 | - RMS norm instead of LayerNorm 34 | - GELU `sqrt(2)` scaling 35 | - Increases total batch size to 4096 36 | - Decreases the number of epochs from 300 to 200 37 | - No warmup 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 LIONS 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
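The modded-nanogpt notes above say that `ScionLight` must not have its gradients zeroed (the top-level README explains that it reuses `p.grad` as its internal buffer). A hedged sketch of the resulting usage pattern — hypothetical loop code, not the repo's actual `train_gpt_scionlight.py`:

```python
# Sketch only: assumes ScionLight keeps its momentum estimate inside p.grad
# and decays that buffer itself during step(), so zero_grad() must be skipped.
for step in range(num_steps):
    x, y = get_batch()   # hypothetical data loader
    loss = model(x, y)
    loss.backward()      # new gradient accumulates onto the decayed buffer
    optimizer.step()     # LMO update; do NOT call optimizer.zero_grad()
```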
-------------------------------------------------------------------------------- /examples/shallow-nanogpt/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Andrej Karpathy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/data/shakespeare/prepare.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import tiktoken 4 | import numpy as np 5 | 6 | # download the tiny shakespeare dataset 7 | input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt') 8 | if not os.path.exists(input_file_path): 9 | data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt' 10 | with open(input_file_path, 'w', encoding='utf-8') as f: 11 | f.write(requests.get(data_url).text) 12 | 13 | with open(input_file_path, 'r', encoding='utf-8') as f: 14 | data = f.read() 15 | n = len(data) 16 | train_data = data[:int(n*0.9)] 17 | val_data = data[int(n*0.9):] 18 | 19 | # encode with tiktoken gpt2 bpe 20 | enc = tiktoken.get_encoding("gpt2") 21 | train_ids = enc.encode_ordinary(train_data) 22 | val_ids = enc.encode_ordinary(val_data) 23 | print(f"train has {len(train_ids):,} tokens") 24 | print(f"val has {len(val_ids):,} tokens") 25 | 26 | # export to bin files 27 | train_ids = np.array(train_ids, dtype=np.uint16) 28 | val_ids = np.array(val_ids, dtype=np.uint16) 29 | train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin')) 30 | val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin')) 31 | 32 | # train.bin has 301,966 tokens 33 | # val.bin has 36,059 tokens 34 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/config/train_shakespeare_char.py: -------------------------------------------------------------------------------- 1 | # train a miniature character-level shakespeare model 2 | # good for debugging and playing on macbooks and such 3 | 4 | out_dir = 'out-shakespeare-char' 5 | eval_interval = 250 # keep frequent because we'll overfit 6 | eval_iters = 200 7 | log_interval = 10 # don't print too too often 8 | 9 | # we expect to overfit on this small dataset, so only save when val improves 10 | always_save_checkpoint = False 11 | 12 | wandb_log = False # override via command line if you like 13 
| wandb_project = 'shakespeare-char' 14 | wandb_run_name = 'mini-gpt' 15 | 16 | dataset = 'shakespeare_char' 17 | gradient_accumulation_steps = 1 18 | batch_size = 64 19 | block_size = 256 # context of up to 256 previous characters 20 | 21 | # baby GPT model :) 22 | n_layer = 6 23 | n_head = 6 24 | n_embd = 384 25 | dropout = 0.2 26 | 27 | learning_rate = 1e-3 # with baby networks can afford to go a bit higher 28 | max_iters = 5000 29 | lr_decay_iters = 5000 # make equal to max_iters usually 30 | min_lr = 1e-4 # learning_rate / 10 usually 31 | beta2 = 0.99 # make a bit bigger because number of tokens per iter is small 32 | 33 | warmup_iters = 100 # not super necessary potentially 34 | 35 | # on macbook also add 36 | # device = 'cpu' # run on cpu only 37 | # compile = False # do not torch compile the model 38 | -------------------------------------------------------------------------------- /examples/airbench/README.md: -------------------------------------------------------------------------------- 1 | # Scion on [airbench](https://github.com/KellerJordan/cifar10-airbench/tree/1da61ae58ee9c112e7f166fac3b97d245aa72942) 2 | 3 | Scion simplifies the training setup by: 4 | 5 | - avoiding Frobenius normalization within the optimizer 6 | - remove the need for training the whiten bias (i.e., `whiten_bias_epoch=0`) 7 | 8 | ## Overview 9 | 10 | - [`airbench_muon.py`](airbench_muon.py): 94.04% (mean over 200 runs) 11 | - [`airbench_sgd.py`](airbench_sgd.py): 94.01% (mean over 200 runs) 12 | - [`airbench_scion.py`](airbench_scion.py): 93.95% (mean over 50 runs) 13 | - [`airbench_scion_speedrun.py`](airbench_scion_speedrun.py): 94.07% (mean over 200 runs) through further optimization of the scaling factors and the learning rate. 14 | 15 | 16 | ## Pseudocode 17 | 18 | The configuration used in `airbench_scion.py`: 19 | 20 | ```python 21 | radius = 8.0 22 | optim_groups = [{ 23 | 'params': conv_layers, 24 | 'norm': 'SpectralConv', 25 | 'norm_kwargs': {'steps': 9}, # to stay consistent with the Muon baseline 26 | 'scale': radius, 27 | }, { 28 | 'params': batchnorm_layers, 29 | 'norm': 'BiasRMS', # heuristically uses l2 for normalization layers 30 | 'norm_kwargs': {}, 31 | 'scale': radius, 32 | }, { 33 | 'params': output_layer, 34 | 'norm': 'Sign', 35 | 'norm_kwargs': {'normalized': True}, 36 | 'scale': radius*16, 37 | }] 38 | optimizer = Scion(optim_groups, lr=0.05, momentum=0.6) 39 | ``` 40 | 41 | The implementation uses the Newton-Schulz version used in the Muon baseline for fair comparison. 42 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/README.md: -------------------------------------------------------------------------------- 1 | # Shallow NanoGPT 2 | 3 | See [`mup_examples/README.md`](mup_examples/README.md) for setup. 
4 | 5 | ```bash 6 | # Transfer lr/width sweep 7 | bash mup_examples/mutransfer_lr_shakespeare_char/uscion/run.sh 8 | bash mup_examples/mutransfer_lr_shakespeare_char/scion/run.sh 9 | bash mup_examples/mutransfer_lr_shakespeare_char/scion_full/run.sh 10 | bash mup_examples/mutransfer_lr_shakespeare_char/sp/run.sh 11 | bash mup_examples/mutransfer_lr_shakespeare_char/mup/run.sh 12 | 13 | # Testing a single run (smallest width with optimal stepsize) 14 | bash mup_examples/mutransfer_lr_shakespeare_char/uscion/run_test.sh 15 | bash mup_examples/mutransfer_lr_shakespeare_char/scion/run_test.sh 16 | bash mup_examples/mutransfer_lr_shakespeare_char/scion_full/run_test.sh 17 | bash mup_examples/mutransfer_lr_shakespeare_char/sp/run_test.sh 18 | bash mup_examples/mutransfer_lr_shakespeare_char/mup/run_test.sh 19 | ``` 20 | 21 | 22 | ## CHANGELOG 23 | 24 | Changes made (see `SCION_CHANGE` code comments): 25 | 26 | - Modernized architecture: 27 | - Rotary embedding 28 | - RMS norm instead of LayerNorm 29 | - No weight sharing for first and last layer 30 | - Linear decay instead of cosine 31 | - GELU `sqrt(2)` scaling 32 | - Decouples QKV into three separate `Linear` layers to expose each independently to the optimizer 33 | - Increases batch size to 32 (maximum allowed for 4096 width model on an A100) 34 | - Disables `torch.compile` to support running without triton 35 | - Logs final validation loss (instead of a running average) 36 | 37 | 38 | ## Acknowledgements 39 | 40 | This codebase builds on [EleutherAI/nanoGPT-mup](https://github.com/EleutherAI/nanoGPT-mup/). 41 | See [`README_orig.md`](README_orig.md) for the original readme. 42 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/mup_examples/mutransfer_lr_shakespeare_char/sp/run_test.sh: -------------------------------------------------------------------------------- 1 | for width in 256 #512 1024 2048 2 | do 3 | for lr in 0.00048828 4 | do 5 | for seed in 1 #2 3 6 | do 7 | head_size=64 8 | n_heads=$((width / head_size)) 9 | out_dir="mup_examples/mutransfer_lr_shakespeare_char/sp/out/test" 10 | python train.py \ 11 | --out_dir=$out_dir \ 12 | --log_interval=1 \ 13 | --eval_on_end=True \ 14 | --eval_iters=200 \ 15 | --skip_val_loss=False \ 16 | --eval_only=False \ 17 | --always_save_checkpoint=False \ 18 | --never_save_checkpoint=True \ 19 | --init_from='scratch' \ 20 | --wandb_log=False \ 21 | --csv_log=True \ 22 | --dataset='shakespeare_char' \ 23 | --gradient_accumulation_steps=1 \ 24 | --batch_size=32 \ 25 | --block_size=1024 \ 26 | --n_layer=2 \ 27 | --n_head=$n_heads \ 28 | --n_embd=$width \ 29 | --dropout=0.0 \ 30 | --bias=False \ 31 | --init_std=0.02 \ 32 | --learning_rate=$lr \ 33 | --max_iters=122 \ 34 | --weight_decay=1e-1 \ 35 | --beta1=0.9 \ 36 | --beta2=0.95 \ 37 | --grad_clip=1.0 \ 38 | --decay_lr=False \ 39 | --seed=$seed \ 40 | --backend='nccl' \ 41 | --device='cuda' \ 42 | --dtype='float32' \ 43 | --compile=False 44 | done 45 | done 46 | done 47 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/mup_examples/mutransfer_lr_shakespeare_char/uscion/run_test.sh: -------------------------------------------------------------------------------- 1 | for width in 512 #512 1024 2048 2 | do 3 | for lr in 0.03125 4 | do 5 | for seed in 2 6 | do 7 | head_size=64 8 | n_heads=$((width / head_size)) 9 | out_dir="mup_examples/mutransfer_lr_shakespeare_char/uscion/out/test" 10 | python train.py \ 11 | --out_dir=$out_dir \ 12 | 
--log_interval=1 \ 13 | --eval_on_end=True \ 14 | --eval_iters=200 \ 15 | --skip_val_loss=False \ 16 | --eval_only=False \ 17 | --always_save_checkpoint=False \ 18 | --never_save_checkpoint=True \ 19 | --init_from='scratch' \ 20 | --wandb_log=False \ 21 | --csv_log=True \ 22 | --dataset='shakespeare_char' \ 23 | --gradient_accumulation_steps=1 \ 24 | --batch_size=32 \ 25 | --block_size=1024 \ 26 | --n_layer=2 \ 27 | --n_head=$n_heads \ 28 | --n_embd=$width \ 29 | --dropout=0.0 \ 30 | --bias=False \ 31 | --init_std=0.02 \ 32 | --learning_rate=$lr \ 33 | --max_iters=122 \ 34 | --beta1=0.9 \ 35 | --grad_clip=0.0 \ 36 | --warmup_iters=0 \ 37 | --decay_lr=True \ 38 | --scion_enabled=True \ 39 | --scion_first_layer='Sign' \ 40 | --seed=$seed \ 41 | --backend='nccl' \ 42 | --device='cuda' \ 43 | --dtype='float32' \ 44 | --compile=False 45 | done 46 | done 47 | done 48 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/mup_examples/mutransfer_lr_shakespeare_char/scion/run_test.sh: -------------------------------------------------------------------------------- 1 | for width in 2048 #512 1024 2048 2 | do 3 | for lr in 0.0078125 4 | do 5 | for seed in 2 6 | do 7 | head_size=64 8 | n_heads=$((width / head_size)) 9 | out_dir="mup_examples/mutransfer_lr_shakespeare_char/scion/out/test" 10 | python train.py \ 11 | --out_dir=$out_dir \ 12 | --log_interval=1 \ 13 | --eval_on_end=True \ 14 | --eval_iters=200 \ 15 | --skip_val_loss=False \ 16 | --eval_only=False \ 17 | --always_save_checkpoint=False \ 18 | --never_save_checkpoint=True \ 19 | --init_from='scratch' \ 20 | --wandb_log=False \ 21 | --csv_log=True \ 22 | --dataset='shakespeare_char' \ 23 | --gradient_accumulation_steps=1 \ 24 | --batch_size=32 \ 25 | --block_size=1024 \ 26 | --n_layer=2 \ 27 | --n_head=$n_heads \ 28 | --n_embd=$width \ 29 | --dropout=0.0 \ 30 | --bias=False \ 31 | --init_std=0.02 \ 32 | --learning_rate=$lr \ 33 | --max_iters=122 \ 34 | --beta1=0.9 \ 35 | --grad_clip=0.0 \ 36 | --warmup_iters=0 \ 37 | --decay_lr=True \ 38 | --scion_enabled=True \ 39 | --scion_first_layer='Sign' \ 40 | --scion_unconstrained=True \ 41 | --seed=$seed \ 42 | --backend='nccl' \ 43 | --device='cuda' \ 44 | --dtype='float32' \ 45 | --compile=False 46 | done 47 | done 48 | done 49 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/mup_examples/mutransfer_lr_shakespeare_char/scion_full/run_test.sh: -------------------------------------------------------------------------------- 1 | for width in 2048 #512 1024 2048 2 | do 3 | for lr in 0.0001 # 0.01 #0.0078125 4 | do 5 | for seed in 2 6 | do 7 | head_size=64 8 | n_heads=$((width / head_size)) 9 | out_dir="mup_examples/mutransfer_lr_shakespeare_char/scion_full/out/test" 10 | python train.py \ 11 | --out_dir=$out_dir \ 12 | --log_interval=1 \ 13 | --eval_on_end=True \ 14 | --eval_iters=200 \ 15 | --skip_val_loss=False \ 16 | --eval_only=False \ 17 | --always_save_checkpoint=False \ 18 | --never_save_checkpoint=True \ 19 | --init_from='scratch' \ 20 | --wandb_log=False \ 21 | --csv_log=True \ 22 | --dataset='shakespeare_char' \ 23 | --gradient_accumulation_steps=1 \ 24 | --batch_size=32 \ 25 | --block_size=1024 \ 26 | --n_layer=2 \ 27 | --n_head=$n_heads \ 28 | --n_embd=$width \ 29 | --dropout=0.0 \ 30 | --bias=False \ 31 | --init_std=0.02 \ 32 | --learning_rate=$lr \ 33 | --max_iters=122 \ 34 | --beta1=0.9 \ 35 | --grad_clip=0.0 \ 36 | --warmup_iters=0 \ 37 | --decay_lr=True \ 38 | 
--scion_enabled=True \ 39 | --scion_mode='Sign-naive' \ 40 | --scion_unconstrained=True \ 41 | --seed=$seed \ 42 | --backend='nccl' \ 43 | --device='cuda' \ 44 | --dtype='float32' \ 45 | --compile=False 46 | done 47 | done 48 | done 49 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/mup_examples/mutransfer_lr_shakespeare_char/sp/run.sh: -------------------------------------------------------------------------------- 1 | for width in 256 512 1024 2048 2 | do 3 | for lr in 0.00390625 0.001953125 0.0009765625 0.00048828125 0.000244140625 0.0001220703125 0.00006103515625 0.00003051757812 0.00001525878906 0.000007629394531 0.000003814697266 4 | do 5 | for seed in 1 2 3 6 | do 7 | head_size=64 8 | n_heads=$((width / head_size)) 9 | out_dir="mup_examples/mutransfer_lr_shakespeare_char/sp/out/width${width}_depth2_seed${seed}_lr${lr}" 10 | python train.py \ 11 | --out_dir=$out_dir \ 12 | --eval_interval=1 \ 13 | --log_interval=1 \ 14 | --eval_iters=1 \ 15 | --eval_only=False \ 16 | --skip_val_loss=False \ 17 | --always_save_checkpoint=False \ 18 | --never_save_checkpoint=True \ 19 | --init_from='scratch' \ 20 | --wandb_log=False \ 21 | --csv_log=True \ 22 | --dataset='shakespeare_char' \ 23 | --gradient_accumulation_steps=1 \ 24 | --batch_size=32 \ 25 | --block_size=1024 \ 26 | --n_layer=2 \ 27 | --n_head=$n_heads \ 28 | --n_embd=$width \ 29 | --dropout=0.0 \ 30 | --bias=False \ 31 | --init_std=0.02 \ 32 | --learning_rate=$lr \ 33 | --max_iters=122 \ 34 | --weight_decay=1e-1 \ 35 | --beta1=0.9 \ 36 | --beta2=0.95 \ 37 | --grad_clip=1.0 \ 38 | --decay_lr=False \ 39 | --seed=$seed \ 40 | --backend='nccl' \ 41 | --device='cuda' \ 42 | --dtype='float32' \ 43 | --compile=False 44 | done 45 | done 46 | done 47 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/configurator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Poor Man's Configurator. Probably a terrible idea. Example usage: 3 | $ python train.py config/override_file.py --batch_size=32 4 | this will first run config/override_file.py, then override batch_size to 32 5 | 6 | The code in this file will be run as follows from e.g. train.py: 7 | >>> exec(open('configurator.py').read()) 8 | 9 | So it's not a Python module, it's just shuttling this code away from train.py 10 | The code in this script then overrides the globals() 11 | 12 | I know people are not going to love this, I just really dislike configuration 13 | complexity and having to prepend config. to every single variable. If someone 14 | comes up with a better simple Python solution I am all ears. 15 | """ 16 | 17 | import sys 18 | from ast import literal_eval 19 | 20 | for arg in sys.argv[1:]: 21 | if '=' not in arg: 22 | # assume it's the name of a config file 23 | assert not arg.startswith('--') 24 | config_file = arg 25 | print(f"Overriding config with {config_file}:") 26 | with open(config_file) as f: 27 | print(f.read()) 28 | exec(open(config_file).read()) 29 | else: 30 | # assume it's a --key=value argument 31 | assert arg.startswith('--') 32 | key, val = arg.split('=') 33 | key = key[2:] 34 | if key in globals(): 35 | try: 36 | # attempt to eval it it (e.g. 
if bool, number, or etc) 37 | attempt = literal_eval(val) 38 | except (SyntaxError, ValueError): 39 | # if that goes wrong, just use the string 40 | attempt = val 41 | # ensure the types match ok 42 | assert type(attempt) == type(globals()[key]) 43 | # cross fingers 44 | print(f"Overriding: {key} = {attempt}") 45 | globals()[key] = attempt 46 | else: 47 | raise ValueError(f"Unknown config key: {key}") 48 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/mup_examples/mutransfer_lr_shakespeare_char/mup/run_test.sh: -------------------------------------------------------------------------------- 1 | for width in 256 2 | do 3 | for lr in 0.00097656 4 | do 5 | for seed in 1 # 2 3 6 | do 7 | head_size=64 8 | n_heads=$((width / head_size)) 9 | mup_base_width=256 10 | mup_width_multiplier=$(echo "scale=8; $width/$mup_base_width" | bc -l) 11 | out_dir="mup_examples/mutransfer_lr_shakespeare_char/mup/out/test" 12 | python train.py \ 13 | --out_dir=$out_dir \ 14 | --eval_on_end=True \ 15 | --eval_iters=200 \ 16 | --skip_val_loss=False \ 17 | --eval_only=False \ 18 | --log_interval=1 \ 19 | --always_save_checkpoint=False \ 20 | --never_save_checkpoint=True \ 21 | --init_from='scratch' \ 22 | --wandb_log=False \ 23 | --csv_log=True \ 24 | --dataset='shakespeare_char' \ 25 | --gradient_accumulation_steps=1\ 26 | --batch_size=32 \ 27 | --block_size=1024 \ 28 | --n_layer=2 \ 29 | --n_head=$n_heads \ 30 | --n_embd=$width \ 31 | --dropout=0.0 \ 32 | --bias=False \ 33 | --init_std=0.02 \ 34 | --learning_rate=$lr \ 35 | --max_iters=122 \ 36 | --weight_decay=1e-1 \ 37 | --beta1=0.9 \ 38 | --beta2=0.95 \ 39 | --grad_clip=1.0 \ 40 | --decay_lr=False \ 41 | --mup_enabled=True \ 42 | --mup_width_multiplier=$mup_width_multiplier \ 43 | --mup_input_alpha=1.0 \ 44 | --mup_output_alpha=1.0 \ 45 | --seed=$seed \ 46 | --backend='nccl' \ 47 | --device='cuda' \ 48 | --dtype='float32' \ 49 | --compile=False 50 | done 51 | done 52 | done 53 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/mup_examples/mutransfer_lr_shakespeare_char/mup/run.sh: -------------------------------------------------------------------------------- 1 | for width in 256 512 1024 2048 2 | do 3 | for lr in 0.125 0.0625 0.03125 0.015625 0.0078125 0.00390625 0.001953125 0.0009765625 0.00048828125 0.000244140625 0.0001220703125 0.00006103515625 4 | do 5 | for seed in 1 2 3 6 | do 7 | head_size=64 8 | n_heads=$((width / head_size)) 9 | mup_base_width=256 10 | mup_width_multiplier=$(echo "scale=8; $width/$mup_base_width" | bc -l) 11 | out_dir="mup_examples/mutransfer_lr_shakespeare_char/mup/out/width${width}_depth2_seed${seed}_lr${lr}" 12 | python train.py \ 13 | --out_dir=$out_dir \ 14 | --eval_interval=1 \ 15 | --log_interval=1 \ 16 | --eval_iters=1 \ 17 | --eval_only=False \ 18 | --skip_val_loss=False \ 19 | --always_save_checkpoint=False \ 20 | --never_save_checkpoint=True \ 21 | --init_from='scratch' \ 22 | --wandb_log=False \ 23 | --csv_log=True \ 24 | --dataset='shakespeare_char' \ 25 | --gradient_accumulation_steps=1\ 26 | --batch_size=32 \ 27 | --block_size=1024 \ 28 | --n_layer=2 \ 29 | --n_head=$n_heads \ 30 | --n_embd=$width \ 31 | --dropout=0.0 \ 32 | --bias=False \ 33 | --init_std=0.02 \ 34 | --learning_rate=$lr \ 35 | --max_iters=122 \ 36 | --weight_decay=1e-1 \ 37 | --beta1=0.9 \ 38 | --beta2=0.95 \ 39 | --grad_clip=1.0 \ 40 | --decay_lr=False \ 41 | --mup_enabled=True \ 42 | --mup_width_multiplier=$mup_width_multiplier \ 43 
| --mup_input_alpha=1.0 \ 44 | --mup_output_alpha=1.0 \ 45 | --seed=$seed \ 46 | --backend='nccl' \ 47 | --device='cuda' \ 48 | --dtype='float32' \ 49 | --compile=False 50 | done 51 | done 52 | done 53 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/mup_examples/mutransfer_lr_shakespeare_char/uscion/run.sh: -------------------------------------------------------------------------------- 1 | for first_layer in 'Sign' 'Spectral' 'ColNorm' 2 | do 3 | for width in 256 512 1024 2048 4 | do 5 | for lr in 0.5 0.25 0.125 0.0625 0.03125 0.015625 0.0078125 0.00390625 0.00195312 0.0009765625 0.00048828125 0.000244140625 6 | do 7 | for seed in 1 2 3 8 | do 9 | head_size=64 10 | n_heads=$((width / head_size)) 11 | out_dir="mup_examples/mutransfer_lr_shakespeare_char/uscion/out/first_layer${first_layer}_width${width}_depth2_seed${seed}_lr${lr}" 12 | python train.py \ 13 | --out_dir=$out_dir \ 14 | --eval_on_end=True \ 15 | --eval_iters=200 \ 16 | --skip_val_loss=False \ 17 | --eval_only=False \ 18 | --log_interval=1 \ 19 | --always_save_checkpoint=False \ 20 | --never_save_checkpoint=True \ 21 | --init_from='scratch' \ 22 | --wandb_log=False \ 23 | --csv_log=True \ 24 | --dataset='shakespeare_char' \ 25 | --gradient_accumulation_steps=1 \ 26 | --batch_size=32 \ 27 | --block_size=1024 \ 28 | --n_layer=2 \ 29 | --n_head=$n_heads \ 30 | --n_embd=$width \ 31 | --dropout=0.0 \ 32 | --bias=False \ 33 | --init_std=0.02 \ 34 | --learning_rate=$lr \ 35 | --max_iters=122 \ 36 | --beta1=0.9 \ 37 | --grad_clip=0.0 \ 38 | --warmup_iters=0 \ 39 | --decay_lr=True \ 40 | --scion_enabled=True \ 41 | --scion_first_layer=$first_layer \ 42 | --seed=$seed \ 43 | --backend='nccl' \ 44 | --device='cuda' \ 45 | --dtype='float32' \ 46 | --compile=False 47 | done 48 | done 49 | done 50 | done 51 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/mup_examples/mutransfer_lr_shakespeare_char/scion_full/run.sh: -------------------------------------------------------------------------------- 1 | for mode in 'Sign' 'ColNorm' 'RowNorm' 2 | do 3 | for width in 256 512 1024 2048 4 | do 5 | for lr in 0.5 0.25 0.125 0.0625 0.03125 0.015625 0.0078125 0.00390625 0.00195312 0.0009765625 0.00048828125 0.000244140625 6 | do 7 | for seed in 1 2 3 8 | do 9 | head_size=64 10 | n_heads=$((width / head_size)) 11 | out_dir="mup_examples/mutransfer_lr_shakespeare_char/scion_full/out/mode${mode}_width${width}_depth2_seed${seed}_lr${lr}" 12 | python train.py \ 13 | --out_dir=$out_dir \ 14 | --eval_on_end=True \ 15 | --eval_iters=200 \ 16 | --skip_val_loss=False \ 17 | --eval_only=False \ 18 | --log_interval=1 \ 19 | --always_save_checkpoint=False \ 20 | --never_save_checkpoint=True \ 21 | --init_from='scratch' \ 22 | --wandb_log=False \ 23 | --csv_log=True \ 24 | --dataset='shakespeare_char' \ 25 | --gradient_accumulation_steps=1 \ 26 | --batch_size=32 \ 27 | --block_size=1024 \ 28 | --n_layer=2 \ 29 | --n_head=$n_heads \ 30 | --n_embd=$width \ 31 | --dropout=0.0 \ 32 | --bias=False \ 33 | --init_std=0.02 \ 34 | --learning_rate=$lr \ 35 | --max_iters=122 \ 36 | --beta1=0.9 \ 37 | --grad_clip=0.0 \ 38 | --warmup_iters=0 \ 39 | --decay_lr=True \ 40 | --scion_enabled=True \ 41 | --scion_mode=$mode \ 42 | --scion_unconstrained=True \ 43 | --seed=$seed \ 44 | --backend='nccl' \ 45 | --device='cuda' \ 46 | --dtype='float32' \ 47 | --compile=False 48 | done 49 | done 50 | done 51 | done 52 | 
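The `--scion_*` flags in these sweep scripts select how `train.py` groups parameters for the optimizer (first-layer norm, constrained vs. unconstrained, and so on). As a rough orientation only — hypothetical glue code, with scale factors borrowed from the weight-sharing example in the top-level README rather than the values actually used here — such flags could map onto `Scion` parameter groups like this:

```python
from scion import Scion

def build_scion(model, lr, first_layer='Sign', unconstrained=True):
    # Hypothetical helper: the real wiring lives in train.py / scion.py.
    radius = 50.0
    groups = [
        {'params': model.transformer.wte.parameters(),  # input embedding ("first layer")
         'norm': first_layer, 'norm_kwargs': {}, 'scale': radius},
        {'params': model.transformer.h.parameters(),    # transformer blocks
         'norm': 'Spectral', 'norm_kwargs': {}, 'scale': radius},
        {'params': model.lm_head.parameters(),          # output head
         'norm': 'Sign', 'norm_kwargs': {}, 'scale': radius * 60.0},
    ]
    return Scion(groups, lr=lr, momentum=0.1, unconstrained=unconstrained)
```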
-------------------------------------------------------------------------------- /examples/shallow-nanogpt/mup_examples/mutransfer_lr_shakespeare_char/scion/run.sh: -------------------------------------------------------------------------------- 1 | for first_layer in 'Sign' 'Spectral' 'ColNorm' 2 | do 3 | for width in 256 512 1024 2048 4 | do 5 | for lr in 0.5 0.25 0.125 0.0625 0.03125 0.015625 0.0078125 0.00390625 0.00195312 0.0009765625 0.00048828125 0.000244140625 6 | do 7 | for seed in 1 2 3 8 | do 9 | head_size=64 10 | n_heads=$((width / head_size)) 11 | out_dir="mup_examples/mutransfer_lr_shakespeare_char/scion/out/first_layer${first_layer}_width${width}_depth2_seed${seed}_lr${lr}" 12 | python train.py \ 13 | --out_dir=$out_dir \ 14 | --eval_on_end=True \ 15 | --eval_iters=200 \ 16 | --skip_val_loss=False \ 17 | --eval_only=False \ 18 | --log_interval=1 \ 19 | --always_save_checkpoint=False \ 20 | --never_save_checkpoint=True \ 21 | --init_from='scratch' \ 22 | --wandb_log=False \ 23 | --csv_log=True \ 24 | --dataset='shakespeare_char' \ 25 | --gradient_accumulation_steps=1 \ 26 | --batch_size=32 \ 27 | --block_size=1024 \ 28 | --n_layer=2 \ 29 | --n_head=$n_heads \ 30 | --n_embd=$width \ 31 | --dropout=0.0 \ 32 | --bias=False \ 33 | --init_std=0.02 \ 34 | --learning_rate=$lr \ 35 | --max_iters=122 \ 36 | --beta1=0.9 \ 37 | --grad_clip=0.0 \ 38 | --warmup_iters=0 \ 39 | --decay_lr=True \ 40 | --scion_enabled=True \ 41 | --scion_first_layer=$first_layer \ 42 | --scion_unconstrained=True \ 43 | --seed=$seed \ 44 | --backend='nccl' \ 45 | --device='cuda' \ 46 | --dtype='float32' \ 47 | --compile=False 48 | done 49 | done 50 | done 51 | done 52 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/mup_examples/mutransfer_lr_shakespeare_char/scion_full/run_naive.sh: -------------------------------------------------------------------------------- 1 | for mode in 'Sign-naive' 2 | do 3 | for width in 256 512 1024 2048 4 | do 5 | for lr in 0.00390625 0.001953125 0.0009765625 0.00048828125 0.000244140625 0.0001220703125 0.00006103515625 0.00003051757812 0.00001525878906 0.000007629394531 0.000003814697266 6 | do 7 | for seed in 1 2 3 8 | do 9 | head_size=64 10 | n_heads=$((width / head_size)) 11 | out_dir="mup_examples/mutransfer_lr_shakespeare_char/scion_full/out/mode${mode}_width${width}_depth2_seed${seed}_lr${lr}" 12 | python train.py \ 13 | --out_dir=$out_dir \ 14 | --eval_on_end=True \ 15 | --eval_iters=200 \ 16 | --skip_val_loss=False \ 17 | --eval_only=False \ 18 | --log_interval=1 \ 19 | --always_save_checkpoint=False \ 20 | --never_save_checkpoint=True \ 21 | --init_from='scratch' \ 22 | --wandb_log=False \ 23 | --csv_log=True \ 24 | --dataset='shakespeare_char' \ 25 | --gradient_accumulation_steps=1 \ 26 | --batch_size=32 \ 27 | --block_size=1024 \ 28 | --n_layer=2 \ 29 | --n_head=$n_heads \ 30 | --n_embd=$width \ 31 | --dropout=0.0 \ 32 | --bias=False \ 33 | --init_std=0.02 \ 34 | --learning_rate=$lr \ 35 | --max_iters=122 \ 36 | --beta1=0.9 \ 37 | --grad_clip=0.0 \ 38 | --warmup_iters=0 \ 39 | --decay_lr=True \ 40 | --scion_enabled=True \ 41 | --scion_mode=$mode \ 42 | --scion_unconstrained=True \ 43 | --seed=$seed \ 44 | --backend='nccl' \ 45 | --device='cuda' \ 46 | --dtype='float32' \ 47 | --compile=False 48 | done 49 | done 50 | done 51 | done 52 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/data/shakespeare_char/prepare.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Prepare the Shakespeare dataset for character-level language modeling. 3 | So instead of encoding with GPT-2 BPE tokens, we just map characters to ints. 4 | Will save train.bin, val.bin containing the ids, and meta.pkl containing the 5 | encoder and decoder and some other related info. 6 | """ 7 | import os 8 | import pickle 9 | import requests 10 | import numpy as np 11 | 12 | # download the tiny shakespeare dataset 13 | input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt') 14 | if not os.path.exists(input_file_path): 15 | data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt' 16 | with open(input_file_path, 'w') as f: 17 | f.write(requests.get(data_url).text) 18 | 19 | with open(input_file_path, 'r') as f: 20 | data = f.read() 21 | print(f"length of dataset in characters: {len(data):,}") 22 | 23 | # get all the unique characters that occur in this text 24 | chars = sorted(list(set(data))) 25 | vocab_size = len(chars) 26 | print("all the unique characters:", ''.join(chars)) 27 | print(f"vocab size: {vocab_size:,}") 28 | 29 | # create a mapping from characters to integers 30 | stoi = { ch:i for i,ch in enumerate(chars) } 31 | itos = { i:ch for i,ch in enumerate(chars) } 32 | def encode(s): 33 | return [stoi[c] for c in s] # encoder: take a string, output a list of integers 34 | def decode(l): 35 | return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string 36 | 37 | # create the train and test splits 38 | n = len(data) 39 | train_data = data[:int(n*0.9)] 40 | val_data = data[int(n*0.9):] 41 | 42 | # encode both to integers 43 | train_ids = encode(train_data) 44 | val_ids = encode(val_data) 45 | print(f"train has {len(train_ids):,} tokens") 46 | print(f"val has {len(val_ids):,} tokens") 47 | 48 | # export to bin files 49 | train_ids = np.array(train_ids, dtype=np.uint16) 50 | val_ids = np.array(val_ids, dtype=np.uint16) 51 | train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin')) 52 | val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin')) 53 | 54 | # save the meta information as well, to help us encode/decode later 55 | meta = { 56 | 'vocab_size': vocab_size, 57 | 'itos': itos, 58 | 'stoi': stoi, 59 | } 60 | with open(os.path.join(os.path.dirname(__file__), 'meta.pkl'), 'wb') as f: 61 | pickle.dump(meta, f) 62 | 63 | # length of dataset in characters: 1115394 64 | # all the unique characters: 65 | # !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz 66 | # vocab size: 65 67 | # train has 1003854 tokens 68 | # val has 111540 tokens 69 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/data/openwebtext/prepare.py: -------------------------------------------------------------------------------- 1 | # saves the openwebtext dataset to a binary file for training. following was helpful: 2 | # https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py 3 | 4 | import os 5 | from tqdm import tqdm 6 | import numpy as np 7 | import tiktoken 8 | from datasets import load_dataset # huggingface datasets 9 | 10 | # number of workers in .map() call 11 | # good number to use is ~order number of cpu cores // 2 12 | num_proc = 8 13 | 14 | # number of workers in load_dataset() call 15 | # best number might be different from num_proc above as it also depends on NW speed. 
16 | # it is better than 1 usually though 17 | num_proc_load_dataset = num_proc 18 | 19 | enc = tiktoken.get_encoding("gpt2") 20 | 21 | if __name__ == '__main__': 22 | # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769) 23 | dataset = load_dataset("openwebtext", num_proc=num_proc_load_dataset) 24 | 25 | # owt by default only contains the 'train' split, so create a test split 26 | split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True) 27 | split_dataset['val'] = split_dataset.pop('test') # rename the test split to val 28 | 29 | # this results in: 30 | # >>> split_dataset 31 | # DatasetDict({ 32 | # train: Dataset({ 33 | # features: ['text'], 34 | # num_rows: 8009762 35 | # }) 36 | # val: Dataset({ 37 | # features: ['text'], 38 | # num_rows: 4007 39 | # }) 40 | # }) 41 | 42 | # we now want to tokenize the dataset. first define the encoding function (gpt2 bpe) 43 | def process(example): 44 | ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens 45 | ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe 46 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 47 | out = {'ids': ids, 'len': len(ids)} 48 | return out 49 | 50 | # tokenize the dataset 51 | tokenized = split_dataset.map( 52 | process, 53 | remove_columns=['text'], 54 | desc="tokenizing the splits", 55 | num_proc=num_proc, 56 | ) 57 | 58 | # concatenate all the ids in each dataset into one large file we can use for training 59 | for split, dset in tokenized.items(): 60 | arr_len = np.sum(dset['len'], dtype=np.uint64) 61 | filename = os.path.join(os.path.dirname(__file__), f'{split}.bin') 62 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 63 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 64 | total_batches = 1024 65 | 66 | idx = 0 67 | for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'): 68 | # Batch together samples for faster write 69 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy') 70 | arr_batch = np.concatenate(batch['ids']) 71 | # Write into mmap 72 | arr[idx : idx + len(arr_batch)] = arr_batch 73 | idx += len(arr_batch) 74 | arr.flush() 75 | 76 | # train.bin is ~17GB, val.bin ~8.5MB 77 | # train has ~9B tokens (9,035,582,198) 78 | # val has ~4M tokens (4,434,897) 79 | 80 | # to read the bin files later, e.g. with numpy: 81 | # m = np.memmap('train.bin', dtype=np.uint16, mode='r') 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Scion 2 | 3 | Code accompanying the paper [Training Deep Learning Models with Norm-Constrained LMOs](https://arxiv.org/pdf/2502.07529). 4 | 5 | ## Repository structure 6 | 7 | - [`scion.py`](scion.py): Contains the `Scion` and `ScionLight` reference implementation along with various norm choices. 8 | `ScionLight` is a memory-efficient variant that reuses `p.grad`. 9 | - [`examples/`](examples/): Example usage containing nanoGPT experiments with and without weight sharing. 10 | 11 | ## Notes 12 | 13 | The `Scion` optimizer comes with a couple of hyperparameters: 14 | 15 | - `momentum`: The parameter is `1-usual_momentum` of e.g. the PyTorch implementation of SGD with momentum. 16 | A good default is 0.1. 17 | Higher values seems to work better (e.g. 
0.5) for short training runs with low noise as also supported by theory. 18 | - `scale`: Controls the per-layer constraint radius factor. 19 | The layerwise radius can be tuned on a small proxy model similarly to the input and output scaling factor of µP. 20 | - `lr`: The learning rate can similarly be tuned on a small proxy model (corresponds to γ in the paper). 21 | - `unconstrained`: When set to `False` the constrained variant of the Scion is used, which guarantees the iterates to stay bounded. 22 | The flag is useful for numerical stability in long training runs and to avoid overfitting. 23 | See [Section 3](https://arxiv.org/pdf/2502.07529) for a discussion on the connection with weight decay. 24 | 25 | Architectural changes: 26 | 27 | - Scale activation functions (ReLU, GELU) [by √2](https://github.com/LIONS-EPFL/scion/blob/main/examples/shallow-nanogpt/model.py#L104) to maintain the input variance. 28 | 29 | 30 | ## Examples 31 | 32 | For runnable examples see [`examples/`](examples/). 33 | Below are some pseudocode configurations for different architectures and domains (see [Appendix E.4](https://arxiv.org/pdf/2502.07529) for exact parameter choices): 34 | 35 | 36 | - nanoGPT with weight sharing: 37 | 38 | ```python 39 | radius = 50.0 40 | optim_groups = [{ 41 | 'params': model.transformer.h.parameters(), 42 | 'norm': 'Spectral', 43 | 'norm_kwargs': {}, 44 | 'scale': radius, 45 | }, { 46 | 'params': model.lm_head.parameters(), 47 | 'norm': 'Sign', 48 | 'norm_kwargs': {}, 49 | 'scale': radius*60.0, 50 | }] 51 | optimizer = Scion(optim_groups, lr=2**-12, momentum=0.1, unconstrained=False) 52 | ``` 53 | 54 | - MLP: 55 | 56 | ```python 57 | radius = 1.0 58 | optim_groups = [{ 59 | 'params': input_layer, 60 | 'norm': 'Spectral', 61 | 'norm_kwargs': {'max': True}, 62 | 'scale': radius, 63 | }, { 64 | 'params': hidden_layers, 65 | 'norm': 'Spectral', 66 | 'norm_kwargs': {}, 67 | 'scale': radius, 68 | }, { 69 | 'params': output_layer, 70 | 'norm': 'Sign', 71 | 'norm_kwargs': {'normalized': True}, 72 | 'scale': radius*2**10.0, 73 | }] 74 | optimizer = Scion(optim_groups, lr=2**-6, momentum=0.1) 75 | optimizer.init() 76 | ``` 77 | 78 | - CNN (see [`examples/airbench`](examples/airbench) for further details): 79 | 80 | ```python 81 | radius = 8.0 82 | optim_groups = [{ 83 | 'params': remaining_parameters, 84 | 'norm': 'Auto', # Picks layerwise norm based on the parameter shape 85 | 'norm_kwargs': {}, 86 | 'scale': radius, 87 | }, { 88 | 'params': output_layer, 89 | 'norm': 'Sign', 90 | 'norm_kwargs': {'normalized': True}, 91 | 'scale': radius*16, 92 | }] 93 | optimizer = Scion(optim_groups, lr=2**-4, momentum=0.5) 94 | ``` 95 | 96 | 97 | ## Citation 98 | 99 | If you find this work useful, please cite it as follows: 100 | 101 | ```bibtex 102 | @article{pethick2025training, 103 | title={Training Deep Learning Models with Norm-Constrained LMOs}, 104 | author={Pethick, Thomas and Xie, Wanyun and Antonakopoulos, Kimon and Zhu, Zhenyu and Silveti-Falls, Antonio and Cevher, Volkan}, 105 | journal={arXiv preprint arXiv:2502.07529}, 106 | year={2025} 107 | } 108 | ``` 109 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sample from a trained model 3 | """ 4 | import os 5 | import pickle 6 | from contextlib import nullcontext 7 | import torch 8 | import tiktoken 9 | from model import GPTConfig, GPT 10 | 11 | # 
----------------------------------------------------------------------------- 12 | init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl') 13 | out_dir = 'out' # ignored if init_from is not 'resume' 14 | start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt" 15 | num_samples = 10 # number of samples to draw 16 | max_new_tokens = 500 # number of tokens generated in each sample 17 | temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions 18 | top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability 19 | seed = 1337 20 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. 21 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16' 22 | compile = False # use PyTorch 2.0 to compile the model to be faster 23 | exec(open('configurator.py').read()) # overrides from command line or config file 24 | # ----------------------------------------------------------------------------- 25 | 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 29 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 30 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 31 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 32 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 33 | 34 | # model 35 | if init_from == 'resume': 36 | # init from a model saved in a specific directory 37 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 38 | checkpoint = torch.load(ckpt_path, map_location=device) 39 | gptconf = GPTConfig(**checkpoint['model_args']) 40 | model = GPT(gptconf) 41 | state_dict = checkpoint['model'] 42 | unwanted_prefix = '_orig_mod.' 43 | for k,v in list(state_dict.items()): 44 | if k.startswith(unwanted_prefix): 45 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 46 | model.load_state_dict(state_dict) 47 | elif init_from.startswith('gpt2'): 48 | # init from a given GPT-2 model 49 | model = GPT.from_pretrained(init_from, dict(dropout=0.0)) 50 | 51 | model.eval() 52 | model.to(device) 53 | if compile: 54 | model = torch.compile(model) # requires PyTorch 2.0 (optional) 55 | 56 | # look for the meta pickle in case it is available in the dataset folder 57 | load_meta = False 58 | if init_from == 'resume' and 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these... 
59 | meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl') 60 | load_meta = os.path.exists(meta_path) 61 | if load_meta: 62 | print(f"Loading meta from {meta_path}...") 63 | with open(meta_path, 'rb') as f: 64 | meta = pickle.load(f) 65 | # TODO want to make this more general to arbitrary encoder/decoder schemes 66 | stoi, itos = meta['stoi'], meta['itos'] 67 | encode = lambda s: [stoi[c] for c in s] 68 | decode = lambda l: ''.join([itos[i] for i in l]) 69 | else: 70 | # ok let's assume gpt-2 encodings by default 71 | print("No meta.pkl found, assuming GPT-2 encodings...") 72 | enc = tiktoken.get_encoding("gpt2") 73 | encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"}) 74 | decode = lambda l: enc.decode(l) 75 | 76 | # encode the beginning of the prompt 77 | if start.startswith('FILE:'): 78 | with open(start[5:], 'r', encoding='utf-8') as f: 79 | start = f.read() 80 | start_ids = encode(start) 81 | x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]) 82 | 83 | # run generation 84 | with torch.no_grad(): 85 | with ctx: 86 | for k in range(num_samples): 87 | y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k) 88 | print(decode(y[0].tolist())) 89 | print('---------------') 90 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/bench.py: -------------------------------------------------------------------------------- 1 | """ 2 | A much shorter version of train.py for benchmarking 3 | """ 4 | import os 5 | from contextlib import nullcontext 6 | import numpy as np 7 | import time 8 | import torch 9 | from model import GPTConfig, GPT 10 | 11 | # ----------------------------------------------------------------------------- 12 | batch_size = 12 13 | block_size = 1024 14 | bias = False 15 | real_data = True 16 | seed = 1337 17 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. 18 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16' 19 | compile = True # use PyTorch 2.0 to compile the model to be faster 20 | profile = False # use pytorch profiler, or just simple benchmarking? 
21 | exec(open('configurator.py').read()) # overrides from command line or config file 22 | # ----------------------------------------------------------------------------- 23 | 24 | torch.manual_seed(seed) 25 | torch.cuda.manual_seed(seed) 26 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 27 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 28 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 29 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 30 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 31 | 32 | # data loading init 33 | if real_data: 34 | dataset = 'openwebtext' 35 | data_dir = os.path.join('data', dataset) 36 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 37 | def get_batch(split): 38 | data = train_data # note ignore split in benchmarking script 39 | ix = torch.randint(len(data) - block_size, (batch_size,)) 40 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 41 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) 42 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 43 | return x, y 44 | else: 45 | # alternatively, if fixed data is desired to not care about data loading 46 | x = torch.randint(50304, (batch_size, block_size), device=device) 47 | y = torch.randint(50304, (batch_size, block_size), device=device) 48 | get_batch = lambda split: (x, y) 49 | 50 | # model init 51 | gptconf = GPTConfig( 52 | block_size = block_size, # how far back does the model look? i.e. context size 53 | n_layer = 12, n_head = 12, n_embd = 768, # size of the model 54 | dropout = 0, # for determinism 55 | bias = bias, 56 | ) 57 | model = GPT(gptconf) 58 | model.to(device) 59 | 60 | optimizer = model.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95), device_type=device_type) 61 | 62 | if compile: 63 | print("Compiling model...") 64 | model = torch.compile(model) # pytorch 2.0 65 | 66 | if profile: 67 | # useful docs on pytorch profiler: 68 | # - tutorial https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html 69 | # - api https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile 70 | wait, warmup, active = 5, 5, 5 71 | num_steps = wait + warmup + active 72 | with torch.profiler.profile( 73 | activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], 74 | schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1), 75 | on_trace_ready=torch.profiler.tensorboard_trace_handler('./bench_log'), 76 | record_shapes=False, 77 | profile_memory=False, 78 | with_stack=False, # incurs an additional overhead, disable if not needed 79 | with_flops=True, 80 | with_modules=False, # only for torchscript models atm 81 | ) as prof: 82 | 83 | X, Y = get_batch('train') 84 | for k in range(num_steps): 85 | with ctx: 86 | logits, loss = model(X, Y) 87 | X, Y = get_batch('train') 88 | optimizer.zero_grad(set_to_none=True) 89 | loss.backward() 90 | optimizer.step() 91 | lossf = loss.item() 92 | print(f"{k}/{num_steps} loss: {lossf:.4f}") 93 | 94 | prof.step() # notify the profiler at end of each step 95 | 96 | else: 97 | 98 | # simple benchmarking 99 | torch.cuda.synchronize() 100 | for stage, num_steps in enumerate([10, 20]): # burnin, then benchmark 101 | t0 = 
time.time() 102 | X, Y = get_batch('train') 103 | for k in range(num_steps): 104 | with ctx: 105 | logits, loss = model(X, Y) 106 | X, Y = get_batch('train') 107 | optimizer.zero_grad(set_to_none=True) 108 | loss.backward() 109 | optimizer.step() 110 | lossf = loss.item() 111 | print(f"{k}/{num_steps} loss: {lossf:.4f}") 112 | torch.cuda.synchronize() 113 | t1 = time.time() 114 | dt = t1-t0 115 | mfu = model.estimate_mfu(batch_size * 1 * num_steps, dt) 116 | if stage == 1: 117 | print(f"time per iteration: {dt/num_steps*1000:.4f}ms, MFU: {mfu*100:.2f}%") 118 | -------------------------------------------------------------------------------- /examples/modded-nanogpt/data/fineweb.py: -------------------------------------------------------------------------------- 1 | """ 2 | FineWeb dataset (for srs pretraining) 3 | https://huggingface.co/datasets/HuggingFaceFW/fineweb 4 | 5 | example doc to highlight the structure of the dataset: 6 | { 7 | "text": "Posted by mattsmith on 20th April 2012\nStraight from...", 8 | "id": "", 9 | "dump": "CC-MAIN-2013-20", 10 | "url": "http://nleastchatter.com/philliesphandom/tag/freddy-galvis/", 11 | "date": "2013-05-18T07:24:47Z", 12 | "file_path": "s3://commoncrawl/long.../path.../file.gz", 13 | "language": "en", 14 | "language_score": 0.9185474514961243, 15 | "token_count": 594 16 | } 17 | """ 18 | import os 19 | import argparse 20 | import multiprocessing as mp 21 | import numpy as np 22 | import tiktoken 23 | # from huggingface_hub import snapshot_download 24 | from datasets import load_dataset 25 | from tqdm import tqdm 26 | import argparse 27 | import numpy as np 28 | def write_datafile(filename, toks): 29 | """ 30 | Saves token data as a .bin file, for reading in C. 31 | - First comes a header with 256 int32s 32 | - The tokens follow, each as a uint16 33 | """ 34 | assert len(toks) < 2**31, "token count too large" # ~2.1B tokens 35 | # construct the header 36 | header = np.zeros(256, dtype=np.int32) 37 | header[0] = 20240520 # magic 38 | header[1] = 1 # version 39 | header[2] = len(toks) # number of tokens after the 256*4 bytes of header (each 2 bytes as uint16) 40 | # construct the tokens numpy array, if not already 41 | if not isinstance(toks, np.ndarray) or not toks.dtype == np.uint16: 42 | # validate that no token exceeds a uint16 43 | maxtok = 2**16 44 | assert all(0 <= t < maxtok for t in toks), "token dictionary too large for uint16" 45 | toks_np = np.array(toks, dtype=np.uint16) 46 | else: 47 | toks_np = toks 48 | # write to file 49 | print(f"writing {len(toks):,} tokens to {filename}") 50 | with open(filename, "wb") as f: 51 | f.write(header.tobytes()) 52 | f.write(toks_np.tobytes()) 53 | # ------------------------------------------ 54 | 55 | parser = argparse.ArgumentParser(description="FineWeb dataset preprocessing") 56 | parser.add_argument("-v", "--version", type=str, default="10B", help="Which version of fineweb to use 10B|100B") 57 | parser.add_argument("-s", "--shard_size", type=int, default=10**8, help="Size of each shard in tokens") 58 | args = parser.parse_args() 59 | 60 | # FineWeb has a few possible subsamples available 61 | assert args.version in ["10B", "100B"], "version must be one of 10B, 100B" 62 | if args.version == "10B": 63 | local_dir = "fineweb10B" 64 | remote_name = "sample-10BT" 65 | elif args.version == "100B": 66 | local_dir = "fineweb100B" 67 | remote_name = "sample-100BT" 68 | 69 | # create the cache the local directory if it doesn't exist yet 70 | DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), local_dir) 71 
| os.makedirs(DATA_CACHE_DIR, exist_ok=True) 72 | 73 | # download the dataset 74 | fw = load_dataset("HuggingFaceFW/fineweb", name=remote_name, split="train") 75 | 76 | # init the tokenizer 77 | enc = tiktoken.get_encoding("gpt2") 78 | eot = enc._special_tokens['<|endoftext|>'] # end of text token 79 | def tokenize(doc): 80 | # tokenizes a single document and returns a numpy array of uint16 tokens 81 | tokens = [eot] # the special <|endoftext|> token delimits all documents 82 | tokens.extend(enc.encode_ordinary(doc["text"])) 83 | tokens_np = np.array(tokens) 84 | assert (0 <= tokens_np).all() and (tokens_np < 2**16).all(), "token dictionary too large for uint16" 85 | tokens_np_uint16 = tokens_np.astype(np.uint16) 86 | return tokens_np_uint16 87 | 88 | # tokenize all documents and write output shards, each of shard_size tokens (last shard has remainder) 89 | nprocs = max(1, os.cpu_count() - 2) # don't hog the entire system 90 | with mp.Pool(nprocs) as pool: 91 | shard_index = 0 92 | # preallocate buffer to hold current shard 93 | all_tokens_np = np.empty((args.shard_size,), dtype=np.uint16) 94 | token_count = 0 95 | progress_bar = None 96 | for tokens in pool.imap(tokenize, fw, chunksize=16): 97 | 98 | # is there enough space in the current shard for the new tokens? 99 | if token_count + len(tokens) < args.shard_size: 100 | # simply append tokens to current shard 101 | all_tokens_np[token_count:token_count+len(tokens)] = tokens 102 | token_count += len(tokens) 103 | # update progress bar 104 | if progress_bar is None: 105 | progress_bar = tqdm(total=args.shard_size, unit="tokens", desc=f"Shard {shard_index}") 106 | progress_bar.update(len(tokens)) 107 | else: 108 | # write the current shard and start a new one 109 | split = "val" if shard_index == 0 else "train" 110 | filename = os.path.join(DATA_CACHE_DIR, f"fineweb_{split}_{shard_index:06d}.bin") 111 | # split the document into whatever fits in this shard; the remainder goes to next one 112 | remainder = args.shard_size - token_count 113 | progress_bar.update(remainder) 114 | all_tokens_np[token_count:token_count+remainder] = tokens[:remainder] 115 | write_datafile(filename, all_tokens_np) 116 | shard_index += 1 117 | progress_bar = None 118 | # populate the next shard with the leftovers of the current doc 119 | all_tokens_np[0:len(tokens)-remainder] = tokens[remainder:] 120 | token_count = len(tokens)-remainder 121 | 122 | # write any remaining tokens as the last shard 123 | if token_count != 0: 124 | split = "val" if shard_index == 0 else "train" 125 | filename = os.path.join(DATA_CACHE_DIR, f"fineweb_{split}_{shard_index:06d}.bin") 126 | write_datafile(filename, all_tokens_np[:token_count]) 127 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/csv_logging.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authored by Gavia Gray (https://github.com/gngdb) 3 | 4 | Wrapper for wandb logging with efficient CSV logging and correct config JSON writing. 5 | The CSV structure maintains a consistent order of keys based on their first appearance, 6 | using a simple list for ordering. This ensures data integrity and allows for graceful 7 | failure and manual recovery if needed. 8 | 9 | Example usage: 10 | run = wandb.init(config=your_config) 11 | wrapper = LogWrapper(run, out_dir='path/to/output') 12 | 13 | ... 
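# (note: the wrapper class defined below is CSVLogWrapper(logf=None, config={}, out_dir=None);
#  pass a logging callable such as run.log as logf rather than the run object itself)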
14 | # in train loop 15 | wrapper.log({"train/loss": 0.5, "train/accuracy": 0.9, "val/loss": 0.6, "val/accuracy": 0.85}) 16 | wrapper.print("Train: {loss=:.4f}, {accuracy=:.2%}", prefix="train/") 17 | wrapper.print("Val: {loss=:.4f}, {accuracy=:.2%}", prefix="val/") 18 | wrapper.step() 19 | 20 | ... 21 | # at the end of your script 22 | wrapper.close() 23 | 24 | # If the script terminates unexpectedly, you can still recover the CSV using bash: 25 | # cat path/to/output/log_header.csv.tmp path/to/output/log_data.csv.tmp > path/to/output/log.csv 26 | """ 27 | 28 | import re 29 | import os 30 | import csv 31 | import json 32 | import atexit 33 | 34 | 35 | def exists(x): return x is not None 36 | 37 | def transform_format_string(s): 38 | """ 39 | Transforms a string containing f-string-like expressions to a format 40 | compatible with str.format(). 41 | 42 | This function converts expressions like '{var=}' or '{var=:formatting}' 43 | to 'var={var}' or 'var={var:formatting}' respectively. This allows 44 | for f-string-like syntax to be used with str.format(). 45 | 46 | Args: 47 | s (str): The input string containing f-string-like expressions. 48 | 49 | Returns: 50 | str: The transformed string, compatible with str.format(). 51 | 52 | Examples: 53 | >>> transform_format_string("Value is {x=}") 54 | "Value is x={x}" 55 | >>> transform_format_string("Formatted value is {x=:.2f}") 56 | "Formatted value is x={x:.2f}" 57 | """ 58 | pattern = r'\{(\w+)=(:.[^}]*)?\}' 59 | return re.sub(pattern, lambda m: f"{m.group(1)}={{{m.group(1)}{m.group(2) or ''}}}", s) 60 | 61 | class CSVLogWrapper: 62 | def __init__(self, logf=None, config={}, out_dir=None): 63 | self.logf = logf 64 | self.config = config 65 | self.log_dict = {} 66 | self.out_dir = out_dir 67 | self.csv_data_file = None 68 | self.csv_header_file = None 69 | self.csv_writer = None 70 | self.step_count = 0 71 | self.ordered_keys = [] 72 | self.header_updated = False 73 | self.is_finalized = False 74 | self.no_sync_keyword = 'no_sync' # Keyword to prevent syncing to wandb 75 | 76 | if self.out_dir: 77 | os.makedirs(self.out_dir, exist_ok=True) 78 | self.setup_csv_writer() 79 | self.write_config() 80 | 81 | atexit.register(self.close) 82 | 83 | def setup_csv_writer(self): 84 | self.csv_data_path = os.path.join(self.out_dir, 'log_data.csv.tmp') 85 | self.csv_header_path = os.path.join(self.out_dir, 'log_header.csv.tmp') 86 | self.csv_data_file = open(self.csv_data_path, 'w', newline='') 87 | self.csv_header_file = open(self.csv_header_path, 'w', newline='') 88 | self.csv_writer = csv.writer(self.csv_data_file) 89 | 90 | def write_config(self): 91 | if self.config: 92 | config_path = os.path.join(self.out_dir, 'config.json') 93 | with open(config_path, 'w') as f: 94 | json.dump(dict(**self.config), f, indent=2) 95 | 96 | def log(self, data): 97 | self.log_dict.update(data) 98 | for key in data: 99 | if key not in self.ordered_keys: 100 | self.ordered_keys.append(key) 101 | self.header_updated = True 102 | 103 | def update_header(self): 104 | if self.header_updated: 105 | header = ['step'] + self.ordered_keys 106 | with open(self.csv_header_path, 'w', newline='') as header_file: 107 | csv.writer(header_file).writerow(header) 108 | self.header_updated = False 109 | 110 | def print(self, format_string, prefix=None): 111 | format_string = transform_format_string(format_string) 112 | 113 | if prefix: 114 | # Filter keys with the given prefix and remove the prefix 115 | filtered_dict = {k.replace(prefix, ''): v for k, v in self.log_dict.items() if 
k.startswith(prefix)} 116 | else: 117 | filtered_dict = self.log_dict 118 | # replace any '/' in keys with '_' 119 | filtered_dict = {k.replace('/', '_'): v for k, v in filtered_dict.items()} 120 | 121 | try: 122 | print(format_string.format(**filtered_dict)) 123 | except KeyError as e: 124 | print(f"KeyError: {e}. Available keys: {', '.join(filtered_dict.keys())}") 125 | raise e 126 | 127 | def step(self): 128 | if exists(self.logf) and self.log_dict: 129 | self.logf({k: v for k, v in self.log_dict.items() if self.no_sync_keyword not in k}) 130 | 131 | if self.csv_writer and self.log_dict: 132 | self.update_header() 133 | 134 | # Prepare the row data 135 | row_data = [self.step_count] + [self.log_dict.get(key, '') for key in self.ordered_keys] 136 | self.csv_writer.writerow(row_data) 137 | self.csv_data_file.flush() # Ensure data is written to file 138 | 139 | self.step_count += 1 140 | self.log_dict.clear() 141 | 142 | def close(self): 143 | if self.csv_data_file: 144 | self.csv_data_file.close() 145 | 146 | self.finalize_csv() 147 | 148 | def finalize_csv(self): 149 | if self.is_finalized: 150 | return 151 | 152 | csv_final_path = os.path.join(self.out_dir, 'log.csv') 153 | 154 | with open(csv_final_path, 'w', newline='') as final_csv: 155 | # Copy header 156 | with open(self.csv_header_path, 'r') as header_file: 157 | final_csv.write(header_file.read()) 158 | 159 | # Copy data 160 | with open(self.csv_data_path, 'r') as data_file: 161 | final_csv.write(data_file.read()) 162 | self.is_finalized = True 163 | 164 | # Remove the temporary files 165 | os.remove(self.csv_header_path) 166 | os.remove(self.csv_data_path) 167 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/mup_examples/mutransfer_lr_shakespeare_char/plot.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import pandas as pd\n", 11 | "import numpy as np\n", 12 | "from tqdm import tqdm\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "import matplotlib as mpl\n", 15 | "from matplotlib import cm\n", 16 | "import seaborn as sns\n", 17 | "#sns.set(style='whitegrid')\n", 18 | "sns.set(style='whitegrid')\n", 19 | "\n", 20 | "PARAMETERIZATIONS = [[\n", 21 | " ('uscion', r'Scion (Sign $\\rightarrow$ Spectral $\\rightarrow$ Sign)', 'Sign'),\n", 22 | " ('uscion', r'Scion (Spectral $\\rightarrow$ Spectral $\\rightarrow$ Sign)', 'Spectral'),\n", 23 | " ('uscion', r'Scion (ColNorm $\\rightarrow$ Spectral $\\rightarrow$ Sign)', 'ColNorm'),\n", 24 | "],[\n", 25 | " ('sp', r'SP (AdamW)', None),\n", 26 | " ('scion', r'Unconstrained Scion (Sign $\\rightarrow$ Spectral $\\rightarrow$ Sign)', 'Sign'),\n", 27 | " ('scion', r'Unconstrained Scion (Spectral $\\rightarrow$ Spectral $\\rightarrow$ Sign)', 'Spectral'),\n", 28 | " ('scion', r'Unconstrained Scion (ColNorm $\\rightarrow$ Spectral $\\rightarrow$ Sign)', 'ColNorm'),\n", 29 | "],[\n", 30 | " ('scion_full', r'Signum', 'Sign-naive'),\n", 31 | " ('scion_full', r'Unconstrained Scion (Sign throughout)', 'Sign'),\n", 32 | "],[\n", 33 | " ('scion_full', r'Unconstrained Scion (RowNorm throughout)', 'RowNorm'),\n", 34 | " ('scion_full', r'Unconstrained Scion (ColNorm throughout)', 'ColNorm'),\n", 35 | "]]\n", 36 | "for j, parameterizations in enumerate(PARAMETERIZATIONS):\n", 37 | " seeds = [1,2,3]\n", 38 | " widths = [\n", 39 | " 256,\n", 40 | " 
512,\n", 41 | " 1024,\n", 42 | " 2048,\n", 43 | " ]\n", 44 | " lrs = [\n", 45 | " 0.5,\n", 46 | " 0.25,\n", 47 | " 0.125,\n", 48 | " 0.0625,\n", 49 | " 0.03125,\n", 50 | " 0.015625,\n", 51 | " 0.0078125,\n", 52 | " 0.00390625,\n", 53 | " 0.001953125,\n", 54 | " 0.0009765625,\n", 55 | " 0.00048828125,\n", 56 | " 0.000244140625,\n", 57 | " 0.0001220703125,\n", 58 | " 0.00006103515625,\n", 59 | " 0.00003051757812,\n", 60 | " 0.00001525878906,\n", 61 | " 0.000007629394531,\n", 62 | " 0.000003814697266,\n", 63 | " ]\n", 64 | " class MplColorHelper:\n", 65 | "\n", 66 | " def __init__(self, cmap_name, start_val, stop_val):\n", 67 | " self.cmap_name = cmap_name\n", 68 | " self.cmap = plt.get_cmap(cmap_name)\n", 69 | " self.norm = mpl.colors.Normalize(vmin=start_val, vmax=stop_val)\n", 70 | " self.scalarMap = cm.ScalarMappable(norm=self.norm, cmap=self.cmap)\n", 71 | "\n", 72 | " def get_rgb(self, val):\n", 73 | " return self.scalarMap.to_rgba(val)\n", 74 | "\n", 75 | "\n", 76 | " color_helper = MplColorHelper('viridis', 0, len(widths)-1)\n", 77 | " n_cols = len(parameterizations)\n", 78 | " n_rows = 1\n", 79 | " fig, axes = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 3.33*n_rows))\n", 80 | " plt.subplots_adjust(wspace=5.0) # Adjust the width space between the axes\n", 81 | "\n", 82 | " for parameterization_idx, (parameterization, parameterization_str, extra) in enumerate(parameterizations):\n", 83 | " ax = axes[parameterization_idx]\n", 84 | " optimal_lrs = []\n", 85 | " optimal_losses = []\n", 86 | " for width_idx, width in enumerate(widths):\n", 87 | " mean_losses = []\n", 88 | " sem_losses = []\n", 89 | " lrs_to_plot = []\n", 90 | " for lr in lrs:\n", 91 | " losses = []\n", 92 | " for seed in seeds:\n", 93 | " if parameterization == 'scion_full' and extra is not None:\n", 94 | " job_name = f'mode{extra}_width{width}_depth2_seed{seed}_lr{lr:.20f}'.rstrip('0') \n", 95 | " elif extra is not None:\n", 96 | " job_name = f'first_layer{extra}_width{width}_depth2_seed{seed}_lr{lr:.20f}'.rstrip('0')\n", 97 | " else:\n", 98 | " job_name = f'width{width}_depth2_seed{seed}_lr{lr:.20f}'.rstrip('0')\n", 99 | " csv_path = os.path.join(parameterization, 'out', job_name, 'log.csv')\n", 100 | " if os.path.exists(csv_path):\n", 101 | " ckpt_df = pd.read_csv(csv_path)\n", 102 | " #losses.append(ckpt_df['train/loss'].mean())\n", 103 | " #losses.append(ckpt_df['train/loss'].min())\n", 104 | " losses.append(ckpt_df['val/loss'].values[-1])\n", 105 | " #losses.append(ckpt_df['train/loss'].ewm(alpha=0.9).mean().values[-1])\n", 106 | " # else:\n", 107 | " # print(f'Missing {csv_path}')\n", 108 | " if len(losses):\n", 109 | " mean_losses.append(np.mean(losses))\n", 110 | " sem_losses.append(np.std(losses, ddof=1) / np.sqrt(len(losses)))\n", 111 | " lrs_to_plot.append(lr)\n", 112 | " \n", 113 | " mean_losses = np.array(mean_losses)\n", 114 | " sem_losses = np.array(sem_losses)\n", 115 | " #ax.plot(lrs_to_plot, mean_losses, label=width, marker='o', color=color_helper.get_rgb(width_idx))\n", 116 | " #ax.fill_between(lrs_to_plot, mean_losses-sem_losses, mean_losses+sem_losses, color=color_helper.get_rgb(width_idx), alpha=0.33)\n", 117 | " palette = sns.color_palette(\"mako\", n_colors=len(widths))\n", 118 | " ax.plot(lrs_to_plot, mean_losses, label=width, color=palette[len(widths)-width_idx-1])\n", 119 | " ax.fill_between(lrs_to_plot, mean_losses-sem_losses, mean_losses+sem_losses, color=palette[len(widths)-width_idx-1], alpha=0.33)\n", 120 | "\n", 121 | " if len(mean_losses):\n", 122 | " optimum_idx = 
np.argmin(mean_losses)\n", 123 | " optimal_lrs.append(lrs_to_plot[optimum_idx])\n", 124 | " optimal_losses.append(mean_losses[optimum_idx])\n", 125 | " \n", 126 | " ax.plot(optimal_lrs, optimal_losses, color='red', linestyle='none', marker='o')\n", 127 | " ax.set_xscale('log', base=2)\n", 128 | " ax.set_xlabel('Learning rate')\n", 129 | " ax.set_title(parameterization_str)\n", 130 | " #ax.set_ylim(2.57, 3.15)\n", 131 | " #ax.set_ylim(2.0, 3.0)\n", 132 | " if j >= 2:\n", 133 | " ax.set_ylim(1.75, 3.0)\n", 134 | " else: \n", 135 | " ax.set_ylim(1.3, 3.0)\n", 136 | " #ax.set_ylim(2.3, 8.0)\n", 137 | " # ax.set_ylim(2.3, 2.7)\n", 138 | " # ax.set_ylim(2.4, 2.8)\n", 139 | "\n", 140 | " axes[0].legend(title='Width')\n", 141 | " # axes[0].set_ylabel('Train loss on\\nshakespeare_char')\n", 142 | " axes[0].set_ylabel('Validation loss')\n", 143 | " for i in range(len(axes))[1:]:\n", 144 | " axes[i].yaxis.set_ticklabels([])\n", 145 | " axes[i].tick_params(axis='y', length=0, width=0)\n", 146 | "\n", 147 | " plt.tight_layout()\n", 148 | " plt.savefig(f\"GPT_shakespeare_transfer_{j}.pdf\")\n", 149 | " plt.show()\n", 150 | " plt.close()\n" 151 | ] 152 | } 153 | ], 154 | "metadata": { 155 | "kernelspec": { 156 | "display_name": "base", 157 | "language": "python", 158 | "name": "python3" 159 | }, 160 | "language_info": { 161 | "codemirror_mode": { 162 | "name": "ipython", 163 | "version": 3 164 | }, 165 | "file_extension": ".py", 166 | "mimetype": "text/x-python", 167 | "name": "python", 168 | "nbconvert_exporter": "python", 169 | "pygments_lexer": "ipython3", 170 | "version": "3.9.15" 171 | } 172 | }, 173 | "nbformat": 4, 174 | "nbformat_minor": 2 175 | } 176 | -------------------------------------------------------------------------------- /examples/airbench/scion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | ####################################################### 6 | # Scion 7 | ####################################################### 8 | 9 | 10 | class Norm(object): 11 | def lmo(self, g): 12 | raise NotImplementedError 13 | 14 | def init(self, w): 15 | raise NotImplementedError 16 | 17 | 18 | class ColNorm(Norm): 19 | """ 20 | Column-wise normalization. 21 | 22 | Args: 23 | normalized (bool, optional): If True, normalizes by the input dimension. Use True only for non-input layers. 24 | transpose (bool, optional): If True, transposes input before normalization. Use True for embedding layers 25 | which store weights as (vocab_size, embedding_dim). 26 | """ 27 | def __init__(self, normalized=False, transpose=False): 28 | self.normalized = normalized 29 | self.transpose = transpose 30 | 31 | def lmo(self, g): 32 | eps = 1e-8 33 | if self.transpose: 34 | g = g.transpose(0, 1) 35 | rms_values = 1/math.sqrt(g.size(0))*torch.sqrt(torch.sum(g ** 2, dim=0, keepdim=True)) 36 | if self.normalized: 37 | rms_values *= g.size(1) 38 | g = g / (rms_values + eps) 39 | if self.transpose: 40 | g = g.transpose(0, 1) 41 | return g 42 | 43 | def init(self, w): 44 | dtype = w.data.dtype 45 | if self.transpose: 46 | w.data = w.data.transpose(0, 1) 47 | torch.nn.init.normal_(w.data) 48 | w.data /= w.norm(dim=0, keepdim=True) 49 | w.data *= math.sqrt(w.size(0)) 50 | if self.normalized: 51 | w.data /= w.size(1) 52 | w.data = w.data.to(dtype=dtype) 53 | if self.transpose: 54 | w.data = w.data.transpose(0, 1) 55 | return w 56 | 57 | 58 | class RowNorm(Norm): 59 | """ 60 | Row-wise normalization. 
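    (With transpose=False, lmo rescales each row of the gradient to unit l2 norm,
    or to 1/sqrt(fan_in) when normalized=True; init draws Gaussian rows with the
    same row-wise scale.)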
61 | 62 | Args: 63 | normalized (bool, optional): If True, normalizes by the input dimension. Use False only for the input layer. 64 | transpose (bool, optional): If True, transposes input before normalization. Use True for embedding layers 65 | which store weights as (vocab_size, embedding_dim). 66 | """ 67 | def __init__(self, normalized=True, transpose=False): 68 | self.normalized = normalized 69 | self.transpose = transpose 70 | 71 | def lmo(self, g): 72 | eps = 1e-8 73 | if self.transpose: 74 | g = g.transpose(0, 1) 75 | rms_values = torch.sqrt(torch.sum(g ** 2, dim=-1, keepdim=True)) 76 | if self.normalized: 77 | rms_values *= math.sqrt(g.size(-1)) 78 | g = g / (rms_values + eps) 79 | if self.transpose: 80 | g = g.transpose(0, 1) 81 | return g 82 | 83 | def init(self, w): 84 | dtype = w.data.dtype 85 | if self.transpose: 86 | w.data = w.data.transpose(0, 1) 87 | torch.nn.init.normal_(w.data) 88 | w.data /= w.norm(dim=-1, keepdim=True) 89 | if self.normalized: 90 | w.data /= math.sqrt(w.size(-1)) 91 | w.data = w.data.to(dtype=dtype) 92 | if self.transpose: 93 | w.data = w.data.transpose(0, 1) 94 | return w 95 | 96 | 97 | class BiasRMS(Norm): 98 | def lmo(self, g): 99 | eps = 1e-8 100 | rms_values = torch.sqrt(torch.mean(g ** 2, dim=0, keepdim=True)) 101 | g = g / (rms_values + eps) 102 | return g 103 | 104 | def init(self, g): 105 | return torch.nn.init.zeros_(g) 106 | 107 | class SpectralConv(Norm): 108 | def __init__(self, steps=5): 109 | self.steps = steps 110 | 111 | def lmo(self, g): 112 | g = zeropower_via_newtonschulz5(g.reshape(len(g), -1), steps=self.steps).view(g.shape) 113 | if g.ndim == 3: 114 | out_channels, in_channels, k = g.shape 115 | g *= (out_channels / in_channels)**0.5 / k 116 | elif g.ndim == 4: 117 | out_channels, in_channels, k, _ = g.shape 118 | g *= (out_channels / in_channels)**0.5 / (k ** 2) 119 | return g 120 | 121 | def init(self, w): 122 | w_fp = w.data.double() 123 | k = w.data.size(2) 124 | for kx in range(k): 125 | for ky in range(k): 126 | torch.nn.init.orthogonal_(w_fp[:,:,kx,ky]) 127 | 128 | if w.ndim == 3: 129 | out_channels, in_channels, k = w_fp.shape 130 | w_fp.mul_((out_channels / in_channels)**0.5 / k) 131 | elif w.ndim == 4: 132 | out_channels, in_channels, k, _ = w_fp.shape 133 | w_fp.mul_((out_channels / in_channels)**0.5 / (k ** 2)) 134 | w.data = w_fp.to(dtype=w.data.dtype) 135 | return w 136 | 137 | 138 | class Spectral(Norm): 139 | def __init__(self, max=False, normalized=True, steps=5): 140 | self.max = max 141 | self.steps = steps 142 | self.normalized = normalized 143 | 144 | def lmo(self, g): 145 | g = zeropower_via_newtonschulz5(g.reshape(len(g), -1), steps=self.steps).view(g.shape) 146 | d_out, d_in = g.shape 147 | 148 | if self.normalized: 149 | scale = (d_out / d_in)**0.5 150 | else: 151 | scale = d_out**0.5 152 | if self.max: 153 | scale = max(1,scale) 154 | g *= scale 155 | 156 | return g 157 | 158 | def init(self, w): 159 | w_fp = w.data.double() 160 | torch.nn.init.orthogonal_(w_fp) 161 | d_out, d_in = w_fp.shape 162 | 163 | if self.normalized: 164 | scale = (d_out / d_in)**0.5 165 | else: 166 | scale = d_out**0.5 167 | if self.max: 168 | scale = max(1,scale) 169 | w_fp.mul_(scale) 170 | 171 | w.data = w_fp.to(dtype=w.data.dtype) 172 | return w 173 | 174 | 175 | class Sign(Norm): 176 | def __init__(self, zero_init=False, normalized=True): 177 | self.zero_init = zero_init 178 | self.normalized = normalized 179 | 180 | def lmo(self, g): 181 | d_out, d_in = g.shape 182 | if self.normalized: 183 | return (1/d_in)*torch.sign(g) 
184 | else: 185 | return torch.sign(g) 186 | 187 | def init(self, w): 188 | if self.zero_init: 189 | torch.nn.init.zeros_(w) 190 | else: 191 | # Generate -1/fan_in or 1/fan_in uniformly at random 192 | d_out, d_in = w.shape 193 | w.data = (torch.randint(0, 2, w.shape, dtype=w.dtype, device=w.device) * 2 - 1) 194 | if self.normalized: 195 | w.data *= (1/d_in) 196 | return w 197 | 198 | 199 | class Auto(Norm): 200 | def lmo(self, g): 201 | if g.ndim in [3,4]: 202 | return SpectralConv().lmo(g) 203 | elif g.ndim == 2: 204 | return Spectral().lmo(g) 205 | elif g.ndim in [0,1]: 206 | return BiasRMS().lmo(g) 207 | 208 | def init(self, w): 209 | if w.ndim in [3,4]: 210 | return SpectralConv().init(w) 211 | elif w.ndim == 2: 212 | return Spectral().init(w) 213 | elif w.ndim in [0,1]: 214 | return BiasRMS().init(w) 215 | 216 | 217 | norm_dict = { 218 | 'ColNorm': ColNorm, 219 | 'RowNorm': RowNorm, 220 | 'BiasRMS': BiasRMS, 221 | 'SpectralConv': SpectralConv, 222 | 'Spectral': Spectral, 223 | 'Sign': Sign, 224 | 'Auto': Auto, 225 | } 226 | 227 | 228 | class Scion(torch.optim.Optimizer): 229 | """Scion optimizer implementation. 230 | 231 | Args: 232 | params: Iterable of parameters to optimize or dicts defining parameter groups 233 | lr (float, optional): Learning rate (default: 1e-3) 234 | momentum (float, optional): One minus the traditional momentum factor. For example, 235 | a traditional momentum of 0.9 would be specified as momentum=0.1 here (default: 1.0) 236 | norm (str, optional): Choice of norm for gradient projection ('Auto', 'SpectralConv', 237 | 'ColNorm', 'RowNorm', 'BiasRMS', 'Spectral', or 'Sign') (default: 'Auto') 238 | norm_kwargs (dict, optional): Additional arguments for the norm projection (default: None) 239 | scale (float, optional): Scale factor for updates (default: 1.0) 240 | unconstrained (bool, optional): Whether to use unconstrained updates (default: False) 241 | 242 | Example: 243 | >>> radius = 50.0 244 | >>> optim_groups = [{ 245 | ... 'params': model.transformer.h.parameters(), 246 | ... 'norm': 'Spectral', 247 | ... 'norm_kwargs': {}, 248 | ... 'scale': radius, 249 | ... }, { 250 | ... 'params': model.lm_head.parameters(), 251 | ... 'norm': 'Sign', 252 | ... 'norm_kwargs': {}, 253 | ... 'scale': radius*60.0, 254 | ... 
}] 255 | >>> optimizer = Scion(optim_groups, lr=2**-12, momentum=0.1) 256 | """ 257 | def __init__(self, params, lr=1e-3, momentum=1.0, norm: str='Auto', norm_kwargs: dict=None, scale=1.0, unconstrained=False): 258 | if lr < 0.0: 259 | raise ValueError(f"Invalid learning rate: {lr}") 260 | if momentum < 0.0: 261 | raise ValueError(f"Invalid momentum value: {momentum}") 262 | if norm_kwargs is None: 263 | norm_kwargs = {} 264 | defaults = dict(lr=lr, momentum=momentum, scale=scale, unconstrained=unconstrained, norm=norm, norm_kwargs=norm_kwargs) 265 | super().__init__(params, defaults) 266 | 267 | def step(self): 268 | for group in self.param_groups: 269 | lr = group['lr'] 270 | momentum = group['momentum'] 271 | scale = group['scale'] 272 | unconstrained = group['unconstrained'] 273 | norm_backend = norm_dict[group['norm']](**group['norm_kwargs']) 274 | for p in group['params']: 275 | g = p.grad 276 | if g is None: 277 | continue 278 | state = self.state[p] 279 | 280 | if momentum != 1: 281 | if 'momentum_buffer' not in state.keys(): 282 | state['momentum_buffer'] = torch.zeros_like(g).add_(g) 283 | buf = state['momentum_buffer'] 284 | buf.mul_(1-momentum).add_(g, alpha=momentum) 285 | g = buf 286 | 287 | update = scale * norm_backend.lmo(g) 288 | if not unconstrained: 289 | p.data.mul_(1-lr) 290 | p.data.add_(update, alpha=-lr) 291 | 292 | def init(self): 293 | for group in self.param_groups: 294 | norm_backend = norm_dict[group['norm']](**group['norm_kwargs']) 295 | init_func = norm_backend.init 296 | scale = group['scale'] 297 | for p in group['params']: 298 | init_func(p) 299 | p.data *= scale 300 | 301 | 302 | @torch.compile 303 | def zeropower_via_newtonschulz5(G, steps=9): 304 | """ 305 | For fair comparison we use the same implementation as the Muon baseline. 306 | 307 | Computing zeroth matrix powers via Lakic 1998. 308 | paper: "On the Computation of the Matrix k-th Root" 309 | Suppose we have a matrix G = USV^T and we want to compute G^0 defined via G^0 = UV^T. 310 | We might want to do this to run "stochastic spectral descent" of Carlson et al 2015. 311 | The naive way to do this is via the SVD. But we can also just do (GG^T)^(-1/2) G or 312 | alternatively G (G^TG)^(-1/2) and apply the iterative method from Lakic 1998. 313 | In particular, we implement the first special case of Alg 1 in that paper. 314 | 315 | Code taken from: https://gist.github.com/jxbz/fe235ee1c72b8b41ccd0d02b43378cf2 316 | https://x.com/jxbz/status/1821610280708948103 317 | Modifications: To speed things up, I am running this in bfloat16 using torch.compile. 318 | """ 319 | 320 | orig_dtype = G.dtype 321 | G = G.bfloat16() 322 | 323 | d1, d2 = G.shape 324 | d = min(d1, d2) 325 | I = torch.eye(d, device=G.device, dtype=G.dtype) 326 | 327 | # store the smaller of the squares as S 328 | S = G @ G.T if d1 < d2 else G.T @ G 329 | S_norm = torch.linalg.matrix_norm(S, ord='fro') # there is freedom here. 
See Lakic (1998) Thm 2.3 330 | S /= S_norm 331 | 332 | # Now let's set up the state for the Lakic (1998) method 333 | N = S 334 | X = I.clone() 335 | 336 | # Now let's run the iteration 337 | for step in range(steps): 338 | U = (3 * I - N) / 2 339 | X = X @ U if step > 0 else U # optimization since X = I on step 0 340 | if step < steps-1: # optimization suggested by @EitanTurok https://x.com/EitanTurok/status/1839754807696855333 341 | N = N @ U @ U 342 | X /= S_norm.sqrt() 343 | 344 | # X should now store either (G G^T)^(-1/2) or (G^T G)^(-1/2) 345 | O = X @ G if d1 < d2 else G @ X 346 | return O.to(orig_dtype) 347 | -------------------------------------------------------------------------------- /examples/deit/scion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | ####################################################### 6 | # Scion 7 | ####################################################### 8 | 9 | 10 | class Norm(object): 11 | def lmo(self, g): 12 | raise NotImplementedError 13 | 14 | def init(self, w): 15 | raise NotImplementedError 16 | 17 | 18 | class ColNorm(Norm): 19 | """ 20 | Column-wise normalization. 21 | 22 | Args: 23 | normalized (bool, optional): If True, normalizes by the input dimension. Use True only for non-input layers. 24 | transpose (bool, optional): If True, transposes input before normalization. Use True for embedding layers 25 | which store weights as (vocab_size, embedding_dim). 26 | """ 27 | def __init__(self, normalized=False, transpose=False): 28 | self.normalized = normalized 29 | self.transpose = transpose 30 | 31 | def lmo(self, g): 32 | eps = 1e-8 33 | if self.transpose: 34 | g = g.transpose(0, 1) 35 | rms_values = 1/math.sqrt(g.size(0))*torch.sqrt(torch.sum(g ** 2, dim=0, keepdim=True)) 36 | if self.normalized: 37 | rms_values *= g.size(1) 38 | g = g / (rms_values + eps) 39 | if self.transpose: 40 | g = g.transpose(0, 1) 41 | return g 42 | 43 | def init(self, w): 44 | dtype = w.data.dtype 45 | if self.transpose: 46 | w.data = w.data.transpose(0, 1) 47 | torch.nn.init.normal_(w.data) 48 | w.data /= w.norm(dim=0, keepdim=True) 49 | w.data *= math.sqrt(w.size(0)) 50 | if self.normalized: 51 | w.data /= w.size(1) 52 | w.data = w.data.to(dtype=dtype) 53 | if self.transpose: 54 | w.data = w.data.transpose(0, 1) 55 | return w 56 | 57 | 58 | class RowNorm(Norm): 59 | """ 60 | Row-wise normalization. 61 | 62 | Args: 63 | normalized (bool, optional): If True, normalizes by the input dimension. Use False only for the input layer. 64 | transpose (bool, optional): If True, transposes input before normalization. Use True for embedding layers 65 | which store weights as (vocab_size, embedding_dim). 
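    Example (illustrative sketch of the row scaling; assumes a 2-D gradient):
        >>> g = torch.randn(512, 256)             # gradient of a (d_out, d_in) weight
        >>> u = RowNorm(normalized=True).lmo(g)
        >>> u.norm(dim=-1)                        # every row has l2 norm ~ 1/sqrt(256)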
66 | """ 67 | def __init__(self, normalized=True, transpose=False): 68 | self.normalized = normalized 69 | self.transpose = transpose 70 | 71 | def lmo(self, g): 72 | eps = 1e-8 73 | if self.transpose: 74 | g = g.transpose(0, 1) 75 | rms_values = torch.sqrt(torch.sum(g ** 2, dim=-1, keepdim=True)) 76 | if self.normalized: 77 | rms_values *= math.sqrt(g.size(-1)) 78 | g = g / (rms_values + eps) 79 | if self.transpose: 80 | g = g.transpose(0, 1) 81 | return g 82 | 83 | def init(self, w): 84 | dtype = w.data.dtype 85 | if self.transpose: 86 | w.data = w.data.transpose(0, 1) 87 | torch.nn.init.normal_(w.data) 88 | w.data /= w.norm(dim=-1, keepdim=True) 89 | if self.normalized: 90 | w.data /= math.sqrt(w.size(-1)) 91 | w.data = w.data.to(dtype=dtype) 92 | if self.transpose: 93 | w.data = w.data.transpose(0, 1) 94 | return w 95 | 96 | 97 | class BiasRMS(Norm): 98 | def lmo(self, g): 99 | eps = 1e-8 100 | rms_values = torch.sqrt(torch.mean(g ** 2, dim=0, keepdim=True)) 101 | g = g / (rms_values + eps) 102 | return g 103 | 104 | def init(self, g): 105 | return torch.nn.init.zeros_(g) 106 | 107 | 108 | class SpectralConv(Norm): 109 | def __init__(self, steps=5): 110 | self.steps = steps 111 | 112 | def lmo(self, g): 113 | g = zeropower_via_newtonschulz5(g.reshape(len(g), -1), steps=self.steps).view(g.shape) 114 | if g.ndim == 3: 115 | out_channels, in_channels, k = g.shape 116 | g *= (out_channels / in_channels)**0.5 / k 117 | elif g.ndim == 4: 118 | out_channels, in_channels, k, _ = g.shape 119 | g *= (out_channels / in_channels)**0.5 / (k ** 2) 120 | return g 121 | 122 | def init(self, w): 123 | w_fp = w.data.double() 124 | k = w.data.size(2) 125 | for kx in range(k): 126 | for ky in range(k): 127 | torch.nn.init.orthogonal_(w_fp[:,:,kx,ky]) 128 | 129 | if w.ndim == 3: 130 | out_channels, in_channels, k = w_fp.shape 131 | w_fp.mul_((out_channels / in_channels)**0.5 / k) 132 | elif w.ndim == 4: 133 | out_channels, in_channels, k, _ = w_fp.shape 134 | w_fp.mul_((out_channels / in_channels)**0.5 / (k ** 2)) 135 | w.data = w_fp.to(dtype=w.data.dtype) 136 | return w 137 | 138 | 139 | class Spectral(Norm): 140 | def __init__(self, max=False, normalized=True, steps=5): 141 | self.max = max 142 | self.steps = steps 143 | self.normalized = normalized 144 | 145 | def lmo(self, g): 146 | g = zeropower_via_newtonschulz5(g.reshape(len(g), -1), steps=self.steps).view(g.shape) 147 | d_out, d_in = g.shape 148 | 149 | if self.normalized: 150 | scale = (d_out / d_in)**0.5 151 | else: 152 | scale = d_out**0.5 153 | if self.max: 154 | scale = max(1,scale) 155 | g *= scale 156 | 157 | return g 158 | 159 | def init(self, w): 160 | w_fp = w.data.double() 161 | torch.nn.init.orthogonal_(w_fp) 162 | d_out, d_in = w_fp.shape 163 | 164 | if self.normalized: 165 | scale = (d_out / d_in)**0.5 166 | else: 167 | scale = d_out**0.5 168 | if self.max: 169 | scale = max(1,scale) 170 | w_fp.mul_(scale) 171 | 172 | w.data = w_fp.to(dtype=w.data.dtype) 173 | return w 174 | 175 | 176 | class Sign(Norm): 177 | def __init__(self, zero_init=False, normalized=True): 178 | self.zero_init = zero_init 179 | self.normalized = normalized 180 | 181 | def lmo(self, g): 182 | d_out, d_in = g.shape 183 | if self.normalized: 184 | return (1/d_in)*torch.sign(g) 185 | else: 186 | return torch.sign(g) 187 | 188 | def init(self, w): 189 | if self.zero_init: 190 | torch.nn.init.zeros_(w) 191 | else: 192 | # Generate -1/fan_in or 1/fan_in uniformly at random 193 | d_out, d_in = w.shape 194 | w.data = (torch.randint(0, 2, w.shape, dtype=w.dtype, 
device=w.device) * 2 - 1) 195 | if self.normalized: 196 | w.data *= (1/d_in) 197 | return w 198 | 199 | 200 | class Auto(Norm): 201 | def lmo(self, g): 202 | if g.ndim in [3,4]: 203 | return SpectralConv().lmo(g) 204 | elif g.ndim == 2: 205 | return Spectral().lmo(g) 206 | elif g.ndim in [0,1]: 207 | return BiasRMS().lmo(g) 208 | 209 | def init(self, w): 210 | if w.ndim in [3,4]: 211 | return SpectralConv().init(w) 212 | elif w.ndim == 2: 213 | return Spectral().init(w) 214 | elif w.ndim in [0,1]: 215 | return BiasRMS().init(w) 216 | 217 | 218 | norm_dict = { 219 | 'ColNorm': ColNorm, 220 | 'RowNorm': RowNorm, 221 | 'BiasRMS': BiasRMS, 222 | 'SpectralConv': SpectralConv, 223 | 'Spectral': Spectral, 224 | 'Sign': Sign, 225 | 'Auto': Auto, 226 | } 227 | 228 | 229 | class Scion(torch.optim.Optimizer): 230 | """Scion optimizer implementation. 231 | 232 | Args: 233 | params: Iterable of parameters to optimize or dicts defining parameter groups 234 | lr (float, optional): Learning rate (default: 1e-3) 235 | momentum (float, optional): One minus the traditional momentum factor. For example, 236 | a traditional momentum of 0.9 would be specified as momentum=0.1 here (default: 1.0) 237 | norm (str, optional): Choice of norm for gradient projection ('Auto', 'SpectralConv', 238 | 'ColNorm', 'RowNorm', 'BiasRMS', 'Spectral', or 'Sign') (default: 'Auto') 239 | norm_kwargs (dict, optional): Additional arguments for the norm projection (default: None) 240 | scale (float, optional): Scale factor for updates (default: 1.0) 241 | unconstrained (bool, optional): Whether to use unconstrained updates (default: False) 242 | 243 | Example: 244 | >>> radius = 50.0 245 | >>> optim_groups = [{ 246 | ... 'params': model.transformer.h.parameters(), 247 | ... 'norm': 'Spectral', 248 | ... 'norm_kwargs': {}, 249 | ... 'scale': radius, 250 | ... }, { 251 | ... 'params': model.lm_head.parameters(), 252 | ... 'norm': 'Sign', 253 | ... 'norm_kwargs': {}, 254 | ... 'scale': radius*60.0, 255 | ... 
}] 256 | >>> optimizer = Scion(optim_groups, lr=2**-12, momentum=0.1) 257 | """ 258 | def __init__(self, params, lr=1e-3, momentum=1.0, norm: str='Auto', norm_kwargs: dict=None, scale=1.0, unconstrained=False): 259 | if lr < 0.0: 260 | raise ValueError(f"Invalid learning rate: {lr}") 261 | if momentum < 0.0: 262 | raise ValueError(f"Invalid momentum value: {momentum}") 263 | if norm_kwargs is None: 264 | norm_kwargs = {} 265 | defaults = dict(lr=lr, momentum=momentum, scale=scale, unconstrained=unconstrained, norm=norm, norm_kwargs=norm_kwargs) 266 | super().__init__(params, defaults) 267 | 268 | def step(self): 269 | for group in self.param_groups: 270 | lr = group['lr'] 271 | momentum = group['momentum'] 272 | scale = group['scale'] 273 | unconstrained = group['unconstrained'] 274 | norm_backend = norm_dict[group['norm']](**group['norm_kwargs']) 275 | for p in group['params']: 276 | g = p.grad 277 | if g is None: 278 | continue 279 | state = self.state[p] 280 | 281 | if momentum != 1: 282 | if 'momentum_buffer' not in state.keys(): 283 | state['momentum_buffer'] = torch.zeros_like(g) 284 | buf = state['momentum_buffer'] 285 | buf.mul_(1-momentum).add_(g, alpha=momentum) 286 | g = buf 287 | 288 | update = scale * norm_backend.lmo(g) 289 | if not unconstrained: 290 | p.data.mul_(1-lr) 291 | p.data.add_(update, alpha=-lr) 292 | 293 | def init(self): 294 | for group in self.param_groups: 295 | norm_backend = norm_dict[group['norm']](**group['norm_kwargs']) 296 | init_func = norm_backend.init 297 | scale = group['scale'] 298 | for p in group['params']: 299 | init_func(p) 300 | p.data *= scale 301 | 302 | 303 | class ScionLight(torch.optim.Optimizer): 304 | """Memory-efficient variant of the Scion optimizer. 305 | 306 | This implementation saves memory by storing only the averaged gradient instead of 307 | both the gradient and its average. Note that gradients should not be zeroed since 308 | p.grad is used directly to store the gradient average. 309 | 310 | Args: 311 | params: Iterable of parameters to optimize or dicts defining parameter groups 312 | lr (float, optional): Learning rate (default: 1e-3) 313 | momentum (float, optional): One minus the traditional momentum factor. For example, 314 | a traditional momentum of 0.9 would be specified as momentum=0.1 here (default: 1.0) 315 | norm (str, optional): Choice of norm for gradient projection ('Auto', 'SpectralConv', 316 | 'ColNorm', 'RowNorm', 'BiasRMS', 'Spectral', or 'Sign') (default: 'Auto') 317 | norm_kwargs (dict, optional): Additional arguments for the norm projection (default: None) 318 | scale (float, optional): Scale factor for updates (default: 1.0) 319 | unconstrained (bool, optional): Whether to use unconstrained updates (default: False) 320 | 321 | Example: 322 | >>> radius = 50.0 323 | >>> optim_groups = [{ 324 | ... 'params': model.transformer.h.parameters(), 325 | ... 'norm': 'Spectral', 326 | ... 'norm_kwargs': {}, 327 | ... 'scale': radius, 328 | ... }, { 329 | ... 'params': model.lm_head.parameters(), 330 | ... 'norm': 'Sign', 331 | ... 'norm_kwargs': {}, 332 | ... 'scale': radius*60.0, 333 | ... 
}] 334 | >>> optimizer = ScionLight(optim_groups, lr=2**-12, momentum=0.1) 335 | """ 336 | def __init__(self, params, lr=1e-3, momentum=1.0, norm: str='Auto', norm_kwargs: dict=None, scale=1.0, unconstrained=False): 337 | if lr < 0.0: 338 | raise ValueError(f"Invalid learning rate: {lr}") 339 | if momentum < 0.0: 340 | raise ValueError(f"Invalid momentum value: {momentum}") 341 | if norm_kwargs is None: 342 | norm_kwargs = {} 343 | defaults = dict(lr=lr, momentum=momentum, scale=scale, unconstrained=unconstrained, norm=norm, norm_kwargs=norm_kwargs) 344 | super().__init__(params, defaults) 345 | 346 | def step(self): 347 | for group in self.param_groups: 348 | lr = group['lr'] 349 | momentum = group['momentum'] 350 | scale = group['scale'] 351 | unconstrained = group['unconstrained'] 352 | norm_backend = norm_dict[group['norm']](**group['norm_kwargs']) 353 | for p in group['params']: 354 | G = p.grad 355 | if G is None: 356 | continue 357 | 358 | update = scale * norm_backend.lmo(G) 359 | if not unconstrained: 360 | p.data.mul_(1-lr) 361 | p.data.add_(update, alpha=-lr) 362 | 363 | if momentum != 1: 364 | G.mul_(1-momentum) 365 | 366 | def init(self): 367 | for group in self.param_groups: 368 | norm_backend = norm_dict[group['norm']](**group['norm_kwargs']) 369 | init_func = norm_backend.init 370 | scale = group['scale'] 371 | for p in group['params']: 372 | init_func(p) 373 | p.data *= scale 374 | 375 | 376 | @torch.compile 377 | def zeropower_via_newtonschulz5(G, steps=5): 378 | """ 379 | From: https://github.com/KellerJordan/modded-nanogpt/blob/master/records/101724_DistributedMuon/22d24867-eb5a-4fcc-ae2c-263d0277dfd1.txt 380 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 381 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 382 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 383 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 384 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T 385 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model 386 | performance at all relative to UV^T, where USV^T = G is the SVD. 387 | """ 388 | assert len(G.shape) == 2 389 | a, b, c = (3.4445, -4.7750, 2.0315) 390 | X = G.bfloat16() 391 | if G.size(0) > G.size(1): 392 | X = X.T 393 | 394 | # Ensure spectral norm is at most 1 395 | X = X / (X.norm() + 1e-7) 396 | # Perform the NS iterations 397 | for _ in range(steps): 398 | A = X @ X.T 399 | B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng 400 | X = a * X + B @ X 401 | 402 | if G.size(0) > G.size(1): 403 | X = X.T 404 | return X 405 | 406 | 407 | def zeroth_power_via_svd(G): 408 | U, S, V = G.svd() 409 | return U @ V.T 410 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/scion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | ####################################################### 6 | # Scion 7 | ####################################################### 8 | 9 | 10 | class Norm(object): 11 | def lmo(self, g): 12 | raise NotImplementedError 13 | 14 | def init(self, w): 15 | raise NotImplementedError 16 | 17 | 18 | class ColNorm(Norm): 19 | """ 20 | Column-wise normalization. 
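    (With transpose=False, lmo rescales each column of the gradient to l2 norm
    sqrt(d_out), with an extra 1/d_in factor when normalized=True; init produces
    columns with the same scale.)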
21 | 22 | Args: 23 | normalized (bool, optional): If True, normalizes by the input dimension. Use True only for non-input layers. 24 | transpose (bool, optional): If True, transposes input before normalization. Use True for embedding layers 25 | which store weights as (vocab_size, embedding_dim). 26 | """ 27 | def __init__(self, normalized=False, transpose=False): 28 | self.normalized = normalized 29 | self.transpose = transpose 30 | 31 | def lmo(self, g): 32 | eps = 1e-8 33 | if self.transpose: 34 | g = g.transpose(0, 1) 35 | rms_values = 1/math.sqrt(g.size(0))*torch.sqrt(torch.sum(g ** 2, dim=0, keepdim=True)) 36 | if self.normalized: 37 | rms_values *= g.size(1) 38 | g = g / (rms_values + eps) 39 | if self.transpose: 40 | g = g.transpose(0, 1) 41 | return g 42 | 43 | def init(self, w): 44 | dtype = w.data.dtype 45 | if self.transpose: 46 | w.data = w.data.transpose(0, 1) 47 | torch.nn.init.normal_(w.data) 48 | w.data /= w.norm(dim=0, keepdim=True) 49 | w.data *= math.sqrt(w.size(0)) 50 | if self.normalized: 51 | w.data /= w.size(1) 52 | w.data = w.data.to(dtype=dtype) 53 | if self.transpose: 54 | w.data = w.data.transpose(0, 1) 55 | return w 56 | 57 | 58 | class RowNorm(Norm): 59 | """ 60 | Row-wise normalization. 61 | 62 | Args: 63 | normalized (bool, optional): If True, normalizes by the input dimension. Use False only for the input layer. 64 | transpose (bool, optional): If True, transposes input before normalization. Use True for embedding layers 65 | which store weights as (vocab_size, embedding_dim). 66 | """ 67 | def __init__(self, normalized=True, transpose=False): 68 | self.normalized = normalized 69 | self.transpose = transpose 70 | 71 | def lmo(self, g): 72 | eps = 1e-8 73 | if self.transpose: 74 | g = g.transpose(0, 1) 75 | rms_values = torch.sqrt(torch.sum(g ** 2, dim=-1, keepdim=True)) 76 | if self.normalized: 77 | rms_values *= math.sqrt(g.size(-1)) 78 | g = g / (rms_values + eps) 79 | if self.transpose: 80 | g = g.transpose(0, 1) 81 | return g 82 | 83 | def init(self, w): 84 | dtype = w.data.dtype 85 | if self.transpose: 86 | w.data = w.data.transpose(0, 1) 87 | torch.nn.init.normal_(w.data) 88 | w.data /= w.norm(dim=-1, keepdim=True) 89 | if self.normalized: 90 | w.data /= math.sqrt(w.size(-1)) 91 | w.data = w.data.to(dtype=dtype) 92 | if self.transpose: 93 | w.data = w.data.transpose(0, 1) 94 | return w 95 | 96 | 97 | class BiasRMS(Norm): 98 | def lmo(self, g): 99 | eps = 1e-8 100 | rms_values = torch.sqrt(torch.mean(g ** 2, dim=0, keepdim=True)) 101 | g = g / (rms_values + eps) 102 | return g 103 | 104 | def init(self, g): 105 | return torch.nn.init.zeros_(g) 106 | 107 | 108 | class SpectralConv(Norm): 109 | def __init__(self, steps=5): 110 | self.steps = steps 111 | 112 | def lmo(self, g): 113 | g = zeropower_via_newtonschulz5(g.reshape(len(g), -1), steps=self.steps).view(g.shape) 114 | if g.ndim == 3: 115 | out_channels, in_channels, k = g.shape 116 | g *= (out_channels / in_channels)**0.5 / k 117 | elif g.ndim == 4: 118 | out_channels, in_channels, k, _ = g.shape 119 | g *= (out_channels / in_channels)**0.5 / (k ** 2) 120 | return g 121 | 122 | def init(self, w): 123 | w_fp = w.data.double() 124 | k = w.data.size(2) 125 | for kx in range(k): 126 | for ky in range(k): 127 | torch.nn.init.orthogonal_(w_fp[:,:,kx,ky]) 128 | 129 | if w.ndim == 3: 130 | out_channels, in_channels, k = w_fp.shape 131 | w_fp.mul_((out_channels / in_channels)**0.5 / k) 132 | elif w.ndim == 4: 133 | out_channels, in_channels, k, _ = w_fp.shape 134 | w_fp.mul_((out_channels / 
in_channels)**0.5 / (k ** 2)) 135 | w.data = w_fp.to(dtype=w.data.dtype) 136 | return w 137 | 138 | 139 | class Spectral(Norm): 140 | def __init__(self, max=False, normalized=True, steps=5): 141 | self.max = max 142 | self.steps = steps 143 | self.normalized = normalized 144 | 145 | def lmo(self, g): 146 | g = zeropower_via_newtonschulz5(g.reshape(len(g), -1), steps=self.steps).view(g.shape) 147 | d_out, d_in = g.shape 148 | 149 | if self.normalized: 150 | scale = (d_out / d_in)**0.5 151 | else: 152 | scale = d_out**0.5 153 | if self.max: 154 | scale = max(1,scale) 155 | g *= scale 156 | 157 | return g 158 | 159 | def init(self, w): 160 | w_fp = w.data.double() 161 | torch.nn.init.orthogonal_(w_fp) 162 | d_out, d_in = w_fp.shape 163 | 164 | if self.normalized: 165 | scale = (d_out / d_in)**0.5 166 | else: 167 | scale = d_out**0.5 168 | if self.max: 169 | scale = max(1,scale) 170 | w_fp.mul_(scale) 171 | 172 | w.data = w_fp.to(dtype=w.data.dtype) 173 | return w 174 | 175 | 176 | class Sign(Norm): 177 | def __init__(self, zero_init=False, normalized=True): 178 | self.zero_init = zero_init 179 | self.normalized = normalized 180 | 181 | def lmo(self, g): 182 | d_out, d_in = g.shape 183 | if self.normalized: 184 | return (1/d_in)*torch.sign(g) 185 | else: 186 | return torch.sign(g) 187 | 188 | def init(self, w): 189 | if self.zero_init: 190 | torch.nn.init.zeros_(w) 191 | else: 192 | # Generate -1/fan_in or 1/fan_in uniformly at random 193 | d_out, d_in = w.shape 194 | w.data = (torch.randint(0, 2, w.shape, dtype=w.dtype, device=w.device) * 2 - 1) 195 | if self.normalized: 196 | w.data *= (1/d_in) 197 | return w 198 | 199 | 200 | class Auto(Norm): 201 | def lmo(self, g): 202 | if g.ndim in [3,4]: 203 | return SpectralConv().lmo(g) 204 | elif g.ndim == 2: 205 | return Spectral().lmo(g) 206 | elif g.ndim in [0,1]: 207 | return BiasRMS().lmo(g) 208 | 209 | def init(self, w): 210 | if w.ndim in [3,4]: 211 | return SpectralConv().init(w) 212 | elif w.ndim == 2: 213 | return Spectral().init(w) 214 | elif w.ndim in [0,1]: 215 | return BiasRMS().init(w) 216 | 217 | 218 | norm_dict = { 219 | 'ColNorm': ColNorm, 220 | 'RowNorm': RowNorm, 221 | 'BiasRMS': BiasRMS, 222 | 'SpectralConv': SpectralConv, 223 | 'Spectral': Spectral, 224 | 'Sign': Sign, 225 | 'Auto': Auto, 226 | } 227 | 228 | 229 | class Scion(torch.optim.Optimizer): 230 | """Scion optimizer implementation. 231 | 232 | Args: 233 | params: Iterable of parameters to optimize or dicts defining parameter groups 234 | lr (float, optional): Learning rate (default: 1e-3) 235 | momentum (float, optional): One minus the traditional momentum factor. For example, 236 | a traditional momentum of 0.9 would be specified as momentum=0.1 here (default: 1.0) 237 | norm (str, optional): Choice of norm for gradient projection ('Auto', 'SpectralConv', 238 | 'ColNorm', 'RowNorm', 'BiasRMS', 'Spectral', or 'Sign') (default: 'Auto') 239 | norm_kwargs (dict, optional): Additional arguments for the norm projection (default: None) 240 | scale (float, optional): Scale factor for updates (default: 1.0) 241 | unconstrained (bool, optional): Whether to use unconstrained updates (default: False) 242 | 243 | Example: 244 | >>> radius = 50.0 245 | >>> optim_groups = [{ 246 | ... 'params': model.transformer.h.parameters(), 247 | ... 'norm': 'Spectral', 248 | ... 'norm_kwargs': {}, 249 | ... 'scale': radius, 250 | ... }, { 251 | ... 'params': model.lm_head.parameters(), 252 | ... 'norm': 'Sign', 253 | ... 'norm_kwargs': {}, 254 | ... 'scale': radius*60.0, 255 | ... 
}] 256 | >>> optimizer = Scion(optim_groups, lr=2**-12, momentum=0.1) 257 | """ 258 | def __init__(self, params, lr=1e-3, momentum=1.0, norm: str='Auto', norm_kwargs: dict=None, scale=1.0, unconstrained=False): 259 | if lr < 0.0: 260 | raise ValueError(f"Invalid learning rate: {lr}") 261 | if momentum < 0.0: 262 | raise ValueError(f"Invalid momentum value: {momentum}") 263 | if norm_kwargs is None: 264 | norm_kwargs = {} 265 | defaults = dict(lr=lr, momentum=momentum, scale=scale, unconstrained=unconstrained, norm=norm, norm_kwargs=norm_kwargs) 266 | super().__init__(params, defaults) 267 | 268 | def step(self): 269 | for group in self.param_groups: 270 | lr = group['lr'] 271 | momentum = group['momentum'] 272 | scale = group['scale'] 273 | unconstrained = group['unconstrained'] 274 | norm_backend = norm_dict[group['norm']](**group['norm_kwargs']) 275 | for p in group['params']: 276 | g = p.grad 277 | if g is None: 278 | continue 279 | state = self.state[p] 280 | 281 | if momentum != 1: 282 | if 'momentum_buffer' not in state.keys(): 283 | state['momentum_buffer'] = torch.zeros_like(g) 284 | buf = state['momentum_buffer'] 285 | buf.mul_(1-momentum).add_(g, alpha=momentum) 286 | g = buf 287 | 288 | update = scale * norm_backend.lmo(g) 289 | if not unconstrained: 290 | p.data.mul_(1-lr) 291 | p.data.add_(update, alpha=-lr) 292 | 293 | def init(self): 294 | for group in self.param_groups: 295 | norm_backend = norm_dict[group['norm']](**group['norm_kwargs']) 296 | init_func = norm_backend.init 297 | scale = group['scale'] 298 | for p in group['params']: 299 | init_func(p) 300 | p.data *= scale 301 | 302 | 303 | class ScionLight(torch.optim.Optimizer): 304 | """Memory-efficient variant of the Scion optimizer. 305 | 306 | This implementation saves memory by storing only the averaged gradient instead of 307 | both the gradient and its average. Note that gradients should not be zeroed since 308 | p.grad is used directly to store the gradient average. 309 | 310 | Args: 311 | params: Iterable of parameters to optimize or dicts defining parameter groups 312 | lr (float, optional): Learning rate (default: 1e-3) 313 | momentum (float, optional): One minus the traditional momentum factor. For example, 314 | a traditional momentum of 0.9 would be specified as momentum=0.1 here (default: 1.0) 315 | norm (str, optional): Choice of norm for gradient projection ('Auto', 'SpectralConv', 316 | 'ColNorm', 'RowNorm', 'BiasRMS', 'Spectral', or 'Sign') (default: 'Auto') 317 | norm_kwargs (dict, optional): Additional arguments for the norm projection (default: None) 318 | scale (float, optional): Scale factor for updates (default: 1.0) 319 | unconstrained (bool, optional): Whether to use unconstrained updates (default: False) 320 | 321 | Example: 322 | >>> radius = 50.0 323 | >>> optim_groups = [{ 324 | ... 'params': model.transformer.h.parameters(), 325 | ... 'norm': 'Spectral', 326 | ... 'norm_kwargs': {}, 327 | ... 'scale': radius, 328 | ... }, { 329 | ... 'params': model.lm_head.parameters(), 330 | ... 'norm': 'Sign', 331 | ... 'norm_kwargs': {}, 332 | ... 'scale': radius*60.0, 333 | ... 
}] 334 | >>> optimizer = ScionLight(optim_groups, lr=2**-12, momentum=0.1) 335 | """ 336 | def __init__(self, params, lr=1e-3, momentum=1.0, norm: str='Auto', norm_kwargs: dict=None, scale=1.0, unconstrained=False): 337 | if lr < 0.0: 338 | raise ValueError(f"Invalid learning rate: {lr}") 339 | if momentum < 0.0: 340 | raise ValueError(f"Invalid momentum value: {momentum}") 341 | if norm_kwargs is None: 342 | norm_kwargs = {} 343 | defaults = dict(lr=lr, momentum=momentum, scale=scale, unconstrained=unconstrained, norm=norm, norm_kwargs=norm_kwargs) 344 | super().__init__(params, defaults) 345 | 346 | def step(self): 347 | for group in self.param_groups: 348 | lr = group['lr'] 349 | momentum = group['momentum'] 350 | scale = group['scale'] 351 | unconstrained = group['unconstrained'] 352 | norm_backend = norm_dict[group['norm']](**group['norm_kwargs']) 353 | for p in group['params']: 354 | G = p.grad 355 | if G is None: 356 | continue 357 | 358 | update = scale * norm_backend.lmo(G) 359 | if not unconstrained: 360 | p.data.mul_(1-lr) 361 | p.data.add_(update, alpha=-lr) 362 | 363 | if momentum != 1: 364 | G.mul_(1-momentum) 365 | 366 | def init(self): 367 | for group in self.param_groups: 368 | norm_backend = norm_dict[group['norm']](**group['norm_kwargs']) 369 | init_func = norm_backend.init 370 | scale = group['scale'] 371 | for p in group['params']: 372 | init_func(p) 373 | p.data *= scale 374 | 375 | 376 | #@torch.compile 377 | def zeropower_via_newtonschulz5(G, steps=5): 378 | """ 379 | From: https://github.com/KellerJordan/modded-nanogpt/blob/master/records/101724_DistributedMuon/22d24867-eb5a-4fcc-ae2c-263d0277dfd1.txt 380 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 381 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 382 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 383 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 384 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T 385 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model 386 | performance at all relative to UV^T, where USV^T = G is the SVD. 387 | """ 388 | assert len(G.shape) == 2 389 | a, b, c = (3.4445, -4.7750, 2.0315) 390 | X = G.bfloat16() 391 | if G.size(0) > G.size(1): 392 | X = X.T 393 | 394 | # Ensure spectral norm is at most 1 395 | X = X / (X.norm() + 1e-7) 396 | # Perform the NS iterations 397 | for _ in range(steps): 398 | A = X @ X.T 399 | B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng 400 | X = a * X + B @ X 401 | 402 | if G.size(0) > G.size(1): 403 | X = X.T 404 | return X 405 | 406 | 407 | def zeroth_power_via_svd(G): 408 | U, S, V = G.svd() 409 | return U @ V.T 410 | -------------------------------------------------------------------------------- /scion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | ####################################################### 6 | # Scion 7 | ####################################################### 8 | 9 | 10 | class Norm(object): 11 | def lmo(self, g): 12 | raise NotImplementedError 13 | 14 | def init(self, w): 15 | raise NotImplementedError 16 | 17 | 18 | class ColNorm(Norm): 19 | """ 20 | Column-wise normalization. 
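    Illustrative sketch (doctest-style) of the column scaling:
        >>> g = torch.randn(512, 256)             # gradient of a (d_out, d_in) weight
        >>> u = ColNorm().lmo(g)
        >>> u.norm(dim=0)                         # every column has l2 norm ~ sqrt(512)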
21 | 22 | Args: 23 | normalized (bool, optional): If True, normalizes by the input dimension. Use True only for non-input layers. 24 | transpose (bool, optional): If True, transposes input before normalization. Use True for embedding layers 25 | which store weights as (vocab_size, embedding_dim). 26 | """ 27 | def __init__(self, normalized=False, transpose=False): 28 | self.normalized = normalized 29 | self.transpose = transpose 30 | 31 | def lmo(self, g): 32 | eps = 1e-8 33 | if self.transpose: 34 | g = g.transpose(0, 1) 35 | rms_values = 1/math.sqrt(g.size(0))*torch.sqrt(torch.sum(g ** 2, dim=0, keepdim=True)) 36 | if self.normalized: 37 | rms_values *= g.size(1) 38 | g = g / (rms_values + eps) 39 | if self.transpose: 40 | g = g.transpose(0, 1) 41 | return g 42 | 43 | def init(self, w): 44 | dtype = w.data.dtype 45 | if self.transpose: 46 | w.data = w.data.transpose(0, 1) 47 | torch.nn.init.normal_(w.data) 48 | w.data /= w.norm(dim=0, keepdim=True) 49 | w.data *= math.sqrt(w.size(0)) 50 | if self.normalized: 51 | w.data /= w.size(1) 52 | w.data = w.data.to(dtype=dtype) 53 | if self.transpose: 54 | w.data = w.data.transpose(0, 1) 55 | return w 56 | 57 | 58 | class RowNorm(Norm): 59 | """ 60 | Row-wise normalization. 61 | 62 | Args: 63 | normalized (bool, optional): If True, normalizes by the input dimension. Use False only for the input layer. 64 | transpose (bool, optional): If True, transposes input before normalization. Use True for embedding layers 65 | which store weights as (vocab_size, embedding_dim). 66 | """ 67 | def __init__(self, normalized=True, transpose=False): 68 | self.normalized = normalized 69 | self.transpose = transpose 70 | 71 | def lmo(self, g): 72 | eps = 1e-8 73 | if self.transpose: 74 | g = g.transpose(0, 1) 75 | rms_values = torch.sqrt(torch.sum(g ** 2, dim=-1, keepdim=True)) 76 | if self.normalized: 77 | rms_values *= math.sqrt(g.size(-1)) 78 | g = g / (rms_values + eps) 79 | if self.transpose: 80 | g = g.transpose(0, 1) 81 | return g 82 | 83 | def init(self, w): 84 | dtype = w.data.dtype 85 | if self.transpose: 86 | w.data = w.data.transpose(0, 1) 87 | torch.nn.init.normal_(w.data) 88 | w.data /= w.norm(dim=-1, keepdim=True) 89 | if self.normalized: 90 | w.data /= math.sqrt(w.size(-1)) 91 | w.data = w.data.to(dtype=dtype) 92 | if self.transpose: 93 | w.data = w.data.transpose(0, 1) 94 | return w 95 | 96 | 97 | class BiasRMS(Norm): 98 | def lmo(self, g): 99 | eps = 1e-8 100 | rms_values = torch.sqrt(torch.mean(g ** 2, dim=0, keepdim=True)) 101 | g = g / (rms_values + eps) 102 | return g 103 | 104 | def init(self, g): 105 | return torch.nn.init.zeros_(g) 106 | 107 | 108 | class SpectralConv(Norm): 109 | def __init__(self, steps=5): 110 | self.steps = steps 111 | 112 | def lmo(self, g): 113 | g = zeropower_via_newtonschulz5(g.reshape(len(g), -1), steps=self.steps).view(g.shape) 114 | if g.ndim == 3: # Conv1d 115 | out_channels, in_channels, k = g.shape 116 | g *= (out_channels / in_channels)**0.5 / k 117 | elif g.ndim == 4: # Conv2d 118 | out_channels, in_channels, k, _ = g.shape 119 | g *= (out_channels / in_channels)**0.5 / (k ** 2) 120 | return g 121 | 122 | def init(self, w): 123 | w_fp = w.data.double() 124 | k = w.data.size(2) 125 | for kx in range(k): 126 | for ky in range(k): 127 | torch.nn.init.orthogonal_(w_fp[:,:,kx,ky]) 128 | 129 | if w.ndim == 3: # Conv1d 130 | out_channels, in_channels, k = w_fp.shape 131 | w_fp.mul_((out_channels / in_channels)**0.5 / k) 132 | elif w.ndim == 4: # Conv2d 133 | out_channels, in_channels, k, _ = w_fp.shape 134 | 
w_fp.mul_((out_channels / in_channels)**0.5 / (k ** 2)) 135 | w.data = w_fp.to(dtype=w.data.dtype) 136 | return w 137 | 138 | 139 | class Spectral(Norm): 140 | def __init__(self, max=False, normalized=True, steps=5): 141 | self.max = max 142 | self.steps = steps 143 | self.normalized = normalized 144 | 145 | def lmo(self, g): 146 | g = zeropower_via_newtonschulz5(g.reshape(len(g), -1), steps=self.steps).view(g.shape) 147 | d_out, d_in = g.shape 148 | 149 | if self.normalized: 150 | scale = (d_out / d_in)**0.5 151 | else: 152 | scale = d_out**0.5 153 | if self.max: 154 | scale = max(1,scale) 155 | g *= scale 156 | 157 | return g 158 | 159 | def init(self, w): 160 | w_fp = w.data.double() 161 | torch.nn.init.orthogonal_(w_fp) 162 | d_out, d_in = w_fp.shape 163 | 164 | if self.normalized: 165 | scale = (d_out / d_in)**0.5 166 | else: 167 | scale = d_out**0.5 168 | if self.max: 169 | scale = max(1,scale) 170 | w_fp.mul_(scale) 171 | 172 | w.data = w_fp.to(dtype=w.data.dtype) 173 | return w 174 | 175 | 176 | class Sign(Norm): 177 | def __init__(self, zero_init=False, normalized=True): 178 | self.zero_init = zero_init 179 | self.normalized = normalized 180 | 181 | def lmo(self, g): 182 | d_out, d_in = g.shape 183 | if self.normalized: 184 | return (1/d_in)*torch.sign(g) 185 | else: 186 | return torch.sign(g) 187 | 188 | def init(self, w): 189 | if self.zero_init: 190 | torch.nn.init.zeros_(w) 191 | else: 192 | # Generate -1/fan_in or 1/fan_in uniformly at random 193 | d_out, d_in = w.shape 194 | w.data = (torch.randint(0, 2, w.shape, dtype=w.dtype, device=w.device) * 2 - 1) 195 | if self.normalized: 196 | w.data *= (1/d_in) 197 | return w 198 | 199 | 200 | class Auto(Norm): 201 | def lmo(self, g): 202 | if g.ndim in [3,4]: 203 | return SpectralConv().lmo(g) 204 | elif g.ndim == 2: 205 | return Spectral().lmo(g) 206 | elif g.ndim in [0,1]: 207 | return BiasRMS().lmo(g) 208 | 209 | def init(self, w): 210 | if w.ndim in [3,4]: 211 | return SpectralConv().init(w) 212 | elif w.ndim == 2: 213 | return Spectral().init(w) 214 | elif w.ndim in [0,1]: 215 | return BiasRMS().init(w) 216 | 217 | 218 | norm_dict = { 219 | 'ColNorm': ColNorm, 220 | 'RowNorm': RowNorm, 221 | 'BiasRMS': BiasRMS, 222 | 'SpectralConv': SpectralConv, 223 | 'Spectral': Spectral, 224 | 'Sign': Sign, 225 | 'Auto': Auto, 226 | } 227 | 228 | 229 | class Scion(torch.optim.Optimizer): 230 | """Scion optimizer implementation. 231 | 232 | Args: 233 | params: Iterable of parameters to optimize or dicts defining parameter groups 234 | lr (float, optional): Learning rate (default: 1e-3) 235 | momentum (float, optional): One minus the traditional momentum factor. For example, 236 | a traditional momentum of 0.9 would be specified as momentum=0.1 here (default: 1.0) 237 | norm (str, optional): Choice of norm for gradient projection ('Auto', 'SpectralConv', 238 | 'ColNorm', 'RowNorm', 'BiasRMS', 'Spectral', or 'Sign') (default: 'Auto') 239 | norm_kwargs (dict, optional): Additional arguments for the norm projection (default: None) 240 | scale (float, optional): Scale factor for updates (default: 1.0) 241 | unconstrained (bool, optional): Whether to use unconstrained updates (default: False) 242 | 243 | Example: 244 | >>> radius = 50.0 245 | >>> optim_groups = [{ 246 | ... 'params': model.transformer.h.parameters(), 247 | ... 'norm': 'Spectral', 248 | ... 'norm_kwargs': {}, 249 | ... 'scale': radius, 250 | ... }, { 251 | ... 'params': model.lm_head.parameters(), 252 | ... 'norm': 'Sign', 253 | ... 'norm_kwargs': {}, 254 | ... 
'scale': radius*60.0, 255 | ... }] 256 | >>> optimizer = Scion(optim_groups, lr=2**-12, momentum=0.1) 257 | """ 258 | def __init__(self, params, lr=1e-3, momentum=1.0, norm: str='Auto', norm_kwargs: dict=None, scale=1.0, unconstrained=False): 259 | if lr < 0.0: 260 | raise ValueError(f"Invalid learning rate: {lr}") 261 | if momentum < 0.0: 262 | raise ValueError(f"Invalid momentum value: {momentum}") 263 | if norm_kwargs is None: 264 | norm_kwargs = {} 265 | defaults = dict(lr=lr, momentum=momentum, scale=scale, unconstrained=unconstrained, norm=norm, norm_kwargs=norm_kwargs) 266 | super().__init__(params, defaults) 267 | 268 | def step(self): 269 | for group in self.param_groups: 270 | lr = group['lr'] 271 | momentum = group['momentum'] 272 | scale = group['scale'] 273 | unconstrained = group['unconstrained'] 274 | norm_backend = norm_dict[group['norm']](**group['norm_kwargs']) 275 | for p in group['params']: 276 | g = p.grad 277 | if g is None: 278 | continue 279 | state = self.state[p] 280 | 281 | if momentum != 1: 282 | if 'momentum_buffer' not in state.keys(): 283 | state['momentum_buffer'] = torch.zeros_like(g) 284 | buf = state['momentum_buffer'] 285 | buf.mul_(1-momentum).add_(g, alpha=momentum) 286 | g = buf 287 | 288 | update = scale * norm_backend.lmo(g) 289 | if not unconstrained: 290 | p.data.mul_(1-lr) 291 | p.data.add_(update, alpha=-lr) 292 | 293 | def init(self): 294 | for group in self.param_groups: 295 | norm_backend = norm_dict[group['norm']](**group['norm_kwargs']) 296 | init_func = norm_backend.init 297 | scale = group['scale'] 298 | for p in group['params']: 299 | init_func(p) 300 | p.data *= scale 301 | 302 | 303 | class ScionLight(torch.optim.Optimizer): 304 | """Memory-efficient variant of the Scion optimizer. 305 | 306 | This implementation saves memory by storing only the averaged gradient instead of 307 | both the gradient and its average. Note that gradients should not be zeroed since 308 | p.grad is used directly to store the gradient average. 309 | 310 | Args: 311 | params: Iterable of parameters to optimize or dicts defining parameter groups 312 | lr (float, optional): Learning rate (default: 1e-3) 313 | momentum (float, optional): One minus the traditional momentum factor. For example, 314 | a traditional momentum of 0.9 would be specified as momentum=0.1 here (default: 1.0) 315 | norm (str, optional): Choice of norm for gradient projection ('Auto', 'SpectralConv', 316 | 'ColNorm', 'RowNorm', 'BiasRMS', 'Spectral', or 'Sign') (default: 'Auto') 317 | norm_kwargs (dict, optional): Additional arguments for the norm projection (default: None) 318 | scale (float, optional): Scale factor for updates (default: 1.0) 319 | unconstrained (bool, optional): Whether to use unconstrained updates (default: False) 320 | 321 | Example: 322 | >>> radius = 50.0 323 | >>> optim_groups = [{ 324 | ... 'params': model.transformer.h.parameters(), 325 | ... 'norm': 'Spectral', 326 | ... 'norm_kwargs': {}, 327 | ... 'scale': radius, 328 | ... }, { 329 | ... 'params': model.lm_head.parameters(), 330 | ... 'norm': 'Sign', 331 | ... 'norm_kwargs': {}, 332 | ... 'scale': radius*60.0, 333 | ... 
}] 334 | >>> optimizer = ScionLight(optim_groups, lr=2**-12, momentum=0.1) 335 | """ 336 | def __init__(self, params, lr=1e-3, momentum=1.0, norm: str='Auto', norm_kwargs: dict=None, scale=1.0, unconstrained=False): 337 | if lr < 0.0: 338 | raise ValueError(f"Invalid learning rate: {lr}") 339 | if momentum < 0.0: 340 | raise ValueError(f"Invalid momentum value: {momentum}") 341 | if norm_kwargs is None: 342 | norm_kwargs = {} 343 | defaults = dict(lr=lr, momentum=momentum, scale=scale, unconstrained=unconstrained, norm=norm, norm_kwargs=norm_kwargs) 344 | super().__init__(params, defaults) 345 | # Initialize state 346 | self._store_grads_in_state() 347 | # Do not pass `self` through syntactic sugar. We need the 348 | # argument to not be populated. 349 | self.register_state_dict_pre_hook( 350 | type(self)._store_grads_in_state, 351 | ) 352 | self.register_load_state_dict_post_hook( 353 | type(self)._load_grads_from_state, 354 | ) 355 | 356 | def step(self): 357 | for group in self.param_groups: 358 | lr = group['lr'] 359 | momentum = group['momentum'] 360 | scale = group['scale'] 361 | unconstrained = group['unconstrained'] 362 | norm_backend = norm_dict[group['norm']](**group['norm_kwargs']) 363 | for p in group['params']: 364 | G = p.grad 365 | if G is None: 366 | continue 367 | 368 | update = scale * norm_backend.lmo(G) 369 | if not unconstrained: 370 | p.data.mul_(1-lr) 371 | p.data.add_(update, alpha=-lr) 372 | 373 | if momentum != 1: 374 | G.mul_(1-momentum) 375 | 376 | def init(self): 377 | for group in self.param_groups: 378 | norm_backend = norm_dict[group['norm']](**group['norm_kwargs']) 379 | init_func = norm_backend.init 380 | scale = group['scale'] 381 | for p in group['params']: 382 | init_func(p) 383 | p.data *= scale 384 | 385 | def __getstate__(self): 386 | self._store_grads_in_state() 387 | return super().__getstate__() 388 | 389 | def __setstate__(self, state): 390 | super().__setstate__(state) 391 | self._load_grads_from_state() 392 | 393 | def _store_grads_in_state(self): 394 | for group in self.param_groups: 395 | for param in group['params']: 396 | if isinstance(param, torch.Tensor) and param.grad is not None: 397 | self.state.setdefault(param, {})['grad_state'] = param.grad 398 | 399 | def _load_grads_from_state(self): 400 | for (param, state) in self.state.items(): 401 | if 'grad_state' in state: 402 | param.grad = state['grad_state'] 403 | elif isinstance(param, torch.Tensor): 404 | param.grad = None 405 | 406 | 407 | @torch.compile 408 | def zeropower_via_newtonschulz5(G, steps=5): 409 | """ 410 | From: https://github.com/KellerJordan/modded-nanogpt/blob/master/records/101724_DistributedMuon/22d24867-eb5a-4fcc-ae2c-263d0277dfd1.txt 411 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 412 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 413 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 414 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 415 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T 416 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model 417 | performance at all relative to UV^T, where USV^T = G is the SVD. 
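    In update form, each Newton-Schulz step maps X = U S V^T to U p(S) V^T with
    p(s) = a*s + b*s^3 + c*s^5 (since X <- a*X + b*(X X^T) X + c*(X X^T)^2 X), so the
    singular vectors are left untouched while every singular value is pushed towards
    the neighbourhood of 1 described above.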
418 | """ 419 | assert len(G.shape) == 2 420 | a, b, c = (3.4445, -4.7750, 2.0315) 421 | X = G.bfloat16() 422 | if G.size(0) > G.size(1): 423 | X = X.T 424 | 425 | # Ensure spectral norm is at most 1 426 | X = X / (X.norm() + 1e-7) 427 | # Perform the NS iterations 428 | for _ in range(steps): 429 | A = X @ X.T 430 | B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng 431 | X = a * X + B @ X 432 | 433 | if G.size(0) > G.size(1): 434 | X = X.T 435 | return X 436 | 437 | 438 | def zeroth_power_via_svd(G): 439 | U, S, V = G.svd() 440 | return U @ V.T 441 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/README_orig.md: -------------------------------------------------------------------------------- 1 | # nanoGPT-mup 2 | 3 | This repository is a fork of [nanoGPT](https://github.com/karpathy/nanoGPT) that provides a minimal implementation of the [maximal update parameterization](https://arxiv.org/abs/2203.03466) ([muP](https://github.com/microsoft/mup)). 4 | 5 | Branches 6 | - The [master branch](https://github.com/EleutherAI/nanoGPT-mup) acts as supplementary material for ["The Practitioner’s Guide to the Maximal Update Parameterization"](https://www.cerebras.ai/blog/the-practitioners-guide-to-the-maximal-update-parameterization). 7 | - The [supar branch](https://github.com/EleutherAI/nanoGPT-mup/tree/supar) contains a minimal implementation of sparse maximal update parameterization (SuPar) introduced in [Sparse maximal update parameterization: A holistic approach to sparse training dynamics](https://arxiv.org/abs/2405.15743). 8 | 9 | The [mup_examples](https://github.com/EleutherAI/nanoGPT-mup/tree/master/mup_examples) folder contains scripts to reproduce the plots in ["The Practitioner’s Guide to the Maximal Update Parameterization"](https://www.cerebras.ai/blog/the-practitioners-guide-to-the-maximal-update-parameterization) (see [mup_examples/README.md](https://github.com/EleutherAI/nanoGPT-mup/blob/master/mup_examples/README.md) for instructions to reproduce). 10 | 11 | Each of the critical muP changes are marked with 12 | ``` 13 | ### Begin muP code ### 14 | 15 | ### End muP code ### 16 | ``` 17 | to make everything easily searchable. 18 | 19 | | Parameterization | SP | **μP** | Code | 20 | |------------------|----|----|----| 21 | | Embedding Init. Var. | $σ_{base}^2$ | $σ_{base}^2$ | | 22 | | Embedding LR | $η_{base}$ | $η_{base}$ | | 23 | | Embedding Fwd. | $x W_{\text{emb}}$ | $\mathbf{α_{input}} · x W_{\text{emb}}$ | [Code](https://github.com/EleutherAI/nanoGPT-mup/blob/bcadbc3c7a44138525eca8a799764afba7dca2b3/model.py#L208) | 24 | | Hidden Init. Var. | $σ_{base}^2$ | $σ_{base}^2 / \mathbf{m_d}$ | [Code](https://github.com/EleutherAI/nanoGPT-mup/blob/bcadbc3c7a44138525eca8a799764afba7dca2b3/model.py#L163-L169) | 25 | | Hidden LR (Adam) | $η_{base}$ | $η_{base} / \mathbf{m_d}$ | [Code](https://github.com/EleutherAI/nanoGPT-mup/blob/bcadbc3c7a44138525eca8a799764afba7dca2b3/model.py#L306-L329) | 26 | | Output Logit Fwd. 
| $x W_{\text{emb}}^\top$ | $\mathbf{α_{output}} · x W_{\text{emb}}^\top / \mathbf{m_d}$ | [Code](https://github.com/EleutherAI/nanoGPT-mup/blob/bcadbc3c7a44138525eca8a799764afba7dca2b3/model.py#L219) | 27 | | Attention logits | $Q^\top K / \sqrt{d_{\text{head}}}$ | $Q^\top K / \mathbf{d_{\text{head}}}$ | [Code](https://github.com/EleutherAI/nanoGPT-mup/blob/bcadbc3c7a44138525eca8a799764afba7dca2b3/model.py#L65) | 28 | 29 | 30 | ## Implementation Validation 31 | 32 | ### Coordinate Checks 33 | 34 | Standard Parameterization: 35 | 36 | SP 37 | 38 | muTransfer: 39 | 40 | muP 41 | 42 | 43 | ### Learning Rate muTransfer 44 | 45 | **Tiny Shakespeare** | **OpenWebText** 46 | :-------------------------:|:-------------------------: 47 | mup-shakespeare | mup-owt 48 | 49 | 50 | ## Citation 51 | 52 | If ["The Practitioner’s Guide to the Maximal Update Parameterization"](https://www.cerebras.ai/blog/the-practitioners-guide-to-the-maximal-update-parameterization) or this repository was useful to you, please cite: 53 | ``` 54 | @misc{cerebras2024mupguide, 55 | author = {Dey, Nolan and Anthony, Quentin and Hestness, Joel}, 56 | title = {{The practitioner’s guide to the maximal update parameterization}}, 57 | month = September, 58 | year = 2024, 59 | howpublished = {\url{https://www.cerebras.ai/blog/the-practitioners-guide-to-the-maximal-update-parameterization}}, 60 | url = \url{https://www.cerebras.ai/blog/the-practitioners-guide-to-the-maximal-update-parameterization}, 61 | } 62 | ``` 63 | 64 | # nanoGPT (Original README) 65 | 66 | ![nanoGPT](assets/nanogpt.jpg) 67 | 68 | The simplest, fastest repository for training/finetuning medium-sized GPTs. It is a rewrite of [minGPT](https://github.com/karpathy/minGPT) that prioritizes teeth over education. Still under active development, but currently the file `train.py` reproduces GPT-2 (124M) on OpenWebText, running on a single 8XA100 40GB node in about 4 days of training. The code itself is plain and readable: `train.py` is a ~300-line boilerplate training loop and `model.py` a ~300-line GPT model definition, which can optionally load the GPT-2 weights from OpenAI. That's it. 69 | 70 | ![repro124m](assets/gpt2_124M_loss.png) 71 | 72 | Because the code is so simple, it is very easy to hack to your needs, train new models from scratch, or finetune pretrained checkpoints (e.g. biggest one currently available as a starting point would be the GPT-2 1.3B model from OpenAI). 73 | 74 | ## install 75 | 76 | ``` 77 | pip install torch numpy transformers datasets tiktoken wandb tqdm 78 | ``` 79 | 80 | Dependencies: 81 | 82 | - [pytorch](https://pytorch.org) <3 83 | - [numpy](https://numpy.org/install/) <3 84 | - `transformers` for huggingface transformers <3 (to load GPT-2 checkpoints) 85 | - `datasets` for huggingface datasets <3 (if you want to download + preprocess OpenWebText) 86 | - `tiktoken` for OpenAI's fast BPE code <3 87 | - `wandb` for optional logging <3 88 | - `tqdm` for progress bars <3 89 | 90 | ## quick start 91 | 92 | If you are not a deep learning professional and you just want to feel the magic and get your feet wet, the fastest way to get started is to train a character-level GPT on the works of Shakespeare. First, we download it as a single (1MB) file and turn it from raw text into one large stream of integers: 93 | 94 | ```sh 95 | python data/shakespeare_char/prepare.py 96 | ``` 97 | 98 | This creates a `train.bin` and `val.bin` in that data directory. Now it is time to train your GPT. 
The size of it very much depends on the computational resources of your system: 99 | 100 | **I have a GPU**. Great, we can quickly train a baby GPT with the settings provided in the [config/train_shakespeare_char.py](config/train_shakespeare_char.py) config file: 101 | 102 | ```sh 103 | python train.py config/train_shakespeare_char.py 104 | ``` 105 | 106 | If you peek inside it, you'll see that we're training a GPT with a context size of up to 256 characters, 384 feature channels, and it is a 6-layer Transformer with 6 heads in each layer. On one A100 GPU this training run takes about 3 minutes and the best validation loss is 1.4697. Based on the configuration, the model checkpoints are being written into the `--out_dir` directory `out-shakespeare-char`. So once the training finishes we can sample from the best model by pointing the sampling script at this directory: 107 | 108 | ```sh 109 | python sample.py --out_dir=out-shakespeare-char 110 | ``` 111 | 112 | This generates a few samples, for example: 113 | 114 | ``` 115 | ANGELO: 116 | And cowards it be strawn to my bed, 117 | And thrust the gates of my threats, 118 | Because he that ale away, and hang'd 119 | An one with him. 120 | 121 | DUKE VINCENTIO: 122 | I thank your eyes against it. 123 | 124 | DUKE VINCENTIO: 125 | Then will answer him to save the malm: 126 | And what have you tyrannous shall do this? 127 | 128 | DUKE VINCENTIO: 129 | If you have done evils of all disposition 130 | To end his power, the day of thrust for a common men 131 | That I leave, to fight with over-liking 132 | Hasting in a roseman. 133 | ``` 134 | 135 | lol `¯\_(ツ)_/¯`. Not bad for a character-level model after 3 minutes of training on a GPU. Better results are quite likely obtainable by instead finetuning a pretrained GPT-2 model on this dataset (see finetuning section later). 136 | 137 | **I only have a macbook** (or other cheap computer). No worries, we can still train a GPT but we want to dial things down a notch. I recommend getting the bleeding edge PyTorch nightly ([select it here](https://pytorch.org/get-started/locally/) when installing) as it is currently quite likely to make your code more efficient. But even without it, a simple train run could look as follows: 138 | 139 | ```sh 140 | python train.py config/train_shakespeare_char.py --device=cpu --compile=False --eval_iters=20 --log_interval=1 --block_size=64 --batch_size=12 --n_layer=4 --n_head=4 --n_embd=128 --max_iters=2000 --lr_decay_iters=2000 --dropout=0.0 141 | ``` 142 | 143 | Here, since we are running on CPU instead of GPU we must set both `--device=cpu` and also turn off PyTorch 2.0 compile with `--compile=False`. Then when we evaluate we get a bit more noisy but faster estimate (`--eval_iters=20`, down from 200), our context size is only 64 characters instead of 256, and the batch size only 12 examples per iteration, not 64. We'll also use a much smaller Transformer (4 layers, 4 heads, 128 embedding size), and decrease the number of iterations to 2000 (and correspondingly usually decay the learning rate to around max_iters with `--lr_decay_iters`). Because our network is so small we also ease down on regularization (`--dropout=0.0`). 
This still runs in about ~3 minutes, but gets us a loss of only 1.88 and therefore also worse samples, but it's still good fun: 144 | 145 | ```sh 146 | python sample.py --out_dir=out-shakespeare-char --device=cpu 147 | ``` 148 | Generates samples like this: 149 | 150 | ``` 151 | GLEORKEN VINGHARD III: 152 | Whell's the couse, the came light gacks, 153 | And the for mought you in Aut fries the not high shee 154 | bot thou the sought bechive in that to doth groan you, 155 | No relving thee post mose the wear 156 | ``` 157 | 158 | Not bad for ~3 minutes on a CPU, for a hint of the right character gestalt. If you're willing to wait longer, feel free to tune the hyperparameters, increase the size of the network, the context length (`--block_size`), the length of training, etc. 159 | 160 | Finally, on Apple Silicon Macbooks and with a recent PyTorch version make sure to add `--device=mps` (short for "Metal Performance Shaders"); PyTorch then uses the on-chip GPU that can *significantly* accelerate training (2-3X) and allow you to use larger networks. See [Issue 28](https://github.com/karpathy/nanoGPT/issues/28) for more. 161 | 162 | ## reproducing GPT-2 163 | 164 | A more serious deep learning professional may be more interested in reproducing GPT-2 results. So here we go - we first tokenize the dataset, in this case the [OpenWebText](https://openwebtext2.readthedocs.io/en/latest/), an open reproduction of OpenAI's (private) WebText: 165 | 166 | ```sh 167 | python data/openwebtext/prepare.py 168 | ``` 169 | 170 | This downloads and tokenizes the [OpenWebText](https://huggingface.co/datasets/openwebtext) dataset. It will create a `train.bin` and `val.bin` which holds the GPT2 BPE token ids in one sequence, stored as raw uint16 bytes. Then we're ready to kick off training. To reproduce GPT-2 (124M) you'll want at least an 8X A100 40GB node and run: 171 | 172 | ```sh 173 | torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py 174 | ``` 175 | 176 | This will run for about 4 days using PyTorch Distributed Data Parallel (DDP) and go down to loss of ~2.85. Now, a GPT-2 model just evaluated on OWT gets a val loss of about 3.11, but if you finetune it it will come down to ~2.85 territory (due to an apparent domain gap), making the two models ~match. 177 | 178 | If you're in a cluster environment and you are blessed with multiple GPU nodes you can make GPU go brrrr e.g. across 2 nodes like: 179 | 180 | ```sh 181 | # Run on the first (master) node with example IP 123.456.123.456: 182 | torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py 183 | # Run on the worker node: 184 | torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py 185 | ``` 186 | 187 | It is a good idea to benchmark your interconnect (e.g. iperf3). In particular, if you don't have Infiniband then also prepend `NCCL_IB_DISABLE=1` to the above launches. Your multinode training will work, but most likely _crawl_. By default checkpoints are periodically written to the `--out_dir`. We can sample from the model by simply `python sample.py`. 188 | 189 | Finally, to train on a single GPU simply run the `python train.py` script. Have a look at all of its args, the script tries to be very readable, hackable and transparent. You'll most likely want to tune a number of those variables depending on your needs. 
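A minimal sketch of the interconnect check mentioned above, using iperf3 (the IP is the same placeholder as in the launch commands, not a real address):

```sh
# on the master node: start an iperf3 server
iperf3 -s
# on the worker node: measure bandwidth to the master node
iperf3 -c 123.456.123.456
```

If the measured bandwidth is low and you don't have Infiniband, the `NCCL_IB_DISABLE=1` launches above will be network-bound, so expect gradient synchronization to dominate iteration time.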
190 | 191 | ## baselines 192 | 193 | OpenAI GPT-2 checkpoints allow us to get some baselines in place for openwebtext. We can get the numbers as follows: 194 | 195 | ```sh 196 | $ python train.py config/eval_gpt2.py 197 | $ python train.py config/eval_gpt2_medium.py 198 | $ python train.py config/eval_gpt2_large.py 199 | $ python train.py config/eval_gpt2_xl.py 200 | ``` 201 | 202 | and observe the following losses on train and val: 203 | 204 | | model | params | train loss | val loss | 205 | | ------| ------ | ---------- | -------- | 206 | | gpt2 | 124M | 3.11 | 3.12 | 207 | | gpt2-medium | 350M | 2.85 | 2.84 | 208 | | gpt2-large | 774M | 2.66 | 2.67 | 209 | | gpt2-xl | 1558M | 2.56 | 2.54 | 210 | 211 | However, we have to note that GPT-2 was trained on (closed, never released) WebText, while OpenWebText is just a best-effort open reproduction of this dataset. This means there is a dataset domain gap. Indeed, taking the GPT-2 (124M) checkpoint and finetuning on OWT directly for a while reaches loss down to ~2.85. This then becomes the more appropriate baseline w.r.t. reproduction. 212 | 213 | ## finetuning 214 | 215 | Finetuning is no different than training, we just make sure to initialize from a pretrained model and train with a smaller learning rate. For an example of how to finetune a GPT on new text go to `data/shakespeare` and run `prepare.py` to download the tiny shakespeare dataset and render it into a `train.bin` and `val.bin`, using the OpenAI BPE tokenizer from GPT-2. Unlike OpenWebText this will run in seconds. Finetuning can take very little time, e.g. on a single GPU just a few minutes. Run an example finetuning like: 216 | 217 | ```sh 218 | python train.py config/finetune_shakespeare.py 219 | ``` 220 | 221 | This will load the config parameter overrides in `config/finetune_shakespeare.py` (I didn't tune them much though). Basically, we initialize from a GPT2 checkpoint with `init_from` and train as normal, except shorter and with a small learning rate. If you're running out of memory try decreasing the model size (they are `{'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}`) or possibly decreasing the `block_size` (context length). The best checkpoint (lowest validation loss) will be in the `out_dir` directory, e.g. in `out-shakespeare` by default, per the config file. You can then run the code in `sample.py --out_dir=out-shakespeare`: 222 | 223 | ``` 224 | THEODORE: 225 | Thou shalt sell me to the highest bidder: if I die, 226 | I sell thee to the first; if I go mad, 227 | I sell thee to the second; if I 228 | lie, I sell thee to the third; if I slay, 229 | I sell thee to the fourth: so buy or sell, 230 | I tell thee again, thou shalt not sell my 231 | possession. 232 | 233 | JULIET: 234 | And if thou steal, thou shalt not sell thyself. 235 | 236 | THEODORE: 237 | I do not steal; I sell the stolen goods. 238 | 239 | THEODORE: 240 | Thou know'st not what thou sell'st; thou, a woman, 241 | Thou art ever a victim, a thing of no worth: 242 | Thou hast no right, no right, but to be sold. 243 | ``` 244 | 245 | Whoa there, GPT, entering some dark place over there. I didn't really tune the hyperparameters in the config too much, feel free to try! 246 | 247 | ## sampling / inference 248 | 249 | Use the script `sample.py` to sample either from pre-trained GPT-2 models released by OpenAI, or from a model you trained yourself. 
For example, here is a way to sample from the largest available `gpt2-xl` model: 250 | 251 | ```sh 252 | python sample.py \ 253 | --init_from=gpt2-xl \ 254 | --start="What is the answer to life, the universe, and everything?" \ 255 | --num_samples=5 --max_new_tokens=100 256 | ``` 257 | 258 | If you'd like to sample from a model you trained, use the `--out_dir` to point the code appropriately. You can also prompt the model with some text from a file, e.g. ```python sample.py --start=FILE:prompt.txt```. 259 | 260 | ## efficiency notes 261 | 262 | For simple model benchmarking and profiling, `bench.py` might be useful. It's identical to what happens in the meat of the training loop of `train.py`, but omits much of the other complexities. 263 | 264 | Note that the code by default uses [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/). At the time of writing (Dec 29, 2022) this makes `torch.compile()` available in the nightly release. The improvement from the one line of code is noticeable, e.g. cutting down iteration time from ~250ms / iter to 135ms / iter. Nice work PyTorch team! 265 | 266 | ## todos 267 | 268 | - Investigate and add FSDP instead of DDP 269 | - Eval zero-shot perplexities on standard evals (e.g. LAMBADA? HELM? etc.) 270 | - Finetune the finetuning script, I think the hyperparams are not great 271 | - Schedule for linear batch size increase during training 272 | - Incorporate other embeddings (rotary, alibi) 273 | - Separate out the optim buffers from model params in checkpoints I think 274 | - Additional logging around network health (e.g. gradient clip events, magnitudes) 275 | - Few more investigations around better init etc. 276 | 277 | ## troubleshooting 278 | 279 | Note that by default this repo uses PyTorch 2.0 (i.e. `torch.compile`). This is fairly new and experimental, and not yet available on all platforms (e.g. Windows). If you're running into related error messages try to disable this by adding `--compile=False` flag. This will slow down the code but at least it will run. 280 | 281 | For some context on this repository, GPT, and language modeling it might be helpful to watch my [Zero To Hero series](https://karpathy.ai/zero-to-hero.html). Specifically, the [GPT video](https://www.youtube.com/watch?v=kCc8FmEb1nY) is popular if you have some prior language modeling context. 282 | 283 | For more questions/discussions feel free to stop by **#nanoGPT** on Discord: 284 | 285 | [![](https://dcbadge.vercel.app/api/server/3zy8kqD9Cp?compact=true&style=flat)](https://discord.gg/3zy8kqD9Cp) 286 | 287 | ## acknowledgements 288 | 289 | All nanoGPT experiments are powered by GPUs on [Lambda labs](https://lambdalabs.com), my favorite Cloud GPU provider. Thank you Lambda labs for sponsoring nanoGPT! 290 | -------------------------------------------------------------------------------- /examples/shallow-nanogpt/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This training script can be run both on a single gpu in debug mode, 3 | and also in a larger training run with distributed data parallel (ddp). 
4 | 5 | To run on a single GPU, example: 6 | $ python train.py --batch_size=32 --compile=False 7 | 8 | To run with DDP on 4 gpus on 1 node, example: 9 | $ torchrun --standalone --nproc_per_node=4 train.py 10 | 11 | To run with DDP on 4 gpus across 2 nodes, example: 12 | - Run on the first (master) node with example IP 123.456.123.456: 13 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py 14 | - Run on the worker node: 15 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py 16 | (If your cluster does not have Infiniband interconnect prepend NCCL_IB_DISABLE=1) 17 | """ 18 | 19 | import os 20 | import time 21 | import math 22 | import pickle 23 | from contextlib import nullcontext 24 | from functools import partial 25 | 26 | import numpy as np 27 | import torch 28 | from torch.nn.parallel import DistributedDataParallel as DDP 29 | from torch.distributed import init_process_group, destroy_process_group 30 | 31 | from model import GPTConfig, GPT 32 | 33 | # ----------------------------------------------------------------------------- 34 | # default config values designed to train a gpt2 (124M) on OpenWebText 35 | # I/O 36 | out_dir = 'out' 37 | eval_interval = 2000 38 | log_interval = 1 39 | eval_on_end = False 40 | eval_iters = 200 41 | eval_only = False # if True, script exits right after the first eval 42 | skip_val_loss = False # If True, will only measure train loss 43 | always_save_checkpoint = True # if True, always save a checkpoint after each eval 44 | never_save_checkpoint = False # if True, never save a checkpoint 45 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*' 46 | # wandb logging 47 | wandb_log = False # disabled by default 48 | wandb_project = 'owt' 49 | wandb_run_name = 'gpt2' # 'run' + str(time.time()) 50 | # csv logging 51 | csv_log = False # If enabled, logs stats to a csv file 52 | # data 53 | dataset = 'openwebtext' 54 | gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes 55 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size 56 | block_size = 1024 57 | # model 58 | n_layer = 12 59 | n_head = 12 60 | n_embd = 768 61 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 62 | bias = False # do we use bias inside LayerNorm and Linear layers? 63 | init_std = 0.02 # Initialization standard deviation for weights 64 | # adamw optimizer 65 | learning_rate = 6e-4 # max learning rate 66 | max_iters = 600000 # total number of training iterations 67 | weight_decay = 1e-1 68 | beta1 = 0.9 69 | beta2 = 0.95 70 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 71 | # learning rate decay settings 72 | decay_lr = True # whether to decay the learning rate 73 | warmup_iters = 2000 # how many steps to warm up for 74 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla 75 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 76 | scion_enabled = False 77 | scion_first_layer = 'Sign' 78 | scion_unconstrained = False 79 | scion_mode = 'Mixed' 80 | # mup settings 81 | mup_enabled = False # Whether to use muP. 
If False then all other mup variables are ignored 82 | mup_disable_attention_scaling = False # Uses 1/sqrt(d_head) attn scaling instead of 1/d_head (Only needed for the step-by-step coord check in the blog) 83 | mup_disable_hidden_lr_scaling = False # Disables muP hidden LR adjustment (Only needed for the step-by-step coord check in the blog) 84 | mup_width_multiplier = 1.0 # mup_width_multiplier = width / base_width where base_width is typically 256 85 | mup_input_alpha = 1.0 # Optional tunable multiplier applied to input embedding forward pass output 86 | mup_output_alpha = 1.0 # Optional tunable multiplier applied to output unembedding forward pass output 87 | mup_enable_coord_check_logging = False # If True will track the output.abs().mean() of various layers throughout training 88 | # seed 89 | seed = 1337 90 | # DDP settings 91 | backend = 'nccl' # 'nccl', 'gloo', etc. 92 | # system 93 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 94 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 95 | compile = True # use PyTorch 2.0 to compile the model to be faster 96 | # ----------------------------------------------------------------------------- 97 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 98 | exec(open('configurator.py').read()) # overrides from command line or config file 99 | config = {k: globals()[k] for k in config_keys} # will be useful for logging 100 | # ----------------------------------------------------------------------------- 101 | 102 | assert not (never_save_checkpoint and always_save_checkpoint) 103 | 104 | # various inits, derived attributes, I/O setup 105 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? 106 | if ddp: 107 | init_process_group(backend=backend) 108 | ddp_rank = int(os.environ['RANK']) 109 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 110 | ddp_world_size = int(os.environ['WORLD_SIZE']) 111 | device = f'cuda:{ddp_local_rank}' 112 | torch.cuda.set_device(device) 113 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 
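    # Worked example with the defaults above (a hypothetical 8-GPU run, not a measured one):
    # the requested gradient_accumulation_steps = 5 * 8 = 40 is divided by ddp_world_size = 8
    # a few lines below, so each rank performs 5 micro-steps per iteration and
    # tokens_per_iter = 5 * 8 * 12 * 1024 = 491,520 tokens.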
114 | seed_offset = ddp_rank # each process gets a different seed 115 | # world_size number of processes will be training simultaneously, so we can scale 116 | # down the desired gradient accumulation iterations per process proportionally 117 | assert gradient_accumulation_steps % ddp_world_size == 0 118 | gradient_accumulation_steps //= ddp_world_size 119 | else: 120 | # if not ddp, we are running on a single gpu, and one process 121 | master_process = True 122 | seed_offset = 0 123 | ddp_world_size = 1 124 | tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size 125 | print(f"tokens per iteration will be: {tokens_per_iter:,}") 126 | 127 | if master_process: 128 | os.makedirs(out_dir, exist_ok=True) 129 | torch.manual_seed(seed + seed_offset) 130 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 131 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 132 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 133 | # note: float16 data type will automatically use a GradScaler 134 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 135 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 136 | 137 | # poor man's data loader 138 | data_dir = os.path.join('data', dataset) 139 | def get_batch(split): 140 | # We recreate np.memmap every batch to avoid a memory leak, as per 141 | # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122 142 | if split == 'train': 143 | data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 144 | else: 145 | data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') 146 | ix = torch.randint(len(data) - block_size, (batch_size,)) 147 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 148 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) 149 | if device_type == 'cuda': 150 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 151 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 152 | else: 153 | x, y = x.to(device), y.to(device) 154 | return x, y 155 | 156 | # init these up here, can override if init_from='resume' (i.e. 
from a checkpoint) 157 | iter_num = 0 158 | best_val_loss = 1e9 159 | 160 | # attempt to derive vocab_size from the dataset 161 | meta_path = os.path.join(data_dir, 'meta.pkl') 162 | meta_vocab_size = None 163 | if os.path.exists(meta_path): 164 | with open(meta_path, 'rb') as f: 165 | meta = pickle.load(f) 166 | meta_vocab_size = meta['vocab_size'] 167 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})") 168 | 169 | # model init 170 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, 171 | bias=bias, vocab_size=None, dropout=dropout, 172 | scion_enabled=scion_enabled, 173 | scion_first_layer=scion_first_layer, 174 | scion_unconstrained=scion_unconstrained, 175 | scion_mode=scion_mode, 176 | mup_enabled=mup_enabled, 177 | mup_disable_attention_scaling=mup_disable_attention_scaling, 178 | mup_disable_hidden_lr_scaling=mup_disable_hidden_lr_scaling, 179 | mup_width_multiplier=mup_width_multiplier, mup_input_alpha=mup_input_alpha, 180 | mup_output_alpha=mup_output_alpha) # start with model_args from command line 181 | 182 | if init_from == 'scratch': 183 | # init a new model from scratch 184 | print("Initializing a new model from scratch") 185 | # determine the vocab size we'll use for from-scratch training 186 | if meta_vocab_size is None: 187 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") 188 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304 189 | gptconf = GPTConfig(**model_args) 190 | model = GPT(gptconf) 191 | elif init_from == 'resume': 192 | print(f"Resuming training from {out_dir}") 193 | # resume training from a checkpoint. 194 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 195 | checkpoint = torch.load(ckpt_path, map_location=device) 196 | checkpoint_model_args = checkpoint['model_args'] 197 | # force these config attributes to be equal otherwise we can't even resume training 198 | # the rest of the attributes (e.g. dropout) can stay as desired from command line 199 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 200 | model_args[k] = checkpoint_model_args[k] 201 | # create the model 202 | gptconf = GPTConfig(**model_args) 203 | model = GPT(gptconf) 204 | state_dict = checkpoint['model'] 205 | # fix the keys of the state dictionary :( 206 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 207 | unwanted_prefix = '_orig_mod.' 208 | for k,v in list(state_dict.items()): 209 | if k.startswith(unwanted_prefix): 210 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 211 | model.load_state_dict(state_dict) 212 | iter_num = checkpoint['iter_num'] 213 | best_val_loss = checkpoint['best_val_loss'] 214 | elif init_from.startswith('gpt2'): 215 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}") 216 | # initialize from OpenAI GPT-2 weights 217 | override_args = dict(dropout=dropout) 218 | model = GPT.from_pretrained(init_from, override_args) 219 | # read off the created config params, so we can store them into checkpoint correctly 220 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 221 | model_args[k] = getattr(model.config, k) 222 | # crop down the model block size if desired, using model surgery 223 | if block_size < model.config.block_size: 224 | model.crop_block_size(block_size) 225 | model_args['block_size'] = block_size # so that the checkpoint will have the right value 226 | model.to(device) 227 | 228 | # initialize a GradScaler. 
If enabled=False scaler is a no-op 229 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16')) 230 | 231 | # optimizer 232 | optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type) 233 | if init_from == 'resume': 234 | optimizer.load_state_dict(checkpoint['optimizer']) 235 | checkpoint = None # free up memory 236 | 237 | # compile the model 238 | if compile: 239 | print("compiling the model... (takes a ~minute)") 240 | unoptimized_model = model 241 | model = torch.compile(model) # requires PyTorch 2.0 242 | 243 | # wrap model into DDP container 244 | if ddp: 245 | model = DDP(model, device_ids=[ddp_local_rank]) 246 | 247 | # helps estimate an arbitrarily accurate loss over either split using many batches 248 | @torch.no_grad() 249 | def estimate_loss(): 250 | out = {} 251 | model.eval() 252 | splits = ['train'] if skip_val_loss else ['train', 'val'] 253 | for split in splits: 254 | losses = torch.zeros(eval_iters) 255 | for k in range(eval_iters): 256 | X, Y = get_batch(split) 257 | with ctx: 258 | logits, loss = model(X, Y) 259 | losses[k] = loss.item() 260 | out[split] = losses.mean().item() 261 | if skip_val_loss: 262 | out['val'] = -1 263 | model.train() 264 | return out 265 | 266 | # learning rate decay scheduler (cosine with warmup) 267 | # def get_lr(it): 268 | # # 1) linear warmup for warmup_iters steps 269 | # if it < warmup_iters: 270 | # return learning_rate * it / warmup_iters 271 | # # 2) if it > lr_decay_iters, return min learning rate 272 | # if it > lr_decay_iters: 273 | # return min_lr 274 | # # 3) in between, use cosine decay down to min learning rate 275 | # decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 276 | # assert 0 <= decay_ratio <= 1 277 | # coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 278 | # return min_lr + coeff * (learning_rate - min_lr) 279 | 280 | # FW_CHANGE: Use linear decay schedule 281 | def get_lr(it): 282 | # 1) linear warmup for warmup_iters steps 283 | if it < warmup_iters: 284 | return learning_rate * it / warmup_iters 285 | # 2) if it > lr_decay_iters, return min learning rate 286 | if it > lr_decay_iters: 287 | return min_lr 288 | # 3) in between, use linear decay down to min learning rate 289 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 290 | assert 0 <= decay_ratio <= 1 291 | return learning_rate - decay_ratio * (learning_rate - min_lr) 292 | 293 | 294 | # logging 295 | if master_process: 296 | if wandb_log: 297 | import wandb 298 | wandb_run = wandb.init(project=wandb_project, name=wandb_run_name, config=config) 299 | if csv_log: 300 | from csv_logging import CSVLogWrapper 301 | def log(log_dict): 302 | pass 303 | csv_logger = CSVLogWrapper(log, config=config, out_dir=out_dir) 304 | 305 | # training loop 306 | X, Y = get_batch('train') # fetch the very first batch 307 | t0 = time.time() 308 | local_iter_num = 0 # number of iterations in the lifetime of this process 309 | raw_model = model.module if ddp else model # unwrap DDP container if needed 310 | running_mfu = -1.0 311 | coord_check_dict = None 312 | while True: 313 | 314 | # determine and set the learning rate for this iteration 315 | lr = get_lr(iter_num) if decay_lr else learning_rate 316 | for param_group in optimizer.param_groups: 317 | param_group['lr'] = lr * param_group.get('lr_scale', 1.0) 318 | 319 | # evaluate the loss on train/val sets and write checkpoints 320 | # FW_CHANGE: allow only evaluation at the end (so that we have budget to compute a better 
estimate) 321 | do_eval = (not eval_on_end and iter_num % eval_interval == 0) or (eval_on_end and iter_num >= max_iters) 322 | if do_eval and master_process: 323 | losses = estimate_loss() 324 | if np.isnan(losses['train']): 325 | raise Exception('NaN loss') 326 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 327 | log_dict = { 328 | "iter": iter_num, 329 | "train/loss": losses['train'], 330 | "val/loss": losses['val'], 331 | "lr": lr, 332 | "mfu": running_mfu*100, # convert to percentage 333 | } 334 | if mup_enable_coord_check_logging and coord_check_dict is not None: 335 | for key in coord_check_dict: 336 | log_dict[key + '_act_abs_mean'] = np.mean(coord_check_dict[key]) 337 | if wandb_log: 338 | wandb_run.log(log_dict) 339 | if csv_log: 340 | csv_logger.log(log_dict) 341 | csv_logger.step() 342 | if (not never_save_checkpoint) and (losses['val'] < best_val_loss or always_save_checkpoint): 343 | best_val_loss = losses['val'] 344 | if iter_num > 0: 345 | checkpoint = { 346 | 'model': raw_model.state_dict(), 347 | 'optimizer': optimizer.state_dict(), 348 | 'model_args': model_args, 349 | 'iter_num': iter_num, 350 | 'best_val_loss': best_val_loss, 351 | 'config': config, 352 | } 353 | print(f"saving checkpoint to {out_dir}") 354 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) 355 | if iter_num == 0 and eval_only: 356 | break 357 | 358 | if mup_enable_coord_check_logging: 359 | coord_check_dict = { 360 | 'token_embedding': [], 361 | 'attn': [], 362 | 'mlp': [], 363 | 'lm_head': [], 364 | } 365 | def hook(module, input, output, key): 366 | with torch.no_grad(): 367 | coord_check_dict[key].append(output.abs().mean().item()) 368 | coord_check_handles = [] 369 | for module_name, module in model.named_modules(): 370 | if module_name == 'transformer.wte': 371 | coord_check_handles.append(module.register_forward_hook(partial(hook, key='token_embedding'))) 372 | elif module_name.endswith('.attn'): 373 | coord_check_handles.append(module.register_forward_hook(partial(hook, key='attn'))) 374 | elif module_name.endswith('.mlp'): 375 | coord_check_handles.append(module.register_forward_hook(partial(hook, key='mlp'))) 376 | elif module_name == 'lm_head': 377 | coord_check_handles.append(module.register_forward_hook(partial(hook, key='lm_head'))) 378 | else: 379 | coord_check_dict = None 380 | 381 | # forward backward update, with optional gradient accumulation to simulate larger batch size 382 | # and using the GradScaler if data type is float16 383 | for micro_step in range(gradient_accumulation_steps): 384 | if ddp: 385 | # in DDP training we only need to sync gradients at the last micro step. 
386 | # the official way to do this is with model.no_sync() context manager, but 387 | # I really dislike that this bloats the code and forces us to repeat code 388 | # looking at the source of that context manager, it just toggles this variable 389 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) 390 | with ctx: 391 | logits, loss = model(X, Y) 392 | loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation 393 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 394 | X, Y = get_batch('train') 395 | # backward pass, with gradient scaling if training in fp16 396 | scaler.scale(loss).backward() 397 | # clip the gradient 398 | if grad_clip != 0.0: 399 | scaler.unscale_(optimizer) 400 | torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 401 | # step the optimizer and scaler if training in fp16 402 | scaler.step(optimizer) 403 | scaler.update() 404 | # flush the gradients as soon as we can, no need for this memory anymore 405 | optimizer.zero_grad(set_to_none=True) 406 | 407 | # timing and logging 408 | t1 = time.time() 409 | dt = t1 - t0 410 | t0 = t1 411 | if iter_num % log_interval == 0 and master_process: 412 | # get loss as float. note: this is a CPU-GPU sync point 413 | # scale up to undo the division above, approximating the true total loss (exact would have been a sum) 414 | lossf = loss.item() * gradient_accumulation_steps 415 | if local_iter_num >= 5: # let the training loop settle a bit 416 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt) 417 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu 418 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%") 419 | iter_num += 1 420 | local_iter_num += 1 421 | 422 | if mup_enable_coord_check_logging: 423 | for handle in coord_check_handles: 424 | handle.remove() 425 | 426 | # termination conditions 427 | if iter_num > max_iters: 428 | break 429 | 430 | if ddp: 431 | destroy_process_group() 432 | -------------------------------------------------------------------------------- /examples/airbench/airbench_scion.py: -------------------------------------------------------------------------------- 1 | """ 2 | airbench94_spectral.py 3 | Runs in 2.67 seconds on a 400W NVIDIA A100 4 | Attains 94.04 mean accuracy (n=200 trials) 5 | """ 6 | 7 | ############################################# 8 | # Setup/Hyperparameters # 9 | ############################################# 10 | 11 | import os 12 | import sys 13 | import uuid 14 | import math 15 | from math import ceil 16 | 17 | import numpy as np 18 | import torch 19 | from torch import nn 20 | import torch.nn.functional as F 21 | import torchvision 22 | import torchvision.transforms as T 23 | 24 | import wandb 25 | from scion import Scion 26 | 27 | # ADDED 28 | import torch._dynamo 29 | torch._dynamo.config.suppress_errors = True 30 | 31 | torch.backends.cudnn.benchmark = True 32 | 33 | # We express the main training hyperparameters (batch size, learning rate, momentum, and weight decay) 34 | # in decoupled form, so that each one can be tuned independently. This accomplishes the following: 35 | # * Assuming time-constant gradients, the average step size is decoupled from everything but the lr. 36 | # * The size of the weight decay update is decoupled from everything but the wd. 
37 | # In constrast, normally when we increase the (Nesterov) momentum, this also scales up the step size 38 | # proportionally to 1 + 1 / (1 - momentum), meaning we cannot change momentum without having to re-tune 39 | # the learning rate. Similarly, normally when we increase the learning rate this also increases the size 40 | # of the weight decay, requiring a proportional decrease in the wd to maintain the same decay strength. 41 | # 42 | # The practical impact is that hyperparameter tuning is faster, since this parametrization allows each 43 | # one to be tuned independently. See https://myrtle.ai/learn/how-to-train-your-resnet-5-hyperparameters/. 44 | 45 | hyp = { 46 | 'meta': { 47 | 'runs': 5, 48 | }, 49 | 'opt': { 50 | 'svd_backend': 'newton', 51 | 'train_epochs': 8, 52 | 'batch_size': 2000, 53 | 'lr': 6.5, # learning rate per 1024 examples 54 | 'momentum': 0.85, 55 | 'weight_decay': 0.015, # weight decay per 1024 examples (decoupled from learning rate) 56 | 'bias_scaler': 64.0, # scales up learning rate (but not weight decay) for BatchNorm biases 57 | 'label_smoothing': 0.2, 58 | 'whiten_bias_epochs': 0, # how many epochs to train the whitening layer bias before freezing 59 | }, 60 | 'aug': { 61 | 'flip': True, 62 | 'translate': 2, 63 | }, 64 | 'net': { 65 | 'widths': { 66 | 'block1': 64, 67 | 'block2': 256, 68 | 'block3': 256, 69 | }, 70 | 'batchnorm_momentum': 0.6, 71 | 'scaling_factor': 1/9, 72 | 'tta_level': 2, # the level of test-time augmentation: 0=none, 1=mirror, 2=mirror+translate 73 | }, 74 | } 75 | 76 | 77 | ############################################# 78 | # DataLoader # 79 | ############################################# 80 | 81 | CIFAR_MEAN = torch.tensor((0.4914, 0.4822, 0.4465)) 82 | CIFAR_STD = torch.tensor((0.2470, 0.2435, 0.2616)) 83 | 84 | def batch_flip_lr(inputs): 85 | flip_mask = (torch.rand(len(inputs), device=inputs.device) < 0.5).view(-1, 1, 1, 1) 86 | return torch.where(flip_mask, inputs.flip(-1), inputs) 87 | 88 | def batch_crop(images, crop_size): 89 | r = (images.size(-1) - crop_size)//2 90 | shifts = torch.randint(-r, r+1, size=(len(images), 2), device=images.device) 91 | images_out = torch.empty((len(images), 3, crop_size, crop_size), device=images.device, dtype=images.dtype) 92 | # The two cropping methods in this if-else produce equivalent results, but the second is faster for r > 2. 
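    # Branch 1 copies each image once but needs (2r+1)^2 masked assignments (one per
    # (sy, sx) shift pair); branch 2 crops rows and then columns in two passes, needing
    # only 2*(2r+1) masked assignments at the cost of an intermediate buffer.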
93 | if r <= 2: 94 | for sy in range(-r, r+1): 95 | for sx in range(-r, r+1): 96 | mask = (shifts[:, 0] == sy) & (shifts[:, 1] == sx) 97 | images_out[mask] = images[mask, :, r+sy:r+sy+crop_size, r+sx:r+sx+crop_size] 98 | else: 99 | images_tmp = torch.empty((len(images), 3, crop_size, crop_size+2*r), device=images.device, dtype=images.dtype) 100 | for s in range(-r, r+1): 101 | mask = (shifts[:, 0] == s) 102 | images_tmp[mask] = images[mask, :, r+s:r+s+crop_size, :] 103 | for s in range(-r, r+1): 104 | mask = (shifts[:, 1] == s) 105 | images_out[mask] = images_tmp[mask, :, :, r+s:r+s+crop_size] 106 | return images_out 107 | 108 | class CifarLoader: 109 | 110 | def __init__(self, path, train=True, batch_size=500, aug=None, drop_last=None, shuffle=None, gpu=0): 111 | data_path = os.path.join(path, 'train.pt' if train else 'test.pt') 112 | if not os.path.exists(data_path): 113 | dset = torchvision.datasets.CIFAR10(path, download=True, train=train) 114 | images = torch.tensor(dset.data) 115 | labels = torch.tensor(dset.targets) 116 | torch.save({'images': images, 'labels': labels, 'classes': dset.classes}, data_path) 117 | 118 | data = torch.load(data_path, map_location=torch.device(gpu)) 119 | self.images, self.labels, self.classes = data['images'], data['labels'], data['classes'] 120 | # It's faster to load+process uint8 data than to load preprocessed fp16 data 121 | self.images = (self.images.half() / 255).permute(0, 3, 1, 2).to(memory_format=torch.channels_last) 122 | 123 | self.normalize = T.Normalize(CIFAR_MEAN, CIFAR_STD) 124 | self.proc_images = {} # Saved results of image processing to be done on the first epoch 125 | self.epoch = 0 126 | 127 | self.aug = aug or {} 128 | for k in self.aug.keys(): 129 | assert k in ['flip', 'translate'], 'Unrecognized key: %s' % k 130 | 131 | self.batch_size = batch_size 132 | self.drop_last = train if drop_last is None else drop_last 133 | self.shuffle = train if shuffle is None else shuffle 134 | 135 | def __len__(self): 136 | return len(self.images)//self.batch_size if self.drop_last else ceil(len(self.images)/self.batch_size) 137 | 138 | def __iter__(self): 139 | 140 | if self.epoch == 0: 141 | images = self.proc_images['norm'] = self.normalize(self.images) 142 | # Pre-flip images in order to do every-other epoch flipping scheme 143 | if self.aug.get('flip', False): 144 | images = self.proc_images['flip'] = batch_flip_lr(images) 145 | # Pre-pad images to save time when doing random translation 146 | pad = self.aug.get('translate', 0) 147 | if pad > 0: 148 | self.proc_images['pad'] = F.pad(images, (pad,)*4, 'reflect') 149 | 150 | if self.aug.get('translate', 0) > 0: 151 | images = batch_crop(self.proc_images['pad'], self.images.shape[-2]) 152 | elif self.aug.get('flip', False): 153 | images = self.proc_images['flip'] 154 | else: 155 | images = self.proc_images['norm'] 156 | # Flip all images together every other epoch. 
This increases diversity relative to random flipping 157 | if self.aug.get('flip', False): 158 | if self.epoch % 2 == 1: 159 | images = images.flip(-1) 160 | 161 | self.epoch += 1 162 | 163 | indices = (torch.randperm if self.shuffle else torch.arange)(len(images), device=images.device) 164 | for i in range(len(self)): 165 | idxs = indices[i*self.batch_size:(i+1)*self.batch_size] 166 | yield (images[idxs], self.labels[idxs]) 167 | 168 | ############################################# 169 | # Network Components # 170 | ############################################# 171 | 172 | class Flatten(nn.Module): 173 | def forward(self, x): 174 | return x.view(x.size(0), -1) 175 | 176 | class Mul(nn.Module): 177 | def __init__(self, scale): 178 | super().__init__() 179 | self.scale = scale 180 | def forward(self, x): 181 | return x * self.scale 182 | 183 | class BatchNorm(nn.BatchNorm2d): 184 | def __init__(self, num_features, momentum, eps=1e-12, 185 | weight=False, bias=True): 186 | super().__init__(num_features, eps=eps, momentum=1-momentum) 187 | self.weight.requires_grad = weight 188 | self.bias.requires_grad = bias 189 | # Note that PyTorch already initializes the weights to one and bias to zero 190 | 191 | class Conv(nn.Conv2d): 192 | def __init__(self, in_channels, out_channels, kernel_size=3, padding='same', bias=False): 193 | super().__init__(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias) 194 | 195 | def reset_parameters(self): 196 | super().reset_parameters() 197 | if self.bias is not None: 198 | self.bias.data.zero_() 199 | w = self.weight.data 200 | torch.nn.init.dirac_(w[:w.size(1)]) 201 | 202 | class ConvGroup(nn.Module): 203 | def __init__(self, channels_in, channels_out, batchnorm_momentum): 204 | super().__init__() 205 | self.conv1 = Conv(channels_in, channels_out) 206 | self.pool = nn.MaxPool2d(2) 207 | self.norm1 = BatchNorm(channels_out, batchnorm_momentum) 208 | self.conv2 = Conv(channels_out, channels_out) 209 | self.norm2 = BatchNorm(channels_out, batchnorm_momentum) 210 | self.activ = nn.GELU() 211 | 212 | def forward(self, x): 213 | x = self.conv1(x) 214 | x = self.pool(x) 215 | x = self.norm1(x) 216 | x = self.activ(x) 217 | x = self.conv2(x) 218 | x = self.norm2(x) 219 | x = self.activ(x) 220 | return x 221 | 222 | ############################################# 223 | # Network Definition # 224 | ############################################# 225 | 226 | def make_net(widths): 227 | batchnorm_momentum = hyp['net']['batchnorm_momentum'] 228 | whiten_kernel_size = 2 229 | whiten_width = 2 * 3 * whiten_kernel_size**2 230 | net = nn.Sequential( 231 | Conv(3, whiten_width, whiten_kernel_size, padding=0, bias=True), 232 | nn.GELU(), 233 | ConvGroup(whiten_width, widths['block1'], batchnorm_momentum), 234 | ConvGroup(widths['block1'], widths['block2'], batchnorm_momentum), 235 | ConvGroup(widths['block2'], widths['block3'], batchnorm_momentum), 236 | nn.MaxPool2d(3), 237 | Flatten(), 238 | nn.Linear(widths['block3'], 10, bias=False), 239 | Mul(hyp['net']['scaling_factor']), 240 | ) 241 | net[0].weight.requires_grad = False 242 | net = net.half().cuda() 243 | net = net.to(memory_format=torch.channels_last) 244 | for mod in net.modules(): 245 | if isinstance(mod, BatchNorm): 246 | mod.float() 247 | return net 248 | 249 | 250 | def reinit_net(model): 251 | for m in model.modules(): 252 | if type(m) in (Conv, BatchNorm, nn.Linear): 253 | m.reset_parameters() 254 | # if type(m) in (Conv,): 255 | # conv_init(m.weight) 256 | # if type(m) in (nn.Linear,): 257 | # 
embedding_init(m.weight) 258 | #linear_init(m.weight) 259 | 260 | 261 | ############################################# 262 | # Whitening Conv Initialization # 263 | ############################################# 264 | 265 | def get_patches(x, patch_shape): 266 | c, (h, w) = x.shape[1], patch_shape 267 | return x.unfold(2,h,1).unfold(3,w,1).transpose(1,3).reshape(-1,c,h,w).float() 268 | 269 | def get_whitening_parameters(patches): 270 | n,c,h,w = patches.shape 271 | patches_flat = patches.view(n, -1) 272 | est_patch_covariance = (patches_flat.T @ patches_flat) / n 273 | eigenvalues, eigenvectors = torch.linalg.eigh(est_patch_covariance, UPLO='U') 274 | return eigenvalues.flip(0).view(-1, 1, 1, 1), eigenvectors.T.reshape(c*h*w,c,h,w).flip(0) 275 | 276 | def init_whitening_conv(layer, train_set, eps=5e-4): 277 | patches = get_patches(train_set, patch_shape=layer.weight.data.shape[2:]) 278 | eigenvalues, eigenvectors = get_whitening_parameters(patches) 279 | eigenvectors_scaled = eigenvectors / torch.sqrt(eigenvalues + eps) 280 | layer.weight.data[:] = torch.cat((eigenvectors_scaled, -eigenvectors_scaled)) 281 | 282 | ############################################ 283 | # Logging # 284 | ############################################ 285 | 286 | def print_columns(columns_list, is_head=False, is_final_entry=False): 287 | print_string = '' 288 | for col in columns_list: 289 | print_string += '| %s ' % col 290 | print_string += '|' 291 | if is_head: 292 | print('-'*len(print_string)) 293 | print(print_string) 294 | if is_head or is_final_entry: 295 | print('-'*len(print_string)) 296 | 297 | logging_columns_list = ['run ', 'epoch', 'train_loss', 'train_acc', 'val_acc', 'tta_val_acc', 'total_time_seconds'] 298 | def print_training_details(variables, is_final_entry): 299 | formatted = [] 300 | for col in logging_columns_list: 301 | var = variables.get(col.strip(), None) 302 | if type(var) in (int, str): 303 | res = str(var) 304 | elif type(var) is float: 305 | res = '{:0.4f}'.format(var) 306 | else: 307 | assert var is None 308 | res = '' 309 | formatted.append(res.rjust(len(col))) 310 | print_columns(formatted, is_final_entry=is_final_entry) 311 | 312 | ############################################ 313 | # Evaluation # 314 | ############################################ 315 | 316 | def infer(model, loader, tta_level=0): 317 | 318 | # Test-time augmentation strategy (for tta_level=2): 319 | # 1. Flip/mirror the image left-to-right (50% of the time). 320 | # 2. Translate the image by one pixel either up-and-left or down-and-right (50% of the time, 321 | # i.e. both happen 25% of the time). 322 | # 323 | # This creates 6 views per image (left/right times the two translations and no-translation), 324 | # which we evaluate and then weight according to the given probabilities. 
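# Concretely, given the view functions defined below, the six views end up weighted as: original 0.25,
# mirrored 0.25, and each of the four translated views (two shifts, each also mirrored) 0.125, so the
# weights sum to 1.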
325 | 326 | def infer_basic(inputs, net): 327 | return net(inputs).clone() 328 | 329 | def infer_mirror(inputs, net): 330 | return 0.5 * net(inputs) + 0.5 * net(inputs.flip(-1)) 331 | 332 | def infer_mirror_translate(inputs, net): 333 | logits = infer_mirror(inputs, net) 334 | pad = 1 335 | padded_inputs = F.pad(inputs, (pad,)*4, 'reflect') 336 | inputs_translate_list = [ 337 | padded_inputs[:, :, 0:32, 0:32], 338 | padded_inputs[:, :, 2:34, 2:34], 339 | ] 340 | logits_translate_list = [infer_mirror(inputs_translate, net) 341 | for inputs_translate in inputs_translate_list] 342 | logits_translate = torch.stack(logits_translate_list).mean(0) 343 | return 0.5 * logits + 0.5 * logits_translate 344 | 345 | model.eval() 346 | test_images = loader.normalize(loader.images) 347 | infer_fn = [infer_basic, infer_mirror, infer_mirror_translate][tta_level] 348 | with torch.no_grad(): 349 | return torch.cat([infer_fn(inputs, model) for inputs in test_images.split(2000)]) 350 | 351 | def evaluate(model, loader, tta_level=0): 352 | logits = infer(model, loader, tta_level) 353 | return (logits.argmax(1) == loader.labels).float().mean().item() 354 | 355 | ############################################ 356 | # Training # 357 | ############################################ 358 | 359 | def main(run, model_trainbias, model_freezebias, lr, momentum): 360 | batch_size = hyp['opt']['batch_size'] 361 | epochs = hyp['opt']['train_epochs'] 362 | #momentum = hyp['opt']['momentum'] 363 | # Assuming gradients are constant in time, for Nesterov momentum, the below ratio is how much 364 | # larger the default steps will be than the underlying per-example gradients. We divide the 365 | # learning rate by this ratio in order to ensure steps are the same scale as gradients, regardless 366 | # of the choice of momentum. 
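# Illustrative arithmetic only (the Scion optimizers below take the decoupled lr directly, so this conversion
# stays commented out): with the hyp momentum of 0.85 the ratio is 1 + 1/(1 - 0.85), i.e. about 7.67, so a
# decoupled lr of 6.5 per 1024 examples would map to a plain-SGD learning rate of roughly
# 6.5 / (1024 * 7.67), i.e. about 8.3e-4.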
367 | #kilostep_scale = 1024 * (1 + 1 / (1 - momentum)) 368 | #lr = hyp['opt']['lr'] / kilostep_scale # un-decoupled learning rate for PyTorch SGD 369 | #wd = hyp['opt']['weight_decay'] * batch_size / kilostep_scale 370 | #lr_biases = lr * hyp['opt']['bias_scaler'] 371 | 372 | loss_fn = nn.CrossEntropyLoss(label_smoothing=hyp['opt']['label_smoothing'], reduction='none') 373 | 374 | test_loader = CifarLoader('cifar10', train=False, batch_size=2000) 375 | train_loader = CifarLoader('cifar10', train=True, batch_size=batch_size, aug=hyp['aug']) 376 | if run == 'warmup': 377 | # The only purpose of the first run is to warmup the compiled model, so we can use dummy data 378 | train_loader.labels = torch.randint(0, 10, size=(len(train_loader.labels),), device=train_loader.labels.device) 379 | total_train_steps = ceil(len(train_loader) * epochs) 380 | 381 | # Create optimizers for train whiten bias stage 382 | # (in practice we don't need this stage since `whiten_bias_epochs=0` but we keep for ablation) 383 | model = model_trainbias 384 | # print([k for k,v in model.named_parameters()]) 385 | first_layer = model._orig_mod[0].weight 386 | fc_layer = model._orig_mod[-2].weight 387 | whiten_bias = model._orig_mod[0].bias 388 | radius = 8.0 389 | head_radius = 128 390 | parameters = [ 391 | dict(norm="SpectralConv", norm_kwargs={'steps': 9}, scale=radius, params=[first_layer]), 392 | dict(norm='SpectralConv', norm_kwargs={'steps': 9}, scale=radius, params=[p for n, p in model.named_parameters() if len(p.shape) == 4 and p.requires_grad and "0.weight" not in n]), 393 | dict(norm='BiasRMS', scale=radius, params=[p for n, p in model.named_parameters() if 'norm' in n and p.requires_grad]), 394 | dict(norm='Sign', scale=head_radius, params=[fc_layer]), 395 | ] 396 | optimizer = Scion(parameters, lr=lr, momentum=momentum) 397 | optimizer_trainbias = optimizer 398 | optimizer2_trainbias = Scion(norm='BiasRMS', scale=radius, params=[whiten_bias], lr=lr, momentum=momentum) 399 | 400 | # Create optimizers for frozen whiten bias stage 401 | model = model_freezebias 402 | first_layer = model._orig_mod[0].weight 403 | fc_layer = model._orig_mod[-2].weight 404 | radius = 8.0 405 | head_radius = 128 406 | parameters = [ 407 | dict(norm="SpectralConv", norm_kwargs={'steps': 9}, scale=radius, params=[first_layer]), 408 | dict(norm='SpectralConv', norm_kwargs={'steps': 9}, scale=radius, params=[p for n, p in model.named_parameters() if len(p.shape) == 4 and p.requires_grad and "0.weight" not in n]), 409 | dict(norm='BiasRMS', scale=radius, params=[p for n, p in model.named_parameters() if 'norm' in n and p.requires_grad]), 410 | dict(norm='Sign', scale=head_radius, params=[fc_layer]), 411 | ] 412 | optimizer = Scion(parameters, lr=lr, momentum=momentum) 413 | optimizer_freezebias = optimizer 414 | 415 | # Make learning rate schedulers for all 3 optimizers 416 | def get_lr(step): 417 | return 1 - step / total_train_steps 418 | scheduler_trainbias = torch.optim.lr_scheduler.LambdaLR(optimizer_trainbias, get_lr) 419 | scheduler2_trainbias = torch.optim.lr_scheduler.LambdaLR(optimizer2_trainbias, get_lr) 420 | scheduler_freezebias = torch.optim.lr_scheduler.LambdaLR(optimizer_freezebias, get_lr) 421 | 422 | # Reinitialize the network from scratch - nothing is reused from previous runs besides the PyTorch compilation 423 | reinit_net(model_trainbias) 424 | current_steps = 0 425 | # for optimizer in [optimizer1_trainbias, optimizer2_trainbias, optimizer22_trainbias, optimizer3_trainbias]: 426 | # optimizer.init() 427 | 428 |
# For accurately timing GPU code 429 | starter = torch.cuda.Event(enable_timing=True) 430 | ender = torch.cuda.Event(enable_timing=True) 431 | total_time_seconds = 0.0 432 | 433 | # Initialize the whitening layer using training images 434 | starter.record() 435 | train_images = train_loader.normalize(train_loader.images[:5000]) 436 | init_whitening_conv(model_trainbias._orig_mod[0], train_images) 437 | ender.record() 438 | torch.cuda.synchronize() 439 | total_time_seconds += 1e-3 * starter.elapsed_time(ender) 440 | 441 | # import norm_logging 442 | # for k,v in norm_logging.log_spectral_norms(model).items(): 443 | # print(f"{k}: {v}") 444 | 445 | for epoch in range(ceil(epochs)): 446 | 447 | # After training the whiten bias for some epochs, swap in the compiled model with frozen bias 448 | if epoch == 0: 449 | model = model_trainbias 450 | optimizers = [optimizer_trainbias, optimizer2_trainbias] 451 | schedulers = [scheduler_trainbias, scheduler2_trainbias] 452 | elif epoch == hyp['opt']['whiten_bias_epochs']: 453 | model = model_freezebias 454 | old_optimizers = optimizers 455 | old_schedulers = schedulers 456 | optimizers = [optimizer_freezebias] 457 | schedulers = [scheduler_freezebias] 458 | model.load_state_dict(model_trainbias.state_dict()) 459 | for i, (opt, sched) in enumerate(zip(optimizers, schedulers)): 460 | opt.load_state_dict(old_optimizers[i].state_dict()) 461 | sched.load_state_dict(old_schedulers[i].state_dict()) 462 | 463 | #################### 464 | # Training # 465 | #################### 466 | 467 | starter.record() 468 | 469 | model.train() 470 | for inputs, labels in train_loader: 471 | outputs = model(inputs) 472 | loss = loss_fn(outputs, labels).sum() 473 | model.zero_grad(set_to_none=True) 474 | loss.backward() 475 | for opt, sched in zip(optimizers, schedulers): 476 | opt.step() 477 | sched.step() 478 | current_steps += 1 479 | if current_steps >= total_train_steps: 480 | break 481 | 482 | ender.record() 483 | torch.cuda.synchronize() 484 | total_time_seconds += 1e-3 * starter.elapsed_time(ender) 485 | 486 | #################### 487 | # Evaluation # 488 | #################### 489 | 490 | # Save the accuracy and loss from the last training batch of the epoch 491 | train_acc = (outputs.detach().argmax(1) == labels).float().mean().item() 492 | train_loss = loss.item() / batch_size 493 | val_acc = evaluate(model, test_loader, tta_level=0) 494 | print_training_details(locals(), is_final_entry=False) 495 | run = None # Only print the run number once 496 | 497 | #################### 498 | # TTA Evaluation # 499 | #################### 500 | 501 | starter.record() 502 | tta_val_acc = evaluate(model, test_loader, tta_level=hyp['net']['tta_level']) 503 | ender.record() 504 | torch.cuda.synchronize() 505 | total_time_seconds += 1e-3 * starter.elapsed_time(ender) 506 | 507 | epoch = 'eval' 508 | print_training_details(locals(), is_final_entry=True) 509 | 510 | # import norm_logging 511 | # for k,v in norm_logging.log_spectral_norms(model).items(): 512 | # print(f"{k}: {v}") 513 | 514 | return tta_val_acc 515 | 516 | 517 | ####################################################### 518 | # Run 519 | ####################################################### 520 | 521 | if __name__ == "__main__": 522 | with open(sys.argv[0]) as f: 523 | code = f.read() 524 | 525 | #for width_factor in [0.5, 1.0, 2.0, 4.0]: 526 | for width_factor in [1.0]: 527 | widths = {k: round(width_factor*v) for k,v in hyp['net']['widths'].items()} 528 | # These two compiled models are first warmed up, and then 
reinitialized every run. No learned 529 | # weights are reused between runs. To implement freezing of the whitening-layer bias parameter 530 | # midway through training, we use two compiled models, one with trainable and the other with 531 | # frozen whitening bias. This is faster than the naive approach of setting requires_grad=False 532 | # on the whitening bias midway through training on a single compiled model. 533 | model_trainbias = make_net(widths) 534 | model_freezebias = make_net(widths) 535 | model_freezebias[0].bias.requires_grad = False 536 | model_trainbias = torch.compile(model_trainbias)#, mode='max-autotune') 537 | model_freezebias = torch.compile(model_freezebias)#, mode='max-autotune') 538 | 539 | print_columns(logging_columns_list, is_head=True) 540 | main('warmup', model_trainbias, model_freezebias, lr=0.05, momentum=0.5) 541 | 542 | #for log2lr in np.linspace(-9, 0, 10): 543 | for log2lr in [math.log2(0.05)]: 544 | momentum = 0.6 545 | # run = wandb.init( 546 | # project="MYPROJECT", 547 | # entity="MYENTITY", 548 | # name=f"v1|scion|transfer|width-factor={width_factor}|log2lr={log2lr}", 549 | # tags=['transfer-v1'], 550 | # config={'momentum': momentum, 'log2lr': log2lr, 'method': 'scion', 'width-factor': width_factor}, reinit=True) 551 | 552 | accs = torch.tensor([main(run, model_trainbias, model_freezebias, lr=2**log2lr, momentum=momentum) for run in range(50)]) 553 | print('log2lr=%.2f width_factor=%.1f - Mean: %.4f Std: %.4f' % (log2lr, width_factor, accs.mean(), accs.std())) 554 | # wandb.log({'test_acc_mean': accs.mean(), 'test_acc_std': accs.std()}) 555 | # run.finish() 556 | 557 | --------------------------------------------------------------------------------