├── LICENSE
├── README.md
├── configuration_gpt2.py
├── configuration_ngpt.py
├── data
│   └── openwebtext_llama
│       └── prepare.py
├── images
│   ├── 4k_arceasy.png
│   ├── 4k_average.png
│   ├── 4k_hellaswag.png
│   ├── 4k_lambada.png
│   ├── 4k_loss.png
│   ├── 4k_winogrande.png
│   ├── 4k_wsc273.png
│   ├── arc_easy.png
│   ├── average.png
│   ├── hellaswag.png
│   ├── lambada.png
│   ├── loss.png
│   ├── network_params.png
│   ├── optimizer_params.png
│   ├── wall_clock.png
│   ├── winogrande.png
│   └── wsc273.png
├── modeling_gpt2.py
├── modeling_ngpt.py
├── train_gpt2.py
└── train_ngpt.py
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Joe Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Sponsored by Nous Research Inc., 2025.
2 | 3 | # nGPT 4 | This is an open-source reproduction of NVIDIA's [nGPT](https://arxiv.org/abs/2410.01131) (Normalized Transformer with Representation Learning on the Hypersphere) paper by Loshchilov et al., which claims to reduce "the number of training steps required to achieve the same accuracy by a factor of 4 to 20, depending on the sequence length," compared to a baseline transformer model. 5 | ## Project Overview 6 | This repository provides modeling and training code for a modified GPT-2 model and an nGPT model, along with our results. Both models were pre-trained on [OpenWebText](https://huggingface.co/datasets/Skylion007/openwebtext). We attempt to adhere to the paper's specifications as closely as possible. 7 | 8 | ### Dependencies 9 | 10 | - **Hugging Face [transformers](https://github.com/huggingface/transformers) library**: `modeling_ngpt.py` and `modeling_gpt2.py` extend the `PreTrainedModel` class that Hugging Face provides. 11 | - **[nanoGPT](https://github.com/karpathy/nanoGPT)**: the training and data-preparation code builds on this repository (`train_ngpt.py`, `train_gpt2.py`, `data/openwebtext_llama/prepare.py`). 12 | - **[EleutherAI/lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)**: used for the `hellaswag`, `arc_easy`, `winogrande`, `wsc273`, and `lambada_openai` evals. 13 | 14 | ### Key Modifications 15 | #### Tokenization 16 | 17 | The LLaMA tokenizer (vocab size 32000) is used instead of the GPT-2 tokenizer (vocab size 50257). See `data/openwebtext_llama/prepare.py`. 18 | #### GPT-2 19 | 20 | - SwiGLU activation function 21 | - Rotary Position Embeddings (RoPE) 22 | - No weight tying between the input token embedding layer and the final output logits layer 23 | 24 | #### nGPT 25 | See the paper for detailed explanations, particularly Section 2.6. Briefly, all hidden vectors and all weight vectors along the embedding dimension are normalized to unit norm, so that they lie on the same unit hypersphere.
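A minimal sketch of the two normalization ideas described above, written in plain Python so it runs without PyTorch. It assumes the paper's normalized residual update, h ← Norm(h + α ⊙ (Norm(block(h)) − h)), where α is a learnable per-dimension step size. The helper names (`normalize`, `normalize_rows`, `ngpt_residual_update`) are hypothetical and are not taken from `modeling_ngpt.py`.

```python
import math
from typing import List

Vector = List[float]


def normalize(v: Vector, eps: float = 1e-12) -> Vector:
    """Scale v to unit L2 norm so that it lies on the unit hypersphere."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]  # eps guards against division by zero


def normalize_rows(weight: List[Vector]) -> List[Vector]:
    """Normalize every row of a weight matrix, assuming each row is a vector
    along the embedding dimension (e.g. one row per token embedding)."""
    return [normalize(row) for row in weight]


def ngpt_residual_update(h: Vector, block_out: Vector, alpha: Vector) -> Vector:
    """Hypothetical nGPT-style residual step.

    Instead of the usual h + block_out, move h part of the way toward the
    normalized block output (per-dimension step size alpha), then project
    the result back onto the hypersphere.
    """
    h_block = normalize(block_out)
    stepped = [hi + ai * (bi - hi) for hi, ai, bi in zip(h, alpha, h_block)]
    return normalize(stepped)


if __name__ == "__main__":
    h = normalize([3.0, 4.0])  # [0.6, 0.8]
    out = ngpt_residual_update(h, [2.0, 0.0], [0.5, 0.5])
    print(out, math.sqrt(sum(x * x for x in out)))  # the output has unit norm
```

With `alpha` set to all zeros the update returns `h` unchanged, and with all ones it returns the normalized block output, so a learned `alpha` interpolates between the two. In PyTorch the same projection would typically be written as `torch.nn.functional.normalize(x, dim=-1)`.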
26 | 27 | ### Training 28 | 29 | We trained 0.5B models with 1024 and 4096 context lengths on OpenWebText, using the same hyperparameters as specified in the nGPT paper (shown below). The initial learning rate is 15e-4 for 1024 context length and 30e-4 for 4096 context length. The [model card](https://huggingface.co/NousResearch/ngpt_0.5B_4k_200B) for a model trained on 200B tokens at 4k context length is available on Hugging Face. 30 |
31 | 32 | 33 |
34 | 35 | ### Results 36 | 37 | By visual inspection of the following graphs, we observe speedups of roughly 1.5-2x and 4x at ~200 billion tokens for 1k and 4k context length, respectively. Note that every data point on the following graphs represents a **separate** training run that ran to completion (i.e., each with its own learning rate schedule). 38 | 39 | #### Loss 40 |
41 | 42 | 43 |
44 | 45 | #### Performance on downstream tasks (1k context) 46 | 47 |
48 | 49 | 50 | 51 |
52 |
53 | 54 | 55 | 56 |
57 | 58 | #### Performance on downstream tasks (4k context) 59 | 60 |
61 | 62 | 63 | 64 |
65 |
66 | 67 | 68 | 69 |
70 | 71 | #### Wall-Clock Time for Training on 100B Tokens with 8xH100 72 | 73 | 74 | ### Analysis 75 | 76 | We observe a smaller speedup than the nGPT paper claims. We attribute this mainly to parameter precision: the model parameters in this reproduction are stored in float32, while the original paper stores them in bfloat16. NVIDIA's open-source implementation notes the following: 77 | 78 | > In order to reflect our experimental setup of the paper where parameters of matrices are in bfloat16, we also set bfloat16 as the dtype of network parameters (all except embeddings). Apparently, the change from float32 to bfloat16 only moderately affects nGPT but greatly degrades performance of the baseline GPT. 79 | 80 | We observe the same effect in our reproduction: our nGPT model closely matches the paper's results, but our GPT-2 baseline performs better with float32 parameters. For 0.5B models at 1k context length, the paper claims a 4x speedup with bfloat16 parameters at ~400 billion tokens, whereas our nGPT reproduction achieves roughly a 1.5-2x speedup with float32 parameters at ~200 billion tokens. For 0.5B models at 4k context length, the paper claims a 10x speedup with bfloat16 parameters at ~400 billion tokens, whereas our nGPT reproduction achieves roughly a 4x speedup with float32 parameters at ~200 billion tokens. Moreover, we observe greater speedups for longer training runs.
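To make the precision argument more concrete, here is a small stdlib-only illustration (not code from this repository) of one way that storing parameters in bfloat16 can lose small optimizer updates. It emulates bfloat16 by keeping the upper 16 bits of a float32 with round-to-nearest-even, and it assumes finite inputs within float32 range. A real run would use `torch.bfloat16` tensors and an optimizer such as Adam, so this only illustrates the mechanism and does not reproduce the measured gap.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bfloat16(x: float) -> float:
    """Round to the nearest bfloat16 value (round-to-nearest-even).

    bfloat16 is the upper half of a float32: 1 sign bit, 8 exponent bits and
    7 explicit mantissa bits. Assumes x is finite and within float32 range.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    lsb = (bits >> 16) & 1  # lowest bit that survives truncation
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000  # round, then drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def repeated_update(weight: float, update: float, steps: int, round_fn) -> float:
    """Add `update` to `weight` `steps` times, storing the weight with round_fn."""
    w = round_fn(weight)
    for _ in range(steps):
        w = round_fn(w + update)
    return w


if __name__ == "__main__":
    # A weight of 1.0 receiving 100 updates of 1e-4 each.
    print("float32: ", repeated_update(1.0, 1e-4, 100, to_float32))   # about 1.01
    print("bfloat16:", repeated_update(1.0, 1e-4, 100, to_bfloat16))  # stays 1.0
```

Just above 1.0, adjacent bfloat16 values are 2**-7 = 0.0078125 apart, so any update smaller than half of that spacing is rounded away, while float32 (spacing 2**-23 near 1.0) keeps it. This is one plausible mechanism behind the degradation described in the quoted note; how much it matters in practice depends on weight magnitudes, the optimizer, and whether a higher-precision master copy of the weights is kept.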
81 | -------------------------------------------------------------------------------- /configuration_gpt2.py: -------------------------------------------------------------------------------- 1 | """OpenAI GPT-2 configuration""" 2 | 3 | from collections import OrderedDict 4 | from typing import Any, List, Mapping, Optional 5 | 6 | from transformers import PreTrainedTokenizer, TensorType, is_torch_available 7 | from transformers import PretrainedConfig 8 | from transformers.utils import logging 9 | 10 | 11 | logger = logging.get_logger(__name__) 12 | 13 | 14 | class GPT2Config(PretrainedConfig): 15 | """ 16 | This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to 17 | instantiate a GPT-2 model according to the specified arguments, defining the model architecture. Instantiating a 18 | configuration with the defaults will yield a similar configuration to that of the GPT-2 19 | [openai-community/gpt2](https://huggingface.co/openai-community/gpt2) architecture. 20 | 21 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 22 | documentation from [`PretrainedConfig`] for more information. 23 | 24 | 25 | Args: 26 | vocab_size (`int`, *optional*, defaults to 50257): 27 | Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the 28 | `input_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`]. 29 | n_positions (`int`, *optional*, defaults to 1024): 30 | The maximum sequence length that this model might ever be used with. Typically set this to something large 31 | just in case (e.g., 512 or 1024 or 2048). 32 | n_embd (`int`, *optional*, defaults to 768): 33 | Dimensionality of the embeddings and hidden states. 34 | n_layer (`int`, *optional*, defaults to 12): 35 | Number of hidden layers in the Transformer decoder.
36 | n_head (`int`, *optional*, defaults to 12): 37 | Number of attention heads for each attention layer in the Transformer decoder. 38 | n_inner (`int`, *optional*): 39 | Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd. 40 | activation_function (`str`, *optional*, defaults to `"gelu_new"`): 41 | Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. 42 | resid_pdrop (`float`, *optional*, defaults to 0.0): 43 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 44 | embd_pdrop (`float`, *optional*, defaults to 0.0): 45 | The dropout ratio for the embeddings. 46 | attn_pdrop (`float`, *optional*, defaults to 0.0): 47 | The dropout ratio for the attention. 48 | layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): 49 | The epsilon to use in the layer normalization layers. 50 | initializer_range (`float`, *optional*, defaults to 0.02): 51 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 52 | summary_type (`string`, *optional*, defaults to `"cls_index"`): 53 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and 54 | [`TFGPT2DoubleHeadsModel`]. 55 | 56 | Has to be one of the following options: 57 | 58 | - `"last"`: Take the last token hidden state (like XLNet). 59 | - `"first"`: Take the first token hidden state (like BERT). 60 | - `"mean"`: Take the mean of all tokens' hidden states. 61 | - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). 62 | - `"attn"`: Not implemented now, use multi-head attention. 63 | summary_use_proj (`bool`, *optional*, defaults to `True`): 64 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and 65 | [`TFGPT2DoubleHeadsModel`]. 66 | 67 | Whether or not to add a projection after the vector extraction.
68 | summary_activation (`str`, *optional*): 69 | Argument used when doing sequence summary. Used for the multiple choice head in 70 | [`GPT2DoubleHeadsModel`]. 71 | 72 | Pass `"tanh"` for a tanh activation to the output; any other value will result in no activation. 73 | summary_proj_to_labels (`bool`, *optional*, defaults to `True`): 74 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and 75 | [`TFGPT2DoubleHeadsModel`]. 76 | 77 | Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes. 78 | summary_first_dropout (`float`, *optional*, defaults to 0.0): 79 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and 80 | [`TFGPT2DoubleHeadsModel`]. 81 | 82 | The dropout ratio to be used after the projection and activation. 83 | scale_attn_weights (`bool`, *optional*, defaults to `True`): 84 | Scale attention weights by dividing by sqrt(head_dim), the per-head dimension. 85 | use_cache (`bool`, *optional*, defaults to `True`): 86 | Whether or not the model should return the last key/values attentions (not used by all models). 87 | bos_token_id (`int`, *optional*, defaults to 50256): 88 | Id of the beginning of sentence token in the vocabulary. 89 | eos_token_id (`int`, *optional*, defaults to 50256): 90 | Id of the end of sentence token in the vocabulary. 91 | scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): 92 | Whether to additionally scale attention weights by `1 / (layer_idx + 1)`. 93 | reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): 94 | Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention 95 | dot-product/softmax to float() when training with mixed precision.
96 | 97 | Example: 98 | 99 | ```python 100 | >>> from transformers import GPT2Config, GPT2Model 101 | 102 | >>> # Initializing a GPT2 configuration 103 | >>> configuration = GPT2Config() 104 | 105 | >>> # Initializing a model (with random weights) from the configuration 106 | >>> model = GPT2Model(configuration) 107 | 108 | >>> # Accessing the model configuration 109 | >>> configuration = model.config 110 | ```""" 111 | 112 | model_type = "gpt2" 113 | keys_to_ignore_at_inference = ["past_key_values"] 114 | attribute_map = { 115 | "hidden_size": "n_embd", 116 | "max_position_embeddings": "n_positions", 117 | "num_attention_heads": "n_head", 118 | "num_hidden_layers": "n_layer", 119 | } 120 | 121 | def __init__( 122 | self, 123 | vocab_size=50257, 124 | n_positions=1024, 125 | n_embd=768, 126 | n_layer=12, 127 | n_head=12, 128 | n_inner=None, 129 | activation_function="gelu_new", 130 | resid_pdrop=0.0, 131 | embd_pdrop=0.0, 132 | attn_pdrop=0.0, 133 | layer_norm_epsilon=1e-5, 134 | initializer_range=0.02, 135 | summary_type="cls_index", 136 | summary_use_proj=True, 137 | summary_activation=None, 138 | summary_proj_to_labels=True, 139 | summary_first_dropout=0.0, 140 | scale_attn_weights=True, 141 | use_cache=True, 142 | bos_token_id=50256, 143 | eos_token_id=50256, 144 | scale_attn_by_inverse_layer_idx=False, 145 | reorder_and_upcast_attn=False, 146 | **kwargs, 147 | ): 148 | self.vocab_size = vocab_size 149 | self.n_positions = n_positions 150 | self.n_embd = n_embd 151 | self.n_layer = n_layer 152 | self.n_head = n_head 153 | self.n_inner = n_inner 154 | self.activation_function = activation_function 155 | self.resid_pdrop = resid_pdrop 156 | self.embd_pdrop = embd_pdrop 157 | self.attn_pdrop = attn_pdrop 158 | self.layer_norm_epsilon = layer_norm_epsilon 159 | self.initializer_range = initializer_range 160 | self.summary_type = summary_type 161 | self.summary_use_proj = summary_use_proj 162 | self.summary_activation = summary_activation 163 | 
self.summary_first_dropout = summary_first_dropout 164 | self.summary_proj_to_labels = summary_proj_to_labels 165 | self.scale_attn_weights = scale_attn_weights 166 | self.use_cache = use_cache 167 | self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx 168 | self.reorder_and_upcast_attn = reorder_and_upcast_attn 169 | 170 | self.bos_token_id = bos_token_id 171 | self.eos_token_id = eos_token_id 172 | 173 | super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) 174 | 175 | 176 | -------------------------------------------------------------------------------- /configuration_ngpt.py: -------------------------------------------------------------------------------- 1 | """Ngpt configuration""" 2 | 3 | from collections import OrderedDict 4 | from typing import Any, List, Mapping, Optional 5 | 6 | from transformers import PreTrainedTokenizer, TensorType, is_torch_available 7 | from transformers import PretrainedConfig 8 | from transformers import logging 9 | 10 | 11 | logger = logging.get_logger(__name__) 12 | 13 | 14 | class NgptConfig(PretrainedConfig): 15 | """ 16 | This is the configuration class to store the configuration of a [`NgptModel`] or a [`TFNgptModel`]. It is used to 17 | instantiate an nGPT model according to the specified arguments, defining the model architecture. Instantiating a 18 | configuration with the defaults will yield a similar configuration to that of the GPT-2 19 | architecture. 20 | 21 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 22 | documentation from [`PretrainedConfig`] for more information. 23 | 24 | 25 | Args: 26 | vocab_size (`int`, *optional*, defaults to 50257): 27 | Vocabulary size of the nGPT model. Defines the number of different tokens that can be represented by the 28 | `input_ids` passed when calling [`NgptModel`] or [`TFNgptModel`].
29 | n_positions (`int`, *optional*, defaults to 1024): 30 | The maximum sequence length that this model might ever be used with. Typically set this to something large 31 | just in case (e.g., 512 or 1024 or 2048). 32 | n_embd (`int`, *optional*, defaults to 768): 33 | Dimensionality of the embeddings and hidden states. 34 | n_layer (`int`, *optional*, defaults to 12): 35 | Number of hidden layers in the Transformer decoder. 36 | n_head (`int`, *optional*, defaults to 12): 37 | Number of attention heads for each attention layer in the Transformer decoder. 38 | n_inner (`int`, *optional*): 39 | Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd. 40 | activation_function (`str`, *optional*, defaults to `"silu"`): 41 | Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. 42 | resid_pdrop (`float`, *optional*, defaults to 0.0): 43 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 44 | embd_pdrop (`float`, *optional*, defaults to 0.0): 45 | The dropout ratio for the embeddings. 46 | attn_pdrop (`float`, *optional*, defaults to 0.0): 47 | The dropout ratio for the attention. 48 | layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): 49 | The epsilon to use in the layer normalization layers. 50 | initializer_range (`float`, *optional*, defaults to 0.02): 51 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. Note that `__init__` overrides this value with `base_scale = 1 / sqrt(n_embd)`. 52 | summary_type (`string`, *optional*, defaults to `"cls_index"`): 53 | Argument used when doing sequence summary, used in the models [`NgptDoubleHeadsModel`] and 54 | [`TFNgptDoubleHeadsModel`]. 55 | 56 | Has to be one of the following options: 57 | 58 | - `"last"`: Take the last token hidden state (like XLNet). 59 | - `"first"`: Take the first token hidden state (like BERT). 60 | - `"mean"`: Take the mean of all tokens' hidden states.
61 | - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). 62 | - `"attn"`: Not implemented now, use multi-head attention. 63 | summary_use_proj (`bool`, *optional*, defaults to `True`): 64 | Argument used when doing sequence summary, used in the models [`NgptDoubleHeadsModel`] and 65 | [`TFNgptDoubleHeadsModel`]. 66 | 67 | Whether or not to add a projection after the vector extraction. 68 | summary_activation (`str`, *optional*): 69 | Argument used when doing sequence summary. Used for the multiple choice head in 70 | [`NgptDoubleHeadsModel`]. 71 | 72 | Pass `"tanh"` for a tanh activation to the output; any other value will result in no activation. 73 | summary_proj_to_labels (`bool`, *optional*, defaults to `True`): 74 | Argument used when doing sequence summary, used in the models [`NgptDoubleHeadsModel`] and 75 | [`TFNgptDoubleHeadsModel`]. 76 | 77 | Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes. 78 | summary_first_dropout (`float`, *optional*, defaults to 0.1): 79 | Argument used when doing sequence summary, used in the models [`NgptDoubleHeadsModel`] and 80 | [`TFNgptDoubleHeadsModel`]. 81 | 82 | The dropout ratio to be used after the projection and activation. 83 | scale_attn_weights (`bool`, *optional*, defaults to `True`): 84 | Scale attention weights by dividing by sqrt(hidden_size). 85 | use_cache (`bool`, *optional*, defaults to `True`): 86 | Whether or not the model should return the last key/values attentions (not used by all models). 87 | bos_token_id (`int`, *optional*, defaults to 50256): 88 | Id of the beginning of sentence token in the vocabulary. 89 | eos_token_id (`int`, *optional*, defaults to 50256): 90 | Id of the end of sentence token in the vocabulary. 91 | scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): 92 | Whether to additionally scale attention weights by `1 / (layer_idx + 1)`.
93 | reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): 94 | Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention 95 | dot-product/softmax to float() when training with mixed precision. 96 | 97 | Example: 98 | 99 | ```python 100 | >>> from configuration_ngpt import NgptConfig 101 | >>> from modeling_ngpt import NgptModel 102 | >>> # Initializing a Ngpt configuration 103 | >>> configuration = NgptConfig() 104 | 105 | >>> # Initializing a model (with random weights) from the configuration 106 | >>> model = NgptModel(configuration) 107 | 108 | >>> # Accessing the model configuration 109 | >>> configuration = model.config 110 | ```""" 111 | 112 | model_type = "ngpt" 113 | keys_to_ignore_at_inference = ["past_key_values"] 114 | attribute_map = { 115 | "hidden_size": "n_embd", 116 | "max_position_embeddings": "n_positions", 117 | "num_attention_heads": "n_head", 118 | "num_hidden_layers": "n_layer", 119 | } 120 | 121 | def __init__( 122 | self, 123 | vocab_size=50257, 124 | n_positions=1024, 125 | n_embd=768, 126 | n_layer=12, 127 | n_head=12, 128 | n_inner=None, 129 | activation_function="silu", 130 | resid_pdrop=0.0, 131 | embd_pdrop=0.0, 132 | attn_pdrop=0.0, 133 | layer_norm_epsilon=1e-5, 134 | initializer_range=0.02, 135 | summary_type="cls_index", 136 | summary_use_proj=True, 137 | summary_activation=None, 138 | summary_proj_to_labels=True, 139 | summary_first_dropout=0.1, 140 | scale_attn_weights=True, 141 | use_cache=True, 142 | bos_token_id=50256, 143 | eos_token_id=50256, 144 | scale_attn_by_inverse_layer_idx=False, 145 | reorder_and_upcast_attn=False, 146 | **kwargs, 147 | ): 148 | self.vocab_size = vocab_size 149 | self.n_positions = n_positions 150 | self.n_embd = n_embd 151 | self.n_layer = n_layer 152 | self.n_head = n_head 153 | self.n_inner = n_inner 154 | self.activation_function = activation_function 155 | self.resid_pdrop = resid_pdrop 156 | self.embd_pdrop = embd_pdrop 157 | self.attn_pdrop = attn_pdrop 158 | 
self.layer_norm_epsilon
= layer_norm_epsilon 159 | self.initializer_range = initializer_range 160 | self.summary_type = summary_type 161 | self.summary_use_proj = summary_use_proj 162 | self.summary_activation = summary_activation 163 | self.summary_first_dropout = summary_first_dropout 164 | self.summary_proj_to_labels = summary_proj_to_labels 165 | self.scale_attn_weights = scale_attn_weights 166 | self.use_cache = use_cache 167 | self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx 168 | self.reorder_and_upcast_attn = reorder_and_upcast_attn 169 | 170 | self.bos_token_id = bos_token_id 171 | self.eos_token_id = eos_token_id 172 | 173 | ''' 174 | BASE SCALING FACTOR s = 1 / sqrt(n_embd) (section 2.2.2); it also replaces the initializer_range argument 175 | ''' 176 | self.base_scale = 1.0 / (self.n_embd ** 0.5) 177 | self.initializer_range = self.base_scale 178 | 179 | super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) 180 | 181 | -------------------------------------------------------------------------------- /data/openwebtext_llama/prepare.py: -------------------------------------------------------------------------------- 1 | # saves the openwebtext dataset to a binary file for training. following was helpful: 2 | # https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py 3 | 4 | import os 5 | from tqdm import tqdm 6 | import numpy as np 7 | import tiktoken 8 | from datasets import load_dataset # huggingface datasets 9 | from transformers import AutoTokenizer 10 | # number of workers in .map() call 11 | # good number to use is ~order number of cpu cores // 2 12 | num_proc = 8 13 | 14 | # number of workers in load_dataset() call 15 | # best number might be different from num_proc above as it also depends on network speed.
16 | # it is usually better than 1, though 17 | num_proc_load_dataset = num_proc 18 | 19 | #enc = tiktoken.get_encoding("gpt2") 20 | tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") 21 | eot_token = 2 # id of the </s> (end-of-sequence) token in the LLaMA-2 tokenizer 22 | 23 | if __name__ == '__main__': 24 | # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769) 25 | dataset = load_dataset("openwebtext", num_proc=num_proc_load_dataset) 26 | 27 | # owt by default only contains the 'train' split, so create a test split 28 | split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True) 29 | split_dataset['val'] = split_dataset.pop('test') # rename the test split to val 30 | 31 | # this results in: 32 | # >>> split_dataset 33 | # DatasetDict({ 34 | # train: Dataset({ 35 | # features: ['text'], 36 | # num_rows: 8009762 37 | # }) 38 | # val: Dataset({ 39 | # features: ['text'], 40 | # num_rows: 4007 41 | # }) 42 | # }) 43 | 44 | # we now want to tokenize the dataset. first define the encoding function (LLaMA tokenizer) 45 | def process(example): 46 | ids = tokenizer.encode(example['text']) # encode() adds special tokens by default, i.e. the BOS token <s> for LLaMA 47 | ids.append(eot_token) # add the end-of-sequence token (</s>, id 2 for LLaMA) 48 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though...
49 | out = {'ids': ids, 'len': len(ids)} 50 | return out 51 | 52 | # tokenize the dataset 53 | tokenized = split_dataset.map( 54 | process, 55 | remove_columns=['text'], 56 | desc="tokenizing the splits", 57 | num_proc=num_proc, 58 | ) 59 | 60 | # concatenate all the ids in each dataset into one large file we can use for training 61 | for split, dset in tokenized.items(): 62 | arr_len = np.sum(dset['len'], dtype=np.uint64) 63 | filename = os.path.join(os.path.dirname(__file__), f'{split}.bin') 64 | dtype = np.uint16 # (can do since the LLaMA vocab size, 32000, is < 2**16) 65 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 66 | total_batches = 1024 67 | 68 | idx = 0 69 | for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'): 70 | # Batch together samples for faster write 71 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy') 72 | arr_batch = np.concatenate(batch['ids']) 73 | # Write into mmap 74 | arr[idx : idx + len(arr_batch)] = arr_batch 75 | idx += len(arr_batch) 76 | arr.flush() 77 | 78 | # (sizes below are from nanoGPT's GPT-2-tokenized version; LLaMA-tokenized sizes differ) train.bin is ~17GB, val.bin ~8.5MB 79 | # train has ~9B tokens (9,035,582,198) 80 | # val has ~4M tokens (4,434,897) 81 | 82 | # to read the bin files later, e.g.
with numpy: 83 | # m = np.memmap('train.bin', dtype=np.uint16, mode='r') 84 | -------------------------------------------------------------------------------- /images/4k_arceasy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/4k_arceasy.png -------------------------------------------------------------------------------- /images/4k_average.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/4k_average.png -------------------------------------------------------------------------------- /images/4k_hellaswag.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/4k_hellaswag.png -------------------------------------------------------------------------------- /images/4k_lambada.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/4k_lambada.png -------------------------------------------------------------------------------- /images/4k_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/4k_loss.png -------------------------------------------------------------------------------- /images/4k_winogrande.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/4k_winogrande.png -------------------------------------------------------------------------------- /images/4k_wsc273.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/4k_wsc273.png -------------------------------------------------------------------------------- /images/arc_easy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/arc_easy.png -------------------------------------------------------------------------------- /images/average.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/average.png -------------------------------------------------------------------------------- /images/hellaswag.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/hellaswag.png -------------------------------------------------------------------------------- /images/lambada.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/lambada.png -------------------------------------------------------------------------------- /images/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/loss.png -------------------------------------------------------------------------------- /images/network_params.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/network_params.png 
--------------------------------------------------------------------------------
/images/optimizer_params.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/optimizer_params.png
--------------------------------------------------------------------------------
/images/wall_clock.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/wall_clock.png
--------------------------------------------------------------------------------
/images/winogrande.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/winogrande.png
--------------------------------------------------------------------------------
/images/wsc273.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/wsc273.png
--------------------------------------------------------------------------------
/modeling_gpt2.py:
--------------------------------------------------------------------------------
"""PyTorch OpenAI GPT-2 model modifications according to nGPT specifications"""

import math
import os
import warnings
from dataclasses import dataclass
from typing import Callable, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN

from transformers import GenerationMixin
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, SequenceSummary
from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
from transformers.utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from configuration_gpt2 import GPT2Config
from torchtune.modules import RotaryPositionalEmbeddings

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "openai-community/gpt2"
_CONFIG_FOR_DOC = "GPT2Config"


def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
    """Load tf checkpoints in a pytorch model"""
    try:
        import re

        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(gpt2_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array.squeeze())

    for name, array in zip(names, arrays):
        name = name[6:]  # skip "model/"
        name = name.split("/")
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+\d+", m_name):
                scope_names = re.split(r"(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "w" or scope_names[0] == "g":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "b":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "wpe" or scope_names[0] == "wte":
                pointer = getattr(pointer, scope_names[0])
                pointer = getattr(pointer, "weight")
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        try:
            if pointer.shape != array.shape:
                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        except ValueError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model


def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs):
    attn_weights = torch.matmul(query, key.transpose(-1, -2))

    if module.scale_attn_weights:
        attn_weights = attn_weights / torch.full(
            [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
        )

    # Layer-wise attention scaling
    if module.scale_attn_by_inverse_layer_idx:
        attn_weights = attn_weights / float(module.layer_idx + 1)

    if not module.is_cross_attention:
        # if only "normal" attention layer implements causal mask
        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length]
        mask_value = torch.finfo(attn_weights.dtype).min
        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
        mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
        attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

    if attention_mask is not None:
        # Apply the attention mask
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)

    # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
    attn_weights = attn_weights.type(value.dtype)
    attn_weights = module.attn_dropout(attn_weights)

    # Mask heads if we want to
    if head_mask is not None:
        attn_weights = attn_weights * head_mask

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2)

    return attn_output, attn_weights


class GPT2Attention(nn.Module):
    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__()
        self.config = config
        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False,
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        # INITIALIZE ROTARY POSITIONAL EMBEDDINGS HERE
        self.rpe = RotaryPositionalEmbeddings(dim=self.head_dim, max_seq_len=config.max_position_embeddings)

        self.scale_attn_weights = config.scale_attn_weights
        self.is_cross_attention = is_cross_attention

        # Layer-wise attention scaling, reordering, and upcasting
        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
        self.layer_idx = layer_idx
        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

        if self.is_cross_attention:
            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
        else:
            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.is_causal = True

        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
        self.num_heads = self.num_heads - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()

        # Preallocate attn_weights for `baddbmm`
        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)

        # Compute Scale Factor
        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5

        if self.scale_attn_by_inverse_layer_idx:
            scale_factor /= float(self.layer_idx + 1)

        # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
        with torch.amp.autocast(query.device.type, enabled=False):
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
        if attn_weights.dtype != torch.float32:
            raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        attn_output = attn_output.transpose(1, 2)

        return attn_output, attn_weights

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )

            query_states = self.q_attn(hidden_states)
            key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)

        shape_q = (*query_states.shape[:-1], -1, self.head_dim)
        shape_kv = (*key_states.shape[:-1], -1, self.head_dim)

        query_states = query_states.view(shape_q).transpose(1, 2)
        key_states = key_states.view(shape_kv).transpose(1, 2)
        value_states = value_states.view(shape_kv).transpose(1, 2)

        # APPLY ROTARY POSITION EMBEDDING HERE
        # torchtune's RotaryPositionalEmbeddings expects [batch, seq_len, num_heads, head_dim], so the heads
        # axis is temporarily moved back to dim 2. No `input_pos` is passed, so the tokens of this call are
        # rotated as positions 0..seq_len-1 (the offset implied by `layer_past` is not taken into account).
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)

        query_states = self.rpe(query_states)
        key_states = self.rpe(key_states)

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)

        if layer_past is not None:
            past_key, past_value = layer_past
            key_states = torch.cat((past_key, key_states), dim=-2)
            value_states = torch.cat((past_value, value_states), dim=-2)

        if use_cache is True:
            present = (key_states, value_states)
        else:
            present = None

        is_cross_attention = encoder_hidden_states is not None
        is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention

        using_eager = self.config._attn_implementation == "eager"
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
                using_eager = True
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
                # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
                # not necessarily to eager (if the mentioned options are provided).
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        if using_eager and self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(
                query_states, key_states, value_states, attention_mask, head_mask
            )
        else:
            attn_output, attn_weights = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                head_mask=head_mask,
                dropout=self.attn_dropout.p if self.training else 0.0,
                is_causal=is_causal,
                **kwargs,
            )

        attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)


class GPT2MLP(nn.Module):
    def __init__(self, intermediate_size, config):
        super().__init__()
        embed_dim = config.hidden_size
        # Original GPT-2: self.c_fc = Conv1D(intermediate_size, embed_dim)
        # SwiGLU needs a value branch and a gate branch, so c_fc produces 2 * intermediate_size features.
        self.c_fc = Conv1D(2 * intermediate_size, embed_dim)
        self.c_proj = Conv1D(embed_dim, intermediate_size)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
        hidden_states = self.c_fc(hidden_states)
        # IMPLEMENT SWIGLU: split into (u, v) and gate u with act(v) instead of applying act to a single branch
        u, v = torch.chunk(hidden_states, 2, dim=-1)
        hidden_states = u * self.act(v)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class GPT2Block(nn.Module):
    def __init__(self, config, layer_idx=None):
        super().__init__()
        hidden_size = config.hidden_size
        # Original GPT-2 default: 4 * hidden_size. With SwiGLU's doubled c_fc, an inner size of k = 8/3 * d keeps the
        # MLP weight count at 3 * d * k = 8 * d^2, the same as the original 4d MLP (biases ignored).
        inner_dim = config.n_inner if config.n_inner is not None else 8 * hidden_size // 3

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config=config, layer_idx=layer_idx)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        if config.add_cross_attention:
            self.crossattention = GPT2Attention(config=config, is_cross_attention=True, layer_idx=layer_idx)
            self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = GPT2MLP(inner_dim, config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        outputs = attn_outputs[1:]
        # residual connection
        hidden_states = attn_output + residual

        if encoder_hidden_states is not None:
            # add one self-attention block for cross-attention
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            residual = hidden_states
            hidden_states = self.ln_cross_attn(hidden_states)
            cross_attn_outputs = self.crossattention(
                hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attn_output = cross_attn_outputs[0]
            # residual connection
            hidden_states = residual + attn_output
            outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            outputs = (hidden_states,) + outputs[1:]

        return outputs  # hidden_states, present, (attentions, cross_attentions)


class GPT2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GPT2Config
    load_tf_weights = load_tf_weights_in_gpt2
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True
    _no_split_modules = ["GPT2Block"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, Conv1D)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name == "c_proj.weight":
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))


@dataclass
class GPT2DoubleHeadsModelOutput(ModelOutput):
    """
    Base class for outputs of models predicting if two sentences are consecutive or not.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss.
        mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
            Multiple choice classification loss.
        logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads,
            sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (keys and values in the attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            GPT2Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    mc_loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    mc_logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


GPT2_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`GPT2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

GPT2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
            Contains precomputed hidden-states (keys and values in the attention blocks) as computed by the model (see
            `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
            their past given to this model should not be passed as `input_ids` as they have already been computed.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
            `past_key_values`. In other words, the `attention_mask` always has to have the length:
            `len(past_key_values) + len(input_ids)`

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.

            If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
            `past_key_values`).
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
PARALLELIZE_DOCSTRING = r"""
    This is an experimental feature and is subject to change at a moment's notice.

    Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
    it will evenly distribute blocks across all devices.

    Args:
        device_map (`Dict[int, list]`, *optional*):
            A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
            automatically mapped to the first device (for esoteric reasons). That means that the first device should
            have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
            following number of attention modules:

                - openai-community/gpt2: 12
                - openai-community/gpt2-medium: 24
                - openai-community/gpt2-large: 36
                - openai-community/gpt2-xl: 48

    Example:

    ```python
    # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
    model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl")
    device_map = {
        0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
        1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
        2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
        3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
    }
    model.parallelize(device_map)
    ```
"""
DEPARALLELIZE_DOCSTRING = r"""
    Moves the model to cpu from a model parallel state.

    Example:

    ```python
    # On a 4 GPU machine with openai-community/gpt2-large:
    model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large")
    device_map = {
        0: [0, 1, 2, 3, 4, 5, 6, 7],
        1: [8, 9, 10, 11, 12, 13, 14, 15],
        2: [16, 17, 18, 19, 20, 21, 22, 23],
        3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
    }
    model.parallelize(device_map)  # Splits the model across several devices
    model.deparallelize()  # Puts the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
    ```
"""


@add_start_docstrings(
    "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
    GPT2_START_DOCSTRING,
)
class GPT2Model(GPT2PreTrainedModel):
    _supports_param_buffer_assignment = False

    def __init__(self, config):
        super().__init__(config)

        self.embed_dim = config.hidden_size

        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False
        self._attn_implementation = config._attn_implementation

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        # Check validity of device_map
        warnings.warn(
            "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
            " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
            " ...}",
            FutureWarning,
        )
        self.device_map = (
            get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
        )
        assert_device_map(self.device_map, len(self.h))
        self.model_parallel = True
        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
        self.last_device = "cuda:" + str(max(self.device_map.keys()))
        self.wte = self.wte.to(self.first_device)
        self.wpe = self.wpe.to(self.first_device)
        # Load onto devices
        for k, v in self.device_map.items():
            for block in v:
                cuda_device = "cuda:" + str(k)
                self.h[block] = self.h[block].to(cuda_device)
        # ln_f to last
        self.ln_f = self.ln_f.to(self.last_device)

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.model_parallel = False
        self.device_map = None
        self.first_device = "cpu"
        self.last_device = "cpu"
        self.wte = self.wte.to("cpu")
        self.wpe = self.wpe.to("cpu")
        for index in range(len(self.h)):
            self.h[index] = self.h[index].to("cpu")
        self.ln_f = self.ln_f.to("cpu")
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        # Note: learned absolute position embeddings (wpe) are still added here, in addition to the RoPE
        # applied to queries and keys inside GPT2Attention.
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        # Attention mask.
830 | _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None 831 | attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None 832 | if self._attn_implementation == "flash_attention_2": 833 | attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None 834 | elif _use_sdpa: 835 | attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( 836 | attention_mask=attention_mask, 837 | input_shape=(batch_size, input_shape[-1]), 838 | inputs_embeds=inputs_embeds, 839 | past_key_values_length=past_length, 840 | ) 841 | else: 842 | if attention_mask is not None: 843 | # We create a 4D attention mask from a 2D tensor mask. 844 | # Sizes are [batch_size, 1, 1, to_seq_length] 845 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 846 | # This attention mask is simpler than the triangular masking of causal attention 847 | # used in OpenAI GPT; we just need to prepare the broadcast dimension here. 848 | attention_mask = attention_mask[:, None, None, :] 849 | 850 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 851 | # masked positions, this operation will create a tensor which is 0.0 for 852 | # positions we want to attend and the dtype's smallest value for masked positions. 853 | # Since we are adding it to the raw scores before the softmax, this is 854 | # effectively the same as removing these entirely.
855 | attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility 856 | attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min 857 | 858 | # If a 2D or 3D attention mask is provided for the cross-attention, 859 | # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length] 860 | if self.config.add_cross_attention and encoder_hidden_states is not None: 861 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 862 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 863 | if encoder_attention_mask is None: 864 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 865 | if _use_sdpa: 866 | encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( 867 | mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1] 868 | ) 869 | elif not self._attn_implementation == "flash_attention_2": 870 | encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) 871 | else: 872 | encoder_attention_mask = None 873 | 874 | # Prepare head mask if needed 875 | # 1.0 in head_mask indicates we keep the head 876 | # attention_probs has shape bsz x n_heads x N x N 877 | # head_mask has shape n_layer x batch x n_heads x N x N 878 | head_mask = self.get_head_mask(head_mask, self.config.n_layer) 879 | 880 | if token_type_ids is not None: 881 | token_type_embeds = self.wte(token_type_ids) 882 | hidden_states = hidden_states + token_type_embeds 883 | 884 | hidden_states = self.drop(hidden_states) 885 | 886 | output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) 887 | 888 | if self.gradient_checkpointing and self.training: 889 | if use_cache: 890 | logger.warning_once( 891 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
892 | ) 893 | use_cache = False 894 | 895 | presents = () if use_cache else None 896 | all_self_attentions = () if output_attentions else None 897 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None 898 | all_hidden_states = () if output_hidden_states else None 899 | for i in range(len(self.h)): 900 | block, layer_past = self.h[i], past_key_values[i] 901 | # Model parallel 902 | if self.model_parallel: 903 | torch.cuda.set_device(hidden_states.device) 904 | # Ensure layer_past is on same device as hidden_states (might not be correct) 905 | if layer_past is not None: 906 | layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) 907 | # Ensure that attention_mask is always on the same device as hidden_states 908 | if attention_mask is not None: 909 | attention_mask = attention_mask.to(hidden_states.device) 910 | if isinstance(head_mask, torch.Tensor): 911 | head_mask = head_mask.to(hidden_states.device) 912 | if output_hidden_states: 913 | all_hidden_states = all_hidden_states + (hidden_states,) 914 | 915 | if self.gradient_checkpointing and self.training: 916 | outputs = self._gradient_checkpointing_func( 917 | block.__call__, 918 | hidden_states, 919 | None, 920 | attention_mask, 921 | head_mask[i], 922 | encoder_hidden_states, 923 | encoder_attention_mask, 924 | use_cache, 925 | output_attentions, 926 | ) 927 | else: 928 | outputs = block( 929 | hidden_states, 930 | layer_past=layer_past, 931 | attention_mask=attention_mask, 932 | head_mask=head_mask[i], 933 | encoder_hidden_states=encoder_hidden_states, 934 | encoder_attention_mask=encoder_attention_mask, 935 | use_cache=use_cache, 936 | output_attentions=output_attentions, 937 | ) 938 | 939 | hidden_states = outputs[0] 940 | if use_cache is True: 941 | presents = presents + (outputs[1],) 942 | 943 | if output_attentions: 944 | all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) 945 | if 
self.config.add_cross_attention: 946 | all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) 947 | 948 | # Model Parallel: If it's the last layer for that device, put things on the next device 949 | if self.model_parallel: 950 | for k, v in self.device_map.items(): 951 | if i == v[-1] and "cuda:" + str(k) != self.last_device: 952 | hidden_states = hidden_states.to("cuda:" + str(k + 1)) 953 | 954 | hidden_states = self.ln_f(hidden_states) 955 | 956 | hidden_states = hidden_states.view(output_shape) 957 | # Add last hidden state 958 | if output_hidden_states: 959 | all_hidden_states = all_hidden_states + (hidden_states,) 960 | 961 | if not return_dict: 962 | return tuple( 963 | v 964 | for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] 965 | if v is not None 966 | ) 967 | 968 | return BaseModelOutputWithPastAndCrossAttentions( 969 | last_hidden_state=hidden_states, 970 | past_key_values=presents, 971 | hidden_states=all_hidden_states, 972 | attentions=all_self_attentions, 973 | cross_attentions=all_cross_attentions, 974 | ) 975 | 976 | 977 | @add_start_docstrings( 978 | """ 979 | The GPT2 Model transformer with a language modeling head on top (a linear layer, `lm_head_2`, whose weights are not 980 | tied to the input embeddings). 981 | """, 982 | GPT2_START_DOCSTRING, 983 | ) 984 | class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin): 985 | _tied_weights_keys = ["lm_head.weight"] 986 | 987 | def __init__(self, config): 988 | super().__init__(config) 989 | self.transformer = GPT2Model(config) 990 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # this line is now completely useless.
lm_head will not be used anywhere 991 | self.lm_head_2 = nn.Linear(config.n_embd, config.vocab_size, bias=False) 992 | 993 | # Model parallel 994 | self.model_parallel = False 995 | self.device_map = None 996 | 997 | # Initialize weights and apply final processing 998 | self.post_init() 999 | 1000 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 1001 | def parallelize(self, device_map=None): 1002 | warnings.warn( 1003 | "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load" 1004 | " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" 1005 | " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':" 1006 | " 0, 'transformer.h.1': 1, ...}", 1007 | FutureWarning, 1008 | ) 1009 | self.device_map = ( 1010 | get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) 1011 | if device_map is None 1012 | else device_map 1013 | ) 1014 | assert_device_map(self.device_map, len(self.transformer.h)) 1015 | self.transformer.parallelize(self.device_map) 1016 | self.lm_head, self.lm_head_2 = self.lm_head.to(self.transformer.first_device), self.lm_head_2.to(self.transformer.first_device) 1017 | self.model_parallel = True 1018 | 1019 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING) 1020 | def deparallelize(self): 1021 | warnings.warn( 1022 | "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", 1023 | FutureWarning, 1024 | ) 1025 | self.transformer.deparallelize() 1026 | self.transformer = self.transformer.to("cpu") 1027 | self.lm_head, self.lm_head_2 = self.lm_head.to("cpu"), self.lm_head_2.to("cpu") 1028 | self.model_parallel = False 1029 | torch.cuda.empty_cache() 1030 | 1031 | def get_output_embeddings(self): 1032 | return self.lm_head 1033 | 1034 | def set_output_embeddings(self, new_embeddings): 1035 | self.lm_head = new_embeddings 1036 | 1037 | 
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) 1038 | @add_code_sample_docstrings( 1039 | checkpoint=_CHECKPOINT_FOR_DOC,
1040 | output_type=CausalLMOutputWithCrossAttentions, 1041 | config_class=_CONFIG_FOR_DOC, 1042 | ) 1043 | def forward( 1044 | self, 1045 | input_ids: Optional[torch.LongTensor] = None, 1046 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 1047 | attention_mask: Optional[torch.FloatTensor] = None, 1048 | token_type_ids: Optional[torch.LongTensor] = None, 1049 | position_ids: Optional[torch.LongTensor] = None, 1050 | head_mask: Optional[torch.FloatTensor] = None, 1051 | inputs_embeds: Optional[torch.FloatTensor] = None, 1052 | encoder_hidden_states: Optional[torch.Tensor] = None, 1053 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 1054 | labels: Optional[torch.LongTensor] = None, 1055 | use_cache: Optional[bool] = None, 1056 | output_attentions: Optional[bool] = None, 1057 | output_hidden_states: Optional[bool] = None, 1058 | return_dict: Optional[bool] = None, 1059 | ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: 1060 | r""" 1061 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1062 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set 1063 | `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to `-100` 1064 | are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`. 1065 | """ 1066 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1067 | 1068 | transformer_outputs = self.transformer( 1069 | input_ids, 1070 | past_key_values=past_key_values, 1071 | attention_mask=attention_mask, 1072 | token_type_ids=token_type_ids, 1073 | position_ids=position_ids, 1074 | head_mask=head_mask, 1075 | inputs_embeds=inputs_embeds, 1076 | encoder_hidden_states=encoder_hidden_states, 1077 | encoder_attention_mask=encoder_attention_mask, 1078 | use_cache=use_cache, 1079 | output_attentions=output_attentions, 1080 | output_hidden_states=output_hidden_states, 1081 | return_dict=return_dict, 1082 | ) 1083 | hidden_states = transformer_outputs[0] 1084 | 1085 | # Set device for model parallelism 1086 | if self.model_parallel: 1087 | torch.cuda.set_device(self.transformer.first_device) 1088 | hidden_states = hidden_states.to(self.lm_head_2.weight.device) 1089 | 1090 | lm_logits = self.lm_head_2(hidden_states) 1091 | 1092 | loss = None 1093 | if labels is not None: 1094 | # move labels to correct device to enable model parallelism 1095 | labels = labels.to(lm_logits.device) 1096 | # Shift so that tokens < n predict n 1097 | shift_logits = lm_logits[..., :-1, :].contiguous() 1098 | shift_labels = labels[..., 1:].contiguous() 1099 | # Flatten the tokens 1100 | loss_fct = CrossEntropyLoss() 1101 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 1102 | 1103 | if not return_dict: 1104 | output = (lm_logits,) + transformer_outputs[1:] 1105 | return ((loss,) + output) if loss is not None else output 1106 | 1107 | return CausalLMOutputWithCrossAttentions( 1108 | loss=loss, 1109 | logits=lm_logits, 1110 | past_key_values=transformer_outputs.past_key_values, 1111 |
hidden_states=transformer_outputs.hidden_states, 1112 | attentions=transformer_outputs.attentions, 1113 | cross_attentions=transformer_outputs.cross_attentions, 1114 | ) 1115 | 1116 | @staticmethod 1117 | def _reorder_cache( 1118 | past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor 1119 | ) -> Tuple[Tuple[torch.Tensor]]: 1120 | """ 1121 | This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or 1122 | [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct 1123 | beam_idx at every generation step. 1124 | """ 1125 | return tuple( 1126 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) 1127 | for layer_past in past_key_values 1128 | ) 1129 | 1130 | 1131 | @add_start_docstrings( 1132 | """ 1133 | The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top, e.g. for 1134 | RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the 1135 | input embeddings; the classification head takes as input the hidden state at a specified classification token index 1136 | in the input sequence.
1137 | """, 1138 | GPT2_START_DOCSTRING, 1139 | ) 1140 | class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin): 1141 | _tied_weights_keys = ["lm_head.weight"] 1142 | 1143 | def __init__(self, config): 1144 | super().__init__(config) 1145 | config.num_labels = 1 1146 | self.transformer = GPT2Model(config) 1147 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 1148 | self.multiple_choice_head = SequenceSummary(config) 1149 | 1150 | # Model parallel 1151 | self.model_parallel = False 1152 | self.device_map = None 1153 | 1154 | # Initialize weights and apply final processing 1155 | self.post_init() 1156 | 1157 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 1158 | def parallelize(self, device_map=None): 1159 | warnings.warn( 1160 | "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should" 1161 | " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your" 1162 | " own `device_map` but it needs to be a dictionary module_name to device, so for instance" 1163 | " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}", 1164 | FutureWarning, 1165 | ) 1166 | self.device_map = ( 1167 | get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) 1168 | if device_map is None 1169 | else device_map 1170 | ) 1171 | assert_device_map(self.device_map, len(self.transformer.h)) 1172 | self.transformer.parallelize(self.device_map) 1173 | self.lm_head = self.lm_head.to(self.transformer.first_device) 1174 | self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device) 1175 | self.model_parallel = True 1176 | 1177 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING) 1178 | def deparallelize(self): 1179 | warnings.warn( 1180 | "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", 1181 | FutureWarning, 1182 | ) 1183 | self.transformer.deparallelize() 1184 | self.transformer = 
self.transformer.to("cpu") 1185 | self.lm_head = self.lm_head.to("cpu") 1186 | self.multiple_choice_head = self.multiple_choice_head.to("cpu") 1187 | self.model_parallel = False 1188 | torch.cuda.empty_cache() 1189 | 1190 | def get_output_embeddings(self): 1191 | return self.lm_head 1192 | 1193 | def set_output_embeddings(self, new_embeddings): 1194 | self.lm_head = new_embeddings 1195 | 1196 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) 1197 | @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) 1198 | def forward( 1199 | self, 1200 | input_ids: Optional[torch.LongTensor] = None, 1201 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 1202 | attention_mask: Optional[torch.FloatTensor] = None, 1203 | token_type_ids: Optional[torch.LongTensor] = None, 1204 | position_ids: Optional[torch.LongTensor] = None, 1205 | head_mask: Optional[torch.FloatTensor] = None, 1206 | inputs_embeds: Optional[torch.FloatTensor] = None, 1207 | mc_token_ids: Optional[torch.LongTensor] = None, 1208 | labels: Optional[torch.LongTensor] = None, 1209 | mc_labels: Optional[torch.LongTensor] = None, 1210 | use_cache: Optional[bool] = None, 1211 | output_attentions: Optional[bool] = None, 1212 | output_hidden_states: Optional[bool] = None, 1213 | return_dict: Optional[bool] = None, 1214 | **kwargs, 1215 | ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]: 1216 | r""" 1217 | mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, defaults to the index of the last token of the input): 1218 | Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - 1219 | 1]`. 1220 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1221 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set 1222 | `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. 
All labels set to 1223 | `-100` are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`. 1224 | mc_labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1225 | Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices - 1]` 1226 | where *num_choices* is the size of the second dimension of the input tensors (see *input_ids* above). 1227 | 1228 | Return: 1229 | 1230 | Example: 1231 | 1232 | ```python 1233 | >>> import torch 1234 | >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel 1235 | 1236 | >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") 1237 | >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2") 1238 | 1239 | >>> # Add a [CLS] to the vocabulary (we should train it also!) 1240 | >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"}) 1241 | >>> # Update the model embeddings with the new vocabulary size 1242 | >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) 1243 | 1244 | >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] 1245 | >>> encoded_choices = [tokenizer.encode(s) for s in choices] 1246 | >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices] 1247 | 1248 | >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2 1249 | >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1 1250 | 1251 | >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) 1252 | >>> lm_logits = outputs.logits 1253 | >>> mc_logits = outputs.mc_logits 1254 | ```""" 1255 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1256 | 1257 | transformer_outputs = self.transformer( 1258 | input_ids, 1259 | past_key_values=past_key_values, 1260 | attention_mask=attention_mask, 1261 | token_type_ids=token_type_ids, 1262 | position_ids=position_ids, 1263 |
head_mask=head_mask, 1264 | inputs_embeds=inputs_embeds, 1265 | use_cache=use_cache, 1266 | output_attentions=output_attentions, 1267 | output_hidden_states=output_hidden_states, 1268 | return_dict=return_dict, 1269 | ) 1270 | 1271 | hidden_states = transformer_outputs[0] 1272 | 1273 | # Set device for model parallelism 1274 | if self.model_parallel: 1275 | torch.cuda.set_device(self.transformer.first_device) 1276 | hidden_states = hidden_states.to(self.lm_head.weight.device) 1277 | 1278 | lm_logits = self.lm_head(hidden_states) 1279 | mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) 1280 | 1281 | mc_loss = None 1282 | if mc_labels is not None: 1283 | loss_fct = CrossEntropyLoss() 1284 | mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) 1285 | lm_loss = None 1286 | if labels is not None: 1287 | labels = labels.to(lm_logits.device) 1288 | shift_logits = lm_logits[..., :-1, :].contiguous() 1289 | shift_labels = labels[..., 1:].contiguous() 1290 | loss_fct = CrossEntropyLoss() 1291 | lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 1292 | 1293 | if not return_dict: 1294 | output = (lm_logits, mc_logits) + transformer_outputs[1:] 1295 | if mc_loss is not None: 1296 | output = (mc_loss,) + output 1297 | return ((lm_loss,) + output) if lm_loss is not None else output 1298 | 1299 | return GPT2DoubleHeadsModelOutput( 1300 | loss=lm_loss, 1301 | mc_loss=mc_loss, 1302 | logits=lm_logits, 1303 | mc_logits=mc_logits, 1304 | past_key_values=transformer_outputs.past_key_values, 1305 | hidden_states=transformer_outputs.hidden_states, 1306 | attentions=transformer_outputs.attentions, 1307 | ) 1308 | 1309 | @staticmethod 1310 | def _reorder_cache( 1311 | past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor 1312 | ) -> Tuple[Tuple[torch.Tensor]]: 1313 | """ 1314 | This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or 1315 | 
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct 1316 | beam_idx at every generation step. 1317 | """ 1318 | return tuple( 1319 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) 1320 | for layer_past in past_key_values 1321 | ) 1322 | 1323 | 1324 | @add_start_docstrings( 1325 | """ 1326 | The GPT2 Model transformer with a sequence classification head on top (linear layer). 1327 | 1328 | [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models 1329 | (e.g. GPT-1) do. 1330 | 1331 | Since it does classification on the last token, it needs to know the position of the last token. If a 1332 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If 1333 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the 1334 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (takes the last value in 1335 | each row of the batch).
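The index selection described above can be sketched in plain Python. This is a minimal, illustrative sketch assuming padding sits at one end of each row. `last_non_pad_index` is a hypothetical helper and is not part of this module. It mirrors the `argmax`-then-modulo logic used in `forward`: the position of the first pad token minus one is the last real token, and a row with no padding wraps around to its final position.

```python
from typing import List


def last_non_pad_index(row: List[int], pad_token_id: int) -> int:
    # Hypothetical helper mirroring forward():
    #   torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1, then % seq_len.
    # argmax over a row with no pad token returns 0, so default to 0 as well.
    first_pad = next((i for i, tok in enumerate(row) if tok == pad_token_id), 0)
    # Python's % maps -1 to len(row) - 1, i.e. the final position.
    return (first_pad - 1) % len(row)


batch = [
    [464, 3290, 318, 0, 0],  # right-padded with two pad tokens (pad id 0 is illustrative)
    [464, 3797, 318, 13, 287],  # no padding
]
print([last_non_pad_index(row, pad_token_id=0) for row in batch])  # [2, 4]
```

Because the search stops at the first pad token, a pad id that also appears mid-sequence would select a position too early. Left-padded rows, by contrast, have their first pad at index 0, so they wrap to the final position, which is their last real token.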
1336 | """, 1337 | GPT2_START_DOCSTRING, 1338 | ) 1339 | class GPT2ForSequenceClassification(GPT2PreTrainedModel): 1340 | def __init__(self, config): 1341 | super().__init__(config) 1342 | self.num_labels = config.num_labels 1343 | self.transformer = GPT2Model(config) 1344 | self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) 1345 | 1346 | # Model parallel 1347 | self.model_parallel = False 1348 | self.device_map = None 1349 | 1350 | # Initialize weights and apply final processing 1351 | self.post_init() 1352 | 1353 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) 1354 | @add_code_sample_docstrings( 1355 | checkpoint="microsoft/DialogRPT-updown", 1356 | output_type=SequenceClassifierOutputWithPast, 1357 | config_class=_CONFIG_FOR_DOC, 1358 | ) 1359 | def forward( 1360 | self, 1361 | input_ids: Optional[torch.LongTensor] = None, 1362 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 1363 | attention_mask: Optional[torch.FloatTensor] = None, 1364 | token_type_ids: Optional[torch.LongTensor] = None, 1365 | position_ids: Optional[torch.LongTensor] = None, 1366 | head_mask: Optional[torch.FloatTensor] = None, 1367 | inputs_embeds: Optional[torch.FloatTensor] = None, 1368 | labels: Optional[torch.LongTensor] = None, 1369 | use_cache: Optional[bool] = None, 1370 | output_attentions: Optional[bool] = None, 1371 | output_hidden_states: Optional[bool] = None, 1372 | return_dict: Optional[bool] = None, 1373 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]: 1374 | r""" 1375 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1376 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 1377 | config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (mean-squared error); if 1378 | `config.num_labels > 1`, a classification loss is computed (cross-entropy for integer class labels, or `BCEWithLogitsLoss` for multi-label targets).
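When `config.problem_type` is unset, `forward` infers it from `num_labels` and the label dtype before choosing a loss. Below is a minimal plain-Python sketch of that dispatch. `infer_problem_type` and `LOSS_FOR_PROBLEM_TYPE` are hypothetical names. The boolean `labels_are_integer` stands in for the `labels.dtype` check against `torch.long`/`torch.int`.

```python
def infer_problem_type(num_labels: int, labels_are_integer: bool) -> str:
    # Same branch order as forward(): a single output means regression,
    # several outputs with integer class ids mean single-label classification,
    # and anything else (e.g. float multi-hot targets) is multi-label.
    if num_labels == 1:
        return "regression"
    if num_labels > 1 and labels_are_integer:
        return "single_label_classification"
    return "multi_label_classification"


# Name of the torch.nn loss class that forward() instantiates for each problem type.
LOSS_FOR_PROBLEM_TYPE = {
    "regression": "MSELoss",
    "single_label_classification": "CrossEntropyLoss",
    "multi_label_classification": "BCEWithLogitsLoss",
}

print(LOSS_FOR_PROBLEM_TYPE[infer_problem_type(3, labels_are_integer=False)])  # BCEWithLogitsLoss
```

`forward` writes the inferred value back to `self.config.problem_type`, so the inference only happens on the first call that passes `labels`.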
1379 | """ 1380 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1381 | 1382 | transformer_outputs = self.transformer( 1383 | input_ids, 1384 | past_key_values=past_key_values, 1385 | attention_mask=attention_mask, 1386 | token_type_ids=token_type_ids, 1387 | position_ids=position_ids, 1388 | head_mask=head_mask, 1389 | inputs_embeds=inputs_embeds, 1390 | use_cache=use_cache, 1391 | output_attentions=output_attentions, 1392 | output_hidden_states=output_hidden_states, 1393 | return_dict=return_dict, 1394 | ) 1395 | hidden_states = transformer_outputs[0] 1396 | logits = self.score(hidden_states) 1397 | 1398 | if input_ids is not None: 1399 | batch_size, sequence_length = input_ids.shape[:2] 1400 | else: 1401 | batch_size, sequence_length = inputs_embeds.shape[:2] 1402 | 1403 | assert ( 1404 | self.config.pad_token_id is not None or batch_size == 1 1405 | ), "Cannot handle batch sizes > 1 if no padding token is defined." 1406 | if self.config.pad_token_id is None: 1407 | sequence_lengths = -1 1408 | else: 1409 | if input_ids is not None: 1410 | # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility 1411 | sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 1412 | sequence_lengths = sequence_lengths % input_ids.shape[-1] 1413 | sequence_lengths = sequence_lengths.to(logits.device) 1414 | else: 1415 | sequence_lengths = -1 1416 | logger.warning_once( 1417 | f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " 1418 | "unexpected if using padding tokens in conjunction with `inputs_embeds`." 1419 | ) 1420 | 1421 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] 1422 | 1423 | loss = None 1424 | if labels is not None: 1425 | if self.config.problem_type is None: 1426 | if self.num_labels == 1: 1427 | self.config.problem_type = "regression" 1428 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 1429 | self.config.problem_type = "single_label_classification" 1430 | else: 1431 | self.config.problem_type = "multi_label_classification" 1432 | 1433 | if self.config.problem_type == "regression": 1434 | loss_fct = MSELoss() 1435 | if self.num_labels == 1: 1436 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) 1437 | else: 1438 | loss = loss_fct(pooled_logits, labels) 1439 | elif self.config.problem_type == "single_label_classification": 1440 | loss_fct = CrossEntropyLoss() 1441 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) 1442 | elif self.config.problem_type == "multi_label_classification": 1443 | loss_fct = BCEWithLogitsLoss() 1444 | loss = loss_fct(pooled_logits, labels) 1445 | if not return_dict: 1446 | output = (pooled_logits,) + transformer_outputs[1:] 1447 | return ((loss,) + output) if loss is not None else output 1448 | 1449 | return SequenceClassifierOutputWithPast( 1450 | loss=loss, 1451 | logits=pooled_logits, 1452 | past_key_values=transformer_outputs.past_key_values, 1453 | hidden_states=transformer_outputs.hidden_states, 1454 | attentions=transformer_outputs.attentions, 1455 | ) 1456 | 1457 | 1458 | @add_start_docstrings( 1459 | """ 1460 | GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output), e.g. for 1461 | Named-Entity-Recognition (NER) tasks.
1462 | """, 1463 | GPT2_START_DOCSTRING, 1464 | ) 1465 | class GPT2ForTokenClassification(GPT2PreTrainedModel): 1466 | def __init__(self, config): 1467 | super().__init__(config) 1468 | self.num_labels = config.num_labels 1469 | 1470 | self.transformer = GPT2Model(config) 1471 | if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: 1472 | classifier_dropout = config.classifier_dropout 1473 | elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: 1474 | classifier_dropout = config.hidden_dropout 1475 | else: 1476 | classifier_dropout = 0.1 1477 | self.dropout = nn.Dropout(classifier_dropout) 1478 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 1479 | 1480 | # Model parallel 1481 | self.model_parallel = False 1482 | self.device_map = None 1483 | 1484 | # Initialize weights and apply final processing 1485 | self.post_init() 1486 | 1487 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) 1488 | # fmt: off 1489 | @add_code_sample_docstrings( 1490 | checkpoint="brad1141/gpt2-finetuned-comp2", 1491 | output_type=TokenClassifierOutput, 1492 | config_class=_CONFIG_FOR_DOC, 1493 | expected_loss=0.25, 1494 | expected_output=[ 1495 | "Lead", 1496 | "Lead", 1497 | "Lead", 1498 | "Position", 1499 | "Lead", 1500 | "Lead", 1501 | "Lead", 1502 | "Lead", 1503 | "Lead", 1504 | "Lead", 1505 | "Lead", 1506 | "Lead", 1507 | ], 1508 | ) 1509 | # fmt: on 1510 | def forward( 1511 | self, 1512 | input_ids: Optional[torch.LongTensor] = None, 1513 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 1514 | attention_mask: Optional[torch.FloatTensor] = None, 1515 | token_type_ids: Optional[torch.LongTensor] = None, 1516 | position_ids: Optional[torch.LongTensor] = None, 1517 | head_mask: Optional[torch.FloatTensor] = None, 1518 | inputs_embeds: Optional[torch.FloatTensor] = None, 1519 | labels: Optional[torch.LongTensor] = None, 1520 | use_cache: Optional[bool] = None, 1521 | output_attentions: 
Optional[bool] = None, 1522 | output_hidden_states: Optional[bool] = None, 1523 | return_dict: Optional[bool] = None, 1524 | ) -> Union[Tuple, TokenClassifierOutput]: 1525 | r""" 1526 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1527 | Labels for computing the token classification loss. Indices should be in `[0, ..., 1528 | config.num_labels - 1]`. A classification loss (Cross-Entropy) is computed over every token; positions 1529 | labelled `-100` (the `CrossEntropyLoss` default `ignore_index`) are ignored. 1530 | """ 1531 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1532 | 1533 | transformer_outputs = self.transformer( 1534 | input_ids, 1535 | past_key_values=past_key_values, 1536 | attention_mask=attention_mask, 1537 | token_type_ids=token_type_ids, 1538 | position_ids=position_ids, 1539 | head_mask=head_mask, 1540 | inputs_embeds=inputs_embeds, 1541 | use_cache=use_cache, 1542 | output_attentions=output_attentions, 1543 | output_hidden_states=output_hidden_states, 1544 | return_dict=return_dict, 1545 | ) 1546 | 1547 | hidden_states = transformer_outputs[0] 1548 | hidden_states = self.dropout(hidden_states) 1549 | logits = self.classifier(hidden_states) 1550 | 1551 | loss = None 1552 | if labels is not None: 1553 | labels = labels.to(logits.device) 1554 | loss_fct = CrossEntropyLoss() 1555 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1556 | 1557 | if not return_dict: 1558 | output = (logits,) + transformer_outputs[2:] 1559 | return ((loss,) + output) if loss is not None else output 1560 | 1561 | return TokenClassifierOutput( 1562 | loss=loss, 1563 | logits=logits, 1564 | hidden_states=transformer_outputs.hidden_states, 1565 | attentions=transformer_outputs.attentions, 1566 | ) 1567 | 1568 | 1569 | @add_start_docstrings( 1570 | """ 1571 | The GPT-2 Model transformer with a span classification head on top for extractive question-answering tasks
like 1572 | SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 1573 | """, 1574 | GPT2_START_DOCSTRING, 1575 | ) 1576 | class GPT2ForQuestionAnswering(GPT2PreTrainedModel): 1577 | def __init__(self, config): 1578 | super().__init__(config) 1579 | self.num_labels = config.num_labels 1580 | self.transformer = GPT2Model(config) 1581 | self.qa_outputs = nn.Linear(config.hidden_size, 2) 1582 | 1583 | # Model parallel 1584 | self.model_parallel = False 1585 | self.device_map = None 1586 | 1587 | # Initialize weights and apply final processing 1588 | self.post_init() 1589 | 1590 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1591 | @add_code_sample_docstrings( 1592 | checkpoint=_CHECKPOINT_FOR_DOC, 1593 | output_type=QuestionAnsweringModelOutput, 1594 | config_class=_CONFIG_FOR_DOC, 1595 | real_checkpoint=_CHECKPOINT_FOR_DOC, 1596 | ) 1597 | def forward( 1598 | self, 1599 | input_ids: Optional[torch.LongTensor] = None, 1600 | attention_mask: Optional[torch.FloatTensor] = None, 1601 | token_type_ids: Optional[torch.LongTensor] = None, 1602 | position_ids: Optional[torch.LongTensor] = None, 1603 | head_mask: Optional[torch.FloatTensor] = None, 1604 | inputs_embeds: Optional[torch.FloatTensor] = None, 1605 | start_positions: Optional[torch.LongTensor] = None, 1606 | end_positions: Optional[torch.LongTensor] = None, 1607 | output_attentions: Optional[bool] = None, 1608 | output_hidden_states: Optional[bool] = None, 1609 | return_dict: Optional[bool] = None, 1610 | ) -> Union[Tuple, QuestionAnsweringModelOutput]: 1611 | r""" 1612 | start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1613 | Labels for position (index) of the start of the labelled span for computing the token classification loss. 1614 | Positions are clamped to the length of the sequence (`sequence_length`). 
Positions outside of the sequence 1615 | are not taken into account for computing the loss. 1616 | end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1617 | Labels for position (index) of the end of the labelled span for computing the token classification loss. 1618 | Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence 1619 | are not taken into account for computing the loss. 1620 | """ 1621 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1622 | 1623 | outputs = self.transformer( 1624 | input_ids, 1625 | attention_mask=attention_mask, 1626 | token_type_ids=token_type_ids, 1627 | position_ids=position_ids, 1628 | head_mask=head_mask, 1629 | inputs_embeds=inputs_embeds, 1630 | output_attentions=output_attentions, 1631 | output_hidden_states=output_hidden_states, 1632 | return_dict=return_dict, 1633 | ) 1634 | 1635 | sequence_output = outputs[0] 1636 | 1637 | logits = self.qa_outputs(sequence_output) 1638 | start_logits, end_logits = logits.split(1, dim=-1) 1639 | start_logits = start_logits.squeeze(-1).contiguous() 1640 | end_logits = end_logits.squeeze(-1).contiguous() 1641 | 1642 | total_loss = None 1643 | if start_positions is not None and end_positions is not None: 1644 | # If we are on multi-GPU, the positions may carry an extra dimension; squeeze it 1645 | if len(start_positions.size()) > 1: 1646 | start_positions = start_positions.squeeze(-1).to(start_logits.device) 1647 | if len(end_positions.size()) > 1: 1648 | end_positions = end_positions.squeeze(-1).to(end_logits.device) 1649 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 1650 | ignored_index = start_logits.size(1) 1651 | start_positions = start_positions.clamp(0, ignored_index) 1652 | end_positions = end_positions.clamp(0, ignored_index) 1653 | 1654 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 1655 | start_loss = loss_fct(start_logits, start_positions) 1656 | end_loss
= loss_fct(end_logits, end_positions) 1657 | total_loss = (start_loss + end_loss) / 2 1658 | 1659 | if not return_dict: 1660 | output = (start_logits, end_logits) + outputs[2:] 1661 | return ((total_loss,) + output) if total_loss is not None else output 1662 | 1663 | return QuestionAnsweringModelOutput( 1664 | loss=total_loss, 1665 | start_logits=start_logits, 1666 | end_logits=end_logits, 1667 | hidden_states=outputs.hidden_states, 1668 | attentions=outputs.attentions, 1669 | ) 1670 | 1671 | 1672 | __all__ = [ 1673 | "GPT2DoubleHeadsModel", 1674 | "GPT2ForQuestionAnswering", 1675 | "GPT2ForSequenceClassification", 1676 | "GPT2ForTokenClassification", 1677 | "GPT2LMHeadModel", 1678 | "GPT2Model", 1679 | "GPT2PreTrainedModel", 1680 | "load_tf_weights_in_gpt2", 1681 | ] 1682 | -------------------------------------------------------------------------------- /train_gpt2.py: -------------------------------------------------------------------------------- 1 | """ 2 | This training script can be run both on a single gpu in debug mode, 3 | and also in a larger training run with distributed data parallel (ddp). 
4 | 5 | To run on a single GPU, example: 6 | $ python train_gpt2.py (edit batch_size, compile, etc. in the config block below; command-line overrides are not parsed) 7 | 8 | To run with DDP on 4 gpus on 1 node, example: 9 | $ torchrun --standalone --nproc_per_node=4 train_gpt2.py 10 | """ 11 | 12 | import os 13 | import time 14 | import math 15 | import pickle 16 | from contextlib import nullcontext 17 | import inspect 18 | 19 | import numpy as np 20 | import torch 21 | from torch.nn.parallel import DistributedDataParallel as DDP 22 | from torch.distributed import init_process_group, destroy_process_group 23 | from modeling_gpt2 import GPT2Model, GPT2LMHeadModel, GPT2Config 24 | 25 | 26 | overall_name = "gpt2_llama_0.5B_4k_100k" 27 | # ----------------------------------------------------------------------------- 28 | # default config values for training the ~0.5B-parameter GPT-2 baseline on OpenWebText (LLaMA tokenizer) 29 | # I/O 30 | out_dir = 'out_' + overall_name 31 | eval_interval = 1000 32 | log_interval = 1 33 | eval_iters = 200 34 | eval_only = False # if True, script exits right after the first eval 35 | always_save_checkpoint = True # if True, always save a checkpoint after each eval 36 | init_from = 'scratch' # 'scratch' or 'resume' 37 | # wandb logging 38 | wandb_log = True # set to False to disable wandb logging 39 | wandb_project = 'owt' 40 | wandb_run_name = overall_name 41 | # data 42 | dataset = 'openwebtext_llama' 43 | total_batch_size = 524288 # unused, kept for reference: global_batch * block_size = 512 * 1024 = 524,288 tokens per step 44 | global_batch = 512 45 | batch_size = 8 # if gradient_accumulation_steps > 1, this is the micro-batch size 46 | block_size = 1024 47 | #gradient_accumulation_steps = total_batch_size // (batch_size * block_size) # used to simulate larger batch sizes 48 | gradient_accumulation_steps = global_batch // batch_size 49 | 50 | # adamw optimizer 51 | learning_rate = 30e-4 # max learning rate 52 | max_iters = 100000 # total number of training iterations 53 | weight_decay = 0.1 54 | beta1 = 0.9 55 | beta2 = 0.95 56 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 57 | # learning rate decay settings 58 |
decay_lr = True # whether to decay the learning rate 59 | warmup_iters = 2000 # how many steps to warm up for 60 | lr_decay_iters = 100000 # should be ~= max_iters per Chinchilla 61 | min_lr = 0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 62 | # DDP settings 63 | backend = 'nccl' # 'nccl', 'gloo', etc. 64 | # system 65 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 66 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 67 | compile = True # use PyTorch 2.0 to compile the model to be faster 68 | # ----------------------------------------------------------------------------- 69 | config = {"lr": learning_rate, "weight_decay": weight_decay, "decay_lr": decay_lr, "warmup_iters": warmup_iters, "min_lr": min_lr} 70 | # various inits, derived attributes, I/O setup 71 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? 72 | if ddp: 73 | init_process_group(backend=backend) 74 | ddp_rank = int(os.environ['RANK']) 75 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 76 | ddp_world_size = int(os.environ['WORLD_SIZE']) 77 | device = f'cuda:{ddp_local_rank}' 78 | torch.cuda.set_device(device) 79 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 
80 | seed_offset = ddp_rank # each process gets a different seed 81 | # world_size number of processes will be training simultaneously, so we can scale 82 | # down the desired gradient accumulation iterations per process proportionally 83 | assert gradient_accumulation_steps % ddp_world_size == 0 84 | gradient_accumulation_steps //= ddp_world_size 85 | else: 86 | # if not ddp, we are running on a single gpu, and one process 87 | master_process = True 88 | seed_offset = 0 89 | ddp_world_size = 1 90 | tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size 91 | if master_process: 92 | print(f"tokens per iteration will be: {tokens_per_iter:,}") 93 | 94 | if master_process: 95 | os.makedirs(out_dir, exist_ok=True) 96 | torch.manual_seed(1337 + seed_offset) 97 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 98 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 99 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 100 | # note: float16 data type will automatically use a GradScaler 101 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 102 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 103 | 104 | # poor man's data loader 105 | data_dir = os.path.join('data', dataset) 106 | def get_batch(split): 107 | # We recreate np.memmap every batch to avoid a memory leak, as per 108 | # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122 109 | if split == 'train': 110 | data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 111 | else: 112 | data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') 113 | ix = torch.randint(len(data) - block_size, (batch_size,)) 114 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 115 | y = 
torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) # identical to x on purpose: the LMHeadModel is expected to shift labels by one position internally, as Hugging Face's GPT2LMHeadModel does 116 | if device_type == 'cuda': 117 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 118 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 119 | else: 120 | x, y = x.to(device), y.to(device) 121 | return x, y 122 | 123 | def configure_optimizers(model, weight_decay, learning_rate, device_type): 124 | param_dict = {pn: p for pn, p in model.named_parameters()} 125 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 126 | 127 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 128 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 129 | optim_groups = [ 130 | {'params': decay_params, 'weight_decay': weight_decay}, 131 | {'params': nodecay_params, 'weight_decay': 0.0} 132 | ] 133 | num_decay_params = sum(p.numel() for p in decay_params) 134 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 135 | if master_process: 136 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") 137 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") 138 | # Create AdamW optimizer and use the fused version if it is available 139 | fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters 140 | use_fused = fused_available and device_type == "cuda" 141 | if master_process: 142 | print(f"using fused AdamW: {use_fused}") 143 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(beta1, beta2), eps=1e-8, fused=use_fused) 144 | return optimizer 145 | 146 | # init these up here, can override if init_from='resume' (i.e.
from a checkpoint) 147 | iter_num = 0 148 | best_val_loss = 1e9 149 | 150 | # model init 151 | '''model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, 152 | bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line''' 153 | if init_from == 'scratch': 154 | # init a new model from scratch 155 | if master_process: 156 | print("Initializing a new model from scratch") 157 | # vocab_size=32000 matches the LLaMA tokenizer used in data/openwebtext_llama/prepare.py 158 | 159 | #model = GPT2LMHeadModel(GPT2Config(vocab_size=50304, n_positions=block_size, n_embd=1024, n_layer=24, n_head=16, n_inner=4096)) 160 | model = GPT2LMHeadModel(GPT2Config(vocab_size=32000, n_positions=block_size, n_embd=1024, n_layer=24, n_head=16, n_inner=4096, bos_token_id=1, eos_token_id=2)) 161 | elif init_from == 'resume': 162 | print(f"Resuming training from {out_dir}") 163 | # resume training from a checkpoint. 164 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 165 | checkpoint = torch.load(ckpt_path, map_location=device) 166 | # the checkpoints saved below store 'model', 'optimizer', 'iter_num', 'best_val_loss' 167 | # and 'config' (the learning-rate schedule and weight-decay settings) but no model 168 | # arguments, so rebuild the architecture with the same hyperparameters as the 169 | # scratch branch; they must match the run that wrote the checkpoint 170 | model = GPT2LMHeadModel(GPT2Config(vocab_size=32000, n_positions=block_size, n_embd=1024, n_layer=24, n_head=16, n_inner=4096, bos_token_id=1, eos_token_id=2)) 171 | 172 | 173 | 174 | state_dict = checkpoint['model'] 175 | # fix the keys of the state dictionary: torch.compile keeps the original module as `_orig_mod`, 176 | # so a state_dict taken from the compiled model prefixes every key with '_orig_mod.' 177 | unwanted_prefix = '_orig_mod.'
178 | for k,v in list(state_dict.items()): 179 | if k.startswith(unwanted_prefix): 180 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 181 | model.load_state_dict(state_dict) 182 | iter_num = checkpoint['iter_num'] 183 | best_val_loss = checkpoint['best_val_loss'] 184 | 185 | model.to(device) 186 | 187 | # initialize a GradScaler. If enabled=False scaler is a no-op 188 | scaler = torch.GradScaler("cuda", enabled=(dtype == 'float16')) 189 | 190 | # optimizer 191 | optimizer = configure_optimizers(model, weight_decay, learning_rate, device_type) 192 | 193 | if init_from == 'resume': 194 | optimizer.load_state_dict(checkpoint['optimizer']) 195 | checkpoint = None # free up memory 196 | 197 | # compile the model 198 | if compile: 199 | if master_process: 200 | print("compiling the model... (takes ~a minute)") 201 | unoptimized_model = model 202 | model = torch.compile(model) # requires PyTorch 2.0 203 | 204 | # wrap model into DDP container 205 | if ddp: 206 | model = DDP(model, device_ids=[ddp_local_rank]) 207 | 208 | # helps estimate an arbitrarily accurate loss over either split using many batches 209 | @torch.no_grad() 210 | def estimate_loss(): 211 | out = {} 212 | model.eval() 213 | for split in ['train', 'val']: 214 | losses = torch.zeros(eval_iters) 215 | for k in range(eval_iters): 216 | X, Y = get_batch(split) 217 | with ctx: 218 | ret = model(input_ids=X, labels=Y) 219 | logits,loss = ret["logits"], ret["loss"] 220 | losses[k] = loss.item() 221 | out[split] = losses.mean() 222 | model.train() 223 | return out 224 | 225 | # learning rate decay scheduler (cosine with warmup) 226 | def get_lr(it): 227 | # 1) linear warmup for warmup_iters steps 228 | if it < warmup_iters: 229 | return learning_rate * (it + 1) / (warmup_iters + 1) 230 | # 2) if it > lr_decay_iters, return min learning rate 231 | if it > lr_decay_iters: 232 | return min_lr 233 | # 3) in between, use cosine decay down to min learning rate 234 | decay_ratio = (it - warmup_iters) /
(lr_decay_iters - warmup_iters) 235 | assert 0 <= decay_ratio <= 1 236 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 237 | return min_lr + coeff * (learning_rate - min_lr) 238 | 239 | # logging 240 | if wandb_log and master_process: 241 | import wandb 242 | wandb.init(project=wandb_project, name=wandb_run_name, config=config) 243 | 244 | # training loop 245 | X, Y = get_batch('train') # fetch the very first batch 246 | t0 = time.time() 247 | local_iter_num = 0 # number of iterations in the lifetime of this process 248 | raw_model = model.module if ddp else model # unwrap DDP container if needed 249 | running_mfu = -1.0 250 | while True: 251 | 252 | # determine and set the learning rate for this iteration 253 | lr = get_lr(iter_num) if decay_lr else learning_rate 254 | for param_group in optimizer.param_groups: 255 | param_group['lr'] = lr 256 | 257 | # evaluate the loss on train/val sets and write checkpoints 258 | if iter_num % eval_interval == 0 and master_process: 259 | losses = estimate_loss() 260 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 261 | if wandb_log: 262 | wandb.log({ 263 | "iter": iter_num, 264 | "train/loss": losses['train'], 265 | "val/loss": losses['val'], 266 | "lr": lr, 267 | }) 268 | if losses['val'] < best_val_loss or always_save_checkpoint: 269 | best_val_loss = losses['val'] 270 | if iter_num > 0: 271 | checkpoint = { 272 | 'model': raw_model.state_dict(), 273 | 'optimizer': optimizer.state_dict(), 274 | 'iter_num': iter_num, 275 | 'best_val_loss': best_val_loss, 276 | 'config': config, 277 | } 278 | print(f"saving checkpoint to {out_dir}") 279 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) 280 | if iter_num == 0 and eval_only: 281 | break 282 | 283 | # forward backward update, with optional gradient accumulation to simulate larger batch size 284 | # and using the GradScaler if data type is float16 285 | for micro_step in 
range(gradient_accumulation_steps): 286 | if ddp: 287 | # in DDP training we only need to sync gradients at the last micro step. 288 | # the official way to do this is with model.no_sync() context manager, but 289 | # I really dislike that this bloats the code and forces us to repeat code 290 | # looking at the source of that context manager, it just toggles this variable 291 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) 292 | with ctx: 293 | ret = model(input_ids=X, labels=Y) 294 | logits,loss = ret["logits"], ret["loss"] 295 | loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation 296 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 297 | X, Y = get_batch('train') 298 | # backward pass, with gradient scaling if training in fp16 299 | scaler.scale(loss).backward() 300 | # clip the gradient 301 | if grad_clip != 0.0: 302 | scaler.unscale_(optimizer) 303 | norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 304 | # step the optimizer and scaler if training in fp16 305 | scaler.step(optimizer) 306 | scaler.update() 307 | # flush the gradients as soon as we can, no need for this memory anymore 308 | optimizer.zero_grad(set_to_none=True) 309 | torch.cuda.synchronize() 310 | # timing and logging 311 | t1 = time.time() 312 | dt = t1 - t0 313 | t0 = t1 314 | if iter_num % log_interval == 0 and master_process: 315 | # get loss as float. 
note: this is a CPU-GPU sync point 316 | # scale up to undo the division above, approximating the true total loss (exact would have been a sum) 317 | lossf = loss.item() * gradient_accumulation_steps 318 | tokens_per_sec = tokens_per_iter / dt 319 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, lr {lr}, norm: {(norm.item() if grad_clip != 0.0 else float('nan')):.4f}, tok/sec: {tokens_per_sec}") 320 | iter_num += 1 321 | local_iter_num += 1 322 | 323 | # termination conditions 324 | if iter_num > max_iters: 325 | break 326 | 327 | if ddp: 328 | destroy_process_group() 329 | -------------------------------------------------------------------------------- /train_ngpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | This training script can be run both on a single gpu in debug mode, 3 | and also in a larger training run with distributed data parallel (ddp). 4 | 5 | To run on a single GPU, example: 6 | $ python train_ngpt.py (edit batch_size, compile, etc. in the config block below; command-line overrides are not parsed) 7 | 8 | To run with DDP on 4 gpus on 1 node, example: 9 | $ torchrun --standalone --nproc_per_node=4 train_ngpt.py 10 | 11 | """ 12 | 13 | import os 14 | import time 15 | import math 16 | import pickle 17 | from contextlib import nullcontext 18 | import inspect 19 | 20 | import numpy as np 21 | import torch 22 | from torch.nn.parallel import DistributedDataParallel as DDP 23 | from torch.distributed import init_process_group, destroy_process_group 24 | from modeling_ngpt import NgptModel, NgptLMHeadModel, NgptConfig 25 | 26 | # ----------------------------------------------------------------------------- 27 | # default config values for training the ~0.5B-parameter nGPT model on OpenWebText (LLaMA tokenizer) 28 | # I/O 29 | overall_name = "ngpt_llama_0.5B_1k_100k" 30 | 31 | out_dir = 'out_' + overall_name 32 | eval_interval = 2000 33 | log_interval = 1 34 | eval_iters = 200 35 | eval_only = False # if True, script exits right after the first eval 36 | always_save_checkpoint = True # if True, always save a checkpoint after each eval 37 |
init_from = 'scratch' # 'scratch' or 'resume' 38 | # wandb logging 39 | wandb_log = True # set to False to disable wandb logging 40 | wandb_project = 'owt' 41 | wandb_run_name = overall_name # 'run' + str(time.time()) 42 | # data 43 | dataset = 'openwebtext_llama' 44 | global_batch = 512 45 | batch_size = 4 # if gradient_accumulation_steps > 1, this is the micro-batch size 46 | block_size = 1024 47 | #gradient_accumulation_steps = total_batch_size // (batch_size * block_size) # used to simulate larger batch sizes 48 | gradient_accumulation_steps = global_batch // batch_size 49 | 50 | # adamw optimizer 51 | learning_rate = 30e-4 # max learning rate 52 | max_iters = 100000 # total number of training iterations 53 | weight_decay = 0.0 54 | beta1 = 0.9 55 | beta2 = 0.95 56 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 57 | # learning rate decay settings 58 | decay_lr = True # whether to decay the learning rate 59 | warmup_iters = 0 # how many steps to warm up for 60 | lr_decay_iters = 100000 # should be ~= max_iters per Chinchilla 61 | min_lr = 0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 62 | # DDP settings 63 | backend = 'nccl' # 'nccl', 'gloo', etc. 64 | # system 65 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 66 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 67 | compile = True # use PyTorch 2.0 to compile the model to be faster 68 | # ----------------------------------------------------------------------------- 69 | config = {"lr": learning_rate, "weight_decay": weight_decay, "decay_lr": decay_lr, "warmup_iters": warmup_iters, "min_lr": min_lr} 70 | # various inits, derived attributes, I/O setup 71 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
72 | if ddp: 73 | init_process_group(backend=backend) 74 | ddp_rank = int(os.environ['RANK']) 75 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 76 | ddp_world_size = int(os.environ['WORLD_SIZE']) 77 | device = f'cuda:{ddp_local_rank}' 78 | torch.cuda.set_device(device) 79 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 80 | seed_offset = ddp_rank # each process gets a different seed 81 | # world_size number of processes will be training simultaneously, so we can scale 82 | # down the desired gradient accumulation iterations per process proportionally 83 | assert gradient_accumulation_steps % ddp_world_size == 0 84 | gradient_accumulation_steps //= ddp_world_size 85 | else: 86 | # if not ddp, we are running on a single gpu, and one process 87 | master_process = True 88 | seed_offset = 0 89 | ddp_world_size = 1 90 | tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size 91 | if master_process: 92 | print(f"tokens per iteration will be: {tokens_per_iter:,}") 93 | 94 | if master_process: 95 | os.makedirs(out_dir, exist_ok=True) 96 | torch.manual_seed(1337 + seed_offset) 97 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 98 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 99 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 100 | # note: float16 data type will automatically use a GradScaler 101 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 102 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 103 | 104 | # poor man's data loader 105 | data_dir = os.path.join('data', dataset) 106 | def get_batch(split): 107 | # We recreate np.memmap every batch to avoid a memory leak, as per 108 | # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122 109 | if split == 
'train': 110 | data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 111 | else: 112 | data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') 113 | ix = torch.randint(len(data) - block_size, (batch_size,)) 114 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 115 | y = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) # identical to x on purpose: the LMHeadModel is expected to shift labels by one position internally, as Hugging Face's GPT2LMHeadModel does 116 | if device_type == 'cuda': 117 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 118 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 119 | else: 120 | x, y = x.to(device), y.to(device) 121 | return x, y 122 | 123 | def configure_optimizers(model, weight_decay, learning_rate, device_type): 124 | param_dict = {pn: p for pn, p in model.named_parameters()} 125 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 126 | 127 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 128 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 129 | optim_groups = [ 130 | {'params': decay_params, 'weight_decay': weight_decay}, 131 | {'params': nodecay_params, 'weight_decay': 0.0} 132 | ] 133 | num_decay_params = sum(p.numel() for p in decay_params) 134 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 135 | if master_process: 136 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") 137 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") 138 | # Create AdamW optimizer and use the fused version if it is available 139 | fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters 140 | use_fused = fused_available and device_type == "cuda" 141 | if master_process: 142 | print(f"using fused AdamW: {use_fused}") 143 | optimizer = torch.optim.AdamW(optim_groups,
lr=learning_rate, betas=(beta1, beta2), eps=1e-8, fused=use_fused) 144 | return optimizer 145 | 146 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint) 147 | iter_num = 0 148 | best_val_loss = 1e9 149 | 150 | # model init 151 | '''model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, 152 | bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line''' 153 | 154 | #model_config = NgptConfig(vocab_size=50304, n_embd=1024, n_layer=24, n_head=16, n_inner=4096) 155 | model_config = NgptConfig(vocab_size=32000, n_positions=block_size, n_embd=1024, n_layer=24, n_head=16, n_inner=4096, bos_token_id = 1, eos_token_id = 2) 156 | if init_from == 'scratch': 157 | # init a new model from scratch 158 | if master_process: 159 | print("Initializing a new model from scratch") 160 | # vocab_size=32000 matches the LLaMA tokenizer used in data/openwebtext_llama/prepare.py 161 | 162 | model = NgptLMHeadModel(model_config) 163 | elif init_from == 'resume': 164 | print(f"Resuming training from {out_dir}") 165 | # resume training from a checkpoint. 166 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 167 | checkpoint = torch.load(ckpt_path, map_location=device) 168 | # rebuild the architecture from model_config (defined above) instead of 169 | # reading model arguments from the checkpoint; model_config must match the 170 | # run that wrote the checkpoint so that the saved weights load correctly 171 | model = NgptLMHeadModel(model_config) 172 | 173 | 174 | 175 | 176 | state_dict = checkpoint['model'] 177 | # fix the keys of the state dictionary: torch.compile keeps the original module as `_orig_mod`, 178 | # so a state_dict taken from the compiled model prefixes every key with '_orig_mod.' 179 | unwanted_prefix = '_orig_mod.'
180 | for k,v in list(state_dict.items()): 181 | if k.startswith(unwanted_prefix): 182 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 183 | model.load_state_dict(state_dict) 184 | iter_num = checkpoint['iter_num'] 185 | best_val_loss = checkpoint['best_val_loss'] 186 | 187 | model.to(device) 188 | 189 | # initialize a GradScaler. If enabled=False scaler is a no-op 190 | scaler = torch.GradScaler("cuda", enabled=(dtype == 'float16')) 191 | 192 | # optimizer 193 | optimizer = configure_optimizers(model, weight_decay, learning_rate, device_type) 194 | 195 | if init_from == 'resume': 196 | optimizer.load_state_dict(checkpoint['optimizer']) 197 | checkpoint = None # free up memory 198 | 199 | # compile the model 200 | if compile: 201 | if master_process: 202 | print("compiling the model... (takes ~a minute)") 203 | unoptimized_model = model 204 | model = torch.compile(model) # requires PyTorch 2.0 205 | 206 | # wrap model into DDP container 207 | if ddp: 208 | model = DDP(model, device_ids=[ddp_local_rank]) 209 | 210 | # helps estimate an arbitrarily accurate loss over either split using many batches 211 | @torch.no_grad() 212 | def estimate_loss(): 213 | out = {} 214 | model.eval() 215 | for split in ['train', 'val']: 216 | losses = torch.zeros(eval_iters) 217 | for k in range(eval_iters): 218 | X, Y = get_batch(split) 219 | with ctx: 220 | ret = model(input_ids=X, labels=Y) 221 | logits,loss = ret["logits"], ret["loss"] 222 | losses[k] = loss.item() 223 | out[split] = losses.mean() 224 | model.train() 225 | return out 226 | 227 | # learning rate decay scheduler (cosine with warmup) 228 | def get_lr(it): 229 | # 1) linear warmup for warmup_iters steps 230 | if it < warmup_iters: 231 | return learning_rate * (it + 1) / (warmup_iters + 1) 232 | # 2) if it > lr_decay_iters, return min learning rate 233 | if it > lr_decay_iters: 234 | return min_lr 235 | # 3) in between, use cosine decay down to min learning rate 236 | decay_ratio = (it - warmup_iters) /
(lr_decay_iters - warmup_iters) 237 | assert 0 <= decay_ratio <= 1 238 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 239 | return min_lr + coeff * (learning_rate - min_lr) 240 | 241 | # logging 242 | if wandb_log and master_process: 243 | import wandb 244 | wandb.init(project=wandb_project, name=wandb_run_name, config=config) 245 | 246 | # training loop 247 | X, Y = get_batch('train') # fetch the very first batch 248 | t0 = time.time() 249 | local_iter_num = 0 # number of iterations in the lifetime of this process 250 | raw_model = model.module if ddp else model # unwrap DDP container if needed 251 | running_mfu = -1.0 252 | 253 | def justnorm(x, idim): 254 | dtype = x.dtype 255 | x = x.float() 256 | res = (x / x.norm(p=2, dim=idim, keepdim=True)).to(dtype=dtype) 257 | return res 258 | 259 | def normalize_matrices(): 260 | #print(raw_model.test.weight.data.shape) 261 | #print(raw_model.transformer.wte.weight.data.shape) 262 | #print(raw_model.lm_head.weight.data.shape) 263 | raw_model.transformer.wte.weight.data.copy_(justnorm(raw_model.transformer.wte.weight.data, 1)) 264 | raw_model.lm_head.weight.data.copy_(justnorm(raw_model.lm_head.weight.data, 1)) 265 | 266 | for layer_idx in range(model_config.num_hidden_layers): 267 | block = raw_model.transformer.h[layer_idx] 268 | 269 | #print(block.attn.c_attn.weight.data.shape) 270 | block.attn.c_attn.weight.data.copy_(justnorm(block.attn.c_attn.weight.data, 0)) #n_embd, 3*n_embd 271 | #print(block.attn.c_proj.weight.data.shape) 272 | block.attn.c_proj.weight.data.copy_(justnorm(block.attn.c_proj.weight.data, 1)) #n_proj, n_embd 273 | 274 | #print(block.mlp.c_fc.weight.data.shape) 275 | #print(block.mlp.c_proj.weight.data.shape) 276 | block.mlp.c_fc.weight.data.copy_(justnorm(block.mlp.c_fc.weight.data, 0)) 277 | block.mlp.c_proj.weight.data.copy_(justnorm(block.mlp.c_proj.weight.data, 1)) 278 | 279 | normalize_matrices() 280 | 281 | while True: 282 | 283 | # determine and set the learning 
rate for this iteration 284 | lr = get_lr(iter_num) if decay_lr else learning_rate 285 | for param_group in optimizer.param_groups: 286 | param_group['lr'] = lr 287 | 288 | # evaluate the loss on train/val sets and write checkpoints 289 | if iter_num % eval_interval == 0 and master_process: 290 | losses = estimate_loss() 291 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 292 | if wandb_log: 293 | wandb.log({ 294 | "iter": iter_num, 295 | "train/loss": losses['train'], 296 | "val/loss": losses['val'], 297 | "lr": lr, 298 | }) 299 | if losses['val'] < best_val_loss or always_save_checkpoint: 300 | best_val_loss = losses['val'] 301 | if iter_num > 0: 302 | checkpoint = { 303 | 'model': raw_model.state_dict(), 304 | 'optimizer': optimizer.state_dict(), 305 | 'iter_num': iter_num, 306 | 'best_val_loss': best_val_loss, 307 | 'config': config, 308 | } 309 | print(f"saving checkpoint to {out_dir}") 310 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) 311 | if iter_num == 0 and eval_only: 312 | break 313 | 314 | # forward backward update, with optional gradient accumulation to simulate larger batch size 315 | # and using the GradScaler if data type is float16 316 | for micro_step in range(gradient_accumulation_steps): 317 | if ddp: 318 | # in DDP training we only need to sync gradients at the last micro step. 
            # the official way to do this is with model.no_sync() context manager, but
            # I really dislike that this bloats the code and forces us to repeat code
            # looking at the source of that context manager, it just toggles this variable
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
        with ctx:
            ret = model(input_ids=X, labels=Y)
            logits, loss = ret["logits"], ret["loss"]
            loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
        # immediately async prefetch next batch while model is doing the forward pass on the GPU
        X, Y = get_batch('train')
        # backward pass, with gradient scaling if training in fp16
        scaler.scale(loss).backward()
    # clip the gradient; norm stays None when clipping is disabled (grad_clip == 0.0)
    norm = None
    if grad_clip != 0.0:
        scaler.unscale_(optimizer)
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    # step the optimizer and scaler if training in fp16
    scaler.step(optimizer)
    scaler.update()
    # flush the gradients as soon as we can, no need for this memory anymore
    optimizer.zero_grad(set_to_none=True)
    torch.cuda.synchronize()

    # re-normalize the weight matrices after every optimizer step
    normalize_matrices()
    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        # get loss as float. note: this is a CPU-GPU sync point
        # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
        lossf = loss.item() * gradient_accumulation_steps
        tokens_per_sec = tokens_per_iter / dt
        norm_str = f"{norm.item():.4f}" if norm is not None else "n/a"
        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, lr {lr}, norm: {norm_str}, tok/sec: {tokens_per_sec}")
    iter_num += 1
    local_iter_num += 1

    # termination conditions
    if iter_num > max_iters:
        break

if ddp:
    destroy_process_group()

--------------------------------------------------------------------------------
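The `justnorm` / `normalize_matrices()` pair in `train_ngpt.py` re-projects weights onto the unit hypersphere after every optimizer step. The standalone sketch below isolates that step on a hypothetical toy `nn.Embedding` (10 tokens, embedding dimension 4) and a hypothetical Conv1D-style matrix, assuming the repository keeps GPT-2's Conv1D layout of `(n_embd, 3*n_embd)` for `c_attn`, as the script's shape comment suggests. It shows why the embedding table is normalized along dim 1 (rows) while `c_attn` is normalized along dim 0 (columns).

```python
import torch

def justnorm(x: torch.Tensor, idim: int) -> torch.Tensor:
    # same computation as the training script: divide by the L2 norm along idim,
    # doing the arithmetic in float32 and casting back to the input dtype
    dtype = x.dtype
    x = x.float()
    return (x / x.norm(p=2, dim=idim, keepdim=True)).to(dtype=dtype)

torch.manual_seed(0)

# hypothetical toy embedding table: weight shape (num_tokens, n_embd) = (10, 4)
wte = torch.nn.Embedding(10, 4)

# pretend an optimizer step pushed every row off the unit sphere
with torch.no_grad():
    wte.weight.mul_(3.0)

# in-place re-projection, mirroring normalize_matrices(): rows are embedding vectors
wte.weight.data.copy_(justnorm(wte.weight.data, 1))
row_norms = wte.weight.detach().norm(dim=1)

# hypothetical Conv1D-style weight of shape (n_embd, 3*n_embd): the embedding-dim
# vectors are the columns, which is why the script normalizes c_attn along dim 0
w = torch.randn(4, 12)
col_norms = justnorm(w, 0).norm(dim=0)

print(row_norms)
print(col_norms)
```

Doing the division in float32 and casting back means a bfloat16 weight stays bfloat16 after normalization.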
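`get_lr` in `train_ngpt.py` reads module-level globals, which makes its shape hard to see in isolation. The sketch below is the same warmup-plus-cosine logic with the globals turned into parameters; the hyperparameter values are hypothetical and chosen only so the numbers are easy to read, not taken from this repository's configuration.

```python
import math

def get_lr(it, learning_rate, min_lr, warmup_iters, lr_decay_iters):
    # 1) linear warmup: it = 0 .. warmup_iters-1 gives learning_rate * (it+1)/(warmup_iters+1)
    if it < warmup_iters:
        return learning_rate * (it + 1) / (warmup_iters + 1)
    # 2) past the decay horizon, hold at the floor
    if it > lr_decay_iters:
        return min_lr
    # 3) cosine decay: learning_rate at it == warmup_iters, min_lr at it == lr_decay_iters
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1 down to 0
    return min_lr + coeff * (learning_rate - min_lr)

# hypothetical hyperparameters, for illustration only
hp = dict(learning_rate=1e-3, min_lr=1e-4, warmup_iters=9, lr_decay_iters=109)
for it in (0, 8, 9, 59, 109, 200):
    print(it, get_lr(it, **hp))
```

Because of the `+ 1` in both numerator and denominator, the warmup tops out just below `learning_rate` (9e-4 at `it = 8` here) and the full rate is first used at `it = warmup_iters`, where the cosine branch starts with a coefficient of 1.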