├── LICENSE
├── README.md
├── assets
│   ├── 1.5B_200k_new.png
│   ├── medium_100k_plus.png
│   ├── small_100k_plus.png
│   └── t5_winrate.png
├── config
│   ├── train_gpt2_large_adam.py
│   ├── train_gpt2_large_sophiag.py
│   ├── train_gpt2_medium_adam.py
│   ├── train_gpt2_medium_sophiag.py
│   ├── train_gpt2_small_adam.py
│   └── train_gpt2_small_sophiag.py
├── configurator.py
├── data
│   └── openwebtext
│       └── prepare.py
├── model.py
├── sophia.py
├── train_adam.py
└── train_sophiag.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Hong Liu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training
2 |
3 |
4 | This is an official implementation of the **Sophia-G** optimizer in the paper [https://arxiv.org/abs/2305.14342](https://arxiv.org/abs/2305.14342) and GPT-2 training scripts. The code is based on [nanoGPT](https://github.com/karpathy/nanoGPT/) and [levanter](https://github.com/stanford-crfm/levanter/). Please cite the paper and star this repo if you find Sophia useful. Thanks!
5 |
6 |
7 | ```tex
8 | @article{liu2023sophia,
9 | title={Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training},
10 | author={Liu, Hong and Li, Zhiyuan and Hall, David and Liang, Percy and Ma, Tengyu},
11 | journal={arXiv preprint arXiv:2305.14342},
12 | year={2023}
13 | }
14 | ```
15 |
16 |
17 | ## News and Updates
18 | - Updated results with the latest PyTorch version.
19 |
20 |
21 |
22 | ## Dependencies
23 |
24 |
25 | - [PyTorch](https://pytorch.org) 2.1.2
26 | - transformers 4.33.0
27 | - datasets
28 | - tiktoken
29 | - wandb
30 |
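The pinned versions above can usually be installed with pip (the exact PyTorch install command may differ depending on your CUDA setup):

```
$ pip install torch==2.1.2 transformers==4.33.0 datasets tiktoken wandb
```
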
31 | ## General Usage
32 |
33 | Below is an example code snippet for training a general model with NLL loss using SophiaG. Please refer to the next section for guidelines on hyperparameter tuning.
34 |
35 | ```python
36 | import torch
37 | import torch.nn.functional as F
38 | from sophia import SophiaG
39 |
40 | # init model loss function and input data
41 | model = Model()
42 | data_loader = ...
43 |
44 | # init the optimizer
45 | optimizer = SophiaG(model.parameters(), lr=2e-4, betas=(0.965, 0.99), rho=0.01, weight_decay=1e-1)
46 |
47 | total_bs = len(data_loader)
48 | bs = total_bs * block_size
49 | k = 10
50 | iter_num = -1
51 |
52 | # training loop
53 | for epoch in range(epochs):
54 | for X, Y in data_loader:
55 | # standard training code
56 | logits, loss = model(X, Y)
57 | loss.backward()
58 | optimizer.step(bs=bs)
59 | optimizer.zero_grad(set_to_none=True)
60 | iter_num += 1
61 |
62 | if iter_num % k != k - 1:
63 | continue
64 | else:
65 | # update hessian EMA
66 | logits, _ = model(X, None)
67 | samp_dist = torch.distributions.Categorical(logits=logits)
68 | y_sample = samp_dist.sample()
69 | loss_sampled = F.cross_entropy(logits.view(-1, logits.size(-1)), y_sample.view(-1), ignore_index=-1)
70 | loss_sampled.backward()
71 | optimizer.update_hessian()
72 | optimizer.zero_grad(set_to_none=True)
73 | model.zero_grad()
74 | ```
75 |
76 |
77 | ## Hyper-parameter Tuning
78 |
79 | ### Definition of learning rate
80 | - The update in the code is written as $\theta_{t+1} = \theta_t - lr\cdot\textup{clip}(m_t / (\rho \cdot h_t + \epsilon), 1)$, which is equivalent to the update in the paper up to a re-parameterization (the $lr$ here corresponds to $\rho \cdot \eta_t$ in the paper). As a result, the learning rates of AdamW and Lion are not directly comparable to each other; empirically, Adam and Lion with a learning rate ratio of 5:1 behave similarly. The learning rate of SophiaG, however, is directly comparable to that of Lion. Sophia allows the use of a much larger learning rate than Lion, and this is why Sophia is much faster. A minimal sketch of this update is given below.
81 |
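For intuition, the snippet below is a minimal sketch (not the shipped implementation) of this re-parameterized update for a single parameter tensor. Note that `sophia.py` in this repo additionally scales the Hessian EMA by the token batch size `bs` and applies decoupled weight decay before the update; `eps` here plays the role of the $\epsilon$ above.

```python
import torch

def sophia_style_update(param, exp_avg, hess, lr, rho, weight_decay, eps=1e-15):
    """One SophiaG-style update: theta <- theta - lr * clip(m / (rho * h + eps), 1)."""
    # decoupled weight decay, applied before the clipped update
    param.mul_(1 - lr * weight_decay)
    # elementwise ratio |m_t| / (rho * h_t + eps), clipped at 1
    ratio = (exp_avg.abs() / (rho * hess + eps)).clamp(max=1.0)
    # step against the sign of the momentum, scaled by the clipped ratio
    param.add_(exp_avg.sign() * ratio, alpha=-lr)
    return param
```
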
82 | ### Tuning the hyperparameter $\rho$
83 | - Tune $\rho$ to keep the proportion of coordinates whose update is not clipped stable and in a proper range. This is tracked as ```train/win_rate``` in the [GPT-2 training example](https://github.com/Liuhong99/Sophia/blob/2443b03529ecdccf65699a5e55e68d69ede39509/train_sophiag.py#L398C21-L398C65). ```train/win_rate``` should peak at the beginning and remain stable afterwards, staying in the range of 0.1 - 0.5. Typically, a larger $\rho$ leads to a larger ```train/win_rate```; a sketch of how to estimate this metric yourself follows this bullet. An example of typical ```win_rate``` behavior in a T5 model is shown in the figure further below.
84 |
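The sketch below estimates this quantity from the optimizer state kept by `SophiaG` (`exp_avg`, `hessian`, `rho`). It assumes ```win_rate``` is the fraction of coordinates whose update is not clipped, and it is an illustration rather than the exact logging code in `train_sophiag.py`.

```python
import torch

@torch.no_grad()
def estimate_win_rate(optimizer, bs, eps=1e-15):
    """Fraction of coordinates where |m| / (rho * bs * h + eps) < 1, i.e. not clipped."""
    total, unclipped = 0, 0
    for group in optimizer.param_groups:
        rho = group['rho']
        for p in group['params']:
            state = optimizer.state[p]
            if 'exp_avg' not in state or 'hessian' not in state:
                continue  # parameter has not been stepped yet
            ratio = state['exp_avg'].abs() / (rho * bs * state['hessian'] + eps)
            unclipped += (ratio < 1).sum().item()
            total += ratio.numel()
    return unclipped / max(total, 1)
```
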
85 | ### Tuning the learning rate and weight decay
86 | - Choose lr to be slightly smaller than the learning rate that you would use for AdamW or 3 - 5 times the learning rate that you would use for Lion.
87 |
88 |
89 | ![Typical win_rate behavior on a T5 model](assets/t5_winrate.png)
90 |
91 | - If the loss blows up, slightly decrease the learning rate or increase $\rho$.
92 |
93 | - Always use about 2x larger weight decay than what you would use for AdamW.
94 |
95 | ### Hyperparameters for GPT-2 models
96 |
97 | - Choose lr to be about the same as the learning rate that you would use for AdamW or 5 - 10 times the learning rate that you would use for Lion.
98 | - Tune $\rho$ to make the proportion of the parameters where the update is not clipped stable and in a proper range. This is tracked as ```train/win_rate``` in the [GPT-2 training example](https://github.com/Liuhong99/Sophia/blob/2443b03529ecdccf65699a5e55e68d69ede39509/train_sophiag.py#L398C21-L398C65). ```train/win_rate``` should peak in the beginning and remain stable afterwards. ```train/win_rate``` should stay in the range of 0.1 - 0.5. Typically a large $\rho$ will lead to a large ```train/win_rate```.
99 | - Use a slightly larger weight decay than you would for AdamW, e.g. 0.2.
100 | - Except for the learning rate, all other hyperparameters are transferable across different model sizes.
101 | - See the table below for the hyperparameters for different model sizes.
102 |
103 | | Model Size | lr for Adam | lr for Lion | lr for Sophia | $\rho$ for Sophia | weight decay for Sophia |
104 | | -------- | ------- | ------- | ------- | ------- | ------- |
105 | | 125M | 6e-4 | 1e-4 | 6e-4 | 0.05 | 0.2 |
106 | | 355M | 3e-4 | 1e-4 | 7e-4 | 0.08 | 0.2 |
107 | | 770M | 2e-4 | 8e-5 | 3e-4 | 0.05 | 0.2 |
108 |
109 | - Please feel free to let us know what you find out during hyperparameter tuning. We appreciate your valuable feedback and comments!
110 |
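As a concrete example, the 125M row of the table above corresponds to constructing the optimizer roughly as follows (betas taken from the GPT-2 configs in this repo; `model` is your own module):

```python
from sophia import SophiaG

# GPT-2 Small (125M) hyperparameters from the table above
optimizer = SophiaG(model.parameters(), lr=6e-4, betas=(0.965, 0.99),
                    rho=0.05, weight_decay=0.2)
```
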
111 | ## Reproduce GPT-2 Results
112 |
113 | Prepare the [OpenWebText](https://huggingface.co/datasets/openwebtext) data following [nanoGPT](https://github.com/karpathy/nanoGPT/):
114 | ```
115 | $ python data/openwebtext/prepare.py
116 | ```
117 | Start pre-training GPT2 Small (125M):
118 |
119 | If you have a machine with 10 A5000 (24GB) GPUs,
120 | ```
121 | $ torchrun --standalone --nproc_per_node=10 \
122 | train_sophiag.py \
123 | config/train_gpt2_small_sophiag.py \
124 | --batch_size=8 \
125 | --gradient_accumulation_steps=6
126 | ```
127 | If you have a machine with 8 A100 (40GB) GPUs,
128 | ```
129 | $ torchrun --standalone --nproc_per_node=8 \
130 | train_sophiag.py \
131 | config/train_gpt2_small_sophiag.py \
132 | --batch_size=12 \
133 | --gradient_accumulation_steps=5
134 | ```
135 |
136 | To reproduce the AdamW baseline following [nanoGPT](https://github.com/karpathy/nanoGPT/):
137 | ```
138 | $ torchrun --standalone --nproc_per_node=10 \
139 | train_adam.py \
140 | config/train_gpt2_small_adam.py \
141 | --batch_size=8 \
142 | --gradient_accumulation_steps=6
143 | ```
144 |
145 | This will lead to results in the figure below:
146 |
147 |
148 | ![GPT-2 Small (125M) results](assets/small_100k_plus.png)
149 |
150 | Start pre-training GPT2 Medium (355M):
151 |
152 | If you have a machine with 8 A100 (40GB) GPUs,
153 | ```
154 | $ torchrun --standalone --nproc_per_node=8 \
155 | train_sophiag.py \
156 | config/train_gpt2_medium_sophiag.py \
157 | --batch_size=6 \
158 | --gradient_accumulation_steps=10
159 | ```
160 |
161 | To reproduce the AdamW baseline:
162 | ```
163 | $ torchrun --standalone --nproc_per_node=8 \
164 | train_adam.py \
165 | config/train_gpt2_medium_adam.py \
166 | --batch_size=6 \
167 | --gradient_accumulation_steps=10
168 | ```
169 |
170 | Please adjust ```nproc_per_node```, ```batch_size```, and ```gradient_accumulation_steps``` accordingly if you use a different hardware setup. Make sure their product equals 480.
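For example, on a machine with 4 GPUs you could pass ```--nproc_per_node=4 --batch_size=12 --gradient_accumulation_steps=10```, since 4 × 12 × 10 = 480.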
171 |
172 |
173 | This will lead to results in the figure below:
174 |
175 |
176 | ![GPT-2 Medium (355M) results](assets/medium_100k_plus.png)
177 |
178 | Start pre-training GPT2 1.5B:
179 |
180 | We use [the Pile](https://github.com/EleutherAI/the-pile) and the GPT-NeoX tokenizer. First set up TPU instances and the environment following [levanter](https://github.com/stanford-crfm/levanter/blob/e183ec80ec5971b12d4a3fb08a160268de342670/docs/Getting-Started-TPU-VM.md). Then change GAMMA_SOPHIA_G to 200 in [optim.py](https://github.com/stanford-crfm/levanter/blob/e183ec80ec5971b12d4a3fb08a160268de342670/src/levanter/optim.py). The training command for the 1.5B model is
181 | ```
182 | gcloud compute tpus tpu-vm ssh \
183 | --zone \
184 | --worker=all \
185 | --command 'WANDB_API_KEY= levanter/infra/launch.sh python levanter/examples/gpt2_example.py --config_path levanter/config/gpt2_1536_pile.yaml --trainer.beta1 0.965 --trainer.beta2 0.99 --trainer.min_lr_ratio 0.020 --trainer.weight_decay 0.15 --trainer.learning_rate 2.5e-4 --trainer.warmup_ratio 0.01'
186 |
187 | ```
188 |
189 | ## Acknowledgement
190 |
191 | The GPT-2 training code is based on [nanoGPT](https://github.com/karpathy/nanoGPT/), which is elegant and super efficient.
192 |
--------------------------------------------------------------------------------
/assets/1.5B_200k_new.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuhong99/Sophia/a7e157229b71d58cf995d32854f1be15c265b350/assets/1.5B_200k_new.png
--------------------------------------------------------------------------------
/assets/medium_100k_plus.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuhong99/Sophia/a7e157229b71d58cf995d32854f1be15c265b350/assets/medium_100k_plus.png
--------------------------------------------------------------------------------
/assets/small_100k_plus.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuhong99/Sophia/a7e157229b71d58cf995d32854f1be15c265b350/assets/small_100k_plus.png
--------------------------------------------------------------------------------
/assets/t5_winrate.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuhong99/Sophia/a7e157229b71d58cf995d32854f1be15c265b350/assets/t5_winrate.png
--------------------------------------------------------------------------------
/config/train_gpt2_large_adam.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'sophia'
3 | wandb_run_name='gpt2-large-adam-100k'
4 |
5 | # these make the total batch size be ~0.5M
6 | # 4 batch size * 1024 block size * 12 gradaccum * 10 GPUs = 491,520
7 | batch_size = 4
8 | block_size = 1024
9 | gradient_accumulation_steps = 12
10 |
11 | n_layer = 36
12 | n_head = 20
13 | n_embd = 1280
14 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
15 | bias = False
16 | scale_attn_by_inverse_layer_idx = True
17 |
18 | # this makes the total number of tokens be ~50B
19 | max_iters = 100000
20 | lr_decay_iters = 100000
21 |
22 | # eval stuff
23 | eval_interval = 1000
24 | eval_iters = 200
25 | log_interval = 10
26 |
27 | # optimizer
28 | optimizer_name = 'adamw'
29 | learning_rate = 2e-4 # max learning rate
30 | weight_decay = 1e-1
31 | beta1 = 0.9
32 | beta2 = 0.95
33 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
34 | # learning rate decay settings
35 | decay_lr = True # whether to decay the learning rate
36 | warmup_iters = 2000 # how many steps to warm up for
37 | min_lr = 1e-5
38 |
39 | compile = True
40 |
41 | out_dir = 'out_large_adam_100k'
42 |
--------------------------------------------------------------------------------
/config/train_gpt2_large_sophiag.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'sophia'
3 | wandb_run_name='gpt2-large-sophiag-100k'
4 |
5 | # these make the total batch size be ~0.5M
6 | # 4 batch size * 1024 block size * 12 gradaccum * 10 GPUs = 491,520
7 | batch_size = 4
8 | block_size = 1024
9 | gradient_accumulation_steps = 12
10 |
11 | n_layer = 36
12 | n_head = 20
13 | n_embd = 1280
14 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
15 | bias = False
16 | scale_attn_by_inverse_layer_idx = True
17 |
18 | # this makes the total number of tokens be ~50B
19 | max_iters = 100000
20 | lr_decay_iters = 100000
21 |
22 | # eval stuff
23 | eval_interval = 1000
24 | eval_iters = 200
25 | log_interval = 10
26 |
27 | # optimizer
28 | optimizer_name = 'sophiag'
29 | learning_rate = 3e-4 # max learning rate
30 | weight_decay = 2e-1
31 | beta1 = 0.965
32 | beta2 = 0.99
33 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
34 | # learning rate decay settings
35 | decay_lr = True # whether to decay the learning rate
36 | warmup_iters = 2000 # how many steps to warm up for
37 | min_lr = 1e-5
38 | rho = 0.05
39 | interval = 10
40 |
41 | compile = True
42 |
43 | out_dir = 'out_large_sophiag_100k'
44 |
--------------------------------------------------------------------------------
/config/train_gpt2_medium_adam.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'sophia'
3 | wandb_run_name='gpt2-medium-adam-100k'
4 |
5 | # these make the total batch size be ~0.5M
6 | # 6 batch size * 1024 block size * 10 gradaccum * 8 GPUs = 491,520
7 | batch_size = 6
8 | block_size = 1024
9 | gradient_accumulation_steps = 8
10 |
11 | n_layer = 24
12 | n_head = 16
13 | n_embd = 1024
14 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
15 | bias = False
16 | scale_attn_by_inverse_layer_idx = True
17 |
18 | # this makes the total number of tokens be ~50B
19 | max_iters = 100000
20 | lr_decay_iters = 100000
21 |
22 | # eval stuff
23 | eval_interval = 1000
24 | eval_iters = 200
25 | log_interval = 10
26 |
27 | # optimizer
28 | optimizer_name = 'adamw'
29 | learning_rate = 3e-4 # max learning rate
30 | weight_decay = 1e-1
31 | beta1 = 0.9
32 | beta2 = 0.95
33 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
34 | # learning rate decay settings
35 | decay_lr = True # whether to decay the learning rate
36 | warmup_iters = 2000 # how many steps to warm up for
37 | min_lr = 6e-5
38 |
39 | compile = True
40 |
41 | out_dir = 'out_medium_adam_100k'
42 |
--------------------------------------------------------------------------------
/config/train_gpt2_medium_sophiag.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'sophia'
3 | wandb_run_name='gpt2-medium-sophiag-100k'
4 |
5 | # these make the total batch size be ~0.5M
6 | # 6 batch size * 1024 block size * 10 gradaccum * 8 GPUs = 491,520
7 | batch_size = 10
8 | block_size = 1024
9 | gradient_accumulation_steps = 6
10 |
11 | n_layer = 24
12 | n_head = 16
13 | n_embd = 1024
14 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
15 | bias = False
16 | scale_attn_by_inverse_layer_idx = True
17 |
18 | # this makes the total number of tokens be ~50B
19 | max_iters = 100000
20 | lr_decay_iters = 100000
21 |
22 | # eval stuff
23 | eval_interval = 1000
24 | eval_iters = 200
25 | log_interval = 10
26 |
27 | # optimizer
28 | optimizer_name = 'sophiag'
29 | learning_rate = 7e-4 # max learning rate
30 | weight_decay = 2e-1
31 | beta1 = 0.965
32 | beta2 = 0.99
33 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
34 | # learning rate decay settings
35 | decay_lr = True # whether to decay the learning rate
36 | warmup_iters = 2000 # how many steps to warm up for
37 | min_lr = 1e-5
38 | rho = 0.08
39 | interval = 10
40 |
41 | compile = True
42 |
43 | out_dir = 'out_medium_sophiag_100k'
44 |
--------------------------------------------------------------------------------
/config/train_gpt2_small_adam.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'sophia'
3 | wandb_run_name='gpt2-small-adam-100k'
4 |
5 | # these make the total batch size be ~0.5M
6 | # 8 batch size * 1024 block size * 6 gradaccum * 10 GPUs = 491,520
7 | batch_size = 8
8 | block_size = 1024
9 | gradient_accumulation_steps = 6
10 |
11 | n_layer = 12
12 | n_head = 12
13 | n_embd = 768
14 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
15 | bias = False
16 |
17 | # this makes the total number of tokens be ~50B
18 | max_iters = 100000
19 | lr_decay_iters = 100000
20 |
21 | # eval stuff
22 | eval_interval = 1000
23 | eval_iters = 200
24 | log_interval = 10
25 |
26 | # optimizer
27 | optimizer_name = 'adamw'
28 | learning_rate = 6e-4 # max learning rate
29 | weight_decay = 1e-1
30 | beta1 = 0.9
31 | beta2 = 0.95
32 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
33 | # learning rate decay settings
34 | decay_lr = True # whether to decay the learning rate
35 | warmup_iters = 2000 # how many steps to warm up for
36 | min_lr = 3e-5
37 |
38 | compile = True
39 |
40 | out_dir = 'out_small_adam_100k'
41 |
--------------------------------------------------------------------------------
/config/train_gpt2_small_sophiag.py:
--------------------------------------------------------------------------------
1 | wandb_log = True
2 | wandb_project = 'sophia'
3 | wandb_run_name='gpt2-small-sophiag-100k'
4 |
5 | # these make the total batch size be ~0.5M
6 | # 8 batch size * 1024 block size * 6 gradaccum * 10 GPUs = 491,520
7 | batch_size = 8
8 | block_size = 1024
9 | gradient_accumulation_steps = 6
10 | total_bs = 480
11 |
12 | n_layer = 12
13 | n_head = 12
14 | n_embd = 768
15 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
16 | bias = False
17 |
18 | # this makes the total number of tokens be ~50B
19 | max_iters = 100000
20 | lr_decay_iters = 100000
21 |
22 | # eval stuff
23 | eval_interval = 1000
24 | eval_iters = 200
25 | log_interval = 10
26 |
27 | # optimizer
28 | optimizer_name = 'sophiag'
29 | learning_rate = 6e-4 # max learning rate
30 | weight_decay = 2e-1
31 | beta1 = 0.965
32 | beta2 = 0.99
33 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
34 | # learning rate decay settings
35 | decay_lr = True # whether to decay the learning rate
36 | warmup_iters = 2000 # how many steps to warm up for
37 | min_lr = 1.5e-5
38 | rho = 0.05
39 | interval = 10
40 |
41 | compile = True
42 |
43 | out_dir = 'out_small_sophiag_100k'
44 |
--------------------------------------------------------------------------------
/configurator.py:
--------------------------------------------------------------------------------
1 | """
2 | Poor Man's Configurator. Probably a terrible idea. Example usage:
3 | $ python train.py config/override_file.py --batch_size=32
4 | this will first run config/override_file.py, then override batch_size to 32
5 |
6 | The code in this file will be run as follows from e.g. train.py:
7 | >>> exec(open('configurator.py').read())
8 |
9 | So it's not a Python module, it's just shuttling this code away from train.py
10 | The code in this script then overrides the globals()
11 |
12 | I know people are not going to love this, I just really dislike configuration
13 | complexity and having to prepend config. to every single variable. If someone
14 | comes up with a better simple Python solution I am all ears.
15 | """
16 |
17 | import sys
18 | from ast import literal_eval
19 |
20 | for arg in sys.argv[1:]:
21 | if '=' not in arg:
22 | # assume it's the name of a config file
23 | assert not arg.startswith('--')
24 | config_file = arg
25 | print(f"Overriding config with {config_file}:")
26 | with open(config_file) as f:
27 | print(f.read())
28 | exec(open(config_file).read())
29 | else:
30 | # assume it's a --key=value argument
31 | assert arg.startswith('--')
32 | key, val = arg.split('=')
33 | key = key[2:]
34 | if key in globals():
35 | try:
36 | # attempt to eval it (e.g. if bool, number, etc.)
37 | attempt = literal_eval(val)
38 | except (SyntaxError, ValueError):
39 | # if that goes wrong, just use the string
40 | attempt = val
41 | # ensure the types match ok
42 | assert type(attempt) == type(globals()[key])
43 | # cross fingers
44 | print(f"Overriding: {key} = {attempt}")
45 | globals()[key] = attempt
46 | else:
47 | raise ValueError(f"Unknown config key: {key}")
48 |
--------------------------------------------------------------------------------
/data/openwebtext/prepare.py:
--------------------------------------------------------------------------------
1 | # saves the openwebtext dataset to a binary file for training. following was helpful:
2 | # https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
3 |
4 | import os
5 | from tqdm import tqdm
6 | import numpy as np
7 | import tiktoken
8 | from datasets import load_dataset # huggingface datasets
9 |
10 | # number of workers in .map() call
11 | # good number to use is ~order number of cpu cores // 2
12 | num_proc = 8
13 |
14 | # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769)
15 | dataset = load_dataset("openwebtext", cache_dir="/tiger/u/hliu99/nanoGPT/cache")
16 |
17 | # owt by default only contains the 'train' split, so create a test split
18 | split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
19 | split_dataset['val'] = split_dataset.pop('test') # rename the test split to val
20 |
21 | # this results in:
22 | # >>> split_dataset
23 | # DatasetDict({
24 | # train: Dataset({
25 | # features: ['text'],
26 | # num_rows: 8009762
27 | # })
28 | # val: Dataset({
29 | # features: ['text'],
30 | # num_rows: 4007
31 | # })
32 | # })
33 |
34 | # we now want to tokenize the dataset. first define the encoding function (gpt2 bpe)
35 | enc = tiktoken.get_encoding("gpt2")
36 | def process(example):
37 | ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens
38 | ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe
39 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though...
40 | out = {'ids': ids, 'len': len(ids)}
41 | return out
42 |
43 | # tokenize the dataset
44 | tokenized = split_dataset.map(
45 | process,
46 | remove_columns=['text'],
47 | desc="tokenizing the splits",
48 | num_proc=num_proc,
49 | )
50 | print('tokenization finished')
51 | # concatenate all the ids in each dataset into one large file we can use for training
52 | for split, dset in tokenized.items():
53 | arr_len = np.sum(dset['len'])
54 | filename = os.path.join(os.path.dirname(__file__), f'{split}.bin')
55 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
56 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
57 |
58 | print(f"writing {filename}...")
59 | idx = 0
60 | for example in tqdm(dset):
61 | arr[idx : idx + example['len']] = example['ids']
62 | idx += example['len']
63 | arr.flush()
64 |
65 | # train.bin is ~17GB, val.bin ~8.5MB
66 | # train has ~9B tokens (9,035,582,198)
67 | # val has ~4M tokens (4,434,897)
68 |
69 | # to read the bin files later, e.g. with numpy:
70 | # m = np.memmap('train.bin', dtype=np.uint16, mode='r')
71 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import inspect
3 | from dataclasses import dataclass
4 | from sophia import SophiaG
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn import functional as F
9 |
10 | optimizer_dict = {'adamw': torch.optim.AdamW,
11 | 'sophiag': SophiaG
12 | }
13 |
14 | # @torch.jit.script # good to enable when not using torch.compile, disable when using (our default)
15 | def new_gelu(x):
16 | """
17 | Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
18 | Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
19 | """
20 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
21 |
22 | class LayerNorm(nn.Module):
23 | """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
24 |
25 | def __init__(self, ndim, bias):
26 | super().__init__()
27 | self.weight = nn.Parameter(torch.ones(ndim))
28 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
29 |
30 | def forward(self, input):
31 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
32 |
33 | class CausalSelfAttention(nn.Module):
34 |
35 | def __init__(self, config, idx_layer):
36 | super().__init__()
37 | assert config.n_embd % config.n_head == 0
38 | # key, query, value projections for all heads, but in a batch
39 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
40 | # output projection
41 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
42 | # regularization
43 | self.attn_dropout = nn.Dropout(config.dropout)
44 | self.resid_dropout = nn.Dropout(config.dropout)
45 | self.n_head = config.n_head
46 | self.n_embd = config.n_embd
47 | self.dropout = config.dropout
48 | self.idx_layer = idx_layer
49 | self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
50 |
51 | # causal mask to ensure that attention is only applied to the left in the input sequence
52 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
53 | .view(1, 1, config.block_size, config.block_size))
54 |
55 | def forward(self, x):
56 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
57 |
58 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
59 | q, k ,v = self.c_attn(x).split(self.n_embd, dim=2)
60 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
61 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
62 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
63 |
64 | if self.scale_attn_by_inverse_layer_idx:
65 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)) / float(self.idx_layer + 1))
66 | else:
67 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
68 | att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
69 | att = F.softmax(att, dim=-1)
70 | att = self.attn_dropout(att)
71 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
72 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
73 |
74 | # output projection
75 | y = self.resid_dropout(self.c_proj(y))
76 | return y
77 |
78 | class MLP(nn.Module):
79 |
80 | def __init__(self, config):
81 | super().__init__()
82 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
83 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
84 | self.dropout = nn.Dropout(config.dropout)
85 |
86 | def forward(self, x):
87 | x = self.c_fc(x)
88 | x = new_gelu(x)
89 | x = self.c_proj(x)
90 | x = self.dropout(x)
91 | return x
92 |
93 | class Block(nn.Module):
94 |
95 | def __init__(self, config, idx_layer):
96 | super().__init__()
97 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
98 | self.attn = CausalSelfAttention(config, idx_layer)
99 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
100 | self.mlp = MLP(config)
101 |
102 | def forward(self, x):
103 | x = x + self.attn(self.ln_1(x))
104 | x = x + self.mlp(self.ln_2(x))
105 | return x
106 |
107 | @dataclass
108 | class GPTConfig:
109 | block_size: int = 1024
110 | vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
111 | n_layer: int = 12
112 | n_head: int = 12
113 | n_embd: int = 768
114 | dropout: float = 0.0
115 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
116 | scale_attn_by_inverse_layer_idx: bool = False
117 |
118 |
119 | class GPT(nn.Module):
120 |
121 | def __init__(self, config):
122 | super().__init__()
123 | assert config.vocab_size is not None
124 | assert config.block_size is not None
125 | self.config = config
126 |
127 | self.transformer = nn.ModuleDict(dict(
128 | wte = nn.Embedding(config.vocab_size, config.n_embd),
129 | wpe = nn.Embedding(config.block_size, config.n_embd),
130 | drop = nn.Dropout(config.dropout),
131 | h = nn.ModuleList([Block(config, idx_layer) for idx_layer in range(config.n_layer)]),
132 | ln_f = LayerNorm(config.n_embd, bias=config.bias),
133 | ))
134 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
135 | # with weight tying when using torch.compile() some warnings get generated:
136 | # "UserWarning: functional_call was passed multiple values for tied weights.
137 | # This behavior is deprecated and will be an error in future versions"
138 | # not 100% sure what this is, so far seems to be harmless. TODO investigate
139 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
140 |
141 | # init all weights
142 | self.apply(self._init_weights)
143 | # apply special scaled init to the residual projections, per GPT-2 paper
144 | for pn, p in self.named_parameters():
145 | if pn.endswith('c_proj.weight'):
146 | torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
147 |
148 | # report number of parameters
149 | print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
150 |
151 | def get_num_params(self, non_embedding=True):
152 | """
153 | Return the number of parameters in the model.
154 | For non-embedding count (default), the position embeddings get subtracted.
155 | The token embeddings would too, except due to the parameter sharing these
156 | params are actually used as weights in the final layer, so we include them.
157 | """
158 | n_params = sum(p.numel() for p in self.parameters())
159 | if non_embedding:
160 | n_params -= self.transformer.wpe.weight.numel()
161 | return n_params
162 |
163 | def _init_weights(self, module):
164 | if isinstance(module, nn.Linear):
165 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
166 | if module.bias is not None:
167 | torch.nn.init.zeros_(module.bias)
168 | elif isinstance(module, nn.Embedding):
169 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
170 |
171 | def forward(self, idx, targets=None):
172 | device = idx.device
173 | b, t = idx.size()
174 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
175 | pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
176 |
177 | # forward the GPT model itself
178 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
179 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
180 | x = self.transformer.drop(tok_emb + pos_emb)
181 | for block in self.transformer.h:
182 | x = block(x)
183 | x = self.transformer.ln_f(x)
184 |
185 | if targets is not None:
186 | # if we are given some desired targets also calculate the loss
187 | if not isinstance(targets, int):
188 | logits = self.lm_head(x)
189 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
190 | else:
191 | logits = self.lm_head(x)
192 | loss = None
193 | else:
194 | # inference-time mini-optimization: only forward the lm_head on the very last position
195 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
196 | loss = None
197 |
198 | return logits, loss
199 |
200 | def crop_block_size(self, block_size):
201 | # model surgery to decrease the block size if necessary
202 | # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
203 | # but want to use a smaller block size for some smaller, simpler model
204 | assert block_size <= self.config.block_size
205 | self.config.block_size = block_size
206 | self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
207 | for block in self.transformer.h:
208 | block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
209 |
210 | @classmethod
211 | def from_pretrained(cls, model_type, override_args=None):
212 | assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
213 | override_args = override_args or {} # default to empty dict
214 | # only dropout can be overridden see more notes below
215 | assert all(k == 'dropout' for k in override_args)
216 | from transformers import GPT2LMHeadModel
217 | print("loading weights from pretrained gpt: %s" % model_type)
218 |
219 | # n_layer, n_head and n_embd are determined from model_type
220 | config_args = {
221 | 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
222 | 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
223 | 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
224 | 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
225 | }[model_type]
226 | print("forcing vocab_size=50257, block_size=1024, bias=True")
227 | config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
228 | config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
229 | config_args['bias'] = True # always True for GPT model checkpoints
230 | # we can override the dropout rate, if desired
231 | if 'dropout' in override_args:
232 | print(f"overriding dropout rate to {override_args['dropout']}")
233 | config_args['dropout'] = override_args['dropout']
234 | # create a from-scratch initialized minGPT model
235 | config = GPTConfig(**config_args)
236 | model = GPT(config)
237 | sd = model.state_dict()
238 | sd_keys = sd.keys()
239 | sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
240 |
241 | # init a huggingface/transformers model
242 | model_hf = GPT2LMHeadModel.from_pretrained(model_type)
243 | sd_hf = model_hf.state_dict()
244 |
245 | # copy while ensuring all of the parameters are aligned and match in names and shapes
246 | sd_keys_hf = sd_hf.keys()
247 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
248 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
249 | transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
250 | # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
251 | # this means that we have to transpose these weights when we import them
252 | assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
253 | for k in sd_keys_hf:
254 | if any(k.endswith(w) for w in transposed):
255 | # special treatment for the Conv1D weights we need to transpose
256 | assert sd_hf[k].shape[::-1] == sd[k].shape
257 | with torch.no_grad():
258 | sd[k].copy_(sd_hf[k].t())
259 | else:
260 | # vanilla copy over the other parameters
261 | assert sd_hf[k].shape == sd[k].shape
262 | with torch.no_grad():
263 | sd[k].copy_(sd_hf[k])
264 |
265 | return model
266 |
267 | def configure_optimizers(self, optimizer_name, weight_decay, learning_rate, betas, rho, device_type):
268 | """
269 | This long function is unfortunately doing something very simple and is being very defensive:
270 | We are separating out all parameters of the model into two buckets: those that will experience
271 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
272 | We are then returning the PyTorch optimizer object.
273 | """
274 |
275 | # separate out all parameters to those that will and won't experience regularizing weight decay
276 | decay = set()
277 | no_decay = set()
278 | whitelist_weight_modules = (torch.nn.Linear, )
279 | blacklist_weight_modules = (torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding)
280 | for mn, m in self.named_modules():
281 | for pn, p in m.named_parameters():
282 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
283 | # random note: because named_modules and named_parameters are recursive
284 | # we will see the same tensors p many many times. but doing it this way
285 | # allows us to know which parent module any tensor p belongs to...
286 | if pn.endswith('bias'):
287 | # all biases will not be decayed
288 | no_decay.add(fpn)
289 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
290 | # weights of whitelist modules will be weight decayed
291 | decay.add(fpn)
292 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
293 | # weights of blacklist modules will NOT be weight decayed
294 | no_decay.add(fpn)
295 |
296 | # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they
297 | # will appear in the no_decay and decay sets respectively after the above.
298 | # In addition, because named_parameters() doesn't return duplicates, it
299 | # will only return the first occurrence, key'd by 'transformer.wte.weight', below.
300 | # so let's manually remove 'lm_head.weight' from decay set. This will include
301 | # this tensor into optimization via transformer.wte.weight only, and not decayed.
302 | decay.remove('lm_head.weight')
303 |
304 | # validate that we considered every parameter
305 | param_dict = {pn: p for pn, p in self.named_parameters()}
306 | inter_params = decay & no_decay
307 | union_params = decay | no_decay
308 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
309 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
310 | % (str(param_dict.keys() - union_params), )
311 |
312 | # create the pytorch optimizer object
313 | optim_groups = [
314 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
315 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
316 | ]
317 |
318 | opt_func = optimizer_dict[optimizer_name]
319 | if optimizer_name == 'adamw':
320 | # new PyTorch nightly has a new 'fused' option for AdamW that is much faster
321 | use_fused = (device_type == 'cuda') and ('fused' in inspect.signature(torch.optim.AdamW).parameters)
322 | print(f"using fused AdamW: {use_fused}")
323 | extra_args = dict(fused=True) if use_fused else dict()
324 | optimizer = opt_func(optim_groups, lr=learning_rate, betas=betas, **extra_args)
325 | elif optimizer_name == 'sophiag':
326 | optimizer = opt_func(optim_groups, lr=learning_rate, betas=betas, rho=rho)
327 | else:
328 | raise ValueError('Invalid optimizer.')
329 | return optimizer
330 |
331 | def estimate_mfu(self, fwdbwd_per_iter, dt):
332 | """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
333 | # first estimate the number of flops we do per iteration.
334 | # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
335 | N = self.get_num_params()
336 | cfg = self.config
337 | L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
338 | flops_per_token = 6*N + 12*L*H*Q*T
339 | flops_per_fwdbwd = flops_per_token * T
340 | flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
341 | # express our flops throughput as ratio of A100 bfloat16 peak flops
342 | flops_achieved = flops_per_iter * (1.0/dt) # per second
343 | flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
344 | mfu = flops_achieved / flops_promised
345 | return mfu
346 |
347 | @torch.no_grad()
348 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
349 | """
350 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
351 | the sequence max_new_tokens times, feeding the predictions back into the model each time.
352 | Most likely you'll want to make sure to be in model.eval() mode of operation for this.
353 | """
354 | for _ in range(max_new_tokens):
355 | # if the sequence context is growing too long we must crop it at block_size
356 | idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
357 | # forward the model to get the logits for the index in the sequence
358 | logits, _ = self(idx_cond)
359 | # pluck the logits at the final step and scale by desired temperature
360 | logits = logits[:, -1, :] / temperature
361 | # optionally crop the logits to only the top k options
362 | if top_k is not None:
363 | v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
364 | logits[logits < v[:, [-1]]] = -float('Inf')
365 | # apply softmax to convert logits to (normalized) probabilities
366 | probs = F.softmax(logits, dim=-1)
367 | # sample from the distribution
368 | idx_next = torch.multinomial(probs, num_samples=1)
369 | # append sampled index to the running sequence and continue
370 | idx = torch.cat((idx, idx_next), dim=1)
371 |
372 | return idx
373 |
--------------------------------------------------------------------------------
/sophia.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import Tensor
4 | from torch.optim.optimizer import Optimizer
5 | from typing import List, Optional
6 |
7 |
8 | class SophiaG(Optimizer):
9 | def __init__(self, params, lr=1e-4, betas=(0.965, 0.99), rho = 0.04,
10 | weight_decay=1e-1, *, maximize: bool = False,
11 | capturable: bool = False):
12 | if not 0.0 <= lr:
13 | raise ValueError("Invalid learning rate: {}".format(lr))
14 | if not 0.0 <= betas[0] < 1.0:
15 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
16 | if not 0.0 <= betas[1] < 1.0:
17 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
18 | if not 0.0 <= rho:
19 | raise ValueError("Invalid rho parameter at index 1: {}".format(rho))
20 | if not 0.0 <= weight_decay:
21 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
22 | defaults = dict(lr=lr, betas=betas, rho=rho,
23 | weight_decay=weight_decay,
24 | maximize=maximize, capturable=capturable)
25 | super(SophiaG, self).__init__(params, defaults)
26 |
27 | def __setstate__(self, state):
28 | super().__setstate__(state)
29 | for group in self.param_groups:
30 | group.setdefault('maximize', False)
31 | group.setdefault('capturable', False)
32 | state_values = list(self.state.values())
33 | step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
34 | if not step_is_tensor:
35 | for s in state_values:
36 | s['step'] = torch.tensor(float(s['step']))
37 |
38 | @torch.no_grad()
39 | def update_hessian(self):
40 | for group in self.param_groups:
41 | beta1, beta2 = group['betas']
42 | for p in group['params']:
43 | if p.grad is None:
44 | continue
45 | state = self.state[p]
46 |
47 | if len(state) == 0:
48 | state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
49 | if self.defaults['capturable'] else torch.tensor(0.)
50 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
51 | state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
52 |
53 | if 'hessian' not in state.keys():
54 | state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
55 |
56 | state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2)
57 |
58 |
59 | @torch.no_grad()
60 | def step(self, closure=None, bs=5120):
61 | loss = None
62 | if closure is not None:
63 | with torch.enable_grad():
64 | loss = closure()
65 |
66 | for group in self.param_groups:
67 | params_with_grad = []
68 | grads = []
69 | exp_avgs = []
70 | state_steps = []
71 | hessian = []
72 | beta1, beta2 = group['betas']
73 |
74 | for p in group['params']:
75 | if p.grad is None:
76 | continue
77 | params_with_grad.append(p)
78 |
79 | if p.grad.is_sparse:
80 | raise RuntimeError('SophiaG does not support sparse gradients')
81 | grads.append(p.grad)
82 | state = self.state[p]
83 | # State initialization
84 | if len(state) == 0:
85 | state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
86 | if self.defaults['capturable'] else torch.tensor(0.)
87 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
88 | state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
89 |
90 | if 'hessian' not in state.keys():
91 | state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
92 |
93 | exp_avgs.append(state['exp_avg'])
94 | state_steps.append(state['step'])
95 | hessian.append(state['hessian'])
96 |
97 | if self.defaults['capturable']:
98 | bs = torch.ones((1,), dtype=torch.float, device=p.device) * bs
99 |
100 | sophiag(params_with_grad,
101 | grads,
102 | exp_avgs,
103 | hessian,
104 | state_steps,
105 | bs=bs,
106 | beta1=beta1,
107 | beta2=beta2,
108 | rho=group['rho'],
109 | lr=group['lr'],
110 | weight_decay=group['weight_decay'],
111 | maximize=group['maximize'],
112 | capturable=group['capturable'])
113 |
114 | return loss
115 |
116 | def sophiag(params: List[Tensor],
117 | grads: List[Tensor],
118 | exp_avgs: List[Tensor],
119 | hessian: List[Tensor],
120 | state_steps: List[Tensor],
121 | capturable: bool = False,
122 | *,
123 | bs: int,
124 | beta1: float,
125 | beta2: float,
126 | rho: float,
127 | lr: float,
128 | weight_decay: float,
129 | maximize: bool):
130 |
131 | if not all(isinstance(t, torch.Tensor) for t in state_steps):
132 | raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")
133 |
134 |
135 | func = _single_tensor_sophiag
136 |
137 | func(params,
138 | grads,
139 | exp_avgs,
140 | hessian,
141 | state_steps,
142 | bs=bs,
143 | beta1=beta1,
144 | beta2=beta2,
145 | rho=rho,
146 | lr=lr,
147 | weight_decay=weight_decay,
148 | maximize=maximize,
149 | capturable=capturable)
150 |
151 | def _single_tensor_sophiag(params: List[Tensor],
152 | grads: List[Tensor],
153 | exp_avgs: List[Tensor],
154 | hessian: List[Tensor],
155 | state_steps: List[Tensor],
156 | *,
157 | bs: int,
158 | beta1: float,
159 | beta2: float,
160 | rho: float,
161 | lr: float,
162 | weight_decay: float,
163 | maximize: bool,
164 | capturable: bool):
165 |
166 | for i, param in enumerate(params):
167 | grad = grads[i] if not maximize else -grads[i]
168 | exp_avg = exp_avgs[i]
169 | hess = hessian[i]
170 | step_t = state_steps[i]
171 |
172 | if capturable:
173 | assert param.is_cuda and step_t.is_cuda and bs.is_cuda
174 |
175 | if torch.is_complex(param):
176 | grad = torch.view_as_real(grad)
177 | exp_avg = torch.view_as_real(exp_avg)
178 | hess = torch.view_as_real(hess)
179 | param = torch.view_as_real(param)
180 |
181 | # update step
182 | step_t += 1
183 |
184 | # Perform stepweight decay
185 | param.mul_(1 - lr * weight_decay)
186 |
187 | # Decay the first and second moment running average coefficient
188 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
189 |
190 | if capturable:
191 | step_size = lr
192 | step_size_neg = step_size.neg()
193 |
194 | ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None,1)
195 | param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg)
196 | else:
197 | step_size_neg = - lr
198 |
199 | ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None,1)
200 | param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg)
--------------------------------------------------------------------------------
/train_adam.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import math
4 | import pickle
5 | from contextlib import nullcontext
6 |
7 | import numpy as np
8 | import torch
9 | from torch.nn.parallel import DistributedDataParallel as DDP
10 | from torch.distributed import init_process_group, destroy_process_group
11 |
12 | from model import GPTConfig, GPT
13 |
14 | # -----------------------------------------------------------------------------
15 | # default config values designed to train a gpt2 (124M) on OpenWebText
16 | # I/O
17 | out_dir = 'out'
18 | eval_interval = 2000
19 | log_interval = 1
20 | eval_iters = 200
21 | eval_only = False # if True, script exits right after the first eval
22 | always_save_checkpoint = True # if True, always save a checkpoint after each eval
23 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
24 | # wandb logging
25 | wandb_log = False # disabled by default
26 | wandb_project = 'owt'
27 | wandb_run_name = 'gpt2' # 'run' + str(time.time())
28 | # data
29 | dataset = 'openwebtext'
30 | gradient_accumulation_steps = 5 # used to simulate larger batch sizes
31 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
32 | block_size = 1024
33 | # model
34 | n_layer = 12
35 | n_head = 12
36 | n_embd = 768
37 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
38 | bias = False # do we use bias inside LayerNorm and Linear layers?
39 | # optimizer
40 | optimizer_name = 'adamw'
41 | learning_rate = 6e-4 # max learning rate
42 | max_iters = 600000 # total number of training iterations
43 | weight_decay = 1e-1
44 | beta1 = 0.9
45 | beta2 = 0.95
46 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
47 | rho = 0.1
48 | interval = 10
49 | variant = 4
50 | # learning rate decay settings
51 | decay_lr = True # whether to decay the learning rate
52 | warmup_iters = 2000 # how many steps to warm up for
53 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
54 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
55 | # DDP settings
56 | backend = 'nccl' # 'nccl', 'gloo', etc.
57 | # system
58 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
59 | dtype = 'bfloat16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
60 | compile = True # use PyTorch 2.0 to compile the model to be faster
61 | scale_attn_by_inverse_layer_idx = True
62 | # -----------------------------------------------------------------------------
63 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
64 | exec(open('configurator.py').read()) # overrides from command line or config file
65 | config = {k: globals()[k] for k in config_keys} # will be useful for logging
66 | # -----------------------------------------------------------------------------
67 |
68 | # various inits, derived attributes, I/O setup
69 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
70 | if ddp:
71 | init_process_group(backend=backend)
72 | ddp_rank = int(os.environ['RANK'])
73 | ddp_local_rank = int(os.environ['LOCAL_RANK'])
74 | device = f'cuda:{ddp_local_rank}'
75 | torch.cuda.set_device(device)
76 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
77 | seed_offset = ddp_rank # each process gets a different seed
78 | else:
79 | # if not ddp, we are running on a single gpu, and one process
80 | master_process = True
81 | seed_offset = 0
82 | gradient_accumulation_steps *= 8 # simulate 8 gpus
83 |
84 | if master_process:
85 | os.makedirs(out_dir, exist_ok=True)
86 | torch.manual_seed(5000 + seed_offset)
87 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
88 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
89 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
90 | # note: float16 data type will automatically use a GradScaler
91 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
92 | ctx = nullcontext() if device_type == 'cpu' else torch.autocast(device_type=device_type, dtype=ptdtype)
93 |
94 | # poor man's data loader
95 | data_dir = os.path.join('data', dataset)
96 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
97 | val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
98 | def get_batch(split):
99 | data = train_data if split == 'train' else val_data
100 | ix = torch.randint(len(data) - block_size, (batch_size,))
101 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
102 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
103 | if device_type == 'cuda':
104 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
105 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
106 | else:
107 | x, y = x.to(device), y.to(device)
108 | return x, y
109 |
110 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
111 | iter_num = 0
112 | best_val_loss = 1e9
113 |
114 | # attempt to derive vocab_size from the dataset
115 | meta_path = os.path.join(data_dir, 'meta.pkl')
116 | meta_vocab_size = None
117 | if os.path.exists(meta_path):
118 | with open(meta_path, 'rb') as f:
119 | meta = pickle.load(f)
120 | meta_vocab_size = meta['vocab_size']
121 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
122 |
123 | # model init
124 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
125 | bias=bias, vocab_size=None, dropout=dropout, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx) # start with model_args from command line
126 | if init_from == 'scratch':
127 | # init a new model from scratch
128 | print("Initializing a new model from scratch")
129 | # determine the vocab size we'll use for from-scratch training
130 | if meta_vocab_size is None:
131 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
132 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
133 | gptconf = GPTConfig(**model_args)
134 | model = GPT(gptconf)
135 | elif init_from == 'resume':
136 | print(f"Resuming training from {out_dir}")
137 | # resume training from a checkpoint.
138 | ckpt_path = os.path.join(out_dir, 'ckpt.pt')
139 | checkpoint = torch.load(ckpt_path, map_location=device)
140 | checkpoint_model_args = checkpoint['model_args']
141 | # force these config attributes to be equal otherwise we can't even resume training
142 | # the rest of the attributes (e.g. dropout) can stay as desired from command line
143 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
144 | model_args[k] = checkpoint_model_args[k]
145 | # create the model
146 | gptconf = GPTConfig(**model_args)
147 | model = GPT(gptconf)
148 | state_dict = checkpoint['model']
149 | # fix the keys of the state dictionary :(
150 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more
151 | unwanted_prefix = '_orig_mod.'
152 | for k,v in list(state_dict.items()):
153 | if k.startswith(unwanted_prefix):
154 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
155 | model.load_state_dict(state_dict)
156 | iter_num = checkpoint['iter_num']
157 | best_val_loss = checkpoint['best_val_loss']
158 | elif init_from.startswith('gpt2'):
159 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
160 | # initialize from OpenAI GPT-2 weights
161 | override_args = dict(dropout=dropout)
162 | model = GPT.from_pretrained(init_from, override_args)
163 | # read off the created config params, so we can store them into checkpoint correctly
164 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
165 | model_args[k] = getattr(model.config, k)
166 | # crop down the model block size if desired, using model surgery
167 | if block_size < model.config.block_size:
168 | model.crop_block_size(block_size)
169 | model_args['block_size'] = block_size # so that the checkpoint will have the right value
170 | model.to(device)
171 |
172 | # initialize a GradScaler. If enabled=False scaler is a no-op
173 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
174 |
175 | # optimizer
176 | optimizer = model.configure_optimizers(optimizer_name, weight_decay, learning_rate, (beta1, beta2), rho, device_type)
177 | if init_from == 'resume':
178 | optimizer.load_state_dict(checkpoint['optimizer'])
179 | del state_dict
180 | del checkpoint
181 | # compile the model
182 | if compile:
183 | print("compiling the model... (takes a ~minute)")
184 | unoptimized_model = model
185 | model = torch.compile(model) # requires PyTorch 2.0
186 |
187 | # wrap model into DDP container
188 | if ddp:
189 | model = DDP(model, device_ids=[ddp_local_rank])
190 |
191 | # helps estimate an arbitrarily accurate loss over either split using many batches
192 | @torch.no_grad()
193 | def estimate_loss():
194 | out = {}
195 | model.eval()
196 | for split in ['train', 'val']:
197 | losses = torch.zeros(eval_iters)
198 | for k in range(eval_iters):
199 | X, Y = get_batch(split)
200 | with ctx:
201 | logits, loss = model(X, Y)
202 | losses[k] = loss.item()
203 | out[split] = losses.mean()
204 | model.train()
205 | return out
206 |
207 | # learning rate decay scheduler (cosine with warmup)
208 | def get_lr(it):
209 | # 1) linear warmup for warmup_iters steps
210 | if it < warmup_iters:
211 | return learning_rate * it / warmup_iters
212 | # 2) if it > lr_decay_iters, return min learning rate
213 | if it > lr_decay_iters:
214 | return min_lr
215 | # 3) in between, use cosine decay down to min learning rate
216 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
217 | assert 0 <= decay_ratio <= 1
218 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
219 | return min_lr + coeff * (learning_rate - min_lr)
220 |
221 | # logging
222 | if wandb_log and master_process:
223 | import wandb
224 | wandb.init(project=wandb_project, name=wandb_run_name, config=config)
225 |
226 | # training loop
227 | X, Y = get_batch('train') # fetch the very first batch
228 | t0 = time.time()
229 | local_iter_num = 0 # number of iterations in the lifetime of this process
230 | raw_model = model.module if ddp else model # unwrap DDP container if needed
231 | running_mfu = -1.0
232 | clip_time = 0
233 | while True:
234 |
235 | # determine and set the learning rate for this iteration
236 | lr = get_lr(iter_num) if decay_lr else learning_rate
237 | for param_group in optimizer.param_groups:
238 | param_group['lr'] = lr
239 |
240 | # evaluate the loss on train/val sets and write checkpoints
241 | if iter_num % eval_interval == 0 and master_process:
242 | losses = estimate_loss()
243 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
244 | if wandb_log:
245 | wandb.log({
246 | "iter": iter_num,
247 | "train/loss": losses['train'],
248 | "val/loss": losses['val'],
249 | "lr": lr,
250 | "mfu": running_mfu*100, # convert to percentage
251 | }, step=iter_num)
252 | if losses['val'] < best_val_loss or always_save_checkpoint:
253 | best_val_loss = losses['val']
254 | if iter_num > 0:
255 | checkpoint = {
256 | 'model': raw_model.state_dict(),
257 | 'optimizer': optimizer.state_dict(),
258 | 'model_args': model_args,
259 | 'iter_num': iter_num,
260 | 'best_val_loss': best_val_loss,
261 | 'config': config,
262 | }
263 | print(f"saving checkpoint to {out_dir}")
264 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
265 | if iter_num % (eval_interval * 5) == 0:
266 | checkpoint = {
267 | 'model': raw_model.state_dict(),
268 | 'optimizer': optimizer.state_dict(),
269 | 'model_args': model_args,
270 | 'iter_num': iter_num,
271 | 'best_val_loss': best_val_loss,
272 | 'config': config,
273 | }
274 | print(f"saving checkpoint to {out_dir}")
275 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt_'+str(iter_num)+'.pt'))
276 | if iter_num == 0 and eval_only:
277 | break
278 |
279 | # forward backward update, with optional gradient accumulation to simulate larger batch size
280 | # and using the GradScaler if data type is float16
281 | for micro_step in range(gradient_accumulation_steps):
282 | if ddp:
283 | # in DDP training we only need to sync gradients at the last micro step.
284 | # the official way to do this is with model.no_sync() context manager, but
285 | # I really dislike that this bloats the code and forces us to repeat code
286 | # looking at the source of that context manager, it just toggles this variable
287 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
288 | with ctx:
289 | logits, loss = model(X, Y)
290 | # immediately async prefetch next batch while model is doing the forward pass on the GPU
291 | X, Y = get_batch('train')
292 | # backward pass, with gradient scaling if training in fp16
293 | scaler.scale(loss).backward()
294 | # clip the gradient
295 | if grad_clip != 0.0:
296 | scaler.unscale_(optimizer)
297 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
298 | if total_norm.item() > grad_clip:
299 | clip_time += 1
300 | # step the optimizer and scaler if training in fp16
301 | scaler.step(optimizer)
302 | scaler.update()
303 | # flush the gradients as soon as we can, no need for this memory anymore
304 | optimizer.zero_grad(set_to_none=True)
305 |
306 | # timing and logging
307 | t1 = time.time()
308 | dt = t1 - t0
309 | t0 = t1
310 | if iter_num % log_interval == 0 and master_process:
311 | lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
312 | if local_iter_num >= 5: # let the training loop settle a bit
313 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
314 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
315 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
316 | params = []
317 | for (name, p) in model.named_parameters():
318 | params.append(p)
319 | total_param_norm = 0
320 | for p in params:
321 | param_norm = p.data.norm(2)
322 | total_param_norm += param_norm.item() ** 2
323 | total_param_norm = total_param_norm ** 0.5
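    |             # global L2 norm of the optimizer's first-moment (exp_avg) buffers, logged below as momentum_norm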
324 | momentum_norm = 0
325 | LL = len(optimizer.state_dict()['state'])
326 | for jj in range(LL):
327 | momentum_norm += (optimizer.state_dict()['state'][jj]['exp_avg'].detach().norm(2)) ** 2
328 | momentum_norm = torch.sqrt(momentum_norm).item()
329 | if wandb_log:
330 | wandb.log({
331 | "iter": iter_num,
332 | "train/loss": lossf,
333 | "lr": lr,
334 | "param_norm": total_param_norm,
335 | "momentum_norm" : momentum_norm,
336 | "train/clip_rate": clip_time / (iter_num + 1)
337 | }, step=iter_num)
338 | iter_num += 1
339 | local_iter_num += 1
340 |
341 | # termination conditions
342 | if iter_num > max_iters:
343 | break
344 |
345 | if ddp:
346 | destroy_process_group()
347 |
--------------------------------------------------------------------------------
/train_sophiag.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import math
4 | import pickle
5 | from contextlib import nullcontext
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn.functional as F
10 | from torch.nn.parallel import DistributedDataParallel as DDP
11 | from torch.distributed import init_process_group, destroy_process_group
12 | from model import GPTConfig, GPT
13 | import torch.autograd as autograd
14 |
15 | # -----------------------------------------------------------------------------
16 | # default config values designed to train a gpt2 (124M) on OpenWebText
17 | # I/O
18 | out_dir = 'out'
19 | eval_interval = 2000
20 | log_interval = 1
21 | eval_iters = 200
22 | eval_only = False # if True, script exits right after the first eval
23 | always_save_checkpoint = True # if True, always save a checkpoint after each eval
24 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
25 | # wandb logging
26 | wandb_log = False # disabled by default
27 | wandb_project = 'owt'
28 | wandb_run_name = 'gpt2' # 'run' + str(time.time())
29 | # data
30 | dataset = 'openwebtext'
31 | gradient_accumulation_steps = 5 # used to simulate larger batch sizes
32 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
33 | block_size = 1024
34 | total_bs = 480
35 | # model
36 | n_layer = 12
37 | n_head = 12
38 | n_embd = 768
39 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
40 | bias = False # do we use bias inside LayerNorm and Linear layers?
41 | # optimizer
42 | optimizer_name = 'sophiag'
43 | learning_rate = 3e-4 # max learning rate
44 | max_iters = 600000 # total number of training iterations
45 | weight_decay = 1e-1
46 | beta1 = 0.9
47 | beta2 = 0.95
48 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
49 | rho = 0.03 # SophiaG clipping parameter (rho in the paper)
50 | interval = 10 # re-estimate the diagonal Hessian every 10 optimizer steps
51 | hess_interval = interval
52 | variant = 4
53 | # learning rate decay settings
54 | decay_lr = True # whether to decay the learning rate
55 | warmup_iters = 2000 # how many steps to warm up for
56 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
57 | min_lr = 1.5e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
58 | # DDP settings
59 | backend = 'nccl' # 'nccl', 'gloo', etc.
60 | # system
61 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
62 | dtype = 'bfloat16' # 'float32', 'bfloat16'
63 | compile = True # use PyTorch 2.0 to compile the model to be faster
64 | scale_attn_by_inverse_layer_idx = True
65 | # -----------------------------------------------------------------------------
66 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
67 | exec(open('configurator.py').read()) # overrides from command line or config file
68 | config = {k: globals()[k] for k in config_keys} # will be useful for logging
69 | # -----------------------------------------------------------------------------
70 |
71 | # various inits, derived attributes, I/O setup
72 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
73 | if ddp:
74 | init_process_group(backend=backend)
75 | ddp_rank = int(os.environ['RANK'])
76 | ddp_local_rank = int(os.environ['LOCAL_RANK'])
77 | device = f'cuda:{ddp_local_rank}'
78 | torch.cuda.set_device(device)
79 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
80 | seed_offset = ddp_rank # each process gets a different seed
81 | else:
82 | # if not ddp, we are running on a single gpu, and one process
83 |     ddp_rank = 0 # ddp_rank is used in get_batch, so it must also be defined when running on a single GPU
84 | master_process = True
85 | seed_offset = 0
86 | gradient_accumulation_steps *= 8 # simulate 8 gpus
87 |
88 | if master_process:
89 | os.makedirs(out_dir, exist_ok=True)
90 | torch.manual_seed(2099)
91 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
92 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
93 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
94 | # note: this script does not use a GradScaler; training runs in float32 or bfloat16
95 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
96 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
97 |
98 | # poor man's data loader
99 | data_dir = os.path.join('data', dataset)
100 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
101 | val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
102 | def get_batch(split):
103 | data = train_data if split == 'train' else val_data
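    |     # torch.manual_seed(2099) above is the same on every rank, so all ranks draw identical candidate
    |     # index sets; each keeps the one at its own rank index, giving different batches per rank
    |     # (assumes at most 10 DDP processes)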
104 | ix_list = []
105 | for jj in range(10):
106 | ix_list.append(torch.randint(len(data) - block_size, (batch_size,)))
107 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix_list[ddp_rank]])
108 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix_list[ddp_rank]])
109 | if device_type == 'cuda':
110 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
111 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
112 | else:
113 | x, y = x.to(device), y.to(device)
114 | return x, y
115 |
116 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
117 | iter_num = 0
118 | best_val_loss = 1e9
119 |
120 | # attempt to derive vocab_size from the dataset
121 | meta_path = os.path.join(data_dir, 'meta.pkl')
122 | meta_vocab_size = None
123 | if os.path.exists(meta_path):
124 | with open(meta_path, 'rb') as f:
125 | meta = pickle.load(f)
126 | meta_vocab_size = meta['vocab_size']
127 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
128 |
129 | # model init
130 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
131 | bias=bias, vocab_size=None, dropout=dropout, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx) # start with model_args from command line
132 | if init_from == 'scratch':
133 | # init a new model from scratch
134 | print("Initializing a new model from scratch")
135 | # determine the vocab size we'll use for from-scratch training
136 | if meta_vocab_size is None:
137 |         print("defaulting to GPT-2 vocab_size of 50304 (50257 rounded up for efficiency)")
138 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
139 | gptconf = GPTConfig(**model_args)
140 | model = GPT(gptconf)
141 | elif init_from == 'resume':
142 | print(f"Resuming training from {out_dir}")
143 | # resume training from a checkpoint.
144 | ckpt_path = os.path.join(out_dir, 'ckpt.pt')
145 | checkpoint = torch.load(ckpt_path, map_location=device)
146 | checkpoint_model_args = checkpoint['model_args']
147 | # force these config attributes to be equal otherwise we can't even resume training
148 | # the rest of the attributes (e.g. dropout) can stay as desired from command line
149 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
150 | model_args[k] = checkpoint_model_args[k]
151 | # create the model
152 | gptconf = GPTConfig(**model_args)
153 | model = GPT(gptconf)
154 | state_dict = checkpoint['model']
155 | # fix the keys of the state dictionary :(
156 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more
157 | unwanted_prefix = '_orig_mod.'
158 | for k,v in list(state_dict.items()):
159 | if k.startswith(unwanted_prefix):
160 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
161 | model.load_state_dict(state_dict)
162 | iter_num = checkpoint['iter_num']
163 | best_val_loss = checkpoint['best_val_loss']
164 | elif init_from.startswith('gpt2'):
165 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
166 | # initialize from OpenAI GPT-2 weights
167 | override_args = dict(dropout=dropout)
168 | model = GPT.from_pretrained(init_from, override_args)
169 | # read off the created config params, so we can store them into checkpoint correctly
170 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
171 | model_args[k] = getattr(model.config, k)
172 | # crop down the model block size if desired, using model surgery
173 | if block_size < model.config.block_size:
174 | model.crop_block_size(block_size)
175 | model_args['block_size'] = block_size # so that the checkpoint will have the right value
176 | model.to(device)
177 |
178 |
179 | optimizer = model.configure_optimizers(optimizer_name, weight_decay, learning_rate, (beta1, beta2), rho, device_type)
180 | if init_from == 'resume':
181 | optimizer.load_state_dict(checkpoint['optimizer'])
182 | del state_dict
183 | del checkpoint
184 | # compile the model
185 | if compile:
186 | print("compiling the model... (takes a ~minute)")
187 | unoptimized_model = model
188 | model = torch.compile(model) # requires PyTorch 2.0
189 |
190 | # wrap model into DDP container
191 | if ddp:
192 | model = DDP(model, device_ids=[ddp_local_rank])
193 |
194 | # helps estimate an arbitrarily accurate loss over either split using many batches
195 | @torch.no_grad()
196 | def estimate_loss():
197 | out = {}
198 | model.eval()
199 | for split in ['train', 'val']:
200 | losses = torch.zeros(eval_iters)
201 | for k in range(eval_iters):
202 | X, Y = get_batch(split)
203 | with ctx:
204 | logits, loss = model(X, Y)
205 | losses[k] = loss.item()
206 | out[split] = losses.mean()
207 | model.train()
208 | return out
209 |
210 | # learning rate decay scheduler (cosine with warmup)
211 | def get_lr(it):
212 | # 1) linear warmup for warmup_iters steps
213 | if it < warmup_iters:
214 | return learning_rate * it / warmup_iters
215 | # 2) if it > lr_decay_iters, return min learning rate
216 | if it > lr_decay_iters:
217 | return min_lr
218 | # 3) in between, use cosine decay down to min learning rate
219 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
220 | assert 0 <= decay_ratio <= 1
221 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
222 | return min_lr + coeff * (learning_rate - min_lr)
223 |
224 | # logging
225 | if wandb_log and master_process:
226 | import wandb
227 | wandb.init(project=wandb_project, name=wandb_run_name, config=config)
228 |
229 | # training loop
230 | X, Y = get_batch('train') # fetch the very first batch
231 | t0 = time.time()
232 | local_iter_num = 0 # number of iterations in the lifetime of this process
233 | raw_model = model.module if ddp else model # unwrap DDP container if needed
234 | running_mfu = -1.0
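    | # Sophia logging diagnostics: num_effective / num_param tracks the fraction of un-clipped
    | # coordinates (the "win rate"); num_param starts at 1 to avoid dividing by zero before the
    | # first Hessian-estimate step has filled in these statistics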
235 | num_param = 1
236 | num_effective = 0
237 | momentum_norm = 0
238 | hessian_norm = 0
239 | hessian_norm2 = 0
240 | clip_time = 0
241 | while True:
242 |
243 | # determine and set the learning rate for this iteration
244 | lr = get_lr(iter_num) if decay_lr else learning_rate
245 | for param_group in optimizer.param_groups:
246 | param_group['lr'] = lr
247 |
248 | # evaluate the loss on train/val sets and write checkpoints
249 | if iter_num % eval_interval == 0 and master_process:
250 | losses = estimate_loss()
251 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
252 | if wandb_log:
253 | wandb.log({
254 | "iter": iter_num,
255 | "train/loss": losses['train'],
256 | "val/loss": losses['val'],
257 | "lr": lr,
258 | "mfu": running_mfu*100, # convert to percentage
259 | }, step=iter_num)
260 | if losses['val'] < best_val_loss or always_save_checkpoint:
261 | best_val_loss = losses['val']
262 | if iter_num > 0:
263 | checkpoint = {
264 | 'model': raw_model.state_dict(),
265 | 'optimizer': optimizer.state_dict(),
266 | 'model_args': model_args,
267 | 'iter_num': iter_num,
268 | 'best_val_loss': best_val_loss,
269 | 'config': config,
270 | }
271 | print(f"saving checkpoint to {out_dir}")
272 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
273 | if iter_num % (eval_interval * 5) == 0:
274 | checkpoint = {
275 | 'model': raw_model.state_dict(),
276 | 'optimizer': optimizer.state_dict(),
277 | 'model_args': model_args,
278 | 'iter_num': iter_num,
279 | 'best_val_loss': best_val_loss,
280 | 'config': config,
281 | }
282 | print(f"saving checkpoint to {out_dir}")
283 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt_'+str(iter_num)+'.pt'))
284 | if iter_num == 0 and eval_only:
285 | break
286 |
287 |     # forward backward update, with optional gradient accumulation to simulate larger batch size
288 |     # (no GradScaler here: this script trains in float32 or bfloat16)
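    |     # every iteration takes a SophiaG parameter step; every hess_interval-th iteration
    |     # additionally re-estimates the diagonal Hessian (the else branch below)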
289 | if iter_num % hess_interval != hess_interval - 1:
290 | for micro_step in range(gradient_accumulation_steps):
291 | if ddp:
292 | # in DDP training we only need to sync gradients at the last micro step.
293 | # the official way to do this is with model.no_sync() context manager, but
294 | # I really dislike that this bloats the code and forces us to repeat code
295 | # looking at the source of that context manager, it just toggles this variable
296 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
297 | with ctx:
298 | logits, loss = model(X, Y)
299 | # immediately async prefetch next batch while model is doing the forward pass on the GPU
300 | X, Y = get_batch('train')
301 |             # backward pass, scaled for gradient accumulation
302 | (loss / gradient_accumulation_steps).backward()
303 | # clip the gradient
304 | if grad_clip != 0.0:
305 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
306 | if total_norm.item() > grad_clip:
307 | clip_time += 1
308 |             # step the optimizer, passing the total token batch size used in SophiaG's update clipping
309 | optimizer.step(bs=total_bs * block_size)
310 | # flush the gradients as soon as we can, no need for this memory anymore
311 | optimizer.zero_grad(set_to_none=True)
312 |
313 | # timing and logging
314 | t1 = time.time()
315 | dt = t1 - t0
316 | t0 = t1
317 | if iter_num % log_interval == 0 and master_process:
318 | lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
319 | if local_iter_num >= 5: # let the training loop settle a bit
320 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
321 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
322 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
323 | total_param_norm = 0
324 | momentum_norm = 0
325 | params = []
326 | for (name, p) in model.named_parameters():
327 | params.append(p)
328 | for p in params:
329 | param_norm = p.data.norm(2)
330 | total_param_norm += param_norm.item() ** 2
331 | total_param_norm = total_param_norm ** 0.5
332 | momentum_norm = 0
333 | LL = len(optimizer.state_dict()['state'])
334 | for jj in range(LL):
335 | momentum_norm += (optimizer.state_dict()['state'][jj]['exp_avg'].detach().norm(2)) ** 2
336 | momentum_norm = torch.sqrt(momentum_norm).item()
337 | if wandb_log:
338 | wandb.log({
339 | "iter": iter_num,
340 | "train/loss": lossf,
341 | "lr": lr,
342 | "param_norm": total_param_norm,
343 | "momentum_norm" : momentum_norm,
344 | "hessian_norm": hessian_norm,
345 | "hessian_norm2": hessian_norm2,
346 | "train/win_rate": num_effective / num_param,
347 | "train/clip_rate": clip_time / (iter_num + 1)
348 |
349 | }, step=iter_num)
350 | iter_num += 1
351 | local_iter_num += 1
352 | else:
353 | for micro_step in range(gradient_accumulation_steps):
354 | if ddp:
355 | # in DDP training we only need to sync gradients at the last micro step.
356 | # the official way to do this is with model.no_sync() context manager, but
357 | # I really dislike that this bloats the code and forces us to repeat code
358 | # looking at the source of that context manager, it just toggles this variable
359 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
360 | with ctx:
361 | logits, loss = model(X, Y)
362 | # immediately async prefetch next batch while model is doing the forward pass on the GPU
363 | X, Y = get_batch('train')
364 |                 # backward pass, scaled for gradient accumulation
365 | (loss / gradient_accumulation_steps).backward()
366 | # clip the gradient
367 | if grad_clip != 0.0:
368 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
369 | if total_norm.item() > grad_clip:
370 | clip_time += 1
371 |                 # step the optimizer, passing the total token batch size used in SophiaG's update clipping
372 | optimizer.step(bs=total_bs * block_size)
373 | # flush the gradients as soon as we can, no need for this memory anymore
374 | optimizer.zero_grad(set_to_none=True)
375 |
376 | # timing and logging
377 | t1 = time.time()
378 | dt = t1 - t0
379 | t0 = t1
380 | if iter_num % log_interval == 0 and master_process:
381 | lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
382 | if local_iter_num >= 5: # let the training loop settle a bit
383 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
384 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
385 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
386 |
387 | total_param_norm = 0
388 | momentum_norm = 0
389 |
390 | if wandb_log:
391 | wandb.log({
392 | "iter": iter_num,
393 | "train/loss": lossf,
394 | "lr": lr,
395 | "param_norm": total_param_norm,
396 | "momentum_norm" : momentum_norm,
397 |                     "hessian_norm": hessian_norm,
398 | "hessian_norm2": hessian_norm2,
399 | "train/win_rate": num_effective / num_param,
400 | "train/clip_rate": clip_time / (iter_num + 1)
401 |
402 | }, step=iter_num)
403 | iter_num += 1
404 | local_iter_num += 1
405 |
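    |         # Gauss-Newton-Bartlett Hessian estimate: sample labels from the model's own logits,
    |         # backprop the cross-entropy loss on those sampled labels, and let
    |         # optimizer.update_hessian() fold the squared gradients into its Hessian EMA (see sophia.py)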
406 | for micro_step in range(gradient_accumulation_steps):
407 | if ddp:
408 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
409 | with ctx:
410 | logits, _ = model(X, 0)
411 | X, Y = get_batch('train')
412 | samp_dist = torch.distributions.Categorical(logits=logits)
413 | y_sample = samp_dist.sample()
414 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y_sample.view(-1), ignore_index=-1)
415 |                 # backward pass on the sampled-label loss, scaled for gradient accumulation
416 | (loss / gradient_accumulation_steps).backward()
417 | # clip the gradient
418 | if grad_clip != 0.0:
419 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
420 |             # accumulate the squared gradients from the sampled-label loss into SophiaG's Hessian EMA (no parameter update here)
421 | optimizer.update_hessian()
422 | # flush the gradients as soon as we can, no need for this memory anymore
423 | optimizer.zero_grad(set_to_none=True)
424 |
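    |         # diagnostics over the optimizer state: count coordinates where |exp_avg| stays below
    |         # rho * bs * hessian (i.e. the update is not clipped) and accumulate Hessian norms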
425 | num_param = 0
426 | num_effective = 0
427 | hessian_norm = 0
428 | hessian_norm2 = 0
429 |
430 | LL = len(optimizer.state_dict()['state'])
431 |
432 | for jj in range(LL):
433 | num_param += optimizer.state_dict()['state'][jj]['exp_avg'].numel()
434 | num_effective += torch.sum(torch.abs(optimizer.state_dict()['state'][jj]['exp_avg']) < rho * total_bs * block_size * optimizer.state_dict()['state'][jj]['hessian'])
435 | hessian_norm += optimizer.state_dict()['state'][jj]['hessian'].detach().norm(1).item()
436 | hessian_norm2 += optimizer.state_dict()['state'][jj]['hessian'].detach().norm(2).item() ** 2
437 | hessian_norm2 = hessian_norm2 ** 0.5
438 |
439 |
440 |
441 | t1 = time.time()
442 | dt = t1 - t0
443 | t0 = t1
444 | if master_process:
445 | # loss as float. note: this is a CPU-GPU sync point
446 | if local_iter_num >= 5: # let the training loop settle a bit
447 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
448 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
449 | print(f"iter {iter_num}: time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
450 |
451 |
452 | # termination conditions
453 | if iter_num > max_iters:
454 | break
455 |
456 | if ddp:
457 | destroy_process_group()
458 |
--------------------------------------------------------------------------------