├── LICENSE
├── README.md
├── configuration_gpt2.py
├── configuration_ngpt.py
├── data
│   └── openwebtext_llama
│       └── prepare.py
├── images
│   ├── 4k_arceasy.png
│   ├── 4k_average.png
│   ├── 4k_hellaswag.png
│   ├── 4k_lambada.png
│   ├── 4k_loss.png
│   ├── 4k_winogrande.png
│   ├── 4k_wsc273.png
│   ├── arc_easy.png
│   ├── average.png
│   ├── hellaswag.png
│   ├── lambada.png
│   ├── loss.png
│   ├── network_params.png
│   ├── optimizer_params.png
│   ├── wall_clock.png
│   ├── winogrande.png
│   └── wsc273.png
├── modeling_gpt2.py
├── modeling_ngpt.py
├── train_gpt2.py
└── train_ngpt.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Joe Li
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Sponsored by Nous Research Inc., 2025.
2 |
3 | # nGPT
4 | This is an open-source reproduction of NVIDIA's [nGPT](https://arxiv.org/abs/2410.01131) (Normalized Transformer with Representation Learning on the Hypersphere) paper by Loshchilov et al., which claims to reduce "the number of training steps required to achieve the same accuracy by a factor of 4 to 20, depending on the sequence length," compared to a baseline transformer model.
5 | ## Project Overview
6 | This repository provides modeling and training code for a modified GPT-2 model and an nGPT model, as well as our results. Both models were pre-trained on [OpenWebText](https://huggingface.co/datasets/Skylion007/openwebtext). We attempt to adhere to the paper's specifications as closely as possible.
7 |
8 | ### Dependencies
9 |
10 | - **Hugging Face [transformers](https://github.com/huggingface/transformers) library**: `modeling_ngpt.py` and `modeling_gpt2.py` extend the `PreTrainedModel` class that Hugging Face provides.
11 | - **[nanoGPT](https://github.com/karpathy/nanoGPT)**: the training and data preparation code builds on this repository (`train_ngpt.py`, `train_gpt2.py`, `data/openwebtext_llama/prepare.py`).
12 | - **[EleutherAI/lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)**: used for the HellaSwag, ARC-Easy, WinoGrande, WSC273, and LAMBADA (OpenAI version) evals.
13 |
14 | ### Key Modifications
15 | #### Tokenization
16 |
17 | The LLaMA-2 tokenizer (vocab size 32000) is used instead of the GPT-2 tokenizer (vocab size 50257). See `data/openwebtext_llama/prepare.py`.
18 | #### GPT-2
19 |
20 | - SwiGLU activation function
21 | - Rotary Position Embeddings (RoPE)
22 | - No weight tying between the input token embedding layer and the final output logits layer
23 |
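The first two changes can be made concrete with a minimal NumPy sketch of a SwiGLU feed-forward block and of rotary position embeddings. It is illustrative only and is not the code in `modeling_gpt2.py`, which imports `RotaryPositionalEmbeddings` from `torchtune.modules` for RoPE. The weight names `w_gate`, `w_up`, and `w_down`, the interleaved-pair RoPE layout, and the base of 10000 are assumptions. RoPE implementations differ in how they pair dimensions.

```python
import numpy as np


def silu(x: np.ndarray) -> np.ndarray:
    """SiLU (swish) activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))


def swiglu_ffn(x, w_gate, w_up, w_down):
    """SwiGLU feed-forward block: (silu(x @ w_gate) * (x @ w_up)) @ w_down."""
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down


def rope(x: np.ndarray, positions: np.ndarray, base: float = 10000.0) -> np.ndarray:
    """Rotary position embedding over interleaved pairs (x[2i], x[2i+1]).

    x: (seq_len, d) with d even; positions: (seq_len,).
    Pair i at position p is rotated by the angle p * base ** (-2i / d).
    """
    d = x.shape[-1]
    inv_freq = base ** (-np.arange(0, d, 2) / d)      # (d/2,) rotation frequencies
    angles = positions[:, None] * inv_freq[None, :]   # (seq_len, d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x_even * cos - x_odd * sin         # 2-D rotation of each pair
    out[:, 1::2] = x_even * sin + x_odd * cos
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))                       # 4 tokens, d_model = 8
y = swiglu_ffn(x,
               rng.standard_normal((8, 16)),          # w_gate
               rng.standard_normal((8, 16)),          # w_up
               rng.standard_normal((16, 8)))          # w_down
print(y.shape)                                        # (4, 8)
```

Each pair is rotated by an angle proportional to its position. As a result, the dot product between a rotated query at position m and a rotated key at position n depends only on m - n, which makes RoPE a relative position encoding.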
24 | #### nGPT
25 | See paper for detailed explanations, particularly Section 2.6. Briefly, all hidden vectors and weight vectors that lie along the embedding dimension are normalized to have unit norm and lie on the same unit hypersphere.
26 |
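A minimal NumPy sketch of the two operations this implies follows. The first projects vectors onto the unit hypersphere along the embedding dimension. The second is a normalized residual update of the form `h <- Norm(h + alpha * (Norm(block_out) - h))`, which is our reading of the paper's update rule, with `alpha` a learnable per-dimension step size (the paper's "eigen learning rates"). The function names and `eps` are hypothetical, and this is not the code in `modeling_ngpt.py`.

```python
import numpy as np


def l2_normalize(x: np.ndarray, axis: int = -1, eps: float = 1e-8) -> np.ndarray:
    """Project vectors onto the unit hypersphere along `axis` (the embedding dimension)."""
    return x / (np.linalg.norm(x, axis=axis, keepdims=True) + eps)


def normalized_residual_update(h: np.ndarray, block_out: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    """Move h toward the normalized block output by a per-dimension step alpha, then re-normalize.

    h, block_out: (..., d); alpha: (d,). With alpha = 0 the hidden state is unchanged.
    """
    return l2_normalize(h + alpha * (l2_normalize(block_out) - h))


rng = np.random.default_rng(0)
d = 16
emb = l2_normalize(rng.standard_normal((100, d)))   # each embedding row has unit norm
h = emb[:4]                                          # hidden states start on the sphere
attn_out = rng.standard_normal((4, d))               # stand-in for an attention output
h = normalized_residual_update(h, attn_out, alpha=np.full(d, 0.05))
print(np.linalg.norm(h, axis=-1))                    # all approximately 1.0
logits = h @ emb.T                                   # cosine similarities in [-1, 1]
```

Both `h` and the embedding rows are unit-norm, so their dot products are cosine similarities bounded by 1 in magnitude. The paper applies a learnable scaling factor to such logits.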
27 | ### Training
28 |
29 | 0.5B models with 1024 and 4096 context lengths were trained on OpenWebText. We use the same hyperparameters as the nGPT paper, shown below, with an initial learning rate of 15e-4 for 1024 context length and 30e-4 for 4096 context length. Here is the [model card](https://huggingface.co/NousResearch/ngpt_0.5B_4k_200B) for a model trained on 200B tokens at 4k context length.
30 |
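31 |
32 | *(Figures: network hyperparameters, `images/network_params.png`, and optimizer hyperparameters, `images/optimizer_params.png`.)*
33 |
34 |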
35 | ### Results
36 |
37 | By visual inspection of the following graphs, we observe a speedup of roughly 1.5-2x at 1k context length and roughly 4x at 4k context length, both at ~200 billion tokens. Note that every data point on the following graphs represents a **separate** training run that ran to completion (i.e., each point has its own learning rate schedule).
38 |
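These speedups are read off the graphs by eye. One common way to make such a reading reproducible, sketched here under our own assumptions, is to interpolate the baseline's loss-versus-tokens curve. The question becomes how many tokens the baseline needs to reach the loss that nGPT reached. The ratio of that count to nGPT's token count is the speedup. The helper names and all numbers in the example are hypothetical, not results from this repository.

```python
from typing import List, Optional


def tokens_to_reach(target: float, tokens: List[float], losses: List[float]) -> Optional[float]:
    """Tokens the baseline needs to reach `target`, by linear interpolation.

    `tokens` is increasing and `losses` is decreasing; each (tokens, loss) pair is the
    final loss of one separate, fully completed run. Returns None if never reached.
    """
    if losses[0] <= target:
        return tokens[0]
    for t0, l0, t1, l1 in zip(tokens, losses, tokens[1:], losses[1:]):
        if l1 <= target:                       # target lies in [l1, l0), so l0 > l1
            frac = (l0 - target) / (l0 - l1)
            return t0 + frac * (t1 - t0)
    return None


def speedup(ngpt_tokens: float, ngpt_loss: float,
            gpt_tokens: List[float], gpt_losses: List[float]) -> Optional[float]:
    """Baseline tokens needed to match nGPT's loss, divided by the tokens nGPT used."""
    needed = tokens_to_reach(ngpt_loss, gpt_tokens, gpt_losses)
    return None if needed is None else needed / ngpt_tokens


# Made-up numbers for illustration only (billions of tokens, final loss per run).
gpt_tokens = [50, 100, 200, 400]
gpt_losses = [3.20, 3.10, 3.00, 2.90]
print(speedup(100, 3.00, gpt_tokens, gpt_losses))  # 2.0
print(speedup(50, 3.05, gpt_tokens, gpt_losses))   # about 3.0
```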
39 | #### Loss
40 |
41 |
42 | *(Figures: loss at 1k context, `images/loss.png`, and at 4k context, `images/4k_loss.png`.)*
43 |
44 |
45 | #### Performance on downstream tasks (1k context)
46 |
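47 | *(Figures: ARC-Easy, HellaSwag, LAMBADA, WinoGrande, WSC273, and average results at 1k context: `images/arc_easy.png`, `images/hellaswag.png`, `images/lambada.png`, `images/winogrande.png`, `images/wsc273.png`, `images/average.png`.)*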
58 | #### Performance on downstream tasks (4k context)
59 |
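60 | *(Figures: ARC-Easy, HellaSwag, LAMBADA, WinoGrande, WSC273, and average results at 4k context: `images/4k_arceasy.png`, `images/4k_hellaswag.png`, `images/4k_lambada.png`, `images/4k_winogrande.png`, `images/4k_wsc273.png`, `images/4k_average.png`.)*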
71 | #### Wall-Clock Time for Training on 100B Tokens with 8xH100
72 |
73 | *(Figure: wall-clock training time, `images/wall_clock.png`.)*
74 | ### Analysis
75 |
76 | We observe a smaller speedup than the nGPT paper reports. We attribute this mainly to parameter precision: the model parameters in this reproduction are stored in float32, while the original paper stores them in bfloat16. NVIDIA's open-source implementation notes the following:
77 |
78 | > In order to reflect our experimental setup of the paper where parameters of matrices are in bfloat16, we also set bfloat16 as the dtype of network parameters (all except embeddings). Apparently, the change from float32 to bfloat16 only moderately affects nGPT but greatly degrades performance of the baseline GPT.
79 |
80 | We observe the same pattern in our reproduction: our nGPT model closely matches the paper's experimental results, but our GPT-2 model performs better with float32 model parameters. For 0.5B models at 1k context length, the paper claims a 4x speedup with bfloat16 parameters at ~400 billion tokens, whereas our nGPT reproduction achieves roughly a 1.5-2x speedup with float32 parameters at ~200 billion tokens. For 0.5B models at 4k context length, the paper claims a 10x speedup with bfloat16 parameters at ~400 billion tokens, whereas our nGPT reproduction achieves roughly a 4x speedup with float32 parameters at ~200 billion tokens. We also observe greater speedups for longer training runs.
81 |
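One commonly cited mechanism for this precision sensitivity is assumed here; it is not something measured in this repository. bfloat16 keeps only 8 significant bits. An update much smaller than a weight's magnitude can therefore round away entirely when the weight is stored in bfloat16. The pure-Python sketch below emulates float32 and bfloat16 storage by rounding bit patterns (round-to-nearest-even, NaN not handled). The helper names are hypothetical.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bfloat16(x: float) -> float:
    """Round to the nearest bfloat16 value (round-to-nearest-even; NaN not handled)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    # bfloat16 is the top 16 bits of a float32; bias before truncating to round to nearest-even.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits = (bits >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def accumulate(store, start: float, update: float, steps: int) -> float:
    """Apply `steps` identical updates, re-rounding the stored weight after each one."""
    w = store(start)
    for _ in range(steps):
        w = store(w + update)
    return w


print(accumulate(to_float32, 1.0, 1e-3, 1000))   # close to 2.0
print(accumulate(to_bfloat16, 1.0, 1e-3, 1000))  # 1.0: every 1e-3 step rounds away
```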
--------------------------------------------------------------------------------
/configuration_gpt2.py:
--------------------------------------------------------------------------------
1 | """OpenAI GPT-2 configuration"""
2 |
3 | from collections import OrderedDict
4 | from typing import Any, List, Mapping, Optional
5 |
6 | from transformers import PreTrainedTokenizer, TensorType, is_torch_available
7 | from transformers import PretrainedConfig
8 | from transformers.utils import logging
9 |
10 |
11 | logger = logging.get_logger(__name__)
12 |
13 |
14 | class GPT2Config(PretrainedConfig):
15 | """
16 | This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to
17 | instantiate a GPT-2 model according to the specified arguments, defining the model architecture. Instantiating a
18 | configuration with the defaults will yield a similar configuration to that of the GPT-2
19 | [openai-community/gpt2](https://huggingface.co/openai-community/gpt2) architecture.
20 |
21 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
22 | documentation from [`PretrainedConfig`] for more information.
23 |
24 |
25 | Args:
26 | vocab_size (`int`, *optional*, defaults to 50257):
27 | Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
28 | `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`].
29 | n_positions (`int`, *optional*, defaults to 1024):
30 | The maximum sequence length that this model might ever be used with. Typically set this to something large
31 | just in case (e.g., 512 or 1024 or 2048).
32 | n_embd (`int`, *optional*, defaults to 768):
33 | Dimensionality of the embeddings and hidden states.
34 | n_layer (`int`, *optional*, defaults to 12):
35 | Number of hidden layers in the Transformer decoder.
36 | n_head (`int`, *optional*, defaults to 12):
37 | Number of attention heads for each attention layer in the Transformer decoder.
38 | n_inner (`int`, *optional*):
39 | Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
40 | activation_function (`str`, *optional*, defaults to `"gelu_new"`):
41 | Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
42 | resid_pdrop (`float`, *optional*, defaults to 0.1):
43 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
44 | embd_pdrop (`float`, *optional*, defaults to 0.1):
45 | The dropout ratio for the embeddings.
46 | attn_pdrop (`float`, *optional*, defaults to 0.1):
47 | The dropout ratio for the attention.
48 | layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
49 | The epsilon to use in the layer normalization layers.
50 | initializer_range (`float`, *optional*, defaults to 0.02):
51 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
52 | summary_type (`string`, *optional*, defaults to `"cls_index"`):
53 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
54 | [`TFGPT2DoubleHeadsModel`].
55 |
56 | Has to be one of the following options:
57 |
58 | - `"last"`: Take the last token hidden state (like XLNet).
59 | - `"first"`: Take the first token hidden state (like BERT).
60 | - `"mean"`: Take the mean of all tokens hidden states.
61 | - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2).
62 | - `"attn"`: Not implemented now, use multi-head attention.
63 | summary_use_proj (`bool`, *optional*, defaults to `True`):
64 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
65 | [`TFGPT2DoubleHeadsModel`].
66 |
67 | Whether or not to add a projection after the vector extraction.
68 | summary_activation (`str`, *optional*):
69 | Argument used when doing sequence summary. Used for the multiple choice head in
70 | [`GPT2DoubleHeadsModel`].
71 |
72 | Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation.
73 | summary_proj_to_labels (`bool`, *optional*, defaults to `True`):
74 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
75 | [`TFGPT2DoubleHeadsModel`].
76 |
77 | Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes.
78 | summary_first_dropout (`float`, *optional*, defaults to 0.0):
79 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
80 | [`TFGPT2DoubleHeadsModel`].
81 |
82 | The dropout ratio to be used after the projection and activation.
83 | scale_attn_weights (`bool`, *optional*, defaults to `True`):
84 | Scale attention weights by dividing by sqrt(head_dim).
85 | use_cache (`bool`, *optional*, defaults to `True`):
86 | Whether or not the model should return the last key/values attentions (not used by all models).
87 | bos_token_id (`int`, *optional*, defaults to 50256):
88 | Id of the beginning of sentence token in the vocabulary.
89 | eos_token_id (`int`, *optional*, defaults to 50256):
90 | Id of the end of sentence token in the vocabulary.
91 | scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
92 | Whether to additionally scale attention weights by `1 / (layer_idx + 1)`.
93 | reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
94 | Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
95 | dot-product/softmax to float() when training with mixed precision.
96 |
97 | Example:
98 |
99 | ```python
100 | >>> from transformers import GPT2Config, GPT2Model
101 |
102 | >>> # Initializing a GPT2 configuration
103 | >>> configuration = GPT2Config()
104 |
105 | >>> # Initializing a model (with random weights) from the configuration
106 | >>> model = GPT2Model(configuration)
107 |
108 | >>> # Accessing the model configuration
109 | >>> configuration = model.config
110 | ```"""
111 |
112 | model_type = "gpt2"
113 | keys_to_ignore_at_inference = ["past_key_values"]
114 | attribute_map = {
115 | "hidden_size": "n_embd",
116 | "max_position_embeddings": "n_positions",
117 | "num_attention_heads": "n_head",
118 | "num_hidden_layers": "n_layer",
119 | }
120 |
121 | def __init__(
122 | self,
123 | vocab_size=50257,
124 | n_positions=1024,
125 | n_embd=768,
126 | n_layer=12,
127 | n_head=12,
128 | n_inner=None,
129 | activation_function="gelu_new",
130 | resid_pdrop=0.0,
131 | embd_pdrop=0.0,
132 | attn_pdrop=0.0,
133 | layer_norm_epsilon=1e-5,
134 | initializer_range=0.02,
135 | summary_type="cls_index",
136 | summary_use_proj=True,
137 | summary_activation=None,
138 | summary_proj_to_labels=True,
139 | summary_first_dropout=0.0,
140 | scale_attn_weights=True,
141 | use_cache=True,
142 | bos_token_id=50256,
143 | eos_token_id=50256,
144 | scale_attn_by_inverse_layer_idx=False,
145 | reorder_and_upcast_attn=False,
146 | **kwargs,
147 | ):
148 | self.vocab_size = vocab_size
149 | self.n_positions = n_positions
150 | self.n_embd = n_embd
151 | self.n_layer = n_layer
152 | self.n_head = n_head
153 | self.n_inner = n_inner
154 | self.activation_function = activation_function
155 | self.resid_pdrop = resid_pdrop
156 | self.embd_pdrop = embd_pdrop
157 | self.attn_pdrop = attn_pdrop
158 | self.layer_norm_epsilon = layer_norm_epsilon
159 | self.initializer_range = initializer_range
160 | self.summary_type = summary_type
161 | self.summary_use_proj = summary_use_proj
162 | self.summary_activation = summary_activation
163 | self.summary_first_dropout = summary_first_dropout
164 | self.summary_proj_to_labels = summary_proj_to_labels
165 | self.scale_attn_weights = scale_attn_weights
166 | self.use_cache = use_cache
167 | self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
168 | self.reorder_and_upcast_attn = reorder_and_upcast_attn
169 |
170 | self.bos_token_id = bos_token_id
171 | self.eos_token_id = eos_token_id
172 |
173 | super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
174 |
175 |
176 |
--------------------------------------------------------------------------------
/configuration_ngpt.py:
--------------------------------------------------------------------------------
1 | """Ngpt configuration"""
2 |
3 | from collections import OrderedDict
4 | from typing import Any, List, Mapping, Optional
5 |
6 | from transformers import PreTrainedTokenizer, TensorType, is_torch_available
7 | from transformers import PretrainedConfig
8 | from transformers import logging
9 |
10 |
11 | logger = logging.get_logger(__name__)
12 |
13 |
14 | class NgptConfig(PretrainedConfig):
15 | """
16 | This is the configuration class to store the configuration of a [`NgptModel`] or a [`TFNgptModel`]. It is used to
17 | instantiate an nGPT model according to the specified arguments, defining the model architecture. Instantiating a
18 | configuration with the defaults will yield a similar configuration to that of the GPT-2
19 | architecture.
20 |
21 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
22 | documentation from [`PretrainedConfig`] for more information.
23 |
24 |
25 | Args:
26 | vocab_size (`int`, *optional*, defaults to 50257):
27 | Vocabulary size of the nGPT model. Defines the number of different tokens that can be represented by the
28 | `inputs_ids` passed when calling [`NgptModel`] or [`TFNgptModel`].
29 | n_positions (`int`, *optional*, defaults to 1024):
30 | The maximum sequence length that this model might ever be used with. Typically set this to something large
31 | just in case (e.g., 512 or 1024 or 2048).
32 | n_embd (`int`, *optional*, defaults to 768):
33 | Dimensionality of the embeddings and hidden states.
34 | n_layer (`int`, *optional*, defaults to 12):
35 | Number of hidden layers in the Transformer encoder.
36 | n_head (`int`, *optional*, defaults to 12):
37 | Number of attention heads for each attention layer in the Transformer encoder.
38 | n_inner (`int`, *optional*):
39 | Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
40 | activation_function (`str`, *optional*, defaults to `"silu"`):
41 | Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
42 | resid_pdrop (`float`, *optional*, defaults to 0.0):
43 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
44 | embd_pdrop (`float`, *optional*, defaults to 0.0):
45 | The dropout ratio for the embeddings.
46 | attn_pdrop (`float`, *optional*, defaults to 0.0):
48 | layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
49 | The epsilon to use in the layer normalization layers.
50 | initializer_range (`float`, *optional*, defaults to 0.02):
51 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. This value is overridden in `__init__` with `base_scale = 1 / sqrt(n_embd)`.
52 | summary_type (`string`, *optional*, defaults to `"cls_index"`):
53 | Argument used when doing sequence summary, used in the models [`NgptDoubleHeadsModel`] and
54 | [`TFNgptDoubleHeadsModel`].
55 |
56 | Has to be one of the following options:
57 |
58 | - `"last"`: Take the last token hidden state (like XLNet).
59 | - `"first"`: Take the first token hidden state (like BERT).
60 | - `"mean"`: Take the mean of all tokens hidden states.
61 | - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2).
62 | - `"attn"`: Not implemented now, use multi-head attention.
63 | summary_use_proj (`bool`, *optional*, defaults to `True`):
64 | Argument used when doing sequence summary, used in the models [`NgptDoubleHeadsModel`] and
65 | [`TFNgptDoubleHeadsModel`].
66 |
67 | Whether or not to add a projection after the vector extraction.
68 | summary_activation (`str`, *optional*):
69 | Argument used when doing sequence summary. Used for the multiple choice head in
70 | [`NgptDoubleHeadsModel`].
71 |
72 | Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation.
73 | summary_proj_to_labels (`bool`, *optional*, defaults to `True`):
74 | Argument used when doing sequence summary, used in the models [`NgptDoubleHeadsModel`] and
75 | [`TFNgptDoubleHeadsModel`].
76 |
77 | Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes.
78 | summary_first_dropout (`float`, *optional*, defaults to 0.1):
79 | Argument used when doing sequence summary, used in the models [`NgptDoubleHeadsModel`] and
80 | [`TFNgptDoubleHeadsModel`].
81 |
82 | The dropout ratio to be used after the projection and activation.
83 | scale_attn_weights (`bool`, *optional*, defaults to `True`):
84 | Scale attention weights by dividing by sqrt(hidden_size).
85 | use_cache (`bool`, *optional*, defaults to `True`):
86 | Whether or not the model should return the last key/values attentions (not used by all models).
87 | bos_token_id (`int`, *optional*, defaults to 50256):
88 | Id of the beginning of sentence token in the vocabulary.
89 | eos_token_id (`int`, *optional*, defaults to 50256):
90 | Id of the end of sentence token in the vocabulary.
91 | scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
92 | Whether to additionally scale attention weights by `1 / (layer_idx + 1)`.
93 | reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
94 | Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
95 | dot-product/softmax to float() when training with mixed precision.
96 |
97 | Example:
98 |
99 | ```python
100 | >>> from configuration_ngpt import NgptConfig
101 | >>> from modeling_ngpt import NgptModel
102 | >>> # Initializing a Ngpt configuration
103 | >>> configuration = NgptConfig()
104 |
105 | >>> # Initializing a model (with random weights) from the configuration
106 | >>> model = NgptModel(configuration)
107 |
108 | >>> # Accessing the model configuration
109 | >>> configuration = model.config
110 | ```"""
111 |
112 | model_type = "ngpt"
113 | keys_to_ignore_at_inference = ["past_key_values"]
114 | attribute_map = {
115 | "hidden_size": "n_embd",
116 | "max_position_embeddings": "n_positions",
117 | "num_attention_heads": "n_head",
118 | "num_hidden_layers": "n_layer",
119 | }
120 |
121 | def __init__(
122 | self,
123 | vocab_size=50257,
124 | n_positions=1024,
125 | n_embd=768,
126 | n_layer=12,
127 | n_head=12,
128 | n_inner=None,
129 | activation_function="silu",
130 | resid_pdrop=0.0,
131 | embd_pdrop=0.0,
132 | attn_pdrop=0.0,
133 | layer_norm_epsilon=1e-5,
134 | initializer_range=0.02,
135 | summary_type="cls_index",
136 | summary_use_proj=True,
137 | summary_activation=None,
138 | summary_proj_to_labels=True,
139 | summary_first_dropout=0.1,
140 | scale_attn_weights=True,
141 | use_cache=True,
142 | bos_token_id=50256,
143 | eos_token_id=50256,
144 | scale_attn_by_inverse_layer_idx=False,
145 | reorder_and_upcast_attn=False,
146 | **kwargs,
147 | ):
148 | self.vocab_size = vocab_size
149 | self.n_positions = n_positions
150 | self.n_embd = n_embd
151 | self.n_layer = n_layer
152 | self.n_head = n_head
153 | self.n_inner = n_inner
154 | self.activation_function = activation_function
155 | self.resid_pdrop = resid_pdrop
156 | self.embd_pdrop = embd_pdrop
157 | self.attn_pdrop = attn_pdrop
158 | self.layer_norm_epsilon = layer_norm_epsilon
159 | self.initializer_range = initializer_range
160 | self.summary_type = summary_type
161 | self.summary_use_proj = summary_use_proj
162 | self.summary_activation = summary_activation
163 | self.summary_first_dropout = summary_first_dropout
164 | self.summary_proj_to_labels = summary_proj_to_labels
165 | self.scale_attn_weights = scale_attn_weights
166 | self.use_cache = use_cache
167 | self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
168 | self.reorder_and_upcast_attn = reorder_and_upcast_attn
169 |
170 | self.bos_token_id = bos_token_id
171 | self.eos_token_id = eos_token_id
172 |
173 | '''
174 | ADD BASE SCALING FACTOR (section 2.2.2)
175 | '''
176 | self.base_scale = 1.0 / (self.n_embd ** 0.5)
177 | self.initializer_range = self.base_scale
178 |
179 | super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
180 |
181 |
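`NgptConfig` sets `initializer_range` to `base_scale = 1 / sqrt(n_embd)`. One way to read this choice is our interpretation, not a statement from the paper. Suppose each of the `n_embd` entries of a weight row is drawn independently with standard deviation `1 / sqrt(n_embd)`. The row's expected squared norm is then `n_embd * (1 / n_embd) = 1`, so rows start close to the unit hypersphere. The sketch below assumes weights are drawn from a normal distribution with std `initializer_range`, as Hugging Face's GPT-2 initialization does. The helper names are hypothetical.

```python
import math
import random


def base_scale(n_embd: int) -> float:
    """1 / sqrt(n_embd), matching `NgptConfig.base_scale`."""
    return 1.0 / math.sqrt(n_embd)


def mean_sq_norm_at_init(n_embd: int, n_rows: int = 200, seed: int = 0) -> float:
    """Average squared L2 norm of rows drawn i.i.d. from N(0, base_scale(n_embd) ** 2)."""
    rng = random.Random(seed)
    std = base_scale(n_embd)
    total = 0.0
    for _ in range(n_rows):
        total += sum(rng.gauss(0.0, std) ** 2 for _ in range(n_embd))
    return total / n_rows


for n_embd in (768, 1024):
    print(n_embd, round(base_scale(n_embd), 5), round(mean_sq_norm_at_init(n_embd), 2))
# base_scale: 768 -> 0.03608, 1024 -> 0.03125; the mean squared norm is close to 1.0 in both cases
```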
--------------------------------------------------------------------------------
/data/openwebtext_llama/prepare.py:
--------------------------------------------------------------------------------
1 | # saves the openwebtext dataset to a binary file for training. following was helpful:
2 | # https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
3 |
4 | import os
5 | from tqdm import tqdm
6 | import numpy as np
7 | from datasets import load_dataset # huggingface datasets
8 | from transformers import AutoTokenizer
9 |
10 | # number of workers in .map() call
11 | # good number to use is ~order number of cpu cores // 2
12 | num_proc = 8
13 |
14 | # number of workers in load_dataset() call
15 | # best number might be different from num_proc above as it also depends on NW speed.
16 | # it is better than 1 usually though
17 | num_proc_load_dataset = num_proc
18 |
19 | # LLaMA-2 tokenizer (vocab size 32000) replaces nanoGPT's tiktoken GPT-2 encoding
20 | tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
21 | eot_token = tokenizer.eos_token_id # </s>, id 2 for LLaMA-2
22 |
23 | if __name__ == '__main__':
24 | # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769)
25 | dataset = load_dataset("openwebtext", num_proc=num_proc_load_dataset)
26 |
27 | # owt by default only contains the 'train' split, so create a test split
28 | split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
29 | split_dataset['val'] = split_dataset.pop('test') # rename the test split to val
30 |
31 | # this results in:
32 | # >>> split_dataset
33 | # DatasetDict({
34 | # train: Dataset({
35 | # features: ['text'],
36 | # num_rows: 8009762
37 | # })
38 | # val: Dataset({
39 | # features: ['text'],
40 | # num_rows: 4007
41 | # })
42 | # })
43 |
44 | # we now want to tokenize the dataset. first define the encoding function (LLaMA-2 tokenizer)
45 | def process(example):
46 | ids = tokenizer.encode(example['text']) # note: HF encode() adds special tokens by default (a BOS token, id 1, for LLaMA-2)
47 | ids.append(eot_token) # add the end of text token (</s>, id 2 for LLaMA-2)
48 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though...
49 | out = {'ids': ids, 'len': len(ids)}
50 | return out
51 |
52 | # tokenize the dataset
53 | tokenized = split_dataset.map(
54 | process,
55 | remove_columns=['text'],
56 | desc="tokenizing the splits",
57 | num_proc=num_proc,
58 | )
59 |
60 | # concatenate all the ids in each dataset into one large file we can use for training
61 | for split, dset in tokenized.items():
62 | arr_len = np.sum(dset['len'], dtype=np.uint64)
63 | filename = os.path.join(os.path.dirname(__file__), f'{split}.bin')
64 | dtype = np.uint16 # (can do since the LLaMA-2 vocab size, 32000, is < 2**16)
65 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
66 | total_batches = 1024
67 |
68 | idx = 0
69 | for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
70 | # Batch together samples for faster write
71 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
72 | arr_batch = np.concatenate(batch['ids'])
73 | # Write into mmap
74 | arr[idx : idx + len(arr_batch)] = arr_batch
75 | idx += len(arr_batch)
76 | arr.flush()
77 |
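78 | # train.bin is ~17GB, val.bin ~8.5MB (figures inherited from nanoGPT's GPT-2 BPE version; sizes and token counts will differ with the LLaMA-2 tokenizer)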
79 | # train has ~9B tokens (9,035,582,198)
80 | # val has ~4M tokens (4,434,897)
81 |
82 | # to read the bin files later, e.g. with numpy:
83 | # m = np.memmap('train.bin', dtype=np.uint16, mode='r')
84 |
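`prepare.py` writes one flat `uint16` token stream per split. This is safe because the LLaMA-2 vocabulary (32000 ids) fits under 2**16 = 65536. The sketch below mirrors that write path and shows one way to read the file back into training batches, in the style of nanoGPT's `get_batch`. The helper names `write_tokens` and `get_batch` and the token ids in the example are hypothetical, and the training scripts here may sample differently.

```python
import os
import tempfile
from typing import List, Tuple

import numpy as np

LLAMA_VOCAB_SIZE = 32000  # from the README; token ids must fit in uint16
assert LLAMA_VOCAB_SIZE <= 2**16


def write_tokens(path: str, docs: List[List[int]]) -> int:
    """Concatenate token-id lists into one flat uint16 memmap, as prepare.py does per split."""
    total = sum(len(d) for d in docs)
    arr = np.memmap(path, dtype=np.uint16, mode="w+", shape=(total,))
    idx = 0
    for ids in docs:
        arr[idx: idx + len(ids)] = ids
        idx += len(ids)
    arr.flush()
    return total


def get_batch(path: str, batch_size: int, block_size: int,
              rng: np.random.Generator) -> Tuple[np.ndarray, np.ndarray]:
    """Sample random windows; targets are the inputs shifted left by one token."""
    data = np.memmap(path, dtype=np.uint16, mode="r")
    starts = rng.integers(0, len(data) - block_size, size=batch_size)  # high is exclusive
    x = np.stack([data[s: s + block_size].astype(np.int64) for s in starts])
    y = np.stack([data[s + 1: s + 1 + block_size].astype(np.int64) for s in starts])
    return x, y


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "train.bin")
    # Hypothetical documents: BOS (1) + content ids + EOS (2).
    n = write_tokens(path, [[1, 10, 11, 12, 2], [1, 20, 21, 2]])
    x, y = get_batch(path, batch_size=2, block_size=3, rng=np.random.default_rng(0))
    print(n, x.shape, y.shape)  # 9 (2, 3) (2, 3)
```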
--------------------------------------------------------------------------------
/images/4k_arceasy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/4k_arceasy.png
--------------------------------------------------------------------------------
/images/4k_average.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/4k_average.png
--------------------------------------------------------------------------------
/images/4k_hellaswag.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/4k_hellaswag.png
--------------------------------------------------------------------------------
/images/4k_lambada.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/4k_lambada.png
--------------------------------------------------------------------------------
/images/4k_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/4k_loss.png
--------------------------------------------------------------------------------
/images/4k_winogrande.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/4k_winogrande.png
--------------------------------------------------------------------------------
/images/4k_wsc273.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/4k_wsc273.png
--------------------------------------------------------------------------------
/images/arc_easy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/arc_easy.png
--------------------------------------------------------------------------------
/images/average.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/average.png
--------------------------------------------------------------------------------
/images/hellaswag.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/hellaswag.png
--------------------------------------------------------------------------------
/images/lambada.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/lambada.png
--------------------------------------------------------------------------------
/images/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/loss.png
--------------------------------------------------------------------------------
/images/network_params.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/network_params.png
--------------------------------------------------------------------------------
/images/optimizer_params.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/optimizer_params.png
--------------------------------------------------------------------------------
/images/wall_clock.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/wall_clock.png
--------------------------------------------------------------------------------
/images/winogrande.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/winogrande.png
--------------------------------------------------------------------------------
/images/wsc273.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoeLi12345/nGPT/dc660f70c3fcfe386d560a31754e42ee309b1f00/images/wsc273.png
--------------------------------------------------------------------------------
/modeling_gpt2.py:
--------------------------------------------------------------------------------
1 | """PyTorch OpenAI GPT-2 model modifications according to nGPT specifications"""
2 |
3 | import math
4 | import os
5 | import warnings
6 | from dataclasses import dataclass
7 | from typing import Callable, Optional, Tuple, Union
8 | import torch
9 | import torch.utils.checkpoint
10 | from torch import nn
11 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
12 |
13 | from transformers.activations import ACT2FN
14 |
15 | from transformers import GenerationMixin
16 | from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
17 | from transformers.modeling_outputs import (
18 | BaseModelOutputWithPastAndCrossAttentions,
19 | CausalLMOutputWithCrossAttentions,
20 | QuestionAnsweringModelOutput,
21 | SequenceClassifierOutputWithPast,
22 | TokenClassifierOutput,
23 | )
24 | from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, SequenceSummary
25 | from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
26 | from transformers.utils import (
27 | ModelOutput,
28 | add_code_sample_docstrings,
29 | add_start_docstrings,
30 | add_start_docstrings_to_model_forward,
31 | logging,
32 | replace_return_docstrings,
33 | )
34 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
35 | from configuration_gpt2 import GPT2Config
36 | from torchtune.modules import RotaryPositionalEmbeddings
37 |
38 | logger = logging.get_logger(__name__)
39 |
40 | _CHECKPOINT_FOR_DOC = "openai-community/gpt2"
41 | _CONFIG_FOR_DOC = "GPT2Config"
42 |
43 |
44 | def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
45 | """Load tf checkpoints in a pytorch model"""
46 | try:
47 | import re
48 |
49 | import tensorflow as tf
50 | except ImportError:
51 | logger.error(
52 | "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
53 | "https://www.tensorflow.org/install/ for installation instructions."
54 | )
55 | raise
56 | tf_path = os.path.abspath(gpt2_checkpoint_path)
57 | logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
58 | # Load weights from TF model
59 | init_vars = tf.train.list_variables(tf_path)
60 | names = []
61 | arrays = []
62 | for name, shape in init_vars:
63 | logger.info(f"Loading TF weight {name} with shape {shape}")
64 | array = tf.train.load_variable(tf_path, name)
65 | names.append(name)
66 | arrays.append(array.squeeze())
67 |
68 | for name, array in zip(names, arrays):
69 | name = name[6:] # skip "model/"
70 | name = name.split("/")
71 | pointer = model
72 | for m_name in name:
73 | if re.fullmatch(r"[A-Za-z]+\d+", m_name):
74 | scope_names = re.split(r"(\d+)", m_name)
75 | else:
76 | scope_names = [m_name]
77 | if scope_names[0] == "w" or scope_names[0] == "g":
78 | pointer = getattr(pointer, "weight")
79 | elif scope_names[0] == "b":
80 | pointer = getattr(pointer, "bias")
81 | elif scope_names[0] == "wpe" or scope_names[0] == "wte":
82 | pointer = getattr(pointer, scope_names[0])
83 | pointer = getattr(pointer, "weight")
84 | else:
85 | pointer = getattr(pointer, scope_names[0])
86 | if len(scope_names) >= 2:
87 | num = int(scope_names[1])
88 | pointer = pointer[num]
89 | try:
90 | if pointer.shape != array.shape:
91 | raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
92 | except ValueError as e:
93 | e.args += (pointer.shape, array.shape)
94 | raise
95 | logger.info(f"Initialize PyTorch weight {name}")
96 | pointer.data = torch.from_numpy(array)
97 | return model
98 |
99 |
100 | def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs):
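"""
Illustrative sketch (not part of the original repository) of the causal-mask slice used below,
`module.bias[:, :, key_length - query_length : key_length, :key_length]`. It assumes a toy 6x6
lower-triangular boolean buffer in place of the real `max_positions`-sized one. Taking the last
`query_length` rows is what makes cached decoding work: a single new query against 4 keys
(3 cached) lands on row 3 of the triangle, so it may attend to every key.

```python
import torch

# Toy stand-in for the registered `bias` buffer: shape (1, 1, max_positions, max_positions).
max_positions = 6
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
    1, 1, max_positions, max_positions
)


def causal_slice(query_length, key_length):
    # Same slicing as eager_attention_forward: last `query_length` rows, first `key_length` columns.
    return bias[:, :, key_length - query_length : key_length, :key_length]


# Full-sequence pass (no cache): 4 queries over 4 keys gives the usual triangle.
full = causal_slice(4, 4)
# Cached decoding step: 1 new query over 3 cached keys plus 1 new key.
step = causal_slice(1, 4)
print(full[0, 0].int())
print(step[0, 0].int())
```

With no cache, `query_length == key_length` and the slice is simply the top-left triangle.
"""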
101 | attn_weights = torch.matmul(query, key.transpose(-1, -2))
102 |
103 | if module.scale_attn_weights:
104 | attn_weights = attn_weights / torch.full(
105 | [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
106 | )
107 |
108 | # Layer-wise attention scaling
109 | if module.scale_attn_by_inverse_layer_idx:
110 | attn_weights = attn_weights / float(module.layer_idx + 1)
111 |
112 | if not module.is_cross_attention:
113 | # only the "normal" self-attention layer applies the causal mask
114 | query_length, key_length = query.size(-2), key.size(-2)
115 | causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length]
116 | mask_value = torch.finfo(attn_weights.dtype).min
117 | # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
118 | # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
119 | mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
120 | attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
121 |
122 | if attention_mask is not None:
123 | # Apply the attention mask
124 | attn_weights = attn_weights + attention_mask
125 |
126 | attn_weights = nn.functional.softmax(attn_weights, dim=-1)
127 |
128 | # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
129 | attn_weights = attn_weights.type(value.dtype)
130 | attn_weights = module.attn_dropout(attn_weights)
131 |
132 | # Mask heads if we want to
133 | if head_mask is not None:
134 | attn_weights = attn_weights * head_mask
135 |
136 | attn_output = torch.matmul(attn_weights, value)
137 | attn_output = attn_output.transpose(1, 2)
138 |
139 | return attn_output, attn_weights
140 |
141 |
142 | class GPT2Attention(nn.Module):
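"""
Hand-written rotary-embedding sketch (not torchtune's implementation, and not part of the original
repository), in the same `[batch, seq_len, num_heads, head_dim]` layout that `self.rpe` receives in
`forward`. It rotates adjacent feature pairs with base 10000, which is assumed here to mirror
torchtune's default, and checks the property RoPE relies on: after rotation, a query-key dot
product depends only on the offset between the two positions.

```python
import torch


def rope_sketch(x, positions, base=10000.0):
    # x: [batch, seq_len, num_heads, head_dim]; positions: [seq_len] integer positions.
    head_dim = x.shape[-1]
    # One frequency per feature pair: theta_i = base^(-2i / head_dim).
    inv_freq = base ** (-torch.arange(0, head_dim, 2, dtype=torch.float64) / head_dim)
    angles = positions.to(torch.float64)[:, None] * inv_freq[None, :]  # [seq_len, head_dim / 2]
    cos = angles.cos()[None, :, None, :]  # broadcast over batch and heads
    sin = angles.sin()[None, :, None, :]
    # Split into adjacent pairs (x0, x1), (x2, x3), ... and rotate each pair by its angle.
    pairs = x.to(torch.float64).reshape(*x.shape[:-1], head_dim // 2, 2)
    x_even, x_odd = pairs[..., 0], pairs[..., 1]
    rotated = torch.stack((x_even * cos - x_odd * sin, x_even * sin + x_odd * cos), dim=-1)
    return rotated.reshape(x.shape)


torch.manual_seed(0)
q = torch.randn(1, 1, 2, 8)  # one token, two heads, head_dim 8
k = torch.randn(1, 1, 2, 8)


def score(q_pos, k_pos):
    # Per-head dot product after rotating q and k to their positions.
    rq = rope_sketch(q, torch.tensor([q_pos]))
    rk = rope_sketch(k, torch.tensor([k_pos]))
    return (rq * rk).sum(-1)


# Same offset (3), different absolute positions: the scores match.
print(score(5, 2), score(3, 0))
```

Because only offsets matter, each token has to be rotated to its true absolute position; for
cached decoding, torchtune's `RotaryPositionalEmbeddings.forward` accepts an optional `input_pos`
tensor for that purpose.
"""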
143 | def __init__(self, config, is_cross_attention=False, layer_idx=None):
144 | super().__init__()
145 | self.config = config
146 | max_positions = config.max_position_embeddings
147 | self.register_buffer(
148 | "bias",
149 | torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
150 | 1, 1, max_positions, max_positions
151 | ),
152 | persistent=False,
153 | )
154 | self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
155 |
156 | self.embed_dim = config.hidden_size
157 | self.num_heads = config.num_attention_heads
158 | self.head_dim = self.embed_dim // self.num_heads
159 | self.split_size = self.embed_dim
160 | if self.head_dim * self.num_heads != self.embed_dim:
161 | raise ValueError(
162 | f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
163 | f" {self.num_heads})."
164 | )
165 |
166 | # Rotary positional embeddings (RoPE), applied to queries and keys in forward()
167 | self.rpe = RotaryPositionalEmbeddings(dim=self.head_dim, max_seq_len=config.max_position_embeddings)
168 |
169 | self.scale_attn_weights = config.scale_attn_weights
170 | self.is_cross_attention = is_cross_attention
171 |
172 | # Layer-wise attention scaling, reordering, and upcasting
173 | self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
174 | self.layer_idx = layer_idx
175 | self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
176 |
177 | if self.is_cross_attention:
178 | self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
179 | self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
180 | else:
181 | self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
182 | self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
183 |
184 | self.attn_dropout = nn.Dropout(config.attn_pdrop)
185 | self.resid_dropout = nn.Dropout(config.resid_pdrop)
186 | self.is_causal = True
187 |
188 | self.pruned_heads = set()
189 |
190 | def prune_heads(self, heads):
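"""
Sketch (not part of the original repository) of how the pruned column indices are built for the
packed `c_attn` projection, whose output is laid out as `[q | k | v]` with each block `split_size`
wide. `kept_indices_sketch` is a hypothetical simplification of `find_pruneable_heads_and_indices`
that ignores previously pruned heads; the toy sizes (2 heads of width 4) are assumptions.

```python
import torch


def kept_indices_sketch(heads_to_prune, num_heads, head_dim):
    # Hypothetical stand-in (no previously pruned heads): zero out every feature column
    # that belongs to a pruned head and return the indices of the columns that remain.
    mask = torch.ones(num_heads, head_dim)
    for head in heads_to_prune:
        mask[head] = 0
    return torch.arange(num_heads * head_dim)[mask.view(-1).eq(1)]


num_heads, head_dim = 2, 4
split_size = num_heads * head_dim  # width of each packed q, k, v block in c_attn's output
index = kept_indices_sketch({1}, num_heads, head_dim)
# The same head columns must be kept (or dropped) consistently in the q, k and v blocks.
index_attn = torch.cat([index, index + split_size, index + (2 * split_size)])
print(index.tolist())
print(index_attn.tolist())
```

The kept `index` is then reused without offsets to prune the input rows of `c_proj`.
"""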
191 | if len(heads) == 0:
192 | return
193 | heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
194 | index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
195 |
196 | # Prune conv1d layers
197 | self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
198 | self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
199 |
200 | # Update hyper params
201 | self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
202 | self.num_heads = self.num_heads - len(heads)
203 | self.pruned_heads = self.pruned_heads.union(heads)
204 |
205 | def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
206 | # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
207 | bsz, num_heads, q_seq_len, dk = query.size()
208 | _, _, k_seq_len, _ = key.size()
209 |
210 | # Preallocate attn_weights for `baddbmm`
211 | attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
212 |
213 | # Compute Scale Factor
214 | scale_factor = 1.0
215 | if self.scale_attn_weights:
216 | scale_factor /= float(value.size(-1)) ** 0.5
217 |
218 | if self.scale_attn_by_inverse_layer_idx:
219 | scale_factor /= float(self.layer_idx + 1)
220 |
221 | # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
222 | with torch.amp.autocast(query.device.type, enabled=False):
223 | q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
224 | attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
225 | attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
226 |
227 | if not self.is_cross_attention:
228 | # only the "normal" self-attention layer applies the causal mask
229 | query_length, key_length = query.size(-2), key.size(-2)
230 | causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
231 | mask_value = torch.finfo(attn_weights.dtype).min
232 | # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
233 | # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
234 | mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
235 | attn_weights = torch.where(causal_mask, attn_weights, mask_value)
236 |
237 | if attention_mask is not None:
238 | # Apply the attention mask
239 | attn_weights = attn_weights + attention_mask
240 |
241 | attn_weights = nn.functional.softmax(attn_weights, dim=-1)
242 |
243 | # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
244 | if attn_weights.dtype != torch.float32:
245 | raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
246 | attn_weights = attn_weights.type(value.dtype)
247 | attn_weights = self.attn_dropout(attn_weights)
248 |
249 | # Mask heads if we want to
250 | if head_mask is not None:
251 | attn_weights = attn_weights * head_mask
252 |
253 | attn_output = torch.matmul(attn_weights, value)
254 | attn_output = attn_output.transpose(1, 2)
255 |
256 | return attn_output, attn_weights
257 |
258 | def forward(
259 | self,
260 | hidden_states: Optional[Tuple[torch.FloatTensor]],
261 | layer_past: Optional[Tuple[torch.Tensor]] = None,
262 | attention_mask: Optional[torch.FloatTensor] = None,
263 | head_mask: Optional[torch.FloatTensor] = None,
264 | encoder_hidden_states: Optional[torch.Tensor] = None,
265 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
266 | use_cache: Optional[bool] = False,
267 | output_attentions: Optional[bool] = False,
268 | **kwargs,
269 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
270 | if encoder_hidden_states is not None:
271 | if not hasattr(self, "q_attn"):
272 | raise ValueError(
273 | "If class is used as cross attention, the weights `q_attn` have to be defined. "
274 | "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
275 | )
276 |
277 | query_states = self.q_attn(hidden_states)
278 | key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
279 | attention_mask = encoder_attention_mask
280 | else:
281 | query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
282 |
283 | shape_q = (*query_states.shape[:-1], -1, self.head_dim)
284 | shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
285 |
286 |
287 | query_states = query_states.view(shape_q).transpose(1, 2)
288 | key_states = key_states.view(shape_kv).transpose(1, 2)
289 | value_states = value_states.view(shape_kv).transpose(1, 2)
290 |
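291 | # Apply rotary position embeddings; torchtune's RotaryPositionalEmbeddings expects [batch, seq_len, num_heads, head_dim]
292 | query_states = query_states.transpose(1, 2)
293 | key_states = key_states.transpose(1, 2)
294 |
295 | query_states = self.rpe(query_states)  # NOTE: no input_pos is passed, so positions restart at 0 on every call (incorrect offsets when layer_past is used)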
296 | key_states = self.rpe(key_states)
297 |
298 | query_states = query_states.transpose(1, 2)
299 | key_states = key_states.transpose(1, 2)
300 |
301 | if layer_past is not None:
302 | past_key, past_value = layer_past
303 | key_states = torch.cat((past_key, key_states), dim=-2)
304 | value_states = torch.cat((past_value, value_states), dim=-2)
305 |
306 | if use_cache is True:
307 | present = (key_states, value_states)
308 | else:
309 | present = None
310 |
311 | is_cross_attention = encoder_hidden_states is not None
312 | is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
313 |
314 | using_eager = self.config._attn_implementation == "eager"
315 | attention_interface: Callable = eager_attention_forward
316 |
317 | if self.config._attn_implementation != "eager":
318 | if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
319 | using_eager = True
320 | logger.warning_once(
321 | "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `head_mask`. Falling back to "
322 | 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
323 | )
324 | else:
325 | # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
326 | # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
327 | # not necessarily to eager (if the mentioned options are provided).
328 | attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
329 |
330 | if using_eager and self.reorder_and_upcast_attn:
331 | attn_output, attn_weights = self._upcast_and_reordered_attn(
332 | query_states, key_states, value_states, attention_mask, head_mask
333 | )
334 | else:
335 | attn_output, attn_weights = attention_interface(
336 | self,
337 | query_states,
338 | key_states,
339 | value_states,
340 | attention_mask,
341 | head_mask=head_mask,
342 | dropout=self.attn_dropout.p if self.training else 0.0,
343 | is_causal=is_causal,
344 | **kwargs,
345 | )
346 |
347 | attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
348 | attn_output = self.c_proj(attn_output)
349 | attn_output = self.resid_dropout(attn_output)
350 |
351 | outputs = (attn_output, present)
352 | if output_attentions:
353 | outputs += (attn_weights,)
354 |
355 | return outputs # a, present, (attentions)
356 |
357 |
358 | class GPT2MLP(nn.Module):
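"""
Sketch (not part of the original repository) of why `GPT2Block` defaults the inner width to
`8 * hidden_size // 3`: the gated MLP has three hidden_size-by-inner weight blocks (the `u` and `v`
halves packed into `c_fc`, plus `c_proj`), versus two for the classic `4 * hidden_size` MLP, so
8/3 keeps the weight count the same. Biases are ignored, `nn.Linear` stands in for `Conv1D`, and
`silu` is assumed as the activation (the SwiGLU case); GPT-2 small's `hidden_size=768` is used
because 8 * 768 / 3 is exactly 2048.

```python
import torch
from torch import nn


def mlp_weight_count(hidden_size, gated):
    # Weights only (biases ignored), matching the shapes GPT2MLP builds with Conv1D.
    if gated:
        inner = 8 * hidden_size // 3
        return hidden_size * (2 * inner) + inner * hidden_size  # c_fc packs u and v, then c_proj
    inner = 4 * hidden_size
    return hidden_size * inner + inner * hidden_size


class SwiGLUSketch(nn.Module):
    # nn.Linear stands in for transformers' Conv1D; silu is assumed as the activation.
    def __init__(self, hidden_size):
        super().__init__()
        inner = 8 * hidden_size // 3
        self.c_fc = nn.Linear(hidden_size, 2 * inner)
        self.c_proj = nn.Linear(inner, hidden_size)

    def forward(self, x):
        u, v = torch.chunk(self.c_fc(x), 2, dim=-1)  # value half and gate half
        return self.c_proj(u * nn.functional.silu(v))


print(mlp_weight_count(768, gated=False), mlp_weight_count(768, gated=True))
out = SwiGLUSketch(768)(torch.randn(2, 5, 768))
print(out.shape)
```

For hidden sizes not divisible by 3, the floor division makes the gated MLP slightly smaller.
"""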
359 | def __init__(self, intermediate_size, config):
360 | super().__init__()
361 | embed_dim = config.hidden_size
362 | #self.c_fc = Conv1D(intermediate_size, embed_dim)
363 | self.c_fc = Conv1D(2*intermediate_size, embed_dim)
364 | self.c_proj = Conv1D(embed_dim, intermediate_size)
365 | self.act = ACT2FN[config.activation_function]
366 | self.dropout = nn.Dropout(config.resid_pdrop)
367 |
368 | def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
369 | hidden_states = self.c_fc(hidden_states)
370 | # Gated MLP: split the c_fc output into value u and gate v (SwiGLU when activation_function is "silu")
371 | u, v = torch.chunk(hidden_states, 2, dim=-1)
372 | #hidden_states = self.c_fc(hidden_states)
373 | hidden_states = u * self.act(v)
374 | # hidden_states = self.act(hidden_states)
375 | hidden_states = self.c_proj(hidden_states)
376 | hidden_states = self.dropout(hidden_states)
377 | return hidden_states
378 |
379 |
380 | class GPT2Block(nn.Module):
381 | def __init__(self, config, layer_idx=None):
382 | super().__init__()
383 | hidden_size = config.hidden_size
384 | #inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
385 | inner_dim = config.n_inner if config.n_inner is not None else 8 * hidden_size // 3
386 |
387 | self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
388 | self.attn = GPT2Attention(config=config, layer_idx=layer_idx)
389 | self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
390 |
391 | if config.add_cross_attention:
392 | self.crossattention = GPT2Attention(config=config, is_cross_attention=True, layer_idx=layer_idx)
393 | self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
394 |
395 | self.mlp = GPT2MLP(inner_dim, config)
396 |
397 | def forward(
398 | self,
399 | hidden_states: Optional[Tuple[torch.FloatTensor]],
400 | layer_past: Optional[Tuple[torch.Tensor]] = None,
401 | attention_mask: Optional[torch.FloatTensor] = None,
402 | head_mask: Optional[torch.FloatTensor] = None,
403 | encoder_hidden_states: Optional[torch.Tensor] = None,
404 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
405 | use_cache: Optional[bool] = False,
406 | output_attentions: Optional[bool] = False,
407 | ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
408 | residual = hidden_states
409 | hidden_states = self.ln_1(hidden_states)
410 | attn_outputs = self.attn(
411 | hidden_states,
412 | layer_past=layer_past,
413 | attention_mask=attention_mask,
414 | head_mask=head_mask,
415 | use_cache=use_cache,
416 | output_attentions=output_attentions,
417 | )
418 | attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
419 | outputs = attn_outputs[1:]
420 | # residual connection
421 | hidden_states = attn_output + residual
422 |
423 | if encoder_hidden_states is not None:
424 | # add one self-attention block for cross-attention
425 | if not hasattr(self, "crossattention"):
426 | raise ValueError(
427 | f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
428 | "cross-attention layers by setting `config.add_cross_attention=True`"
429 | )
430 | residual = hidden_states
431 | hidden_states = self.ln_cross_attn(hidden_states)
432 | cross_attn_outputs = self.crossattention(
433 | hidden_states,
434 | attention_mask=attention_mask,
435 | head_mask=head_mask,
436 | encoder_hidden_states=encoder_hidden_states,
437 | encoder_attention_mask=encoder_attention_mask,
438 | output_attentions=output_attentions,
439 | )
440 | attn_output = cross_attn_outputs[0]
441 | # residual connection
442 | hidden_states = residual + attn_output
443 | outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
444 |
445 | residual = hidden_states
446 | hidden_states = self.ln_2(hidden_states)
447 | feed_forward_hidden_states = self.mlp(hidden_states)
448 | # residual connection
449 | hidden_states = residual + feed_forward_hidden_states
450 |
451 | if use_cache:
452 | outputs = (hidden_states,) + outputs
453 | else:
454 | outputs = (hidden_states,) + outputs[1:]
455 |
456 | return outputs # hidden_states, present, (attentions, cross_attentions)
457 |
458 |
459 | class GPT2PreTrainedModel(PreTrainedModel):
460 | """
461 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
462 | models.
463 | """
464 |
465 | config_class = GPT2Config
466 | load_tf_weights = load_tf_weights_in_gpt2
467 | base_model_prefix = "transformer"
468 | is_parallelizable = True
469 | supports_gradient_checkpointing = True
470 | _no_split_modules = ["GPT2Block"]
471 | _skip_keys_device_placement = "past_key_values"
472 | _supports_flash_attn_2 = True
473 | _supports_sdpa = True
474 |
475 | def __init__(self, *inputs, **kwargs):
476 | super().__init__(*inputs, **kwargs)
477 |
478 | def _init_weights(self, module):
479 | """Initialize the weights."""
480 | if isinstance(module, (nn.Linear, Conv1D)):
481 | # Slightly different from the TF version which uses truncated_normal for initialization
482 | # cf https://github.com/pytorch/pytorch/pull/5617
483 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
484 | if module.bias is not None:
485 | module.bias.data.zero_()
486 | elif isinstance(module, nn.Embedding):
487 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
488 | if module.padding_idx is not None:
489 | module.weight.data[module.padding_idx].zero_()
490 | elif isinstance(module, nn.LayerNorm):
491 | module.bias.data.zero_()
492 | module.weight.data.fill_(1.0)
493 |
494 | # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
495 | # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
496 | # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
497 | # > -- GPT-2 :: https://openai.com/blog/better-language-models/
498 | #
499 | # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
500 | for name, p in module.named_parameters():
501 | if name == "c_proj.weight":
502 | # Special scaled initialization: there are 2 residual branches (attention and MLP) per Transformer block
503 | p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
504 |
505 |
506 | @dataclass
507 | class GPT2DoubleHeadsModelOutput(ModelOutput):
508 | """
509 | Base class for outputs of models predicting if two sentences are consecutive or not.
510 |
511 | Args:
512 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
513 | Language modeling loss.
514 | mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
515 | Multiple choice classification loss.
516 | logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
517 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
518 | mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
519 | Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
520 | past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
521 | Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads,
522 | sequence_length, embed_size_per_head)`.
523 |
524 | Contains pre-computed hidden-states (keys and values in the attention blocks) that can be used (see
525 | `past_key_values` input) to speed up sequential decoding.
526 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
527 | Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
528 | shape `(batch_size, sequence_length, hidden_size)`.
529 |
530 | Hidden-states of the model at the output of each layer plus the initial embedding outputs.
531 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
532 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
533 | sequence_length)`.
534 |
535 | Attention weights after the attention softmax, used to compute the weighted average in the
536 | self-attention heads.
537 | """
538 |
539 | loss: Optional[torch.FloatTensor] = None
540 | mc_loss: Optional[torch.FloatTensor] = None
541 | logits: torch.FloatTensor = None
542 | mc_logits: torch.FloatTensor = None
543 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
544 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None
545 | attentions: Optional[Tuple[torch.FloatTensor]] = None
546 |
547 |
548 | GPT2_START_DOCSTRING = r"""
549 |
550 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
551 | library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
552 | etc.)
553 |
554 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
555 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
556 | and behavior.
557 |
558 | Parameters:
559 | config ([`GPT2Config`]): Model configuration class with all the parameters of the model.
560 | Initializing with a config file does not load the weights associated with the model, only the
561 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
562 | """
563 |
564 | GPT2_INPUTS_DOCSTRING = r"""
565 | Args:
566 | input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
567 | `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
568 | `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
569 | sequence tokens in the vocabulary.
570 |
571 | If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
572 | `input_ids`.
573 |
574 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
575 | [`PreTrainedTokenizer.__call__`] for details.
576 |
577 | [What are input IDs?](../glossary#input-ids)
578 | past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
579 | Contains precomputed hidden-states (keys and values in the attention blocks) as computed by the model (see
580 | `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
581 | their past given to this model should not be passed as `input_ids` as they have already been computed.
582 | attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
583 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
584 |
585 | - 1 for tokens that are **not masked**,
586 | - 0 for tokens that are **masked**.
587 |
588 | If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
589 | `past_key_values`. In other words, the `attention_mask` always has to have the length:
590 | `len(past_key_values) + len(input_ids)`
591 |
592 | [What are attention masks?](../glossary#attention-mask)
593 | token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
594 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
595 | 1]`:
596 |
597 | - 0 corresponds to a *sentence A* token,
598 | - 1 corresponds to a *sentence B* token.
599 |
600 | [What are token type IDs?](../glossary#token-type-ids)
601 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
602 | Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
603 | config.max_position_embeddings - 1]`.
604 |
605 | [What are position IDs?](../glossary#position-ids)
606 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
607 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
608 |
609 | - 1 indicates the head is **not masked**,
610 | - 0 indicates the head is **masked**.
611 |
612 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
613 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
614 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
615 | model's internal embedding lookup matrix.
616 |
617 | If `past_key_values` is used, only the last `inputs_embeds` need to be passed (see
618 | `past_key_values`).
619 | use_cache (`bool`, *optional*):
620 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
621 | `past_key_values`).
622 | output_attentions (`bool`, *optional*):
623 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
624 | tensors for more detail.
625 | output_hidden_states (`bool`, *optional*):
626 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
627 | more detail.
628 | return_dict (`bool`, *optional*):
629 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
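
Sketch, assuming the usual Transformers convention (the conversion code itself is outside this
excerpt): the eager attention path adds `attention_mask` directly to the raw scores, so a `[0, 1]`
padding mask must first become an additive mask that is 0 where tokens are kept and the dtype's
minimum where they are masked. The helper `to_additive_mask` and the toy tensors are hypothetical.

```python
import torch


def to_additive_mask(mask, dtype=torch.float32):
    # mask: [batch, seq_len] with 1 = attend, 0 = padding.
    # Returns [batch, 1, 1, seq_len] so it broadcasts over heads and query positions.
    mask = mask[:, None, None, :].to(dtype)
    return (1.0 - mask) * torch.finfo(dtype).min


padding_mask = torch.tensor([[1, 1, 1, 0]])  # last token is padding
scores = torch.zeros(1, 2, 4, 4)  # [batch, heads, query, key]; uniform raw scores
probs = torch.softmax(scores + to_additive_mask(padding_mask), dim=-1)
print(probs[0, 0, 0])
```

After the softmax the padded key receives zero weight, and the three remaining keys share it equally.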
630 | """
631 | PARALLELIZE_DOCSTRING = r"""
632 | This is an experimental feature and is subject to change at a moment's notice.
633 |
634 | Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
635 | it will evenly distribute blocks across all devices.
636 |
637 | Args:
638 | device_map (`Dict[int, list]`, *optional*):
639 | A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
640 | automatically mapped to the first device (for esoteric reasons). That means that the first device should
641 | have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
642 | following number of attention modules:
643 |
644 | - openai-community/gpt2: 12
645 | - openai-community/gpt2-medium: 24
646 | - openai-community/gpt2-large: 36
647 | - openai-community/gpt2-xl: 48
648 |
649 | Example:
650 |
651 | ```python
652 | # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
653 | model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl")
654 | device_map = {
655 | 0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
656 | 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
657 | 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
658 | 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
659 | }
660 | model.parallelize(device_map)
661 | ```
662 | """
663 | DEPARALLELIZE_DOCSTRING = r"""
664 | Moves the model to cpu from a model parallel state.
665 |
666 | Example:
667 |
668 | ```python
669 | # On a 4 GPU machine with openai-community/gpt2-large:
670 | model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large")
671 | device_map = {
672 | 0: [0, 1, 2, 3, 4, 5, 6, 7],
673 | 1: [8, 9, 10, 11, 12, 13, 14, 15],
674 | 2: [16, 17, 18, 19, 20, 21, 22, 23],
675 | 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
676 | }
677 | model.parallelize(device_map) # Splits the model across several devices
678 | model.deparallelize() # Puts the model back on the CPU and frees cached GPU memory via torch.cuda.empty_cache()
679 | ```
680 | """
681 |
682 |
683 | @add_start_docstrings(
684 | "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
685 | GPT2_START_DOCSTRING,
686 | )
687 | class GPT2Model(GPT2PreTrainedModel):
688 | _supports_param_buffer_assignment = False
689 |
690 | def __init__(self, config):
691 | super().__init__(config)
692 |
693 | self.embed_dim = config.hidden_size
694 |
695 | self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
696 | self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
697 |
698 | self.drop = nn.Dropout(config.embd_pdrop)
699 | self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
700 | self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
701 |
702 | # Model parallel
703 | self.model_parallel = False
704 | self.device_map = None
705 | self.gradient_checkpointing = False
706 | self._attn_implementation = config._attn_implementation
707 |
708 | # Initialize weights and apply final processing
709 | self.post_init()
710 |
711 | @add_start_docstrings(PARALLELIZE_DOCSTRING)
712 | def parallelize(self, device_map=None):
713 | # Check validity of device_map
714 | warnings.warn(
715 | "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
716 | " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
717 | " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
718 | " ...}",
719 | FutureWarning,
720 | )
721 | self.device_map = (
722 | get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
723 | )
724 | assert_device_map(self.device_map, len(self.h))
725 | self.model_parallel = True
726 | self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
727 | self.last_device = "cuda:" + str(max(self.device_map.keys()))
728 | self.wte = self.wte.to(self.first_device)
729 | self.wpe = self.wpe.to(self.first_device)
730 | # Load onto devices
731 | for k, v in self.device_map.items():
732 | for block in v:
733 | cuda_device = "cuda:" + str(k)
734 | self.h[block] = self.h[block].to(cuda_device)
735 | # Move the final LayerNorm (ln_f) to the last device
736 | self.ln_f = self.ln_f.to(self.last_device)
737 |
738 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
739 | def deparallelize(self):
740 | warnings.warn(
741 | "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
742 | FutureWarning,
743 | )
744 | self.model_parallel = False
745 | self.device_map = None
746 | self.first_device = "cpu"
747 | self.last_device = "cpu"
748 | self.wte = self.wte.to("cpu")
749 | self.wpe = self.wpe.to("cpu")
750 | for index in range(len(self.h)):
751 | self.h[index] = self.h[index].to("cpu")
752 | self.ln_f = self.ln_f.to("cpu")
753 | torch.cuda.empty_cache()
754 |
755 | def get_input_embeddings(self):
756 | return self.wte
757 |
758 | def set_input_embeddings(self, new_embeddings):
759 | self.wte = new_embeddings
760 |
761 | def _prune_heads(self, heads_to_prune):
762 | """
763 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
764 | """
765 | for layer, heads in heads_to_prune.items():
766 | self.h[layer].attn.prune_heads(heads)
767 |
768 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
769 | @add_code_sample_docstrings(
770 | checkpoint=_CHECKPOINT_FOR_DOC,
771 | output_type=BaseModelOutputWithPastAndCrossAttentions,
772 | config_class=_CONFIG_FOR_DOC,
773 | )
774 | def forward(
775 | self,
776 | input_ids: Optional[torch.LongTensor] = None,
777 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
778 | attention_mask: Optional[torch.FloatTensor] = None,
779 | token_type_ids: Optional[torch.LongTensor] = None,
780 | position_ids: Optional[torch.LongTensor] = None,
781 | head_mask: Optional[torch.FloatTensor] = None,
782 | inputs_embeds: Optional[torch.FloatTensor] = None,
783 | encoder_hidden_states: Optional[torch.Tensor] = None,
784 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
785 | use_cache: Optional[bool] = None,
786 | output_attentions: Optional[bool] = None,
787 | output_hidden_states: Optional[bool] = None,
788 | return_dict: Optional[bool] = None,
789 | ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
790 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
791 | output_hidden_states = (
792 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
793 | )
794 | use_cache = use_cache if use_cache is not None else self.config.use_cache
795 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
796 |
797 | if input_ids is not None and inputs_embeds is not None:
798 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
799 | elif input_ids is not None:
800 | self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
801 | input_shape = input_ids.size()
802 | input_ids = input_ids.view(-1, input_shape[-1])
803 | batch_size = input_ids.shape[0]
804 | elif inputs_embeds is not None:
805 | input_shape = inputs_embeds.size()[:-1]
806 | batch_size = inputs_embeds.shape[0]
807 | else:
808 | raise ValueError("You have to specify either input_ids or inputs_embeds")
809 |
810 | device = input_ids.device if input_ids is not None else inputs_embeds.device
811 |
812 | if token_type_ids is not None:
813 | token_type_ids = token_type_ids.view(-1, input_shape[-1])
814 |
815 | if past_key_values is None:
816 | past_length = 0
817 | past_key_values = tuple([None] * len(self.h))
818 | else:
819 | past_length = past_key_values[0][0].size(-2)
820 | if position_ids is None:
821 | position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
822 | position_ids = position_ids.unsqueeze(0)
823 |
824 | if inputs_embeds is None:
825 | inputs_embeds = self.wte(input_ids)
826 | position_embeds = self.wpe(position_ids)
827 | hidden_states = inputs_embeds + position_embeds
828 |
829 | # Attention mask.
830 | _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
831 | attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None
832 | if self._attn_implementation == "flash_attention_2":
833 | attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
834 | elif _use_sdpa:
835 | attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
836 | attention_mask=attention_mask,
837 | input_shape=(batch_size, input_shape[-1]),
838 | inputs_embeds=inputs_embeds,
839 | past_key_values_length=past_length,
840 | )
841 | else:
842 | if attention_mask is not None:
843 | # We create a 4D attention mask from a 2D tensor mask.
844 | # Sizes are [batch_size, 1, 1, to_seq_length]
845 | # so we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length].
846 | # This attention mask is simpler than the triangular masking of causal attention
847 | # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
848 | attention_mask = attention_mask[:, None, None, :]
849 |
850 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
851 | # masked positions, this operation will create a tensor which is 0.0 for
852 | # positions we want to attend and the dtype's smallest value for masked positions.
853 | # Since we are adding it to the raw scores before the softmax, this is
854 | # effectively the same as removing these entirely.
855 | attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
856 | attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
857 |
858 | # If a 2D or 3D attention mask is provided for the cross-attention
859 | # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
860 | if self.config.add_cross_attention and encoder_hidden_states is not None:
861 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
862 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
863 | if encoder_attention_mask is None:
864 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
865 | if _use_sdpa:
866 | encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
867 | mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
868 | )
869 | elif not self._attn_implementation == "flash_attention_2":
870 | encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
871 | else:
872 | encoder_attention_mask = None
873 |
874 | # Prepare head mask if needed
875 | # 1.0 in head_mask indicates that we keep the head
876 | # attention_probs has shape bsz x n_heads x N x N
877 | # head_mask has shape n_layer x batch x n_heads x N x N
878 | head_mask = self.get_head_mask(head_mask, self.config.n_layer)
879 |
880 | if token_type_ids is not None:
881 | token_type_embeds = self.wte(token_type_ids)
882 | hidden_states = hidden_states + token_type_embeds
883 |
884 | hidden_states = self.drop(hidden_states)
885 |
886 | output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
887 |
888 | if self.gradient_checkpointing and self.training:
889 | if use_cache:
890 | logger.warning_once(
891 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
892 | )
893 | use_cache = False
894 |
895 | presents = () if use_cache else None
896 | all_self_attentions = () if output_attentions else None
897 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
898 | all_hidden_states = () if output_hidden_states else None
899 | for i in range(len(self.h)):
900 | block, layer_past = self.h[i], past_key_values[i]
901 | # Model parallel
902 | if self.model_parallel:
903 | torch.cuda.set_device(hidden_states.device)
904 | # Ensure layer_past is on same device as hidden_states (might not be correct)
905 | if layer_past is not None:
906 | layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
907 | # Ensure that attention_mask is always on the same device as hidden_states
908 | if attention_mask is not None:
909 | attention_mask = attention_mask.to(hidden_states.device)
910 | if isinstance(head_mask, torch.Tensor):
911 | head_mask = head_mask.to(hidden_states.device)
912 | if output_hidden_states:
913 | all_hidden_states = all_hidden_states + (hidden_states,)
914 |
915 | if self.gradient_checkpointing and self.training:
916 | outputs = self._gradient_checkpointing_func(
917 | block.__call__,
918 | hidden_states,
919 | None,
920 | attention_mask,
921 | head_mask[i],
922 | encoder_hidden_states,
923 | encoder_attention_mask,
924 | use_cache,
925 | output_attentions,
926 | )
927 | else:
928 | outputs = block(
929 | hidden_states,
930 | layer_past=layer_past,
931 | attention_mask=attention_mask,
932 | head_mask=head_mask[i],
933 | encoder_hidden_states=encoder_hidden_states,
934 | encoder_attention_mask=encoder_attention_mask,
935 | use_cache=use_cache,
936 | output_attentions=output_attentions,
937 | )
938 |
939 | hidden_states = outputs[0]
940 | if use_cache is True:
941 | presents = presents + (outputs[1],)
942 |
943 | if output_attentions:
944 | all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
945 | if self.config.add_cross_attention:
946 | all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
947 |
948 | # Model Parallel: If it's the last layer for that device, put things on the next device
949 | if self.model_parallel:
950 | for k, v in self.device_map.items():
951 | if i == v[-1] and "cuda:" + str(k) != self.last_device:
952 | hidden_states = hidden_states.to("cuda:" + str(k + 1))
953 |
954 | hidden_states = self.ln_f(hidden_states)
955 |
956 | hidden_states = hidden_states.view(output_shape)
957 | # Add last hidden state
958 | if output_hidden_states:
959 | all_hidden_states = all_hidden_states + (hidden_states,)
960 |
961 | if not return_dict:
962 | return tuple(
963 | v
964 | for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
965 | if v is not None
966 | )
967 |
968 | return BaseModelOutputWithPastAndCrossAttentions(
969 | last_hidden_state=hidden_states,
970 | past_key_values=presents,
971 | hidden_states=all_hidden_states,
972 | attentions=all_self_attentions,
973 | cross_attentions=all_cross_attentions,
974 | )
975 |
976 |
977 | @add_start_docstrings(
978 | """
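979 | The GPT2 Model transformer with a language modeling head on top. Logits are computed by `lm_head_2`, a linear layer
980 | that is not tied to the input embeddings; `lm_head` (listed in `_tied_weights_keys`) is kept but unused in `forward`.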
981 | """,
982 | GPT2_START_DOCSTRING,
983 | )
984 | class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
985 | _tied_weights_keys = ["lm_head.weight"]
986 |
987 | def __init__(self, config):
988 | super().__init__(config)
989 | self.transformer = GPT2Model(config)
990 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # not used to compute logits; forward() uses the untied lm_head_2 below
991 | self.lm_head_2 = nn.Linear(config.n_embd, config.vocab_size, bias=False)
992 |
993 | # Model parallel
994 | self.model_parallel = False
995 | self.device_map = None
996 |
997 | # Initialize weights and apply final processing
998 | self.post_init()
999 |
1000 | @add_start_docstrings(PARALLELIZE_DOCSTRING)
1001 | def parallelize(self, device_map=None):
1002 | warnings.warn(
1003 | "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
1004 | " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
1005 | " `device_map`, but it must be a dictionary mapping module names to devices, for instance {'transformer.h.0':"
1006 | " 0, 'transformer.h.1': 1, ...}",
1007 | FutureWarning,
1008 | )
1009 | self.device_map = (
1010 | get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1011 | if device_map is None
1012 | else device_map
1013 | )
1014 | assert_device_map(self.device_map, len(self.transformer.h))
1015 | self.transformer.parallelize(self.device_map)
1016 | self.lm_head = self.lm_head.to(self.transformer.first_device)
1017 | self.model_parallel = True
1018 |
1019 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1020 | def deparallelize(self):
1021 | warnings.warn(
1022 | "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1023 | FutureWarning,
1024 | )
1025 | self.transformer.deparallelize()
1026 | self.transformer = self.transformer.to("cpu")
1027 | self.lm_head = self.lm_head.to("cpu")
1028 | self.model_parallel = False
1029 | torch.cuda.empty_cache()
1030 |
1031 | def get_output_embeddings(self):
1032 | return self.lm_head
1033 |
1034 | def set_output_embeddings(self, new_embeddings):
1035 | self.lm_head = new_embeddings
1036 |
1037 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1038 | @add_code_sample_docstrings(
1039 | checkpoint=_CHECKPOINT_FOR_DOC,
1040 | output_type=CausalLMOutputWithCrossAttentions,
1041 | config_class=_CONFIG_FOR_DOC,
1042 | )
1043 | def forward(
1044 | self,
1045 | input_ids: Optional[torch.LongTensor] = None,
1046 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1047 | attention_mask: Optional[torch.FloatTensor] = None,
1048 | token_type_ids: Optional[torch.LongTensor] = None,
1049 | position_ids: Optional[torch.LongTensor] = None,
1050 | head_mask: Optional[torch.FloatTensor] = None,
1051 | inputs_embeds: Optional[torch.FloatTensor] = None,
1052 | encoder_hidden_states: Optional[torch.Tensor] = None,
1053 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
1054 | labels: Optional[torch.LongTensor] = None,
1055 | use_cache: Optional[bool] = None,
1056 | output_attentions: Optional[bool] = None,
1057 | output_hidden_states: Optional[bool] = None,
1058 | return_dict: Optional[bool] = None,
1059 | ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1060 | r"""
1061 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1062 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1063 | `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to `-100`
1064 | are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`.
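The snippet below is a minimal sketch of the label handling described above, using random logits instead of real model output; the vocabulary size and token ids are made up for illustration. It shows that the logits at position t are scored against the label at position t + 1, and that labels set to `-100` drop out of the average:

```python
import torch
from torch.nn import CrossEntropyLoss

torch.manual_seed(0)
vocab_size = 10  # toy vocabulary, stands in for config.vocab_size
input_ids = torch.tensor([[3, 7, 1, 4, 2]])
labels = input_ids.clone()  # "labels = input_ids"; the shift happens below
labels[0, -1] = -100  # ignore the last target token

lm_logits = torch.randn(1, input_ids.size(1), vocab_size)  # stand-in for model output

# Logits at position t are scored against the token at position t + 1
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))

# Manual check: -100 is CrossEntropyLoss's default ignore_index, so those
# targets are dropped and the mean is taken over the remaining ones
flat_labels = shift_labels.view(-1)
keep = flat_labels != -100
log_probs = torch.log_softmax(shift_logits.view(-1, vocab_size), dim=-1)
manual = -log_probs[keep].gather(1, flat_labels[keep].unsqueeze(1)).mean()
```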
1065 | """
1066 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1067 |
1068 | transformer_outputs = self.transformer(
1069 | input_ids,
1070 | past_key_values=past_key_values,
1071 | attention_mask=attention_mask,
1072 | token_type_ids=token_type_ids,
1073 | position_ids=position_ids,
1074 | head_mask=head_mask,
1075 | inputs_embeds=inputs_embeds,
1076 | encoder_hidden_states=encoder_hidden_states,
1077 | encoder_attention_mask=encoder_attention_mask,
1078 | use_cache=use_cache,
1079 | output_attentions=output_attentions,
1080 | output_hidden_states=output_hidden_states,
1081 | return_dict=return_dict,
1082 | )
1083 | hidden_states = transformer_outputs[0]
1084 |
1085 | # Set device for model parallelism
1086 | if self.model_parallel:
1087 | torch.cuda.set_device(self.transformer.first_device)
1088 | hidden_states = hidden_states.to(self.lm_head_2.weight.device) # lm_head_2 is the layer that produces the logits
1089 |
1090 | lm_logits = self.lm_head_2(hidden_states)
1091 |
1092 | loss = None
1093 | if labels is not None:
1094 | # move labels to correct device to enable model parallelism
1095 | labels = labels.to(lm_logits.device)
1096 | # Shift so that tokens < n predict n
1097 | shift_logits = lm_logits[..., :-1, :].contiguous()
1098 | shift_labels = labels[..., 1:].contiguous()
1099 | # Flatten the tokens
1100 | loss_fct = CrossEntropyLoss()
1101 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1102 |
1103 | if not return_dict:
1104 | output = (lm_logits,) + transformer_outputs[1:]
1105 | return ((loss,) + output) if loss is not None else output
1106 |
1107 | return CausalLMOutputWithCrossAttentions(
1108 | loss=loss,
1109 | logits=lm_logits,
1110 | past_key_values=transformer_outputs.past_key_values,
1111 | hidden_states=transformer_outputs.hidden_states,
1112 | attentions=transformer_outputs.attentions,
1113 | cross_attentions=transformer_outputs.cross_attentions,
1114 | )
1115 |
1116 | @staticmethod
1117 | def _reorder_cache(
1118 | past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1119 | ) -> Tuple[Tuple[torch.Tensor]]:
1120 | """
1121 | This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1122 | [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1123 | beam_idx at every generation step.
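A toy illustration of the reordering, assuming the legacy tuple cache used here (one `(key, value)` pair per layer, with beam hypotheses along dim 0). `reorder_cache` is a hypothetical standalone copy of this method's body, and the tensor sizes are arbitrary:

```python
import torch


def reorder_cache(past_key_values, beam_idx):
    # Hypothetical standalone copy of `_reorder_cache`: select rows along dim 0
    # of every key/value tensor in every layer, following the surviving beams.
    return tuple(
        tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
        for layer_past in past_key_values
    )


# Toy cache: 2 layers, each a (key, value) pair with 3 beams in dim 0.
# Row b of every tensor is filled with the value b so the reordering is visible.
num_layers, num_beams, num_heads, seq_len, head_dim = 2, 3, 1, 2, 4
base = torch.arange(num_beams, dtype=torch.float).view(num_beams, 1, 1, 1)
past = tuple(
    (
        base.expand(num_beams, num_heads, seq_len, head_dim).clone(),
        base.expand(num_beams, num_heads, seq_len, head_dim).clone(),
    )
    for _ in range(num_layers)
)

# New beams 0 and 1 both continue from old beam 2; new beam 2 continues from old beam 0.
beam_idx = torch.tensor([2, 2, 0])
reordered = reorder_cache(past, beam_idx)
```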
1124 | """
1125 | return tuple(
1126 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
1127 | for layer_past in past_key_values
1128 | )
1129 |
1130 |
1131 | @add_start_docstrings(
1132 | """
1133 | The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top, e.g. for
1134 | RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
1135 | input embeddings; the classification head takes as input the hidden state at a specified classification token
1136 | index in the input sequence.
1137 | """,
1138 | GPT2_START_DOCSTRING,
1139 | )
1140 | class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin):
1141 | _tied_weights_keys = ["lm_head.weight"]
1142 |
1143 | def __init__(self, config):
1144 | super().__init__(config)
1145 | config.num_labels = 1
1146 | self.transformer = GPT2Model(config)
1147 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1148 | self.multiple_choice_head = SequenceSummary(config)
1149 |
1150 | # Model parallel
1151 | self.model_parallel = False
1152 | self.device_map = None
1153 |
1154 | # Initialize weights and apply final processing
1155 | self.post_init()
1156 |
1157 | @add_start_docstrings(PARALLELIZE_DOCSTRING)
1158 | def parallelize(self, device_map=None):
1159 | warnings.warn(
1160 | "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should"
1161 | " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your"
1162 | " own `device_map`, but it must be a dictionary mapping module names to devices, for instance"
1163 | " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}",
1164 | FutureWarning,
1165 | )
1166 | self.device_map = (
1167 | get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1168 | if device_map is None
1169 | else device_map
1170 | )
1171 | assert_device_map(self.device_map, len(self.transformer.h))
1172 | self.transformer.parallelize(self.device_map)
1173 | self.lm_head = self.lm_head.to(self.transformer.first_device)
1174 | self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device)
1175 | self.model_parallel = True
1176 |
1177 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1178 | def deparallelize(self):
1179 | warnings.warn(
1180 | "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1181 | FutureWarning,
1182 | )
1183 | self.transformer.deparallelize()
1184 | self.transformer = self.transformer.to("cpu")
1185 | self.lm_head = self.lm_head.to("cpu")
1186 | self.multiple_choice_head = self.multiple_choice_head.to("cpu")
1187 | self.model_parallel = False
1188 | torch.cuda.empty_cache()
1189 |
1190 | def get_output_embeddings(self):
1191 | return self.lm_head
1192 |
1193 | def set_output_embeddings(self, new_embeddings):
1194 | self.lm_head = new_embeddings
1195 |
1196 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1197 | @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
1198 | def forward(
1199 | self,
1200 | input_ids: Optional[torch.LongTensor] = None,
1201 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1202 | attention_mask: Optional[torch.FloatTensor] = None,
1203 | token_type_ids: Optional[torch.LongTensor] = None,
1204 | position_ids: Optional[torch.LongTensor] = None,
1205 | head_mask: Optional[torch.FloatTensor] = None,
1206 | inputs_embeds: Optional[torch.FloatTensor] = None,
1207 | mc_token_ids: Optional[torch.LongTensor] = None,
1208 | labels: Optional[torch.LongTensor] = None,
1209 | mc_labels: Optional[torch.LongTensor] = None,
1210 | use_cache: Optional[bool] = None,
1211 | output_attentions: Optional[bool] = None,
1212 | output_hidden_states: Optional[bool] = None,
1213 | return_dict: Optional[bool] = None,
1214 | **kwargs,
1215 | ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]:
1216 | r"""
1217 | mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, defaults to the index of the last token of the input):
1218 | Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
1219 | 1]`.
1220 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1221 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1222 | `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
1223 | `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`
1224 | mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
1225 | Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices - 1]`
1226 | where *num_choices* is the size of the second dimension of the input tensors (see *input_ids* above).
1227 |
1228 | Return:
1229 |
1230 | Example:
1231 |
1232 | ```python
1233 | >>> import torch
1234 | >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel
1235 |
1236 | >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
1237 | >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2")
1238 |
1239 | >>> # Add a [CLS] to the vocabulary (we should train it also!)
1240 | >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"})
1241 | >>> # Update the model embeddings with the new vocabulary size
1242 | >>> embedding_layer = model.resize_token_embeddings(len(tokenizer))
1243 |
1244 | >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
1245 | >>> encoded_choices = [tokenizer.encode(s) for s in choices]
1246 | >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
1247 |
1248 | >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2
1249 | >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1
1250 |
1251 | >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
1252 | >>> lm_logits = outputs.logits
1253 | >>> mc_logits = outputs.mc_logits
1254 | ```"""
1255 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1256 |
1257 | transformer_outputs = self.transformer(
1258 | input_ids,
1259 | past_key_values=past_key_values,
1260 | attention_mask=attention_mask,
1261 | token_type_ids=token_type_ids,
1262 | position_ids=position_ids,
1263 | head_mask=head_mask,
1264 | inputs_embeds=inputs_embeds,
1265 | use_cache=use_cache,
1266 | output_attentions=output_attentions,
1267 | output_hidden_states=output_hidden_states,
1268 | return_dict=return_dict,
1269 | )
1270 |
1271 | hidden_states = transformer_outputs[0]
1272 |
1273 | # Set device for model parallelism
1274 | if self.model_parallel:
1275 | torch.cuda.set_device(self.transformer.first_device)
1276 | hidden_states = hidden_states.to(self.lm_head.weight.device)
1277 |
1278 | lm_logits = self.lm_head(hidden_states)
1279 | mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
1280 |
1281 | mc_loss = None
1282 | if mc_labels is not None:
1283 | loss_fct = CrossEntropyLoss()
1284 | mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
1285 | lm_loss = None
1286 | if labels is not None:
1287 | labels = labels.to(lm_logits.device)
1288 | shift_logits = lm_logits[..., :-1, :].contiguous()
1289 | shift_labels = labels[..., 1:].contiguous()
1290 | loss_fct = CrossEntropyLoss()
1291 | lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1292 |
1293 | if not return_dict:
1294 | output = (lm_logits, mc_logits) + transformer_outputs[1:]
1295 | if mc_loss is not None:
1296 | output = (mc_loss,) + output
1297 | return ((lm_loss,) + output) if lm_loss is not None else output
1298 |
1299 | return GPT2DoubleHeadsModelOutput(
1300 | loss=lm_loss,
1301 | mc_loss=mc_loss,
1302 | logits=lm_logits,
1303 | mc_logits=mc_logits,
1304 | past_key_values=transformer_outputs.past_key_values,
1305 | hidden_states=transformer_outputs.hidden_states,
1306 | attentions=transformer_outputs.attentions,
1307 | )
1308 |
1309 | @staticmethod
1310 | def _reorder_cache(
1311 | past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1312 | ) -> Tuple[Tuple[torch.Tensor]]:
1313 | """
1314 | This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1315 | [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1316 | beam_idx at every generation step.
1317 | """
1318 | return tuple(
1319 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
1320 | for layer_past in past_key_values
1321 | )
1322 |
1323 |
1324 | @add_start_docstrings(
1325 | """
1326 | The GPT2 Model transformer with a sequence classification head on top (linear layer).
1327 |
1328 | [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1329 | (e.g. GPT-1) do.
1330 |
1331 | Since it does classification on the last token, it needs to know the position of the last token. If a
1332 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1333 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1334 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1335 | each row of the batch).
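A minimal sketch of how `forward` locates the last non-padding token when `pad_token_id` is set, assuming right padding; the pad id `0` and the token ids are illustrative only:

```python
import torch

pad_token_id = 0  # illustrative value, not GPT-2's actual configuration
input_ids = torch.tensor([
    [5, 8, 2, 0, 0],  # right-padded: last real token at index 2
    [4, 6, 9, 3, 7],  # no padding: last real token at index 4
])
batch_size, seq_len = input_ids.shape

# argmax returns the first pad position; a row without padding is all zeros,
# so argmax gives 0, and (0 - 1) % seq_len wraps around to the last index.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % seq_len

scores = torch.randn(batch_size, seq_len, 3)  # stand-in for per-token logits
pooled = scores[torch.arange(batch_size), sequence_lengths]
```

The modulo step replaces the `-1` reverse index with an explicit positive index, which the code comment in `forward` attributes to ONNX compatibility.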
1336 | """,
1337 | GPT2_START_DOCSTRING,
1338 | )
1339 | class GPT2ForSequenceClassification(GPT2PreTrainedModel):
1340 | def __init__(self, config):
1341 | super().__init__(config)
1342 | self.num_labels = config.num_labels
1343 | self.transformer = GPT2Model(config)
1344 | self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
1345 |
1346 | # Model parallel
1347 | self.model_parallel = False
1348 | self.device_map = None
1349 |
1350 | # Initialize weights and apply final processing
1351 | self.post_init()
1352 |
1353 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1354 | @add_code_sample_docstrings(
1355 | checkpoint="microsoft/DialogRPT-updown",
1356 | output_type=SequenceClassifierOutputWithPast,
1357 | config_class=_CONFIG_FOR_DOC,
1358 | )
1359 | def forward(
1360 | self,
1361 | input_ids: Optional[torch.LongTensor] = None,
1362 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1363 | attention_mask: Optional[torch.FloatTensor] = None,
1364 | token_type_ids: Optional[torch.LongTensor] = None,
1365 | position_ids: Optional[torch.LongTensor] = None,
1366 | head_mask: Optional[torch.FloatTensor] = None,
1367 | inputs_embeds: Optional[torch.FloatTensor] = None,
1368 | labels: Optional[torch.LongTensor] = None,
1369 | use_cache: Optional[bool] = None,
1370 | output_attentions: Optional[bool] = None,
1371 | output_hidden_states: Optional[bool] = None,
1372 | return_dict: Optional[bool] = None,
1373 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1374 | r"""
1375 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1376 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1377 | config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (mean-squared error); if
1378 | `config.num_labels > 1`, a classification loss is computed (cross-entropy for integer labels, `BCEWithLogitsLoss` otherwise).
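A sketch of how the loss type is picked when `config.problem_type` is unset, mirroring the branches in `forward`. `infer_problem_type` and `sequence_classification_loss` are hypothetical helpers written for illustration, and the regression branch here only covers the inferred `num_labels == 1` case:

```python
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss


def infer_problem_type(num_labels, labels):
    # Hypothetical helper: same rule `forward` applies when config.problem_type is None
    if num_labels == 1:
        return "regression"
    if labels.dtype in (torch.long, torch.int):
        return "single_label_classification"
    return "multi_label_classification"


def sequence_classification_loss(pooled_logits, labels):
    # Hypothetical helper: pick the loss that matches the inferred problem type
    num_labels = pooled_logits.size(-1)
    problem_type = infer_problem_type(num_labels, labels)
    if problem_type == "regression":
        return problem_type, MSELoss()(pooled_logits.squeeze(), labels.squeeze())
    if problem_type == "single_label_classification":
        return problem_type, CrossEntropyLoss()(pooled_logits.view(-1, num_labels), labels.view(-1))
    return problem_type, BCEWithLogitsLoss()(pooled_logits, labels)


logits = torch.tensor([[2.0, -1.0, 0.5], [0.1, 0.3, -0.2]])
kind, loss = sequence_classification_loss(logits, torch.tensor([0, 1]))  # integer class ids
kind_multi, _ = sequence_classification_loss(logits, torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]))
kind_reg, _ = sequence_classification_loss(torch.tensor([[0.7], [1.2]]), torch.tensor([1.0, 1.0]))
```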
1379 | """
1380 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1381 |
1382 | transformer_outputs = self.transformer(
1383 | input_ids,
1384 | past_key_values=past_key_values,
1385 | attention_mask=attention_mask,
1386 | token_type_ids=token_type_ids,
1387 | position_ids=position_ids,
1388 | head_mask=head_mask,
1389 | inputs_embeds=inputs_embeds,
1390 | use_cache=use_cache,
1391 | output_attentions=output_attentions,
1392 | output_hidden_states=output_hidden_states,
1393 | return_dict=return_dict,
1394 | )
1395 | hidden_states = transformer_outputs[0]
1396 | logits = self.score(hidden_states)
1397 |
1398 | if input_ids is not None:
1399 | batch_size, sequence_length = input_ids.shape[:2]
1400 | else:
1401 | batch_size, sequence_length = inputs_embeds.shape[:2]
1402 |
1403 | assert (
1404 | self.config.pad_token_id is not None or batch_size == 1
1405 | ), "Cannot handle batch sizes > 1 if no padding token is defined."
1406 | if self.config.pad_token_id is None:
1407 | sequence_lengths = -1
1408 | else:
1409 | if input_ids is not None:
1410 | # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1411 | sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1412 | sequence_lengths = sequence_lengths % input_ids.shape[-1]
1413 | sequence_lengths = sequence_lengths.to(logits.device)
1414 | else:
1415 | sequence_lengths = -1
1416 | logger.warning_once(
1417 | f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1418 | "unexpected if using padding tokens in conjunction with `inputs_embeds`."
1419 | )
1420 |
1421 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1422 |
1423 | loss = None
1424 | if labels is not None:
1425 | if self.config.problem_type is None:
1426 | if self.num_labels == 1:
1427 | self.config.problem_type = "regression"
1428 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1429 | self.config.problem_type = "single_label_classification"
1430 | else:
1431 | self.config.problem_type = "multi_label_classification"
1432 |
1433 | if self.config.problem_type == "regression":
1434 | loss_fct = MSELoss()
1435 | if self.num_labels == 1:
1436 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1437 | else:
1438 | loss = loss_fct(pooled_logits, labels)
1439 | elif self.config.problem_type == "single_label_classification":
1440 | loss_fct = CrossEntropyLoss()
1441 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1442 | elif self.config.problem_type == "multi_label_classification":
1443 | loss_fct = BCEWithLogitsLoss()
1444 | loss = loss_fct(pooled_logits, labels)
1445 | if not return_dict:
1446 | output = (pooled_logits,) + transformer_outputs[1:]
1447 | return ((loss,) + output) if loss is not None else output
1448 |
1449 | return SequenceClassifierOutputWithPast(
1450 | loss=loss,
1451 | logits=pooled_logits,
1452 | past_key_values=transformer_outputs.past_key_values,
1453 | hidden_states=transformer_outputs.hidden_states,
1454 | attentions=transformer_outputs.attentions,
1455 | )
1456 |
1457 |
1458 | @add_start_docstrings(
1459 | """
1460 | GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1461 | Named-Entity-Recognition (NER) tasks.
1462 | """,
1463 | GPT2_START_DOCSTRING,
1464 | )
1465 | class GPT2ForTokenClassification(GPT2PreTrainedModel):
1466 | def __init__(self, config):
1467 | super().__init__(config)
1468 | self.num_labels = config.num_labels
1469 |
1470 | self.transformer = GPT2Model(config)
1471 | if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
1472 | classifier_dropout = config.classifier_dropout
1473 | elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1474 | classifier_dropout = config.hidden_dropout
1475 | else:
1476 | classifier_dropout = 0.1
1477 | self.dropout = nn.Dropout(classifier_dropout)
1478 | self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1479 |
1480 | # Model parallel
1481 | self.model_parallel = False
1482 | self.device_map = None
1483 |
1484 | # Initialize weights and apply final processing
1485 | self.post_init()
1486 |
1487 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1488 | # fmt: off
1489 | @add_code_sample_docstrings(
1490 | checkpoint="brad1141/gpt2-finetuned-comp2",
1491 | output_type=TokenClassifierOutput,
1492 | config_class=_CONFIG_FOR_DOC,
1493 | expected_loss=0.25,
1494 | expected_output=[
1495 | "Lead",
1496 | "Lead",
1497 | "Lead",
1498 | "Position",
1499 | "Lead",
1500 | "Lead",
1501 | "Lead",
1502 | "Lead",
1503 | "Lead",
1504 | "Lead",
1505 | "Lead",
1506 | "Lead",
1507 | ],
1508 | )
1509 | # fmt: on
1510 | def forward(
1511 | self,
1512 | input_ids: Optional[torch.LongTensor] = None,
1513 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1514 | attention_mask: Optional[torch.FloatTensor] = None,
1515 | token_type_ids: Optional[torch.LongTensor] = None,
1516 | position_ids: Optional[torch.LongTensor] = None,
1517 | head_mask: Optional[torch.FloatTensor] = None,
1518 | inputs_embeds: Optional[torch.FloatTensor] = None,
1519 | labels: Optional[torch.LongTensor] = None,
1520 | use_cache: Optional[bool] = None,
1521 | output_attentions: Optional[bool] = None,
1522 | output_hidden_states: Optional[bool] = None,
1523 | return_dict: Optional[bool] = None,
1524 | ) -> Union[Tuple, TokenClassifierOutput]:
1525 | r"""
1526 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1527 | Labels for computing the token classification loss. Indices should be in `[0, ...,
1528 | config.num_labels - 1]`. A cross-entropy loss is always computed, whatever the value of
1529 | `config.num_labels`; positions labelled `-100` (the `CrossEntropyLoss` default `ignore_index`) are ignored.
1530 | """
1531 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1532 |
1533 | transformer_outputs = self.transformer(
1534 | input_ids,
1535 | past_key_values=past_key_values,
1536 | attention_mask=attention_mask,
1537 | token_type_ids=token_type_ids,
1538 | position_ids=position_ids,
1539 | head_mask=head_mask,
1540 | inputs_embeds=inputs_embeds,
1541 | use_cache=use_cache,
1542 | output_attentions=output_attentions,
1543 | output_hidden_states=output_hidden_states,
1544 | return_dict=return_dict,
1545 | )
1546 |
1547 | hidden_states = transformer_outputs[0]
1548 | hidden_states = self.dropout(hidden_states)
1549 | logits = self.classifier(hidden_states)
1550 |
1551 | loss = None
1552 | if labels is not None:
1553 | labels = labels.to(logits.device)
1554 | loss_fct = CrossEntropyLoss()
1555 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1556 |
1557 | if not return_dict:
1558 | output = (logits,) + transformer_outputs[2:]
1559 | return ((loss,) + output) if loss is not None else output
1560 |
1561 | return TokenClassifierOutput(
1562 | loss=loss,
1563 | logits=logits,
1564 | hidden_states=transformer_outputs.hidden_states,
1565 | attentions=transformer_outputs.attentions,
1566 | )
1567 |
1568 |
1569 | @add_start_docstrings(
1570 | """
1571 | The GPT-2 Model transformer with a span classification head on top for extractive question-answering tasks like
1572 | SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1573 | """,
1574 | GPT2_START_DOCSTRING,
1575 | )
1576 | class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
1577 | def __init__(self, config):
1578 | super().__init__(config)
1579 | self.num_labels = config.num_labels
1580 | self.transformer = GPT2Model(config)
1581 | self.qa_outputs = nn.Linear(config.hidden_size, 2)
1582 |
1583 | # Model parallel
1584 | self.model_parallel = False
1585 | self.device_map = None
1586 |
1587 | # Initialize weights and apply final processing
1588 | self.post_init()
1589 |
1590 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1591 | @add_code_sample_docstrings(
1592 | checkpoint=_CHECKPOINT_FOR_DOC,
1593 | output_type=QuestionAnsweringModelOutput,
1594 | config_class=_CONFIG_FOR_DOC,
1595 | real_checkpoint=_CHECKPOINT_FOR_DOC,
1596 | )
1597 | def forward(
1598 | self,
1599 | input_ids: Optional[torch.LongTensor] = None,
1600 | attention_mask: Optional[torch.FloatTensor] = None,
1601 | token_type_ids: Optional[torch.LongTensor] = None,
1602 | position_ids: Optional[torch.LongTensor] = None,
1603 | head_mask: Optional[torch.FloatTensor] = None,
1604 | inputs_embeds: Optional[torch.FloatTensor] = None,
1605 | start_positions: Optional[torch.LongTensor] = None,
1606 | end_positions: Optional[torch.LongTensor] = None,
1607 | output_attentions: Optional[bool] = None,
1608 | output_hidden_states: Optional[bool] = None,
1609 | return_dict: Optional[bool] = None,
1610 | ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1611 | r"""
1612 | start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1613 | Labels for position (index) of the start of the labelled span for computing the token classification loss.
1614 | Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1615 | are not taken into account for computing the loss.
1616 | end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1617 | Labels for position (index) of the end of the labelled span for computing the token classification loss.
1618 | Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1619 | are not taken into account for computing the loss.
1620 | """
1621 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1622 |
1623 | outputs = self.transformer(
1624 | input_ids,
1625 | attention_mask=attention_mask,
1626 | token_type_ids=token_type_ids,
1627 | position_ids=position_ids,
1628 | head_mask=head_mask,
1629 | inputs_embeds=inputs_embeds,
1630 | output_attentions=output_attentions,
1631 | output_hidden_states=output_hidden_states,
1632 | return_dict=return_dict,
1633 | )
1634 |
1635 | sequence_output = outputs[0]
1636 |
1637 | logits = self.qa_outputs(sequence_output)
1638 | start_logits, end_logits = logits.split(1, dim=-1)
1639 | start_logits = start_logits.squeeze(-1).contiguous()
1640 | end_logits = end_logits.squeeze(-1).contiguous()
1641 |
1642 | total_loss = None
1643 | if start_positions is not None and end_positions is not None:
1644 | # On multi-GPU the positions may carry an extra trailing dimension; squeeze it
1645 | if len(start_positions.size()) > 1:
1646 | start_positions = start_positions.squeeze(-1).to(start_logits.device)
1647 | if len(end_positions.size()) > 1:
1648 | end_positions = end_positions.squeeze(-1).to(end_logits.device)
1649 | # start/end positions can fall outside the model inputs: clamp them to ignored_index so the loss skips them
1650 | ignored_index = start_logits.size(1)
1651 | start_positions = start_positions.clamp(0, ignored_index)
1652 | end_positions = end_positions.clamp(0, ignored_index)
1653 |
1654 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1655 | start_loss = loss_fct(start_logits, start_positions)
1656 | end_loss = loss_fct(end_logits, end_positions)
1657 | total_loss = (start_loss + end_loss) / 2
1658 |
1659 | if not return_dict:
1660 | output = (start_logits, end_logits) + outputs[2:]
1661 | return ((total_loss,) + output) if total_loss is not None else output
1662 |
1663 | return QuestionAnsweringModelOutput(
1664 | loss=total_loss,
1665 | start_logits=start_logits,
1666 | end_logits=end_logits,
1667 | hidden_states=outputs.hidden_states,
1668 | attentions=outputs.attentions,
1669 | )
1670 |
1671 |
1672 | __all__ = [
1673 | "GPT2DoubleHeadsModel",
1674 | "GPT2ForQuestionAnswering",
1675 | "GPT2ForSequenceClassification",
1676 | "GPT2ForTokenClassification",
1677 | "GPT2LMHeadModel",
1678 | "GPT2Model",
1679 | "GPT2PreTrainedModel",
1680 | "load_tf_weights_in_gpt2",
1681 | ]
1682 |
--------------------------------------------------------------------------------
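The last-token pooling in `GPT2ForSequenceClassification.forward` above is easy to misread, so here is a standalone sketch of the same index arithmetic. It assumes right padding and that `pad_token_id` never occurs as a real token; `last_token_indices` and `pool_last_token` are hypothetical helper names, not functions in `modeling_gpt2.py`.

```python
import torch


def last_token_indices(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Index of the last non-pad token per row (assumes right padding).

    Same arithmetic as GPT2ForSequenceClassification: position of the first pad
    token, minus one. A row with no pad token gives argmax(all zeros) == 0, so
    the index becomes -1, and the modulo maps it to seq_len - 1 (the final
    position) without relying on negative indexing.
    """
    seq_len = input_ids.shape[-1]
    first_pad = torch.eq(input_ids, pad_token_id).int().argmax(-1)
    return (first_pad - 1) % seq_len


def pool_last_token(logits: torch.Tensor, input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Select logits[b, last_token_indices[b]] for each batch row b."""
    batch_size = logits.shape[0]
    idx = last_token_indices(input_ids, pad_token_id).to(logits.device)
    # advanced indexing with two index tensors: one (row, position) pair per example
    return logits[torch.arange(batch_size, device=logits.device), idx]


input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [1, 2, 3, 4, 5]])
logits = torch.randn(2, 5, 3)  # (batch, seq_len, num_labels)
print(last_token_indices(input_ids, pad_token_id=0))
print(pool_last_token(logits, input_ids, pad_token_id=0).shape)
```

For `input_ids = [[5, 6, 7, 0, 0], [1, 2, 3, 4, 5]]` with `pad_token_id=0`, the indices are `[2, 4]`: the first row pools position 2 (its last real token), and the unpadded second row falls through to its final position via the modulo.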
/train_gpt2.py:
--------------------------------------------------------------------------------
1 | """
2 | This training script can run on a single GPU in debug mode
3 | or in a larger training run with distributed data parallel (DDP).
4 |
5 | To run on a single GPU (settings such as batch_size and compile are edited at the top of this file; no command-line overrides are parsed):
6 | $ python train_gpt2.py
7 |
8 | To run with DDP on 4 GPUs on 1 node:
9 | $ torchrun --standalone --nproc_per_node=4 train_gpt2.py
10 | """
11 |
12 | import os
13 | import time
14 | import math
15 | import pickle
16 | from contextlib import nullcontext
17 | import inspect
18 |
19 | import numpy as np
20 | import torch
21 | from torch.nn.parallel import DistributedDataParallel as DDP
22 | from torch.distributed import init_process_group, destroy_process_group
23 | from modeling_gpt2 import GPT2Model, GPT2LMHeadModel, GPT2Config
24 |
25 |
26 | overall_name = "gpt2_llama_0.5B_4k_100k"
27 | # -----------------------------------------------------------------------------
28 | # default config values: a 24-layer, 1024-wide GPT-2 trained on OpenWebText tokenized for a 32k-token LLaMA vocabulary
29 | # I/O
30 | out_dir = 'out_' + overall_name
31 | eval_interval = 1000
32 | log_interval = 1
33 | eval_iters = 200
34 | eval_only = False # if True, script exits right after the first eval
35 | always_save_checkpoint = True # if True, always save a checkpoint after each eval
36 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
37 | # wandb logging
38 | wandb_log = True # set to False to disable Weights & Biases logging
39 | wandb_project = 'owt'
40 | wandb_run_name = overall_name
41 | # data
42 | dataset = 'openwebtext_llama'
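43 | total_batch_size = 524288 # tokens per optimizer step, for reference only (= global_batch * block_size); not used below
44 | global_batch = 512 # sequences per optimizer step, summed over micro-batches and DDP ranks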
45 | batch_size = 8 # if gradient_accumulation_steps > 1, this is the micro-batch size
46 | block_size = 1024
47 | #gradient_accumulation_steps = total_batch_size // (batch_size * block_size) # used to simulate larger batch sizes
48 | gradient_accumulation_steps = global_batch // batch_size
49 |
50 | # adamw optimizer
51 | learning_rate = 30e-4 # max learning rate
52 | max_iters = 100000 # total number of training iterations
53 | weight_decay = 0.1
54 | beta1 = 0.9
55 | beta2 = 0.95
56 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
57 | # learning rate decay settings
58 | decay_lr = True # whether to decay the learning rate
59 | warmup_iters = 2000 # how many steps to warm up for
60 | lr_decay_iters = 100000 # should be ~= max_iters per Chinchilla
61 | min_lr = 0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
62 | # DDP settings
63 | backend = 'nccl' # 'nccl', 'gloo', etc.
64 | # system
65 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
66 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
67 | compile = True # use PyTorch 2.0 to compile the model to be faster
68 | # -----------------------------------------------------------------------------
69 | config = {"lr": learning_rate, "weight_decay": weight_decay, "decay_lr": decay_lr, "warmup_iters": warmup_iters, "min_lr": min_lr}
70 | # various inits, derived attributes, I/O setup
71 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
72 | if ddp:
73 | init_process_group(backend=backend)
74 | ddp_rank = int(os.environ['RANK'])
75 | ddp_local_rank = int(os.environ['LOCAL_RANK'])
76 | ddp_world_size = int(os.environ['WORLD_SIZE'])
77 | device = f'cuda:{ddp_local_rank}'
78 | torch.cuda.set_device(device)
79 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
80 | seed_offset = ddp_rank # each process gets a different seed
81 | # world_size number of processes will be training simultaneously, so we can scale
82 | # down the desired gradient accumulation iterations per process proportionally
83 | assert gradient_accumulation_steps % ddp_world_size == 0
84 | gradient_accumulation_steps //= ddp_world_size
85 | else:
86 | # if not ddp, we are running on a single gpu, and one process
87 | master_process = True
88 | seed_offset = 0
89 | ddp_world_size = 1
90 | tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
91 | if master_process:
92 | print(f"tokens per iteration will be: {tokens_per_iter:,}")
93 |
94 | if master_process:
95 | os.makedirs(out_dir, exist_ok=True)
96 | torch.manual_seed(1337 + seed_offset)
97 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
98 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
99 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
100 | # note: float16 data type will automatically use a GradScaler
101 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
102 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
103 |
104 | # poor man's data loader
105 | data_dir = os.path.join('data', dataset)
106 | def get_batch(split):
107 | # We recreate np.memmap every batch to avoid a memory leak, as per
108 | # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
109 | if split == 'train':
110 | data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
111 | else:
112 | data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
113 | ix = torch.randint(len(data) - block_size, (batch_size,))
114 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
115 | y = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) # labels equal the inputs: the HF-style LM head shifts labels by one position internally
116 | if device_type == 'cuda':
117 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
118 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
119 | else:
120 | x, y = x.to(device), y.to(device)
121 | return x, y
122 |
123 | def configure_optimizers(model, weight_decay, learning_rate, device_type):
124 | param_dict = {pn: p for pn, p in model.named_parameters()}
125 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
126 |
127 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
128 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
129 | optim_groups = [
130 | {'params': decay_params, 'weight_decay': weight_decay},
131 | {'params': nodecay_params, 'weight_decay': 0.0}
132 | ]
133 | num_decay_params = sum(p.numel() for p in decay_params)
134 | num_nodecay_params = sum(p.numel() for p in nodecay_params)
135 | if master_process:
136 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
137 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
138 | # Create AdamW optimizer and use the fused version if it is available
139 | fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
140 | use_fused = fused_available and device_type == "cuda"
141 | if master_process:
142 | print(f"using fused AdamW: {use_fused}")
143 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
144 | return optimizer
145 |
146 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
147 | iter_num = 0
148 | best_val_loss = 1e9
149 |
150 | # model init
151 | '''model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
152 | bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line'''
153 | if init_from == 'scratch':
154 | # init a new model from scratch
155 | if master_process:
156 | print("Initializing a new model from scratch")
157 | # determine the vocab size we'll use for from-scratch training
158 |
159 | #model = GPT2LMHeadModel(GPT2Config(vocab_size=50304, n_positions=block_size, n_embd=1024, n_layer=24, n_head=16, n_inner=4096))
160 | model = GPT2LMHeadModel(GPT2Config(vocab_size=32000, n_positions=block_size, n_embd=1024, n_layer=24, n_head=16, n_inner=4096, bos_token_id=1, eos_token_id=2))
161 | elif init_from == 'resume':
162 | print(f"Resuming training from {out_dir}")
163 | # resume training from a checkpoint.
164 | ckpt_path = os.path.join(out_dir, 'ckpt.pt')
165 | checkpoint = torch.load(ckpt_path, map_location=device)
166 | # the checkpoints written by this script store 'config' (optimizer settings) but no
167 | # 'model_args', and GPTConfig/GPT are not defined here, so rebuild the exact
168 | # architecture used for from-scratch training instead; it must match the saved
169 | # state_dict or load_state_dict will fail
170 | model = GPT2LMHeadModel(GPT2Config(
171 | vocab_size=32000, n_positions=block_size, n_embd=1024, n_layer=24,
172 | n_head=16, n_inner=4096, bos_token_id=1, eos_token_id=2,
173 | ))
174 | state_dict = checkpoint['model']
175 | # strip the '_orig_mod.' prefix: with compile=True the saved raw_model is the torch.compile
176 | # wrapper, whose state_dict keys are all prefixed with '_orig_mod.'
177 | unwanted_prefix = '_orig_mod.'
178 | for k,v in list(state_dict.items()):
179 | if k.startswith(unwanted_prefix):
180 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
181 | model.load_state_dict(state_dict)
182 | iter_num = checkpoint['iter_num']
183 | best_val_loss = checkpoint['best_val_loss']
184 |
185 | model.to(device)
186 |
187 | # initialize a GradScaler. If enabled=False scaler is a no-op
188 | scaler = torch.GradScaler("cuda", enabled=(dtype == 'float16'))
189 |
190 | # optimizer
191 | optimizer = configure_optimizers(model, weight_decay, learning_rate, device_type)
192 |
193 | if init_from == 'resume':
194 | optimizer.load_state_dict(checkpoint['optimizer'])
195 | checkpoint = None # free up memory
196 |
197 | # compile the model
198 | if compile:
199 | if master_process:
200 | print("compiling the model... (takes a ~minute)")
201 | unoptimized_model = model
202 | model = torch.compile(model) # requires PyTorch 2.0
203 |
204 | # wrap model into DDP container
205 | if ddp:
206 | model = DDP(model, device_ids=[ddp_local_rank])
207 |
208 | # helps estimate an arbitrarily accurate loss over either split using many batches
209 | @torch.no_grad()
210 | def estimate_loss():
211 | out = {}
212 | model.eval()
213 | for split in ['train', 'val']:
214 | losses = torch.zeros(eval_iters)
215 | for k in range(eval_iters):
216 | X, Y = get_batch(split)
217 | with ctx:
218 | ret = model(input_ids=X, labels=Y)
219 | logits,loss = ret["logits"], ret["loss"]
220 | losses[k] = loss.item()
221 | out[split] = losses.mean()
222 | model.train()
223 | return out
224 |
225 | # learning rate decay scheduler (cosine with warmup)
226 | def get_lr(it):
227 | # 1) linear warmup for warmup_iters steps
228 | if it < warmup_iters:
229 | return learning_rate * (it + 1) / (warmup_iters + 1)
230 | # 2) if it > lr_decay_iters, return min learning rate
231 | if it > lr_decay_iters:
232 | return min_lr
233 | # 3) in between, use cosine decay down to min learning rate
234 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
235 | assert 0 <= decay_ratio <= 1
236 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
237 | return min_lr + coeff * (learning_rate - min_lr)
238 |
239 | # logging
240 | if wandb_log and master_process:
241 | import wandb
242 | wandb.init(project=wandb_project, name=wandb_run_name, config=config)
243 |
244 | # training loop
245 | X, Y = get_batch('train') # fetch the very first batch
246 | t0 = time.time()
247 | local_iter_num = 0 # number of iterations in the lifetime of this process
248 | raw_model = model.module if ddp else model # unwrap DDP container if needed
249 | running_mfu = -1.0
250 | while True:
251 |
252 | # determine and set the learning rate for this iteration
253 | lr = get_lr(iter_num) if decay_lr else learning_rate
254 | for param_group in optimizer.param_groups:
255 | param_group['lr'] = lr
256 |
257 | # evaluate the loss on train/val sets and write checkpoints
258 | if iter_num % eval_interval == 0 and master_process:
259 | losses = estimate_loss()
260 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
261 | if wandb_log:
262 | wandb.log({
263 | "iter": iter_num,
264 | "train/loss": losses['train'],
265 | "val/loss": losses['val'],
266 | "lr": lr,
267 | })
268 | if losses['val'] < best_val_loss or always_save_checkpoint:
269 | best_val_loss = losses['val']
270 | if iter_num > 0:
271 | checkpoint = {
272 | 'model': raw_model.state_dict(),
273 | 'optimizer': optimizer.state_dict(),
274 | 'iter_num': iter_num,
275 | 'best_val_loss': best_val_loss,
276 | 'config': config,
277 | }
278 | print(f"saving checkpoint to {out_dir}")
279 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
280 | if iter_num == 0 and eval_only:
281 | break
282 |
283 | # forward backward update, with optional gradient accumulation to simulate larger batch size
284 | # and using the GradScaler if data type is float16
285 | for micro_step in range(gradient_accumulation_steps):
286 | if ddp:
287 | # in DDP training we only need to sync gradients at the last micro step.
288 | # the official way to do this is with model.no_sync() context manager, but
289 | # I really dislike that this bloats the code and forces us to repeat code
290 | # looking at the source of that context manager, it just toggles this variable
291 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
292 | with ctx:
293 | ret = model(input_ids=X, labels=Y)
294 | logits,loss = ret["logits"], ret["loss"]
295 | loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
296 | # immediately async prefetch next batch while model is doing the forward pass on the GPU
297 | X, Y = get_batch('train')
298 | # backward pass, with gradient scaling if training in fp16
299 | scaler.scale(loss).backward()
300 | # clip the gradient
301 | if grad_clip != 0.0:
302 | scaler.unscale_(optimizer)
303 | norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
304 | # step the optimizer and scaler if training in fp16
305 | scaler.step(optimizer)
306 | scaler.update()
307 | # flush the gradients as soon as we can, no need for this memory anymore
308 | optimizer.zero_grad(set_to_none=True)
309 | torch.cuda.synchronize()
310 | # timing and logging
311 | t1 = time.time()
312 | dt = t1 - t0
313 | t0 = t1
314 | if iter_num % log_interval == 0 and master_process:
315 | # get loss as float. note: this is a CPU-GPU sync point
316 | # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
317 | lossf = loss.item() * gradient_accumulation_steps
318 | tokens_per_sec = tokens_per_iter / dt
319 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, lr {lr:.3e}, norm {(norm if grad_clip != 0.0 else float('nan')):.4f}, tok/sec {tokens_per_sec:,.0f}")
320 | iter_num += 1
321 | local_iter_num += 1
322 |
323 | # termination conditions
324 | if iter_num > max_iters:
325 | break
326 |
327 | if ddp:
328 | destroy_process_group()
329 |
--------------------------------------------------------------------------------
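The learning-rate schedule and batch bookkeeping in `train_gpt2.py` can be checked in isolation. This is a minimal sketch with defaults copied from the config block above (`3e-3` is the same value as `30e-4`); `accumulation_plan` is a hypothetical helper that mirrors the `gradient_accumulation_steps` and `tokens_per_iter` arithmetic, not a function in the script.

```python
import math


def get_lr(it: int,
           learning_rate: float = 3e-3,
           warmup_iters: int = 2000,
           lr_decay_iters: int = 100_000,
           min_lr: float = 0.0) -> float:
    """Linear warmup, then cosine decay to min_lr (same shape as get_lr in train_gpt2.py)."""
    # 1) linear warmup: ramps from learning_rate/(warmup_iters+1) up toward learning_rate
    if it < warmup_iters:
        return learning_rate * (it + 1) / (warmup_iters + 1)
    # 2) past the decay horizon, hold at the floor
    if it > lr_decay_iters:
        return min_lr
    # 3) cosine decay: coeff goes 1 -> 0 as decay_ratio goes 0 -> 1
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (learning_rate - min_lr)


def accumulation_plan(global_batch: int = 512,
                      batch_size: int = 8,
                      block_size: int = 1024,
                      world_size: int = 1) -> tuple[int, int]:
    """Hypothetical helper: (accumulation steps per rank, tokens per optimizer step)."""
    steps = global_batch // batch_size  # micro-batches per optimizer step, all ranks together
    assert steps % world_size == 0, "accumulation steps must split evenly across ranks"
    steps_per_rank = steps // world_size
    tokens_per_iter = steps_per_rank * world_size * batch_size * block_size
    return steps_per_rank, tokens_per_iter


for it in (0, 1999, 2000, 51_000, 100_000):
    print(it, f"{get_lr(it):.3e}")
print(accumulation_plan(), accumulation_plan(world_size=4))
```

With these defaults one optimizer step covers 524,288 tokens whether it runs on 1 rank (64 accumulation steps) or 4 ranks (16 steps each), the same value as the otherwise unused `total_batch_size`. `train_ngpt.py` sets `warmup_iters = 0`, so its schedule starts directly on the cosine branch at the full learning rate.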
/train_ngpt.py:
--------------------------------------------------------------------------------
1 | """
2 | This training script can run on a single GPU in debug mode
3 | or in a larger training run with distributed data parallel (DDP).
4 |
5 | To run on a single GPU (settings such as batch_size and compile are edited at the top of this file; no command-line overrides are parsed):
6 | $ python train_ngpt.py
7 |
8 | To run with DDP on 4 GPUs on 1 node:
9 | $ torchrun --standalone --nproc_per_node=4 train_ngpt.py
10 |
11 | """
12 |
13 | import os
14 | import time
15 | import math
16 | import pickle
17 | from contextlib import nullcontext
18 | import inspect
19 |
20 | import numpy as np
21 | import torch
22 | from torch.nn.parallel import DistributedDataParallel as DDP
23 | from torch.distributed import init_process_group, destroy_process_group
24 | from modeling_ngpt import NgptModel, NgptLMHeadModel, NgptConfig
25 |
26 | # -----------------------------------------------------------------------------
27 | # default config values: a 24-layer, 1024-wide nGPT trained on OpenWebText tokenized for a 32k-token LLaMA vocabulary
28 | # I/O
29 | overall_name = "ngpt_llama_0.5B_1k_100k"
30 |
31 | out_dir = 'out_' + overall_name
32 | eval_interval = 2000
33 | log_interval = 1
34 | eval_iters = 200
35 | eval_only = False # if True, script exits right after the first eval
36 | always_save_checkpoint = True # if True, always save a checkpoint after each eval
37 | init_from = 'scratch' # 'scratch' or 'resume' or 'Ngpt*'
38 | # wandb logging
39 | wandb_log = True # set to False to disable Weights & Biases logging
40 | wandb_project = 'owt'
41 | wandb_run_name = overall_name # 'run' + str(time.time())
42 | # data
43 | dataset = 'openwebtext_llama'
44 | global_batch = 512
45 | batch_size = 4 # if gradient_accumulation_steps > 1, this is the micro-batch size
46 | block_size = 1024
47 | #gradient_accumulation_steps = total_batch_size // (batch_size * block_size) # used to simulate larger batch sizes
48 | gradient_accumulation_steps = global_batch // batch_size
49 |
50 | # adamw optimizer
51 | learning_rate = 30e-4 # max learning rate
52 | max_iters = 100000 # total number of training iterations
53 | weight_decay = 0.0
54 | beta1 = 0.9
55 | beta2 = 0.95
56 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
57 | # learning rate decay settings
58 | decay_lr = True # whether to decay the learning rate
59 | warmup_iters = 0 # how many steps to warm up for
60 | lr_decay_iters = 100000 # should be ~= max_iters per Chinchilla
61 | min_lr = 0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
62 | # DDP settings
63 | backend = 'nccl' # 'nccl', 'gloo', etc.
64 | # system
65 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
66 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
67 | compile = True # use PyTorch 2.0 to compile the model to be faster
68 | # -----------------------------------------------------------------------------
69 | config = {"lr": learning_rate, "weight_decay": weight_decay, "decay_lr": decay_lr, "warmup_iters": warmup_iters, "min_lr": min_lr}
70 | # various inits, derived attributes, I/O setup
71 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
72 | if ddp:
73 | init_process_group(backend=backend)
74 | ddp_rank = int(os.environ['RANK'])
75 | ddp_local_rank = int(os.environ['LOCAL_RANK'])
76 | ddp_world_size = int(os.environ['WORLD_SIZE'])
77 | device = f'cuda:{ddp_local_rank}'
78 | torch.cuda.set_device(device)
79 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
80 | seed_offset = ddp_rank # each process gets a different seed
81 | # world_size number of processes will be training simultaneously, so we can scale
82 | # down the desired gradient accumulation iterations per process proportionally
83 | assert gradient_accumulation_steps % ddp_world_size == 0
84 | gradient_accumulation_steps //= ddp_world_size
85 | else:
86 | # if not ddp, we are running on a single gpu, and one process
87 | master_process = True
88 | seed_offset = 0
89 | ddp_world_size = 1
90 | tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
91 | if master_process:
92 | print(f"tokens per iteration will be: {tokens_per_iter:,}")
93 |
94 | if master_process:
95 | os.makedirs(out_dir, exist_ok=True)
96 | torch.manual_seed(1337 + seed_offset)
97 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
98 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
99 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
100 | # note: float16 data type will automatically use a GradScaler
101 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
102 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
103 |
104 | # poor man's data loader
105 | data_dir = os.path.join('data', dataset)
106 | def get_batch(split):
107 | # We recreate np.memmap every batch to avoid a memory leak, as per
108 | # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
109 | if split == 'train':
110 | data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
111 | else:
112 | data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
113 | ix = torch.randint(len(data) - block_size, (batch_size,))
114 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
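115 | y = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) # labels equal the inputs: the LM head is expected to shift labels by one position internally, as HF-style heads do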
116 | if device_type == 'cuda':
117 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
118 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
119 | else:
120 | x, y = x.to(device), y.to(device)
121 | return x, y
122 |
123 | def configure_optimizers(model, weight_decay, learning_rate, device_type):
124 | param_dict = {pn: p for pn, p in model.named_parameters()}
125 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
126 |
127 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
128 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
129 | optim_groups = [
130 | {'params': decay_params, 'weight_decay': weight_decay},
131 | {'params': nodecay_params, 'weight_decay': 0.0}
132 | ]
133 | num_decay_params = sum(p.numel() for p in decay_params)
134 | num_nodecay_params = sum(p.numel() for p in nodecay_params)
135 | if master_process:
136 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
137 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
138 | # Create AdamW optimizer and use the fused version if it is available
139 | fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
140 | use_fused = fused_available and device_type == "cuda"
141 | if master_process:
142 | print(f"using fused AdamW: {use_fused}")
143 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
144 | return optimizer
145 |
146 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
147 | iter_num = 0
148 | best_val_loss = 1e9
149 |
150 | # model init
151 | # unlike nanoGPT, the architecture is not assembled from command-line model_args;
152 | # it is fixed by model_config below, which only takes block_size from the config
153 |
154 | #model_config = NgptConfig(vocab_size=50304, n_embd=1024, n_layer=24, n_head=16, n_inner=4096)
155 | model_config = NgptConfig(vocab_size=32000, n_positions=block_size, n_embd=1024, n_layer=24, n_head=16, n_inner=4096, bos_token_id = 1, eos_token_id = 2)
156 | if init_from == 'scratch':
157 | # init a new model from scratch
158 | if master_process:
159 | print("Initializing a new model from scratch")
160 | # vocab_size (32000) and the bos/eos token ids are fixed in model_config above
161 |
162 | model = NgptLMHeadModel(model_config)
163 | elif init_from == 'resume':
164 | if master_process: print(f"Resuming training from {out_dir}")
165 | # resume training from a checkpoint.
166 | ckpt_path = os.path.join(out_dir, 'ckpt.pt')
167 | checkpoint = torch.load(ckpt_path, map_location=device)
168 | # the checkpoints saved by this script store the training 'config' dict rather than
169 | # 'model_args', and the nanoGPT GPT/GPTConfig classes are not used here, so rebuild
170 | # the nGPT model from the same model_config used for from-scratch training.
171 | # model_config must match the architecture the checkpoint was trained with
172 | # (n_layer, n_head, n_embd, n_positions, vocab_size), or load_state_dict will fail.
173 | # the remaining training hyperparameters (e.g. learning rate) can still come
174 | # from the command line.
175 | model = NgptLMHeadModel(model_config)
176 | state_dict = checkpoint['model']
177 | # state dicts saved from a torch.compile'd model carry an '_orig_mod.' prefix on every key
178 | # (the compiled wrapper stores the original module as _orig_mod), so strip it before loading
179 | unwanted_prefix = '_orig_mod.'
180 | for k,v in list(state_dict.items()):
181 | if k.startswith(unwanted_prefix):
182 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
183 | model.load_state_dict(state_dict)
184 | iter_num = checkpoint['iter_num']
185 | best_val_loss = checkpoint['best_val_loss']
186 |
187 | model.to(device)
188 |
189 | # initialize a GradScaler. If enabled=False scaler is a no-op
190 | scaler = torch.amp.GradScaler("cuda", enabled=(dtype == 'float16'))
191 |
192 | # optimizer
193 | optimizer = configure_optimizers(model, weight_decay, learning_rate, device_type)
194 |
195 | if init_from == 'resume':
196 | optimizer.load_state_dict(checkpoint['optimizer'])
197 | checkpoint = None # free up memory
198 |
199 | # compile the model
200 | if compile:
201 | if master_process:
202 | print("compiling the model... (takes a ~minute)")
203 | unoptimized_model = model
204 | model = torch.compile(model) # requires PyTorch 2.0
205 |
206 | # wrap model into DDP container
207 | if ddp:
208 | model = DDP(model, device_ids=[ddp_local_rank])
209 |
210 | # helps estimate an arbitrarily accurate loss over either split using many batches
211 | @torch.no_grad()
212 | def estimate_loss():
213 | out = {}
214 | model.eval()
215 | for split in ['train', 'val']:
216 | losses = torch.zeros(eval_iters)
217 | for k in range(eval_iters):
218 | X, Y = get_batch(split)
219 | with ctx:
220 | ret = model(input_ids=X, labels=Y)
221 | logits,loss = ret["logits"], ret["loss"]
222 | losses[k] = loss.item()
223 | out[split] = losses.mean()
224 | model.train()
225 | return out
226 |
227 | # learning rate decay scheduler (cosine with warmup)
228 | def get_lr(it):
229 | # 1) linear warmup for warmup_iters steps
230 | if it < warmup_iters:
231 | return learning_rate * (it + 1) / (warmup_iters + 1)
232 | # 2) if it > lr_decay_iters, return min learning rate
233 | if it > lr_decay_iters:
234 | return min_lr
235 | # 3) in between, use cosine decay down to min learning rate
236 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
237 | assert 0 <= decay_ratio <= 1
238 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
239 | return min_lr + coeff * (learning_rate - min_lr)
240 |
241 | # logging
242 | if wandb_log and master_process:
243 | import wandb
244 | wandb.init(project=wandb_project, name=wandb_run_name, config=config)
245 |
246 | # training loop
247 | X, Y = get_batch('train') # fetch the very first batch
248 | t0 = time.time()
249 | local_iter_num = 0 # number of iterations in the lifetime of this process
250 | raw_model = model.module if ddp else model # unwrap DDP container if needed
251 | running_mfu = -1.0
252 |
253 | def justnorm(x, idim):
254 | dtype = x.dtype
255 | x = x.float()
256 | res = (x / x.norm(p=2, dim=idim, keepdim=True)).to(dtype=dtype)
257 | return res
258 |
259 | def normalize_matrices():
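260 | # nGPT keeps the weights on the unit hypersphere: re-normalize every vector along the n_embd axis.
261 | # wte and lm_head weights are (vocab_size, n_embd), so each token's vector is normalized along dim 1.
262 | # the Conv1D weights in each block are (in_features, out_features), so the n_embd axis is dim 0 or 1.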
263 | raw_model.transformer.wte.weight.data.copy_(justnorm(raw_model.transformer.wte.weight.data, 1))
264 | raw_model.lm_head.weight.data.copy_(justnorm(raw_model.lm_head.weight.data, 1))
265 |
266 | for layer_idx in range(model_config.num_hidden_layers):
267 | block = raw_model.transformer.h[layer_idx]
268 |
269 | # c_attn maps n_embd -> 3*n_embd: normalize along dim 0 (the n_embd input axis)
270 | block.attn.c_attn.weight.data.copy_(justnorm(block.attn.c_attn.weight.data, 0)) #n_embd, 3*n_embd
271 | # attention c_proj maps back to n_embd: normalize along dim 1 (the n_embd output axis)
272 | block.attn.c_proj.weight.data.copy_(justnorm(block.attn.c_proj.weight.data, 1)) #n_proj, n_embd
273 |
274 | # MLP: c_fc is (n_embd, n_inner), normalized along dim 0;
275 | # c_proj is (n_inner, n_embd), normalized along dim 1
276 | block.mlp.c_fc.weight.data.copy_(justnorm(block.mlp.c_fc.weight.data, 0))
277 | block.mlp.c_proj.weight.data.copy_(justnorm(block.mlp.c_proj.weight.data, 1))
278 |
279 | normalize_matrices()
280 |
281 | while True:
282 |
283 | # determine and set the learning rate for this iteration
284 | lr = get_lr(iter_num) if decay_lr else learning_rate
285 | for param_group in optimizer.param_groups:
286 | param_group['lr'] = lr
287 |
288 | # evaluate the loss on train/val sets and write checkpoints
289 | if iter_num % eval_interval == 0 and master_process:
290 | losses = estimate_loss()
291 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
292 | if wandb_log:
293 | wandb.log({
294 | "iter": iter_num,
295 | "train/loss": losses['train'],
296 | "val/loss": losses['val'],
297 | "lr": lr,
298 | })
299 | if losses['val'] < best_val_loss or always_save_checkpoint:
300 | best_val_loss = losses['val']
301 | if iter_num > 0:
302 | checkpoint = {
303 | 'model': raw_model.state_dict(),
304 | 'optimizer': optimizer.state_dict(),
305 | 'iter_num': iter_num,
306 | 'best_val_loss': best_val_loss,
307 | 'config': config,
308 | }
309 | print(f"saving checkpoint to {out_dir}")
310 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
311 | if iter_num == 0 and eval_only:
312 | break
313 |
314 | # forward backward update, with optional gradient accumulation to simulate larger batch size
315 | # and using the GradScaler if data type is float16
316 | for micro_step in range(gradient_accumulation_steps):
317 | if ddp:
318 | # in DDP training we only need to sync gradients at the last micro step.
319 | # the official way to do this is with model.no_sync() context manager, but
320 | # I really dislike that this bloats the code and forces us to repeat code
321 | # looking at the source of that context manager, it just toggles this variable
322 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
323 | with ctx:
324 | ret = model(input_ids=X, labels=Y)
325 | logits,loss = ret["logits"], ret["loss"]
326 | loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
327 | # immediately async prefetch next batch while model is doing the forward pass on the GPU
328 | X, Y = get_batch('train')
329 | # backward pass, with gradient scaling if training in fp16
330 | scaler.scale(loss).backward()
331 | norm = float('nan') # gradient norm for logging; stays NaN when clipping is disabled (grad_clip == 0.0)
332 | if grad_clip != 0.0:
333 | scaler.unscale_(optimizer)
334 | norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
335 | # step the optimizer and scaler if training in fp16
336 | scaler.step(optimizer)
337 | scaler.update()
338 | # flush the gradients as soon as we can, no need for this memory anymore
339 | optimizer.zero_grad(set_to_none=True)
340 | if device_type == 'cuda': torch.cuda.synchronize() # wait for the GPU to finish this step (CUDA only)
341 |
342 | normalize_matrices() # re-project the weight matrices onto the unit hypersphere after each update
343 | # timing and logging
344 | t1 = time.time()
345 | dt = t1 - t0
346 | t0 = t1
347 | if iter_num % log_interval == 0 and master_process:
348 | # get loss as float. note: this is a CPU-GPU sync point
349 | # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
350 | lossf = loss.item() * gradient_accumulation_steps
351 | tokens_per_sec = tokens_per_iter / dt
352 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, lr {lr:.3e}, norm {norm:.4f}, tok/sec {tokens_per_sec:,.0f}")
353 | iter_num += 1
354 | local_iter_num += 1
355 |
356 | # termination conditions
357 | if iter_num > max_iters:
358 | break
359 |
360 | if ddp:
361 | destroy_process_group()
362 |
--------------------------------------------------------------------------------
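The warmup-plus-cosine schedule implemented by `get_lr` in train_ngpt.py can be sanity-checked outside the training loop. The sketch below is a standalone copy of that logic; the hyperparameter values (`learning_rate`, `min_lr`, `warmup_iters`, `lr_decay_iters`) are hypothetical placeholders chosen for illustration, not necessarily what this run was configured with.

```python
import math

# Hypothetical hyperparameters, for illustration only; the training script
# reads its own values from its configuration.
learning_rate = 6e-4
min_lr = 6e-5
warmup_iters = 2000
lr_decay_iters = 600000


def get_lr(it: int) -> float:
    """Linear warmup, then cosine decay to min_lr (same logic as get_lr above)."""
    # 1) linear warmup: rises from learning_rate / (warmup_iters + 1)
    #    to learning_rate * warmup_iters / (warmup_iters + 1)
    if it < warmup_iters:
        return learning_rate * (it + 1) / (warmup_iters + 1)
    # 2) past the decay horizon, hold at the floor
    if it > lr_decay_iters:
        return min_lr
    # 3) cosine decay: decay_ratio 0 -> learning_rate, decay_ratio 1 -> min_lr
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # ranges 1..0
    return min_lr + coeff * (learning_rate - min_lr)


if __name__ == "__main__":
    midpoint = (warmup_iters + lr_decay_iters) // 2
    for it in (0, warmup_iters - 1, warmup_iters, midpoint, lr_decay_iters, lr_decay_iters + 1):
        print(f"iter {it:>7}: lr {get_lr(it):.3e}")
```

With these values the warmup tops out at `learning_rate * 2000/2001` on its last step, the cosine phase starts at exactly `learning_rate`, passes `(learning_rate + min_lr) / 2` halfway through the decay window, and holds at `min_lr` afterwards.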