├── assets
│   ├── bird.png
│   ├── equation.png
│   ├── main_teaser.png
│   └── reg_feat_mean
│       ├── base_imagemean_bs10k_trainaug.pt
│       ├── giant_imagemean_bs10k_trainaug.pt
│       ├── large_imagemean_bs10k_trainaug.pt
│       └── small_imagemean_bs10k_trainaug.pt
├── monkey_patch
│   ├── __init__.py
│   ├── modify_vit.py
│   ├── README.md
│   ├── modify_phi2.py
│   ├── modify_mistral.py
│   └── modify_llama.py
├── lib
│   ├── __init__.py
│   ├── hook.py
│   ├── load_data.py
│   ├── model_dict.py
│   ├── eval_utils.py
│   ├── plot_utils_vit.py
│   ├── load_model.py
│   ├── plot_utils_llm.py
│   └── plot_utils.py
├── gpt-2
│   ├── config
│   │   ├── eval_gpt2_attn_bias.py
│   │   ├── eval_gpt2_default.py
│   │   ├── eval_gpt2_sink.py
│   │   ├── train_gpt2_attn_bias.py
│   │   ├── train_gpt2_default.py
│   │   └── train_gpt2_sink.py
│   ├── configurator.py
│   ├── README.md
│   ├── plot_gpt2.py
│   ├── test.py
│   ├── analyze.py
│   ├── train.py
│   ├── model_default.py
│   ├── model_sink.py
│   └── model_attn_bias.py
├── LICENSE
├── INSTALL.md
├── main_vit.py
├── README.md
└── main_llm.py
/assets/bird.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/massive-activations/HEAD/assets/bird.png
--------------------------------------------------------------------------------
/assets/equation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/massive-activations/HEAD/assets/equation.png
--------------------------------------------------------------------------------
/assets/main_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/massive-activations/HEAD/assets/main_teaser.png
--------------------------------------------------------------------------------
/monkey_patch/__init__.py:
--------------------------------------------------------------------------------
1 | from .modify_llama import *
2 | from .modify_mistral import *
3 | from .modify_vit import *
4 | from .modify_phi2 import *
--------------------------------------------------------------------------------
/assets/reg_feat_mean/base_imagemean_bs10k_trainaug.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/massive-activations/HEAD/assets/reg_feat_mean/base_imagemean_bs10k_trainaug.pt
--------------------------------------------------------------------------------
/assets/reg_feat_mean/giant_imagemean_bs10k_trainaug.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/massive-activations/HEAD/assets/reg_feat_mean/giant_imagemean_bs10k_trainaug.pt
--------------------------------------------------------------------------------
/assets/reg_feat_mean/large_imagemean_bs10k_trainaug.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/massive-activations/HEAD/assets/reg_feat_mean/large_imagemean_bs10k_trainaug.pt
--------------------------------------------------------------------------------
/assets/reg_feat_mean/small_imagemean_bs10k_trainaug.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/massive-activations/HEAD/assets/reg_feat_mean/small_imagemean_bs10k_trainaug.pt
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
1 | from .load_model import load_llm, load_vit, load_dinov2_linear_head
2 | from .eval_utils import test_imagenet, setup_dinov2_model_for_eval, fix_reg_mean, eval_ppl
3 | from .plot_utils_llm import plot_3d_feat, plot_layer_ax, plot_attn
4 | from .plot_utils_vit import plot_3d_feat_vit, plot_layer_ax_vit
5 | from .load_data import get_data
6 | from .hook import setup_intervene_hook
--------------------------------------------------------------------------------
/gpt-2/config/eval_gpt2_attn_bias.py:
--------------------------------------------------------------------------------
1 | # evaluate gpt2 model with explicit attention biases
2 | # n_layer=12, n_head=12, n_embd=768
3 | batch_size = 8
4 | eval_iters = 500 # use more iterations to get good estimate
5 | eval_only = True
6 | wandb_log = False
7 | init_from = 'resume'
8 | ckpt_iter = 50000
9 | out_dir="../pretrained-models/attn_bias"
10 | data_dir = '/data/locus/project_data/project_data2/mingjies/nanoGPT/data'
11 | save_dir="results/gpt-2/attn_bias/"
12 | compile = False
13 | model_type = "gpt2_attn_bias"
--------------------------------------------------------------------------------
/gpt-2/config/eval_gpt2_default.py:
--------------------------------------------------------------------------------
1 | # evaluate gpt2 model with default architecture
2 | # n_layer=12, n_head=12, n_embd=768
3 | batch_size = 8
4 | eval_iters = 500 # use more iterations to get good estimate
5 | eval_only = True
6 | wandb_log = False
7 | init_from = 'resume'
8 | ckpt_iter = 50000
9 | out_dir="../pretrained-models/default"
10 | data_dir = '/data/locus/project_data/project_data2/mingjies/nanoGPT/data'
11 | save_dir="results/gpt-2/default/"
12 | compile = False
13 | model_type = "gpt2_default"
--------------------------------------------------------------------------------
/gpt-2/config/eval_gpt2_sink.py:
--------------------------------------------------------------------------------
1 | # evaluate gpt2 model with a sink token
2 | # n_layer=12, n_head=12, n_embd=768
3 | batch_size = 8
4 | eval_iters = 500 # use more iterations to get good estimate
5 | eval_only = True
6 | wandb_log = False
7 | init_from = 'resume'
8 | ckpt_iter = 50000
9 | out_dir="../pretrained-models/sink"
10 | data_dir = '/data/locus/project_data/project_data2/mingjies/nanoGPT/data'
11 | save_dir="results/gpt-2/sink/"
12 | compile = False
13 | model_type = "gpt2_sink"
14 | num_reg = 1
--------------------------------------------------------------------------------
/gpt-2/config/train_gpt2_attn_bias.py:
--------------------------------------------------------------------------------
1 | out_dir = 'results/'
2 | data_dir = '/data/locus/project_data/project_data2/mingjies/nanoGPT/data'
3 |
4 | wandb_log = False
5 | wandb_project = 'owt'
6 | wandb_run_name='gpt2-124M-default-run'
7 | compile=False
8 |
9 | # these make the total batch size be ~0.5M
10 | # 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520
11 | batch_size = 12
12 | block_size = 1024
13 | gradient_accumulation_steps = 5 * 8
14 |
15 | # this makes total number of tokens be 300B
16 | max_iters = 600000
17 | lr_decay_iters = 600000
18 |
19 | # eval stuff
20 | eval_interval = 1000
21 | eval_iters = 200
22 | log_interval = 10
23 |
24 | # weight decay
25 | weight_decay = 1e-1
26 |
27 | model_type = "gpt2_attn_bias"
--------------------------------------------------------------------------------
/gpt-2/config/train_gpt2_default.py:
--------------------------------------------------------------------------------
1 | out_dir = 'results/'
2 | data_dir = '/data/locus/project_data/project_data2/mingjies/nanoGPT/data'
3 |
4 | wandb_log = False
5 | wandb_project = 'owt'
6 | wandb_run_name='gpt2-124M-default-run'
7 | compile=False
8 |
9 | # these make the total batch size be ~0.5M
10 | # 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520
11 | batch_size = 12
12 | block_size = 1024
13 | gradient_accumulation_steps = 5 * 8
14 |
15 | # this makes total number of tokens be 300B
16 | max_iters = 600000
17 | lr_decay_iters = 600000
18 |
19 | # eval stuff
20 | eval_interval = 1000
21 | eval_iters = 200
22 | log_interval = 10
23 |
24 | # weight decay
25 | weight_decay = 1e-1
26 |
27 | model_type = "gpt2_default"
--------------------------------------------------------------------------------
/gpt-2/config/train_gpt2_sink.py:
--------------------------------------------------------------------------------
1 | out_dir = 'results/'
2 | data_dir = '/data/locus/project_data/project_data2/mingjies/nanoGPT/data'
3 |
4 | wandb_log = False
5 | wandb_project = 'owt'
6 | wandb_run_name='gpt2-124M-default-run'
7 | compile=False
8 |
9 | # these make the total batch size be ~0.5M
10 | # 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520
11 | batch_size = 12
12 | block_size = 1024
13 | gradient_accumulation_steps = 5 * 8
14 |
15 | # this makes total number of tokens be 300B
16 | max_iters = 600000
17 | lr_decay_iters = 600000
18 |
19 | # eval stuff
20 | eval_interval = 1000
21 | eval_iters = 200
22 | log_interval = 10
23 |
24 | # weight decay
25 | weight_decay = 1e-1
26 |
27 | model_type = "gpt2_sink"
28 | num_reg=1
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 CMU Locus Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | ## Dependencies
4 | Create a new conda virtual environment:
5 | ```sh
6 | conda create -n massive-activations python=3.9 -y
7 | conda activate massive-activations
8 | ```
9 |
10 | Install [PyTorch](https://pytorch.org/)>=2.0.0 and [torchvision](https://pytorch.org/vision/stable/index.html)>=0.15.0 following the official instructions. For example:
11 | ```sh
12 | pip install torch==2.0.0+cu117 torchvision==0.15.1+cu117 -f https://download.pytorch.org/whl/torch_stable.html
13 | ```
14 |
15 | Install additional dependencies:
16 | ```sh
17 | pip install timm==0.9.12 transformers==4.36.0 accelerate==0.23.0 datasets==2.14.5 matplotlib==3.8.0 seaborn sentencepiece protobuf
18 | ```
19 |
20 | ## Pretrained Models
21 |
22 | - **LLM Models**: To use pretrained LLM models, update the `CACHE_DIR_BASE` variable in the [model_dict.py](lib/model_dict.py) file to point to the directory containing the pretrained model weights.
23 |
24 | - **DINOv2-reg Models**: To use the DINOv2-reg model for linear classification, download the pretrained linear classification head from this [link](https://github.com/facebookresearch/dinov2?tab=readme-ov-file#pretrained-heads---image-classification). Set the `--linear_head_path` argument in the [main_vit.py](main_vit.py) script to the directory where you've stored the downloaded weights.
25 |
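26 | For reference, [load_model.py](lib/load_model.py) looks for the downloaded heads inside `--linear_head_path` under names of the form `dinov2_vit{s,b,l,g}14_reg4_linear_head.pth` (or `dinov2_vit{s,b,l,g}14_linear_head.pth` for DINOv2 without registers). Below is a hypothetical invocation; the paths are placeholders, not files shipped with this repository:
27 | ```sh
28 | # expected filenames under --linear_head_path (as read by lib/load_model.py):
29 | #   /path/to/dinov2_heads/dinov2_vitg14_reg4_linear_head.pth
30 | #   /path/to/dinov2_heads/dinov2_vitl14_reg4_linear_head.pth
31 | CUDA_VISIBLE_DEVICES=0 python main_vit.py --model_family dinov2_reg --model_size giant \
32 |     --exp3 --linear_head_path /path/to/dinov2_heads \
33 |     --imagenet_dir /path/to/imagenet/val --savedir results/vit/exp3_giant/
34 | ```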
--------------------------------------------------------------------------------
/lib/hook.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class Intervened_Layer:
4 | def __init__(self, model_name, reset_type):
5 | self.model_name = model_name
6 | self.reset_type = reset_type
7 |
8 | def update(self, inp, out):
9 | out_feature = out[0]
10 |
11 | if self.reset_type == "set_mean":
12 | alpha = 1.0
13 | elif self.reset_type == "set_zero":
14 | alpha = 0.0
15 | else:
16 | raise ValueError(f"reset_type {self.reset_type} not supported")
17 |
18 | if self.model_name == "llama2_13b":
19 | out_feature[:, 0, 4743] = - 1223.5 * alpha
20 | out_feature[:, 0, 2100] = - 717.95 * alpha
21 | elif self.model_name == "llama2_7b":
22 | feat_abs = out_feature.abs()
23 | sort_res = torch.sort(feat_abs.flatten(), descending=True)
24 |
25 | top_indices = sort_res.indices[0]
26 | token_dim = top_indices.item() // feat_abs.shape[2]
27 |
28 | if token_dim != 0:
29 | out_feature[:, token_dim, 2533] = 2546.8 * alpha
30 | out_feature[:, token_dim, 1415] = - 1502.0 * alpha
31 |
32 | out_feature[:, 0, 2533] = 767.6 * alpha
33 | out_feature[:, 0, 1415] = - 458.55 * alpha
34 |
35 | return (out_feature, *out[1:])
36 |
37 | def setup_intervene_hook(layer, model_name, reset_type):
38 | update_layer = Intervened_Layer(model_name, reset_type)
39 |
40 | def add_batch():
41 | def modify_hook(_, inp, out):
42 | update_layer.update(inp, out)
43 | return modify_hook
44 |
45 | handle = layer.register_forward_hook(add_batch())
46 |
47 | return handle
--------------------------------------------------------------------------------
/monkey_patch/modify_vit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import types
3 |
4 | def vit_custom_block_forward(self, x: torch.Tensor) -> torch.Tensor:
5 | x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
6 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
7 | self.feat = x.clone().detach().cpu().double() ## this is to get the output feature of each layer;
8 | return x
9 |
10 | def enable_vit_custom_block(layer, layer_id):
11 | layer.layer_id = layer_id
12 | layer.forward = types.MethodType(vit_custom_block_forward, layer)
13 |
14 |
15 | def vit_custom_attention_forward(self, x) -> torch.Tensor:
16 | B, N, C = x.shape
17 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
18 | q, k, v = qkv.unbind(0)
19 | q, k = self.q_norm(q), self.k_norm(k)
20 |
21 | q = q * self.scale
22 | attn = q @ k.transpose(-2, -1)
23 |
24 | # ###################################################
25 | self.attn_logits = attn.clone().detach().cpu().double()
26 | # ###################################################
27 |
28 | attn = attn.softmax(dim=-1)
29 |
30 | # ###################################################
31 | self.attn_probs = attn.clone().detach().cpu().double()
32 | # ###################################################
33 |
34 | attn = self.attn_drop(attn)
35 | x = attn @ v
36 |
37 | x = x.transpose(1, 2).reshape(B, N, C)
38 | x = self.proj(x)
39 | x = self.proj_drop(x)
40 | return x
41 |
42 | def enable_vit_custom_attention(layer, layer_id):
43 | modified_module = layer.attn
44 | modified_module.layer_id = layer_id
45 | modified_module.forward = types.MethodType(vit_custom_attention_forward, modified_module)
46 |
--------------------------------------------------------------------------------
/gpt-2/configurator.py:
--------------------------------------------------------------------------------
1 | """
2 | Poor Man's Configurator. Probably a terrible idea. Example usage:
3 | $ python train.py config/override_file.py --batch_size=32
4 | this will first run config/override_file.py, then override batch_size to 32
5 |
6 | The code in this file will be run as follows from e.g. train.py:
7 | >>> exec(open('configurator.py').read())
8 |
9 | So it's not a Python module, it's just shuttling this code away from train.py
10 | The code in this script then overrides the globals()
11 |
12 | I know people are not going to love this, I just really dislike configuration
13 | complexity and having to prepend config. to every single variable. If someone
14 | comes up with a better simple Python solution I am all ears.
15 | """
16 |
17 | import sys
18 | from ast import literal_eval
19 |
20 | for arg in sys.argv[1:]:
21 | if '=' not in arg:
22 | # assume it's the name of a config file
23 | assert not arg.startswith('--')
24 | config_file = arg
25 | print(f"Overriding config with {config_file}:")
26 | with open(config_file) as f:
27 | print(f.read())
28 | exec(open(config_file).read())
29 | else:
30 | # assume it's a --key=value argument
31 | assert arg.startswith('--')
32 | key, val = arg.split('=')
33 | key = key[2:]
34 | if key in globals():
35 | try:
36 |                 # attempt to eval it (e.g. if bool, number, etc.)
37 | attempt = literal_eval(val)
38 | except (SyntaxError, ValueError):
39 | # if that goes wrong, just use the string
40 | attempt = val
41 | # ensure the types match ok
42 | assert type(attempt) == type(globals()[key])
43 | # cross fingers
44 | print(f"Overriding: {key} = {attempt}")
45 | globals()[key] = attempt
46 | else:
47 | raise ValueError(f"Unknown config key: {key}")
48 |
--------------------------------------------------------------------------------
/monkey_patch/README.md:
--------------------------------------------------------------------------------
1 | # Monkey Patch LLMs
2 |
3 | This directory provides the code we use to collect intermediate hidden states from LLMs implemented in [Transformers](https://github.com/huggingface/transformers/tree/main/src/transformers/models) and ViTs implemented in [timm](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py). For other LLM model families, you need to write a custom forward function that matches the model definition.
4 |
5 | A sample snippet that replaces the forward function of [LlamaDecoderLayer](https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/llama/modeling_llama.py#L755) with a custom forward function ``llama_custom_decoderlayer_forward``:
6 | ```py
7 | import types
8 |
9 | def enable_llama_custom_decoderlayer(layer, layer_id):
10 | """
11 | This function modifies a given LlamaDecoderLayer object by setting its layer_id and replacing its forward method with a custom implementation.
12 | It allows for customization of the layer's behavior during the forward pass, which is when the layer processes input data.
13 | """
14 |
15 | layer.layer_id = layer_id
16 | # This line assigns a unique identifier to the layer. The `layer_id` parameter is used to specify this identifier,
17 | # which can be useful for tracking, debugging, or applying specific operations to certain layers within a larger model.
18 |
19 | layer.forward = types.MethodType(
20 | llama_custom_decoderlayer_forward, layer
21 | )
22 | # This line replaces the layer's original `forward` method with a new one.
23 | # `types.MethodType` is used to bind a new method to an existing object. In this case, it binds the
24 | # `llama_custom_decoderlayer_forward` function to the `layer` object as its new `forward` method.
25 | # `llama_custom_decoderlayer_forward` should be a function defined elsewhere that takes the same arguments as the original
26 | # `forward` method of the layer and implements the desired custom behavior for processing input data.
27 | # This allows for dynamic modification of the layer's behavior without altering the original class definition.
28 | ```
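29 | 
30 | As a usage sketch, mirroring how [main_llm.py](../main_llm.py) calls these helpers: after loading a model, patch each decoder layer and then read the cached per-layer features. The snippet below is illustrative only; it assumes an `args` namespace with the fields expected by `lib.load_llm` (e.g. `model`, `access_token`):
31 | ```py
32 | import torch
33 | 
34 | import lib                 # lib.load_llm from this repository
35 | import monkey_patch as mp  # enable_llama_custom_decoderlayer from this directory
36 | 
37 | model, tokenizer, device, layers, hidden_size, seq_len = lib.load_llm(args)
38 | 
39 | # replace the forward function of every decoder layer
40 | for layer_id, layer in enumerate(layers):
41 |     mp.enable_llama_custom_decoderlayer(layer, layer_id)
42 | 
43 | valenc = tokenizer("Summer is warm. Winter is cold.", return_tensors="pt").input_ids.to(device)
44 | with torch.no_grad():
45 |     model(valenc)
46 | 
47 | # each patched layer caches its output feature in layer.feat
48 | print(layers[2].feat.abs().max())
49 | ```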
--------------------------------------------------------------------------------
/lib/load_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import torch
4 | from datasets import load_dataset
5 |
6 | def set_seed(seed):
7 | np.random.seed(seed)
8 | torch.random.manual_seed(seed)
9 |
10 | def get_data(tokenizer, nsamples=50, seqlen=2048, device=None):
11 | valdata = load_dataset("togethercomputer/RedPajama-Data-1T-Sample")
12 |
13 | num_seq = len(valdata["train"])
14 | seq_indices = np.random.choice(num_seq, 500, replace=False).tolist()
15 | seq_list = []
16 | for seq_ind in seq_indices:
17 | seq_list.append(valdata["train"][seq_ind]['text'])
18 |
19 | testenc = tokenizer("\n\n".join(seq_list), return_tensors='pt', add_special_tokens=False).input_ids
20 |
21 | testseq_list = []
22 | for i in range(nsamples):
23 | test_seq = testenc[:, (i * seqlen):((i+1) * seqlen)].to(device)
24 | testseq_list.append(test_seq.reshape(1, seqlen))
25 |
26 | return testseq_list
27 |
28 |
29 | def get_wikitext(tokenizer, seqlen=2048, device=None):
30 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
31 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt', add_special_tokens=False)
32 | testenc = testenc.input_ids
33 |
34 | testseq_list = []
35 | nsamples = testenc.numel() // seqlen
36 |
37 | for i in range(nsamples):
38 | testenc_cur = testenc[:,(i * seqlen):((i+1) * seqlen)].to(device)
39 | testseq_list.append(testenc_cur.reshape(1, seqlen))
40 | return testseq_list
41 |
42 | def get_pg19(tokenizer, seqlen=2048, device=None):
43 | valdata = load_dataset(
44 | 'emozilla/pg19', split='validation'
45 | )
46 |
47 | testseq_list = []
48 | valenc = tokenizer(' '.join(valdata[:5]['text']), return_tensors='pt').input_ids
49 | for i in range(100):
50 | testseq_list.append(valenc[:, (i * seqlen):((i+1) * seqlen)].to(device))
51 | return testseq_list
52 |
53 | def get_c4(tokenizer, seqlen=2048, device=None):
54 | valdata = load_dataset("NeelNanda/c4-10k")
55 |
56 | testseq_list = []
57 | valenc = tokenizer(' '.join(valdata["train"][:5000]['text']), return_tensors='pt').input_ids
58 | for i in range(100):
59 | testseq_list.append(valenc[:, (i * seqlen):((i+1) * seqlen)].to(device))
60 | return testseq_list
61 |
62 | def get_test_data(dataset_name, tokenizer=None, seed=0, seqlen=2048, device=None):
63 | random.seed(seed)
64 | set_seed(seed)
65 | if dataset_name == "wikitext":
66 | return get_wikitext(tokenizer, seqlen=seqlen, device=device)
67 | elif dataset_name == "c4":
68 | return get_c4(tokenizer, seqlen=seqlen, device=device)
69 | elif dataset_name == "pg19":
70 | return get_pg19(tokenizer, seqlen=seqlen, device=device)
71 |
--------------------------------------------------------------------------------
/gpt-2/README.md:
--------------------------------------------------------------------------------
1 | # Training GPT-2 with Explicit Attention Biases
2 |
3 | We provide the code and pretrained checkpoints for the experiments in Section 5.2 on "Explicit attention biases". The code for training GPT-2 is based on the open-source [nanoGPT](https://github.com/karpathy/nanoGPT) repository.
4 |
5 | ---
6 |
7 |
9 |
10 | We propose to augment the self-attention mechanism with explicit attention biases, by inserting auxiliary key and value parameters (a minimal sketch of the idea is given at the end of this README).
11 |
12 | [model_attn_bias.py](model_attn_bias.py) contains the model definition of GPT-2 augmented with explicit attention biases.
13 |
14 | ## Setup
15 |
16 | - *data*: Follow [here](https://github.com/karpathy/nanoGPT?tab=readme-ov-file#reproducing-gpt-2) to set up the training and validation data from OpenWebText2.
17 |
18 | - *pretrained models*: Here we provide the checkpoints of the three GPT-2 models we trained, each for 50k iterations:
19 |
20 | | model name | download path | validation perplexity |
21 | |:---:|:---:|:---:|
22 | | default | [model](https://drive.google.com/file/d/1_oiybR7wmJ5ibPZMM3sGjtJRqVpMp0H6/view?usp=drive_link) | 3.04 |
23 | | sink | [model](https://drive.google.com/file/d/1ZnhFxN-A7qc9Cghcp_jLpxQXDRE2BqSc/view?usp=drive_link) | 3.04 |
24 | | attn_bias | [model](https://drive.google.com/file/d/1jSpGpNGqJ9Ff_goqoFjSRA5EN7Qmdv1U/view?usp=drive_link) | 3.04 |
25 |
26 | **Note**: For the config files in [config](config), set `out_dir` to the directory of the downloaded pretrained models and `data_dir` to the directory of the prepared OpenWebText2 dataset.
27 |
28 | ## Evaluate
29 |
30 | Running the following commands will evaluate the three GPT-2 checkpoints.
31 | ```sh
32 | CUDA_VISIBLE_DEVICES=0 python test.py config/eval_gpt2_default.py ### gpt2 default architecture
33 | CUDA_VISIBLE_DEVICES=0 python test.py config/eval_gpt2_sink.py ### gpt2 sink token
34 | CUDA_VISIBLE_DEVICES=0 python test.py config/eval_gpt2_attn_bias.py ### gpt2 attention biases
35 | ```
36 |
37 | ## Training
38 | Running the following commands will train the three GPT-2 models from scratch (the launch command can be adjusted to train on multiple GPUs):
39 | ```sh
40 | CUDA_VISIBLE_DEVICES=0 python train.py config/train_gpt2_default.py ### gpt2 default architecture
41 | CUDA_VISIBLE_DEVICES=0 python train.py config/train_gpt2_sink.py ### gpt2 sink token
42 | CUDA_VISIBLE_DEVICES=0 python train.py config/train_gpt2_attn_bias.py ### gpt2 attention biases
43 | ```
44 |
45 | ## Analysis
46 | We provide commands for visualizing the activation magnitudes of an intermediate-layer feature, as well as the layerwise largest activation magnitudes:
47 | ```sh
48 | CUDA_VISIBLE_DEVICES=0 python analyze.py config/eval_gpt2_default.py
49 | CUDA_VISIBLE_DEVICES=0 python analyze.py config/eval_gpt2_sink.py
50 | CUDA_VISIBLE_DEVICES=0 python analyze.py config/eval_gpt2_attn_bias.py
51 | ```
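52 | 
53 | For intuition, below is a minimal sketch of self-attention with explicit attention biases, i.e., auxiliary key/value vectors that every query can attend to. It is an illustration of the idea only, not the actual implementation in [model_attn_bias.py](model_attn_bias.py); multi-head splitting, dropout and the causal mask are omitted, and all names are illustrative:
54 | ```py
55 | import torch
56 | import torch.nn as nn
57 | import torch.nn.functional as F
58 | 
59 | class AttentionWithExplicitBias(nn.Module):
60 |     def __init__(self, n_embd):
61 |         super().__init__()
62 |         self.q_proj = nn.Linear(n_embd, n_embd)
63 |         self.k_proj = nn.Linear(n_embd, n_embd)
64 |         self.v_proj = nn.Linear(n_embd, n_embd)
65 |         # auxiliary key/value parameters, shared across all positions
66 |         self.k_bias = nn.Parameter(torch.randn(1, 1, n_embd) * 0.02)
67 |         self.v_bias = nn.Parameter(torch.randn(1, 1, n_embd) * 0.02)
68 | 
69 |     def forward(self, x):
70 |         B, T, C = x.shape
71 |         q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
72 |         # prepend the auxiliary key/value so that every query can attend to them
73 |         k = torch.cat([self.k_bias.expand(B, -1, -1), k], dim=1)  # (B, T+1, C)
74 |         v = torch.cat([self.v_bias.expand(B, -1, -1), v], dim=1)  # (B, T+1, C)
75 |         attn = (q @ k.transpose(-2, -1)) / (C ** 0.5)              # (B, T, T+1)
76 |         # a causal mask over the T real positions would go here; the auxiliary slot stays visible
77 |         attn = F.softmax(attn, dim=-1)
78 |         return attn @ v                                            # (B, T, C)
79 | ```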
--------------------------------------------------------------------------------
/gpt-2/plot_gpt2.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 |
6 | def plot_3d_ax_gpt2(feat, model_name, inp_seq, savedir):
7 | fig = plt.figure(figsize=(6,5))
8 | fig.tight_layout() # Or equivalently, "plt.tight_layout()"
9 | plt.subplots_adjust(wspace=0.1)
10 |
11 | ax = fig.add_subplot(1,1,1, projection='3d')
12 |
13 | name_title = {
14 | "gpt2_default": "GPT-2 Default",
15 | "gpt2_sink": "GPT-2 with Sink Token",
16 | "gpt2_attn_bias": "GPT-2 with Attention Bias",
17 | }
18 | num_tokens = feat.shape[1]
19 | num_channels = feat.shape[2]
20 | xdata = np.array([np.linspace(0,num_tokens-1,num_tokens) for i in range(num_channels)])
21 | ydata = np.array([np.ones(num_tokens) * i for i in range(num_channels)])
22 | zdata = feat[0].numpy().T
23 | ax.plot_wireframe(xdata, ydata, zdata, rstride=0, color="blue", linewidth=1.5)
24 |
25 | ax.set_xticks(np.linspace(0,num_tokens-1,num_tokens), inp_seq,
26 | rotation=60, fontsize=16)
27 |
28 | ax.set(yticklabels=[])
29 | ax.set_zticks([0, 1000, 2000], [0, "1k", "2k"], fontsize=18)
30 |
31 | ax.set_title(name_title[model_name], fontsize=20, fontweight="bold", y=1.005)
32 | plt.setp(ax.get_xticklabels(), rotation=50, ha="right", va="center",rotation_mode="anchor")
33 | plt.setp(ax.get_yticklabels(), ha="left",rotation_mode="anchor")
34 | plt.setp(ax.get_zticklabels(), ha="left",rotation_mode="anchor")
35 |
36 | ax.tick_params(axis='x', which='major', pad=-5)
37 | ax.tick_params(axis='y', which='major', pad=-5)
38 | ax.tick_params(axis='z', which='major', pad=-5)
39 | plt.savefig(os.path.join(savedir, f"{model_name}_3d.png"), bbox_inches="tight", dpi=200)
40 |
41 | def plot_layer_ax_gpt2(mean, model_name, savedir=None):
42 | fig = plt.figure(figsize=(6,5))
43 | fig.tight_layout() # Or equivalently, "plt.tight_layout()"
44 | plt.subplots_adjust(wspace=0.1)
45 |
46 | ax = fig.add_subplot(1,1,1)
47 |
48 | name_title = {
49 | "gpt2_default": "GPT-2 Default",
50 | "gpt2_sink": "GPT-2 with Sink Token",
51 | "gpt2_attn_bias": "GPT-2 with Attention Bias",
52 | }
53 |
54 | colors = ["cornflowerblue", "mediumseagreen", "C4", "teal", "dimgrey"]
55 | x_axis = np.arange(mean.shape[-1])+1
56 | for i in range(3):
57 | ax.plot(x_axis, mean[i], label=f"Top {i+1}", color=colors[i],
58 | linestyle="-", marker="o", markerfacecolor='none', markersize=5)
59 |
60 | ax.set_title(name_title[model_name], fontsize=18, fontweight="bold")
61 | ax.set_ylabel("Magnitudes", fontsize=18)
62 |
63 | num_layers = mean.shape[1]
64 | xtick_label = [1, num_layers//4, num_layers//2, num_layers*3//4, num_layers]
65 | ax.set_xticks(xtick_label, xtick_label, fontsize=16)
66 |
67 | ax.set_xlabel('Layers', fontsize=18, labelpad=0.3)
68 | ax.tick_params(axis='x', which='major', pad=1.0)
69 | ax.tick_params(axis='y', which='major', pad=0.4)
70 | ax.grid(axis='x', color='0.75')
71 | plt.savefig(os.path.join(savedir, f"{model_name}_layerwise.png"), bbox_inches="tight", dpi=200)
--------------------------------------------------------------------------------
/lib/model_dict.py:
--------------------------------------------------------------------------------
1 | CACHE_DIR_BASE = "./model_weights"
2 |
3 | MODEL_DICT_LLMs = {
4 | ### llama2 model
5 | "llama2_7b": {
6 | "model_id": "meta-llama/Llama-2-7b-hf",
7 | "cache_dir": CACHE_DIR_BASE
8 | },
9 | "llama2_13b": {
10 | "model_id": "meta-llama/Llama-2-13b-hf",
11 | "cache_dir": CACHE_DIR_BASE
12 | },
13 | "llama2_70b": {
14 | "model_id": "meta-llama/Llama-2-70b-hf",
15 | "cache_dir": CACHE_DIR_BASE
16 | },
17 |
18 | ### llama2 chat model
19 | "llama2_7b_chat": {
20 | "model_id": "meta-llama/Llama-2-7b-chat-hf",
21 | "cache_dir": CACHE_DIR_BASE
22 | },
23 | "llama2_13b_chat": {
24 | "model_id": "meta-llama/Llama-2-13b-chat-hf",
25 | "cache_dir": CACHE_DIR_BASE
26 | },
27 | "llama2_70b_chat": {
28 | "model_id": "meta-llama/Llama-2-70b-chat-hf",
29 | "cache_dir": CACHE_DIR_BASE
30 | },
31 |
32 | ### mistral model
33 | "mistral_7b": {
34 | "model_id": "mistralai/Mistral-7B-v0.1",
35 | "cache_dir": CACHE_DIR_BASE,
36 | },
37 | "mistral_moe": {
38 | "model_id": "mistralai/Mixtral-8x7B-v0.1",
39 | "cache_dir": CACHE_DIR_BASE,
40 | },
41 | "mistral_7b_instruct":{
42 | "model_id": "mistralai/Mistral-7B-Instruct-v0.2",
43 | "cache_dir": CACHE_DIR_BASE,
44 | },
45 | "mistral_moe_instruct": {
46 | "model_id": "mistralai/Mixtral-8x7B-Instruct-v0.1",
47 | "cache_dir": CACHE_DIR_BASE,
48 | },
49 |
50 | ### phi-2
51 | "phi-2": {
52 | "model_id": "microsoft/phi-2",
53 | "cache_dir": CACHE_DIR_BASE,
54 | },
55 |
56 | ### falcon model
57 | "falcon_7b": {
58 | "model_id": "tiiuae/falcon-7b",
59 | "cache_dir": CACHE_DIR_BASE,
60 | },
61 | "falcon_40b": {
62 | "model_id": "tiiuae/falcon-40b",
63 | "cache_dir": CACHE_DIR_BASE,
64 | },
65 |
66 | ### mpt model
67 | "mpt_7b": {
68 | "model_id": "mosaicml/mpt-7b",
69 | "cache_dir": CACHE_DIR_BASE,
70 | },
71 | "mpt_30b": {
72 | "model_id": "mosaicml/mpt-30b",
73 | "cache_dir": CACHE_DIR_BASE,
74 | },
75 |
76 | ### opt model
77 | "opt_125m": {
78 | "model_id": "facebook/opt-125m",
79 | "cache_dir": CACHE_DIR_BASE,
80 | },
81 | "opt_350m": {
82 | "model_id": "facebook/opt-350m",
83 | "cache_dir": CACHE_DIR_BASE,
84 | },
85 | "opt_1.3b": {
86 | "model_id": "facebook/opt-1.3b",
87 | "cache_dir": CACHE_DIR_BASE,
88 | },
89 | "opt_2.7b": {
90 | "model_id": "facebook/opt-2.7b",
91 | "cache_dir": CACHE_DIR_BASE,
92 | },
93 | "opt_7b": {
94 | "model_id": "facebook/opt-6.7b",
95 | "cache_dir": CACHE_DIR_BASE,
96 | },
97 | "opt_13b": {
98 | "model_id": "facebook/opt-13b",
99 | "cache_dir": CACHE_DIR_BASE,
100 | },
101 | "opt_30b": {
102 | "model_id": "facebook/opt-30b",
103 | "cache_dir": CACHE_DIR_BASE,
104 | },
105 | "opt_66b": {
106 | "model_id": "facebook/opt-66b",
107 | "cache_dir": CACHE_DIR_BASE,
108 | },
109 |
110 | ### gpt2 model
111 | "gpt2": {
112 | "model_id": "gpt2",
113 | "cache_dir": CACHE_DIR_BASE
114 | },
115 | "gpt2_medium": {
116 | "model_id": "gpt2-medium",
117 | "cache_dir": CACHE_DIR_BASE
118 | },
119 | "gpt2_large": {
120 | "model_id": "gpt2-large",
121 | "cache_dir": CACHE_DIR_BASE
122 | },
123 | "gpt2_xl": {
124 | "model_id": "gpt2-xl",
125 | "cache_dir": CACHE_DIR_BASE
126 | },
127 | }
--------------------------------------------------------------------------------
/lib/eval_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import types
3 |
4 | import torch
5 | import torch.nn as nn
6 | from timm.utils import accuracy
7 |
8 | from .load_data import get_test_data
9 |
10 |
11 | def fix_reg_mean(args, model):
12 | """
13 | Modifies the forward pass of each layer in a ViT model to set register token features to pre-computed means.
14 | """
15 | def custom_layer_forward(self, x: torch.Tensor) -> torch.Tensor:
16 | for reg_id in range(4):
17 | reg_id = int(reg_id)
18 | cur_reg_feat = self.reg_feat[reg_id].clone()
19 | x[:,reg_id+1,:] = cur_reg_feat.reshape(1,-1).repeat(x.shape[0],1)
20 |
21 | x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
22 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
23 | return x
24 |
25 |     # Load the pre-computed means of register token features for the DINOv2-reg ViT-`args.model_size`
26 | imagenet10k_reg_feat_mean = torch.load(os.path.join(args.reg_feat_mean, f"{args.model_size}_imagemean_bs10k_trainaug.pt"))
27 | for layer_id in range(len(model.blocks)):
28 | layer = model.blocks[layer_id]
29 | layer.reg_feat = imagenet10k_reg_feat_mean[layer_id]
30 |
31 | # Replace the layer's forward function with the custom forward function defined above
32 | layer.forward = types.MethodType(
33 | custom_layer_forward, layer
34 | )
35 |
36 | def setup_dinov2_model_for_eval(model, linear_head_weights):
37 | """
38 | Configures a DINOv2 model for ImageNet evaluation by setting up a new linear head with given weights and a custom forward function.
39 |
40 | Args:
41 | model: The DINOv2 pretrained ViT models.
42 | linear_head_weights: A dictionary containing the weights and bias for the new linear head.
43 | """
44 |     out_features, in_features = linear_head_weights["weight"].shape  # nn.Linear stores weight as (out_features, in_features)
45 |     model.head = nn.Linear(in_features, out_features, bias=True)
46 | model.head.weight.data = linear_head_weights["weight"]
47 | model.head.bias.data = linear_head_weights["bias"]
48 | model.head.cuda()
49 |
50 | def custom_forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
51 | x = torch.cat([x[:,0], x[:, self.num_prefix_tokens:].mean(dim=1)], dim=-1)
52 | x = self.fc_norm(x)
53 | x = self.head_drop(x)
54 | return self.head(x)
55 |
56 | model.forward_head = types.MethodType(
57 | custom_forward_head, model
58 | )
59 |
60 | def test_imagenet(model, dataloader):
61 | acc = 0
62 | cnt = 0
63 | for batch_idx, batch in enumerate(dataloader):
64 | if batch_idx % 10 == 0 and batch_idx > 0:
65 | print(f"batch idx {batch_idx} acc {acc/cnt}")
66 | images = batch[0].cuda()
67 | target = batch[-1].cuda()
68 |
69 | with torch.no_grad():
70 | output = model(images)
71 | acc1, _ = accuracy(output, target, topk=(1, 5))
72 | acc += acc1 * images.shape[0]
73 | cnt += images.shape[0]
74 |
75 | return acc/cnt
76 |
77 | @torch.no_grad()
78 | def eval_ppl(dataset_name, model, tokenizer, seed):
79 | print(f"Evaluating on {dataset_name}")
80 | seqlen=4096
81 | testseq_list = get_test_data(
82 | dataset_name, seed=seed, tokenizer=tokenizer, seqlen=seqlen, device="cuda:0"
83 | )
84 |
85 | nlls = []
86 | with torch.no_grad():
87 | for test_seq in testseq_list:
88 | lm_logits = model(test_seq).logits
89 |
90 |             shift_logits = lm_logits[:, :-1, :].contiguous() ## shape: [1, seqlen-1, vocab_size]
91 |             shift_labels = test_seq[:, 1:] ## shape: [1, seqlen-1]
92 |
93 | loss_fct = nn.CrossEntropyLoss()
94 | loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
95 | neg_log_likelihood = loss.float() * test_seq.numel()
96 | nlls.append(neg_log_likelihood)
97 |
98 | ppl = torch.exp(torch.stack(nlls).sum() / (len(testseq_list) * seqlen))
99 | return ppl.item()
--------------------------------------------------------------------------------
/main_vit.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from importlib.metadata import version
4 |
5 | import numpy as np
6 | import torch
7 | from PIL import Image
8 | from torchvision import datasets
9 |
10 | import lib
11 | import monkey_patch as mp
12 |
13 | print('torch', version('torch'))
14 | print('transformers', version('transformers'))
15 | print('accelerate', version('accelerate'))
16 | print('# of gpus: ', torch.cuda.device_count())
17 |
18 | if __name__ == "__main__":
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('--model_family', type=str)
21 | parser.add_argument('--model_size', type=str)
22 | parser.add_argument("--layer_id", type=int, default=1)
23 | parser.add_argument('--exp1', action="store_true", help="plot 3d feature")
24 | parser.add_argument('--exp2', action="store_true", help="layerwise analysis")
25 | parser.add_argument('--exp3', action="store_true", help="test original and fix-reg-mean accuracy")
26 | parser.add_argument('--imagenet_dir', type=str, default="/home/mingjies/imagenet-data/val")
27 | parser.add_argument('--linear_head_path', type=str, default="/data/locus/project_data/project_data2/mingjies/dinov2")
28 | parser.add_argument('--reg_feat_mean', type=str, default="assets/reg_feat_mean/")
29 | parser.add_argument('--seed', type=int, default=0)
30 | parser.add_argument('--num_imgs_mean', type=int, default=10)
31 | parser.add_argument('--savedir', type=str)
32 | args = parser.parse_args()
33 |
34 | torch.manual_seed(args.seed)
35 |
36 | if not os.path.exists(args.savedir):
37 | os.makedirs(args.savedir)
38 |
39 | model, layers, val_transform = lib.load_vit(args)
40 |
41 | if args.exp1:
42 | layer_id = args.layer_id - 1
43 | layer = layers[layer_id]
44 | mp.enable_vit_custom_block(layer, layer_id)
45 |
46 | img_path = os.path.join("assets", f"bird.png")
47 | img = Image.open(img_path)
48 | img = val_transform(img).unsqueeze(0).cuda()
49 |
50 | with torch.no_grad():
51 | output = model(img)
52 |
53 | feat_abs = layers[layer_id].feat.abs()
54 |
55 | lib.plot_3d_feat_vit(feat_abs, layer_id, args.model_family, args.model_size, args.savedir)
56 | # torch.save(stats, os.path.join(args.savedir, f"stats.pt"))
57 |
58 | elif args.exp2:
59 | for layer_id in range(len(layers)):
60 | layer = layers[layer_id]
61 | mp.enable_vit_custom_block(layer, layer_id)
62 |
63 | dataset = datasets.ImageFolder(args.imagenet_dir, transform=val_transform)
64 |
65 | stats = []
66 | for img_idx in range(args.num_imgs_mean):
67 | print("img_idx", img_idx)
68 | images, target = dataset[img_idx]
69 | images = images.unsqueeze(0).cuda()
70 |
71 | with torch.no_grad():
72 | output = model(images)
73 |
74 | layer_stats_np = np.zeros((4, len(layers)))
75 | for layer_id in range(len(layers)):
76 | feat_abs = layers[layer_id].feat.abs()
77 | sort_res = torch.sort(feat_abs.flatten(), descending=True)
78 | layer_stats_np[:3, layer_id] = sort_res.values[:3]
79 | layer_stats_np[3, layer_id] = torch.median(feat_abs)
80 |
81 | stats.append(layer_stats_np)
82 |
83 | lib.plot_layer_ax_vit(np.mean(stats, axis=0), args.model_family, args.model_size, args.savedir)
84 |
85 | elif args.exp3:
86 | linear_head = lib.load_dinov2_linear_head(args)
87 | lib.setup_dinov2_model_for_eval(model, linear_head)
88 |
89 | dataset = datasets.ImageFolder(args.imagenet_dir, transform=val_transform)
90 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=50, num_workers=8, pin_memory=False)
91 |
92 | f = open(os.path.join(args.savedir, "eval.txt"), "w")
93 | top1_acc = lib.test_imagenet(model, dataloader)
94 | print(f"{args.model_family} ViT-{args.model_size} original accuracy: {top1_acc}", file=f, flush=True)
95 |
96 | lib.fix_reg_mean(args, model)
97 | top1_acc = lib.test_imagenet(model, dataloader)
98 | print(f"{args.model_family} ViT-{args.model_size} fix-reg-mean accuracy: {top1_acc}", file=f, flush=True)
--------------------------------------------------------------------------------
/lib/plot_utils_vit.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import matplotlib
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 |
7 | # Configuration settings for matplotlib
8 | matplotlib.rcParams['pgf.texsystem'] = 'pdflatex'
9 | matplotlib.rcParams.update({
10 | 'font.size': 18,
11 | 'axes.labelsize': 20,
12 | 'axes.titlesize': 24,
13 | 'figure.titlesize': 28
14 | })
15 | matplotlib.rcParams['text.usetex'] = False
16 |
17 | def plot_3d_feat_vit_sub(ax, feat, layer_id, model_name, model_size):
18 | model_title={"dinov2_reg": f"DINOv2-reg ViT-{model_size}", "mistral_7b": "Mistral-7B",
19 | "llama2_13b": "LLaMA-2-13B", "llama2_70b": "LLaMA-2-70B", "mistral_moe":"Mixtral-8x7B"}
20 |
21 | num_channels = feat.shape[2]
22 |
23 | inp_seq = ["CLS", "reg 1", "reg 2", "reg 3", "reg 4",
24 | "patch 1", "patch 2", "patch i", "patch n"]
25 |
26 | xbase_index = [0,1,2,3,4,5,7,9]
27 | num_tokens = len(xbase_index)
28 | xdata = np.array([xbase_index for i in range(num_channels)])
29 | ydata = np.array([np.ones(num_tokens) * i for i in range(num_channels)])
30 | zdata = feat[0,:num_tokens,:].abs().numpy().T
31 | ax.plot_wireframe(xdata, ydata, zdata, rstride=0, color="royalblue", linewidth=2.5)
32 |
33 | ax.set_title(model_title[model_name]+f", Layer {layer_id+1}", fontsize=20, fontweight="bold", y=1.015)
34 |
35 | ax.set_yticks([179, 999], [179, 999], fontsize=15, fontweight="heavy")
36 |
37 | xbase_index = [0,1,2,3,4,]
38 | inp_seq = ["CLS", "reg 1", "reg 2", "reg 3", "reg 4"]
39 | ax.set_xticks(xbase_index, inp_seq, rotation=60, fontsize=16)
40 | ax.tick_params(axis='x', which='major', pad=-4)
41 | plt.setp(ax.get_xticklabels(), rotation=50, ha="right", va="center", rotation_mode="anchor")
42 |
43 | ax.set_zticks([0, 500, 1000], ["0", "500", "1k"], fontsize=16)
44 | ax.get_xticklabels()[3].set_weight("heavy")
45 | plt.setp(ax.get_yticklabels(), ha="left", va="center",rotation_mode="anchor")
46 | plt.setp(ax.get_zticklabels(), ha="left", va="top", rotation_mode="anchor")
47 |
48 | ax.tick_params(axis='x', which='major', pad=-5)
49 | ax.tick_params(axis='y', which='major', pad=-3)
50 | ax.tick_params(axis='z', which='major', pad=-5)
51 |
52 | def plot_3d_feat_vit(feat, layer_id, model_name, model_size, savedir):
53 | fig = plt.figure(figsize=(8,6))
54 | fig.tight_layout() # Or equivalently, "plt.tight_layout()"
55 | plt.subplots_adjust(wspace=0.)
56 |
57 | ax = fig.add_subplot(1,1, 1, projection='3d')
58 | plot_3d_feat_vit_sub(ax, feat, layer_id, model_name, model_size)
59 | plt.savefig(os.path.join(savedir, f"{model_name}_{model_size}_layer_{layer_id+1}.png"), bbox_inches="tight", dpi=200)
60 |
61 |
62 | def plot_layer_ax_vit_sub(ax, mean, model_family, model_size, colors=["royalblue", "darkorange", "forestgreen", "black"]):
63 | model_title={"dinov2_reg": "DINOv2-reg",
64 | "dinov2": "DINOv2", "mae": "MAE", "open_clip": "Open CLIP", "openai_clip": "OpenAI CLIP",
65 | "vit_orig": "ViT", "samvit": "SAM-ViT"}
66 |
67 | x_axis = np.arange(mean.shape[-1])+1
68 | for i in range(3):
69 | ax.plot(x_axis, mean[i], label=f"Top {i+1}", color=colors[i],
70 | linestyle="-", marker="o", markerfacecolor='none', markersize=5)
71 |
72 | ax.plot(x_axis, mean[-1], label=f"median", color=colors[-1],
73 | linestyle="-", marker="v", markerfacecolor='none', markersize=5)
74 |
75 | ax.set_title(model_title[model_family]+f" ViT-{model_size}", fontsize=18, fontweight="bold")
76 | ax.set_ylabel("Magnitudes", fontsize=18)
77 |
78 | num_layers = mean.shape[1]
79 | xtick_label = [1, num_layers//4, num_layers//2, num_layers*3//4, num_layers]
80 | ax.set_xticks(xtick_label, xtick_label, fontsize=16)
81 |
82 | ax.set_xlabel('Layers', fontsize=18, labelpad=4.0)
83 | ax.tick_params(axis='x', which='major', pad=2.0)
84 | ax.tick_params(axis='y', which='major', pad=0.4)
85 | ax.grid(axis='x', color='0.75')
86 | ax.grid(axis='y', color='0.75')
87 |
88 | def plot_layer_ax_vit(mean, model_family, model_size, savedir):
89 | fig = plt.figure(figsize=(8,6))
90 | fig.tight_layout() # Or equivalently, "plt.tight_layout()"
91 | plt.subplots_adjust(wspace=0.)
92 |
93 | ax = fig.add_subplot(1,1, 1)
94 | plot_layer_ax_vit_sub(ax, mean, model_family, model_size)
95 | plt.savefig(os.path.join(savedir, f"{model_family}_{model_size}.png"), bbox_inches="tight", dpi=200)
--------------------------------------------------------------------------------
/lib/load_model.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import timm
5 | from transformers import AutoTokenizer, AutoModelForCausalLM
6 |
7 | from .model_dict import MODEL_DICT_LLMs
8 |
9 |
10 | def load_llm(args):
11 | print(f"loading model {args.model}")
12 | model_name, cache_dir = MODEL_DICT_LLMs[args.model]["model_id"], MODEL_DICT_LLMs[args.model]["cache_dir"]
13 |
14 | if "falcon" in args.model or "mpt" in args.model or "phi" in args.model:
15 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, cache_dir=cache_dir, low_cpu_mem_usage=True, device_map="auto", trust_remote_code=True, token=args.access_token)
16 | elif "mistral" in args.model or "pythia" in args.model:
17 | model = AutoModelForCausalLM.from_pretrained(model_name, revision=args.revision, torch_dtype=torch.float16, cache_dir=cache_dir, low_cpu_mem_usage=True, device_map="auto", token=args.access_token)
18 | else:
19 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, cache_dir=cache_dir, low_cpu_mem_usage=True, device_map="auto", token=args.access_token)
20 | model.eval()
21 |
22 | if "mpt" in args.model or "pythia" in args.model:
23 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, token=args.access_token)
24 | else:
25 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, token=args.access_token)
26 |
27 | device = torch.device("cuda:0")
28 | if "mpt_30b" in args.model:
29 | device = model.hf_device_map["transformer.wte"]
30 | elif "30b" in args.model or "65b" in args.model or "70b" in args.model or "40b" in args.model: # for 30b and 65b we use device_map to load onto multiple A6000 GPUs, thus the processing here.
31 | device = torch.device("cuda:"+str(model.hf_device_map["lm_head"]))
32 |
33 | if "llama2_13b" == args.model:
34 | # device = torch.device("cuda:"+str(model.hf_device_map["lm_head"]))
35 | device = torch.device("cuda:1")
36 |
37 | seq_len=4096
38 | if "llama" in args.model or "mistral" in args.model:
39 | layers = model.model.layers
40 | hidden_size = model.config.hidden_size
41 | elif "falcon" in args.model:
42 | layers = model.transformer.h
43 | hidden_size = model.config.hidden_size
44 | elif "mpt" in args.model:
45 | layers = model.transformer.blocks
46 | hidden_size = model.config.d_model
47 | seq_len=2048
48 | elif "opt" in args.model:
49 | layers = model.model.decoder.layers
50 | hidden_size = model.config.hidden_size
51 | seq_len = 2048
52 | elif "gpt2" in args.model:
53 | layers = model.transformer.h
54 | hidden_size = model.transformer.embed_dim
55 | seq_len = 1024
56 | elif "pythia" in args.model:
57 | layers = model.gpt_neox.layers
58 | hidden_size = model.gpt_neox.config.hidden_size
59 | seq_len = 2048
60 | elif "phi-2" in args.model:
61 | layers = model.model.layers
62 | hidden_size = model.config.hidden_size
63 |
64 | return model, tokenizer, device, layers, hidden_size, seq_len
65 |
66 | def load_vit(args):
67 | if args.model_family == "mae":
68 | patch_size=14 if args.model_size == "huge" else 16
69 | model = timm.create_model(f'vit_{args.model_size}_patch{patch_size}_224.mae', pretrained=True)
70 | elif args.model_family == "openai_clip":
71 | patch_size=14 if args.model_size == "large" else 16
72 | model = timm.create_model(f"vit_{args.model_size}_patch{patch_size}_clip_224.openai", pretrained=True)
73 | elif args.model_family == "dinov2":
74 | model = timm.create_model(f"vit_{args.model_size}_patch14_dinov2.lvd142m", pretrained=True, num_classes=1000)
75 | elif args.model_family == "dinov2_reg":
76 | model = timm.create_model(f"vit_{args.model_size}_patch14_reg4_dinov2.lvd142m", pretrained=True, num_classes=1000)
77 |
78 | model = model.cuda()
79 | model = model.eval()
80 |
81 | layers = model.blocks
82 |
83 | data_config = timm.data.resolve_model_data_config(model)
84 | val_transform = timm.data.create_transform(**data_config, is_training=False)
85 |
86 | return model, layers, val_transform
87 |
88 | def load_dinov2_linear_head(args):
89 | assert "dinov2" in args.model_family, "this function is only for dinov2 models"
90 | if args.model_family == "dinov2_reg":
91 | linear_head_path = os.path.join(args.linear_head_path, f"dinov2_vit{args.model_size[0]}14_reg4_linear_head.pth")
92 | elif args.model_family == "dinov2":
93 | linear_head_path = os.path.join(args.linear_head_path, f"dinov2_vit{args.model_size[0]}14_linear_head.pth")
94 |
95 | linear_head_weights = torch.load(linear_head_path)
96 | return linear_head_weights
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Massive Activations in Large Language Models
2 |
3 | Official PyTorch implementation of our paper:
4 |
5 | **Massive Activations in Large Language Models**
6 | [Mingjie Sun](https://eric-mingjie.github.io/), [Xinlei Chen](https://xinleic.xyz/), [J. Zico Kolter](https://zicokolter.com/), [Zhuang Liu](https://liuzhuang13.github.io/)
7 | Carnegie Mellon University, Meta AI Research and Bosch Center for AI
8 | [Paper](https://arxiv.org/abs/2402.17762) - [Project page](https://eric-mingjie.github.io/massive-activations/index.html)
9 |
10 | Most of the experiments in this paper were done on one A6000 GPU.
11 |
12 | ---
13 |
14 |
16 |
17 |
18 | This paper studies the existence of *massive activations* in Large Language Models (LLMs). These activations have magnitudes significantly larger than those of other activations, yet they are extremely few in number.
19 |
20 | ## This repository
21 |
22 | ### Setup
23 | Installation instructions can be found in [INSTALL.md](INSTALL.md).
24 |
25 | ### Outline
26 | The contents of this repository are as follows:
27 |
28 | * [lib](lib) contains the utility functions for loading models, plotting figures, and evaluation.
29 | * [monkey_patch](monkey_patch) contains the code for monkey patching LLMs and ViTs with custom forward functions, with the goal of collecting internal activation and attention statistics.
30 | * [gpt-2](gpt-2) contains the code for training GPT-2 with explicit attention biases.
31 | * [main_llm.py](main_llm.py) contains the code for reproducing our experiments on LLMs.
32 | * [main_vit.py](main_vit.py) contains the code for reproducing our experiments on ViTs.
33 |
34 | ### Large Language Models (LLMs)
35 |
36 | * We provide an example command to visualize a hidden state feature on the residual stream:
37 | ```sh
38 | CUDA_VISIBLE_DEVICES=0 python main_llm.py --model llama2_7b \
39 | --exp1 --layer_id 2 \
40 | --savedir results/llm/3d_feat_vis/
41 | ```
42 | Running this command will visualize the output feature of layer 2 in LLaMA-2-7B, on the input prompt "*Summer is warm. Winter is cold.\n*". The resulting visualizations are saved in `results/llm/3d_feat_vis/`.
43 |
44 | For some LLMs, e.g., LLaMA-2-7B, you need to set the argument `--access_token` in order to access the weights.
45 |
46 | * We provide an example command to visualize the layerwise top 3 largest activation magnitudes:
47 | ```sh
48 | CUDA_VISIBLE_DEVICES=0 python main_llm.py --model llama2_7b \
49 | --exp2 \
50 | --savedir results/llm/layerwise/
51 | ```
52 | Running this command will visualize the per-layer top activation magnitudes. The resulting visualizations are saved in `results/llm/layerwise`.
53 |
54 | * We provide an example command to run the intervention analysis:
55 | ```sh
56 | CUDA_VISIBLE_DEVICES=0 python main_llm.py --model llama2_7b \
57 | --exp3 \
58 | --reset_type set_zero \
59 | --layer_id 2 \
60 | --savedir results/llm/intervention_analysis/
61 | ```
62 | Here the argument `--reset_type` can be either `set_zero` or `set_mean`. This command will zero the massive activations in the output feature of layer 2 in LLaMA-2-7B. The evaluation results are saved in `results/llm/intervention_analysis`.
63 |
64 | * We provide an example command for attention visualization:
65 | ```sh
66 | CUDA_VISIBLE_DEVICES=0 python main_llm.py --model llama2_7b \
67 | --exp4 \
68 | --layer_id 3 \
69 | --savedir results/llm/attn_vis/
70 | ```
71 | Running this command will visualize the attention logits (average over attention heads) in layer 3 of LLaMA-2-7B. The visualizations are saved in `results/llm/attn_vis/`.
72 |
73 | ### Vision Transformers (ViTs)
74 |
75 | * We provide an example command for visualizing the activation magnitudes of the output feature of an intermediate layer:
76 | ```sh
77 | CUDA_VISIBLE_DEVICES=0 python main_vit.py --model_family dinov2_reg --model_size giant \
78 | --exp1 \
79 | --layer_id 40 \
80 | --savedir results/vit/3d_feat_vis/
81 | ```
82 |
83 | * We provide an example command for visualizing the layer-wise largest activation magnitudes:
84 | ```sh
85 | CUDA_VISIBLE_DEVICES=0 python main_vit.py --model_family dinov2_reg --model_size giant \
86 | --exp2 \
87 | --savedir results/vit/layerwise/
88 | ```
89 |
90 | * For reproducing the results of `Fix-Reg-Mean` on [DINOv2-reg](https://arxiv.org/abs/2309.16588), run the following commands:
91 | ```sh
92 | for model_size in small base large giant
93 | do
94 | CUDA_VISIBLE_DEVICES=0 python main_vit.py \
95 | --model_family dinov2_reg --model_size ${model_size} --exp3 \
96 | --reg_feat_mean assets/reg_feat_mean \
97 | --imagenet_dir [Path to ImageNet validation set] \
98 | --savedir results/vit/exp4/dinov2_reg_${model_size}
99 | done
100 | ```
101 | The argument `--reg_feat_mean` points to the directory containing the per-layer means of the register features, collected over 10k ImageNet training images with data augmentation.
102 |
103 | Results (ImageNet top-1 accuracy):
104 | | DINOv2-reg | ViT-S | ViT-B | ViT-L | ViT-G |
105 | |------------------|----------|----------|--------|-------|
106 | | Original | 81.9 | 84.8 | 86.3 | 87.0 |
107 | | `Fix-Reg-Mean` | 81.7 | 85.0 | 86.2 | 87.0 |
108 |
109 | ## License
110 | This project is released under the MIT license. Please see the [LICENSE](LICENSE) file for more information.
111 |
112 | ## Reference
113 | ```bibtex
114 | @article{sun2024massive,
115 | title={Massive Activations in Large Language Models},
116 | author={Sun, Mingjie and Chen, Xinlei and Kolter, J. Zico and Liu, Zhuang},
117 | year={2024},
118 | journal={arXiv preprint arXiv:2402.17762}
119 | }
120 | ```
--------------------------------------------------------------------------------
/main_llm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from importlib.metadata import version
4 |
5 | import numpy as np
6 | import torch
7 |
8 | import lib
9 | import monkey_patch as mp
10 |
11 | print('torch', version('torch'))
12 | print('transformers', version('transformers'))
13 | print('accelerate', version('accelerate'))
14 | print('# of gpus: ', torch.cuda.device_count())
15 |
16 | if __name__ == "__main__":
17 | parser = argparse.ArgumentParser()
18 |
19 | parser.add_argument('--model', type=str, help='LLaMA model')
20 | parser.add_argument("--dataset", type=str, default="wikitext")
21 | parser.add_argument('--seed',type=int, default=1, help='Seed for sampling the calibration data.')
22 | parser.add_argument("--revision", type=str, default="main")
23 |
24 | parser.add_argument("--exp1", action="store_true", help="plot 3d feature")
25 | parser.add_argument("--exp2", action="store_true", help="layerwise analysis")
26 | parser.add_argument("--exp3", action="store_true", help="intervention analysis")
27 | parser.add_argument("--exp4", action="store_true", help="attention visualization")
28 | parser.add_argument("--layer_id", type=int, default=1)
29 | parser.add_argument("--reset_type", type=str, default="set_mean")
30 | parser.add_argument("--access_token", type=str, default="type in your access token here")
31 | parser.add_argument("--savedir", type=str)
32 | args = parser.parse_args()
33 |
34 | np.random.seed(args.seed)
35 | torch.manual_seed(args.seed)
36 | if not os.path.exists(args.savedir):
37 | os.makedirs(args.savedir)
38 |
39 | model, tokenizer, device, layers, hidden_size, seq_len = lib.load_llm(args)
40 | print("use device ", device)
41 |
42 | if args.exp1: ### visualize the output feature of a layer in LLMs
43 | layer_id = args.layer_id - 1
44 | if "llama2" in args.model:
45 | mp.enable_llama_custom_decoderlayer(layers[layer_id], layer_id)
46 | elif "mistral" in args.model:
47 | mp.enable_mistral_custom_decoderlayer(layers[layer_id], layer_id)
48 | elif "phi-2" in args.model:
49 | mp.enable_phi2_custom_decoderlayer(layers[layer_id], layer_id)
50 | else:
51 | raise ValueError(f"model {args.model} not supported")
52 |
53 | stats = {}
54 | seq = "Summer is warm. Winter is cold."
55 | valenc = tokenizer(seq, return_tensors='pt', add_special_tokens=False).input_ids.to(device)
56 |
57 | with torch.no_grad():
58 | model(valenc)
59 |
60 | seq_decoded = []
61 | for i in range(valenc.shape[1]):
62 | seq_decoded.append(tokenizer.decode(valenc[0,i].item()))
63 |
64 | stats[f"seq"] = seq_decoded
65 | feat_abs = layers[layer_id].feat.abs()
66 |
67 | stats[f"{layer_id}"] = feat_abs
68 |
69 | lib.plot_3d_feat(stats, layer_id, args.model, args.savedir)
70 |
71 | elif args.exp2: ### visualize the layerwise top activation magnitudes
72 | for layer_id in range(len(layers)):
73 | layer = layers[layer_id]
74 | if "llama2" in args.model:
75 | mp.enable_llama_custom_decoderlayer(layer, layer_id)
76 | elif "mistral" in args.model:
77 | mp.enable_mistral_custom_decoderlayer(layer, layer_id)
78 | elif "phi-2" in args.model:
79 |                 mp.enable_phi2_custom_decoderlayer(layer, layer_id)
80 | else:
81 | raise ValueError(f"model {args.model} not supported")
82 |
83 | testseq_list = lib.get_data(tokenizer, nsamples=10, seqlen=seq_len, device=device)
84 |
85 | stats = []
86 | for seqid, testseq in enumerate(testseq_list):
87 | print(f"processing seq {seqid}")
88 | with torch.no_grad():
89 | model(testseq)
90 |
91 |             seq_np = np.zeros((4, len(layers)))  # rows 0-2: top-3 activation magnitudes per layer; row 3: median
92 | for layer_id in range(len(layers)):
93 | feat_abs = layers[layer_id].feat.abs()
94 | sort_res = torch.sort(feat_abs.flatten(), descending=True)
95 | seq_np[:3, layer_id] = sort_res.values[:3]
96 | seq_np[3, layer_id] = torch.median(feat_abs)
97 |
98 | stats.append(seq_np)
99 |
100 | lib.plot_layer_ax(stats, args.model, args.savedir)
101 |
102 | elif args.exp3: ### intervention analysis
103 | layer = layers[args.layer_id-1]
104 | lib.setup_intervene_hook(layer, args.model, args.reset_type)
105 |
106 | f = open(os.path.join(args.savedir, f"{args.model}_{args.reset_type}.log"), "a")
107 |
108 | ds_list = ["wikitext", "c4", "pg19"]
109 | res = {}
110 | for ds_name in ds_list:
111 | ppl = lib.eval_ppl(ds_name, model, tokenizer, args.seed, device)
112 | res[ds_name] = ppl
113 | print(f"{ds_name} ppl: {ppl}", file=f, flush=True)
114 |
115 |     elif args.exp4: ### attention visualization
116 | layer_id = args.layer_id - 1
117 | if "llama2" in args.model:
118 | modified_attn_layer = mp.enable_llama_custom_attention(layers[layer_id], layer_id)
119 | elif "mistral" in args.model:
120 | modified_attn_layer = mp.enable_mistral_custom_attention(layers[layer_id], layer_id)
121 | elif "phi-2" in args.model:
122 | modified_attn_layer = mp.enable_phi2_custom_attention(layers[layer_id], layer_id)
123 | else:
124 | raise ValueError(f"model {args.model} not supported")
125 |
126 | seq = "The following are multiple choice questions (with answers) about machine learning.\n\n A 6-sided die is rolled 15 times and the results are: side 1 comes up 0 times;"
127 | valenc = tokenizer(seq, return_tensors='pt', add_special_tokens=False).input_ids.to(device)
128 |
129 | with torch.no_grad():
130 | model(valenc)
131 |
132 | attn_logit = layers[layer_id].self_attn.attn_logits.detach().cpu()
133 | lib.plot_attn(attn_logit, args.model, layer_id, args.savedir)
--------------------------------------------------------------------------------
/lib/plot_utils_llm.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import matplotlib
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import seaborn as sns
7 |
8 | matplotlib.rcParams['pgf.texsystem'] = 'pdflatex'
9 | matplotlib.rcParams.update({
10 | 'font.size': 18,
11 | 'axes.labelsize': 20,
12 | 'axes.titlesize': 24,
13 | 'figure.titlesize': 28
14 | })
15 | matplotlib.rcParams['text.usetex'] = False
16 |
17 | MODEL_TITLE_DICT={"llama2_7b": "LLaMA-2-7B", "mistral_7b": "Mistral-7B",
18 | "llama2_13b_chat": "LLaMA-2-13B-chat", "llama2_70b_chat": "LLaMA-2-70B-chat",
19 | "llama2_7b_chat": "LLaMA-2-7B-chat", "llama2_13b": "LLaMA-2-13B", "llama2_70b": "LLaMA-2-70B",
20 | "mistral_moe":"Mixtral-8x7B", "falcon_7b": "Falcon-7B", "falcon_40b": "Falcon-40B", "phi-2": "Phi-2",
21 | "opt_7b":"OPT-7B", "opt_13b": "OPT-13B", "opt_30b": "OPT-30B", "opt_66b": "OPT-66B",
22 | "mpt_7b": "MPT-7B", "mpt_30b": "MPT-30B", "pythia_7b": "Pythia-7B", "pythia_12b": "Pythia-12B",
23 | "gpt2": "GPT-2", "gpt2_large": "GPT-2-Large", "gpt2_xl": "GPT-2-XL", "gpt2_medium": "GPT-2-Medium",
24 | "mistral_moe_instruct": "Mixtral-8x7B-Instruct", "mistral_7b_instruct": "Mistral-7B-Instruct"}
25 |
26 |
27 | def plot_3d_feat_sub(ax, obj, seq_id, layer_id, model_name):
28 | num_tokens = len(obj[f"seq"])
29 | num_channels = obj[f"{layer_id}"].shape[2]
30 | inp_seq = obj[f"seq"]
31 | inp_seq = [x if x != "<0x0A>" else r"\n" for x in inp_seq]
32 | xdata = np.array([np.linspace(0,num_tokens-1,num_tokens) for i in range(num_channels)])
33 | ydata = np.array([np.ones(num_tokens) * i for i in range(num_channels)])
34 | zdata = obj[f"{layer_id}"][0].abs().numpy().T
35 | ax.plot_wireframe(xdata, ydata, zdata, rstride=0, color="royalblue", linewidth=2.5)
36 |
37 | ax.set_xticks(np.linspace(0,num_tokens-1,num_tokens), inp_seq,
38 | rotation=50, fontsize=16)
39 |
40 | ax.set_zticks([0, 1000, 2000], ["0", "1k", "2k"], fontsize=15)
41 |     ax.set_yticks([1415, 2533], [1415, 2533], fontsize=15, fontweight="heavy")  # hard-coded channel indices of the massive activations (as observed for LLaMA-2-7B)
42 | ax.get_xticklabels()[0].set_weight("heavy")
43 |
44 | if seq_id in [0, 1]:
45 | ax.get_xticklabels()[3].set_weight("heavy")
46 |
47 | ax.set_title(MODEL_TITLE_DICT[model_name], fontsize=18, fontweight="bold", y=1.015)
48 |
49 | plt.setp(ax.get_xticklabels(), rotation=50, ha="right", va="center",
50 | rotation_mode="anchor")
51 | plt.setp(ax.get_yticklabels(), ha="left",
52 | rotation_mode="anchor")
53 |
54 | ax.tick_params(axis='x', which='major', pad=-4)
55 | ax.tick_params(axis='y', which='major', pad=-5)
56 | ax.tick_params(axis='z', which='major', pad=-1)
57 | ax.set_zlim(0,2400)
58 |
59 | def plot_3d_feat(obj, layer_id, model_name, savedir):
60 | fig = plt.figure(figsize=(14,6))
61 | fig.tight_layout() # Or equivalently, "plt.tight_layout()"
62 | plt.subplots_adjust(wspace=0.13)
63 |
64 | # for i in range(3):
65 | ax = fig.add_subplot(1,1, 1, projection='3d')
66 | plot_3d_feat_sub(ax, obj, 0, layer_id, model_name)
67 | plt.savefig(os.path.join(savedir, f"{model_name}_layer_{layer_id+1}.png"), bbox_inches="tight", dpi=200)
68 |
69 |
70 | def plot_layer_ax_sub(ax, mean, model_name):
71 | colors = ["cornflowerblue", "mediumseagreen", "C4", "teal", "dimgrey"]
72 |
73 | x_axis = np.arange(mean.shape[-1])+1
74 | for i in range(3):
75 | ax.plot(x_axis, mean[i], label=f"Top {i+1}", color=colors[i],
76 | linestyle="-", marker="o", markerfacecolor='none', markersize=5)
77 |
78 | ax.plot(x_axis, mean[-1], label=f"Median", color=colors[-1],
79 | linestyle="-", marker="v", markerfacecolor='none', markersize=5)
80 |
81 | ax.set_title(MODEL_TITLE_DICT[model_name], fontsize=18, fontweight="bold")
82 |
83 | num_layers = mean.shape[1]
84 | xtick_label = [1, num_layers//4, num_layers//2, num_layers*3//4, num_layers]
85 | ax.set_xticks(xtick_label, xtick_label, fontsize=16)
86 |
87 | ax.set_xlabel('Layers', fontsize=18, labelpad=0.8)
88 | ax.set_ylabel("Magnitudes", fontsize=18)
89 | ax.tick_params(axis='x', which='major', pad=1.0)
90 | ax.tick_params(axis='y', which='major', pad=0.4)
91 | ax.grid(axis='x', color='0.75')
92 |
93 | def plot_layer_ax(obj, model_name, savedir):
94 | fig, axs = plt.subplots(nrows=1, ncols=1, figsize=(7.5, 4.5))
95 | fig.tight_layout() # Or equivalently, "plt.tight_layout()"
96 | plt.subplots_adjust(wspace=0.13)
97 |
98 | mean = np.mean(obj,axis=0)
99 | plot_layer_ax_sub(axs, mean, model_name)
100 | leg = axs.legend(
101 | loc='center', bbox_to_anchor=(0.5, -0.25),
102 | ncol=4, fancybox=True, prop={'size': 14}
103 | )
104 | leg.get_frame().set_edgecolor('silver')
105 | leg.get_frame().set_linewidth(1.0)
106 |
107 | plt.savefig(os.path.join(savedir, f"{model_name}.png"), bbox_inches="tight", dpi=200)
108 |
109 |
110 | def plot_attn_sub(ax, corr, model_name, layer_id):
111 | mask = np.zeros_like(corr)
112 | mask[np.triu_indices_from(mask, k=1)] = True
113 | sns.heatmap(corr, mask=mask, square=True, ax=ax,
114 | cmap="YlGnBu",cbar_kws={"shrink": 1.0, "pad": 0.01, "aspect":50})
115 |
116 | ax.set_facecolor("whitesmoke")
117 | cax = ax.figure.axes[-1]
118 | cax.tick_params(labelsize=18)
119 |
120 | ax.tick_params(axis='x', which='major')
121 | ax.set(xticklabels=[])
122 | ax.set(yticklabels=[])
123 | ax.tick_params(left=False, bottom=False)
124 | ax.set_title(f"{MODEL_TITLE_DICT[model_name]}, Layer {layer_id+1}", fontsize=24, fontweight="bold")
125 |
126 |
127 | def plot_attn(attn_logits, model_name, layer_id, savedir):
128 | fig, axs = plt.subplots(nrows=1, ncols=1, figsize=(8, 4.75))
129 | fig.tight_layout() # Or equivalently, "plt.tight_layout()"
130 | plt.subplots_adjust(wspace=0.15)
131 |
132 | corr = attn_logits.numpy()[0].mean(0)
133 | corr = corr.astype("float64")
134 |
135 | plot_attn_sub(axs, corr, model_name, layer_id)
136 | plt.savefig(os.path.join(savedir, f"{model_name}_layer{layer_id+1}.pdf"), bbox_inches="tight", dpi=200)
--------------------------------------------------------------------------------
/lib/plot_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | import matplotlib.pyplot as plt
6 | import seaborn as sns
7 | import matplotlib
8 | import matplotlib.gridspec as gridspec
9 | from matplotlib.patheffects import withStroke
10 | from collections import defaultdict
11 | from mpl_toolkits import mplot3d
12 | from mpl_toolkits.mplot3d import axes3d
13 |
14 | matplotlib.rcParams['pgf.texsystem'] = 'pdflatex'
15 | matplotlib.rcParams.update({
16 | # 'font.family': 'Arial',
17 | 'font.size': 18,
18 | 'axes.labelsize': 20,
19 | 'axes.titlesize': 24,
20 | 'figure.titlesize': 28
21 | })
22 | matplotlib.rcParams['text.usetex'] = False
23 |
24 | MODEL_TITLE_DICT={"llama2_7b": "LLaMA-2-7B", "mistral_7b": "Mistral-7B",
25 | "llama2_13b_chat": "LLaMA-2-13B-chat", "llama2_70b_chat": "LLaMA-2-70B-chat",
26 | "llama2_7b_chat": "LLaMA-2-7B-chat", "llama2_13b": "LLaMA-2-13B", "llama2_70b": "LLaMA-2-70B",
27 | "mistral_moe":"Mixtral-8x7B", "falcon_7b": "Falcon-7B", "falcon_40b": "Falcon-40B", "phi-2": "Phi-2",
28 | "opt_7b":"OPT-7B", "opt_13b": "OPT-13B", "opt_30b": "OPT-30B", "opt_66b": "OPT-66B",
29 | "mpt_7b": "MPT-7B", "mpt_30b": "MPT-30B", "pythia_7b": "Pythia-7B", "pythia_12b": "Pythia-12B",
30 | "gpt2": "GPT-2", "gpt2_large": "GPT-2-Large", "gpt2_xl": "GPT-2-XL", "gpt2_medium": "GPT-2-Medium",
31 | "mistral_moe_instruct": "Mixtral-8x7B-Instruct", "mistral_7b_instruct": "Mistral-7B-Instruct"}
32 |
33 |
34 | def plot_3d_feat_sub(ax, obj, seq_id, layer_id, model_name):
35 | num_tokens = len(obj[f"seq"])
36 | num_channels = obj[f"{layer_id}"].shape[2]
37 | inp_seq = obj[f"seq"]
38 | inp_seq = [x if x != "<0x0A>" else r"\n" for x in inp_seq]
39 | xdata = np.array([np.linspace(0,num_tokens-1,num_tokens) for i in range(num_channels)])
40 | ydata = np.array([np.ones(num_tokens) * i for i in range(num_channels)])
41 | zdata = obj[f"{layer_id}"][0].abs().numpy().T
42 | ax.plot_wireframe(xdata, ydata, zdata, rstride=0, color="royalblue", linewidth=2.5)
43 |
44 | ax.set_xticks(np.linspace(0,num_tokens-1,num_tokens), inp_seq,
45 | rotation=50, fontsize=16)
46 |
47 | ax.set_zticks([0, 1000, 2000], ["0", "1k", "2k"], fontsize=15)
48 |     ax.set_yticks([1415, 2533], [1415, 2533], fontsize=15, fontweight="heavy")  # hard-coded channel indices of the massive activations (as observed for LLaMA-2-7B)
49 | ax.get_xticklabels()[0].set_weight("heavy")
50 |
51 | if seq_id in [0, 1]:
52 | ax.get_xticklabels()[3].set_weight("heavy")
53 |
54 | ax.set_title(MODEL_TITLE_DICT[model_name], fontsize=18, fontweight="bold", y=1.015)
55 |
56 | plt.setp(ax.get_xticklabels(), rotation=50, ha="right", va="center",
57 | rotation_mode="anchor")
58 | plt.setp(ax.get_yticklabels(), ha="left",
59 | rotation_mode="anchor")
60 |
61 | ax.tick_params(axis='x', which='major', pad=-4)
62 | ax.tick_params(axis='y', which='major', pad=-5)
63 | ax.tick_params(axis='z', which='major', pad=-1)
64 | ax.set_zlim(0,2400)
65 |
66 | def plot_3d_feat(obj, layer_id, model_name, savedir):
67 | fig = plt.figure(figsize=(14,6))
68 | fig.tight_layout() # Or equivalently, "plt.tight_layout()"
69 | plt.subplots_adjust(wspace=0.13)
70 |
71 | # for i in range(3):
72 | ax = fig.add_subplot(1,1, 1, projection='3d')
73 | plot_3d_feat_sub(ax, obj, 0, layer_id, model_name)
74 | plt.savefig(os.path.join(savedir, f"{model_name}_layer_{layer_id+1}.png"), bbox_inches="tight", dpi=200)
75 |
76 |
77 | def plot_layer_ax_sub(ax, mean, model_name):
78 | colors = ["cornflowerblue", "mediumseagreen", "C4", "teal", "dimgrey"]
79 |
80 | x_axis = np.arange(mean.shape[-1])+1
81 | for i in range(3):
82 | ax.plot(x_axis, mean[i], label=f"Top {i+1}", color=colors[i],
83 | linestyle="-", marker="o", markerfacecolor='none', markersize=5)
84 |
85 | ax.plot(x_axis, mean[-1], label=f"Median", color=colors[-1],
86 | linestyle="-", marker="v", markerfacecolor='none', markersize=5)
87 |
88 | ax.set_title(MODEL_TITLE_DICT[model_name], fontsize=18, fontweight="bold")
89 |
90 | num_layers = mean.shape[1]
91 | xtick_label = [1, num_layers//4, num_layers//2, num_layers*3//4, num_layers]
92 | ax.set_xticks(xtick_label, xtick_label, fontsize=16)
93 |
94 | ax.set_xlabel('Layers', fontsize=18, labelpad=0.8)
95 | ax.set_ylabel("Magnitudes", fontsize=18)
96 | ax.tick_params(axis='x', which='major', pad=1.0)
97 | ax.tick_params(axis='y', which='major', pad=0.4)
98 | ax.grid(axis='x', color='0.75')
99 |
100 | def plot_layer_ax(obj, model_name, savedir):
101 | fig, axs = plt.subplots(nrows=1, ncols=1, figsize=(7.5, 4.5))
102 | fig.tight_layout() # Or equivalently, "plt.tight_layout()"
103 | plt.subplots_adjust(wspace=0.13)
104 |
105 | mean = np.mean(obj,axis=0)
106 | plot_layer_ax_sub(axs, mean, model_name)
107 | leg = axs.legend(
108 | loc='center', bbox_to_anchor=(0.5, -0.25),
109 | ncol=4, fancybox=True, prop={'size': 14}
110 | )
111 | leg.get_frame().set_edgecolor('silver')
112 | leg.get_frame().set_linewidth(1.0)
113 |
114 | plt.savefig(os.path.join(savedir, f"{model_name}.png"), bbox_inches="tight", dpi=200)
115 |
116 |
117 | def plot_attn_sub(ax, corr, model_name, layer_id):
118 | mask = np.zeros_like(corr)
119 | mask[np.triu_indices_from(mask, k=1)] = True
120 | sns.heatmap(corr, mask=mask, square=True, ax=ax,
121 | cmap="YlGnBu",cbar_kws={"shrink": 1.0, "pad": 0.01, "aspect":50})
122 |
123 | ax.set_facecolor("whitesmoke")
124 | cax = ax.figure.axes[-1]
125 | cax.tick_params(labelsize=18)
126 |
127 | ax.tick_params(axis='x', which='major')
128 | ax.set(xticklabels=[])
129 | ax.set(yticklabels=[])
130 | ax.tick_params(left=False, bottom=False)
131 | ax.set_title(f"{MODEL_TITLE_DICT[model_name]}, Layer {layer_id+1}", fontsize=24, fontweight="bold")
132 |
133 |
134 | def plot_attn(attn_logits, model_name, layer_id, savedir):
135 | fig, axs = plt.subplots(nrows=1, ncols=1, figsize=(8, 4.75))
136 | fig.tight_layout() # Or equivalently, "plt.tight_layout()"
137 | plt.subplots_adjust(wspace=0.15)
138 |
139 | corr = attn_logits.numpy()[0].mean(0)
140 | corr = corr.astype("float64")
141 |
142 | plot_attn_sub(axs, corr, model_name, layer_id)
143 | plt.savefig(os.path.join(savedir, f"{model_name}_layer{layer_id+1}.pdf"), bbox_inches="tight", dpi=200)
--------------------------------------------------------------------------------
/monkey_patch/modify_phi2.py:
--------------------------------------------------------------------------------
1 | import math
2 | import types
3 | from typing import Optional, Tuple
4 |
5 | import torch
6 | from torch import nn
7 | import torch.utils.checkpoint
8 | from transformers.models.llama.modeling_llama import (
9 | apply_rotary_pos_emb,
10 | repeat_kv,
11 | )
12 |
13 |
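# Custom Phi-2 attention forward: mirrors the Hugging Face implementation, but
# additionally caches the q/k/v projections and the pre-/post-softmax attention
# scores (self.attn_logits / self.attn_probs) on the module for later inspection.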
14 | def phi2_custom_attention_forward(
15 | self,
16 | hidden_states,
17 | attention_mask = None,
18 | position_ids = None,
19 | past_key_value = None,
20 | output_attentions = False,
21 | use_cache = False,
22 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
23 | bsz, q_len, _ = hidden_states.size()
24 |
25 | query_states = self.q_proj(hidden_states)
26 | key_states = self.k_proj(hidden_states)
27 | value_states = self.v_proj(hidden_states)
28 |
29 | ##################################################################
30 | self.query_states = query_states.detach().cpu().clone()
31 | self.key_states = key_states.detach().cpu().clone()
32 | self.value_states = value_states.detach().cpu().clone()
33 | ##################################################################
34 |
35 | if self.qk_layernorm:
36 | query_states = self.q_layernorm(query_states)
37 | key_states = self.k_layernorm(key_states)
38 |
39 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
40 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
41 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
42 |
43 | kv_seq_len = key_states.shape[-2]
44 | if past_key_value is not None:
45 | if self.layer_idx is None:
46 | raise ValueError(
47 | f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
48 | "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
49 | "with a layer index."
50 | )
51 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
53 |
54 | # Partial rotary embedding
55 | query_rot, query_pass = (
56 | query_states[..., : self.rotary_emb.dim],
57 | query_states[..., self.rotary_emb.dim :],
58 | )
59 | key_rot, key_pass = (
60 | key_states[..., : self.rotary_emb.dim],
61 | key_states[..., self.rotary_emb.dim :],
62 | )
63 | # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
64 | query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
65 |
66 | # [batch_size, seq_length, num_heads, head_dim]
67 | query_states = torch.cat((query_rot, query_pass), dim=-1)
68 | key_states = torch.cat((key_rot, key_pass), dim=-1)
69 |
70 | if past_key_value is not None:
71 | cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
72 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
73 |
74 | key_states = repeat_kv(key_states, self.num_key_value_groups)
75 | value_states = repeat_kv(value_states, self.num_key_value_groups)
76 |
77 | # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
78 | attn_weights = torch.matmul(
79 | query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
80 | ) / math.sqrt(self.head_dim)
81 |
82 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
83 | raise ValueError(
84 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
85 | f" {attn_weights.size()}"
86 | )
87 |
88 | if attention_mask is not None:
89 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
90 | raise ValueError(
91 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
92 | )
93 | attn_weights = attn_weights + attention_mask
94 |
95 | # ###################################################
96 | self.attn_logits = attn_weights
97 | # ###################################################
98 |
99 | # upcast attention to fp32
100 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
101 | attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
102 |
103 | # ###################################################
104 | self.attn_probs = attn_weights
105 | # ###################################################
106 | attn_output = torch.matmul(attn_weights, value_states)
107 |
108 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
109 | raise ValueError(
110 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
111 | f" {attn_output.size()}"
112 | )
113 |
114 | attn_output = attn_output.transpose(1, 2).contiguous()
115 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
116 |
117 | attn_output = self.dense(attn_output)
118 |
119 | if not output_attentions:
120 | attn_weights = None
121 |
122 | return attn_output, attn_weights, past_key_value
123 |
124 | def enable_phi2_custom_attention(layer, layer_id):
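    """
    replace the forward function of the Phi-2 attention module with a custom forward function `phi2_custom_attention_forward`
    """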
125 | modified_module = layer.self_attn
126 | modified_module.layer_id = layer_id
127 | modified_module.forward = types.MethodType(phi2_custom_attention_forward, modified_module)
128 |
129 | return modified_module
130 |
131 | def phi2_custom_decoderlayer_forward(
132 | self,
133 | hidden_states: torch.Tensor,
134 | attention_mask: Optional[torch.Tensor] = None,
135 | position_ids: Optional[torch.LongTensor] = None,
136 | output_attentions: Optional[bool] = False,
137 | use_cache: Optional[bool] = False,
138 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
139 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
140 | residual = hidden_states
141 |
142 | hidden_states = self.input_layernorm(hidden_states)
143 |
144 | # Self Attention
145 | attn_outputs, self_attn_weights, present_key_value = self.self_attn(
146 | hidden_states=hidden_states,
147 | attention_mask=attention_mask,
148 | position_ids=position_ids,
149 | past_key_value=past_key_value,
150 | output_attentions=output_attentions,
151 | use_cache=use_cache,
152 | )
153 | attn_outputs = self.resid_dropout(attn_outputs)
154 |
155 | feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
156 | hidden_states = attn_outputs + feed_forward_hidden_states + residual
157 |
158 | self.feat = hidden_states.clone().detach().cpu().double()
159 | outputs = (hidden_states,)
160 |
161 | if output_attentions:
162 | outputs += (self_attn_weights,)
163 |
164 | if use_cache:
165 | outputs += (present_key_value,)
166 |
167 | return outputs
168 |
169 | def enable_phi2_custom_decoderlayer(layer, layer_id):
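    """
    replace the forward function of the Phi-2 decoder layer with a custom forward function `phi2_custom_decoderlayer_forward`
    """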
170 | layer.layer_id = layer_id
171 | layer.forward = types.MethodType(
172 | phi2_custom_decoderlayer_forward, layer
173 | )
174 |
--------------------------------------------------------------------------------
/monkey_patch/modify_mistral.py:
--------------------------------------------------------------------------------
1 | import math
2 | import types
3 | import warnings
4 | from typing import Optional, Tuple
5 |
6 | import torch
7 | from torch import nn
8 | import torch.utils.checkpoint
9 | from transformers.models.mistral.modeling_mistral import (
10 | apply_rotary_pos_emb,
11 | repeat_kv,
12 | )
13 |
14 |
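# Custom MistralDecoderLayer forward: same computation as the Hugging Face
# implementation, with the layer output additionally cached as `self.feat`
# (on CPU, in float64) for activation analysis.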
15 | def mistral_custom_decoderlayer_forward(
16 | self,
17 | hidden_states: torch.Tensor,
18 | attention_mask: Optional[torch.Tensor] = None,
19 | position_ids: Optional[torch.LongTensor] = None,
20 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
21 | output_attentions: Optional[bool] = False,
22 | use_cache: Optional[bool] = False,
23 | **kwargs,
24 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
25 | if "padding_mask" in kwargs:
26 | warnings.warn(
27 | "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
28 | )
29 |
30 | residual = hidden_states
31 |
32 | hidden_states = self.input_layernorm(hidden_states)
33 |
34 | # Self Attention
35 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
36 | hidden_states=hidden_states,
37 | attention_mask=attention_mask,
38 | position_ids=position_ids,
39 | past_key_value=past_key_value,
40 | output_attentions=output_attentions,
41 | use_cache=use_cache,
42 | )
43 |
44 | if residual.device.index != hidden_states.device.index:
45 | residual = residual.to(hidden_states.device)
46 | hidden_states = residual + hidden_states
47 |
48 | # Fully Connected
49 | residual = hidden_states
50 | hidden_states = self.post_attention_layernorm(hidden_states)
51 | hidden_states = self.mlp(hidden_states)
52 | hidden_states = residual + hidden_states
53 |
54 | self.feat = hidden_states.clone().detach().cpu().double()
55 |
56 | outputs = (hidden_states,)
57 |
58 | if output_attentions:
59 | outputs += (self_attn_weights,)
60 |
61 | if use_cache:
62 | outputs += (present_key_value,)
63 |
64 | return outputs
65 |
66 | def enable_mistral_custom_decoderlayer(layer, layer_id):
67 | """
68 | replace the forward function of MistralDecoderLayer with a custom forward function `mistral_custom_decoderlayer_forward`
69 | """
70 | layer.layer_id = layer_id
71 | layer.forward = types.MethodType(
72 | mistral_custom_decoderlayer_forward, layer
73 | )
74 |
75 | def mistral_custom_attention_forward(
76 | self,
77 | hidden_states,
78 | attention_mask = None,
79 | position_ids = None,
80 | past_key_value = None,
81 | output_attentions: bool = False,
82 | use_cache: bool = False,
83 | **kwargs,
84 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
85 | if "padding_mask" in kwargs:
86 | warnings.warn(
87 | "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
88 | )
89 | bsz, q_len, _ = hidden_states.size()
90 |
91 | query_states = self.q_proj(hidden_states)
92 | key_states = self.k_proj(hidden_states)
93 | value_states = self.v_proj(hidden_states)
94 |
95 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
96 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
97 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
98 |
99 | # ###################################################
100 | # self.hidden_states = hidden_states.detach().cpu().clone()
101 | self.query_states = query_states.detach().cpu().clone()
102 | self.key_states = key_states.detach().cpu().clone()
103 | self.value_states = value_states.detach().cpu().clone()
104 | # ###################################################
105 |
106 | kv_seq_len = key_states.shape[-2]
107 | if past_key_value is not None:
108 | if self.layer_idx is None:
109 | raise ValueError(
110 | f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
111 | "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
112 | "with a layer index."
113 | )
114 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
115 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
116 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
117 |
118 | if past_key_value is not None:
119 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
120 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
121 |
122 | # repeat k/v heads if n_kv_heads < n_heads
123 | key_states = repeat_kv(key_states, self.num_key_value_groups)
124 | value_states = repeat_kv(value_states, self.num_key_value_groups)
125 |
126 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
127 |
128 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
129 | raise ValueError(
130 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
131 | f" {attn_weights.size()}"
132 | )
133 |
134 | if attention_mask is not None:
135 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
136 | raise ValueError(
137 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
138 | )
139 |
140 | attn_weights = attn_weights + attention_mask
141 |
142 | # ###################################################
143 | self.attn_logits = attn_weights
144 | # ###################################################
145 |
146 | # upcast attention to fp32
147 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
148 |
149 | # ###################################################
150 | self.attn_probs = attn_weights
151 | # ###################################################
152 |
153 | attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
154 | attn_output = torch.matmul(attn_weights, value_states)
155 |
156 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
157 | raise ValueError(
158 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
159 | f" {attn_output.size()}"
160 | )
161 |
162 | attn_output = attn_output.transpose(1, 2).contiguous()
163 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
164 |
165 | attn_output = self.o_proj(attn_output)
166 |
167 | if not output_attentions:
168 | attn_weights = None
169 |
170 | return attn_output, attn_weights, past_key_value
171 |
172 | def enable_mistral_custom_attention(layer, layer_id):
173 | modified_module = layer.self_attn
174 | modified_module.layer_id = layer_id
175 | modified_module.forward = types.MethodType(mistral_custom_attention_forward, modified_module)
176 |
177 | return modified_module
178 |
--------------------------------------------------------------------------------
/monkey_patch/modify_llama.py:
--------------------------------------------------------------------------------
1 | import math
2 | import types
3 | import warnings
4 | from typing import Optional, Tuple
5 |
6 | import torch
7 | from torch import nn
8 | import torch.nn.functional as F
9 | import torch.utils.checkpoint
10 | from transformers.models.llama.modeling_llama import (
11 | apply_rotary_pos_emb,
12 | repeat_kv,
13 | rotate_half,
14 | )
15 |
16 |
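# Custom LlamaDecoderLayer forward: same computation as the Hugging Face
# implementation, with the layer output additionally cached as `self.feat`
# (on CPU, in float64) for activation analysis.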
17 | def llama_custom_decoderlayer_forward(
18 | self,
19 | hidden_states: torch.Tensor,
20 | attention_mask: Optional[torch.Tensor] = None,
21 | position_ids: Optional[torch.LongTensor] = None,
22 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
23 | output_attentions: Optional[bool] = False,
24 | use_cache: Optional[bool] = False,
25 | **kwargs,
26 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
27 | residual = hidden_states
28 |
29 | hidden_states = self.input_layernorm(hidden_states)
30 |
31 | # Self Attention
32 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
33 | hidden_states=hidden_states,
34 | attention_mask=attention_mask,
35 | position_ids=position_ids,
36 | past_key_value=past_key_value,
37 | output_attentions=output_attentions,
38 | use_cache=use_cache,
39 | **kwargs,
40 | )
41 |
42 | if residual.device.index != hidden_states.device.index:
43 | residual = residual.to(hidden_states.device)
44 | hidden_states = residual + hidden_states
45 |
46 | # Fully Connected
47 | residual = hidden_states
48 | hidden_states = self.post_attention_layernorm(hidden_states)
49 | hidden_states = self.mlp(hidden_states)
50 | hidden_states = residual + hidden_states
51 |
52 | self.feat = hidden_states.clone().detach().cpu().double()
53 |
54 | outputs = (hidden_states,)
55 |
56 | if output_attentions:
57 | outputs += (self_attn_weights,)
58 |
59 | if use_cache:
60 | outputs += (present_key_value,)
61 |
62 | return outputs
63 |
64 | def enable_llama_custom_decoderlayer(layer, layer_id):
65 | """
66 | replace the forward function of LlamaDecoderLayer with a custom forward function `llama_custom_decoderlayer_forward`
67 | """
68 | layer.layer_id = layer_id
69 | layer.forward = types.MethodType(
70 | llama_custom_decoderlayer_forward, layer
71 | )
72 |
73 |
74 | def apply_rotary_pos_emb_single(q, k, cos, sin, position_ids):
75 | cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
76 | sin = sin[position_ids].unsqueeze(1)
77 | q_embed = (q * cos) + (rotate_half(q) * sin)
78 | k_embed = (k * cos) + (rotate_half(k) * sin)
79 | return q_embed, k_embed
80 |
81 | def llama_custom_attention_forward(
82 | self,
83 | hidden_states,
84 | attention_mask = None,
85 | position_ids = None,
86 | past_key_value = None,
87 | output_attentions = False,
88 | use_cache = False,
89 | **kwargs,
90 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
91 | if "padding_mask" in kwargs:
92 | warnings.warn(
93 | "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
94 | )
95 |
96 | bsz, q_len, _ = hidden_states.size()
97 |
98 | query_states = self.q_proj(hidden_states)
99 | key_states = self.k_proj(hidden_states)
100 | value_states = self.v_proj(hidden_states)
101 |
102 | ##################################################################
103 | self.query_states = query_states.detach().cpu().clone()
104 | self.key_states = key_states.detach().cpu().clone()
105 | self.value_states = value_states.detach().cpu().clone()
106 | # ###################################################
107 |
108 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
109 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
110 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
111 |
112 | # ##################################################################
113 | # self.query_states = query_states.detach().cpu().clone()
114 | # self.key_states = key_states.detach().cpu().clone()
115 | # self.value_states = value_states.detach().cpu().clone()
116 | # # ###################################################
117 |
118 | kv_seq_len = key_states.shape[-2]
119 | if past_key_value is not None:
120 | if self.layer_idx is None:
121 | raise ValueError(
122 | f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
123 | "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
124 | "with a layer index."
125 | )
126 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
127 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
128 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
129 |
130 | if past_key_value is not None:
131 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
132 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
133 |
134 | key_states = repeat_kv(key_states, self.num_key_value_groups)
135 | value_states = repeat_kv(value_states, self.num_key_value_groups)
136 |
137 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
138 |
139 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
140 | raise ValueError(
141 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
142 | f" {attn_weights.size()}"
143 | )
144 |
145 | if attention_mask is not None:
146 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
147 | raise ValueError(
148 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
149 | )
150 | attn_weights = attn_weights + attention_mask
151 |
152 | # ###################################################
153 | self.attn_logits = attn_weights
154 | # ###################################################
155 |
156 | # upcast attention to fp32
157 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
158 |
159 | # ###################################################
160 | self.attn_probs = attn_weights
161 | # ###################################################
162 |
163 | attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
164 | attn_output = torch.matmul(attn_weights, value_states)
165 |
166 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
167 | raise ValueError(
168 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
169 | f" {attn_output.size()}"
170 | )
171 |
172 | attn_output = attn_output.transpose(1, 2).contiguous()
173 |
174 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
175 |
176 | if self.config.pretraining_tp > 1:
177 | attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
178 | o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
179 | attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
180 | else:
181 | attn_output = self.o_proj(attn_output)
182 |
183 | if not output_attentions:
184 | attn_weights = None
185 |
186 | return attn_output, attn_weights, past_key_value
187 |
188 | def enable_llama_custom_attention(layer, layer_id):
189 | """
190 | replace the forward function of LlamaAttention with a custom forward function `llama_custom_attention_forward`
191 | """
192 | modified_module = layer.self_attn
193 | modified_module.layer_id = layer_id
194 | modified_module.forward = types.MethodType(llama_custom_attention_forward, modified_module)
195 |
196 | return modified_module
--------------------------------------------------------------------------------
/gpt-2/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import time
4 | from contextlib import nullcontext
5 |
6 | import numpy as np
7 | import torch
8 | from torch.distributed import destroy_process_group, init_process_group
9 | from torch.nn.parallel import DistributedDataParallel as DDP
10 |
11 | # -----------------------------------------------------------------------------
12 | # default config values designed to train a gpt2 (124M) on OpenWebText
13 | # I/O
14 | out_dir = 'logs_ckpt/'
15 | save_dir = "results/"
16 | data_dir = "data/"
17 | model_type="gpt2_default"
18 | num_reg = 0
19 |
20 | eval_interval = 2000
21 | log_interval = 1
22 | eval_iters = 200
23 | eval_only = False # if True, script exits right after the first eval
24 | always_save_checkpoint = True # if True, always save a checkpoint after each eval
25 | init_from = 'resume' # 'scratch' or 'resume' or 'gpt2*'
26 | ckpt_iter = 50000
27 | # wandb logging
28 | wandb_log = False # disabled by default
29 | wandb_project = 'owt'
30 | wandb_run_name = 'gpt2' # 'run' + str(time.time())
31 | # data
32 | dataset = 'openwebtext'
33 | gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes
34 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
35 | block_size = 1024
36 | # model
37 | n_layer = 12
38 | n_head = 12
39 | n_embd = 768
40 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
41 | bias = False # do we use bias inside LayerNorm and Linear layers?
42 | # adamw optimizer
43 | optim_name="adam"
44 | learning_rate = 6e-4 # max learning rate
45 | max_iters = 600000 # total number of training iterations
46 | weight_decay = 1e-1
47 | beta1 = 0.9
48 | beta2 = 0.95
49 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
50 | # learning rate decay settings
51 | decay_lr = True # whether to decay the learning rate
52 | warmup_iters = 2000 # how many steps to warm up for
53 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
54 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
55 | # DDP settings
56 | backend = 'nccl' # 'nccl', 'gloo', etc.
57 | # system
58 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
59 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
60 | compile = True # use PyTorch 2.0 to compile the model to be faster
61 | # -----------------------------------------------------------------------------
62 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
63 | exec(open('configurator.py').read()) # overrides from command line or config file
64 | config = {k: globals()[k] for k in config_keys} # will be useful for logging
65 | # -----------------------------------------------------------------------------
66 |
67 | if model_type == "gpt2_default":
68 | from model_default import GPTConfig, GPT
69 | elif model_type == "gpt2_sink":
70 | from model_sink import GPTConfig, GPT
71 | elif model_type == "gpt2_attn_bias":
72 | from model_attn_bias import GPTConfig, GPT
73 | else:
74 | raise ValueError(f"model_type {model_type} not supported")
75 |
76 | # various inits, derived attributes, I/O setup
77 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
78 | if ddp:
79 | init_process_group(backend=backend)
80 | ddp_rank = int(os.environ['RANK'])
81 | ddp_local_rank = int(os.environ['LOCAL_RANK'])
82 | ddp_world_size = int(os.environ['WORLD_SIZE'])
83 | device = f'cuda:{ddp_local_rank}'
84 | torch.cuda.set_device(device)
85 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
86 | seed_offset = ddp_rank # each process gets a different seed
87 | # world_size number of processes will be training simultaneously, so we can scale
88 | # down the desired gradient accumulation iterations per process proportionally
89 | assert gradient_accumulation_steps % ddp_world_size == 0
90 | gradient_accumulation_steps //= ddp_world_size
91 | else:
92 | # if not ddp, we are running on a single gpu, and one process
93 | master_process = True
94 | seed_offset = 0
95 | ddp_world_size = 1
96 | tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
97 | print(f"tokens per iteration will be: {tokens_per_iter:,}")
98 |
99 | if master_process:
100 | os.makedirs(out_dir, exist_ok=True)
101 | torch.manual_seed(1337 + seed_offset)
102 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
103 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
104 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
105 | # note: float16 data type will automatically use a GradScaler
106 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
107 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
108 |
109 | # poor man's data loader
110 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
111 | val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
112 | def get_batch(split):
113 | data = train_data if split == 'train' else val_data
114 | ix = torch.randint(len(data) - block_size, (batch_size,))
115 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
116 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
117 | if device_type == 'cuda':
118 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
119 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
120 | else:
121 | x, y = x.to(device), y.to(device)
122 | return x, y
123 |
124 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
125 | iter_num = 0
126 | best_val_loss = 1e9
127 |
128 | # attempt to derive vocab_size from the dataset
129 | meta_path = os.path.join(data_dir, 'meta.pkl')
130 | meta_vocab_size = None
131 | if os.path.exists(meta_path):
132 | with open(meta_path, 'rb') as f:
133 | meta = pickle.load(f)
134 | meta_vocab_size = meta['vocab_size']
135 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
136 |
137 | # model init
138 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
139 | bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line
140 | if init_from == 'scratch':
141 | # init a new model from scratch
142 | print("Initializing a new model from scratch")
143 | # determine the vocab size we'll use for from-scratch training
144 | if meta_vocab_size is None:
145 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
146 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
147 | gptconf = GPTConfig(**model_args)
148 | model = GPT(gptconf)
149 | elif init_from == 'resume':
150 | print(f"Resuming training from {out_dir}")
151 | # resume training from a checkpoint.
152 | ckpt_path = os.path.join(out_dir, f'ckpt_{ckpt_iter}.pt')
153 | checkpoint = torch.load(ckpt_path, map_location=device)
154 | checkpoint_model_args = checkpoint['model_args']
155 | # force these config attributes to be equal otherwise we can't even resume training
156 | # the rest of the attributes (e.g. dropout) can stay as desired from command line
157 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
158 | model_args[k] = checkpoint_model_args[k]
159 | # create the model
160 | gptconf = GPTConfig(**model_args)
161 | model = GPT(gptconf)
162 | state_dict = checkpoint['model']
163 | # fix the keys of the state dictionary :(
164 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more
165 | unwanted_prefix = '_orig_mod.'
166 | for k,v in list(state_dict.items()):
167 | if k.startswith(unwanted_prefix):
168 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
169 | model.load_state_dict(state_dict)
170 | iter_num = checkpoint['iter_num']
171 | best_val_loss = checkpoint['best_val_loss']
172 | elif init_from.startswith('gpt2'):
173 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
174 | # initialize from OpenAI GPT-2 weights
175 | override_args = dict(dropout=dropout)
176 | model = GPT.from_pretrained(init_from, override_args)
177 | # read off the created config params, so we can store them into checkpoint correctly
178 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
179 | model_args[k] = getattr(model.config, k)
180 | # crop down the model block size if desired, using model surgery
181 | if block_size < model.config.block_size:
182 | model.crop_block_size(block_size)
183 | model_args['block_size'] = block_size # so that the checkpoint will have the right value
184 | model.to(device)
185 |
186 |
187 | # initialize a GradScaler. If enabled=False scaler is a no-op
188 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
189 |
190 | # optimizer
191 | optimizer = model.configure_optimizers(optim_name, weight_decay, learning_rate, (beta1, beta2), device_type)
192 | if init_from == 'resume':
193 | optimizer.load_state_dict(checkpoint['optimizer'])
194 | checkpoint = None # free up memory
195 |
196 | # compile the model
197 | if compile:
198 | print("compiling the model... (takes a ~minute)")
199 | unoptimized_model = model
200 | model = torch.compile(model) # requires PyTorch 2.0
201 |
202 | # wrap model into DDP container
203 | if ddp:
204 | model = DDP(model, device_ids=[ddp_local_rank])
205 |
206 | # helps estimate an arbitrarily accurate loss over either split using many batches
207 | @torch.no_grad()
208 | def estimate_loss():
209 | out = {}
210 | model.eval()
211 | for split in ['train', 'val']:
212 | losses = torch.zeros(eval_iters)
213 | for k in range(eval_iters):
214 | X, Y = get_batch(split)
215 | with ctx:
216 | logits, loss = model(X, Y)
217 | losses[k] = loss.item()
218 | out[split] = losses.mean()
219 | model.train()
220 | return out
221 |
222 | # training loop
223 | t0 = time.time()
224 | local_iter_num = 0 # number of iterations in the lifetime of this process
225 | raw_model = model.module if ddp else model # unwrap DDP container if needed
226 | running_mfu = -1.0
227 |
228 | losses = estimate_loss()
229 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
230 |
--------------------------------------------------------------------------------
/gpt-2/analyze.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from contextlib import nullcontext
4 | import types
5 |
6 | import numpy as np
7 | import torch
8 | from torch.nn.parallel import DistributedDataParallel as DDP
9 | from torch.distributed import init_process_group, destroy_process_group
10 |
11 | import tiktoken
12 | from plot_gpt2 import *
13 |
14 | # -----------------------------------------------------------------------------
15 | # default config values designed to train a gpt2 (124M) on OpenWebText
16 | # I/O
17 | out_dir = 'logs_ckpt/'
18 | save_dir = "results/"
19 | data_dir = "data/"
20 | model_type="gpt2_default"
21 | num_reg = 0
22 |
23 | eval_interval = 2000
24 | log_interval = 1
25 | eval_iters = 1
26 | eval_only = False # if True, script exits right after the first eval
27 | always_save_checkpoint = True # if True, always save a checkpoint after each eval
28 | init_from = 'resume' # 'scratch' or 'resume' or 'gpt2*'
29 | ckpt_iter = 50000
30 | # wandb logging
31 | wandb_log = False # disabled by default
32 | wandb_project = 'owt'
33 | wandb_run_name = 'gpt2' # 'run' + str(time.time())
34 | # data
35 | dataset = 'openwebtext'
36 | gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes
37 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
38 | block_size = 1024
39 | # model
40 | n_layer = 12
41 | n_head = 12
42 | n_embd = 768
43 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
44 | bias = False # do we use bias inside LayerNorm and Linear layers?
45 | # adamw optimizer
46 | optim_name="adam"
47 | learning_rate = 6e-4 # max learning rate
48 | max_iters = 600000 # total number of training iterations
49 | weight_decay = 1e-1
50 | beta1 = 0.9
51 | beta2 = 0.95
52 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
53 | # learning rate decay settings
54 | decay_lr = True # whether to decay the learning rate
55 | warmup_iters = 2000 # how many steps to warm up for
56 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
57 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
58 | # DDP settings
59 | backend = 'nccl' # 'nccl', 'gloo', etc.
60 | # system
61 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
62 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
63 | compile = True # use PyTorch 2.0 to compile the model to be faster
64 | # -----------------------------------------------------------------------------
65 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
66 | exec(open('configurator.py').read()) # overrides from command line or config file
67 | config = {k: globals()[k] for k in config_keys} # will be useful for logging
68 | # -----------------------------------------------------------------------------
69 |
70 | if model_type == "gpt2_default":
71 | from model_default import GPTConfig, GPT
72 | elif model_type == "gpt2_sink":
73 | from model_sink import GPTConfig, GPT
74 | elif model_type == "gpt2_attn_bias":
75 | from model_attn_bias import GPTConfig, GPT
76 | else:
77 | raise ValueError(f"model_type {model_type} not supported")
78 |
79 |
80 | if not os.path.exists(save_dir):
81 | os.makedirs(save_dir)
82 |
83 | # various inits, derived attributes, I/O setup
84 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
85 | if ddp:
86 | init_process_group(backend=backend)
87 | ddp_rank = int(os.environ['RANK'])
88 | ddp_local_rank = int(os.environ['LOCAL_RANK'])
89 | ddp_world_size = int(os.environ['WORLD_SIZE'])
90 | device = f'cuda:{ddp_local_rank}'
91 | torch.cuda.set_device(device)
92 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
93 | seed_offset = ddp_rank # each process gets a different seed
94 | # world_size number of processes will be training simultaneously, so we can scale
95 | # down the desired gradient accumulation iterations per process proportionally
96 | assert gradient_accumulation_steps % ddp_world_size == 0
97 | gradient_accumulation_steps //= ddp_world_size
98 | else:
99 | # if not ddp, we are running on a single gpu, and one process
100 | master_process = True
101 | seed_offset = 0
102 | ddp_world_size = 1
103 | tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
104 | print(f"tokens per iteration will be: {tokens_per_iter:,}")
105 |
106 | if master_process:
107 | os.makedirs(out_dir, exist_ok=True)
108 | torch.manual_seed(1337 + seed_offset)
109 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
110 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
111 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
112 | # note: float16 data type will automatically use a GradScaler
113 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
114 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
115 |
116 | # poor man's data loader
117 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
118 | val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
119 | def get_batch(split):
120 | data = train_data if split == 'train' else val_data
121 | ix = torch.randint(len(data) - block_size, (batch_size,))
122 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
123 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
124 | if device_type == 'cuda':
125 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
126 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
127 | else:
128 | x, y = x.to(device), y.to(device)
129 | return x, y
130 |
131 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
132 | iter_num = 0
133 | best_val_loss = 1e9
134 |
135 | # attempt to derive vocab_size from the dataset
136 | meta_path = os.path.join(data_dir, 'meta.pkl')
137 | meta_vocab_size = None
138 | if os.path.exists(meta_path):
139 | with open(meta_path, 'rb') as f:
140 | meta = pickle.load(f)
141 | meta_vocab_size = meta['vocab_size']
142 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
143 |
144 | # model init
145 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
146 | bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line
147 | if init_from == 'scratch':
148 | # init a new model from scratch
149 | print("Initializing a new model from scratch")
150 | # determine the vocab size we'll use for from-scratch training
151 | if meta_vocab_size is None:
152 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
153 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
154 | gptconf = GPTConfig(**model_args)
155 | model = GPT(gptconf)
156 | elif init_from == 'resume':
157 | print(f"Resuming training from {out_dir}")
158 | # resume training from a checkpoint.
159 | ckpt_path = os.path.join(out_dir, f'ckpt_{ckpt_iter}.pt')
160 | checkpoint = torch.load(ckpt_path, map_location=device)
161 | checkpoint_model_args = checkpoint['model_args']
162 | # force these config attributes to be equal otherwise we can't even resume training
163 | # the rest of the attributes (e.g. dropout) can stay as desired from command line
164 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
165 | model_args[k] = checkpoint_model_args[k]
166 | # create the model
167 | gptconf = GPTConfig(**model_args)
168 | model = GPT(gptconf)
169 | state_dict = checkpoint['model']
170 | # fix the keys of the state dictionary :(
171 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more
172 | unwanted_prefix = '_orig_mod.'
173 | for k,v in list(state_dict.items()):
174 | if k.startswith(unwanted_prefix):
175 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
176 | model.load_state_dict(state_dict)
177 | iter_num = checkpoint['iter_num']
178 | best_val_loss = checkpoint['best_val_loss']
179 | elif init_from.startswith('gpt2'):
180 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
181 | # initialize from OpenAI GPT-2 weights
182 | override_args = dict(dropout=dropout)
183 | model = GPT.from_pretrained(init_from, override_args)
184 | # read off the created config params, so we can store them into checkpoint correctly
185 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
186 | model_args[k] = getattr(model.config, k)
187 | # crop down the model block size if desired, using model surgery
188 | if block_size < model.config.block_size:
189 | model.crop_block_size(block_size)
190 | model_args['block_size'] = block_size # so that the checkpoint will have the right value
191 | model.to(device)
192 |
193 |
194 | # initialize a GradScaler. If enabled=False scaler is a no-op
195 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
196 |
197 | # optimizer
198 | optimizer = model.configure_optimizers(optim_name, weight_decay, learning_rate, (beta1, beta2), device_type)
199 | if init_from == 'resume':
200 | optimizer.load_state_dict(checkpoint['optimizer'])
201 | checkpoint = None # free up memory
202 |
203 | # compile the model
204 | if compile:
205 | print("compiling the model... (takes a ~minute)")
206 | unoptimized_model = model
207 | model = torch.compile(model) # requires PyTorch 2.0
208 |
209 | # wrap model into DDP container
210 | if ddp:
211 | model = DDP(model, device_ids=[ddp_local_rank])
212 |
213 | def custom_block_forward(self, x):
214 | x = x + self.attn(self.ln_1(x))
215 | x = x + self.mlp(self.ln_2(x))
216 | self.feat = x.cpu().clone().detach()
217 | return x
218 |
219 |
220 |
221 | layers = model.transformer.h
222 | for layer_id in range(len(model.transformer.h)):
223 | layer = layers[layer_id]
224 | layer.forward = types.MethodType(custom_block_forward, layer)
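# The loop above swaps every block's forward() for custom_block_forward, so each block
# caches its output hidden states in `layer.feat` (on CPU) during the forward pass;
# the statistics below are computed from those cached activations.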
225 |
226 |
227 | stats = []
228 | eval_iters = 1
229 | batch_size = 1  # analyze a single sequence per forward pass
230 | for seqid in range(3):
231 | print("current seqid: ", seqid)
232 | X, Y = get_batch("val")
233 | with torch.no_grad():
234 | model(X, Y)
235 |
236 |     seq_np = np.zeros((202, len(layers)))  # per layer: top-200 |activation| values, plus mean and median in the last two rows
237 | for layer_id in range(len(layers)):
238 | feat_abs = layers[layer_id].feat.abs()
239 | sort_res = torch.sort(feat_abs.flatten(), descending=True)
240 | seq_np[:200, layer_id] = sort_res.values[:200]
241 | seq_np[-2, layer_id] = torch.mean(feat_abs)
242 | seq_np[-1, layer_id] = torch.median(feat_abs)
243 |
244 | stats.append(seq_np)
245 |
246 | plot_layer_ax_gpt2(np.mean(np.array(stats), axis=0), model_type, save_dir)
247 |
248 | ##########################################################################
249 | enc = tiktoken.get_encoding("gpt2")
250 | encoding = enc.encode("Summer is warm. Winter is cold.")
251 |
252 | layer_id = 5
253 | layer = model.transformer.h[layer_id]
254 | layer.forward = types.MethodType(custom_block_forward, layer)
255 |
256 |
257 | input_seq = torch.tensor([encoding], dtype=torch.int64).to(device)
258 |
259 | with torch.no_grad():
260 | model(input_seq, None)
261 |
262 | stats = {}
263 | feat_abs = layer.feat.abs()
264 | stats[f"seq0_{layer_id}"] = feat_abs
265 |
266 | inp_seq = ["Summer", "is", "warm", ".", "Winter", "is", "cold", "."]  # readable token labels passed to the plotting helper
267 | plot_3d_ax_gpt2(feat_abs, model_type, inp_seq, save_dir)
--------------------------------------------------------------------------------
/gpt-2/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import math
4 | import pickle
5 | from contextlib import nullcontext
6 |
7 | import numpy as np
8 | import torch
9 | from torch.nn.parallel import DistributedDataParallel as DDP
10 | from torch.distributed import init_process_group, destroy_process_group
11 |
12 | # -----------------------------------------------------------------------------
13 | # default config values designed to train a gpt2 (124M) on OpenWebText
14 | # I/O
15 | out_dir = 'logs_ckpt/'
16 | save_dir = "results/"
17 | data_dir = "data/"
18 | model_type = "gpt2_default"
19 | num_reg = 0
20 |
21 | eval_interval = 10000
22 | log_interval = 1
23 | eval_iters = 400
24 | eval_only = False # if True, script exits right after the first eval
25 | always_save_checkpoint = True # if True, always save a checkpoint after each eval
26 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
27 | # wandb logging
28 | wandb_log = False # disabled by default
29 | wandb_project = 'owt'
30 | wandb_run_name = 'gpt2' # 'run' + str(time.time())
31 | # data
32 | dataset = 'openwebtext'
33 | gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes
34 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
35 | block_size = 1024
36 | # model
37 | n_layer = 12
38 | n_head = 12
39 | n_embd = 768
40 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
41 | bias = False # do we use bias inside LayerNorm and Linear layers?
42 | # adamw optimizer
43 | optim_name="adam"
44 | learning_rate = 6e-4 # max learning rate
45 | max_iters = 600000 # total number of training iterations
46 | weight_decay = 1e-1
47 | beta1 = 0.9
48 | beta2 = 0.95
49 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
50 | # learning rate decay settings
51 | decay_lr = True # whether to decay the learning rate
52 | warmup_iters = 2000 # how many steps to warm up for
53 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
54 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
55 | # DDP settings
56 | backend = 'nccl' # 'nccl', 'gloo', etc.
57 | # system
58 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
59 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
60 | compile = True # use PyTorch 2.0 to compile the model to be faster
61 | # -----------------------------------------------------------------------------
62 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
63 | exec(open('configurator.py').read()) # overrides from command line or config file
64 | config = {k: globals()[k] for k in config_keys} # will be useful for logging
65 | # -----------------------------------------------------------------------------
66 |
67 | if model_type == "gpt2_default":
68 | from model_default import GPTConfig, GPT
69 | elif model_type == "gpt2_sink":
70 | from model_sink import GPTConfig, GPT
71 | elif model_type == "gpt2_attn_bias":
72 | from model_attn_bias import GPTConfig, GPT
73 | else:
74 | raise ValueError(f"model_type {model_type} not supported")
75 |
76 | # various inits, derived attributes, I/O setup
77 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
78 | if ddp:
79 | init_process_group(backend=backend)
80 | ddp_rank = int(os.environ['RANK'])
81 | ddp_local_rank = int(os.environ['LOCAL_RANK'])
82 | ddp_world_size = int(os.environ['WORLD_SIZE'])
83 | device = f'cuda:{ddp_local_rank}'
84 | torch.cuda.set_device(device)
85 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
86 | seed_offset = ddp_rank # each process gets a different seed
87 | # world_size number of processes will be training simultaneously, so we can scale
88 | # down the desired gradient accumulation iterations per process proportionally
89 | assert gradient_accumulation_steps % ddp_world_size == 0
90 | gradient_accumulation_steps //= ddp_world_size
91 | else:
92 | # if not ddp, we are running on a single gpu, and one process
93 | master_process = True
94 | seed_offset = 0
95 | ddp_world_size = 1
96 | tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
97 | print(f"tokens per iteration will be: {tokens_per_iter:,}")
98 |
99 | if master_process:
100 | os.makedirs(out_dir, exist_ok=True)
101 | torch.manual_seed(1337 + seed_offset)
102 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
103 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
104 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
105 | # note: float16 data type will automatically use a GradScaler
106 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
107 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
108 |
109 | # poor man's data loader
110 | # data_dir = os.path.join('data', dataset)
111 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
112 | val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
113 | def get_batch(split):
114 | data = train_data if split == 'train' else val_data
115 | ix = torch.randint(len(data) - block_size, (batch_size,))
116 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
117 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
118 | if device_type == 'cuda':
119 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
120 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
121 | else:
122 | x, y = x.to(device), y.to(device)
123 | return x, y
124 |
125 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
126 | iter_num = 0
127 | best_val_loss = 1e9
128 |
129 | # attempt to derive vocab_size from the dataset
130 | meta_path = os.path.join(data_dir, 'meta.pkl')
131 | meta_vocab_size = None
132 | if os.path.exists(meta_path):
133 | with open(meta_path, 'rb') as f:
134 | meta = pickle.load(f)
135 | meta_vocab_size = meta['vocab_size']
136 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
137 |
138 | # model init
139 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
140 | bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line
141 | if init_from == 'scratch':
142 | # init a new model from scratch
143 | print("Initializing a new model from scratch")
144 | # determine the vocab size we'll use for from-scratch training
145 | if meta_vocab_size is None:
146 |         print("defaulting to GPT-2 vocab_size of 50304 (50257 rounded up for efficiency)")
147 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
148 | gptconf = GPTConfig(**model_args)
149 | model = GPT(gptconf)
150 | elif init_from == 'resume':
151 | # resume training from a checkpoint.
152 | # ckpt_path = os.path.join(out_dir, 'ckpt_275000.pt')
153 | ckpt_path = os.path.join(out_dir, 'ckpt_15000.pt')
154 | print(f"******************************************************")
155 | print(f"Resuming training from {ckpt_path}")
156 | print(f"******************************************************")
157 | checkpoint = torch.load(ckpt_path, map_location=device)
158 | checkpoint_model_args = checkpoint['model_args']
159 | # force these config attributes to be equal otherwise we can't even resume training
160 | # the rest of the attributes (e.g. dropout) can stay as desired from command line
161 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
162 | model_args[k] = checkpoint_model_args[k]
163 | # create the model
164 | gptconf = GPTConfig(**model_args)
165 | model = GPT(gptconf)
166 | state_dict = checkpoint['model']
167 | # fix the keys of the state dictionary :(
168 |     # checkpoints saved from a torch.compile()'d model carry this '_orig_mod.' prefix on every key
169 | unwanted_prefix = '_orig_mod.'
170 | for k,v in list(state_dict.items()):
171 | if k.startswith(unwanted_prefix):
172 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
173 | model.load_state_dict(state_dict)
174 | iter_num = checkpoint['iter_num']
175 | best_val_loss = checkpoint['best_val_loss']
176 | elif init_from.startswith('gpt2'):
177 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
178 | # initialize from OpenAI GPT-2 weights
179 | override_args = dict(dropout=dropout)
180 | model = GPT.from_pretrained(init_from, override_args)
181 | # read off the created config params, so we can store them into checkpoint correctly
182 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
183 | model_args[k] = getattr(model.config, k)
184 | # crop down the model block size if desired, using model surgery
185 | if block_size < model.config.block_size:
186 | model.crop_block_size(block_size)
187 | model_args['block_size'] = block_size # so that the checkpoint will have the right value
188 | model.to(device)
189 |
190 | # initialize a GradScaler. If enabled=False scaler is a no-op
191 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
192 |
193 | # optimizer
194 | optimizer = model.configure_optimizers(optim_name, weight_decay, learning_rate, (beta1, beta2), device_type)
195 | if init_from == 'resume':
196 | optimizer.load_state_dict(checkpoint['optimizer'])
197 | checkpoint = None # free up memory
198 |
199 | # compile the model
200 | if compile:
201 | print("compiling the model... (takes a ~minute)")
202 | unoptimized_model = model
203 | model = torch.compile(model) # requires PyTorch 2.0
204 |
205 | # wrap model into DDP container
206 | if ddp:
207 | model = DDP(model, device_ids=[ddp_local_rank])
208 |
209 | # helps estimate an arbitrarily accurate loss over either split using many batches
210 | @torch.no_grad()
211 | def estimate_loss():
212 | out = {}
213 | model.eval()
214 | for split in ['train', 'val']:
215 | losses = torch.zeros(eval_iters)
216 | for k in range(eval_iters):
217 | X, Y = get_batch(split)
218 | with ctx:
219 | logits, loss = model(X, Y)
220 | losses[k] = loss.item()
221 | out[split] = losses.mean()
222 | model.train()
223 | return out
224 |
225 | # learning rate decay scheduler (cosine with warmup)
226 | def get_lr(it):
227 | # 1) linear warmup for warmup_iters steps
228 | if it < warmup_iters:
229 | return learning_rate * it / warmup_iters
230 | # 2) if it > lr_decay_iters, return min learning rate
231 | if it > lr_decay_iters:
232 | return min_lr
233 | # 3) in between, use cosine decay down to min learning rate
234 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
235 | assert 0 <= decay_ratio <= 1
236 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
237 | return min_lr + coeff * (learning_rate - min_lr)
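# With the defaults (learning_rate=6e-4, min_lr=6e-5, warmup_iters=2000, lr_decay_iters=600000):
# lr ramps linearly to 6e-4 over the first 2000 iters, passes 6e-5 + 0.5*(6e-4 - 6e-5) = 3.3e-4
# at the midpoint of the cosine, and is held at 6e-5 after iter 600000.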
238 |
239 | # logging
240 | if wandb_log and master_process:
241 | import wandb
242 | wandb.init(project=wandb_project, name=wandb_run_name, config=config)
243 |
244 | # training loop
245 | X, Y = get_batch('train') # fetch the very first batch
246 | t0 = time.time()
247 | local_iter_num = 0 # number of iterations in the lifetime of this process
248 | raw_model = model.module if ddp else model # unwrap DDP container if needed
249 | running_mfu = -1.0
250 | while True:
251 |
252 | # determine and set the learning rate for this iteration
253 | lr = get_lr(iter_num) if decay_lr else learning_rate
254 | for param_group in optimizer.param_groups:
255 | param_group['lr'] = lr
256 |
257 | # evaluate the loss on train/val sets and write checkpoints
258 | if iter_num % eval_interval == 0 and master_process:
259 | losses = estimate_loss()
260 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
261 | if wandb_log:
262 | wandb.log({
263 | "iter": iter_num,
264 | "train/loss": losses['train'],
265 | "val/loss": losses['val'],
266 | "lr": lr,
267 | "mfu": running_mfu*100, # convert to percentage
268 | })
269 | if losses['val'] < best_val_loss or always_save_checkpoint:
270 | best_val_loss = losses['val']
271 | if iter_num > 0:
272 | checkpoint = {
273 | 'model': raw_model.state_dict(),
274 | 'optimizer': optimizer.state_dict(),
275 | 'model_args': model_args,
276 | 'iter_num': iter_num,
277 | 'best_val_loss': best_val_loss,
278 | 'config': config,
279 | }
280 | print(f"saving checkpoint to {out_dir}")
281 | torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{iter_num}.pt'))
282 | if iter_num == 0 and eval_only:
283 | break
284 |
285 | # forward backward update, with optional gradient accumulation to simulate larger batch size
286 | # and using the GradScaler if data type is float16
287 | for micro_step in range(gradient_accumulation_steps):
288 | if ddp:
289 | # in DDP training we only need to sync gradients at the last micro step.
290 | # the official way to do this is with model.no_sync() context manager, but
291 | # I really dislike that this bloats the code and forces us to repeat code
292 | # looking at the source of that context manager, it just toggles this variable
293 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
294 | with ctx:
295 | logits, loss = model(X, Y)
296 | loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
297 | # immediately async prefetch next batch while model is doing the forward pass on the GPU
298 | X, Y = get_batch('train')
299 | # backward pass, with gradient scaling if training in fp16
300 | scaler.scale(loss).backward()
301 | # clip the gradient
302 | if grad_clip != 0.0:
303 | scaler.unscale_(optimizer)
304 | torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
305 | # step the optimizer and scaler if training in fp16
306 | scaler.step(optimizer)
307 | scaler.update()
308 | # flush the gradients as soon as we can, no need for this memory anymore
309 | optimizer.zero_grad(set_to_none=True)
310 |
311 | # timing and logging
312 | t1 = time.time()
313 | dt = t1 - t0
314 | t0 = t1
315 | if iter_num % log_interval == 0 and master_process:
316 | # get loss as float. note: this is a CPU-GPU sync point
317 | # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
318 | lossf = loss.item() * gradient_accumulation_steps
319 | if local_iter_num >= 5: # let the training loop settle a bit
320 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
321 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
322 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
323 | iter_num += 1
324 | local_iter_num += 1
325 |
326 | # termination conditions
327 | if iter_num > max_iters:
328 | break
329 |
330 | if ddp:
331 | destroy_process_group()
332 |
--------------------------------------------------------------------------------
/gpt-2/model_default.py:
--------------------------------------------------------------------------------
1 | """
2 | Full definition of a GPT Language Model, all of it in this single file.
3 | References:
4 | 1) the official GPT-2 TensorFlow implementation released by OpenAI:
5 | https://github.com/openai/gpt-2/blob/master/src/model.py
6 | 2) huggingface/transformers PyTorch implementation:
7 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
8 | """
9 |
10 | import math
11 | import inspect
12 | from dataclasses import dataclass
13 |
14 | import torch
15 | import torch.nn as nn
16 | from torch.nn import functional as F
17 |
18 | class LayerNorm(nn.Module):
19 | """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
20 |
21 | def __init__(self, ndim, bias):
22 | super().__init__()
23 | self.weight = nn.Parameter(torch.ones(ndim))
24 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
25 |
26 | def forward(self, input):
27 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
28 |
29 | class CausalSelfAttention(nn.Module):
30 |
31 | def __init__(self, config):
32 | super().__init__()
33 | assert config.n_embd % config.n_head == 0
34 | # key, query, value projections for all heads, but in a batch
35 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
36 | # output projection
37 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
38 | # regularization
39 | self.attn_dropout = nn.Dropout(config.dropout)
40 | self.resid_dropout = nn.Dropout(config.dropout)
41 | self.n_head = config.n_head
42 | self.n_embd = config.n_embd
43 | self.dropout = config.dropout
44 | # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
45 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
46 | if not self.flash:
47 | print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
48 | # causal mask to ensure that attention is only applied to the left in the input sequence
49 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
50 | .view(1, 1, config.block_size, config.block_size))
51 |
52 | def forward(self, x):
53 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
54 |
55 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
56 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
57 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
58 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
59 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
60 |
61 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
62 | if self.flash:
63 | # efficient attention using Flash Attention CUDA kernels
64 | y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
65 |             # (the attn-bias variant in model_attn_bias.py passes an explicit attn_mask here instead of is_causal=True)
66 | else:
67 | # manual implementation of attention
68 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
69 | att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
70 | att = F.softmax(att, dim=-1)
71 | att = self.attn_dropout(att)
72 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
73 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
74 |
75 | # output projection
76 | y = self.resid_dropout(self.c_proj(y))
77 | return y
78 |
79 | class MLP(nn.Module):
80 |
81 | def __init__(self, config):
82 | super().__init__()
83 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
84 | self.gelu = nn.GELU()
85 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
86 | self.dropout = nn.Dropout(config.dropout)
87 |
88 | def forward(self, x):
89 | x = self.c_fc(x)
90 | x = self.gelu(x)
91 | x = self.c_proj(x)
92 | x = self.dropout(x)
93 | return x
94 |
95 | class Block(nn.Module):
96 |
97 | def __init__(self, config):
98 | super().__init__()
99 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
100 | self.attn = CausalSelfAttention(config)
101 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
102 | self.mlp = MLP(config)
103 |
104 | def forward(self, x):
105 | x = x + self.attn(self.ln_1(x))
106 | x = x + self.mlp(self.ln_2(x))
107 | return x
108 |
109 | @dataclass
110 | class GPTConfig:
111 | block_size: int = 1024
112 | vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
113 | n_layer: int = 12
114 | n_head: int = 12
115 | n_embd: int = 768
116 | dropout: float = 0.0
117 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
118 |
119 | class GPT(nn.Module):
120 |
121 | def __init__(self, config):
122 | super().__init__()
123 | assert config.vocab_size is not None
124 | assert config.block_size is not None
125 | self.config = config
126 |
127 | self.transformer = nn.ModuleDict(dict(
128 | wte = nn.Embedding(config.vocab_size, config.n_embd),
129 | wpe = nn.Embedding(config.block_size, config.n_embd),
130 | drop = nn.Dropout(config.dropout),
131 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
132 | ln_f = LayerNorm(config.n_embd, bias=config.bias),
133 | ))
134 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
135 | # with weight tying when using torch.compile() some warnings get generated:
136 | # "UserWarning: functional_call was passed multiple values for tied weights.
137 | # This behavior is deprecated and will be an error in future versions"
138 | # not 100% sure what this is, so far seems to be harmless. TODO investigate
139 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
140 |
141 | # init all weights
142 | self.apply(self._init_weights)
143 | # apply special scaled init to the residual projections, per GPT-2 paper
144 | for pn, p in self.named_parameters():
145 | if pn.endswith('c_proj.weight'):
146 | torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
147 |
148 | # report number of parameters
149 | print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
150 |
151 | def get_num_params(self, non_embedding=True):
152 | """
153 | Return the number of parameters in the model.
154 | For non-embedding count (default), the position embeddings get subtracted.
155 | The token embeddings would too, except due to the parameter sharing these
156 | params are actually used as weights in the final layer, so we include them.
157 | """
158 | n_params = sum(p.numel() for p in self.parameters())
159 | if non_embedding:
160 | n_params -= self.transformer.wpe.weight.numel()
161 | return n_params
162 |
163 | def _init_weights(self, module):
164 | if isinstance(module, nn.Linear):
165 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
166 | if module.bias is not None:
167 | torch.nn.init.zeros_(module.bias)
168 | elif isinstance(module, nn.Embedding):
169 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
170 |
171 | def forward(self, idx, targets=None):
172 | device = idx.device
173 | b, t = idx.size()
174 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
175 | pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
176 |
177 | # forward the GPT model itself
178 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
179 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
180 | x = self.transformer.drop(tok_emb + pos_emb)
181 | for block in self.transformer.h:
182 | x = block(x)
183 | x = self.transformer.ln_f(x)
184 |
185 | if targets is not None:
186 | # if we are given some desired targets also calculate the loss
187 | logits = self.lm_head(x)
188 |
189 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
190 | else:
191 | # inference-time mini-optimization: only forward the lm_head on the very last position
192 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
193 | loss = None
194 |
195 | return logits, loss
196 |
197 | def crop_block_size(self, block_size):
198 | # model surgery to decrease the block size if necessary
199 | # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
200 | # but want to use a smaller block size for some smaller, simpler model
201 | assert block_size <= self.config.block_size
202 | self.config.block_size = block_size
203 | self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
204 | for block in self.transformer.h:
205 | if hasattr(block.attn, 'bias'):
206 | block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
207 |
208 | @classmethod
209 | def from_pretrained(cls, model_type, override_args=None):
210 | assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
211 | override_args = override_args or {} # default to empty dict
212 | # only dropout can be overridden see more notes below
213 | assert all(k == 'dropout' for k in override_args)
214 | from transformers import GPT2LMHeadModel
215 | print("loading weights from pretrained gpt: %s" % model_type)
216 |
217 | # n_layer, n_head and n_embd are determined from model_type
218 | config_args = {
219 | 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
220 | 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
221 | 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
222 | 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
223 | }[model_type]
224 | print("forcing vocab_size=50257, block_size=1024, bias=True")
225 | config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
226 | config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
227 | config_args['bias'] = True # always True for GPT model checkpoints
228 | # we can override the dropout rate, if desired
229 | if 'dropout' in override_args:
230 | print(f"overriding dropout rate to {override_args['dropout']}")
231 | config_args['dropout'] = override_args['dropout']
232 | # create a from-scratch initialized minGPT model
233 | config = GPTConfig(**config_args)
234 | model = GPT(config)
235 | sd = model.state_dict()
236 | sd_keys = sd.keys()
237 | sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
238 |
239 | # init a huggingface/transformers model
240 | model_hf = GPT2LMHeadModel.from_pretrained(model_type)
241 | sd_hf = model_hf.state_dict()
242 |
243 | # copy while ensuring all of the parameters are aligned and match in names and shapes
244 | sd_keys_hf = sd_hf.keys()
245 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
246 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
247 | transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
248 | # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
249 | # this means that we have to transpose these weights when we import them
250 | assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
251 | for k in sd_keys_hf:
252 | if any(k.endswith(w) for w in transposed):
253 | # special treatment for the Conv1D weights we need to transpose
254 | assert sd_hf[k].shape[::-1] == sd[k].shape
255 | with torch.no_grad():
256 | sd[k].copy_(sd_hf[k].t())
257 | else:
258 | # vanilla copy over the other parameters
259 | assert sd_hf[k].shape == sd[k].shape
260 | with torch.no_grad():
261 | sd[k].copy_(sd_hf[k])
262 |
263 | return model
264 |
265 | def configure_optimizers(self, optim_name, weight_decay, learning_rate, betas, device_type):
266 | # start with all of the candidate parameters
267 | param_dict = {pn: p for pn, p in self.named_parameters()}
268 | # filter out those that do not require grad
269 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
270 |         # create optim groups. Any parameter that is 2D will be weight decayed, otherwise no.
271 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
272 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
273 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
274 | optim_groups = [
275 | {'params': decay_params, 'weight_decay': weight_decay},
276 | {'params': nodecay_params, 'weight_decay': 0.0}
277 | ]
278 | num_decay_params = sum(p.numel() for p in decay_params)
279 | num_nodecay_params = sum(p.numel() for p in nodecay_params)
280 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
281 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
282 | # Create AdamW optimizer and use the fused version if it is available
283 | if optim_name == "adam":
284 | fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
285 | use_fused = fused_available and device_type == 'cuda'
286 | extra_args = dict(fused=True) if use_fused else dict()
287 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
288 | print(f"using fused AdamW: {use_fused}")
289 | elif optim_name == "gdm":
290 | optimizer = torch.optim.SGD(optim_groups, lr=learning_rate, momentum=betas[0], dampening=0.)
291 |
292 | return optimizer
293 |
294 | def estimate_mfu(self, fwdbwd_per_iter, dt):
295 | """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
296 | # first estimate the number of flops we do per iteration.
297 | # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
298 | N = self.get_num_params()
299 | cfg = self.config
300 | L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
301 | flops_per_token = 6*N + 12*L*H*Q*T
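        # rough numbers for the default 124M config (L=12, H=12, Q=64, T=1024, N ~ 124M):
        # 6*N ~ 7.4e8 and 12*L*H*Q*T ~ 1.1e8, so roughly 8.6e8 FLOPs per token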
302 | flops_per_fwdbwd = flops_per_token * T
303 | flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
304 | # express our flops throughput as ratio of A100 bfloat16 peak flops
305 | flops_achieved = flops_per_iter * (1.0/dt) # per second
306 | flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
307 | mfu = flops_achieved / flops_promised
308 | return mfu
309 |
310 | @torch.no_grad()
311 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
312 | """
313 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
314 | the sequence max_new_tokens times, feeding the predictions back into the model each time.
315 | Most likely you'll want to make sure to be in model.eval() mode of operation for this.
316 | """
317 | for _ in range(max_new_tokens):
318 | # if the sequence context is growing too long we must crop it at block_size
319 | idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
320 | # forward the model to get the logits for the index in the sequence
321 | logits, _ = self(idx_cond)
322 | # pluck the logits at the final step and scale by desired temperature
323 | logits = logits[:, -1, :] / temperature
324 | # optionally crop the logits to only the top k options
325 | if top_k is not None:
326 | v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
327 | logits[logits < v[:, [-1]]] = -float('Inf')
328 | # apply softmax to convert logits to (normalized) probabilities
329 | probs = F.softmax(logits, dim=-1)
330 | # sample from the distribution
331 | idx_next = torch.multinomial(probs, num_samples=1)
332 | # append sampled index to the running sequence and continue
333 | idx = torch.cat((idx, idx_next), dim=1)
334 |
335 | return idx
336 |
--------------------------------------------------------------------------------
/gpt-2/model_sink.py:
--------------------------------------------------------------------------------
1 | """
2 | Full definition of a GPT Language Model, all of it in this single file.
3 | References:
4 | 1) the official GPT-2 TensorFlow implementation released by OpenAI:
5 | https://github.com/openai/gpt-2/blob/master/src/model.py
6 | 2) huggingface/transformers PyTorch implementation:
7 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
8 | """
9 |
10 | import math
11 | import inspect
12 | from dataclasses import dataclass
13 |
14 | import torch
15 | import torch.nn as nn
16 | from torch.nn import functional as F
17 |
18 | class LayerNorm(nn.Module):
19 | """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
20 |
21 | def __init__(self, ndim, bias):
22 | super().__init__()
23 | self.weight = nn.Parameter(torch.ones(ndim))
24 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
25 |
26 | def forward(self, input):
27 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
28 |
29 | class CausalSelfAttention(nn.Module):
30 |
31 | def __init__(self, config):
32 | super().__init__()
33 | assert config.n_embd % config.n_head == 0
34 | # key, query, value projections for all heads, but in a batch
35 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
36 | # output projection
37 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
38 | # regularization
39 | self.attn_dropout = nn.Dropout(config.dropout)
40 | self.resid_dropout = nn.Dropout(config.dropout)
41 | self.n_head = config.n_head
42 | self.n_embd = config.n_embd
43 | self.dropout = config.dropout
44 | # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
45 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
46 | if not self.flash:
47 | print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
48 | # causal mask to ensure that attention is only applied to the left in the input sequence
49 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
50 | .view(1, 1, config.block_size, config.block_size))
51 |
52 | def forward(self, x):
53 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
54 |
55 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
56 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
57 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
58 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
59 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
60 |
61 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
62 | if self.flash:
63 | # efficient attention using Flash Attention CUDA kernels
64 | y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
65 | else:
66 | # manual implementation of attention
67 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
68 | att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
69 | att = F.softmax(att, dim=-1)
70 | att = self.attn_dropout(att)
71 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
72 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
73 |
74 | # output projection
75 | y = self.resid_dropout(self.c_proj(y))
76 | return y
77 |
78 | class MLP(nn.Module):
79 |
80 | def __init__(self, config):
81 | super().__init__()
82 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
83 | self.gelu = nn.GELU()
84 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
85 | self.dropout = nn.Dropout(config.dropout)
86 |
87 | def forward(self, x):
88 | x = self.c_fc(x)
89 | x = self.gelu(x)
90 | x = self.c_proj(x)
91 | x = self.dropout(x)
92 | return x
93 |
94 | class Block(nn.Module):
95 |
96 | def __init__(self, config):
97 | super().__init__()
98 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
99 | self.attn = CausalSelfAttention(config)
100 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
101 | self.mlp = MLP(config)
102 |
103 | def forward(self, x):
104 | x = x + self.attn(self.ln_1(x))
105 | x = x + self.mlp(self.ln_2(x))
106 | return x
107 |
108 | @dataclass
109 | class GPTConfig:
110 | block_size: int = 1024
111 | vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
112 | n_layer: int = 12
113 | n_head: int = 12
114 | n_embd: int = 768
115 | dropout: float = 0.0
116 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
117 |
118 | class GPT(nn.Module):
119 |
120 | def __init__(self, config):
121 | super().__init__()
122 | assert config.vocab_size is not None
123 | assert config.block_size is not None
124 | self.config = config
125 |
126 | self.transformer = nn.ModuleDict(dict(
127 | wte = nn.Embedding(config.vocab_size, config.n_embd),
128 | wpe = nn.Embedding(config.block_size, config.n_embd),
129 | drop = nn.Dropout(config.dropout),
130 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
131 | ln_f = LayerNorm(config.n_embd, bias=config.bias),
132 | ))
133 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
134 | # with weight tying when using torch.compile() some warnings get generated:
135 | # "UserWarning: functional_call was passed multiple values for tied weights.
136 | # This behavior is deprecated and will be an error in future versions"
137 | # not 100% sure what this is, so far seems to be harmless. TODO investigate
138 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
139 |
140 | # init all weights
141 | self.apply(self._init_weights)
142 | # apply special scaled init to the residual projections, per GPT-2 paper
143 | for pn, p in self.named_parameters():
144 | if pn.endswith('c_proj.weight'):
145 | torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
146 |
147 | # report number of parameters
148 | print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
149 |
150 |         ### add a learnable register (attention-sink) token
151 |         self.reg_token = nn.Parameter(torch.zeros(1, 1, config.n_embd), requires_grad=True)
152 | torch.nn.init.normal_(self.reg_token, mean=0.0, std=0.02)
153 |
154 | def get_num_params(self, non_embedding=True):
155 | """
156 | Return the number of parameters in the model.
157 | For non-embedding count (default), the position embeddings get subtracted.
158 | The token embeddings would too, except due to the parameter sharing these
159 | params are actually used as weights in the final layer, so we include them.
160 | """
161 | n_params = sum(p.numel() for p in self.parameters())
162 | if non_embedding:
163 | n_params -= self.transformer.wpe.weight.numel()
164 | return n_params
165 |
166 | def _init_weights(self, module):
167 | if isinstance(module, nn.Linear):
168 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
169 | if module.bias is not None:
170 | torch.nn.init.zeros_(module.bias)
171 | elif isinstance(module, nn.Embedding):
172 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
173 |
174 | def forward(self, idx, targets=None):
175 | device = idx.device
176 | b, t = idx.size()
177 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
178 | pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
179 |
180 | # forward the GPT model itself
181 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
182 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
183 |
184 | tok_emb = tok_emb + pos_emb
185 | #############################
186 | ## insert a register/sink token
187 | reg_tok_emb = self.reg_token.repeat((b, 1, 1))
188 | tok_emb = torch.cat([reg_tok_emb, tok_emb], dim=1)
189 | ###################################
190 | x = self.transformer.drop(tok_emb)
191 | for block in self.transformer.h:
192 | x = block(x)
193 | x = self.transformer.ln_f(x)
194 | x = x[:, 1:, :] ## remove the register token feature
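        # the register token is prepended before the blocks (attention runs over t+1 positions)
        # and its feature is dropped here, so the logits below stay aligned with the t input
        # positions and the language-modeling targets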
195 |
196 | if targets is not None:
197 | # if we are given some desired targets also calculate the loss
198 | logits = self.lm_head(x)
199 |
200 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
201 | else:
202 | # inference-time mini-optimization: only forward the lm_head on the very last position
203 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
204 | loss = None
205 |
206 | return logits, loss
207 |
208 | def crop_block_size(self, block_size):
209 | # model surgery to decrease the block size if necessary
210 | # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
211 | # but want to use a smaller block size for some smaller, simpler model
212 | assert block_size <= self.config.block_size
213 | self.config.block_size = block_size
214 | self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
215 | for block in self.transformer.h:
216 | if hasattr(block.attn, 'bias'):
217 | block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
218 |
219 | @classmethod
220 | def from_pretrained(cls, model_type, override_args=None):
221 | assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
222 | override_args = override_args or {} # default to empty dict
223 | # only dropout can be overridden see more notes below
224 | assert all(k == 'dropout' for k in override_args)
225 | from transformers import GPT2LMHeadModel
226 | print("loading weights from pretrained gpt: %s" % model_type)
227 |
228 | # n_layer, n_head and n_embd are determined from model_type
229 | config_args = {
230 | 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
231 | 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
232 | 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
233 | 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
234 | }[model_type]
235 | print("forcing vocab_size=50257, block_size=1024, bias=True")
236 | config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
237 | config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
238 | config_args['bias'] = True # always True for GPT model checkpoints
239 | # we can override the dropout rate, if desired
240 | if 'dropout' in override_args:
241 | print(f"overriding dropout rate to {override_args['dropout']}")
242 | config_args['dropout'] = override_args['dropout']
243 | # create a from-scratch initialized minGPT model
244 | config = GPTConfig(**config_args)
245 | model = GPT(config)
246 | sd = model.state_dict()
247 | sd_keys = sd.keys()
248 | sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
249 |
250 | # init a huggingface/transformers model
251 | model_hf = GPT2LMHeadModel.from_pretrained(model_type)
252 | sd_hf = model_hf.state_dict()
253 |
254 | # copy while ensuring all of the parameters are aligned and match in names and shapes
255 | sd_keys_hf = sd_hf.keys()
256 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
257 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
258 | transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
259 | # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
260 | # this means that we have to transpose these weights when we import them
261 | assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
262 | for k in sd_keys_hf:
263 | if any(k.endswith(w) for w in transposed):
264 | # special treatment for the Conv1D weights we need to transpose
265 | assert sd_hf[k].shape[::-1] == sd[k].shape
266 | with torch.no_grad():
267 | sd[k].copy_(sd_hf[k].t())
268 | else:
269 | # vanilla copy over the other parameters
270 | assert sd_hf[k].shape == sd[k].shape
271 | with torch.no_grad():
272 | sd[k].copy_(sd_hf[k])
273 |
274 | return model
275 |
276 | def configure_optimizers(self, optim_name, weight_decay, learning_rate, betas, device_type):
277 | # start with all of the candidate parameters
278 | param_dict = {pn: p for pn, p in self.named_parameters()}
279 | # filter out those that do not require grad
280 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
281 |         # create optim groups. Any parameter that is 2D will be weight decayed, otherwise no.
282 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
283 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
284 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
285 | optim_groups = [
286 | {'params': decay_params, 'weight_decay': weight_decay},
287 | {'params': nodecay_params, 'weight_decay': 0.0}
288 | ]
289 | num_decay_params = sum(p.numel() for p in decay_params)
290 | num_nodecay_params = sum(p.numel() for p in nodecay_params)
291 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
292 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
293 | # Create AdamW optimizer and use the fused version if it is available
294 | if optim_name == "adam":
295 | fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
296 | use_fused = fused_available and device_type == 'cuda'
297 | extra_args = dict(fused=True) if use_fused else dict()
298 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
299 | print(f"using fused AdamW: {use_fused}")
300 | elif optim_name == "gdm":
301 | optimizer = torch.optim.SGD(optim_groups, lr=learning_rate, momentum=betas[0], dampening=0.)
302 |
303 | return optimizer
304 |
305 | def estimate_mfu(self, fwdbwd_per_iter, dt):
306 | """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
307 | # first estimate the number of flops we do per iteration.
308 | # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
309 | N = self.get_num_params()
310 | cfg = self.config
311 | L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
312 | flops_per_token = 6*N + 12*L*H*Q*T
313 | flops_per_fwdbwd = flops_per_token * T
314 | flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
315 | # express our flops throughput as ratio of A100 bfloat16 peak flops
316 | flops_achieved = flops_per_iter * (1.0/dt) # per second
317 | flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
318 | mfu = flops_achieved / flops_promised
319 | return mfu
320 |
321 | @torch.no_grad()
322 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
323 | """
324 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
325 | the sequence max_new_tokens times, feeding the predictions back into the model each time.
326 | Most likely you'll want to make sure to be in model.eval() mode of operation for this.
327 | """
328 | for _ in range(max_new_tokens):
329 | # if the sequence context is growing too long we must crop it at block_size
330 | idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
331 | # forward the model to get the logits for the index in the sequence
332 | logits, _ = self(idx_cond)
333 | # pluck the logits at the final step and scale by desired temperature
334 | logits = logits[:, -1, :] / temperature
335 | # optionally crop the logits to only the top k options
336 | if top_k is not None:
337 | v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
338 | logits[logits < v[:, [-1]]] = -float('Inf')
339 | # apply softmax to convert logits to (normalized) probabilities
340 | probs = F.softmax(logits, dim=-1)
341 | # sample from the distribution
342 | idx_next = torch.multinomial(probs, num_samples=1)
343 | # append sampled index to the running sequence and continue
344 | idx = torch.cat((idx, idx_next), dim=1)
345 |
346 | return idx
347 |
--------------------------------------------------------------------------------
/gpt-2/model_attn_bias.py:
--------------------------------------------------------------------------------
1 | """
2 | Full definition of a GPT Language Model, all of it in this single file.
3 | References:
4 | 1) the official GPT-2 TensorFlow implementation released by OpenAI:
5 | https://github.com/openai/gpt-2/blob/master/src/model.py
6 | 2) huggingface/transformers PyTorch implementation:
7 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
8 | """
9 |
10 | import math
11 | import inspect
12 | from dataclasses import dataclass
13 |
14 | import torch
15 | import torch.nn as nn
16 | from torch.nn import functional as F
17 |
18 | class LayerNorm(nn.Module):
19 | """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
20 |
21 | def __init__(self, ndim, bias):
22 | super().__init__()
23 | self.weight = nn.Parameter(torch.ones(ndim))
24 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
25 |
26 | def forward(self, input):
27 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
28 |
29 | class CausalSelfAttention(nn.Module):
30 |
31 | def __init__(self, config):
32 | super().__init__()
33 | assert config.n_embd % config.n_head == 0
34 | # key, query, value projections for all heads, but in a batch
35 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
36 | # output projection
37 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
38 | # regularization
39 | self.attn_dropout = nn.Dropout(config.dropout)
40 | self.resid_dropout = nn.Dropout(config.dropout)
41 | self.n_head = config.n_head
42 | self.n_embd = config.n_embd
43 | self.dropout = config.dropout
44 | # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
45 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
46 | if not self.flash:
47 | print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
48 | # causal mask to ensure that attention is only applied to the left in the input sequence
49 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
50 | .view(1, 1, config.block_size, config.block_size))
51 |
52 | #############################################################################
53 | self.k_bias = nn.Parameter(torch.zeros(1, self.n_head, 1, self.n_embd // self.n_head), requires_grad=True)
54 | self.v_bias = nn.Parameter(torch.zeros(1, self.n_head, 1, self.n_embd // self.n_head), requires_grad=True)
55 |
56 | torch.nn.init.normal_(self.k_bias, mean=0.0, std=0.02)
57 | torch.nn.init.normal_(self.v_bias, mean=0.0, std=0.02)
58 | #############################################################################
59 |
60 |
61 | def forward(self, x):
62 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
63 |
64 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
65 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
66 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
67 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
68 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
69 |
70 | ###########################################################
71 | L, S = q.size(-2), k.size(-2)
72 | temp_mask = torch.ones(L, S, dtype=torch.bool, device=q.device).tril(diagonal=0)
73 | true_values = torch.ones(temp_mask.size(0), 1, dtype=torch.bool, device=q.device)
74 | temp_mask = torch.cat((true_values, temp_mask), dim=1)
75 |
76 | k_bias = self.k_bias.repeat(B, 1, 1, 1) # (B, num_heads, 1, dim // num_heads)
77 | v_bias = self.v_bias.repeat(B, 1, 1, 1)
78 |
79 | k = torch.cat((k_bias, k), dim=2)
80 | v = torch.cat((v_bias, v), dim=2)
81 | ###########################################################
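        # k_bias/v_bias act as one learned key/value pair prepended to every head; the extra
        # all-True first column of temp_mask lets every query attend to it regardless of position,
        # so the attention scores have shape (B, nh, T, T+1)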
82 |
83 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
84 | if self.flash:
85 | # efficient attention using Flash Attention CUDA kernels
86 | # y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
87 | y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=temp_mask, dropout_p=self.dropout if self.training else 0, is_causal=False)
88 | else:
89 | # manual implementation of attention
90 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
91 |             att = att.masked_fill(~temp_mask, float('-inf'))  # temp_mask has an extra all-True first column for the k/v bias
92 |             att = F.softmax(att, dim=-1)
93 |             att = self.attn_dropout(att)
94 |             y = att @ v # (B, nh, T, T+1) x (B, nh, T+1, hs) -> (B, nh, T, hs)
95 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
96 |
97 | # output projection
98 | y = self.resid_dropout(self.c_proj(y))
99 | return y
100 |
101 | class MLP(nn.Module):
102 |
103 | def __init__(self, config):
104 | super().__init__()
105 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
106 | self.gelu = nn.GELU()
107 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
108 | self.dropout = nn.Dropout(config.dropout)
109 |
110 | def forward(self, x):
111 | x = self.c_fc(x)
112 | x = self.gelu(x)
113 | x = self.c_proj(x)
114 | x = self.dropout(x)
115 | return x
116 |
117 | class Block(nn.Module):
118 |
119 | def __init__(self, config):
120 | super().__init__()
121 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
122 | self.attn = CausalSelfAttention(config)
123 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
124 | self.mlp = MLP(config)
125 |
126 | def forward(self, x):
127 | x = x + self.attn(self.ln_1(x))
128 | x = x + self.mlp(self.ln_2(x))
129 | return x
130 |
131 | @dataclass
132 | class GPTConfig:
133 | block_size: int = 1024
134 | vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
135 | n_layer: int = 12
136 | n_head: int = 12
137 | n_embd: int = 768
138 | dropout: float = 0.0
139 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
140 |
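# Illustrative note (not from the original file): the per-head size is
# n_embd // n_head, so n_head must divide n_embd evenly; the k_bias / v_bias
# parameters above are shaped (1, n_head, 1, n_embd // n_head) from these same
# fields. For example, a toy config such as
#   GPTConfig(block_size=32, vocab_size=64, n_layer=2, n_head=2, n_embd=32)
# would use 16-dim heads.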
141 | class GPT(nn.Module):
142 |
143 | def __init__(self, config):
144 | super().__init__()
145 | assert config.vocab_size is not None
146 | assert config.block_size is not None
147 | self.config = config
148 |
149 | self.transformer = nn.ModuleDict(dict(
150 | wte = nn.Embedding(config.vocab_size, config.n_embd),
151 | wpe = nn.Embedding(config.block_size, config.n_embd),
152 | drop = nn.Dropout(config.dropout),
153 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
154 | ln_f = LayerNorm(config.n_embd, bias=config.bias),
155 | ))
156 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
157 | # with weight tying when using torch.compile() some warnings get generated:
158 | # "UserWarning: functional_call was passed multiple values for tied weights.
159 | # This behavior is deprecated and will be an error in future versions"
160 | # not 100% sure what this is, so far seems to be harmless. TODO investigate
161 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
162 |
163 | # init all weights
164 | self.apply(self._init_weights)
165 | # apply special scaled init to the residual projections, per GPT-2 paper
166 | for pn, p in self.named_parameters():
167 | if pn.endswith('c_proj.weight'):
168 | torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
169 |
170 | # report number of parameters
171 | print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
172 |
173 | def get_num_params(self, non_embedding=True):
174 | """
175 | Return the number of parameters in the model.
176 | For non-embedding count (default), the position embeddings get subtracted.
177 | The token embeddings would too, except due to the parameter sharing these
178 | params are actually used as weights in the final layer, so we include them.
179 | """
180 | n_params = sum(p.numel() for p in self.parameters())
181 | if non_embedding:
182 | n_params -= self.transformer.wpe.weight.numel()
183 | return n_params
184 |
185 | def _init_weights(self, module):
186 | if isinstance(module, nn.Linear):
187 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
188 | if module.bias is not None:
189 | torch.nn.init.zeros_(module.bias)
190 | elif isinstance(module, nn.Embedding):
191 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
192 |
193 | def forward(self, idx, targets=None):
194 | device = idx.device
195 | b, t = idx.size()
196 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
197 | pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
198 |
199 | # forward the GPT model itself
200 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
201 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
202 | x = self.transformer.drop(tok_emb + pos_emb)
203 | for block in self.transformer.h:
204 | x = block(x)
205 | x = self.transformer.ln_f(x)
206 |
207 | if targets is not None:
208 | # if we are given some desired targets also calculate the loss
209 | logits = self.lm_head(x)
210 |
211 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
212 | else:
213 | # inference-time mini-optimization: only forward the lm_head on the very last position
214 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
215 | loss = None
216 |
217 | return logits, loss
218 |
219 | def crop_block_size(self, block_size):
220 | # model surgery to decrease the block size if necessary
221 | # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
222 | # but want to use a smaller block size for some smaller, simpler model
223 | assert block_size <= self.config.block_size
224 | self.config.block_size = block_size
225 | self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
226 | for block in self.transformer.h:
227 | if hasattr(block.attn, 'bias'):
228 | block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
229 |
230 | @classmethod
231 | def from_pretrained(cls, model_type, override_args=None):
232 | assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
233 | override_args = override_args or {} # default to empty dict
234 |         # only dropout can be overridden; see more notes below
235 | assert all(k == 'dropout' for k in override_args)
236 | from transformers import GPT2LMHeadModel
237 | print("loading weights from pretrained gpt: %s" % model_type)
238 |
239 | # n_layer, n_head and n_embd are determined from model_type
240 | config_args = {
241 | 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
242 | 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
243 | 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
244 | 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
245 | }[model_type]
246 | print("forcing vocab_size=50257, block_size=1024, bias=True")
247 | config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
248 | config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
249 | config_args['bias'] = True # always True for GPT model checkpoints
250 | # we can override the dropout rate, if desired
251 | if 'dropout' in override_args:
252 | print(f"overriding dropout rate to {override_args['dropout']}")
253 | config_args['dropout'] = override_args['dropout']
254 | # create a from-scratch initialized minGPT model
255 | config = GPTConfig(**config_args)
256 | model = GPT(config)
257 | sd = model.state_dict()
258 | sd_keys = sd.keys()
259 |         sd_keys = [k for k in sd_keys if not k.endswith(('.attn.bias', '.attn.k_bias', '.attn.v_bias'))] # discard the causal-mask buffer; also skip the learnable attention biases, which have no pretrained counterpart and keep their fresh init
260 |
261 | # init a huggingface/transformers model
262 | model_hf = GPT2LMHeadModel.from_pretrained(model_type)
263 | sd_hf = model_hf.state_dict()
264 |
265 | # copy while ensuring all of the parameters are aligned and match in names and shapes
266 | sd_keys_hf = sd_hf.keys()
267 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
268 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
269 | transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
270 | # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
271 | # this means that we have to transpose these weights when we import them
272 | assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
273 | for k in sd_keys_hf:
274 | if any(k.endswith(w) for w in transposed):
275 | # special treatment for the Conv1D weights we need to transpose
276 | assert sd_hf[k].shape[::-1] == sd[k].shape
277 | with torch.no_grad():
278 | sd[k].copy_(sd_hf[k].t())
279 | else:
280 | # vanilla copy over the other parameters
281 | assert sd_hf[k].shape == sd[k].shape
282 | with torch.no_grad():
283 | sd[k].copy_(sd_hf[k])
284 |
285 | return model
286 |
287 | def configure_optimizers(self, optim_name, weight_decay, learning_rate, betas, device_type):
288 | # start with all of the candidate parameters
289 | param_dict = {pn: p for pn, p in self.named_parameters()}
290 | # filter out those that do not require grad
291 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
292 |         # create optim groups. Any parameter that is 2D will be weight decayed, otherwise not.
293 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
294 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
295 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
296 | optim_groups = [
297 | {'params': decay_params, 'weight_decay': weight_decay},
298 | {'params': nodecay_params, 'weight_decay': 0.0}
299 | ]
300 | num_decay_params = sum(p.numel() for p in decay_params)
301 | num_nodecay_params = sum(p.numel() for p in nodecay_params)
302 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
303 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
304 |         # Create the optimizer; for AdamW, use the fused version if it is available
305 | if optim_name == "adam":
306 | fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
307 | use_fused = fused_available and device_type == 'cuda'
308 | extra_args = dict(fused=True) if use_fused else dict()
309 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
310 | print(f"using fused AdamW: {use_fused}")
311 | elif optim_name == "gdm":
312 | optimizer = torch.optim.SGD(optim_groups, lr=learning_rate, momentum=betas[0], dampening=0.)
313 |         else: raise ValueError(f"unknown optim_name {optim_name!r}; expected 'adam' or 'gdm'") # fail fast rather than returning an unbound name
314 | return optimizer
315 |
316 | def estimate_mfu(self, fwdbwd_per_iter, dt):
317 | """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
318 | # first estimate the number of flops we do per iteration.
319 | # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
320 | N = self.get_num_params()
321 | cfg = self.config
322 | L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
323 | flops_per_token = 6*N + 12*L*H*Q*T
324 | flops_per_fwdbwd = flops_per_token * T
325 | flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
326 | # express our flops throughput as ratio of A100 bfloat16 peak flops
327 | flops_achieved = flops_per_iter * (1.0/dt) # per second
328 | flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
329 | mfu = flops_achieved / flops_promised
330 | return mfu
331 |
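    # Worked example for estimate_mfu (illustrative numbers, not measured results):
    # for the default 124M config, N ~= 124e6, L = 12, H = 12, Q = 64, T = 1024, so
    #   flops_per_token  ~= 6*124e6 + 12*12*12*64*1024 ~= 7.4e8 + 1.1e8 ~= 8.6e8
    #   flops_per_fwdbwd ~= 8.6e8 * 1024               ~= 8.8e11
    # With, say, fwdbwd_per_iter = 40 and dt = 0.5 s, flops_achieved ~= 8.8e11 * 40 / 0.5
    # ~= 7.0e13 FLOP/s, giving mfu ~= 7.0e13 / 312e12 ~= 0.22, i.e. roughly 22% of
    # A100 bf16 peak.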
332 | @torch.no_grad()
333 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
334 | """
335 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
336 | the sequence max_new_tokens times, feeding the predictions back into the model each time.
337 |         Most likely you'll want to make sure the model is in model.eval() mode for this.
338 | """
339 | for _ in range(max_new_tokens):
340 | # if the sequence context is growing too long we must crop it at block_size
341 | idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
342 | # forward the model to get the logits for the index in the sequence
343 | logits, _ = self(idx_cond)
344 | # pluck the logits at the final step and scale by desired temperature
345 | logits = logits[:, -1, :] / temperature
346 | # optionally crop the logits to only the top k options
347 | if top_k is not None:
348 | v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
349 | logits[logits < v[:, [-1]]] = -float('Inf')
350 | # apply softmax to convert logits to (normalized) probabilities
351 | probs = F.softmax(logits, dim=-1)
352 | # sample from the distribution
353 | idx_next = torch.multinomial(probs, num_samples=1)
354 | # append sampled index to the running sequence and continue
355 | idx = torch.cat((idx, idx_next), dim=1)
356 |
357 | return idx
358 |
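# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the real settings live in gpt-2/config/):
# builds a tiny model, fetches an optimizer via configure_optimizers, runs a
# training-style forward pass, and samples a few tokens with generate. All sizes
# below are made up for the example.
if __name__ == "__main__":
    cfg = GPTConfig(block_size=32, vocab_size=64, n_layer=2, n_head=2, n_embd=32,
                    dropout=0.0, bias=True)
    model = GPT(cfg)
    optimizer = model.configure_optimizers(optim_name="adam", weight_decay=0.1,
                                           learning_rate=3e-4, betas=(0.9, 0.95),
                                           device_type="cpu")
    idx = torch.randint(0, cfg.vocab_size, (2, 16))   # (batch=2, t=16) random tokens
    logits, loss = model(idx, targets=idx)            # with targets, a loss is returned
    print(f"demo loss: {loss.item():.3f}")
    out = model.generate(idx[:, :4], max_new_tokens=8, top_k=5)
    print("generated shape:", tuple(out.shape))       # (2, 12)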
--------------------------------------------------------------------------------