├── LICENSE ├── README.md ├── assets ├── 1.5B_200k_new.png ├── medium_100k_plus.png ├── small_100k_plus.png └── t5_winrate.png ├── config ├── train_gpt2_large_adam.py ├── train_gpt2_large_sophiag.py ├── train_gpt2_medium_adam.py ├── train_gpt2_medium_sophiag.py ├── train_gpt2_small_adam.py └── train_gpt2_small_sophiag.py ├── configurator.py ├── data └── openwebtext │ └── prepare.py ├── model.py ├── sophia.py ├── train_adam.py └── train_sophiag.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Hong Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training 2 | 3 | 4 | This is an official implementation of the **Sophia-G** optimizer in the paper [https://arxiv.org/abs/2305.14342](https://arxiv.org/abs/2305.14342) and GPT-2 training scripts. The code is based on [nanoGPT](https://github.com/karpathy/nanoGPT/) and [levanter](https://github.com/stanford-crfm/levanter/). Please cite the paper and star this repo if you find Sophia useful. Thanks! 5 | 6 | 7 | ```tex 8 | @article{liu2023sophia, 9 | title={Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training}, 10 | author={Liu, Hong and Li, Zhiyuan and Hall, David and Liang, Percy and Ma, Tengyu}, 11 | journal={arXiv preprint arXiv:2305.14342}, 12 | year={2023} 13 | } 14 | ``` 15 | 16 | 17 | ## News and Updates 18 | - Updated results with latest PyTorch version. 19 | 20 | 21 | 22 | ## Dependencies 23 | 24 | 25 | - [PyTorch](https://pytorch.org) 2.1.2 26 | - transformers 4.33.0 27 | - datasets 28 | - tiktoken 29 | - wandb 30 | 31 | ## General Usage 32 | 33 | Below is an example code snippet for training a general model with NLL loss with SophiaG. Please refer to the next section for guidelines on hyperparameter tuning. 34 | 35 | ```python 36 | import torch 37 | import torch.nn.functional as F 38 | from sophia import SophiaG 39 | 40 | # init model loss function and input data 41 | model = Model() 42 | data_loader = ... 
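# Note on this snippet: `Model`, `data_loader`, `block_size`, and `epochs` are placeholders
# to be supplied by the caller, e.g. any nn.Module whose forward returns (logits, loss),
# an iterable of (X, Y) token batches, the sequence length, and the number of training passes.
# Purely illustrative values (assumptions, not defaults of this repo):
# block_size, epochs = 1024, 1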
43 | 44 | # init the optimizer 45 | optimizer = SophiaG(model.parameters(), lr=2e-4, betas=(0.965, 0.99), rho=0.01, weight_decay=1e-1) 46 | 47 | total_bs = len(data_loader) 48 | bs = total_bs * block_size 49 | k = 10 50 | iter_num = -1 51 | 52 | # training loop 53 | for epoch in range(epochs): 54 | for X, Y in data_loader: 55 | # standard training code 56 | logits, loss = model(X, Y) 57 | loss.backward() 58 | optimizer.step(bs=bs) 59 | optimizer.zero_grad(set_to_none=True) 60 | iter_num += 1 61 | 62 | if iter_num % k != k - 1: 63 | continue 64 | else: 65 | # update hessian EMA 66 | logits, _ = model(X, None) 67 | samp_dist = torch.distributions.Categorical(logits=logits) 68 | y_sample = samp_dist.sample() 69 | loss_sampled = F.cross_entropy(logits.view(-1, logits.size(-1)), y_sample.view(-1), ignore_index=-1) 70 | loss_sampled.backward() 71 | optimizer.update_hessian() 72 | optimizer.zero_grad(set_to_none=True) 73 | model.zero_grad() 74 | ``` 75 | 76 | 77 | ## Hyper-parameter Tuning 78 | 79 | ### Definition of learning rate 80 | - The update in the code is written as $\theta_{t+1} = \theta_t - lr*\textup{clip}(m_t / (\rho * h_t + \epsilon), 1)$, which is equivalent to the update in the paper up to a re-parameterization (the $lr$ here corresponds to $\rho \cdot \eta_t$ in the paper). As a result, this learning rate is not directly comparable to the learning rate of AdamW. Empirically, Adam and Lion with a learning rate ratio of 5:1 have similar behaviour. The learning rate of SophiaG and Lion is directly comparable. Sophia allows a much larger learning rate than Lion, and this is why Sophia is much faster. 81 | 82 | ### Tuning the hyperparameter $\rho$ 83 | - Tune $\rho$ so that the proportion of coordinates where the update is not clipped stays stable and in a proper range. This is tracked as ```train/win_rate``` in the [GPT-2 training example](https://github.com/Liuhong99/Sophia/blob/2443b03529ecdccf65699a5e55e68d69ede39509/train_sophiag.py#L398C21-L398C65). ```train/win_rate``` should peak in the beginning and remain stable afterwards. ```train/win_rate``` should stay in the range of 0.1 - 0.5. Typically a large $\rho$ will lead to a large ```train/win_rate```. An example of typical ```win_rate``` behavior in a T5 model is provided below, together with a short sketch of how this quantity is computed. 84 | 85 | ### Tuning the learning rate and weight decay 86 | - Choose lr to be slightly smaller than the learning rate that you would use for AdamW or 3 - 5 times the learning rate that you would use for Lion. 87 |

88 | [figure: assets/t5_winrate.png, an example of typical ```train/win_rate``` behavior in a T5 model] 89 |
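For intuition, the clipped update and the quantity logged as ```train/win_rate``` can be sketched as follows. This is a simplified illustration in the notation of the formula above, not the implementation in ```sophia.py``` (which additionally scales $h_t$ by the batch size passed to ```step(bs=...)```):

```python
import torch

def sophia_update_sketch(param, m, h, lr, rho, eps=1e-15):
    # theta <- theta - lr * sign(m) * min(|m| / (rho * h + eps), 1)
    ratio = (m.abs() / (rho * h + eps)).clamp(max=1.0)
    new_param = param - lr * m.sign() * ratio
    # "win rate": fraction of coordinates whose update is NOT clipped,
    # i.e. where |m| < rho * h + eps
    win_rate = (m.abs() < rho * h + eps).float().mean()
    return new_param, win_rate
```

A larger $\rho$ raises the clipping threshold, so more coordinates stay unclipped and ```train/win_rate``` increases, consistent with the guideline above.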

90 | 91 | - If the loss blows up, slightly decrease the learning rate or increase $\rho$. 92 | 93 | - Always use a weight decay about 2x larger than what you would use for AdamW. 94 | 95 | ### Hyperparameters for GPT-2 models 96 | 97 | - Choose lr to be about the same as the learning rate that you would use for AdamW or 5 - 10 times the learning rate that you would use for Lion. 98 | - Tune $\rho$ to make the proportion of the parameters where the update is not clipped stable and in a proper range. This is tracked as ```train/win_rate``` in the [GPT-2 training example](https://github.com/Liuhong99/Sophia/blob/2443b03529ecdccf65699a5e55e68d69ede39509/train_sophiag.py#L398C21-L398C65). ```train/win_rate``` should peak in the beginning and remain stable afterwards. ```train/win_rate``` should stay in the range of 0.1 - 0.5. Typically a large $\rho$ will lead to a large ```train/win_rate```. 99 | - Use a slightly larger weight decay than AdamW, e.g. 0.2. 100 | - Except for the learning rate, all other hyperparameters are transferable across different model sizes. 101 | - See the table below for the hyperparameters for different model sizes. 102 | 103 | | Model Size | lr for Adam | lr for Lion | lr for Sophia | $\rho$ for Sophia | weight decay for Sophia | 104 | | -------- | ------- | ------- | ------- | ------- | ------- | 105 | | 125M | 6e-4 | 1e-4 | 6e-4 | 0.05 | 0.2 | 106 | | 355M | 3e-4 | 1e-4 | 7e-4 | 0.08 | 0.2 | 107 | | 770M | 2e-4 | 8e-5 | 3e-4 | 0.05 | 0.2 | 108 | 109 | - Please feel free to let us know what you find out during hyper-parameter tuning. We appreciate your valuable feedback and comments! 110 | 111 | ## Reproduce GPT-2 Results 112 | 113 | Prepare the [OpenWebText](https://huggingface.co/datasets/openwebtext) data following [nanoGPT](https://github.com/karpathy/nanoGPT/): 114 | ``` 115 | $ python data/openwebtext/prepare.py 116 | ``` 117 | Start pre-training GPT2 Small (125M): 118 | 119 | If you have a machine with 10 A5000 (24GB) GPUs, 120 | ``` 121 | $ torchrun --standalone --nproc_per_node=10 \ 122 | train_sophiag.py \ 123 | config/train_gpt2_small_sophiag.py \ 124 | --batch_size=8 \ 125 | --gradient_accumulation_steps=6 126 | ``` 127 | If you have a machine with 8 A100 (40GB) GPUs, 128 | ``` 129 | $ torchrun --standalone --nproc_per_node=8 \ 130 | train_sophiag.py \ 131 | config/train_gpt2_small_sophiag.py \ 132 | --batch_size=12 \ 133 | --gradient_accumulation_steps=5 134 | ``` 135 | 136 | To reproduce the AdamW baseline following [nanoGPT](https://github.com/karpathy/nanoGPT/): 137 | ``` 138 | $ torchrun --standalone --nproc_per_node=10 \ 139 | train_adam.py \ 140 | config/train_gpt2_small_adam.py \ 141 | --batch_size=8 \ 142 | --gradient_accumulation_steps=6 143 | ``` 144 | 145 | This will lead to results in the figure below: 146 |

147 | [figure: assets/small_100k_plus.png, GPT-2 Small (125M) results, SophiaG vs. AdamW] 148 |
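As a sanity check for other hardware setups (an illustration, not code from this repository): each launch command above uses an effective batch of 480 sequences per optimizer step, which is roughly 0.5M tokens at a block size of 1024.

```python
# effective sequences per step = nproc_per_node * batch_size * gradient_accumulation_steps
assert 10 * 8 * 6 == 480   # 10 A5000 GPUs, batch_size=8, gradient_accumulation_steps=6
assert 8 * 12 * 5 == 480   # 8 A100 GPUs, batch_size=12, gradient_accumulation_steps=5
print(480 * 1024)          # 491520 tokens per optimizer step, i.e. ~0.5M
```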

149 | 150 | Start pre-training GPT2 Medium (355M): 151 | 152 | If you have a machine with 8 A100 (40GB) GPUs, 153 | ``` 154 | $ torchrun --standalone --nproc_per_node=8 \ 155 | train_sophiag.py \ 156 | config/train_gpt2_medium_sophiag.py \ 157 | --batch_size=6 \ 158 | --gradient_accumulation_steps=10 159 | ``` 160 | 161 | To reproduce the AdamW baseline: 162 | ``` 163 | $ torchrun --standalone --nproc_per_node=8 \ 164 | train_adam.py \ 165 | config/train_gpt2_medium_adam.py \ 166 | --batch_size=6 \ 167 | --gradient_accumulation_steps=10 168 | ``` 169 | 170 | Please adjust ```nproc_per_node```, ```batch_size```, and ```gradient_accumulation_steps``` accordingly if you use a different hardware setup. Make sure their product equals 480. 171 | 172 | 173 | This will lead to results in the figure below: 174 |

175 | [figure: assets/medium_100k_plus.png, GPT-2 Medium (355M) results, SophiaG vs. AdamW] 176 |
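The ```--batch_size``` and ```--gradient_accumulation_steps``` flags in the commands above override the defaults from the ```config/*.py``` files via ```configurator.py``` (reproduced later in this listing). A minimal, self-contained sketch of that override pattern, using a hypothetical standalone script rather than the repo's code:

```python
import sys
from ast import literal_eval

batch_size = 6  # default, as it would appear in a config/*.py file

# apply --key=value overrides, mirroring the approach in configurator.py
for arg in sys.argv[1:]:
    if arg.startswith('--') and '=' in arg:
        key, val = arg[2:].split('=', 1)
        try:
            val = literal_eval(val)  # e.g. "--batch_size=12" -> int 12
        except (SyntaxError, ValueError):
            pass                     # leave non-literal values as strings
        globals()[key] = val

print(batch_size)
```

In the actual scripts, ```configurator.py``` additionally checks that the overridden key already exists and that its type matches the default.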

177 | 178 | Start pre-training GPT2 1.5B: 179 | 180 | We use [the Pile](https://github.com/EleutherAI/the-pile) and GPT NeoX tokenizer. First set up TPU instances and environment following [levanter](https://github.com/stanford-crfm/levanter/blob/e183ec80ec5971b12d4a3fb08a160268de342670/docs/Getting-Started-TPU-VM.md). Then change GAMMA_SOPHIA_G to 200 in [optim.py](https://github.com/stanford-crfm/levanter/blob/e183ec80ec5971b12d4a3fb08a160268de342670/src/levanter/optim.py). The training script for 1.5B model is 181 | ``` 182 | gcloud compute tpus tpu-vm ssh \ 183 | --zone \ 184 | --worker=all \ 185 | --command 'WANDB_API_KEY= levanter/infra/launch.sh python levanter/examples/gpt2_example.py --config_path levanter/config/gpt2_1536_pile.yaml --trainer.beta1 0.965 --trainer.beta2 0.99 --trainer.min_lr_ratio 0.020 --trainer.weight_decay 0.15 --trainer.learning_rate 2.5e-4 --trainer.warmup_ratio 0.01' 186 | 187 | ``` 188 | 189 | ## Acknowledgement 190 | 191 | The GPT-2 training code is based on [nanoGPT](https://github.com/karpathy/nanoGPT/), which is elegant and super efficient. 192 | -------------------------------------------------------------------------------- /assets/1.5B_200k_new.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuhong99/Sophia/a7e157229b71d58cf995d32854f1be15c265b350/assets/1.5B_200k_new.png -------------------------------------------------------------------------------- /assets/medium_100k_plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuhong99/Sophia/a7e157229b71d58cf995d32854f1be15c265b350/assets/medium_100k_plus.png -------------------------------------------------------------------------------- /assets/small_100k_plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuhong99/Sophia/a7e157229b71d58cf995d32854f1be15c265b350/assets/small_100k_plus.png -------------------------------------------------------------------------------- /assets/t5_winrate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuhong99/Sophia/a7e157229b71d58cf995d32854f1be15c265b350/assets/t5_winrate.png -------------------------------------------------------------------------------- /config/train_gpt2_large_adam.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'sophia' 3 | wandb_run_name='gpt2-large-adam-100k' 4 | 5 | # these make the total batch size be ~0.5M 6 | # 6 batch size * 1024 block size * 10 gradaccum * 8 GPUs = 491,520 7 | batch_size = 4 8 | block_size = 1024 9 | gradient_accumulation_steps = 12 10 | 11 | n_layer = 36 12 | n_head = 20 13 | n_embd = 1280 14 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 15 | bias = False 16 | scale_attn_by_inverse_layer_idx = True 17 | 18 | # this makes total number of tokens be 300B 19 | max_iters = 100000 20 | lr_decay_iters = 100000 21 | 22 | # eval stuff 23 | eval_interval = 1000 24 | eval_iters = 200 25 | log_interval = 10 26 | 27 | # optimizer 28 | optimizer_name = 'adamw' 29 | learning_rate = 2e-4 # max learning rate 30 | weight_decay = 1e-1 31 | beta1 = 0.9 32 | beta2 = 0.95 33 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 34 | # learning rate decay settings 35 | decay_lr = True # whether to decay the learning rate 36 | warmup_iters 
= 2000 # how many steps to warm up for 37 | min_lr = 1e-5 38 | 39 | compile = True 40 | 41 | out_dir = 'out_large_adam_100k' 42 | -------------------------------------------------------------------------------- /config/train_gpt2_large_sophiag.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'sophia' 3 | wandb_run_name='gpt2-large-sophiag-100k' 4 | 5 | # these make the total batch size be ~0.5M 6 | # 6 batch size * 1024 block size * 10 gradaccum * 8 GPUs = 491,520 7 | batch_size = 4 8 | block_size = 1024 9 | gradient_accumulation_steps = 12 10 | 11 | n_layer = 36 12 | n_head = 20 13 | n_embd = 1280 14 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 15 | bias = False 16 | scale_attn_by_inverse_layer_idx = True 17 | 18 | # this makes total number of tokens be 300B 19 | max_iters = 100000 20 | lr_decay_iters = 100000 21 | 22 | # eval stuff 23 | eval_interval = 1000 24 | eval_iters = 200 25 | log_interval = 10 26 | 27 | # optimizer 28 | optimizer_name = 'sophiag' 29 | learning_rate = 3e-4 # max learning rate 30 | weight_decay = 2e-1 31 | beta1 = 0.965 32 | beta2 = 0.99 33 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 34 | # learning rate decay settings 35 | decay_lr = True # whether to decay the learning rate 36 | warmup_iters = 2000 # how many steps to warm up for 37 | min_lr = 1e-5 38 | rho = 0.05 39 | interval = 10 40 | 41 | compile = True 42 | 43 | out_dir = 'out_large_sophiag_100k' 44 | -------------------------------------------------------------------------------- /config/train_gpt2_medium_adam.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'sophia' 3 | wandb_run_name='gpt2-medium-adam-100k' 4 | 5 | # these make the total batch size be ~0.5M 6 | # 6 batch size * 1024 block size * 10 gradaccum * 8 GPUs = 491,520 7 | batch_size = 6 8 | block_size = 1024 9 | gradient_accumulation_steps = 8 10 | 11 | n_layer = 24 12 | n_head = 16 13 | n_embd = 1024 14 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 15 | bias = False 16 | scale_attn_by_inverse_layer_idx = True 17 | 18 | # this makes total number of tokens be 300B 19 | max_iters = 100000 20 | lr_decay_iters = 100000 21 | 22 | # eval stuff 23 | eval_interval = 1000 24 | eval_iters = 200 25 | log_interval = 10 26 | 27 | # optimizer 28 | optimizer_name = 'adamw' 29 | learning_rate = 3e-4 # max learning rate 30 | weight_decay = 1e-1 31 | beta1 = 0.9 32 | beta2 = 0.95 33 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 34 | # learning rate decay settings 35 | decay_lr = True # whether to decay the learning rate 36 | warmup_iters = 2000 # how many steps to warm up for 37 | min_lr = 6e-5 38 | 39 | compile = True 40 | 41 | out_dir = 'out_medium_adam_100k' 42 | -------------------------------------------------------------------------------- /config/train_gpt2_medium_sophiag.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'sophia' 3 | wandb_run_name='gpt2-medium-sophiag-100k' 4 | 5 | # these make the total batch size be ~0.5M 6 | # 6 batch size * 1024 block size * 10 gradaccum * 8 GPUs = 491,520 7 | batch_size = 10 8 | block_size = 1024 9 | gradient_accumulation_steps = 6 10 | 11 | n_layer = 24 12 | n_head = 16 13 | n_embd = 1024 14 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 15 | bias = False 16 | 
scale_attn_by_inverse_layer_idx = True 17 | 18 | # this makes total number of tokens be 300B 19 | max_iters = 100000 20 | lr_decay_iters = 100000 21 | 22 | # eval stuff 23 | eval_interval = 1000 24 | eval_iters = 200 25 | log_interval = 10 26 | 27 | # optimizer 28 | optimizer_name = 'sophiag' 29 | learning_rate = 7e-4 # max learning rate 30 | weight_decay = 2e-1 31 | beta1 = 0.965 32 | beta2 = 0.99 33 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 34 | # learning rate decay settings 35 | decay_lr = True # whether to decay the learning rate 36 | warmup_iters = 2000 # how many steps to warm up for 37 | min_lr = 1e-5 38 | rho = 0.08 39 | interval = 10 40 | 41 | compile = True 42 | 43 | out_dir = 'out_medium_sophiag_100k' 44 | -------------------------------------------------------------------------------- /config/train_gpt2_small_adam.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'sophia' 3 | wandb_run_name='gpt2-small-adam-100k' 4 | 5 | # these make the total batch size be ~0.5M 6 | # 8 batch size * 1024 block size * 6 gradaccum * 10 GPUs = 491,520 7 | batch_size = 8 8 | block_size = 1024 9 | gradient_accumulation_steps = 6 10 | 11 | n_layer = 12 12 | n_head = 12 13 | n_embd = 768 14 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 15 | bias = False 16 | 17 | # this makes total number of tokens be 300B 18 | max_iters = 100000 19 | lr_decay_iters = 100000 20 | 21 | # eval stuff 22 | eval_interval = 1000 23 | eval_iters = 200 24 | log_interval = 10 25 | 26 | # optimizer 27 | optimizer_name = 'adamw' 28 | learning_rate = 6e-4 # max learning rate 29 | weight_decay = 1e-1 30 | beta1 = 0.9 31 | beta2 = 0.95 32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 33 | # learning rate decay settings 34 | decay_lr = True # whether to decay the learning rate 35 | warmup_iters = 2000 # how many steps to warm up for 36 | min_lr = 3e-5 37 | 38 | compile = True 39 | 40 | out_dir = 'out_small_adam_100k' 41 | -------------------------------------------------------------------------------- /config/train_gpt2_small_sophiag.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'sophia' 3 | wandb_run_name='gpt2-small-sophiag-100k' 4 | 5 | # these make the total batch size be ~0.5M 6 | # 8 batch size * 1024 block size * 6 gradaccum * 10 GPUs = 491,520 7 | batch_size = 8 8 | block_size = 1024 9 | gradient_accumulation_steps = 6 10 | total_bs = 480 11 | 12 | n_layer = 12 13 | n_head = 12 14 | n_embd = 768 15 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 16 | bias = False 17 | 18 | # this makes total number of tokens be 300B 19 | max_iters = 100000 20 | lr_decay_iters = 100000 21 | 22 | # eval stuff 23 | eval_interval = 1000 24 | eval_iters = 200 25 | log_interval = 10 26 | 27 | # optimizer 28 | optimizer_name = 'sophiag' 29 | learning_rate = 6e-4 # max learning rate 30 | weight_decay = 2e-1 31 | beta1 = 0.965 32 | beta2 = 0.99 33 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 34 | # learning rate decay settings 35 | decay_lr = True # whether to decay the learning rate 36 | warmup_iters = 2000 # how many steps to warm up for 37 | min_lr = 1.5e-5 38 | rho = 0.05 39 | interval = 10 40 | 41 | compile = True 42 | 43 | out_dir = 'out_small_sophiag_100k' 44 | -------------------------------------------------------------------------------- /configurator.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Poor Man's Configurator. Probably a terrible idea. Example usage: 3 | $ python train.py config/override_file.py --batch_size=32 4 | this will first run config/override_file.py, then override batch_size to 32 5 | 6 | The code in this file will be run as follows from e.g. train.py: 7 | >>> exec(open('configurator.py').read()) 8 | 9 | So it's not a Python module, it's just shuttling this code away from train.py 10 | The code in this script then overrides the globals() 11 | 12 | I know people are not going to love this, I just really dislike configuration 13 | complexity and having to prepend config. to every single variable. If someone 14 | comes up with a better simple Python solution I am all ears. 15 | """ 16 | 17 | import sys 18 | from ast import literal_eval 19 | 20 | for arg in sys.argv[1:]: 21 | if '=' not in arg: 22 | # assume it's the name of a config file 23 | assert not arg.startswith('--') 24 | config_file = arg 25 | print(f"Overriding config with {config_file}:") 26 | with open(config_file) as f: 27 | print(f.read()) 28 | exec(open(config_file).read()) 29 | else: 30 | # assume it's a --key=value argument 31 | assert arg.startswith('--') 32 | key, val = arg.split('=') 33 | key = key[2:] 34 | if key in globals(): 35 | try: 36 | # attempt to eval it it (e.g. if bool, number, or etc) 37 | attempt = literal_eval(val) 38 | except (SyntaxError, ValueError): 39 | # if that goes wrong, just use the string 40 | attempt = val 41 | # ensure the types match ok 42 | assert type(attempt) == type(globals()[key]) 43 | # cross fingers 44 | print(f"Overriding: {key} = {attempt}") 45 | globals()[key] = attempt 46 | else: 47 | raise ValueError(f"Unknown config key: {key}") 48 | -------------------------------------------------------------------------------- /data/openwebtext/prepare.py: -------------------------------------------------------------------------------- 1 | # saves the openwebtext dataset to a binary file for training. following was helpful: 2 | # https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py 3 | 4 | import os 5 | from tqdm import tqdm 6 | import numpy as np 7 | import tiktoken 8 | from datasets import load_dataset # huggingface datasets 9 | 10 | # number of workers in .map() call 11 | # good number to use is ~order number of cpu cores // 2 12 | num_proc = 8 13 | 14 | # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769) 15 | dataset = load_dataset("openwebtext", cache_dir="/tiger/u/hliu99/nanoGPT/cache") 16 | 17 | # owt by default only contains the 'train' split, so create a test split 18 | split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True) 19 | split_dataset['val'] = split_dataset.pop('test') # rename the test split to val 20 | 21 | # this results in: 22 | # >>> split_dataset 23 | # DatasetDict({ 24 | # train: Dataset({ 25 | # features: ['text'], 26 | # num_rows: 8009762 27 | # }) 28 | # val: Dataset({ 29 | # features: ['text'], 30 | # num_rows: 4007 31 | # }) 32 | # }) 33 | 34 | # we now want to tokenize the dataset. first define the encoding function (gpt2 bpe) 35 | enc = tiktoken.get_encoding("gpt2") 36 | def process(example): 37 | ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens 38 | ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe 39 | # note: I think eot should be prepended not appended... hmm. 
it's called "eot" though... 40 | out = {'ids': ids, 'len': len(ids)} 41 | return out 42 | 43 | # tokenize the dataset 44 | tokenized = split_dataset.map( 45 | process, 46 | remove_columns=['text'], 47 | desc="tokenizing the splits", 48 | num_proc=num_proc, 49 | ) 50 | print('tokenization finished') 51 | # concatenate all the ids in each dataset into one large file we can use for training 52 | for split, dset in tokenized.items(): 53 | arr_len = np.sum(dset['len']) 54 | filename = os.path.join(os.path.dirname(__file__), f'{split}.bin') 55 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 56 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 57 | 58 | print(f"writing {filename}...") 59 | idx = 0 60 | for example in tqdm(dset): 61 | arr[idx : idx + example['len']] = example['ids'] 62 | idx += example['len'] 63 | arr.flush() 64 | 65 | # train.bin is ~17GB, val.bin ~8.5MB 66 | # train has ~9B tokens (9,035,582,198) 67 | # val has ~4M tokens (4,434,897) 68 | 69 | # to read the bin files later, e.g. with numpy: 70 | # m = np.memmap('train.bin', dtype=np.uint16, mode='r') 71 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import inspect 3 | from dataclasses import dataclass 4 | from sophia import SophiaG 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import functional as F 9 | 10 | optimizer_dict = {'adamw': torch.optim.AdamW, 11 | 'sophiag': SophiaG 12 | } 13 | 14 | # @torch.jit.script # good to enable when not using torch.compile, disable when using (our default) 15 | def new_gelu(x): 16 | """ 17 | Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). 18 | Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 19 | """ 20 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 21 | 22 | class LayerNorm(nn.Module): 23 | """ LayerNorm but with an optional bias. 
PyTorch doesn't support simply bias=False """ 24 | 25 | def __init__(self, ndim, bias): 26 | super().__init__() 27 | self.weight = nn.Parameter(torch.ones(ndim)) 28 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None 29 | 30 | def forward(self, input): 31 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) 32 | 33 | class CausalSelfAttention(nn.Module): 34 | 35 | def __init__(self, config, idx_layer): 36 | super().__init__() 37 | assert config.n_embd % config.n_head == 0 38 | # key, query, value projections for all heads, but in a batch 39 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) 40 | # output projection 41 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 42 | # regularization 43 | self.attn_dropout = nn.Dropout(config.dropout) 44 | self.resid_dropout = nn.Dropout(config.dropout) 45 | self.n_head = config.n_head 46 | self.n_embd = config.n_embd 47 | self.dropout = config.dropout 48 | self.idx_layer = idx_layer 49 | self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx 50 | 51 | # causal mask to ensure that attention is only applied to the left in the input sequence 52 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) 53 | .view(1, 1, config.block_size, config.block_size)) 54 | 55 | def forward(self, x): 56 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 57 | 58 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 59 | q, k ,v = self.c_attn(x).split(self.n_embd, dim=2) 60 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 61 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 62 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 63 | 64 | if self.scale_attn_by_inverse_layer_idx: 65 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)) / float(self.idx_layer + 1)) 66 | else: 67 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 68 | att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) 69 | att = F.softmax(att, dim=-1) 70 | att = self.attn_dropout(att) 71 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 72 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 73 | 74 | # output projection 75 | y = self.resid_dropout(self.c_proj(y)) 76 | return y 77 | 78 | class MLP(nn.Module): 79 | 80 | def __init__(self, config): 81 | super().__init__() 82 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) 83 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) 84 | self.dropout = nn.Dropout(config.dropout) 85 | 86 | def forward(self, x): 87 | x = self.c_fc(x) 88 | x = new_gelu(x) 89 | x = self.c_proj(x) 90 | x = self.dropout(x) 91 | return x 92 | 93 | class Block(nn.Module): 94 | 95 | def __init__(self, config, idx_layer): 96 | super().__init__() 97 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) 98 | self.attn = CausalSelfAttention(config, idx_layer) 99 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) 100 | self.mlp = MLP(config) 101 | 102 | def forward(self, x): 103 | x = x + self.attn(self.ln_1(x)) 104 | x = x + self.mlp(self.ln_2(x)) 105 | return x 106 | 107 | @dataclass 108 | class GPTConfig: 109 | block_size: int = 1024 110 | vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest 
multiple of 64 for efficiency 111 | n_layer: int = 12 112 | n_head: int = 12 113 | n_embd: int = 768 114 | dropout: float = 0.0 115 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster 116 | scale_attn_by_inverse_layer_idx: bool = False 117 | 118 | 119 | class GPT(nn.Module): 120 | 121 | def __init__(self, config): 122 | super().__init__() 123 | assert config.vocab_size is not None 124 | assert config.block_size is not None 125 | self.config = config 126 | 127 | self.transformer = nn.ModuleDict(dict( 128 | wte = nn.Embedding(config.vocab_size, config.n_embd), 129 | wpe = nn.Embedding(config.block_size, config.n_embd), 130 | drop = nn.Dropout(config.dropout), 131 | h = nn.ModuleList([Block(config, idx_layer) for idx_layer in range(config.n_layer)]), 132 | ln_f = LayerNorm(config.n_embd, bias=config.bias), 133 | )) 134 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 135 | # with weight tying when using torch.compile() some warnings get generated: 136 | # "UserWarning: functional_call was passed multiple values for tied weights. 137 | # This behavior is deprecated and will be an error in future versions" 138 | # not 100% sure what this is, so far seems to be harmless. TODO investigate 139 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying 140 | 141 | # init all weights 142 | self.apply(self._init_weights) 143 | # apply special scaled init to the residual projections, per GPT-2 paper 144 | for pn, p in self.named_parameters(): 145 | if pn.endswith('c_proj.weight'): 146 | torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) 147 | 148 | # report number of parameters 149 | print("number of parameters: %.2fM" % (self.get_num_params()/1e6,)) 150 | 151 | def get_num_params(self, non_embedding=True): 152 | """ 153 | Return the number of parameters in the model. 154 | For non-embedding count (default), the position embeddings get subtracted. 155 | The token embeddings would too, except due to the parameter sharing these 156 | params are actually used as weights in the final layer, so we include them. 
157 | """ 158 | n_params = sum(p.numel() for p in self.parameters()) 159 | if non_embedding: 160 | n_params -= self.transformer.wpe.weight.numel() 161 | return n_params 162 | 163 | def _init_weights(self, module): 164 | if isinstance(module, nn.Linear): 165 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 166 | if module.bias is not None: 167 | torch.nn.init.zeros_(module.bias) 168 | elif isinstance(module, nn.Embedding): 169 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 170 | 171 | def forward(self, idx, targets=None): 172 | device = idx.device 173 | b, t = idx.size() 174 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 175 | pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) 176 | 177 | # forward the GPT model itself 178 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 179 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd) 180 | x = self.transformer.drop(tok_emb + pos_emb) 181 | for block in self.transformer.h: 182 | x = block(x) 183 | x = self.transformer.ln_f(x) 184 | 185 | if targets is not None: 186 | # if we are given some desired targets also calculate the loss 187 | if not isinstance(targets, int): 188 | logits = self.lm_head(x) 189 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 190 | else: 191 | logits = self.lm_head(x) 192 | loss = None 193 | else: 194 | # inference-time mini-optimization: only forward the lm_head on the very last position 195 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim 196 | loss = None 197 | 198 | return logits, loss 199 | 200 | def crop_block_size(self, block_size): 201 | # model surgery to decrease the block size if necessary 202 | # e.g. 
we may load the GPT2 pretrained model checkpoint (block size 1024) 203 | # but want to use a smaller block size for some smaller, simpler model 204 | assert block_size <= self.config.block_size 205 | self.config.block_size = block_size 206 | self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size]) 207 | for block in self.transformer.h: 208 | block.attn.bias = block.attn.bias[:,:,:block_size,:block_size] 209 | 210 | @classmethod 211 | def from_pretrained(cls, model_type, override_args=None): 212 | assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'} 213 | override_args = override_args or {} # default to empty dict 214 | # only dropout can be overridden see more notes below 215 | assert all(k == 'dropout' for k in override_args) 216 | from transformers import GPT2LMHeadModel 217 | print("loading weights from pretrained gpt: %s" % model_type) 218 | 219 | # n_layer, n_head and n_embd are determined from model_type 220 | config_args = { 221 | 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params 222 | 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params 223 | 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params 224 | 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params 225 | }[model_type] 226 | print("forcing vocab_size=50257, block_size=1024, bias=True") 227 | config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints 228 | config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints 229 | config_args['bias'] = True # always True for GPT model checkpoints 230 | # we can override the dropout rate, if desired 231 | if 'dropout' in override_args: 232 | print(f"overriding dropout rate to {override_args['dropout']}") 233 | config_args['dropout'] = override_args['dropout'] 234 | # create a from-scratch initialized minGPT model 235 | config = GPTConfig(**config_args) 236 | model = GPT(config) 237 | sd = model.state_dict() 238 | sd_keys = sd.keys() 239 | sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param 240 | 241 | # init a huggingface/transformers model 242 | model_hf = GPT2LMHeadModel.from_pretrained(model_type) 243 | sd_hf = model_hf.state_dict() 244 | 245 | # copy while ensuring all of the parameters are aligned and match in names and shapes 246 | sd_keys_hf = sd_hf.keys() 247 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer 248 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer) 249 | transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight'] 250 | # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear 251 | # this means that we have to transpose these weights when we import them 252 | assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" 253 | for k in sd_keys_hf: 254 | if any(k.endswith(w) for w in transposed): 255 | # special treatment for the Conv1D weights we need to transpose 256 | assert sd_hf[k].shape[::-1] == sd[k].shape 257 | with torch.no_grad(): 258 | sd[k].copy_(sd_hf[k].t()) 259 | else: 260 | # vanilla copy over the other parameters 261 | assert sd_hf[k].shape == sd[k].shape 262 | with torch.no_grad(): 263 | sd[k].copy_(sd_hf[k]) 264 | 265 | return model 266 | 267 | def configure_optimizers(self, optimizer_name, weight_decay, learning_rate, betas, rho, device_type): 268 
| """ 269 | This long function is unfortunately doing something very simple and is being very defensive: 270 | We are separating out all parameters of the model into two buckets: those that will experience 271 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 272 | We are then returning the PyTorch optimizer object. 273 | """ 274 | 275 | # separate out all parameters to those that will and won't experience regularizing weight decay 276 | decay = set() 277 | no_decay = set() 278 | whitelist_weight_modules = (torch.nn.Linear, ) 279 | blacklist_weight_modules = (torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding) 280 | for mn, m in self.named_modules(): 281 | for pn, p in m.named_parameters(): 282 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 283 | # random note: because named_modules and named_parameters are recursive 284 | # we will see the same tensors p many many times. but doing it this way 285 | # allows us to know which parent module any tensor p belongs to... 286 | if pn.endswith('bias'): 287 | # all biases will not be decayed 288 | no_decay.add(fpn) 289 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 290 | # weights of whitelist modules will be weight decayed 291 | decay.add(fpn) 292 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 293 | # weights of blacklist modules will NOT be weight decayed 294 | no_decay.add(fpn) 295 | 296 | # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they 297 | # will appear in the no_decay and decay sets respectively after the above. 298 | # In addition, because named_parameters() doesn't return duplicates, it 299 | # will only return the first occurrence, key'd by 'transformer.wte.weight', below. 300 | # so let's manually remove 'lm_head.weight' from decay set. This will include 301 | # this tensor into optimization via transformer.wte.weight only, and not decayed. 302 | decay.remove('lm_head.weight') 303 | 304 | # validate that we considered every parameter 305 | param_dict = {pn: p for pn, p in self.named_parameters()} 306 | inter_params = decay & no_decay 307 | union_params = decay | no_decay 308 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) 309 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" 
\ 310 | % (str(param_dict.keys() - union_params), ) 311 | 312 | # create the pytorch optimizer object 313 | optim_groups = [ 314 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, 315 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 316 | ] 317 | 318 | opt_func = optimizer_dict[optimizer_name] 319 | if optimizer_name == 'adamw': 320 | # new PyTorch nightly has a new 'fused' option for AdamW that is much faster 321 | use_fused = (device_type == 'cuda') and ('fused' in inspect.signature(torch.optim.AdamW).parameters) 322 | print(f"using fused AdamW: {use_fused}") 323 | extra_args = dict(fused=True) if use_fused else dict() 324 | optimizer = opt_func(optim_groups, lr=learning_rate, betas=betas, **extra_args) 325 | elif optimizer_name == 'sophiag': 326 | optimizer = opt_func(optim_groups, lr=learning_rate, betas=betas, rho=rho) 327 | else: 328 | raise ValueError('Invalid optimizer.') 329 | return optimizer 330 | 331 | def estimate_mfu(self, fwdbwd_per_iter, dt): 332 | """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """ 333 | # first estimate the number of flops we do per iteration. 334 | # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 335 | N = self.get_num_params() 336 | cfg = self.config 337 | L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size 338 | flops_per_token = 6*N + 12*L*H*Q*T 339 | flops_per_fwdbwd = flops_per_token * T 340 | flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter 341 | # express our flops throughput as ratio of A100 bfloat16 peak flops 342 | flops_achieved = flops_per_iter * (1.0/dt) # per second 343 | flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS 344 | mfu = flops_achieved / flops_promised 345 | return mfu 346 | 347 | @torch.no_grad() 348 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): 349 | """ 350 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete 351 | the sequence max_new_tokens times, feeding the predictions back into the model each time. 352 | Most likely you'll want to make sure to be in model.eval() mode of operation for this. 
353 | """ 354 | for _ in range(max_new_tokens): 355 | # if the sequence context is growing too long we must crop it at block_size 356 | idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] 357 | # forward the model to get the logits for the index in the sequence 358 | logits, _ = self(idx_cond) 359 | # pluck the logits at the final step and scale by desired temperature 360 | logits = logits[:, -1, :] / temperature 361 | # optionally crop the logits to only the top k options 362 | if top_k is not None: 363 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 364 | logits[logits < v[:, [-1]]] = -float('Inf') 365 | # apply softmax to convert logits to (normalized) probabilities 366 | probs = F.softmax(logits, dim=-1) 367 | # sample from the distribution 368 | idx_next = torch.multinomial(probs, num_samples=1) 369 | # append sampled index to the running sequence and continue 370 | idx = torch.cat((idx, idx_next), dim=1) 371 | 372 | return idx 373 | -------------------------------------------------------------------------------- /sophia.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import Tensor 4 | from torch.optim.optimizer import Optimizer 5 | from typing import List, Optional 6 | 7 | 8 | class SophiaG(Optimizer): 9 | def __init__(self, params, lr=1e-4, betas=(0.965, 0.99), rho = 0.04, 10 | weight_decay=1e-1, *, maximize: bool = False, 11 | capturable: bool = False): 12 | if not 0.0 <= lr: 13 | raise ValueError("Invalid learning rate: {}".format(lr)) 14 | if not 0.0 <= betas[0] < 1.0: 15 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 16 | if not 0.0 <= betas[1] < 1.0: 17 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 18 | if not 0.0 <= rho: 19 | raise ValueError("Invalid rho parameter at index 1: {}".format(rho)) 20 | if not 0.0 <= weight_decay: 21 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 22 | defaults = dict(lr=lr, betas=betas, rho=rho, 23 | weight_decay=weight_decay, 24 | maximize=maximize, capturable=capturable) 25 | super(SophiaG, self).__init__(params, defaults) 26 | 27 | def __setstate__(self, state): 28 | super().__setstate__(state) 29 | for group in self.param_groups: 30 | group.setdefault('maximize', False) 31 | group.setdefault('capturable', False) 32 | state_values = list(self.state.values()) 33 | step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step']) 34 | if not step_is_tensor: 35 | for s in state_values: 36 | s['step'] = torch.tensor(float(s['step'])) 37 | 38 | @torch.no_grad() 39 | def update_hessian(self): 40 | for group in self.param_groups: 41 | beta1, beta2 = group['betas'] 42 | for p in group['params']: 43 | if p.grad is None: 44 | continue 45 | state = self.state[p] 46 | 47 | if len(state) == 0: 48 | state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \ 49 | if self.defaults['capturable'] else torch.tensor(0.) 
50 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 51 | state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) 52 | 53 | if 'hessian' not in state.keys(): 54 | state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) 55 | 56 | state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2) 57 | 58 | 59 | @torch.no_grad() 60 | def step(self, closure=None, bs=5120): 61 | loss = None 62 | if closure is not None: 63 | with torch.enable_grad(): 64 | loss = closure() 65 | 66 | for group in self.param_groups: 67 | params_with_grad = [] 68 | grads = [] 69 | exp_avgs = [] 70 | state_steps = [] 71 | hessian = [] 72 | beta1, beta2 = group['betas'] 73 | 74 | for p in group['params']: 75 | if p.grad is None: 76 | continue 77 | params_with_grad.append(p) 78 | 79 | if p.grad.is_sparse: 80 | raise RuntimeError('Hero does not support sparse gradients') 81 | grads.append(p.grad) 82 | state = self.state[p] 83 | # State initialization 84 | if len(state) == 0: 85 | state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \ 86 | if self.defaults['capturable'] else torch.tensor(0.) 87 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 88 | state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) 89 | 90 | if 'hessian' not in state.keys(): 91 | state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) 92 | 93 | exp_avgs.append(state['exp_avg']) 94 | state_steps.append(state['step']) 95 | hessian.append(state['hessian']) 96 | 97 | if self.defaults['capturable']: 98 | bs = torch.ones((1,), dtype=torch.float, device=p.device) * bs 99 | 100 | sophiag(params_with_grad, 101 | grads, 102 | exp_avgs, 103 | hessian, 104 | state_steps, 105 | bs=bs, 106 | beta1=beta1, 107 | beta2=beta2, 108 | rho=group['rho'], 109 | lr=group['lr'], 110 | weight_decay=group['weight_decay'], 111 | maximize=group['maximize'], 112 | capturable=group['capturable']) 113 | 114 | return loss 115 | 116 | def sophiag(params: List[Tensor], 117 | grads: List[Tensor], 118 | exp_avgs: List[Tensor], 119 | hessian: List[Tensor], 120 | state_steps: List[Tensor], 121 | capturable: bool = False, 122 | *, 123 | bs: int, 124 | beta1: float, 125 | beta2: float, 126 | rho: float, 127 | lr: float, 128 | weight_decay: float, 129 | maximize: bool): 130 | 131 | if not all(isinstance(t, torch.Tensor) for t in state_steps): 132 | raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors") 133 | 134 | 135 | func = _single_tensor_sophiag 136 | 137 | func(params, 138 | grads, 139 | exp_avgs, 140 | hessian, 141 | state_steps, 142 | bs=bs, 143 | beta1=beta1, 144 | beta2=beta2, 145 | rho=rho, 146 | lr=lr, 147 | weight_decay=weight_decay, 148 | maximize=maximize, 149 | capturable=capturable) 150 | 151 | def _single_tensor_sophiag(params: List[Tensor], 152 | grads: List[Tensor], 153 | exp_avgs: List[Tensor], 154 | hessian: List[Tensor], 155 | state_steps: List[Tensor], 156 | *, 157 | bs: int, 158 | beta1: float, 159 | beta2: float, 160 | rho: float, 161 | lr: float, 162 | weight_decay: float, 163 | maximize: bool, 164 | capturable: bool): 165 | 166 | for i, param in enumerate(params): 167 | grad = grads[i] if not maximize else -grads[i] 168 | exp_avg = exp_avgs[i] 169 | hess = hessian[i] 170 | step_t = state_steps[i] 171 | 172 | if capturable: 173 | assert param.is_cuda and step_t.is_cuda and bs.is_cuda 174 | 175 | if torch.is_complex(param): 176 | grad = torch.view_as_real(grad) 177 
| exp_avg = torch.view_as_real(exp_avg) 178 | hess = torch.view_as_real(hess) 179 | param = torch.view_as_real(param) 180 | 181 | # update step 182 | step_t += 1 183 | 184 | # Perform stepweight decay 185 | param.mul_(1 - lr * weight_decay) 186 | 187 | # Decay the first and second moment running average coefficient 188 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 189 | 190 | if capturable: 191 | step_size = lr 192 | step_size_neg = step_size.neg() 193 | 194 | ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None,1) 195 | param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg) 196 | else: 197 | step_size_neg = - lr 198 | 199 | ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None,1) 200 | param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg) -------------------------------------------------------------------------------- /train_adam.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import pickle 5 | from contextlib import nullcontext 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | from torch.distributed import init_process_group, destroy_process_group 11 | 12 | from model import GPTConfig, GPT 13 | 14 | # ----------------------------------------------------------------------------- 15 | # default config values designed to train a gpt2 (124M) on OpenWebText 16 | # I/O 17 | out_dir = 'out' 18 | eval_interval = 2000 19 | log_interval = 1 20 | eval_iters = 200 21 | eval_only = False # if True, script exits right after the first eval 22 | always_save_checkpoint = True # if True, always save a checkpoint after each eval 23 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*' 24 | # wandb logging 25 | wandb_log = False # disabled by default 26 | wandb_project = 'owt' 27 | wandb_run_name = 'gpt2' # 'run' + str(time.time()) 28 | # data 29 | dataset = 'openwebtext' 30 | gradient_accumulation_steps = 5 # used to simulate larger batch sizes 31 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size 32 | block_size = 1024 33 | # model 34 | n_layer = 12 35 | n_head = 12 36 | n_embd = 768 37 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 38 | bias = False # do we use bias inside LayerNorm and Linear layers? 39 | # optimizer 40 | optimizer_name = 'adamw' 41 | learning_rate = 6e-4 # max learning rate 42 | max_iters = 600000 # total number of training iterations 43 | weight_decay = 1e-1 44 | beta1 = 0.9 45 | beta2 = 0.95 46 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 47 | rho = 0.1 48 | interval = 10 49 | variant = 4 50 | # learning rate decay settings 51 | decay_lr = True # whether to decay the learning rate 52 | warmup_iters = 2000 # how many steps to warm up for 53 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla 54 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 55 | # DDP settings 56 | backend = 'nccl' # 'nccl', 'gloo', etc. 
57 | # system 58 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 59 | dtype = 'bfloat16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 60 | compile = True # use PyTorch 2.0 to compile the model to be faster 61 | scale_attn_by_inverse_layer_idx = True 62 | # ----------------------------------------------------------------------------- 63 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 64 | exec(open('configurator.py').read()) # overrides from command line or config file 65 | config = {k: globals()[k] for k in config_keys} # will be useful for logging 66 | # ----------------------------------------------------------------------------- 67 | 68 | # various inits, derived attributes, I/O setup 69 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? 70 | if ddp: 71 | init_process_group(backend=backend) 72 | ddp_rank = int(os.environ['RANK']) 73 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 74 | device = f'cuda:{ddp_local_rank}' 75 | torch.cuda.set_device(device) 76 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 77 | seed_offset = ddp_rank # each process gets a different seed 78 | else: 79 | # if not ddp, we are running on a single gpu, and one process 80 | master_process = True 81 | seed_offset = 0 82 | gradient_accumulation_steps *= 8 # simulate 8 gpus 83 | 84 | if master_process: 85 | os.makedirs(out_dir, exist_ok=True) 86 | torch.manual_seed(5000 + seed_offset) 87 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 88 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 89 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 90 | # note: float16 data type will automatically use a GradScaler 91 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 92 | ctx = nullcontext() if device_type == 'cpu' else torch.autocast(device_type=device_type, dtype=ptdtype) 93 | 94 | # poor man's data loader 95 | data_dir = os.path.join('data', dataset) 96 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 97 | val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') 98 | def get_batch(split): 99 | data = train_data if split == 'train' else val_data 100 | ix = torch.randint(len(data) - block_size, (batch_size,)) 101 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 102 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) 103 | if device_type == 'cuda': 104 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 105 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 106 | else: 107 | x, y = x.to(device), y.to(device) 108 | return x, y 109 | 110 | # init these up here, can override if init_from='resume' (i.e. 
from a checkpoint) 111 | iter_num = 0 112 | best_val_loss = 1e9 113 | 114 | # attempt to derive vocab_size from the dataset 115 | meta_path = os.path.join(data_dir, 'meta.pkl') 116 | meta_vocab_size = None 117 | if os.path.exists(meta_path): 118 | with open(meta_path, 'rb') as f: 119 | meta = pickle.load(f) 120 | meta_vocab_size = meta['vocab_size'] 121 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})") 122 | 123 | # model init 124 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, 125 | bias=bias, vocab_size=None, dropout=dropout, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx) # start with model_args from command line 126 | if init_from == 'scratch': 127 | # init a new model from scratch 128 | print("Initializing a new model from scratch") 129 | # determine the vocab size we'll use for from-scratch training 130 | if meta_vocab_size is None: 131 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") 132 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304 133 | gptconf = GPTConfig(**model_args) 134 | model = GPT(gptconf) 135 | elif init_from == 'resume': 136 | print(f"Resuming training from {out_dir}") 137 | # resume training from a checkpoint. 138 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 139 | checkpoint = torch.load(ckpt_path, map_location=device) 140 | checkpoint_model_args = checkpoint['model_args'] 141 | # force these config attributes to be equal otherwise we can't even resume training 142 | # the rest of the attributes (e.g. dropout) can stay as desired from command line 143 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 144 | model_args[k] = checkpoint_model_args[k] 145 | # create the model 146 | gptconf = GPTConfig(**model_args) 147 | model = GPT(gptconf) 148 | state_dict = checkpoint['model'] 149 | # fix the keys of the state dictionary :( 150 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 151 | unwanted_prefix = '_orig_mod.' 152 | for k,v in list(state_dict.items()): 153 | if k.startswith(unwanted_prefix): 154 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 155 | model.load_state_dict(state_dict) 156 | iter_num = checkpoint['iter_num'] 157 | best_val_loss = checkpoint['best_val_loss'] 158 | elif init_from.startswith('gpt2'): 159 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}") 160 | # initialize from OpenAI GPT-2 weights 161 | override_args = dict(dropout=dropout) 162 | model = GPT.from_pretrained(init_from, override_args) 163 | # read off the created config params, so we can store them into checkpoint correctly 164 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 165 | model_args[k] = getattr(model.config, k) 166 | # crop down the model block size if desired, using model surgery 167 | if block_size < model.config.block_size: 168 | model.crop_block_size(block_size) 169 | model_args['block_size'] = block_size # so that the checkpoint will have the right value 170 | model.to(device) 171 | 172 | # initialize a GradScaler. 
If enabled=False scaler is a no-op 173 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16')) 174 | 175 | # optimizer 176 | optimizer = model.configure_optimizers(optimizer_name, weight_decay, learning_rate, (beta1, beta2), rho, device_type) 177 | if init_from == 'resume': 178 | optimizer.load_state_dict(checkpoint['optimizer']) 179 | del state_dict 180 | del checkpoint 181 | # compile the model 182 | if compile: 183 | print("compiling the model... (takes a ~minute)") 184 | unoptimized_model = model 185 | model = torch.compile(model) # requires PyTorch 2.0 186 | 187 | # wrap model into DDP container 188 | if ddp: 189 | model = DDP(model, device_ids=[ddp_local_rank]) 190 | 191 | # helps estimate an arbitrarily accurate loss over either split using many batches 192 | @torch.no_grad() 193 | def estimate_loss(): 194 | out = {} 195 | model.eval() 196 | for split in ['train', 'val']: 197 | losses = torch.zeros(eval_iters) 198 | for k in range(eval_iters): 199 | X, Y = get_batch(split) 200 | with ctx: 201 | logits, loss = model(X, Y) 202 | losses[k] = loss.item() 203 | out[split] = losses.mean() 204 | model.train() 205 | return out 206 | 207 | # learning rate decay scheduler (cosine with warmup) 208 | def get_lr(it): 209 | # 1) linear warmup for warmup_iters steps 210 | if it < warmup_iters: 211 | return learning_rate * it / warmup_iters 212 | # 2) if it > lr_decay_iters, return min learning rate 213 | if it > lr_decay_iters: 214 | return min_lr 215 | # 3) in between, use cosine decay down to min learning rate 216 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 217 | assert 0 <= decay_ratio <= 1 218 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 219 | return min_lr + coeff * (learning_rate - min_lr) 220 | 221 | # logging 222 | if wandb_log and master_process: 223 | import wandb 224 | wandb.init(project=wandb_project, name=wandb_run_name, config=config) 225 | 226 | # training loop 227 | X, Y = get_batch('train') # fetch the very first batch 228 | t0 = time.time() 229 | local_iter_num = 0 # number of iterations in the lifetime of this process 230 | raw_model = model.module if ddp else model # unwrap DDP container if needed 231 | running_mfu = -1.0 232 | clip_time = 0 233 | while True: 234 | 235 | # determine and set the learning rate for this iteration 236 | lr = get_lr(iter_num) if decay_lr else learning_rate 237 | for param_group in optimizer.param_groups: 238 | param_group['lr'] = lr 239 | 240 | # evaluate the loss on train/val sets and write checkpoints 241 | if iter_num % eval_interval == 0 and master_process: 242 | losses = estimate_loss() 243 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 244 | if wandb_log: 245 | wandb.log({ 246 | "iter": iter_num, 247 | "train/loss": losses['train'], 248 | "val/loss": losses['val'], 249 | "lr": lr, 250 | "mfu": running_mfu*100, # convert to percentage 251 | }, step=iter_num) 252 | if losses['val'] < best_val_loss or always_save_checkpoint: 253 | best_val_loss = losses['val'] 254 | if iter_num > 0: 255 | checkpoint = { 256 | 'model': raw_model.state_dict(), 257 | 'optimizer': optimizer.state_dict(), 258 | 'model_args': model_args, 259 | 'iter_num': iter_num, 260 | 'best_val_loss': best_val_loss, 261 | 'config': config, 262 | } 263 | print(f"saving checkpoint to {out_dir}") 264 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) 265 | if iter_num % (eval_interval * 5) == 0: 266 | checkpoint = { 267 | 'model': raw_model.state_dict(), 268 | 
'optimizer': optimizer.state_dict(), 269 | 'model_args': model_args, 270 | 'iter_num': iter_num, 271 | 'best_val_loss': best_val_loss, 272 | 'config': config, 273 | } 274 | print(f"saving checkpoint to {out_dir}") 275 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt_'+str(iter_num)+'.pt')) 276 | if iter_num == 0 and eval_only: 277 | break 278 | 279 | # forward backward update, with optional gradient accumulation to simulate larger batch size 280 | # and using the GradScaler if data type is float16 281 | for micro_step in range(gradient_accumulation_steps): 282 | if ddp: 283 | # in DDP training we only need to sync gradients at the last micro step. 284 | # the official way to do this is with model.no_sync() context manager, but 285 | # I really dislike that this bloats the code and forces us to repeat code 286 | # looking at the source of that context manager, it just toggles this variable 287 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) 288 | with ctx: 289 | logits, loss = model(X, Y) 290 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 291 | X, Y = get_batch('train') 292 | # backward pass, with gradient scaling if training in fp16 293 | scaler.scale(loss).backward() 294 | # clip the gradient 295 | if grad_clip != 0.0: 296 | scaler.unscale_(optimizer) 297 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 298 | if total_norm.item() > grad_clip: 299 | clip_time += 1 300 | # step the optimizer and scaler if training in fp16 301 | scaler.step(optimizer) 302 | scaler.update() 303 | # flush the gradients as soon as we can, no need for this memory anymore 304 | optimizer.zero_grad(set_to_none=True) 305 | 306 | # timing and logging 307 | t1 = time.time() 308 | dt = t1 - t0 309 | t0 = t1 310 | if iter_num % log_interval == 0 and master_process: 311 | lossf = loss.item() # loss as float. 
note: this is a CPU-GPU sync point 312 | if local_iter_num >= 5: # let the training loop settle a bit 313 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt) 314 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu 315 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%") 316 | params = [] 317 | for (name, p) in model.named_parameters(): 318 | params.append(p) 319 | total_param_norm = 0 320 | for p in params: 321 | param_norm = p.data.norm(2) 322 | total_param_norm += param_norm.item() ** 2 323 | total_param_norm = total_param_norm ** 0.5 324 | momentum_norm = 0 325 | LL = len(optimizer.state_dict()['state']) 326 | for jj in range(LL): 327 | momentum_norm += (optimizer.state_dict()['state'][jj]['exp_avg'].detach().norm(2)) ** 2 328 | momentum_norm = torch.sqrt(momentum_norm).item() 329 | if wandb_log: 330 | wandb.log({ 331 | "iter": iter_num, 332 | "train/loss": lossf, 333 | "lr": lr, 334 | "param_norm": total_param_norm, 335 | "momentum_norm" : momentum_norm, 336 | "train/clip_rate": clip_time / (iter_num + 1) 337 | }, step=iter_num) 338 | iter_num += 1 339 | local_iter_num += 1 340 | 341 | # termination conditions 342 | if iter_num > max_iters: 343 | break 344 | 345 | if ddp: 346 | destroy_process_group() 347 | -------------------------------------------------------------------------------- /train_sophiag.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import pickle 5 | from contextlib import nullcontext 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.nn.parallel import DistributedDataParallel as DDP 11 | from torch.distributed import init_process_group, destroy_process_group 12 | from model import GPTConfig, GPT 13 | import torch.autograd as autograd 14 | 15 | # ----------------------------------------------------------------------------- 16 | # default config values designed to train a gpt2 (124M) on OpenWebText 17 | # I/O 18 | out_dir = 'out' 19 | eval_interval = 2000 20 | log_interval = 1 21 | eval_iters = 200 22 | eval_only = False # if True, script exits right after the first eval 23 | always_save_checkpoint = True # if True, always save a checkpoint after each eval 24 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*' 25 | # wandb logging 26 | wandb_log = False # disabled by default 27 | wandb_project = 'owt' 28 | wandb_run_name = 'gpt2' # 'run' + str(time.time()) 29 | # data 30 | dataset = 'openwebtext' 31 | gradient_accumulation_steps = 5 # used to simulate larger batch sizes 32 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size 33 | block_size = 1024 34 | total_bs = 480 35 | # model 36 | n_layer = 12 37 | n_head = 12 38 | n_embd = 768 39 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 40 | bias = False # do we use bias inside LayerNorm and Linear layers? 
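# Note on batch sizes (explanatory comment, not executed code): one optimizer step
# consumes batch_size * gradient_accumulation_steps * (number of GPUs) sequences.
# With the defaults above that is 12 * 5 * 8 = 480 sequences on 8 GPUs (the
# single-GPU path below scales gradient_accumulation_steps by 8 to reach the
# same 480), which is what total_bs records. The `bs` argument later passed to
# optimizer.step() additionally multiplies by block_size, i.e. it is measured
# in tokens: total_bs * block_size = 480 * 1024 = 491,520.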
41 | # optimizer 42 | optimizer_name = 'sophiag' 43 | learning_rate = 3e-4 # max learning rate 44 | max_iters = 600000 # total number of training iterations 45 | weight_decay = 1e-1 46 | beta1 = 0.9 47 | beta2 = 0.95 48 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 49 | rho = 0.03 50 | interval = 10 51 | hess_interval = interval 52 | variant = 4 53 | # learning rate decay settings 54 | decay_lr = True # whether to decay the learning rate 55 | warmup_iters = 2000 # how many steps to warm up for 56 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla 57 | min_lr = 1.5e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 58 | # DDP settings 59 | backend = 'nccl' # 'nccl', 'gloo', etc. 60 | # system 61 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 62 | dtype = 'bfloat16' # 'float32', 'bfloat16' 63 | compile = True # use PyTorch 2.0 to compile the model to be faster 64 | scale_attn_by_inverse_layer_idx = True 65 | # ----------------------------------------------------------------------------- 66 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 67 | exec(open('configurator.py').read()) # overrides from command line or config file 68 | config = {k: globals()[k] for k in config_keys} # will be useful for logging 69 | # ----------------------------------------------------------------------------- 70 | 71 | # various inits, derived attributes, I/O setup 72 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? 73 | if ddp: 74 | init_process_group(backend=backend) 75 | ddp_rank = int(os.environ['RANK']) 76 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 77 | device = f'cuda:{ddp_local_rank}' 78 | torch.cuda.set_device(device) 79 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 
80 | seed_offset = ddp_rank # each process gets a different seed 81 | else: 82 | # if not ddp, we are running on a single gpu, and one process 83 | ddp_rank = 0 #ddp_rank is used in get_batch function so this has to be here also when running locally 84 | master_process = True 85 | seed_offset = 0 86 | gradient_accumulation_steps *= 8 # simulate 8 gpus 87 | 88 | if master_process: 89 | os.makedirs(out_dir, exist_ok=True) 90 | torch.manual_seed(2099) 91 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 92 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 93 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 94 | # note: float16 data type will automatically use a GradScaler 95 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 96 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 97 | 98 | # poor man's data loader 99 | data_dir = os.path.join('data', dataset) 100 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 101 | val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') 102 | def get_batch(split): 103 | data = train_data if split == 'train' else val_data 104 | ix_list = [] 105 | for jj in range(10): 106 | ix_list.append(torch.randint(len(data) - block_size, (batch_size,))) 107 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix_list[ddp_rank]]) 108 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix_list[ddp_rank]]) 109 | if device_type == 'cuda': 110 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 111 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 112 | else: 113 | x, y = x.to(device), y.to(device) 114 | return x, y 115 | 116 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint) 117 | iter_num = 0 118 | best_val_loss = 1e9 119 | 120 | # attempt to derive vocab_size from the dataset 121 | meta_path = os.path.join(data_dir, 'meta.pkl') 122 | meta_vocab_size = None 123 | if os.path.exists(meta_path): 124 | with open(meta_path, 'rb') as f: 125 | meta = pickle.load(f) 126 | meta_vocab_size = meta['vocab_size'] 127 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})") 128 | 129 | # model init 130 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, 131 | bias=bias, vocab_size=None, dropout=dropout, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx) # start with model_args from command line 132 | if init_from == 'scratch': 133 | # init a new model from scratch 134 | print("Initializing a new model from scratch") 135 | # determine the vocab size we'll use for from-scratch training 136 | if meta_vocab_size is None: 137 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") 138 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304 139 | gptconf = GPTConfig(**model_args) 140 | model = GPT(gptconf) 141 | elif init_from == 'resume': 142 | print(f"Resuming training from {out_dir}") 143 | # resume training from a checkpoint. 
144 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 145 | checkpoint = torch.load(ckpt_path, map_location=device) 146 | checkpoint_model_args = checkpoint['model_args'] 147 | # force these config attributes to be equal otherwise we can't even resume training 148 | # the rest of the attributes (e.g. dropout) can stay as desired from command line 149 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 150 | model_args[k] = checkpoint_model_args[k] 151 | # create the model 152 | gptconf = GPTConfig(**model_args) 153 | model = GPT(gptconf) 154 | state_dict = checkpoint['model'] 155 | # fix the keys of the state dictionary :( 156 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 157 | unwanted_prefix = '_orig_mod.' 158 | for k,v in list(state_dict.items()): 159 | if k.startswith(unwanted_prefix): 160 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 161 | model.load_state_dict(state_dict) 162 | iter_num = checkpoint['iter_num'] 163 | best_val_loss = checkpoint['best_val_loss'] 164 | elif init_from.startswith('gpt2'): 165 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}") 166 | # initialize from OpenAI GPT-2 weights 167 | override_args = dict(dropout=dropout) 168 | model = GPT.from_pretrained(init_from, override_args) 169 | # read off the created config params, so we can store them into checkpoint correctly 170 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 171 | model_args[k] = getattr(model.config, k) 172 | # crop down the model block size if desired, using model surgery 173 | if block_size < model.config.block_size: 174 | model.crop_block_size(block_size) 175 | model_args['block_size'] = block_size # so that the checkpoint will have the right value 176 | model.to(device) 177 | 178 | 179 | optimizer = model.configure_optimizers(optimizer_name, weight_decay, learning_rate, (beta1, beta2), rho, device_type) 180 | if init_from == 'resume': 181 | optimizer.load_state_dict(checkpoint['optimizer']) 182 | del state_dict 183 | del checkpoint 184 | # compile the model 185 | if compile: 186 | print("compiling the model... 
(takes a ~minute)") 187 | unoptimized_model = model 188 | model = torch.compile(model) # requires PyTorch 2.0 189 | 190 | # wrap model into DDP container 191 | if ddp: 192 | model = DDP(model, device_ids=[ddp_local_rank]) 193 | 194 | # helps estimate an arbitrarily accurate loss over either split using many batches 195 | @torch.no_grad() 196 | def estimate_loss(): 197 | out = {} 198 | model.eval() 199 | for split in ['train', 'val']: 200 | losses = torch.zeros(eval_iters) 201 | for k in range(eval_iters): 202 | X, Y = get_batch(split) 203 | with ctx: 204 | logits, loss = model(X, Y) 205 | losses[k] = loss.item() 206 | out[split] = losses.mean() 207 | model.train() 208 | return out 209 | 210 | # learning rate decay scheduler (cosine with warmup) 211 | def get_lr(it): 212 | # 1) linear warmup for warmup_iters steps 213 | if it < warmup_iters: 214 | return learning_rate * it / warmup_iters 215 | # 2) if it > lr_decay_iters, return min learning rate 216 | if it > lr_decay_iters: 217 | return min_lr 218 | # 3) in between, use cosine decay down to min learning rate 219 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 220 | assert 0 <= decay_ratio <= 1 221 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 222 | return min_lr + coeff * (learning_rate - min_lr) 223 | 224 | # logging 225 | if wandb_log and master_process: 226 | import wandb 227 | wandb.init(project=wandb_project, name=wandb_run_name, config=config) 228 | 229 | # training loop 230 | X, Y = get_batch('train') # fetch the very first batch 231 | t0 = time.time() 232 | local_iter_num = 0 # number of iterations in the lifetime of this process 233 | raw_model = model.module if ddp else model # unwrap DDP container if needed 234 | running_mfu = -1.0 235 | num_param = 1 236 | num_effective = 0 237 | momentum_norm = 0 238 | hessian_norm = 0 239 | hessian_norm2 = 0 240 | clip_time = 0 241 | while True: 242 | 243 | # determine and set the learning rate for this iteration 244 | lr = get_lr(iter_num) if decay_lr else learning_rate 245 | for param_group in optimizer.param_groups: 246 | param_group['lr'] = lr 247 | 248 | # evaluate the loss on train/val sets and write checkpoints 249 | if iter_num % eval_interval == 0 and master_process: 250 | losses = estimate_loss() 251 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 252 | if wandb_log: 253 | wandb.log({ 254 | "iter": iter_num, 255 | "train/loss": losses['train'], 256 | "val/loss": losses['val'], 257 | "lr": lr, 258 | "mfu": running_mfu*100, # convert to percentage 259 | }, step=iter_num) 260 | if losses['val'] < best_val_loss or always_save_checkpoint: 261 | best_val_loss = losses['val'] 262 | if iter_num > 0: 263 | checkpoint = { 264 | 'model': raw_model.state_dict(), 265 | 'optimizer': optimizer.state_dict(), 266 | 'model_args': model_args, 267 | 'iter_num': iter_num, 268 | 'best_val_loss': best_val_loss, 269 | 'config': config, 270 | } 271 | print(f"saving checkpoint to {out_dir}") 272 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) 273 | if iter_num % (eval_interval * 5) == 0: 274 | checkpoint = { 275 | 'model': raw_model.state_dict(), 276 | 'optimizer': optimizer.state_dict(), 277 | 'model_args': model_args, 278 | 'iter_num': iter_num, 279 | 'best_val_loss': best_val_loss, 280 | 'config': config, 281 | } 282 | print(f"saving checkpoint to {out_dir}") 283 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt_'+str(iter_num)+'.pt')) 284 | if iter_num == 0 and eval_only: 285 | break 286 | 287 | 
# forward backward update, with optional gradient accumulation to simulate larger batch size 288 | # and using the GradScaler if data type is float16 289 | if iter_num % hess_interval != hess_interval - 1: 290 | for micro_step in range(gradient_accumulation_steps): 291 | if ddp: 292 | # in DDP training we only need to sync gradients at the last micro step. 293 | # the official way to do this is with model.no_sync() context manager, but 294 | # I really dislike that this bloats the code and forces us to repeat code 295 | # looking at the source of that context manager, it just toggles this variable 296 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) 297 | with ctx: 298 | logits, loss = model(X, Y) 299 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 300 | X, Y = get_batch('train') 301 | # backward pass, with gradient scaling if training in fp16 302 | (loss / gradient_accumulation_steps).backward() 303 | # clip the gradient 304 | if grad_clip != 0.0: 305 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 306 | if total_norm.item() > grad_clip: 307 | clip_time += 1 308 | # step the optimizer and scaler if training in fp16 309 | optimizer.step(bs=total_bs * block_size) 310 | # flush the gradients as soon as we can, no need for this memory anymore 311 | optimizer.zero_grad(set_to_none=True) 312 | 313 | # timing and logging 314 | t1 = time.time() 315 | dt = t1 - t0 316 | t0 = t1 317 | if iter_num % log_interval == 0 and master_process: 318 | lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point 319 | if local_iter_num >= 5: # let the training loop settle a bit 320 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt) 321 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu 322 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%") 323 | total_param_norm = 0 324 | momentum_norm = 0 325 | params = [] 326 | for (name, p) in model.named_parameters(): 327 | params.append(p) 328 | for p in params: 329 | param_norm = p.data.norm(2) 330 | total_param_norm += param_norm.item() ** 2 331 | total_param_norm = total_param_norm ** 0.5 332 | momentum_norm = 0 333 | LL = len(optimizer.state_dict()['state']) 334 | for jj in range(LL): 335 | momentum_norm += (optimizer.state_dict()['state'][jj]['exp_avg'].detach().norm(2)) ** 2 336 | momentum_norm = torch.sqrt(momentum_norm).item() 337 | if wandb_log: 338 | wandb.log({ 339 | "iter": iter_num, 340 | "train/loss": lossf, 341 | "lr": lr, 342 | "param_norm": total_param_norm, 343 | "momentum_norm" : momentum_norm, 344 | "hessian_norm": hessian_norm, 345 | "hessian_norm2": hessian_norm2, 346 | "train/win_rate": num_effective / num_param, 347 | "train/clip_rate": clip_time / (iter_num + 1) 348 | 349 | }, step=iter_num) 350 | iter_num += 1 351 | local_iter_num += 1 352 | else: 353 | for micro_step in range(gradient_accumulation_steps): 354 | if ddp: 355 | # in DDP training we only need to sync gradients at the last micro step. 
356 | # the official way to do this is with model.no_sync() context manager, but 357 | # I really dislike that this bloats the code and forces us to repeat code 358 | # looking at the source of that context manager, it just toggles this variable 359 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) 360 | with ctx: 361 | logits, loss = model(X, Y) 362 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 363 | X, Y = get_batch('train') 364 | # backward pass, with gradient scaling if training in fp16 365 | (loss / gradient_accumulation_steps).backward() 366 | # clip the gradient 367 | if grad_clip != 0.0: 368 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 369 | if total_norm.item() > grad_clip: 370 | clip_time += 1 371 | # step the optimizer and scaler if training in fp16 372 | optimizer.step(bs=total_bs * block_size) 373 | # flush the gradients as soon as we can, no need for this memory anymore 374 | optimizer.zero_grad(set_to_none=True) 375 | 376 | # timing and logging 377 | t1 = time.time() 378 | dt = t1 - t0 379 | t0 = t1 380 | if iter_num % log_interval == 0 and master_process: 381 | lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point 382 | if local_iter_num >= 5: # let the training loop settle a bit 383 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt) 384 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu 385 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%") 386 | 387 | total_param_norm = 0 388 | momentum_norm = 0 389 | 390 | if wandb_log: 391 | wandb.log({ 392 | "iter": iter_num, 393 | "train/loss": lossf, 394 | "lr": lr, 395 | "param_norm": total_param_norm, 396 | "momentum_norm" : momentum_norm, 397 | "hessian_norm": hessian_norm.item(), 398 | "hessian_norm2": hessian_norm2, 399 | "train/win_rate": num_effective / num_param, 400 | "train/clip_rate": clip_time / (iter_num + 1) 401 | 402 | }, step=iter_num) 403 | iter_num += 1 404 | local_iter_num += 1 405 | 406 | for micro_step in range(gradient_accumulation_steps): 407 | if ddp: 408 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) 409 | with ctx: 410 | logits, _ = model(X, 0) 411 | X, Y = get_batch('train') 412 | samp_dist = torch.distributions.Categorical(logits=logits) 413 | y_sample = samp_dist.sample() 414 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y_sample.view(-1), ignore_index=-1) 415 | # backward pass, with gradient scaling if training in fp16 416 | (loss / gradient_accumulation_steps).backward() 417 | # clip the gradient 418 | if grad_clip != 0.0: 419 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 420 | # step the optimizer and scaler if training in fp16 421 | optimizer.update_hessian() 422 | # flush the gradients as soon as we can, no need for this memory anymore 423 | optimizer.zero_grad(set_to_none=True) 424 | 425 | num_param = 0 426 | num_effective = 0 427 | hessian_norm = 0 428 | hessian_norm2 = 0 429 | 430 | LL = len(optimizer.state_dict()['state']) 431 | 432 | for jj in range(LL): 433 | num_param += optimizer.state_dict()['state'][jj]['exp_avg'].numel() 434 | num_effective += torch.sum(torch.abs(optimizer.state_dict()['state'][jj]['exp_avg']) < rho * total_bs * block_size * optimizer.state_dict()['state'][jj]['hessian']) 435 | hessian_norm += optimizer.state_dict()['state'][jj]['hessian'].detach().norm(1).item() 436 | 
hessian_norm2 += optimizer.state_dict()['state'][jj]['hessian'].detach().norm(2).item() ** 2 437 | hessian_norm2 = hessian_norm2 ** 0.5 438 | 439 | 440 | 441 | t1 = time.time() 442 | dt = t1 - t0 443 | t0 = t1 444 | if master_process: 445 | # timing and MFU logging only; no loss is printed on hessian-update iterations 446 | if local_iter_num >= 5: # let the training loop settle a bit 447 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt) 448 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu 449 | print(f"iter {iter_num}: time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%") 450 | 451 | 452 | # termination conditions 453 | if iter_num > max_iters: 454 | break 455 | 456 | if ddp: 457 | destroy_process_group() 458 | --------------------------------------------------------------------------------
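For reference, the `get_lr` function defined in both training scripts above implements linear warmup followed by cosine decay down to `min_lr`. In closed form, using the variable names from the scripts:

$$
\text{lr}(t)=\begin{cases}
\text{learning\_rate}\cdot\dfrac{t}{\text{warmup\_iters}}, & t < \text{warmup\_iters},\\[6pt]
\text{min\_lr} + \tfrac{1}{2}\left(1+\cos\!\left(\pi\,\dfrac{t-\text{warmup\_iters}}{\text{lr\_decay\_iters}-\text{warmup\_iters}}\right)\right)\left(\text{learning\_rate}-\text{min\_lr}\right), & \text{warmup\_iters}\le t \le \text{lr\_decay\_iters},\\[6pt]
\text{min\_lr}, & t > \text{lr\_decay\_iters}.
\end{cases}
$$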
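The `train/win_rate` metric logged by `train_sophiag.py` is the fraction of parameter coordinates for which `|exp_avg| < rho * bs * hessian`, where `bs` is the token-level batch size passed to `optimizer.step()`. Below is a minimal standalone sketch of the same computation; the helper name `sophia_win_rate` is illustrative and not part of the repository, and it assumes the `SophiaG` optimizer has taken at least one step and one Hessian update so that `exp_avg` and `hessian` exist in its state.

```python
import torch
from sophia import SophiaG  # optimizer shipped in this repository


def sophia_win_rate(optimizer: SophiaG, rho: float, bs: int) -> float:
    """Fraction of coordinates with |exp_avg| < rho * bs * hessian.

    Mirrors the train/win_rate computation in train_sophiag.py; `bs` is the
    token-level batch size passed to optimizer.step() (total_bs * block_size
    in the training script).
    """
    num_param = 0
    num_effective = 0
    for state in optimizer.state.values():
        if 'exp_avg' not in state or 'hessian' not in state:
            continue  # parameter not yet stepped / no Hessian estimate yet
        exp_avg, hessian = state['exp_avg'], state['hessian']
        num_param += exp_avg.numel()
        num_effective += torch.sum(torch.abs(exp_avg) < rho * bs * hessian).item()
    return num_effective / max(num_param, 1)
```

With the defaults in the script above, `bs` corresponds to `total_bs * block_size = 480 * 1024`.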