├── README.md ├── assets └── teaser.svg ├── datasets ├── __init__.py └── activations.py ├── dictionary_learning ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── buffer.py ├── config.py ├── dictionary.py ├── evaluation.py ├── grad_pursuit.py ├── interp.py ├── pretrained_dictionary_downloader.sh ├── requirements.txt ├── tests │ └── test_end_to_end.py ├── trainers │ ├── __init__.py │ ├── batch_top_k.py │ ├── gated_anneal.py │ ├── gdm.py │ ├── jumprelu.py │ ├── matroyshka_batch_top_k.py │ ├── p_anneal.py │ ├── standard.py │ ├── top_k.py │ └── trainer.py ├── training.py └── utils.py ├── encode_images.py ├── find_hai_indices.py ├── imagenet_subset.py ├── images └── white.png ├── inat_depth.py ├── metric.py ├── models ├── clip.py ├── dino.py ├── llava.py └── siglip.py ├── requirements.txt ├── sae_train.py ├── save_activations.py ├── scripts ├── matryoshka_hierarchy.sh ├── mllm_steering.sh └── monosemanticity_score.sh ├── similarity_baseline.py ├── steering_qualitative.py ├── steering_score.py ├── uniqueness.py ├── utils.py └── visualize_neurons.py /README.md: -------------------------------------------------------------------------------- 1 | ## 2 |

Sparse Autoencoders Learn Monosemantic Features in Vision-Language Models

3 | 4 |
5 | Mateusz Pach, 6 | Shyamgopal Karthik, 7 | Quentin Bouniot, 8 | Serge Belongie, 9 | Zeynep Akata 10 |
11 |
12 | 13 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2504.02821) 14 |
15 | 16 |

Abstract

17 | 18 |

19 | Sparse Autoencoders (SAEs) have recently been shown to enhance interpretability and steerability in Large Language Models (LLMs). In this work, we extend the application of SAEs to Vision-Language Models (VLMs), such as CLIP, and introduce a comprehensive framework for evaluating monosemanticity in vision representations. Our experimental results reveal that SAEs trained on VLMs significantly enhance the monosemanticity of individual neurons while also exhibiting hierarchical representations that align well with expert-defined structures (e.g., iNaturalist taxonomy). Most notably, we demonstrate that applying SAEs to intervene on a CLIP vision encoder directly steers the output of multimodal LLMs (e.g., LLaVA) without any modifications to the underlying model. These findings emphasize the practicality and efficacy of SAEs as an unsupervised approach for enhancing both the interpretability and control of VLMs.
20 | 

21 |
22 |
23 | Teaser 24 |
25 | 
26 | ---
27 | ### Setup
28 | Install the required pip packages.
29 | ```bash
30 | pip install -r requirements.txt
31 | ```
32 | Download the following datasets:
33 | * ImageNet (https://pytorch.org/vision/main/generated/torchvision.datasets.ImageNet.html)
34 | * iNaturalist 2021 (https://github.com/visipedia/inat_comp/tree/master/2021)
35 | 
36 | Export the paths to the dataset directories. The directories should contain `train/` and `val/` subdirectories.
37 | ```bash
38 | export IMAGENET_PATH=""
39 | export INAT_PATH=""
40 | ```
41 | The code was run using Python 3.11.10.
42 | ### Running Experiments
43 | The commands required to reproduce the results are organized into scripts located in the `scripts/` directory:
44 | * `monosemanticity_score.sh` computes the Monosemanticity Score (MS) for specified SAEs, layers, models, and image encoders.
45 | * `matryoshka_hierarchy.sh` analyzes the hierarchical structure that emerges in Matryoshka SAEs.
46 | * `mllm_steering.sh` enables experimentation with steering LLaVA using an SAE built on top of the vision encoder.
47 | 
48 | We use the implementation of sparse autoencoders available at https://github.com/saprmarks/dictionary_learning.
49 | ### Citation
50 | ```bibtex
51 | @article{pach2025sparse,
52 |   title={Sparse Autoencoders Learn Monosemantic Features in Vision-Language Models},
53 |   author={Mateusz Pach and Shyamgopal Karthik and Quentin Bouniot and Serge Belongie and Zeynep Akata},
54 |   journal={arXiv preprint arXiv:2504.02821},
55 |   year={2025}
56 | }
57 | ```
58 | 
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # from .cc3m import get_cc3m
--------------------------------------------------------------------------------
/datasets/activations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import os
4 | import bisect
5 | 
6 | class ChunkedActivationsDataset(Dataset):
7 |     def __init__(self, directory, transform=None, device="cpu"):
8 |         """
9 |         Args:
10 |             directory (str): Path to the directory containing .pth files.
11 |             transform (callable, optional): Optional transform to be applied on a sample.
12 |             device (str): Device to store the activations ('cpu' or 'cuda').
13 |         """
14 |         self.directory = directory
15 |         self.files = sorted(
16 |             (f for f in os.listdir(directory) if f.endswith('.pth') or f.endswith('.pt')),
17 |             key=lambda x: int(x.split('_part')[-1].split('.pt')[0])
18 |         )
19 |         self.transform = transform
20 |         self.device = device
21 |         self.file_offsets = []
22 |         self.cumulative_lengths = []
23 | 
24 |         # Cache to store the last accessed file
25 |         self._cached_file = None
26 |         self._cached_tensor = None
27 | 
28 |         # Compute the offsets and cumulative lengths for efficient indexing
29 |         cumulative = 0
30 |         for file in self.files:
31 |             tensor = torch.load(os.path.join(directory, file))  # Load file temporarily
32 |             length = tensor.size(0)  # Number of samples in the file (K)
33 |             self.file_offsets.append((file, length))
34 |             cumulative += length
35 |             self.cumulative_lengths.append(cumulative)
36 |             # break
37 | 
38 |     def __len__(self):
39 |         return self.cumulative_lengths[-1] if self.cumulative_lengths else 0
40 | 
41 |     def _load_file(self, file):
42 |         """
43 |         Loads a file and caches it for future access.
44 |         Moves the tensor to the specified device if not already cached.
45 | """ 46 | if self._cached_file != file: 47 | file_path = os.path.join(self.directory, file) 48 | tensor = torch.load(file_path) 49 | self._cached_tensor = tensor.to(self.device) # Move to the specified device 50 | self._cached_file = file 51 | 52 | def __getitem__(self, idx): 53 | """ 54 | Args: 55 | idx (int): Index of the sample to retrieve. 56 | Returns: 57 | torch.Tensor: Neuron activation vector on the specified device. 58 | """ 59 | # Find the correct file corresponding to the index 60 | for file_idx, (file, length) in enumerate(self.file_offsets): 61 | if idx < self.cumulative_lengths[file_idx]: 62 | # Calculate relative index within the file 63 | relative_idx = idx - (self.cumulative_lengths[file_idx - 1] if file_idx > 0 else 0) 64 | 65 | # Load the file into cache if it's not already cached 66 | self._load_file(file) 67 | 68 | # Retrieve the sample 69 | sample = self._cached_tensor[relative_idx] 70 | if self.transform: 71 | sample = self.transform(sample) 72 | return sample 73 | 74 | raise IndexError(f"Index {idx} out of range for dataset with length {len(self)}") 75 | 76 | 77 | class ActivationsDataset(Dataset): 78 | # def __init__(self, directory, transform=None, device="cpu", take_every=1): 79 | # """ 80 | # Args: 81 | # directory (str): Path to the directory containing .pth files. 82 | # transform (callable, optional): Optional transform to be applied on a sample. 83 | # device (str): Device to store the activations ('cpu' or 'cuda'). 84 | # """ 85 | # self.directory = directory 86 | # self.files = sorted( 87 | # (f for f in os.listdir(directory) if f.endswith('.pth') or f.endswith('.pt')), 88 | # key=lambda x: int(x.split('_part')[-1].split('.pt')[0]) 89 | # ) 90 | # self.transform = transform 91 | # self.device = device 92 | # 93 | # # Load all tensors into memory 94 | # self.cached_tensors = [] 95 | # self.cumulative_lengths = [] 96 | # cumulative = 0 97 | # 98 | # for file in self.files: 99 | # tensor = torch.load(os.path.join(directory, file)).to(self.device) 100 | # self.cached_tensors.append(tensor) 101 | # cumulative += tensor.size(0) # Number of samples in the file (K) 102 | # self.cumulative_lengths.append(cumulative) 103 | def __init__(self, directory, transform=None, device="cpu", take_every=1): 104 | """ 105 | Args: 106 | directory (str): Path to the directory containing .pth files. 107 | transform (callable, optional): Optional transform to be applied on a sample. 108 | device (str): Device to store the activations ('cpu' or 'cuda'). 109 | take_every (int): Load every N-th row from the concatenated tensors to reduce memory usage. 
110 | """ 111 | self.directory = directory 112 | self.files = sorted( 113 | (f for f in os.listdir(directory) if (f.endswith('.pth') or f.endswith('.pt')) and not f.startswith('all')), 114 | key=lambda x: int(x.split('_part')[-1].split('.pt')[0]) 115 | ) 116 | self.transform = transform 117 | self.device = device 118 | self.take_every = take_every 119 | 120 | # Load all tensors into memory with skipping rows as per take_every 121 | self.cached_tensors = [] 122 | self.cumulative_lengths = [] 123 | cumulative = 0 124 | 125 | # Track the global row index 126 | global_row_index = 0 127 | 128 | for file in self.files: 129 | tensor = torch.load(os.path.join(directory, file)).to(self.device) 130 | num_rows = tensor.size(0) 131 | 132 | # Calculate global indices for the rows in this tensor 133 | global_indices = list(range(global_row_index, global_row_index + num_rows)) 134 | selected_indices = [ 135 | idx - global_row_index for idx in global_indices if idx % self.take_every == 0 136 | ] 137 | 138 | # Extract rows using the selected indices 139 | if selected_indices: 140 | tensor = tensor[selected_indices] 141 | self.cached_tensors.append(tensor) 142 | cumulative += len(selected_indices) 143 | self.cumulative_lengths.append(cumulative) 144 | 145 | # Update the global_row_index 146 | global_row_index += num_rows 147 | 148 | def __len__(self): 149 | return self.cumulative_lengths[-1] if self.cumulative_lengths else 0 150 | 151 | def __getitem__(self, idx): 152 | """ 153 | Args: 154 | idx (int): Index of the sample to retrieve. 155 | Returns: 156 | torch.Tensor: Neuron activation vector on the specified device. 157 | """ 158 | # # Find the correct tensor corresponding to the index 159 | # for tensor_idx, tensor in enumerate(self.cached_tensors): 160 | # # return torch.randn_like(tensor[0]).to(self.device) 161 | # if idx < self.cumulative_lengths[tensor_idx]: 162 | # # Calculate relative index within the tensor 163 | # relative_idx = idx - (self.cumulative_lengths[tensor_idx - 1] if tensor_idx > 0 else 0) 164 | # 165 | # # Retrieve the sample 166 | # sample = tensor[relative_idx] 167 | # if self.transform: 168 | # sample = self.transform(sample) 169 | # return sample 170 | 171 | tensor_idx = bisect.bisect_right(self.cumulative_lengths, idx) 172 | relative_idx = idx - (self.cumulative_lengths[tensor_idx - 1] if tensor_idx > 0 else 0) 173 | sample = self.cached_tensors[tensor_idx][relative_idx] 174 | if self.transform: 175 | sample = self.transform(sample) 176 | return sample -------------------------------------------------------------------------------- /dictionary_learning/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | dictionaries 162 | wandb 163 | experiment* 164 | run_experiment.sh 165 | nohup.out 166 | *.zip -------------------------------------------------------------------------------- /dictionary_learning/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 saprmarks 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /dictionary_learning/__init__.py: -------------------------------------------------------------------------------- 1 | from .dictionary import AutoEncoder, GatedAutoEncoder, JumpReluAutoEncoder 2 | from .buffer import ActivationBuffer -------------------------------------------------------------------------------- /dictionary_learning/config.py: -------------------------------------------------------------------------------- 1 | # debugging flag for use in other scripts 2 | DEBUG = False -------------------------------------------------------------------------------- /dictionary_learning/evaluation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for evaluating dictionaries on a model and dataset. 3 | """ 4 | 5 | import torch as t 6 | from collections import defaultdict 7 | 8 | from .buffer import ActivationBuffer, NNsightActivationBuffer 9 | from nnsight import LanguageModel 10 | from .config import DEBUG 11 | 12 | 13 | def loss_recovered( 14 | text, # a batch of text 15 | model: LanguageModel, # an nnsight LanguageModel 16 | submodule, # submodules of model 17 | dictionary, # dictionaries for submodules 18 | max_len=None, # max context length for loss recovered 19 | normalize_batch=False, # normalize batch before passing through dictionary 20 | io="out", # can be 'in', 'out', or 'in_and_out' 21 | tracer_args = {'use_cache': False, 'output_attentions': False}, # minimize cache during model trace. 22 | ): 23 | """ 24 | How much of the model's loss is recovered by replacing the component output 25 | with the reconstruction by the autoencoder? 
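    Returns a tuple (loss_original, loss_reconstructed, loss_zero); the fraction of loss recovered
    can then be computed as (loss_reconstructed - loss_zero) / (loss_original - loss_zero),
    as done in evaluate() below.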
26 | """ 27 | 28 | if max_len is None: 29 | invoker_args = {} 30 | else: 31 | invoker_args = {"truncation": True, "max_length": max_len } 32 | 33 | with model.trace("_"): 34 | temp_output = submodule.output.save() 35 | 36 | output_is_tuple = False 37 | # Note: isinstance() won't work here as torch.Size is a subclass of tuple, 38 | # so isinstance(temp_output.shape, tuple) would return True even for torch.Size. 39 | if type(temp_output.shape) == tuple: 40 | output_is_tuple = True 41 | 42 | # unmodified logits 43 | with model.trace(text, invoker_args=invoker_args): 44 | logits_original = model.output.save() 45 | logits_original = logits_original.value 46 | 47 | # logits when replacing component activations with reconstruction by autoencoder 48 | with model.trace(text, **tracer_args, invoker_args=invoker_args): 49 | if io == 'in': 50 | x = submodule.input 51 | if normalize_batch: 52 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() 53 | x = x * scale 54 | elif io == 'out': 55 | x = submodule.output 56 | if output_is_tuple: x = x[0] 57 | if normalize_batch: 58 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() 59 | x = x * scale 60 | elif io == 'in_and_out': 61 | x = submodule.input 62 | if normalize_batch: 63 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() 64 | x = x * scale 65 | else: 66 | raise ValueError(f"Invalid value for io: {io}") 67 | x = x.save() 68 | 69 | # If we incorrectly handle output_is_tuple, such as with some mlp submodules, we will get an error here. 70 | assert len(x.shape) == 3, f"Expected x to have shape (B, L, D), got {x.shape}, output_is_tuple: {output_is_tuple}" 71 | 72 | x_hat = dictionary(x).to(model.dtype) 73 | 74 | # intervene with `x_hat` 75 | with model.trace(text, **tracer_args, invoker_args=invoker_args): 76 | if io == 'in': 77 | x = submodule.input 78 | if normalize_batch: 79 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() 80 | x_hat = x_hat / scale 81 | submodule.input[:] = x_hat 82 | elif io == 'out': 83 | x = submodule.output 84 | if output_is_tuple: x = x[0] 85 | if normalize_batch: 86 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() 87 | x_hat = x_hat / scale 88 | if output_is_tuple: 89 | submodule.output[0][:] = x_hat 90 | else: 91 | submodule.output[:] = x_hat 92 | elif io == 'in_and_out': 93 | x = submodule.input 94 | if normalize_batch: 95 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() 96 | x_hat = x_hat / scale 97 | if output_is_tuple: 98 | submodule.output[0][:] = x_hat 99 | else: 100 | submodule.output[:] = x_hat 101 | else: 102 | raise ValueError(f"Invalid value for io: {io}") 103 | 104 | logits_reconstructed = model.output.save() 105 | logits_reconstructed = logits_reconstructed.value 106 | 107 | # logits when replacing component activations with zeros 108 | with model.trace(text, **tracer_args, invoker_args=invoker_args): 109 | if io == 'in': 110 | x = submodule.input 111 | submodule.input[:] = t.zeros_like(x) 112 | elif io in ['out', 'in_and_out']: 113 | x = submodule.output 114 | if output_is_tuple: 115 | submodule.output[0][:] = t.zeros_like(x[0]) 116 | else: 117 | submodule.output[:] = t.zeros_like(x) 118 | else: 119 | raise ValueError(f"Invalid value for io: {io}") 120 | 121 | input = model.inputs.save() 122 | logits_zero = model.output.save() 123 | 124 | logits_zero = logits_zero.value 125 | 126 | # get everything into the right format 127 | try: 128 | logits_original = logits_original.logits 129 | logits_reconstructed = 
logits_reconstructed.logits 130 | logits_zero = logits_zero.logits 131 | except: 132 | pass 133 | 134 | if isinstance(text, t.Tensor): 135 | tokens = text 136 | else: 137 | try: 138 | tokens = input[1]['input_ids'] 139 | except: 140 | tokens = input[1]['input'] 141 | 142 | # compute losses 143 | losses = [] 144 | if hasattr(model, 'tokenizer') and model.tokenizer is not None: 145 | loss_kwargs = {'ignore_index': model.tokenizer.pad_token_id} 146 | else: 147 | loss_kwargs = {} 148 | for logits in [logits_original, logits_reconstructed, logits_zero]: 149 | loss = t.nn.CrossEntropyLoss(**loss_kwargs)( 150 | logits[:, :-1, :].reshape(-1, logits.shape[-1]), tokens[:, 1:].reshape(-1) 151 | ) 152 | losses.append(loss) 153 | 154 | return tuple(losses) 155 | 156 | @t.no_grad() 157 | def evaluate( 158 | dictionary, # a dictionary 159 | activations, # a generator of activations; if an ActivationBuffer, also compute loss recovered 160 | max_len=128, # max context length for loss recovered 161 | batch_size=128, # batch size for loss recovered 162 | io="out", # can be 'in', 'out', or 'in_and_out' 163 | normalize_batch=False, # normalize batch before passing through dictionary 164 | tracer_args={'use_cache': False, 'output_attentions': False}, # minimize cache during model trace. 165 | device="cpu", 166 | n_batches: int = 1, 167 | ): 168 | assert n_batches > 0 169 | out = defaultdict(float) 170 | active_features = t.zeros(dictionary.dict_size, dtype=t.float32, device=device) 171 | 172 | for _ in range(n_batches): 173 | try: 174 | x = next(activations).to(device) 175 | if normalize_batch: 176 | x = x / x.norm(dim=-1).mean() * (dictionary.activation_dim ** 0.5) 177 | except StopIteration: 178 | raise StopIteration( 179 | "Not enough activations in buffer. Pass a buffer with a smaller batch size or more data." 
180 | ) 181 | x_hat, f = dictionary(x, output_features=True) 182 | l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() 183 | l1_loss = f.norm(p=1, dim=-1).mean() 184 | l0 = (f != 0).float().sum(dim=-1).mean() 185 | 186 | features_BF = t.flatten(f, start_dim=0, end_dim=-2).to(dtype=t.float32) # If f is shape (B, L, D), flatten to (B*L, D) 187 | assert features_BF.shape[-1] == dictionary.dict_size 188 | assert len(features_BF.shape) == 2 189 | 190 | active_features += features_BF.sum(dim=0) 191 | 192 | # cosine similarity between x and x_hat 193 | x_normed = x / t.linalg.norm(x, dim=-1, keepdim=True) 194 | x_hat_normed = x_hat / t.linalg.norm(x_hat, dim=-1, keepdim=True) 195 | cossim = (x_normed * x_hat_normed).sum(dim=-1).mean() 196 | 197 | # l2 ratio 198 | l2_ratio = (t.linalg.norm(x_hat, dim=-1) / t.linalg.norm(x, dim=-1)).mean() 199 | 200 | #compute variance explained 201 | total_variance = t.var(x, dim=0).sum() 202 | residual_variance = t.var(x - x_hat, dim=0).sum() 203 | frac_variance_explained = (1 - residual_variance / total_variance) 204 | 205 | # Equation 10 from https://arxiv.org/abs/2404.16014 206 | x_hat_norm_squared = t.linalg.norm(x_hat, dim=-1, ord=2)**2 207 | x_dot_x_hat = (x * x_hat).sum(dim=-1) 208 | relative_reconstruction_bias = x_hat_norm_squared.mean() / x_dot_x_hat.mean() 209 | 210 | out["l2_loss"] += l2_loss.item() 211 | out["l1_loss"] += l1_loss.item() 212 | out["l0"] += l0.item() 213 | out["frac_variance_explained"] += frac_variance_explained.item() 214 | out["cossim"] += cossim.item() 215 | out["l2_ratio"] += l2_ratio.item() 216 | out['relative_reconstruction_bias'] += relative_reconstruction_bias.item() 217 | 218 | if not isinstance(activations, (ActivationBuffer, NNsightActivationBuffer)): 219 | continue 220 | 221 | # compute loss recovered 222 | loss_original, loss_reconstructed, loss_zero = loss_recovered( 223 | activations.text_batch(batch_size=batch_size), 224 | activations.model, 225 | activations.submodule, 226 | dictionary, 227 | max_len=max_len, 228 | normalize_batch=normalize_batch, 229 | io=io, 230 | tracer_args=tracer_args 231 | ) 232 | frac_recovered = (loss_reconstructed - loss_zero) / (loss_original - loss_zero) 233 | 234 | out["loss_original"] += loss_original.item() 235 | out["loss_reconstructed"] += loss_reconstructed.item() 236 | out["loss_zero"] += loss_zero.item() 237 | out["frac_recovered"] += frac_recovered.item() 238 | 239 | out = {key: value / n_batches for key, value in out.items()} 240 | frac_alive = (active_features != 0).float().sum() / dictionary.dict_size 241 | out["frac_alive"] = frac_alive.item() 242 | 243 | return out -------------------------------------------------------------------------------- /dictionary_learning/grad_pursuit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements batched gradient pursuit algorithm here: 3 | https://www.lesswrong.com/posts/C5KAZQib3bzzpeyrg/full-post-progress-update-1-from-the-gdm-mech-interp-team#Inference_Time_Optimisation:~:text=two%20seem%20promising.-,Details%20of%20Sparse%20Approximation%20Algorithms%20(for%20accelerators),-This%20section%20gets 4 | """ 5 | 6 | import torch as t 7 | 8 | 9 | def _grad_pursuit_update_step(signal, weights, dictionary, batch_arange, selected_features): 10 | """ 11 | signal: b x d, weights: b x n, dictionary: d x n, batch_arange: b, selected_features: b x n 12 | """ 13 | residual = signal - t.einsum('bn,dn -> bd', weights, dictionary) 14 | # choose the element with largest inner product with residual, as in 
matched pursuit. 15 | inner_products = t.einsum('dn,bd -> bn', dictionary, residual) 16 | idxs = t.argmax(inner_products, dim=1) 17 | # add the new feature to the active set. 18 | selected_features[batch_arange, idxs] = 1 19 | 20 | # the gradient for the weights is the inner product, restricted to the chosen features 21 | grad = selected_features * inner_products 22 | # the next two steps compute the optimal step size 23 | c = t.einsum('bn,dn -> bd', grad, dictionary) 24 | step_size = t.einsum('bd,bd -> b', c, residual) / t.einsum('bd,bd -> b ', c, c) 25 | weights = weights + t.einsum('b,bn -> bn', step_size, grad) 26 | weights = t.clip(weights, min=0) # clip the weights to be positive 27 | return weights, selected_features 28 | 29 | def grad_pursuit(signal, dictionary, target_l0 : int = 20, device : str = 'cpu'): 30 | """ 31 | Inputs: signal: b x d, dictionary: d x n, target_l0: int, device: str 32 | Outputs: weights: b x n 33 | """ 34 | assert len(signal.shape) == 2 # makes sure this a batch of signals 35 | with t.no_grad(): 36 | batch_arange = t.arange(signal.shape[0]).to(device) 37 | weights = t.zeros((signal.shape[0], dictionary.shape[1])).to(device) 38 | selected_features = t.zeros((signal.shape[0], dictionary.shape[1])).to(device) 39 | for _ in range(target_l0): 40 | weights, selected_features = _grad_pursuit_update_step( 41 | signal, weights, dictionary, batch_arange, selected_features) 42 | return weights -------------------------------------------------------------------------------- /dictionary_learning/interp.py: -------------------------------------------------------------------------------- 1 | import random 2 | from circuitsvis.activations import text_neuron_activations 3 | from einops import rearrange 4 | import torch as t 5 | from collections import namedtuple 6 | import umap 7 | import pandas as pd 8 | import plotly.express as px 9 | 10 | 11 | def feature_effect( 12 | model, 13 | submodule, 14 | dictionary, 15 | feature, 16 | inputs, 17 | max_length=128, 18 | add_residual=True, # whether to compensate for dictionary reconstruction error by adding residual 19 | k=10, 20 | largest=True, 21 | ): 22 | """ 23 | Effect of ablating the feature on top k predictions for next token. 
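    Returns (top_tokens, top_probs): the k tokens whose next-token log-probability changes the most
    when the feature is ablated, together with the corresponding change (clean minus ablated)
    averaged over the batch.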
24 | """ 25 | tracer_kwargs = { 26 | "scan": False, 27 | "validate": False, 28 | "invoker_args": dict(max_length=max_length), 29 | } 30 | # clean run 31 | with t.no_grad(), model.trace(inputs, **tracer_kwargs): 32 | if dictionary is None: 33 | pass 34 | elif not add_residual: # run hidden state through autoencoder 35 | if type(submodule.output.shape) == tuple: 36 | submodule.output[0][:] = dictionary(submodule.output[0]) 37 | else: 38 | submodule.output = dictionary(submodule.output) 39 | clean_output = model.output.save() 40 | try: 41 | clean_logits = clean_output.value.logits[:, -1, :] 42 | except: 43 | clean_logits = clean_output.value[:, -1, :] 44 | clean_logprobs = t.nn.functional.log_softmax(clean_logits, dim=-1) 45 | 46 | # ablated run 47 | with t.no_grad(), model.trace(inputs, **tracer_kwargs): 48 | if dictionary is None: 49 | if type(submodule.output.shape) == tuple: 50 | submodule.output[0][:, -1, feature] = 0 51 | else: 52 | submodule.output[:, -1, feature] = 0 53 | else: 54 | x = submodule.output 55 | if type(x.shape) == tuple: 56 | x = x[0] 57 | x_hat, f = dictionary(x, output_features=True) 58 | residual = x - x_hat 59 | 60 | f[:, -1, feature] = 0 61 | if add_residual: 62 | x_hat = dictionary.decode(f) + residual 63 | else: 64 | x_hat = dictionary.decode(f) 65 | 66 | if type(submodule.output.shape) == tuple: 67 | submodule.output[0][:] = x_hat 68 | else: 69 | submodule.output = x_hat 70 | ablated_output = model.output.save() 71 | try: 72 | ablated_logits = ablated_output.value.logits[:, -1, :] 73 | except: 74 | ablated_logits = ablated_output.value[:, -1, :] 75 | ablated_logprobs = t.nn.functional.log_softmax(ablated_logits, dim=-1) 76 | 77 | diff = clean_logprobs - ablated_logprobs 78 | top_probs, top_tokens = t.topk(diff.mean(dim=0), k=k, largest=largest) 79 | return top_tokens, top_probs 80 | 81 | 82 | def examine_dimension( 83 | model, submodule, buffer, dictionary=None, max_length=128, n_inputs=512, dim_idx=None, k=30 84 | ): 85 | 86 | tracer_kwargs = { 87 | "scan": False, 88 | "validate": False, 89 | "invoker_args": dict(max_length=max_length), 90 | } 91 | 92 | def _list_decode(x): 93 | if isinstance(x, int): 94 | return model.tokenizer.decode(x) 95 | else: 96 | return [_list_decode(y) for y in x] 97 | 98 | if dim_idx is None: 99 | dim_idx = random.randint(0, activations.shape[-1] - 1) 100 | 101 | inputs = buffer.tokenized_batch(batch_size=n_inputs) 102 | 103 | with t.no_grad(), model.trace(inputs, **tracer_kwargs): 104 | tokens = model.inputs[1][ 105 | "input_ids" 106 | ].save() # if you're getting errors, check here; might only work for pythia models 107 | activations = submodule.output 108 | if type(activations.shape) == tuple: 109 | activations = activations[0] 110 | if dictionary is not None: 111 | activations = dictionary.encode(activations) 112 | activations = activations[:, :, dim_idx].save() 113 | activations = activations.value 114 | 115 | # get top k tokens by mean activation 116 | tokens = tokens.value 117 | token_mean_acts = {} 118 | for ctx in tokens: 119 | for tok in ctx: 120 | if tok.item() in token_mean_acts: 121 | continue 122 | idxs = (tokens == tok).nonzero(as_tuple=True) 123 | token_mean_acts[tok.item()] = activations[idxs].mean().item() 124 | top_tokens = sorted(token_mean_acts.items(), key=lambda x: x[1], reverse=True)[:k] 125 | top_tokens = [(model.tokenizer.decode(tok), act) for tok, act in top_tokens] 126 | 127 | flattened_acts = rearrange(activations, "b n -> (b n)") 128 | topk_indices = t.argsort(flattened_acts, dim=0, descending=True)[:k] 
129 | batch_indices = topk_indices // activations.shape[1] 130 | token_indices = topk_indices % activations.shape[1] 131 | tokens = [ 132 | tokens[batch_idx, : token_idx + 1].tolist() 133 | for batch_idx, token_idx in zip(batch_indices, token_indices) 134 | ] 135 | activations = [ 136 | activations[batch_idx, : token_id + 1, None, None] 137 | for batch_idx, token_id in zip(batch_indices, token_indices) 138 | ] 139 | decoded_tokens = _list_decode(tokens) 140 | top_contexts = text_neuron_activations(decoded_tokens, activations) 141 | 142 | top_affected = feature_effect( 143 | model, submodule, dictionary, dim_idx, tokens, max_length=max_length, k=k 144 | ) 145 | top_affected = [(model.tokenizer.decode(tok), prob.item()) for tok, prob in zip(*top_affected)] 146 | 147 | return namedtuple("featureProfile", ["top_contexts", "top_tokens", "top_affected"])( 148 | top_contexts, top_tokens, top_affected 149 | ) 150 | 151 | 152 | def feature_umap( 153 | dictionary, 154 | weight="decoder", # 'encoder' or 'decoder' 155 | # UMAP parameters 156 | n_neighbors=15, 157 | metric="cosine", 158 | min_dist=0.05, 159 | n_components=2, # dimension of the UMAP embedding 160 | feat_idxs=None, # if not none, indicate the feature with a red dot 161 | ): 162 | """ 163 | Fit a UMAP embedding of the dictionary features and return a plotly plot of the result.""" 164 | if weight == "encoder": 165 | df = pd.DataFrame(dictionary.encoder.weight.cpu().detach().numpy()) 166 | else: 167 | df = pd.DataFrame(dictionary.decoder.weight.T.cpu().detach().numpy()) 168 | reducer = umap.UMAP( 169 | n_neighbors=n_neighbors, 170 | metric=metric, 171 | min_dist=min_dist, 172 | n_components=n_components, 173 | ) 174 | embedding = reducer.fit_transform(df) 175 | if feat_idxs is None: 176 | colors = None 177 | if isinstance(feat_idxs, int): 178 | feat_idxs = [feat_idxs] 179 | else: 180 | colors = ["blue" if i not in feat_idxs else "red" for i in range(embedding.shape[0])] 181 | if n_components == 2: 182 | return px.scatter(x=embedding[:, 0], y=embedding[:, 1], hover_name=df.index, color=colors) 183 | if n_components == 3: 184 | return px.scatter_3d( 185 | x=embedding[:, 0], 186 | y=embedding[:, 1], 187 | z=embedding[:, 2], 188 | hover_name=df.index, 189 | color=colors, 190 | ) 191 | raise ValueError("n_components must be 2 or 3") 192 | -------------------------------------------------------------------------------- /dictionary_learning/pretrained_dictionary_downloader.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget https://huggingface.co/saprmarks/pythia-70m-deduped-saes/resolve/main/dictionaries_pythia-70m-deduped_10.zip 4 | unzip dictionaries_pythia-70m-deduped_10.zip -------------------------------------------------------------------------------- /dictionary_learning/requirements.txt: -------------------------------------------------------------------------------- 1 | circuitsvis>=1.43.2 2 | datasets>=2.18.0 3 | einops>=0.7.0 4 | matplotlib>=3.8.3 5 | nnsight>=0.3.0 6 | pandas>=2.2.1 7 | plotly>=5.18.0 8 | torch>=2.1.2 9 | tqdm>=4.66.1 10 | umap-learn>=0.5.6 11 | zstandard>=0.22.0 12 | wandb>=0.12.0 13 | pytest>=6.2.4 -------------------------------------------------------------------------------- /dictionary_learning/tests/test_end_to_end.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | from nnsight import LanguageModel 3 | import os 4 | import json 5 | import random 6 | 7 | from dictionary_learning.training 
import trainSAE 8 | from dictionary_learning.trainers.standard import StandardTrainer 9 | from dictionary_learning.trainers.top_k import TopKTrainer, AutoEncoderTopK 10 | from dictionary_learning.utils import hf_dataset_to_generator, get_nested_folders, load_dictionary 11 | from dictionary_learning.buffer import ActivationBuffer 12 | from dictionary_learning.dictionary import ( 13 | AutoEncoder, 14 | GatedAutoEncoder, 15 | AutoEncoderNew, 16 | JumpReluAutoEncoder, 17 | ) 18 | from dictionary_learning.evaluation import evaluate 19 | 20 | EXPECTED_RESULTS = { 21 | "AutoEncoderTopK": { 22 | "l2_loss": 4.362327718734742, 23 | "l1_loss": 50.94957427978515, 24 | "l0": 40.0, 25 | "frac_variance_explained": 0.9578053653240204, 26 | "cossim": 0.9478691875934601, 27 | "l2_ratio": 0.9478908002376556, 28 | "relative_reconstruction_bias": 0.999762898683548, 29 | "loss_original": 3.3361297130584715, 30 | "loss_reconstructed": 3.8404462814331053, 31 | "loss_zero": 13.251659297943116, 32 | "frac_recovered": 0.948982036113739, 33 | "frac_alive": 0.99951171875, 34 | }, 35 | "AutoEncoder": { 36 | "l2_loss": 6.822444677352905, 37 | "l1_loss": 19.382131576538086, 38 | "l0": 37.45087890625, 39 | "frac_variance_explained": 0.8993501663208008, 40 | "cossim": 0.8791120409965515, 41 | "l2_ratio": 0.74552041888237, 42 | "relative_reconstruction_bias": 0.9595054805278778, 43 | "loss_original": 3.3361297130584715, 44 | "loss_reconstructed": 5.208198881149292, 45 | "loss_zero": 13.251659297943116, 46 | "frac_recovered": 0.8106247961521149, 47 | "frac_alive": 0.99658203125, 48 | }, 49 | } 50 | 51 | DEVICE = "cuda:0" 52 | SAVE_DIR = "./test_data" 53 | MODEL_NAME = "EleutherAI/pythia-70m-deduped" 54 | RANDOM_SEED = 42 55 | LAYER = 3 56 | DATASET_NAME = "monology/pile-uncopyrighted" 57 | 58 | EVAL_TOLERANCE = 0.01 59 | 60 | 61 | def test_sae_training(): 62 | """End to end test for training an SAE. Takes ~2 minutes on an RTX 3090. 63 | This isn't a nice suite of unit tests, but it's better than nothing. 64 | I have observed that results can slightly vary with library versions. For full determinism, 65 | use pytorch 2.5.1 and nnsight 0.3.7. 66 | 67 | NOTE: `dictionary_learning` is meant to be used as a submodule. Thus, to run this test, you need to use `dictionary_learning` as a submodule 68 | and run the test from the root of the repository using `pytest -s`. 
Refer to https://github.com/adamkarvonen/dictionary_learning_demo for an example""" 69 | random.seed(RANDOM_SEED) 70 | t.manual_seed(RANDOM_SEED) 71 | 72 | model = LanguageModel(MODEL_NAME, dispatch=True, device_map=DEVICE) 73 | 74 | context_length = 128 75 | llm_batch_size = 512 # Fits on a 24GB GPU 76 | sae_batch_size = 8192 77 | num_contexts_per_sae_batch = sae_batch_size // context_length 78 | 79 | num_inputs_in_buffer = num_contexts_per_sae_batch * 20 80 | 81 | num_tokens = 10_000_000 82 | 83 | # sae training parameters 84 | k = 40 85 | sparsity_penalty = 2.0 86 | expansion_factor = 8 87 | 88 | steps = int(num_tokens / sae_batch_size) # Total number of batches to train 89 | save_steps = None 90 | warmup_steps = 1000 # Warmup period at start of training and after each resample 91 | resample_steps = None 92 | 93 | # standard sae training parameters 94 | learning_rate = 3e-4 95 | 96 | # topk sae training parameters 97 | decay_start = None 98 | auxk_alpha = 1 / 32 99 | 100 | submodule = model.gpt_neox.layers[LAYER] 101 | submodule_name = f"resid_post_layer_{LAYER}" 102 | io = "out" 103 | activation_dim = model.config.hidden_size 104 | 105 | generator = hf_dataset_to_generator(DATASET_NAME) 106 | 107 | activation_buffer = ActivationBuffer( 108 | generator, 109 | model, 110 | submodule, 111 | n_ctxs=num_inputs_in_buffer, 112 | ctx_len=context_length, 113 | refresh_batch_size=llm_batch_size, 114 | out_batch_size=sae_batch_size, 115 | io=io, 116 | d_submodule=activation_dim, 117 | device=DEVICE, 118 | ) 119 | 120 | # create the list of configs 121 | trainer_configs = [] 122 | trainer_configs.extend( 123 | [ 124 | { 125 | "trainer": TopKTrainer, 126 | "dict_class": AutoEncoderTopK, 127 | "lr": None, 128 | "activation_dim": activation_dim, 129 | "dict_size": expansion_factor * activation_dim, 130 | "k": k, 131 | "auxk_alpha": auxk_alpha, # see Appendix A.2 132 | "warmup_steps": 0, 133 | "decay_start": decay_start, # when does the lr decay start 134 | "steps": steps, # when when does training end 135 | "seed": RANDOM_SEED, 136 | "wandb_name": f"TopKTrainer-{MODEL_NAME}-{submodule_name}", 137 | "device": DEVICE, 138 | "layer": LAYER, 139 | "lm_name": MODEL_NAME, 140 | "submodule_name": submodule_name, 141 | }, 142 | ] 143 | ) 144 | trainer_configs.extend( 145 | [ 146 | { 147 | "trainer": StandardTrainer, 148 | "dict_class": AutoEncoder, 149 | "activation_dim": activation_dim, 150 | "dict_size": expansion_factor * activation_dim, 151 | "lr": learning_rate, 152 | "l1_penalty": sparsity_penalty, 153 | "warmup_steps": warmup_steps, 154 | "sparsity_warmup_steps": None, 155 | "decay_start": decay_start, 156 | "steps": steps, 157 | "resample_steps": resample_steps, 158 | "seed": RANDOM_SEED, 159 | "wandb_name": f"StandardTrainer-{MODEL_NAME}-{submodule_name}", 160 | "layer": LAYER, 161 | "lm_name": MODEL_NAME, 162 | "device": DEVICE, 163 | "submodule_name": submodule_name, 164 | }, 165 | ] 166 | ) 167 | 168 | print(f"len trainer configs: {len(trainer_configs)}") 169 | output_dir = f"{SAVE_DIR}/{submodule_name}" 170 | 171 | trainSAE( 172 | data=activation_buffer, 173 | trainer_configs=trainer_configs, 174 | steps=steps, 175 | save_steps=save_steps, 176 | save_dir=output_dir, 177 | ) 178 | 179 | folders = get_nested_folders(output_dir) 180 | 181 | assert len(folders) == 2 182 | 183 | for folder in folders: 184 | dictionary, config = load_dictionary(folder, DEVICE) 185 | 186 | assert dictionary is not None 187 | assert config is not None 188 | 189 | 190 | def test_evaluation(): 191 | 
random.seed(RANDOM_SEED) 192 | t.manual_seed(RANDOM_SEED) 193 | 194 | model = LanguageModel(MODEL_NAME, dispatch=True, device_map=DEVICE) 195 | ae_paths = get_nested_folders(SAVE_DIR) 196 | 197 | context_length = 128 198 | llm_batch_size = 100 199 | sae_batch_size = 4096 200 | n_batches = 10 201 | buffer_size = 256 202 | io = "out" 203 | 204 | generator = hf_dataset_to_generator(DATASET_NAME) 205 | submodule = model.gpt_neox.layers[LAYER] 206 | 207 | input_strings = [] 208 | for i, example in enumerate(generator): 209 | input_strings.append(example) 210 | if i > buffer_size * n_batches: 211 | break 212 | 213 | for ae_path in ae_paths: 214 | dictionary, config = load_dictionary(ae_path, DEVICE) 215 | dictionary = dictionary.to(dtype=model.dtype) 216 | 217 | activation_dim = config["trainer"]["activation_dim"] 218 | context_length = config["buffer"]["ctx_len"] 219 | 220 | activation_buffer_data = iter(input_strings) 221 | 222 | activation_buffer = ActivationBuffer( 223 | activation_buffer_data, 224 | model, 225 | submodule, 226 | n_ctxs=buffer_size, 227 | ctx_len=context_length, 228 | refresh_batch_size=llm_batch_size, 229 | out_batch_size=sae_batch_size, 230 | io=io, 231 | d_submodule=activation_dim, 232 | device=DEVICE, 233 | ) 234 | 235 | eval_results = evaluate( 236 | dictionary, 237 | activation_buffer, 238 | context_length, 239 | llm_batch_size, 240 | io=io, 241 | device=DEVICE, 242 | n_batches=n_batches, 243 | ) 244 | 245 | print(eval_results) 246 | 247 | dict_class = config["trainer"]["dict_class"] 248 | expected_results = EXPECTED_RESULTS[dict_class] 249 | 250 | max_diff = 0 251 | max_diff_percent = 0 252 | for key, value in expected_results.items(): 253 | diff = abs(eval_results[key] - value) 254 | max_diff = max(max_diff, diff) 255 | max_diff_percent = max(max_diff_percent, diff / value) 256 | 257 | print(f"Max diff: {max_diff}, max diff %: {max_diff_percent}") 258 | assert max_diff < EVAL_TOLERANCE 259 | -------------------------------------------------------------------------------- /dictionary_learning/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .standard import StandardTrainer 2 | from .gdm import GatedSAETrainer 3 | from .p_anneal import PAnnealTrainer 4 | from .gated_anneal import GatedAnnealTrainer 5 | from .top_k import TopKTrainer 6 | from .jumprelu import JumpReluTrainer 7 | from .batch_top_k import BatchTopKTrainer, BatchTopKSAE 8 | from .matroyshka_batch_top_k import MatroyshkaBatchTopKTrainer, MatroyshkaBatchTopKSAE -------------------------------------------------------------------------------- /dictionary_learning/trainers/batch_top_k.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import einops 5 | from collections import namedtuple 6 | from typing import Optional 7 | 8 | from ..dictionary import Dictionary 9 | from ..trainers.trainer import ( 10 | SAETrainer, 11 | get_lr_schedule, 12 | set_decoder_norm_to_unit_norm, 13 | remove_gradient_parallel_to_decoder_directions, 14 | ) 15 | 16 | 17 | class BatchTopKSAE(Dictionary, nn.Module): 18 | def __init__(self, activation_dim: int, dict_size: int, k: int): 19 | super().__init__() 20 | self.activation_dim = activation_dim 21 | self.dict_size = dict_size 22 | 23 | assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer" 24 | self.register_buffer("k", t.tensor(k, dtype=t.int)) 25 | self.register_buffer("threshold", 
t.tensor(-1.0, dtype=t.float32)) 26 | 27 | self.decoder = nn.Linear(dict_size, activation_dim, bias=False) 28 | self.decoder.weight.data = set_decoder_norm_to_unit_norm( 29 | self.decoder.weight, activation_dim, dict_size 30 | ) 31 | 32 | self.encoder = nn.Linear(activation_dim, dict_size) 33 | self.encoder.weight.data = self.decoder.weight.T.clone() 34 | self.encoder.bias.data.zero_() 35 | self.b_dec = nn.Parameter(t.zeros(activation_dim)) 36 | 37 | def encode(self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True): 38 | post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec)) 39 | 40 | if use_threshold: 41 | encoded_acts_BF = post_relu_feat_acts_BF * (post_relu_feat_acts_BF > self.threshold) 42 | else: 43 | # Flatten and perform batch top-k 44 | flattened_acts = post_relu_feat_acts_BF.flatten() 45 | post_topk = flattened_acts.topk(self.k * x.size(0), sorted=False, dim=-1) 46 | 47 | encoded_acts_BF = ( 48 | t.zeros_like(post_relu_feat_acts_BF.flatten()) 49 | .scatter_(-1, post_topk.indices, post_topk.values) 50 | .reshape(post_relu_feat_acts_BF.shape) 51 | ) 52 | 53 | if return_active: 54 | return encoded_acts_BF, encoded_acts_BF.sum(0) > 0, post_relu_feat_acts_BF 55 | else: 56 | return encoded_acts_BF 57 | 58 | def decode(self, x: t.Tensor) -> t.Tensor: 59 | return self.decoder(x) + self.b_dec 60 | 61 | def forward(self, x: t.Tensor, output_features: bool = False): 62 | encoded_acts_BF = self.encode(x) 63 | x_hat_BD = self.decode(encoded_acts_BF) 64 | 65 | if not output_features: 66 | return x_hat_BD 67 | else: 68 | return x_hat_BD, encoded_acts_BF 69 | 70 | def scale_biases(self, scale: float): 71 | self.encoder.bias.data *= scale 72 | self.b_dec.data *= scale 73 | if self.threshold >= 0: 74 | self.threshold *= scale 75 | 76 | @classmethod 77 | def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "BatchTopKSAE": 78 | state_dict = t.load(path) 79 | dict_size, activation_dim = state_dict["encoder.weight"].shape 80 | if k is None: 81 | k = state_dict["k"].item() 82 | elif "k" in state_dict and k != state_dict["k"].item(): 83 | raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']") 84 | 85 | autoencoder = cls(activation_dim, dict_size, k) 86 | autoencoder.load_state_dict(state_dict) 87 | if device is not None: 88 | autoencoder.to(device) 89 | return autoencoder 90 | 91 | 92 | class BatchTopKTrainer(SAETrainer): 93 | def __init__( 94 | self, 95 | steps: int, # total number of steps to train for 96 | activation_dim: int, 97 | dict_size: int, 98 | k: int, 99 | layer: int, 100 | lm_name: str, 101 | dict_class: type = BatchTopKSAE, 102 | lr: Optional[float] = None, 103 | auxk_alpha: float = 1 / 32, 104 | warmup_steps: int = 1000, 105 | decay_start: Optional[int] = None, # when does the lr decay start 106 | threshold_beta: float = 0.999, 107 | threshold_start_step: int = 1000, 108 | seed: Optional[int] = None, 109 | device: Optional[str] = None, 110 | wandb_name: str = "BatchTopKSAE", 111 | submodule_name: Optional[str] = None, 112 | ): 113 | super().__init__(seed) 114 | assert layer is not None and lm_name is not None 115 | self.layer = layer 116 | self.lm_name = lm_name 117 | self.submodule_name = submodule_name 118 | self.wandb_name = wandb_name 119 | self.steps = steps 120 | self.decay_start = decay_start 121 | self.warmup_steps = warmup_steps 122 | self.k = k 123 | self.threshold_beta = threshold_beta 124 | self.threshold_start_step = threshold_start_step 125 | 126 | if seed is not None: 127 | t.manual_seed(seed) 128 | 
t.cuda.manual_seed_all(seed) 129 | 130 | self.ae = dict_class(activation_dim, dict_size, k) 131 | 132 | if device is None: 133 | self.device = "cuda" if t.cuda.is_available() else "cpu" 134 | else: 135 | self.device = device 136 | self.ae.to(self.device) 137 | 138 | if lr is not None: 139 | self.lr = lr 140 | else: 141 | # Auto-select LR using 1 / sqrt(d) scaling law from Figure 3 of the paper 142 | scale = dict_size / (2**14) 143 | self.lr = 2e-4 / scale**0.5 144 | 145 | self.auxk_alpha = auxk_alpha 146 | self.dead_feature_threshold = 10_000_000 147 | self.top_k_aux = activation_dim // 2 # Heuristic from B.1 of the paper 148 | self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device) 149 | self.logging_parameters = ["effective_l0", "dead_features", "pre_norm_auxk_loss"] 150 | self.effective_l0 = -1 151 | self.dead_features = -1 152 | self.pre_norm_auxk_loss = -1 153 | 154 | self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) 155 | 156 | lr_fn = get_lr_schedule(steps, warmup_steps, decay_start=decay_start) 157 | 158 | self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) 159 | 160 | def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor): 161 | dead_features = self.num_tokens_since_fired >= self.dead_feature_threshold 162 | self.dead_features = int(dead_features.sum()) 163 | 164 | if dead_features.sum() > 0: 165 | k_aux = min(self.top_k_aux, dead_features.sum()) 166 | 167 | auxk_latents = t.where(dead_features[None], post_relu_acts_BF, -t.inf) 168 | 169 | # Top-k dead latents 170 | auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False) 171 | 172 | auxk_buffer_BF = t.zeros_like(post_relu_acts_BF) 173 | auxk_acts_BF = auxk_buffer_BF.scatter_(dim=-1, index=auxk_indices, src=auxk_acts) 174 | 175 | # Note: decoder(), not decode(), as we don't want to apply the bias 176 | x_reconstruct_aux = self.ae.decoder(auxk_acts_BF) 177 | l2_loss_aux = ( 178 | (residual_BD.float() - x_reconstruct_aux.float()).pow(2).sum(dim=-1).mean() 179 | ) 180 | 181 | self.pre_norm_auxk_loss = l2_loss_aux 182 | 183 | # normalization from OpenAI implementation: https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/kernels.py#L614 184 | residual_mu = residual_BD.mean(dim=0)[None, :].broadcast_to(residual_BD.shape) 185 | loss_denom = (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean() 186 | normalized_auxk_loss = l2_loss_aux / loss_denom 187 | 188 | return normalized_auxk_loss.nan_to_num(0.0) 189 | else: 190 | self.pre_norm_auxk_loss = -1 191 | return t.tensor(0, dtype=residual_BD.dtype, device=residual_BD.device) 192 | 193 | def update_threshold(self, f: t.Tensor): 194 | device_type = "cuda" if f.is_cuda else "cpu" 195 | with t.autocast(device_type=device_type, enabled=False), t.no_grad(): 196 | active = f[f > 0] 197 | 198 | if active.size(0) == 0: 199 | min_activation = 0.0 200 | else: 201 | min_activation = active.min().detach().to(dtype=t.float32) 202 | 203 | if self.ae.threshold < 0: 204 | self.ae.threshold = min_activation 205 | else: 206 | self.ae.threshold = (self.threshold_beta * self.ae.threshold) + ( 207 | (1 - self.threshold_beta) * min_activation 208 | ) 209 | 210 | def loss(self, x, step=None, logging=False): 211 | f, active_indices_F, post_relu_acts_BF = self.ae.encode( 212 | x, return_active=True, use_threshold=False 213 | ) 214 | # l0 = (f != 0).float().sum(dim=-1).mean().item() 215 | 216 | if step > self.threshold_start_step: 217 | self.update_threshold(f) 
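        # note: the threshold updated above is an exponential moving average of the smallest active activation;
        # it is not used for this training step (encode was called with use_threshold=False),
        # but enables thresholded encoding at inference time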
218 | 219 | x_hat = self.ae.decode(f) 220 | 221 | e = x - x_hat 222 | 223 | self.effective_l0 = self.k 224 | 225 | num_tokens_in_step = x.size(0) 226 | did_fire = t.zeros_like(self.num_tokens_since_fired, dtype=t.bool) 227 | did_fire[active_indices_F] = True 228 | self.num_tokens_since_fired += num_tokens_in_step 229 | self.num_tokens_since_fired[did_fire] = 0 230 | 231 | l2_loss = e.pow(2).sum(dim=-1).mean() 232 | auxk_loss = self.get_auxiliary_loss(e.detach(), post_relu_acts_BF) 233 | loss = l2_loss + self.auxk_alpha * auxk_loss 234 | 235 | if not logging: 236 | return loss 237 | else: 238 | return namedtuple("LossLog", ["x", "x_hat", "f", "losses"])( 239 | x, 240 | x_hat, 241 | f, 242 | {"l2_loss": l2_loss.item(), "auxk_loss": auxk_loss.item(), "loss": loss.item()}, 243 | ) 244 | 245 | def update(self, step, x): 246 | if step == 0: 247 | median = self.geometric_median(x) 248 | median = median.to(self.ae.b_dec.dtype) 249 | self.ae.b_dec.data = median 250 | 251 | x = x.to(self.device) 252 | loss = self.loss(x, step=step) 253 | loss.backward() 254 | 255 | self.ae.decoder.weight.grad = remove_gradient_parallel_to_decoder_directions( 256 | self.ae.decoder.weight, 257 | self.ae.decoder.weight.grad, 258 | self.ae.activation_dim, 259 | self.ae.dict_size, 260 | ) 261 | t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) 262 | 263 | self.optimizer.step() 264 | self.optimizer.zero_grad() 265 | self.scheduler.step() 266 | 267 | # Make sure the decoder is still unit-norm 268 | self.ae.decoder.weight.data = set_decoder_norm_to_unit_norm( 269 | self.ae.decoder.weight, self.ae.activation_dim, self.ae.dict_size 270 | ) 271 | 272 | return loss.item() 273 | 274 | @property 275 | def config(self): 276 | return { 277 | "trainer_class": "BatchTopKTrainer", 278 | "dict_class": "BatchTopKSAE", 279 | "lr": self.lr, 280 | "steps": self.steps, 281 | "auxk_alpha": self.auxk_alpha, 282 | "warmup_steps": self.warmup_steps, 283 | "decay_start": self.decay_start, 284 | "threshold_beta": self.threshold_beta, 285 | "threshold_start_step": self.threshold_start_step, 286 | "top_k_aux": self.top_k_aux, 287 | "seed": self.seed, 288 | "activation_dim": self.ae.activation_dim, 289 | "dict_size": self.ae.dict_size, 290 | "k": self.ae.k.item(), 291 | "device": self.device, 292 | "layer": self.layer, 293 | "lm_name": self.lm_name, 294 | "wandb_name": self.wandb_name, 295 | "submodule_name": self.submodule_name, 296 | } 297 | 298 | @staticmethod 299 | def geometric_median(points: t.Tensor, max_iter: int = 100, tol: float = 1e-5): 300 | guess = points.mean(dim=0) 301 | prev = t.zeros_like(guess) 302 | weights = t.ones(len(points), device=points.device) 303 | 304 | for _ in range(max_iter): 305 | prev = guess 306 | weights = 1 / t.norm(points - guess, dim=1) 307 | weights /= weights.sum() 308 | guess = (weights.unsqueeze(1) * points).sum(dim=0) 309 | if t.norm(guess - prev) < tol: 310 | break 311 | 312 | return guess 313 | -------------------------------------------------------------------------------- /dictionary_learning/trainers/gated_anneal.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements the training scheme for a gated SAE described in https://arxiv.org/abs/2404.16014 3 | """ 4 | 5 | import torch as t 6 | from typing import Optional 7 | 8 | from ..trainers.trainer import SAETrainer, get_lr_schedule, get_sparsity_warmup_fn, ConstrainedAdam 9 | from ..config import DEBUG 10 | from ..dictionary import GatedAutoEncoder 11 | from collections import namedtuple 12 | 13 | 
class GatedAnnealTrainer(SAETrainer): 14 | """ 15 | Gated SAE training scheme with p-annealing. 16 | """ 17 | def __init__(self, 18 | steps: int, # total number of steps to train for 19 | activation_dim: int, 20 | dict_size: int, 21 | layer: int, 22 | lm_name: str, 23 | dict_class: type = GatedAutoEncoder, 24 | lr: float = 3e-4, 25 | warmup_steps: int = 1000, # lr warmup period at start of training and after each resample 26 | sparsity_warmup_steps: Optional[int] = 2000, # sparsity warmup period at start of training 27 | decay_start: Optional[int] = None, # decay learning rate after this many steps 28 | sparsity_function: str = 'Lp^p', # Lp or Lp^p 29 | initial_sparsity_penalty: float = 1e-1, # equal to l1 penalty in standard trainer 30 | anneal_start: int = 15000, # step at which to start annealing p 31 | anneal_end: Optional[int] = None, # step at which to stop annealing, defaults to steps-1 32 | p_start: float = 1, # starting value of p (constant throughout warmup) 33 | p_end: float = 0, # annealing p_start to p_end linearly after warmup_steps, exact endpoint excluded 34 | n_sparsity_updates: int | str = 10, # number of times to update the sparsity penalty, at most steps-anneal_start times 35 | sparsity_queue_length: int = 10, # number of recent sparsity loss terms, onle needed for adaptive_sparsity_penalty 36 | resample_steps: Optional[int] = None, # number of steps after which to resample dead neurons 37 | device: Optional[str] = None, 38 | seed: Optional[int] = 42, 39 | wandb_name: str = 'GatedAnnealTrainer', 40 | ): 41 | super().__init__(seed) 42 | 43 | assert layer is not None and lm_name is not None 44 | self.layer = layer 45 | self.lm_name = lm_name 46 | 47 | if seed is not None: 48 | t.manual_seed(seed) 49 | t.cuda.manual_seed_all(seed) 50 | 51 | # initialize dictionary 52 | # initialize dictionary 53 | self.activation_dim = activation_dim 54 | self.dict_size = dict_size 55 | self.ae = dict_class(activation_dim, dict_size) 56 | 57 | if device is None: 58 | self.device = 'cuda' if t.cuda.is_available() else 'cpu' 59 | else: 60 | self.device = device 61 | self.ae.to(self.device) 62 | 63 | self.lr = lr 64 | self.sparsity_function = sparsity_function 65 | self.anneal_start = anneal_start 66 | self.anneal_end = anneal_end if anneal_end is not None else steps 67 | self.p_start = p_start 68 | self.p_end = p_end 69 | self.p = p_start # p is set in self.loss() 70 | self.next_p = None # set in self.loss() 71 | self.lp_loss = None # set in self.loss() 72 | self.scaled_lp_loss = None # set in self.loss() 73 | if n_sparsity_updates == "continuous": 74 | self.n_sparsity_updates = self.anneal_end - anneal_start +1 75 | else: 76 | self.n_sparsity_updates = n_sparsity_updates 77 | self.sparsity_update_steps = t.linspace(anneal_start, self.anneal_end, self.n_sparsity_updates, dtype=int) 78 | self.p_values = t.linspace(p_start, p_end, self.n_sparsity_updates) 79 | self.p_step_count = 0 80 | self.sparsity_coeff = initial_sparsity_penalty # alpha 81 | self.sparsity_queue_length = sparsity_queue_length 82 | self.sparsity_queue = [] 83 | 84 | self.warmup_steps = warmup_steps 85 | self.sparsity_warmup_steps = sparsity_warmup_steps 86 | self.decay_start = decay_start 87 | self.steps = steps 88 | self.logging_parameters = ['p', 'next_p', 'lp_loss', 'scaled_lp_loss', 'sparsity_coeff'] 89 | self.seed = seed 90 | self.wandb_name = wandb_name 91 | 92 | self.resample_steps = resample_steps 93 | if self.resample_steps is not None: 94 | # how many steps since each neuron was last activated? 
95 | self.steps_since_active = t.zeros(self.dict_size, dtype=int).to(self.device) 96 | else: 97 | self.steps_since_active = None 98 | 99 | self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr, betas=(0.0, 0.999)) 100 | 101 | lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps, sparsity_warmup_steps) 102 | self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) 103 | self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps) 104 | 105 | def resample_neurons(self, deads, activations): 106 | with t.no_grad(): 107 | if deads.sum() == 0: return 108 | print(f"resampling {deads.sum().item()} neurons") 109 | 110 | # compute loss for each activation 111 | losses = (activations - self.ae(activations)).norm(dim=-1) 112 | 113 | # sample input to create encoder/decoder weights from 114 | n_resample = min([deads.sum(), losses.shape[0]]) 115 | indices = t.multinomial(losses, num_samples=n_resample, replacement=False) 116 | sampled_vecs = activations[indices] 117 | 118 | # reset encoder/decoder weights for dead neurons 119 | alive_norm = self.ae.encoder.weight[~deads].norm(dim=-1).mean() 120 | self.ae.encoder.weight[deads][:n_resample] = sampled_vecs * alive_norm * 0.2 121 | self.ae.decoder.weight[:,deads][:,:n_resample] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T 122 | self.ae.encoder.bias[deads][:n_resample] = 0. 123 | 124 | 125 | # reset Adam parameters for dead neurons 126 | state_dict = self.optimizer.state_dict()['state'] 127 | ## encoder weight 128 | state_dict[1]['exp_avg'][deads] = 0. 129 | state_dict[1]['exp_avg_sq'][deads] = 0. 130 | ## encoder bias 131 | state_dict[2]['exp_avg'][deads] = 0. 132 | state_dict[2]['exp_avg_sq'][deads] = 0. 133 | ## decoder weight 134 | state_dict[3]['exp_avg'][:,deads] = 0. 135 | state_dict[3]['exp_avg_sq'][:,deads] = 0. 
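# --- Editor's aside: illustrative sketch, not part of this file ---------------
# How the p-annealing schedule configured above plays out: the constructor spaces
# `n_sparsity_updates` update steps evenly over [anneal_start, anneal_end] and
# pairs each with a p value interpolated from p_start down to p_end, so the
# sparsity penalty gradually morphs from an L1-like term (p = 1) toward an
# L0-like count of active features (p -> 0). The hyperparameters below are made
# up, chosen only to show the shape of the schedule.
import torch as t

anneal_start, anneal_end, n_updates = 15_000, 30_000, 4
update_steps = t.linspace(anneal_start, anneal_end, n_updates, dtype=int)
p_values = t.linspace(1.0, 0.0, n_updates)
print(update_steps.tolist())  # [15000, 20000, 25000, 30000]
print(p_values.tolist())      # [1.0, 0.666..., 0.333..., 0.0]
# -------------------------------------------------------------------------------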
136 | 137 | def lp_norm(self, f, p): 138 | norm_sq = f.pow(p).sum(dim=-1) 139 | if self.sparsity_function == 'Lp^p': 140 | return norm_sq.mean() 141 | elif self.sparsity_function == 'Lp': 142 | return norm_sq.pow(1/p).mean() 143 | else: 144 | raise ValueError("Sparsity function must be 'Lp' or 'Lp^p'") 145 | 146 | def loss(self, x:t.Tensor, step:int, logging=False, **kwargs): 147 | sparsity_scale = self.sparsity_warmup_fn(step) 148 | f, f_gate = self.ae.encode(x, return_gate=True) 149 | x_hat = self.ae.decode(f) 150 | x_hat_gate = f_gate @ self.ae.decoder.weight.detach().T + self.ae.decoder_bias.detach() 151 | 152 | L_recon = (x - x_hat).pow(2).sum(dim=-1).mean() 153 | L_aux = (x - x_hat_gate).pow(2).sum(dim=-1).mean() 154 | 155 | fs = f_gate # feature activation that we use for sparsity term 156 | lp_loss = self.lp_norm(fs, self.p) 157 | scaled_lp_loss = lp_loss * self.sparsity_coeff * sparsity_scale 158 | self.lp_loss = lp_loss 159 | self.scaled_lp_loss = scaled_lp_loss 160 | 161 | if self.next_p is not None: 162 | lp_loss_next = self.lp_norm(fs, self.next_p) 163 | self.sparsity_queue.append([self.lp_loss.item(), lp_loss_next.item()]) 164 | self.sparsity_queue = self.sparsity_queue[-self.sparsity_queue_length:] 165 | 166 | if step in self.sparsity_update_steps: 167 | # check to make sure we don't update on repeat step: 168 | if step >= self.sparsity_update_steps[self.p_step_count]: 169 | # Adapt sparsity penalty alpha 170 | if self.next_p is not None: 171 | local_sparsity_new = t.tensor([i[0] for i in self.sparsity_queue]).mean() 172 | local_sparsity_old = t.tensor([i[1] for i in self.sparsity_queue]).mean() 173 | self.sparsity_coeff = self.sparsity_coeff * (local_sparsity_new / local_sparsity_old).item() 174 | # Update p 175 | self.p = self.p_values[self.p_step_count].item() 176 | if self.p_step_count < self.n_sparsity_updates-1: 177 | self.next_p = self.p_values[self.p_step_count+1].item() 178 | else: 179 | self.next_p = self.p_end 180 | self.p_step_count += 1 181 | 182 | # Update dead feature count 183 | if self.steps_since_active is not None: 184 | # update steps_since_active 185 | deads = (f == 0).all(dim=0) 186 | self.steps_since_active[deads] += 1 187 | self.steps_since_active[~deads] = 0 188 | 189 | loss = L_recon + scaled_lp_loss + L_aux 190 | 191 | if not logging: 192 | return loss 193 | else: 194 | return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])( 195 | x, x_hat, f, 196 | { 197 | 'mse_loss' : L_recon.item(), 198 | 'aux_loss' : L_aux.item(), 199 | 'loss' : loss.item(), 200 | 'p' : self.p, 201 | 'next_p' : self.next_p, 202 | 'lp_loss' : lp_loss.item(), 203 | 'sparsity_loss' : scaled_lp_loss.item(), 204 | 'sparsity_coeff' : self.sparsity_coeff, 205 | } 206 | ) 207 | 208 | def update(self, step, activations): 209 | activations = activations.to(self.device) 210 | 211 | self.optimizer.zero_grad() 212 | loss = self.loss(activations, step, logging=False) 213 | loss.backward() 214 | self.optimizer.step() 215 | self.scheduler.step() 216 | 217 | if self.resample_steps is not None and step % self.resample_steps == self.resample_steps - 1: 218 | self.resample_neurons(self.steps_since_active > self.resample_steps / 2, activations) 219 | 220 | # @property 221 | # def config(self): 222 | # return { 223 | # 'trainer_class' : 'GatedSAETrainer', 224 | # 'activation_dim' : self.ae.activation_dim, 225 | # 'dict_size' : self.ae.dict_size, 226 | # 'lr' : self.lr, 227 | # 'l1_penalty' : self.l1_penalty, 228 | # 'warmup_steps' : self.warmup_steps, 229 | # 'device' : self.device, 230 | # 
'wandb_name': self.wandb_name, 231 | # } 232 | 233 | @property 234 | def config(self): 235 | return { 236 | 'trainer_class' : "GatedAnnealTrainer", 237 | 'dict_class' : "GatedAutoEncoder", 238 | 'activation_dim' : self.activation_dim, 239 | 'dict_size' : self.dict_size, 240 | 'lr' : self.lr, 241 | 'sparsity_function' : self.sparsity_function, 242 | 'sparsity_penalty' : self.sparsity_coeff, 243 | 'p_start' : self.p_start, 244 | 'p_end' : self.p_end, 245 | 'anneal_start' : self.anneal_start, 246 | 'sparsity_queue_length' : self.sparsity_queue_length, 247 | 'n_sparsity_updates' : self.n_sparsity_updates, 248 | 'warmup_steps' : self.warmup_steps, 249 | 'resample_steps' : self.resample_steps, 250 | 'sparsity_warmup_steps' : self.sparsity_warmup_steps, 251 | 'decay_start' : self.decay_start, 252 | 'steps' : self.steps, 253 | 'seed' : self.seed, 254 | 'layer' : self.layer, 255 | 'lm_name' : self.lm_name, 256 | 'wandb_name' : self.wandb_name, 257 | } 258 | -------------------------------------------------------------------------------- /dictionary_learning/trainers/gdm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements the training scheme for a gated SAE described in https://arxiv.org/abs/2404.16014 3 | """ 4 | 5 | import torch as t 6 | from typing import Optional 7 | 8 | from ..trainers.trainer import SAETrainer, get_lr_schedule, get_sparsity_warmup_fn, ConstrainedAdam 9 | from ..config import DEBUG 10 | from ..dictionary import GatedAutoEncoder 11 | from collections import namedtuple 12 | 13 | class GatedSAETrainer(SAETrainer): 14 | """ 15 | Gated SAE training scheme. 16 | """ 17 | def __init__(self, 18 | steps: int, # total number of steps to train for 19 | activation_dim: int, 20 | dict_size: int, 21 | layer: int, 22 | lm_name: str, 23 | dict_class = GatedAutoEncoder, 24 | lr: float = 5e-5, 25 | l1_penalty: float = 1e-1, 26 | warmup_steps: int = 1000, # lr warmup period at start of training and after each resample 27 | sparsity_warmup_steps: Optional[int] = 2000, 28 | decay_start:Optional[int]=None, # decay learning rate after this many steps 29 | seed: Optional[int] = None, 30 | device: Optional[str] = None, 31 | wandb_name: Optional[str] = 'GatedSAETrainer', 32 | submodule_name: Optional[str] = None, 33 | ): 34 | super().__init__(seed) 35 | 36 | assert layer is not None and lm_name is not None 37 | self.layer = layer 38 | self.lm_name = lm_name 39 | self.submodule_name = submodule_name 40 | 41 | if seed is not None: 42 | t.manual_seed(seed) 43 | t.cuda.manual_seed_all(seed) 44 | 45 | # initialize dictionary 46 | self.ae = dict_class(activation_dim, dict_size) 47 | 48 | self.lr = lr 49 | self.l1_penalty=l1_penalty 50 | self.warmup_steps = warmup_steps 51 | self.sparsity_warmup_steps = sparsity_warmup_steps 52 | self.decay_start = decay_start 53 | self.wandb_name = wandb_name 54 | 55 | if device is None: 56 | self.device = 'cuda' if t.cuda.is_available() else 'cpu' 57 | else: 58 | self.device = device 59 | self.ae.to(self.device) 60 | 61 | self.optimizer = ConstrainedAdam( 62 | self.ae.parameters(), 63 | self.ae.decoder.parameters(), 64 | lr=lr, 65 | betas=(0.0, 0.999), 66 | ) 67 | 68 | lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps=None, sparsity_warmup_steps=sparsity_warmup_steps) 69 | 70 | self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_fn) 71 | self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps) 72 | 73 | 74 | def loss(self, x:t.Tensor, step:int, 
logging:bool=False, **kwargs): 75 | 76 | sparsity_scale = self.sparsity_warmup_fn(step) 77 | 78 | f, f_gate = self.ae.encode(x, return_gate=True) 79 | x_hat = self.ae.decode(f) 80 | x_hat_gate = f_gate @ self.ae.decoder.weight.detach().T + self.ae.decoder_bias.detach() 81 | 82 | L_recon = (x - x_hat).pow(2).sum(dim=-1).mean() 83 | L_sparse = t.linalg.norm(f_gate, ord=1, dim=-1).mean() 84 | L_aux = (x - x_hat_gate).pow(2).sum(dim=-1).mean() 85 | 86 | loss = L_recon + (self.l1_penalty * L_sparse * sparsity_scale) + L_aux 87 | 88 | if not logging: 89 | return loss 90 | else: 91 | return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])( 92 | x, x_hat, f, 93 | { 94 | 'mse_loss' : L_recon.item(), 95 | 'sparsity_loss' : L_sparse.item(), 96 | 'aux_loss' : L_aux.item(), 97 | 'loss' : loss.item() 98 | } 99 | ) 100 | 101 | def update(self, step, x): 102 | x = x.to(self.device) 103 | self.optimizer.zero_grad() 104 | loss = self.loss(x, step) 105 | loss.backward() 106 | self.optimizer.step() 107 | self.scheduler.step() 108 | 109 | @property 110 | def config(self): 111 | return { 112 | 'dict_class': 'GatedAutoEncoder', 113 | 'trainer_class' : 'GatedSAETrainer', 114 | 'activation_dim' : self.ae.activation_dim, 115 | 'dict_size' : self.ae.dict_size, 116 | 'lr' : self.lr, 117 | 'l1_penalty' : self.l1_penalty, 118 | 'warmup_steps' : self.warmup_steps, 119 | 'sparsity_warmup_steps' : self.sparsity_warmup_steps, 120 | 'decay_start' : self.decay_start, 121 | 'seed' : self.seed, 122 | 'device' : self.device, 123 | 'layer' : self.layer, 124 | 'lm_name' : self.lm_name, 125 | 'wandb_name': self.wandb_name, 126 | 'submodule_name': self.submodule_name, 127 | } 128 | -------------------------------------------------------------------------------- /dictionary_learning/trainers/jumprelu.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import torch 4 | import torch.autograd as autograd 5 | from torch import nn 6 | from typing import Optional 7 | 8 | from ..dictionary import Dictionary, JumpReluAutoEncoder 9 | from ..trainers.trainer import ( 10 | SAETrainer, 11 | get_lr_schedule, 12 | get_sparsity_warmup_fn, 13 | set_decoder_norm_to_unit_norm, 14 | remove_gradient_parallel_to_decoder_directions, 15 | ) 16 | 17 | 18 | class RectangleFunction(autograd.Function): 19 | @staticmethod 20 | def forward(ctx, x): 21 | ctx.save_for_backward(x) 22 | return ((x > -0.5) & (x < 0.5)).float() 23 | 24 | @staticmethod 25 | def backward(ctx, grad_output): 26 | (x,) = ctx.saved_tensors 27 | grad_input = grad_output.clone() 28 | grad_input[(x <= -0.5) | (x >= 0.5)] = 0 29 | return grad_input 30 | 31 | 32 | class JumpReLUFunction(autograd.Function): 33 | @staticmethod 34 | def forward(ctx, x, threshold, bandwidth): 35 | ctx.save_for_backward(x, threshold, torch.tensor(bandwidth)) 36 | return x * (x > threshold).float() 37 | 38 | @staticmethod 39 | def backward(ctx, grad_output): 40 | x, threshold, bandwidth_tensor = ctx.saved_tensors 41 | bandwidth = bandwidth_tensor.item() 42 | x_grad = (x > threshold).float() * grad_output 43 | threshold_grad = ( 44 | -(threshold / bandwidth) 45 | * RectangleFunction.apply((x - threshold) / bandwidth) 46 | * grad_output 47 | ) 48 | return x_grad, threshold_grad, None # None for bandwidth 49 | 50 | 51 | class StepFunction(autograd.Function): 52 | @staticmethod 53 | def forward(ctx, x, threshold, bandwidth): 54 | ctx.save_for_backward(x, threshold, torch.tensor(bandwidth)) 55 | return (x > threshold).float() 56 | 57 | 
@staticmethod 58 | def backward(ctx, grad_output): 59 | x, threshold, bandwidth_tensor = ctx.saved_tensors 60 | bandwidth = bandwidth_tensor.item() 61 | x_grad = torch.zeros_like(x) 62 | threshold_grad = ( 63 | -(1.0 / bandwidth) * RectangleFunction.apply((x - threshold) / bandwidth) * grad_output 64 | ) 65 | return x_grad, threshold_grad, None # None for bandwidth 66 | 67 | 68 | class JumpReluTrainer(nn.Module, SAETrainer): 69 | """ 70 | Trains a JumpReLU autoencoder. 71 | 72 | Note: does not use learning rate or sparsity scheduling as in the paper. 73 | """ 74 | 75 | def __init__( 76 | self, 77 | steps: int, # total number of steps to train for 78 | activation_dim: int, 79 | dict_size: int, 80 | layer: int, 81 | lm_name: str, 82 | dict_class=JumpReluAutoEncoder, 83 | seed: Optional[int] = None, 84 | # TODO: What's the default lr used in the paper? 85 | lr: float = 7e-5, 86 | bandwidth: float = 0.001, 87 | sparsity_penalty: float = 1.0, 88 | warmup_steps: int = 1000, # lr warmup period at start of training and after each resample 89 | sparsity_warmup_steps: Optional[int] = 2000, # sparsity warmup period at start of training 90 | decay_start: Optional[int] = None, # decay learning rate after this many steps 91 | target_l0: float = 20.0, 92 | device: str = "cpu", 93 | wandb_name: str = "JumpRelu", 94 | submodule_name: Optional[str] = None, 95 | ): 96 | super().__init__() 97 | 98 | # TODO: Should just be args, and this should be commonised 99 | assert layer is not None, "Layer must be specified" 100 | assert lm_name is not None, "Language model name must be specified" 101 | self.lm_name = lm_name 102 | self.layer = layer 103 | self.submodule_name = submodule_name 104 | self.device = device 105 | self.steps = steps 106 | self.lr = lr 107 | self.seed = seed 108 | 109 | self.bandwidth = bandwidth 110 | self.sparsity_coefficient = sparsity_penalty 111 | self.warmup_steps = warmup_steps 112 | self.sparsity_warmup_steps = sparsity_warmup_steps 113 | self.decay_start = decay_start 114 | self.target_l0 = target_l0 115 | 116 | # TODO: Better auto-naming (e.g. in BatchTopK package) 117 | self.wandb_name = wandb_name 118 | 119 | # TODO: Why not just pass in the initialised autoencoder instead?
120 | self.ae = dict_class( 121 | activation_dim=activation_dim, 122 | dict_size=dict_size, 123 | device=device, 124 | ).to(self.device) 125 | 126 | # Parameters from the paper 127 | self.optimizer = torch.optim.Adam(self.ae.parameters(), lr=lr, betas=(0.0, 0.999), eps=1e-8) 128 | 129 | lr_fn = get_lr_schedule( 130 | steps, 131 | warmup_steps, 132 | decay_start, 133 | resample_steps=None, 134 | sparsity_warmup_steps=sparsity_warmup_steps, 135 | ) 136 | 137 | self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) 138 | 139 | self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps) 140 | 141 | # Purely for logging purposes 142 | self.dead_feature_threshold = 10_000_000 143 | self.num_tokens_since_fired = torch.zeros(dict_size, dtype=torch.long, device=device) 144 | self.dead_features = -1 145 | self.logging_parameters = ["dead_features"] 146 | 147 | def loss(self, x: torch.Tensor, step: int, logging=False, **_): 148 | # Note: We are using threshold, not log_threshold as in this notebook: 149 | # https://colab.research.google.com/drive/1PlFzI_PWGTN9yCQLuBcSuPJUjgHL7GiD#scrollTo=yP828a6uIlSO 150 | # I had poor results when using log_threshold and it would complicate the scale_biases() function 151 | 152 | sparsity_scale = self.sparsity_warmup_fn(step) 153 | x = x.to(self.ae.W_enc.dtype) 154 | 155 | pre_jump = x @ self.ae.W_enc + self.ae.b_enc 156 | f = JumpReLUFunction.apply(pre_jump, self.ae.threshold, self.bandwidth) 157 | 158 | active_indices = f.sum(0) > 0 159 | did_fire = torch.zeros_like(self.num_tokens_since_fired, dtype=torch.bool) 160 | did_fire[active_indices] = True 161 | self.num_tokens_since_fired += x.size(0) 162 | self.num_tokens_since_fired[active_indices] = 0 163 | self.dead_features = ( 164 | (self.num_tokens_since_fired > self.dead_feature_threshold).sum().item() 165 | ) 166 | 167 | recon = self.ae.decode(f) 168 | 169 | recon_loss = (x - recon).pow(2).sum(dim=-1).mean() 170 | l0 = StepFunction.apply(f, self.ae.threshold, self.bandwidth).sum(dim=-1).mean() 171 | 172 | sparsity_loss = ( 173 | self.sparsity_coefficient * ((l0 / self.target_l0) - 1).pow(2) * sparsity_scale 174 | ) 175 | loss = recon_loss + sparsity_loss 176 | 177 | if not logging: 178 | return loss 179 | else: 180 | return namedtuple("LossLog", ["x", "recon", "f", "losses"])( 181 | x, 182 | recon, 183 | f, 184 | { 185 | "l2_loss": recon_loss.item(), 186 | "loss": loss.item(), 187 | }, 188 | ) 189 | 190 | def update(self, step, x): 191 | x = x.to(self.device) 192 | loss = self.loss(x, step=step) 193 | loss.backward() 194 | 195 | # We must transpose because we are using nn.Parameter, not nn.Linear 196 | self.ae.W_dec.grad = remove_gradient_parallel_to_decoder_directions( 197 | self.ae.W_dec.T, self.ae.W_dec.grad.T, self.ae.activation_dim, self.ae.dict_size 198 | ).T 199 | torch.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) 200 | 201 | self.optimizer.step() 202 | self.scheduler.step() 203 | self.optimizer.zero_grad() 204 | 205 | # We must transpose because we are using nn.Parameter, not nn.Linear 206 | self.ae.W_dec.data = set_decoder_norm_to_unit_norm( 207 | self.ae.W_dec.T, self.ae.activation_dim, self.ae.dict_size 208 | ).T 209 | 210 | return loss.item() 211 | 212 | @property 213 | def config(self): 214 | return { 215 | "trainer_class": "JumpReluTrainer", 216 | "dict_class": "JumpReluAutoEncoder", 217 | "lr": self.lr, 218 | "steps": self.steps, 219 | "seed": self.seed, 220 | "activation_dim": self.ae.activation_dim, 221 | "dict_size": self.ae.dict_size, 
222 | "device": self.device, 223 | "layer": self.layer, 224 | "lm_name": self.lm_name, 225 | "wandb_name": self.wandb_name, 226 | "submodule_name": self.submodule_name, 227 | "bandwidth": self.bandwidth, 228 | "sparsity_penalty": self.sparsity_coefficient, 229 | "sparsity_warmup_steps": self.sparsity_warmup_steps, 230 | "target_l0": self.target_l0, 231 | } 232 | -------------------------------------------------------------------------------- /dictionary_learning/trainers/p_anneal.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | from typing import Optional 3 | """ 4 | Implements the standard SAE training scheme. 5 | """ 6 | 7 | from ..dictionary import AutoEncoder 8 | from ..trainers.trainer import SAETrainer, get_lr_schedule, get_sparsity_warmup_fn, ConstrainedAdam 9 | from ..config import DEBUG 10 | 11 | class PAnnealTrainer(SAETrainer): 12 | """ 13 | SAE training scheme with the option to anneal the sparsity parameter p. 14 | You can further choose to use Lp or Lp^p sparsity. 15 | """ 16 | def __init__(self, 17 | steps: int, # total number of steps to train for 18 | activation_dim: int, 19 | dict_size: int, 20 | layer: int, 21 | lm_name: str, 22 | dict_class: type = AutoEncoder, 23 | lr: float = 1e-3, 24 | warmup_steps: int = 1000, # lr warmup period at start of training and after each resample 25 | decay_start: Optional[int] = None, # step at which to start decaying lr 26 | sparsity_warmup_steps: Optional[int] = 2000, # number of steps to warm up sparsity penalty 27 | sparsity_function: str = 'Lp', # Lp or Lp^p 28 | initial_sparsity_penalty: float = 1e-1, # equal to l1 penalty in standard trainer 29 | anneal_start: int = 15000, # step at which to start annealing p 30 | anneal_end: Optional[int] = None, # step at which to stop annealing, defaults to steps-1 31 | p_start: float = 1, # starting value of p (constant throughout warmup) 32 | p_end: float = 0, # annealing p_start to p_end linearly after warmup_steps, exact endpoint excluded 33 | n_sparsity_updates: int | str = 10, # number of times to update the sparsity penalty, at most steps-anneal_start times 34 | sparsity_queue_length: int = 10, # number of recent sparsity loss terms, onle needed for adaptive_sparsity_penalty 35 | resample_steps: Optional[int] = None, # number of steps after which to resample dead neurons 36 | device: Optional[str] = None, 37 | seed: int = 42, 38 | wandb_name: str = 'PAnnealTrainer', 39 | submodule_name: Optional[str] = None, 40 | ): 41 | super().__init__(seed) 42 | 43 | assert layer is not None and lm_name is not None 44 | self.layer = layer 45 | self.lm_name = lm_name 46 | self.submodule_name = submodule_name 47 | 48 | if seed is not None: 49 | t.manual_seed(seed) 50 | t.cuda.manual_seed_all(seed) 51 | 52 | if device is None: 53 | self.device = t.device('cuda' if t.cuda.is_available() else 'cpu') 54 | else: 55 | self.device = device 56 | 57 | # initialize dictionary 58 | self.activation_dim = activation_dim 59 | self.dict_size = dict_size 60 | self.ae = dict_class(activation_dim, dict_size) 61 | self.ae.to(self.device) 62 | 63 | self.lr = lr 64 | self.sparsity_function = sparsity_function 65 | self.anneal_start = anneal_start 66 | self.anneal_end = anneal_end if anneal_end is not None else steps 67 | self.p_start = p_start 68 | self.p_end = p_end 69 | self.p = p_start 70 | self.next_p = None 71 | if n_sparsity_updates == "continuous": 72 | self.n_sparsity_updates = self.anneal_end - anneal_start +1 73 | else: 74 | self.n_sparsity_updates = 
n_sparsity_updates 75 | self.sparsity_update_steps = t.linspace(anneal_start, self.anneal_end, self.n_sparsity_updates, dtype=int) 76 | self.p_values = t.linspace(p_start, p_end, self.n_sparsity_updates) 77 | self.p_step_count = 0 78 | self.sparsity_coeff = initial_sparsity_penalty # alpha 79 | self.sparsity_queue_length = sparsity_queue_length 80 | self.sparsity_queue = [] 81 | 82 | self.warmup_steps = warmup_steps 83 | self.sparsity_warmup_steps = sparsity_warmup_steps 84 | self.decay_start = decay_start 85 | self.steps = steps 86 | self.logging_parameters = ['p', 'next_p', 'lp_loss', 'scaled_lp_loss', 'sparsity_coeff'] 87 | self.seed = seed 88 | self.wandb_name = wandb_name 89 | 90 | self.resample_steps = resample_steps 91 | if self.resample_steps is not None: 92 | # how many steps since each neuron was last activated? 93 | self.steps_since_active = t.zeros(self.dict_size, dtype=int).to(self.device) 94 | else: 95 | self.steps_since_active = None 96 | 97 | self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr) 98 | 99 | lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps, sparsity_warmup_steps) 100 | self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) 101 | 102 | self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps) 103 | 104 | if (self.sparsity_update_steps.unique(return_counts=True)[1] >1).any(): 105 | print("Warning! Duplicates in self.sparsity_update_steps detected!") 106 | 107 | def resample_neurons(self, deads, activations): 108 | with t.no_grad(): 109 | if deads.sum() == 0: return 110 | print(f"resampling {deads.sum().item()} neurons") 111 | 112 | # compute loss for each activation 113 | losses = (activations - self.ae(activations)).norm(dim=-1) 114 | 115 | # sample input to create encoder/decoder weights from 116 | n_resample = min([deads.sum(), losses.shape[0]]) 117 | indices = t.multinomial(losses, num_samples=n_resample, replacement=False) 118 | sampled_vecs = activations[indices] 119 | 120 | # reset encoder/decoder weights for dead neurons 121 | alive_norm = self.ae.encoder.weight[~deads].norm(dim=-1).mean() 122 | self.ae.encoder.weight[deads][:n_resample] = sampled_vecs * alive_norm * 0.2 123 | self.ae.decoder.weight[:,deads][:,:n_resample] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T 124 | self.ae.encoder.bias[deads][:n_resample] = 0. 125 | 126 | 127 | # reset Adam parameters for dead neurons 128 | state_dict = self.optimizer.state_dict()['state'] 129 | ## encoder weight 130 | state_dict[1]['exp_avg'][deads] = 0. 131 | state_dict[1]['exp_avg_sq'][deads] = 0. 132 | ## encoder bias 133 | state_dict[2]['exp_avg'][deads] = 0. 134 | state_dict[2]['exp_avg_sq'][deads] = 0. 135 | ## decoder weight 136 | state_dict[3]['exp_avg'][:,deads] = 0. 137 | state_dict[3]['exp_avg_sq'][:,deads] = 0.
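# --- Editor's aside: illustrative sketch, not part of this file ---------------
# The "adaptive_sparsity_penalty" mentioned in the constructor arguments works by
# keeping a short queue of recent sparsity losses evaluated under both the current
# p and the upcoming next_p (see loss() below). When p is switched, the coefficient
# alpha is rescaled so that alpha_new * E[L_next_p] is roughly alpha_old * E[L_p],
# i.e. the overall penalty magnitude stays continuous across the switch. The
# numbers below are made up, purely to show the rescaling.
import torch as t

queue = [[0.80, 1.60], [0.75, 1.50], [0.85, 1.70]]  # [lp_loss at p, lp_loss at next_p]
loss_at_p = t.tensor([q[0] for q in queue]).mean()       # ~0.80
loss_at_next_p = t.tensor([q[1] for q in queue]).mean()  # ~1.60
alpha = 0.1
alpha = alpha * (loss_at_p / loss_at_next_p).item()      # ~0.1 * 0.80 / 1.60 = 0.05
# the smaller new alpha applied to the larger next_p loss keeps the penalty ~constant
# -------------------------------------------------------------------------------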
138 | 139 | def lp_norm(self, f, p): 140 | norm_sq = f.pow(p).sum(dim=-1) 141 | if self.sparsity_function == 'Lp^p': 142 | return norm_sq.mean() 143 | elif self.sparsity_function == 'Lp': 144 | return norm_sq.pow(1/p).mean() 145 | else: 146 | raise ValueError("Sparsity function must be 'Lp' or 'Lp^p'") 147 | 148 | def loss(self, x: t.Tensor, step:int, logging=False): 149 | sparsity_scale = self.sparsity_warmup_fn(step) 150 | 151 | # Compute loss terms 152 | x_hat, f = self.ae(x, output_features=True) 153 | recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean() 154 | lp_loss = self.lp_norm(f, self.p) 155 | scaled_lp_loss = lp_loss * self.sparsity_coeff * sparsity_scale 156 | self.lp_loss = lp_loss 157 | self.scaled_lp_loss = scaled_lp_loss 158 | 159 | if self.next_p is not None: 160 | lp_loss_next = self.lp_norm(f, self.next_p) 161 | self.sparsity_queue.append([self.lp_loss.item(), lp_loss_next.item()]) 162 | self.sparsity_queue = self.sparsity_queue[-self.sparsity_queue_length:] 163 | 164 | if step in self.sparsity_update_steps: 165 | # check to make sure we don't update on repeat step: 166 | if step >= self.sparsity_update_steps[self.p_step_count]: 167 | # Adapt sparsity penalty alpha 168 | if self.next_p is not None: 169 | local_sparsity_new = t.tensor([i[0] for i in self.sparsity_queue]).mean() 170 | local_sparsity_old = t.tensor([i[1] for i in self.sparsity_queue]).mean() 171 | self.sparsity_coeff = self.sparsity_coeff * (local_sparsity_new / local_sparsity_old).item() 172 | # Update p 173 | self.p = self.p_values[self.p_step_count].item() 174 | if self.p_step_count < self.n_sparsity_updates-1: 175 | self.next_p = self.p_values[self.p_step_count+1].item() 176 | else: 177 | self.next_p = self.p_end 178 | self.p_step_count += 1 179 | 180 | # Update dead feature count 181 | if self.steps_since_active is not None: 182 | # update steps_since_active 183 | deads = (f == 0).all(dim=0) 184 | self.steps_since_active[deads] += 1 185 | self.steps_since_active[~deads] = 0 186 | 187 | if logging is False: 188 | return recon_loss + scaled_lp_loss 189 | else: 190 | loss_log = { 191 | 'p' : self.p, 192 | 'next_p' : self.next_p, 193 | 'lp_loss' : lp_loss.item(), 194 | 'scaled_lp_loss' : scaled_lp_loss.item(), 195 | 'sparsity_coeff' : self.sparsity_coeff, 196 | } 197 | return x, x_hat, f, loss_log 198 | 199 | 200 | def update(self, step, activations): 201 | activations = activations.to(self.device) 202 | 203 | self.optimizer.zero_grad() 204 | loss = self.loss(activations, step, logging=False) 205 | loss.backward() 206 | self.optimizer.step() 207 | self.scheduler.step() 208 | 209 | if self.resample_steps is not None and step % self.resample_steps == self.resample_steps - 1: 210 | self.resample_neurons(self.steps_since_active > self.resample_steps / 2, activations) 211 | 212 | @property 213 | def config(self): 214 | return { 215 | 'trainer_class' : "PAnnealTrainer", 216 | 'dict_class' : "AutoEncoder", 217 | 'activation_dim' : self.activation_dim, 218 | 'dict_size' : self.dict_size, 219 | 'lr' : self.lr, 220 | 'sparsity_function' : self.sparsity_function, 221 | 'sparsity_penalty' : self.sparsity_coeff, 222 | 'p_start' : self.p_start, 223 | 'p_end' : self.p_end, 224 | 'anneal_start' : self.anneal_start, 225 | 'sparsity_queue_length' : self.sparsity_queue_length, 226 | 'n_sparsity_updates' : self.n_sparsity_updates, 227 | 'warmup_steps' : self.warmup_steps, 228 | 'sparsity_warmup_steps': self.sparsity_warmup_steps, 229 | 'decay_start': self.decay_start, 230 | 'resample_steps' : self.resample_steps, 231 | 'steps' 
: self.steps, 232 | 'seed' : self.seed, 233 | 'layer' : self.layer, 234 | 'lm_name' : self.lm_name, 235 | 'wandb_name' : self.wandb_name, 236 | 'submodule_name' : self.submodule_name, 237 | } 238 | -------------------------------------------------------------------------------- /dictionary_learning/trainers/standard.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements the standard SAE training scheme. 3 | """ 4 | import torch as t 5 | from typing import Optional 6 | 7 | from ..trainers.trainer import SAETrainer, get_lr_schedule, get_sparsity_warmup_fn, ConstrainedAdam 8 | from ..config import DEBUG 9 | from ..dictionary import AutoEncoder 10 | from collections import namedtuple 11 | 12 | class StandardTrainer(SAETrainer): 13 | """ 14 | Standard SAE training scheme following Towards Monosemanticity. Decoder column norms are constrained to 1. 15 | """ 16 | def __init__(self, 17 | steps: int, # total number of steps to train for 18 | activation_dim: int, 19 | dict_size: int, 20 | layer: int, 21 | lm_name: str, 22 | dict_class=AutoEncoder, 23 | lr:float=1e-3, 24 | l1_penalty:float=1e-1, 25 | warmup_steps:int=1000, # lr warmup period at start of training and after each resample 26 | sparsity_warmup_steps:Optional[int]=2000, # sparsity warmup period at start of training 27 | decay_start:Optional[int]=None, # decay learning rate after this many steps 28 | resample_steps:Optional[int]=None, # how often to resample neurons 29 | seed:Optional[int]=None, 30 | device=None, 31 | wandb_name:Optional[str]='StandardTrainer', 32 | submodule_name:Optional[str]=None, 33 | ): 34 | super().__init__(seed) 35 | 36 | assert layer is not None and lm_name is not None 37 | self.layer = layer 38 | self.lm_name = lm_name 39 | self.submodule_name = submodule_name 40 | 41 | if seed is not None: 42 | t.manual_seed(seed) 43 | t.cuda.manual_seed_all(seed) 44 | 45 | # initialize dictionary 46 | self.ae = dict_class(activation_dim, dict_size) 47 | 48 | self.lr = lr 49 | self.l1_penalty=l1_penalty 50 | self.warmup_steps = warmup_steps 51 | self.sparsity_warmup_steps = sparsity_warmup_steps 52 | self.steps = steps 53 | self.decay_start = decay_start 54 | self.wandb_name = wandb_name 55 | 56 | if device is None: 57 | self.device = 'cuda' if t.cuda.is_available() else 'cpu' 58 | else: 59 | self.device = device 60 | self.ae.to(self.device) 61 | 62 | self.resample_steps = resample_steps 63 | if self.resample_steps is not None: 64 | # how many steps since each neuron was last activated? 
65 | self.steps_since_active = t.zeros(self.ae.dict_size, dtype=int).to(self.device) 66 | else: 67 | self.steps_since_active = None 68 | 69 | self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr) 70 | 71 | lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps, sparsity_warmup_steps) 72 | self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) 73 | 74 | self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps) 75 | 76 | def resample_neurons(self, deads, activations): 77 | with t.no_grad(): 78 | if deads.sum() == 0: return 79 | print(f"resampling {deads.sum().item()} neurons") 80 | 81 | # compute loss for each activation 82 | losses = (activations - self.ae(activations)).norm(dim=-1) 83 | 84 | # sample input to create encoder/decoder weights from 85 | n_resample = min([deads.sum(), losses.shape[0]]) 86 | indices = t.multinomial(losses, num_samples=n_resample, replacement=False) 87 | sampled_vecs = activations[indices] 88 | 89 | # get norm of the living neurons 90 | alive_norm = self.ae.encoder.weight[~deads].norm(dim=-1).mean() 91 | 92 | # resample first n_resample dead neurons 93 | deads[deads.nonzero()[n_resample:]] = False 94 | self.ae.encoder.weight[deads] = sampled_vecs * alive_norm * 0.2 95 | self.ae.decoder.weight[:,deads] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T 96 | self.ae.encoder.bias[deads] = 0. 97 | 98 | 99 | # reset Adam parameters for dead neurons 100 | state_dict = self.optimizer.state_dict()['state'] 101 | ## encoder weight 102 | state_dict[1]['exp_avg'][deads] = 0. 103 | state_dict[1]['exp_avg_sq'][deads] = 0. 104 | ## encoder bias 105 | state_dict[2]['exp_avg'][deads] = 0. 106 | state_dict[2]['exp_avg_sq'][deads] = 0. 107 | ## decoder weight 108 | state_dict[3]['exp_avg'][:,deads] = 0. 109 | state_dict[3]['exp_avg_sq'][:,deads] = 0. 
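# --- Editor's aside: illustrative sketch, not part of this file ---------------
# "Decoder column norms are constrained to 1" (see the class docstring above) is
# enforced by the ConstrainedAdam optimizer defined in trainers/trainer.py: before
# each step it projects away the gradient component parallel to every decoder
# column, and after the step it renormalizes the columns. A condensed standalone
# version of those two projections, for an nn.Linear-style decoder weight of shape
# [activation_dim, dict_size] (the 16 and 64 below are arbitrary toy sizes):
import torch

W = torch.randn(16, 64, requires_grad=True)   # toy decoder weight
W.grad = torch.randn_like(W)                  # stand-in for a backprop gradient
with torch.no_grad():
    cols = W / W.norm(dim=0, keepdim=True)                     # unit-norm columns
    W.grad -= (W.grad * cols).sum(dim=0, keepdim=True) * cols  # drop parallel part
    # ... optimizer step would happen here ...
    W /= W.norm(dim=0, keepdim=True)                           # back onto the unit sphere
# -------------------------------------------------------------------------------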
110 | 111 | def loss(self, x, step: int, logging=False, **kwargs): 112 | 113 | sparsity_scale = self.sparsity_warmup_fn(step) 114 | 115 | x_hat, f = self.ae(x, output_features=True) 116 | l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() 117 | recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean() 118 | l1_loss = f.norm(p=1, dim=-1).mean() 119 | 120 | if self.steps_since_active is not None: 121 | # update steps_since_active 122 | deads = (f == 0).all(dim=0) 123 | self.steps_since_active[deads] += 1 124 | self.steps_since_active[~deads] = 0 125 | 126 | loss = recon_loss + self.l1_penalty * sparsity_scale * l1_loss 127 | 128 | if not logging: 129 | return loss 130 | else: 131 | return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])( 132 | x, x_hat, f, 133 | { 134 | 'l2_loss' : l2_loss.item(), 135 | 'mse_loss' : recon_loss.item(), 136 | 'sparsity_loss' : l1_loss.item(), 137 | 'loss' : loss.item() 138 | } 139 | ) 140 | 141 | 142 | def update(self, step, activations): 143 | activations = activations.to(self.device) 144 | 145 | self.optimizer.zero_grad() 146 | loss = self.loss(activations, step=step) 147 | loss.backward() 148 | self.optimizer.step() 149 | self.scheduler.step() 150 | 151 | if self.resample_steps is not None and step % self.resample_steps == 0: 152 | self.resample_neurons(self.steps_since_active > self.resample_steps / 2, activations) 153 | 154 | @property 155 | def config(self): 156 | return { 157 | 'dict_class': 'AutoEncoder', 158 | 'trainer_class' : 'StandardTrainer', 159 | 'activation_dim': self.ae.activation_dim, 160 | 'dict_size': self.ae.dict_size, 161 | 'lr' : self.lr, 162 | 'l1_penalty' : self.l1_penalty, 163 | 'warmup_steps' : self.warmup_steps, 164 | 'resample_steps' : self.resample_steps, 165 | 'sparsity_warmup_steps' : self.sparsity_warmup_steps, 166 | 'steps' : self.steps, 167 | 'decay_start' : self.decay_start, 168 | 'seed' : self.seed, 169 | 'device' : self.device, 170 | 'layer' : self.layer, 171 | 'lm_name' : self.lm_name, 172 | 'wandb_name': self.wandb_name, 173 | 'submodule_name': self.submodule_name, 174 | } 175 | 176 | 177 | class StandardTrainerAprilUpdate(SAETrainer): 178 | """ 179 | Standard SAE training scheme following the Anthropic April update. Decoder column norms are NOT constrained to 1. 180 | This trainer does not support resampling or ghost gradients. This trainer will have fewer dead neurons than the standard trainer. 
181 | """ 182 | def __init__(self, 183 | steps: int, # total number of steps to train for 184 | activation_dim: int, 185 | dict_size: int, 186 | layer: int, 187 | lm_name: str, 188 | dict_class=AutoEncoder, 189 | lr:float=1e-3, 190 | l1_penalty:float=1e-1, 191 | warmup_steps:int=1000, # lr warmup period at start of training 192 | sparsity_warmup_steps:Optional[int]=2000, # sparsity warmup period at start of training 193 | decay_start:Optional[int]=None, # decay learning rate after this many steps 194 | seed:Optional[int]=None, 195 | device=None, 196 | wandb_name:Optional[str]='StandardTrainerAprilUpdate', 197 | submodule_name:Optional[str]=None, 198 | ): 199 | super().__init__(seed) 200 | 201 | assert layer is not None and lm_name is not None 202 | self.layer = layer 203 | self.lm_name = lm_name 204 | self.submodule_name = submodule_name 205 | 206 | if seed is not None: 207 | t.manual_seed(seed) 208 | t.cuda.manual_seed_all(seed) 209 | 210 | # initialize dictionary 211 | self.ae = dict_class(activation_dim, dict_size) 212 | 213 | self.lr = lr 214 | self.l1_penalty=l1_penalty 215 | self.warmup_steps = warmup_steps 216 | self.sparsity_warmup_steps = sparsity_warmup_steps 217 | self.steps = steps 218 | self.decay_start = decay_start 219 | self.wandb_name = wandb_name 220 | 221 | if device is None: 222 | self.device = 'cuda' if t.cuda.is_available() else 'cpu' 223 | else: 224 | self.device = device 225 | self.ae.to(self.device) 226 | 227 | self.optimizer = t.optim.Adam(self.ae.parameters(), lr=lr) 228 | 229 | lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps=None, sparsity_warmup_steps=sparsity_warmup_steps) 230 | self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) 231 | 232 | self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps) 233 | 234 | def loss(self, x, step: int, logging=False, **kwargs): 235 | 236 | sparsity_scale = self.sparsity_warmup_fn(step) 237 | 238 | x_hat, f = self.ae(x, output_features=True) 239 | l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() 240 | recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean() 241 | l1_loss = (f * self.ae.decoder.weight.norm(p=2, dim=0)).sum(dim=-1).mean() 242 | 243 | loss = recon_loss + self.l1_penalty * sparsity_scale * l1_loss 244 | 245 | if not logging: 246 | return loss 247 | else: 248 | return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])( 249 | x, x_hat, f, 250 | { 251 | 'l2_loss' : l2_loss.item(), 252 | 'mse_loss' : recon_loss.item(), 253 | 'sparsity_loss' : l1_loss.item(), 254 | 'loss' : loss.item() 255 | } 256 | ) 257 | 258 | 259 | def update(self, step, activations): 260 | activations = activations.to(self.device) 261 | 262 | self.optimizer.zero_grad() 263 | loss = self.loss(activations, step=step) 264 | loss.backward() 265 | t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) 266 | self.optimizer.step() 267 | self.scheduler.step() 268 | 269 | @property 270 | def config(self): 271 | return { 272 | 'dict_class': 'AutoEncoder', 273 | 'trainer_class' : 'StandardTrainerAprilUpdate', 274 | 'activation_dim': self.ae.activation_dim, 275 | 'dict_size': self.ae.dict_size, 276 | 'lr' : self.lr, 277 | 'l1_penalty' : self.l1_penalty, 278 | 'warmup_steps' : self.warmup_steps, 279 | 'sparsity_warmup_steps' : self.sparsity_warmup_steps, 280 | 'steps' : self.steps, 281 | 'decay_start' : self.decay_start, 282 | 'seed' : self.seed, 283 | 'device' : self.device, 284 | 'layer' : self.layer, 285 | 'lm_name' : self.lm_name, 286 | 'wandb_name': self.wandb_name, 287 | 
'submodule_name': self.submodule_name, 288 | } 289 | 290 | -------------------------------------------------------------------------------- /dictionary_learning/trainers/top_k.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements the SAE training scheme from https://arxiv.org/abs/2406.04093. 3 | Significant portions of this code have been copied from https://github.com/EleutherAI/sae/blob/main/sae 4 | """ 5 | 6 | import einops 7 | import torch as t 8 | import torch.nn as nn 9 | from collections import namedtuple 10 | from typing import Optional 11 | 12 | from ..config import DEBUG 13 | from ..dictionary import Dictionary 14 | from ..trainers.trainer import ( 15 | SAETrainer, 16 | get_lr_schedule, 17 | set_decoder_norm_to_unit_norm, 18 | remove_gradient_parallel_to_decoder_directions, 19 | ) 20 | 21 | 22 | @t.no_grad() 23 | def geometric_median(points: t.Tensor, max_iter: int = 100, tol: float = 1e-5): 24 | """Compute the geometric median of `points`. Used for initializing decoder bias.""" 25 | # Initialize our guess as the mean of the points 26 | guess = points.mean(dim=0) 27 | prev = t.zeros_like(guess) 28 | 29 | # Weights for iteratively reweighted least squares 30 | weights = t.ones(len(points), device=points.device) 31 | 32 | for _ in range(max_iter): 33 | prev = guess 34 | 35 | # Compute the weights 36 | weights = 1 / t.norm(points - guess, dim=1) 37 | 38 | # Normalize the weights 39 | weights /= weights.sum() 40 | 41 | # Compute the new geometric median 42 | guess = (weights.unsqueeze(1) * points).sum(dim=0) 43 | 44 | # Early stopping condition 45 | if t.norm(guess - prev) < tol: 46 | break 47 | 48 | return guess 49 | 50 | 51 | class AutoEncoderTopK(Dictionary, nn.Module): 52 | """ 53 | The top-k autoencoder architecture and initialization used in https://arxiv.org/abs/2406.04093 54 | NOTE: (From Adam Karvonen) There is an unmaintained implementation using Triton kernels in the topk-triton-implementation branch. 55 | We abandoned it as we didn't notice a significant speedup and it added complications, which are noted 56 | in the AutoEncoderTopK class docstring in that branch. 57 | 58 | With some additional effort, you can train a Top-K SAE with the Triton kernels and modify the state dict for compatibility with this class. 59 | Notably, the Triton kernels currently require the decoder to be stored in nn.Parameter, not nn.Linear, and the decoder weights must also 60 | be stored in the same shape as the encoder.
61 | """ 62 | 63 | def __init__(self, activation_dim: int, dict_size: int, k: int): 64 | super().__init__() 65 | self.activation_dim = activation_dim 66 | self.dict_size = dict_size 67 | 68 | assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer" 69 | self.register_buffer("k", t.tensor(k, dtype=t.int)) 70 | self.register_buffer("threshold", t.tensor(-1.0, dtype=t.float32)) 71 | 72 | self.decoder = nn.Linear(dict_size, activation_dim, bias=False) 73 | self.decoder.weight.data = set_decoder_norm_to_unit_norm( 74 | self.decoder.weight, activation_dim, dict_size 75 | ) 76 | 77 | self.encoder = nn.Linear(activation_dim, dict_size) 78 | self.encoder.weight.data = self.decoder.weight.T.clone() 79 | self.encoder.bias.data.zero_() 80 | 81 | self.b_dec = nn.Parameter(t.zeros(activation_dim)) 82 | 83 | def encode(self, x: t.Tensor, return_topk: bool = False, use_threshold: bool = False): 84 | post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec)) 85 | 86 | if use_threshold: 87 | encoded_acts_BF = post_relu_feat_acts_BF * (post_relu_feat_acts_BF > self.threshold) 88 | if return_topk: 89 | post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1) 90 | return encoded_acts_BF, post_topk.values, post_topk.indices, post_relu_feat_acts_BF 91 | else: 92 | return encoded_acts_BF 93 | 94 | post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1) 95 | 96 | # We can't split immediately due to nnsight 97 | tops_acts_BK = post_topk.values 98 | top_indices_BK = post_topk.indices 99 | 100 | buffer_BF = t.zeros_like(post_relu_feat_acts_BF) 101 | encoded_acts_BF = buffer_BF.scatter_(dim=-1, index=top_indices_BK, src=tops_acts_BK) 102 | 103 | if return_topk: 104 | return encoded_acts_BF, tops_acts_BK, top_indices_BK, post_relu_feat_acts_BF 105 | else: 106 | return encoded_acts_BF 107 | 108 | def decode(self, x: t.Tensor) -> t.Tensor: 109 | return self.decoder(x) + self.b_dec 110 | 111 | def forward(self, x: t.Tensor, output_features: bool = False): 112 | encoded_acts_BF = self.encode(x) 113 | x_hat_BD = self.decode(encoded_acts_BF) 114 | if not output_features: 115 | return x_hat_BD 116 | else: 117 | return x_hat_BD, encoded_acts_BF 118 | 119 | def scale_biases(self, scale: float): 120 | self.encoder.bias.data *= scale 121 | self.b_dec.data *= scale 122 | if self.threshold >= 0: 123 | self.threshold *= scale 124 | 125 | def from_pretrained(path, k: Optional[int] = None, device=None): 126 | """ 127 | Load a pretrained autoencoder from a file. 128 | """ 129 | state_dict = t.load(path) 130 | dict_size, activation_dim = state_dict["encoder.weight"].shape 131 | 132 | if k is None: 133 | k = state_dict["k"].item() 134 | elif "k" in state_dict and k != state_dict["k"].item(): 135 | raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']") 136 | 137 | autoencoder = AutoEncoderTopK(activation_dim, dict_size, k) 138 | autoencoder.load_state_dict(state_dict) 139 | if device is not None: 140 | autoencoder.to(device) 141 | return autoencoder 142 | 143 | 144 | class TopKTrainer(SAETrainer): 145 | """ 146 | Top-K SAE training scheme. 
147 | """ 148 | 149 | def __init__( 150 | self, 151 | steps: int, # total number of steps to train for 152 | activation_dim: int, 153 | dict_size: int, 154 | k: int, 155 | layer: int, 156 | lm_name: str, 157 | dict_class: type = AutoEncoderTopK, 158 | lr: Optional[float] = None, 159 | auxk_alpha: float = 1 / 32, # see Appendix A.2 160 | warmup_steps: int = 1000, 161 | decay_start: Optional[int] = None, # when does the lr decay start 162 | threshold_beta: float = 0.999, 163 | threshold_start_step: int = 1000, 164 | seed: Optional[int] = None, 165 | device: Optional[str] = None, 166 | wandb_name: str = "AutoEncoderTopK", 167 | submodule_name: Optional[str] = None, 168 | ): 169 | super().__init__(seed) 170 | 171 | assert layer is not None and lm_name is not None 172 | self.layer = layer 173 | self.lm_name = lm_name 174 | self.submodule_name = submodule_name 175 | 176 | self.wandb_name = wandb_name 177 | self.steps = steps 178 | self.decay_start = decay_start 179 | self.warmup_steps = warmup_steps 180 | self.k = k 181 | self.threshold_beta = threshold_beta 182 | self.threshold_start_step = threshold_start_step 183 | 184 | if seed is not None: 185 | t.manual_seed(seed) 186 | t.cuda.manual_seed_all(seed) 187 | 188 | # Initialise autoencoder 189 | self.ae = dict_class(activation_dim, dict_size, k) 190 | if device is None: 191 | self.device = "cuda" if t.cuda.is_available() else "cpu" 192 | else: 193 | self.device = device 194 | self.ae.to(self.device) 195 | 196 | if lr is not None: 197 | self.lr = lr 198 | else: 199 | # Auto-select LR using 1 / sqrt(d) scaling law from Figure 3 of the paper 200 | scale = dict_size / (2**14) 201 | self.lr = 2e-4 / scale**0.5 202 | 203 | self.auxk_alpha = auxk_alpha 204 | self.dead_feature_threshold = 10_000_000 205 | self.top_k_aux = activation_dim // 2 # Heuristic from B.1 of the paper 206 | self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device) 207 | self.logging_parameters = ["effective_l0", "dead_features", "pre_norm_auxk_loss"] 208 | self.effective_l0 = -1 209 | self.dead_features = -1 210 | self.pre_norm_auxk_loss = -1 211 | 212 | # Optimizer and scheduler 213 | self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) 214 | 215 | lr_fn = get_lr_schedule(steps, warmup_steps, decay_start=decay_start) 216 | 217 | self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) 218 | 219 | def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor): 220 | dead_features = self.num_tokens_since_fired >= self.dead_feature_threshold 221 | self.dead_features = int(dead_features.sum()) 222 | 223 | if self.dead_features > 0: 224 | k_aux = min(self.top_k_aux, self.dead_features) 225 | 226 | auxk_latents = t.where(dead_features[None], post_relu_acts_BF, -t.inf) 227 | 228 | # Top-k dead latents 229 | auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False) 230 | 231 | auxk_buffer_BF = t.zeros_like(post_relu_acts_BF) 232 | auxk_acts_BF = auxk_buffer_BF.scatter_(dim=-1, index=auxk_indices, src=auxk_acts) 233 | 234 | # Note: decoder(), not decode(), as we don't want to apply the bias 235 | x_reconstruct_aux = self.ae.decoder(auxk_acts_BF) 236 | l2_loss_aux = ( 237 | (residual_BD.float() - x_reconstruct_aux.float()).pow(2).sum(dim=-1).mean() 238 | ) 239 | 240 | self.pre_norm_auxk_loss = l2_loss_aux 241 | 242 | # normalization from OpenAI implementation: https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/kernels.py#L614 243 | residual_mu = 
residual_BD.mean(dim=0)[None, :].broadcast_to(residual_BD.shape) 244 | loss_denom = (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean() 245 | normalized_auxk_loss = l2_loss_aux / loss_denom 246 | 247 | return normalized_auxk_loss.nan_to_num(0.0) 248 | else: 249 | self.pre_norm_auxk_loss = -1 250 | return t.tensor(0, dtype=residual_BD.dtype, device=residual_BD.device) 251 | 252 | def update_threshold(self, top_acts_BK: t.Tensor): 253 | device_type = "cuda" if top_acts_BK.is_cuda else "cpu" 254 | with t.autocast(device_type=device_type, enabled=False), t.no_grad(): 255 | active = top_acts_BK.clone().detach() 256 | active[active <= 0] = float("inf") 257 | min_activations = active.min(dim=1).values.to(dtype=t.float32) 258 | min_activation = min_activations.mean() 259 | 260 | B, K = active.shape 261 | assert len(active.shape) == 2 262 | assert min_activations.shape == (B,) 263 | 264 | if self.ae.threshold < 0: 265 | self.ae.threshold = min_activation 266 | else: 267 | self.ae.threshold = (self.threshold_beta * self.ae.threshold) + ( 268 | (1 - self.threshold_beta) * min_activation 269 | ) 270 | 271 | def loss(self, x, step=None, logging=False): 272 | # Run the SAE 273 | f, top_acts_BK, top_indices_BK, post_relu_acts_BF = self.ae.encode( 274 | x, return_topk=True, use_threshold=False 275 | ) 276 | 277 | if step > self.threshold_start_step: 278 | self.update_threshold(top_acts_BK) 279 | 280 | x_hat = self.ae.decode(f) 281 | 282 | # Measure goodness of reconstruction 283 | e = x - x_hat 284 | 285 | # Update the effective L0 (again, should just be K) 286 | self.effective_l0 = top_acts_BK.size(1) 287 | 288 | # Update "number of tokens since fired" for each feature 289 | num_tokens_in_step = x.size(0) 290 | did_fire = t.zeros_like(self.num_tokens_since_fired, dtype=t.bool) 291 | did_fire[top_indices_BK.flatten()] = True 292 | self.num_tokens_since_fired += num_tokens_in_step 293 | self.num_tokens_since_fired[did_fire] = 0 294 | 295 | l2_loss = e.pow(2).sum(dim=-1).mean() 296 | auxk_loss = ( 297 | self.get_auxiliary_loss(e.detach(), post_relu_acts_BF) if self.auxk_alpha > 0 else 0 298 | ) 299 | 300 | loss = l2_loss + self.auxk_alpha * auxk_loss 301 | 302 | if not logging: 303 | return loss 304 | else: 305 | return namedtuple("LossLog", ["x", "x_hat", "f", "losses"])( 306 | x, 307 | x_hat, 308 | f, 309 | {"l2_loss": l2_loss.item(), "auxk_loss": auxk_loss.item(), "loss": loss.item()}, 310 | ) 311 | 312 | def update(self, step, x): 313 | # Initialise the decoder bias 314 | if step == 0: 315 | median = geometric_median(x) 316 | median = median.to(self.ae.b_dec.dtype) 317 | self.ae.b_dec.data = median 318 | 319 | # compute the loss 320 | x = x.to(self.device) 321 | loss = self.loss(x, step=step) 322 | loss.backward() 323 | 324 | # clip grad norm and remove grads parallel to decoder directions 325 | self.ae.decoder.weight.grad = remove_gradient_parallel_to_decoder_directions( 326 | self.ae.decoder.weight, 327 | self.ae.decoder.weight.grad, 328 | self.ae.activation_dim, 329 | self.ae.dict_size, 330 | ) 331 | t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) 332 | 333 | # do a training step 334 | self.optimizer.step() 335 | self.optimizer.zero_grad() 336 | self.scheduler.step() 337 | 338 | # Make sure the decoder is still unit-norm 339 | self.ae.decoder.weight.data = set_decoder_norm_to_unit_norm( 340 | self.ae.decoder.weight, self.ae.activation_dim, self.ae.dict_size 341 | ) 342 | 343 | return loss.item() 344 | 345 | @property 346 | def config(self): 347 | return { 348 |
"trainer_class": "TopKTrainer", 349 | "dict_class": "AutoEncoderTopK", 350 | "lr": self.lr, 351 | "steps": self.steps, 352 | "auxk_alpha": self.auxk_alpha, 353 | "warmup_steps": self.warmup_steps, 354 | "decay_start": self.decay_start, 355 | "threshold_beta": self.threshold_beta, 356 | "threshold_start_step": self.threshold_start_step, 357 | "seed": self.seed, 358 | "activation_dim": self.ae.activation_dim, 359 | "dict_size": self.ae.dict_size, 360 | "k": self.ae.k.item(), 361 | "device": self.device, 362 | "layer": self.layer, 363 | "lm_name": self.lm_name, 364 | "wandb_name": self.wandb_name, 365 | "submodule_name": self.submodule_name, 366 | } 367 | -------------------------------------------------------------------------------- /dictionary_learning/trainers/trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable 2 | import torch 3 | import einops 4 | 5 | 6 | class SAETrainer: 7 | """ 8 | Generic class for implementing SAE training algorithms 9 | """ 10 | 11 | def __init__(self, seed=None): 12 | self.seed = seed 13 | self.logging_parameters = [] 14 | 15 | def update( 16 | self, 17 | step, # index of step in training 18 | activations, # of shape [batch_size, d_submodule] 19 | ): 20 | pass # implemented by subclasses 21 | 22 | def get_logging_parameters(self): 23 | stats = {} 24 | for param in self.logging_parameters: 25 | if hasattr(self, param): 26 | stats[param] = getattr(self, param) 27 | else: 28 | print(f"Warning: {param} not found in {self}") 29 | return stats 30 | 31 | @property 32 | def config(self): 33 | return { 34 | "wandb_name": "trainer", 35 | } 36 | 37 | 38 | class ConstrainedAdam(torch.optim.Adam): 39 | """ 40 | A variant of Adam where some of the parameters are constrained to have unit norm. 41 | Note: This should be used with a decoder that is nn.Linear, not nn.Parameter. 42 | If nn.Parameter, the dim argument to norm should be 1. 43 | """ 44 | 45 | def __init__( 46 | self, params, constrained_params, lr: float, betas: tuple[float, float] = (0.9, 0.999) 47 | ): 48 | super().__init__(params, lr=lr, betas=betas) 49 | self.constrained_params = list(constrained_params) 50 | 51 | def step(self, closure=None): 52 | with torch.no_grad(): 53 | for p in self.constrained_params: 54 | normed_p = p / p.norm(dim=0, keepdim=True) 55 | # project away the parallel component of the gradient 56 | p.grad -= (p.grad * normed_p).sum(dim=0, keepdim=True) * normed_p 57 | super().step(closure=closure) 58 | with torch.no_grad(): 59 | for p in self.constrained_params: 60 | # renormalize the constrained parameters 61 | p /= p.norm(dim=0, keepdim=True) 62 | 63 | 64 | # The next two functions could be replaced with the ConstrainedAdam Optimizer 65 | @torch.no_grad() 66 | def set_decoder_norm_to_unit_norm( 67 | W_dec_DF: torch.nn.Parameter, activation_dim: int, d_sae: int 68 | ) -> torch.Tensor: 69 | """There's a major footgun here: we use this with both nn.Linear and nn.Parameter decoders. 70 | nn.Linear stores the decoder weights in a transposed format (d_model, d_sae). 
So, we pass the dimensions in 71 | to catch this error.""" 72 | 73 | D, F = W_dec_DF.shape 74 | 75 | assert D == activation_dim 76 | assert F == d_sae 77 | 78 | eps = torch.finfo(W_dec_DF.dtype).eps 79 | norm = torch.norm(W_dec_DF.data, dim=0, keepdim=True) 80 | W_dec_DF.data /= norm + eps 81 | return W_dec_DF.data 82 | 83 | 84 | @torch.no_grad() 85 | def remove_gradient_parallel_to_decoder_directions( 86 | W_dec_DF: torch.Tensor, 87 | W_dec_DF_grad: torch.Tensor, 88 | activation_dim: int, 89 | d_sae: int, 90 | ) -> torch.Tensor: 91 | """There's a major footgun here: we use this with both nn.Linear and nn.Parameter decoders. 92 | nn.Linear stores the decoder weights in a transposed format (d_model, d_sae). So, we pass the dimensions in 93 | to catch this error.""" 94 | 95 | D, F = W_dec_DF.shape 96 | assert D == activation_dim 97 | assert F == d_sae 98 | 99 | normed_W_dec_DF = W_dec_DF / (torch.norm(W_dec_DF, dim=0, keepdim=True) + 1e-6) 100 | 101 | parallel_component = einops.einsum( 102 | W_dec_DF_grad, 103 | normed_W_dec_DF, 104 | "d_in d_sae, d_in d_sae -> d_sae", 105 | ) 106 | W_dec_DF_grad -= einops.einsum( 107 | parallel_component, 108 | normed_W_dec_DF, 109 | "d_sae, d_in d_sae -> d_in d_sae", 110 | ) 111 | return W_dec_DF_grad 112 | 113 | 114 | def get_lr_schedule( 115 | total_steps: int, 116 | warmup_steps: int, 117 | decay_start: Optional[int] = None, 118 | resample_steps: Optional[int] = None, 119 | sparsity_warmup_steps: Optional[int] = None, 120 | ) -> Callable[[int], float]: 121 | """ 122 | Creates a learning rate schedule function with linear warmup followed by an optional decay phase. 123 | 124 | Note: resample_steps creates a repeating warmup pattern instead of the standard phases, but 125 | is rarely used in practice. 126 | 127 | Args: 128 | total_steps: Total number of training steps 129 | warmup_steps: Steps for linear warmup from 0 to 1 130 | decay_start: Optional step to begin linear decay to 0 131 | resample_steps: Optional period for repeating warmup pattern 132 | sparsity_warmup_steps: Used for validation with decay_start 133 | 134 | Returns: 135 | Function that computes LR scale factor for a given step 136 | """ 137 | if decay_start is not None: 138 | assert resample_steps is None, ( 139 | "decay_start and resample_steps are currently mutually exclusive." 140 | ) 141 | assert 0 <= decay_start < total_steps, "decay_start must be >= 0 and < steps." 142 | assert decay_start > warmup_steps, "decay_start must be > warmup_steps." 143 | if sparsity_warmup_steps is not None: 144 | assert decay_start > sparsity_warmup_steps, ( 145 | "decay_start must be > sparsity_warmup_steps." 146 | ) 147 | 148 | assert 0 <= warmup_steps < total_steps, "warmup_steps must be >= 0 and < steps." 149 | 150 | if resample_steps is None: 151 | 152 | def lr_schedule(step: int) -> float: 153 | if step < warmup_steps: 154 | # Warm-up phase 155 | return step / warmup_steps 156 | 157 | if decay_start is not None and step >= decay_start: 158 | # Decay phase 159 | return (total_steps - step) / (total_steps - decay_start) 160 | 161 | # Constant phase 162 | return 1.0 163 | else: 164 | assert 0 < resample_steps < total_steps, "resample_steps must be > 0 and < steps." 
165 | 166 | def lr_schedule(step: int) -> float: 167 | return min((step % resample_steps) / warmup_steps, 1.0) 168 | 169 | return lr_schedule 170 | 171 | 172 | def get_sparsity_warmup_fn( 173 | total_steps: int, sparsity_warmup_steps: Optional[int] = None 174 | ) -> Callable[[int], float]: 175 | """ 176 | Return a function that computes a scale factor for sparsity penalty at a given step. 177 | 178 | If `sparsity_warmup_steps` is None or 0, returns 1.0 for all steps. 179 | Otherwise, scales from 0.0 up to 1.0 across `sparsity_warmup_steps`. 180 | """ 181 | 182 | if sparsity_warmup_steps is not None: 183 | assert 0 <= sparsity_warmup_steps < total_steps, ( 184 | "sparsity_warmup_steps must be >= 0 and < steps." 185 | ) 186 | 187 | def scale_fn(step: int) -> float: 188 | if not sparsity_warmup_steps: 189 | # If it's None or zero, we just return 1.0 190 | return 1.0 191 | else: 192 | # Gradually increase from 0.0 -> 1.0 as step goes from 0 -> sparsity_warmup_steps 193 | return min(step / sparsity_warmup_steps, 1.0) 194 | 195 | return scale_fn 196 | -------------------------------------------------------------------------------- /dictionary_learning/training.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training dictionaries 3 | """ 4 | 5 | import json 6 | import torch.multiprocessing as mp 7 | import os 8 | from queue import Empty 9 | from typing import Optional 10 | from contextlib import nullcontext 11 | from itertools import cycle 12 | 13 | import torch as t 14 | from tqdm import tqdm 15 | 16 | import wandb 17 | 18 | from .dictionary import AutoEncoder 19 | from .evaluation import evaluate 20 | from .trainers.standard import StandardTrainer 21 | from .trainers.matroyshka_batch_top_k import MatroyshkaBatchTopKSAE 22 | 23 | 24 | def new_wandb_process(config, log_queue, entity, project): 25 | wandb.init(entity=entity, project=project, config=config, name=config["wandb_name"]) 26 | while True: 27 | try: 28 | log = log_queue.get(timeout=1) 29 | if log == "DONE": 30 | break 31 | wandb.log(log) 32 | except Empty: 33 | continue 34 | wandb.finish() 35 | 36 | 37 | def log_stats( 38 | trainers, 39 | step: int, 40 | act: t.Tensor, 41 | activations_split_by_head: bool, 42 | transcoder: bool, 43 | log_queues: list=[], 44 | verbose: bool=False, 45 | ): 46 | with t.no_grad(): 47 | # quick hack to make sure all trainers get the same x 48 | z = act.clone() 49 | for i, trainer in enumerate(trainers): 50 | log = {} 51 | act = z.clone() 52 | if activations_split_by_head: # x.shape: [batch, pos, n_heads, d_head] 53 | act = act[..., i, :] 54 | if not transcoder: 55 | act, act_hat, f, losslog = trainer.loss(act, step=step, logging=True) 56 | 57 | # L0 58 | l0 = (f != 0).float().sum(dim=-1).mean().item() 59 | # fraction of variance explained 60 | total_variance = t.var(act, dim=0).sum() 61 | residual_variance = t.var(act - act_hat, dim=0).sum() 62 | frac_variance_explained = 1 - residual_variance / total_variance 63 | log[f"frac_variance_explained"] = frac_variance_explained.item() 64 | else: # transcoder 65 | x, x_hat, f, losslog = trainer.loss(act, step=step, logging=True) 66 | 67 | # L0 68 | l0 = (f != 0).float().sum(dim=-1).mean().item() 69 | 70 | if verbose: 71 | print(f"Step {step}: L0 = {l0}, frac_variance_explained = {frac_variance_explained}") 72 | 73 | # log parameters from training 74 | log.update({f"{k}": v.cpu().item() if isinstance(v, t.Tensor) else v for k, v in losslog.items()}) 75 | log[f"l0"] = l0 76 | trainer_log = 
trainer.get_logging_parameters() 77 | for name, value in trainer_log.items(): 78 | if isinstance(value, t.Tensor): 79 | value = value.cpu().item() 80 | log[f"{name}"] = value 81 | 82 | if log_queues: 83 | log_queues[i].put(log) 84 | 85 | def get_norm_factor(data, steps: int) -> float: 86 | """Per Section 3.1, find a fixed scalar factor so activation vectors have unit mean squared norm. 87 | This is very helpful for hyperparameter transfer between different layers and models. 88 | Use more steps for more accurate results. 89 | https://arxiv.org/pdf/2408.05147 90 | 91 | If experiencing troubles with hyperparameter transfer between models, it may be worth instead normalizing to the square root of d_model. 92 | https://transformer-circuits.pub/2024/april-update/index.html#training-saes""" 93 | total_mean_squared_norm = 0 94 | count = 0 95 | 96 | for step, act_BD in enumerate(tqdm(data, total=steps, desc="Calculating norm factor")): 97 | if step > steps: 98 | break 99 | 100 | count += 1 101 | mean_squared_norm = t.mean(t.sum(act_BD ** 2, dim=1)) 102 | total_mean_squared_norm += mean_squared_norm 103 | 104 | average_mean_squared_norm = total_mean_squared_norm / count 105 | norm_factor = t.sqrt(average_mean_squared_norm).item() 106 | 107 | print(f"Average mean squared norm: {average_mean_squared_norm}") 108 | print(f"Norm factor: {norm_factor}") 109 | 110 | return norm_factor 111 | 112 | # Assumes only one trainer and one log queue 113 | def validation(val_data, autocast_dtype, trainer, log_queue, norm_factor): 114 | for use_threshold in [False, True]: 115 | l0s = [] 116 | l2s = [] 117 | fracs = [] 118 | 119 | for act in val_data: 120 | act = act.detach().clone() 121 | act = act.to(dtype=autocast_dtype) 122 | act /= norm_factor 123 | with t.no_grad(): 124 | f = trainer.ae.encode(act, use_threshold=use_threshold) 125 | act_hat = trainer.ae.decode(f) 126 | e = act - act_hat 127 | 128 | # Sparsity - L0 129 | l0 = (f != 0).float().sum(dim=-1).mean().item() 130 | 131 | # Reconstruction - L2 132 | l2 = e.pow(2).sum(dim=-1).mean().item() 133 | 134 | # Reconstruction - Fraction of variance explained 135 | total_variance = t.var(act, dim=0).sum() 136 | residual_variance = t.var(e, dim=0).sum() 137 | frac_variance_explained = 1 - residual_variance / total_variance 138 | frac = frac_variance_explained.item() 139 | 140 | l0s.append(l0) 141 | l2s.append(l2) 142 | fracs.append(frac) 143 | 144 | threshold_str = "true" if use_threshold else "false" 145 | 146 | log = { 147 | f"val_threshold_{threshold_str}/sparsity_l0": t.mean(t.tensor(l0s)).item(), 148 | f"val_threshold_{threshold_str}/reconstruction_l2": t.mean(t.tensor(l2s)).item(), 149 | f"val_threshold_{threshold_str}/frac_variance_explained": t.mean(t.tensor(fracs)).item(), 150 | } 151 | 152 | log_queue.put(log) 153 | 154 | 155 | def trainSAE( 156 | data, 157 | val_data, 158 | trainer_configs: list[dict], 159 | steps: int, 160 | use_wandb:bool=False, 161 | wandb_entity:str="", 162 | wandb_project:str="", 163 | save_steps:Optional[list[int]]=None, 164 | save_dir:Optional[str]=None, 165 | log_steps:Optional[int]=None, 166 | activations_split_by_head:bool=False, 167 | transcoder:bool=False, 168 | run_cfg:dict={}, 169 | normalize_activations:bool=True, 170 | verbose:bool=False, 171 | device:str="cuda", 172 | autocast_dtype: t.dtype = t.float32, 173 | ): 174 | """ 175 | Train SAEs using the given trainers 176 | 177 | If normalize_activations is True, the activations will be normalized to have unit mean squared norm. 
178 | The autoencoders weights will be scaled before saving, so the activations don't need to be scaled during inference. 179 | This is very helpful for hyperparameter transfer between different layers and models. 180 | 181 | Setting autocast_dtype to t.bfloat16 provides a significant speedup with minimal change in performance. 182 | """ 183 | 184 | device_type = "cuda" if "cuda" in device else "cpu" 185 | autocast_context = nullcontext() if device_type == "cpu" else t.autocast(device_type=device_type, dtype=autocast_dtype) 186 | 187 | trainers = [] 188 | for i, config in enumerate(trainer_configs): 189 | if "wandb_name" in config: 190 | config["wandb_name"] = f"{config['wandb_name']}_trainer_{i}" 191 | trainer_class = config["trainer"] 192 | del config["trainer"] 193 | trainers.append(trainer_class(**config)) 194 | 195 | wandb_processes = [] 196 | log_queues = [] 197 | 198 | if use_wandb: 199 | # Note: If encountering wandb and CUDA related errors, try setting start method to spawn in the if __name__ == "__main__" block 200 | # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.set_start_method 201 | # Everything should work fine with the default fork method but it may not be as robust 202 | for i, trainer in enumerate(trainers): 203 | log_queue = mp.Queue() 204 | log_queues.append(log_queue) 205 | wandb_config = trainer.config | run_cfg 206 | # Make sure wandb config doesn't contain any CUDA tensors 207 | wandb_config = {k: v.cpu().item() if isinstance(v, t.Tensor) else v 208 | for k, v in wandb_config.items()} 209 | wandb_process = mp.Process( 210 | target=new_wandb_process, 211 | args=(wandb_config, log_queue, wandb_entity, wandb_project), 212 | ) 213 | wandb_process.start() 214 | wandb_processes.append(wandb_process) 215 | 216 | # make save dirs, export config 217 | if save_dir is not None: 218 | save_dirs = [ 219 | os.path.join(save_dir, f"trainer_{i}") for i in range(len(trainer_configs)) 220 | ] 221 | for trainer, dir in zip(trainers, save_dirs): 222 | os.makedirs(dir, exist_ok=True) 223 | # save config 224 | config = {"trainer": trainer.config} 225 | try: 226 | config["buffer"] = data.config 227 | except: 228 | pass 229 | with open(os.path.join(dir, "config.json"), "w") as f: 230 | json.dump(config, f, indent=4) 231 | else: 232 | save_dirs = [None for _ in trainer_configs] 233 | 234 | if normalize_activations: 235 | norm_factor = get_norm_factor(data, steps=100) 236 | 237 | for trainer in trainers: 238 | trainer.config["norm_factor"] = norm_factor 239 | # Verify that all autoencoders have a scale_biases method 240 | trainer.ae.scale_biases(1.0) 241 | 242 | # def rand_cycle(iterable): 243 | # while True: 244 | # for x in iterable: 245 | # yield x 246 | 247 | for step, act in enumerate(tqdm(cycle(data), total=steps)): 248 | 249 | act = act.detach().clone() # TODO: maybe remove if activation dataset modified 250 | act = act.to(dtype=autocast_dtype) 251 | 252 | if normalize_activations: 253 | act /= norm_factor 254 | 255 | if step >= steps: 256 | break 257 | 258 | # logging 259 | if (use_wandb or verbose) and step % log_steps == 0: 260 | log_stats( 261 | trainers, step, act, activations_split_by_head, transcoder, log_queues=log_queues, verbose=verbose 262 | ) 263 | 264 | # logging validation 265 | # if (use_wandb or verbose) and step % save_steps == 0: 266 | if step % log_steps == 0: 267 | validation(val_data, autocast_dtype, trainers[0], log_queues[0], norm_factor) 268 | 269 | # saving 270 | if save_steps is not None and step in save_steps: 271 | for dir, trainer 
in zip(save_dirs, trainers): 272 | if dir is not None: 273 | 274 | if normalize_activations: 275 | # Temporarily scale up biases for checkpoint saving 276 | trainer.ae.scale_biases(norm_factor) 277 | 278 | if not os.path.exists(os.path.join(dir, "checkpoints")): 279 | os.mkdir(os.path.join(dir, "checkpoints")) 280 | 281 | checkpoint = {k: v.cpu() for k, v in trainer.ae.state_dict().items()} 282 | t.save( 283 | checkpoint, 284 | os.path.join(dir, "checkpoints", f"ae_{step}.pt"), 285 | ) 286 | 287 | if normalize_activations: 288 | trainer.ae.scale_biases(1 / norm_factor) 289 | 290 | # training 291 | for trainer in trainers: 292 | with autocast_context: 293 | trainer.update(step, act) 294 | 295 | # save final SAEs 296 | for save_dir, trainer in zip(save_dirs, trainers): 297 | if normalize_activations: 298 | trainer.ae.scale_biases(norm_factor) 299 | if save_dir is not None: 300 | final = {k: v.cpu() for k, v in trainer.ae.state_dict().items()} 301 | t.save(final, os.path.join(save_dir, "ae.pt")) 302 | 303 | # Signal wandb processes to finish 304 | if use_wandb: 305 | for queue in log_queues: 306 | queue.put("DONE") 307 | for process in wandb_processes: 308 | process.join() 309 | -------------------------------------------------------------------------------- /dictionary_learning/utils.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import zstandard as zstd 3 | import io 4 | import json 5 | import os 6 | from nnsight import LanguageModel 7 | 8 | from .trainers.top_k import AutoEncoderTopK 9 | from .trainers.batch_top_k import BatchTopKSAE 10 | from .trainers.matroyshka_batch_top_k import MatroyshkaBatchTopKSAE 11 | from .dictionary import ( 12 | AutoEncoder, 13 | GatedAutoEncoder, 14 | AutoEncoderNew, 15 | JumpReluAutoEncoder, 16 | ) 17 | 18 | 19 | def hf_dataset_to_generator(dataset_name, split="train", streaming=True): 20 | dataset = load_dataset(dataset_name, split=split, streaming=streaming) 21 | 22 | def gen(): 23 | for x in iter(dataset): 24 | yield x["text"] 25 | 26 | return gen() 27 | 28 | 29 | def zst_to_generator(data_path): 30 | """ 31 | Load a dataset from a .jsonl.zst file. 
32 | The jsonl entries is assumed to have a 'text' field 33 | """ 34 | compressed_file = open(data_path, "rb") 35 | dctx = zstd.ZstdDecompressor() 36 | reader = dctx.stream_reader(compressed_file) 37 | text_stream = io.TextIOWrapper(reader, encoding="utf-8") 38 | 39 | def generator(): 40 | for line in text_stream: 41 | yield json.loads(line)["text"] 42 | 43 | return generator() 44 | 45 | 46 | def get_nested_folders(path: str) -> list[str]: 47 | """ 48 | Recursively get a list of folders that contain an ae.pt file, starting the search from the given path 49 | """ 50 | folder_names = [] 51 | 52 | for root, dirs, files in os.walk(path): 53 | if "ae.pt" in files: 54 | folder_names.append(root) 55 | 56 | return folder_names 57 | 58 | 59 | def load_dictionary(base_path: str, device: str) -> tuple: 60 | ae_path = f"{base_path}/ae.pt" 61 | config_path = f"{base_path}/config.json" 62 | 63 | with open(config_path, "r") as f: 64 | config = json.load(f) 65 | 66 | dict_class = config["trainer"]["dict_class"] 67 | 68 | if dict_class == "AutoEncoder": 69 | dictionary = AutoEncoder.from_pretrained(ae_path, device=device) 70 | elif dict_class == "GatedAutoEncoder": 71 | dictionary = GatedAutoEncoder.from_pretrained(ae_path, device=device) 72 | elif dict_class == "AutoEncoderNew": 73 | dictionary = AutoEncoderNew.from_pretrained(ae_path, device=device) 74 | elif dict_class == "AutoEncoderTopK": 75 | k = config["trainer"]["k"] 76 | dictionary = AutoEncoderTopK.from_pretrained(ae_path, k=k, device=device) 77 | elif dict_class == "BatchTopKSAE": 78 | k = config["trainer"]["k"] 79 | dictionary = BatchTopKSAE.from_pretrained(ae_path, k=k, device=device) 80 | elif dict_class == "MatroyshkaBatchTopKSAE": 81 | k = config["trainer"]["k"] 82 | dictionary = MatroyshkaBatchTopKSAE.from_pretrained(ae_path, k=k, device=device) 83 | elif dict_class == "JumpReluAutoEncoder": 84 | dictionary = JumpReluAutoEncoder.from_pretrained(ae_path, device=device) 85 | else: 86 | raise ValueError(f"Dictionary class {dict_class} not supported") 87 | 88 | return dictionary, config 89 | 90 | 91 | def get_submodule(model: LanguageModel, layer: int): 92 | """Gets the residual stream submodule""" 93 | model_name = model._model_key 94 | 95 | if "pythia" in model_name: 96 | return model.gpt_neox.layers[layer] 97 | elif "gemma" in model_name: 98 | return model.model.layers[layer] 99 | else: 100 | raise ValueError(f"Please add submodule for model {model_name}") 101 | -------------------------------------------------------------------------------- /encode_images.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os.path 3 | import argparse 4 | import tqdm 5 | from utils import get_dataset, get_model 6 | 7 | def get_args_parser(): 8 | parser = argparse.ArgumentParser("Encode images", add_help=False) 9 | parser.add_argument("--embeddings_path") 10 | parser.add_argument("--model_name", default="clip", type=str) 11 | parser.add_argument("--dataset_name", default="imagenet", type=str) 12 | parser.add_argument("--data_path", default="/shared-network/inat2021", type=str) 13 | parser.add_argument("--split", default="train", type=str) 14 | parser.add_argument("--batch_size", default=128, type=int) 15 | parser.add_argument("--num_workers", default=10, type=int) 16 | parser.add_argument("--device", default="cuda:0") 17 | return parser 18 | 19 | 20 | if __name__ == "__main__": 21 | args = get_args_parser().parse_args() 22 | 23 | if os.path.exists(args.embeddings_path): 24 | print(f"Embeddings 
already saved at {args.embeddings_path}") 25 | else: 26 | model, processor = get_model(args) 27 | ds, dl = get_dataset(args, preprocess=None, processor=processor, split=args.split) 28 | 29 | embeddings = [] 30 | pbar = tqdm.tqdm(dl) 31 | for image in pbar: 32 | with torch.no_grad(): 33 | output = model.encode(image) 34 | embeddings.append(output.detach().cpu()) 35 | 36 | embeddings = torch.cat(embeddings, dim=0) 37 | os.makedirs(os.path.dirname(args.embeddings_path), exist_ok=True) 38 | torch.save(embeddings, args.embeddings_path) 39 | print(f"Embeddings shape: {embeddings.shape}") 40 | print(f"Saved embeddings to {args.embeddings_path}") 41 | 42 | -------------------------------------------------------------------------------- /find_hai_indices.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tqdm 3 | import os 4 | import torch 5 | from torchvision import transforms 6 | from utils import get_dataset 7 | import argparse 8 | from datasets.activations import ActivationsDataset, ChunkedActivationsDataset 9 | from torch.utils.data import DataLoader, Subset 10 | from itertools import combinations 11 | from collections import Counter 12 | import math 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser(description="Find indices of highest activating images from activations") 16 | parser.add_argument('--activations_dir', type=str, required=True) 17 | parser.add_argument("--dataset_name", default="imagenet", type=str) 18 | parser.add_argument("--data_path", default="/shared-network/inat2021", type=str) 19 | parser.add_argument('--split', type=str, default='train') 20 | parser.add_argument('--k', type=int, default=16) 21 | parser.add_argument('--chunk_size', type=int) 22 | return parser.parse_args() 23 | 24 | if __name__ == "__main__": 25 | # Parse command line arguments 26 | args = parse_args() 27 | args.batch_size = 1 # not used 28 | args.num_workers = 0 # not used 29 | 30 | hai_indices_path = os.path.join(args.activations_dir, f"hai_indices_{args.k}") 31 | if os.path.exists(hai_indices_path): 32 | print(f"HAI indices already saved at {hai_indices_path}") 33 | else: 34 | print("Computing HAI indices", flush=True) 35 | print(f"Loading activations from {args.activations_dir}", flush=True) 36 | activations_dataset = ActivationsDataset(args.activations_dir, device=torch.device("cpu")) 37 | print(f"Dataset loaded. 
Total samples: {len(activations_dataset)}", flush=True) 38 | 39 | activations_dataloader = DataLoader(activations_dataset, batch_size=args.chunk_size, shuffle=False, num_workers=16) 40 | num_samples = len(activations_dataset) 41 | 42 | first_batch = next(iter(activations_dataloader)) 43 | num_neurons = first_batch.shape[1] 44 | print(f"Number of neurons detected: {num_neurons}") 45 | num_chunks = math.ceil(num_neurons / args.chunk_size) 46 | print(f"Processing {num_chunks} chunks of {args.chunk_size} neurons each...", flush=True) 47 | 48 | importants = [] 49 | worst_hais = [] 50 | pbar = tqdm.tqdm(list(range(num_chunks))) 51 | for i in pbar: 52 | neuron_start = i * args.chunk_size 53 | neuron_end = min((i + 1) * args.chunk_size, num_neurons) 54 | activations_chunks = np.zeros((num_samples, neuron_end - neuron_start)) 55 | for j, activations_chunk in enumerate(activations_dataloader): 56 | sample_start = j * args.chunk_size 57 | sample_end = min((j + 1) * args.chunk_size, num_samples) 58 | activations_chunk = activations_chunk.numpy() 59 | activations_chunks[sample_start:sample_end, :] = activations_chunk[:, neuron_start:neuron_end] 60 | for neuron in range(neuron_end - neuron_start): 61 | neuron_activations = activations_chunks[:, neuron] 62 | important = np.argsort(neuron_activations)[-args.k:] 63 | importants.append(important) 64 | worst_hai = neuron_activations[important[0]] 65 | worst_hais.append(worst_hai) 66 | 67 | hai_indices = np.array(importants) 68 | print(f"hai_indices.shape(): {hai_indices.shape}") 69 | np.save(hai_indices_path, hai_indices) 70 | print(f"Saved HAI indices to: {hai_indices_path}") 71 | 72 | worst_hai_indices_path = os.path.join(args.activations_dir, f"hai_indices_{args.k}_worst") 73 | worst_hais = np.array(worst_hais) 74 | np.save(worst_hai_indices_path, worst_hais) 75 | print(f"Saved worst HAI indices to: {worst_hai_indices_path}") -------------------------------------------------------------------------------- /imagenet_subset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | 5 | def main(imagenet_root, output_dir): 6 | os.makedirs(output_dir, exist_ok=True) 7 | 8 | for class_folder in sorted(os.listdir(imagenet_root)): 9 | class_path = os.path.join(imagenet_root, class_folder) 10 | 11 | if os.path.isdir(class_path): 12 | images = sorted(os.listdir(class_path)) 13 | if images: 14 | first_image = images[0] 15 | src_path = os.path.join(class_path, first_image) 16 | dest_path = os.path.join(output_dir, f"{class_folder}_{first_image}") 17 | 18 | shutil.copy(src_path, dest_path) 19 | print(f"Copied {first_image} from {class_folder}") 20 | 21 | print("Done selecting first images for each class!") 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser(description="Copy the first image from each ImageNet class folder.") 25 | parser.add_argument("--imagenet_root", required=True, help="Path to the ImageNet training dataset root directory") 26 | parser.add_argument("--output_dir", default="./images_imagenet", help="Output directory for selected images") 27 | 28 | args = parser.parse_args() 29 | main(args.imagenet_root, args.output_dir) 30 | -------------------------------------------------------------------------------- /images/white.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExplainableML/sae-for-vlm/69e211dc4d0a31cfb3b3f6f682f78a1890677d5f/images/white.png 
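At its core, `find_hai_indices.py` above performs a per-neuron top-k selection over the activation matrix. A minimal sketch, assuming the full `[num_samples, num_neurons]` matrix fits in memory (the script instead processes neurons in chunks); the function and variable names here are illustrative:

```python
import numpy as np

def top_k_images_per_neuron(activations: np.ndarray, k: int = 16):
    """activations: [num_samples, num_neurons] array of (SAE) neuron activations."""
    # Indices of the k highest-activating images for each neuron,
    # sorted in ascending order of activation (last column = strongest image).
    hai_indices = np.argsort(activations, axis=0)[-k:].T  # [num_neurons, k]
    # Activation of the weakest image within each neuron's top k; a value of 0.0
    # is later used to flag dead neurons when measuring hierarchy depth.
    worst_hais = activations[hai_indices[:, 0], np.arange(activations.shape[1])]
    return hai_indices, worst_hais
```

Applied to the saved SAE activations, each row of `hai_indices` gives the dataset indices of that neuron's highest activating images (HAI).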
-------------------------------------------------------------------------------- /inat_depth.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tqdm 3 | import os 4 | import torch 5 | from utils import get_dataset 6 | import argparse 7 | from itertools import combinations 8 | from collections import Counter 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser(description="Measure hierarchy in iNaturalist trained Matryoshka SAE") 13 | parser.add_argument('--activations_dir', type=str, required=True) 14 | parser.add_argument('--hai_indices_path', type=str, required=True) 15 | parser.add_argument("--data_path", default="/shared-network/inat2021", type=str) 16 | parser.add_argument('--split', type=str, default='train') 17 | parser.add_argument('--k', type=int, default=16) 18 | parser.add_argument('--group_fractions', type=float, nargs='+', required=True) 19 | return parser.parse_args() 20 | 21 | 22 | if __name__ == "__main__": 23 | # Parse command line arguments 24 | args = parse_args() 25 | args.batch_size = 1 # not used 26 | args.num_workers = 0 # not used 27 | args.dataset_name = 'inat' # fixed 28 | 29 | # Get HAI indices 30 | hai_indices = np.load(args.hai_indices_path) 31 | print(f"Loaded HAI indices found at {args.hai_indices_path}") 32 | 33 | # Get HAI worst scores 34 | worst_scores_path = f"{args.hai_indices_path[:-4]}_worst.npy" 35 | worst_scores = np.load(worst_scores_path) 36 | print(f"Loaded worst scores of HAI found at {worst_scores_path}") 37 | 38 | # Assign path to each of top k images 39 | hai_classes = [] 40 | num_neurons = hai_indices.shape[0] 41 | ds, _ = get_dataset(args, preprocess=None, processor=None, split=args.split) 42 | 43 | for neuron in range(num_neurons): 44 | hai_classes_neuron = [] 45 | for i in range(args.k): 46 | class_index = hai_indices[neuron, i] 47 | image_path = ds.imgs[class_index][0] 48 | class_name = image_path.split(os.path.sep)[-2].split("_")[1:] 49 | hai_classes_neuron.append(class_name) 50 | hai_classes.append(hai_classes_neuron) 51 | 52 | 53 | # Compute pairwise LCA 54 | def get_lca(x, y): 55 | lca = 0 56 | for a, b in zip(x, y): 57 | if a == b: 58 | lca += 1 59 | else: 60 | break 61 | return lca 62 | 63 | 64 | lcas_majority = [] 65 | lcas_mean = [] 66 | 67 | for neuron in range(num_neurons): 68 | lcas_neuron = [] 69 | hai_classes_neuron = hai_classes[neuron] 70 | for x, y in combinations(hai_classes_neuron, 2): 71 | lca = get_lca(x, y) 72 | lcas_neuron.append(lca) 73 | 74 | if lcas_neuron: 75 | lcas_mean.append(round(sum(lcas_neuron) / len(lcas_neuron))) 76 | lcas_majority.append(Counter(lcas_neuron).most_common(1)[0][0]) 77 | 78 | # Compute avg depth of LCA 79 | assert np.isclose(sum(args.group_fractions), 1.0), "group_fractions must sum to 1.0" 80 | group_sizes = [int(f * num_neurons) for f in args.group_fractions[:-1]] 81 | group_sizes.append(num_neurons - sum(group_sizes)) # Ensure it adds up to num_neurons 82 | 83 | start_idx = 0 84 | depths_majority = [] 85 | depths_mean = [] 86 | 87 | for group_idx, group_size in enumerate(group_sizes): 88 | end_idx = start_idx + group_size 89 | valid_mask = worst_scores[start_idx:end_idx] != 0.0 90 | valid_lcas_majority = np.compress(valid_mask, lcas_majority[start_idx:end_idx]) 91 | valid_lcas_mean = np.compress(valid_mask, lcas_mean[start_idx:end_idx]) 92 | depths_majority.append(np.mean(valid_lcas_majority)) 93 | depths_mean.append(np.mean(valid_lcas_mean)) 94 | start_idx = end_idx 95 | 96 | num_excluded = 
np.sum(~valid_mask) 97 | percentage_excluded = (num_excluded / group_size) * 100 98 | print(f"Group {group_idx}: {percentage_excluded:.2f}% neurons excluded") 99 | 100 | print("Group-wise Average WordNet Depths (mean):", depths_mean) 101 | print("Group-wise Average WordNet Depths (majority):", depths_majority) -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os.path 3 | import argparse 4 | from datasets.activations import ActivationsDataset 5 | import os 6 | 7 | from torch.utils.data import DataLoader, Subset 8 | import tqdm 9 | import torch.nn.functional as F 10 | 11 | def get_args_parser(): 12 | parser = argparse.ArgumentParser("Measure monosemanticity via weighted pairwise cosine similarity", add_help=False) 13 | parser.add_argument("--embeddings_path") 14 | parser.add_argument("--activations_dir") 15 | parser.add_argument("--output_subdir") 16 | parser.add_argument("--device", default="cpu") 17 | return parser 18 | 19 | def main(args): 20 | # Load embeddings 21 | embeddings = torch.load(args.embeddings_path, map_location=torch.device(args.device)) 22 | print(f"Loaded embeddings found at {args.embeddings_path}") 23 | print(f"Embeddings shape: {embeddings.shape}") 24 | 25 | # Load activations 26 | activations_dataset = ActivationsDataset(args.activations_dir, device=torch.device(args.device), take_every=1) 27 | activations_dataloader = DataLoader(activations_dataset, batch_size=len(activations_dataset), shuffle=False) 28 | activations = next(iter(activations_dataloader)) 29 | print(f"Loaded activations found at {args.activations_dir}") 30 | print(f"Activations shape: {activations.shape}") 31 | 32 | # Scale to 0-1 per neuron 33 | min_values = activations.min(dim=0, keepdim=True)[0] 34 | max_values = activations.max(dim=0, keepdim=True)[0] 35 | activations = (activations - min_values) / (max_values - min_values) 36 | 37 | # embeddings = embeddings - embeddings.mean(dim=0, keepdim=True) 38 | num_images, embed_dim = embeddings.shape 39 | num_neurons = activations.shape[1] 40 | 41 | # Initialize accumulators 42 | weighted_cosine_similarity_sum = torch.zeros(num_neurons, device=torch.device(args.device)) 43 | weight_sum = torch.zeros(num_neurons, device=torch.device(args.device)) 44 | batch_size = 100 # Set batch size 45 | 46 | for i in tqdm.tqdm(range(num_images), desc="Processing image pairs"): 47 | for j_start in range(i + 1, num_images, batch_size): # Process in batches 48 | j_end = min(j_start + batch_size, num_images) 49 | 50 | embeddings_i = embeddings[i].cuda() # (embedding_dim) 51 | embeddings_j = embeddings[j_start:j_end].cuda() # (batch_size, embedding_dim) 52 | activations_i = activations[i].cuda() # (num_neurons) 53 | activations_j = activations[j_start:j_end].cuda() # (batch_size, num_neurons) 54 | 55 | # Compute cosine similarity 56 | cosine_similarities = F.cosine_similarity( 57 | embeddings_i.unsqueeze(0).expand(j_end - j_start, -1), # Expanding to (batch_size, embedding_dim) 58 | embeddings_j, 59 | dim=1 60 | ) 61 | 62 | # Compute weights and weighted similarities 63 | # Expanding activations_i to (1, num_neurons) 64 | weights = activations_i.unsqueeze(0) * activations_j # (batch_size, num_neurons) 65 | weighted_cosine_similarities = weights * cosine_similarities.unsqueeze(1) # (batch_size, num_neurons) 66 | 67 | weighted_cosine_similarities = torch.sum(weighted_cosine_similarities, dim=0) # (num_neurons) 68 | 
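# Accumulate running sums over all image pairs; after both loops finish, the
# per-neuron Monosemanticity Score is
#   MS[n] = sum_{i<j} a_i[n] * a_j[n] * cos(e_i, e_j) / sum_{i<j} a_i[n] * a_j[n],
# i.e. an activation-weighted average of pairwise embedding similarities.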
weighted_cosine_similarity_sum += weighted_cosine_similarities.cpu() 69 | 70 | weights = torch.sum(weights, dim=0) # (num_neurons) 71 | weight_sum += weights.cpu() 72 | 73 | monosemanticity = torch.where(weight_sum != 0, weighted_cosine_similarity_sum / weight_sum, torch.nan) 74 | 75 | os.makedirs(os.path.join(args.activations_dir, args.output_subdir), exist_ok=True) 76 | torch.save(monosemanticity, os.path.join(args.activations_dir, args.output_subdir, "all_neurons_scores.pth")) 77 | 78 | is_nan = torch.isnan(monosemanticity) 79 | nan_count = is_nan.sum() 80 | monosemanticity_mean = torch.mean(monosemanticity[~is_nan]) 81 | monosemanticity_std = torch.std(monosemanticity[~is_nan]) 82 | 83 | print(f"Monosemanticity: {monosemanticity_mean.item()} +- {monosemanticity_std.item()}") 84 | print(f"Dead neurons:", nan_count.item()) 85 | print(f"Total neurons:", num_neurons) 86 | 87 | # Filter out NaNs 88 | valid_indices = ~torch.isnan(monosemanticity) 89 | valid_monosemanticity = monosemanticity[valid_indices] 90 | valid_indices = torch.nonzero(valid_indices).squeeze() 91 | 92 | # Get top 10 highest and lowest monosemantic neurons 93 | top_10_values, top_10_indices = torch.topk(valid_monosemanticity, 10) 94 | bottom_10_values, bottom_10_indices = torch.topk(valid_monosemanticity, 10, largest=False) 95 | 96 | # Map indices back to original positions 97 | top_10_indices = valid_indices[top_10_indices] 98 | bottom_10_indices = valid_indices[bottom_10_indices] 99 | 100 | # Print results 101 | print("Top 10 most monosemantic neurons:") 102 | for i, (idx, val) in enumerate(zip(top_10_indices, top_10_values)): 103 | print(f"{i + 1}. Neuron {idx.item()} - {val.item()}") 104 | 105 | print("\nBottom 10 least monosemantic neurons:") 106 | for i, (idx, val) in enumerate(zip(bottom_10_indices, bottom_10_values)): 107 | print(f"{i + 1}. 
Neuron {idx.item()} - {val.item()}") 108 | 109 | # Save to file 110 | output_path = os.path.join(args.activations_dir, args.output_subdir, "metric_stats_new.txt") 111 | with open(output_path, "w") as file: 112 | file.write(f"Monosemanticity: {monosemanticity_mean.item()} +- {monosemanticity_std.item()}\n") 113 | file.write(f"Dead neurons: {nan_count.item()}\n") 114 | file.write(f"Total neurons: {num_neurons}\n\n") 115 | 116 | file.write("Top 10 most monosemantic neurons:\n") 117 | for idx, val in zip(top_10_indices, top_10_values): 118 | file.write(f"Neuron {idx.item()} - {val.item()}\n") 119 | 120 | file.write("\nBottom 10 least monosemantic neurons:\n") 121 | for idx, val in zip(bottom_10_indices, bottom_10_values): 122 | file.write(f"Neuron {idx.item()} - {val.item()}\n") 123 | 124 | 125 | if __name__ == "__main__": 126 | args = get_args_parser() 127 | args = args.parse_args() 128 | main(args) 129 | -------------------------------------------------------------------------------- /models/clip.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoProcessor, CLIPVisionModelWithProjection 2 | import torch 3 | import torch.nn as nn 4 | from typing import Optional, Tuple 5 | 6 | class Clip: 7 | def __init__(self, model_name, device): 8 | self.device = device 9 | self.model = CLIPVisionModelWithProjection.from_pretrained(f"openai/{model_name}").to(device) 10 | self.processor = AutoProcessor.from_pretrained(f"openai/{model_name}") 11 | self.register = {} 12 | self.attach_methods = { 13 | 'in_mlp': self._attach_in_mlp, 14 | 'post_mlp': self._attach_post_mlp, 15 | 'post_mlp_residual': self._attach_post_mlp_residual, 16 | 'post_projection': self._attach_post_projection, 17 | } 18 | 19 | def encode(self, inputs): 20 | for hook in self.register.keys(): 21 | self.register[hook] = [] 22 | outputs = self.model(**inputs.to(self.device)) 23 | # pooled_output = outputs.pooler_output 24 | # return pooled_output 25 | image_embeds = outputs.image_embeds 26 | return image_embeds 27 | 28 | def attach(self, attachment_point, layer, sae=None): 29 | if attachment_point in self.attach_methods: 30 | self.attach_methods[attachment_point](layer, sae) 31 | self.register[f'{attachment_point}_{layer}'] = [] 32 | else: 33 | raise NotImplementedError(f"Attachment point {attachment_point} not implemented") 34 | 35 | def _attach_in_mlp(self, layer, sae): 36 | raise NotImplementedError 37 | 38 | def _attach_post_mlp(self, layer, sae): 39 | raise NotImplementedError 40 | 41 | def _attach_post_mlp_residual(self, layer, sae): 42 | self.model.vision_model.encoder.layers[layer] = CLIPEncoderLayerPostMlpResidual( 43 | self.model.vision_model.encoder.layers[layer], 44 | sae, 45 | layer, 46 | self.register, 47 | ) 48 | 49 | def _attach_post_projection(self, layer, sae): 50 | self.model.visual_projection = CLIPProjectionLayer( 51 | self.model.visual_projection, 52 | sae, 53 | layer, 54 | self.register, 55 | ) 56 | 57 | class CLIPProjectionLayer(nn.Module): 58 | def __init__(self, projector, sae, layer, register): 59 | super().__init__() 60 | self.projector = projector 61 | self.sae = sae 62 | self.layer = layer 63 | self.register = register 64 | 65 | def forward(self, inputs): 66 | outputs = self.projector(inputs) 67 | if self.sae is not None: 68 | outputs = self.sae.encode(outputs) 69 | self.register[f'post_projection_{self.layer}'].append(outputs.detach().cpu()) 70 | outputs = self.sae.decode(outputs) 71 | else: 72 | 
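# No SAE attached: record the raw post-projection embeddings instead.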
self.register[f'post_projection_{self.layer}'].append(outputs.detach().cpu()) 73 | return outputs 74 | 75 | class CLIPEncoderLayerPostMlpResidual(nn.Module): 76 | def __init__(self, base, sae, layer, register): 77 | super().__init__() 78 | self.embed_dim = base.embed_dim 79 | self.self_attn = base.self_attn 80 | self.layer_norm1 = base.layer_norm1 81 | self.mlp = base.mlp 82 | self.layer_norm2 = base.layer_norm2 83 | 84 | self.sae = sae 85 | self.layer = layer 86 | self.register = register 87 | 88 | def forward( 89 | self, 90 | hidden_states: torch.Tensor, 91 | attention_mask: torch.Tensor, 92 | causal_attention_mask: torch.Tensor, 93 | output_attentions: Optional[bool] = False, 94 | ) -> Tuple[torch.FloatTensor]: 95 | residual = hidden_states 96 | 97 | hidden_states = self.layer_norm1(hidden_states) 98 | hidden_states, attn_weights = self.self_attn( 99 | hidden_states=hidden_states, 100 | attention_mask=attention_mask, 101 | causal_attention_mask=causal_attention_mask, 102 | output_attentions=output_attentions, 103 | ) 104 | hidden_states = residual + hidden_states 105 | 106 | residual = hidden_states 107 | hidden_states = self.layer_norm2(hidden_states) 108 | hidden_states = self.mlp(hidden_states) 109 | hidden_states = residual + hidden_states 110 | 111 | if self.sae is not None: 112 | hidden_states = self.sae.encode(hidden_states) 113 | self.register[f'post_mlp_residual_{self.layer}'].append(hidden_states.detach().cpu()) 114 | hidden_states = self.sae.decode(hidden_states) 115 | else: 116 | self.register[f'post_mlp_residual_{self.layer}'].append(hidden_states.detach().cpu()) 117 | 118 | outputs = (hidden_states,) 119 | 120 | if output_attentions: 121 | outputs += (attn_weights,) 122 | 123 | return outputs 124 | -------------------------------------------------------------------------------- /models/dino.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoImageProcessor, Dinov2Model 2 | import torch 3 | 4 | class Dino: 5 | def __init__(self, model_name="dinov2-base", device=torch.device("cuda")): 6 | self.device = device 7 | self.model = Dinov2Model.from_pretrained(f"facebook/{model_name}").to(device) 8 | self.processor = AutoImageProcessor.from_pretrained(f"facebook/{model_name}") 9 | 10 | def encode(self, inputs): 11 | outputs = self.model(**inputs.to(self.device)) 12 | image_embeds = outputs.pooler_output 13 | return image_embeds 14 | -------------------------------------------------------------------------------- /models/llava.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoProcessor, LlavaForConditionalGeneration 2 | import torch 3 | import torch.nn as nn 4 | from typing import Optional, Tuple 5 | import copy 6 | 7 | class Llava: 8 | 9 | def __init__(self, device): 10 | self.device = device 11 | self.layer = 22 12 | self.model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", 13 | torch_dtype=torch.float16, 14 | device_map=self.device) 15 | self.processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") 16 | self.base_CLIPEncoderLayerPostMlpResidual = copy.deepcopy( 17 | self.model.vision_tower.vision_model.encoder.layers[self.layer] 18 | ) 19 | 20 | def prompt(self, text, image, max_tokens=5): 21 | conversation = [ 22 | { 23 | "role": "user", 24 | "content": [ 25 | {"type": "image"}, 26 | {"type": "text", "text": text}, 27 | ], 28 | }, 29 | ] 30 | prompt = self.processor.apply_chat_template(conversation, 
add_generation_prompt=True) 31 | inputs = self.processor(images=[image], text=[prompt], 32 | padding=True, return_tensors="pt").to(self.model.device, torch.float16) 33 | generate_ids = self.model.generate(**inputs, max_new_tokens=max_tokens) 34 | output = self.processor.batch_decode(generate_ids, skip_special_tokens=True) 35 | output = [x.split('ASSISTANT: ')[-1] for x in output] 36 | return output 37 | 38 | def attach_and_fix(self, sae, neurons_to_fix={}, pre_zero=False): 39 | modified_sae = SAEWrapper(sae, neurons_to_fix, pre_zero) 40 | self.model.vision_tower.vision_model.encoder.layers[self.layer] = CLIPEncoderLayerPostMlpResidual( 41 | self.base_CLIPEncoderLayerPostMlpResidual, 42 | modified_sae, 43 | ) 44 | 45 | 46 | class SAEWrapper(nn.Module): 47 | 48 | def __init__(self, sae, neurons_to_fix, pre_zero): 49 | super().__init__() 50 | self.sae = sae 51 | self.neurons_to_fix = neurons_to_fix 52 | self.pre_zero = pre_zero 53 | 54 | def encode(self, x): 55 | x = self.sae.encode(x) 56 | if self.pre_zero: 57 | x = torch.zeros_like(x) 58 | for neuron_id, value in self.neurons_to_fix.items(): 59 | x[:, :, neuron_id] = value 60 | return x 61 | 62 | def decode(self, x): 63 | x = self.sae.decode(x) 64 | x = x.to(dtype=torch.float16) 65 | return x 66 | 67 | 68 | 69 | 70 | class CLIPEncoderLayerPostMlpResidual(nn.Module): 71 | 72 | def __init__(self, base, sae): 73 | super().__init__() 74 | self.embed_dim = base.embed_dim 75 | self.self_attn = base.self_attn 76 | self.layer_norm1 = base.layer_norm1 77 | self.mlp = base.mlp 78 | self.layer_norm2 = base.layer_norm2 79 | self.sae = sae 80 | 81 | def forward( 82 | self, 83 | hidden_states: torch.Tensor, 84 | attention_mask: torch.Tensor, 85 | causal_attention_mask: torch.Tensor, 86 | output_attentions: Optional[bool] = False, 87 | ) -> Tuple[torch.FloatTensor]: 88 | residual = hidden_states 89 | hidden_states = self.layer_norm1(hidden_states) 90 | hidden_states, attn_weights = self.self_attn( 91 | hidden_states=hidden_states, 92 | attention_mask=attention_mask, 93 | causal_attention_mask=causal_attention_mask, 94 | output_attentions=output_attentions, 95 | ) 96 | hidden_states = residual + hidden_states 97 | residual = hidden_states 98 | hidden_states = self.layer_norm2(hidden_states) 99 | hidden_states = self.mlp(hidden_states) 100 | hidden_states = residual + hidden_states 101 | encoded_hidden_states = self.sae.encode(hidden_states) 102 | decoded_hidden_states = self.sae.decode(encoded_hidden_states) 103 | hidden_states = decoded_hidden_states 104 | outputs = (hidden_states,) 105 | if output_attentions: 106 | outputs += (attn_weights,) 107 | return outputs -------------------------------------------------------------------------------- /models/siglip.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoProcessor, SiglipVisionModel 2 | import torch 3 | import torch.nn as nn 4 | from typing import Optional, Tuple 5 | 6 | 7 | # siglip-so400m-patch14-384 8 | class Siglip: 9 | def __init__(self, model_name, device): 10 | self.device = device 11 | self.model = SiglipVisionModel.from_pretrained(f"google/{model_name}").to(device) 12 | self.processor = AutoProcessor.from_pretrained(f"google/{model_name}") 13 | self.register = {} 14 | self.attach_methods = { 15 | 'post_mlp_residual': self._attach_post_mlp_residual, 16 | 'post_projection': self._attach_post_projection, 17 | } 18 | self.sae = None 19 | self.layer = None 20 | 21 | def encode(self, inputs): 22 | for hook in self.register.keys(): 23 
| self.register[hook] = [] 24 | outputs = self.model(**inputs.to(self.device)) 25 | pooled_output = outputs.pooler_output 26 | 27 | if self.sae is not None: 28 | pooled_output = self.sae.encode(pooled_output) 29 | self.register[f'post_projection_{self.layer}'].append(pooled_output.detach().cpu()) 30 | pooled_output = self.sae.decode(pooled_output) 31 | elif self.layer is not None: 32 | self.register[f'post_projection_{self.layer}'].append(pooled_output.detach().cpu()) 33 | 34 | return pooled_output 35 | 36 | def attach(self, attachment_point, layer, sae=None): 37 | if attachment_point in self.attach_methods: 38 | self.attach_methods[attachment_point](layer, sae) 39 | self.register[f'{attachment_point}_{layer}'] = [] 40 | else: 41 | raise NotImplementedError(f"Attachment point {attachment_point} not implemented") 42 | 43 | def _attach_post_mlp_residual(self, layer, sae): 44 | self.model.vision_model.encoder.layers[layer] = SiglipEncoderLayerPostMlpResidual( 45 | self.model.vision_model.encoder.layers[layer], 46 | sae, 47 | layer, 48 | self.register, 49 | ) 50 | 51 | def _attach_post_projection(self, layer, sae): 52 | self.sae = sae 53 | self.layer = layer 54 | 55 | 56 | class SiglipEncoderLayerPostMlpResidual(nn.Module): 57 | def __init__(self, base, sae, layer, register): 58 | super().__init__() 59 | self.embed_dim = base.embed_dim 60 | self.self_attn = base.self_attn 61 | self.layer_norm1 = base.layer_norm1 62 | self.mlp = base.mlp 63 | self.layer_norm2 = base.layer_norm2 64 | 65 | self.sae = sae 66 | self.layer = layer 67 | self.register = register 68 | 69 | def forward( 70 | self, 71 | hidden_states: torch.Tensor, 72 | attention_mask: torch.Tensor, 73 | output_attentions: Optional[bool] = False, 74 | ) -> Tuple[torch.FloatTensor]: 75 | residual = hidden_states 76 | 77 | hidden_states = self.layer_norm1(hidden_states) 78 | hidden_states, attn_weights = self.self_attn( 79 | hidden_states=hidden_states, 80 | attention_mask=attention_mask, 81 | output_attentions=output_attentions, 82 | ) 83 | hidden_states = residual + hidden_states 84 | 85 | residual = hidden_states 86 | hidden_states = self.layer_norm2(hidden_states) 87 | hidden_states = self.mlp(hidden_states) 88 | hidden_states = residual + hidden_states 89 | 90 | if self.sae is not None: 91 | hidden_states = self.sae.encode(hidden_states) 92 | self.register[f'post_mlp_residual_{self.layer}'].append(hidden_states.detach().cpu()) 93 | hidden_states = self.sae.decode(hidden_states) 94 | else: 95 | self.register[f'post_mlp_residual_{self.layer}'].append(hidden_states.detach().cpu()) 96 | 97 | outputs = (hidden_states,) 98 | 99 | if output_attentions: 100 | outputs += (attn_weights,) 101 | 102 | return outputs -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.1.1 2 | aiohappyeyeballs==2.4.3 3 | aiohttp==3.11.8 4 | aiosignal==1.3.1 5 | annotated-types==0.7.0 6 | anyio==4.7.0 7 | argon2-cffi==23.1.0 8 | argon2-cffi-bindings==21.2.0 9 | arrow==1.3.0 10 | asttokens==3.0.0 11 | async-lru==2.0.4 12 | attrs==24.2.0 13 | babel==2.16.0 14 | beautifulsoup4==4.12.3 15 | bidict==0.23.1 16 | bleach==6.2.0 17 | certifi==2024.8.30 18 | cffi==1.17.1 19 | charset-normalizer==3.4.0 20 | circuitsvis==1.43.2 21 | click==8.1.7 22 | comm==0.2.2 23 | contourpy==1.3.1 24 | cycler==0.12.1 25 | datasets==3.1.0 26 | debugpy==1.8.9 27 | decorator==5.1.1 28 | defusedxml==0.7.1 29 | diffusers==0.31.0 30 | dill==0.3.8 31 | 
docker-pycreds==0.4.0 32 | einops==0.8.0 33 | executing==2.1.0 34 | fastjsonschema==2.21.1 35 | filelock==3.16.1 36 | fonttools==4.55.0 37 | fqdn==1.5.1 38 | frozenlist==1.5.0 39 | fsspec==2024.9.0 40 | gitdb==4.0.11 41 | GitPython==3.1.43 42 | h11==0.14.0 43 | httpcore==1.0.7 44 | httpx==0.28.1 45 | huggingface-hub==0.26.3 46 | idna==3.10 47 | importlib_metadata==8.5.0 48 | ipykernel==6.29.5 49 | ipython==8.30.0 50 | ipywidgets==8.1.5 51 | isoduration==20.11.0 52 | jedi==0.19.2 53 | Jinja2==3.1.4 54 | joblib==1.4.2 55 | json5==0.10.0 56 | jsonpointer==3.0.0 57 | jsonschema==4.23.0 58 | jsonschema-specifications==2024.10.1 59 | jupyter==1.1.1 60 | jupyter-console==6.6.3 61 | jupyter-events==0.10.0 62 | jupyter-lsp==2.2.5 63 | jupyter_client==8.6.3 64 | jupyter_core==5.7.2 65 | jupyter_server==2.14.2 66 | jupyter_server_terminals==0.5.3 67 | jupyterlab==4.3.3 68 | jupyterlab_pygments==0.3.0 69 | jupyterlab_server==2.27.3 70 | jupyterlab_widgets==3.0.13 71 | kiwisolver==1.4.7 72 | llvmlite==0.43.0 73 | MarkupSafe==3.0.2 74 | matplotlib==3.9.3 75 | matplotlib-inline==0.1.7 76 | mistune==3.0.2 77 | mpmath==1.3.0 78 | multidict==6.1.0 79 | multiprocess==0.70.16 80 | nbclient==0.10.1 81 | nbconvert==7.16.4 82 | nbformat==5.10.4 83 | nest-asyncio==1.6.0 84 | networkx==3.4.2 85 | nltk==3.9.1 86 | nnsight==0.3.3 87 | notebook==7.3.1 88 | notebook_shim==0.2.4 89 | numba==0.60.0 90 | numpy==1.26.4 91 | nvidia-cublas-cu12==12.1.3.1 92 | nvidia-cuda-cupti-cu12==12.1.105 93 | nvidia-cuda-nvrtc-cu12==12.1.105 94 | nvidia-cuda-runtime-cu12==12.1.105 95 | nvidia-cudnn-cu12==8.9.2.26 96 | nvidia-cufft-cu12==11.0.2.54 97 | nvidia-curand-cu12==10.3.2.106 98 | nvidia-cusolver-cu12==11.4.5.107 99 | nvidia-cusparse-cu12==12.1.0.106 100 | nvidia-nccl-cu12==2.18.1 101 | nvidia-nvjitlink-cu12==12.6.85 102 | nvidia-nvtx-cu12==12.1.105 103 | overrides==7.7.0 104 | packaging==24.2 105 | pandas==2.2.3 106 | pandocfilters==1.5.1 107 | parso==0.8.4 108 | pexpect==4.9.0 109 | pillow==11.0.0 110 | platformdirs==4.3.6 111 | plotly==5.24.1 112 | prometheus_client==0.21.1 113 | prompt_toolkit==3.0.48 114 | propcache==0.2.0 115 | protobuf==5.29.0 116 | psutil==6.1.0 117 | ptyprocess==0.7.0 118 | pure_eval==0.2.3 119 | pyarrow==18.1.0 120 | pycparser==2.22 121 | pydantic==2.10.2 122 | pydantic_core==2.27.1 123 | Pygments==2.18.0 124 | pynndescent==0.5.13 125 | pyparsing==3.2.0 126 | python-dateutil==2.9.0.post0 127 | python-engineio==4.10.1 128 | python-json-logger==2.0.7 129 | python-socketio==5.11.4 130 | pytz==2024.2 131 | PyYAML==6.0.2 132 | pyzmq==26.2.0 133 | referencing==0.35.1 134 | regex==2024.11.6 135 | requests==2.32.3 136 | rfc3339-validator==0.1.4 137 | rfc3986-validator==0.1.1 138 | rpds-py==0.22.3 139 | safetensors==0.4.5 140 | scikit-learn==1.5.2 141 | scipy==1.14.1 142 | Send2Trash==1.8.3 143 | sentencepiece==0.2.0 144 | sentry-sdk==2.19.0 145 | setproctitle==1.3.4 146 | simple-websocket==1.1.0 147 | six==1.16.0 148 | smmap==5.0.1 149 | sniffio==1.3.1 150 | soupsieve==2.6 151 | stack-data==0.6.3 152 | sympy==1.13.3 153 | tenacity==9.0.0 154 | terminado==0.18.1 155 | threadpoolctl==3.5.0 156 | tinycss2==1.4.0 157 | tokenizers==0.20.3 158 | torch==2.1.2 159 | torchvision==0.16.2 160 | tornado==6.4.2 161 | tqdm==4.67.1 162 | traitlets==5.14.3 163 | transformers==4.46.3 164 | triton==2.1.0 165 | types-python-dateutil==2.9.0.20241206 166 | typing_extensions==4.12.2 167 | tzdata==2024.2 168 | umap-learn==0.5.7 169 | uri-template==1.3.0 170 | urllib3==2.2.3 171 | wandb==0.18.7 172 | wcwidth==0.2.13 173 | 
webcolors==24.11.1 174 | webencodings==0.5.1 175 | websocket-client==1.8.0 176 | widgetsnbextension==4.0.13 177 | wsproto==1.2.0 178 | xxhash==3.5.0 179 | yarl==1.18.0 180 | zipp==3.21.0 181 | zstandard==0.23.0 182 | -------------------------------------------------------------------------------- /sae_train.py: -------------------------------------------------------------------------------- 1 | from dictionary_learning import ActivationBuffer, AutoEncoder, JumpReluAutoEncoder 2 | from dictionary_learning.trainers import * 3 | from dictionary_learning.training import trainSAE 4 | from torch.utils.data import DataLoader 5 | from datasets.activations import ActivationsDataset 6 | import torch 7 | from pathlib import Path 8 | import argparse 9 | 10 | def get_args_parser(): 11 | parser = argparse.ArgumentParser("Train Sparse Autoencoder", add_help=False) 12 | parser.add_argument("--sae_model", default="jumprelu", type=str) 13 | parser.add_argument("--activations_dir", required=True, type=str) 14 | parser.add_argument("--val_activations_dir", required=True, type=str) 15 | parser.add_argument("--checkpoints_dir", default="./output_dir", type=str) 16 | parser.add_argument("--device", default="cuda:0") 17 | parser.add_argument("--expansion_factor", type=int, default=1) 18 | parser.add_argument("--lr", type=float) 19 | parser.add_argument("--batch_size", type=int, default=8192) 20 | parser.add_argument("--steps", type=int, default=10_000) 21 | parser.add_argument("--save_steps", type=int, default=1_000) 22 | parser.add_argument("--log_steps", type=int, default=50) 23 | # JumpRelu 24 | parser.add_argument("--bandwidth", type=float, default=0.001) 25 | parser.add_argument("--sparsity_penalty", type=float, default=0.1) 26 | # Standard 27 | parser.add_argument("--l1_penalty", type=float, default=0.1) 28 | parser.add_argument("--warmup_steps", type=int, default=0) 29 | parser.add_argument("--resample_steps", type=int, default=None) 30 | # TopK + Batch TopK 31 | parser.add_argument("--k", type=int, default=8) 32 | parser.add_argument("--auxk_alpha", type=float, default=1/32) 33 | parser.add_argument("--decay_start", type=int, default=1_000_000) 34 | # Batch TopK 35 | parser.add_argument("--threshold_beta", type=float, default=0.999) 36 | parser.add_argument("--threshold_start_step", type=int, default=1000) 37 | # MatryoshkaBatchTopK 38 | parser.add_argument("--group_fractions", type=float, nargs="+") 39 | 40 | return parser 41 | 42 | def train_sae(args): 43 | dataset = ActivationsDataset(args.activations_dir, device=torch.device(args.device)) 44 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) 45 | 46 | val_dataset = ActivationsDataset(args.val_activations_dir, device=torch.device(args.device)) 47 | val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False) 48 | 49 | sample = next(iter(dataloader)) 50 | print(sample.shape) 51 | 52 | activation_dim = sample.shape[1] 53 | dictionary_size = args.expansion_factor * activation_dim 54 | 55 | trainers = { 56 | 'jumprelu': JumpReluTrainer, 57 | 'standard': StandardTrainer, 58 | 'batch_top_k': BatchTopKTrainer, 59 | 'top_k': TopKTrainer, 60 | 'matroyshka_batch_top_k': MatroyshkaBatchTopKTrainer, 61 | } 62 | 63 | # autoencoders = { 64 | # 'jumprelu': JumpReluAutoEncoder, 65 | # 'standard': AutoEncoder, 66 | # 'batch_top_k': BatchTopKSAE, 67 | # 'top_k': AutoEncoderTopK, 68 | # } 69 | 70 | trainer_cfg = { 71 | "trainer": trainers[args.sae_model], 72 | "activation_dim": activation_dim, 73 | "dict_size": 
dictionary_size, 74 | "lr": args.lr, 75 | "device": args.device, 76 | "steps": args.steps, 77 | "layer": "", 78 | "lm_name": "", 79 | "submodule_name": "" 80 | } 81 | 82 | if args.sae_model == "jumprelu": 83 | trainer_cfg["bandwidth"] = args.bandwidth 84 | trainer_cfg["sparsity_penalty"] = args.sparsity_penalty 85 | if args.sae_model == "standard": 86 | trainer_cfg["l1_penalty"] = args.l1_penalty 87 | trainer_cfg["warmup_steps"] = args.warmup_steps 88 | trainer_cfg["resample_steps"] = args.resample_steps 89 | if args.sae_model == "top_k" or args.sae_model == "batch_top_k" or args.sae_model == "matroyshka_batch_top_k": 90 | trainer_cfg["k"] = args.k 91 | trainer_cfg["auxk_alpha"] = args.auxk_alpha 92 | trainer_cfg["decay_start"] = args.decay_start 93 | if args.sae_model == "batch_top_k" or args.sae_model == "matroyshka_batch_top_k": 94 | trainer_cfg["threshold_beta"] = args.threshold_beta 95 | trainer_cfg["threshold_start_step"] = args.threshold_start_step 96 | if args.sae_model == "matroyshka_batch_top_k": 97 | trainer_cfg["group_fractions"] = args.group_fractions 98 | 99 | dataset_name = Path(args.activations_dir).name 100 | save_dir = Path(args.checkpoints_dir) / f"{dataset_name}_{args.sae_model}_{args.k}_x{args.expansion_factor}" 101 | save_dir.mkdir(parents=True, exist_ok=True) 102 | 103 | ae = trainSAE( 104 | data=dataloader, 105 | val_data=val_dataloader, 106 | trainer_configs=[trainer_cfg], 107 | use_wandb=True, 108 | wandb_entity="mateuszpach", 109 | wandb_project="Clip SAE", 110 | steps=args.steps, 111 | save_steps=[x for x in range(0, args.steps, args.save_steps)], 112 | save_dir=save_dir, 113 | log_steps=args.log_steps, 114 | ) 115 | 116 | 117 | if __name__ == "__main__": 118 | args = get_args_parser() 119 | args = args.parse_args() 120 | train_sae(args) -------------------------------------------------------------------------------- /save_activations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | from torch.utils.data import DataLoader 5 | import tqdm 6 | import argparse 7 | from pathlib import Path 8 | from torchvision.datasets import ImageFolder 9 | from utils import get_dataset, get_model 10 | from torchvision.transforms import ToTensor 11 | from dictionary_learning import AutoEncoder 12 | from dictionary_learning.trainers import BatchTopKSAE, MatroyshkaBatchTopKSAE 13 | 14 | 15 | def get_args_parser(): 16 | parser = argparse.ArgumentParser("Save activations used to train SAE", add_help=False) 17 | parser.add_argument("--batch_size", default=128, type=int) 18 | parser.add_argument("--sae_model", default=None, type=str) 19 | parser.add_argument("--model_name", default="clip", type=str) 20 | parser.add_argument("--attachment_point", default="post_mlp_residual", type=str) 21 | parser.add_argument("--layer", default=-1, type=int) 22 | parser.add_argument("--sae_path", default=None, type=str) 23 | parser.add_argument("--dataset_name", default="imagenet", type=str) 24 | parser.add_argument("--data_path", default="/shared-network/inat2021", type=str) 25 | parser.add_argument("--split", default="train", type=str) 26 | parser.add_argument("--num_workers", default=10, type=int) 27 | parser.add_argument("--output_dir", default="./output_dir", type=str) 28 | parser.add_argument("--cls_only", default=False, action="store_true") 29 | parser.add_argument("--mean_pool", default=False, action="store_true") 30 | parser.add_argument("--take_every", default=1, type=int) 31 | 
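# Token selection: --cls_only keeps only the CLS token, --mean_pool averages all
# tokens per image, --random_k keeps k randomly sampled tokens per image, and
# --take_every keeps every n-th image's activations; see save_activations() below.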
parser.add_argument("--random_k", default=-1, type=int) 32 | parser.add_argument("--save_every", default=50_000, type=int) 33 | parser.add_argument("--device", default="cuda:0") 34 | return parser 35 | 36 | def save_activations(activations, count, split, save_count, args): 37 | activations_tensor = torch.cat(activations, dim=0) 38 | if args.take_every > 1: 39 | # Pick every n-th activation in the batch 40 | activations_tensor = activations_tensor[::args.take_every, :, :] 41 | 42 | if args.layer == -1: 43 | # Tokens already pooled 44 | activations_tensor = activations_tensor 45 | elif args.cls_only: 46 | # Keep only CLS token 47 | activations_tensor = activations_tensor[:, 0, :] 48 | elif args.mean_pool: 49 | # Mean pool tokens into one data point 50 | activations_tensor = torch.mean(activations_tensor, dim=1) 51 | elif args.random_k != -1: 52 | # Treat each token as a separate data point but pick random k tokens from each image 53 | batch_size, seq_len, hidden_dim = activations_tensor.shape 54 | indices = torch.randint(0, seq_len, (batch_size, args.random_k)) 55 | activations_tensor = torch.stack([activations_tensor[i, indices[i], :] for i in range(batch_size)]) 56 | activations_tensor = activations_tensor.reshape(-1, hidden_dim) 57 | else: 58 | # Treat each token as a separate data point and use all the tokens 59 | activations_tensor = activations_tensor.reshape(activations_tensor.shape[0] * activations_tensor.shape[1], 60 | activations_tensor.shape[2]) 61 | 62 | filename = f"{args.dataset_name}_{split}_activations_{args.model_name}_{args.layer}_{args.attachment_point}_part{save_count + 1}.pt" 63 | save_path = os.path.join(args.output_dir, filename) 64 | torch.save(torch.tensor(activations_tensor.cpu().numpy()), save_path) 65 | print(f"Saved the activations at count {count} to {save_path}") 66 | 67 | def collect_activations(args): 68 | model, processor = get_model(args) 69 | 70 | if args.sae_model is not None: 71 | if args.sae_model == "standard": 72 | sae = AutoEncoder.from_pretrained(args.sae_path).to(args.device) 73 | if args.sae_model == "batch_top_k": 74 | sae = BatchTopKSAE.from_pretrained(args.sae_path).to(args.device) 75 | if args.sae_model == "matroyshka_batch_top_k": 76 | sae = MatroyshkaBatchTopKSAE.from_pretrained(args.sae_path).to(args.device) 77 | print(f"Attached SAE from {args.sae_path}") 78 | else: 79 | sae = None 80 | print(f"No SAE attached. 
Saving original activations") 81 | 82 | model.attach(args.attachment_point, args.layer, sae=sae) 83 | 84 | ds, dl = get_dataset(args, preprocess=None, processor=processor, split=args.split) 85 | activations = [] 86 | count = 0 87 | save_count = 0 88 | pbar = tqdm.tqdm(dl) 89 | for image in pbar: 90 | 91 | with torch.no_grad(): 92 | model.encode(image) 93 | activations.extend(model.register[f"{args.attachment_point}_{args.layer}"]) 94 | 95 | count += image['pixel_values'].shape[0] 96 | pbar.set_postfix({'Processed data points': count}) 97 | 98 | if count >= args.save_every * (save_count + 1): 99 | save_activations(activations, count, args.split, save_count, args) 100 | activations = [] 101 | save_count += 1 102 | 103 | if activations: 104 | save_activations(activations, count, args.split, save_count, args) 105 | 106 | 107 | if __name__ == "__main__": 108 | args = get_args_parser() 109 | args = args.parse_args() 110 | if args.output_dir: 111 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 112 | collect_activations(args) -------------------------------------------------------------------------------- /scripts/matryoshka_hierarchy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATASET_PATH="${INAT_PATH}" 4 | 5 | # 1. Save original activations 6 | for SPLIT in "train" "val"; do 7 | python save_activations.py \ 8 | --batch_size 32 \ 9 | --model_name "clip-vit-large-patch14-336" \ 10 | --attachment_point "post_projection" \ 11 | --layer "-1" \ 12 | --dataset_name "inat" \ 13 | --split "${SPLIT}" \ 14 | --data_path "${DATASET_PATH}" \ 15 | --num_workers 8 \ 16 | --output_dir "./activations_dir/raw/inat_${SPLIT}_activations_clip-vit-large-patch14-336_-1_post_projection" \ 17 | --cls_only \ 18 | --save_every 100 19 | done 20 | 21 | # 2. Train SAE 22 | python sae_train.py \ 23 | --sae_model "matroyshka_batch_top_k" \ 24 | --activations_dir "activations_dir/raw/inat_train_activations_clip-vit-large-patch14-336_-1_post_projection" \ 25 | --val_activations_dir "activations_dir/raw/inat_val_activations_clip-vit-large-patch14-336_-1_post_projection" \ 26 | --checkpoints_dir "checkpoints_dir/matroyshka_batch_top_k_20_x2" \ 27 | --expansion_factor 2 \ 28 | --steps 110000 \ 29 | --save_steps 20000 \ 30 | --log_steps 10000 \ 31 | --batch_size 4096 \ 32 | --k 20 \ 33 | --auxk_alpha 0.03 \ 34 | --decay_start 109999 \ 35 | --group_fractions 0.002 0.009 0.035 0.189 0.765 36 | 37 | # 3. Save SAE activations 38 | python save_activations.py \ 39 | --batch_size 32 \ 40 | --model_name "clip-vit-large-patch14-336" \ 41 | --attachment_point "post_projection" \ 42 | --layer "-1" \ 43 | --dataset_name "inat" \ 44 | --split "train" \ 45 | --data_path "${DATASET_PATH}" \ 46 | --num_workers 8 \ 47 | --output_dir "./activations_dir/matroyshka_batch_top_k_20_x2/inat_train_activations_clip-vit-large-patch14-336_-1_post_projection" \ 48 | --cls_only \ 49 | --save_every 100 \ 50 | --sae_model "matroyshka_batch_top_k" \ 51 | --sae_path "./checkpoints_dir/matroyshka_batch_top_k_20_x2/inat_train_activations_clip-vit-large-patch14-336_-1_post_projection_matroyshka_batch_top_k_20_x2/trainer_0/checkpoints/ae_100000.pt" 52 | 53 | # 4. 
Compute LCA depth per level 54 | python find_hai_indices.py \ 55 | --activations_dir "./activations_dir/matroyshka_batch_top_k_20_x2/inat_train_activations_clip-vit-large-patch14-336_-1_post_projection" \ 56 | --dataset_name "inat" \ 57 | --data_path "${DATASET_PATH}" \ 58 | --split "train" \ 59 | --k 16 \ 60 | --chunk_size 1000 61 | 62 | python inat_depth.py \ 63 | --activations_dir "./activations_dir/matroyshka_batch_top_k_20_x2/inat_train_activations_clip-vit-large-patch14-336_-1_post_projection" \ 64 | --hai_indices_path "./activations_dir/matroyshka_batch_top_k_20_x2/inat_train_activations_clip-vit-large-patch14-336_-1_post_projection/hai_indices_16.npy" \ 65 | --data_path "${DATASET_PATH}" \ 66 | --split "train" \ 67 | --k 16 \ 68 | --group_fractions 0.002 0.009 0.035 0.189 0.765 69 | -------------------------------------------------------------------------------- /scripts/mllm_steering.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATASET_PATH="${IMAGENET_PATH}" 4 | 5 | # 1. Save original activations 6 | for SPLIT in "train" "val"; do 7 | python save_activations.py \ 8 | --batch_size 32 \ 9 | --model_name "clip-vit-large-patch14-336" \ 10 | --attachment_point "post_mlp_residual" \ 11 | --layer 22 \ 12 | --dataset_name "imagenet" \ 13 | --split "${SPLIT}" \ 14 | --data_path "${DATASET_PATH}" \ 15 | --num_workers 8 \ 16 | --output_dir "./activations_dir/raw/random_k_2/imagenet_${SPLIT}_activations_clip-vit-large-patch14-336_22_post_mlp_residual" \ 17 | --random_k 2 \ 18 | --save_every 100 19 | done 20 | 21 | # 2. Train SAE 22 | python sae_train.py \ 23 | --sae_model "matroyshka_batch_top_k" \ 24 | --activations_dir "activations_dir/raw/random_k_2/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual" \ 25 | --val_activations_dir "activations_dir/raw/random_k_2/imagenet_val_activations_clip-vit-large-patch14-336_22_post_mlp_residual" \ 26 | --checkpoints_dir "checkpoints_dir/matroyshka_batch_top_k_20_x64/random_k_2/" \ 27 | --expansion_factor 64 \ 28 | --steps 110000 \ 29 | --save_steps 20000 \ 30 | --log_steps 10000 \ 31 | --batch_size 4096 \ 32 | --k 20 \ 33 | --auxk_alpha 0.03 \ 34 | --decay_start 109999 \ 35 | --group_fractions 0.0625 0.125 0.25 0.5625 36 | 37 | # 3. Save SAE activations 38 | python save_activations.py \ 39 | --batch_size 32 \ 40 | --model_name "clip-vit-large-patch14-336" \ 41 | --attachment_point "post_mlp_residual" \ 42 | --layer 22 \ 43 | --dataset_name "imagenet" \ 44 | --split "train" \ 45 | --data_path "${DATASET_PATH}" \ 46 | --num_workers 8 \ 47 | --output_dir "./activations_dir/matroyshka_batch_top_k_20_x64/mean_pool/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual" \ 48 | --mean_pool \ 49 | --save_every 100 \ 50 | --sae_model "matroyshka_batch_top_k" \ 51 | --sae_path "./checkpoints_dir/matroyshka_batch_top_k_20_x64/random_k_2/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual_matroyshka_batch_top_k_20_x64/trainer_0/checkpoints/ae_100000.pt" 52 | 53 | # 4. 
Visualize neurons 54 | python find_hai_indices.py \ 55 | --activations_dir "./activations_dir/matroyshka_batch_top_k_20_x64/mean_pool/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual" \ 56 | --dataset_name "imagenet" \ 57 | --data_path "${DATASET_PATH}" \ 58 | --split "train" \ 59 | --k 16 \ 60 | --chunk_size 1000 61 | 62 | python visualize_neurons.py \ 63 | --output_dir "./activations_dir/matroyshka_batch_top_k_20_x64/mean_pool/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual" \ 64 | --top_k 16 \ 65 | --dataset_name "imagenet" \ 66 | --data_path "${DATASET_PATH}" \ 67 | --split "train" \ 68 | --group_fractions 0.0625 0.125 0.25 0.5625 \ 69 | --hai_indices_path "./activations_dir/matroyshka_batch_top_k_20_x64/mean_pool/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual/hai_indices_16.npy" 70 | 71 | # 5. Compute steering score 72 | python encode_images.py \ 73 | --embeddings_path "embeddings_dir/imagenet_train_embeddings_clip-vit-base-patch32.pt" \ 74 | --model_name "clip-vit-base-patch32" \ 75 | --dataset_name "imagenet" \ 76 | --split "train" \ 77 | --data_path "${DATASET_PATH}" \ 78 | --batch_size 128 79 | 80 | python imagenet_subset.py \ 81 | --imagenet_root "${DATASET_PATH}" \ 82 | --output_dir "./images_imagenet" 83 | 84 | # no steering (1000 images, 10 neurons) 85 | python steering_score.py \ 86 | --hai_indices_path "./activations_dir/matroyshka_batch_top_k_20_x64/mean_pool/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual/hai_indices_16.npy" \ 87 | --embeddings_path "./embeddings_dir/imagenet_train_embeddings_clip-vit-base-patch32.pt" \ 88 | --sae_path "./checkpoints_dir/matroyshka_batch_top_k_20_x64/random_k_2/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual_matroyshka_batch_top_k_20_x64/trainer_0/checkpoints/ae_100000.pt" \ 89 | --images_path "./images_imagenet/" \ 90 | --no-pre_zero \ 91 | --model_name "clip-vit-base-patch32" \ 92 | --neuron_prefix 10 \ 93 | --no-steer \ 94 | --output_path "./llava_results_dir/1000/no_steering/" 95 | 96 | # steering (1000 images, 10 neurons) 97 | python steering_score.py \ 98 | --hai_indices_path "./activations_dir/matroyshka_batch_top_k_20_x64/mean_pool/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual/hai_indices_16.npy" \ 99 | --embeddings_path "./embeddings_dir/imagenet_train_embeddings_clip-vit-base-patch32.pt" \ 100 | --sae_path "./checkpoints_dir/matroyshka_batch_top_k_20_x64/random_k_2/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual_matroyshka_batch_top_k_20_x64/trainer_0/checkpoints/ae_100000.pt" \ 101 | --images_path "./images_imagenet/" \ 102 | --no-pre_zero \ 103 | --model_name "clip-vit-base-patch32" \ 104 | --neuron_prefix 10 \ 105 | --steer \ 106 | --output_path "./llava_results_dir/1000/steering/" 107 | 108 | # steering (1 image, 1000 neurons) 109 | python steering_score.py \ 110 | --hai_indices_path "./activations_dir/matroyshka_batch_top_k_20_x64/mean_pool/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual/hai_indices_16.npy" \ 111 | --embeddings_path "./embeddings_dir/imagenet_train_embeddings_clip-vit-base-patch32.pt" \ 112 | --sae_path "./checkpoints_dir/matroyshka_batch_top_k_20_x64/random_k_2/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual_matroyshka_batch_top_k_20_x64/trainer_0/checkpoints/ae_100000.pt" \ 113 | --images_path "./images/" \ 114 | --no-pre_zero \ 115 | --model_name 
"clip-vit-base-patch32" \ 116 | --neuron_prefix 1000 \ 117 | --no-steer \ 118 | --output_path "./llava_results_dir/1/no_steering/" 119 | 120 | python steering_score.py \ 121 | --hai_indices_path "./activations_dir/matroyshka_batch_top_k_20_x64/mean_pool/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual/hai_indices_16.npy" \ 122 | --embeddings_path "./embeddings_dir/imagenet_train_embeddings_clip-vit-base-patch32.pt" \ 123 | --sae_path "./checkpoints_dir/matroyshka_batch_top_k_20_x64/random_k_2/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual_matroyshka_batch_top_k_20_x64/trainer_0/checkpoints/ae_100000.pt" \ 124 | --images_path "./images/" \ 125 | --no-pre_zero \ 126 | --model_name "clip-vit-base-patch32" \ 127 | --neuron_prefix 1000 \ 128 | --steer \ 129 | --output_path "./llava_results_dir/1/steering/" 130 | 131 | # 6. Compute baseline scores 132 | python similarity_baseline.py 133 | 134 | # 7. Finding qualitative examples (e.g. steering pencil neuron) 135 | python steering_qualitative.py 136 | -------------------------------------------------------------------------------- /scripts/monosemanticity_score.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATASET="imagenet" 4 | DATASET_PATH="${IMAGENET_PATH}" 5 | LAYERS=("-1" 11 17 22 23) 6 | MODEL_NAME="clip-vit-large-patch14-336" # "siglip-so400m-patch14-384" 7 | EXPANSION_FACTORS=(64 16 8 4 2 1) 8 | VISION_ENCODER="clip-vit-base-patch32" # "dinov2-base" 9 | 10 | # 1. Save original activations 11 | for LAYER in "${LAYERS[@]}"; do 12 | if [ "${LAYER}" == "-1" ]; then 13 | POINT="post_projection" 14 | else 15 | POINT="post_mlp_residual" 16 | fi 17 | for SPLIT in "train" "val"; do 18 | python save_activations.py \ 19 | --batch_size 32 \ 20 | --model_name "${MODEL_NAME}" \ 21 | --attachment_point "${POINT}" \ 22 | --layer "${LAYER}" \ 23 | --dataset_name "${DATASET}" \ 24 | --split "${SPLIT}" \ 25 | --data_path "${DATASET_PATH}" \ 26 | --num_workers 8 \ 27 | --output_dir "./activations_dir/raw/${DATASET}_${SPLIT}_activations_${MODEL_NAME}_${LAYER}_${POINT}" \ 28 | --cls_only \ 29 | --save_every 100 30 | done 31 | done 32 | 33 | # 2. 
Train SAE 34 | for LAYER in "${LAYERS[@]}"; do 35 | if [ "${LAYER}" == "-1" ]; then 36 | POINT="post_projection" 37 | else 38 | POINT="post_mlp_residual" 39 | fi 40 | for EXPANSION_FACTOR in "${EXPANSION_FACTORS[@]}"; do 41 | python sae_train.py \ 42 | --sae_model "matroyshka_batch_top_k" \ 43 | --activations_dir "activations_dir/raw/${DATASET}_train_activations_${MODEL_NAME}_${LAYER}_${POINT}" \ 44 | --val_activations_dir "activations_dir/raw/${DATASET}_val_activations_${MODEL_NAME}_${LAYER}_${POINT}" \ 45 | --checkpoints_dir "checkpoints_dir/matroyshka_batch_top_k_20_x${EXPANSION_FACTOR}" \ 46 | --expansion_factor "${EXPANSION_FACTOR}" \ 47 | --steps 110000 \ 48 | --save_steps 20000 \ 49 | --log_steps 10000 \ 50 | --batch_size 4096 \ 51 | --k 20 \ 52 | --auxk_alpha 0.03 \ 53 | --decay_start 109999 \ 54 | --group_fractions 0.0625 0.125 0.25 0.5625 55 | 56 | python sae_train.py \ 57 | --sae_model "batch_top_k" \ 58 | --activations_dir "activations_dir/raw/${DATASET}_train_activations_${MODEL_NAME}_${LAYER}_${POINT}" \ 59 | --val_activations_dir "activations_dir/raw/${DATASET}_val_activations_${MODEL_NAME}_${LAYER}_${POINT}" \ 60 | --checkpoints_dir "checkpoints_dir/batch_top_k_20_x${EXPANSION_FACTOR}" \ 61 | --expansion_factor "${EXPANSION_FACTOR}" \ 62 | --steps 110000 \ 63 | --save_steps 20000 \ 64 | --log_steps 10000 \ 65 | --batch_size 4096 \ 66 | --k 20 \ 67 | --auxk_alpha 0.03 \ 68 | --decay_start 109999 69 | done 70 | done 71 | 72 | # 3. Save SAE activations 73 | for LAYER in "${LAYERS[@]}"; do 74 | if [ "${LAYER}" == "-1" ]; then 75 | POINT="post_projection" 76 | else 77 | POINT="post_mlp_residual" 78 | fi 79 | for EXPANSION_FACTOR in "${EXPANSION_FACTORS[@]}"; do 80 | for SAE_MODEL in "matroyshka_batch_top_k" "batch_top_k"; do 81 | python save_activations.py \ 82 | --batch_size 32 \ 83 | --model_name "${MODEL_NAME}" \ 84 | --attachment_point "${POINT}" \ 85 | --layer "${LAYER}" \ 86 | --dataset_name "${DATASET}" \ 87 | --split "val" \ 88 | --data_path "${DATASET_PATH}" \ 89 | --num_workers 8 \ 90 | --output_dir "./activations_dir/${SAE_MODEL}_20_x${EXPANSION_FACTOR}/${DATASET}_val_activations_${MODEL_NAME}_${LAYER}_${POINT}" \ 91 | --cls_only \ 92 | --save_every 100 \ 93 | --sae_model "${SAE_MODEL}" \ 94 | --sae_path "./checkpoints_dir/${SAE_MODEL}_20_x${EXPANSION_FACTOR}/${DATASET}_train_activations_${MODEL_NAME}_${LAYER}_${POINT}_${SAE_MODEL}_20_x${EXPANSION_FACTOR}/trainer_0/checkpoints/ae_100000.pt" 95 | done 96 | done 97 | done 98 | 99 | # 4. Save vision encoder embeddings 100 | python encode_images.py \ 101 | --embeddings_path "embeddings_dir/${DATASET}_val_embeddings_${VISION_ENCODER}.pt" \ 102 | --model_name "${VISION_ENCODER}" \ 103 | --dataset_name "${DATASET}" \ 104 | --split "val" \ 105 | --data_path "${DATASET_PATH}" \ 106 | --batch_size 128 107 | 108 | # 5. 
Compute Monosemanticity Score
109 | for LAYER in "${LAYERS[@]}"; do
110 |   if [ "${LAYER}" == "-1" ]; then
111 |     POINT="post_projection"
112 |   else
113 |     POINT="post_mlp_residual"
114 |   fi
115 |   # SAE neurons
116 |   for EXPANSION_FACTOR in "${EXPANSION_FACTORS[@]}"; do
117 |     for SAE_MODEL in "matroyshka_batch_top_k" "batch_top_k"; do
118 |       python metric.py \
119 |         --activations_dir "activations_dir/${SAE_MODEL}_20_x${EXPANSION_FACTOR}/${DATASET}_val_activations_${MODEL_NAME}_${LAYER}_${POINT}" \
120 |         --embeddings_path "embeddings_dir/${DATASET}_val_embeddings_${VISION_ENCODER}.pt" \
121 |         --output_subdir "ms_${VISION_ENCODER}"
122 |     done
123 |   done
124 |   # original neurons
125 |   python metric.py \
126 |     --activations_dir "activations_dir/raw/${DATASET}_val_activations_${MODEL_NAME}_${LAYER}_${POINT}" \
127 |     --embeddings_path "embeddings_dir/${DATASET}_val_embeddings_${VISION_ENCODER}.pt" \
128 |     --output_subdir "ms_${VISION_ENCODER}"
129 | done
130 | 
131 | # 6. Visualize neurons (of selected SAE, using training set)
132 | python save_activations.py \
133 |   --batch_size 32 \
134 |   --model_name "${MODEL_NAME}" \
135 |   --attachment_point "post_projection" \
136 |   --layer "-1" \
137 |   --dataset_name "${DATASET}" \
138 |   --split "train" \
139 |   --data_path "${DATASET_PATH}" \
140 |   --num_workers 8 \
141 |   --output_dir "./activations_dir/matroyshka_batch_top_k_20_x4/${DATASET}_train_activations_${MODEL_NAME}_-1_post_projection" \
142 |   --cls_only \
143 |   --save_every 100 \
144 |   --sae_model "matroyshka_batch_top_k" \
145 |   --sae_path "./checkpoints_dir/matroyshka_batch_top_k_20_x4/${DATASET}_train_activations_${MODEL_NAME}_-1_post_projection_matroyshka_batch_top_k_20_x4/trainer_0/checkpoints/ae_100000.pt"
146 | 
147 | python find_hai_indices.py \
148 |   --activations_dir "./activations_dir/matroyshka_batch_top_k_20_x4/${DATASET}_train_activations_${MODEL_NAME}_-1_post_projection" \
149 |   --dataset_name "${DATASET}" \
150 |   --data_path "${DATASET_PATH}" \
151 |   --split "train" \
152 |   --k 16 \
153 |   --chunk_size 1000
154 | 
155 | python visualize_neurons.py \
156 |   --output_dir "./activations_dir/matroyshka_batch_top_k_20_x4/${DATASET}_train_activations_${MODEL_NAME}_-1_post_projection" \
157 |   --top_k 16 \
158 |   --dataset_name "${DATASET}" \
159 |   --data_path "${DATASET_PATH}" \
160 |   --split "train" \
161 |   --group_fractions 0.0625 0.125 0.25 0.5625 \
162 |   --hai_indices_path "./activations_dir/matroyshka_batch_top_k_20_x4/${DATASET}_train_activations_${MODEL_NAME}_-1_post_projection/hai_indices_16.npy"
163 | 
--------------------------------------------------------------------------------
/steering_qualitative.py:
--------------------------------------------------------------------------------
1 | from models.llava import Llava
2 | from dictionary_learning.trainers import MatroyshkaBatchTopKSAE
3 | from PIL import Image
4 | import requests
5 | 
6 | llava = Llava("cuda")
7 | sae_path = "checkpoints_dir/matroyshka_batch_top_k_20_x64/random_k_2/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual_matroyshka_batch_top_k_20_x64/trainer_0/checkpoints/ae_100000.pt"
8 | sae = MatroyshkaBatchTopKSAE.from_pretrained(sae_path).cuda()
9 | text = "Write me a short love poem"
10 | url = "https://img.freepik.com/free-photo/cement-texture_1194-5269.jpg?semt=ais_hybrid"
11 | image = Image.open(requests.get(url, stream=True).raw)
12 | for neuron in [39]: # Note: ID of the pencil neuron may change after retraining SAE
13 |     for alpha in [0, 30, 40, 50]:
14 |         print(f"neuron {neuron}, alpha {alpha}")
15 |         llava.attach_and_fix(sae=sae, neurons_to_fix={neuron: alpha}, pre_zero=False)
16 | 
output = llava.prompt(text, image, max_tokens=60)[0] 17 | print(output) 18 | print("======================================================") -------------------------------------------------------------------------------- /steering_score.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tqdm 3 | import os 4 | import torch 5 | import torch.nn.functional as F 6 | from models.llava import Llava 7 | from utils import IdentitySAE, get_text_model 8 | import argparse 9 | from dictionary_learning.trainers import MatroyshkaBatchTopKSAE 10 | from PIL import Image 11 | import random 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser(description="Compute CLIP-based score for steering accuracy") 15 | parser.add_argument('--hai_indices_path', type=str, required=True) 16 | parser.add_argument('--embeddings_path', type=str, required=True) 17 | parser.add_argument('--sae_path', type=str, default=None) 18 | parser.add_argument('--images_path', type=str, required=True) 19 | parser.add_argument("--model_name", type=str, required=True) 20 | parser.add_argument("--device", type=str, default="cuda") 21 | parser.add_argument("--pre_zero", action=argparse.BooleanOptionalAction) 22 | parser.add_argument('--output_path', type=str, required=True) 23 | parser.add_argument('--neuron_prefix', type=int, default=None) 24 | parser.add_argument("--steer", action=argparse.BooleanOptionalAction) 25 | 26 | return parser.parse_args() 27 | 28 | if __name__ == "__main__": 29 | # Parse command line arguments 30 | args = parse_args() 31 | args.batch_size = 1 # not used 32 | args.num_workers = 0 # not used 33 | 34 | # Get HAI indices 35 | hai_indices = torch.from_numpy(np.load(args.hai_indices_path)).to(args.device) 36 | print(f"Loaded HAI indices found at {args.hai_indices_path}") 37 | print(f"hai_indices shape: {hai_indices.shape}") 38 | 39 | # Get image embeddings 40 | embeddings = torch.load(args.embeddings_path).to(args.device) 41 | print(f"Loaded embeddings found at {args.embeddings_path}") 42 | print(f"embeddings shape: {embeddings.shape}") 43 | 44 | # Compute mean image embedding of HAI per neuron 45 | hai_embeddings = embeddings[hai_indices] 46 | print(f"hai_embeddings shape: {hai_embeddings.shape}") # (num_neurons, k, embedding_dim) 47 | hai_embeddings = hai_embeddings.mean(dim=1) 48 | print(f"hai_embeddings shape: {hai_embeddings.shape}") # (num_neurons, embedding_dim) 49 | 50 | # Load LLaVA model 51 | llava = Llava(args.device) 52 | if args.sae_path: 53 | sae = MatroyshkaBatchTopKSAE.from_pretrained(args.sae_path).to(args.device) 54 | print(f"Attached SAE from {args.sae_path}") 55 | else: 56 | sae = IdentitySAE() 57 | print(f"Attached Identity SAE (scoring original neurons)") 58 | 59 | # Filter neurons if prefix is given 60 | num_neurons = hai_embeddings.shape[0] 61 | if args.neuron_prefix: 62 | neuron_indices = list(range(args.neuron_prefix)) 63 | else: 64 | neuron_indices = list(range(num_neurons)) 65 | print(f"Evaluating on {len(neuron_indices)} neurons") 66 | 67 | # Label images while clamping each neuron 68 | image_files = [f for f in os.listdir(args.images_path) if f.endswith(('png', 'jpg', 'jpeg', '.JPEG'))] 69 | print(f"Found {len(image_files)} images in {args.images_path}") 70 | text = "What is shown in this image? Use exactly one word!" 
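# The loop below asks LLaVA for a one-word label of every probe image, once per selected neuron;
# when --steer is set, that neuron's SAE activation is first clamped (to 100) via attach_and_fix.
# The labels are later embedded with the CLIP text encoder and scored against each neuron's mean
# HAI image embedding using cosine similarity.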
71 | labels = {neuron: [] for neuron in neuron_indices} 72 | if args.steer: 73 | print("Steering") 74 | else: 75 | print("Not steering") 76 | for neuron in tqdm.tqdm(neuron_indices, desc="Processing neurons"): 77 | if args.steer: 78 | llava.attach_and_fix(sae=sae, neurons_to_fix={neuron: 100}, pre_zero=args.pre_zero) 79 | for image_file in image_files: 80 | image_path = os.path.join(args.images_path, image_file) 81 | image = Image.open(image_path) 82 | label = llava.prompt(text, image, max_tokens=5)[0].split(" ")[0] 83 | labels[neuron].append((image_file, label)) 84 | 85 | # Save labels 86 | labels_path = os.path.join(os.path.dirname(args.output_path), "labels.txt") 87 | os.makedirs(os.path.dirname(labels_path), exist_ok=True) 88 | with open(labels_path, "w") as f: 89 | for neuron, image_labels in labels.items(): 90 | for image_file, label in image_labels: 91 | f.write(f"{neuron},{image_file},{label}\n") 92 | print(f"Labels saved to {labels_path}") 93 | 94 | # Compute text embeddings for neuron-clamped LLaVA labels 95 | text_encoder, tokenizer = get_text_model(args) 96 | label_embeddings = torch.zeros(len(neuron_indices), len(image_files), hai_embeddings.shape[1]).to(args.device) 97 | for i, (neuron, image_labels) in tqdm.tqdm(enumerate(labels.items()), desc="Computing text embeddings"): 98 | for j, label in enumerate(image_labels): 99 | with torch.no_grad(): 100 | inputs = tokenizer([label[1]], padding=True, return_tensors="pt").to(args.device) 101 | outputs = text_encoder(**inputs) 102 | label_embeddings[i, j] = outputs.text_embeds 103 | 104 | # Compute cosine similarities 105 | cosine_similarities = [] 106 | for i in range(label_embeddings.shape[1]): 107 | cosine_similarities.append(F.cosine_similarity(hai_embeddings[neuron_indices], label_embeddings[:, i], dim=1)) 108 | cosine_similarities = torch.cat(cosine_similarities) 109 | print(cosine_similarities.shape) 110 | torch.save(cosine_similarities, os.path.join(os.path.dirname(args.output_path), "scores_per_neuron")) 111 | mean_cosine_similarity = cosine_similarities.mean().item() 112 | std_cosine_similarity = cosine_similarities.std().item() 113 | 114 | print("Mean Cosine Similarity:", mean_cosine_similarity) 115 | print("Standard Deviation Cosine Similarity:", std_cosine_similarity) 116 | 117 | # Save results 118 | with open(os.path.join(os.path.dirname(args.output_path), "metric.txt"), "w") as f: 119 | f.write(f"Mean Cosine Similarity: {mean_cosine_similarity}\n") 120 | f.write(f"Standard Deviation Cosine Similarity: {std_cosine_similarity}\n") 121 | 122 | print("Done") -------------------------------------------------------------------------------- /uniqueness.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import os 4 | from itertools import combinations 5 | from tqdm import tqdm # Progress bar 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser(description="Measure uniqueness of neurons with pairwise Jaccard Index") 10 | parser.add_argument('--activations_dir', type=str, required=True) 11 | parser.add_argument('--k', type=int, default=16) 12 | return parser.parse_args() 13 | 14 | 15 | def jaccard_index(set1, set2): 16 | """Compute Jaccard index between two sets.""" 17 | intersection = len(set1.intersection(set2)) 18 | union = len(set1.union(set2)) 19 | return intersection / union if union != 0 else 0 20 | 21 | 22 | if __name__ == "__main__": 23 | args = parse_args() 24 | 25 | hai_indices_path = os.path.join(args.activations_dir, 
f"hai_indices_{args.k}.npy") 26 | worst_scores_path = os.path.join(args.activations_dir, f"hai_indices_{args.k}_worst.npy") 27 | 28 | hai_indices = np.load(hai_indices_path) # (num_neurons, k) 29 | worst_scores = np.load(worst_scores_path) # (num_neurons) 30 | 31 | print(f"Loaded HAI indices from {hai_indices_path}") 32 | print(f"Loaded worst scores from {worst_scores_path}") 33 | 34 | # Correct mask condition 35 | mask = worst_scores != 0 # Keep only neurons that are NOT "dead" 36 | hai_indices = hai_indices[mask] 37 | 38 | print(f"Removed {np.count_nonzero(~mask)} dead (or almost dead) neurons") 39 | print(f"Remaining neurons: {hai_indices.shape[0]}") 40 | 41 | # Compute pairwise Jaccard index 42 | num_neurons = hai_indices.shape[0] 43 | jaccard_scores = [] 44 | index_pairs = [] 45 | 46 | total_pairs = (num_neurons * (num_neurons - 1)) // 2 # Number of unique pairs 47 | 48 | for i, j in tqdm(combinations(range(num_neurons), 2), total=total_pairs, desc="Computing Jaccard Index"): 49 | set1, set2 = set(hai_indices[i]), set(hai_indices[j]) 50 | jaccard = jaccard_index(set1, set2) 51 | jaccard_scores.append(jaccard) 52 | index_pairs.append((i, j)) 53 | 54 | jaccard_scores = np.array(jaccard_scores) 55 | 56 | # Sort Jaccard scores 57 | sorted_indices = np.argsort(jaccard_scores) # Ascending order 58 | 59 | # Extract top and bottom 10 index pairs 60 | top_10 = [(index_pairs[i], jaccard_scores[i]) for i in sorted_indices[-10:]] # Top 10 highest Jaccard 61 | bottom_10 = [(index_pairs[i], jaccard_scores[i]) for i in sorted_indices[:10]] # Bottom 10 lowest Jaccard 62 | 63 | print(f"Total pairs: {total_pairs}") 64 | # Count how many have Jaccard Index > 0.1 * i 65 | for i in np.arange(0, 1, 0.1): 66 | high_similarity_count = np.count_nonzero(jaccard_scores > i) 67 | print(f"Pairs with Jaccard index > {i:.1f}: {high_similarity_count}") 68 | print(f"Ratio: {high_similarity_count / total_pairs:.4f}") 69 | 70 | print(f"Top 10 Indexes (Most Similar Pairs): {top_10}") 71 | print(f"Bottom 10 Indexes (Most Unique Pairs): {bottom_10}") 72 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, Subset 2 | from torchvision.datasets import ImageNet, ImageFolder 3 | import torch.nn as nn 4 | from models.clip import Clip 5 | from models.dino import Dino 6 | from models.siglip import Siglip 7 | import os 8 | from transformers import AutoTokenizer, CLIPTextModelWithProjection 9 | 10 | def get_collate_fn(processor): 11 | def collate_fn(batch): 12 | images = [img[0] for img in batch] 13 | return processor(images=images, return_tensors="pt", padding=True) 14 | return collate_fn 15 | 16 | def get_dataset(args, preprocess, processor, split, subset=1.0): 17 | if args.dataset_name == 'cc3m': 18 | # if subset < 1.0: 19 | # raise NotImplementedError 20 | # return get_cc3m(args, preprocess, split) 21 | raise NotImplementedError 22 | elif args.dataset_name == 'inat_birds': 23 | ds = ImageFolder(root=os.path.join(args.data_path, split), transform=preprocess) 24 | elif args.dataset_name == 'inat': 25 | ds = ImageFolder(root=os.path.join(args.data_path, split), transform=preprocess) 26 | elif args.dataset_name == 'imagenet': 27 | ds = ImageNet(root=args.data_path, split=split, transform=preprocess) 28 | elif args.dataset_name == 'cub': 29 | ds = ImageFolder(root=os.path.join(args.data_path, split), transform=preprocess) 30 | 31 | keep_every = int(1.0 / subset) 32 | 
if keep_every > 1: 33 | ds = Subset(ds, list(range(0, len(ds), keep_every))) 34 | if processor is not None: 35 | dl = DataLoader(ds, batch_size=args.batch_size, shuffle=False, 36 | num_workers=args.num_workers, collate_fn=get_collate_fn(processor)) 37 | else: 38 | dl = DataLoader(ds, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) 39 | 40 | return ds, dl 41 | 42 | def get_model(args): 43 | if args.model_name.startswith('clip'): 44 | clip = Clip(args.model_name, args.device) 45 | return clip, clip.processor 46 | elif args.model_name.startswith('dino'): 47 | dino = Dino(args.model_name, args.device) 48 | return dino, dino.processor 49 | elif args.model_name.startswith('siglip'): 50 | siglip = Siglip(args.model_name, args.device) 51 | return siglip, siglip.processor 52 | 53 | def get_text_model(args): 54 | if args.model_name.startswith('clip'): 55 | model = CLIPTextModelWithProjection.from_pretrained(f"openai/{args.model_name}").to(args.device) 56 | tokenizer = AutoTokenizer.from_pretrained(f"openai/{args.model_name}") 57 | return model, tokenizer 58 | 59 | class IdentitySAE(nn.Module): 60 | def encode(self, x): 61 | return x 62 | def decode(self, x): 63 | return x 64 | -------------------------------------------------------------------------------- /visualize_neurons.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | from PIL import Image 4 | import os 5 | from torchvision import transforms 6 | from utils import get_dataset 7 | import argparse 8 | from math import isclose 9 | 10 | 11 | def image_grid(imgs, rows, cols): 12 | assert len(imgs) == rows * cols, "Number of images must match rows * cols." 13 | w, h = imgs[0].size 14 | grid = Image.new("RGB", size=(cols * w, rows * h)) 15 | for i, img in enumerate(imgs): 16 | grid.paste(img, box=(i % cols * w, i // cols * h)) 17 | return grid 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser(description="Visualize top-k activating images for neurons.") 22 | parser.add_argument('--output_dir', type=str, required=True) 23 | parser.add_argument('--top_k', type=int, default=16) 24 | parser.add_argument("--dataset_name", default="imagenet", type=str) 25 | parser.add_argument("--data_path", default="/shared-network/inat2021", type=str) 26 | parser.add_argument('--split', type=str, default='train') 27 | parser.add_argument('--visualization_size', type=int, default=224) 28 | parser.add_argument('--group_fractions', type=float, nargs='+') 29 | parser.add_argument('--hai_indices_path', type=str) 30 | return parser.parse_args() 31 | 32 | 33 | if __name__ == "__main__": 34 | # Parse command line arguments 35 | args = parse_args() 36 | args.batch_size = 1 # not used 37 | args.num_workers = 0 # not used 38 | 39 | importants = np.load(args.hai_indices_path) 40 | print(f"Loaded HAI indices found at {args.hai_indices_path}", flush=True) 41 | num_neurons = importants.shape[0] 42 | 43 | # Visualize selected images 44 | def _convert_to_rgb(image): 45 | return image.convert("RGB") 46 | 47 | visualization_preprocess = transforms.Compose([ 48 | transforms.Resize(size=224, interpolation=Image.BICUBIC), 49 | transforms.CenterCrop(size=(224, 224)), 50 | _convert_to_rgb, 51 | ]) 52 | 53 | ds, dl = get_dataset(args, preprocess=visualization_preprocess, processor=None, split=args.split, subset=1) 54 | 55 | os.makedirs(os.path.join(args.output_dir, 'hai'), exist_ok=True) 56 | 57 | assert isclose(sum(args.group_fractions), 1.0), 
"group_fractions must sum to 1.0" 58 | group_sizes = [int(f * num_neurons) for f in args.group_fractions[:-1]] 59 | group_sizes.append(num_neurons - sum(group_sizes)) 60 | 61 | start_idx = 0 62 | for group_idx, group_size in enumerate(group_sizes): 63 | end_idx = start_idx + group_size 64 | group_neurons = range(start_idx, end_idx) 65 | 66 | for neuron_id, absolute_id in enumerate(group_neurons[:5000]): 67 | print(f"Visualizing neuron {neuron_id} (absolute {absolute_id}) in group {group_idx}", flush=True) 68 | 69 | important = importants[absolute_id] 70 | images = [ds[i][0] for i in important] 71 | s = int(np.sqrt(args.top_k)) 72 | grid_image = image_grid(images[::-1], rows=s, cols=s) 73 | 74 | plt.imshow(grid_image) 75 | plt.axis('off') 76 | filename = f"group_{group_idx}_neuron_{neuron_id}_absolute_{absolute_id}.png" 77 | output_path = os.path.join(args.output_dir, 'tree', filename) 78 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 79 | plt.savefig(output_path, bbox_inches='tight', pad_inches=0) 80 | plt.close() # Close the plot to free memory 81 | 82 | start_idx = end_idx 83 | --------------------------------------------------------------------------------