├── .gitattributes ├── .gitignore ├── README.md ├── assets ├── french.png ├── german.png └── scandinavian.png ├── autoencoder ├── autoencoder.py ├── configurator.py ├── feature-browser │ ├── build_website.py │ ├── main_page.py │ └── subpages.py ├── prepare.py ├── resource_loader.py ├── train.py └── utils │ ├── __init__.py │ └── plotting_utils.py ├── reproduction.md ├── requirements.txt └── transformer ├── README.md ├── config ├── train_gpt2.py └── train_shakespeare_char.py ├── configurator.py ├── data ├── openwebtext │ ├── prepare.py │ └── readme.md ├── shakespeare │ ├── prepare.py │ └── readme.md └── shakespeare_char │ ├── prepare.py │ └── readme.md ├── hooked_model.py ├── model.py └── train.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Override jupyter in Github language stats for more accurate estimate of repo code languages 2 | # reference: https://github.com/github/linguist/blob/master/docs/overrides.md#generated-code 3 | *.ipynb linguist-generated 4 | *.html linguist-generated 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | .ipynb_checkpoints/ 4 | .vscode 5 | __pycache__/ 6 | *.bin 7 | *.pkl 8 | *.pt 9 | *.pyc 10 | *.sh 11 | *.html 12 | *.arrow 13 | *.ipynb 14 | *.css 15 | *.err 16 | *.out 17 | *.png 18 | input.txt 19 | notes.md 20 | autoencoder.ipynb 21 | blog.md 22 | env/ 23 | slurm/ 24 | wandb/ 25 | notes+papers 26 | notes_exp.md 27 | autoencoder/out/env_3.12/ 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Towards Monosemanticity: Decomposing Language Models With Dictionary Learning 3 | 4 | This repository reproduces results of [Anthropic's Sparse Dictionary Learning paper](https://transformer-circuits.pub/2023/monosemantic-features/). The codebase is quite rough, but the results are excellent. See the [feature interface](https://shehper.github.io/feature-interface/) to browse through the features learned by the sparse autoencoder. There are improvements to be made (see the [TODOs](#todos) section below), and I will work on them intermittently as I juggle things in life :) 5 | 6 | I trained a 1-layer transformer model from scratch using [nanoGPT](https://github.com/karpathy/nanoGPT) with $d_{\text{model}} = 128$. Then, I trained a sparse autoencoder with $4096$ features on its MLP activations as in [Anthropic's paper](https://transformer-circuits.pub/2023/monosemantic-features/). 93% of the autoencoder neurons were alive, only 5% of which were of ultra-low density. There are several interesting features. For example, there is [a feature for French language](https://shehper.github.io/feature-interface/?page=2011), 7 | 8 |

9 | 10 |

11 | 12 | a feature each for German, Japanese, and many other languages, as well many other interesting features: 13 | 14 | - [A feature for German](https://shehper.github.io/feature-interface/?page=156) 15 | - [A feature for Scandinavian languages](https://shehper.github.io/feature-interface/?page=1634) 16 | - [A feature for Japanese](https://shehper.github.io/feature-interface/?page=1989) 17 | - [A feature for Hebrew](https://shehper.github.io/feature-interface/?page=2026) 18 | - [A feature for Cyrilic vowels](https://shehper.github.io/feature-interface/?page=3987) 19 | - [A feature for token "at" in words like "Croatian", "Scat", "Hayat", etc](https://shehper.github.io/feature-interface/?page=1662) 20 | - [A single token feature for "much"](https://shehper.github.io/feature-interface/?page=2760) 21 | - [A feature for sports leagues: NHL, NBA, etc](https://shehper.github.io/feature-interface/?page=379) 22 | - [A feature for Gregorian calendar dates](https://shehper.github.io/feature-interface/?page=344) 23 | - [A feature for "when"](https://shehper.github.io/feature-interface/?page=2022): 24 | - this feature particularly stands out because of the size of the mode around large activation values. 25 | - [A feature for "&"](https://shehper.github.io/feature-interface/?page=1916) 26 | - [A feature for ")"](https://shehper.github.io/feature-interface/?page=1917) 27 | - [A feature for "v" in URLs like "com/watch?v=SiN8](https://shehper.github.io/feature-interface/?page=27) 28 | - [A feature for programming code](https://shehper.github.io/feature-interface/?page=45) 29 | - [A feature for Donald Trump](https://shehper.github.io/feature-interface/?page=292) 30 | - [A feature for LaTeX](https://shehper.github.io/feature-interface/?page=538) 31 | 32 | 34 | 35 | 36 | 37 | ### Training Details 38 | 39 | I used the "OpenWebText" dataset to train the transformer model, to generate the MLP activations dataset for the autoencoder, and to generate the feature interface visualizations. The transformer model had $d_{\text{model}}= 128$, $d_{\text{MLP}} = 512$, and $n_{\text{head}}= 4$. I trained this model for $2 \times 10^5$ iterations to roughly match the number of epochs with [Anthropic's training procedure](https://transformer-circuits.pub/2023/monosemantic-features#appendix-transformer). 40 | 41 | I collected the dataset of 4B MLP activations by performing forward pass on 20M prompts (each of length 1024), keeping 200 activation vectors from each prompt. Next, I trained the autoencoder for approximately $5 \times 10^5$ training steps at batch size 8192 and learning rate $3 \times 10^{-4}$. I performed neuron resampling 4 times during training at training steps $2.5 \times i \times 10^4$ for $i=1, 2, 3, 4$. See a complete log of the training run on the [W&B page](https://wandb.ai/shehper/sparse-autoencoder-openwebtext-public/runs/vjbcwjsf?nw=nwusershehper). The L1-coefficient for this training run is $10^{-3}$. I selected the L1-coefficient and the learning rate by performing a grid search. 42 | 43 | For the most part, I followed the training procedure described in the [appendix](https://transformer-circuits.pub/2023/monosemantic-features#appendix-autoencoder) of Anthropic's original paper. I did not follow the improvements they suggested in their [January](https://transformer-circuits.pub/2024/jan-update/index.html) and [February](https://transformer-circuits.pub/2024/feb-update/index.html) updates. 44 | 45 | ### TODOs 46 | - Incorporate the effects of feature ablations in the feature interface. 
47 | - Implement an interface to see "Feature Activations on Example Texts" as done by Anthropic [here](https://transformer-circuits.pub/2023/monosemantic-features/vis/a1-math.html). 48 | - Modify the code so that one can train a sparse autoencoder on activations of any MLP / attention layer. 49 | 50 | ### Related Work 51 | There are several other very interesting works on the web exploring sparse dictionary learning. Here is a small subset of them. 52 | 53 | - [Sparse Autoencoders Find Highly Interpretable Features in Language Models by Cunningham, et al.](https://arxiv.org/abs/2309.08600) 54 | - [Sparse Autoencoders Work on Attention Layer Outputs by Kissane, et al.](https://www.lesswrong.com/posts/DtdzGwFh9dCfsekZZ/sparse-autoencoders-work-on-attention-layer-outputs) 55 | - [Joseph Bloom's SAE codebase](https://github.com/jbloomAus/mats_sae_training) along with a blogpost on [trained SAEs for all residual stream layers of GPT-2 small](https://www.alignmentforum.org/posts/f9EgfLSurAiqRJySD/open-source-sparse-autoencoders-for-all-residual-stream) 56 | - [Neel Nanda's SAE codebase](https://github.com/neelnanda-io/1L-Sparse-Autoencoder) along with a [blogpost](https://www.lesswrong.com/posts/fKuugaxt2XLTkASkk/open-source-replication-and-commentary-on-anthropic-s) 57 | - [Callum McDougall's exercises on SAEs](https://github.com/callummcdougall/sae-exercises-mats/tree/main) 58 | - [SAE library by AI Safey Foundation](https://github.com/ai-safety-foundation/sparse_autoencoder) 59 | 60 | -------------------------------------------------------------------------------- /assets/french.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shehper/sparse-dictionary-learning/aa667b1863fa1cc14fb30819fac561f3517c4cea/assets/french.png -------------------------------------------------------------------------------- /assets/german.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shehper/sparse-dictionary-learning/aa667b1863fa1cc14fb30819fac561f3517c4cea/assets/german.png -------------------------------------------------------------------------------- /assets/scandinavian.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shehper/sparse-dictionary-learning/aa667b1863fa1cc14fb30819fac561f3517c4cea/assets/scandinavian.png -------------------------------------------------------------------------------- /autoencoder/autoencoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines an AutoEncoder class, which also contains an implementation of neuron resampling. 
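A minimal training-step sketch (illustrative only; the sizes follow the README: d_MLP = 512 inputs, 4096 latents, L1 coefficient 1e-3, learning rate 3e-4, batch size 8192; the Adam settings and the exact call order are assumptions, not taken verbatim from train.py):

    sae = AutoEncoder(n_inputs=512, n_latents=4096, lam=1e-3)
    optimizer = torch.optim.Adam(sae.parameters(), lr=3e-4)
    batch = torch.randn(8192, 512)                    # stand-in for a batch of MLP activations
    out = sae(batch)                                  # training mode returns {'loss': ..., 'latents': ...}
    optimizer.zero_grad()
    out['loss'].backward()
    sae.remove_parallel_component_of_decoder_grad()   # keep the update orthogonal to decoder columns
    optimizer.step()
    sae.normalize_decoder_columns()                   # re-project decoder columns to unit norm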
3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | class AutoEncoder(nn.Module): 10 | def __init__(self, n_inputs: int, n_latents: int, lam: float = 0.003, resampling_interval: int = 25000): 11 | """ 12 | n_input: Number of inputs 13 | n_latents: Number of neurons in the hidden layer 14 | lam: L1-coefficient for Sparse Autoencoder 15 | resampling_interval: Number of training steps after which dead neurons will be resampled 16 | """ 17 | super().__init__() 18 | self.n_inputs, self.n_latents = n_inputs, n_latents 19 | self.encoder = nn.Linear(n_inputs, n_latents) 20 | self.relu = nn.ReLU() 21 | self.decoder = nn.Linear(n_latents, n_inputs) 22 | self.lam = lam 23 | self.resampling_interval = resampling_interval 24 | self.dead_neurons = None 25 | self.normalize_decoder_columns() 26 | 27 | def forward(self, x): 28 | latents = self.encode(x) 29 | reconstructed = self.decode(latents) 30 | loss = self.calculate_loss(x, latents, reconstructed) 31 | 32 | if self.training: 33 | return {'loss': loss, 'latents': latents} 34 | else: 35 | return { 36 | 'loss': loss, 37 | 'latents': latents, 38 | 'reconst_acts': reconstructed, 39 | 'mse_loss': self.mse_loss(reconstructed, x), 40 | 'l1_loss': self.l1_loss(latents) 41 | } 42 | 43 | def encode(self, x): 44 | bias_corrected_input = x - self.decoder.bias 45 | return self.relu(self.encoder(bias_corrected_input)) 46 | 47 | def decode(self, encoded): 48 | return self.decoder(encoded) 49 | 50 | def calculate_loss(self, x, encoded, reconstructed): 51 | mse_loss = self.mse_loss(reconstructed, x) 52 | l1_loss = self.l1_loss(encoded) 53 | return mse_loss + self.lam * l1_loss 54 | 55 | def mse_loss(self, reconstructed, original): 56 | return F.mse_loss(reconstructed, original) 57 | 58 | def l1_loss(self, encoded): 59 | return F.l1_loss(encoded, torch.zeros_like(encoded), reduction='sum') / encoded.shape[0] 60 | 61 | @torch.no_grad() 62 | def get_feature_activations(self, inputs, start_idx, end_idx): 63 | """ 64 | Computes the activations of a subset of features in the hidden layer. 65 | 66 | :param inputs: Input tensor of shape (..., n) where n = d_MLP. It includes batch dimensions. 67 | :param start_idx: Starting index (inclusive) of the feature subset. 68 | :param end_idx: Ending index (exclusive) of the feature subset. 69 | 70 | Returns the activations for the specified feature range, reducing computation by 71 | only processing the necessary part of the network's weights and biases. 72 | """ 73 | adjusted_inputs = inputs - self.decoder.bias # Adjust input to account for decoder bias 74 | weight_subset = self.encoder.weight[start_idx:end_idx, :].t() # Transpose the subset of weights 75 | bias_subset = self.encoder.bias[start_idx:end_idx] 76 | 77 | activations = self.relu(adjusted_inputs @ weight_subset + bias_subset) 78 | 79 | return activations 80 | 81 | @torch.no_grad() 82 | def normalize_decoder_columns(self): 83 | """ 84 | Normalize the decoder's weight vectors to have unit norm along the feature dimension. 85 | This normalization can help in maintaining the stability of the network's weights. 86 | """ 87 | self.decoder.weight.data = F.normalize(self.decoder.weight.data, dim=0) 88 | 89 | def remove_parallel_component_of_decoder_grad(self): 90 | """ 91 | Remove the component of the gradient parallel to the decoder's weight vectors. 
92 | """ 93 | unit_weights = F.normalize(self.decoder.weight, dim=0) # \hat{b} 94 | proj = (self.decoder.weight.grad * unit_weights).sum(dim=0) * unit_weights 95 | self.decoder.weight.grad = self.decoder.weight.grad - proj 96 | 97 | @staticmethod 98 | def is_dead_neuron_investigation_step(step, resampling_interval, num_resamples): 99 | """ 100 | Determine if the current step is the start of a phase for investigating dead neurons. 101 | According to Anthropic's specified policy, it occurs at odd multiples of half the resampling interval. 102 | """ 103 | return (step > 0) and step % (resampling_interval // 2) == 0 and (step // (resampling_interval // 2)) % 2 != 0 and step < resampling_interval * num_resamples 104 | 105 | @staticmethod 106 | def is_within_neuron_investigation_phase(step, resampling_interval, num_resamples): 107 | """ 108 | Check if the current step is within a phase where active neurons are investigated. 109 | This phase occurs in intervals defined in the specified range, starting at odd multiples of half the resampling interval. 110 | """ 111 | return any(milestone - resampling_interval // 2 <= step < milestone 112 | for milestone in range(resampling_interval, resampling_interval * (num_resamples + 1), resampling_interval)) 113 | 114 | @torch.no_grad() 115 | def initiate_dead_neurons(self): 116 | self.dead_neurons = set(range(self.n_latents)) 117 | 118 | @torch.no_grad() 119 | def update_dead_neurons(self, latents): 120 | """ 121 | Update the set of dead neurons based on the current feature activations. 122 | If a neuron is active (has non-zero activation), it is removed from the dead neuron set. 123 | """ 124 | active_neurons = torch.nonzero(torch.count_nonzero(latents, dim=0), as_tuple=False).view(-1) 125 | self.dead_neurons.difference_update(active_neurons.tolist()) 126 | 127 | @torch.no_grad() 128 | def resample_dead_neurons(self, data, optimizer, batch_size=8192): 129 | """ 130 | Resample the dead neurons by resetting their weights and biases based on the characteristics 131 | of active neurons. Proceeds only if there are dead neurons to resample. 
132 | """ 133 | if not self.dead_neurons: 134 | return 135 | 136 | device = self._get_device() 137 | dead_neurons_t, alive_neurons = self._get_neuron_indices() 138 | average_enc_norm = self._compute_average_norm_of_alive_neurons(alive_neurons) 139 | probs = self._compute_loss_probabilities(data, batch_size, device) 140 | selected_examples = self._select_examples_based_on_probabilities(data, probs) 141 | 142 | self._resample_neurons(selected_examples, dead_neurons_t, average_enc_norm, device) 143 | self._update_optimizer_parameters(optimizer, dead_neurons_t) 144 | 145 | print('Dead neurons resampled successfully!') 146 | self.dead_neurons = None 147 | 148 | def _get_device(self): 149 | return next(self.parameters()).device 150 | 151 | def _get_neuron_indices(self): 152 | dead_neurons_t = torch.tensor(list(self.dead_neurons), device=self._get_device()) 153 | alive_neurons = torch.tensor([i for i in range(self.n_latents) if i not in self.dead_neurons], device=self._get_device()) 154 | return dead_neurons_t, alive_neurons 155 | 156 | def _compute_average_norm_of_alive_neurons(self, alive_neurons): 157 | return torch.linalg.vector_norm(self.encoder.weight[alive_neurons], dim=1).mean() 158 | 159 | def _compute_loss_probabilities(self, data, batch_size, device): 160 | num_batches = (len(data) + batch_size - 1) // batch_size 161 | probs = torch.zeros(len(data), device=device) 162 | for i in range(num_batches): 163 | batch_slice = slice(i * batch_size, (i + 1) * batch_size) 164 | x_batch = data[batch_slice].to(device) 165 | probs[batch_slice] = self._compute_batch_loss_squared(x_batch) 166 | return probs.cpu() 167 | 168 | def _compute_batch_loss_squared(self, x_batch): 169 | latents = self.encode(x_batch) 170 | reconst_acts = self.decode(latents) 171 | mselosses = F.mse_loss(reconst_acts, x_batch, reduction='none').sum(dim=1) 172 | l1losses = F.l1_loss(latents, torch.zeros_like(latents), reduction='none').sum(dim=1) 173 | return (mselosses + self.lam * l1losses).square() 174 | 175 | def _select_examples_based_on_probabilities(self, data, probs): 176 | selection_indices = torch.multinomial(probs, num_samples=len(self.dead_neurons)) 177 | return data[selection_indices].to(dtype=torch.float32) 178 | 179 | def _resample_neurons(self, examples, dead_neurons_t, average_enc_norm, device): 180 | examples_unit_norm = F.normalize(examples, dim=1).to(device) 181 | self.decoder.weight[:, dead_neurons_t] = examples_unit_norm.T 182 | 183 | # Renormalize examples to have a certain norm and reset encoder weights and biases 184 | adjusted_examples = examples_unit_norm * average_enc_norm * 0.2 185 | self.encoder.weight[dead_neurons_t] = adjusted_examples 186 | self.encoder.bias[dead_neurons_t] = 0 187 | 188 | def _update_optimizer_parameters(self, optimizer, dead_neurons_t): 189 | for i, param in enumerate(optimizer.param_groups[0]['params']): 190 | param_state = optimizer.state[param] 191 | if i in [0, 1]: # Encoder weights and biases 192 | param_state['exp_avg'][dead_neurons_t] = 0 193 | param_state['exp_avg_sq'][dead_neurons_t] = 0 194 | elif i == 2: # Decoder weights 195 | param_state['exp_avg'][:, dead_neurons_t] = 0 196 | param_state['exp_avg_sq'][:, dead_neurons_t] = 0 197 | -------------------------------------------------------------------------------- /autoencoder/configurator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Poor Man's Configurator. Probably a terrible idea. 
Example usage: 3 | $ python train.py config/override_file.py --batch_size=32 4 | this will first run config/override_file.py, then override batch_size to 32 5 | 6 | The code in this file will be run as follows from e.g. train.py: 7 | >>> exec(open('configurator.py').read()) 8 | 9 | So it's not a Python module, it's just shuttling this code away from train.py 10 | The code in this script then overrides the globals() 11 | 12 | I know people are not going to love this, I just really dislike configuration 13 | complexity and having to prepend config. to every single variable. If someone 14 | comes up with a better simple Python solution I am all ears. 15 | """ 16 | 17 | import sys 18 | from ast import literal_eval 19 | 20 | for arg in sys.argv[1:]: 21 | if '=' not in arg: 22 | # assume it's the name of a config file 23 | assert not arg.startswith('--') 24 | config_file = arg 25 | print(f"Overriding config with {config_file}:") 26 | with open(config_file) as f: 27 | print(f.read()) 28 | exec(open(config_file).read()) 29 | else: 30 | # assume it's a --key=value argument 31 | assert arg.startswith('--') 32 | key, val = arg.split('=') 33 | key = key[2:] 34 | if key in globals(): 35 | try: 36 | # attempt to eval it it (e.g. if bool, number, or etc) 37 | attempt = literal_eval(val) 38 | except (SyntaxError, ValueError): 39 | # if that goes wrong, just use the string 40 | attempt = val 41 | # ensure the types match ok 42 | assert type(attempt) == type(globals()[key]) 43 | # cross fingers 44 | print(f"Overriding: {key} = {attempt}") 45 | globals()[key] = attempt 46 | else: 47 | raise ValueError(f"Unknown config key: {key}") 48 | -------------------------------------------------------------------------------- /autoencoder/feature-browser/build_website.py: -------------------------------------------------------------------------------- 1 | """ 2 | Make a feature browser for a trained autoencoder model. 3 | In this file, it is useful to keep track of shapes of each tensor. 4 | Each tensor is followed by a comment describing its shape. 
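For example, a trailing comment such as `feature_acts  # (B, T, H)` means a batch of B contexts of T tokens with H feature activations per token (symbols are defined in the glossary below).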
5 | I use the following glossary: 6 | S: num_sampled_tokens 7 | R: window_radius 8 | W: window_length 9 | L: number of autoencoder latents 10 | H: Number of features being processed in a phase 11 | N: total_sampled_tokens = (num_contexts * num_sampled_tokens) 12 | T: block_size (same as nanoGPT) 13 | B: gpt_batch_size (same as nanoGPT) 14 | C: n_embd a.k.a d_model (same as nanoGPT) 15 | V: vocab_size 16 | SI: samples_per_interval 17 | 18 | Run on a Macbook as 19 | python build_website.py --device=cpu --dataset=shakespeare_char --gpt_ckpt_dir=out_sc_1_2_32 --sae_ckpt_dir=1712254759.95 20 | """ 21 | 22 | from dataclasses import dataclass 23 | import torch 24 | from tensordict import TensorDict 25 | import os 26 | import sys 27 | from math import ceil 28 | from main_page import create_main_html_page 29 | from subpages import write_alive_feature_page, write_dead_feature_page, write_ultralow_density_feature_page 30 | 31 | sys.path.insert(1, '../') 32 | from resource_loader import ResourceLoader 33 | from utils.plotting_utils import make_activations_histogram, make_logits_histogram 34 | 35 | # hyperparameters 36 | # data and model 37 | dataset = 'openwebtext' 38 | gpt_ckpt_dir = 'out' 39 | sae_ckpt_dir = 0.0 # subdirectory containing the specific model to consider 40 | # feature page hyperparameter 41 | num_contexts = 10000 42 | num_sampled_tokens = 10 # number of tokens in each context on which feature activations will be computed 43 | window_radius = 4 # number of tokens to print on either side of a sampled token.. # V / R 44 | num_top_activations = 10 # number of top activations for each feature 45 | num_intervals = 12 # number of intervals to divide activations in; = 12 in Anthropic's work 46 | samples_per_interval = 5 # number of examples to sample from each interval of activations 47 | # evaluation hyperparameters 48 | gpt_batch_size = 156 49 | num_phases = 52 # due to memory constraints, it's useful to process features in phases. 
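# e.g. with the 4096-latent autoencoder from the README, 52 phases gives ceil(4096 / 52) = 79 features per phase (see num_features_per_phase below)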
50 | # system 51 | device = 'cuda' # change it to cpu 52 | # reproducibility 53 | seed = 1442 54 | 55 | @dataclass 56 | class FeatureBrowserConfig: 57 | # dataset and model 58 | dataset: str = "openwebtext" 59 | gpt_ckpt_dir: str = "out" 60 | sae_ckpt_dir: str = "out" 61 | # feature browser hyperparameters 62 | num_contexts: int = int(1e6) 63 | num_sampled_tokens: int = 10 64 | window_radius: int = 4 65 | num_top_activations: int = 10 66 | num_intervals: int = 12 67 | samples_per_interval: int = 5 68 | # processing hyperparameters 69 | seed: int = 0 70 | device: str = "cpu" 71 | gpt_batch_size: int = 156 72 | num_phases: int = 52 73 | 74 | class FeatureBrowser(ResourceLoader): 75 | def __init__(self, config): 76 | super().__init__( 77 | dataset=config.dataset, 78 | gpt_ckpt_dir=config.gpt_ckpt_dir, 79 | device=config.device, 80 | mode="eval", 81 | sae_ckpt_dir=str(config.sae_ckpt_dir), 82 | ) 83 | 84 | # retrieve feature browser hyperparameters from config 85 | self.num_contexts = config.num_contexts 86 | self.num_sampled_tokens = config.num_sampled_tokens 87 | self.window_radius = config.window_radius 88 | self.num_top_activations = num_top_activations 89 | self.num_intervals = num_intervals 90 | self.samples_per_interval = samples_per_interval 91 | self.window_length = 2 * self.window_radius + 1 92 | self.total_sampled_tokens = self.num_contexts * self.num_sampled_tokens # Define total sampled tokens 93 | 94 | self.gpt_batch_size = config.gpt_batch_size 95 | self.n_features = self.autoencoder.n_latents 96 | self.num_phases = config.num_phases 97 | self.num_features_per_phase = ceil(self.n_features / self.num_phases) 98 | self.num_batches = ceil(self.num_contexts / self.gpt_batch_size) 99 | 100 | self.X, self.Y = self.get_text_batch(num_contexts=self.num_contexts) # sample text data for analysis # (B, T) 101 | self.encode, self.decode = self.load_tokenizer() 102 | self.html_out = os.path.join(os.path.dirname(os.path.abspath('.')), 'out', config.dataset, str(config.sae_ckpt_dir)) 103 | self.seed = config.seed 104 | 105 | # create subdirectories to store logit histograms and feature activation histograms 106 | os.makedirs(os.path.join(self.html_out, 'logits_histograms'), exist_ok=True) 107 | os.makedirs(os.path.join(self.html_out, 'activations_histograms'), exist_ok=True) 108 | 109 | # TODO: why are my logits of the order of 10, while Anthropic's are <1. Do they rescale them? 110 | # Or is it because of linear approximation to LayerNorm? 111 | 112 | # self.attributed logits is a tensor of shape (n_features, vocab_size) containing logits for each feature 113 | self.attributed_logits = self.compute_logits() 114 | self.top_logits, self.bottom_logits = self.compute_top_and_bottom_logits() 115 | 116 | print(f"Will process features in {self.num_phases} phases. Each phase will have forward pass in {self.num_batches} batches") 117 | 118 | # TODO: actually, a lot more information may be initialized here. 119 | # e.g. context window data, perhaps current phase, which feature is being processed, top activations data for that feature, etc. 120 | # That way, I will not have to pass so many variables to each method. 121 | # can also have top_acts_data here and sampled_intervals data 122 | 123 | def build(self): 124 | """ 125 | Logic: Process features in `num_phases` phases. 126 | In each phase, compute feature activations and feature ablations for (MLP activations of) text data `self.X`. 127 | Sample context windows from this data. 
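With this script's defaults (num_contexts = 10000, num_sampled_tokens = 10, window_radius = 4), that is 100,000 context windows of 9 tokens each per phase.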
128 | Next, use `get_top_activations` to get tokens (along with windows) with top activations for each feature. 129 | Note that sampled tokens are the same in all phases, thanks to the use of fn_seed in `_sample_context_windows`. 130 | """ 131 | self.write_main_page() 132 | 133 | 134 | for phase in range(self.num_phases): 135 | feature_start_idx = phase * self.num_features_per_phase 136 | feature_end_idx = min((phase + 1) * self.num_features_per_phase, self.n_features) 137 | print(f'working on features # {feature_start_idx} - {feature_end_idx} in phase {phase + 1}/{self.num_phases}') 138 | context_window_data = self.compute_context_window_data(feature_start_idx, feature_end_idx) 139 | top_acts_data = self.compute_top_activations(context_window_data) 140 | for h in range(0, feature_end_idx-feature_start_idx): 141 | # make and save histogram of logits for this feature 142 | feature_id = phase * self.num_features_per_phase + h 143 | make_logits_histogram(logits=self.attributed_logits[feature_id, :], 144 | feature_id=feature_id, 145 | dirpath=self.html_out) 146 | # write the page for this feature 147 | self.write_feature_page(phase, h, context_window_data, top_acts_data) 148 | 149 | # if phase == 0: 150 | # print(f'stored new feature browser pages in {self.html_out}') 151 | # break 152 | 153 | def compute_context_window_data(self, feature_start_idx, feature_end_idx): 154 | """Compute data of tokens and feature activations for all context windows. 155 | This should probably also include feature ablations.""" 156 | context_window_data = self._initialize_context_window_data(feature_start_idx, feature_end_idx) 157 | 158 | for iter in range(self.num_batches): 159 | if iter % 20 == 0: 160 | print(f"computing feature activations for batches {iter+1}-{min(iter+20, self.num_batches)}/{self.num_batches}") 161 | batch_start_idx = iter * self.gpt_batch_size 162 | batch_end_idx = (iter + 1) * self.gpt_batch_size 163 | x, feature_activations, logits_difference_storage = self._compute_batch_feature_activations(batch_start_idx, 164 | batch_end_idx, 165 | feature_start_idx, 166 | feature_end_idx) 167 | # x: (B, T), # feature_activations: (B, T, H), # logits_difference_storage: (B, T, H) 168 | x_context_windows, feature_acts_context_windows, logits_difference_context_window = self._sample_context_windows( x, 169 | feature_activations, 170 | logits_difference_storage, 171 | fn_seed=self.seed+iter) 172 | # context_window_tokens: (B * S, W), context_window_feature_acts: (B * S, W, H), 173 | # logits_difference_context_window: (B * S, W, H) 174 | idx_start = batch_start_idx * self.num_sampled_tokens 175 | idx_end = batch_end_idx * self.num_sampled_tokens 176 | context_window_data["tokens"][idx_start:idx_end] = x_context_windows 177 | context_window_data["feature_acts"][idx_start:idx_end] = feature_acts_context_windows 178 | context_window_data["logits_diff"][idx_start:idx_end] = logits_difference_context_window 179 | 180 | return context_window_data 181 | 182 | def compute_top_activations(self, data): 183 | """Computes top activations of given context window data. 184 | `data` is a TensorDict with keys `tokens` and `feature_acts` of shapes (B*S, W) and (B * S, W, H) respectively.""" 185 | 186 | num_features = data["feature_acts"].shape[-1] # Label this as H. 
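# data["feature_acts"]: (B*S, W, H); index self.window_radius along W is the sampled token at the center of each window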
187 | 188 | # Find the indices of the top activations at the center of the window 189 | _, top_indices = torch.topk(data["feature_acts"][:, self.window_radius, :], 190 | k=self.num_top_activations, dim=0) # (k, H) 191 | 192 | # Prepare the tokens corresponding to the top activations 193 | tokens_with_top_acts = data["tokens"][top_indices].transpose(dim0=1, dim1=2) # (k, W, H) 194 | 195 | # Extract and stack the top feature activations for each feature across all windows 196 | top_feature_activations = torch.stack( 197 | [data["feature_acts"][top_indices[:, i], :, i] for i in range(num_features)], 198 | dim=-1 199 | ) # (k, W, H) 200 | 201 | logits_diff_for_top_acts = torch.stack( 202 | [data["logits_diff"][top_indices[:, i], :, i] for i in range(num_features)], 203 | dim=-1 204 | ) # (k, W, H) 205 | 206 | # Bundle the top tokens and feature activations into a structured data format 207 | top_activations_data = TensorDict({ 208 | "tokens": tokens_with_top_acts, 209 | "feature_acts": top_feature_activations, 210 | "logits_diff": logits_diff_for_top_acts, 211 | }, batch_size=[self.num_top_activations, self.window_length, num_features]) # (k, W, H) 212 | 213 | return top_activations_data 214 | 215 | @torch.no_grad() 216 | def compute_logits(self,): 217 | """ 218 | Computes logits for each feature through path expansion approach. 219 | Returns a torch tensor of shape (num_features, vocab_size) 220 | By default, it uses full LayerNorm instead of its linear approximation. # TODO: understand if that's okay 221 | # also, this function is specific to SAEs trained on the activations of last MLP layer for now. TODO: generalize this 222 | By default, logits for each feature are shifted so that the median value is 0. 223 | """ 224 | mlp_out = self.transformer.transformer.h[-1].mlp.c_proj(self.autoencoder.decoder.weight.detach().t()) # (L, C) 225 | ln_out = self.transformer.transformer.ln_f(mlp_out) # (L, C) 226 | logits = self.transformer.lm_head(ln_out) # (L, V) 227 | attributed_logits = (logits - logits.median(dim=1, keepdim=True).values) # (L, V) 228 | return attributed_logits 229 | 230 | @torch.no_grad() 231 | def compute_top_and_bottom_logits(self,): 232 | """ 233 | Computes top and bottom logits for each feature. 234 | Returns (top_logits, bottom_logits). Each is of type `torch.return_types.topk`. 235 | """ 236 | # GPT-2 tokenizer has vocab size 50257. nanoGPT sets vocab size = 50304 for higher training speed. 237 | # See https://twitter.com/karpathy/status/1621578354024677377?lang=en 238 | # Decoder will give an error if a token with id > 50256 is given, and bottom_logits may pick one of these tokens. 239 | # Therefore, set max token id to 50256 by hand. 240 | attributed_logits = self.attributed_logits[:, :50257] 241 | top_logits = torch.topk(attributed_logits, largest=True, sorted=True, k=self.num_top_activations, dim=1) # (L, k) 242 | bottom_logits = torch.topk(attributed_logits, largest=False, sorted=True, k=self.num_top_activations, dim=1) # (L, k) 243 | return top_logits, bottom_logits 244 | 245 | def write_feature_page(self, phase, h, data, top_acts_data): 246 | """"Writes features pages for dead / alive neurons; also makes a histogram. 
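A feature is treated as dead if it never fires on the sampled mid-window tokens, as ultralow-density if it fires fewer than num_intervals * samples_per_interval times (12 * 5 = 60 with the defaults above), and as a regular alive feature otherwise.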
247 | For alive features, it calls sample_and_write.""" 248 | 249 | curr_feature_acts_MW = data["feature_acts"][:, :, h] 250 | mid_token_feature_acts_M = curr_feature_acts_MW[:, self.window_radius] 251 | num_nonzero_acts = torch.count_nonzero(mid_token_feature_acts_M) 252 | 253 | feature_id = phase * self.num_features_per_phase + h 254 | if num_nonzero_acts == 0: 255 | write_dead_feature_page(feature_id=feature_id, dirpath=self.html_out) 256 | return 257 | 258 | act_density = torch.count_nonzero(curr_feature_acts_MW) / (self.total_sampled_tokens * self.window_length) * 100 259 | non_zero_acts = curr_feature_acts_MW[curr_feature_acts_MW != 0] 260 | make_activations_histogram(activations=non_zero_acts, 261 | density=act_density, 262 | feature_id=feature_id, 263 | dirpath=self.html_out) 264 | 265 | 266 | # TODO: for ultralow density and other alive neurons, I will need to give feature ablation data 267 | if num_nonzero_acts < self.num_intervals * self.samples_per_interval: 268 | write_ultralow_density_feature_page(feature_id=feature_id, 269 | decode=self.decode, 270 | top_acts_data=top_acts_data[:num_nonzero_acts, :, h], 271 | dirpath=self.html_out) 272 | return 273 | 274 | self.sample_and_write(data, feature_id, num_nonzero_acts, mid_token_feature_acts_M, curr_feature_acts_MW, top_acts_data, h) 275 | 276 | def sample_and_write(self, data, feature_id, num_nonzero_acts, mid_token_feature_acts_M, curr_feature_acts_MW, top_acts_data, h): 277 | _, sorted_indices = torch.sort(mid_token_feature_acts_M, descending=True) # (N*S,) 278 | sampled_indices = torch.stack([ 279 | j * num_nonzero_acts // self.num_intervals + 280 | torch.randperm(num_nonzero_acts // self.num_intervals)[:self.samples_per_interval].sort()[0] 281 | for j in range(self.num_intervals) 282 | ], dim=0) 283 | original_indices = sorted_indices[sampled_indices] # TODO: explain sampled_indices and original_indices 284 | sampled_acts_data = TensorDict({ 285 | "tokens": data["tokens"][original_indices], 286 | "feature_acts": curr_feature_acts_MW[original_indices], 287 | "logits_diff": curr_feature_acts_MW[original_indices], 288 | }, batch_size=[self.num_intervals, self.samples_per_interval, self.window_length]) # (I, SI, W) 289 | 290 | write_alive_feature_page(feature_id=feature_id, 291 | decode=self.decode, 292 | top_logits=self.top_logits, 293 | bottom_logits=self.bottom_logits, 294 | top_acts_data=top_acts_data[:, :, h], 295 | sampled_acts_data=sampled_acts_data, 296 | dirpath=self.html_out) 297 | 298 | def _sample_context_windows(self, *args, fn_seed=0): 299 | """ 300 | Select windows of tokens around randomly sampled tokens from input tensors. 301 | 302 | Given tensors each of shape (B, T, ...), this function returns tensors containing 303 | windows around randomly selected tokens. The shape of the output is (B * S, W, ...), 304 | where S is the number of tokens in each context to evaluate, and W is the window size 305 | (including the token itself and tokens on either side). By default, S = self.num_sampled_tokens, 306 | W = self.window_length. 307 | 308 | Parameters: 309 | - args: Variable number of tensor arguments, each of shape (B, T, ...) 
310 | - fn_seed (int, optional): Seed for random number generator, default is 0 311 | """ 312 | if not args or not all(isinstance(tensor, torch.Tensor) and tensor.ndim >= 2 for tensor in args): 313 | raise ValueError("All inputs must be torch tensors with at least 2 dimensions.") 314 | 315 | # Ensure all tensors have the same shape in the first two dimensions 316 | B, T = args[0].shape[:2] 317 | if not all(tensor.shape[:2] == (B, T) for tensor in args): 318 | raise ValueError("All tensors must have the same shape along the first two dimensions.") 319 | 320 | torch.manual_seed(fn_seed) 321 | num_sampled_tokens=self.num_sampled_tokens 322 | token_idx = torch.stack([self.window_radius + torch.randperm(T - 2 * self.window_radius)[:num_sampled_tokens] 323 | for _ in range(B)], dim=0) # (B, S) # use of torch.randperm for sampling without replacement 324 | window_idx = token_idx.unsqueeze(-1) + torch.arange(-self.window_radius, self.window_radius + 1) # (B, S, W) 325 | batch_idx = torch.arange(B).view(-1, 1, 1).expand_as(window_idx) # (B, S, W) 326 | 327 | result_tensors = [] 328 | for tensor in args: 329 | if tensor.ndim == 3: 330 | L = tensor.shape[2] 331 | sliced_tensor = tensor[batch_idx, window_idx, :] # (B, S, W, L) 332 | sliced_tensor = sliced_tensor.view(-1, self.window_length, L) # (B *S , W, L) 333 | elif tensor.ndim == 2: 334 | sliced_tensor = tensor[batch_idx, window_idx] # (B, S, W) 335 | sliced_tensor = sliced_tensor.view(-1, self.window_length) # (B*S, W) 336 | else: 337 | raise ValueError("Tensor dimensions not supported. Only 2D and 3D tensors are allowed.") 338 | result_tensors.append(sliced_tensor) 339 | 340 | return result_tensors 341 | 342 | def _initialize_context_window_data(self, feature_start_idx, feature_end_idx): 343 | num_features_in_phase = feature_end_idx - feature_start_idx 344 | context_window_data = TensorDict({ 345 | "tokens": torch.zeros(self.total_sampled_tokens, self.window_length, dtype=torch.int32), 346 | "feature_acts": torch.zeros(self.total_sampled_tokens, self.window_length, num_features_in_phase), 347 | "logits_diff": torch.zeros(self.total_sampled_tokens, self.window_length, num_features_in_phase), 348 | }, batch_size=[self.total_sampled_tokens, self.window_length]) # (N * S, W) 349 | return context_window_data 350 | 351 | @torch.no_grad() 352 | def _compute_batch_feature_activations(self, batch_start_idx, batch_end_idx, feature_start_idx, feature_end_idx): 353 | """Computes feature activations for given batch of input text. 354 | """ 355 | x = self.X[batch_start_idx:batch_end_idx].to(self.device) 356 | y = self.Y[batch_start_idx:batch_end_idx].to(self.device) 357 | B, T = x.shape 358 | H = feature_end_idx - feature_start_idx # number of features in this phase 359 | original_logits, _ = self.transformer(x, y) # original_logits.shape = (B, T, V) 360 | mlp_acts = self.transformer.mlp_activation_hooks[0] # (B, T, 4C) 361 | self.transformer.clear_mlp_activation_hooks() 362 | feature_activations = self.autoencoder.get_feature_activations(inputs=mlp_acts, 363 | start_idx=feature_start_idx, 364 | end_idx=feature_end_idx) # (B, T, H) 365 | 366 | dictionary_vectors = self.autoencoder.decoder.weight[:, feature_start_idx:feature_end_idx] # (4C, H) 367 | feature_projections = feature_activations.unsqueeze(2) * dictionary_vectors # (B, T, 4C, H) 368 | feature_ablations = mlp_acts.unsqueeze(-1) - feature_projections # (B, T, 4C, H) 369 | 370 | # TODO: do I need to center the median at 0 before computing differences? 
371 | # Otherwise, the probability of sampling the token can probably not be compared through the logit weight alone. 372 | logits_difference_storage = torch.zeros(B, T, H, device=self.device) # (B, T, H) 373 | for h in range(H): 374 | feat_ablation_logits, _ = self.transformer(x, y, mode="replace", replacement_tensor=feature_ablations[:, :, :, h]) # (B, T, V) 375 | 376 | logits_difference = original_logits - feat_ablation_logits # (B, T, V) 377 | 378 | # restrict logits difference to those for the next token 379 | logits_difference_storage[:, :, h] = logits_difference[torch.arange(B)[:, None], torch.arange(T), y] # (B, T) 380 | 381 | return x, feature_activations, logits_difference_storage 382 | 383 | def write_main_page(self): 384 | create_main_html_page(n_features=self.n_features, dirpath=self.html_out) 385 | 386 | if __name__ == "__main__": 387 | 388 | # ----------------------------------------------------------------------------- 389 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 390 | configurator = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'configurator.py') 391 | exec(open(configurator).read()) # overrides from command line or config file 392 | config = {k: globals()[k] for k in config_keys} # will be useful for logging 393 | # ----------------------------------------------------------------------------- 394 | 395 | torch.manual_seed(seed) 396 | config = FeatureBrowserConfig(**config) 397 | feature_browser = FeatureBrowser(config) 398 | # Run the processing 399 | feature_browser.build() 400 | 401 | 402 | # TODO: tooltip css function should be imported separately and written explicitly I think, for clarity 403 | # TODO: methods that need to be revisited: write_feature_page, sample_and_write. 404 | # TODO: it would be nice if the final output does not depend on num_phases. Set seed for each feature separately? 405 | # TODO: we don't really need autoencoder data in eval mode. Change this in resourceloader. -------------------------------------------------------------------------------- /autoencoder/feature-browser/main_page.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def write_main_page(n_features): 4 | 5 | main = f""" 6 | 7 | 8 | 9 | Feature Visualization 10 | 33 | 34 | 35 |

Feature Browser

36 |

Slide to select a neuron number (0 to {n_features-1}) or enter it below:

37 | 38 | 39 | 40 | 0 41 | 42 | 43 | 44 | 45 | 46 | 47 |
48 | 49 |
50 | 51 | 121 | 122 | 123 | 124 | """ 125 | return main 126 | 127 | def write_tooltip_css_file(): 128 | tooltip_css = f"""/* Style for the tooltip */ 129 | .tooltip {{ 130 | position: relative; 131 | display: inline-block; 132 | cursor: pointer; 133 | margin-right: -4px; 134 | }} 135 | 136 | /* Style for the tooltip trigger text */ 137 | .tooltip > span {{ 138 | background-color: #FFCC99; /* Light orange background color */ 139 | color: #333; /* Dark text color for contrast */ 140 | padding: 0px; /* Add padding to make the background more prominent */ 141 | border-radius: 0px; /* Optional: Adds rounded corners to the background */ 142 | }} 143 | 144 | /* Style for the tooltip content */ 145 | .tooltip .tooltiptext {{ 146 | visibility: hidden; 147 | width: 280px; /* Increased width */ 148 | background-color: #333; 149 | color: #fff; 150 | text-align: center; 151 | border-radius: 6px; 152 | padding: 5px 10px; /* Add horizontal padding if needed */ 153 | position: absolute; 154 | z-index: 1; 155 | bottom: 125%; 156 | left: 50%; 157 | margin-left: -140px; /* Adjusted to half of the new width to keep it centered */ 158 | opacity: 0; 159 | transition: opacity 0.3s; 160 | white-space: pre-wrap; 161 | overflow: hidden; /* Ensures the content does not spill outside the tooltip */ 162 | }} 163 | 164 | /* Show the tooltip content when hovering over the tooltip */ 165 | .tooltip:hover .tooltiptext {{ 166 | visibility: visible; 167 | opacity: 1; 168 | }} 169 | 170 | /* Style for the tooltip trigger text with the default color */ 171 | .tooltip > span.default-color {{ 172 | background-color: #FFCC99; /* Light orange background color */ 173 | color: #333; /* Dark text color for contrast */ 174 | padding: 0px; 175 | border-radius: 0px; 176 | }} 177 | 178 | /* Style for the tooltip trigger text with white color */ 179 | .tooltip > span.white-color {{ 180 | background-color: #FFFFFF; /* White background color */ 181 | color: #333; /* Dark text color for contrast */ 182 | padding: 0px; 183 | border-radius: 0px; 184 | }} 185 | """ 186 | 187 | return tooltip_css 188 | 189 | def create_main_html_page(n_features, dirpath=None): 190 | # create a directory to store feature information 191 | os.makedirs(os.path.join(dirpath, 'feature_pages'), exist_ok=True) 192 | # create a directory to store histograms of feature activations 193 | os.makedirs(os.path.join(dirpath, 'histograms'), exist_ok=True) 194 | # write a helper css file tooltip.css in autoencoder_subdir 195 | with open(os.path.join(dirpath, f'tooltip.css'), 'w') as file: 196 | file.write(write_tooltip_css_file()) 197 | # write the main page for html 198 | with open(os.path.join(dirpath, 'main.html'), 'w') as file: 199 | file.write(write_main_page(n_features)) 200 | print(f'created main page for HTML interface in {dirpath}') -------------------------------------------------------------------------------- /autoencoder/feature-browser/subpages.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Import a dictionary of feature activations from autoencoder_dir/autoencoder_subdir/feature_infos.pkl and write HTML pages 4 | A couple of sample runs: 5 | python write_html.py --k=5 --num_intervals=3 --interval_exs=2 --dataset=shakespeare_char --autoencoder_subdir=1704914564.90-autoencoder-shakespeare_char 6 | python write_html.py --k=5 --num_intervals=6 --interval_exs=3 --autoencoder_subdir=1705203324.45-autoencoder-openwebtext 7 | """ 8 | 9 | import os 10 | import torch 11 | from tensordict import TensorDict 12 | 13 | def 
write_feature_page_header(): 14 | header = f""" 15 | 16 | 17 | 18 | 19 | 20 | 21 | 77 | 78 | 79 |

80 | """ 81 | return header 82 | 83 | def write_dead_feature_page(feature_id, dirpath=None): 84 | html_content = [] 85 | # add page_header to list of texts 86 | html_content.append(write_feature_page_header()) 87 | # add dead neuron text 88 | html_content.append(""" 89 |

Dead Neuron.

90 |
91 | 92 | """) 93 | # write the HTML file 94 | with open(os.path.join(dirpath, 'feature_pages', f'{feature_id}.html'), 'w') as file: 95 | file.write("".join(html_content)) 96 | 97 | def write_activation_example(decode, tokens, activations, logits_diff): 98 | assert isinstance(tokens, torch.Tensor) and isinstance(activations, torch.Tensor), "expect inputs to be torch tensors" 99 | assert tokens.ndim == 1 and activations.ndim == 1, "expect tokens and acts to be 1d tensors" 100 | html_content = [] 101 | W = tokens.shape[0] 102 | mid_token_index = (W-1)//2 103 | 104 | start_bold_text = lambda j, mid_token_index: "" if j == mid_token_index else "" 105 | end_bold_text = lambda j, mid_token_index: "" if j == mid_token_index else "" 106 | 107 | for j in range(W): 108 | char = decode([tokens[j].item()]) 109 | activation = activations[j].item() 110 | logit_diff = logits_diff[j].item() 111 | # TODO: make sure my writing of logit_diff is not off by one token in position 112 | 113 | # for html rendering, replace newline character with its HTML counterpart 114 | char = char.replace('\n', '') 115 | 116 | text_color = "default-color" if activation > 0 else "white-color" 117 | underline_color = "underline-red" if logit_diff > 0 else "underline-blue" if logit_diff < 0 else "" 118 | single_token_text = f""" 119 |
120 | {start_bold_text(j, mid_token_index)} {char.replace(' ', ' ')} {end_bold_text(j, mid_token_index)} 121 | 122 | Token: {char} 123 | Activation: {activation:.4f} 124 | Feature Ablation: {logit_diff:.4f} 125 | 126 |
""" 127 | html_content.append(single_token_text) 128 | html_content.append("""
""") 129 | return "".join(html_content) 130 | 131 | 132 | def write_activations_section(decode, examples_data): 133 | assert examples_data.ndim == 2, "input must be two dimensional, shape: (X, W) or (k, W)" 134 | n, W = examples_data.shape 135 | 136 | html_content = [] 137 | html_content.append(f""" 138 |

Max Activation = {examples_data[0]['feature_acts'][(W-1)//2]:.4f}

139 | """) 140 | 141 | for i in range(n): 142 | html_content.append(write_activation_example(decode, 143 | tokens=examples_data["tokens"][i], 144 | activations=examples_data["feature_acts"][i], 145 | logits_diff=examples_data["logits_diff"][i])) 146 | return "".join(html_content) 147 | 148 | 149 | def include_feature_density_histogram(feature_id, dirpath=None): 150 | if os.path.exists(os.path.join(dirpath, 'activations_histograms', f'{feature_id}.png')): 151 | feature_density_histogram = f""" 152 |
153 | Feature Activations Histogram 154 |
""" 155 | return feature_density_histogram 156 | else: 157 | return "" 158 | 159 | def include_logits_histogram(feature_id, dirpath=None): 160 | if os.path.exists(os.path.join(dirpath, 'logits_histograms', f'{feature_id}.png')): 161 | logits_histogram = f""" 162 |
163 | Logits Histogram 164 |
165 | 166 | """ 167 | return logits_histogram 168 | else: 169 | return "" 170 | 171 | def include_top_and_bottom_logits(top_logits, bottom_logits, decode, feature_id): 172 | # TODO: replace 10 with num_top_activations 173 | logits_text = ["""
"""] 174 | logits_text.append("""
""") 175 | logits_text.append("""

Negative Logits

""") 176 | for i in range(10): 177 | token = decode([bottom_logits.indices[feature_id, i].tolist()]) 178 | token_html = token.replace('\n', '') 179 | logits_line = f"""
180 | 181 | {token_html} 182 | 183 | {bottom_logits.values[feature_id, i]:.4f} 184 |
""" 185 | logits_text.append(logits_line) 186 | logits_text.append("""
""") 187 | 188 | logits_text.append("""
""") 189 | logits_text.append("""

Positive Logits

""") 190 | for i in range(10): 191 | token = decode([top_logits.indices[feature_id, i].tolist()]) 192 | token_html = token.replace('\n', '') 193 | logits_line = f"""
194 | 195 | {token_html} 196 | 197 | {top_logits.values[feature_id, i]:.4f} 198 |
""" 199 | logits_text.append(logits_line) 200 | logits_text.append("""
""") 201 | logits_text.append("""
""") 202 | return "".join(logits_text) 203 | 204 | # TODO: merge write_alive_feature_page and write_ultralow_density_feature_page into one single function 205 | 206 | def write_alive_feature_page(feature_id, decode, top_logits, bottom_logits, top_acts_data, sampled_acts_data, dirpath=None): 207 | 208 | print(f'writing feature page for feature # {feature_id}') 209 | 210 | assert isinstance(top_acts_data, TensorDict), "expect top activation data to be presented in a TensorDict" 211 | assert top_acts_data.ndim == 2, "expect top activation data TensorDict to be 2-dimensional, shape: (k, W)" 212 | 213 | assert isinstance(sampled_acts_data, TensorDict), "expect samples activation data to be presented in a TensorDict" 214 | assert sampled_acts_data.ndim == 3, "expect sampled activation data TensorDict to be 3-dimensional, shape: (I, X, W)" 215 | 216 | assert 'tokens' in top_acts_data.keys() and 'feature_acts' in top_acts_data.keys() and \ 217 | 'tokens' in sampled_acts_data.keys() and 'feature_acts' in sampled_acts_data.keys(), \ 218 | "expect input TensorDicts to have tokens and features_acts keys" 219 | 220 | html_content = [] 221 | 222 | # add page_header to the HTML page 223 | html_content.append(write_feature_page_header()) 224 | html_content.append("""
225 |
""") 226 | 227 | # add histogram of feature activations, top and bottom logits and logits histogram 228 | html_content.append(include_feature_density_histogram(feature_id, dirpath=dirpath)) 229 | html_content.append(include_top_and_bottom_logits(top_logits, bottom_logits, decode, feature_id)) 230 | html_content.append(include_logits_histogram(feature_id, dirpath=dirpath)) 231 | 232 | # add feature #, and the information that it is an ultralow density neuron 233 | html_content.append(f"""
234 |

Neuron # {feature_id}

""") 235 | 236 | # include a section on top activations 237 | html_content.append(""" 238 |

Top Activations

239 | """) 240 | html_content.append(write_activations_section(decode, top_acts_data)) 241 | 242 | # include a section on sampled activations 243 | I = sampled_acts_data.shape[0] # number of intervals 244 | for i in range(I): 245 | if i < I - 1: 246 | html_content.append(f"

Subsample Interval {i}

") 247 | else: 248 | html_content.append(f"

Bottom Activations

") 249 | html_content.append(write_activations_section(decode, sampled_acts_data[i])) 250 | 251 | # include the end of the HTML page 252 | html_content.append(" ") 253 | with open(os.path.join(dirpath, 'feature_pages', f'{feature_id}.html'), 'w') as file: 254 | file.write("".join(html_content)) 255 | 256 | def write_ultralow_density_feature_page(feature_id, decode, top_acts_data, dirpath=None): 257 | 258 | print(f'writing feature page for feature # {feature_id}') 259 | 260 | assert isinstance(top_acts_data, TensorDict), "expect top activation data to be presented in a TensorDict" 261 | assert top_acts_data.ndim == 2, "expect top activation data TensorDict to be 2-dimensional, shape: (n, W)" 262 | 263 | assert 'tokens' in top_acts_data.keys() and 'feature_acts' in top_acts_data.keys() and \ 264 | "expect input TensorDict to have tokens and features_acts keys" 265 | 266 | html_content = [] 267 | 268 | # add page_header to the HTML page 269 | html_content.append(write_feature_page_header()) 270 | 271 | # add histogram of feature activations 272 | if os.path.exists(os.path.join(dirpath, 'activations_histograms', f'{feature_id}.png')): 273 | html_content.append(f"""
274 |
275 | Feature Activations Histogram 276 |
""") 277 | 278 | # add feature #, and the information that it is an ultralow density neuron 279 | html_content.append(f"""
280 |

Neuron # {feature_id}

281 |

Ultralow Density Neuron

""") 282 | 283 | 284 | # include a section on top activations 285 | html_content.append(""" 286 |

Top Activations

287 | """) 288 | html_content.append(write_activations_section(decode, top_acts_data)) 289 | 290 | # include the end of the HTML page 291 | html_content.append(" ") 292 | with open(os.path.join(dirpath, 'feature_pages', f'{feature_id}.html'), 'w') as file: 293 | file.write("".join(html_content)) -------------------------------------------------------------------------------- /autoencoder/prepare.py: -------------------------------------------------------------------------------- 1 | """" 2 | Prepares training dataset for our autoencoder. 3 | Run on Macbook as 4 | python -u prepare.py --num_contexts=5000 --num_sampled_tokens=16 --dataset=shakespeare_char --gpt_ckpt_dir=out_sc_1_2_32 5 | """ 6 | import os 7 | import torch 8 | import time 9 | import psutil 10 | from resource_loader import ResourceLoader 11 | 12 | # Default parameters, can be overridden by command line arguments or a configuration file 13 | # dataset and model 14 | dataset = 'openwebtext' 15 | gpt_ckpt_dir = 'out' # Model checkpoint directory 16 | # autoencoder data size 17 | num_contexts = int(2e6) # Number of context windows 18 | num_sampled_tokens = 200 # Tokens per context window 19 | # system 20 | device = 'cpu' 21 | num_partitions = 20 # Number of output files 22 | # reproducibility 23 | seed = 0 24 | 25 | # ----------------------------------------------------------------------------- 26 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 27 | exec(open('configurator.py').read()) # overrides from command line or config file 28 | config = {k: globals()[k] for k in config_keys} # will be useful for logging 29 | # ----------------------------------------------------------------------------- 30 | 31 | torch.manual_seed(seed) 32 | 33 | # Load resources and model 34 | resource_loader = ResourceLoader(dataset=dataset, gpt_ckpt_dir=gpt_ckpt_dir, mode="prepare") 35 | gpt = resource_loader.transformer 36 | 37 | # Get model configurations 38 | block_size = gpt.config.block_size 39 | n_ffwd = 4 * gpt.config.n_embd 40 | 41 | # Prepare storage for activations 42 | data_storage = torch.zeros(num_contexts * num_sampled_tokens, n_ffwd, dtype=torch.float32) 43 | shuffled_indices = torch.randperm(num_contexts * num_sampled_tokens) 44 | 45 | def compute_activations(): 46 | start_time = time.time() 47 | gpt_batch_size = 500 48 | n_batches = num_contexts // gpt_batch_size 49 | 50 | for batch in range(n_batches): 51 | # Load batch and compute activations 52 | x, _ = resource_loader.get_text_batch(gpt_batch_size) 53 | _, _ = gpt(x) # Forward pass 54 | activations = gpt.mlp_activation_hooks[0] # Retrieve activations 55 | 56 | # Clean up to save memory 57 | gpt.clear_mlp_activation_hooks() 58 | 59 | # Process and store activations 60 | token_locs = torch.stack([torch.randperm(block_size)[:num_sampled_tokens] for _ in range(gpt_batch_size)]) 61 | data = torch.gather(activations, 1, token_locs.unsqueeze(2).expand(-1, -1, activations.size(2))).view(-1, n_ffwd) 62 | data_storage[shuffled_indices[batch * gpt_batch_size * num_sampled_tokens: (batch + 1) * gpt_batch_size * num_sampled_tokens]] = data 63 | 64 | print(f"Batch {batch}/{n_batches} processed in {(time.time() - start_time) / (batch + 1):.2f} seconds; " 65 | f"Memory: {psutil.virtual_memory().available / (1024 ** 3):.2f} GB available, {psutil.virtual_memory().percent}% used.") 66 | 67 | def save_activations(): 68 | sae_data_dir = os.path.join(os.path.abspath('.'), 'data', dataset, str(n_ffwd)) 69 | os.makedirs(sae_data_dir, 
exist_ok=True) 70 | examples_per_file = num_contexts * num_sampled_tokens // num_partitions 71 | 72 | for i in range(num_partitions): 73 | file_path = f'{sae_data_dir}/{seed * num_partitions + i}.pt' 74 | if os.path.exists(file_path): 75 | print(f"Warning: File {file_path} already exists and will be overwritten.") 76 | 77 | # Save data to file, cloning to reduce memory usage 78 | torch.save(data_storage[i * examples_per_file: (i + 1) * examples_per_file].clone(), file_path) 79 | print(f'Saved {file_path}') 80 | 81 | if __name__ == '__main__': 82 | compute_activations() 83 | save_activations() 84 | -------------------------------------------------------------------------------- /autoencoder/resource_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import sys 5 | import pickle 6 | import tiktoken 7 | 8 | # Extend the Python path to include the transformer subdirectory for GPT class import 9 | base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | sys.path.insert(0, os.path.join(base_dir, 'transformer')) 11 | from model import GPTConfig 12 | from hooked_model import HookedGPT 13 | 14 | class ResourceLoader: 15 | """ 16 | Manages resources for training, evaluation, and preparation. 17 | This includes loading datasets, model weights, and handling batches of data. 18 | """ 19 | 20 | def __init__(self, dataset, gpt_ckpt_dir, device='cpu', mode="train", sae_ckpt_dir=""): 21 | assert mode in ["train", "eval", "prepare"], "Invalid mode; must be 'train', 'eval', or 'prepare'." 22 | 23 | self.dataset = dataset # Name of the dataset (e.g., openwebtext, shakespeare_char) 24 | self.gpt_ckpt_dir = gpt_ckpt_dir # Directory containing GPT model weights 25 | self.device = device # Device on which the models will be loaded 26 | self.mode = mode 27 | 28 | # Set the path to the repository as the base directory 29 | current_file_dir = os.path.dirname(os.path.abspath(__file__)) 30 | self.base_dir = os.path.dirname(current_file_dir) 31 | 32 | # Load the text data and transformer model 33 | self.text_data = self.load_text_data() 34 | self.transformer = self.load_transformer_model() 35 | self.n_ffwd = self.transformer.config.n_embd * 4 36 | 37 | if mode == "train": 38 | self.autoencoder_data_dir = os.path.join(self.base_dir, 'autoencoder', 'data', self.dataset, str(self.n_ffwd)) 39 | self.autoencoder_data = self.load_next_autoencoder_partition(partition_id=0) 40 | self.autoencoder_data_info = self.init_autoencoder_data_info() 41 | 42 | if mode == "eval": 43 | assert sae_ckpt_dir, "A path to autoencoder checkpoint must be given" 44 | self.sae_ckpt_dir = sae_ckpt_dir 45 | self.autoencoder = self.load_autoencoder_model() 46 | self.autoencoder.eval() # note that if we load an autoencoder to resume training, we must not do this 47 | 48 | def load_text_data(self): 49 | """Loads the text data from the specified dataset.""" 50 | text_data_path = os.path.join(self.base_dir, 'transformer', 'data', self.dataset, 'train.bin') 51 | if not os.path.exists(text_data_path): 52 | # if train.bin is too large to be loaded to RAM (e.g. on your laptop), can experiment with val.bin. 
53 | print(f"train.bin not found; attempting to find val.bin") 54 | text_data_path = os.path.join(self.base_dir, 'transformer', 'data', self.dataset, 'val.bin') 55 | return np.memmap(text_data_path, dtype=np.uint16, mode='r') 56 | 57 | def load_transformer_model(self): 58 | """Loads the GPT model with pre-trained weights.""" 59 | ckpt_path = os.path.join(self.base_dir, 'transformer', self.gpt_ckpt_dir, 'ckpt.pt') 60 | checkpoint = torch.load(ckpt_path, map_location=self.device) 61 | gpt_conf = GPTConfig(**checkpoint['model_args']) 62 | transformer = HookedGPT(gpt_conf) 63 | state_dict = checkpoint['model'] 64 | 65 | # Remove unwanted prefix from state_dict keys 66 | unwanted_prefix = '_orig_mod.' 67 | for key in list(state_dict.keys()): 68 | if key.startswith(unwanted_prefix): 69 | state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key) 70 | 71 | transformer.load_state_dict(state_dict) 72 | transformer.eval() 73 | transformer.to(self.device) 74 | return transformer 75 | 76 | def get_number_of_autoencoder_data_files(self): 77 | """Returns the number of files in the autoencoder data directory.""" 78 | try: 79 | num_partitions = len(next(os.walk(self.autoencoder_data_dir))[2]) 80 | except StopIteration: 81 | raise ValueError("Autoencoder data directory is empty") 82 | return num_partitions 83 | 84 | def init_autoencoder_data_info(self): 85 | """Initializes and returns information about the autoencoder data.""" 86 | num_partitions = self.get_number_of_autoencoder_data_files() 87 | return { 88 | 'num_partitions': num_partitions, 89 | 'current_partition_id': 0, 90 | 'offset': 0, 91 | 'examples_per_partition': self.autoencoder_data.shape[0], 92 | 'total_examples': num_partitions * self.autoencoder_data.shape[0] 93 | } 94 | 95 | def load_autoencoder_model(self): 96 | """Loads the AutoEncoder model with pre-trained weights""" 97 | autoencoder_path = os.path.join(self.base_dir, "autoencoder", "out", self.dataset, self.sae_ckpt_dir) 98 | autoencoder_ckpt = torch.load(os.path.join(autoencoder_path, 'ckpt.pt'), map_location=self.device) 99 | state_dict = autoencoder_ckpt['autoencoder'] 100 | n_features, n_ffwd = state_dict['encoder.weight'].shape # H, F 101 | l1_coeff = autoencoder_ckpt['config']['l1_coeff'] 102 | from autoencoder import AutoEncoder 103 | autoencoder = AutoEncoder(n_ffwd, n_features, lam=l1_coeff).to(self.device) 104 | autoencoder.load_state_dict(state_dict) 105 | return autoencoder 106 | 107 | def get_text_batch(self, num_contexts): 108 | """Generates and returns a batch of text data for training or evaluation.""" 109 | block_size = self.transformer.config.block_size 110 | ix = torch.randint(len(self.text_data) - block_size, (num_contexts,)) 111 | X = torch.stack([torch.from_numpy(self.text_data[i:i+block_size].astype(np.int64)) for i in ix]) 112 | Y = torch.stack([torch.from_numpy(self.text_data[i+1:i+1+block_size].astype(np.int64)) for i in ix]) 113 | return X.to(device=self.device), Y.to(device=self.device) 114 | 115 | def get_autoencoder_data_batch(self, step, batch_size=8192): 116 | """ 117 | Retrieves a batch of autoencoder data based on the step and batch size. 118 | It loads the next data partition if the batch exceeds the current partition. 
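Worked example (assuming the prepare.py defaults: 2e6 contexts x 200 sampled tokens split over 20 partitions, i.e. examples_per_partition = 20,000,000, with batch_size = 8192): step 2441 requests examples 19,996,672 through 20,004,863, which straddles the partition boundary, so the last 3,328 examples of partition 0 are concatenated with the first 4,864 examples of partition 1 and offset is set to 4,864; at step 2442, batch_start = 2442*8192 - 1*20,000,000 - 4,864 = 0, so reading resumes right after the borrowed examples.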
119 | """ 120 | info = self.autoencoder_data_info 121 | batch_start = step * batch_size - info["current_partition_id"] * info["examples_per_partition"] - info["offset"] 122 | batch_end = batch_start + batch_size 123 | 124 | if batch_end > info["examples_per_partition"]: 125 | # When batch exceeds current partition, load data from the next partition 126 | if info["current_partition_id"] < info["num_partitions"] - 1: 127 | remaining = info["examples_per_partition"] - batch_start 128 | batch = self.autoencoder_data[batch_start:] 129 | info["current_partition_id"] += 1 130 | self.load_next_autoencoder_partition(info["current_partition_id"]) 131 | batch = torch.cat([batch, self.autoencoder_data[:batch_size - remaining]]) 132 | info["offset"] = batch_size - remaining 133 | else: 134 | raise IndexError("Autoencoder data batch request exceeds available partitions.") 135 | else: 136 | batch = self.autoencoder_data[batch_start:batch_end] 137 | 138 | assert len(batch) == batch_size, f"Batch length mismatch at step {step}" 139 | return batch.to(self.device) 140 | 141 | def load_next_autoencoder_partition(self, partition_id): 142 | """ 143 | Loads the specified partition of the autoencoder data. 144 | """ 145 | file_path = os.path.join(self.autoencoder_data_dir, f'{partition_id}.pt') 146 | self.autoencoder_data = torch.load(file_path) 147 | return self.autoencoder_data 148 | 149 | def select_resampling_data(self, size=819200): 150 | """ 151 | Selects a subset of autoencoder data for resampling, distributed evenly across partitions. 152 | """ 153 | info = self.autoencoder_data_info 154 | num_samples_per_partition = size // info["num_partitions"] 155 | resampling_data = torch.zeros(size, self.n_ffwd) 156 | 157 | for partition_id in range(info["num_partitions"]): 158 | partition_data = torch.load(os.path.join(self.autoencoder_data_dir, f'{partition_id}.pt')) 159 | sample_indices = torch.randint(info["examples_per_partition"], (num_samples_per_partition,)) 160 | start_index = partition_id * num_samples_per_partition 161 | resampling_data[start_index:start_index + num_samples_per_partition] = partition_data[sample_indices] 162 | 163 | return resampling_data 164 | 165 | def load_tokenizer(self): 166 | load_meta = False 167 | meta_path = os.path.join(self.base_dir, 'transformer', 'data', self.dataset, 'meta.pkl') 168 | load_meta = os.path.exists(meta_path) 169 | if load_meta: 170 | print(f"Loading meta from {meta_path}...") 171 | with open(meta_path, 'rb') as f: 172 | meta = pickle.load(f) 173 | # TODO want to make this more general to arbitrary encoder/decoder schemes 174 | stoi, itos = meta['stoi'], meta['itos'] 175 | encode = lambda s: [stoi[c] for c in s] 176 | decode = lambda l: ''.join([itos[i] for i in l]) 177 | else: 178 | # ok let's assume gpt-2 encodings by default 179 | print("No meta.pkl found, assuming GPT-2 encodings...") 180 | enc = tiktoken.get_encoding("gpt2") 181 | encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"}) 182 | decode = lambda l: enc.decode(l) 183 | return encode, decode -------------------------------------------------------------------------------- /autoencoder/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a Sparse AutoEncoder model 3 | 4 | Run on a macbook on a Shakespeare dataset as 5 | python train.py --dataset=shakespeare_char --gpt_ckpt_dir=out_sc_1_2_32 --eval_iters=1 --eval_batch_size=16 --batch_size=128 --device=cpu --eval_interval=100 --n_features=1024 --resampling_interval=150 
--wandb_log=True 6 | """ 7 | import os 8 | import torch 9 | import numpy as np 10 | import time 11 | from autoencoder import AutoEncoder 12 | from resource_loader import ResourceLoader 13 | from utils.plotting_utils import make_density_histogram 14 | 15 | ## hyperparameters 16 | # dataset and model 17 | dataset = 'openwebtext' 18 | gpt_ckpt_dir = 'out' 19 | # training 20 | n_features = 4096 21 | batch_size = 8192 # batch size for autoencoder training 22 | l1_coeff = 3e-3 23 | learning_rate = 3e-4 24 | resampling_interval = 25000 # number of training steps after which neuron resampling will be performed 25 | num_resamples = 4 # number of times resampling is to be performed; it is done 4 times in Anthropic's paper 26 | resampling_data_size = 819200 27 | # evaluation 28 | eval_batch_size = 16 # batch size (number of GPT contexts) for evaluation 29 | eval_iters = 200 # number of iterations in the evaluation loop 30 | eval_interval = 1000 # number of training steps after which the autoencoder is evaluated 31 | # I/O 32 | save_checkpoint = True # whether to save model, optimizer, etc or not 33 | save_interval = 10000 # number of training steps after which a checkpoint will be saved 34 | out_dir = 'out' # directory containing trained autoencoder model weights 35 | # wandb logging 36 | wandb_log = True 37 | # system 38 | device = 'cuda' 39 | # reproducibility 40 | seed = 1442 41 | 42 | # ----------------------------------------------------------------------------- 43 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 44 | exec(open('configurator.py').read()) # overrides from command line or config file 45 | config = {k: globals()[k] for k in config_keys} # will be useful for logging 46 | # ----------------------------------------------------------------------------- 47 | 48 | torch.manual_seed(seed) 49 | # initiating ResourceLoader in training mode loads Transformer checkpoint, text data, and autoencoder data 50 | resourceloader = ResourceLoader( 51 | dataset=dataset, 52 | gpt_ckpt_dir=gpt_ckpt_dir, 53 | device=device, 54 | mode="train", 55 | ) 56 | 57 | gpt = resourceloader.transformer # TODO: either it should be called transformer or gpt 58 | autoencoder = AutoEncoder(n_inputs = 4 * resourceloader.transformer.config.n_embd, 59 | n_latents = n_features, 60 | lam = l1_coeff, 61 | resampling_interval = resampling_interval).to(device) 62 | optimizer = torch.optim.Adam(autoencoder.parameters(), lr=learning_rate) 63 | 64 | ## prepare for logging and saving checkpoints 65 | run_name = f'{time.time():.2f}' 66 | if wandb_log: 67 | import wandb 68 | wandb.init(project=f'sparse-autoencoder-{dataset}', name=run_name, config=config) 69 | if save_checkpoint: 70 | ckpt_path = os.path.join(out_dir, dataset, run_name) 71 | os.makedirs(ckpt_path, exist_ok=True) 72 | 73 | ############## TRAINING LOOP ############### 74 | start_time = time.time() 75 | num_steps = resourceloader.autoencoder_data_info["total_examples"] // batch_size 76 | 77 | for step in range(num_steps): 78 | 79 | batch = resourceloader.get_autoencoder_data_batch(step, batch_size=batch_size) 80 | optimizer.zero_grad(set_to_none=True) 81 | autoencoder_output = autoencoder(batch) # f has shape (batch_size, n_features) 82 | autoencoder_output['loss'].backward() 83 | 84 | # remove component of gradient parallel to weight 85 | autoencoder.remove_parallel_component_of_decoder_grad() 86 | optimizer.step() 87 | 88 | # periodically update the norm of dictionary vectors to ensure they stay 
close to 1. 89 | if step % 1000 == 0: 90 | autoencoder.normalize_decoder_columns() 91 | 92 | ## ------------ perform neuron resampling ----------- ###### 93 | # check if we should start investigating dead/alive neurons at this step 94 | # This is done at an odd multiple of resampling_interval // 2 in Anthropic's paper. 95 | if autoencoder.is_dead_neuron_investigation_step(step, resampling_interval, num_resamples): 96 | print(f'initiating investigation of dead neurons at step = {step}') 97 | autoencoder.initiate_dead_neurons() 98 | 99 | # check if we should look for dead neurons at this step 100 | # This is done between an odd and an even multiple of resampling_interval // 2. 101 | if autoencoder.is_within_neuron_investigation_phase(step, resampling_interval, num_resamples): 102 | autoencoder.update_dead_neurons(autoencoder_output['latents']) 103 | 104 | # perform neuron resampling if step is a multiple of resampling interval 105 | if (step+1) % resampling_interval == 0 and step < num_resamples * resampling_interval: 106 | num_dead_neurons = len(autoencoder.dead_neurons) 107 | print(f'{num_dead_neurons} neurons to be resampled at step = {step}') 108 | if num_dead_neurons > 0: 109 | autoencoder.resample_dead_neurons(data=resourceloader.select_resampling_data(size=resampling_data_size), 110 | optimizer=optimizer, 111 | batch_size=batch_size) 112 | 113 | ### ------------ log info ----------- ###### 114 | if (step % eval_interval == 0) or step == num_steps - 1: 115 | print(f'Entering evaluation mode at step = {step}') 116 | autoencoder.eval() 117 | 118 | log_dict = {'losses/reconstructed_nll': 0, # log-likelihood loss using reconstructed MLP activations 119 | 'losses/l0_norm': 0, # L0-norm; average number of non-zero components of a feature activation vector 120 | 'losses/reconstruction_loss': 0, # |xhat - x|^2 <-- L2-norm between MLP activations & their reconstruction 121 | 'losses/l1_norm': 0, # L1-norm of feature activations 122 | 'losses/autoencoder_loss': 0, # reconstruction_loss + L1-coeff * l1_loss 123 | 'losses/nll_score': 0, # ratio of (nll_loss - ablated_loss) to (nll_loss - reconstructed_nll) 124 | } 125 | 126 | # initiate a tensor containing the number of tokens on which each feature activates 127 | feat_acts_count = torch.zeros(n_features, dtype=torch.float32) 128 | 129 | # get batches of text data and evaluate the autoencoder on MLP activations 130 | for iter in range(eval_iters): 131 | if iter % 20 == 0: 132 | print(f'Performing evaluation at iterations # ({iter} - {min(iter+19, eval_iters)})/{eval_iters}') 133 | x, y = resourceloader.get_text_batch(num_contexts=eval_batch_size) 134 | 135 | _, nll_loss = gpt(x, y) 136 | mlp_acts = gpt.mlp_activation_hooks[0] 137 | gpt.clear_mlp_activation_hooks() # free up memory 138 | _, ablated_loss = gpt(x, y, mode="replace") 139 | 140 | with torch.no_grad(): 141 | autoencoder_output = autoencoder(mlp_acts) 142 | _, reconstructed_nll = gpt(x, y, mode="replace", replacement_tensor=autoencoder_output['reconst_acts']) 143 | 144 | # for each feature, calculate the TOTAL number of tokens on which it is active; shape: 145 | feat_acts = autoencoder_output['latents'].to('cpu') # (eval_batch_size, block_size, n_features) 146 | torch.add(feat_acts_count, feat_acts.count_nonzero(dim=[0, 1]), out=feat_acts_count) # (n_features, ) 147 | 148 | # calculat the AVERAGE number of non-zero entries in each feature vector and log all losses 149 | log_dict['losses/l0_norm'] += feat_acts.count_nonzero(dim=-1).float().mean().item() 150 | 
log_dict['losses/reconstructed_nll'] += reconstructed_nll.item() 151 | log_dict['losses/autoencoder_loss'] += autoencoder_output['loss'].item() 152 | log_dict['losses/reconstruction_loss'] += autoencoder_output['mse_loss'].item() 153 | log_dict['losses/l1_norm'] += autoencoder_output['l1_loss'].item() 154 | log_dict['losses/nll_score'] += (nll_loss - reconstructed_nll).item()/(nll_loss - ablated_loss).item() 155 | 156 | # compute feature densities and plot feature density histogram 157 | log_feat_acts_density = np.log10(feat_acts_count[feat_acts_count != 0]/(eval_iters * eval_batch_size * gpt.config.block_size)) # (n_features,) 158 | feat_density_historgram = make_density_histogram(log_feat_acts_density) 159 | 160 | # take mean of all loss values by dividing by the number of evaluation batches; also log more metrics 161 | log_dict = {key: val/eval_iters for key, val in log_dict.items()} 162 | log_dict.update( 163 | {'training_step': step, 164 | 'training_examples': step * batch_size, 165 | 'debug/mean_dictionary_vector_length': torch.linalg.vector_norm(autoencoder.decoder.weight, dim=0).mean(), 166 | 'feature_density/min_log_feat_density': log_feat_acts_density.min().item() if len(log_feat_acts_density) > 0 else -100, 167 | 'feature_density/num_neurons_with_feature_density_above_1e-3': (log_feat_acts_density > -3).sum(), 168 | 'feature_density/num_neurons_with_feature_density_below_1e-3': (log_feat_acts_density < -3).sum(), 169 | 'feature_density/num_neurons_with_feature_density_below_1e-4': (log_feat_acts_density < -4).sum(), 170 | 'feature_density/num_neurons_with_feature_density_below_1e-5': (log_feat_acts_density < -5).sum(), 171 | 'feature_density/num_alive_neurons': len(log_feat_acts_density), 172 | }) 173 | if wandb_log: 174 | log_dict.update({'feature_density/feature_density_histograms': wandb.Image(feat_density_historgram)}) 175 | wandb.log(log_dict) 176 | 177 | autoencoder.train() 178 | print(f'Exiting evaluation mode at step = {step}') 179 | 180 | ### ------------ save a checkpoint ----------- ###### 181 | if save_checkpoint and step > 0 and (step % save_interval == 0 or step == num_steps - 1): 182 | checkpoint = { 183 | 'autoencoder': autoencoder.state_dict(), 184 | 'optimizer': optimizer.state_dict(), 185 | 'log_dict': log_dict, 186 | 'config': config, 187 | 'feature_activation_counts': feat_acts_count, # may be used later to identify alive vs dead neurons 188 | } 189 | print(f"saving checkpoint to {ckpt_path} at training step = {step}") 190 | torch.save(checkpoint, os.path.join(ckpt_path, 'ckpt.pt')) 191 | 192 | if wandb_log: 193 | wandb.finish() -------------------------------------------------------------------------------- /autoencoder/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .plotting_utils import make_density_histogram, make_activations_histogram, make_logits_histogram -------------------------------------------------------------------------------- /autoencoder/utils/plotting_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Three different histogram functions. The difference lies in whether to save the histogram image on disk or not, 3 | color scheme and axes labels. 4 | These can perhaps be combined into one function, but leaving it as it is for now. 
5 | """ 6 | import matplotlib.pyplot as plt 7 | from PIL import Image 8 | from io import BytesIO 9 | import torch 10 | import os 11 | 12 | def make_density_histogram(data, bins='auto'): 13 | """Makes a histogram image from the provided data and returns it. 14 | We use it in train.py to plot feature density histograms and log them with W&B.""" 15 | fig, ax = plt.subplots() 16 | ax.hist(data, bins=bins) 17 | ax.set_title('Histogram') 18 | plt.tight_layout() 19 | 20 | buf = BytesIO() # create a BytesIO buffer 21 | fig.savefig(buf, format='png') # save the plot to the buffer in PNG format 22 | buf.seek(0) # rewind the buffer to the beginning 23 | image = Image.open(buf) # open the image from the buffer 24 | 25 | plt.close(fig) # close the figure to free memory 26 | return image 27 | 28 | def make_activations_histogram(activations, density, feature_id, dirpath=None): 29 | """makes a histogram of activations and saves it on the disk 30 | we later include the histogram in the feature browser""" 31 | if isinstance(activations, torch.Tensor): 32 | activations = activations.cpu().numpy() 33 | plt.hist(activations, bins='auto') # You can adjust the number of bins as needed 34 | plt.title(f'Activations (Density = {density:.4f}%)') 35 | plt.xlabel('Activation') 36 | plt.ylabel('Frequency') 37 | 38 | # Save the histogram as an image 39 | image_path = os.path.join(dirpath, 'activations_histograms', f'{feature_id}.png') 40 | plt.savefig(image_path) 41 | plt.close() 42 | 43 | def make_logits_histogram(logits, feature_id, dirpath=None): 44 | """ 45 | Makes a histogram of logits for a given feature and saves it as a PNG file 46 | Input: 47 | logits: a torch tensor of shape (vocab_size,) 48 | feature_id: int 49 | dirpath: histogram is saved as dirpath/logits_histograms/feature_id.png 50 | """ 51 | plt.hist(logits.cpu().numpy(), bins='auto') # You can adjust the number of bins as needed 52 | 53 | image_path = os.path.join(dirpath, 'logits_histograms', f'{feature_id}.png') 54 | plt.savefig(image_path) 55 | plt.close() -------------------------------------------------------------------------------- /reproduction.md: -------------------------------------------------------------------------------- 1 | ## reproducing results 2 | 3 | **step 0: make a virtual environment and install required packages** 4 | 5 | Clone the repository and change the directory. 6 | ``` 7 | https://github.com/shehper/monosemantic.git && cd monosemantic 8 | ``` 9 | 10 | Make a new virtual environment, and activate it. 11 | ``` 12 | python -m venv ./env 13 | source ./env/bin/activate 14 | ``` 15 | 16 | Install packages from requirements.txt. 17 | ``` 18 | pip install -r requirements.txt 19 | ``` 20 | 21 | I used Python 3.9 for this project. If you have an older version of OpenSSL on your machine, you will notice that downloading and tokenizing dataset in Step 1 will return a compatibility error between the versions of urllib3 and OpenSSL. In this case, you may upgrade OpenSSL or downgrade sentry-sdk and urllib3 to older versions as follows. 22 | 23 | ``` 24 | pip install sentry-sdk==1.29.2 # try only if prepare.py in Step 1 returns ImportError for urllib3 25 | pip install urllib3==1.26.15 # try only if prepare.py in Step 1 returns ImportError for urllib3 26 | ``` 27 | 28 | **step 1: train a one-layer transformer model** 29 | 30 | I used [nanoGPT](https://github.com/karpathy/nanoGPT) to train a one-layer transformer. The required code is in the 'transformer' subfolder of this repository. 
31 | 32 | In order to train this transformer model, first move to the 'transformer' subdirectory. 33 | ``` 34 | cd transformer 35 | ``` 36 | 37 | Next, download and tokenize the OpenWebText dataset as follows. (If this step gives any import errors, see the possible solution provided in Step 0.) 38 | 39 | ``` 40 | python data/openwebtext/prepare.py 41 | ``` 42 | 43 | This will result in two files in the data/openwebtext/ folder, named train.bin (containing ~9B tokens) and val.bin (containing ~4M tokens). Now, train a 1-layer transformer model with embedding dimension 128: 44 | ``` 45 | python train.py config/train_gpt2.py --wandb_project=monosemantic --n_layer=1 --n_embd=128 --n_head=4 --max_iters=200000 --lr_decay_iters=200000 46 | ``` 47 | 48 | This run saves the model checkpoints in the subfolder transformer/out. I trained the model for 200000 iterations in order to match the number of training epochs with Anthropic's paper. This run took around 3 days on an A100 GPU and achieved a validation loss of 4.609. 49 | 50 | If you have a node with more than one GPU available, you may alternatively train the model as follows for faster training, where num_gpus is the number of GPUs on the node. 51 | 52 | ``` 53 | torchrun --standalone --nproc_per_node=num_gpus train.py config/train_gpt2.py --wandb_project=monosemantic --n_layer=1 --n_embd=128 --n_head=4 --max_iters=200000 --lr_decay_iters=200000 54 | ``` 55 | 56 | **step 2: generate training data for autoencoder** 57 | 58 | Now move to the autoencoder subdirectory. 59 | ``` 60 | cd ../autoencoder 61 | ``` 62 | 63 | First, generate the training data for the autoencoder. 64 | ``` 65 | python generate_mlp_data.py 66 | ``` 67 | By default, this computes MLP activations for 4 million contexts, sampling and randomly shuffling the activations of 200 tokens per context. The dataset is saved in n_files=20 files in the 'sae_data' subfolder of autoencoder. You may choose different values for these variables using the --total_contexts, --tokens_per_context and --n_files command line arguments. 68 | 69 | I used a node with 1TB of RAM for this step, as the dataset takes about 770GB of space. I saved it in 20 files so that the autoencoder can be trained on a node with much less CPU RAM (as low as 64GB) in Step 3. 70 | 71 | By default, MLP activations are saved in the float16 data type, but you may change that by passing the '--convert_to_f16=False' flag on the command line. 72 | 73 | **step 2a: choose a subset of data for neuron resampling** 74 | 75 | Anthropic used a random subset of 819200 activation vectors to resample neurons four times during training. The node I used for training (in Step 3) did not have enough RAM to load the entire autoencoder training dataset and select 819200 examples at resampling time, so I used a high-RAM (> 1TB) node to pre-select 4*819200 examples and saved them in a separate file, 'data_for_resampling_neurons.pt'. 76 | 77 | This may be done as follows. 78 | ``` 79 | python select_resampling_data.py 80 | ``` 81 | 82 | If you have enough RAM available on your GPU node, you may skip this step and sample the subset randomly at the time of neuron resampling. 83 | 84 | **step 3: train a sparse autoencoder model** 85 | 86 | Next, you may train the sparse autoencoder model as follows. 
87 | ``` 88 | python train.py --l1_coeff=3e-7 89 | ``` 90 | 91 | I tried a few different values of the L1-coefficient and learning rate and noticed that the best trade-off between feature activation sparsity (=L0-norm) and reconstructed NLL score occurred around l1_coeff=3e-7 and learning_rate=3e-4. This L1-coefficient is much smaller than the values used in Anthropic's paper. I do not know why this is the case. 92 | 93 | ## analysis of features 94 | During training, I logged various metrics including feature density histograms. They are available on this [Weights & Biases project](https://wandb.ai/shehper/sparse-autoencoder-openwebtext-public). The spikes in various loss curves appear at the training steps of neuron resampling, as one would expect. 95 | 96 | Anthropic's paper mentions that they performed manual inspection of features during training. I did not perform this inspection *during* training, but I did perform it after training finished in order to compare different models. 97 | 98 | For this step, I used top_activations.py as 99 | ``` 100 | python top_activations.py --autoencoder_subdir=/subdirectory/of/out_autoencoder/containing_model_ckpt --eval_contexts=20000 --length_context_on_each_side=10 --k=10 --publish_html=True 101 | ``` 102 | 103 | where /subdirectory/of/out_autoencoder/containing_model_ckpt is the name of the subdirectory of the 'out_autoencoder' folder that contains the model checkpoint. This evaluates the model on 20000 contexts from the OpenWebText dataset. The output is saved as a dictionary of the k=10 top activations for each autoencoder neuron. If we pass publish_html=True, it also saves the top 10 activations and the associated tokens and contexts for each neuron in the form of an HTML file in the same subdirectory. 104 | 105 | For example, please see the HTML files [high_density_neurons.html](https://shehper.github.io/monosemantic/autoencoder/out_autoencoder/1704783101.13-autoencoder-openwebtext/high_density_neurons.html) and [ultra_low_density_neurons.html](https://shehper.github.io/monosemantic/autoencoder/out_autoencoder/1704783101.13-autoencoder-openwebtext/ultra_low_density_neurons.html) for the model with l1_coeff=3e-7, learning_rate=3e-4, and loss curves as in the aforementioned [Weights & Biases page](https://wandb.ai/shehper/sparse-autoencoder-openwebtext-public/runs/rajo0rsx?workspace=user-). 
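As a rough illustration of what this analysis computes, the sketch below collects the top-k activation values per feature using the ResourceLoader, HookedGPT, and AutoEncoder interfaces from this repository. It is a minimal sketch, not the actual top_activations.py: the checkpoint directory name, the number of evaluation batches, and the decoder weight shape noted in the comments are assumptions.

```
# a minimal sketch (not top_activations.py): gather top-k activation values
# per autoencoder feature over a few batches of OpenWebText contexts
import torch
from resource_loader import ResourceLoader

# 'path_to_sae_ckpt_dir' is a placeholder for the run subdirectory under autoencoder/out/openwebtext
resources = ResourceLoader(dataset='openwebtext', gpt_ckpt_dir='out',
                           mode="eval", sae_ckpt_dir='path_to_sae_ckpt_dir')
gpt, autoencoder = resources.transformer, resources.autoencoder

k = 10
n_features = autoencoder.decoder.weight.shape[1]  # assumes decoder.weight is (n_ffwd, n_features), as implied by train.py
top_vals = torch.zeros(n_features, 0)

with torch.no_grad():
    for _ in range(100):  # number of evaluation batches; adjust as needed
        x, _ = resources.get_text_batch(num_contexts=16)
        gpt(x)                                      # forward pass stores MLP activations in the hook
        mlp_acts = gpt.mlp_activation_hooks[0]
        gpt.clear_mlp_activation_hooks()            # free memory
        latents = autoencoder(mlp_acts)['latents']  # (contexts, block_size, n_features)
        flat = latents.reshape(-1, n_features).T    # (n_features, num_tokens)
        top_vals = torch.cat([top_vals, flat.topk(k, dim=1).values], dim=1)
        top_vals = top_vals.topk(k, dim=1).values   # keep a running top-k per feature

print(top_vals[0])  # ten largest activation values of feature 0 over the sampled contexts
```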
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.9.4 2 | aiosignal==1.3.1 3 | appdirs==1.4.4 4 | appnope==0.1.3 5 | asttokens==2.4.1 6 | async-timeout==4.0.3 7 | attrs==23.2.0 8 | certifi==2023.11.17 9 | charset-normalizer==3.3.2 10 | click==8.1.7 11 | comm==0.2.0 12 | contourpy==1.2.0 13 | cycler==0.12.1 14 | datasets==2.16.1 15 | debugpy==1.8.0 16 | decorator==5.1.1 17 | dill==0.3.7 18 | docker-pycreds==0.4.0 19 | exceptiongroup==1.2.0 20 | executing==2.0.1 21 | filelock==3.13.1 22 | fonttools==4.47.0 23 | frozenlist==1.4.1 24 | fsspec==2023.10.0 25 | gitdb==4.0.11 26 | GitPython==3.1.41 27 | huggingface-hub==0.20.2 28 | idna==3.7 29 | importlib-metadata==7.0.0 30 | importlib-resources==6.1.1 31 | ipykernel==6.27.1 32 | ipython==8.18.1 33 | jedi==0.19.1 34 | Jinja2==3.1.3 35 | jupyter_client==8.6.0 36 | jupyter_core==5.5.0 37 | kiwisolver==1.4.5 38 | MarkupSafe==2.1.3 39 | matplotlib==3.8.2 40 | matplotlib-inline==0.1.6 41 | mpmath==1.3.0 42 | multidict==6.0.4 43 | multiprocess==0.70.15 44 | nest-asyncio==1.5.8 45 | networkx==3.2.1 46 | numpy==1.26.2 47 | packaging==23.2 48 | pandas==2.1.4 49 | parso==0.8.3 50 | pexpect==4.9.0 51 | Pillow==10.3.0 52 | platformdirs==4.1.0 53 | prompt-toolkit==3.0.43 54 | protobuf==4.25.1 55 | psutil==5.9.6 56 | ptyprocess==0.7.0 57 | pure-eval==0.2.2 58 | pyarrow==14.0.2 59 | pyarrow-hotfix==0.6 60 | Pygments==2.17.2 61 | pyparsing==3.1.1 62 | python-dateutil==2.8.2 63 | pytz==2023.3.post1 64 | PyYAML==6.0.1 65 | pyzmq==25.1.2 66 | regex==2023.10.3 67 | requests==2.31.0 68 | sentry-sdk==1.39.1 69 | setproctitle==1.3.3 70 | six==1.16.0 71 | smmap==5.0.1 72 | stack-data==0.6.3 73 | sympy==1.12 74 | tiktoken==0.5.2 75 | torch==2.1.2 76 | tornado==6.4 77 | tqdm==4.66.1 78 | traitlets==5.14.0 79 | typing_extensions==4.9.0 80 | tzdata==2023.4 81 | urllib3==2.1.0 82 | wandb==0.16.1 83 | wcwidth==0.2.12 84 | xxhash==3.4.1 85 | yarl==1.9.4 86 | zipp==3.17.0 87 | -------------------------------------------------------------------------------- /transformer/README.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformer/config/train_gpt2.py: -------------------------------------------------------------------------------- 1 | # config for training GPT-2 (124M) down to very nice loss of ~2.85 on 1 node of 8X A100 40GB 2 | # launch as the following (e.g. 
in a screen session) and wait ~5 days: 3 | # $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py 4 | 5 | wandb_log = True 6 | wandb_project = 'owt' 7 | wandb_run_name='gpt2-124M' 8 | 9 | # these make the total batch size be ~0.5M 10 | # 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520 11 | batch_size = 12 12 | block_size = 1024 13 | gradient_accumulation_steps = 5 * 8 14 | 15 | # this makes total number of tokens be 300B 16 | max_iters = 600000 17 | lr_decay_iters = 600000 18 | 19 | # eval stuff 20 | eval_interval = 1000 21 | eval_iters = 200 22 | log_interval = 10 23 | 24 | # weight decay 25 | weight_decay = 1e-1 26 | -------------------------------------------------------------------------------- /transformer/config/train_shakespeare_char.py: -------------------------------------------------------------------------------- 1 | # train a miniature character-level shakespeare model 2 | # good for debugging and playing on macbooks and such 3 | 4 | out_dir = 'out-shakespeare-char' 5 | eval_interval = 250 # keep frequent because we'll overfit 6 | eval_iters = 200 7 | log_interval = 10 # don't print too too often 8 | 9 | # we expect to overfit on this small dataset, so only save when val improves 10 | always_save_checkpoint = False 11 | 12 | wandb_log = False # override via command line if you like 13 | wandb_project = 'shakespeare-char' 14 | wandb_run_name = 'mini-gpt' 15 | 16 | dataset = 'shakespeare_char' 17 | gradient_accumulation_steps = 1 18 | batch_size = 64 19 | block_size = 256 # context of up to 256 previous characters 20 | 21 | # baby GPT model :) 22 | n_layer = 6 23 | n_head = 6 24 | n_embd = 384 25 | dropout = 0.2 26 | 27 | learning_rate = 1e-3 # with baby networks can afford to go a bit higher 28 | max_iters = 5000 29 | lr_decay_iters = 5000 # make equal to max_iters usually 30 | min_lr = 1e-4 # learning_rate / 10 usually 31 | beta2 = 0.99 # make a bit bigger because number of tokens per iter is small 32 | 33 | warmup_iters = 100 # not super necessary potentially 34 | 35 | # on macbook also add 36 | # device = 'cpu' # run on cpu only 37 | # compile = False # do not torch compile the model 38 | -------------------------------------------------------------------------------- /transformer/configurator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Poor Man's Configurator. Probably a terrible idea. Example usage: 3 | $ python train.py config/override_file.py --batch_size=32 4 | this will first run config/override_file.py, then override batch_size to 32 5 | 6 | The code in this file will be run as follows from e.g. train.py: 7 | >>> exec(open('configurator.py').read()) 8 | 9 | So it's not a Python module, it's just shuttling this code away from train.py 10 | The code in this script then overrides the globals() 11 | 12 | I know people are not going to love this, I just really dislike configuration 13 | complexity and having to prepend config. to every single variable. If someone 14 | comes up with a better simple Python solution I am all ears. 
15 | """ 16 | 17 | import sys 18 | from ast import literal_eval 19 | 20 | for arg in sys.argv[1:]: 21 | if '=' not in arg: 22 | # assume it's the name of a config file 23 | assert not arg.startswith('--') 24 | config_file = arg 25 | print(f"Overriding config with {config_file}:") 26 | with open(config_file) as f: 27 | print(f.read()) 28 | exec(open(config_file).read()) 29 | else: 30 | # assume it's a --key=value argument 31 | assert arg.startswith('--') 32 | key, val = arg.split('=') 33 | key = key[2:] 34 | if key in globals(): 35 | try: 36 | # attempt to eval it it (e.g. if bool, number, or etc) 37 | attempt = literal_eval(val) 38 | except (SyntaxError, ValueError): 39 | # if that goes wrong, just use the string 40 | attempt = val 41 | # ensure the types match ok 42 | assert type(attempt) == type(globals()[key]) 43 | # cross fingers 44 | print(f"Overriding: {key} = {attempt}") 45 | globals()[key] = attempt 46 | else: 47 | raise ValueError(f"Unknown config key: {key}") 48 | -------------------------------------------------------------------------------- /transformer/data/openwebtext/prepare.py: -------------------------------------------------------------------------------- 1 | # saves the openwebtext dataset to a binary file for training. following was helpful: 2 | # https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py 3 | 4 | import os 5 | from tqdm import tqdm 6 | import numpy as np 7 | import tiktoken 8 | from datasets import load_dataset # huggingface datasets 9 | 10 | # number of workers in .map() call 11 | # good number to use is ~order number of cpu cores // 2 12 | num_proc = 8 13 | 14 | # number of workers in load_dataset() call 15 | # best number might be different from num_proc above as it also depends on NW speed. 16 | # it is better than 1 usually though 17 | num_proc_load_dataset = num_proc 18 | 19 | if __name__ == '__main__': 20 | # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769) 21 | dataset = load_dataset("openwebtext", num_proc=num_proc_load_dataset) 22 | 23 | # owt by default only contains the 'train' split, so create a test split 24 | split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True) 25 | split_dataset['val'] = split_dataset.pop('test') # rename the test split to val 26 | 27 | # this results in: 28 | # >>> split_dataset 29 | # DatasetDict({ 30 | # train: Dataset({ 31 | # features: ['text'], 32 | # num_rows: 8009762 33 | # }) 34 | # val: Dataset({ 35 | # features: ['text'], 36 | # num_rows: 4007 37 | # }) 38 | # }) 39 | 40 | # we now want to tokenize the dataset. first define the encoding function (gpt2 bpe) 41 | enc = tiktoken.get_encoding("gpt2") 42 | def process(example): 43 | ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens 44 | ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe 45 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 
46 | out = {'ids': ids, 'len': len(ids)} 47 | return out 48 | 49 | # tokenize the dataset 50 | tokenized = split_dataset.map( 51 | process, 52 | remove_columns=['text'], 53 | desc="tokenizing the splits", 54 | num_proc=num_proc, 55 | ) 56 | 57 | # concatenate all the ids in each dataset into one large file we can use for training 58 | for split, dset in tokenized.items(): 59 | arr_len = np.sum(dset['len'], dtype=np.uint64) 60 | filename = os.path.join(os.path.dirname(__file__), f'{split}.bin') 61 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 62 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 63 | total_batches = 1024 64 | 65 | idx = 0 66 | for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'): 67 | # Batch together samples for faster write 68 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy') 69 | arr_batch = np.concatenate(batch['ids']) 70 | # Write into mmap 71 | arr[idx : idx + len(arr_batch)] = arr_batch 72 | idx += len(arr_batch) 73 | arr.flush() 74 | 75 | # train.bin is ~17GB, val.bin ~8.5MB 76 | # train has ~9B tokens (9,035,582,198) 77 | # val has ~4M tokens (4,434,897) 78 | 79 | # to read the bin files later, e.g. with numpy: 80 | # m = np.memmap('train.bin', dtype=np.uint16, mode='r') 81 | -------------------------------------------------------------------------------- /transformer/data/openwebtext/readme.md: -------------------------------------------------------------------------------- 1 | 2 | ## openwebtext dataset 3 | 4 | after running `prepare.py` (preprocess) we get: 5 | 6 | - train.bin is ~17GB, val.bin ~8.5MB 7 | - train has ~9B tokens (9,035,582,198) 8 | - val has ~4M tokens (4,434,897) 9 | 10 | this came from 8,013,769 documents in total. 
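For a quick sanity check of the generated files, one can memmap a `.bin` file and decode a few tokens. This is a small sketch, assuming the GPT-2 BPE encoding used by `prepare.py`; the slice length is arbitrary.

```
import numpy as np
import tiktoken

enc = tiktoken.get_encoding("gpt2")
m = np.memmap('train.bin', dtype=np.uint16, mode='r')  # token ids stored as uint16
print(f"{len(m):,} tokens")                            # ~9B for train.bin
print(enc.decode(m[:100].tolist()))                    # decode the first 100 tokens back to text
```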
11 | 12 | references: 13 | 14 | - OpenAI's WebText dataset is discussed in [GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) 15 | - [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) dataset 16 | -------------------------------------------------------------------------------- /transformer/data/shakespeare/prepare.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import tiktoken 4 | import numpy as np 5 | 6 | # download the tiny shakespeare dataset 7 | input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt') 8 | if not os.path.exists(input_file_path): 9 | data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt' 10 | with open(input_file_path, 'w') as f: 11 | f.write(requests.get(data_url).text) 12 | 13 | with open(input_file_path, 'r') as f: 14 | data = f.read() 15 | n = len(data) 16 | train_data = data[:int(n*0.9)] 17 | val_data = data[int(n*0.9):] 18 | 19 | # encode with tiktoken gpt2 bpe 20 | enc = tiktoken.get_encoding("gpt2") 21 | train_ids = enc.encode_ordinary(train_data) 22 | val_ids = enc.encode_ordinary(val_data) 23 | print(f"train has {len(train_ids):,} tokens") 24 | print(f"val has {len(val_ids):,} tokens") 25 | 26 | # export to bin files 27 | train_ids = np.array(train_ids, dtype=np.uint16) 28 | val_ids = np.array(val_ids, dtype=np.uint16) 29 | train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin')) 30 | val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin')) 31 | 32 | # train.bin has 301,966 tokens 33 | # val.bin has 36,059 tokens 34 | -------------------------------------------------------------------------------- /transformer/data/shakespeare/readme.md: -------------------------------------------------------------------------------- 1 | 2 | # tiny shakespeare 3 | 4 | Tiny shakespeare, of the good old char-rnn fame :) 5 | 6 | After running `prepare.py`: 7 | 8 | - train.bin has 301,966 tokens 9 | - val.bin has 36,059 tokens 10 | -------------------------------------------------------------------------------- /transformer/data/shakespeare_char/prepare.py: -------------------------------------------------------------------------------- 1 | """ 2 | Prepare the Shakespeare dataset for character-level language modeling. 3 | So instead of encoding with GPT-2 BPE tokens, we just map characters to ints. 4 | Will save train.bin, val.bin containing the ids, and meta.pkl containing the 5 | encoder and decoder and some other related info. 
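The saved meta.pkl can later be reloaded (e.g. by autoencoder/resource_loader.py) to rebuild the character-level encode/decode functions.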
6 | """ 7 | import os 8 | import pickle 9 | import requests 10 | import numpy as np 11 | 12 | # download the tiny shakespeare dataset 13 | input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt') 14 | if not os.path.exists(input_file_path): 15 | data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt' 16 | with open(input_file_path, 'w') as f: 17 | f.write(requests.get(data_url).text) 18 | 19 | with open(input_file_path, 'r') as f: 20 | data = f.read() 21 | print(f"length of dataset in characters: {len(data):,}") 22 | 23 | # get all the unique characters that occur in this text 24 | chars = sorted(list(set(data))) 25 | vocab_size = len(chars) 26 | print("all the unique characters:", ''.join(chars)) 27 | print(f"vocab size: {vocab_size:,}") 28 | 29 | # create a mapping from characters to integers 30 | stoi = { ch:i for i,ch in enumerate(chars) } 31 | itos = { i:ch for i,ch in enumerate(chars) } 32 | def encode(s): 33 | return [stoi[c] for c in s] # encoder: take a string, output a list of integers 34 | def decode(l): 35 | return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string 36 | 37 | # create the train and test splits 38 | n = len(data) 39 | train_data = data[:int(n*0.9)] 40 | val_data = data[int(n*0.9):] 41 | 42 | # encode both to integers 43 | train_ids = encode(train_data) 44 | val_ids = encode(val_data) 45 | print(f"train has {len(train_ids):,} tokens") 46 | print(f"val has {len(val_ids):,} tokens") 47 | 48 | # export to bin files 49 | train_ids = np.array(train_ids, dtype=np.uint16) 50 | val_ids = np.array(val_ids, dtype=np.uint16) 51 | train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin')) 52 | val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin')) 53 | 54 | # save the meta information as well, to help us encode/decode later 55 | meta = { 56 | 'vocab_size': vocab_size, 57 | 'itos': itos, 58 | 'stoi': stoi, 59 | } 60 | with open(os.path.join(os.path.dirname(__file__), 'meta.pkl'), 'wb') as f: 61 | pickle.dump(meta, f) 62 | 63 | # length of dataset in characters: 1115394 64 | # all the unique characters: 65 | # !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz 66 | # vocab size: 65 67 | # train has 1003854 tokens 68 | # val has 111540 tokens 69 | -------------------------------------------------------------------------------- /transformer/data/shakespeare_char/readme.md: -------------------------------------------------------------------------------- 1 | 2 | # tiny shakespeare, character-level 3 | 4 | Tiny shakespeare, of the good old char-rnn fame :) Treated on character-level. 
5 | 6 | After running `prepare.py`: 7 | 8 | - train.bin has 1,003,854 tokens 9 | - val.bin has 111,540 tokens 10 | -------------------------------------------------------------------------------- /transformer/hooked_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from model import GPT 4 | 5 | class HookedGPT(GPT): 6 | def __init__(self, config): 7 | super().__init__(config) 8 | self.mlp_activation_hooks = [] 9 | self.hook_handle = None 10 | 11 | def hook_fn(self, module, input, output, mode='store', replacement_tensor=None): 12 | if mode == 'store': 13 | self.mlp_activation_hooks.append(output.clone().detach()) 14 | elif mode == 'replace': 15 | if replacement_tensor is None: 16 | replacement_tensor = torch.zeros_like(output) 17 | return replacement_tensor 18 | 19 | def forward(self, idx, targets=None, mode='store', replacement_tensor=None): 20 | device = idx.device 21 | b, t = idx.size() 22 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 23 | pos = torch.arange(0, t, dtype=torch.long, device=device) 24 | 25 | tok_emb = self.transformer.wte(idx) 26 | pos_emb = self.transformer.wpe(pos) 27 | x = self.transformer.drop(tok_emb + pos_emb) 28 | 29 | if mode == 'store' and self.mlp_activation_hooks: 30 | self.clear_mlp_activation_hooks() 31 | 32 | # Register the hook on the MLP GELU activation of the last transformer block 33 | self.hook_handle = self.transformer.h[-1].mlp.gelu.register_forward_hook( 34 | lambda module, input, output: 35 | self.hook_fn(module, input, output, mode=mode, replacement_tensor=replacement_tensor) 36 | ) 37 | 38 | for block in self.transformer.h: 39 | x = block(x) 40 | x = self.transformer.ln_f(x) 41 | 42 | if targets is not None: 43 | logits = self.lm_head(x) 44 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 45 | else: 46 | logits = self.lm_head(x[:, [-1], :]) 47 | loss = None 48 | 49 | # Remove the hook after the forward pass 50 | self.hook_handle.remove() 51 | self.hook_handle = None 52 | 53 | return logits, loss 54 | 55 | def clear_mlp_activation_hooks(self): 56 | self.mlp_activation_hooks.clear() -------------------------------------------------------------------------------- /transformer/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Full definition of a GPT Language Model, all of it in this single file. 3 | References: 4 | 1) the official GPT-2 TensorFlow implementation released by OpenAI: 5 | https://github.com/openai/gpt-2/blob/master/src/model.py 6 | 2) huggingface/transformers PyTorch implementation: 7 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py 8 | """ 9 | 10 | import math 11 | import inspect 12 | from dataclasses import dataclass 13 | 14 | import torch 15 | import torch.nn as nn 16 | from torch.nn import functional as F 17 | 18 | class LayerNorm(nn.Module): 19 | """ LayerNorm but with an optional bias. 
PyTorch doesn't support simply bias=False """ 20 | 21 | def __init__(self, ndim, bias): 22 | super().__init__() 23 | self.weight = nn.Parameter(torch.ones(ndim)) 24 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None 25 | 26 | def forward(self, input): 27 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) 28 | 29 | class CausalSelfAttention(nn.Module): 30 | 31 | def __init__(self, config): 32 | super().__init__() 33 | assert config.n_embd % config.n_head == 0 34 | # key, query, value projections for all heads, but in a batch 35 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) 36 | # output projection 37 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 38 | # regularization 39 | self.attn_dropout = nn.Dropout(config.dropout) 40 | self.resid_dropout = nn.Dropout(config.dropout) 41 | self.n_head = config.n_head 42 | self.n_embd = config.n_embd 43 | self.dropout = config.dropout 44 | # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 45 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') 46 | if not self.flash: 47 | print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") 48 | # causal mask to ensure that attention is only applied to the left in the input sequence 49 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) 50 | .view(1, 1, config.block_size, config.block_size)) 51 | 52 | def forward(self, x): 53 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 54 | 55 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 56 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 57 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 58 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 59 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 60 | 61 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 62 | if self.flash: 63 | # efficient attention using Flash Attention CUDA kernels 64 | y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True) 65 | else: 66 | # manual implementation of attention 67 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 68 | att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) 69 | att = F.softmax(att, dim=-1) 70 | att = self.attn_dropout(att) 71 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 72 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 73 | 74 | # output projection 75 | y = self.resid_dropout(self.c_proj(y)) 76 | return y 77 | 78 | class MLP(nn.Module): 79 | 80 | def __init__(self, config): 81 | super().__init__() 82 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) 83 | self.gelu = nn.GELU() 84 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) 85 | self.dropout = nn.Dropout(config.dropout) 86 | 87 | def forward(self, x): 88 | x = self.c_fc(x) 89 | x = self.gelu(x) 90 | x = self.c_proj(x) 91 | x = self.dropout(x) 92 | return x 93 | 94 | class Block(nn.Module): 95 | 96 | def __init__(self, config): 97 | super().__init__() 98 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) 99 | self.attn = CausalSelfAttention(config) 100 | 
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) 101 | self.mlp = MLP(config) 102 | 103 | def forward(self, x): 104 | x = x + self.attn(self.ln_1(x)) 105 | x = x + self.mlp(self.ln_2(x)) 106 | return x 107 | 108 | @dataclass 109 | class GPTConfig: 110 | block_size: int = 1024 111 | vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency 112 | n_layer: int = 12 113 | n_head: int = 12 114 | n_embd: int = 768 115 | dropout: float = 0.0 116 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster 117 | 118 | class GPT(nn.Module): 119 | 120 | def __init__(self, config): 121 | super().__init__() 122 | assert config.vocab_size is not None 123 | assert config.block_size is not None 124 | self.config = config 125 | 126 | self.transformer = nn.ModuleDict(dict( 127 | wte = nn.Embedding(config.vocab_size, config.n_embd), 128 | wpe = nn.Embedding(config.block_size, config.n_embd), 129 | drop = nn.Dropout(config.dropout), 130 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 131 | ln_f = LayerNorm(config.n_embd, bias=config.bias), 132 | )) 133 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 134 | # with weight tying when using torch.compile() some warnings get generated: 135 | # "UserWarning: functional_call was passed multiple values for tied weights. 136 | # This behavior is deprecated and will be an error in future versions" 137 | # not 100% sure what this is, so far seems to be harmless. TODO investigate 138 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying 139 | 140 | # init all weights 141 | self.apply(self._init_weights) 142 | # apply special scaled init to the residual projections, per GPT-2 paper 143 | for pn, p in self.named_parameters(): 144 | if pn.endswith('c_proj.weight'): 145 | torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) 146 | 147 | # report number of parameters 148 | print("number of parameters: %.2fM" % (self.get_num_params()/1e6,)) 149 | 150 | def get_num_params(self, non_embedding=True): 151 | """ 152 | Return the number of parameters in the model. 153 | For non-embedding count (default), the position embeddings get subtracted. 154 | The token embeddings would too, except due to the parameter sharing these 155 | params are actually used as weights in the final layer, so we include them. 
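For example, the default GPT-2 style config in this file (n_layer=12, n_head=12, n_embd=768, block_size=1024) gives roughly 124M parameters by this count.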
156 | """ 157 | n_params = sum(p.numel() for p in self.parameters()) 158 | if non_embedding: 159 | n_params -= self.transformer.wpe.weight.numel() 160 | return n_params 161 | 162 | def _init_weights(self, module): 163 | if isinstance(module, nn.Linear): 164 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 165 | if module.bias is not None: 166 | torch.nn.init.zeros_(module.bias) 167 | elif isinstance(module, nn.Embedding): 168 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 169 | 170 | def forward(self, idx, targets=None): 171 | device = idx.device 172 | b, t = idx.size() 173 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 174 | pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) 175 | 176 | # forward the GPT model itself 177 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 178 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) 179 | x = self.transformer.drop(tok_emb + pos_emb) 180 | for block in self.transformer.h: 181 | x = block(x) 182 | x = self.transformer.ln_f(x) 183 | 184 | if targets is not None: 185 | # if we are given some desired targets also calculate the loss 186 | logits = self.lm_head(x) 187 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 188 | else: 189 | # inference-time mini-optimization: only forward the lm_head on the very last position 190 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim 191 | loss = None 192 | 193 | return logits, loss 194 | 195 | def crop_block_size(self, block_size): 196 | # model surgery to decrease the block size if necessary 197 | # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024) 198 | # but want to use a smaller block size for some smaller, simpler model 199 | assert block_size <= self.config.block_size 200 | self.config.block_size = block_size 201 | self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size]) 202 | for block in self.transformer.h: 203 | if hasattr(block.attn, 'bias'): 204 | block.attn.bias = block.attn.bias[:,:,:block_size,:block_size] 205 | 206 | @classmethod 207 | def from_pretrained(cls, model_type, override_args=None): 208 | assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'} 209 | override_args = override_args or {} # default to empty dict 210 | # only dropout can be overridden see more notes below 211 | assert all(k == 'dropout' for k in override_args) 212 | from transformers import GPT2LMHeadModel 213 | print("loading weights from pretrained gpt: %s" % model_type) 214 | 215 | # n_layer, n_head and n_embd are determined from model_type 216 | config_args = { 217 | 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params 218 | 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params 219 | 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params 220 | 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params 221 | }[model_type] 222 | print("forcing vocab_size=50257, block_size=1024, bias=True") 223 | config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints 224 | config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints 225 | config_args['bias'] = True # always True for GPT model checkpoints 226 | # we can override the dropout rate, if desired 227 | if 'dropout' in override_args: 228 | print(f"overriding dropout rate to 
{override_args['dropout']}") 229 | config_args['dropout'] = override_args['dropout'] 230 | # create a from-scratch initialized minGPT model 231 | config = GPTConfig(**config_args) 232 | model = GPT(config) 233 | sd = model.state_dict() 234 | sd_keys = sd.keys() 235 | sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param 236 | 237 | # init a huggingface/transformers model 238 | model_hf = GPT2LMHeadModel.from_pretrained(model_type) 239 | sd_hf = model_hf.state_dict() 240 | 241 | # copy while ensuring all of the parameters are aligned and match in names and shapes 242 | sd_keys_hf = sd_hf.keys() 243 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer 244 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer) 245 | transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight'] 246 | # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear 247 | # this means that we have to transpose these weights when we import them 248 | assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" 249 | for k in sd_keys_hf: 250 | if any(k.endswith(w) for w in transposed): 251 | # special treatment for the Conv1D weights we need to transpose 252 | assert sd_hf[k].shape[::-1] == sd[k].shape 253 | with torch.no_grad(): 254 | sd[k].copy_(sd_hf[k].t()) 255 | else: 256 | # vanilla copy over the other parameters 257 | assert sd_hf[k].shape == sd[k].shape 258 | with torch.no_grad(): 259 | sd[k].copy_(sd_hf[k]) 260 | 261 | return model 262 | 263 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): 264 | # start with all of the candidate parameters 265 | param_dict = {pn: p for pn, p in self.named_parameters()} 266 | # filter out those that do not require grad 267 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 268 | # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. 269 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 270 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 271 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 272 | optim_groups = [ 273 | {'params': decay_params, 'weight_decay': weight_decay}, 274 | {'params': nodecay_params, 'weight_decay': 0.0} 275 | ] 276 | num_decay_params = sum(p.numel() for p in decay_params) 277 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 278 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") 279 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") 280 | # Create AdamW optimizer and use the fused version if it is available 281 | fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters 282 | use_fused = fused_available and device_type == 'cuda' 283 | extra_args = dict(fused=True) if use_fused else dict() 284 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) 285 | print(f"using fused AdamW: {use_fused}") 286 | 287 | return optimizer 288 | 289 | def estimate_mfu(self, fwdbwd_per_iter, dt): 290 | """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """ 291 | # first estimate the number of flops we do per iteration. 
292 | # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 293 | N = self.get_num_params() 294 | cfg = self.config 295 | L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size 296 | flops_per_token = 6*N + 12*L*H*Q*T 297 | flops_per_fwdbwd = flops_per_token * T 298 | flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter 299 | # express our flops throughput as ratio of A100 bfloat16 peak flops 300 | flops_achieved = flops_per_iter * (1.0/dt) # per second 301 | flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS 302 | mfu = flops_achieved / flops_promised 303 | return mfu 304 | 305 | @torch.no_grad() 306 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): 307 | """ 308 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete 309 | the sequence max_new_tokens times, feeding the predictions back into the model each time. 310 | Most likely you'll want to make sure to be in model.eval() mode of operation for this. 311 | """ 312 | for _ in range(max_new_tokens): 313 | # if the sequence context is growing too long we must crop it at block_size 314 | idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] 315 | # forward the model to get the logits for the index in the sequence 316 | logits, _ = self(idx_cond) 317 | # pluck the logits at the final step and scale by desired temperature 318 | logits = logits[:, -1, :] / temperature 319 | # optionally crop the logits to only the top k options 320 | if top_k is not None: 321 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 322 | logits[logits < v[:, [-1]]] = -float('Inf') 323 | # apply softmax to convert logits to (normalized) probabilities 324 | probs = F.softmax(logits, dim=-1) 325 | # sample from the distribution 326 | idx_next = torch.multinomial(probs, num_samples=1) 327 | # append sampled index to the running sequence and continue 328 | idx = torch.cat((idx, idx_next), dim=1) 329 | 330 | return idx -------------------------------------------------------------------------------- /transformer/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This training script can be run both on a single gpu in debug mode, 3 | and also in a larger training run with distributed data parallel (ddp). 
4 | 
 5 | To run on a single GPU, example:
 6 | $ python train.py --batch_size=32 --compile=False
 7 | 
 8 | To run with DDP on 4 GPUs on 1 node, example:
 9 | $ torchrun --standalone --nproc_per_node=4 train.py
10 | 
11 | To run with DDP across 2 nodes with 8 GPUs each (16 GPUs total), example:
12 | - Run on the first (master) node with example IP 123.456.123.456:
13 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
14 | - Run on the worker node:
15 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py
16 | (If your cluster does not have an Infiniband interconnect, prepend NCCL_IB_DISABLE=1)
17 | """
18 | 
19 | import os
20 | import time
21 | import math
22 | import pickle
23 | from contextlib import nullcontext
24 | 
25 | import numpy as np
26 | import torch
27 | from torch.nn.parallel import DistributedDataParallel as DDP
28 | from torch.distributed import init_process_group, destroy_process_group
29 | 
30 | from model import GPTConfig, GPT
31 | 
32 | 
33 | import torch._dynamo
34 | torch._dynamo.config.suppress_errors = True
35 | 
36 | # -----------------------------------------------------------------------------
37 | # default config values designed to train a gpt2 (124M) on OpenWebText
38 | # I/O
39 | out_dir = 'out'
40 | eval_interval = 2000
41 | log_interval = 1
42 | eval_iters = 200
43 | eval_only = False # if True, script exits right after the first eval
44 | always_save_checkpoint = True # if True, always save a checkpoint after each eval
45 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
46 | # wandb logging
47 | wandb_log = False # disabled by default
48 | wandb_project = 'owt'
49 | wandb_run_name = 'gpt2' # 'run' + str(time.time())
50 | # data
51 | dataset = 'openwebtext'
52 | gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes
53 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
54 | block_size = 1024
55 | # model
56 | n_layer = 12
57 | n_head = 12
58 | n_embd = 768
59 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
60 | bias = False # do we use bias inside LayerNorm and Linear layers?
61 | # adamw optimizer
62 | learning_rate = 6e-4 # max learning rate
63 | max_iters = 600000 # total number of training iterations
64 | weight_decay = 1e-1
65 | beta1 = 0.9
66 | beta2 = 0.95
67 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
68 | # learning rate decay settings
69 | decay_lr = True # whether to decay the learning rate
70 | warmup_iters = 2000 # how many steps to warm up for
71 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
72 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
73 | # DDP settings
74 | backend = 'nccl' # 'nccl', 'gloo', etc.
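# (editor's note, illustrative only: every default in this block can be overridden through
#  configurator.py -- see the exec(...) call further down, "overrides from command line or config file" --
#  e.g. $ python train.py config/train_gpt2.py --batch_size=8 --gradient_accumulation_steps=40)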
75 | # system 76 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 77 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 78 | compile = True # use PyTorch 2.0 to compile the model to be faster 79 | # ----------------------------------------------------------------------------- 80 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 81 | exec(open('configurator.py').read()) # overrides from command line or config file 82 | config = {k: globals()[k] for k in config_keys} # will be useful for logging 83 | # ----------------------------------------------------------------------------- 84 | 85 | # various inits, derived attributes, I/O setup 86 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? 87 | if ddp: 88 | init_process_group(backend=backend) 89 | ddp_rank = int(os.environ['RANK']) 90 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 91 | ddp_world_size = int(os.environ['WORLD_SIZE']) 92 | device = f'cuda:{ddp_local_rank}' 93 | torch.cuda.set_device(device) 94 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 95 | seed_offset = ddp_rank # each process gets a different seed 96 | # world_size number of processes will be training simultaneously, so we can scale 97 | # down the desired gradient accumulation iterations per process proportionally 98 | assert gradient_accumulation_steps % ddp_world_size == 0 99 | gradient_accumulation_steps //= ddp_world_size 100 | else: 101 | # if not ddp, we are running on a single gpu, and one process 102 | master_process = True 103 | seed_offset = 0 104 | ddp_world_size = 1 105 | tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size 106 | print(f"tokens per iteration will be: {tokens_per_iter:,}") 107 | 108 | if master_process: 109 | os.makedirs(out_dir, exist_ok=True) 110 | torch.manual_seed(1337 + seed_offset) 111 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 112 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 113 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 114 | # note: float16 data type will automatically use a GradScaler 115 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 116 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 117 | 118 | # poor man's data loader 119 | data_dir = os.path.join('data', dataset) 120 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 121 | val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') 122 | def get_batch(split): 123 | data = train_data if split == 'train' else val_data 124 | ix = torch.randint(len(data) - block_size, (batch_size,)) 125 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 126 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) 127 | if device_type == 'cuda': 128 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 129 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 130 | else: 131 | x, y = x.to(device), y.to(device) 132 | return x, y 133 | 134 | # init these 
up here, can override if init_from='resume' (i.e. from a checkpoint) 135 | iter_num = 0 136 | best_val_loss = 1e9 137 | 138 | # attempt to derive vocab_size from the dataset 139 | meta_path = os.path.join(data_dir, 'meta.pkl') 140 | meta_vocab_size = None 141 | if os.path.exists(meta_path): 142 | with open(meta_path, 'rb') as f: 143 | meta = pickle.load(f) 144 | meta_vocab_size = meta['vocab_size'] 145 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})") 146 | 147 | # model init 148 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, 149 | bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line 150 | if init_from == 'scratch': 151 | # init a new model from scratch 152 | print("Initializing a new model from scratch") 153 | # determine the vocab size we'll use for from-scratch training 154 | if meta_vocab_size is None: 155 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") 156 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304 157 | gptconf = GPTConfig(**model_args) 158 | model = GPT(gptconf) 159 | elif init_from == 'resume': 160 | print(f"Resuming training from {out_dir}") 161 | # resume training from a checkpoint. 162 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 163 | checkpoint = torch.load(ckpt_path, map_location=device) 164 | checkpoint_model_args = checkpoint['model_args'] 165 | # force these config attributes to be equal otherwise we can't even resume training 166 | # the rest of the attributes (e.g. dropout) can stay as desired from command line 167 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 168 | model_args[k] = checkpoint_model_args[k] 169 | # create the model 170 | gptconf = GPTConfig(**model_args) 171 | model = GPT(gptconf) 172 | state_dict = checkpoint['model'] 173 | # fix the keys of the state dictionary :( 174 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 175 | unwanted_prefix = '_orig_mod.' 176 | for k,v in list(state_dict.items()): 177 | if k.startswith(unwanted_prefix): 178 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 179 | model.load_state_dict(state_dict) 180 | iter_num = checkpoint['iter_num'] 181 | best_val_loss = checkpoint['best_val_loss'] 182 | elif init_from.startswith('gpt2'): 183 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}") 184 | # initialize from OpenAI GPT-2 weights 185 | override_args = dict(dropout=dropout) 186 | model = GPT.from_pretrained(init_from, override_args) 187 | # read off the created config params, so we can store them into checkpoint correctly 188 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 189 | model_args[k] = getattr(model.config, k) 190 | # crop down the model block size if desired, using model surgery 191 | if block_size < model.config.block_size: 192 | model.crop_block_size(block_size) 193 | model_args['block_size'] = block_size # so that the checkpoint will have the right value 194 | model.to(device) 195 | 196 | # initialize a GradScaler. 
If enabled=False scaler is a no-op 197 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16')) 198 | 199 | # optimizer 200 | optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type) 201 | if init_from == 'resume': 202 | optimizer.load_state_dict(checkpoint['optimizer']) 203 | checkpoint = None # free up memory 204 | 205 | # compile the model 206 | if compile: 207 | print("compiling the model... (takes a ~minute)") 208 | unoptimized_model = model 209 | model = torch.compile(model) # requires PyTorch 2.0 210 | 211 | # wrap model into DDP container 212 | if ddp: 213 | model = DDP(model, device_ids=[ddp_local_rank]) 214 | 215 | # helps estimate an arbitrarily accurate loss over either split using many batches 216 | @torch.no_grad() 217 | def estimate_loss(): 218 | out = {} 219 | model.eval() 220 | for split in ['train', 'val']: 221 | losses = torch.zeros(eval_iters) 222 | for k in range(eval_iters): 223 | X, Y = get_batch(split) 224 | with ctx: 225 | logits, loss = model(X, Y) 226 | losses[k] = loss.item() 227 | out[split] = losses.mean() 228 | model.train() 229 | return out 230 | 231 | # learning rate decay scheduler (cosine with warmup) 232 | def get_lr(it): 233 | # 1) linear warmup for warmup_iters steps 234 | if it < warmup_iters: 235 | return learning_rate * it / warmup_iters 236 | # 2) if it > lr_decay_iters, return min learning rate 237 | if it > lr_decay_iters: 238 | return min_lr 239 | # 3) in between, use cosine decay down to min learning rate 240 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 241 | assert 0 <= decay_ratio <= 1 242 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 243 | return min_lr + coeff * (learning_rate - min_lr) 244 | 245 | # logging 246 | if wandb_log and master_process: 247 | import wandb 248 | wandb.init(project=wandb_project, name=wandb_run_name, config=config) 249 | 250 | # training loop 251 | X, Y = get_batch('train') # fetch the very first batch 252 | t0 = time.time() 253 | local_iter_num = 0 # number of iterations in the lifetime of this process 254 | raw_model = model.module if ddp else model # unwrap DDP container if needed 255 | running_mfu = -1.0 256 | while True: 257 | 258 | # determine and set the learning rate for this iteration 259 | lr = get_lr(iter_num) if decay_lr else learning_rate 260 | for param_group in optimizer.param_groups: 261 | param_group['lr'] = lr 262 | 263 | # evaluate the loss on train/val sets and write checkpoints 264 | if iter_num % eval_interval == 0 and master_process: 265 | losses = estimate_loss() 266 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 267 | if wandb_log: 268 | wandb.log({ 269 | "iter": iter_num, 270 | "train/loss": losses['train'], 271 | "val/loss": losses['val'], 272 | "lr": lr, 273 | "mfu": running_mfu*100, # convert to percentage 274 | }) 275 | if losses['val'] < best_val_loss or always_save_checkpoint: 276 | best_val_loss = losses['val'] 277 | if iter_num > 0: 278 | checkpoint = { 279 | 'model': raw_model.state_dict(), 280 | 'optimizer': optimizer.state_dict(), 281 | 'model_args': model_args, 282 | 'iter_num': iter_num, 283 | 'best_val_loss': best_val_loss, 284 | 'config': config, 285 | } 286 | print(f"saving checkpoint to {out_dir}") 287 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) 288 | if iter_num == 0 and eval_only: 289 | break 290 | 291 | # forward backward update, with optional gradient accumulation to simulate larger batch size 292 | # and 
using the GradScaler if data type is float16 293 | for micro_step in range(gradient_accumulation_steps): 294 | if ddp: 295 | # in DDP training we only need to sync gradients at the last micro step. 296 | # the official way to do this is with model.no_sync() context manager, but 297 | # I really dislike that this bloats the code and forces us to repeat code 298 | # looking at the source of that context manager, it just toggles this variable 299 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) 300 | with ctx: 301 | logits, loss = model(X, Y) 302 | loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation 303 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 304 | X, Y = get_batch('train') 305 | # backward pass, with gradient scaling if training in fp16 306 | scaler.scale(loss).backward() 307 | # clip the gradient 308 | if grad_clip != 0.0: 309 | scaler.unscale_(optimizer) 310 | torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 311 | # step the optimizer and scaler if training in fp16 312 | scaler.step(optimizer) 313 | scaler.update() 314 | # flush the gradients as soon as we can, no need for this memory anymore 315 | optimizer.zero_grad(set_to_none=True) 316 | 317 | # timing and logging 318 | t1 = time.time() 319 | dt = t1 - t0 320 | t0 = t1 321 | if iter_num % log_interval == 0 and master_process: 322 | # get loss as float. note: this is a CPU-GPU sync point 323 | # scale up to undo the division above, approximating the true total loss (exact would have been a sum) 324 | lossf = loss.item() * gradient_accumulation_steps 325 | if local_iter_num >= 5: # let the training loop settle a bit 326 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt) 327 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu 328 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%") 329 | iter_num += 1 330 | local_iter_num += 1 331 | 332 | # termination conditions 333 | if iter_num > max_iters: 334 | break 335 | 336 | if ddp: 337 | destroy_process_group() 338 | --------------------------------------------------------------------------------
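Editor's note: the transformer directory does not appear to include a sampling script, so the following is a minimal, illustrative sketch of how a checkpoint written by transformer/train.py (out/ckpt.pt with the default out_dir) could be loaded and sampled with GPT.generate from model.py. It is not part of the repository: the file name sample_sketch.py, the prompt text, and the use of tiktoken's GPT-2 encoding (an assumption about how the training data was tokenized) are placeholders to adapt to your setup.

# sample_sketch.py (hypothetical file name, not part of the repository)
import torch
import tiktoken  # assumption: the dataset was prepared with the GPT-2 BPE tokenizer

from model import GPTConfig, GPT

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# load the checkpoint written by train.py (keys 'model' and 'model_args' as saved above)
checkpoint = torch.load('out/ckpt.pt', map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)

# strip the '_orig_mod.' prefix added by torch.compile, mirroring the resume path in train.py
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
model.eval()
model.to(device)

# encode a prompt and sample with the generate() method defined in model.py
enc = tiktoken.get_encoding('gpt2')
idx = torch.tensor([enc.encode("The meaning of life is")], dtype=torch.long, device=device)
out = model.generate(idx, max_new_tokens=100, temperature=0.8, top_k=200)
print(enc.decode(out[0].tolist()))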