├── hyperparams.png ├── LICENSE ├── LICENSES.md ├── configurator.py ├── launcher.sh ├── README.md ├── model.py └── train.py /hyperparams.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/ngpt/HEAD/hyperparams.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright(c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | MIT License 3 | [https://opensource.org/license/mit](https://opensource.org/license/mit) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a 6 | copy of this software and associated documentation files (the "Software"), 7 | to deal in the Software without restriction, including without limitation 8 | the rights to use, copy, modify, merge, publish, distribute, sublicense, 9 | and/or sell copies of the Software, and to permit persons to whom the 10 | Software is furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 18 | THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 21 | DEALINGS IN THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /LICENSES.md: -------------------------------------------------------------------------------- 1 | Copyright(c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | 3 | MIT License 4 | [https://opensource.org/license/mit](https://opensource.org/license/mit) 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a 7 | copy of this software and associated documentation files (the "Software"), 8 | to deal in the Software without restriction, including without limitation 9 | the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | and/or sell copies of the Software, and to permit persons to whom the 11 | Software is furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in 14 | all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | DEALINGS IN THE SOFTWARE. 23 | 24 | # Third-Party Licenses 25 | 26 | This project uses third-party open-source software (OSS). 27 | Below is a list of these components and their respective licenses. 
28 | 29 | ================================================== 30 | 31 | ## [nanoGPT] 32 | [https://github.com/karpathy/nanoGPT](https://github.com/karpathy/nanoGPT) 33 | 34 | MIT License 35 | [https://github.com/karpathy/nanoGPT/blob/master/LICENSE](https://github.com/karpathy/nanoGPT/blob/master/LICENSE) 36 | 37 | Copyright (c) 2022 Andrej Karpathy 38 | 39 | Permission is hereby granted, free of charge, to any person obtaining a copy 40 | of this software and associated documentation files (the "Software"), to deal 41 | in the Software without restriction, including without limitation the rights 42 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 43 | copies of the Software, and to permit persons to whom the Software is 44 | furnished to do so, subject to the following conditions: 45 | 46 | The above copyright notice and this permission notice shall be included in all 47 | copies or substantial portions of the Software. 48 | 49 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 50 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 51 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 52 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 53 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 54 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 55 | SOFTWARE. 56 | 57 | -------------------------------------------------------------------------------- /configurator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # MIT license 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a 5 | # copy of this software and associated documentation files (the "Software"), 6 | # to deal in the Software without restriction, including without limitation 7 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | # and/or sell copies of the Software, and to permit persons to whom the 9 | # Software is furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in 12 | # all copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | # DEALINGS IN THE SOFTWARE. 21 | 22 | 23 | # The text below is the original header from the nanoGPT library 24 | """ 25 | Poor Man's Configurator. Probably a terrible idea. Example usage: 26 | $ python train.py config/override_file.py --batch_size=32 27 | this will first run config/override_file.py, then override batch_size to 32 28 | 29 | The code in this file will be run as follows from e.g. 
train.py: 30 | >>> exec(open('configurator.py').read()) 31 | 32 | So it's not a Python module, it's just shuttling this code away from train.py 33 | The code in this script then overrides the globals() 34 | 35 | I know people are not going to love this, I just really dislike configuration 36 | complexity and having to prepend config. to every single variable. If someone 37 | comes up with a better simple Python solution I am all ears. 38 | """ 39 | 40 | import sys 41 | from ast import literal_eval 42 | 43 | for arg in sys.argv[1:]: 44 | if '=' not in arg: 45 | # assume it's the name of a config file 46 | assert not arg.startswith('--') 47 | config_file = arg 48 | print(f"Overriding config with {config_file}:") 49 | with open(config_file) as f: 50 | print(f.read()) 51 | exec(open(config_file).read()) 52 | else: 53 | # assume it's a --key=value argument 54 | assert arg.startswith('--') 55 | key, val = arg.split('=') 56 | key = key[2:] 57 | if key in globals(): 58 | try: 59 | # attempt to eval it it (e.g. if bool, number, or etc) 60 | attempt = literal_eval(val) 61 | except (SyntaxError, ValueError): 62 | # if that goes wrong, just use the string 63 | attempt = val 64 | # ensure the types match ok 65 | assert type(attempt) == type(globals()[key]) 66 | # cross fingers 67 | print(f"Overriding: {key} = {attempt}") 68 | globals()[key] = attempt 69 | else: 70 | raise ValueError(f"Unknown config key: {key}") 71 | -------------------------------------------------------------------------------- /launcher.sh: -------------------------------------------------------------------------------- 1 | # Copyright(c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # MIT License 3 | # [https://opensource.org/license/mit](https://opensource.org/license/mit) 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a 6 | # copy of this software and associated documentation files (the "Software"), 7 | # to deal in the Software without restriction, including without limitation 8 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 9 | # and/or sell copies of the Software, and to permit persons to whom the 10 | # Software is furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in 13 | # all copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 18 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 21 | # DEALINGS IN THE SOFTWARE. 22 | 23 | 24 | 25 | problem_name="GPT_1kctx_10k_lr30e-4" 26 | mycommand="" 27 | runtype="scratch" # the first run from scratch 28 | #runtype="resume" # uncomment to resume from the last checkpoint when needed 29 | 30 | # Notes: 31 | # block_size = sequence/context length 32 | # total batch size = gradient_accumulation_steps * batch_size 33 | # if there is a limit on job duration, you will need to implement scratch/resume logic (also check time_limit_seconds and max_iters_per_launch) 34 | # you can adjust max_iters_per_launch to stop training after the specified number of local (within the job) training steps. 
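# an illustrative note on the arithmetic: with batch_size=8, gradient_accumulation_steps=64 and block_size=1024,
#   one optimizer step consumes 8 * 64 * 1024 = 524,288 tokens in total; train.py divides
#   gradient_accumulation_steps by the world size, so this total is independent of the GPU count (as long as it divides 64).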
35 | # the settings for gradient_accumulation_steps and batch_size are configured for running on 8 nodes (64 GPUs) in parallel. 36 | 37 | if [ "$problem_name" = "GPT_1kctx_10k_lr30e-4" ]; then 38 | mycommand=" --init_from='$runtype' --use_nGPT=0 --learning_rate=30e-4 --weight_decay=0.1 --warmup_iters=2000 --n_layer=24 --n_head=16 --n_embd=1024 --block_size=1024 --compile=False --batch_size=8 --gradient_accumulation_steps=64 --eval_iters=1000 --max_iters=10000 --lr_decay_iters=10000 --time_limit_seconds=103700 --min_lr=0.0 --eval_interval=2000 --max_iters_per_launch=14000" 39 | fi 40 | 41 | if [ "$problem_name" = "nGPT_1kctx_10k_lr30e-4" ]; then 42 | mycommand=" --init_from='$runtype' --use_nGPT=1 --learning_rate=30e-4 --weight_decay=0.0 --warmup_iters=0 --n_layer=24 --n_head=16 --n_embd=1024 --block_size=1024 --compile=False --batch_size=8 --gradient_accumulation_steps=64 --eval_iters=1000 --max_iters=10000 --lr_decay_iters=10000 --time_limit_seconds=103700 --min_lr=0.0 --eval_interval=2000 --max_iters_per_launch=14000" 43 | fi 44 | 45 | if [ "$problem_name" = "GPT_4kctx_10k_lr30e-4" ]; then 46 | mycommand=" --init_from='$runtype' --use_nGPT=0 --learning_rate=30e-4 --weight_decay=0.1 --warmup_iters=2000 --n_layer=24 --n_head=16 --n_embd=1024 --block_size=4096 --compile=False --batch_size=2 --gradient_accumulation_steps=256 --eval_iters=1000 --max_iters=10000 --lr_decay_iters=10000 --time_limit_seconds=103700 --min_lr=0.0 --eval_interval=2000 --max_iters_per_launch=18000" 47 | fi 48 | 49 | if [ "$problem_name" = "nGPT_4kctx_10k_lr30e-4" ]; then 50 | mycommand=" --init_from='$runtype' --use_nGPT=1 --learning_rate=30e-4 --weight_decay=0.0 --warmup_iters=0 --n_layer=24 --n_head=16 --n_embd=1024 --block_size=4096 --compile=False --batch_size=2 --gradient_accumulation_steps=256 --eval_iters=1000 --max_iters=10000 --lr_decay_iters=10000 --time_limit_seconds=103700 --min_lr=0.0 --eval_interval=2000 --max_iters_per_launch=18000" 51 | fi 52 | 53 | if [ "$mycommand" != "" ]; then 54 | torchrun --nnodes 1 --nproc_per_node 8 --rdzv_endpoint=localhost:29501 train.py $mycommand 55 | fi 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Copyright(c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | 3 | MIT License 4 | [https://opensource.org/license/mit](https://opensource.org/license/mit) 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a 7 | copy of this software and associated documentation files (the "Software"), 8 | to deal in the Software without restriction, including without limitation 9 | the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | and/or sell copies of the Software, and to permit persons to whom the 11 | Software is furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in 14 | all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL 19 | THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | DEALINGS IN THE SOFTWARE. 23 | 24 | 25 | Please take a moment to read this text in full. It may save you time. 26 | 27 | # **nGPT: Normalized Transformer with Representation Learning on the Hypersphere** 28 | 29 | **Authors**: Ilya Loshchilov, Cheng-Ping Hsieh, Simeng Sun, and Boris Ginsburg 30 | **Paper**: [arXiv:2410.01131](https://arxiv.org/abs/2410.01131) 31 | 32 | --- 33 | 34 | This repository provides code for **nGPT**, which builds on **nanoGPT** by Andrej Karpathy. Familiarity with the nanoGPT codebase is **required**, since many common setup questions are already answered there; with that foundation, you’ll find using and understanding nGPT much easier. 35 | 36 | ## **Project Overview** 37 | 38 | The main difference in this codebase lies in the Transformer models: 39 | 40 | 1. **Modifications**: 41 | - `model.py` includes both the **original** and the **normalized Transformer** models. 42 | - `train.py` contains the **normalization procedure** used during training. 43 | - The architecture follows the paper's specifications. The vocabulary size is different, which changes the scale of the loss values. 44 | 45 | 2. **Dependencies**: 46 | - **nanoGPT**: To generate the data folder with OpenWebText, see the [nanoGPT repository](https://github.com/karpathy/nanoGPT). 47 | - **FlashAttention**: FlashAttention from [Dao-AILab](https://github.com/Dao-AILab/flash-attention) (BSD 3-Clause License) is used, though PyTorch’s default attention can be substituted if preferred. 48 | 49 | ## **Getting Started** 50 | 51 | ### **Running the Code** 52 | 53 | To start the training process with the predefined hyperparameters, execute `launcher.sh`. 54 | 55 | ### **Experiment Replication** 56 | 57 | The paper demonstrates a speedup of nGPT w.r.t. the baseline GPT in three settings: 1k, 4k and 8k context length. For 1k context length, **the expected speedup is a factor of 4 at about 200k iteration steps** of training (i.e., after 200k steps the validation loss is on par with that of the baseline GPT after 800k steps). **For shorter training runs, the expected speedup is smaller**. For 4k context length, the expected speedup is a factor of 10. In general, the longer the training, the greater the speedup. 58 | 59 | When reproducing these experiments using the nanoGPT codebase, we introduced a few modifications to reflect the experimental setup used in the paper. First, we use RoPE for the positional embedding instead of the absolute positional embedding. Second, we use SwiGLU instead of GELU, which was the default in nanoGPT. Third, the original nanoGPT code can perform operations in lower-precision arithmetic (e.g., bfloat16), but the parameters themselves are stored in float32. To reflect the experimental setup of the paper, where the matrix parameters are in bfloat16, we also set bfloat16 as the dtype of the network parameters (all except embeddings). Apparently, the change from float32 to bfloat16 only moderately affects nGPT but greatly degrades the performance of the baseline GPT. In fact, the speedups we observe in this reimplementation are greater than those of our internal experiments. This suggests that the treatment of precision-critical operations in this code is rather suboptimal and that this affects the baseline GPT much more. Thus, **the speedup factors demonstrated here are greater than they normally should be and greater than those reported in the paper**. One possible explanation is that nGPT is less sensitive to low-precision arithmetic/storage than the baseline GPT; as a result, one could see a smaller speedup when using float32 precision and a greater speedup when using bfloat16 or even lower precision. However, it could also be that something is wrong in this implementation. We would love to iterate on this code to make it a better approximation of our internal experiments and, overall, to attain a reproducible experimental setup.
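As a sanity check before launching the full experiments, a scaled-down single-GPU run can be assembled from options that `train.py` already exposes (the flag values below are illustrative and are not one of the paper's configurations); it assumes the OpenWebText data folder and FlashAttention are already set up:

```bash
# illustrative smoke test; the actual experiments are launched via launcher.sh
python train.py --use_nGPT=1 --n_layer=4 --n_head=16 --n_embd=1024 --block_size=1024 \
    --batch_size=4 --gradient_accumulation_steps=1 \
    --max_iters=100 --lr_decay_iters=100 --eval_interval=50 --eval_iters=20 --compile=False
```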
60 | 61 | --- 62 | 63 | ![Hyperparameters](./hyperparams.png) 64 | 65 | --- 66 | 67 | ### **Repository Goals** 68 | 69 | Like nanoGPT, it might be beneficial to keep this code stable (fixing only bugs) so that it can serve as a consistent reference implementation. 70 | 71 | This implementation is not optimized for memory or compute performance: the main goal is to **_illustrate_** how nGPT works, not to provide production-ready code. The paper suggests that nGPT can be simplified in various ways, sometimes without any loss in performance. 72 | 73 | **Special Thanks**: Many thanks to Andrej Karpathy for creating the nanoGPT library, which serves as a foundational component of this project. 74 | 75 | --- 76 | 77 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # MIT license 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a 5 | # copy of this software and associated documentation files (the "Software"), 6 | # to deal in the Software without restriction, including without limitation 7 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | # and/or sell copies of the Software, and to permit persons to whom the 9 | # Software is furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in 12 | # all copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | # DEALINGS IN THE SOFTWARE. 21 | 22 | 23 | # The text below is the original header from the nanoGPT library 24 | """ 25 | Full definition of a GPT Language Model, all of it in this single file.
26 | References: 27 | 1) the official GPT-2 TensorFlow implementation released by OpenAI: 28 | https://github.com/openai/gpt-2/blob/master/src/model.py 29 | 2) huggingface/transformers PyTorch implementation: 30 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py 31 | """ 32 | 33 | import math 34 | import inspect 35 | from dataclasses import dataclass 36 | 37 | import torch 38 | import torch.nn as nn 39 | from torch.nn import functional as F 40 | import numpy as np 41 | from flash_attn import flash_attn_qkvpacked_func, flash_attn_func 42 | 43 | 44 | def apply_rotary_position_embeddings(sinusoidal_pos, q, k): 45 | # Split the sinusoidal_pos into sin and cos parts 46 | sin, cos = sinusoidal_pos.chunk(2, dim=-1) 47 | # Apply the rotary embeddings to the query and key 48 | q_rot = torch.stack((-q[..., 1::2], q[..., ::2]), dim=-1) 49 | k_rot = torch.stack((-k[..., 1::2], k[..., ::2]), dim=-1) 50 | q_rot = torch.reshape(q_rot, q.shape[:-1] + (q.shape[-1]//2, 2)) * torch.stack((cos, sin), dim=-1) 51 | k_rot = torch.reshape(k_rot, k.shape[:-1] + (k.shape[-1]//2, 2)) * torch.stack((cos, sin), dim=-1) 52 | q_rot = torch.reshape(q_rot, q.shape) 53 | k_rot = torch.reshape(k_rot, k.shape) 54 | return q_rot, k_rot 55 | 56 | def get_sinusoidal_embeddings( n_positions, dim): 57 | """Generate sinusoidal positional embeddings.""" 58 | position = torch.arange(n_positions, dtype=torch.float).unsqueeze(1) 59 | div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)) 60 | sinusoidal_emb = torch.zeros((n_positions, dim)) 61 | sinusoidal_emb[:, 0::2] = torch.sin(position * div_term) 62 | sinusoidal_emb[:, 1::2] = torch.cos(position * div_term) 63 | return sinusoidal_emb 64 | 65 | 66 | class Block(nn.Module): 67 | 68 | def __init__(self, config, iblock): 69 | super().__init__() 70 | self.config = config 71 | 72 | self.key = nn.Linear(config.n_embd, config.n_embd, bias=config.bias, dtype=torch.bfloat16) 73 | self.query = nn.Linear(config.n_embd, config.n_embd, bias=config.bias, dtype=torch.bfloat16) 74 | self.value = nn.Linear(config.n_embd, config.n_embd, bias=config.bias, dtype=torch.bfloat16) 75 | self.att_c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias, dtype=torch.bfloat16) 76 | 77 | self.c_fc = nn.Linear(config.n_embd, 2 * 4 * config.n_embd, bias=config.bias, dtype=torch.bfloat16) 78 | self.silu = nn.SiLU() 79 | self.mlp_c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias, dtype=torch.bfloat16) 80 | 81 | if (config.use_nGPT == 0): 82 | self.rmsnorm_att = RMSNorm(config.n_embd) 83 | self.rmsnorm_mlp = RMSNorm(config.n_embd) 84 | 85 | if (config.use_nGPT == 1): 86 | self.attn_alpha_init_value = 0.05 87 | self.attn_alpha_init_scaling = config.base_scale 88 | self.attn_alpha = torch.nn.Parameter(self.attn_alpha_init_scaling*torch.ones(self.config.n_embd, dtype=torch.float32)) 89 | 90 | self.mlp_alpha_init_value = 0.05 91 | self.mlp_alpha_init_scaling = config.base_scale 92 | self.mlp_alpha = torch.nn.Parameter(self.mlp_alpha_init_scaling*torch.ones(self.config.n_embd, dtype=torch.float32)) 93 | 94 | self.sqk_init_value = 1.0 95 | self.sqk_init_scaling = config.base_scale 96 | self.sqk = torch.nn.Parameter(self.sqk_init_scaling*torch.ones(self.config.n_embd, dtype=torch.float32)) 97 | 98 | self.suv_init_value = 1.0 99 | self.suv_init_scaling = 1.0 100 | self.suv = torch.nn.Parameter(self.suv_init_scaling*torch.ones(2 * 4 * config.n_embd, dtype=torch.float32)) 101 | 102 | 103 | def justnorm(self, x): 104 
| #return F.normalize(x, p=2, dim=-1) 105 | res = x / x.norm(p=2, dim=-1, keepdim=True) 106 | return res 107 | 108 | def forward(self, h): 109 | B, T, C = h.size() 110 | 111 | hin = h 112 | if (self.config.use_nGPT == 0): 113 | hin = self.rmsnorm_att(h) 114 | 115 | q = self.query(hin) 116 | k = self.key(hin) 117 | v = self.value(hin) 118 | 119 | q = q.view(B, T, self.config.n_head, self.config.n_embd // self.config.n_head) 120 | k = k.view(B, T, self.config.n_head, self.config.n_embd // self.config.n_head) 121 | v = v.view(B, T, self.config.n_head, self.config.n_embd // self.config.n_head) 122 | 123 | sinusoidal_pos = get_sinusoidal_embeddings(T, self.config.n_embd // self.config.n_head).to(device=q.device) 124 | q, k = apply_rotary_position_embeddings(sinusoidal_pos, q.transpose(1, 2), k.transpose(1, 2)) 125 | q = q.transpose(2, 1) 126 | k = k.transpose(2, 1) 127 | 128 | if (self.config.use_nGPT == 1): 129 | sqk = (self.sqk * (self.sqk_init_value/self.sqk_init_scaling)).view(1, 1, self.config.n_head, self.config.n_embd // self.config.n_head) 130 | q = sqk * self.justnorm(q) 131 | k = sqk * self.justnorm(k) 132 | 133 | sqrt_head_dim = (self.config.n_embd / self.config.n_head) ** 0.5 134 | if (self.config.use_nGPT == 0): softmax_scale = 1.0 / sqrt_head_dim 135 | if (self.config.use_nGPT == 1): softmax_scale = sqrt_head_dim 136 | y = flash_attn_func(q.to(dtype=torch.bfloat16), k.to(dtype=torch.bfloat16), v.to(dtype=torch.bfloat16), dropout_p=0.0, softmax_scale=softmax_scale, causal=True, window_size=(-1, -1), alibi_slopes=None, deterministic=True) 137 | y = y.to(dtype=q.dtype) 138 | y = y.contiguous().view(B, T, self.config.n_embd) 139 | 140 | h_att = self.att_c_proj(y) 141 | 142 | if (self.config.use_nGPT == 0): 143 | h = h + h_att 144 | if (self.config.use_nGPT == 1): 145 | lr = self.attn_alpha * (self.attn_alpha_init_value / self.attn_alpha_init_scaling) 146 | lr = torch.abs(lr) 147 | 148 | A_norm = self.justnorm(h) # normally, normalization is not needed 149 | B_norm = self.justnorm(h_att) 150 | 151 | #res = (1.0 - lr) * A_norm + lr * B_norm 152 | res = A_norm + lr * (B_norm - A_norm) 153 | h = self.justnorm(res) 154 | 155 | hin = h 156 | if (self.config.use_nGPT == 0): 157 | hin = self.rmsnorm_mlp(h) 158 | uv = self.c_fc(hin) 159 | if (self.config.use_nGPT == 1): 160 | suv = (self.suv * ((self.suv_init_value/self.suv_init_scaling) * (self.config.n_embd ** 0.5))) 161 | uv = suv * uv 162 | u, v = torch.chunk(uv, 2, dim=-1) 163 | x_mlp = u * self.silu(v) 164 | h_mlp = self.mlp_c_proj(x_mlp) 165 | 166 | if (self.config.use_nGPT == 0): 167 | h = h + h_mlp 168 | if (self.config.use_nGPT == 1): 169 | lr = self.mlp_alpha * (self.mlp_alpha_init_value / self.mlp_alpha_init_scaling) 170 | lr = torch.abs(lr) 171 | 172 | A_norm = self.justnorm(h) # normally, normalization is not needed 173 | B_norm = self.justnorm(h_mlp) 174 | 175 | #res = (1.0 - lr) * A_norm + lr * B_norm 176 | res = A_norm + lr * (B_norm - A_norm) 177 | h = self.justnorm(res) 178 | 179 | return h 180 | 181 | @dataclass 182 | class GPTConfig: 183 | block_size: int = 1024 184 | vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency 185 | n_layer: int = 12 186 | n_head: int = 12 187 | n_embd: int = 1024 188 | base_scale: float = 1.0 / (1024.0 ** 0.5) # 1 / sqrt(n_embd) 189 | use_nGPT: int = 0 190 | dropout: float = 0.0 191 | bias: bool = False 192 | 193 | class RMSNorm(torch.nn.Module): 194 | def __init__(self, embdim: int, eps: float = 1e-6) -> None: 195 | super().__init__() 196 | 
self.weight = torch.nn.Parameter(torch.ones(embdim)) 197 | self.eps = eps 198 | 199 | def forward(self, x: torch.Tensor) -> torch.Tensor: 200 | dtype = x.dtype 201 | x = x.float() 202 | norm = torch.mean(x * x, dim=-1, keepdim=True) 203 | xnorm = x * torch.rsqrt(norm + self.eps) 204 | xnorm = xnorm.to(dtype=dtype) 205 | return xnorm * self.weight 206 | 207 | 208 | class GPT(nn.Module): 209 | 210 | def __init__(self, config): 211 | super().__init__() 212 | assert config.vocab_size is not None 213 | assert config.block_size is not None 214 | self.config = config 215 | 216 | self.transformer = nn.ModuleDict(dict( 217 | wte = nn.Embedding(config.vocab_size, config.n_embd), 218 | drop = nn.Dropout(config.dropout), 219 | h = nn.ModuleList([Block(config, il) for il in range(config.n_layer)]) 220 | )) 221 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 222 | # with weight tying when using torch.compile() some warnings get generated: 223 | # "UserWarning: functional_call was passed multiple values for tied weights. 224 | # This behavior is deprecated and will be an error in future versions" 225 | # not 100% sure what this is, so far seems to be harmless. TODO investigate 226 | # *we don't use it becuase in the nGPT paper there was no weight tying of weights* 227 | # self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying 228 | 229 | # init all weights 230 | self.apply(self._init_weights) 231 | # apply special scaled init to the residual projections, per GPT-2 paper 232 | for pn, p in self.named_parameters(): 233 | if pn.endswith('c_proj.weight'): 234 | torch.nn.init.normal_(p, mean=0.0, std=config.base_scale/math.sqrt(2 * config.n_layer)) 235 | # report number of parameters 236 | print("number of parameters: %.2fM" % (self.get_num_params()/1e6,)) 237 | 238 | if (config.use_nGPT == 1): 239 | self.sz_init_value = 1.00 240 | self.sz_init_scaling = config.base_scale 241 | self.sz = torch.nn.Parameter(self.sz_init_scaling*torch.ones(config.vocab_size, dtype=torch.float32)) 242 | 243 | if (config.use_nGPT == 0): 244 | self.rmsnorm_f = RMSNorm(config.n_embd) 245 | 246 | 247 | def get_num_params(self, non_embedding=True): 248 | """ 249 | Return the number of parameters in the model. 250 | For non-embedding count (default), the position embeddings get subtracted. 251 | The token embeddings would too, except due to the parameter sharing these 252 | params are actually used as weights in the final layer, so we include them. 
253 | """ 254 | n_params = sum(p.numel() for p in self.parameters()) 255 | #if non_embedding: 256 | # n_params -= self.transformer.wpe.weight.numel() 257 | return n_params 258 | 259 | def _init_weights(self, module): 260 | if isinstance(module, nn.Linear): 261 | torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.base_scale) 262 | if module.bias is not None: 263 | torch.nn.init.zeros_(module.bias) 264 | elif isinstance(module, nn.Embedding): 265 | torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.base_scale) 266 | 267 | def forward(self, idx, targets=None): 268 | device = idx.device 269 | b, t = idx.size() 270 | #assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 271 | 272 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 273 | 274 | x = tok_emb 275 | for block in self.transformer.h: 276 | x = block(x) 277 | 278 | if (self.config.use_nGPT == 0): 279 | x = self.rmsnorm_f(x) 280 | 281 | if targets is not None: 282 | logits = self.lm_head(x) 283 | if (self.config.use_nGPT == 1): 284 | sz = self.sz * (self.sz_init_value/self.sz_init_scaling) 285 | logits = sz * logits 286 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 287 | else: 288 | # inference-time mini-optimization: only forward the lm_head on the very last position 289 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim 290 | if (self.config.use_nGPT == 1): 291 | sz = self.sz * (self.sz_init_value/self.sz_init_scaling) 292 | logits = sz * logits 293 | loss = None 294 | 295 | return logits, loss 296 | 297 | 298 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): 299 | # start with all of the candidate parameters 300 | param_dict = {pn: p for pn, p in self.named_parameters()} 301 | # filter out those that do not require grad 302 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 303 | # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. 304 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 305 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 306 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 307 | optim_groups = [ 308 | {'params': decay_params, 'weight_decay': weight_decay}, 309 | {'params': nodecay_params, 'weight_decay': 0.0} 310 | ] 311 | num_decay_params = sum(p.numel() for p in decay_params) 312 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 313 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") 314 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") 315 | # Create AdamW optimizer and use the fused version if it is available 316 | fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters 317 | use_fused = False#fused_available and device_type == 'cuda' 318 | extra_args = dict(fused=True) if use_fused else dict() 319 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) 320 | print(f"using fused AdamW: {use_fused}") 321 | return optimizer 322 | 323 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright(c) 2024 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. 2 | # MIT License 3 | # [https://opensource.org/license/mit](https://opensource.org/license/mit) 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a 6 | # copy of this software and associated documentation files (the "Software"), 7 | # to deal in the Software without restriction, including without limitation 8 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 9 | # and/or sell copies of the Software, and to permit persons to whom the 10 | # Software is furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in 13 | # all copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 18 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 21 | # DEALINGS IN THE SOFTWARE. 22 | 23 | 24 | # The text below is the original header from the nanoGPT library 25 | """ 26 | This training script can be run both on a single gpu in debug mode, 27 | and also in a larger training run with distributed data parallel (ddp). 28 | 29 | To run on a single GPU, example: 30 | $ python train.py --batch_size=32 --compile=False 31 | 32 | To run with DDP on 4 gpus on 1 node, example: 33 | $ torchrun --standalone --nproc_per_node=4 train.py 34 | 35 | To run with DDP on 4 gpus across 2 nodes, example: 36 | - Run on the first (master) node with example IP 123.456.123.456: 37 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py 38 | - Run on the worker node: 39 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py 40 | (If your cluster does not have Infiniband interconnect prepend NCCL_IB_DISABLE=1) 41 | """ 42 | 43 | import os 44 | import time 45 | import math 46 | import pickle 47 | import sys 48 | from contextlib import nullcontext 49 | from typing import Callable, Optional, Tuple, List 50 | 51 | import numpy as np 52 | import torch 53 | from torch.nn.parallel import DistributedDataParallel as DDP 54 | import torch.distributed as dist 55 | from model import GPTConfig, GPT 56 | from torch.nn import functional as F 57 | from datetime import timedelta 58 | 59 | # ----------------------------------------------------------------------------- 60 | # I/O 61 | 62 | eval_interval = 1000 63 | log_interval = 10 64 | eval_iters = 200 65 | eval_only = False # if True, script exits right after the first eval 66 | always_save_checkpoint = True # if True, always save a checkpoint after each eval 67 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*' 68 | # wandb logging 69 | wandb_log = False # disabled by default 70 | wandb_project = 'owt' 71 | wandb_run_name = 'gpt2' # 'run' + str(time.time()) 72 | # data 73 | dataset = 'openwebtext' 74 | gradient_accumulation_steps = 64 # used to simulate larger batch sizes 75 | batch_size = 8 # if gradient_accumulation_steps > 1, this is the micro-batch size 76 | # model 77 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 78 | bias = False # do we use bias inside LayerNorm and 
Linear layers? 79 | # adamw optimizer 80 | max_iters = 600000 # total number of training iterations 81 | beta1 = 0.9 82 | beta2 = 0.95 83 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 84 | # learning rate decay settings 85 | decay_lr = True # whether to decay the learning rate 86 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla 87 | # DDP settings 88 | backend = 'nccl' # 'nccl', 'gloo', etc. 89 | # system 90 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 91 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 92 | compile = False # use PyTorch 2.0 to compile the model to be faster 93 | # 94 | time_limit_seconds = 1000000000 # stop after x seconds 95 | max_iters_per_launch = 1000000000 # stop after x steps of the current 96 | 97 | use_nGPT = 1 98 | learning_rate = 15e-4 99 | 100 | # model size and seqlen 101 | if (1): 102 | n_layer = 12 103 | n_head = 16 104 | n_embd = 1024 105 | block_size = 1024 # = context/sequence length 106 | 107 | if (use_nGPT == 0): 108 | min_lr = 0.0 109 | weight_decay = 0.1 110 | warmup_iters = 2000 111 | if (use_nGPT == 1): 112 | min_lr = 0.0 113 | weight_decay = 0.0 114 | warmup_iters = 0 115 | 116 | tlaunch = time.time() 117 | print("Current Directory:", os.getcwd()) 118 | # the input configurations will overwrite all configs given above! 119 | # ----------------------------------------------------------------------------- 120 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 121 | exec(open('configurator.py').read()) # overrides from command line or config file 122 | config = {k: globals()[k] for k in config_keys} # will be useful for logging 123 | # ----------------------------------------------------------------------------- 124 | 125 | if (use_nGPT == 0): 126 | base_scale = 0.02 # can be interpreted as init_std 127 | if (use_nGPT == 1): 128 | base_scale = 1.0 / n_embd ** 0.5 129 | 130 | # various inits, derived attributes, I/O setup 131 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? 132 | if ddp: 133 | #init_process_group(backend=backend) 134 | dist.init_process_group(backend=backend, 135 | timeout=timedelta(milliseconds=20*60000) # Setting a 20-minute timeout 136 | ) 137 | ddp_rank = int(os.environ['RANK']) 138 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 139 | ddp_world_size = int(os.environ['WORLD_SIZE']) 140 | device = f'cuda:{ddp_local_rank}' 141 | torch.cuda.set_device(device) 142 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 
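    # note (illustrative): with the launcher.sh settings (gradient_accumulation_steps=64) and 8 GPUs per node,
    # the division below leaves 64 // 8 = 8 micro-batches per rank for each optimizer step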
143 | seed_offset = ddp_rank # each process gets a different seed 144 | # world_size number of processes will be training simultaneously, so we can scale 145 | # down the desired gradient accumulation iterations per process proportionally 146 | assert gradient_accumulation_steps % ddp_world_size == 0 147 | gradient_accumulation_steps //= ddp_world_size 148 | dist.barrier() 149 | else: 150 | # if not ddp, we are running on a single gpu, and one process 151 | master_process = True 152 | seed_offset = 0 153 | ddp_world_size = 1 154 | tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size 155 | print(f"tokens per iteration will be: {tokens_per_iter:,}") 156 | 157 | 158 | out_dir='./' 159 | if master_process: 160 | if not os.path.exists(out_dir): 161 | os.makedirs(out_dir) 162 | 163 | 164 | local_seed = seed_offset 165 | np.random.seed(local_seed) 166 | torch.manual_seed(local_seed) 167 | torch.cuda.manual_seed(local_seed) 168 | 169 | 170 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 171 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 172 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 173 | # note: float16 data type will automatically use a GradScaler 174 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 175 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 176 | 177 | # poor man's data loader 178 | tdataloading_begin = time.time() 179 | if os.path.exists('./../../data'): 180 | data_dir = os.path.join('./../../data', dataset) 181 | else: 182 | data_dir = os.path.join('data', dataset) 183 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 184 | val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') 185 | 186 | def get_batch(split): 187 | # We recreate np.memmap every batch to avoid a memory leak, as per 188 | # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122 189 | if split == 'train': 190 | data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 191 | else: 192 | data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') 193 | ix = torch.randint(len(data) - block_size, (batch_size,)) 194 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 195 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) 196 | if device_type == 'cuda': 197 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 198 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 199 | else: 200 | x, y = x.to(device), y.to(device) 201 | return x, y 202 | 203 | # init these up here, can override if init_from='resume' (i.e. 
from a checkpoint) 204 | iter_num = 0 205 | 206 | 207 | # attempt to derive vocab_size from the dataset 208 | meta_path = os.path.join(data_dir, 'meta.pkl') 209 | meta_vocab_size = None 210 | if os.path.exists(meta_path): 211 | with open(meta_path, 'rb') as f: 212 | meta = pickle.load(f) 213 | meta_vocab_size = meta['vocab_size'] 214 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})") 215 | print("Data loading time: %f sec" % (time.time()-tdataloading_begin)) 216 | 217 | 218 | # model init 219 | tmodelinit_begin = time.time() 220 | model_args = dict(use_nGPT=use_nGPT, n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, base_scale=base_scale, 221 | bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line 222 | if init_from == 'scratch': 223 | # init a new model from scratch 224 | print("Initializing a new model from scratch") 225 | # determine the vocab size we'll use for from-scratch training 226 | if meta_vocab_size is None: 227 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") 228 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304 229 | gptconf = GPTConfig(**model_args) 230 | model = GPT(gptconf) 231 | elif init_from == 'resume': 232 | print(f"Resuming training from {out_dir}") 233 | # resume training from a checkpoint. 234 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 235 | checkpoint = torch.load(ckpt_path, map_location=device) 236 | checkpoint_model_args = checkpoint['model_args'] 237 | # force these config attributes to be equal otherwise we can't even resume training 238 | # the rest of the attributes (e.g. dropout) can stay as desired from command line 239 | for k in ['use_nGPT', 'base_scale', 'n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 240 | model_args[k] = checkpoint_model_args[k] 241 | # create the model 242 | gptconf = GPTConfig(**model_args) 243 | model = GPT(gptconf) 244 | state_dict = checkpoint['model'] 245 | # fix the keys of the state dictionary :( 246 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 247 | unwanted_prefix = '_orig_mod.' 248 | for k,v in list(state_dict.items()): 249 | if k.startswith(unwanted_prefix): 250 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 251 | model.load_state_dict(state_dict) 252 | iter_num = checkpoint['iter_num'] 253 | # crop down the model block size if desired, using model surgery 254 | if block_size < model.config.block_size: 255 | model.crop_block_size(block_size) 256 | model_args['block_size'] = block_size # so that the checkpoint will have the right value 257 | model.to(device) 258 | print("Model initialization/loading time: %f sec" % (time.time()-tmodelinit_begin)) 259 | 260 | # optimizer 261 | optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type) 262 | if init_from == 'resume': 263 | optimizer.load_state_dict(checkpoint['optimizer']) 264 | checkpoint = None # free up memory 265 | 266 | # compile the model 267 | if compile: 268 | print("compiling the model... 
(takes a ~minute)") 269 | unoptimized_model = model 270 | model = torch.compile(model) # requires PyTorch 2.0 271 | 272 | # wrap model into DDP container 273 | if ddp: 274 | model = DDP(model, device_ids=[ddp_local_rank]) 275 | 276 | # helps estimate an arbitrarily accurate loss over either split using many batches 277 | @torch.no_grad() 278 | def estimate_loss(): 279 | out = {} 280 | model.eval() 281 | for split in ['train', 'val']: 282 | losses = torch.zeros(eval_iters) 283 | for k in range(eval_iters): 284 | X, Y = get_batch(split) 285 | with ctx: 286 | logits, loss = model(X, Y) 287 | losses[k] = loss.item() 288 | out[split] = losses.mean() 289 | model.train() 290 | return out 291 | 292 | # learning rate decay scheduler (cosine with warmup) 293 | def get_lr(it): 294 | # 1) linear warmup for warmup_iters steps 295 | if it < warmup_iters: 296 | return learning_rate * it / warmup_iters 297 | # 2) if it > lr_decay_iters, return min learning rate 298 | if it > lr_decay_iters: 299 | return min_lr 300 | # 3) in between, use cosine decay down to min learning rate 301 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 302 | assert 0 <= decay_ratio <= 1 303 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 304 | return min_lr + coeff * (learning_rate - min_lr) 305 | 306 | # logging 307 | if wandb_log and master_process: 308 | import wandb 309 | wandb.init(project=wandb_project, name=wandb_run_name, config=config) 310 | 311 | # training loop 312 | #X, Y = get_batch('train') # fetch the very first batch 313 | t0 = time.time() 314 | local_iter_num = 0 # number of iterations in the lifetime of this process 315 | raw_model = model.module if ddp else model # unwrap DDP container if needed 316 | 317 | 318 | if master_process: 319 | print("learning_rate: %f" % (learning_rate)) 320 | print("min_lr: %f" % (min_lr)) 321 | print("max_iters: %f" % (max_iters)) 322 | print("lr_decay_iters: %f" % (lr_decay_iters)) 323 | print("warmup_iters: %f" % (warmup_iters)) 324 | print("batch_size: %f" % (batch_size)) 325 | print("gradient_accumulation_steps: %f" % (gradient_accumulation_steps)) 326 | print("block_size: %f" % (block_size)) 327 | print("weight_decay: %f" % (weight_decay)) 328 | 329 | def get_hparams_str(model): 330 | if (use_nGPT == 0): 331 | return "" 332 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 333 | transformer = model.module.transformer 334 | config = model.module.config 335 | module = model.module 336 | else: 337 | transformer = model.transformer 338 | config = model.config 339 | module = model 340 | 341 | resstr = "%.5f " % torch.mean( module.sz * (module.sz_init_value/module.sz_init_scaling) ) 342 | 343 | for layer_idx in range(0, config.n_layer): 344 | block = transformer["h"][layer_idx] 345 | sqk = block.sqk * (block.sqk_init_value/block.sqk_init_scaling) 346 | attn_alpha = block.attn_alpha * (block.attn_alpha_init_value / block.attn_alpha_init_scaling) 347 | mlp_alpha = block.mlp_alpha * (block.mlp_alpha_init_value / block.mlp_alpha_init_scaling) 348 | suv = block.suv * (block.suv_init_value/block.suv_init_scaling) 349 | 350 | resstr = resstr + "%.5f " % torch.mean( sqk ) 351 | resstr = resstr + "%.5f " % torch.mean( attn_alpha ) 352 | resstr = resstr + "%.5f " % torch.mean( mlp_alpha ) 353 | resstr = resstr + "%.5f " % torch.mean( suv ) 354 | 355 | return resstr 356 | 357 | stat_fname = out_dir + "/stat" 358 | if master_process: 359 | if init_from == 'scratch': 360 | file = open(stat_fname, "w") 361 | resstr = f"{0:.6e} 
{0:.4e} {0.0:.4e} {0.0:.4e} {0.0:.4e} {0.0:.4e} {0.0:.4e} {0.0:.4e} {0.0:.4e} {0.0:.4e} {0.0:.4e} {0:.4e} {0.0:.4e}" 362 | resstr = resstr + get_hparams_str(model) + "\n" 363 | file.write(resstr) 364 | arguments = sys.argv 365 | fname_arg = out_dir + "/args" 366 | with open(fname_arg, 'w') as file_arg: 367 | for arg in arguments: 368 | file_arg.write(arg + '\n') 369 | 370 | if init_from == 'resume': 371 | file = open(stat_fname, "a") 372 | 373 | 374 | time_spent = time.time() - tlaunch 375 | print(f"Time spent: {time_spent} seconds") 376 | starting_iter_num = iter_num 377 | print("starting_iter_num: %d" % iter_num) 378 | 379 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 380 | transformer = model.module.transformer 381 | config = model.module.config 382 | module = model.module 383 | else: 384 | transformer = model.transformer 385 | config = model.config 386 | module = model 387 | 388 | def justnorm(x, idim=-1): 389 | dtype = x.dtype 390 | x = x.float() 391 | res = (x / x.norm(p=2, dim=idim, keepdim=True)).to(dtype=dtype) 392 | return res 393 | 394 | def normalize_matrices(): 395 | transformer.wte.weight.data.copy_(justnorm(transformer.wte.weight.data, 1)) # V, n_embd 396 | module.lm_head.weight.data.copy_(justnorm(module.lm_head.weight.data, 1)) # V, n_embd 397 | 398 | 399 | for layer_idx in range(0, config.n_layer): 400 | block = transformer["h"][layer_idx] 401 | 402 | block.query.weight.data.copy_(justnorm(block.query.weight.data, 1)) # n_proj, n_embd 403 | block.key.weight.data.copy_(justnorm(block.key.weight.data, 1)) # n_proj, n_embd 404 | block.value.weight.data.copy_(justnorm(block.value.weight.data, 1)) # n_proj, n_embd 405 | block.att_c_proj.weight.data.copy_(justnorm(block.att_c_proj.weight.data, 0)) # n_embd, n_proj 406 | 407 | block.c_fc.weight.data.copy_(justnorm(block.c_fc.weight.data, 1)) # n_proj, n_embd 408 | block.mlp_c_proj.weight.data.copy_(justnorm(block.mlp_c_proj.weight.data, 0)) # n_embd, n_proj 409 | 410 | if (use_nGPT == 1): 411 | normalize_matrices() 412 | 413 | while True: 414 | #sys.stdout.flush() 415 | if (local_iter_num > max_iters_per_launch): 416 | break 417 | if (1): 418 | local_seed = 100*iter_num + seed_offset # local_seed should never exceed 2.147e+9 because of np.random.seed, 100 here should be > nworkers 419 | np.random.seed(local_seed) 420 | torch.manual_seed(local_seed) 421 | torch.cuda.manual_seed(local_seed) 422 | #if (iter_num % 10 == 0): # uncomment to make sure different seeds are used 423 | # print("iter: %d seed: %d" % (iter_num, local_seed)) 424 | 425 | # determine and set the learning rate for this iteration 426 | lr = get_lr(iter_num) if decay_lr else learning_rate 427 | for param_group in optimizer.param_groups: 428 | param_group['lr'] = lr 429 | 430 | # evaluate the loss on train/val sets and write checkpoints 431 | if iter_num % eval_interval == 0 and master_process: 432 | rng_state_pytorch = torch.get_rng_state() 433 | rng_state_bytes = rng_state_pytorch.numpy().tobytes() 434 | losses = estimate_loss() 435 | print(f"step {iter_num}: train loss {losses['train']:.6f}, val loss {losses['val']:.6f}") 436 | 437 | if wandb_log: 438 | wandb.log({ 439 | "iter": iter_num, 440 | "train/loss": losses['train'], 441 | "val/loss": losses['val'], 442 | "lr": lr 443 | }) 444 | 445 | if always_save_checkpoint: 446 | if iter_num > starting_iter_num: 447 | tcheckpointsaving_begin = time.time() 448 | checkpoint = { 449 | 'model': raw_model.state_dict(), 450 | 'optimizer': optimizer.state_dict(), 451 | 'model_args': model_args, 452 | 
'iter_num': iter_num, 453 | 'config': config, 454 | 'rng_state_pytorch_bytes': rng_state_bytes, 455 | 'rng_state_numpy': np.random.get_state() 456 | } 457 | print(f"saving checkpoint to {out_dir}") 458 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) 459 | print("Checkpoint saving time: %f sec" % (time.time()-tcheckpointsaving_begin)) 460 | 461 | if iter_num == 0 and eval_only: 462 | break 463 | 464 | # forward backward update, with optional gradient accumulation to simulate larger batch size 465 | # and using the GradScaler if data type is float16 466 | X, Y = get_batch('train') 467 | for micro_step in range(gradient_accumulation_steps): 468 | if ddp: 469 | # in DDP training we only need to sync gradients at the last micro step. 470 | # the official way to do this is with model.no_sync() context manager, but 471 | # I really dislike that this bloats the code and forces us to repeat code 472 | # looking at the source of that context manager, it just toggles this variable 473 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) 474 | with ctx: 475 | logits, loss = model(X, Y) 476 | loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation 477 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 478 | X, Y = get_batch('train') 479 | # backward pass, with gradient scaling if training in fp16 480 | #.scale(loss).backward() 481 | loss.backward() 482 | 483 | if grad_clip != 0.0: 484 | torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 485 | optimizer.step() 486 | # flush the gradients as soon as we can, no need for this memory anymore 487 | optimizer.zero_grad(set_to_none=True) 488 | 489 | # timing and logging 490 | t1 = time.time() 491 | dt = t1 - t0 492 | t0 = t1 493 | if iter_num % log_interval == 0 and master_process: 494 | # get loss as float. note: this is a CPU-GPU sync point 495 | # scale up to undo the division above, approximating the true total loss (exact would have been a sum) 496 | lossf = loss.item() * gradient_accumulation_steps 497 | print(f"iter {iter_num}: loss {lossf:.6f}, time {dt*1000:.2f}ms") 498 | 499 | if (use_nGPT == 1): 500 | normalize_matrices() 501 | 502 | if (iter_num % 100 == 0) and master_process: 503 | print("lr=%f" % lr) 504 | 505 | if master_process: 506 | resstr = f"{iter_num:.6e} {lr:.4e} {losses['train']:.4e} {losses['val']:.4e} {0.0:.4e} {0.0:.4e} {0.0:.4e} {0.0:.4e} {0.0:.4e} {0.0:.4e} {0.0:.4e} {0:.4e} {0.0:.4e} " 507 | resstr = resstr + get_hparams_str(model) + "\n" 508 | 509 | file.write(resstr) 510 | file.flush() 511 | 512 | if iter_num >= max_iters: 513 | finished_fname = out_dir + "/finished" 514 | finished_file = open(finished_fname, "w") 515 | finished_file.write("1") 516 | finished_file.close() 517 | 518 | if (time.time() - tlaunch > time_limit_seconds): 519 | break 520 | 521 | iter_num += 1 522 | local_iter_num += 1 523 | if iter_num > max_iters: 524 | break 525 | time_spent = time.time() - tlaunch 526 | print(f"Time spent: {time_spent} seconds") 527 | if ddp: 528 | dist.barrier() 529 | dist.destroy_process_group() 530 | 531 | --------------------------------------------------------------------------------