├── assets ├── bird.png ├── equation.png ├── main_teaser.png └── reg_feat_mean │ ├── base_imagemean_bs10k_trainaug.pt │ ├── giant_imagemean_bs10k_trainaug.pt │ ├── large_imagemean_bs10k_trainaug.pt │ └── small_imagemean_bs10k_trainaug.pt ├── monkey_patch ├── __init__.py ├── modify_vit.py ├── README.md ├── modify_phi2.py ├── modify_mistral.py └── modify_llama.py ├── lib ├── __init__.py ├── hook.py ├── load_data.py ├── model_dict.py ├── eval_utils.py ├── plot_utils_vit.py ├── load_model.py ├── plot_utils_llm.py └── plot_utils.py ├── gpt-2 ├── config │ ├── eval_gpt2_attn_bias.py │ ├── eval_gpt2_default.py │ ├── eval_gpt2_sink.py │ ├── train_gpt2_attn_bias.py │ ├── train_gpt2_default.py │ └── train_gpt2_sink.py ├── configurator.py ├── README.md ├── plot_gpt2.py ├── test.py ├── analyze.py ├── train.py ├── model_default.py ├── model_sink.py └── model_attn_bias.py ├── LICENSE ├── INSTALL.md ├── main_vit.py ├── README.md └── main_llm.py /assets/bird.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/massive-activations/HEAD/assets/bird.png -------------------------------------------------------------------------------- /assets/equation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/massive-activations/HEAD/assets/equation.png -------------------------------------------------------------------------------- /assets/main_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/massive-activations/HEAD/assets/main_teaser.png -------------------------------------------------------------------------------- /monkey_patch/__init__.py: -------------------------------------------------------------------------------- 1 | from .modify_llama import * 2 | from .modify_mistral import * 3 | from .modify_vit import * 4 | from .modify_phi2 import * -------------------------------------------------------------------------------- /assets/reg_feat_mean/base_imagemean_bs10k_trainaug.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/massive-activations/HEAD/assets/reg_feat_mean/base_imagemean_bs10k_trainaug.pt -------------------------------------------------------------------------------- /assets/reg_feat_mean/giant_imagemean_bs10k_trainaug.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/massive-activations/HEAD/assets/reg_feat_mean/giant_imagemean_bs10k_trainaug.pt -------------------------------------------------------------------------------- /assets/reg_feat_mean/large_imagemean_bs10k_trainaug.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/massive-activations/HEAD/assets/reg_feat_mean/large_imagemean_bs10k_trainaug.pt -------------------------------------------------------------------------------- /assets/reg_feat_mean/small_imagemean_bs10k_trainaug.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/massive-activations/HEAD/assets/reg_feat_mean/small_imagemean_bs10k_trainaug.pt -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.load_model import load_llm, load_vit, load_dinov2_linear_head 2 | from .eval_utils import test_imagenet, setup_dinov2_model_for_eval, fix_reg_mean, eval_ppl 3 | from .plot_utils_llm import plot_3d_feat, plot_layer_ax, plot_attn 4 | from .plot_utils_vit import plot_3d_feat_vit, plot_layer_ax_vit 5 | from .load_data import get_data 6 | from .hook import setup_intervene_hook -------------------------------------------------------------------------------- /gpt-2/config/eval_gpt2_attn_bias.py: -------------------------------------------------------------------------------- 1 | # evaluate gpt2 model with a sink token 2 | # n_layer=12, n_head=12, n_embd=768 3 | batch_size = 8 4 | eval_iters = 500 # use more iterations to get good estimate 5 | eval_only = True 6 | wandb_log = False 7 | init_from = 'resume' 8 | ckpt_iter = 50000 9 | out_dir="../pretrained-models/attn_bias" 10 | data_dir = '/data/locus/project_data/project_data2/mingjies/nanoGPT/data' 11 | save_dir="results/gpt-2/attn_bias/" 12 | compile = False 13 | model_type = "gpt2_attn_bias" -------------------------------------------------------------------------------- /gpt-2/config/eval_gpt2_default.py: -------------------------------------------------------------------------------- 1 | # evaluate gpt2 model with default architecture 2 | # n_layer=12, n_head=12, n_embd=768 3 | batch_size = 8 4 | eval_iters = 500 # use more iterations to get good estimate 5 | eval_only = True 6 | wandb_log = False 7 | init_from = 'resume' 8 | ckpt_iter = 50000 9 | out_dir="../pretrained-models/default" 10 | data_dir = '/data/locus/project_data/project_data2/mingjies/nanoGPT/data' 11 | save_dir="results/gpt-2/default/" 12 | compile = False 13 | model_type = "gpt2_default" -------------------------------------------------------------------------------- /gpt-2/config/eval_gpt2_sink.py: -------------------------------------------------------------------------------- 1 | # evaluate gpt2 model with a sink token 2 | # n_layer=12, n_head=12, n_embd=768 3 | batch_size = 8 4 | eval_iters = 500 # use more iterations to get good estimate 5 | eval_only = True 6 | wandb_log = False 7 | init_from = 'resume' 8 | ckpt_iter = 50000 9 | out_dir="../pretrained-models/sink" 10 | data_dir = '/data/locus/project_data/project_data2/mingjies/nanoGPT/data' 11 | save_dir="results/gpt-2/sink/" 12 | compile = False 13 | model_type = "gpt2_sink" 14 | num_reg = 1 -------------------------------------------------------------------------------- /gpt-2/config/train_gpt2_attn_bias.py: -------------------------------------------------------------------------------- 1 | out_dir = 'results/' 2 | data_dir = '/data/locus/project_data/project_data2/mingjies/nanoGPT/data' 3 | 4 | wandb_log = False 5 | wandb_project = 'owt' 6 | wandb_run_name='gpt2-124M-default-run' 7 | compile=False 8 | 9 | # these make the total batch size be ~0.5M 10 | # 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520 11 | batch_size = 12 12 | block_size = 1024 13 | gradient_accumulation_steps = 5 * 8 14 | 15 | # this makes total number of tokens be 300B 16 | max_iters = 600000 17 | lr_decay_iters = 600000 18 | 19 | # eval stuff 20 | eval_interval = 1000 21 | eval_iters = 200 22 | log_interval = 10 23 | 24 | # weight decay 25 | weight_decay = 1e-1 26 | 27 | model_type = "gpt2_attn_bias" -------------------------------------------------------------------------------- /gpt-2/config/train_gpt2_default.py: -------------------------------------------------------------------------------- 1 | out_dir = 'results/' 2 | 
data_dir = '/data/locus/project_data/project_data2/mingjies/nanoGPT/data' 3 | 4 | wandb_log = False 5 | wandb_project = 'owt' 6 | wandb_run_name='gpt2-124M-default-run' 7 | compile=False 8 | 9 | # these make the total batch size be ~0.5M 10 | # 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520 11 | batch_size = 12 12 | block_size = 1024 13 | gradient_accumulation_steps = 5 * 8 14 | 15 | # this makes total number of tokens be 300B 16 | max_iters = 600000 17 | lr_decay_iters = 600000 18 | 19 | # eval stuff 20 | eval_interval = 1000 21 | eval_iters = 200 22 | log_interval = 10 23 | 24 | # weight decay 25 | weight_decay = 1e-1 26 | 27 | model_type = "gpt2_default" -------------------------------------------------------------------------------- /gpt-2/config/train_gpt2_sink.py: -------------------------------------------------------------------------------- 1 | out_dir = 'results/' 2 | data_dir = '/data/locus/project_data/project_data2/mingjies/nanoGPT/data' 3 | 4 | wandb_log = False 5 | wandb_project = 'owt' 6 | wandb_run_name='gpt2-124M-default-run' 7 | compile=False 8 | 9 | # these make the total batch size be ~0.5M 10 | # 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520 11 | batch_size = 12 12 | block_size = 1024 13 | gradient_accumulation_steps = 5 * 8 14 | 15 | # this makes total number of tokens be 300B 16 | max_iters = 600000 17 | lr_decay_iters = 600000 18 | 19 | # eval stuff 20 | eval_interval = 1000 21 | eval_iters = 200 22 | log_interval = 10 23 | 24 | # weight decay 25 | weight_decay = 1e-1 26 | 27 | model_type = "gpt2_sink" 28 | num_reg=1 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 CMU Locus Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ## Dependencies 4 | Create a new conda virtual environment 5 | ```sh 6 | conda create -n massive-activations python=3.9 -y 7 | conda activate massive-activations 8 | ``` 9 | 10 | Install [PyTorch](https://pytorch.org/)>=2.0.0, [torchvision](https://pytorch.org/vision/stable/index.html)>=0.15.0 following official instructions.
For example: 11 | ```sh 12 | pip install torch==2.0.0+cu117 torchvision==0.15.1+cu117 -f https://download.pytorch.org/whl/torch_stable.html 13 | ``` 14 | 15 | Install additional dependencies: 16 | ```sh 17 | pip install timm==0.9.12 transformers==4.36.0 accelerate==0.23.0 datasets==2.14.5 matplotlib==3.8.0 seaborn sentencepiece protobuf 18 | ``` 19 | 20 | ## Pretrained Models 21 | 22 | - **LLM Models**: To use pretrained LLM models, update the `CACHE_DIR_BASE` variable in the [model_dict.py](lib/model_dict.py) file to point to the directory containing the pretrained model weights. 23 | 24 | - **DINOv2-reg Models**: To use the DINOv2-reg model for linear classification, download the pretrained linear classification head from this [link](https://github.com/facebookresearch/dinov2?tab=readme-ov-file#pretrained-heads---image-classification). Set the `--linear_head_path` argument in the [main_vit.py](main_vit.py) script to the directory where you've stored the downloaded weights. 25 | -------------------------------------------------------------------------------- /lib/hook.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class Intervened_Layer: 4 | def __init__(self, model_name, reset_type): 5 | self.model_name = model_name 6 | self.reset_type = reset_type 7 | 8 | def update(self, inp, out): 9 | out_feature = out[0] 10 | 11 | if self.reset_type == "set_mean": 12 | alpha = 1.0 13 | elif self.reset_type == "set_zero": 14 | alpha = 0.0 15 | else: 16 | raise ValueError(f"reset_type {self.reset_type} not supported") 17 | 18 | if self.model_name == "llama2_13b": 19 | out_feature[:, 0, 4743] = - 1223.5 * alpha 20 | out_feature[:, 0, 2100] = - 717.95 * alpha 21 | elif self.model_name == "llama2_7b": 22 | feat_abs = out_feature.abs() 23 | sort_res = torch.sort(feat_abs.flatten(), descending=True) 24 | 25 | top_indices = sort_res.indices[0] 26 | token_dim = top_indices.item() // feat_abs.shape[2] 27 | 28 | if token_dim != 0: 29 | out_feature[:, token_dim, 2533] = 2546.8 * alpha 30 | out_feature[:, token_dim, 1415] = - 1502.0 * alpha 31 | 32 | out_feature[:, 0, 2533] = 767.6 * alpha 33 | out_feature[:, 0, 1415] = - 458.55 * alpha 34 | 35 | return (out_feature, *out[1:]) 36 | 37 | def setup_intervene_hook(layer, model_name, reset_type): 38 | update_layer = Intervened_Layer(model_name, reset_type) 39 | 40 | def add_batch(): 41 | def modify_hook(_, inp, out): 42 | update_layer.update(inp, out) 43 | return modify_hook 44 | 45 | handle = layer.register_forward_hook(add_batch()) 46 | 47 | return handle -------------------------------------------------------------------------------- /monkey_patch/modify_vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import types 3 | 4 | def vit_custom_block_forward(self, x: torch.Tensor) -> torch.Tensor: 5 | x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) 6 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) 7 | self.feat = x.clone().detach().cpu().double() ## this is to get the output feature of each layer; 8 | return x 9 | 10 | def enable_vit_custom_block(layer, layer_id): 11 | layer.layer_id = layer_id 12 | layer.forward = types.MethodType(vit_custom_block_forward, layer) 13 | 14 | 15 | def vit_custom_attention_forward(self, x) -> torch.Tensor: 16 | B, N, C = x.shape 17 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 18 | q, k, v = qkv.unbind(0) 19 | q, k = self.q_norm(q), self.k_norm(k) 20 
| 21 | q = q * self.scale 22 | attn = q @ k.transpose(-2, -1) 23 | 24 | # ################################################### 25 | self.attn_logits = attn.clone().detach().cpu().double() 26 | # ################################################### 27 | 28 | attn = attn.softmax(dim=-1) 29 | 30 | # ################################################### 31 | self.attn_probs = attn.clone().detach().cpu().double() 32 | # ################################################### 33 | 34 | attn = self.attn_drop(attn) 35 | x = attn @ v 36 | 37 | x = x.transpose(1, 2).reshape(B, N, C) 38 | x = self.proj(x) 39 | x = self.proj_drop(x) 40 | return x 41 | 42 | def enable_vit_custom_attention(layer, layer_id): 43 | modified_module = layer.attn 44 | modified_module.layer_id = layer_id 45 | modified_module.forward = types.MethodType(vit_custom_attention_forward, modified_module) 46 | -------------------------------------------------------------------------------- /gpt-2/configurator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Poor Man's Configurator. Probably a terrible idea. Example usage: 3 | $ python train.py config/override_file.py --batch_size=32 4 | this will first run config/override_file.py, then override batch_size to 32 5 | 6 | The code in this file will be run as follows from e.g. train.py: 7 | >>> exec(open('configurator.py').read()) 8 | 9 | So it's not a Python module, it's just shuttling this code away from train.py 10 | The code in this script then overrides the globals() 11 | 12 | I know people are not going to love this, I just really dislike configuration 13 | complexity and having to prepend config. to every single variable. If someone 14 | comes up with a better simple Python solution I am all ears. 15 | """ 16 | 17 | import sys 18 | from ast import literal_eval 19 | 20 | for arg in sys.argv[1:]: 21 | if '=' not in arg: 22 | # assume it's the name of a config file 23 | assert not arg.startswith('--') 24 | config_file = arg 25 | print(f"Overriding config with {config_file}:") 26 | with open(config_file) as f: 27 | print(f.read()) 28 | exec(open(config_file).read()) 29 | else: 30 | # assume it's a --key=value argument 31 | assert arg.startswith('--') 32 | key, val = arg.split('=') 33 | key = key[2:] 34 | if key in globals(): 35 | try: 36 | # attempt to eval it it (e.g. if bool, number, or etc) 37 | attempt = literal_eval(val) 38 | except (SyntaxError, ValueError): 39 | # if that goes wrong, just use the string 40 | attempt = val 41 | # ensure the types match ok 42 | assert type(attempt) == type(globals()[key]) 43 | # cross fingers 44 | print(f"Overriding: {key} = {attempt}") 45 | globals()[key] = attempt 46 | else: 47 | raise ValueError(f"Unknown config key: {key}") 48 | -------------------------------------------------------------------------------- /monkey_patch/README.md: -------------------------------------------------------------------------------- 1 | # Monkey Patch LLMs 2 | 3 | This directory provide the code where we use to get intermediate hidden states from LLMs from [Transformers](https://github.com/huggingface/transformers/tree/main/src/transformers/models) and ViTs from [timm](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py). For other LLM model families, you need to write a custom forward function, depending on the model definition. 
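As a reference for what such a custom forward function typically does, below is a minimal sketch that wraps an existing decoder layer's forward pass and caches its output hidden states on the module, mirroring the `self.feat` pattern used in [modify_vit.py](modify_vit.py). The helper name `make_feat_recording_forward` is illustrative and not part of this directory; the actual `modify_llama.py` re-implements the full forward signature (attention mask, position ids, caches) rather than wrapping it.

```py
import types

def make_feat_recording_forward(orig_forward):
    # Wraps an existing decoder-layer forward so that, after each call, the layer's
    # output hidden states are cached on the module as `layer.feat`, analogous to the
    # `self.feat = x.clone().detach().cpu().double()` line in modify_vit.py.
    def forward(self, hidden_states, *args, **kwargs):
        outputs = orig_forward(hidden_states, *args, **kwargs)
        hidden = outputs[0] if isinstance(outputs, tuple) else outputs
        self.feat = hidden.clone().detach().cpu().double()
        return outputs
    return forward

# usage sketch: layer.forward = types.MethodType(make_feat_recording_forward(layer.forward), layer)
```

After a forward pass of the model, `layers[layer_id].feat` can then be read out directly, which is how the analysis scripts in this repository collect activation statistics.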
4 | 5 | A sample command to replace the forward function of [LLaMADecoderLayer](https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/llama/modeling_llama.py#L755) with a custom forward function ``llama_custom_decoderlayer_forward``: 6 | ```py 7 | import types 8 | 9 | def enable_llama_custom_decoderlayer(layer, layer_id): 10 | """ 11 | This function modifies a given LlamaDecoderLayer object by setting its layer_id and replacing its forward method with a custom implementation. 12 | It allows for customization of the layer's behavior during the forward pass, which is when the layer processes input data. 13 | """ 14 | 15 | layer.layer_id = layer_id 16 | # This line assigns a unique identifier to the layer. The `layer_id` parameter is used to specify this identifier, 17 | # which can be useful for tracking, debugging, or applying specific operations to certain layers within a larger model. 18 | 19 | layer.forward = types.MethodType( 20 | llama_custom_decoderlayer_forward, layer 21 | ) 22 | # This line replaces the layer's original `forward` method with a new one. 23 | # `types.MethodType` is used to bind a new method to an existing object. In this case, it binds the 24 | # `llama_custom_decoderlayer_forward` function to the `layer` object as its new `forward` method. 25 | # `llama_custom_decoderlayer_forward` should be a function defined elsewhere that takes the same arguments as the original 26 | # `forward` method of the layer and implements the desired custom behavior for processing input data. 27 | # This allows for dynamic modification of the layer's behavior without altering the original class definition. 28 | ``` -------------------------------------------------------------------------------- /lib/load_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | from datasets import load_dataset 5 | 6 | def set_seed(seed): 7 | np.random.seed(seed) 8 | torch.random.manual_seed(seed) 9 | 10 | def get_data(tokenizer, nsamples=50, seqlen=2048, device=None): 11 | valdata = load_dataset("togethercomputer/RedPajama-Data-1T-Sample") 12 | 13 | num_seq = len(valdata["train"]) 14 | seq_indices = np.random.choice(num_seq, 500, replace=False).tolist() 15 | seq_list = [] 16 | for seq_ind in seq_indices: 17 | seq_list.append(valdata["train"][seq_ind]['text']) 18 | 19 | testenc = tokenizer("\n\n".join(seq_list), return_tensors='pt', add_special_tokens=False).input_ids 20 | 21 | testseq_list = [] 22 | for i in range(nsamples): 23 | test_seq = testenc[:, (i * seqlen):((i+1) * seqlen)].to(device) 24 | testseq_list.append(test_seq.reshape(1, seqlen)) 25 | 26 | return testseq_list 27 | 28 | 29 | def get_wikitext(tokenizer, seqlen=2048, device=None): 30 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 31 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt', add_special_tokens=False) 32 | testenc = testenc.input_ids 33 | 34 | testseq_list = [] 35 | nsamples = testenc.numel() // seqlen 36 | 37 | for i in range(nsamples): 38 | testenc_cur = testenc[:,(i * seqlen):((i+1) * seqlen)].to(device) 39 | testseq_list.append(testenc_cur.reshape(1, seqlen)) 40 | return testseq_list 41 | 42 | def get_pg19(tokenizer, seqlen=2048, device=None): 43 | valdata = load_dataset( 44 | 'emozilla/pg19', split='validation' 45 | ) 46 | 47 | testseq_list = [] 48 | valenc = tokenizer(' '.join(valdata[:5]['text']), return_tensors='pt').input_ids 49 | for i in range(100): 50 | 
testseq_list.append(valenc[:, (i * seqlen):((i+1) * seqlen)].to(device)) 51 | return testseq_list 52 | 53 | def get_c4(tokenizer, seqlen=2048, device=None): 54 | valdata = load_dataset("NeelNanda/c4-10k") 55 | 56 | testseq_list = [] 57 | valenc = tokenizer(' '.join(valdata["train"][:5000]['text']), return_tensors='pt').input_ids 58 | for i in range(100): 59 | testseq_list.append(valenc[:, (i * seqlen):((i+1) * seqlen)].to(device)) 60 | return testseq_list 61 | 62 | def get_test_data(dataset_name, tokenizer=None, seed=0, seqlen=2048, device=None): 63 | random.seed(seed) 64 | set_seed(seed) 65 | if dataset_name == "wikitext": 66 | return get_wikitext(tokenizer, seqlen=seqlen, device=device) 67 | elif dataset_name == "c4": 68 | return get_c4(tokenizer, seqlen=seqlen, device=device) 69 | elif dataset_name == "pg19": 70 | return get_pg19(tokenizer, seqlen=seqlen, device=device) 71 | -------------------------------------------------------------------------------- /gpt-2/README.md: -------------------------------------------------------------------------------- 1 | # Training GPT-2 with Explicit Attention Biases 2 | 3 | We provide the code and pretrained checkpoints for the experiments in Section 5.2 on "Explicit attention biases". The code for training GPT-2 is based on the open-source [nanoGPT](https://github.com/karpathy/nanoGPT) repository. 4 | 5 | --- 6 |

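A sketch of self-attention augmented with explicit attention biases, written per attention head; here $\mathbf{k}', \mathbf{v}' \in \mathbb{R}^{d}$ denote the auxiliary key and value parameters (notation assumed for illustration — see [model_attn_bias.py](model_attn_bias.py) for the exact implementation):

$$
\mathrm{Attention}(Q, K, V; \mathbf{k}', \mathbf{v}') = \mathrm{softmax}\left(\frac{Q\,[\,\mathbf{k}' \;\; K^{\top}\,]}{\sqrt{d}}\right) \begin{bmatrix} \mathbf{v}'^{\top} \\ V \end{bmatrix}
$$

The auxiliary pair is a learned parameter of each attention layer rather than an extra token, so the input sequence itself is left unchanged.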

10 | We propose to augment the self-attention mechanism with explicit attention biases, by inserting auxiliary key and value parameters. 11 | 12 | [model_attn_bias.py](model_attn_bias.py) contains the model definition of GPT-2 augmented with explicit attention biases. 13 | 14 | ## Setup 15 | 16 | - *data*: Follow [here](https://github.com/karpathy/nanoGPT?tab=readme-ov-file#reproducing-gpt-2) to setup the training and validation data from OpenWebText2. 17 | 18 | - *pretrained models*: Here we provide the model checkpoints for three GPT-2 models we trained, each with 50k iterations 19 | 20 | | model name | download path | validation perplexity | 21 | |:---:|:---:|:---:| 22 | | default | [model](https://drive.google.com/file/d/1_oiybR7wmJ5ibPZMM3sGjtJRqVpMp0H6/view?usp=drive_link) | 3.04 | 23 | | sink | [model](https://drive.google.com/file/d/1ZnhFxN-A7qc9Cghcp_jLpxQXDRE2BqSc/view?usp=drive_link) | 3.04 | 24 | | attn_bias | [model](https://drive.google.com/file/d/1jSpGpNGqJ9Ff_goqoFjSRA5EN7Qmdv1U/view?usp=drive_link) | 3.04 | 25 | 26 | **Note**: For the config files in [config](config), set `out_dir` to the directory of the downloaded pretrained models and `data_dir` to the directories of the prepared OpenWebText2 dataset. 27 | 28 | ## Evalutate 29 | 30 | Running the following commands will evaluate the three GPT-2 checkpoints. 31 | ```sh 32 | CUDA_VISIBLE_DEVICES=0 python test.py config/eval_gpt2_default.py ### gpt2 default architecture 33 | CUDA_VISIBLE_DEVICES=0 python test.py config/eval_gpt2_sink.py ### gpt2 sink token 34 | CUDA_VISIBLE_DEVICES=0 python test.py config/eval_gpt2_attn_bias.py ### gpt2 attention biases 35 | ``` 36 | 37 | ## Training 38 | Running the following commands will train the three GPT-2 models from scratch: (can adjust the number of GPUs for training on multiple GPUs) 39 | ```sh 40 | CUDA_VISIBLE_DEVICES=0 python train.py config/train_gpt2_default.py ### gpt2 default architecture 41 | CUDA_VISIBLE_DEVICES=0 python train.py config/train_gpt2_sink.py ### gpt2 sink token 42 | CUDA_VISIBLE_DEVICES=0 python train.py config/train_gpt2_attn_bias.py ### gpt2 attention biases 43 | ``` 44 | 45 | ## Analysis 46 | We provide the commands for visualizing the activaiton magnitudes of an intermediate feature and also layerwise largest activation magnitudes: 47 | ```sh 48 | CUDA_VISIBLE_DEVICES=0 python analyze.py config/eval_gpt2_default.py 49 | CUDA_VISIBLE_DEVICES=0 python analyze.py config/eval_gpt2_sink.py 50 | CUDA_VISIBLE_DEVICES=0 python analyze.py config/eval_gpt2_attn_bias.py 51 | ``` -------------------------------------------------------------------------------- /gpt-2/plot_gpt2.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | def plot_3d_ax_gpt2(feat, model_name, inp_seq, savedir): 7 | fig = plt.figure(figsize=(6,5)) 8 | fig.tight_layout() # Or equivalently, "plt.tight_layout()" 9 | plt.subplots_adjust(wspace=0.1) 10 | 11 | ax = fig.add_subplot(1,1,1, projection='3d') 12 | 13 | name_title = { 14 | "gpt2_default": "GPT-2 Default", 15 | "gpt2_sink": "GPT-2 with Sink Token", 16 | "gpt2_attn_bias": "GPT-2 with Attention Bias", 17 | } 18 | num_tokens = feat.shape[1] 19 | num_channels = feat.shape[2] 20 | xdata = np.array([np.linspace(0,num_tokens-1,num_tokens) for i in range(num_channels)]) 21 | ydata = np.array([np.ones(num_tokens) * i for i in range(num_channels)]) 22 | zdata = feat[0].numpy().T 23 | ax.plot_wireframe(xdata, ydata, zdata, rstride=0, 
color="blue", linewidth=1.5) 24 | 25 | ax.set_xticks(np.linspace(0,num_tokens-1,num_tokens), inp_seq, 26 | rotation=60, fontsize=16) 27 | 28 | ax.set(yticklabels=[]) 29 | ax.set_zticks([0, 1000, 2000], [0, "1k", "2k"], fontsize=18) 30 | 31 | ax.set_title(name_title[model_name], fontsize=20, fontweight="bold", y=1.005) 32 | plt.setp(ax.get_xticklabels(), rotation=50, ha="right", va="center",rotation_mode="anchor") 33 | plt.setp(ax.get_yticklabels(), ha="left",rotation_mode="anchor") 34 | plt.setp(ax.get_zticklabels(), ha="left",rotation_mode="anchor") 35 | 36 | ax.tick_params(axis='x', which='major', pad=-5) 37 | ax.tick_params(axis='y', which='major', pad=-5) 38 | ax.tick_params(axis='z', which='major', pad=-5) 39 | plt.savefig(os.path.join(savedir, f"{model_name}_3d.png"), bbox_inches="tight", dpi=200) 40 | 41 | def plot_layer_ax_gpt2(mean, model_name, savedir=None): 42 | fig = plt.figure(figsize=(6,5)) 43 | fig.tight_layout() # Or equivalently, "plt.tight_layout()" 44 | plt.subplots_adjust(wspace=0.1) 45 | 46 | ax = fig.add_subplot(1,1,1) 47 | 48 | name_title = { 49 | "gpt2_default": "GPT-2 Default", 50 | "gpt2_sink": "GPT-2 with Sink Token", 51 | "gpt2_attn_bias": "GPT-2 with Attention Bias", 52 | } 53 | 54 | colors = ["cornflowerblue", "mediumseagreen", "C4", "teal", "dimgrey"] 55 | x_axis = np.arange(mean.shape[-1])+1 56 | for i in range(3): 57 | ax.plot(x_axis, mean[i], label=f"Top {i+1}", color=colors[i], 58 | linestyle="-", marker="o", markerfacecolor='none', markersize=5) 59 | 60 | ax.set_title(name_title[model_name], fontsize=18, fontweight="bold") 61 | ax.set_ylabel("Magnitudes", fontsize=18) 62 | 63 | num_layers = mean.shape[1] 64 | xtick_label = [1, num_layers//4, num_layers//2, num_layers*3//4, num_layers] 65 | ax.set_xticks(xtick_label, xtick_label, fontsize=16) 66 | 67 | ax.set_xlabel('Layers', fontsize=18, labelpad=0.3) 68 | ax.tick_params(axis='x', which='major', pad=1.0) 69 | ax.tick_params(axis='y', which='major', pad=0.4) 70 | ax.grid(axis='x', color='0.75') 71 | plt.savefig(os.path.join(savedir, f"{model_name}_layerwise.png"), bbox_inches="tight", dpi=200) -------------------------------------------------------------------------------- /lib/model_dict.py: -------------------------------------------------------------------------------- 1 | CACHE_DIR_BASE = "./model_weights" 2 | 3 | MODEL_DICT_LLMs = { 4 | ### llama2 model 5 | "llama2_7b": { 6 | "model_id": "meta-llama/Llama-2-7b-hf", 7 | "cache_dir": CACHE_DIR_BASE 8 | }, 9 | "llama2_13b": { 10 | "model_id": "meta-llama/Llama-2-13b-hf", 11 | "cache_dir": CACHE_DIR_BASE 12 | }, 13 | "llama2_70b": { 14 | "model_id": "meta-llama/Llama-2-70b-hf", 15 | "cache_dir": CACHE_DIR_BASE 16 | }, 17 | 18 | ### llama2 chat model 19 | "llama2_7b_chat": { 20 | "model_id": "meta-llama/Llama-2-7b-chat-hf", 21 | "cache_dir": CACHE_DIR_BASE 22 | }, 23 | "llama2_13b_chat": { 24 | "model_id": "meta-llama/Llama-2-13b-chat-hf", 25 | "cache_dir": CACHE_DIR_BASE 26 | }, 27 | "llama2_70b_chat": { 28 | "model_id": "meta-llama/Llama-2-70b-chat-hf", 29 | "cache_dir": CACHE_DIR_BASE 30 | }, 31 | 32 | ### mistral model 33 | "mistral_7b": { 34 | "model_id": "mistralai/Mistral-7B-v0.1", 35 | "cache_dir": CACHE_DIR_BASE, 36 | }, 37 | "mistral_moe": { 38 | "model_id": "mistralai/Mixtral-8x7B-v0.1", 39 | "cache_dir": CACHE_DIR_BASE, 40 | }, 41 | "mistral_7b_instruct":{ 42 | "model_id": "mistralai/Mistral-7B-Instruct-v0.2", 43 | "cache_dir": CACHE_DIR_BASE, 44 | }, 45 | "mistral_moe_instruct": { 46 | "model_id": "mistralai/Mixtral-8x7B-Instruct-v0.1", 47 | 
"cache_dir": CACHE_DIR_BASE, 48 | }, 49 | 50 | ### phi-2 51 | "phi-2": { 52 | "model_id": "microsoft/phi-2", 53 | "cache_dir": CACHE_DIR_BASE, 54 | }, 55 | 56 | ### falcon model 57 | "falcon_7b": { 58 | "model_id": "tiiuae/falcon-7b", 59 | "cache_dir": CACHE_DIR_BASE, 60 | }, 61 | "falcon_40b": { 62 | "model_id": "tiiuae/falcon-40b", 63 | "cache_dir": CACHE_DIR_BASE, 64 | }, 65 | 66 | ### mpt model 67 | "mpt_7b": { 68 | "model_id": "mosaicml/mpt-7b", 69 | "cache_dir": CACHE_DIR_BASE, 70 | }, 71 | "mpt_30b": { 72 | "model_id": "mosaicml/mpt-30b", 73 | "cache_dir": CACHE_DIR_BASE, 74 | }, 75 | 76 | ### opt model 77 | "opt_125m": { 78 | "model_id": "facebook/opt-125m", 79 | "cache_dir": CACHE_DIR_BASE, 80 | }, 81 | "opt_350m": { 82 | "model_id": "facebook/opt-350m", 83 | "cache_dir": CACHE_DIR_BASE, 84 | }, 85 | "opt_1.3b": { 86 | "model_id": "facebook/opt-1.3b", 87 | "cache_dir": CACHE_DIR_BASE, 88 | }, 89 | "opt_2.7b": { 90 | "model_id": "facebook/opt-2.7b", 91 | "cache_dir": CACHE_DIR_BASE, 92 | }, 93 | "opt_7b": { 94 | "model_id": "facebook/opt-6.7b", 95 | "cache_dir": CACHE_DIR_BASE, 96 | }, 97 | "opt_13b": { 98 | "model_id": "facebook/opt-13b", 99 | "cache_dir": CACHE_DIR_BASE, 100 | }, 101 | "opt_30b": { 102 | "model_id": "facebook/opt-30b", 103 | "cache_dir": CACHE_DIR_BASE, 104 | }, 105 | "opt_66b": { 106 | "model_id": "facebook/opt-66b", 107 | "cache_dir": CACHE_DIR_BASE, 108 | }, 109 | 110 | ### gpt2 model 111 | "gpt2": { 112 | "model_id": "gpt2", 113 | "cache_dir": CACHE_DIR_BASE 114 | }, 115 | "gpt2_medium": { 116 | "model_id": "gpt2-medium", 117 | "cache_dir": CACHE_DIR_BASE 118 | }, 119 | "gpt2_large": { 120 | "model_id": "gpt2-large", 121 | "cache_dir": CACHE_DIR_BASE 122 | }, 123 | "gpt2_xl": { 124 | "model_id": "gpt2-xl", 125 | "cache_dir": CACHE_DIR_BASE 126 | }, 127 | } -------------------------------------------------------------------------------- /lib/eval_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import types 3 | 4 | import torch 5 | import torch.nn as nn 6 | from timm.utils import accuracy 7 | 8 | from .load_data import get_test_data 9 | 10 | 11 | def fix_reg_mean(args, model): 12 | """ 13 | Modifies the forward pass of each layer in a ViT model to set register token features to pre-computed means. 14 | """ 15 | def custom_layer_forward(self, x: torch.Tensor) -> torch.Tensor: 16 | for reg_id in range(4): 17 | reg_id = int(reg_id) 18 | cur_reg_feat = self.reg_feat[reg_id].clone() 19 | x[:,reg_id+1,:] = cur_reg_feat.reshape(1,-1).repeat(x.shape[0],1) 20 | 21 | x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) 22 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) 23 | return x 24 | 25 | # Load the pre-computed means of resgiter token features for the DINOv2-reg ViT-`args.model_size` 26 | imagenet10k_reg_feat_mean = torch.load(os.path.join(args.reg_feat_mean, f"{args.model_size}_imagemean_bs10k_trainaug.pt")) 27 | for layer_id in range(len(model.blocks)): 28 | layer = model.blocks[layer_id] 29 | layer.reg_feat = imagenet10k_reg_feat_mean[layer_id] 30 | 31 | # Replace the layer's forward function with the custom forward function defined above 32 | layer.forward = types.MethodType( 33 | custom_layer_forward, layer 34 | ) 35 | 36 | def setup_dinov2_model_for_eval(model, linear_head_weights): 37 | """ 38 | Configures a DINOv2 model for ImageNet evaluation by setting up a new linear head with given weights and a custom forward function. 
39 | 40 | Args: 41 | model: The DINOv2 pretrained ViT models. 42 | linear_head_weights: A dictionary containing the weights and bias for the new linear head. 43 | """ 44 | in_features, out_features = linear_head_weights["weight"].shape 45 | model.head = nn.Linear(in_features, out_features, bias=True) 46 | model.head.weight.data = linear_head_weights["weight"] 47 | model.head.bias.data = linear_head_weights["bias"] 48 | model.head.cuda() 49 | 50 | def custom_forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: 51 | x = torch.cat([x[:,0], x[:, self.num_prefix_tokens:].mean(dim=1)], dim=-1) 52 | x = self.fc_norm(x) 53 | x = self.head_drop(x) 54 | return self.head(x) 55 | 56 | model.forward_head = types.MethodType( 57 | custom_forward_head, model 58 | ) 59 | 60 | def test_imagenet(model, dataloader): 61 | acc = 0 62 | cnt = 0 63 | for batch_idx, batch in enumerate(dataloader): 64 | if batch_idx % 10 == 0 and batch_idx > 0: 65 | print(f"batch idx {batch_idx} acc {acc/cnt}") 66 | images = batch[0].cuda() 67 | target = batch[-1].cuda() 68 | 69 | with torch.no_grad(): 70 | output = model(images) 71 | acc1, _ = accuracy(output, target, topk=(1, 5)) 72 | acc += acc1 * images.shape[0] 73 | cnt += images.shape[0] 74 | 75 | return acc/cnt 76 | 77 | @torch.no_grad() 78 | def eval_ppl(dataset_name, model, tokenizer, seed): 79 | print(f"Evaluating on {dataset_name}") 80 | seqlen=4096 81 | testseq_list = get_test_data( 82 | dataset_name, seed=seed, tokenizer=tokenizer, seqlen=seqlen, device="cuda:0" 83 | ) 84 | 85 | nlls = [] 86 | with torch.no_grad(): 87 | for test_seq in testseq_list: 88 | lm_logits = model(test_seq).logits 89 | 90 | shift_logits = lm_logits[:, :-1, :].contiguous() ## shape: [1, 2047, 50272] 91 | shift_labels = test_seq[:, 1:] ## shape: shape: [1, 2047] 92 | 93 | loss_fct = nn.CrossEntropyLoss() 94 | loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1)) 95 | neg_log_likelihood = loss.float() * test_seq.numel() 96 | nlls.append(neg_log_likelihood) 97 | 98 | ppl = torch.exp(torch.stack(nlls).sum() / (len(testseq_list) * seqlen)) 99 | return ppl.item() -------------------------------------------------------------------------------- /main_vit.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from importlib.metadata import version 4 | 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | from torchvision import datasets 9 | 10 | import lib 11 | import monkey_patch as mp 12 | 13 | print('torch', version('torch')) 14 | print('transformers', version('transformers')) 15 | print('accelerate', version('accelerate')) 16 | print('# of gpus: ', torch.cuda.device_count()) 17 | 18 | if __name__ == "__main__": 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--model_family', type=str) 21 | parser.add_argument('--model_size', type=str) 22 | parser.add_argument("--layer_id", type=int, default=1) 23 | parser.add_argument('--exp1', action="store_true", help="plot 3d feature") 24 | parser.add_argument('--exp2', action="store_true", help="layerwise analysis") 25 | parser.add_argument('--exp3', action="store_true", help="test original and fix-reg-mean accuracy") 26 | parser.add_argument('--imagenet_dir', type=str, default="/home/mingjies/imagenet-data/val") 27 | parser.add_argument('--linear_head_path', type=str, default="/data/locus/project_data/project_data2/mingjies/dinov2") 28 | parser.add_argument('--reg_feat_mean', type=str, 
default="assets/reg_feat_mean/") 29 | parser.add_argument('--seed', type=int, default=0) 30 | parser.add_argument('--num_imgs_mean', type=int, default=10) 31 | parser.add_argument('--savedir', type=str) 32 | args = parser.parse_args() 33 | 34 | torch.manual_seed(args.seed) 35 | 36 | if not os.path.exists(args.savedir): 37 | os.makedirs(args.savedir) 38 | 39 | model, layers, val_transform = lib.load_vit(args) 40 | 41 | if args.exp1: 42 | layer_id = args.layer_id - 1 43 | layer = layers[layer_id] 44 | mp.enable_vit_custom_block(layer, layer_id) 45 | 46 | img_path = os.path.join("assets", f"bird.png") 47 | img = Image.open(img_path) 48 | img = val_transform(img).unsqueeze(0).cuda() 49 | 50 | with torch.no_grad(): 51 | output = model(img) 52 | 53 | feat_abs = layers[layer_id].feat.abs() 54 | 55 | lib.plot_3d_feat_vit(feat_abs, layer_id, args.model_family, args.model_size, args.savedir) 56 | # torch.save(stats, os.path.join(args.savedir, f"stats.pt")) 57 | 58 | elif args.exp2: 59 | for layer_id in range(len(layers)): 60 | layer = layers[layer_id] 61 | mp.enable_vit_custom_block(layer, layer_id) 62 | 63 | dataset = datasets.ImageFolder(args.imagenet_dir, transform=val_transform) 64 | 65 | stats = [] 66 | for img_idx in range(args.num_imgs_mean): 67 | print("img_idx", img_idx) 68 | images, target = dataset[img_idx] 69 | images = images.unsqueeze(0).cuda() 70 | 71 | with torch.no_grad(): 72 | output = model(images) 73 | 74 | layer_stats_np = np.zeros((4, len(layers))) 75 | for layer_id in range(len(layers)): 76 | feat_abs = layers[layer_id].feat.abs() 77 | sort_res = torch.sort(feat_abs.flatten(), descending=True) 78 | layer_stats_np[:3, layer_id] = sort_res.values[:3] 79 | layer_stats_np[3, layer_id] = torch.median(feat_abs) 80 | 81 | stats.append(layer_stats_np) 82 | 83 | lib.plot_layer_ax_vit(np.mean(stats, axis=0), args.model_family, args.model_size, args.savedir) 84 | 85 | elif args.exp3: 86 | linear_head = lib.load_dinov2_linear_head(args) 87 | lib.setup_dinov2_model_for_eval(model, linear_head) 88 | 89 | dataset = datasets.ImageFolder(args.imagenet_dir, transform=val_transform) 90 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=50, num_workers=8, pin_memory=False) 91 | 92 | f = open(os.path.join(args.savedir, "eval.txt"), "w") 93 | top1_acc = lib.test_imagenet(model, dataloader) 94 | print(f"{args.model_family} ViT-{args.model_size} original accuracy: {top1_acc}", file=f, flush=True) 95 | 96 | lib.fix_reg_mean(args, model) 97 | top1_acc = lib.test_imagenet(model, dataloader) 98 | print(f"{args.model_family} ViT-{args.model_size} fix-reg-mean accuracy: {top1_acc}", file=f, flush=True) -------------------------------------------------------------------------------- /lib/plot_utils_vit.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | # Configuration settings for matplotlib 8 | matplotlib.rcParams['pgf.texsystem'] = 'pdflatex' 9 | matplotlib.rcParams.update({ 10 | 'font.size': 18, 11 | 'axes.labelsize': 20, 12 | 'axes.titlesize': 24, 13 | 'figure.titlesize': 28 14 | }) 15 | matplotlib.rcParams['text.usetex'] = False 16 | 17 | def plot_3d_feat_vit_sub(ax, feat, layer_id, model_name, model_size): 18 | model_title={"dinov2_reg": f"DINOv2-reg ViT-{model_size}", "mistral_7b": "Mistral-7B", 19 | "llama2_13b": "LLaMA-2-13B", "llama2_70b": "LLaMA-2-70B", "mistral_moe":"Mixtral-8x7B"} 20 | 21 | num_channels = feat.shape[2] 22 | 23 | inp_seq = 
["CLS", "reg 1", "reg 2", "reg 3", "reg 4", 24 | "patch 1", "patch 2", "patch i", "patch n"] 25 | 26 | xbase_index = [0,1,2,3,4,5,7,9] 27 | num_tokens = len(xbase_index) 28 | xdata = np.array([xbase_index for i in range(num_channels)]) 29 | ydata = np.array([np.ones(num_tokens) * i for i in range(num_channels)]) 30 | zdata = feat[0,:num_tokens,:].abs().numpy().T 31 | ax.plot_wireframe(xdata, ydata, zdata, rstride=0, color="royalblue", linewidth=2.5) 32 | 33 | ax.set_title(model_title[model_name]+f", Layer {layer_id+1}", fontsize=20, fontweight="bold", y=1.015) 34 | 35 | ax.set_yticks([179, 999], [179, 999], fontsize=15, fontweight="heavy") 36 | 37 | xbase_index = [0,1,2,3,4,] 38 | inp_seq = ["CLS", "reg 1", "reg 2", "reg 3", "reg 4"] 39 | ax.set_xticks(xbase_index, inp_seq, rotation=60, fontsize=16) 40 | ax.tick_params(axis='x', which='major', pad=-4) 41 | plt.setp(ax.get_xticklabels(), rotation=50, ha="right", va="center", rotation_mode="anchor") 42 | 43 | ax.set_zticks([0, 500, 1000], ["0", "500", "1k"], fontsize=16) 44 | ax.get_xticklabels()[3].set_weight("heavy") 45 | plt.setp(ax.get_yticklabels(), ha="left", va="center",rotation_mode="anchor") 46 | plt.setp(ax.get_zticklabels(), ha="left", va="top", rotation_mode="anchor") 47 | 48 | ax.tick_params(axis='x', which='major', pad=-5) 49 | ax.tick_params(axis='y', which='major', pad=-3) 50 | ax.tick_params(axis='z', which='major', pad=-5) 51 | 52 | def plot_3d_feat_vit(feat, layer_id, model_name, model_size, savedir): 53 | fig = plt.figure(figsize=(8,6)) 54 | fig.tight_layout() # Or equivalently, "plt.tight_layout()" 55 | plt.subplots_adjust(wspace=0.) 56 | 57 | ax = fig.add_subplot(1,1, 1, projection='3d') 58 | plot_3d_feat_vit_sub(ax, feat, layer_id, model_name, model_size) 59 | plt.savefig(os.path.join(savedir, f"{model_name}_{model_size}_layer_{layer_id+1}.png"), bbox_inches="tight", dpi=200) 60 | 61 | 62 | def plot_layer_ax_vit_sub(ax, mean, model_family, model_size, colors=["royalblue", "darkorange", "forestgreen", "black"]): 63 | model_title={"dinov2_reg": "DINOv2-reg", 64 | "dinov2": "DINOv2", "mae": "MAE", "open_clip": "Open CLIP", "openai_clip": "OpenAI CLIP", 65 | "vit_orig": "ViT", "samvit": "SAM-ViT"} 66 | 67 | x_axis = np.arange(mean.shape[-1])+1 68 | for i in range(3): 69 | ax.plot(x_axis, mean[i], label=f"Top {i+1}", color=colors[i], 70 | linestyle="-", marker="o", markerfacecolor='none', markersize=5) 71 | 72 | ax.plot(x_axis, mean[-1], label=f"median", color=colors[-1], 73 | linestyle="-", marker="v", markerfacecolor='none', markersize=5) 74 | 75 | ax.set_title(model_title[model_family]+f" ViT-{model_size}", fontsize=18, fontweight="bold") 76 | ax.set_ylabel("Magnitudes", fontsize=18) 77 | 78 | num_layers = mean.shape[1] 79 | xtick_label = [1, num_layers//4, num_layers//2, num_layers*3//4, num_layers] 80 | ax.set_xticks(xtick_label, xtick_label, fontsize=16) 81 | 82 | ax.set_xlabel('Layers', fontsize=18, labelpad=4.0) 83 | ax.tick_params(axis='x', which='major', pad=2.0) 84 | ax.tick_params(axis='y', which='major', pad=0.4) 85 | ax.grid(axis='x', color='0.75') 86 | ax.grid(axis='y', color='0.75') 87 | 88 | def plot_layer_ax_vit(mean, model_family, model_size, savedir): 89 | fig = plt.figure(figsize=(8,6)) 90 | fig.tight_layout() # Or equivalently, "plt.tight_layout()" 91 | plt.subplots_adjust(wspace=0.) 
92 | 93 | ax = fig.add_subplot(1,1, 1) 94 | plot_layer_ax_vit_sub(ax, mean, model_family, model_size) 95 | plt.savefig(os.path.join(savedir, f"{model_family}_{model_size}.png"), bbox_inches="tight", dpi=200) -------------------------------------------------------------------------------- /lib/load_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import timm 5 | from transformers import AutoTokenizer, AutoModelForCausalLM 6 | 7 | from .model_dict import MODEL_DICT_LLMs 8 | 9 | 10 | def load_llm(args): 11 | print(f"loading model {args.model}") 12 | model_name, cache_dir = MODEL_DICT_LLMs[args.model]["model_id"], MODEL_DICT_LLMs[args.model]["cache_dir"] 13 | 14 | if "falcon" in args.model or "mpt" in args.model or "phi" in args.model: 15 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, cache_dir=cache_dir, low_cpu_mem_usage=True, device_map="auto", trust_remote_code=True, token=args.access_token) 16 | elif "mistral" in args.model or "pythia" in args.model: 17 | model = AutoModelForCausalLM.from_pretrained(model_name, revision=args.revision, torch_dtype=torch.float16, cache_dir=cache_dir, low_cpu_mem_usage=True, device_map="auto", token=args.access_token) 18 | else: 19 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, cache_dir=cache_dir, low_cpu_mem_usage=True, device_map="auto", token=args.access_token) 20 | model.eval() 21 | 22 | if "mpt" in args.model or "pythia" in args.model: 23 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, token=args.access_token) 24 | else: 25 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, token=args.access_token) 26 | 27 | device = torch.device("cuda:0") 28 | if "mpt_30b" in args.model: 29 | device = model.hf_device_map["transformer.wte"] 30 | elif "30b" in args.model or "65b" in args.model or "70b" in args.model or "40b" in args.model: # for 30b and 65b we use device_map to load onto multiple A6000 GPUs, thus the processing here. 
31 | device = torch.device("cuda:"+str(model.hf_device_map["lm_head"])) 32 | 33 | if "llama2_13b" == args.model: 34 | # device = torch.device("cuda:"+str(model.hf_device_map["lm_head"])) 35 | device = torch.device("cuda:1") 36 | 37 | seq_len=4096 38 | if "llama" in args.model or "mistral" in args.model: 39 | layers = model.model.layers 40 | hidden_size = model.config.hidden_size 41 | elif "falcon" in args.model: 42 | layers = model.transformer.h 43 | hidden_size = model.config.hidden_size 44 | elif "mpt" in args.model: 45 | layers = model.transformer.blocks 46 | hidden_size = model.config.d_model 47 | seq_len=2048 48 | elif "opt" in args.model: 49 | layers = model.model.decoder.layers 50 | hidden_size = model.config.hidden_size 51 | seq_len = 2048 52 | elif "gpt2" in args.model: 53 | layers = model.transformer.h 54 | hidden_size = model.transformer.embed_dim 55 | seq_len = 1024 56 | elif "pythia" in args.model: 57 | layers = model.gpt_neox.layers 58 | hidden_size = model.gpt_neox.config.hidden_size 59 | seq_len = 2048 60 | elif "phi-2" in args.model: 61 | layers = model.model.layers 62 | hidden_size = model.config.hidden_size 63 | 64 | return model, tokenizer, device, layers, hidden_size, seq_len 65 | 66 | def load_vit(args): 67 | if args.model_family == "mae": 68 | patch_size=14 if args.model_size == "huge" else 16 69 | model = timm.create_model(f'vit_{args.model_size}_patch{patch_size}_224.mae', pretrained=True) 70 | elif args.model_family == "openai_clip": 71 | patch_size=14 if args.model_size == "large" else 16 72 | model = timm.create_model(f"vit_{args.model_size}_patch{patch_size}_clip_224.openai", pretrained=True) 73 | elif args.model_family == "dinov2": 74 | model = timm.create_model(f"vit_{args.model_size}_patch14_dinov2.lvd142m", pretrained=True, num_classes=1000) 75 | elif args.model_family == "dinov2_reg": 76 | model = timm.create_model(f"vit_{args.model_size}_patch14_reg4_dinov2.lvd142m", pretrained=True, num_classes=1000) 77 | 78 | model = model.cuda() 79 | model = model.eval() 80 | 81 | layers = model.blocks 82 | 83 | data_config = timm.data.resolve_model_data_config(model) 84 | val_transform = timm.data.create_transform(**data_config, is_training=False) 85 | 86 | return model, layers, val_transform 87 | 88 | def load_dinov2_linear_head(args): 89 | assert "dinov2" in args.model_family, "this function is only for dinov2 models" 90 | if args.model_family == "dinov2_reg": 91 | linear_head_path = os.path.join(args.linear_head_path, f"dinov2_vit{args.model_size[0]}14_reg4_linear_head.pth") 92 | elif args.model_family == "dinov2": 93 | linear_head_path = os.path.join(args.linear_head_path, f"dinov2_vit{args.model_size[0]}14_linear_head.pth") 94 | 95 | linear_head_weights = torch.load(linear_head_path) 96 | return linear_head_weights -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Massive Activations in Large Language Models 2 | 3 | Official PyTorch implementation of our paper: 4 | 5 | **Massive Activations in Large Language Models**
6 | [Mingjie Sun](https://eric-mingjie.github.io/), [Xinlei Chen](https://xinleic.xyz/), [J. Zico Kolter](https://zicokolter.com/), [Zhuang Liu](https://liuzhuang13.github.io/)
7 | Carnegie Mellon University, Meta AI Research and Bosch Center for AI
8 | [Paper](https://arxiv.org/abs/2402.17762) - [Project page](https://eric-mingjie.github.io/massive-activations/index.html) 9 | 10 | Most of the experiments in this paper were done on one A6000 GPU. 11 | 12 | --- 13 |


17 | 18 | This paper studies the existence of *massive activations* in Large Language Models (LLMs). These activations have significantly larger magnitudes than other activations while on the other hand are extremely few in quantity. 19 | 20 | ## This repository 21 | 22 | ### Setup 23 | Installation instructions can be found in [INSTALL.md](INSTALL.md). 24 | 25 | ### Outline 26 | The contents of this repository are as follows: 27 | 28 | * [lib](lib) contains the util function for loading models, plotting figures and evaluation. 29 | * [monkey_patch](monkey_patch) contains the code for monkey patching LLMs with custom forward function, with a goal of collecting internal activation and attention statistics. 30 | * [gpt-2](gpt-2) contains the code for training GPT-2 with explicit attention biases. 31 | * [main_llm.py](main_llm.py) contains the code for reproducing our experiments on LLMs. 32 | * [main_vit.py](main_vit.py) contains the code for reproducing our experiments on ViTs. 33 | 34 | ### Large Language Models (LLMs) 35 | 36 | * We provide an example command to visualize a hidden state feature on the residual stream: 37 | ```sh 38 | CUDA_VISIBLE_DEVICES=0 python main_llm.py --model llama2_7b \ 39 | --exp1 --layer_id 2 \ 40 | --savedir results/llm/3d_feat_vis/ 41 | ``` 42 | Running this command will visualize the output feature of layer 2 in LLaMA-2-7B, on the input prompt "*Summer is warm. Winter is cold.\n*". The resulting visualizations are saved in `results/llm/3d_feat_vis/`. 43 | 44 | For some LLMs, e.g., LLaMA2-7B, you need to set the argument `--access-token` in order to access the weights. 45 | 46 | * We provide an example command to visualize the layerwise top 3 largest activation magnitudes: 47 | ```sh 48 | CUDA_VISIBLE_DEVICES=0 python main_llm.py --model llama2_7b \ 49 | --exp2 \ 50 | --savedir results/llm/layerwise/ 51 | ``` 52 | Running this command will visualize the per layer top activation magnitudes. The resulting visualizations are saved in `results/llm/layerwise`. 53 | 54 | * We provide an example command to run the intervention analysis: 55 | ```sh 56 | CUDA_VISIBLE_DEVICES=0 python main_llm.py --model llama2_7b \ 57 | --exp3 \ 58 | --reset_type set_zero \ 59 | --layer_id 2 \ 60 | --savedir results/llm/intervention_analysis/ 61 | ``` 62 | Here the argument `--reset_type` can be either `set_zero` or `set_mean`. This command will zero the massive activations in the output feature of layer 2 in LLaMA-2-7B. The evaluation results are saved in `results/llm/intervention_analysis`. 63 | 64 | * We provide an example command for attention visualization: 65 | ```sh 66 | CUDA_VISIBLE_DEVICES=0 python main_llm.py --model llama2_7b \ 67 | --exp4 \ 68 | --layer_id 3 \ 69 | --savedir results/llm/attn_vis/ 70 | ``` 71 | Running this command will visualize the attention logits (average over attention heads) in layer 3 of LLaMA-2-7B. The visualizations are saved in `results/llm/attn_vis/`. 
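Across these experiments, massive activations are located by sorting the absolute values of a captured hidden state and comparing the few largest entries against the median magnitude, which is why the layerwise plots track the top-3 magnitudes alongside the median. A minimal sketch of that bookkeeping (mirroring the `--exp2` logic in [main_llm.py](main_llm.py); the recovery of `(token, channel)` positions from the flattened sort is added here for illustration):

```py
import torch

def top_activation_stats(feat, k=3):
    # feat: a layer's output hidden states of shape [1, seq_len, hidden_dim],
    # captured via the monkey-patched forward pass (layers[layer_id].feat).
    feat_abs = feat.abs()
    sort_res = torch.sort(feat_abs.flatten(), descending=True)
    top_vals = sort_res.values[:k]
    # map flattened indices back to (token position, channel) pairs
    hidden_dim = feat_abs.shape[2]
    top_pos = [(idx.item() // hidden_dim, idx.item() % hidden_dim)
               for idx in sort_res.indices[:k]]
    return top_vals, top_pos, torch.median(feat_abs)
```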
72 | 73 | ### Vision Transformers (ViTs) 74 | 75 | * We provide an example command for visualizing the activation magnitudes of the output feature of an intermediate layer: 76 | ```sh 77 | CUDA_VISIBLE_DEVICES=0 python main_vit.py --model_family dinov2_reg --model_size giant \ 78 | --exp1 \ 79 | --layer_id 40 \ 80 | --savedir results/vit/3d_feat_vis/ 81 | ``` 82 | 83 | * We provide an example command for visualizing the layer-wise largest activation magnitudes: 84 | ```sh 85 | CUDA_VISIBLE_DEVICES=0 python main_vit.py --model_family dinov2_reg --model_size giant \ 86 | --exp2 \ 87 | --savedir results/vit/layerwise/ 88 | ``` 89 | 90 | * For reproducing the results of `Fix-Reg-Mean` on [DINOv2-reg](https://arxiv.org/abs/2309.16588), run the following commands: 91 | ```sh 92 | for model_size in small base large giant 93 | do 94 | CUDA_VISIBLE_DEVICES=0 python main_vit.py \ 95 | --model_family dinov2_reg --model_size ${model_size} --exp3 \ 96 | --reg_feat_mean assets/reg_feat_mean \ 97 | --imagenet_dir [Path to ImageNet validation set] \ 98 | --savedir results/vit/exp4/dinov2_reg_${model_size} 99 | done 100 | ``` 101 | The argument `--reg_feat_mean` corresponds to the directory containing the mean of the register features at all layers collected over 10k ImageNet training images with data augmentations. 102 | 103 | Results 104 | | DINOv2-reg | ViT-S | ViT-B | ViT-L | ViT-G | 105 | |------------------|----------|----------|--------|-------| 106 | | Original | 81.9 | 84.8 | 86.3 | 87.0 | 107 | | `Fix-Reg-Mean` | 81.7 | 85.0 | 86.2 | 87.0 | 108 | 109 | ## License 110 | This project is released under the MIT license. Please see the [LICENSE](LICENSE) file for more information. 111 | 112 | ## Reference 113 | ```bibtex 114 | @article{sun2024massive, 115 | title={Massive Activations in Large Language Models}, 116 | author={Sun, Mingjie and Chen, Xinlei and Kolter, J. 
Zico and Liu, Zhuang}, 117 | year={2024}, 118 | journal={arXiv preprint arXiv:2402.17762} 119 | } 120 | ``` -------------------------------------------------------------------------------- /main_llm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from importlib.metadata import version 4 | 5 | import numpy as np 6 | import torch 7 | 8 | import lib 9 | import monkey_patch as mp 10 | 11 | print('torch', version('torch')) 12 | print('transformers', version('transformers')) 13 | print('accelerate', version('accelerate')) 14 | print('# of gpus: ', torch.cuda.device_count()) 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | 19 | parser.add_argument('--model', type=str, help='LLaMA model') 20 | parser.add_argument("--dataset", type=str, default="wikitext") 21 | parser.add_argument('--seed',type=int, default=1, help='Seed for sampling the calibration data.') 22 | parser.add_argument("--revision", type=str, default="main") 23 | 24 | parser.add_argument("--exp1", action="store_true", help="plot 3d feature") 25 | parser.add_argument("--exp2", action="store_true", help="layerwise analysis") 26 | parser.add_argument("--exp3", action="store_true", help="intervention analysis") 27 | parser.add_argument("--exp4", action="store_true", help="attention visualization") 28 | parser.add_argument("--layer_id", type=int, default=1) 29 | parser.add_argument("--reset_type", type=str, default="set_mean") 30 | parser.add_argument("--access_token", type=str, default="type in your access token here") 31 | parser.add_argument("--savedir", type=str) 32 | args = parser.parse_args() 33 | 34 | np.random.seed(args.seed) 35 | torch.manual_seed(args.seed) 36 | if not os.path.exists(args.savedir): 37 | os.makedirs(args.savedir) 38 | 39 | model, tokenizer, device, layers, hidden_size, seq_len = lib.load_llm(args) 40 | print("use device ", device) 41 | 42 | if args.exp1: ### visualize the output feature of a layer in LLMs 43 | layer_id = args.layer_id - 1 44 | if "llama2" in args.model: 45 | mp.enable_llama_custom_decoderlayer(layers[layer_id], layer_id) 46 | elif "mistral" in args.model: 47 | mp.enable_mistral_custom_decoderlayer(layers[layer_id], layer_id) 48 | elif "phi-2" in args.model: 49 | mp.enable_phi2_custom_decoderlayer(layers[layer_id], layer_id) 50 | else: 51 | raise ValueError(f"model {args.model} not supported") 52 | 53 | stats = {} 54 | seq = "Summer is warm. Winter is cold." 
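        # tokenize the probe prompt without special tokens; the decoded tokens below
        # are reused as the x-axis labels of the 3D feature plot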
55 | valenc = tokenizer(seq, return_tensors='pt', add_special_tokens=False).input_ids.to(device) 56 | 57 | with torch.no_grad(): 58 | model(valenc) 59 | 60 | seq_decoded = [] 61 | for i in range(valenc.shape[1]): 62 | seq_decoded.append(tokenizer.decode(valenc[0,i].item())) 63 | 64 | stats[f"seq"] = seq_decoded 65 | feat_abs = layers[layer_id].feat.abs() 66 | 67 | stats[f"{layer_id}"] = feat_abs 68 | 69 | lib.plot_3d_feat(stats, layer_id, args.model, args.savedir) 70 | 71 | elif args.exp2: ### visualize the layerwise top activation magnitudes 72 | for layer_id in range(len(layers)): 73 | layer = layers[layer_id] 74 | if "llama2" in args.model: 75 | mp.enable_llama_custom_decoderlayer(layer, layer_id) 76 | elif "mistral" in args.model: 77 | mp.enable_mistral_custom_decoderlayer(layer, layer_id) 78 | elif "phi-2" in args.model: 79 | mp.enable_phi2_custom_decoderlayer(layers[layer_id], layer_id) 80 | else: 81 | raise ValueError(f"model {args.model} not supported") 82 | 83 | testseq_list = lib.get_data(tokenizer, nsamples=10, seqlen=seq_len, device=device) 84 | 85 | stats = [] 86 | for seqid, testseq in enumerate(testseq_list): 87 | print(f"processing seq {seqid}") 88 | with torch.no_grad(): 89 | model(testseq) 90 | 91 | seq_np = np.zeros((4, len(layers))) 92 | for layer_id in range(len(layers)): 93 | feat_abs = layers[layer_id].feat.abs() 94 | sort_res = torch.sort(feat_abs.flatten(), descending=True) 95 | seq_np[:3, layer_id] = sort_res.values[:3] 96 | seq_np[3, layer_id] = torch.median(feat_abs) 97 | 98 | stats.append(seq_np) 99 | 100 | lib.plot_layer_ax(stats, args.model, args.savedir) 101 | 102 | elif args.exp3: ### intervention analysis 103 | layer = layers[args.layer_id-1] 104 | lib.setup_intervene_hook(layer, args.model, args.reset_type) 105 | 106 | f = open(os.path.join(args.savedir, f"{args.model}_{args.reset_type}.log"), "a") 107 | 108 | ds_list = ["wikitext", "c4", "pg19"] 109 | res = {} 110 | for ds_name in ds_list: 111 | ppl = lib.eval_ppl(ds_name, model, tokenizer, args.seed, device) 112 | res[ds_name] = ppl 113 | print(f"{ds_name} ppl: {ppl}", file=f, flush=True) 114 | 115 | elif args.exp4: 116 | layer_id = args.layer_id - 1 117 | if "llama2" in args.model: 118 | modified_attn_layer = mp.enable_llama_custom_attention(layers[layer_id], layer_id) 119 | elif "mistral" in args.model: 120 | modified_attn_layer = mp.enable_mistral_custom_attention(layers[layer_id], layer_id) 121 | elif "phi-2" in args.model: 122 | modified_attn_layer = mp.enable_phi2_custom_attention(layers[layer_id], layer_id) 123 | else: 124 | raise ValueError(f"model {args.model} not supported") 125 | 126 | seq = "The following are multiple choice questions (with answers) about machine learning.\n\n A 6-sided die is rolled 15 times and the results are: side 1 comes up 0 times;" 127 | valenc = tokenizer(seq, return_tensors='pt', add_special_tokens=False).input_ids.to(device) 128 | 129 | with torch.no_grad(): 130 | model(valenc) 131 | 132 | attn_logit = layers[layer_id].self_attn.attn_logits.detach().cpu() 133 | lib.plot_attn(attn_logit, args.model, layer_id, args.savedir) -------------------------------------------------------------------------------- /lib/plot_utils_llm.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import seaborn as sns 7 | 8 | matplotlib.rcParams['pgf.texsystem'] = 'pdflatex' 9 | matplotlib.rcParams.update({ 10 | 'font.size': 18, 11 | 'axes.labelsize': 20, 
12 | 'axes.titlesize': 24, 13 | 'figure.titlesize': 28 14 | }) 15 | matplotlib.rcParams['text.usetex'] = False 16 | 17 | MODEL_TITLE_DICT={"llama2_7b": "LLaMA-2-7B", "mistral_7b": "Mistral-7B", 18 | "llama2_13b_chat": "LLaMA-2-13B-chat", "llama2_70b_chat": "LLaMA-2-70B-chat", 19 | "llama2_7b_chat": "LLaMA-2-7B-chat", "llama2_13b": "LLaMA-2-13B", "llama2_70b": "LLaMA-2-70B", 20 | "mistral_moe":"Mixtral-8x7B", "falcon_7b": "Falcon-7B", "falcon_40b": "Falcon-40B", "phi-2": "Phi-2", 21 | "opt_7b":"OPT-7B", "opt_13b": "OPT-13B", "opt_30b": "OPT-30B", "opt_66b": "OPT-66B", 22 | "mpt_7b": "MPT-7B", "mpt_30b": "MPT-30B", "pythia_7b": "Pythia-7B", "pythia_12b": "Pythia-12B", 23 | "gpt2": "GPT-2", "gpt2_large": "GPT-2-Large", "gpt2_xl": "GPT-2-XL", "gpt2_medium": "GPT-2-Medium", 24 | "mistral_moe_instruct": "Mixtral-8x7B-Instruct", "mistral_7b_instruct": "Mistral-7B-Instruct"} 25 | 26 | 27 | def plot_3d_feat_sub(ax, obj, seq_id, layer_id, model_name): 28 | num_tokens = len(obj[f"seq"]) 29 | num_channels = obj[f"{layer_id}"].shape[2] 30 | inp_seq = obj[f"seq"] 31 | inp_seq = [x if x != "<0x0A>" else r"\n" for x in inp_seq] 32 | xdata = np.array([np.linspace(0,num_tokens-1,num_tokens) for i in range(num_channels)]) 33 | ydata = np.array([np.ones(num_tokens) * i for i in range(num_channels)]) 34 | zdata = obj[f"{layer_id}"][0].abs().numpy().T 35 | ax.plot_wireframe(xdata, ydata, zdata, rstride=0, color="royalblue", linewidth=2.5) 36 | 37 | ax.set_xticks(np.linspace(0,num_tokens-1,num_tokens), inp_seq, 38 | rotation=50, fontsize=16) 39 | 40 | ax.set_zticks([0, 1000, 2000], ["0", "1k", "2k"], fontsize=15) 41 | ax.set_yticks([1415, 2533], [1415, 2533], fontsize=15, fontweight="heavy") 42 | ax.get_xticklabels()[0].set_weight("heavy") 43 | 44 | if seq_id in [0, 1]: 45 | ax.get_xticklabels()[3].set_weight("heavy") 46 | 47 | ax.set_title(MODEL_TITLE_DICT[model_name], fontsize=18, fontweight="bold", y=1.015) 48 | 49 | plt.setp(ax.get_xticklabels(), rotation=50, ha="right", va="center", 50 | rotation_mode="anchor") 51 | plt.setp(ax.get_yticklabels(), ha="left", 52 | rotation_mode="anchor") 53 | 54 | ax.tick_params(axis='x', which='major', pad=-4) 55 | ax.tick_params(axis='y', which='major', pad=-5) 56 | ax.tick_params(axis='z', which='major', pad=-1) 57 | ax.set_zlim(0,2400) 58 | 59 | def plot_3d_feat(obj, layer_id, model_name, savedir): 60 | fig = plt.figure(figsize=(14,6)) 61 | fig.tight_layout() # Or equivalently, "plt.tight_layout()" 62 | plt.subplots_adjust(wspace=0.13) 63 | 64 | # for i in range(3): 65 | ax = fig.add_subplot(1,1, 1, projection='3d') 66 | plot_3d_feat_sub(ax, obj, 0, layer_id, model_name) 67 | plt.savefig(os.path.join(savedir, f"{model_name}_layer_{layer_id+1}.png"), bbox_inches="tight", dpi=200) 68 | 69 | 70 | def plot_layer_ax_sub(ax, mean, model_name): 71 | colors = ["cornflowerblue", "mediumseagreen", "C4", "teal", "dimgrey"] 72 | 73 | x_axis = np.arange(mean.shape[-1])+1 74 | for i in range(3): 75 | ax.plot(x_axis, mean[i], label=f"Top {i+1}", color=colors[i], 76 | linestyle="-", marker="o", markerfacecolor='none', markersize=5) 77 | 78 | ax.plot(x_axis, mean[-1], label=f"Median", color=colors[-1], 79 | linestyle="-", marker="v", markerfacecolor='none', markersize=5) 80 | 81 | ax.set_title(MODEL_TITLE_DICT[model_name], fontsize=18, fontweight="bold") 82 | 83 | num_layers = mean.shape[1] 84 | xtick_label = [1, num_layers//4, num_layers//2, num_layers*3//4, num_layers] 85 | ax.set_xticks(xtick_label, xtick_label, fontsize=16) 86 | 87 | ax.set_xlabel('Layers', fontsize=18, labelpad=0.8) 
88 | ax.set_ylabel("Magnitudes", fontsize=18) 89 | ax.tick_params(axis='x', which='major', pad=1.0) 90 | ax.tick_params(axis='y', which='major', pad=0.4) 91 | ax.grid(axis='x', color='0.75') 92 | 93 | def plot_layer_ax(obj, model_name, savedir): 94 | fig, axs = plt.subplots(nrows=1, ncols=1, figsize=(7.5, 4.5)) 95 | fig.tight_layout() # Or equivalently, "plt.tight_layout()" 96 | plt.subplots_adjust(wspace=0.13) 97 | 98 | mean = np.mean(obj,axis=0) 99 | plot_layer_ax_sub(axs, mean, model_name) 100 | leg = axs.legend( 101 | loc='center', bbox_to_anchor=(0.5, -0.25), 102 | ncol=4, fancybox=True, prop={'size': 14} 103 | ) 104 | leg.get_frame().set_edgecolor('silver') 105 | leg.get_frame().set_linewidth(1.0) 106 | 107 | plt.savefig(os.path.join(savedir, f"{model_name}.png"), bbox_inches="tight", dpi=200) 108 | 109 | 110 | def plot_attn_sub(ax, corr, model_name, layer_id): 111 | mask = np.zeros_like(corr) 112 | mask[np.triu_indices_from(mask, k=1)] = True 113 | sns.heatmap(corr, mask=mask, square=True, ax=ax, 114 | cmap="YlGnBu",cbar_kws={"shrink": 1.0, "pad": 0.01, "aspect":50}) 115 | 116 | ax.set_facecolor("whitesmoke") 117 | cax = ax.figure.axes[-1] 118 | cax.tick_params(labelsize=18) 119 | 120 | ax.tick_params(axis='x', which='major') 121 | ax.set(xticklabels=[]) 122 | ax.set(yticklabels=[]) 123 | ax.tick_params(left=False, bottom=False) 124 | ax.set_title(f"{MODEL_TITLE_DICT[model_name]}, Layer {layer_id+1}", fontsize=24, fontweight="bold") 125 | 126 | 127 | def plot_attn(attn_logits, model_name, layer_id, savedir): 128 | fig, axs = plt.subplots(nrows=1, ncols=1, figsize=(8, 4.75)) 129 | fig.tight_layout() # Or equivalently, "plt.tight_layout()" 130 | plt.subplots_adjust(wspace=0.15) 131 | 132 | corr = attn_logits.numpy()[0].mean(0) 133 | corr = corr.astype("float64") 134 | 135 | plot_attn_sub(axs, corr, model_name, layer_id) 136 | plt.savefig(os.path.join(savedir, f"{model_name}_layer{layer_id+1}.pdf"), bbox_inches="tight", dpi=200) -------------------------------------------------------------------------------- /lib/plot_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | import seaborn as sns 7 | import matplotlib 8 | import matplotlib.gridspec as gridspec 9 | from matplotlib.patheffects import withStroke 10 | from collections import defaultdict 11 | from mpl_toolkits import mplot3d 12 | from mpl_toolkits.mplot3d import axes3d 13 | 14 | matplotlib.rcParams['pgf.texsystem'] = 'pdflatex' 15 | matplotlib.rcParams.update({ 16 | # 'font.family': 'Arial', 17 | 'font.size': 18, 18 | 'axes.labelsize': 20, 19 | 'axes.titlesize': 24, 20 | 'figure.titlesize': 28 21 | }) 22 | matplotlib.rcParams['text.usetex'] = False 23 | 24 | MODEL_TITLE_DICT={"llama2_7b": "LLaMA-2-7B", "mistral_7b": "Mistral-7B", 25 | "llama2_13b_chat": "LLaMA-2-13B-chat", "llama2_70b_chat": "LLaMA-2-70B-chat", 26 | "llama2_7b_chat": "LLaMA-2-7B-chat", "llama2_13b": "LLaMA-2-13B", "llama2_70b": "LLaMA-2-70B", 27 | "mistral_moe":"Mixtral-8x7B", "falcon_7b": "Falcon-7B", "falcon_40b": "Falcon-40B", "phi-2": "Phi-2", 28 | "opt_7b":"OPT-7B", "opt_13b": "OPT-13B", "opt_30b": "OPT-30B", "opt_66b": "OPT-66B", 29 | "mpt_7b": "MPT-7B", "mpt_30b": "MPT-30B", "pythia_7b": "Pythia-7B", "pythia_12b": "Pythia-12B", 30 | "gpt2": "GPT-2", "gpt2_large": "GPT-2-Large", "gpt2_xl": "GPT-2-XL", "gpt2_medium": "GPT-2-Medium", 31 | "mistral_moe_instruct": "Mixtral-8x7B-Instruct", "mistral_7b_instruct": 
"Mistral-7B-Instruct"} 32 | 33 | 34 | def plot_3d_feat_sub(ax, obj, seq_id, layer_id, model_name): 35 | num_tokens = len(obj[f"seq"]) 36 | num_channels = obj[f"{layer_id}"].shape[2] 37 | inp_seq = obj[f"seq"] 38 | inp_seq = [x if x != "<0x0A>" else r"\n" for x in inp_seq] 39 | xdata = np.array([np.linspace(0,num_tokens-1,num_tokens) for i in range(num_channels)]) 40 | ydata = np.array([np.ones(num_tokens) * i for i in range(num_channels)]) 41 | zdata = obj[f"{layer_id}"][0].abs().numpy().T 42 | ax.plot_wireframe(xdata, ydata, zdata, rstride=0, color="royalblue", linewidth=2.5) 43 | 44 | ax.set_xticks(np.linspace(0,num_tokens-1,num_tokens), inp_seq, 45 | rotation=50, fontsize=16) 46 | 47 | ax.set_zticks([0, 1000, 2000], ["0", "1k", "2k"], fontsize=15) 48 | ax.set_yticks([1415, 2533], [1415, 2533], fontsize=15, fontweight="heavy") 49 | ax.get_xticklabels()[0].set_weight("heavy") 50 | 51 | if seq_id in [0, 1]: 52 | ax.get_xticklabels()[3].set_weight("heavy") 53 | 54 | ax.set_title(MODEL_TITLE_DICT[model_name], fontsize=18, fontweight="bold", y=1.015) 55 | 56 | plt.setp(ax.get_xticklabels(), rotation=50, ha="right", va="center", 57 | rotation_mode="anchor") 58 | plt.setp(ax.get_yticklabels(), ha="left", 59 | rotation_mode="anchor") 60 | 61 | ax.tick_params(axis='x', which='major', pad=-4) 62 | ax.tick_params(axis='y', which='major', pad=-5) 63 | ax.tick_params(axis='z', which='major', pad=-1) 64 | ax.set_zlim(0,2400) 65 | 66 | def plot_3d_feat(obj, layer_id, model_name, savedir): 67 | fig = plt.figure(figsize=(14,6)) 68 | fig.tight_layout() # Or equivalently, "plt.tight_layout()" 69 | plt.subplots_adjust(wspace=0.13) 70 | 71 | # for i in range(3): 72 | ax = fig.add_subplot(1,1, 1, projection='3d') 73 | plot_3d_feat_sub(ax, obj, 0, layer_id, model_name) 74 | plt.savefig(os.path.join(savedir, f"{model_name}_layer_{layer_id+1}.png"), bbox_inches="tight", dpi=200) 75 | 76 | 77 | def plot_layer_ax_sub(ax, mean, model_name): 78 | colors = ["cornflowerblue", "mediumseagreen", "C4", "teal", "dimgrey"] 79 | 80 | x_axis = np.arange(mean.shape[-1])+1 81 | for i in range(3): 82 | ax.plot(x_axis, mean[i], label=f"Top {i+1}", color=colors[i], 83 | linestyle="-", marker="o", markerfacecolor='none', markersize=5) 84 | 85 | ax.plot(x_axis, mean[-1], label=f"Median", color=colors[-1], 86 | linestyle="-", marker="v", markerfacecolor='none', markersize=5) 87 | 88 | ax.set_title(MODEL_TITLE_DICT[model_name], fontsize=18, fontweight="bold") 89 | 90 | num_layers = mean.shape[1] 91 | xtick_label = [1, num_layers//4, num_layers//2, num_layers*3//4, num_layers] 92 | ax.set_xticks(xtick_label, xtick_label, fontsize=16) 93 | 94 | ax.set_xlabel('Layers', fontsize=18, labelpad=0.8) 95 | ax.set_ylabel("Magnitudes", fontsize=18) 96 | ax.tick_params(axis='x', which='major', pad=1.0) 97 | ax.tick_params(axis='y', which='major', pad=0.4) 98 | ax.grid(axis='x', color='0.75') 99 | 100 | def plot_layer_ax(obj, model_name, savedir): 101 | fig, axs = plt.subplots(nrows=1, ncols=1, figsize=(7.5, 4.5)) 102 | fig.tight_layout() # Or equivalently, "plt.tight_layout()" 103 | plt.subplots_adjust(wspace=0.13) 104 | 105 | mean = np.mean(obj,axis=0) 106 | plot_layer_ax_sub(axs, mean, model_name) 107 | leg = axs.legend( 108 | loc='center', bbox_to_anchor=(0.5, -0.25), 109 | ncol=4, fancybox=True, prop={'size': 14} 110 | ) 111 | leg.get_frame().set_edgecolor('silver') 112 | leg.get_frame().set_linewidth(1.0) 113 | 114 | plt.savefig(os.path.join(savedir, f"{model_name}.png"), bbox_inches="tight", dpi=200) 115 | 116 | 117 | def plot_attn_sub(ax, 
corr, model_name, layer_id): 118 | mask = np.zeros_like(corr) 119 | mask[np.triu_indices_from(mask, k=1)] = True 120 | sns.heatmap(corr, mask=mask, square=True, ax=ax, 121 | cmap="YlGnBu",cbar_kws={"shrink": 1.0, "pad": 0.01, "aspect":50}) 122 | 123 | ax.set_facecolor("whitesmoke") 124 | cax = ax.figure.axes[-1] 125 | cax.tick_params(labelsize=18) 126 | 127 | ax.tick_params(axis='x', which='major') 128 | ax.set(xticklabels=[]) 129 | ax.set(yticklabels=[]) 130 | ax.tick_params(left=False, bottom=False) 131 | ax.set_title(f"{MODEL_TITLE_DICT[model_name]}, Layer {layer_id+1}", fontsize=24, fontweight="bold") 132 | 133 | 134 | def plot_attn(attn_logits, model_name, layer_id, savedir): 135 | fig, axs = plt.subplots(nrows=1, ncols=1, figsize=(8, 4.75)) 136 | fig.tight_layout() # Or equivalently, "plt.tight_layout()" 137 | plt.subplots_adjust(wspace=0.15) 138 | 139 | corr = attn_logits.numpy()[0].mean(0) 140 | corr = corr.astype("float64") 141 | 142 | plot_attn_sub(axs, corr, model_name, layer_id) 143 | plt.savefig(os.path.join(savedir, f"{model_name}_layer{layer_id+1}.pdf"), bbox_inches="tight", dpi=200) -------------------------------------------------------------------------------- /monkey_patch/modify_phi2.py: -------------------------------------------------------------------------------- 1 | import math 2 | import types 3 | from typing import Optional, Tuple 4 | 5 | import torch 6 | from torch import nn 7 | import torch.utils.checkpoint 8 | from transformers.models.llama.modeling_llama import ( 9 | apply_rotary_pos_emb, 10 | repeat_kv, 11 | ) 12 | 13 | 14 | def phi2_custom_attention_forward( 15 | self, 16 | hidden_states, 17 | attention_mask = None, 18 | position_ids = None, 19 | past_key_value = None, 20 | output_attentions = False, 21 | use_cache = False, 22 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 23 | bsz, q_len, _ = hidden_states.size() 24 | 25 | query_states = self.q_proj(hidden_states) 26 | key_states = self.k_proj(hidden_states) 27 | value_states = self.v_proj(hidden_states) 28 | 29 | ################################################################## 30 | self.query_states = query_states.detach().cpu().clone() 31 | self.key_states = key_states.detach().cpu().clone() 32 | self.value_states = value_states.detach().cpu().clone() 33 | ################################################################## 34 | 35 | if self.qk_layernorm: 36 | query_states = self.q_layernorm(query_states) 37 | key_states = self.k_layernorm(key_states) 38 | 39 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 40 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 41 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 42 | 43 | kv_seq_len = key_states.shape[-2] 44 | if past_key_value is not None: 45 | if self.layer_idx is None: 46 | raise ValueError( 47 | f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " 48 | "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " 49 | "with a layer index." 
50 | ) 51 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) 52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 53 | 54 | # Partial rotary embedding 55 | query_rot, query_pass = ( 56 | query_states[..., : self.rotary_emb.dim], 57 | query_states[..., self.rotary_emb.dim :], 58 | ) 59 | key_rot, key_pass = ( 60 | key_states[..., : self.rotary_emb.dim], 61 | key_states[..., self.rotary_emb.dim :], 62 | ) 63 | # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] 64 | query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) 65 | 66 | # [batch_size, seq_length, num_heads, head_dim] 67 | query_states = torch.cat((query_rot, query_pass), dim=-1) 68 | key_states = torch.cat((key_rot, key_pass), dim=-1) 69 | 70 | if past_key_value is not None: 71 | cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} 72 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 73 | 74 | key_states = repeat_kv(key_states, self.num_key_value_groups) 75 | value_states = repeat_kv(value_states, self.num_key_value_groups) 76 | 77 | # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow 78 | attn_weights = torch.matmul( 79 | query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3) 80 | ) / math.sqrt(self.head_dim) 81 | 82 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 83 | raise ValueError( 84 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 85 | f" {attn_weights.size()}" 86 | ) 87 | 88 | if attention_mask is not None: 89 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 90 | raise ValueError( 91 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 92 | ) 93 | attn_weights = attn_weights + attention_mask 94 | 95 | # ################################################### 96 | self.attn_logits = attn_weights 97 | # ################################################### 98 | 99 | # upcast attention to fp32 100 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) 101 | attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) 102 | 103 | # ################################################### 104 | self.attn_probs = attn_weights 105 | # ################################################### 106 | attn_output = torch.matmul(attn_weights, value_states) 107 | 108 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 109 | raise ValueError( 110 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 111 | f" {attn_output.size()}" 112 | ) 113 | 114 | attn_output = attn_output.transpose(1, 2).contiguous() 115 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 116 | 117 | attn_output = self.dense(attn_output) 118 | 119 | if not output_attentions: 120 | attn_weights = None 121 | 122 | return attn_output, attn_weights, past_key_value 123 | 124 | def enable_phi2_custom_attention(layer, layer_id): 125 | modified_module = layer.self_attn 126 | modified_module.layer_id = layer_id 127 | modified_module.forward = types.MethodType(phi2_custom_attention_forward, modified_module) 128 | 129 | return modified_module 130 | 131 | def phi2_custom_decoderlayer_forward( 132 | self, 133 | hidden_states: torch.Tensor, 134 | attention_mask: Optional[torch.Tensor] = None, 
135 | position_ids: Optional[torch.LongTensor] = None, 136 | output_attentions: Optional[bool] = False, 137 | use_cache: Optional[bool] = False, 138 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 139 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 140 | residual = hidden_states 141 | 142 | hidden_states = self.input_layernorm(hidden_states) 143 | 144 | # Self Attention 145 | attn_outputs, self_attn_weights, present_key_value = self.self_attn( 146 | hidden_states=hidden_states, 147 | attention_mask=attention_mask, 148 | position_ids=position_ids, 149 | past_key_value=past_key_value, 150 | output_attentions=output_attentions, 151 | use_cache=use_cache, 152 | ) 153 | attn_outputs = self.resid_dropout(attn_outputs) 154 | 155 | feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states)) 156 | hidden_states = attn_outputs + feed_forward_hidden_states + residual 157 | 158 | self.feat = hidden_states.clone().detach().cpu().double() 159 | outputs = (hidden_states,) 160 | 161 | if output_attentions: 162 | outputs += (self_attn_weights,) 163 | 164 | if use_cache: 165 | outputs += (present_key_value,) 166 | 167 | return outputs 168 | 169 | def enable_phi2_custom_decoderlayer(layer, layer_id): 170 | layer.layer_id = layer_id 171 | layer.forward = types.MethodType( 172 | phi2_custom_decoderlayer_forward, layer 173 | ) 174 | -------------------------------------------------------------------------------- /monkey_patch/modify_mistral.py: -------------------------------------------------------------------------------- 1 | import math 2 | import types 3 | import warnings 4 | from typing import Optional, Tuple 5 | 6 | import torch 7 | from torch import nn 8 | import torch.utils.checkpoint 9 | from transformers.models.mistral.modeling_mistral import ( 10 | apply_rotary_pos_emb, 11 | repeat_kv, 12 | ) 13 | 14 | 15 | def mistral_custom_decoderlayer_forward( 16 | self, 17 | hidden_states: torch.Tensor, 18 | attention_mask: Optional[torch.Tensor] = None, 19 | position_ids: Optional[torch.LongTensor] = None, 20 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 21 | output_attentions: Optional[bool] = False, 22 | use_cache: Optional[bool] = False, 23 | **kwargs, 24 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 25 | if "padding_mask" in kwargs: 26 | warnings.warn( 27 | "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" 28 | ) 29 | 30 | residual = hidden_states 31 | 32 | hidden_states = self.input_layernorm(hidden_states) 33 | 34 | # Self Attention 35 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 36 | hidden_states=hidden_states, 37 | attention_mask=attention_mask, 38 | position_ids=position_ids, 39 | past_key_value=past_key_value, 40 | output_attentions=output_attentions, 41 | use_cache=use_cache, 42 | ) 43 | 44 | if residual.device.index != hidden_states.device.index: 45 | residual = residual.to(hidden_states.device) 46 | hidden_states = residual + hidden_states 47 | 48 | # Fully Connected 49 | residual = hidden_states 50 | hidden_states = self.post_attention_layernorm(hidden_states) 51 | hidden_states = self.mlp(hidden_states) 52 | hidden_states = residual + hidden_states 53 | 54 | self.feat = hidden_states.clone().detach().cpu().double() 55 | 56 | outputs = (hidden_states,) 57 | 58 | if output_attentions: 59 | outputs += (self_attn_weights,) 60 | 61 | if use_cache: 62 | outputs += (present_key_value,) 63 | 64 | return outputs 65 | 66 | def enable_mistral_custom_decoderlayer(layer, layer_id): 67 | """ 68 | replace the forward function of MistralDecoderLayer with a custom forward function `mistral_custom_decoderlayer_forward` 69 | """ 70 | layer.layer_id = layer_id 71 | layer.forward = types.MethodType( 72 | mistral_custom_decoderlayer_forward, layer 73 | ) 74 | 75 | def mistral_custom_attention_forward( 76 | self, 77 | hidden_states, 78 | attention_mask = None, 79 | position_ids = None, 80 | past_key_value = None, 81 | output_attentions: bool = False, 82 | use_cache: bool = False, 83 | **kwargs, 84 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 85 | if "padding_mask" in kwargs: 86 | warnings.warn( 87 | "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" 88 | ) 89 | bsz, q_len, _ = hidden_states.size() 90 | 91 | query_states = self.q_proj(hidden_states) 92 | key_states = self.k_proj(hidden_states) 93 | value_states = self.v_proj(hidden_states) 94 | 95 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 96 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 97 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 98 | 99 | # ################################################### 100 | # self.hidden_states = hidden_states.detach().cpu().clone() 101 | self.query_states = query_states.detach().cpu().clone() 102 | self.key_states = key_states.detach().cpu().clone() 103 | self.value_states = value_states.detach().cpu().clone() 104 | # ################################################### 105 | 106 | kv_seq_len = key_states.shape[-2] 107 | if past_key_value is not None: 108 | if self.layer_idx is None: 109 | raise ValueError( 110 | f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " 111 | "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " 112 | "with a layer index." 
113 | ) 114 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) 115 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 116 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 117 | 118 | if past_key_value is not None: 119 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models 120 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 121 | 122 | # repeat k/v heads if n_kv_heads < n_heads 123 | key_states = repeat_kv(key_states, self.num_key_value_groups) 124 | value_states = repeat_kv(value_states, self.num_key_value_groups) 125 | 126 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 127 | 128 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 129 | raise ValueError( 130 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 131 | f" {attn_weights.size()}" 132 | ) 133 | 134 | if attention_mask is not None: 135 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 136 | raise ValueError( 137 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 138 | ) 139 | 140 | attn_weights = attn_weights + attention_mask 141 | 142 | # ################################################### 143 | self.attn_logits = attn_weights 144 | # ################################################### 145 | 146 | # upcast attention to fp32 147 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 148 | 149 | # ################################################### 150 | self.attn_probs = attn_weights 151 | # ################################################### 152 | 153 | attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) 154 | attn_output = torch.matmul(attn_weights, value_states) 155 | 156 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 157 | raise ValueError( 158 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 159 | f" {attn_output.size()}" 160 | ) 161 | 162 | attn_output = attn_output.transpose(1, 2).contiguous() 163 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 164 | 165 | attn_output = self.o_proj(attn_output) 166 | 167 | if not output_attentions: 168 | attn_weights = None 169 | 170 | return attn_output, attn_weights, past_key_value 171 | 172 | def enable_mistral_custom_attention(layer, layer_id): 173 | modified_module = layer.self_attn 174 | modified_module.layer_id = layer_id 175 | modified_module.forward = types.MethodType(mistral_custom_attention_forward, modified_module) 176 | 177 | return modified_module 178 | -------------------------------------------------------------------------------- /monkey_patch/modify_llama.py: -------------------------------------------------------------------------------- 1 | import math 2 | import types 3 | import warnings 4 | from typing import Optional, Tuple 5 | 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | import torch.utils.checkpoint 10 | from transformers.models.llama.modeling_llama import ( 11 | apply_rotary_pos_emb, 12 | repeat_kv, 13 | rotate_half, 14 | ) 15 | 16 | 17 | def llama_custom_decoderlayer_forward( 18 | self, 19 | hidden_states: torch.Tensor, 20 | attention_mask: Optional[torch.Tensor] = None, 21 | position_ids: Optional[torch.LongTensor] = None, 
22 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 23 | output_attentions: Optional[bool] = False, 24 | use_cache: Optional[bool] = False, 25 | **kwargs, 26 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 27 | residual = hidden_states 28 | 29 | hidden_states = self.input_layernorm(hidden_states) 30 | 31 | # Self Attention 32 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 33 | hidden_states=hidden_states, 34 | attention_mask=attention_mask, 35 | position_ids=position_ids, 36 | past_key_value=past_key_value, 37 | output_attentions=output_attentions, 38 | use_cache=use_cache, 39 | **kwargs, 40 | ) 41 | 42 | if residual.device.index != hidden_states.device.index: 43 | residual = residual.to(hidden_states.device) 44 | hidden_states = residual + hidden_states 45 | 46 | # Fully Connected 47 | residual = hidden_states 48 | hidden_states = self.post_attention_layernorm(hidden_states) 49 | hidden_states = self.mlp(hidden_states) 50 | hidden_states = residual + hidden_states 51 | 52 | self.feat = hidden_states.clone().detach().cpu().double() 53 | 54 | outputs = (hidden_states,) 55 | 56 | if output_attentions: 57 | outputs += (self_attn_weights,) 58 | 59 | if use_cache: 60 | outputs += (present_key_value,) 61 | 62 | return outputs 63 | 64 | def enable_llama_custom_decoderlayer(layer, layer_id): 65 | """ 66 | replace the forward function of LlamaDecoderLayer with a custom forward function `llama_custom_decoderlayer_forward` 67 | """ 68 | layer.layer_id = layer_id 69 | layer.forward = types.MethodType( 70 | llama_custom_decoderlayer_forward, layer 71 | ) 72 | 73 | 74 | def apply_rotary_pos_emb_single(q, k, cos, sin, position_ids): 75 | cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim] 76 | sin = sin[position_ids].unsqueeze(1) 77 | q_embed = (q * cos) + (rotate_half(q) * sin) 78 | k_embed = (k * cos) + (rotate_half(k) * sin) 79 | return q_embed, k_embed 80 | 81 | def llama_custom_attention_forward( 82 | self, 83 | hidden_states, 84 | attention_mask = None, 85 | position_ids = None, 86 | past_key_value = None, 87 | output_attentions = False, 88 | use_cache = False, 89 | **kwargs, 90 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 91 | if "padding_mask" in kwargs: 92 | warnings.warn( 93 | "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" 94 | ) 95 | 96 | bsz, q_len, _ = hidden_states.size() 97 | 98 | query_states = self.q_proj(hidden_states) 99 | key_states = self.k_proj(hidden_states) 100 | value_states = self.v_proj(hidden_states) 101 | 102 | ################################################################## 103 | self.query_states = query_states.detach().cpu().clone() 104 | self.key_states = key_states.detach().cpu().clone() 105 | self.value_states = value_states.detach().cpu().clone() 106 | # ################################################### 107 | 108 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 109 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 110 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 111 | 112 | # ################################################################## 113 | # self.query_states = query_states.detach().cpu().clone() 114 | # self.key_states = key_states.detach().cpu().clone() 115 | # self.value_states = value_states.detach().cpu().clone() 116 | # # ################################################### 117 | 118 | kv_seq_len = key_states.shape[-2] 119 | if past_key_value is not None: 120 | if self.layer_idx is None: 121 | raise ValueError( 122 | f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " 123 | "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " 124 | "with a layer index." 125 | ) 126 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) 127 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 128 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 129 | 130 | if past_key_value is not None: 131 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models 132 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 133 | 134 | key_states = repeat_kv(key_states, self.num_key_value_groups) 135 | value_states = repeat_kv(value_states, self.num_key_value_groups) 136 | 137 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 138 | 139 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 140 | raise ValueError( 141 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 142 | f" {attn_weights.size()}" 143 | ) 144 | 145 | if attention_mask is not None: 146 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 147 | raise ValueError( 148 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 149 | ) 150 | attn_weights = attn_weights + attention_mask 151 | 152 | # ################################################### 153 | self.attn_logits = attn_weights 154 | # ################################################### 155 | 156 | # upcast attention to fp32 157 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 158 | 159 | # ################################################### 160 | self.attn_probs = attn_weights 161 | # ################################################### 162 | 163 | attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) 164 | attn_output = torch.matmul(attn_weights, value_states) 165 | 166 | if attn_output.size() 
!= (bsz, self.num_heads, q_len, self.head_dim): 167 | raise ValueError( 168 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 169 | f" {attn_output.size()}" 170 | ) 171 | 172 | attn_output = attn_output.transpose(1, 2).contiguous() 173 | 174 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 175 | 176 | if self.config.pretraining_tp > 1: 177 | attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) 178 | o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) 179 | attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) 180 | else: 181 | attn_output = self.o_proj(attn_output) 182 | 183 | if not output_attentions: 184 | attn_weights = None 185 | 186 | return attn_output, attn_weights, past_key_value 187 | 188 | def enable_llama_custom_attention(layer, layer_id): 189 | """ 190 | replace the forward function of LlamaAttention with a custom forward function `llama_custom_attention_forward` 191 | """ 192 | modified_module = layer.self_attn 193 | modified_module.layer_id = layer_id 194 | modified_module.forward = types.MethodType(llama_custom_attention_forward, modified_module) 195 | 196 | return modified_module -------------------------------------------------------------------------------- /gpt-2/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import time 4 | from contextlib import nullcontext 5 | 6 | import numpy as np 7 | import torch 8 | from torch.distributed import destroy_process_group, init_process_group 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | 11 | # ----------------------------------------------------------------------------- 12 | # default config values designed to train a gpt2 (124M) on OpenWebText 13 | # I/O 14 | out_dir = 'logs_ckpt/' 15 | save_dir = "results/" 16 | data_dir = "data/" 17 | model_type="gpt2_default" 18 | num_reg = 0 19 | 20 | eval_interval = 2000 21 | log_interval = 1 22 | eval_iters = 200 23 | eval_only = False # if True, script exits right after the first eval 24 | always_save_checkpoint = True # if True, always save a checkpoint after each eval 25 | init_from = 'resume' # 'scratch' or 'resume' or 'gpt2*' 26 | ckpt_iter = 50000 27 | # wandb logging 28 | wandb_log = False # disabled by default 29 | wandb_project = 'owt' 30 | wandb_run_name = 'gpt2' # 'run' + str(time.time()) 31 | # data 32 | dataset = 'openwebtext' 33 | gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes 34 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size 35 | block_size = 1024 36 | # model 37 | n_layer = 12 38 | n_head = 12 39 | n_embd = 768 40 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 41 | bias = False # do we use bias inside LayerNorm and Linear layers? 
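# Usage sketch (assumes the standard nanoGPT configurator.py convention, not shown here):
# the module-level defaults in this block are overridden by exec'ing a config file passed
# on the command line plus any --key=value flags, e.g.
#   python test.py config/eval_gpt2_default.py --ckpt_iter=50000 --batch_size=8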
42 | # adamw optimizer 43 | optim_name="adam" 44 | learning_rate = 6e-4 # max learning rate 45 | max_iters = 600000 # total number of training iterations 46 | weight_decay = 1e-1 47 | beta1 = 0.9 48 | beta2 = 0.95 49 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 50 | # learning rate decay settings 51 | decay_lr = True # whether to decay the learning rate 52 | warmup_iters = 2000 # how many steps to warm up for 53 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla 54 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 55 | # DDP settings 56 | backend = 'nccl' # 'nccl', 'gloo', etc. 57 | # system 58 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 59 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 60 | compile = True # use PyTorch 2.0 to compile the model to be faster 61 | # ----------------------------------------------------------------------------- 62 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 63 | exec(open('configurator.py').read()) # overrides from command line or config file 64 | config = {k: globals()[k] for k in config_keys} # will be useful for logging 65 | # ----------------------------------------------------------------------------- 66 | 67 | if model_type == "gpt2_default": 68 | from model_default import GPTConfig, GPT 69 | elif model_type == "gpt2_sink": 70 | from model_sink import GPTConfig, GPT 71 | elif model_type == "gpt2_attn_bias": 72 | from model_attn_bias import GPTConfig, GPT 73 | else: 74 | raise ValueError(f"model_type {model_type} not supported") 75 | 76 | # various inits, derived attributes, I/O setup 77 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? 78 | if ddp: 79 | init_process_group(backend=backend) 80 | ddp_rank = int(os.environ['RANK']) 81 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 82 | ddp_world_size = int(os.environ['WORLD_SIZE']) 83 | device = f'cuda:{ddp_local_rank}' 84 | torch.cuda.set_device(device) 85 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 
86 | seed_offset = ddp_rank # each process gets a different seed 87 | # world_size number of processes will be training simultaneously, so we can scale 88 | # down the desired gradient accumulation iterations per process proportionally 89 | assert gradient_accumulation_steps % ddp_world_size == 0 90 | gradient_accumulation_steps //= ddp_world_size 91 | else: 92 | # if not ddp, we are running on a single gpu, and one process 93 | master_process = True 94 | seed_offset = 0 95 | ddp_world_size = 1 96 | tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size 97 | print(f"tokens per iteration will be: {tokens_per_iter:,}") 98 | 99 | if master_process: 100 | os.makedirs(out_dir, exist_ok=True) 101 | torch.manual_seed(1337 + seed_offset) 102 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 103 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 104 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 105 | # note: float16 data type will automatically use a GradScaler 106 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 107 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 108 | 109 | # poor man's data loader 110 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 111 | val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') 112 | def get_batch(split): 113 | data = train_data if split == 'train' else val_data 114 | ix = torch.randint(len(data) - block_size, (batch_size,)) 115 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 116 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) 117 | if device_type == 'cuda': 118 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 119 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 120 | else: 121 | x, y = x.to(device), y.to(device) 122 | return x, y 123 | 124 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint) 125 | iter_num = 0 126 | best_val_loss = 1e9 127 | 128 | # attempt to derive vocab_size from the dataset 129 | meta_path = os.path.join(data_dir, 'meta.pkl') 130 | meta_vocab_size = None 131 | if os.path.exists(meta_path): 132 | with open(meta_path, 'rb') as f: 133 | meta = pickle.load(f) 134 | meta_vocab_size = meta['vocab_size'] 135 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})") 136 | 137 | # model init 138 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, 139 | bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line 140 | if init_from == 'scratch': 141 | # init a new model from scratch 142 | print("Initializing a new model from scratch") 143 | # determine the vocab size we'll use for from-scratch training 144 | if meta_vocab_size is None: 145 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") 146 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304 147 | gptconf = GPTConfig(**model_args) 148 | model = GPT(gptconf) 149 | elif init_from == 'resume': 150 | print(f"Resuming training from {out_dir}") 151 | # resume training from a checkpoint. 
152 | ckpt_path = os.path.join(out_dir, f'ckpt_{ckpt_iter}.pt') 153 | checkpoint = torch.load(ckpt_path, map_location=device) 154 | checkpoint_model_args = checkpoint['model_args'] 155 | # force these config attributes to be equal otherwise we can't even resume training 156 | # the rest of the attributes (e.g. dropout) can stay as desired from command line 157 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 158 | model_args[k] = checkpoint_model_args[k] 159 | # create the model 160 | gptconf = GPTConfig(**model_args) 161 | model = GPT(gptconf) 162 | state_dict = checkpoint['model'] 163 | # fix the keys of the state dictionary :( 164 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 165 | unwanted_prefix = '_orig_mod.' 166 | for k,v in list(state_dict.items()): 167 | if k.startswith(unwanted_prefix): 168 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 169 | model.load_state_dict(state_dict) 170 | iter_num = checkpoint['iter_num'] 171 | best_val_loss = checkpoint['best_val_loss'] 172 | elif init_from.startswith('gpt2'): 173 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}") 174 | # initialize from OpenAI GPT-2 weights 175 | override_args = dict(dropout=dropout) 176 | model = GPT.from_pretrained(init_from, override_args) 177 | # read off the created config params, so we can store them into checkpoint correctly 178 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 179 | model_args[k] = getattr(model.config, k) 180 | # crop down the model block size if desired, using model surgery 181 | if block_size < model.config.block_size: 182 | model.crop_block_size(block_size) 183 | model_args['block_size'] = block_size # so that the checkpoint will have the right value 184 | model.to(device) 185 | 186 | 187 | # initialize a GradScaler. If enabled=False scaler is a no-op 188 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16')) 189 | 190 | # optimizer 191 | optimizer = model.configure_optimizers(optim_name, weight_decay, learning_rate, (beta1, beta2), device_type) 192 | if init_from == 'resume': 193 | optimizer.load_state_dict(checkpoint['optimizer']) 194 | checkpoint = None # free up memory 195 | 196 | # compile the model 197 | if compile: 198 | print("compiling the model... 
(takes a ~minute)") 199 | unoptimized_model = model 200 | model = torch.compile(model) # requires PyTorch 2.0 201 | 202 | # wrap model into DDP container 203 | if ddp: 204 | model = DDP(model, device_ids=[ddp_local_rank]) 205 | 206 | # helps estimate an arbitrarily accurate loss over either split using many batches 207 | @torch.no_grad() 208 | def estimate_loss(): 209 | out = {} 210 | model.eval() 211 | for split in ['train', 'val']: 212 | losses = torch.zeros(eval_iters) 213 | for k in range(eval_iters): 214 | X, Y = get_batch(split) 215 | with ctx: 216 | logits, loss = model(X, Y) 217 | losses[k] = loss.item() 218 | out[split] = losses.mean() 219 | model.train() 220 | return out 221 | 222 | # training loop 223 | t0 = time.time() 224 | local_iter_num = 0 # number of iterations in the lifetime of this process 225 | raw_model = model.module if ddp else model # unwrap DDP container if needed 226 | running_mfu = -1.0 227 | 228 | losses = estimate_loss() 229 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 230 | -------------------------------------------------------------------------------- /gpt-2/analyze.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from contextlib import nullcontext 4 | import types 5 | 6 | import numpy as np 7 | import torch 8 | from torch.nn.parallel import DistributedDataParallel as DDP 9 | from torch.distributed import init_process_group, destroy_process_group 10 | 11 | import tiktoken 12 | from plot_gpt2 import * 13 | 14 | # ----------------------------------------------------------------------------- 15 | # default config values designed to train a gpt2 (124M) on OpenWebText 16 | # I/O 17 | out_dir = 'logs_ckpt/' 18 | save_dir = "results/" 19 | data_dir = "data/" 20 | model_type="gpt2_default" 21 | num_reg = 0 22 | 23 | eval_interval = 2000 24 | log_interval = 1 25 | eval_iters = 1 26 | eval_only = False # if True, script exits right after the first eval 27 | always_save_checkpoint = True # if True, always save a checkpoint after each eval 28 | init_from = 'resume' # 'scratch' or 'resume' or 'gpt2*' 29 | ckpt_iter = 50000 30 | # wandb logging 31 | wandb_log = False # disabled by default 32 | wandb_project = 'owt' 33 | wandb_run_name = 'gpt2' # 'run' + str(time.time()) 34 | # data 35 | dataset = 'openwebtext' 36 | gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes 37 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size 38 | block_size = 1024 39 | # model 40 | n_layer = 12 41 | n_head = 12 42 | n_embd = 768 43 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 44 | bias = False # do we use bias inside LayerNorm and Linear layers? 45 | # adamw optimizer 46 | optim_name="adam" 47 | learning_rate = 6e-4 # max learning rate 48 | max_iters = 600000 # total number of training iterations 49 | weight_decay = 1e-1 50 | beta1 = 0.9 51 | beta2 = 0.95 52 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 53 | # learning rate decay settings 54 | decay_lr = True # whether to decay the learning rate 55 | warmup_iters = 2000 # how many steps to warm up for 56 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla 57 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 58 | # DDP settings 59 | backend = 'nccl' # 'nccl', 'gloo', etc. 
60 | # system 61 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 62 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 63 | compile = True # use PyTorch 2.0 to compile the model to be faster 64 | # ----------------------------------------------------------------------------- 65 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 66 | exec(open('configurator.py').read()) # overrides from command line or config file 67 | config = {k: globals()[k] for k in config_keys} # will be useful for logging 68 | # ----------------------------------------------------------------------------- 69 | 70 | if model_type == "gpt2_default": 71 | from model_default import GPTConfig, GPT 72 | elif model_type == "gpt2_sink": 73 | from model_sink import GPTConfig, GPT 74 | elif model_type == "gpt2_attn_bias": 75 | from model_attn_bias import GPTConfig, GPT 76 | else: 77 | raise ValueError(f"model_type {model_type} not supported") 78 | 79 | 80 | if not os.path.exists(save_dir): 81 | os.makedirs(save_dir) 82 | 83 | # various inits, derived attributes, I/O setup 84 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? 85 | if ddp: 86 | init_process_group(backend=backend) 87 | ddp_rank = int(os.environ['RANK']) 88 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 89 | ddp_world_size = int(os.environ['WORLD_SIZE']) 90 | device = f'cuda:{ddp_local_rank}' 91 | torch.cuda.set_device(device) 92 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 93 | seed_offset = ddp_rank # each process gets a different seed 94 | # world_size number of processes will be training simultaneously, so we can scale 95 | # down the desired gradient accumulation iterations per process proportionally 96 | assert gradient_accumulation_steps % ddp_world_size == 0 97 | gradient_accumulation_steps //= ddp_world_size 98 | else: 99 | # if not ddp, we are running on a single gpu, and one process 100 | master_process = True 101 | seed_offset = 0 102 | ddp_world_size = 1 103 | tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size 104 | print(f"tokens per iteration will be: {tokens_per_iter:,}") 105 | 106 | if master_process: 107 | os.makedirs(out_dir, exist_ok=True) 108 | torch.manual_seed(1337 + seed_offset) 109 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 110 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 111 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 112 | # note: float16 data type will automatically use a GradScaler 113 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 114 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 115 | 116 | # poor man's data loader 117 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 118 | val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') 119 | def get_batch(split): 120 | data = train_data if split == 'train' else val_data 121 | ix = torch.randint(len(data) - block_size, (batch_size,)) 122 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 123 | y = 
torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) 124 | if device_type == 'cuda': 125 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 126 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 127 | else: 128 | x, y = x.to(device), y.to(device) 129 | return x, y 130 | 131 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint) 132 | iter_num = 0 133 | best_val_loss = 1e9 134 | 135 | # attempt to derive vocab_size from the dataset 136 | meta_path = os.path.join(data_dir, 'meta.pkl') 137 | meta_vocab_size = None 138 | if os.path.exists(meta_path): 139 | with open(meta_path, 'rb') as f: 140 | meta = pickle.load(f) 141 | meta_vocab_size = meta['vocab_size'] 142 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})") 143 | 144 | # model init 145 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, 146 | bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line 147 | if init_from == 'scratch': 148 | # init a new model from scratch 149 | print("Initializing a new model from scratch") 150 | # determine the vocab size we'll use for from-scratch training 151 | if meta_vocab_size is None: 152 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") 153 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304 154 | gptconf = GPTConfig(**model_args) 155 | model = GPT(gptconf) 156 | elif init_from == 'resume': 157 | print(f"Resuming training from {out_dir}") 158 | # resume training from a checkpoint. 159 | ckpt_path = os.path.join(out_dir, f'ckpt_{ckpt_iter}.pt') 160 | checkpoint = torch.load(ckpt_path, map_location=device) 161 | checkpoint_model_args = checkpoint['model_args'] 162 | # force these config attributes to be equal otherwise we can't even resume training 163 | # the rest of the attributes (e.g. dropout) can stay as desired from command line 164 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 165 | model_args[k] = checkpoint_model_args[k] 166 | # create the model 167 | gptconf = GPTConfig(**model_args) 168 | model = GPT(gptconf) 169 | state_dict = checkpoint['model'] 170 | # fix the keys of the state dictionary :( 171 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 172 | unwanted_prefix = '_orig_mod.' 
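    # The prefix comes from torch.compile: a compiled model is an OptimizedModule that keeps
    # the original module under `_orig_mod`, so a state_dict saved from the compiled wrapper
    # has every key prefixed with '_orig_mod.'. Stripping it lets the freshly constructed
    # (uncompiled) GPT instance load the checkpoint weights.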
173 | for k,v in list(state_dict.items()): 174 | if k.startswith(unwanted_prefix): 175 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 176 | model.load_state_dict(state_dict) 177 | iter_num = checkpoint['iter_num'] 178 | best_val_loss = checkpoint['best_val_loss'] 179 | elif init_from.startswith('gpt2'): 180 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}") 181 | # initialize from OpenAI GPT-2 weights 182 | override_args = dict(dropout=dropout) 183 | model = GPT.from_pretrained(init_from, override_args) 184 | # read off the created config params, so we can store them into checkpoint correctly 185 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 186 | model_args[k] = getattr(model.config, k) 187 | # crop down the model block size if desired, using model surgery 188 | if block_size < model.config.block_size: 189 | model.crop_block_size(block_size) 190 | model_args['block_size'] = block_size # so that the checkpoint will have the right value 191 | model.to(device) 192 | 193 | 194 | # initialize a GradScaler. If enabled=False scaler is a no-op 195 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16')) 196 | 197 | # optimizer 198 | optimizer = model.configure_optimizers(optim_name, weight_decay, learning_rate, (beta1, beta2), device_type) 199 | if init_from == 'resume': 200 | optimizer.load_state_dict(checkpoint['optimizer']) 201 | checkpoint = None # free up memory 202 | 203 | # compile the model 204 | if compile: 205 | print("compiling the model... (takes a ~minute)") 206 | unoptimized_model = model 207 | model = torch.compile(model) # requires PyTorch 2.0 208 | 209 | # wrap model into DDP container 210 | if ddp: 211 | model = DDP(model, device_ids=[ddp_local_rank]) 212 | 213 | def custom_block_forward(self, x): 214 | x = x + self.attn(self.ln_1(x)) 215 | x = x + self.mlp(self.ln_2(x)) 216 | self.feat = x.cpu().clone().detach() 217 | return x 218 | 219 | 220 | 221 | layers = model.transformer.h 222 | for layer_id in range(len(model.transformer.h)): 223 | layer = layers[layer_id] 224 | layer.forward = types.MethodType(custom_block_forward, layer) 225 | 226 | 227 | stats = [] 228 | eval_iters=1 229 | batch_size=1 230 | for seqid in range(3): 231 | print("current seqid: ", seqid) 232 | X, Y = get_batch("val") 233 | with torch.no_grad(): 234 | model(X, Y) 235 | 236 | seq_np = np.zeros((202, len(layers))) 237 | for layer_id in range(len(layers)): 238 | feat_abs = layers[layer_id].feat.abs() 239 | sort_res = torch.sort(feat_abs.flatten(), descending=True) 240 | seq_np[:200, layer_id] = sort_res.values[:200] 241 | seq_np[-2, layer_id] = torch.mean(feat_abs) 242 | seq_np[-1, layer_id] = torch.median(feat_abs) 243 | 244 | stats.append(seq_np) 245 | 246 | plot_layer_ax_gpt2(np.mean(np.array(stats), axis=0), model_type, save_dir) 247 | 248 | ########################################################################## 249 | enc = tiktoken.get_encoding("gpt2") 250 | encoding = enc.encode("Summer is warm. 
Winter is cold.") 251 | 252 | layer_id = 5 253 | layer = model.transformer.h[layer_id] 254 | layer.forward = types.MethodType(custom_block_forward, layer) 255 | 256 | 257 | input_seq = torch.Tensor([encoding]).type(torch.int64).to(device) 258 | 259 | with torch.no_grad(): 260 | model(input_seq, None) 261 | 262 | stats = {} 263 | feat_abs = layer.feat.abs() 264 | stats[f"seq0_{layer_id}"] = feat_abs 265 | 266 | inp_seq = ["Summer", "is", "warm", ".", "Winter", "is", "cold", "."] 267 | plot_3d_ax_gpt2(feat_abs, model_type, inp_seq, save_dir) -------------------------------------------------------------------------------- /gpt-2/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import pickle 5 | from contextlib import nullcontext 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | from torch.distributed import init_process_group, destroy_process_group 11 | 12 | # ----------------------------------------------------------------------------- 13 | # default config values designed to train a gpt2 (124M) on OpenWebText 14 | # I/O 15 | out_dir = 'logs_ckpt/' 16 | save_dir = "results/" 17 | data_dir = "data/" 18 | model_type = "gpt2_default" 19 | num_reg = 0 20 | 21 | eval_interval = 10000 22 | log_interval = 1 23 | eval_iters = 400 24 | eval_only = False # if True, script exits right after the first eval 25 | always_save_checkpoint = True # if True, always save a checkpoint after each eval 26 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*' 27 | # wandb logging 28 | wandb_log = False # disabled by default 29 | wandb_project = 'owt' 30 | wandb_run_name = 'gpt2' # 'run' + str(time.time()) 31 | # data 32 | dataset = 'openwebtext' 33 | gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes 34 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size 35 | block_size = 1024 36 | # model 37 | n_layer = 12 38 | n_head = 12 39 | n_embd = 768 40 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 41 | bias = False # do we use bias inside LayerNorm and Linear layers? 42 | # adamw optimizer 43 | optim_name="adam" 44 | learning_rate = 6e-4 # max learning rate 45 | max_iters = 600000 # total number of training iterations 46 | weight_decay = 1e-1 47 | beta1 = 0.9 48 | beta2 = 0.95 49 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 50 | # learning rate decay settings 51 | decay_lr = True # whether to decay the learning rate 52 | warmup_iters = 2000 # how many steps to warm up for 53 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla 54 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 55 | # DDP settings 56 | backend = 'nccl' # 'nccl', 'gloo', etc. 
57 | # system 58 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 59 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 60 | compile = True # use PyTorch 2.0 to compile the model to be faster 61 | # ----------------------------------------------------------------------------- 62 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 63 | exec(open('configurator.py').read()) # overrides from command line or config file 64 | config = {k: globals()[k] for k in config_keys} # will be useful for logging 65 | # ----------------------------------------------------------------------------- 66 | 67 | if model_type == "gpt2_default": 68 | from model_default import GPTConfig, GPT 69 | elif model_type == "gpt2_sink": 70 | from model_sink import GPTConfig, GPT 71 | elif model_type == "gpt2_attn_bias": 72 | from model_attn_bias import GPTConfig, GPT 73 | else: 74 | raise ValueError(f"model_type {model_type} not supported") 75 | 76 | # various inits, derived attributes, I/O setup 77 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? 78 | if ddp: 79 | init_process_group(backend=backend) 80 | ddp_rank = int(os.environ['RANK']) 81 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 82 | ddp_world_size = int(os.environ['WORLD_SIZE']) 83 | device = f'cuda:{ddp_local_rank}' 84 | torch.cuda.set_device(device) 85 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 86 | seed_offset = ddp_rank # each process gets a different seed 87 | # world_size number of processes will be training simultaneously, so we can scale 88 | # down the desired gradient accumulation iterations per process proportionally 89 | assert gradient_accumulation_steps % ddp_world_size == 0 90 | gradient_accumulation_steps //= ddp_world_size 91 | else: 92 | # if not ddp, we are running on a single gpu, and one process 93 | master_process = True 94 | seed_offset = 0 95 | ddp_world_size = 1 96 | tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size 97 | print(f"tokens per iteration will be: {tokens_per_iter:,}") 98 | 99 | if master_process: 100 | os.makedirs(out_dir, exist_ok=True) 101 | torch.manual_seed(1337 + seed_offset) 102 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 103 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 104 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 105 | # note: float16 data type will automatically use a GradScaler 106 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 107 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 108 | 109 | # poor man's data loader 110 | # data_dir = os.path.join('data', dataset) 111 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 112 | val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') 113 | def get_batch(split): 114 | data = train_data if split == 'train' else val_data 115 | ix = torch.randint(len(data) - block_size, (batch_size,)) 116 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 117 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) 
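# x holds (batch_size, block_size) token ids sampled at random offsets from the memmap;
# y is the same window shifted right by one token, so position j is trained to predict token j+1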
118 | if device_type == 'cuda': 119 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 120 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 121 | else: 122 | x, y = x.to(device), y.to(device) 123 | return x, y 124 | 125 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint) 126 | iter_num = 0 127 | best_val_loss = 1e9 128 | 129 | # attempt to derive vocab_size from the dataset 130 | meta_path = os.path.join(data_dir, 'meta.pkl') 131 | meta_vocab_size = None 132 | if os.path.exists(meta_path): 133 | with open(meta_path, 'rb') as f: 134 | meta = pickle.load(f) 135 | meta_vocab_size = meta['vocab_size'] 136 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})") 137 | 138 | # model init 139 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, 140 | bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line 141 | if init_from == 'scratch': 142 | # init a new model from scratch 143 | print("Initializing a new model from scratch") 144 | # determine the vocab size we'll use for from-scratch training 145 | if meta_vocab_size is None: 146 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") 147 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304 148 | gptconf = GPTConfig(**model_args) 149 | model = GPT(gptconf) 150 | elif init_from == 'resume': 151 | # resume training from a checkpoint. 152 | # ckpt_path = os.path.join(out_dir, 'ckpt_275000.pt') 153 | ckpt_path = os.path.join(out_dir, 'ckpt_15000.pt') 154 | print(f"******************************************************") 155 | print(f"Resuming training from {ckpt_path}") 156 | print(f"******************************************************") 157 | checkpoint = torch.load(ckpt_path, map_location=device) 158 | checkpoint_model_args = checkpoint['model_args'] 159 | # force these config attributes to be equal otherwise we can't even resume training 160 | # the rest of the attributes (e.g. dropout) can stay as desired from command line 161 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 162 | model_args[k] = checkpoint_model_args[k] 163 | # create the model 164 | gptconf = GPTConfig(**model_args) 165 | model = GPT(gptconf) 166 | state_dict = checkpoint['model'] 167 | # fix the keys of the state dictionary :( 168 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 169 | unwanted_prefix = '_orig_mod.' 
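# note: the '_orig_mod.' prefix comes from torch.compile(), which wraps the model in an
# OptimizedModule whose state_dict keys carry this prefix; strip it so the weights load
# into the plain GPT module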
170 | for k,v in list(state_dict.items()): 171 | if k.startswith(unwanted_prefix): 172 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 173 | model.load_state_dict(state_dict) 174 | iter_num = checkpoint['iter_num'] 175 | best_val_loss = checkpoint['best_val_loss'] 176 | elif init_from.startswith('gpt2'): 177 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}") 178 | # initialize from OpenAI GPT-2 weights 179 | override_args = dict(dropout=dropout) 180 | model = GPT.from_pretrained(init_from, override_args) 181 | # read off the created config params, so we can store them into checkpoint correctly 182 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 183 | model_args[k] = getattr(model.config, k) 184 | # crop down the model block size if desired, using model surgery 185 | if block_size < model.config.block_size: 186 | model.crop_block_size(block_size) 187 | model_args['block_size'] = block_size # so that the checkpoint will have the right value 188 | model.to(device) 189 | 190 | # initialize a GradScaler. If enabled=False scaler is a no-op 191 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16')) 192 | 193 | # optimizer 194 | optimizer = model.configure_optimizers(optim_name, weight_decay, learning_rate, (beta1, beta2), device_type) 195 | if init_from == 'resume': 196 | optimizer.load_state_dict(checkpoint['optimizer']) 197 | checkpoint = None # free up memory 198 | 199 | # compile the model 200 | if compile: 201 | print("compiling the model... (takes a ~minute)") 202 | unoptimized_model = model 203 | model = torch.compile(model) # requires PyTorch 2.0 204 | 205 | # wrap model into DDP container 206 | if ddp: 207 | model = DDP(model, device_ids=[ddp_local_rank]) 208 | 209 | # helps estimate an arbitrarily accurate loss over either split using many batches 210 | @torch.no_grad() 211 | def estimate_loss(): 212 | out = {} 213 | model.eval() 214 | for split in ['train', 'val']: 215 | losses = torch.zeros(eval_iters) 216 | for k in range(eval_iters): 217 | X, Y = get_batch(split) 218 | with ctx: 219 | logits, loss = model(X, Y) 220 | losses[k] = loss.item() 221 | out[split] = losses.mean() 222 | model.train() 223 | return out 224 | 225 | # learning rate decay scheduler (cosine with warmup) 226 | def get_lr(it): 227 | # 1) linear warmup for warmup_iters steps 228 | if it < warmup_iters: 229 | return learning_rate * it / warmup_iters 230 | # 2) if it > lr_decay_iters, return min learning rate 231 | if it > lr_decay_iters: 232 | return min_lr 233 | # 3) in between, use cosine decay down to min learning rate 234 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 235 | assert 0 <= decay_ratio <= 1 236 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 237 | return min_lr + coeff * (learning_rate - min_lr) 238 | 239 | # logging 240 | if wandb_log and master_process: 241 | import wandb 242 | wandb.init(project=wandb_project, name=wandb_run_name, config=config) 243 | 244 | # training loop 245 | X, Y = get_batch('train') # fetch the very first batch 246 | t0 = time.time() 247 | local_iter_num = 0 # number of iterations in the lifetime of this process 248 | raw_model = model.module if ddp else model # unwrap DDP container if needed 249 | running_mfu = -1.0 250 | while True: 251 | 252 | # determine and set the learning rate for this iteration 253 | lr = get_lr(iter_num) if decay_lr else learning_rate 254 | for param_group in optimizer.param_groups: 255 | param_group['lr'] = lr 256 | 257 | # evaluate 
the loss on train/val sets and write checkpoints 258 | if iter_num % eval_interval == 0 and master_process: 259 | losses = estimate_loss() 260 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 261 | if wandb_log: 262 | wandb.log({ 263 | "iter": iter_num, 264 | "train/loss": losses['train'], 265 | "val/loss": losses['val'], 266 | "lr": lr, 267 | "mfu": running_mfu*100, # convert to percentage 268 | }) 269 | if losses['val'] < best_val_loss or always_save_checkpoint: 270 | best_val_loss = losses['val'] 271 | if iter_num > 0: 272 | checkpoint = { 273 | 'model': raw_model.state_dict(), 274 | 'optimizer': optimizer.state_dict(), 275 | 'model_args': model_args, 276 | 'iter_num': iter_num, 277 | 'best_val_loss': best_val_loss, 278 | 'config': config, 279 | } 280 | print(f"saving checkpoint to {out_dir}") 281 | torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{iter_num}.pt')) 282 | if iter_num == 0 and eval_only: 283 | break 284 | 285 | # forward backward update, with optional gradient accumulation to simulate larger batch size 286 | # and using the GradScaler if data type is float16 287 | for micro_step in range(gradient_accumulation_steps): 288 | if ddp: 289 | # in DDP training we only need to sync gradients at the last micro step. 290 | # the official way to do this is with model.no_sync() context manager, but 291 | # I really dislike that this bloats the code and forces us to repeat code 292 | # looking at the source of that context manager, it just toggles this variable 293 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) 294 | with ctx: 295 | logits, loss = model(X, Y) 296 | loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation 297 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 298 | X, Y = get_batch('train') 299 | # backward pass, with gradient scaling if training in fp16 300 | scaler.scale(loss).backward() 301 | # clip the gradient 302 | if grad_clip != 0.0: 303 | scaler.unscale_(optimizer) 304 | torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 305 | # step the optimizer and scaler if training in fp16 306 | scaler.step(optimizer) 307 | scaler.update() 308 | # flush the gradients as soon as we can, no need for this memory anymore 309 | optimizer.zero_grad(set_to_none=True) 310 | 311 | # timing and logging 312 | t1 = time.time() 313 | dt = t1 - t0 314 | t0 = t1 315 | if iter_num % log_interval == 0 and master_process: 316 | # get loss as float. note: this is a CPU-GPU sync point 317 | # scale up to undo the division above, approximating the true total loss (exact would have been a sum) 318 | lossf = loss.item() * gradient_accumulation_steps 319 | if local_iter_num >= 5: # let the training loop settle a bit 320 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt) 321 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu 322 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%") 323 | iter_num += 1 324 | local_iter_num += 1 325 | 326 | # termination conditions 327 | if iter_num > max_iters: 328 | break 329 | 330 | if ddp: 331 | destroy_process_group() 332 | -------------------------------------------------------------------------------- /gpt-2/model_default.py: -------------------------------------------------------------------------------- 1 | """ 2 | Full definition of a GPT Language Model, all of it in this single file. 
3 | References: 4 | 1) the official GPT-2 TensorFlow implementation released by OpenAI: 5 | https://github.com/openai/gpt-2/blob/master/src/model.py 6 | 2) huggingface/transformers PyTorch implementation: 7 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py 8 | """ 9 | 10 | import math 11 | import inspect 12 | from dataclasses import dataclass 13 | 14 | import torch 15 | import torch.nn as nn 16 | from torch.nn import functional as F 17 | 18 | class LayerNorm(nn.Module): 19 | """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """ 20 | 21 | def __init__(self, ndim, bias): 22 | super().__init__() 23 | self.weight = nn.Parameter(torch.ones(ndim)) 24 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None 25 | 26 | def forward(self, input): 27 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) 28 | 29 | class CausalSelfAttention(nn.Module): 30 | 31 | def __init__(self, config): 32 | super().__init__() 33 | assert config.n_embd % config.n_head == 0 34 | # key, query, value projections for all heads, but in a batch 35 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) 36 | # output projection 37 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 38 | # regularization 39 | self.attn_dropout = nn.Dropout(config.dropout) 40 | self.resid_dropout = nn.Dropout(config.dropout) 41 | self.n_head = config.n_head 42 | self.n_embd = config.n_embd 43 | self.dropout = config.dropout 44 | # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 45 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') 46 | if not self.flash: 47 | print("WARNING: using slow attention. 
Flash Attention requires PyTorch >= 2.0") 48 | # causal mask to ensure that attention is only applied to the left in the input sequence 49 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) 50 | .view(1, 1, config.block_size, config.block_size)) 51 | 52 | def forward(self, x): 53 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 54 | 55 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 56 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 57 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 58 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 59 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 60 | 61 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 62 | if self.flash: 63 | # efficient attention using Flash Attention CUDA kernels 64 | y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True) 65 | # y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=temp_mask, dropout_p=self.dropout if self.training else 0, is_causal=False) 66 | else: 67 | # manual implementation of attention 68 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 69 | att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) 70 | att = F.softmax(att, dim=-1) 71 | att = self.attn_dropout(att) 72 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 73 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 74 | 75 | # output projection 76 | y = self.resid_dropout(self.c_proj(y)) 77 | return y 78 | 79 | class MLP(nn.Module): 80 | 81 | def __init__(self, config): 82 | super().__init__() 83 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) 84 | self.gelu = nn.GELU() 85 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) 86 | self.dropout = nn.Dropout(config.dropout) 87 | 88 | def forward(self, x): 89 | x = self.c_fc(x) 90 | x = self.gelu(x) 91 | x = self.c_proj(x) 92 | x = self.dropout(x) 93 | return x 94 | 95 | class Block(nn.Module): 96 | 97 | def __init__(self, config): 98 | super().__init__() 99 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) 100 | self.attn = CausalSelfAttention(config) 101 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) 102 | self.mlp = MLP(config) 103 | 104 | def forward(self, x): 105 | x = x + self.attn(self.ln_1(x)) 106 | x = x + self.mlp(self.ln_2(x)) 107 | return x 108 | 109 | @dataclass 110 | class GPTConfig: 111 | block_size: int = 1024 112 | vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency 113 | n_layer: int = 12 114 | n_head: int = 12 115 | n_embd: int = 768 116 | dropout: float = 0.0 117 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster 118 | 119 | class GPT(nn.Module): 120 | 121 | def __init__(self, config): 122 | super().__init__() 123 | assert config.vocab_size is not None 124 | assert config.block_size is not None 125 | self.config = config 126 | 127 | self.transformer = nn.ModuleDict(dict( 128 | wte = nn.Embedding(config.vocab_size, config.n_embd), 129 | wpe = nn.Embedding(config.block_size, config.n_embd), 130 | drop = nn.Dropout(config.dropout), 131 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 132 | ln_f = LayerNorm(config.n_embd, bias=config.bias), 133 | )) 134 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 135 | # with weight tying when using torch.compile() some warnings get generated: 136 | # "UserWarning: functional_call was passed multiple values for tied weights. 137 | # This behavior is deprecated and will be an error in future versions" 138 | # not 100% sure what this is, so far seems to be harmless. TODO investigate 139 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying 140 | 141 | # init all weights 142 | self.apply(self._init_weights) 143 | # apply special scaled init to the residual projections, per GPT-2 paper 144 | for pn, p in self.named_parameters(): 145 | if pn.endswith('c_proj.weight'): 146 | torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) 147 | 148 | # report number of parameters 149 | print("number of parameters: %.2fM" % (self.get_num_params()/1e6,)) 150 | 151 | def get_num_params(self, non_embedding=True): 152 | """ 153 | Return the number of parameters in the model. 154 | For non-embedding count (default), the position embeddings get subtracted. 155 | The token embeddings would too, except due to the parameter sharing these 156 | params are actually used as weights in the final layer, so we include them. 
157 | """ 158 | n_params = sum(p.numel() for p in self.parameters()) 159 | if non_embedding: 160 | n_params -= self.transformer.wpe.weight.numel() 161 | return n_params 162 | 163 | def _init_weights(self, module): 164 | if isinstance(module, nn.Linear): 165 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 166 | if module.bias is not None: 167 | torch.nn.init.zeros_(module.bias) 168 | elif isinstance(module, nn.Embedding): 169 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 170 | 171 | def forward(self, idx, targets=None): 172 | device = idx.device 173 | b, t = idx.size() 174 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 175 | pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) 176 | 177 | # forward the GPT model itself 178 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 179 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) 180 | x = self.transformer.drop(tok_emb + pos_emb) 181 | for block in self.transformer.h: 182 | x = block(x) 183 | x = self.transformer.ln_f(x) 184 | 185 | if targets is not None: 186 | # if we are given some desired targets also calculate the loss 187 | logits = self.lm_head(x) 188 | 189 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 190 | else: 191 | # inference-time mini-optimization: only forward the lm_head on the very last position 192 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim 193 | loss = None 194 | 195 | return logits, loss 196 | 197 | def crop_block_size(self, block_size): 198 | # model surgery to decrease the block size if necessary 199 | # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024) 200 | # but want to use a smaller block size for some smaller, simpler model 201 | assert block_size <= self.config.block_size 202 | self.config.block_size = block_size 203 | self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size]) 204 | for block in self.transformer.h: 205 | if hasattr(block.attn, 'bias'): 206 | block.attn.bias = block.attn.bias[:,:,:block_size,:block_size] 207 | 208 | @classmethod 209 | def from_pretrained(cls, model_type, override_args=None): 210 | assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'} 211 | override_args = override_args or {} # default to empty dict 212 | # only dropout can be overridden see more notes below 213 | assert all(k == 'dropout' for k in override_args) 214 | from transformers import GPT2LMHeadModel 215 | print("loading weights from pretrained gpt: %s" % model_type) 216 | 217 | # n_layer, n_head and n_embd are determined from model_type 218 | config_args = { 219 | 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params 220 | 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params 221 | 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params 222 | 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params 223 | }[model_type] 224 | print("forcing vocab_size=50257, block_size=1024, bias=True") 225 | config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints 226 | config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints 227 | config_args['bias'] = True # always True for GPT model checkpoints 228 | # we can override the dropout rate, if desired 229 | if 'dropout' in override_args: 230 | print(f"overriding dropout rate to 
{override_args['dropout']}") 231 | config_args['dropout'] = override_args['dropout'] 232 | # create a from-scratch initialized minGPT model 233 | config = GPTConfig(**config_args) 234 | model = GPT(config) 235 | sd = model.state_dict() 236 | sd_keys = sd.keys() 237 | sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param 238 | 239 | # init a huggingface/transformers model 240 | model_hf = GPT2LMHeadModel.from_pretrained(model_type) 241 | sd_hf = model_hf.state_dict() 242 | 243 | # copy while ensuring all of the parameters are aligned and match in names and shapes 244 | sd_keys_hf = sd_hf.keys() 245 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer 246 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer) 247 | transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight'] 248 | # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear 249 | # this means that we have to transpose these weights when we import them 250 | assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" 251 | for k in sd_keys_hf: 252 | if any(k.endswith(w) for w in transposed): 253 | # special treatment for the Conv1D weights we need to transpose 254 | assert sd_hf[k].shape[::-1] == sd[k].shape 255 | with torch.no_grad(): 256 | sd[k].copy_(sd_hf[k].t()) 257 | else: 258 | # vanilla copy over the other parameters 259 | assert sd_hf[k].shape == sd[k].shape 260 | with torch.no_grad(): 261 | sd[k].copy_(sd_hf[k]) 262 | 263 | return model 264 | 265 | def configure_optimizers(self, optim_name, weight_decay, learning_rate, betas, device_type): 266 | # start with all of the candidate parameters 267 | param_dict = {pn: p for pn, p in self.named_parameters()} 268 | # filter out those that do not require grad 269 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 270 | # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. 271 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 272 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 273 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 274 | optim_groups = [ 275 | {'params': decay_params, 'weight_decay': weight_decay}, 276 | {'params': nodecay_params, 'weight_decay': 0.0} 277 | ] 278 | num_decay_params = sum(p.numel() for p in decay_params) 279 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 280 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") 281 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") 282 | # Create AdamW optimizer and use the fused version if it is available 283 | if optim_name == "adam": 284 | fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters 285 | use_fused = fused_available and device_type == 'cuda' 286 | extra_args = dict(fused=True) if use_fused else dict() 287 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) 288 | print(f"using fused AdamW: {use_fused}") 289 | elif optim_name == "gdm": 290 | optimizer = torch.optim.SGD(optim_groups, lr=learning_rate, momentum=betas[0], dampening=0.) 
291 | 292 | return optimizer 293 | 294 | def estimate_mfu(self, fwdbwd_per_iter, dt): 295 | """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """ 296 | # first estimate the number of flops we do per iteration. 297 | # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 298 | N = self.get_num_params() 299 | cfg = self.config 300 | L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size 301 | flops_per_token = 6*N + 12*L*H*Q*T 302 | flops_per_fwdbwd = flops_per_token * T 303 | flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter 304 | # express our flops throughput as ratio of A100 bfloat16 peak flops 305 | flops_achieved = flops_per_iter * (1.0/dt) # per second 306 | flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS 307 | mfu = flops_achieved / flops_promised 308 | return mfu 309 | 310 | @torch.no_grad() 311 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): 312 | """ 313 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete 314 | the sequence max_new_tokens times, feeding the predictions back into the model each time. 315 | Most likely you'll want to make sure to be in model.eval() mode of operation for this. 316 | """ 317 | for _ in range(max_new_tokens): 318 | # if the sequence context is growing too long we must crop it at block_size 319 | idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] 320 | # forward the model to get the logits for the index in the sequence 321 | logits, _ = self(idx_cond) 322 | # pluck the logits at the final step and scale by desired temperature 323 | logits = logits[:, -1, :] / temperature 324 | # optionally crop the logits to only the top k options 325 | if top_k is not None: 326 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 327 | logits[logits < v[:, [-1]]] = -float('Inf') 328 | # apply softmax to convert logits to (normalized) probabilities 329 | probs = F.softmax(logits, dim=-1) 330 | # sample from the distribution 331 | idx_next = torch.multinomial(probs, num_samples=1) 332 | # append sampled index to the running sequence and continue 333 | idx = torch.cat((idx, idx_next), dim=1) 334 | 335 | return idx 336 | -------------------------------------------------------------------------------- /gpt-2/model_sink.py: -------------------------------------------------------------------------------- 1 | """ 2 | Full definition of a GPT Language Model, all of it in this single file. 3 | References: 4 | 1) the official GPT-2 TensorFlow implementation released by OpenAI: 5 | https://github.com/openai/gpt-2/blob/master/src/model.py 6 | 2) huggingface/transformers PyTorch implementation: 7 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py 8 | """ 9 | 10 | import math 11 | import inspect 12 | from dataclasses import dataclass 13 | 14 | import torch 15 | import torch.nn as nn 16 | from torch.nn import functional as F 17 | 18 | class LayerNorm(nn.Module): 19 | """ LayerNorm but with an optional bias. 
PyTorch doesn't support simply bias=False """ 20 | 21 | def __init__(self, ndim, bias): 22 | super().__init__() 23 | self.weight = nn.Parameter(torch.ones(ndim)) 24 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None 25 | 26 | def forward(self, input): 27 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) 28 | 29 | class CausalSelfAttention(nn.Module): 30 | 31 | def __init__(self, config): 32 | super().__init__() 33 | assert config.n_embd % config.n_head == 0 34 | # key, query, value projections for all heads, but in a batch 35 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) 36 | # output projection 37 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 38 | # regularization 39 | self.attn_dropout = nn.Dropout(config.dropout) 40 | self.resid_dropout = nn.Dropout(config.dropout) 41 | self.n_head = config.n_head 42 | self.n_embd = config.n_embd 43 | self.dropout = config.dropout 44 | # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 45 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') 46 | if not self.flash: 47 | print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") 48 | # causal mask to ensure that attention is only applied to the left in the input sequence 49 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) 50 | .view(1, 1, config.block_size, config.block_size)) 51 | 52 | def forward(self, x): 53 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 54 | 55 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 56 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 57 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 58 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 59 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 60 | 61 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 62 | if self.flash: 63 | # efficient attention using Flash Attention CUDA kernels 64 | y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True) 65 | else: 66 | # manual implementation of attention 67 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 68 | att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) 69 | att = F.softmax(att, dim=-1) 70 | att = self.attn_dropout(att) 71 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 72 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 73 | 74 | # output projection 75 | y = self.resid_dropout(self.c_proj(y)) 76 | return y 77 | 78 | class MLP(nn.Module): 79 | 80 | def __init__(self, config): 81 | super().__init__() 82 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) 83 | self.gelu = nn.GELU() 84 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) 85 | self.dropout = nn.Dropout(config.dropout) 86 | 87 | def forward(self, x): 88 | x = self.c_fc(x) 89 | x = self.gelu(x) 90 | x = self.c_proj(x) 91 | x = self.dropout(x) 92 | return x 93 | 94 | class Block(nn.Module): 95 | 96 | def __init__(self, config): 97 | super().__init__() 98 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) 99 | self.attn = CausalSelfAttention(config) 100 | 
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) 101 | self.mlp = MLP(config) 102 | 103 | def forward(self, x): 104 | x = x + self.attn(self.ln_1(x)) 105 | x = x + self.mlp(self.ln_2(x)) 106 | return x 107 | 108 | @dataclass 109 | class GPTConfig: 110 | block_size: int = 1024 111 | vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency 112 | n_layer: int = 12 113 | n_head: int = 12 114 | n_embd: int = 768 115 | dropout: float = 0.0 116 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster 117 | 118 | class GPT(nn.Module): 119 | 120 | def __init__(self, config): 121 | super().__init__() 122 | assert config.vocab_size is not None 123 | assert config.block_size is not None 124 | self.config = config 125 | 126 | self.transformer = nn.ModuleDict(dict( 127 | wte = nn.Embedding(config.vocab_size, config.n_embd), 128 | wpe = nn.Embedding(config.block_size, config.n_embd), 129 | drop = nn.Dropout(config.dropout), 130 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 131 | ln_f = LayerNorm(config.n_embd, bias=config.bias), 132 | )) 133 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 134 | # with weight tying when using torch.compile() some warnings get generated: 135 | # "UserWarning: functional_call was passed multiple values for tied weights. 136 | # This behavior is deprecated and will be an error in future versions" 137 | # not 100% sure what this is, so far seems to be harmless. TODO investigate 138 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying 139 | 140 | # init all weights 141 | self.apply(self._init_weights) 142 | # apply special scaled init to the residual projections, per GPT-2 paper 143 | for pn, p in self.named_parameters(): 144 | if pn.endswith('c_proj.weight'): 145 | torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) 146 | 147 | # report number of parameters 148 | print("number of parameters: %.2fM" % (self.get_num_params()/1e6,)) 149 | 150 | ### add register token 151 | self.reg_token = nn.Parameter(torch.zeros(1, 1, config.n_embd), requires_grad=True) ### add register token 152 | torch.nn.init.normal_(self.reg_token, mean=0.0, std=0.02) 153 | 154 | def get_num_params(self, non_embedding=True): 155 | """ 156 | Return the number of parameters in the model. 157 | For non-embedding count (default), the position embeddings get subtracted. 158 | The token embeddings would too, except due to the parameter sharing these 159 | params are actually used as weights in the final layer, so we include them. 
160 | """ 161 | n_params = sum(p.numel() for p in self.parameters()) 162 | if non_embedding: 163 | n_params -= self.transformer.wpe.weight.numel() 164 | return n_params 165 | 166 | def _init_weights(self, module): 167 | if isinstance(module, nn.Linear): 168 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 169 | if module.bias is not None: 170 | torch.nn.init.zeros_(module.bias) 171 | elif isinstance(module, nn.Embedding): 172 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 173 | 174 | def forward(self, idx, targets=None): 175 | device = idx.device 176 | b, t = idx.size() 177 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 178 | pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) 179 | 180 | # forward the GPT model itself 181 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 182 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) 183 | 184 | tok_emb = tok_emb + pos_emb 185 | ############################# 186 | ## insert a register/sink token 187 | reg_tok_emb = self.reg_token.repeat((b, 1, 1)) 188 | tok_emb = torch.cat([reg_tok_emb, tok_emb], dim=1) 189 | ################################### 190 | x = self.transformer.drop(tok_emb) 191 | for block in self.transformer.h: 192 | x = block(x) 193 | x = self.transformer.ln_f(x) 194 | x = x[:, 1:, :] ## remove the register token feature 195 | 196 | if targets is not None: 197 | # if we are given some desired targets also calculate the loss 198 | logits = self.lm_head(x) 199 | 200 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 201 | else: 202 | # inference-time mini-optimization: only forward the lm_head on the very last position 203 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim 204 | loss = None 205 | 206 | return logits, loss 207 | 208 | def crop_block_size(self, block_size): 209 | # model surgery to decrease the block size if necessary 210 | # e.g. 
we may load the GPT2 pretrained model checkpoint (block size 1024) 211 | # but want to use a smaller block size for some smaller, simpler model 212 | assert block_size <= self.config.block_size 213 | self.config.block_size = block_size 214 | self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size]) 215 | for block in self.transformer.h: 216 | if hasattr(block.attn, 'bias'): 217 | block.attn.bias = block.attn.bias[:,:,:block_size,:block_size] 218 | 219 | @classmethod 220 | def from_pretrained(cls, model_type, override_args=None): 221 | assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'} 222 | override_args = override_args or {} # default to empty dict 223 | # only dropout can be overridden see more notes below 224 | assert all(k == 'dropout' for k in override_args) 225 | from transformers import GPT2LMHeadModel 226 | print("loading weights from pretrained gpt: %s" % model_type) 227 | 228 | # n_layer, n_head and n_embd are determined from model_type 229 | config_args = { 230 | 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params 231 | 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params 232 | 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params 233 | 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params 234 | }[model_type] 235 | print("forcing vocab_size=50257, block_size=1024, bias=True") 236 | config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints 237 | config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints 238 | config_args['bias'] = True # always True for GPT model checkpoints 239 | # we can override the dropout rate, if desired 240 | if 'dropout' in override_args: 241 | print(f"overriding dropout rate to {override_args['dropout']}") 242 | config_args['dropout'] = override_args['dropout'] 243 | # create a from-scratch initialized minGPT model 244 | config = GPTConfig(**config_args) 245 | model = GPT(config) 246 | sd = model.state_dict() 247 | sd_keys = sd.keys() 248 | sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param 249 | 250 | # init a huggingface/transformers model 251 | model_hf = GPT2LMHeadModel.from_pretrained(model_type) 252 | sd_hf = model_hf.state_dict() 253 | 254 | # copy while ensuring all of the parameters are aligned and match in names and shapes 255 | sd_keys_hf = sd_hf.keys() 256 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer 257 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer) 258 | transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight'] 259 | # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear 260 | # this means that we have to transpose these weights when we import them 261 | assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" 262 | for k in sd_keys_hf: 263 | if any(k.endswith(w) for w in transposed): 264 | # special treatment for the Conv1D weights we need to transpose 265 | assert sd_hf[k].shape[::-1] == sd[k].shape 266 | with torch.no_grad(): 267 | sd[k].copy_(sd_hf[k].t()) 268 | else: 269 | # vanilla copy over the other parameters 270 | assert sd_hf[k].shape == sd[k].shape 271 | with torch.no_grad(): 272 | sd[k].copy_(sd_hf[k]) 273 | 274 | return model 275 | 276 | def configure_optimizers(self, optim_name, weight_decay, 
learning_rate, betas, device_type): 277 | # start with all of the candidate parameters 278 | param_dict = {pn: p for pn, p in self.named_parameters()} 279 | # filter out those that do not require grad 280 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 281 | # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. 282 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 283 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 284 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 285 | optim_groups = [ 286 | {'params': decay_params, 'weight_decay': weight_decay}, 287 | {'params': nodecay_params, 'weight_decay': 0.0} 288 | ] 289 | num_decay_params = sum(p.numel() for p in decay_params) 290 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 291 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") 292 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") 293 | # Create AdamW optimizer and use the fused version if it is available 294 | if optim_name == "adam": 295 | fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters 296 | use_fused = fused_available and device_type == 'cuda' 297 | extra_args = dict(fused=True) if use_fused else dict() 298 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) 299 | print(f"using fused AdamW: {use_fused}") 300 | elif optim_name == "gdm": 301 | optimizer = torch.optim.SGD(optim_groups, lr=learning_rate, momentum=betas[0], dampening=0.) 302 | 303 | return optimizer 304 | 305 | def estimate_mfu(self, fwdbwd_per_iter, dt): 306 | """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """ 307 | # first estimate the number of flops we do per iteration. 308 | # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 309 | N = self.get_num_params() 310 | cfg = self.config 311 | L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size 312 | flops_per_token = 6*N + 12*L*H*Q*T 313 | flops_per_fwdbwd = flops_per_token * T 314 | flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter 315 | # express our flops throughput as ratio of A100 bfloat16 peak flops 316 | flops_achieved = flops_per_iter * (1.0/dt) # per second 317 | flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS 318 | mfu = flops_achieved / flops_promised 319 | return mfu 320 | 321 | @torch.no_grad() 322 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): 323 | """ 324 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete 325 | the sequence max_new_tokens times, feeding the predictions back into the model each time. 326 | Most likely you'll want to make sure to be in model.eval() mode of operation for this. 
327 | """ 328 | for _ in range(max_new_tokens): 329 | # if the sequence context is growing too long we must crop it at block_size 330 | idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] 331 | # forward the model to get the logits for the index in the sequence 332 | logits, _ = self(idx_cond) 333 | # pluck the logits at the final step and scale by desired temperature 334 | logits = logits[:, -1, :] / temperature 335 | # optionally crop the logits to only the top k options 336 | if top_k is not None: 337 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 338 | logits[logits < v[:, [-1]]] = -float('Inf') 339 | # apply softmax to convert logits to (normalized) probabilities 340 | probs = F.softmax(logits, dim=-1) 341 | # sample from the distribution 342 | idx_next = torch.multinomial(probs, num_samples=1) 343 | # append sampled index to the running sequence and continue 344 | idx = torch.cat((idx, idx_next), dim=1) 345 | 346 | return idx 347 | -------------------------------------------------------------------------------- /gpt-2/model_attn_bias.py: -------------------------------------------------------------------------------- 1 | """ 2 | Full definition of a GPT Language Model, all of it in this single file. 3 | References: 4 | 1) the official GPT-2 TensorFlow implementation released by OpenAI: 5 | https://github.com/openai/gpt-2/blob/master/src/model.py 6 | 2) huggingface/transformers PyTorch implementation: 7 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py 8 | """ 9 | 10 | import math 11 | import inspect 12 | from dataclasses import dataclass 13 | 14 | import torch 15 | import torch.nn as nn 16 | from torch.nn import functional as F 17 | 18 | class LayerNorm(nn.Module): 19 | """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """ 20 | 21 | def __init__(self, ndim, bias): 22 | super().__init__() 23 | self.weight = nn.Parameter(torch.ones(ndim)) 24 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None 25 | 26 | def forward(self, input): 27 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) 28 | 29 | class CausalSelfAttention(nn.Module): 30 | 31 | def __init__(self, config): 32 | super().__init__() 33 | assert config.n_embd % config.n_head == 0 34 | # key, query, value projections for all heads, but in a batch 35 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) 36 | # output projection 37 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 38 | # regularization 39 | self.attn_dropout = nn.Dropout(config.dropout) 40 | self.resid_dropout = nn.Dropout(config.dropout) 41 | self.n_head = config.n_head 42 | self.n_embd = config.n_embd 43 | self.dropout = config.dropout 44 | # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 45 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') 46 | if not self.flash: 47 | print("WARNING: using slow attention. 
Flash Attention requires PyTorch >= 2.0") 48 | # causal mask to ensure that attention is only applied to the left in the input sequence 49 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) 50 | .view(1, 1, config.block_size, config.block_size)) 51 | 52 | ############################################################################# 53 | self.k_bias = nn.Parameter(torch.zeros(1, self.n_head, 1, self.n_embd // self.n_head), requires_grad=True) 54 | self.v_bias = nn.Parameter(torch.zeros(1, self.n_head, 1, self.n_embd // self.n_head), requires_grad=True) 55 | 56 | torch.nn.init.normal_(self.k_bias, mean=0.0, std=0.02) 57 | torch.nn.init.normal_(self.v_bias, mean=0.0, std=0.02) 58 | ############################################################################# 59 | 60 | 61 | def forward(self, x): 62 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 63 | 64 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 65 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 66 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 67 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 68 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 69 | 70 | ########################################################### 71 | L, S = q.size(-2), k.size(-2) 72 | temp_mask = torch.ones(L, S, dtype=torch.bool, device=q.device).tril(diagonal=0) 73 | true_values = torch.ones(temp_mask.size(0), 1, dtype=torch.bool, device=q.device) 74 | temp_mask = torch.cat((true_values, temp_mask), dim=1) 75 | 76 | k_bias = self.k_bias.repeat(B, 1, 1, 1) # (B, num_heads, 1, dim // num_heads) 77 | v_bias = self.v_bias.repeat(B, 1, 1, 1) 78 | 79 | k = torch.cat((k_bias, k), dim=2) 80 | v = torch.cat((v_bias, v), dim=2) 81 | ########################################################### 82 | 83 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 84 | if self.flash: 85 | # efficient attention using Flash Attention CUDA kernels 86 | # y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True) 87 | y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=temp_mask, dropout_p=self.dropout if self.training else 0, is_causal=False) 88 | else: 89 | # manual implementation of attention 90 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 91 | att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) 92 | att = F.softmax(att, dim=-1) 93 | att = self.attn_dropout(att) 94 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 95 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 96 | 97 | # output projection 98 | y = self.resid_dropout(self.c_proj(y)) 99 | return y 100 | 101 | class MLP(nn.Module): 102 | 103 | def __init__(self, config): 104 | super().__init__() 105 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) 106 | self.gelu = nn.GELU() 107 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) 108 | self.dropout = nn.Dropout(config.dropout) 109 | 110 | def forward(self, x): 111 | x = self.c_fc(x) 112 | x = self.gelu(x) 113 | x = self.c_proj(x) 114 | x = self.dropout(x) 115 | return x 116 | 117 | class Block(nn.Module): 118 | 119 | def __init__(self, config): 120 | super().__init__() 121 
| self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) 122 | self.attn = CausalSelfAttention(config) 123 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) 124 | self.mlp = MLP(config) 125 | 126 | def forward(self, x): 127 | x = x + self.attn(self.ln_1(x)) 128 | x = x + self.mlp(self.ln_2(x)) 129 | return x 130 | 131 | @dataclass 132 | class GPTConfig: 133 | block_size: int = 1024 134 | vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency 135 | n_layer: int = 12 136 | n_head: int = 12 137 | n_embd: int = 768 138 | dropout: float = 0.0 139 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster 140 | 141 | class GPT(nn.Module): 142 | 143 | def __init__(self, config): 144 | super().__init__() 145 | assert config.vocab_size is not None 146 | assert config.block_size is not None 147 | self.config = config 148 | 149 | self.transformer = nn.ModuleDict(dict( 150 | wte = nn.Embedding(config.vocab_size, config.n_embd), 151 | wpe = nn.Embedding(config.block_size, config.n_embd), 152 | drop = nn.Dropout(config.dropout), 153 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 154 | ln_f = LayerNorm(config.n_embd, bias=config.bias), 155 | )) 156 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 157 | # with weight tying when using torch.compile() some warnings get generated: 158 | # "UserWarning: functional_call was passed multiple values for tied weights. 159 | # This behavior is deprecated and will be an error in future versions" 160 | # not 100% sure what this is, so far seems to be harmless. TODO investigate 161 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying 162 | 163 | # init all weights 164 | self.apply(self._init_weights) 165 | # apply special scaled init to the residual projections, per GPT-2 paper 166 | for pn, p in self.named_parameters(): 167 | if pn.endswith('c_proj.weight'): 168 | torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) 169 | 170 | # report number of parameters 171 | print("number of parameters: %.2fM" % (self.get_num_params()/1e6,)) 172 | 173 | def get_num_params(self, non_embedding=True): 174 | """ 175 | Return the number of parameters in the model. 176 | For non-embedding count (default), the position embeddings get subtracted. 177 | The token embeddings would too, except due to the parameter sharing these 178 | params are actually used as weights in the final layer, so we include them. 
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets, also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss

    def crop_block_size(self, block_size):
        # model surgery to decrease the block size if necessary
        # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
        # but want to use a smaller block size for some smaller, simpler model
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
        for block in self.transformer.h:
            if hasattr(block.attn, 'bias'):
                block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]

    @classmethod
    def from_pretrained(cls, model_type, override_args=None):
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        override_args = override_args or {} # default to empty dict
        # only dropout can be overridden; see more notes below
        assert all(k == 'dropout' for k in override_args)
        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)

        # n_layer, n_head and n_embd are determined from model_type
        config_args = {
            'gpt2':        dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
            'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
            'gpt2-large':  dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
            'gpt2-xl':     dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
        }[model_type]
        print("forcing vocab_size=50257, block_size=1024, bias=True")
        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
        config_args['bias'] = True # always True for GPT model checkpoints
        # we can override the dropout rate, if desired
        if 'dropout' in override_args:
            print(f"overriding dropout rate to {override_args['dropout']}")
            config_args['dropout'] = override_args['dropout']
        # create a from-scratch initialized minGPT model
        config = GPTConfig(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
        # this means that we have to transpose these weights when we import them
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                # special treatment for the Conv1D weights we need to transpose
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # vanilla copy over the other parameters
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model

    def configure_optimizers(self, optim_name, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameter that is 2D will be weight decayed; all others will not be.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        if optim_name == "adam":
            fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
            use_fused = fused_available and device_type == 'cuda'
            extra_args = dict(fused=True) if use_fused else dict()
            optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
            print(f"using fused AdamW: {use_fused}")
        elif optim_name == "gdm":
            # "gdm": plain SGD with heavy-ball momentum, reusing betas[0] as the momentum coefficient
            optimizer = torch.optim.SGD(optim_groups, lr=learning_rate, momentum=betas[0], dampening=0.)
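        else:
            # not in the original file: defensive guard so an unrecognized optimizer name fails
            # with a clear error instead of an UnboundLocalError at the return statement below
            raise ValueError(f"unknown optim_name: {optim_name}")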

        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = self.get_num_params()
        cfg = self.config
        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0/dt) # per second
        flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        You will most likely want to put the model in model.eval() mode before calling this.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
--------------------------------------------------------------------------------
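For reference, a minimal usage sketch of the GPT class defined above. This snippet is not a file from the repository; it assumes the module above is importable, that the HuggingFace transformers package is installed (required by from_pretrained), and it uses GPT-2's <|endoftext|> id (50256) as a toy one-token prompt in place of a real tokenized input.

import torch

# load the 124M-parameter GPT-2 checkpoint into the GPT class defined above
model = GPT.from_pretrained('gpt2', override_args={'dropout': 0.0})
model.eval()

# a one-token prompt; in practice token ids come from the GPT-2 BPE tokenizer (e.g. tiktoken)
idx = torch.tensor([[50256]], dtype=torch.long)  # <|endoftext|>

out = model.generate(idx, max_new_tokens=20, temperature=0.8, top_k=50)
print(out.shape)  # torch.Size([1, 21]): the prompt plus 20 sampled tokens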