├── mup_examples ├── requirements.txt ├── coord_check_shakespeare_char │ ├── sp │ │ └── run.sh │ ├── mup │ │ └── run.sh │ ├── sp_with_mup_hidden_init │ │ └── run.sh │ ├── sp_with_mup_hidden_init_and_lr │ │ └── run.sh │ ├── sp_with_mup_hidden_init_and_lr_output_logits │ │ └── run.sh │ └── sp_with_mup_hidden_init_and_lr_partial_output_logits │ │ └── run.sh ├── mutransfer_lr_shakespeare_char │ ├── sp │ │ └── run.sh │ ├── mup │ │ └── run.sh │ └── plot.ipynb ├── mutransfer_lr_owt │ ├── sp │ │ └── run.sh │ └── mup │ │ └── run.sh └── README.md ├── assets ├── nanogpt.jpg ├── coord_check_sp.png ├── gpt2_124M_loss.png ├── coord_check_mup.png ├── mutransfer_lr_owt.png └── mutransfer_lr_shakespeare_char.png ├── .gitignore ├── data ├── shakespeare │ ├── readme.md │ └── prepare.py ├── shakespeare_char │ ├── readme.md │ └── prepare.py └── openwebtext │ ├── readme.md │ └── prepare.py ├── .gitattributes ├── config ├── eval_gpt2.py ├── eval_gpt2_xl.py ├── eval_gpt2_large.py ├── eval_gpt2_medium.py ├── finetune_shakespeare.py ├── train_gpt2.py └── train_shakespeare_char.py ├── LICENSE ├── configurator.py ├── sample.py ├── bench.py ├── csv_logging.py ├── transformer_sizing.ipynb ├── README.md ├── train.py └── model.py /mup_examples/requirements.txt: -------------------------------------------------------------------------------- 1 | datasets==2.13.0 2 | seaborn 3 | -------------------------------------------------------------------------------- /assets/nanogpt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/nanoGPT-mup/HEAD/assets/nanogpt.jpg -------------------------------------------------------------------------------- /assets/coord_check_sp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/nanoGPT-mup/HEAD/assets/coord_check_sp.png -------------------------------------------------------------------------------- 
/assets/gpt2_124M_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/nanoGPT-mup/HEAD/assets/gpt2_124M_loss.png -------------------------------------------------------------------------------- /assets/coord_check_mup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/nanoGPT-mup/HEAD/assets/coord_check_mup.png -------------------------------------------------------------------------------- /assets/mutransfer_lr_owt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/nanoGPT-mup/HEAD/assets/mutransfer_lr_owt.png -------------------------------------------------------------------------------- /assets/mutransfer_lr_shakespeare_char.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/nanoGPT-mup/HEAD/assets/mutransfer_lr_shakespeare_char.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | .ipynb_checkpoints/ 4 | .vscode 5 | __pycache__/ 6 | *.bin 7 | *.pkl 8 | *.pt 9 | *.pyc 10 | input.txt 11 | env/ 12 | venv/ 13 | mup_examples/*/*/out/* -------------------------------------------------------------------------------- /data/shakespeare/readme.md: -------------------------------------------------------------------------------- 1 | 2 | # tiny shakespeare 3 | 4 | Tiny shakespeare, of the good old char-rnn fame :) 5 | 6 | After running `prepare.py`: 7 | 8 | - train.bin has 301,966 tokens 9 | - val.bin has 36,059 tokens 10 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # 
Override jupyter in Github language stats for more accurate estimate of repo code languages 2 | # reference: https://github.com/github/linguist/blob/master/docs/overrides.md#generated-code 3 | *.ipynb linguist-generated 4 | -------------------------------------------------------------------------------- /config/eval_gpt2.py: -------------------------------------------------------------------------------- 1 | # evaluate the base gpt2 2 | # n_layer=12, n_head=12, n_embd=768 3 | # 124M parameters 4 | batch_size = 8 5 | eval_iters = 500 # use more iterations to get good estimate 6 | eval_only = True 7 | wandb_log = False 8 | init_from = 'gpt2' 9 | -------------------------------------------------------------------------------- /config/eval_gpt2_xl.py: -------------------------------------------------------------------------------- 1 | # evaluate the base gpt2 2 | # n_layer=48, n_head=25, n_embd=1600 3 | # 1558M parameters 4 | batch_size = 8 5 | eval_iters = 500 # use more iterations to get good estimate 6 | eval_only = True 7 | wandb_log = False 8 | init_from = 'gpt2-xl' 9 | -------------------------------------------------------------------------------- /config/eval_gpt2_large.py: -------------------------------------------------------------------------------- 1 | # evaluate the base gpt2 2 | # n_layer=36, n_head=20, n_embd=1280 3 | # 774M parameters 4 | batch_size = 8 5 | eval_iters = 500 # use more iterations to get good estimate 6 | eval_only = True 7 | wandb_log = False 8 | init_from = 'gpt2-large' 9 | -------------------------------------------------------------------------------- /config/eval_gpt2_medium.py: -------------------------------------------------------------------------------- 1 | # evaluate the base gpt2 2 | # n_layer=24, n_head=16, n_embd=1024 3 | # 350M parameters 4 | batch_size = 8 5 | eval_iters = 500 # use more iterations to get good estimate 6 | eval_only = True 7 | wandb_log = False 8 | init_from = 'gpt2-medium' 9 | 
-------------------------------------------------------------------------------- /data/shakespeare_char/readme.md: -------------------------------------------------------------------------------- 1 | 2 | # tiny shakespeare, character-level 3 | 4 | Tiny shakespeare, of the good old char-rnn fame :) Treated on character-level. 5 | 6 | After running `prepare.py`: 7 | 8 | - train.bin has 1,003,854 tokens 9 | - val.bin has 111,540 tokens 10 | -------------------------------------------------------------------------------- /data/openwebtext/readme.md: -------------------------------------------------------------------------------- 1 | 2 | ## openwebtext dataset 3 | 4 | after running `prepare.py` (preprocess) we get: 5 | 6 | - train.bin is ~17GB, val.bin ~8.5MB 7 | - train has ~9B tokens (9,035,582,198) 8 | - val has ~4M tokens (4,434,897) 9 | 10 | this came from 8,013,769 documents in total. 11 | 12 | references: 13 | 14 | - OpenAI's WebText dataset is discussed in [GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) 15 | - [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) dataset 16 | -------------------------------------------------------------------------------- /config/finetune_shakespeare.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | out_dir = 'out-shakespeare' 4 | eval_interval = 5 5 | eval_iters = 40 6 | wandb_log = False # feel free to turn on 7 | wandb_project = 'shakespeare' 8 | wandb_run_name = 'ft-' + str(time.time()) 9 | 10 | dataset = 'shakespeare' 11 | init_from = 'gpt2-xl' # this is the largest GPT-2 model 12 | 13 | # only save checkpoints if the validation loss improves 14 | always_save_checkpoint = False 15 | 16 | # the number of examples per iter: 17 | # 1 batch_size * 32 grad_accum * 1024 tokens = 32,768 tokens/iter 18 | # shakespeare has 301,966 tokens, so 1 epoch ~= 9.2 iters 19 | batch_size = 1 20 
| gradient_accumulation_steps = 32 21 | max_iters = 20 22 | 23 | # finetune at constant LR 24 | learning_rate = 3e-5 25 | decay_lr = False 26 | -------------------------------------------------------------------------------- /config/train_gpt2.py: -------------------------------------------------------------------------------- 1 | # config for training GPT-2 (124M) down to very nice loss of ~2.85 on 1 node of 8X A100 40GB 2 | # launch as the following (e.g. in a screen session) and wait ~5 days: 3 | # $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py 4 | 5 | wandb_log = True 6 | wandb_project = 'owt' 7 | wandb_run_name='gpt2-124M' 8 | 9 | # these make the total batch size be ~0.5M 10 | # 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520 11 | batch_size = 12 12 | block_size = 1024 13 | gradient_accumulation_steps = 5 * 8 14 | 15 | # this makes total number of tokens be 300B 16 | max_iters = 600000 17 | lr_decay_iters = 600000 18 | 19 | # eval stuff 20 | eval_interval = 1000 21 | eval_iters = 200 22 | log_interval = 10 23 | 24 | # weight decay 25 | weight_decay = 1e-1 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Andrej Karpathy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /data/shakespeare/prepare.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import tiktoken 4 | import numpy as np 5 | 6 | # download the tiny shakespeare dataset 7 | input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt') 8 | if not os.path.exists(input_file_path): 9 | data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt' 10 | with open(input_file_path, 'w', encoding='utf-8') as f: 11 | f.write(requests.get(data_url).text) 12 | 13 | with open(input_file_path, 'r', encoding='utf-8') as f: 14 | data = f.read() 15 | n = len(data) 16 | train_data = data[:int(n*0.9)] 17 | val_data = data[int(n*0.9):] 18 | 19 | # encode with tiktoken gpt2 bpe 20 | enc = tiktoken.get_encoding("gpt2") 21 | train_ids = enc.encode_ordinary(train_data) 22 | val_ids = enc.encode_ordinary(val_data) 23 | print(f"train has {len(train_ids):,} tokens") 24 | print(f"val has {len(val_ids):,} tokens") 25 | 26 | # export to bin files 27 | train_ids = np.array(train_ids, dtype=np.uint16) 28 | val_ids = np.array(val_ids, dtype=np.uint16) 29 | train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin')) 30 | val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin')) 31 | 32 | # train.bin has 301,966 tokens 33 | # val.bin has 36,059 tokens 34 | 
-------------------------------------------------------------------------------- /config/train_shakespeare_char.py: -------------------------------------------------------------------------------- 1 | # train a miniature character-level shakespeare model 2 | # good for debugging and playing on macbooks and such 3 | 4 | out_dir = 'out-shakespeare-char' 5 | eval_interval = 250 # keep frequent because we'll overfit 6 | eval_iters = 200 7 | log_interval = 10 # don't print too too often 8 | 9 | # we expect to overfit on this small dataset, so only save when val improves 10 | always_save_checkpoint = False 11 | 12 | wandb_log = False # override via command line if you like 13 | wandb_project = 'shakespeare-char' 14 | wandb_run_name = 'mini-gpt' 15 | 16 | dataset = 'shakespeare_char' 17 | gradient_accumulation_steps = 1 18 | batch_size = 64 19 | block_size = 256 # context of up to 256 previous characters 20 | 21 | # baby GPT model :) 22 | n_layer = 6 23 | n_head = 6 24 | n_embd = 384 25 | dropout = 0.2 26 | 27 | learning_rate = 1e-3 # with baby networks can afford to go a bit higher 28 | max_iters = 5000 29 | lr_decay_iters = 5000 # make equal to max_iters usually 30 | min_lr = 1e-4 # learning_rate / 10 usually 31 | beta2 = 0.99 # make a bit bigger because number of tokens per iter is small 32 | 33 | warmup_iters = 100 # not super necessary potentially 34 | 35 | # on macbook also add 36 | # device = 'cpu' # run on cpu only 37 | # compile = False # do not torch compile the model 38 | -------------------------------------------------------------------------------- /mup_examples/coord_check_shakespeare_char/sp/run.sh: -------------------------------------------------------------------------------- 1 | for width in 256 512 1024 2048 4096 2 | do 3 | for seed in 1 2 3 4 5 4 | do 5 | head_size=64 6 | n_heads=$((width / head_size)) 7 | out_dir="mup_examples/coord_check_shakespeare_char/sp/out/width${width}_depth2_seed${seed}" 8 | python train.py \ 9 | --out_dir=$out_dir \ 10 | 
--eval_interval=1 \ 11 | --log_interval=1 \ 12 | --eval_iters=1 \ 13 | --eval_only=False \ 14 | --always_save_checkpoint=False \ 15 | --never_save_checkpoint=True \ 16 | --init_from='scratch' \ 17 | --wandb_log=False \ 18 | --csv_log=True \ 19 | --dataset='shakespeare_char' \ 20 | --gradient_accumulation_steps=4 \ 21 | --batch_size=2 \ 22 | --block_size=1024 \ 23 | --n_layer=2 \ 24 | --n_head=$n_heads \ 25 | --n_embd=$width \ 26 | --dropout=0.0 \ 27 | --bias=False \ 28 | --init_std=0.02 \ 29 | --learning_rate=1e-2 \ 30 | --max_iters=10 \ 31 | --weight_decay=1e-1 \ 32 | --beta1=0.9 \ 33 | --beta2=0.95 \ 34 | --grad_clip=1.0 \ 35 | --decay_lr=False \ 36 | --seed=$seed \ 37 | --backend='nccl' \ 38 | --device='mps' \ 39 | --dtype='float32' \ 40 | --compile=False \ 41 | --mup_enable_coord_check_logging=True 42 | done 43 | done 44 | -------------------------------------------------------------------------------- /mup_examples/coord_check_shakespeare_char/mup/run.sh: -------------------------------------------------------------------------------- 1 | for width in 256 512 1024 2048 4096 2 | do 3 | for seed in 1 2 3 4 5 4 | do 5 | head_size=64 6 | n_heads=$((width / head_size)) 7 | mup_base_width=256 8 | mup_width_multiplier=$(echo "scale=8; $width/$mup_base_width" | bc -l) 9 | out_dir="mup_examples/coord_check_shakespeare_char/mup/out/width${width}_depth2_seed${seed}" 10 | python train.py \ 11 | --out_dir=$out_dir \ 12 | --eval_interval=1 \ 13 | --log_interval=1 \ 14 | --eval_iters=1 \ 15 | --eval_only=False \ 16 | --always_save_checkpoint=False \ 17 | --never_save_checkpoint=True \ 18 | --init_from='scratch' \ 19 | --wandb_log=False \ 20 | --csv_log=True \ 21 | --dataset='shakespeare_char' \ 22 | --gradient_accumulation_steps=4 \ 23 | --batch_size=2 \ 24 | --block_size=1024 \ 25 | --n_layer=2 \ 26 | --n_head=$n_heads \ 27 | --n_embd=$width \ 28 | --dropout=0.0 \ 29 | --bias=False \ 30 | --init_std=0.02 \ 31 | --learning_rate=1e-2 \ 32 | --max_iters=10 \ 33 | 
--weight_decay=1e-1 \ 34 | --beta1=0.9 \ 35 | --beta2=0.95 \ 36 | --grad_clip=1.0 \ 37 | --decay_lr=False \ 38 | --mup_enabled=True \ 39 | --mup_width_multiplier=$mup_width_multiplier \ 40 | --mup_input_alpha=1.0 \ 41 | --mup_output_alpha=1.0 \ 42 | --mup_enable_coord_check_logging=True \ 43 | --seed=$seed \ 44 | --backend='nccl' \ 45 | --device='mps' \ 46 | --dtype='float32' \ 47 | --compile=False 48 | done 49 | done 50 | -------------------------------------------------------------------------------- /mup_examples/coord_check_shakespeare_char/sp_with_mup_hidden_init/run.sh: -------------------------------------------------------------------------------- 1 | for width in 256 512 1024 2048 4096 2 | do 3 | for seed in 1 2 3 4 5 4 | do 5 | head_size=64 6 | n_heads=$((width / head_size)) 7 | mup_base_width=256 8 | mup_width_multiplier=$(echo "scale=8; $width/$mup_base_width" | bc -l) 9 | out_dir="mup_examples/coord_check_shakespeare_char/sp_with_mup_hidden_init/out/width${width}_depth2_seed${seed}" 10 | python train.py \ 11 | --out_dir=$out_dir \ 12 | --eval_interval=1 \ 13 | --log_interval=1 \ 14 | --eval_iters=1 \ 15 | --eval_only=False \ 16 | --always_save_checkpoint=False \ 17 | --never_save_checkpoint=True \ 18 | --init_from='scratch' \ 19 | --wandb_log=False \ 20 | --csv_log=True \ 21 | --dataset='shakespeare_char' \ 22 | --gradient_accumulation_steps=4 \ 23 | --batch_size=2 \ 24 | --block_size=1024 \ 25 | --n_layer=2 \ 26 | --n_head=$n_heads \ 27 | --n_embd=$width \ 28 | --dropout=0.0 \ 29 | --bias=False \ 30 | --init_std=0.02 \ 31 | --learning_rate=1e-2 \ 32 | --max_iters=10 \ 33 | --weight_decay=1e-1 \ 34 | --beta1=0.9 \ 35 | --beta2=0.95 \ 36 | --grad_clip=1.0 \ 37 | --decay_lr=False \ 38 | --mup_enabled=True \ 39 | --mup_disable_attention_scaling=True \ 40 | --mup_disable_hidden_lr_scaling=True \ 41 | --mup_width_multiplier=$mup_width_multiplier \ 42 | --mup_input_alpha=1.0 \ 43 | --mup_output_alpha=$mup_width_multiplier \ 44 | 
--mup_enable_coord_check_logging=True \ 45 | --seed=$seed \ 46 | --backend='nccl' \ 47 | --device='mps' \ 48 | --dtype='float32' \ 49 | --compile=False 50 | done 51 | done 52 | -------------------------------------------------------------------------------- /mup_examples/coord_check_shakespeare_char/sp_with_mup_hidden_init_and_lr/run.sh: -------------------------------------------------------------------------------- 1 | for width in 256 512 1024 2048 4096 2 | do 3 | for seed in 1 2 3 4 5 4 | do 5 | head_size=64 6 | n_heads=$((width / head_size)) 7 | mup_base_width=256 8 | mup_width_multiplier=$(echo "scale=8; $width/$mup_base_width" | bc -l) 9 | out_dir="mup_examples/coord_check_shakespeare_char/sp_with_mup_hidden_init_and_lr/out/width${width}_depth2_seed${seed}" 10 | python train.py \ 11 | --out_dir=$out_dir \ 12 | --eval_interval=1 \ 13 | --log_interval=1 \ 14 | --eval_iters=1 \ 15 | --eval_only=False \ 16 | --always_save_checkpoint=False \ 17 | --never_save_checkpoint=True \ 18 | --init_from='scratch' \ 19 | --wandb_log=False \ 20 | --csv_log=True \ 21 | --dataset='shakespeare_char' \ 22 | --gradient_accumulation_steps=4 \ 23 | --batch_size=2 \ 24 | --block_size=1024 \ 25 | --n_layer=2 \ 26 | --n_head=$n_heads \ 27 | --n_embd=$width \ 28 | --dropout=0.0 \ 29 | --bias=False \ 30 | --init_std=0.02 \ 31 | --learning_rate=1e-2 \ 32 | --max_iters=10 \ 33 | --weight_decay=1e-1 \ 34 | --beta1=0.9 \ 35 | --beta2=0.95 \ 36 | --grad_clip=1.0 \ 37 | --decay_lr=False \ 38 | --mup_enabled=True \ 39 | --mup_disable_attention_scaling=True \ 40 | --mup_disable_hidden_lr_scaling=False \ 41 | --mup_width_multiplier=$mup_width_multiplier \ 42 | --mup_input_alpha=1.0 \ 43 | --mup_output_alpha=$mup_width_multiplier \ 44 | --mup_enable_coord_check_logging=True \ 45 | --seed=$seed \ 46 | --backend='nccl' \ 47 | --device='mps' \ 48 | --dtype='float32' \ 49 | --compile=False 50 | done 51 | done 52 | -------------------------------------------------------------------------------- 
/mup_examples/coord_check_shakespeare_char/sp_with_mup_hidden_init_and_lr_output_logits/run.sh: -------------------------------------------------------------------------------- 1 | for width in 256 512 1024 2048 4096 2 | do 3 | for seed in 1 2 3 4 5 4 | do 5 | head_size=64 6 | n_heads=$((width / head_size)) 7 | mup_base_width=256 8 | mup_width_multiplier=$(echo "scale=8; $width/$mup_base_width" | bc -l) 9 | out_dir="mup_examples/coord_check_shakespeare_char/sp_with_mup_hidden_init_and_lr_output_logits/out/width${width}_depth2_seed${seed}" 10 | python train.py \ 11 | --out_dir=$out_dir \ 12 | --eval_interval=1 \ 13 | --log_interval=1 \ 14 | --eval_iters=1 \ 15 | --eval_only=False \ 16 | --always_save_checkpoint=False \ 17 | --never_save_checkpoint=True \ 18 | --init_from='scratch' \ 19 | --wandb_log=False \ 20 | --csv_log=True \ 21 | --dataset='shakespeare_char' \ 22 | --gradient_accumulation_steps=4 \ 23 | --batch_size=2 \ 24 | --block_size=1024 \ 25 | --n_layer=2 \ 26 | --n_head=$n_heads \ 27 | --n_embd=$width \ 28 | --dropout=0.0 \ 29 | --bias=False \ 30 | --init_std=0.02 \ 31 | --learning_rate=1e-2 \ 32 | --max_iters=10 \ 33 | --weight_decay=1e-1 \ 34 | --beta1=0.9 \ 35 | --beta2=0.95 \ 36 | --grad_clip=1.0 \ 37 | --decay_lr=False \ 38 | --mup_enabled=True \ 39 | --mup_disable_attention_scaling=True \ 40 | --mup_disable_hidden_lr_scaling=False \ 41 | --mup_width_multiplier=$mup_width_multiplier \ 42 | --mup_input_alpha=1.0 \ 43 | --mup_output_alpha=1.0 \ 44 | --mup_enable_coord_check_logging=True \ 45 | --seed=$seed \ 46 | --backend='nccl' \ 47 | --device='mps' \ 48 | --dtype='float32' \ 49 | --compile=False 50 | done 51 | done 52 | -------------------------------------------------------------------------------- /configurator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Poor Man's Configurator. Probably a terrible idea. 
Example usage: 3 | $ python train.py config/override_file.py --batch_size=32 4 | this will first run config/override_file.py, then override batch_size to 32 5 | 6 | The code in this file will be run as follows from e.g. train.py: 7 | >>> exec(open('configurator.py').read()) 8 | 9 | So it's not a Python module, it's just shuttling this code away from train.py 10 | The code in this script then overrides the globals() 11 | 12 | I know people are not going to love this, I just really dislike configuration 13 | complexity and having to prepend config. to every single variable. If someone 14 | comes up with a better simple Python solution I am all ears. 15 | """ 16 | 17 | import sys 18 | from ast import literal_eval 19 | 20 | for arg in sys.argv[1:]: 21 | if '=' not in arg: 22 | # assume it's the name of a config file 23 | assert not arg.startswith('--') 24 | config_file = arg 25 | print(f"Overriding config with {config_file}:") 26 | with open(config_file) as f: 27 | print(f.read()) 28 | exec(open(config_file).read()) 29 | else: 30 | # assume it's a --key=value argument 31 | assert arg.startswith('--') 32 | key, val = arg.split('=') 33 | key = key[2:] 34 | if key in globals(): 35 | try: 36 | # attempt to eval it it (e.g. 
if bool, number, or etc) 37 | attempt = literal_eval(val) 38 | except (SyntaxError, ValueError): 39 | # if that goes wrong, just use the string 40 | attempt = val 41 | # ensure the types match ok 42 | assert type(attempt) == type(globals()[key]) 43 | # cross fingers 44 | print(f"Overriding: {key} = {attempt}") 45 | globals()[key] = attempt 46 | else: 47 | raise ValueError(f"Unknown config key: {key}") 48 | -------------------------------------------------------------------------------- /mup_examples/coord_check_shakespeare_char/sp_with_mup_hidden_init_and_lr_partial_output_logits/run.sh: -------------------------------------------------------------------------------- 1 | for width in 256 512 1024 2048 4096 2 | do 3 | for seed in 1 2 3 4 5 4 | do 5 | head_size=64 6 | n_heads=$((width / head_size)) 7 | mup_base_width=256 8 | mup_width_multiplier=$(echo "scale=8; $width/$mup_base_width" | bc -l) 9 | out_dir="mup_examples/coord_check_shakespeare_char/sp_with_mup_hidden_init_and_lr_partial_output_logits/out/width${width}_depth2_seed${seed}" 10 | mup_output_alpha=$(echo "scale=8; sqrt($mup_width_multiplier)" | bc -l) 11 | python train.py \ 12 | --out_dir=$out_dir \ 13 | --eval_interval=1 \ 14 | --log_interval=1 \ 15 | --eval_iters=1 \ 16 | --eval_only=False \ 17 | --always_save_checkpoint=False \ 18 | --never_save_checkpoint=True \ 19 | --init_from='scratch' \ 20 | --wandb_log=False \ 21 | --csv_log=True \ 22 | --dataset='shakespeare_char' \ 23 | --gradient_accumulation_steps=4 \ 24 | --batch_size=2 \ 25 | --block_size=1024 \ 26 | --n_layer=2 \ 27 | --n_head=$n_heads \ 28 | --n_embd=$width \ 29 | --dropout=0.0 \ 30 | --bias=False \ 31 | --init_std=0.02 \ 32 | --learning_rate=1e-2 \ 33 | --max_iters=10 \ 34 | --weight_decay=1e-1 \ 35 | --beta1=0.9 \ 36 | --beta2=0.95 \ 37 | --grad_clip=1.0 \ 38 | --decay_lr=False \ 39 | --mup_enabled=True \ 40 | --mup_disable_attention_scaling=True \ 41 | --mup_disable_hidden_lr_scaling=False \ 42 | 
--mup_width_multiplier=$mup_width_multiplier \ 43 | --mup_input_alpha=1.0 \ 44 | --mup_output_alpha=$mup_output_alpha \ 45 | --mup_enable_coord_check_logging=True \ 46 | --seed=$seed \ 47 | --backend='nccl' \ 48 | --device='mps' \ 49 | --dtype='float32' \ 50 | --compile=False 51 | done 52 | done 53 | -------------------------------------------------------------------------------- /mup_examples/mutransfer_lr_shakespeare_char/sp/run.sh: -------------------------------------------------------------------------------- 1 | for width in 256 512 1024 2048 2 | do 3 | for lr in 0.00390625 0.001953125 0.0009765625 0.00048828125 0.000244140625 0.0001220703125 0.00006103515625 0.00003051757812 0.00001525878906 0.000007629394531 0.000003814697266 4 | do 5 | for seed in 1 2 3 6 | do 7 | head_size=64 8 | n_heads=$((width / head_size)) 9 | out_dir="mup_examples/mutransfer_lr_shakespeare_char/sp/out/width${width}_depth2_seed${seed}_lr${lr}" 10 | python train.py \ 11 | --out_dir=$out_dir \ 12 | --eval_interval=1 \ 13 | --log_interval=1 \ 14 | --eval_iters=1 \ 15 | --eval_only=False \ 16 | --skip_val_loss=True \ 17 | --always_save_checkpoint=False \ 18 | --never_save_checkpoint=True \ 19 | --init_from='scratch' \ 20 | --wandb_log=False \ 21 | --csv_log=True \ 22 | --dataset='shakespeare_char' \ 23 | --gradient_accumulation_steps=8 \ 24 | --batch_size=1 \ 25 | --block_size=1024 \ 26 | --n_layer=2 \ 27 | --n_head=$n_heads \ 28 | --n_embd=$width \ 29 | --dropout=0.0 \ 30 | --bias=False \ 31 | --init_std=0.02 \ 32 | --learning_rate=$lr \ 33 | --max_iters=122 \ 34 | --weight_decay=1e-1 \ 35 | --beta1=0.9 \ 36 | --beta2=0.95 \ 37 | --grad_clip=1.0 \ 38 | --decay_lr=False \ 39 | --seed=$seed \ 40 | --backend='nccl' \ 41 | --device='mps' \ 42 | --dtype='float32' \ 43 | --compile=False 44 | done 45 | done 46 | done 47 | -------------------------------------------------------------------------------- 
/mup_examples/mutransfer_lr_owt/sp/run.sh: -------------------------------------------------------------------------------- 1 | # Single-GPU Launching 2 | LAUNCHER=python 3 | 4 | # Multi-GPU Launching (single node) 5 | #GPU=2 6 | #LAUNCHER="torchrun --standalone --nproc_per_node=$GPU" 7 | 8 | LAYERS=2 9 | 10 | for width in 256 512 1024 2048 11 | do 12 | for lr in 0.125 0.0625 0.03125 0.015625 0.0078125 0.00390625 0.001953125 0.0009765625 0.00048828125 0.000244140625 0.0001220703125 0.00006103515625 13 | do 14 | for seed in 1 2 3 15 | do 16 | head_size=64 17 | n_heads=$((width / head_size)) 18 | min_lr=$(awk "BEGIN {print $lr/10}") 19 | out_dir="mup_examples/mutransfer_lr_owt/sp/out/width${width}_depth${LAYERS}_seed${seed}_lr${lr}" 20 | $LAUNCHER train.py \ 21 | --out_dir=$out_dir \ 22 | --eval_interval=1 \ 23 | --log_interval=1 \ 24 | --eval_iters=1 \ 25 | --eval_only=False \ 26 | --skip_val_loss=True \ 27 | --always_save_checkpoint=False \ 28 | --never_save_checkpoint=True \ 29 | --init_from='scratch' \ 30 | --wandb_log=False \ 31 | --csv_log=True \ 32 | --dataset='openwebtext' \ 33 | --gradient_accumulation_steps=1 \ 34 | --batch_size=32 \ 35 | --block_size=1024 \ 36 | --n_layer=$LAYERS \ 37 | --n_head=$n_heads \ 38 | --n_embd=$width \ 39 | --dropout=0.0 \ 40 | --bias=False \ 41 | --init_std=0.02 \ 42 | --learning_rate=$lr \ 43 | --lr_decay_iters=1000 \ 44 | --min_lr=$min_lr \ 45 | --max_iters=1000 \ 46 | --weight_decay=1e-1 \ 47 | --beta1=0.9 \ 48 | --beta2=0.95 \ 49 | --grad_clip=1.0 \ 50 | --decay_lr=True \ 51 | --seed=$seed \ 52 | --backend='nccl' \ 53 | --device='cuda' \ 54 | --dtype='bfloat16' \ 55 | --compile=True 56 | done 57 | done 58 | done 59 | -------------------------------------------------------------------------------- /mup_examples/mutransfer_lr_shakespeare_char/mup/run.sh: -------------------------------------------------------------------------------- 1 | for width in 256 512 1024 2048 2 | do 3 | for lr in 0.125 0.0625 0.03125 0.015625 0.0078125
0.00390625 0.001953125 0.0009765625 0.00048828125 0.000244140625 0.0001220703125 0.00006103515625 4 | do 5 | for seed in 1 2 3 6 | do 7 | head_size=64 8 | n_heads=$((width / head_size)) 9 | mup_base_width=256 10 | mup_width_multiplier=$(echo "scale=8; $width/$mup_base_width" | bc -l) 11 | out_dir="mup_examples/mutransfer_lr_shakespeare_char/mup/out/width${width}_depth2_seed${seed}_lr${lr}" 12 | python train.py \ 13 | --out_dir=$out_dir \ 14 | --eval_interval=1 \ 15 | --log_interval=1 \ 16 | --eval_iters=1 \ 17 | --eval_only=False \ 18 | --skip_val_loss=True \ 19 | --always_save_checkpoint=False \ 20 | --never_save_checkpoint=True \ 21 | --init_from='scratch' \ 22 | --wandb_log=False \ 23 | --csv_log=True \ 24 | --dataset='shakespeare_char' \ 25 | --gradient_accumulation_steps=8 \ 26 | --batch_size=1 \ 27 | --block_size=1024 \ 28 | --n_layer=2 \ 29 | --n_head=$n_heads \ 30 | --n_embd=$width \ 31 | --dropout=0.0 \ 32 | --bias=False \ 33 | --init_std=0.02 \ 34 | --learning_rate=$lr \ 35 | --max_iters=122 \ 36 | --weight_decay=1e-1 \ 37 | --beta1=0.9 \ 38 | --beta2=0.95 \ 39 | --grad_clip=1.0 \ 40 | --decay_lr=False \ 41 | --mup_enabled=True \ 42 | --mup_width_multiplier=$mup_width_multiplier \ 43 | --mup_input_alpha=1.0 \ 44 | --mup_output_alpha=1.0 \ 45 | --seed=$seed \ 46 | --backend='nccl' \ 47 | --device='mps' \ 48 | --dtype='float32' \ 49 | --compile=False 50 | done 51 | done 52 | done 53 | -------------------------------------------------------------------------------- /mup_examples/mutransfer_lr_owt/mup/run.sh: -------------------------------------------------------------------------------- 1 | # Single-GPU Launching 2 | LAUNCHER=python 3 | 4 | # Multi-GPU Launching (single node) 5 | #GPU=2 6 | #LAUNCHER="torchrun --standalone --nproc_per_node=$GPU" 7 | 8 | LAYERS=2 9 | 10 | for width in 256 512 1024 2048 11 | do 12 | for lr in 0.125 0.0625 0.03125 0.015625 0.0078125 0.00390625 0.001953125 0.0009765625 0.00048828125 0.000244140625 0.0001220703125 0.00006103515625
13 | do 14 | for seed in 1 2 3 15 | do 16 | head_size=64 17 | n_heads=$((width / head_size)) 18 | min_lr=$(awk "BEGIN {print $lr/10}") 19 | mup_base_width=256 20 | mup_width_multiplier=$(echo "scale=8; $width/$mup_base_width" | bc -l) 21 | out_dir="mup_examples/mutransfer_lr_owt/mup/out/width${width}_depth${LAYERS}_seed${seed}_lr${lr}" 22 | $LAUNCHER train.py \ 23 | --out_dir=$out_dir \ 24 | --eval_interval=1 \ 25 | --log_interval=1 \ 26 | --eval_iters=1 \ 27 | --eval_only=False \ 28 | --skip_val_loss=True \ 29 | --always_save_checkpoint=False \ 30 | --never_save_checkpoint=True \ 31 | --init_from='scratch' \ 32 | --wandb_log=False \ 33 | --csv_log=True \ 34 | --dataset='openwebtext' \ 35 | --gradient_accumulation_steps=1 \ 36 | --batch_size=32 \ 37 | --block_size=1024 \ 38 | --n_layer=$LAYERS \ 39 | --n_head=$n_heads \ 40 | --n_embd=$width \ 41 | --dropout=0.0 \ 42 | --bias=False \ 43 | --init_std=0.02 \ 44 | --learning_rate=$lr \ 45 | --lr_decay_iters=1000 \ 46 | --min_lr=$min_lr \ 47 | --max_iters=1000 \ 48 | --weight_decay=1e-1 \ 49 | --beta1=0.9 \ 50 | --beta2=0.95 \ 51 | --grad_clip=1.0 \ 52 | --decay_lr=True \ 53 | --mup_enabled=True \ 54 | --mup_width_multiplier=$mup_width_multiplier \ 55 | --mup_input_alpha=1.0 \ 56 | --mup_output_alpha=1.0 \ 57 | --seed=$seed \ 58 | --backend='nccl' \ 59 | --device='cuda' \ 60 | --dtype='bfloat16' \ 61 | --compile=True 62 | done 63 | done 64 | done 65 | -------------------------------------------------------------------------------- /data/shakespeare_char/prepare.py: -------------------------------------------------------------------------------- 1 | """ 2 | Prepare the Shakespeare dataset for character-level language modeling. 3 | So instead of encoding with GPT-2 BPE tokens, we just map characters to ints. 4 | Will save train.bin, val.bin containing the ids, and meta.pkl containing the 5 | encoder and decoder and some other related info.
6 | """ 7 | import os 8 | import pickle 9 | import requests 10 | import numpy as np 11 | 12 | # download the tiny shakespeare dataset 13 | input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt') 14 | if not os.path.exists(input_file_path): 15 | data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt' 16 | with open(input_file_path, 'w') as f: 17 | f.write(requests.get(data_url).text) 18 | 19 | with open(input_file_path, 'r') as f: 20 | data = f.read() 21 | print(f"length of dataset in characters: {len(data):,}") 22 | 23 | # get all the unique characters that occur in this text 24 | chars = sorted(list(set(data))) 25 | vocab_size = len(chars) 26 | print("all the unique characters:", ''.join(chars)) 27 | print(f"vocab size: {vocab_size:,}") 28 | 29 | # create a mapping from characters to integers 30 | stoi = { ch:i for i,ch in enumerate(chars) } 31 | itos = { i:ch for i,ch in enumerate(chars) } 32 | def encode(s): 33 | return [stoi[c] for c in s] # encoder: take a string, output a list of integers 34 | def decode(l): 35 | return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string 36 | 37 | # create the train and test splits 38 | n = len(data) 39 | train_data = data[:int(n*0.9)] 40 | val_data = data[int(n*0.9):] 41 | 42 | # encode both to integers 43 | train_ids = encode(train_data) 44 | val_ids = encode(val_data) 45 | print(f"train has {len(train_ids):,} tokens") 46 | print(f"val has {len(val_ids):,} tokens") 47 | 48 | # export to bin files 49 | train_ids = np.array(train_ids, dtype=np.uint16) 50 | val_ids = np.array(val_ids, dtype=np.uint16) 51 | train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin')) 52 | val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin')) 53 | 54 | # save the meta information as well, to help us encode/decode later 55 | meta = { 56 | 'vocab_size': vocab_size, 57 | 'itos': itos, 58 | 'stoi': stoi, 59 | } 60 | with 
open(os.path.join(os.path.dirname(__file__), 'meta.pkl'), 'wb') as f: 61 | pickle.dump(meta, f) 62 | 63 | # length of dataset in characters: 1115394 64 | # all the unique characters: 65 | # !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz 66 | # vocab size: 65 67 | # train has 1003854 tokens 68 | # val has 111540 tokens 69 | -------------------------------------------------------------------------------- /mup_examples/README.md: -------------------------------------------------------------------------------- 1 | # Experiment Reproduction 2 | 3 | Install the minimal dataset and plotting requirements with `pip install -r requirements.txt`. We used the PyTorch NGC container for GPU-based runs, but any environment containing the dependencies from [the main README](https://github.com/EleutherAI/nanoGPT-mup?tab=readme-ov-file#install) will suffice. 4 | 5 | To download the tiny shakespeare dataset, run `python data/shakespeare_char/prepare.py`. For OpenWebText (OWT), run `python data/openwebtext/prepare.py`. 6 | 7 | 8 | # Coordinate Checks 9 | 10 | The lowest-overhead correctness check of a mutransfer implementation is a [coordinate check](https://github.com/microsoft/mup?tab=readme-ov-file#checking-correctness-of-parametrization). 11 | 12 | To run coordinate checks in our implementation using the tiny shakespeare dataset, use the following scripts for Standard Parameterization (SP): 13 | 14 | ``` 15 | bash mup_examples/coord_check_shakespeare_char/sp/run.sh 16 | ``` 17 | 18 | And muP: 19 | 20 | ``` 21 | bash mup_examples/coord_check_shakespeare_char/mup/run.sh 22 | ``` 23 | 24 | These scripts populate the `out/` subdirectories with your coord check data, which you can then plot with `mup_examples/coord_check_shakespeare_char/plot.ipynb` 25 | 26 | 27 | # Learning Rate muTransfer 28 | 29 | To actually test transferring hyperparameters, you need to run training for a set number of steps on a chosen dataset. 30 | 31 | 1. 
Tiny Shakespeare is small and simple enough to see stable training loss with few iterations and small batch sizes, so we recommend it to test transfer quickly or on compute-constrained systems (e.g. laptop/desktop CPU). 32 | 2. OpenWebText is comparatively large, but more representative of the massive webcrawl-based datasets used to train most models today. 33 | 34 | The default values chosen in each `run.sh` reflect this. 35 | 36 | ## Tiny Shakespeare 37 | 38 | To sweep over seeds, model widths, and learning rates on the tiny shakespeare dataset with muP: 39 | 40 | ``` 41 | bash mup_examples/mutransfer_lr_shakespeare_char/mup/run.sh 42 | ``` 43 | 44 | and SP: 45 | 46 | ``` 47 | bash mup_examples/mutransfer_lr_shakespeare_char/sp/run.sh 48 | ``` 49 | 50 | ## OpenWebText 51 | 52 | To sweep over seeds, model widths, and learning rates on the OpenWebText (OWT) dataset with muP: 53 | 54 | ``` 55 | bash mup_examples/mutransfer_lr_owt/mup/run.sh 56 | ``` 57 | 58 | and SP: 59 | 60 | ``` 61 | bash mup_examples/mutransfer_lr_owt/sp/run.sh 62 | ``` 63 | 64 | 65 | 66 | These scripts populate the `out/` subdirectories with your train loss data, which you can then plot with `mup_examples/mutransfer_lr_shakespeare_char/plot.ipynb`. -------------------------------------------------------------------------------- /data/openwebtext/prepare.py: -------------------------------------------------------------------------------- 1 | # saves the openwebtext dataset to a binary file for training.
following was helpful: 2 | # https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py 3 | 4 | import os 5 | from tqdm import tqdm 6 | import numpy as np 7 | import tiktoken 8 | from datasets import load_dataset # huggingface datasets 9 | 10 | # number of workers in .map() call 11 | # good number to use is ~order number of cpu cores // 2 12 | num_proc = 8 13 | 14 | # number of workers in load_dataset() call 15 | # best number might be different from num_proc above as it also depends on NW speed. 16 | # it is better than 1 usually though 17 | num_proc_load_dataset = num_proc 18 | 19 | enc = tiktoken.get_encoding("gpt2") 20 | 21 | if __name__ == '__main__': 22 | # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769) 23 | dataset = load_dataset("openwebtext", num_proc=num_proc_load_dataset) 24 | 25 | # owt by default only contains the 'train' split, so create a test split 26 | split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True) 27 | split_dataset['val'] = split_dataset.pop('test') # rename the test split to val 28 | 29 | # this results in: 30 | # >>> split_dataset 31 | # DatasetDict({ 32 | # train: Dataset({ 33 | # features: ['text'], 34 | # num_rows: 8009762 35 | # }) 36 | # val: Dataset({ 37 | # features: ['text'], 38 | # num_rows: 4007 39 | # }) 40 | # }) 41 | 42 | # we now want to tokenize the dataset. first define the encoding function (gpt2 bpe) 43 | def process(example): 44 | ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens 45 | ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe 46 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 
47 | out = {'ids': ids, 'len': len(ids)} 48 | return out 49 | 50 | # tokenize the dataset 51 | tokenized = split_dataset.map( 52 | process, 53 | remove_columns=['text'], 54 | desc="tokenizing the splits", 55 | num_proc=num_proc, 56 | ) 57 | 58 | # concatenate all the ids in each dataset into one large file we can use for training 59 | for split, dset in tokenized.items(): 60 | arr_len = np.sum(dset['len'], dtype=np.uint64) 61 | filename = os.path.join(os.path.dirname(__file__), f'{split}.bin') 62 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 63 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 64 | total_batches = 1024 65 | 66 | idx = 0 67 | for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'): 68 | # Batch together samples for faster write 69 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy') 70 | arr_batch = np.concatenate(batch['ids']) 71 | # Write into mmap 72 | arr[idx : idx + len(arr_batch)] = arr_batch 73 | idx += len(arr_batch) 74 | arr.flush() 75 | 76 | # train.bin is ~17GB, val.bin ~8.5MB 77 | # train has ~9B tokens (9,035,582,198) 78 | # val has ~4M tokens (4,434,897) 79 | 80 | # to read the bin files later, e.g. with numpy: 81 | # m = np.memmap('train.bin', dtype=np.uint16, mode='r') 82 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sample from a trained model 3 | """ 4 | import os 5 | import pickle 6 | from contextlib import nullcontext 7 | import torch 8 | import tiktoken 9 | from model import GPTConfig, GPT 10 | 11 | # ----------------------------------------------------------------------------- 12 | init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 
'gpt2-xl') 13 | out_dir = 'out' # ignored if init_from is not 'resume' 14 | start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt" 15 | num_samples = 10 # number of samples to draw 16 | max_new_tokens = 500 # number of tokens generated in each sample 17 | temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions 18 | top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability 19 | seed = 1337 20 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. 21 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16' 22 | compile = False # use PyTorch 2.0 to compile the model to be faster 23 | exec(open('configurator.py').read()) # overrides from command line or config file 24 | # ----------------------------------------------------------------------------- 25 | 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 29 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 30 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 31 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 32 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 33 | 34 | # model 35 | if init_from == 'resume': 36 | # init from a model saved in a specific directory 37 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 38 | checkpoint = torch.load(ckpt_path, map_location=device) 39 | gptconf = GPTConfig(**checkpoint['model_args']) 40 | model = GPT(gptconf) 41 | state_dict = checkpoint['model'] 42 | unwanted_prefix = '_orig_mod.' 
43 | for k,v in list(state_dict.items()): 44 | if k.startswith(unwanted_prefix): 45 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 46 | model.load_state_dict(state_dict) 47 | elif init_from.startswith('gpt2'): 48 | # init from a given GPT-2 model 49 | model = GPT.from_pretrained(init_from, dict(dropout=0.0)) 50 | 51 | model.eval() 52 | model.to(device) 53 | if compile: 54 | model = torch.compile(model) # requires PyTorch 2.0 (optional) 55 | 56 | # look for the meta pickle in case it is available in the dataset folder 57 | load_meta = False 58 | if init_from == 'resume' and 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these... 59 | meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl') 60 | load_meta = os.path.exists(meta_path) 61 | if load_meta: 62 | print(f"Loading meta from {meta_path}...") 63 | with open(meta_path, 'rb') as f: 64 | meta = pickle.load(f) 65 | # TODO want to make this more general to arbitrary encoder/decoder schemes 66 | stoi, itos = meta['stoi'], meta['itos'] 67 | encode = lambda s: [stoi[c] for c in s] 68 | decode = lambda l: ''.join([itos[i] for i in l]) 69 | else: 70 | # ok let's assume gpt-2 encodings by default 71 | print("No meta.pkl found, assuming GPT-2 encodings...") 72 | enc = tiktoken.get_encoding("gpt2") 73 | encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"}) 74 | decode = lambda l: enc.decode(l) 75 | 76 | # encode the beginning of the prompt 77 | if start.startswith('FILE:'): 78 | with open(start[5:], 'r', encoding='utf-8') as f: 79 | start = f.read() 80 | start_ids = encode(start) 81 | x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]) 82 | 83 | # run generation 84 | with torch.no_grad(): 85 | with ctx: 86 | for k in range(num_samples): 87 | y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k) 88 | print(decode(y[0].tolist())) 89 | print('---------------') 90 | 
-------------------------------------------------------------------------------- /bench.py: -------------------------------------------------------------------------------- 1 | """ 2 | A much shorter version of train.py for benchmarking 3 | """ 4 | import os 5 | from contextlib import nullcontext 6 | import numpy as np 7 | import time 8 | import torch 9 | from model import GPTConfig, GPT 10 | 11 | # ----------------------------------------------------------------------------- 12 | batch_size = 12 13 | block_size = 1024 14 | bias = False 15 | real_data = True 16 | seed = 1337 17 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. 18 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16' 19 | compile = True # use PyTorch 2.0 to compile the model to be faster 20 | profile = False # use pytorch profiler, or just simple benchmarking? 21 | exec(open('configurator.py').read()) # overrides from command line or config file 22 | # ----------------------------------------------------------------------------- 23 | 24 | torch.manual_seed(seed) 25 | torch.cuda.manual_seed(seed) 26 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 27 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 28 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 29 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 30 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 31 | 32 | # data loading init 33 | if real_data: 34 | dataset = 'openwebtext' 35 | data_dir = os.path.join('data', dataset) 36 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 37 | def get_batch(split): 38 | data = train_data # note ignore split in benchmarking script 39 | ix = torch.randint(len(data) - block_size, 
(batch_size,)) 40 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 41 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) 42 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 43 | return x, y 44 | else: 45 | # alternatively, if fixed data is desired to not care about data loading 46 | x = torch.randint(50304, (batch_size, block_size), device=device) 47 | y = torch.randint(50304, (batch_size, block_size), device=device) 48 | get_batch = lambda split: (x, y) 49 | 50 | # model init 51 | gptconf = GPTConfig( 52 | block_size = block_size, # how far back does the model look? i.e. context size 53 | n_layer = 12, n_head = 12, n_embd = 768, # size of the model 54 | dropout = 0, # for determinism 55 | bias = bias, 56 | ) 57 | model = GPT(gptconf) 58 | model.to(device) 59 | 60 | optimizer = model.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95), device_type=device_type) 61 | 62 | if compile: 63 | print("Compiling model...") 64 | model = torch.compile(model) # pytorch 2.0 65 | 66 | if profile: 67 | # useful docs on pytorch profiler: 68 | # - tutorial https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html 69 | # - api https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile 70 | wait, warmup, active = 5, 5, 5 71 | num_steps = wait + warmup + active 72 | with torch.profiler.profile( 73 | activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], 74 | schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1), 75 | on_trace_ready=torch.profiler.tensorboard_trace_handler('./bench_log'), 76 | record_shapes=False, 77 | profile_memory=False, 78 | with_stack=False, # incurs an additional overhead, disable if not needed 79 | with_flops=True, 80 | with_modules=False, # only for torchscript models atm 81 | ) as prof: 82 | 83 | X, Y = 
get_batch('train') 84 | for k in range(num_steps): 85 | with ctx: 86 | logits, loss = model(X, Y) 87 | X, Y = get_batch('train') 88 | optimizer.zero_grad(set_to_none=True) 89 | loss.backward() 90 | optimizer.step() 91 | lossf = loss.item() 92 | print(f"{k}/{num_steps} loss: {lossf:.4f}") 93 | 94 | prof.step() # notify the profiler at end of each step 95 | 96 | else: 97 | 98 | # simple benchmarking 99 | torch.cuda.synchronize() 100 | for stage, num_steps in enumerate([10, 20]): # burnin, then benchmark 101 | t0 = time.time() 102 | X, Y = get_batch('train') 103 | for k in range(num_steps): 104 | with ctx: 105 | logits, loss = model(X, Y) 106 | X, Y = get_batch('train') 107 | optimizer.zero_grad(set_to_none=True) 108 | loss.backward() 109 | optimizer.step() 110 | lossf = loss.item() 111 | print(f"{k}/{num_steps} loss: {lossf:.4f}") 112 | torch.cuda.synchronize() 113 | t1 = time.time() 114 | dt = t1-t0 115 | mfu = model.estimate_mfu(batch_size * 1 * num_steps, dt) 116 | if stage == 1: 117 | print(f"time per iteration: {dt/num_steps*1000:.4f}ms, MFU: {mfu*100:.2f}%") 118 | -------------------------------------------------------------------------------- /csv_logging.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authored by Gavia Gray (https://github.com/gngdb) 3 | 4 | Wrapper for wandb logging with efficient CSV logging and correct config JSON writing. 5 | The CSV structure maintains a consistent order of keys based on their first appearance, 6 | using a simple list for ordering. This ensures data integrity and allows for graceful 7 | failure and manual recovery if needed. 8 | 9 | Example usage: 10 | run = wandb.init(config=your_config) 11 | wrapper = CSVLogWrapper(logf=run.log, out_dir='path/to/output') 12 | 13 | ...
14 | # in train loop 15 | wrapper.log({"train/loss": 0.5, "train/accuracy": 0.9, "val/loss": 0.6, "val/accuracy": 0.85}) 16 | wrapper.print("Train: {loss=:.4f}, {accuracy=:.2%}", prefix="train/") 17 | wrapper.print("Val: {loss=:.4f}, {accuracy=:.2%}", prefix="val/") 18 | wrapper.step() 19 | 20 | ... 21 | # at the end of your script 22 | wrapper.close() 23 | 24 | # If the script terminates unexpectedly, you can still recover the CSV using bash: 25 | # cat path/to/output/log_header.csv.tmp path/to/output/log_data.csv.tmp > path/to/output/log.csv 26 | """ 27 | 28 | import re 29 | import os 30 | import csv 31 | import json 32 | import atexit 33 | 34 | 35 | def exists(x): return x is not None 36 | 37 | def transform_format_string(s): 38 | """ 39 | Transforms a string containing f-string-like expressions to a format 40 | compatible with str.format(). 41 | 42 | This function converts expressions like '{var=}' or '{var=:formatting}' 43 | to 'var={var}' or 'var={var:formatting}' respectively. This allows 44 | for f-string-like syntax to be used with str.format(). 45 | 46 | Args: 47 | s (str): The input string containing f-string-like expressions. 48 | 49 | Returns: 50 | str: The transformed string, compatible with str.format(). 
51 | 52 | Examples: 53 | >>> transform_format_string("Value is {x=}") 54 | "Value is x={x}" 55 | >>> transform_format_string("Formatted value is {x=:.2f}") 56 | "Formatted value is x={x:.2f}" 57 | """ 58 | pattern = r'\{(\w+)=(:.[^}]*)?\}' 59 | return re.sub(pattern, lambda m: f"{m.group(1)}={{{m.group(1)}{m.group(2) or ''}}}", s) 60 | 61 | class CSVLogWrapper: 62 | def __init__(self, logf=None, config={}, out_dir=None, flush_every: int = 100): 63 | self.logf = logf 64 | self.config = config 65 | self.log_dict = {} 66 | self.out_dir = out_dir 67 | self.csv_data_file = None 68 | self.csv_header_file = None 69 | self.csv_writer = None 70 | self.step_count = 0 71 | self.flush_every = flush_every # how often to flush; 0 = never flush mid-run 72 | self.ordered_keys = [] 73 | self.header_updated = False 74 | self.is_finalized = False 75 | self.no_sync_keyword = 'no_sync' # Keyword to prevent syncing to wandb 76 | 77 | if self.out_dir: 78 | os.makedirs(self.out_dir, exist_ok=True) 79 | self.setup_csv_writer() 80 | self.write_config() 81 | 82 | atexit.register(self.close) 83 | 84 | def setup_csv_writer(self): 85 | self.csv_data_path = os.path.join(self.out_dir, 'log_data.csv.tmp') 86 | self.csv_header_path = os.path.join(self.out_dir, 'log_header.csv.tmp') 87 | self.csv_data_file = open(self.csv_data_path, 'w', newline='') 88 | self.csv_header_file = open(self.csv_header_path, 'w', newline='') 89 | self.csv_writer = csv.writer(self.csv_data_file) 90 | 91 | def write_config(self): 92 | if self.config: 93 | config_path = os.path.join(self.out_dir, 'config.json') 94 | with open(config_path, 'w') as f: 95 | json.dump(dict(**self.config), f, indent=2) 96 | 97 | def log(self, data): 98 | self.log_dict.update(data) 99 | for key in data: 100 | if key not in self.ordered_keys: 101 | self.ordered_keys.append(key) 102 | self.header_updated = True 103 | 104 | def update_header(self): 105 | if self.header_updated: 106 | header = ['step'] + self.ordered_keys 107 | with 
open(self.csv_header_path, 'w', newline='') as header_file: 108 | csv.writer(header_file).writerow(header) 109 | self.header_updated = False 110 | 111 | def print(self, format_string, prefix=None): 112 | format_string = transform_format_string(format_string) 113 | 114 | if prefix: 115 | # Filter keys with the given prefix and remove the prefix 116 | filtered_dict = {k.replace(prefix, ''): v for k, v in self.log_dict.items() if k.startswith(prefix)} 117 | else: 118 | filtered_dict = self.log_dict 119 | # replace any '/' in keys with '_' 120 | filtered_dict = {k.replace('/', '_'): v for k, v in filtered_dict.items()} 121 | 122 | try: 123 | print(format_string.format(**filtered_dict)) 124 | except KeyError as e: 125 | print(f"KeyError: {e}. Available keys: {', '.join(filtered_dict.keys())}") 126 | raise e 127 | 128 | def step(self): 129 | if exists(self.logf) and self.log_dict: 130 | self.logf({k: v for k, v in self.log_dict.items() if self.no_sync_keyword not in k}) 131 | 132 | if self.csv_writer and self.log_dict: 133 | self.update_header() 134 | 135 | # Prepare the row data 136 | row_data = [self.step_count] + [self.log_dict.get(key, '') for key in self.ordered_keys] 137 | self.csv_writer.writerow(row_data) 138 | if self.flush_every and (self.step_count % self.flush_every == 0): 139 | self.csv_data_file.flush() 140 | 141 | self.step_count += 1 142 | self.log_dict.clear() 143 | 144 | def close(self): 145 | if self.csv_data_file: 146 | self.csv_data_file.close() 147 | 148 | self.finalize_csv() 149 | 150 | def finalize_csv(self): 151 | if self.is_finalized: 152 | return 153 | 154 | csv_final_path = os.path.join(self.out_dir, 'log.csv') 155 | 156 | with open(csv_final_path, 'w', newline='') as final_csv: 157 | # Copy header 158 | with open(self.csv_header_path, 'r') as header_file: 159 | final_csv.write(header_file.read()) 160 | 161 | # Copy data 162 | with open(self.csv_data_path, 'r') as data_file: 163 | final_csv.write(data_file.read()) 164 | self.is_finalized = 
True 165 | 166 | # Remove the temporary files 167 | os.remove(self.csv_header_path) 168 | os.remove(self.csv_data_path) 169 | -------------------------------------------------------------------------------- /transformer_sizing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "### Transformer Theoretical Model\n", 9 | "\n", 10 | "This notebook stores a bunch of analysis about a Transformer, e.g. estimates the number of FLOPs, parameters, peak memory footprint, checkpoint size, etc." 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "from collections import OrderedDict" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "# config_args = {\n", 29 | "# 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params\n", 30 | "# 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params\n", 31 | "# 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params\n", 32 | "# 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params\n", 33 | "# }[model_type]\n", 34 | "\n", 35 | "block_size = 1024\n", 36 | "vocab_size = 50257\n", 37 | "n_layer = 12\n", 38 | "n_head = 12\n", 39 | "n_embd = 768\n", 40 | "bias = False\n", 41 | "assert not bias, \"this notebook assumes bias=False just for simplicity\"" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "metadata": {}, 48 | "outputs": [ 49 | { 50 | "name": "stdout", 51 | "output_type": "stream", 52 | "text": [ 53 | "we see: 124337664, expected: 124337664, match: True\n", 54 | "name params ratio (%) \n", 55 | "emebedding/position 786432 0.6325\n", 56 | "embedding/token 38597376 31.0424\n", 57 | "embedding 39383808 31.6749\n", 58 | "attention/ln 768 
0.0006\n", 59 | "attention/kqv 1769472 1.4231\n", 60 | "attention/proj 589824 0.4744\n", 61 | "attention 2360064 1.8981\n", 62 | "mlp/ln 768 0.0006\n", 63 | "mlp/ffw 2359296 1.8975\n", 64 | "mlp/proj 2359296 1.8975\n", 65 | "mlp 4719360 3.7956\n", 66 | "block 7079424 5.6937\n", 67 | "transformer 84953088 68.3245\n", 68 | "ln_f 768 0.0006\n", 69 | "dense 0 0.0000\n", 70 | "total 124337664 100.0000\n" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "def params():\n", 76 | " \"\"\" estimates the number of parameters in the model\"\"\"\n", 77 | " out = OrderedDict()\n", 78 | "\n", 79 | " # token and position embeddings\n", 80 | " out['emebedding/position'] = n_embd * block_size\n", 81 | " out['embedding/token'] = n_embd * vocab_size\n", 82 | " out['embedding'] = out['emebedding/position'] + out['embedding/token']\n", 83 | "\n", 84 | " # attention blocks\n", 85 | " out['attention/ln'] = n_embd # note, bias=False in our LN\n", 86 | " out['attention/kqv'] = n_embd * 3*n_embd\n", 87 | " out['attention/proj'] = n_embd**2\n", 88 | " out['attention'] = out['attention/ln'] + out['attention/kqv'] + out['attention/proj']\n", 89 | "\n", 90 | " # MLP blocks\n", 91 | " ffw_size = 4*n_embd # feed forward size\n", 92 | " out['mlp/ln'] = n_embd\n", 93 | " out['mlp/ffw'] = n_embd * ffw_size\n", 94 | " out['mlp/proj'] = ffw_size * n_embd\n", 95 | " out['mlp'] = out['mlp/ln'] + out['mlp/ffw'] + out['mlp/proj']\n", 96 | " \n", 97 | " # the transformer and the rest of it\n", 98 | " out['block'] = out['attention'] + out['mlp']\n", 99 | " out['transformer'] = n_layer * out['block']\n", 100 | " out['ln_f'] = n_embd # final layernorm\n", 101 | " out['dense'] = 0 # 0 because of parameter sharing. 
This layer uses the weights from the embedding layer\n", 102 | "\n", 103 | " # total\n", 104 | " out['total'] = out['embedding'] + out['transformer'] + out['ln_f'] + out['dense']\n", 105 | "\n", 106 | " return out\n", 107 | "\n", 108 | "# compare our param count to that reported by PyTorch\n", 109 | "p = params()\n", 110 | "params_total = p['total']\n", 111 | "print(f\"we see: {params_total}, expected: {124337664}, match: {params_total == 124337664}\")\n", 112 | "# create a header\n", 113 | "print(f\"{'name':20s} {'params':10s} {'ratio (%)':10s}\")\n", 114 | "for k,v in p.items():\n", 115 | " print(f\"{k:20s} {v:10d} {v/params_total*100:10.4f}\")\n", 116 | " " 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 4, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "name": "stdout", 126 | "output_type": "stream", 127 | "text": [ 128 | "est checkpoint size: 1.49 GB\n", 129 | "measured with wc -c ckpt.pt: 1542470366\n", 130 | "fluff ratio: 103.38%\n" 131 | ] 132 | } 133 | ], 134 | "source": [ 135 | "# we can now calculate the size of each checkpoint\n", 136 | "# params are stored in fp32, and the AdamW optimizer has 2 additional buffers per param for statistics\n", 137 | "params_bytes = params_total*4\n", 138 | "params_and_buffers_bytes = params_bytes + 2*params_bytes\n", 139 | "print(f\"est checkpoint size: {params_and_buffers_bytes/1e9:.2f} GB\")\n", 140 | "measured_bytes = 1542470366 # from wc -c ckpt.pt\n", 141 | "print(f\"measured with wc -c ckpt.pt: {measured_bytes}\")\n", 142 | "print(f\"fluff ratio: {measured_bytes/params_and_buffers_bytes*100:.2f}%\")" 143 | ] 144 | }, 145 | { 146 | "attachments": {}, 147 | "cell_type": "markdown", 148 | "metadata": {}, 149 | "source": [ 150 | "We can also estimate the ratio of our GPU memory that will be taken up just by the weights and the buffers inside the AdamW optimizer" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 5, 156 | "metadata": {}, 157 | "outputs": [ 
158 | { 159 | "name": "stdout", 160 | "output_type": "stream", 161 | "text": [ 162 | "memory ratio taken up just for parameters: 3.73%\n" 163 | ] 164 | } 165 | ], 166 | "source": [ 167 | "gpu_memory = 40e9 # 40 GB A100 GPU, roughly\n", 168 | "print(f\"memory ratio taken up just for parameters: {params_and_buffers_bytes / gpu_memory * 100:.2f}%\")" 169 | ] 170 | }, 171 | { 172 | "attachments": {}, 173 | "cell_type": "markdown", 174 | "metadata": {}, 175 | "source": [ 176 | "i.e. not that much of the memory for this tiny model, most of the memory is activations (forward and backward). This of course changes dramatically for larger and larger models." 177 | ] 178 | }, 179 | { 180 | "attachments": {}, 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "Let's estimate FLOPs for a single forward pass." 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 6, 190 | "metadata": {}, 191 | "outputs": [ 192 | { 193 | "name": "stdout", 194 | "output_type": "stream", 195 | "text": [ 196 | "name flops ratio (%) \n", 197 | "attention/kqv 3623878656 1.2426\n", 198 | "attention/scores 1610612736 0.5522\n", 199 | "attention/reduce 1610612736 0.5522\n", 200 | "attention/proj 1207959552 0.4142\n", 201 | "attention 8053063680 2.7612\n", 202 | "mlp/ffw1 4831838208 1.6567\n", 203 | "mlp/ffw2 4831838208 1.6567\n", 204 | "mlp 9663676416 3.3135\n", 205 | "block 17716740096 6.0747\n", 206 | "transformer 212600881152 72.8963\n", 207 | "dense 79047426048 27.1037\n", 208 | "forward_total 291648307200 100.0000\n", 209 | "backward_total 583296614400 200.0000\n", 210 | "total 874944921600 300.0000\n" 211 | ] 212 | } 213 | ], 214 | "source": [ 215 | "def flops():\n", 216 | " # we only count Weight FLOPs, all other layers (LayerNorm, Softmax, etc) are effectively irrelevant\n", 217 | " # we count actual FLOPs, not MACs. 
Hence 2* all over the place\n", 218 | " # basically for any matrix multiply A (BxC) @ B (CxD) -> (BxD) flops are 2*B*C*D\n", 219 | "\n", 220 | " out = OrderedDict()\n", 221 | " head_size = n_embd // n_head\n", 222 | "\n", 223 | " # attention blocks\n", 224 | " # 1) the projection to key, query, values\n", 225 | " out['attention/kqv'] = 2 * block_size * (n_embd * 3*n_embd)\n", 226 | " # 2) calculating the attention scores\n", 227 | " out['attention/scores'] = 2 * block_size * block_size * n_embd\n", 228 | " # 3) the reduction of the values (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)\n", 229 | " out['attention/reduce'] = 2 * n_head * (block_size * block_size * head_size)\n", 230 | " # 4) the final linear projection\n", 231 | " out['attention/proj'] = 2 * block_size * (n_embd * n_embd)\n", 232 | " out['attention'] = sum(out['attention/'+k] for k in ['kqv', 'scores', 'reduce', 'proj'])\n", 233 | "\n", 234 | " # MLP blocks\n", 235 | " ffw_size = 4*n_embd # feed forward size\n", 236 | " out['mlp/ffw1'] = 2 * block_size * (n_embd * ffw_size)\n", 237 | " out['mlp/ffw2'] = 2 * block_size * (ffw_size * n_embd)\n", 238 | " out['mlp'] = out['mlp/ffw1'] + out['mlp/ffw2']\n", 239 | "\n", 240 | " # the transformer and the rest of it\n", 241 | " out['block'] = out['attention'] + out['mlp']\n", 242 | " out['transformer'] = n_layer * out['block']\n", 243 | " out['dense'] = 2 * block_size * (n_embd * vocab_size)\n", 244 | "\n", 245 | " # forward,backward,total\n", 246 | " out['forward_total'] = out['transformer'] + out['dense']\n", 247 | " out['backward_total'] = 2 * out['forward_total'] # use common estimate of bwd = 2*fwd\n", 248 | " out['total'] = out['forward_total'] + out['backward_total']\n", 249 | "\n", 250 | " return out\n", 251 | " \n", 252 | "# compare our param count to that reported by PyTorch\n", 253 | "f = flops()\n", 254 | "flops_total = f['forward_total']\n", 255 | "print(f\"{'name':20s} {'flops':14s} {'ratio (%)':10s}\")\n", 256 | "for k,v in f.items():\n", 257 
| " print(f\"{k:20s} {v:14d} {v/flops_total*100:10.4f}\")\n", 258 | " " 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 7, 264 | "metadata": {}, 265 | "outputs": [ 266 | { 267 | "name": "stdout", 268 | "output_type": "stream", 269 | "text": [ 270 | "palm_flops: 875062886400, flops: 874944921600, ratio: 1.0001\n" 271 | ] 272 | } 273 | ], 274 | "source": [ 275 | "# now here is an estimate copy pasted from the PaLM paper\n", 276 | "# this formula is often used to calculate MFU (model flops utilization)\n", 277 | "def palm_flops():\n", 278 | " \"\"\"estimate of the model flops following PaLM paper formula\"\"\"\n", 279 | " # non-embedding model parameters. note that we do not subtract the\n", 280 | " # embedding/token params because those are tied and get used in the last layer.\n", 281 | " N = params()['total'] - params()['emebedding/position']\n", 282 | " L, H, Q, T = n_layer, n_head, n_embd//n_head, block_size\n", 283 | " mf_per_token = 6*N + 12*L*H*Q*T\n", 284 | " mf = mf_per_token * block_size\n", 285 | " return mf\n", 286 | "\n", 287 | "print(f\"palm_flops: {palm_flops():d}, flops: {flops()['total']:d}, ratio: {palm_flops()/flops()['total']:.4f}\")" 288 | ] 289 | }, 290 | { 291 | "attachments": {}, 292 | "cell_type": "markdown", 293 | "metadata": {}, 294 | "source": [ 295 | "Ok they are quite similar, giving some confidence that my math in flops() function was ~ok. Now, A100 is cited at 312TFLOPS bfloat16 on tensor cores. So what is our model flops utilization (MFU)? I trained the model above with a batch_size of 20 and grad_accum of 5, which runs in about 755ms on a single A100 GPU. 
We get:" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 8, 301 | "metadata": {}, 302 | "outputs": [ 303 | { 304 | "name": "stdout", 305 | "output_type": "stream", 306 | "text": [ 307 | "fraction of A100 used: 37.14%\n" 308 | ] 309 | } 310 | ], 311 | "source": [ 312 | "# here is what we currently roughly measure\n", 313 | "batch_size = 20 * 5 # 5 is grad_accum, so total batch size is 100\n", 314 | "measured_time = 0.755 # in seconds per iteration\n", 315 | "measured_throughput = batch_size / measured_time\n", 316 | "flops_achieved = f['total'] * measured_throughput\n", 317 | "\n", 318 | "# A100 is cited to be 312 TFLOPS of bfloat16 running on tensor cores\n", 319 | "a100_flops_promised = 312e12\n", 320 | "\n", 321 | "# the fraction of the A100 that we are using:\n", 322 | "print(f\"fraction of A100 used: {flops_achieved / a100_flops_promised * 100:.2f}%\")" 323 | ] 324 | }, 325 | { 326 | "attachments": {}, 327 | "cell_type": "markdown", 328 | "metadata": {}, 329 | "source": [ 330 | "For reference, we'd prefer to be somewhere around 50%+, and not just for a single GPU but for an entire DDP run. So we still have some work to do, but at least we're within a factor of ~2X of what is achievable with this GPU."
331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 9, 336 | "metadata": {}, 337 | "outputs": [ 338 | { 339 | "name": "stdout", 340 | "output_type": "stream", 341 | "text": [ 342 | "time needed to train the model: 3.46 days\n" 343 | ] 344 | } 345 | ], 346 | "source": [ 347 | "# Finally let's check out the 6ND approximation as total cost of training in FLOPs\n", 348 | "model_size = params()['total'] # this is number of parameters, N\n", 349 | "tokens_num = 300e9 # 300B tokens, this is dataset size in tokens, D\n", 350 | "a100_flops = 312e12 # 312 TFLOPS\n", 351 | "assumed_mfu = 0.3 # assume this model flops utilization (take the current 37% from above and add some DDP overhead)\n", 352 | "flops_throughput = a100_flops * 8 * assumed_mfu # assume an 8XA100 node at 30% utilization\n", 353 | "flops_needed = 6 * model_size * tokens_num # 6ND\n", 354 | "time_needed_s = flops_needed / flops_throughput # in seconds\n", 355 | "print(f\"time needed to train the model: {time_needed_s/3600/24:.2f} days\")" 356 | ] 357 | }, 358 | { 359 | "attachments": {}, 360 | "cell_type": "markdown", 361 | "metadata": {}, 362 | "source": [ 363 | "This is not a bad estimate at all. I trained this model and it converged in roughly 4 days. Btw as a good reference for where 6ND comes from and some intuition around it I recommend [Dzmitry's post](https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4)." 364 | ] 365 | }, 366 | { 367 | "attachments": {}, 368 | "cell_type": "markdown", 369 | "metadata": {}, 370 | "source": [ 371 | "Now, FLOPs are just one constraint, the other that we have to keep a close track of is the memory bandwidth. TODO estimate LOAD/STORE costs of our model later." 
372 | ] 373 | } 374 | ], 375 | "metadata": { 376 | "kernelspec": { 377 | "display_name": "pytorch2", 378 | "language": "python", 379 | "name": "python3" 380 | }, 381 | "language_info": { 382 | "codemirror_mode": { 383 | "name": "ipython", 384 | "version": 3 385 | }, 386 | "file_extension": ".py", 387 | "mimetype": "text/x-python", 388 | "name": "python", 389 | "nbconvert_exporter": "python", 390 | "pygments_lexer": "ipython3", 391 | "version": "3.10.8" 392 | }, 393 | "orig_nbformat": 4, 394 | "vscode": { 395 | "interpreter": { 396 | "hash": "7f5833218766b48e6e35e4452ee875aac0e2188d05bbe5298f2c62b79f08b222" 397 | } 398 | } 399 | }, 400 | "nbformat": 4, 401 | "nbformat_minor": 2 402 | } 403 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # nanoGPT-mup 3 | 4 | This repository is a fork of [nanoGPT](https://github.com/karpathy/nanoGPT) that provides a minimal implementation of the [maximal update parameterization](https://arxiv.org/abs/2203.03466) ([muP](https://github.com/microsoft/mup)). 5 | 6 | Branches 7 | - The [master](https://github.com/EleutherAI/nanoGPT-mup) branch acts as supplementary material for ["The Practitioner’s Guide to the Maximal Update Parameterization"](https://www.cerebras.ai/blog/the-practitioners-guide-to-the-maximal-update-parameterization). 8 | - The [supar](https://github.com/EleutherAI/nanoGPT-mup/tree/supar) branch contains a minimal implementation of sparse maximal update parameterization (SuPar) introduced in [Sparse maximal update parameterization: A holistic approach to sparse training dynamics](https://arxiv.org/abs/2405.15743). 9 | - The [completep](https://github.com/EleutherAI/nanoGPT-mup/tree/completep) branch contains a minimal implementation of CompleteP introduced in [Don't be lazy: CompleteP enables compute-efficient deep transformers](https://arxiv.org/abs/2505.01618). 
10 | 11 | The [mup_examples](https://github.com/EleutherAI/nanoGPT-mup/tree/master/mup_examples) folder contains scripts to reproduce the plots in ["The Practitioner’s Guide to the Maximal Update Parameterization"](https://www.cerebras.ai/blog/the-practitioners-guide-to-the-maximal-update-parameterization) (see [mup_examples/README.md](https://github.com/EleutherAI/nanoGPT-mup/blob/master/mup_examples/README.md) for instructions to reproduce). 12 | 13 | Each of the critical muP changes is marked with 14 | ``` 15 | ### Begin muP code ### 16 | 17 | ### End muP code ### 18 | ``` 19 | to make everything easily searchable. 20 | 21 | | Parameterization | SP | **μP** | Code | 22 | |------------------|----|----|----| 23 | | Embedding Init. Var. | $σ_{base}^2$ | $σ_{base}^2$ | | 24 | | Embedding LR | $η_{base}$ | $η_{base}$ | | 25 | | Embedding Fwd. | $x W_{\text{emb}}$ | $\mathbf{α_{input}} · x W_{\text{emb}}$ | [Code](https://github.com/EleutherAI/nanoGPT-mup/blob/bcadbc3c7a44138525eca8a799764afba7dca2b3/model.py#L208) | 26 | | Hidden Init. Var. | $σ_{base}^2$ | $σ_{base}^2 / \mathbf{m_d}$ | [Code](https://github.com/EleutherAI/nanoGPT-mup/blob/bcadbc3c7a44138525eca8a799764afba7dca2b3/model.py#L163-L169) | 27 | | Hidden LR (Adam) | $η_{base}$ | $η_{base} / \mathbf{m_d}$ | [Code](https://github.com/EleutherAI/nanoGPT-mup/blob/bcadbc3c7a44138525eca8a799764afba7dca2b3/model.py#L306-L329) | 28 | | Output Logit Fwd.
| $x W_{\text{emb}}^\top$ | $\mathbf{α_{output}} · x W_{\text{emb}}^\top / \mathbf{m_d}$ | [Code](https://github.com/EleutherAI/nanoGPT-mup/blob/bcadbc3c7a44138525eca8a799764afba7dca2b3/model.py#L219) | 29 | | Attention logits | $Q^\top K / \sqrt{d_{\text{head}}}$ | $Q^\top K / \mathbf{d_{\text{head}}}$ | [Code](https://github.com/EleutherAI/nanoGPT-mup/blob/bcadbc3c7a44138525eca8a799764afba7dca2b3/model.py#L65) | 30 | 31 | 32 | ## Implementation Validation 33 | 34 | ### Coordinate Checks 35 | 36 | Standard Parameterization: 37 | 38 | SP 39 | 40 | Maximal Update Parameterization (muP): 41 | 42 | muP 43 | 44 | 45 | ### Learning Rate muTransfer 46 | 47 | **Tiny Shakespeare** | **OpenWebText** 48 | :-------------------------:|:-------------------------: 49 | mup-shakespeare | mup-owt 50 | 51 | 52 | ## Citation 53 | 54 | If ["The Practitioner’s Guide to the Maximal Update Parameterization"](https://www.cerebras.ai/blog/the-practitioners-guide-to-the-maximal-update-parameterization) or this repository was useful to you, please cite: 55 | ``` 56 | @misc{cerebras2024mupguide, 57 | author = {Dey, Nolan and Anthony, Quentin and Hestness, Joel}, 58 | title = {{The practitioner’s guide to the maximal update parameterization}}, 59 | month = sep, 60 | year = 2024, 61 | howpublished = {\url{https://www.cerebras.ai/blog/the-practitioners-guide-to-the-maximal-update-parameterization}}, 62 | url = {https://www.cerebras.ai/blog/the-practitioners-guide-to-the-maximal-update-parameterization}, 63 | } 64 | ``` 65 | 66 | # nanoGPT (Original README) 67 | 68 | ![nanoGPT](assets/nanogpt.jpg) 69 | 70 | The simplest, fastest repository for training/finetuning medium-sized GPTs. It is a rewrite of [minGPT](https://github.com/karpathy/minGPT) that prioritizes teeth over education. Still under active development, but currently the file `train.py` reproduces GPT-2 (124M) on OpenWebText, running on a single 8XA100 40GB node in about 4 days of training.
The code itself is plain and readable: `train.py` is a ~300-line boilerplate training loop and `model.py` a ~300-line GPT model definition, which can optionally load the GPT-2 weights from OpenAI. That's it. 71 | 72 | ![repro124m](assets/gpt2_124M_loss.png) 73 | 74 | Because the code is so simple, it is very easy to hack to your needs, train new models from scratch, or finetune pretrained checkpoints (e.g. the biggest one currently available as a starting point would be the GPT-2 1.5B model from OpenAI). 75 | 76 | ## install 77 | 78 | ``` 79 | pip install torch numpy transformers datasets tiktoken wandb tqdm 80 | ``` 81 | 82 | Dependencies: 83 | 84 | - [pytorch](https://pytorch.org) <3 85 | - [numpy](https://numpy.org/install/) <3 86 | - `transformers` for huggingface transformers <3 (to load GPT-2 checkpoints) 87 | - `datasets` for huggingface datasets <3 (if you want to download + preprocess OpenWebText) 88 | - `tiktoken` for OpenAI's fast BPE code <3 89 | - `wandb` for optional logging <3 90 | - `tqdm` for progress bars <3 91 | 92 | ## quick start 93 | 94 | If you are not a deep learning professional and you just want to feel the magic and get your feet wet, the fastest way to get started is to train a character-level GPT on the works of Shakespeare. First, we download it as a single (1MB) file and turn it from raw text into one large stream of integers: 95 | 96 | ```sh 97 | python data/shakespeare_char/prepare.py 98 | ``` 99 | 100 | This creates a `train.bin` and `val.bin` in that data directory. Now it is time to train your GPT. The size of it very much depends on the computational resources of your system: 101 | 102 | **I have a GPU**.
Great, we can quickly train a baby GPT with the settings provided in the [config/train_shakespeare_char.py](config/train_shakespeare_char.py) config file: 103 | 104 | ```sh 105 | python train.py config/train_shakespeare_char.py 106 | ``` 107 | 108 | If you peek inside it, you'll see that we're training a GPT with a context size of up to 256 characters, 384 feature channels, and it is a 6-layer Transformer with 6 heads in each layer. On one A100 GPU this training run takes about 3 minutes and the best validation loss is 1.4697. Based on the configuration, the model checkpoints are being written into the `--out_dir` directory `out-shakespeare-char`. So once the training finishes we can sample from the best model by pointing the sampling script at this directory: 109 | 110 | ```sh 111 | python sample.py --out_dir=out-shakespeare-char 112 | ``` 113 | 114 | This generates a few samples, for example: 115 | 116 | ``` 117 | ANGELO: 118 | And cowards it be strawn to my bed, 119 | And thrust the gates of my threats, 120 | Because he that ale away, and hang'd 121 | An one with him. 122 | 123 | DUKE VINCENTIO: 124 | I thank your eyes against it. 125 | 126 | DUKE VINCENTIO: 127 | Then will answer him to save the malm: 128 | And what have you tyrannous shall do this? 129 | 130 | DUKE VINCENTIO: 131 | If you have done evils of all disposition 132 | To end his power, the day of thrust for a common men 133 | That I leave, to fight with over-liking 134 | Hasting in a roseman. 135 | ``` 136 | 137 | lol `¯\_(ツ)_/¯`. Not bad for a character-level model after 3 minutes of training on a GPU. Better results are quite likely obtainable by instead finetuning a pretrained GPT-2 model on this dataset (see finetuning section later). 138 | 139 | **I only have a macbook** (or other cheap computer). No worries, we can still train a GPT but we want to dial things down a notch. 
I recommend getting the bleeding edge PyTorch nightly ([select it here](https://pytorch.org/get-started/locally/) when installing) as it is currently quite likely to make your code more efficient. But even without it, a simple train run could look as follows: 140 | 141 | ```sh 142 | python train.py config/train_shakespeare_char.py --device=cpu --compile=False --eval_iters=20 --log_interval=1 --block_size=64 --batch_size=12 --n_layer=4 --n_head=4 --n_embd=128 --max_iters=2000 --lr_decay_iters=2000 --dropout=0.0 143 | ``` 144 | 145 | Here, since we are running on CPU instead of GPU we must set both `--device=cpu` and also turn off PyTorch 2.0 compile with `--compile=False`. Then when we evaluate we get a bit more noisy but faster estimate (`--eval_iters=20`, down from 200), our context size is only 64 characters instead of 256, and the batch size only 12 examples per iteration, not 64. We'll also use a much smaller Transformer (4 layers, 4 heads, 128 embedding size), and decrease the number of iterations to 2000 (and correspondingly usually decay the learning rate to around max_iters with `--lr_decay_iters`). Because our network is so small we also ease down on regularization (`--dropout=0.0`). This still runs in about ~3 minutes, but gets us a loss of only 1.88 and therefore also worse samples, but it's still good fun: 146 | 147 | ```sh 148 | python sample.py --out_dir=out-shakespeare-char --device=cpu 149 | ``` 150 | Generates samples like this: 151 | 152 | ``` 153 | GLEORKEN VINGHARD III: 154 | Whell's the couse, the came light gacks, 155 | And the for mought you in Aut fries the not high shee 156 | bot thou the sought bechive in that to doth groan you, 157 | No relving thee post mose the wear 158 | ``` 159 | 160 | Not bad for ~3 minutes on a CPU, for a hint of the right character gestalt. If you're willing to wait longer, feel free to tune the hyperparameters, increase the size of the network, the context length (`--block_size`), the length of training, etc. 
161 | 162 | Finally, on Apple Silicon Macbooks and with a recent PyTorch version make sure to add `--device=mps` (short for "Metal Performance Shaders"); PyTorch then uses the on-chip GPU that can *significantly* accelerate training (2-3X) and allow you to use larger networks. See [Issue 28](https://github.com/karpathy/nanoGPT/issues/28) for more. 163 | 164 | ## reproducing GPT-2 165 | 166 | A more serious deep learning professional may be more interested in reproducing GPT-2 results. So here we go - we first tokenize the dataset, in this case the [OpenWebText](https://openwebtext2.readthedocs.io/en/latest/), an open reproduction of OpenAI's (private) WebText: 167 | 168 | ```sh 169 | python data/openwebtext/prepare.py 170 | ``` 171 | 172 | This downloads and tokenizes the [OpenWebText](https://huggingface.co/datasets/openwebtext) dataset. It will create a `train.bin` and `val.bin` which holds the GPT2 BPE token ids in one sequence, stored as raw uint16 bytes. Then we're ready to kick off training. To reproduce GPT-2 (124M) you'll want at least an 8X A100 40GB node and run: 173 | 174 | ```sh 175 | torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py 176 | ``` 177 | 178 | This will run for about 4 days using PyTorch Distributed Data Parallel (DDP) and go down to loss of ~2.85. Now, a GPT-2 model just evaluated on OWT gets a val loss of about 3.11, but if you finetune it it will come down to ~2.85 territory (due to an apparent domain gap), making the two models ~match. 179 | 180 | If you're in a cluster environment and you are blessed with multiple GPU nodes you can make GPU go brrrr e.g. 
across 2 nodes like: 181 | 182 | ```sh 183 | # Run on the first (master) node with example IP 123.456.123.456: 184 | torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py 185 | # Run on the worker node: 186 | torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py 187 | ``` 188 | 189 | It is a good idea to benchmark your interconnect (e.g. iperf3). In particular, if you don't have Infiniband then also prepend `NCCL_IB_DISABLE=1` to the above launches. Your multinode training will work, but most likely _crawl_. By default checkpoints are periodically written to the `--out_dir`. We can sample from the model by simply `python sample.py`. 190 | 191 | Finally, to train on a single GPU simply run the `python train.py` script. Have a look at all of its args, the script tries to be very readable, hackable and transparent. You'll most likely want to tune a number of those variables depending on your needs. 192 | 193 | ## baselines 194 | 195 | OpenAI GPT-2 checkpoints allow us to get some baselines in place for openwebtext. We can get the numbers as follows: 196 | 197 | ```sh 198 | $ python train.py config/eval_gpt2.py 199 | $ python train.py config/eval_gpt2_medium.py 200 | $ python train.py config/eval_gpt2_large.py 201 | $ python train.py config/eval_gpt2_xl.py 202 | ``` 203 | 204 | and observe the following losses on train and val: 205 | 206 | | model | params | train loss | val loss | 207 | | ------| ------ | ---------- | -------- | 208 | | gpt2 | 124M | 3.11 | 3.12 | 209 | | gpt2-medium | 350M | 2.85 | 2.84 | 210 | | gpt2-large | 774M | 2.66 | 2.67 | 211 | | gpt2-xl | 1558M | 2.56 | 2.54 | 212 | 213 | However, we have to note that GPT-2 was trained on (closed, never released) WebText, while OpenWebText is just a best-effort open reproduction of this dataset. This means there is a dataset domain gap. 
Indeed, taking the GPT-2 (124M) checkpoint and finetuning on OWT directly for a while reaches loss down to ~2.85. This then becomes the more appropriate baseline w.r.t. reproduction. 214 | 215 | ## finetuning 216 | 217 | Finetuning is no different than training, we just make sure to initialize from a pretrained model and train with a smaller learning rate. For an example of how to finetune a GPT on new text go to `data/shakespeare` and run `prepare.py` to download the tiny shakespeare dataset and render it into a `train.bin` and `val.bin`, using the OpenAI BPE tokenizer from GPT-2. Unlike OpenWebText this will run in seconds. Finetuning can take very little time, e.g. on a single GPU just a few minutes. Run an example finetuning like: 218 | 219 | ```sh 220 | python train.py config/finetune_shakespeare.py 221 | ``` 222 | 223 | This will load the config parameter overrides in `config/finetune_shakespeare.py` (I didn't tune them much though). Basically, we initialize from a GPT2 checkpoint with `init_from` and train as normal, except shorter and with a small learning rate. If you're running out of memory try decreasing the model size (they are `{'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}`) or possibly decreasing the `block_size` (context length). The best checkpoint (lowest validation loss) will be in the `out_dir` directory, e.g. in `out-shakespeare` by default, per the config file. You can then run the code in `sample.py --out_dir=out-shakespeare`: 224 | 225 | ``` 226 | THEODORE: 227 | Thou shalt sell me to the highest bidder: if I die, 228 | I sell thee to the first; if I go mad, 229 | I sell thee to the second; if I 230 | lie, I sell thee to the third; if I slay, 231 | I sell thee to the fourth: so buy or sell, 232 | I tell thee again, thou shalt not sell my 233 | possession. 234 | 235 | JULIET: 236 | And if thou steal, thou shalt not sell thyself. 237 | 238 | THEODORE: 239 | I do not steal; I sell the stolen goods. 
240 | 241 | THEODORE: 242 | Thou know'st not what thou sell'st; thou, a woman, 243 | Thou art ever a victim, a thing of no worth: 244 | Thou hast no right, no right, but to be sold. 245 | ``` 246 | 247 | Whoa there, GPT, entering some dark place over there. I didn't really tune the hyperparameters in the config too much, feel free to try! 248 | 249 | ## sampling / inference 250 | 251 | Use the script `sample.py` to sample either from pre-trained GPT-2 models released by OpenAI, or from a model you trained yourself. For example, here is a way to sample from the largest available `gpt2-xl` model: 252 | 253 | ```sh 254 | python sample.py \ 255 | --init_from=gpt2-xl \ 256 | --start="What is the answer to life, the universe, and everything?" \ 257 | --num_samples=5 --max_new_tokens=100 258 | ``` 259 | 260 | If you'd like to sample from a model you trained, use the `--out_dir` to point the code appropriately. You can also prompt the model with some text from a file, e.g. ```python sample.py --start=FILE:prompt.txt```. 261 | 262 | ## efficiency notes 263 | 264 | For simple model benchmarking and profiling, `bench.py` might be useful. It's identical to what happens in the meat of the training loop of `train.py`, but omits much of the other complexities. 265 | 266 | Note that the code by default uses [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/). At the time of writing (Dec 29, 2022) this makes `torch.compile()` available in the nightly release. The improvement from the one line of code is noticeable, e.g. cutting down iteration time from ~250ms / iter to 135ms / iter. Nice work PyTorch team! 267 | 268 | ## todos 269 | 270 | - Investigate and add FSDP instead of DDP 271 | - Eval zero-shot perplexities on standard evals (e.g. LAMBADA? HELM? etc.) 
272 | - Finetune the finetuning script, I think the hyperparams are not great 273 | - Schedule for linear batch size increase during training 274 | - Incorporate other embeddings (rotary, alibi) 275 | - Separate out the optim buffers from model params in checkpoints I think 276 | - Additional logging around network health (e.g. gradient clip events, magnitudes) 277 | - Few more investigations around better init etc. 278 | 279 | ## troubleshooting 280 | 281 | Note that by default this repo uses PyTorch 2.0 (i.e. `torch.compile`). This is fairly new and experimental, and not yet available on all platforms (e.g. Windows). If you're running into related error messages try to disable this by adding `--compile=False` flag. This will slow down the code but at least it will run. 282 | 283 | For some context on this repository, GPT, and language modeling it might be helpful to watch my [Zero To Hero series](https://karpathy.ai/zero-to-hero.html). Specifically, the [GPT video](https://www.youtube.com/watch?v=kCc8FmEb1nY) is popular if you have some prior language modeling context. 284 | 285 | For more questions/discussions feel free to stop by **#nanoGPT** on Discord: 286 | 287 | [![](https://dcbadge.vercel.app/api/server/3zy8kqD9Cp?compact=true&style=flat)](https://discord.gg/3zy8kqD9Cp) 288 | 289 | ## acknowledgements 290 | 291 | All nanoGPT experiments are powered by GPUs on [Lambda labs](https://lambdalabs.com), my favorite Cloud GPU provider. Thank you Lambda labs for sponsoring nanoGPT! 292 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This training script can be run both on a single gpu in debug mode, 3 | and also in a larger training run with distributed data parallel (ddp). 
4 | 5 | To run on a single GPU, example: 6 | $ python train.py --batch_size=32 --compile=False 7 | 8 | To run with DDP on 4 gpus on 1 node, example: 9 | $ torchrun --standalone --nproc_per_node=4 train.py 10 | 11 | To run with DDP on 4 gpus across 2 nodes, example: 12 | - Run on the first (master) node with example IP 123.456.123.456: 13 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py 14 | - Run on the worker node: 15 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py 16 | (If your cluster does not have Infiniband interconnect prepend NCCL_IB_DISABLE=1) 17 | """ 18 | 19 | import os 20 | import time 21 | import math 22 | import pickle 23 | from contextlib import nullcontext 24 | from functools import partial 25 | 26 | import numpy as np 27 | import torch 28 | from torch.nn.parallel import DistributedDataParallel as DDP 29 | from torch.distributed import init_process_group, destroy_process_group 30 | 31 | from model import GPTConfig, GPT 32 | 33 | # ----------------------------------------------------------------------------- 34 | # default config values designed to train a gpt2 (124M) on OpenWebText 35 | # I/O 36 | out_dir = 'out' 37 | eval_interval = 2000 38 | log_interval = 1 39 | eval_iters = 200 40 | eval_only = False # if True, script exits right after the first eval 41 | skip_val_loss = False # If True, will only measure train loss 42 | always_save_checkpoint = True # if True, always save a checkpoint after each eval 43 | never_save_checkpoint = False # if True, never save a checkpoint 44 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*' 45 | # wandb logging 46 | wandb_log = False # disabled by default 47 | wandb_project = 'owt' 48 | wandb_run_name = 'gpt2' # 'run' + str(time.time()) 49 | # csv logging 50 | csv_log = False # If enabled, logs stats to a csv file 51 | flush_every = 100 # how often to flush, set to 0 to only 
flush on close 52 | # data 53 | dataset = 'openwebtext' 54 | gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes 55 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size 56 | block_size = 1024 57 | # model 58 | n_layer = 12 59 | n_head = 12 60 | n_embd = 768 61 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 62 | bias = False # do we use bias inside LayerNorm and Linear layers? 63 | init_std = 0.02 # Initialization standard deviation for weights 64 | # adamw optimizer 65 | learning_rate = 6e-4 # max learning rate 66 | max_iters = 600000 # total number of training iterations 67 | weight_decay = 1e-1 68 | beta1 = 0.9 69 | beta2 = 0.95 70 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 71 | # learning rate decay settings 72 | decay_lr = True # whether to decay the learning rate 73 | warmup_iters = 2000 # how many steps to warm up for 74 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla 75 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 76 | # mup settings 77 | mup_enabled = False # Whether to use muP. 
If False then all other mup variables are ignored 78 | mup_disable_attention_scaling = False # Uses 1/sqrt(d_head) attn scaling instead of 1/d_head (Only needed for the step-by-step coord check in the blog) 79 | mup_disable_hidden_lr_scaling = False # Disables muP hidden LR adjustment (Only needed for the step-by-step coord check in the blog) 80 | mup_width_multiplier = 1.0 # mup_width_multiplier = width / base_width where base_width is typically 256 81 | mup_input_alpha = 1.0 # Optional tunable multiplier applied to input embedding forward pass output 82 | mup_output_alpha = 1.0 # Optional tunable multiplier applied to output unembedding forward pass output 83 | mup_enable_coord_check_logging = False # If True will track the output.abs().mean() of various layers throughout training 84 | # seed 85 | seed = 1337 86 | # DDP settings 87 | backend = 'nccl' # 'nccl', 'gloo', etc. 88 | # system 89 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 90 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 91 | compile = True # use PyTorch 2.0 to compile the model to be faster 92 | # ----------------------------------------------------------------------------- 93 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 94 | exec(open('configurator.py').read()) # overrides from command line or config file 95 | config = {k: globals()[k] for k in config_keys} # will be useful for logging 96 | # ----------------------------------------------------------------------------- 97 | 98 | assert not (never_save_checkpoint and always_save_checkpoint) 99 | 100 | # various inits, derived attributes, I/O setup 101 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? 
102 | if ddp: 103 | init_process_group(backend=backend) 104 | ddp_rank = int(os.environ['RANK']) 105 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 106 | ddp_world_size = int(os.environ['WORLD_SIZE']) 107 | device = f'cuda:{ddp_local_rank}' 108 | torch.cuda.set_device(device) 109 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 110 | seed_offset = ddp_rank # each process gets a different seed 111 | # world_size number of processes will be training simultaneously, so we can scale 112 | # down the desired gradient accumulation iterations per process proportionally 113 | assert gradient_accumulation_steps % ddp_world_size == 0 114 | gradient_accumulation_steps //= ddp_world_size 115 | else: 116 | # if not ddp, we are running on a single gpu, and one process 117 | master_process = True 118 | seed_offset = 0 119 | ddp_world_size = 1 120 | tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size 121 | print(f"tokens per iteration will be: {tokens_per_iter:,}") 122 | 123 | if master_process: 124 | os.makedirs(out_dir, exist_ok=True) 125 | torch.manual_seed(seed + seed_offset) 126 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 127 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 128 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 129 | # note: float16 data type will automatically use a GradScaler 130 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 131 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 132 | 133 | # poor man's data loader 134 | data_dir = os.path.join('data', dataset) 135 | def get_batch(split): 136 | # We recreate np.memmap every batch to avoid a memory leak, as per 137 | # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122 138 | if split == 
'train': 139 | data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 140 | else: 141 | data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') 142 | ix = torch.randint(len(data) - block_size, (batch_size,)) 143 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 144 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) 145 | if device_type == 'cuda': 146 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 147 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 148 | else: 149 | x, y = x.to(device), y.to(device) 150 | return x, y 151 | 152 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint) 153 | iter_num = 0 154 | best_val_loss = 1e9 155 | 156 | # attempt to derive vocab_size from the dataset 157 | meta_path = os.path.join(data_dir, 'meta.pkl') 158 | meta_vocab_size = None 159 | if os.path.exists(meta_path): 160 | with open(meta_path, 'rb') as f: 161 | meta = pickle.load(f) 162 | meta_vocab_size = meta['vocab_size'] 163 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})") 164 | 165 | # model init 166 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, 167 | bias=bias, vocab_size=None, dropout=dropout, mup_enabled=mup_enabled, 168 | mup_disable_attention_scaling=mup_disable_attention_scaling, 169 | mup_disable_hidden_lr_scaling=mup_disable_hidden_lr_scaling, 170 | mup_width_multiplier=mup_width_multiplier, mup_input_alpha=mup_input_alpha, 171 | mup_output_alpha=mup_output_alpha) # start with model_args from command line 172 | 173 | if init_from == 'scratch': 174 | # init a new model from scratch 175 | print("Initializing a new model from scratch") 176 | # determine the vocab size we'll use for from-scratch training 177 | if meta_vocab_size is None: 178 | 
print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") 179 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304 180 | gptconf = GPTConfig(**model_args) 181 | model = GPT(gptconf) 182 | elif init_from == 'resume': 183 | print(f"Resuming training from {out_dir}") 184 | # resume training from a checkpoint. 185 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 186 | checkpoint = torch.load(ckpt_path, map_location=device) 187 | checkpoint_model_args = checkpoint['model_args'] 188 | # force these config attributes to be equal otherwise we can't even resume training 189 | # the rest of the attributes (e.g. dropout) can stay as desired from command line 190 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 191 | model_args[k] = checkpoint_model_args[k] 192 | # create the model 193 | gptconf = GPTConfig(**model_args) 194 | model = GPT(gptconf) 195 | state_dict = checkpoint['model'] 196 | # fix the keys of the state dictionary :( 197 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 198 | unwanted_prefix = '_orig_mod.' 
199 | for k,v in list(state_dict.items()): 200 | if k.startswith(unwanted_prefix): 201 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 202 | model.load_state_dict(state_dict) 203 | iter_num = checkpoint['iter_num'] 204 | best_val_loss = checkpoint['best_val_loss'] 205 | elif init_from.startswith('gpt2'): 206 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}") 207 | # initialize from OpenAI GPT-2 weights 208 | override_args = dict(dropout=dropout) 209 | model = GPT.from_pretrained(init_from, override_args) 210 | # read off the created config params, so we can store them into checkpoint correctly 211 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 212 | model_args[k] = getattr(model.config, k) 213 | # crop down the model block size if desired, using model surgery 214 | if block_size < model.config.block_size: 215 | model.crop_block_size(block_size) 216 | model_args['block_size'] = block_size # so that the checkpoint will have the right value 217 | model.to(device) 218 | 219 | # initialize a GradScaler. If enabled=False scaler is a no-op 220 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16')) 221 | 222 | # optimizer 223 | optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type) 224 | if init_from == 'resume': 225 | optimizer.load_state_dict(checkpoint['optimizer']) 226 | checkpoint = None # free up memory 227 | 228 | # compile the model 229 | if compile: 230 | print("compiling the model... 
(takes a ~minute)") 231 | unoptimized_model = model 232 | model = torch.compile(model) # requires PyTorch 2.0 233 | 234 | # wrap model into DDP container 235 | if ddp: 236 | model = DDP(model, device_ids=[ddp_local_rank]) 237 | 238 | # helps estimate an arbitrarily accurate loss over either split using many batches 239 | @torch.no_grad() 240 | def estimate_loss(): 241 | out = {} 242 | model.eval() 243 | splits = ['train'] if skip_val_loss else ['train', 'val'] 244 | for split in splits: 245 | losses = torch.zeros(eval_iters) 246 | for k in range(eval_iters): 247 | X, Y = get_batch(split) 248 | with ctx: 249 | logits, loss = model(X, Y) 250 | losses[k] = loss.item() 251 | out[split] = losses.mean().item() 252 | if skip_val_loss: 253 | out['val'] = -1 254 | model.train() 255 | return out 256 | 257 | # learning rate decay scheduler (cosine with warmup) 258 | def get_lr(it): 259 | # 1) linear warmup for warmup_iters steps 260 | if it < warmup_iters: 261 | return learning_rate * it / warmup_iters 262 | # 2) if it > lr_decay_iters, return min learning rate 263 | if it > lr_decay_iters: 264 | return min_lr 265 | # 3) in between, use cosine decay down to min learning rate 266 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 267 | assert 0 <= decay_ratio <= 1 268 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 269 | return min_lr + coeff * (learning_rate - min_lr) 270 | 271 | # logging 272 | if master_process: 273 | if wandb_log: 274 | import wandb 275 | wandb_run = wandb.init(project=wandb_project, name=wandb_run_name, config=config) 276 | if csv_log: 277 | from csv_logging import CSVLogWrapper 278 | def log(log_dict): 279 | pass 280 | csv_logger = CSVLogWrapper(log, config=config, out_dir=out_dir, flush_every=flush_every) 281 | 282 | # training loop 283 | X, Y = get_batch('train') # fetch the very first batch 284 | t0 = time.time() 285 | local_iter_num = 0 # number of iterations in the lifetime of this process 286 | 
raw_model = model.module if ddp else model # unwrap DDP container if needed 287 | running_mfu = -1.0 288 | coord_check_dict = None 289 | while True: 290 | 291 | # determine and set the learning rate for this iteration 292 | lr = get_lr(iter_num) if decay_lr else learning_rate 293 | for param_group in optimizer.param_groups: 294 | param_group['lr'] = lr * param_group.get('lr_scale', 1.0) 295 | 296 | # evaluate the loss on train/val sets and write checkpoints 297 | if iter_num % eval_interval == 0 and master_process: 298 | losses = estimate_loss() 299 | if np.isnan(losses['train']): 300 | raise Exception('NaN loss') 301 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 302 | log_dict = { 303 | "iter": iter_num, 304 | "train/loss": losses['train'], 305 | "val/loss": losses['val'], 306 | "lr": lr, 307 | "mfu": running_mfu*100, # convert to percentage 308 | } 309 | if mup_enable_coord_check_logging and coord_check_dict is not None: 310 | for key in coord_check_dict: 311 | log_dict[key + '_act_abs_mean'] = np.mean(coord_check_dict[key]) 312 | if wandb_log: 313 | wandb_run.log(log_dict) 314 | if csv_log: 315 | csv_logger.log(log_dict) 316 | csv_logger.step() 317 | if (not never_save_checkpoint) and (losses['val'] < best_val_loss or always_save_checkpoint): 318 | best_val_loss = losses['val'] 319 | if iter_num > 0: 320 | checkpoint = { 321 | 'model': raw_model.state_dict(), 322 | 'optimizer': optimizer.state_dict(), 323 | 'model_args': model_args, 324 | 'iter_num': iter_num, 325 | 'best_val_loss': best_val_loss, 326 | 'config': config, 327 | } 328 | print(f"saving checkpoint to {out_dir}") 329 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) 330 | if iter_num == 0 and eval_only: 331 | break 332 | 333 | if mup_enable_coord_check_logging: 334 | coord_check_dict = { 335 | 'token_embedding': [], 336 | 'attn': [], 337 | 'mlp': [], 338 | 'lm_head': [], 339 | } 340 | def hook(module, input, output, key): 341 | with 
torch.no_grad(): 342 | coord_check_dict[key].append(output.abs().mean().item()) 343 | coord_check_handles = [] 344 | for module_name, module in model.named_modules(): 345 | if module_name == 'transformer.wte': 346 | coord_check_handles.append(module.register_forward_hook(partial(hook, key='token_embedding'))) 347 | elif module_name.endswith('.attn'): 348 | coord_check_handles.append(module.register_forward_hook(partial(hook, key='attn'))) 349 | elif module_name.endswith('.mlp'): 350 | coord_check_handles.append(module.register_forward_hook(partial(hook, key='mlp'))) 351 | elif module_name == 'lm_head': 352 | coord_check_handles.append(module.register_forward_hook(partial(hook, key='lm_head'))) 353 | else: 354 | coord_check_dict = None 355 | 356 | # forward backward update, with optional gradient accumulation to simulate larger batch size 357 | # and using the GradScaler if data type is float16 358 | for micro_step in range(gradient_accumulation_steps): 359 | if ddp: 360 | # in DDP training we only need to sync gradients at the last micro step. 
361 | # the official way to do this is with model.no_sync() context manager, but 362 | # I really dislike that this bloats the code and forces us to repeat code 363 | # looking at the source of that context manager, it just toggles this variable 364 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) 365 | with ctx: 366 | logits, loss = model(X, Y) 367 | loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation 368 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 369 | X, Y = get_batch('train') 370 | # backward pass, with gradient scaling if training in fp16 371 | scaler.scale(loss).backward() 372 | # clip the gradient 373 | if grad_clip != 0.0: 374 | scaler.unscale_(optimizer) 375 | torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 376 | # step the optimizer and scaler if training in fp16 377 | scaler.step(optimizer) 378 | scaler.update() 379 | # flush the gradients as soon as we can, no need for this memory anymore 380 | optimizer.zero_grad(set_to_none=True) 381 | 382 | # timing and logging 383 | t1 = time.time() 384 | dt = t1 - t0 385 | t0 = t1 386 | if iter_num % log_interval == 0 and master_process: 387 | # get loss as float. 
note: this is a CPU-GPU sync point 388 | # scale up to undo the division above, approximating the true total loss (exact would have been a sum) 389 | lossf = loss.item() * gradient_accumulation_steps 390 | if local_iter_num >= 5: # let the training loop settle a bit 391 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt) 392 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu 393 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%") 394 | iter_num += 1 395 | local_iter_num += 1 396 | 397 | if mup_enable_coord_check_logging: 398 | for handle in coord_check_handles: 399 | handle.remove() 400 | 401 | # termination conditions 402 | if iter_num > max_iters: 403 | break 404 | 405 | if ddp: 406 | destroy_process_group() 407 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Full definition of a GPT Language Model, all of it in this single file. 3 | References: 4 | 1) the official GPT-2 TensorFlow implementation released by OpenAI: 5 | https://github.com/openai/gpt-2/blob/master/src/model.py 6 | 2) huggingface/transformers PyTorch implementation: 7 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py 8 | """ 9 | 10 | import math 11 | import inspect 12 | from dataclasses import dataclass 13 | 14 | import torch 15 | import torch.nn as nn 16 | from torch.nn import functional as F 17 | 18 | class LayerNorm(nn.Module): 19 | """ LayerNorm but with an optional bias. 
PyTorch doesn't support simply bias=False """ 20 | 21 | def __init__(self, ndim, bias): 22 | super().__init__() 23 | self.weight = nn.Parameter(torch.ones(ndim)) 24 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None 25 | 26 | def forward(self, input): 27 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) 28 | 29 | class CausalSelfAttention(nn.Module): 30 | 31 | def __init__(self, config): 32 | super().__init__() 33 | assert config.n_embd % config.n_head == 0 34 | # key, query, value projections for all heads, but in a batch 35 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) 36 | # output projection 37 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 38 | # regularization 39 | self.attn_dropout = nn.Dropout(config.dropout) 40 | self.resid_dropout = nn.Dropout(config.dropout) 41 | self.n_head = config.n_head 42 | self.n_embd = config.n_embd 43 | self.dropout = config.dropout 44 | self.mup_enabled = config.mup_enabled 45 | self.mup_disable_attention_scaling = config.mup_disable_attention_scaling 46 | # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 47 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') 48 | if not self.flash: 49 | print("WARNING: using slow attention. 
Flash Attention requires PyTorch >= 2.0") 50 | # causal mask to ensure that attention is only applied to the left in the input sequence 51 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) 52 | .view(1, 1, config.block_size, config.block_size)) 53 | 54 | def forward(self, x): 55 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 56 | 57 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 58 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 59 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 60 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 61 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 62 | 63 | if self.mup_enabled and not self.mup_disable_attention_scaling: 64 | ### Begin muP code ### 65 | attention_scale = 1.0 / k.size(-1) 66 | ### End muP code ### 67 | else: 68 | attention_scale = 1.0 / math.sqrt(k.size(-1)) 69 | 70 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 71 | if self.flash: 72 | # efficient attention using Flash Attention CUDA kernels 73 | y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, 74 | dropout_p=self.dropout if self.training else 0, 75 | is_causal=True, scale=attention_scale) 76 | else: 77 | # manual implementation of attention 78 | att = (q @ k.transpose(-2, -1)) * attention_scale 79 | att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) 80 | att = F.softmax(att, dim=-1) 81 | att = self.attn_dropout(att) 82 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 83 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 84 | 85 | # output projection 86 | y = self.resid_dropout(self.c_proj(y)) 87 | return y 88 | 89 | class MLP(nn.Module): 90 | 91 | def __init__(self, config): 92 | 
super().__init__() 93 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) 94 | self.gelu = nn.GELU() 95 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) 96 | self.dropout = nn.Dropout(config.dropout) 97 | 98 | def forward(self, x): 99 | x = self.c_fc(x) 100 | x = self.gelu(x) 101 | x = self.c_proj(x) 102 | x = self.dropout(x) 103 | return x 104 | 105 | class Block(nn.Module): 106 | 107 | def __init__(self, config): 108 | super().__init__() 109 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) 110 | self.attn = CausalSelfAttention(config) 111 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) 112 | self.mlp = MLP(config) 113 | 114 | def forward(self, x): 115 | x = x + self.attn(self.ln_1(x)) 116 | x = x + self.mlp(self.ln_2(x)) 117 | return x 118 | 119 | @dataclass 120 | class GPTConfig: 121 | block_size: int = 1024 122 | vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency 123 | n_layer: int = 12 124 | n_head: int = 12 125 | n_embd: int = 768 126 | dropout: float = 0.0 127 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster 128 | init_std: float = 0.02 129 | mup_enabled: bool = False # Whether to use muP. 
If False then all other mup variables are ignored 130 | mup_disable_attention_scaling: bool = False # Disables mup attention scaling 131 | mup_disable_hidden_lr_scaling: bool = False # Disables mup hidden LR scaling 132 | mup_width_multiplier: float = 1 # `mup_width_multiplier = width / base_width` where base_width is typically 256 133 | mup_input_alpha: float = 1 # Optional tunable multiplier applied to input embedding forward pass output 134 | mup_output_alpha: float = 1 # Optional tunable multiplier applied to output unembedding forward pass output 135 | 136 | class GPT(nn.Module): 137 | 138 | def __init__(self, config): 139 | super().__init__() 140 | assert config.vocab_size is not None 141 | assert config.block_size is not None 142 | self.config = config 143 | 144 | self.transformer = nn.ModuleDict(dict( 145 | wte = nn.Embedding(config.vocab_size, config.n_embd), 146 | wpe = nn.Embedding(config.block_size, config.n_embd), 147 | drop = nn.Dropout(config.dropout), 148 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 149 | ln_f = LayerNorm(config.n_embd, bias=config.bias), 150 | )) 151 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 152 | # with weight tying when using torch.compile() some warnings get generated: 153 | # "UserWarning: functional_call was passed multiple values for tied weights. 154 | # This behavior is deprecated and will be an error in future versions" 155 | # not 100% sure what this is, so far seems to be harmless. 
TODO investigate 156 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying 157 | 158 | # init all weights 159 | self.apply(self._init_weights) 160 | # apply special scaled init to the residual projections, per GPT-2 paper 161 | for pn, p in self.named_parameters(): 162 | if config.mup_enabled: 163 | ### Begin muP code ### 164 | # Adjust hidden weight initialization variance by 1 / mup_width_multiplier 165 | if pn.endswith('c_attn.weight') or pn.endswith('c_fc.weight'): 166 | torch.nn.init.normal_(p, mean=0.0, std=config.init_std / math.sqrt(config.mup_width_multiplier)) 167 | elif pn.endswith('c_proj.weight'): 168 | torch.nn.init.normal_(p, mean=0.0, std=config.init_std / math.sqrt(2 * config.n_layer * config.mup_width_multiplier)) 169 | ### End muP code ### 170 | elif pn.endswith('c_proj.weight'): 171 | torch.nn.init.normal_(p, mean=0.0, std=config.init_std / math.sqrt(2 * config.n_layer)) 172 | 173 | # report number of parameters 174 | print("number of parameters: %.2fM" % (self.get_num_params()/1e6,)) 175 | 176 | def get_num_params(self, non_embedding=True): 177 | """ 178 | Return the number of parameters in the model. 179 | For non-embedding count (default), the position embeddings get subtracted. 180 | The token embeddings would too, except due to the parameter sharing these 181 | params are actually used as weights in the final layer, so we include them. 
182 | """ 183 | n_params = sum(p.numel() for p in self.parameters()) 184 | if non_embedding: 185 | n_params -= self.transformer.wpe.weight.numel() 186 | return n_params 187 | 188 | def _init_weights(self, module): 189 | if isinstance(module, nn.Linear): 190 | torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std) 191 | if module.bias is not None: 192 | torch.nn.init.zeros_(module.bias) 193 | elif isinstance(module, nn.Embedding): 194 | torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std) 195 | 196 | def forward(self, idx, targets=None): 197 | device = idx.device 198 | b, t = idx.size() 199 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 200 | pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) 201 | 202 | # forward the GPT model itself 203 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 204 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) 205 | x = self.transformer.drop(tok_emb + pos_emb) 206 | if self.config.mup_enabled: 207 | ### Begin muP code ### 208 | x *= self.config.mup_input_alpha 209 | ### End muP code ### 210 | for block in self.transformer.h: 211 | x = block(x) 212 | x = self.transformer.ln_f(x) 213 | 214 | if targets is not None: 215 | # if we are given some desired targets also calculate the loss 216 | if self.config.mup_enabled: 217 | ### Begin muP code ### 218 | # Scaling `x` instead of `logits` allows coord check to log change 219 | x *= self.config.mup_output_alpha / self.config.mup_width_multiplier 220 | ### End muP code ### 221 | logits = self.lm_head(x) 222 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 223 | else: 224 | # inference-time mini-optimization: only forward the lm_head on the very last position 225 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim 226 | 
loss = None 227 | 228 | return logits, loss 229 | 230 | def crop_block_size(self, block_size): 231 | # model surgery to decrease the block size if necessary 232 | # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024) 233 | # but want to use a smaller block size for some smaller, simpler model 234 | assert block_size <= self.config.block_size 235 | self.config.block_size = block_size 236 | self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size]) 237 | for block in self.transformer.h: 238 | if hasattr(block.attn, 'bias'): 239 | block.attn.bias = block.attn.bias[:,:,:block_size,:block_size] 240 | 241 | @classmethod 242 | def from_pretrained(cls, model_type, override_args=None): 243 | assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'} 244 | override_args = override_args or {} # default to empty dict 245 | # only dropout can be overridden see more notes below 246 | assert all(k == 'dropout' for k in override_args) 247 | from transformers import GPT2LMHeadModel 248 | print("loading weights from pretrained gpt: %s" % model_type) 249 | 250 | # n_layer, n_head and n_embd are determined from model_type 251 | config_args = { 252 | 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params 253 | 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params 254 | 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params 255 | 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params 256 | }[model_type] 257 | print("forcing vocab_size=50257, block_size=1024, bias=True") 258 | config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints 259 | config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints 260 | config_args['bias'] = True # always True for GPT model checkpoints 261 | # we can override the dropout rate, if desired 262 | if 'dropout' in override_args: 263 | print(f"overriding dropout rate to {override_args['dropout']}") 264 | 
config_args['dropout'] = override_args['dropout'] 265 | # create a from-scratch initialized minGPT model 266 | config = GPTConfig(**config_args) 267 | model = GPT(config) 268 | sd = model.state_dict() 269 | sd_keys = sd.keys() 270 | sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param 271 | 272 | # init a huggingface/transformers model 273 | model_hf = GPT2LMHeadModel.from_pretrained(model_type) 274 | sd_hf = model_hf.state_dict() 275 | 276 | # copy while ensuring all of the parameters are aligned and match in names and shapes 277 | sd_keys_hf = sd_hf.keys() 278 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer 279 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer) 280 | transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight'] 281 | # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear 282 | # this means that we have to transpose these weights when we import them 283 | assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" 284 | for k in sd_keys_hf: 285 | if any(k.endswith(w) for w in transposed): 286 | # special treatment for the Conv1D weights we need to transpose 287 | assert sd_hf[k].shape[::-1] == sd[k].shape 288 | with torch.no_grad(): 289 | sd[k].copy_(sd_hf[k].t()) 290 | else: 291 | # vanilla copy over the other parameters 292 | assert sd_hf[k].shape == sd[k].shape 293 | with torch.no_grad(): 294 | sd[k].copy_(sd_hf[k]) 295 | 296 | return model 297 | 298 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): 299 | # start with all of the candidate parameters 300 | param_dict = {pn: p for pn, p in self.named_parameters()} 301 | # filter out those that do not require grad 302 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 
303 | # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. 304 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 305 | if self.config.mup_enabled and not self.config.mup_disable_hidden_lr_scaling: 306 | ### Begin muP code ### 307 | mup_decay_params = [] 308 | decay_params = [] 309 | nodecay_params = [] 310 | for n, p in param_dict.items(): 311 | if p.dim() >= 2: 312 | if n.endswith('c_attn.weight') or n.endswith('c_fc.weight') or n.endswith('c_proj.weight'): 313 | mup_decay_params.append(p) 314 | else: 315 | decay_params.append(p) 316 | else: 317 | nodecay_params.append(p) 318 | optim_groups = [ 319 | {'params': mup_decay_params, 'weight_decay': weight_decay, 'lr_scale': 1/self.config.mup_width_multiplier}, 320 | {'params': decay_params, 'weight_decay': weight_decay, 'lr_scale': 1}, 321 | {'params': nodecay_params, 'weight_decay': 0.0, 'lr_scale': 1} 322 | ] 323 | num_mup_decay_params = sum(p.numel() for p in mup_decay_params) 324 | num_decay_params = sum(p.numel() for p in decay_params) 325 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 326 | print(f"num mup decayed parameter tensors: {len(mup_decay_params)}, with {num_mup_decay_params:,} parameters") 327 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") 328 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") 329 | ### End muP code ### 330 | else: 331 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 332 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 333 | optim_groups = [ 334 | {'params': decay_params, 'weight_decay': weight_decay}, 335 | {'params': nodecay_params, 'weight_decay': 0.0} 336 | ] 337 | num_decay_params = sum(p.numel() for p in decay_params) 338 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 339 | print(f"num decayed parameter tensors: 
{len(decay_params)}, with {num_decay_params:,} parameters") 340 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") 341 | # Create AdamW optimizer and use the fused version if it is available 342 | fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters 343 | use_fused = fused_available and device_type == 'cuda' 344 | extra_args = dict(fused=True) if use_fused else dict() 345 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) 346 | print(f"using fused AdamW: {use_fused}") 347 | 348 | return optimizer 349 | 350 | def estimate_mfu(self, fwdbwd_per_iter, dt): 351 | """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """ 352 | # first estimate the number of flops we do per iteration. 353 | # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 354 | N = self.get_num_params() 355 | cfg = self.config 356 | L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size 357 | flops_per_token = 6*N + 12*L*H*Q*T 358 | flops_per_fwdbwd = flops_per_token * T 359 | flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter 360 | # express our flops throughput as ratio of A100 bfloat16 peak flops 361 | flops_achieved = flops_per_iter * (1.0/dt) # per second 362 | flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS 363 | mfu = flops_achieved / flops_promised 364 | return mfu 365 | 366 | @torch.no_grad() 367 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): 368 | """ 369 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete 370 | the sequence max_new_tokens times, feeding the predictions back into the model each time. 371 | Most likely you'll want to make sure to be in model.eval() mode of operation for this. 
372 | """ 373 | for _ in range(max_new_tokens): 374 | # if the sequence context is growing too long we must crop it at block_size 375 | idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] 376 | # forward the model to get the logits for the index in the sequence 377 | logits, _ = self(idx_cond) 378 | # pluck the logits at the final step and scale by desired temperature 379 | logits = logits[:, -1, :] / temperature 380 | # optionally crop the logits to only the top k options 381 | if top_k is not None: 382 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 383 | logits[logits < v[:, [-1]]] = -float('Inf') 384 | # apply softmax to convert logits to (normalized) probabilities 385 | probs = F.softmax(logits, dim=-1) 386 | # sample from the distribution 387 | idx_next = torch.multinomial(probs, num_samples=1) 388 | # append sampled index to the running sequence and continue 389 | idx = torch.cat((idx, idx_next), dim=1) 390 | 391 | return idx 392 | -------------------------------------------------------------------------------- /mup_examples/mutransfer_lr_shakespeare_char/plot.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 17, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "image/png": "", 11 | "text/plain": [ 12 | "
" 13 | ] 14 | }, 15 | "metadata": {}, 16 | "output_type": "display_data" 17 | } 18 | ], 19 | "source": [ 20 | "import os\n", 21 | "import pandas as pd\n", 22 | "import numpy as np\n", 23 | "from tqdm import tqdm\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "import matplotlib as mpl\n", 26 | "from matplotlib import cm\n", 27 | "import seaborn as sns\n", 28 | "sns.set(style='whitegrid')\n", 29 | "\n", 30 | "parameterizations = [\n", 31 | " ('sp', r'SP'),\n", 32 | " ('mup', r'$\\mu$P'),\n", 33 | "]\n", 34 | "seeds = [1,2,3]\n", 35 | "widths = [\n", 36 | " 256,\n", 37 | " 512,\n", 38 | " 1024,\n", 39 | " 2048,\n", 40 | "]\n", 41 | "lrs = [\n", 42 | " # 0.125,\n", 43 | " 0.0625,\n", 44 | " 0.03125,\n", 45 | " 0.015625,\n", 46 | " 0.0078125,\n", 47 | " 0.00390625,\n", 48 | " 0.001953125,\n", 49 | " 0.0009765625,\n", 50 | " 0.00048828125,\n", 51 | " 0.000244140625,\n", 52 | " 0.0001220703125,\n", 53 | " 0.00006103515625,\n", 54 | " 0.00003051757812,\n", 55 | " 0.00001525878906,\n", 56 | " 0.000007629394531,\n", 57 | " 0.000003814697266,\n", 58 | "]\n", 59 | "class MplColorHelper:\n", 60 | "\n", 61 | " def __init__(self, cmap_name, start_val, stop_val):\n", 62 | " self.cmap_name = cmap_name\n", 63 | " self.cmap = plt.get_cmap(cmap_name)\n", 64 | " self.norm = mpl.colors.Normalize(vmin=start_val, vmax=stop_val)\n", 65 | " self.scalarMap = cm.ScalarMappable(norm=self.norm, cmap=self.cmap)\n", 66 | "\n", 67 | " def get_rgb(self, val):\n", 68 | " return self.scalarMap.to_rgba(val)\n", 69 | "\n", 70 | "\n", 71 | "color_helper = MplColorHelper('viridis', 0, len(widths)-1)\n", 72 | "n_cols = len(parameterizations)\n", 73 | "n_rows = 1\n", 74 | "fig, axes = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 3.33*n_rows))\n", 75 | "\n", 76 | "for parameterization_idx, (parameterization, parameterization_str) in enumerate(parameterizations):\n", 77 | " ax = axes[parameterization_idx]\n", 78 | " optimal_lrs = []\n", 79 | " optimal_losses = []\n", 80 | " for width_idx, width in 
enumerate(widths):\n", 81 | " mean_losses = []\n", 82 | " sem_losses = []\n", 83 | " lrs_to_plot = []\n", 84 | " for lr in lrs:\n", 85 | " losses = []\n", 86 | " for seed in seeds:\n", 87 | " job_name = f'width{width}_depth2_seed{seed}_lr{lr:.20f}'.rstrip('0')\n", 88 | " csv_path = os.path.join(parameterization, 'out', job_name, 'log.csv')\n", 89 | " if os.path.exists(csv_path):\n", 90 | " ckpt_df = pd.read_csv(csv_path)\n", 91 | " losses.append(ckpt_df['train/loss'].mean())\n", 92 | " # losses.append(ckpt_df['train/loss'].min())\n", 93 | " # losses.append(ckpt_df['train/loss'].ewm(alpha=0.9).mean().values[-1])\n", 94 | " # else:\n", 95 | " # print(f'Missing {csv_path}')\n", 96 | " if len(losses):\n", 97 | " mean_losses.append(np.mean(losses))\n", 98 | " sem_losses.append(np.std(losses, ddof=1) / np.sqrt(len(losses)))\n", 99 | " lrs_to_plot.append(lr)\n", 100 | " \n", 101 | " mean_losses = np.array(mean_losses)\n", 102 | " sem_losses = np.array(sem_losses)\n", 103 | " ax.plot(lrs_to_plot, mean_losses, label=width, marker='o', color=color_helper.get_rgb(width_idx))\n", 104 | " ax.fill_between(lrs_to_plot, mean_losses-sem_losses, mean_losses+sem_losses, color=color_helper.get_rgb(width_idx), alpha=0.33)\n", 105 | " \n", 106 | " if len(mean_losses):\n", 107 | " optimum_idx = np.argmin(mean_losses)\n", 108 | " optimal_lrs.append(lrs_to_plot[optimum_idx])\n", 109 | " optimal_losses.append(mean_losses[optimum_idx])\n", 110 | " \n", 111 | " ax.plot(optimal_lrs, optimal_losses, color='red', linestyle='none', marker='o')\n", 112 | " ax.set_xscale('log', base=2)\n", 113 | " ax.set_xlabel('Learning rate')\n", 114 | " ax.set_title(parameterization_str)\n", 115 | " ax.set_ylim(2.57, 3.15)\n", 116 | " # ax.set_ylim(2.3, 2.7)\n", 117 | " # ax.set_ylim(2.4, 2.8)\n", 118 | "\n", 119 | "axes[1].legend(title='Width')\n", 120 | "# axes[0].set_ylabel('Train loss on\\nshakespeare_char')\n", 121 | "axes[0].set_ylabel('Mean train loss on\\nshakespeare_char')\n", 122 | 
"axes[1].yaxis.set_ticklabels([])\n", 123 | "axes[1].tick_params(axis='y', length=0, width=0)\n", 124 | "\n", 125 | "plt.tight_layout()\n", 126 | "plt.show()\n", 127 | "plt.close()" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [] 136 | } 137 | ], 138 | "metadata": { 139 | "kernelspec": { 140 | "display_name": "nanogpt", 141 | "language": "python", 142 | "name": "python3" 143 | }, 144 | "language_info": { 145 | "codemirror_mode": { 146 | "name": "ipython", 147 | "version": 3 148 | }, 149 | "file_extension": ".py", 150 | "mimetype": "text/x-python", 151 | "name": "python", 152 | "nbconvert_exporter": "python", 153 | "pygments_lexer": "ipython3", 154 | "version": "3.9.19" 155 | } 156 | }, 157 | "nbformat": 4, 158 | "nbformat_minor": 2 159 | } 160 | --------------------------------------------------------------------------------