├── README.md
├── assets
│   └── teaser.svg
├── datasets
│   ├── __init__.py
│   └── activations.py
├── dictionary_learning
│   ├── .gitignore
│   ├── LICENSE
│   ├── README.md
│   ├── __init__.py
│   ├── buffer.py
│   ├── config.py
│   ├── dictionary.py
│   ├── evaluation.py
│   ├── grad_pursuit.py
│   ├── interp.py
│   ├── pretrained_dictionary_downloader.sh
│   ├── requirements.txt
│   ├── tests
│   │   └── test_end_to_end.py
│   ├── trainers
│   │   ├── __init__.py
│   │   ├── batch_top_k.py
│   │   ├── gated_anneal.py
│   │   ├── gdm.py
│   │   ├── jumprelu.py
│   │   ├── matroyshka_batch_top_k.py
│   │   ├── p_anneal.py
│   │   ├── standard.py
│   │   ├── top_k.py
│   │   └── trainer.py
│   ├── training.py
│   └── utils.py
├── encode_images.py
├── find_hai_indices.py
├── imagenet_subset.py
├── images
│   └── white.png
├── inat_depth.py
├── metric.py
├── models
│   ├── clip.py
│   ├── dino.py
│   ├── llava.py
│   └── siglip.py
├── requirements.txt
├── sae_train.py
├── save_activations.py
├── scripts
│   ├── matryoshka_hierarchy.sh
│   ├── mllm_steering.sh
│   └── monosemanticity_score.sh
├── similarity_baseline.py
├── steering_qualitative.py
├── steering_score.py
├── uniqueness.py
├── utils.py
└── visualize_neurons.py
/README.md:
--------------------------------------------------------------------------------
1 | ## Sparse Autoencoders Learn Monosemantic Features in Vision-Language Models
2 |
16 | ### Abstract
17 |
19 | Sparse Autoencoders (SAEs) have recently been shown to enhance interpretability and steerability in Large Language Models (LLMs). In this work, we extend the application of SAEs to Vision-Language Models (VLMs), such as CLIP, and introduce a comprehensive framework for evaluating monosemanticity in vision representations. Our experimental results reveal that SAEs trained on VLMs significantly enhance the monosemanticity of individual neurons while also exhibiting hierarchical representations that align well with expert-defined structures (e.g., the iNaturalist taxonomy). Most notably, we demonstrate that applying SAEs to intervene on a CLIP vision encoder directly steers the output of multimodal LLMs (e.g., LLaVA) without any modifications to the underlying model. These findings emphasize the practicality and efficacy of SAEs as an unsupervised approach for enhancing both the interpretability and control of VLMs.
20 |
21 |
22 |
23 |
24 |
25 |
26 | ---
27 | ### Setup
28 | Install the required pip packages.
29 | ```bash
30 | pip install -r requirements.txt
31 | ```
32 | Download the following datasets:
33 | * ImageNet (https://pytorch.org/vision/main/generated/torchvision.datasets.ImageNet.html)
34 | * INaturalist 2021 (https://github.com/visipedia/inat_comp/tree/master/2021)
35 |
36 | Export paths to dataset directories. The directories should contain `train/` and `val/` subdirectories.
37 | ```bash
38 | export IMAGENET_PATH=""
39 | export INAT_PATH=""
40 | ```
41 | The code was run with Python 3.11.10.
42 | ### Running Experiments
43 | The commands required to reproduce the results are organized into scripts located in the `scripts/` directory:
44 | * `monosemanticity_score.sh` computes the Monosemanticity Score (MS) for specified SAEs, layers, models, and image encoders.
45 | * `matryoshka_hierarchy.sh` analyzes the hierarchical structure that emerges in Matryoshka SAEs.
46 | * `mllm_steering.sh` enables experimentation with steering LLaVA using an SAE built on top of the vision encoder.
47 |
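A typical invocation from the repository root is shown below; check each script for the exact SAEs, layers, models, and image encoders it evaluates, and adjust them before running:
```bash
bash scripts/monosemanticity_score.sh
```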
48 | We use the implementation of sparse autoencoders available at https://github.com/saprmarks/dictionary_learning.
49 | ### Citation
50 | ```bibtex
51 | @article{pach2025sparse,
52 | title={Sparse Autoencoders Learn Monosemantic Features in Vision-Language Models},
53 | author={Mateusz Pach and Shyamgopal Karthik and Quentin Bouniot and Serge Belongie and Zeynep Akata},
54 | journal={arXiv preprint arXiv:2504.02821},
55 | year={2025}
56 | }
57 | ```
58 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # from .cc3m import get_cc3m
--------------------------------------------------------------------------------
/datasets/activations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import os
4 | import bisect
5 |
6 | class ChunkedActivationsDataset(Dataset):
7 | def __init__(self, directory, transform=None, device="cpu"):
8 | """
9 | Args:
10 | directory (str): Path to the directory containing .pth files.
11 | transform (callable, optional): Optional transform to be applied on a sample.
12 | device (str): Device to store the activations ('cpu' or 'cuda').
13 | """
14 | self.directory = directory
15 | self.files = sorted(
16 | (f for f in os.listdir(directory) if f.endswith('.pth') or f.endswith('.pt')),
17 | key=lambda x: int(x.split('_part')[-1].split('.pt')[0])
18 | )
19 | self.transform = transform
20 | self.device = device
21 | self.file_offsets = []
22 | self.cumulative_lengths = []
23 |
24 | # Cache to store the last accessed file
25 | self._cached_file = None
26 | self._cached_tensor = None
27 |
28 | # Compute the offsets and cumulative lengths for efficient indexing
29 | cumulative = 0
30 | for file in self.files:
31 | tensor = torch.load(os.path.join(directory, file)) # Load file temporarily
32 | length = tensor.size(0) # Number of samples in the file (K)
33 | self.file_offsets.append((file, length))
34 | cumulative += length
35 | self.cumulative_lengths.append(cumulative)
37 |
38 | def __len__(self):
39 | return self.cumulative_lengths[-1] if self.cumulative_lengths else 0
40 |
41 | def _load_file(self, file):
42 | """
43 | Loads a file and caches it for future access.
44 | Moves the tensor to the specified device if not already cached.
45 | """
46 | if self._cached_file != file:
47 | file_path = os.path.join(self.directory, file)
48 | tensor = torch.load(file_path)
49 | self._cached_tensor = tensor.to(self.device) # Move to the specified device
50 | self._cached_file = file
51 |
52 | def __getitem__(self, idx):
53 | """
54 | Args:
55 | idx (int): Index of the sample to retrieve.
56 | Returns:
57 | torch.Tensor: Neuron activation vector on the specified device.
58 | """
59 | # Find the correct file corresponding to the index
60 | for file_idx, (file, length) in enumerate(self.file_offsets):
61 | if idx < self.cumulative_lengths[file_idx]:
62 | # Calculate relative index within the file
63 | relative_idx = idx - (self.cumulative_lengths[file_idx - 1] if file_idx > 0 else 0)
64 |
65 | # Load the file into cache if it's not already cached
66 | self._load_file(file)
67 |
68 | # Retrieve the sample
69 | sample = self._cached_tensor[relative_idx]
70 | if self.transform:
71 | sample = self.transform(sample)
72 | return sample
73 |
74 | raise IndexError(f"Index {idx} out of range for dataset with length {len(self)}")
75 |
76 |
77 | class ActivationsDataset(Dataset):
103 | def __init__(self, directory, transform=None, device="cpu", take_every=1):
104 | """
105 | Args:
106 | directory (str): Path to the directory containing .pth files.
107 | transform (callable, optional): Optional transform to be applied on a sample.
108 | device (str): Device to store the activations ('cpu' or 'cuda').
109 | take_every (int): Load every N-th row from the concatenated tensors to reduce memory usage.
110 | """
111 | self.directory = directory
112 | self.files = sorted(
113 | (f for f in os.listdir(directory) if (f.endswith('.pth') or f.endswith('.pt')) and not f.startswith('all')),
114 | key=lambda x: int(x.split('_part')[-1].split('.pt')[0])
115 | )
116 | self.transform = transform
117 | self.device = device
118 | self.take_every = take_every
119 |
120 | # Load all tensors into memory with skipping rows as per take_every
121 | self.cached_tensors = []
122 | self.cumulative_lengths = []
123 | cumulative = 0
124 |
125 | # Track the global row index
126 | global_row_index = 0
127 |
128 | for file in self.files:
129 | tensor = torch.load(os.path.join(directory, file)).to(self.device)
130 | num_rows = tensor.size(0)
131 |
132 | # Calculate global indices for the rows in this tensor
133 | global_indices = list(range(global_row_index, global_row_index + num_rows))
134 | selected_indices = [
135 | idx - global_row_index for idx in global_indices if idx % self.take_every == 0
136 | ]
137 |
138 | # Extract rows using the selected indices
139 | if selected_indices:
140 | tensor = tensor[selected_indices]
141 | self.cached_tensors.append(tensor)
142 | cumulative += len(selected_indices)
143 | self.cumulative_lengths.append(cumulative)
144 |
145 | # Update the global_row_index
146 | global_row_index += num_rows
147 |
148 | def __len__(self):
149 | return self.cumulative_lengths[-1] if self.cumulative_lengths else 0
150 |
151 | def __getitem__(self, idx):
152 | """
153 | Args:
154 | idx (int): Index of the sample to retrieve.
155 | Returns:
156 | torch.Tensor: Neuron activation vector on the specified device.
157 | """
171 | tensor_idx = bisect.bisect_right(self.cumulative_lengths, idx)
172 | relative_idx = idx - (self.cumulative_lengths[tensor_idx - 1] if tensor_idx > 0 else 0)
173 | sample = self.cached_tensors[tensor_idx][relative_idx]
174 | if self.transform:
175 | sample = self.transform(sample)
176 | return sample
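

if __name__ == "__main__":
    # Minimal usage sketch (the directory below is hypothetical). Both dataset classes expect
    # chunk files whose names contain `_part{N}` (e.g. `activations_part0.pth`) and index the
    # concatenated rows; ChunkedActivationsDataset keeps one chunk in memory at a time, while
    # ActivationsDataset loads everything up front (optionally subsampled via `take_every`).
    from torch.utils.data import DataLoader

    dataset = ActivationsDataset("/path/to/activations", device="cpu", take_every=10)
    loader = DataLoader(dataset, batch_size=4096, shuffle=True)
    for batch in loader:
        print(batch.shape)  # (4096, activation_dim)
        break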
--------------------------------------------------------------------------------
/dictionary_learning/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | dictionaries
162 | wandb
163 | experiment*
164 | run_experiment.sh
165 | nohup.out
166 | *.zip
--------------------------------------------------------------------------------
/dictionary_learning/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 saprmarks
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/dictionary_learning/__init__.py:
--------------------------------------------------------------------------------
1 | from .dictionary import AutoEncoder, GatedAutoEncoder, JumpReluAutoEncoder
2 | from .buffer import ActivationBuffer
--------------------------------------------------------------------------------
/dictionary_learning/config.py:
--------------------------------------------------------------------------------
1 | # debugging flag for use in other scripts
2 | DEBUG = False
--------------------------------------------------------------------------------
/dictionary_learning/evaluation.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for evaluating dictionaries on a model and dataset.
3 | """
4 |
5 | import torch as t
6 | from collections import defaultdict
7 |
8 | from .buffer import ActivationBuffer, NNsightActivationBuffer
9 | from nnsight import LanguageModel
10 | from .config import DEBUG
11 |
12 |
13 | def loss_recovered(
14 | text, # a batch of text
15 | model: LanguageModel, # an nnsight LanguageModel
16 | submodule, # submodules of model
17 | dictionary, # dictionaries for submodules
18 | max_len=None, # max context length for loss recovered
19 | normalize_batch=False, # normalize batch before passing through dictionary
20 | io="out", # can be 'in', 'out', or 'in_and_out'
21 | tracer_args = {'use_cache': False, 'output_attentions': False}, # minimize cache during model trace.
22 | ):
23 | """
24 | How much of the model's loss is recovered by replacing the component output
25 | with the reconstruction by the autoencoder?
26 | """
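# The three returned losses are combined by the caller (see `evaluate`) as
# frac_recovered = (loss_reconstructed - loss_zero) / (loss_original - loss_zero).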
27 |
28 | if max_len is None:
29 | invoker_args = {}
30 | else:
31 | invoker_args = {"truncation": True, "max_length": max_len }
32 |
33 | with model.trace("_"):
34 | temp_output = submodule.output.save()
35 |
36 | output_is_tuple = False
37 | # Note: isinstance() won't work here as torch.Size is a subclass of tuple,
38 | # so isinstance(temp_output.shape, tuple) would return True even for torch.Size.
39 | if type(temp_output.shape) == tuple:
40 | output_is_tuple = True
41 |
42 | # unmodified logits
43 | with model.trace(text, invoker_args=invoker_args):
44 | logits_original = model.output.save()
45 | logits_original = logits_original.value
46 |
47 | # logits when replacing component activations with reconstruction by autoencoder
48 | with model.trace(text, **tracer_args, invoker_args=invoker_args):
49 | if io == 'in':
50 | x = submodule.input
51 | if normalize_batch:
52 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
53 | x = x * scale
54 | elif io == 'out':
55 | x = submodule.output
56 | if output_is_tuple: x = x[0]
57 | if normalize_batch:
58 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
59 | x = x * scale
60 | elif io == 'in_and_out':
61 | x = submodule.input
62 | if normalize_batch:
63 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
64 | x = x * scale
65 | else:
66 | raise ValueError(f"Invalid value for io: {io}")
67 | x = x.save()
68 |
69 | # If we incorrectly handle output_is_tuple, such as with some mlp submodules, we will get an error here.
70 | assert len(x.shape) == 3, f"Expected x to have shape (B, L, D), got {x.shape}, output_is_tuple: {output_is_tuple}"
71 |
72 | x_hat = dictionary(x).to(model.dtype)
73 |
74 | # intervene with `x_hat`
75 | with model.trace(text, **tracer_args, invoker_args=invoker_args):
76 | if io == 'in':
77 | x = submodule.input
78 | if normalize_batch:
79 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
80 | x_hat = x_hat / scale
81 | submodule.input[:] = x_hat
82 | elif io == 'out':
83 | x = submodule.output
84 | if output_is_tuple: x = x[0]
85 | if normalize_batch:
86 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
87 | x_hat = x_hat / scale
88 | if output_is_tuple:
89 | submodule.output[0][:] = x_hat
90 | else:
91 | submodule.output[:] = x_hat
92 | elif io == 'in_and_out':
93 | x = submodule.input
94 | if normalize_batch:
95 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
96 | x_hat = x_hat / scale
97 | if output_is_tuple:
98 | submodule.output[0][:] = x_hat
99 | else:
100 | submodule.output[:] = x_hat
101 | else:
102 | raise ValueError(f"Invalid value for io: {io}")
103 |
104 | logits_reconstructed = model.output.save()
105 | logits_reconstructed = logits_reconstructed.value
106 |
107 | # logits when replacing component activations with zeros
108 | with model.trace(text, **tracer_args, invoker_args=invoker_args):
109 | if io == 'in':
110 | x = submodule.input
111 | submodule.input[:] = t.zeros_like(x)
112 | elif io in ['out', 'in_and_out']:
113 | x = submodule.output
114 | if output_is_tuple:
115 | submodule.output[0][:] = t.zeros_like(x[0])
116 | else:
117 | submodule.output[:] = t.zeros_like(x)
118 | else:
119 | raise ValueError(f"Invalid value for io: {io}")
120 |
121 | input = model.inputs.save()
122 | logits_zero = model.output.save()
123 |
124 | logits_zero = logits_zero.value
125 |
126 | # get everything into the right format
127 | try:
128 | logits_original = logits_original.logits
129 | logits_reconstructed = logits_reconstructed.logits
130 | logits_zero = logits_zero.logits
131 | except:
132 | pass
133 |
134 | if isinstance(text, t.Tensor):
135 | tokens = text
136 | else:
137 | try:
138 | tokens = input[1]['input_ids']
139 | except:
140 | tokens = input[1]['input']
141 |
142 | # compute losses
143 | losses = []
144 | if hasattr(model, 'tokenizer') and model.tokenizer is not None:
145 | loss_kwargs = {'ignore_index': model.tokenizer.pad_token_id}
146 | else:
147 | loss_kwargs = {}
148 | for logits in [logits_original, logits_reconstructed, logits_zero]:
149 | loss = t.nn.CrossEntropyLoss(**loss_kwargs)(
150 | logits[:, :-1, :].reshape(-1, logits.shape[-1]), tokens[:, 1:].reshape(-1)
151 | )
152 | losses.append(loss)
153 |
154 | return tuple(losses)
155 |
156 | @t.no_grad()
157 | def evaluate(
158 | dictionary, # a dictionary
159 | activations, # a generator of activations; if an ActivationBuffer, also compute loss recovered
160 | max_len=128, # max context length for loss recovered
161 | batch_size=128, # batch size for loss recovered
162 | io="out", # can be 'in', 'out', or 'in_and_out'
163 | normalize_batch=False, # normalize batch before passing through dictionary
164 | tracer_args={'use_cache': False, 'output_attentions': False}, # minimize cache during model trace.
165 | device="cpu",
166 | n_batches: int = 1,
167 | ):
168 | assert n_batches > 0
169 | out = defaultdict(float)
170 | active_features = t.zeros(dictionary.dict_size, dtype=t.float32, device=device)
171 |
172 | for _ in range(n_batches):
173 | try:
174 | x = next(activations).to(device)
175 | if normalize_batch:
176 | x = x / x.norm(dim=-1).mean() * (dictionary.activation_dim ** 0.5)
177 | except StopIteration:
178 | raise StopIteration(
179 | "Not enough activations in buffer. Pass a buffer with a smaller batch size or more data."
180 | )
181 | x_hat, f = dictionary(x, output_features=True)
182 | l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean()
183 | l1_loss = f.norm(p=1, dim=-1).mean()
184 | l0 = (f != 0).float().sum(dim=-1).mean()
185 |
186 | features_BF = t.flatten(f, start_dim=0, end_dim=-2).to(dtype=t.float32) # If f is shape (B, L, D), flatten to (B*L, D)
187 | assert features_BF.shape[-1] == dictionary.dict_size
188 | assert len(features_BF.shape) == 2
189 |
190 | active_features += features_BF.sum(dim=0)
191 |
192 | # cosine similarity between x and x_hat
193 | x_normed = x / t.linalg.norm(x, dim=-1, keepdim=True)
194 | x_hat_normed = x_hat / t.linalg.norm(x_hat, dim=-1, keepdim=True)
195 | cossim = (x_normed * x_hat_normed).sum(dim=-1).mean()
196 |
197 | # l2 ratio
198 | l2_ratio = (t.linalg.norm(x_hat, dim=-1) / t.linalg.norm(x, dim=-1)).mean()
199 |
200 | #compute variance explained
201 | total_variance = t.var(x, dim=0).sum()
202 | residual_variance = t.var(x - x_hat, dim=0).sum()
203 | frac_variance_explained = (1 - residual_variance / total_variance)
204 |
205 | # Equation 10 from https://arxiv.org/abs/2404.16014
206 | x_hat_norm_squared = t.linalg.norm(x_hat, dim=-1, ord=2)**2
207 | x_dot_x_hat = (x * x_hat).sum(dim=-1)
208 | relative_reconstruction_bias = x_hat_norm_squared.mean() / x_dot_x_hat.mean()
209 |
210 | out["l2_loss"] += l2_loss.item()
211 | out["l1_loss"] += l1_loss.item()
212 | out["l0"] += l0.item()
213 | out["frac_variance_explained"] += frac_variance_explained.item()
214 | out["cossim"] += cossim.item()
215 | out["l2_ratio"] += l2_ratio.item()
216 | out['relative_reconstruction_bias'] += relative_reconstruction_bias.item()
217 |
218 | if not isinstance(activations, (ActivationBuffer, NNsightActivationBuffer)):
219 | continue
220 |
221 | # compute loss recovered
222 | loss_original, loss_reconstructed, loss_zero = loss_recovered(
223 | activations.text_batch(batch_size=batch_size),
224 | activations.model,
225 | activations.submodule,
226 | dictionary,
227 | max_len=max_len,
228 | normalize_batch=normalize_batch,
229 | io=io,
230 | tracer_args=tracer_args
231 | )
232 | frac_recovered = (loss_reconstructed - loss_zero) / (loss_original - loss_zero)
233 |
234 | out["loss_original"] += loss_original.item()
235 | out["loss_reconstructed"] += loss_reconstructed.item()
236 | out["loss_zero"] += loss_zero.item()
237 | out["frac_recovered"] += frac_recovered.item()
238 |
239 | out = {key: value / n_batches for key, value in out.items()}
240 | frac_alive = (active_features != 0).float().sum() / dictionary.dict_size
241 | out["frac_alive"] = frac_alive.item()
242 |
243 | return out
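

if __name__ == "__main__":
    # Minimal sketch (run as `python -m dictionary_learning.evaluation` so the relative imports
    # above resolve). `evaluate` only needs an iterator of activation batches and a
    # dictionary-like object exposing `activation_dim`, `dict_size`, and
    # `__call__(x, output_features=True) -> (x_hat, f)`. The toy module below is a hypothetical
    # stand-in, not a trained SAE, so the reported metrics are only a smoke test.
    class _ToyDictionary(t.nn.Module):
        def __init__(self, activation_dim, dict_size):
            super().__init__()
            self.activation_dim = activation_dim
            self.dict_size = dict_size
            self.enc = t.nn.Linear(activation_dim, dict_size)
            self.dec = t.nn.Linear(dict_size, activation_dim)

        def forward(self, x, output_features=False):
            f = t.relu(self.enc(x))
            x_hat = self.dec(f)
            return (x_hat, f) if output_features else x_hat

    toy = _ToyDictionary(activation_dim=64, dict_size=512)
    batches = iter([t.randn(128, 64) for _ in range(4)])  # stand-in activation batches
    print(evaluate(toy, batches, n_batches=4))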
--------------------------------------------------------------------------------
/dictionary_learning/grad_pursuit.py:
--------------------------------------------------------------------------------
1 | """
2 | Implements batched gradient pursuit algorithm here:
3 | https://www.lesswrong.com/posts/C5KAZQib3bzzpeyrg/full-post-progress-update-1-from-the-gdm-mech-interp-team#Inference_Time_Optimisation:~:text=two%20seem%20promising.-,Details%20of%20Sparse%20Approximation%20Algorithms%20(for%20accelerators),-This%20section%20gets
4 | """
5 |
6 | import torch as t
7 |
8 |
9 | def _grad_pursuit_update_step(signal, weights, dictionary, batch_arange, selected_features):
10 | """
11 | signal: b x d, weights: b x n, dictionary: d x n, batch_arange: b, selected_features: b x n
12 | """
13 | residual = signal - t.einsum('bn,dn -> bd', weights, dictionary)
14 | # choose the element with largest inner product with residual, as in matched pursuit.
15 | inner_products = t.einsum('dn,bd -> bn', dictionary, residual)
16 | idxs = t.argmax(inner_products, dim=1)
17 | # add the new feature to the active set.
18 | selected_features[batch_arange, idxs] = 1
19 |
20 | # the gradient for the weights is the inner product, restricted to the chosen features
21 | grad = selected_features * inner_products
22 | # the next two steps compute the optimal step size
23 | c = t.einsum('bn,dn -> bd', grad, dictionary)
24 | step_size = t.einsum('bd,bd -> b', c, residual) / t.einsum('bd,bd -> b ', c, c)
25 | weights = weights + t.einsum('b,bn -> bn', step_size, grad)
26 | weights = t.clip(weights, min=0) # clip the weights to be positive
27 | return weights, selected_features
28 |
29 | def grad_pursuit(signal, dictionary, target_l0 : int = 20, device : str = 'cpu'):
30 | """
31 | Inputs: signal: b x d, dictionary: d x n, target_l0: int, device: str
32 | Outputs: weights: b x n
33 | """
34 | assert len(signal.shape) == 2 # makes sure this is a batch of signals
35 | with t.no_grad():
36 | batch_arange = t.arange(signal.shape[0]).to(device)
37 | weights = t.zeros((signal.shape[0], dictionary.shape[1])).to(device)
38 | selected_features = t.zeros((signal.shape[0], dictionary.shape[1])).to(device)
39 | for _ in range(target_l0):
40 | weights, selected_features = _grad_pursuit_update_step(
41 | signal, weights, dictionary, batch_arange, selected_features)
42 | return weights
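

if __name__ == "__main__":
    # Minimal usage sketch on random data (sizes are illustrative, not from the paper):
    # approximate each of 8 signals of dimension 512 with at most 20 atoms drawn from a
    # 2048-atom dictionary whose columns are unit-normalized.
    b, d, n = 8, 512, 2048
    dictionary = t.nn.functional.normalize(t.randn(d, n), dim=0)
    signal = t.randn(b, d)
    weights = grad_pursuit(signal, dictionary, target_l0=20)
    print(weights.shape)             # torch.Size([8, 2048])
    print((weights > 0).sum(dim=1))  # at most 20 non-zero weights per signal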
--------------------------------------------------------------------------------
/dictionary_learning/interp.py:
--------------------------------------------------------------------------------
1 | import random
2 | from circuitsvis.activations import text_neuron_activations
3 | from einops import rearrange
4 | import torch as t
5 | from collections import namedtuple
6 | import umap
7 | import pandas as pd
8 | import plotly.express as px
9 |
10 |
11 | def feature_effect(
12 | model,
13 | submodule,
14 | dictionary,
15 | feature,
16 | inputs,
17 | max_length=128,
18 | add_residual=True, # whether to compensate for dictionary reconstruction error by adding residual
19 | k=10,
20 | largest=True,
21 | ):
22 | """
23 | Effect of ablating the feature on top k predictions for next token.
24 | """
25 | tracer_kwargs = {
26 | "scan": False,
27 | "validate": False,
28 | "invoker_args": dict(max_length=max_length),
29 | }
30 | # clean run
31 | with t.no_grad(), model.trace(inputs, **tracer_kwargs):
32 | if dictionary is None:
33 | pass
34 | elif not add_residual: # run hidden state through autoencoder
35 | if type(submodule.output.shape) == tuple:
36 | submodule.output[0][:] = dictionary(submodule.output[0])
37 | else:
38 | submodule.output = dictionary(submodule.output)
39 | clean_output = model.output.save()
40 | try:
41 | clean_logits = clean_output.value.logits[:, -1, :]
42 | except:
43 | clean_logits = clean_output.value[:, -1, :]
44 | clean_logprobs = t.nn.functional.log_softmax(clean_logits, dim=-1)
45 |
46 | # ablated run
47 | with t.no_grad(), model.trace(inputs, **tracer_kwargs):
48 | if dictionary is None:
49 | if type(submodule.output.shape) == tuple:
50 | submodule.output[0][:, -1, feature] = 0
51 | else:
52 | submodule.output[:, -1, feature] = 0
53 | else:
54 | x = submodule.output
55 | if type(x.shape) == tuple:
56 | x = x[0]
57 | x_hat, f = dictionary(x, output_features=True)
58 | residual = x - x_hat
59 |
60 | f[:, -1, feature] = 0
61 | if add_residual:
62 | x_hat = dictionary.decode(f) + residual
63 | else:
64 | x_hat = dictionary.decode(f)
65 |
66 | if type(submodule.output.shape) == tuple:
67 | submodule.output[0][:] = x_hat
68 | else:
69 | submodule.output = x_hat
70 | ablated_output = model.output.save()
71 | try:
72 | ablated_logits = ablated_output.value.logits[:, -1, :]
73 | except:
74 | ablated_logits = ablated_output.value[:, -1, :]
75 | ablated_logprobs = t.nn.functional.log_softmax(ablated_logits, dim=-1)
76 |
77 | diff = clean_logprobs - ablated_logprobs
78 | top_probs, top_tokens = t.topk(diff.mean(dim=0), k=k, largest=largest)
79 | return top_tokens, top_probs
80 |
81 |
82 | def examine_dimension(
83 | model, submodule, buffer, dictionary=None, max_length=128, n_inputs=512, dim_idx=None, k=30
84 | ):
85 |
86 | tracer_kwargs = {
87 | "scan": False,
88 | "validate": False,
89 | "invoker_args": dict(max_length=max_length),
90 | }
91 |
92 | def _list_decode(x):
93 | if isinstance(x, int):
94 | return model.tokenizer.decode(x)
95 | else:
96 | return [_list_decode(y) for y in x]
97 |
98 | if dim_idx is None:
99 | dim_idx = random.randint(0, dictionary.dict_size - 1)  # pick a random feature to examine (requires `dictionary`; pass `dim_idx` explicitly when dictionary is None)
100 |
101 | inputs = buffer.tokenized_batch(batch_size=n_inputs)
102 |
103 | with t.no_grad(), model.trace(inputs, **tracer_kwargs):
104 | tokens = model.inputs[1][
105 | "input_ids"
106 | ].save() # if you're getting errors, check here; might only work for pythia models
107 | activations = submodule.output
108 | if type(activations.shape) == tuple:
109 | activations = activations[0]
110 | if dictionary is not None:
111 | activations = dictionary.encode(activations)
112 | activations = activations[:, :, dim_idx].save()
113 | activations = activations.value
114 |
115 | # get top k tokens by mean activation
116 | tokens = tokens.value
117 | token_mean_acts = {}
118 | for ctx in tokens:
119 | for tok in ctx:
120 | if tok.item() in token_mean_acts:
121 | continue
122 | idxs = (tokens == tok).nonzero(as_tuple=True)
123 | token_mean_acts[tok.item()] = activations[idxs].mean().item()
124 | top_tokens = sorted(token_mean_acts.items(), key=lambda x: x[1], reverse=True)[:k]
125 | top_tokens = [(model.tokenizer.decode(tok), act) for tok, act in top_tokens]
126 |
127 | flattened_acts = rearrange(activations, "b n -> (b n)")
128 | topk_indices = t.argsort(flattened_acts, dim=0, descending=True)[:k]
129 | batch_indices = topk_indices // activations.shape[1]
130 | token_indices = topk_indices % activations.shape[1]
131 | tokens = [
132 | tokens[batch_idx, : token_idx + 1].tolist()
133 | for batch_idx, token_idx in zip(batch_indices, token_indices)
134 | ]
135 | activations = [
136 | activations[batch_idx, : token_id + 1, None, None]
137 | for batch_idx, token_id in zip(batch_indices, token_indices)
138 | ]
139 | decoded_tokens = _list_decode(tokens)
140 | top_contexts = text_neuron_activations(decoded_tokens, activations)
141 |
142 | top_affected = feature_effect(
143 | model, submodule, dictionary, dim_idx, tokens, max_length=max_length, k=k
144 | )
145 | top_affected = [(model.tokenizer.decode(tok), prob.item()) for tok, prob in zip(*top_affected)]
146 |
147 | return namedtuple("featureProfile", ["top_contexts", "top_tokens", "top_affected"])(
148 | top_contexts, top_tokens, top_affected
149 | )
150 |
151 |
152 | def feature_umap(
153 | dictionary,
154 | weight="decoder", # 'encoder' or 'decoder'
155 | # UMAP parameters
156 | n_neighbors=15,
157 | metric="cosine",
158 | min_dist=0.05,
159 | n_components=2, # dimension of the UMAP embedding
160 | feat_idxs=None, # if not none, indicate the feature with a red dot
161 | ):
162 | """
163 | Fit a UMAP embedding of the dictionary features and return a plotly plot of the result."""
164 | if weight == "encoder":
165 | df = pd.DataFrame(dictionary.encoder.weight.cpu().detach().numpy())
166 | else:
167 | df = pd.DataFrame(dictionary.decoder.weight.T.cpu().detach().numpy())
168 | reducer = umap.UMAP(
169 | n_neighbors=n_neighbors,
170 | metric=metric,
171 | min_dist=min_dist,
172 | n_components=n_components,
173 | )
174 | embedding = reducer.fit_transform(df)
175 | if isinstance(feat_idxs, int):
176 | feat_idxs = [feat_idxs]
177 | if feat_idxs is None:
178 | colors = None
179 | else:
180 | colors = ["red" if i in feat_idxs else "blue" for i in range(embedding.shape[0])]
181 | if n_components == 2:
182 | return px.scatter(x=embedding[:, 0], y=embedding[:, 1], hover_name=df.index, color=colors)
183 | if n_components == 3:
184 | return px.scatter_3d(
185 | x=embedding[:, 0],
186 | y=embedding[:, 1],
187 | z=embedding[:, 2],
188 | hover_name=df.index,
189 | color=colors,
190 | )
191 | raise ValueError("n_components must be 2 or 3")
192 |
--------------------------------------------------------------------------------
/dictionary_learning/pretrained_dictionary_downloader.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | wget https://huggingface.co/saprmarks/pythia-70m-deduped-saes/resolve/main/dictionaries_pythia-70m-deduped_10.zip
4 | unzip dictionaries_pythia-70m-deduped_10.zip
--------------------------------------------------------------------------------
/dictionary_learning/requirements.txt:
--------------------------------------------------------------------------------
1 | circuitsvis>=1.43.2
2 | datasets>=2.18.0
3 | einops>=0.7.0
4 | matplotlib>=3.8.3
5 | nnsight>=0.3.0
6 | pandas>=2.2.1
7 | plotly>=5.18.0
8 | torch>=2.1.2
9 | tqdm>=4.66.1
10 | umap-learn>=0.5.6
11 | zstandard>=0.22.0
12 | wandb>=0.12.0
13 | pytest>=6.2.4
--------------------------------------------------------------------------------
/dictionary_learning/tests/test_end_to_end.py:
--------------------------------------------------------------------------------
1 | import torch as t
2 | from nnsight import LanguageModel
3 | import os
4 | import json
5 | import random
6 |
7 | from dictionary_learning.training import trainSAE
8 | from dictionary_learning.trainers.standard import StandardTrainer
9 | from dictionary_learning.trainers.top_k import TopKTrainer, AutoEncoderTopK
10 | from dictionary_learning.utils import hf_dataset_to_generator, get_nested_folders, load_dictionary
11 | from dictionary_learning.buffer import ActivationBuffer
12 | from dictionary_learning.dictionary import (
13 | AutoEncoder,
14 | GatedAutoEncoder,
15 | AutoEncoderNew,
16 | JumpReluAutoEncoder,
17 | )
18 | from dictionary_learning.evaluation import evaluate
19 |
20 | EXPECTED_RESULTS = {
21 | "AutoEncoderTopK": {
22 | "l2_loss": 4.362327718734742,
23 | "l1_loss": 50.94957427978515,
24 | "l0": 40.0,
25 | "frac_variance_explained": 0.9578053653240204,
26 | "cossim": 0.9478691875934601,
27 | "l2_ratio": 0.9478908002376556,
28 | "relative_reconstruction_bias": 0.999762898683548,
29 | "loss_original": 3.3361297130584715,
30 | "loss_reconstructed": 3.8404462814331053,
31 | "loss_zero": 13.251659297943116,
32 | "frac_recovered": 0.948982036113739,
33 | "frac_alive": 0.99951171875,
34 | },
35 | "AutoEncoder": {
36 | "l2_loss": 6.822444677352905,
37 | "l1_loss": 19.382131576538086,
38 | "l0": 37.45087890625,
39 | "frac_variance_explained": 0.8993501663208008,
40 | "cossim": 0.8791120409965515,
41 | "l2_ratio": 0.74552041888237,
42 | "relative_reconstruction_bias": 0.9595054805278778,
43 | "loss_original": 3.3361297130584715,
44 | "loss_reconstructed": 5.208198881149292,
45 | "loss_zero": 13.251659297943116,
46 | "frac_recovered": 0.8106247961521149,
47 | "frac_alive": 0.99658203125,
48 | },
49 | }
50 |
51 | DEVICE = "cuda:0"
52 | SAVE_DIR = "./test_data"
53 | MODEL_NAME = "EleutherAI/pythia-70m-deduped"
54 | RANDOM_SEED = 42
55 | LAYER = 3
56 | DATASET_NAME = "monology/pile-uncopyrighted"
57 |
58 | EVAL_TOLERANCE = 0.01
59 |
60 |
61 | def test_sae_training():
62 | """End to end test for training an SAE. Takes ~2 minutes on an RTX 3090.
63 | This isn't a nice suite of unit tests, but it's better than nothing.
64 | I have observed that results can slightly vary with library versions. For full determinism,
65 | use pytorch 2.5.1 and nnsight 0.3.7.
66 |
67 | NOTE: `dictionary_learning` is meant to be used as a submodule. Thus, to run this test, you need to use `dictionary_learning` as a submodule
68 | and run the test from the root of the repository using `pytest -s`. Refer to https://github.com/adamkarvonen/dictionary_learning_demo for an example"""
69 | random.seed(RANDOM_SEED)
70 | t.manual_seed(RANDOM_SEED)
71 |
72 | model = LanguageModel(MODEL_NAME, dispatch=True, device_map=DEVICE)
73 |
74 | context_length = 128
75 | llm_batch_size = 512 # Fits on a 24GB GPU
76 | sae_batch_size = 8192
77 | num_contexts_per_sae_batch = sae_batch_size // context_length
78 |
79 | num_inputs_in_buffer = num_contexts_per_sae_batch * 20
80 |
81 | num_tokens = 10_000_000
82 |
83 | # sae training parameters
84 | k = 40
85 | sparsity_penalty = 2.0
86 | expansion_factor = 8
87 |
88 | steps = int(num_tokens / sae_batch_size) # Total number of batches to train
89 | save_steps = None
90 | warmup_steps = 1000 # Warmup period at start of training and after each resample
91 | resample_steps = None
92 |
93 | # standard sae training parameters
94 | learning_rate = 3e-4
95 |
96 | # topk sae training parameters
97 | decay_start = None
98 | auxk_alpha = 1 / 32
99 |
100 | submodule = model.gpt_neox.layers[LAYER]
101 | submodule_name = f"resid_post_layer_{LAYER}"
102 | io = "out"
103 | activation_dim = model.config.hidden_size
104 |
105 | generator = hf_dataset_to_generator(DATASET_NAME)
106 |
107 | activation_buffer = ActivationBuffer(
108 | generator,
109 | model,
110 | submodule,
111 | n_ctxs=num_inputs_in_buffer,
112 | ctx_len=context_length,
113 | refresh_batch_size=llm_batch_size,
114 | out_batch_size=sae_batch_size,
115 | io=io,
116 | d_submodule=activation_dim,
117 | device=DEVICE,
118 | )
119 |
120 | # create the list of configs
121 | trainer_configs = []
122 | trainer_configs.extend(
123 | [
124 | {
125 | "trainer": TopKTrainer,
126 | "dict_class": AutoEncoderTopK,
127 | "lr": None,
128 | "activation_dim": activation_dim,
129 | "dict_size": expansion_factor * activation_dim,
130 | "k": k,
131 | "auxk_alpha": auxk_alpha, # see Appendix A.2
132 | "warmup_steps": 0,
133 | "decay_start": decay_start, # when does the lr decay start
134 | "steps": steps, # when when does training end
135 | "seed": RANDOM_SEED,
136 | "wandb_name": f"TopKTrainer-{MODEL_NAME}-{submodule_name}",
137 | "device": DEVICE,
138 | "layer": LAYER,
139 | "lm_name": MODEL_NAME,
140 | "submodule_name": submodule_name,
141 | },
142 | ]
143 | )
144 | trainer_configs.extend(
145 | [
146 | {
147 | "trainer": StandardTrainer,
148 | "dict_class": AutoEncoder,
149 | "activation_dim": activation_dim,
150 | "dict_size": expansion_factor * activation_dim,
151 | "lr": learning_rate,
152 | "l1_penalty": sparsity_penalty,
153 | "warmup_steps": warmup_steps,
154 | "sparsity_warmup_steps": None,
155 | "decay_start": decay_start,
156 | "steps": steps,
157 | "resample_steps": resample_steps,
158 | "seed": RANDOM_SEED,
159 | "wandb_name": f"StandardTrainer-{MODEL_NAME}-{submodule_name}",
160 | "layer": LAYER,
161 | "lm_name": MODEL_NAME,
162 | "device": DEVICE,
163 | "submodule_name": submodule_name,
164 | },
165 | ]
166 | )
167 |
168 | print(f"len trainer configs: {len(trainer_configs)}")
169 | output_dir = f"{SAVE_DIR}/{submodule_name}"
170 |
171 | trainSAE(
172 | data=activation_buffer,
173 | trainer_configs=trainer_configs,
174 | steps=steps,
175 | save_steps=save_steps,
176 | save_dir=output_dir,
177 | )
178 |
179 | folders = get_nested_folders(output_dir)
180 |
181 | assert len(folders) == 2
182 |
183 | for folder in folders:
184 | dictionary, config = load_dictionary(folder, DEVICE)
185 |
186 | assert dictionary is not None
187 | assert config is not None
188 |
189 |
190 | def test_evaluation():
191 | random.seed(RANDOM_SEED)
192 | t.manual_seed(RANDOM_SEED)
193 |
194 | model = LanguageModel(MODEL_NAME, dispatch=True, device_map=DEVICE)
195 | ae_paths = get_nested_folders(SAVE_DIR)
196 |
197 | context_length = 128
198 | llm_batch_size = 100
199 | sae_batch_size = 4096
200 | n_batches = 10
201 | buffer_size = 256
202 | io = "out"
203 |
204 | generator = hf_dataset_to_generator(DATASET_NAME)
205 | submodule = model.gpt_neox.layers[LAYER]
206 |
207 | input_strings = []
208 | for i, example in enumerate(generator):
209 | input_strings.append(example)
210 | if i > buffer_size * n_batches:
211 | break
212 |
213 | for ae_path in ae_paths:
214 | dictionary, config = load_dictionary(ae_path, DEVICE)
215 | dictionary = dictionary.to(dtype=model.dtype)
216 |
217 | activation_dim = config["trainer"]["activation_dim"]
218 | context_length = config["buffer"]["ctx_len"]
219 |
220 | activation_buffer_data = iter(input_strings)
221 |
222 | activation_buffer = ActivationBuffer(
223 | activation_buffer_data,
224 | model,
225 | submodule,
226 | n_ctxs=buffer_size,
227 | ctx_len=context_length,
228 | refresh_batch_size=llm_batch_size,
229 | out_batch_size=sae_batch_size,
230 | io=io,
231 | d_submodule=activation_dim,
232 | device=DEVICE,
233 | )
234 |
235 | eval_results = evaluate(
236 | dictionary,
237 | activation_buffer,
238 | context_length,
239 | llm_batch_size,
240 | io=io,
241 | device=DEVICE,
242 | n_batches=n_batches,
243 | )
244 |
245 | print(eval_results)
246 |
247 | dict_class = config["trainer"]["dict_class"]
248 | expected_results = EXPECTED_RESULTS[dict_class]
249 |
250 | max_diff = 0
251 | max_diff_percent = 0
252 | for key, value in expected_results.items():
253 | diff = abs(eval_results[key] - value)
254 | max_diff = max(max_diff, diff)
255 | max_diff_percent = max(max_diff_percent, diff / value)
256 |
257 | print(f"Max diff: {max_diff}, max diff %: {max_diff_percent}")
258 | assert max_diff < EVAL_TOLERANCE
259 |
--------------------------------------------------------------------------------
/dictionary_learning/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | from .standard import StandardTrainer
2 | from .gdm import GatedSAETrainer
3 | from .p_anneal import PAnnealTrainer
4 | from .gated_anneal import GatedAnnealTrainer
5 | from .top_k import TopKTrainer
6 | from .jumprelu import JumpReluTrainer
7 | from .batch_top_k import BatchTopKTrainer, BatchTopKSAE
8 | from .matroyshka_batch_top_k import MatroyshkaBatchTopKTrainer, MatroyshkaBatchTopKSAE
--------------------------------------------------------------------------------
/dictionary_learning/trainers/batch_top_k.py:
--------------------------------------------------------------------------------
1 | import torch as t
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import einops
5 | from collections import namedtuple
6 | from typing import Optional
7 |
8 | from ..dictionary import Dictionary
9 | from ..trainers.trainer import (
10 | SAETrainer,
11 | get_lr_schedule,
12 | set_decoder_norm_to_unit_norm,
13 | remove_gradient_parallel_to_decoder_directions,
14 | )
15 |
16 |
17 | class BatchTopKSAE(Dictionary, nn.Module):
18 | def __init__(self, activation_dim: int, dict_size: int, k: int):
19 | super().__init__()
20 | self.activation_dim = activation_dim
21 | self.dict_size = dict_size
22 |
23 | assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer"
24 | self.register_buffer("k", t.tensor(k, dtype=t.int))
25 | self.register_buffer("threshold", t.tensor(-1.0, dtype=t.float32))
26 |
27 | self.decoder = nn.Linear(dict_size, activation_dim, bias=False)
28 | self.decoder.weight.data = set_decoder_norm_to_unit_norm(
29 | self.decoder.weight, activation_dim, dict_size
30 | )
31 |
32 | self.encoder = nn.Linear(activation_dim, dict_size)
33 | self.encoder.weight.data = self.decoder.weight.T.clone()
34 | self.encoder.bias.data.zero_()
35 | self.b_dec = nn.Parameter(t.zeros(activation_dim))
36 |
37 | def encode(self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True):
38 | post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec))
39 |
40 | if use_threshold:
41 | encoded_acts_BF = post_relu_feat_acts_BF * (post_relu_feat_acts_BF > self.threshold)
42 | else:
43 | # Flatten and perform batch top-k
44 | flattened_acts = post_relu_feat_acts_BF.flatten()
45 | post_topk = flattened_acts.topk(self.k * x.size(0), sorted=False, dim=-1)
46 |
47 | encoded_acts_BF = (
48 | t.zeros_like(post_relu_feat_acts_BF.flatten())
49 | .scatter_(-1, post_topk.indices, post_topk.values)
50 | .reshape(post_relu_feat_acts_BF.shape)
51 | )
52 |
53 | if return_active:
54 | return encoded_acts_BF, encoded_acts_BF.sum(0) > 0, post_relu_feat_acts_BF
55 | else:
56 | return encoded_acts_BF
57 |
58 | def decode(self, x: t.Tensor) -> t.Tensor:
59 | return self.decoder(x) + self.b_dec
60 |
61 | def forward(self, x: t.Tensor, output_features: bool = False):
62 | encoded_acts_BF = self.encode(x)
63 | x_hat_BD = self.decode(encoded_acts_BF)
64 |
65 | if not output_features:
66 | return x_hat_BD
67 | else:
68 | return x_hat_BD, encoded_acts_BF
69 |
70 | def scale_biases(self, scale: float):
71 | self.encoder.bias.data *= scale
72 | self.b_dec.data *= scale
73 | if self.threshold >= 0:
74 | self.threshold *= scale
75 |
76 | @classmethod
77 | def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "BatchTopKSAE":
78 | state_dict = t.load(path)
79 | dict_size, activation_dim = state_dict["encoder.weight"].shape
80 | if k is None:
81 | k = state_dict["k"].item()
82 | elif "k" in state_dict and k != state_dict["k"].item():
83 | raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']")
84 |
85 | autoencoder = cls(activation_dim, dict_size, k)
86 | autoencoder.load_state_dict(state_dict)
87 | if device is not None:
88 | autoencoder.to(device)
89 | return autoencoder
90 |
91 |
92 | class BatchTopKTrainer(SAETrainer):
93 | def __init__(
94 | self,
95 | steps: int, # total number of steps to train for
96 | activation_dim: int,
97 | dict_size: int,
98 | k: int,
99 | layer: int,
100 | lm_name: str,
101 | dict_class: type = BatchTopKSAE,
102 | lr: Optional[float] = None,
103 | auxk_alpha: float = 1 / 32,
104 | warmup_steps: int = 1000,
105 | decay_start: Optional[int] = None, # when does the lr decay start
106 | threshold_beta: float = 0.999,
107 | threshold_start_step: int = 1000,
108 | seed: Optional[int] = None,
109 | device: Optional[str] = None,
110 | wandb_name: str = "BatchTopKSAE",
111 | submodule_name: Optional[str] = None,
112 | ):
113 | super().__init__(seed)
114 | assert layer is not None and lm_name is not None
115 | self.layer = layer
116 | self.lm_name = lm_name
117 | self.submodule_name = submodule_name
118 | self.wandb_name = wandb_name
119 | self.steps = steps
120 | self.decay_start = decay_start
121 | self.warmup_steps = warmup_steps
122 | self.k = k
123 | self.threshold_beta = threshold_beta
124 | self.threshold_start_step = threshold_start_step
125 |
126 | if seed is not None:
127 | t.manual_seed(seed)
128 | t.cuda.manual_seed_all(seed)
129 |
130 | self.ae = dict_class(activation_dim, dict_size, k)
131 |
132 | if device is None:
133 | self.device = "cuda" if t.cuda.is_available() else "cpu"
134 | else:
135 | self.device = device
136 | self.ae.to(self.device)
137 |
138 | if lr is not None:
139 | self.lr = lr
140 | else:
141 | # Auto-select LR using 1 / sqrt(d) scaling law from Figure 3 of the paper
142 | scale = dict_size / (2**14)
143 | self.lr = 2e-4 / scale**0.5
144 |
145 | self.auxk_alpha = auxk_alpha
146 | self.dead_feature_threshold = 10_000_000
147 | self.top_k_aux = activation_dim // 2 # Heuristic from B.1 of the paper
148 | self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device)
149 | self.logging_parameters = ["effective_l0", "dead_features", "pre_norm_auxk_loss"]
150 | self.effective_l0 = -1
151 | self.dead_features = -1
152 | self.pre_norm_auxk_loss = -1
153 |
154 | self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999))
155 |
156 | lr_fn = get_lr_schedule(steps, warmup_steps, decay_start=decay_start)
157 |
158 | self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn)
159 |
160 | def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor):
161 | dead_features = self.num_tokens_since_fired >= self.dead_feature_threshold
162 | self.dead_features = int(dead_features.sum())
163 |
164 | if dead_features.sum() > 0:
165 | k_aux = min(self.top_k_aux, dead_features.sum())
166 |
167 | auxk_latents = t.where(dead_features[None], post_relu_acts_BF, -t.inf)
168 |
169 | # Top-k dead latents
170 | auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False)
171 |
172 | auxk_buffer_BF = t.zeros_like(post_relu_acts_BF)
173 | auxk_acts_BF = auxk_buffer_BF.scatter_(dim=-1, index=auxk_indices, src=auxk_acts)
174 |
175 | # Note: decoder(), not decode(), as we don't want to apply the bias
176 | x_reconstruct_aux = self.ae.decoder(auxk_acts_BF)
177 | l2_loss_aux = (
178 | (residual_BD.float() - x_reconstruct_aux.float()).pow(2).sum(dim=-1).mean()
179 | )
180 |
181 | self.pre_norm_auxk_loss = l2_loss_aux
182 |
183 | # normalization from OpenAI implementation: https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/kernels.py#L614
184 | residual_mu = residual_BD.mean(dim=0)[None, :].broadcast_to(residual_BD.shape)
185 | loss_denom = (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean()
186 | normalized_auxk_loss = l2_loss_aux / loss_denom
187 |
188 | return normalized_auxk_loss.nan_to_num(0.0)
189 | else:
190 | self.pre_norm_auxk_loss = -1
191 | return t.tensor(0, dtype=residual_BD.dtype, device=residual_BD.device)
192 |
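# update_threshold keeps an exponential moving average (weight threshold_beta) of the smallest
# non-zero activation selected by batch top-k; encode() then uses it as the activation threshold
# when called with use_threshold=True (e.g. at inference time).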
193 | def update_threshold(self, f: t.Tensor):
194 | device_type = "cuda" if f.is_cuda else "cpu"
195 | with t.autocast(device_type=device_type, enabled=False), t.no_grad():
196 | active = f[f > 0]
197 |
198 | if active.size(0) == 0:
199 | min_activation = 0.0
200 | else:
201 | min_activation = active.min().detach().to(dtype=t.float32)
202 |
203 | if self.ae.threshold < 0:
204 | self.ae.threshold = min_activation
205 | else:
206 | self.ae.threshold = (self.threshold_beta * self.ae.threshold) + (
207 | (1 - self.threshold_beta) * min_activation
208 | )
209 |
210 | def loss(self, x, step=None, logging=False):
211 | f, active_indices_F, post_relu_acts_BF = self.ae.encode(
212 | x, return_active=True, use_threshold=False
213 | )
214 | # l0 = (f != 0).float().sum(dim=-1).mean().item()
215 |
216 | if step > self.threshold_start_step:
217 | self.update_threshold(f)
218 |
219 | x_hat = self.ae.decode(f)
220 |
221 | e = x - x_hat
222 |
223 | self.effective_l0 = self.k
224 |
225 | num_tokens_in_step = x.size(0)
226 | did_fire = t.zeros_like(self.num_tokens_since_fired, dtype=t.bool)
227 | did_fire[active_indices_F] = True
228 | self.num_tokens_since_fired += num_tokens_in_step
229 | self.num_tokens_since_fired[did_fire] = 0
230 |
231 | l2_loss = e.pow(2).sum(dim=-1).mean()
232 | auxk_loss = self.get_auxiliary_loss(e.detach(), post_relu_acts_BF)
233 | loss = l2_loss + self.auxk_alpha * auxk_loss
234 |
235 | if not logging:
236 | return loss
237 | else:
238 | return namedtuple("LossLog", ["x", "x_hat", "f", "losses"])(
239 | x,
240 | x_hat,
241 | f,
242 | {"l2_loss": l2_loss.item(), "auxk_loss": auxk_loss.item(), "loss": loss.item()},
243 | )
244 |
245 | def update(self, step, x):
246 | if step == 0:
247 | median = self.geometric_median(x)
248 | median = median.to(self.ae.b_dec.dtype)
249 | self.ae.b_dec.data = median
250 |
251 | x = x.to(self.device)
252 | loss = self.loss(x, step=step)
253 | loss.backward()
254 |
255 | self.ae.decoder.weight.grad = remove_gradient_parallel_to_decoder_directions(
256 | self.ae.decoder.weight,
257 | self.ae.decoder.weight.grad,
258 | self.ae.activation_dim,
259 | self.ae.dict_size,
260 | )
261 | t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0)
262 |
263 | self.optimizer.step()
264 | self.optimizer.zero_grad()
265 | self.scheduler.step()
266 |
267 | # Make sure the decoder is still unit-norm
268 | self.ae.decoder.weight.data = set_decoder_norm_to_unit_norm(
269 | self.ae.decoder.weight, self.ae.activation_dim, self.ae.dict_size
270 | )
271 |
272 | return loss.item()
273 |
274 | @property
275 | def config(self):
276 | return {
277 | "trainer_class": "BatchTopKTrainer",
278 | "dict_class": "BatchTopKSAE",
279 | "lr": self.lr,
280 | "steps": self.steps,
281 | "auxk_alpha": self.auxk_alpha,
282 | "warmup_steps": self.warmup_steps,
283 | "decay_start": self.decay_start,
284 | "threshold_beta": self.threshold_beta,
285 | "threshold_start_step": self.threshold_start_step,
286 | "top_k_aux": self.top_k_aux,
287 | "seed": self.seed,
288 | "activation_dim": self.ae.activation_dim,
289 | "dict_size": self.ae.dict_size,
290 | "k": self.ae.k.item(),
291 | "device": self.device,
292 | "layer": self.layer,
293 | "lm_name": self.lm_name,
294 | "wandb_name": self.wandb_name,
295 | "submodule_name": self.submodule_name,
296 | }
297 |
298 | @staticmethod
299 | def geometric_median(points: t.Tensor, max_iter: int = 100, tol: float = 1e-5):
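# Weiszfeld's algorithm: repeatedly re-estimate the median as the average of the points,
# weighted by the inverse of their distance to the current estimate.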
300 | guess = points.mean(dim=0)
301 | prev = t.zeros_like(guess)
302 | weights = t.ones(len(points), device=points.device)
303 |
304 | for _ in range(max_iter):
305 | prev = guess
306 | weights = 1 / t.norm(points - guess, dim=1)
307 | weights /= weights.sum()
308 | guess = (weights.unsqueeze(1) * points).sum(dim=0)
309 | if t.norm(guess - prev) < tol:
310 | break
311 |
312 | return guess
313 |
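
if __name__ == "__main__":
    # Minimal sketch of the batch top-k behaviour on random data (illustrative sizes).
    # Because of the relative imports above, run it as a module from the repository root,
    # e.g. `python -m dictionary_learning.trainers.batch_top_k`.
    sae = BatchTopKSAE(activation_dim=64, dict_size=512, k=8)
    x = t.randn(32, 64)
    f = sae.encode(x, use_threshold=False)  # keep the top k * batch_size activations across the whole batch
    print(f.shape, int((f != 0).sum()))     # torch.Size([32, 512]), at most 8 * 32 = 256 non-zeros in total
    x_hat = sae.decode(f)                   # reconstruction, shape (32, 64)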
--------------------------------------------------------------------------------
/dictionary_learning/trainers/gated_anneal.py:
--------------------------------------------------------------------------------
1 | """
2 | Implements the training scheme for a gated SAE described in https://arxiv.org/abs/2404.16014
3 | """
4 |
5 | import torch as t
6 | from typing import Optional
7 |
8 | from ..trainers.trainer import SAETrainer, get_lr_schedule, get_sparsity_warmup_fn, ConstrainedAdam
9 | from ..config import DEBUG
10 | from ..dictionary import GatedAutoEncoder
11 | from collections import namedtuple
12 |
13 | class GatedAnnealTrainer(SAETrainer):
14 | """
15 | Gated SAE training scheme with p-annealing.
16 | """
17 | def __init__(self,
18 | steps: int, # total number of steps to train for
19 | activation_dim: int,
20 | dict_size: int,
21 | layer: int,
22 | lm_name: str,
23 | dict_class: type = GatedAutoEncoder,
24 | lr: float = 3e-4,
25 | warmup_steps: int = 1000, # lr warmup period at start of training and after each resample
26 | sparsity_warmup_steps: Optional[int] = 2000, # sparsity warmup period at start of training
27 | decay_start: Optional[int] = None, # decay learning rate after this many steps
28 | sparsity_function: str = 'Lp^p', # Lp or Lp^p
29 | initial_sparsity_penalty: float = 1e-1, # equal to l1 penalty in standard trainer
30 | anneal_start: int = 15000, # step at which to start annealing p
31 | anneal_end: Optional[int] = None, # step at which to stop annealing, defaults to steps-1
32 | p_start: float = 1, # starting value of p (constant throughout warmup)
33 | p_end: float = 0, # annealing p_start to p_end linearly after warmup_steps, exact endpoint excluded
34 | n_sparsity_updates: int | str = 10, # number of times to update the sparsity penalty, at most steps-anneal_start times
35 |                  sparsity_queue_length: int = 10, # number of recent sparsity loss terms, only needed for adaptive_sparsity_penalty
36 | resample_steps: Optional[int] = None, # number of steps after which to resample dead neurons
37 | device: Optional[str] = None,
38 | seed: Optional[int] = 42,
39 | wandb_name: str = 'GatedAnnealTrainer',
40 | ):
41 | super().__init__(seed)
42 |
43 | assert layer is not None and lm_name is not None
44 | self.layer = layer
45 | self.lm_name = lm_name
46 |
47 | if seed is not None:
48 | t.manual_seed(seed)
49 | t.cuda.manual_seed_all(seed)
50 |
51 |         # initialize dictionary
53 | self.activation_dim = activation_dim
54 | self.dict_size = dict_size
55 | self.ae = dict_class(activation_dim, dict_size)
56 |
57 | if device is None:
58 | self.device = 'cuda' if t.cuda.is_available() else 'cpu'
59 | else:
60 | self.device = device
61 | self.ae.to(self.device)
62 |
63 | self.lr = lr
64 | self.sparsity_function = sparsity_function
65 | self.anneal_start = anneal_start
66 | self.anneal_end = anneal_end if anneal_end is not None else steps
67 | self.p_start = p_start
68 | self.p_end = p_end
69 | self.p = p_start # p is set in self.loss()
70 | self.next_p = None # set in self.loss()
71 | self.lp_loss = None # set in self.loss()
72 | self.scaled_lp_loss = None # set in self.loss()
73 | if n_sparsity_updates == "continuous":
74 | self.n_sparsity_updates = self.anneal_end - anneal_start +1
75 | else:
76 | self.n_sparsity_updates = n_sparsity_updates
77 | self.sparsity_update_steps = t.linspace(anneal_start, self.anneal_end, self.n_sparsity_updates, dtype=int)
78 | self.p_values = t.linspace(p_start, p_end, self.n_sparsity_updates)
79 | self.p_step_count = 0
80 | self.sparsity_coeff = initial_sparsity_penalty # alpha
81 | self.sparsity_queue_length = sparsity_queue_length
82 | self.sparsity_queue = []
83 |
84 | self.warmup_steps = warmup_steps
85 | self.sparsity_warmup_steps = sparsity_warmup_steps
86 | self.decay_start = decay_start
87 | self.steps = steps
88 | self.logging_parameters = ['p', 'next_p', 'lp_loss', 'scaled_lp_loss', 'sparsity_coeff']
89 | self.seed = seed
90 | self.wandb_name = wandb_name
91 |
92 | self.resample_steps = resample_steps
93 | if self.resample_steps is not None:
94 | # how many steps since each neuron was last activated?
95 | self.steps_since_active = t.zeros(self.dict_size, dtype=int).to(self.device)
96 | else:
97 | self.steps_since_active = None
98 |
99 | self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr, betas=(0.0, 0.999))
100 |
101 | lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps, sparsity_warmup_steps)
102 | self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn)
103 | self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps)
104 |
105 | def resample_neurons(self, deads, activations):
106 | with t.no_grad():
107 | if deads.sum() == 0: return
108 | print(f"resampling {deads.sum().item()} neurons")
109 |
110 | # compute loss for each activation
111 | losses = (activations - self.ae(activations)).norm(dim=-1)
112 |
113 | # sample input to create encoder/decoder weights from
114 | n_resample = min([deads.sum(), losses.shape[0]])
115 | indices = t.multinomial(losses, num_samples=n_resample, replacement=False)
116 | sampled_vecs = activations[indices]
117 |
118 |             # reset encoder/decoder weights for dead neurons (chained indexing like weight[deads][:n_resample] writes into a copy, so restrict `deads` to the first n_resample entries instead)
119 |             alive_norm = self.ae.encoder.weight[~deads].norm(dim=-1).mean()
120 |             deads[deads.nonzero()[n_resample:]] = False
121 |             self.ae.encoder.weight[deads] = sampled_vecs * alive_norm * 0.2
122 |             self.ae.decoder.weight[:,deads] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T
123 |             self.ae.encoder.bias[deads] = 0.
124 | 
125 | # reset Adam parameters for dead neurons
126 | state_dict = self.optimizer.state_dict()['state']
127 | ## encoder weight
128 | state_dict[1]['exp_avg'][deads] = 0.
129 | state_dict[1]['exp_avg_sq'][deads] = 0.
130 | ## encoder bias
131 | state_dict[2]['exp_avg'][deads] = 0.
132 | state_dict[2]['exp_avg_sq'][deads] = 0.
133 | ## decoder weight
134 | state_dict[3]['exp_avg'][:,deads] = 0.
135 | state_dict[3]['exp_avg_sq'][:,deads] = 0.
136 |
137 | def lp_norm(self, f, p):
138 | norm_sq = f.pow(p).sum(dim=-1)
139 | if self.sparsity_function == 'Lp^p':
140 | return norm_sq.mean()
141 | elif self.sparsity_function == 'Lp':
142 | return norm_sq.pow(1/p).mean()
143 | else:
144 | raise ValueError("Sparsity function must be 'Lp' or 'Lp^p'")
145 |
146 | def loss(self, x:t.Tensor, step:int, logging=False, **kwargs):
147 | sparsity_scale = self.sparsity_warmup_fn(step)
148 | f, f_gate = self.ae.encode(x, return_gate=True)
149 | x_hat = self.ae.decode(f)
150 | x_hat_gate = f_gate @ self.ae.decoder.weight.detach().T + self.ae.decoder_bias.detach()
151 |
152 | L_recon = (x - x_hat).pow(2).sum(dim=-1).mean()
153 | L_aux = (x - x_hat_gate).pow(2).sum(dim=-1).mean()
154 |
155 | fs = f_gate # feature activation that we use for sparsity term
156 | lp_loss = self.lp_norm(fs, self.p)
157 | scaled_lp_loss = lp_loss * self.sparsity_coeff * sparsity_scale
158 | self.lp_loss = lp_loss
159 | self.scaled_lp_loss = scaled_lp_loss
160 |
161 | if self.next_p is not None:
162 | lp_loss_next = self.lp_norm(fs, self.next_p)
163 | self.sparsity_queue.append([self.lp_loss.item(), lp_loss_next.item()])
164 | self.sparsity_queue = self.sparsity_queue[-self.sparsity_queue_length:]
165 |
166 | if step in self.sparsity_update_steps:
167 | # check to make sure we don't update on repeat step:
168 | if step >= self.sparsity_update_steps[self.p_step_count]:
169 | # Adapt sparsity penalty alpha
170 | if self.next_p is not None:
171 | local_sparsity_new = t.tensor([i[0] for i in self.sparsity_queue]).mean()
172 | local_sparsity_old = t.tensor([i[1] for i in self.sparsity_queue]).mean()
173 | self.sparsity_coeff = self.sparsity_coeff * (local_sparsity_new / local_sparsity_old).item()
174 | # Update p
175 | self.p = self.p_values[self.p_step_count].item()
176 | if self.p_step_count < self.n_sparsity_updates-1:
177 | self.next_p = self.p_values[self.p_step_count+1].item()
178 | else:
179 | self.next_p = self.p_end
180 | self.p_step_count += 1
181 |
182 | # Update dead feature count
183 | if self.steps_since_active is not None:
184 | # update steps_since_active
185 | deads = (f == 0).all(dim=0)
186 | self.steps_since_active[deads] += 1
187 | self.steps_since_active[~deads] = 0
188 |
189 | loss = L_recon + scaled_lp_loss + L_aux
190 |
191 | if not logging:
192 | return loss
193 | else:
194 | return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])(
195 | x, x_hat, f,
196 | {
197 | 'mse_loss' : L_recon.item(),
198 | 'aux_loss' : L_aux.item(),
199 | 'loss' : loss.item(),
200 | 'p' : self.p,
201 | 'next_p' : self.next_p,
202 | 'lp_loss' : lp_loss.item(),
203 | 'sparsity_loss' : scaled_lp_loss.item(),
204 | 'sparsity_coeff' : self.sparsity_coeff,
205 | }
206 | )
207 |
208 | def update(self, step, activations):
209 | activations = activations.to(self.device)
210 |
211 | self.optimizer.zero_grad()
212 | loss = self.loss(activations, step, logging=False)
213 | loss.backward()
214 | self.optimizer.step()
215 | self.scheduler.step()
216 |
217 | if self.resample_steps is not None and step % self.resample_steps == self.resample_steps - 1:
218 | self.resample_neurons(self.steps_since_active > self.resample_steps / 2, activations)
219 |
220 | # @property
221 | # def config(self):
222 | # return {
223 | # 'trainer_class' : 'GatedSAETrainer',
224 | # 'activation_dim' : self.ae.activation_dim,
225 | # 'dict_size' : self.ae.dict_size,
226 | # 'lr' : self.lr,
227 | # 'l1_penalty' : self.l1_penalty,
228 | # 'warmup_steps' : self.warmup_steps,
229 | # 'device' : self.device,
230 | # 'wandb_name': self.wandb_name,
231 | # }
232 |
233 | @property
234 | def config(self):
235 | return {
236 | 'trainer_class' : "GatedAnnealTrainer",
237 | 'dict_class' : "GatedAutoEncoder",
238 | 'activation_dim' : self.activation_dim,
239 | 'dict_size' : self.dict_size,
240 | 'lr' : self.lr,
241 | 'sparsity_function' : self.sparsity_function,
242 | 'sparsity_penalty' : self.sparsity_coeff,
243 | 'p_start' : self.p_start,
244 | 'p_end' : self.p_end,
245 | 'anneal_start' : self.anneal_start,
246 | 'sparsity_queue_length' : self.sparsity_queue_length,
247 | 'n_sparsity_updates' : self.n_sparsity_updates,
248 | 'warmup_steps' : self.warmup_steps,
249 | 'resample_steps' : self.resample_steps,
250 | 'sparsity_warmup_steps' : self.sparsity_warmup_steps,
251 | 'decay_start' : self.decay_start,
252 | 'steps' : self.steps,
253 | 'seed' : self.seed,
254 | 'layer' : self.layer,
255 | 'lm_name' : self.lm_name,
256 | 'wandb_name' : self.wandb_name,
257 | }
258 |
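259 | if __name__ == "__main__":
260 |     # Illustrative sketch, not part of the training pipeline: the p-annealing schedule built in
261 |     # __init__, shown standalone. With anneal_start=15000, anneal_end taken as 30000 and 10
262 |     # sparsity updates, p moves linearly from p_start=1 to p_end=0 at the steps printed below.
263 |     # Run from the repo root with `python -m dictionary_learning.trainers.gated_anneal`.
264 |     anneal_start, anneal_end, n_updates = 15000, 30000, 10
265 |     update_steps = t.linspace(anneal_start, anneal_end, n_updates, dtype=int)
266 |     p_values = t.linspace(1.0, 0.0, n_updates)
267 |     for s, p in zip(update_steps.tolist(), p_values.tolist()):
268 |         print(f"step {s}: p -> {p:.2f}")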
--------------------------------------------------------------------------------
/dictionary_learning/trainers/gdm.py:
--------------------------------------------------------------------------------
1 | """
2 | Implements the training scheme for a gated SAE described in https://arxiv.org/abs/2404.16014
3 | """
4 |
5 | import torch as t
6 | from typing import Optional
7 |
8 | from ..trainers.trainer import SAETrainer, get_lr_schedule, get_sparsity_warmup_fn, ConstrainedAdam
9 | from ..config import DEBUG
10 | from ..dictionary import GatedAutoEncoder
11 | from collections import namedtuple
12 |
13 | class GatedSAETrainer(SAETrainer):
14 | """
15 | Gated SAE training scheme.
16 | """
17 | def __init__(self,
18 | steps: int, # total number of steps to train for
19 | activation_dim: int,
20 | dict_size: int,
21 | layer: int,
22 | lm_name: str,
23 | dict_class = GatedAutoEncoder,
24 | lr: float = 5e-5,
25 | l1_penalty: float = 1e-1,
26 | warmup_steps: int = 1000, # lr warmup period at start of training and after each resample
27 | sparsity_warmup_steps: Optional[int] = 2000,
28 | decay_start:Optional[int]=None, # decay learning rate after this many steps
29 | seed: Optional[int] = None,
30 | device: Optional[str] = None,
31 | wandb_name: Optional[str] = 'GatedSAETrainer',
32 | submodule_name: Optional[str] = None,
33 | ):
34 | super().__init__(seed)
35 |
36 | assert layer is not None and lm_name is not None
37 | self.layer = layer
38 | self.lm_name = lm_name
39 | self.submodule_name = submodule_name
40 |
41 | if seed is not None:
42 | t.manual_seed(seed)
43 | t.cuda.manual_seed_all(seed)
44 |
45 | # initialize dictionary
46 | self.ae = dict_class(activation_dim, dict_size)
47 |
48 | self.lr = lr
49 | self.l1_penalty=l1_penalty
50 | self.warmup_steps = warmup_steps
51 | self.sparsity_warmup_steps = sparsity_warmup_steps
52 | self.decay_start = decay_start
53 | self.wandb_name = wandb_name
54 |
55 | if device is None:
56 | self.device = 'cuda' if t.cuda.is_available() else 'cpu'
57 | else:
58 | self.device = device
59 | self.ae.to(self.device)
60 |
61 | self.optimizer = ConstrainedAdam(
62 | self.ae.parameters(),
63 | self.ae.decoder.parameters(),
64 | lr=lr,
65 | betas=(0.0, 0.999),
66 | )
67 |
68 | lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps=None, sparsity_warmup_steps=sparsity_warmup_steps)
69 |
70 | self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_fn)
71 | self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps)
72 |
73 |
74 | def loss(self, x:t.Tensor, step:int, logging:bool=False, **kwargs):
75 |
76 | sparsity_scale = self.sparsity_warmup_fn(step)
77 |
78 | f, f_gate = self.ae.encode(x, return_gate=True)
79 | x_hat = self.ae.decode(f)
80 | x_hat_gate = f_gate @ self.ae.decoder.weight.detach().T + self.ae.decoder_bias.detach()
81 |
82 | L_recon = (x - x_hat).pow(2).sum(dim=-1).mean()
83 | L_sparse = t.linalg.norm(f_gate, ord=1, dim=-1).mean()
84 | L_aux = (x - x_hat_gate).pow(2).sum(dim=-1).mean()
85 |
86 | loss = L_recon + (self.l1_penalty * L_sparse * sparsity_scale) + L_aux
87 |
88 | if not logging:
89 | return loss
90 | else:
91 | return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])(
92 | x, x_hat, f,
93 | {
94 | 'mse_loss' : L_recon.item(),
95 | 'sparsity_loss' : L_sparse.item(),
96 | 'aux_loss' : L_aux.item(),
97 | 'loss' : loss.item()
98 | }
99 | )
100 |
101 | def update(self, step, x):
102 | x = x.to(self.device)
103 | self.optimizer.zero_grad()
104 | loss = self.loss(x, step)
105 | loss.backward()
106 | self.optimizer.step()
107 | self.scheduler.step()
108 |
109 | @property
110 | def config(self):
111 | return {
112 | 'dict_class': 'GatedAutoEncoder',
113 | 'trainer_class' : 'GatedSAETrainer',
114 | 'activation_dim' : self.ae.activation_dim,
115 | 'dict_size' : self.ae.dict_size,
116 | 'lr' : self.lr,
117 | 'l1_penalty' : self.l1_penalty,
118 | 'warmup_steps' : self.warmup_steps,
119 | 'sparsity_warmup_steps' : self.sparsity_warmup_steps,
120 | 'decay_start' : self.decay_start,
121 | 'seed' : self.seed,
122 | 'device' : self.device,
123 | 'layer' : self.layer,
124 | 'lm_name' : self.lm_name,
125 | 'wandb_name': self.wandb_name,
126 | 'submodule_name': self.submodule_name,
127 | }
128 |
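129 | if __name__ == "__main__":
130 |     # Illustrative smoke test, not part of the training pipeline: a few update steps on random
131 |     # activations, assuming GatedAutoEncoder builds on CPU from (activation_dim, dict_size) as it
132 |     # is constructed in __init__ above. Run from the repo root with
133 |     # `python -m dictionary_learning.trainers.gdm`.
134 |     trainer = GatedSAETrainer(
135 |         steps=10, activation_dim=16, dict_size=64, layer=0, lm_name="toy",
136 |         warmup_steps=2, sparsity_warmup_steps=4, device="cpu", seed=0,
137 |     )
138 |     for step in range(10):
139 |         trainer.update(step, t.randn(32, 16))
140 |     print("final loss:", trainer.loss(t.randn(32, 16), step=9).item())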
--------------------------------------------------------------------------------
/dictionary_learning/trainers/jumprelu.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | import torch
4 | import torch.autograd as autograd
5 | from torch import nn
6 | from typing import Optional
7 |
8 | from ..dictionary import Dictionary, JumpReluAutoEncoder
9 | from ..trainers.trainer import (
10 | SAETrainer,
11 | get_lr_schedule,
12 | get_sparsity_warmup_fn,
13 | set_decoder_norm_to_unit_norm,
14 | remove_gradient_parallel_to_decoder_directions,
15 | )
16 |
17 |
18 | class RectangleFunction(autograd.Function):
19 | @staticmethod
20 | def forward(ctx, x):
21 | ctx.save_for_backward(x)
22 | return ((x > -0.5) & (x < 0.5)).float()
23 |
24 | @staticmethod
25 | def backward(ctx, grad_output):
26 | (x,) = ctx.saved_tensors
27 | grad_input = grad_output.clone()
28 | grad_input[(x <= -0.5) | (x >= 0.5)] = 0
29 | return grad_input
30 |
31 |
32 | class JumpReLUFunction(autograd.Function):
33 | @staticmethod
34 | def forward(ctx, x, threshold, bandwidth):
35 | ctx.save_for_backward(x, threshold, torch.tensor(bandwidth))
36 | return x * (x > threshold).float()
37 |
38 | @staticmethod
39 | def backward(ctx, grad_output):
40 | x, threshold, bandwidth_tensor = ctx.saved_tensors
41 | bandwidth = bandwidth_tensor.item()
42 | x_grad = (x > threshold).float() * grad_output
43 | threshold_grad = (
44 | -(threshold / bandwidth)
45 | * RectangleFunction.apply((x - threshold) / bandwidth)
46 | * grad_output
47 | )
48 | return x_grad, threshold_grad, None # None for bandwidth
49 |
50 |
51 | class StepFunction(autograd.Function):
52 | @staticmethod
53 | def forward(ctx, x, threshold, bandwidth):
54 | ctx.save_for_backward(x, threshold, torch.tensor(bandwidth))
55 | return (x > threshold).float()
56 |
57 | @staticmethod
58 | def backward(ctx, grad_output):
59 | x, threshold, bandwidth_tensor = ctx.saved_tensors
60 | bandwidth = bandwidth_tensor.item()
61 | x_grad = torch.zeros_like(x)
62 | threshold_grad = (
63 | -(1.0 / bandwidth) * RectangleFunction.apply((x - threshold) / bandwidth) * grad_output
64 | )
65 | return x_grad, threshold_grad, None # None for bandwidth
66 |
67 |
68 | class JumpReluTrainer(nn.Module, SAETrainer):
69 | """
70 | Trains a JumpReLU autoencoder.
71 |
72 |     Note: does not use the learning rate or sparsity scheduling described in the paper.
73 | """
74 |
75 | def __init__(
76 | self,
77 | steps: int, # total number of steps to train for
78 | activation_dim: int,
79 | dict_size: int,
80 | layer: int,
81 | lm_name: str,
82 | dict_class=JumpReluAutoEncoder,
83 | seed: Optional[int] = None,
84 |         # TODO: What's the default lr used in the paper?
85 | lr: float = 7e-5,
86 | bandwidth: float = 0.001,
87 | sparsity_penalty: float = 1.0,
88 | warmup_steps: int = 1000, # lr warmup period at start of training and after each resample
89 | sparsity_warmup_steps: Optional[int] = 2000, # sparsity warmup period at start of training
90 | decay_start: Optional[int] = None, # decay learning rate after this many steps
91 | target_l0: float = 20.0,
92 | device: str = "cpu",
93 | wandb_name: str = "JumpRelu",
94 | submodule_name: Optional[str] = None,
95 | ):
96 | super().__init__()
97 |
98 | # TODO: Should just be args, and this should be commonised
99 | assert layer is not None, "Layer must be specified"
100 | assert lm_name is not None, "Language model name must be specified"
101 | self.lm_name = lm_name
102 | self.layer = layer
103 | self.submodule_name = submodule_name
104 | self.device = device
105 | self.steps = steps
106 | self.lr = lr
107 | self.seed = seed
108 |
109 | self.bandwidth = bandwidth
110 | self.sparsity_coefficient = sparsity_penalty
111 | self.warmup_steps = warmup_steps
112 | self.sparsity_warmup_steps = sparsity_warmup_steps
113 | self.decay_start = decay_start
114 | self.target_l0 = target_l0
115 |
116 | # TODO: Better auto-naming (e.g. in BatchTopK package)
117 | self.wandb_name = wandb_name
118 |
119 | # TODO: Why not just pass in the initialised autoencoder instead?
120 | self.ae = dict_class(
121 | activation_dim=activation_dim,
122 | dict_size=dict_size,
123 | device=device,
124 | ).to(self.device)
125 |
126 | # Parameters from the paper
127 | self.optimizer = torch.optim.Adam(self.ae.parameters(), lr=lr, betas=(0.0, 0.999), eps=1e-8)
128 |
129 | lr_fn = get_lr_schedule(
130 | steps,
131 | warmup_steps,
132 | decay_start,
133 | resample_steps=None,
134 | sparsity_warmup_steps=sparsity_warmup_steps,
135 | )
136 |
137 | self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn)
138 |
139 | self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps)
140 |
141 | # Purely for logging purposes
142 | self.dead_feature_threshold = 10_000_000
143 | self.num_tokens_since_fired = torch.zeros(dict_size, dtype=torch.long, device=device)
144 | self.dead_features = -1
145 | self.logging_parameters = ["dead_features"]
146 |
147 | def loss(self, x: torch.Tensor, step: int, logging=False, **_):
148 | # Note: We are using threshold, not log_threshold as in this notebook:
149 | # https://colab.research.google.com/drive/1PlFzI_PWGTN9yCQLuBcSuPJUjgHL7GiD#scrollTo=yP828a6uIlSO
150 | # I had poor results when using log_threshold and it would complicate the scale_biases() function
151 |
152 | sparsity_scale = self.sparsity_warmup_fn(step)
153 | x = x.to(self.ae.W_enc.dtype)
154 |
155 | pre_jump = x @ self.ae.W_enc + self.ae.b_enc
156 | f = JumpReLUFunction.apply(pre_jump, self.ae.threshold, self.bandwidth)
157 |
158 | active_indices = f.sum(0) > 0
159 | did_fire = torch.zeros_like(self.num_tokens_since_fired, dtype=torch.bool)
160 | did_fire[active_indices] = True
161 | self.num_tokens_since_fired += x.size(0)
162 | self.num_tokens_since_fired[active_indices] = 0
163 | self.dead_features = (
164 | (self.num_tokens_since_fired > self.dead_feature_threshold).sum().item()
165 | )
166 |
167 | recon = self.ae.decode(f)
168 |
169 | recon_loss = (x - recon).pow(2).sum(dim=-1).mean()
170 | l0 = StepFunction.apply(f, self.ae.threshold, self.bandwidth).sum(dim=-1).mean()
171 |
172 | sparsity_loss = (
173 | self.sparsity_coefficient * ((l0 / self.target_l0) - 1).pow(2) * sparsity_scale
174 | )
175 | loss = recon_loss + sparsity_loss
176 |
177 | if not logging:
178 | return loss
179 | else:
180 | return namedtuple("LossLog", ["x", "recon", "f", "losses"])(
181 | x,
182 | recon,
183 | f,
184 | {
185 | "l2_loss": recon_loss.item(),
186 | "loss": loss.item(),
187 | },
188 | )
189 |
190 | def update(self, step, x):
191 | x = x.to(self.device)
192 | loss = self.loss(x, step=step)
193 | loss.backward()
194 |
195 | # We must transpose because we are using nn.Parameter, not nn.Linear
196 | self.ae.W_dec.grad = remove_gradient_parallel_to_decoder_directions(
197 | self.ae.W_dec.T, self.ae.W_dec.grad.T, self.ae.activation_dim, self.ae.dict_size
198 | ).T
199 | torch.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0)
200 |
201 | self.optimizer.step()
202 | self.scheduler.step()
203 | self.optimizer.zero_grad()
204 |
205 | # We must transpose because we are using nn.Parameter, not nn.Linear
206 | self.ae.W_dec.data = set_decoder_norm_to_unit_norm(
207 | self.ae.W_dec.T, self.ae.activation_dim, self.ae.dict_size
208 | ).T
209 |
210 | return loss.item()
211 |
212 | @property
213 | def config(self):
214 | return {
215 | "trainer_class": "JumpReluTrainer",
216 | "dict_class": "JumpReluAutoEncoder",
217 | "lr": self.lr,
218 | "steps": self.steps,
219 | "seed": self.seed,
220 | "activation_dim": self.ae.activation_dim,
221 | "dict_size": self.ae.dict_size,
222 | "device": self.device,
223 | "layer": self.layer,
224 | "lm_name": self.lm_name,
225 | "wandb_name": self.wandb_name,
226 | "submodule_name": self.submodule_name,
227 | "bandwidth": self.bandwidth,
228 | "sparsity_penalty": self.sparsity_coefficient,
229 | "sparsity_warmup_steps": self.sparsity_warmup_steps,
230 | "target_l0": self.target_l0,
231 | }
232 |
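233 | if __name__ == "__main__":
234 |     # Illustrative sketch, not part of the training pipeline: the straight-through estimator used
235 |     # above. The forward pass hard-thresholds, while the backward pass routes gradient to
236 |     # `threshold` only where the pre-activation lies inside the rectangle window of width
237 |     # `bandwidth`. Run from the repo root with `python -m dictionary_learning.trainers.jumprelu`.
238 |     x = torch.tensor([0.05, 0.5, 2.0], requires_grad=True)
239 |     threshold = torch.tensor([0.1, 0.1, 0.1], requires_grad=True)
240 |     out = JumpReLUFunction.apply(x, threshold, 0.3)
241 |     out.sum().backward()
242 |     print(out.detach())    # tensor([0.0000, 0.5000, 2.0000]); 0.05 is below its threshold
243 |     print(x.grad)          # gradient flows only where x > threshold
244 |     print(threshold.grad)  # nonzero only for the entry within bandwidth/2 of its threshold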
--------------------------------------------------------------------------------
/dictionary_learning/trainers/p_anneal.py:
--------------------------------------------------------------------------------
1 | """
2 | Implements the standard SAE training scheme, with the option to anneal the sparsity norm parameter p.
3 | """
4 | import torch as t
5 | from typing import Optional
6 |
7 | from ..dictionary import AutoEncoder
8 | from ..trainers.trainer import SAETrainer, get_lr_schedule, get_sparsity_warmup_fn, ConstrainedAdam
9 | from ..config import DEBUG
10 |
11 | class PAnnealTrainer(SAETrainer):
12 | """
13 | SAE training scheme with the option to anneal the sparsity parameter p.
14 | You can further choose to use Lp or Lp^p sparsity.
15 | """
16 | def __init__(self,
17 | steps: int, # total number of steps to train for
18 | activation_dim: int,
19 | dict_size: int,
20 | layer: int,
21 | lm_name: str,
22 | dict_class: type = AutoEncoder,
23 | lr: float = 1e-3,
24 | warmup_steps: int = 1000, # lr warmup period at start of training and after each resample
25 | decay_start: Optional[int] = None, # step at which to start decaying lr
26 | sparsity_warmup_steps: Optional[int] = 2000, # number of steps to warm up sparsity penalty
27 | sparsity_function: str = 'Lp', # Lp or Lp^p
28 | initial_sparsity_penalty: float = 1e-1, # equal to l1 penalty in standard trainer
29 | anneal_start: int = 15000, # step at which to start annealing p
30 | anneal_end: Optional[int] = None, # step at which to stop annealing, defaults to steps-1
31 | p_start: float = 1, # starting value of p (constant throughout warmup)
32 | p_end: float = 0, # annealing p_start to p_end linearly after warmup_steps, exact endpoint excluded
33 | n_sparsity_updates: int | str = 10, # number of times to update the sparsity penalty, at most steps-anneal_start times
34 |                  sparsity_queue_length: int = 10, # number of recent sparsity loss terms, only needed for adaptive_sparsity_penalty
35 | resample_steps: Optional[int] = None, # number of steps after which to resample dead neurons
36 | device: Optional[str] = None,
37 | seed: int = 42,
38 | wandb_name: str = 'PAnnealTrainer',
39 | submodule_name: Optional[str] = None,
40 | ):
41 | super().__init__(seed)
42 |
43 | assert layer is not None and lm_name is not None
44 | self.layer = layer
45 | self.lm_name = lm_name
46 | self.submodule_name = submodule_name
47 |
48 | if seed is not None:
49 | t.manual_seed(seed)
50 | t.cuda.manual_seed_all(seed)
51 |
52 | if device is None:
53 | self.device = t.device('cuda' if t.cuda.is_available() else 'cpu')
54 | else:
55 | self.device = device
56 |
57 | # initialize dictionary
58 | self.activation_dim = activation_dim
59 | self.dict_size = dict_size
60 | self.ae = dict_class(activation_dim, dict_size)
61 | self.ae.to(self.device)
62 |
63 | self.lr = lr
64 | self.sparsity_function = sparsity_function
65 | self.anneal_start = anneal_start
66 | self.anneal_end = anneal_end if anneal_end is not None else steps
67 | self.p_start = p_start
68 | self.p_end = p_end
69 | self.p = p_start
70 | self.next_p = None
71 | if n_sparsity_updates == "continuous":
72 | self.n_sparsity_updates = self.anneal_end - anneal_start +1
73 | else:
74 | self.n_sparsity_updates = n_sparsity_updates
75 | self.sparsity_update_steps = t.linspace(anneal_start, self.anneal_end, self.n_sparsity_updates, dtype=int)
76 | self.p_values = t.linspace(p_start, p_end, self.n_sparsity_updates)
77 | self.p_step_count = 0
78 | self.sparsity_coeff = initial_sparsity_penalty # alpha
79 | self.sparsity_queue_length = sparsity_queue_length
80 | self.sparsity_queue = []
81 |
82 | self.warmup_steps = warmup_steps
83 | self.sparsity_warmup_steps = sparsity_warmup_steps
84 | self.decay_start = decay_start
85 | self.steps = steps
86 | self.logging_parameters = ['p', 'next_p', 'lp_loss', 'scaled_lp_loss', 'sparsity_coeff']
87 | self.seed = seed
88 | self.wandb_name = wandb_name
89 |
90 | self.resample_steps = resample_steps
91 | if self.resample_steps is not None:
92 | # how many steps since each neuron was last activated?
93 | self.steps_since_active = t.zeros(self.dict_size, dtype=int).to(self.device)
94 | else:
95 | self.steps_since_active = None
96 |
97 | self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr)
98 |
99 | lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps, sparsity_warmup_steps)
100 | self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn)
101 |
102 | self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps)
103 |
104 | if (self.sparsity_update_steps.unique(return_counts=True)[1] >1).any():
105 |             print("Warning! Duplicates in self.sparsity_update_steps detected!")
106 |
107 | def resample_neurons(self, deads, activations):
108 | with t.no_grad():
109 | if deads.sum() == 0: return
110 | print(f"resampling {deads.sum().item()} neurons")
111 |
112 | # compute loss for each activation
113 | losses = (activations - self.ae(activations)).norm(dim=-1)
114 |
115 | # sample input to create encoder/decoder weights from
116 | n_resample = min([deads.sum(), losses.shape[0]])
117 | indices = t.multinomial(losses, num_samples=n_resample, replacement=False)
118 | sampled_vecs = activations[indices]
119 |
120 |             # reset encoder/decoder weights for dead neurons (chained indexing like weight[deads][:n_resample] writes into a copy, so restrict `deads` to the first n_resample entries instead)
121 |             alive_norm = self.ae.encoder.weight[~deads].norm(dim=-1).mean()
122 |             deads[deads.nonzero()[n_resample:]] = False
123 |             self.ae.encoder.weight[deads] = sampled_vecs * alive_norm * 0.2
124 |             self.ae.decoder.weight[:,deads] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T
125 |             self.ae.encoder.bias[deads] = 0.
126 | 
127 | # reset Adam parameters for dead neurons
128 | state_dict = self.optimizer.state_dict()['state']
129 | ## encoder weight
130 | state_dict[1]['exp_avg'][deads] = 0.
131 | state_dict[1]['exp_avg_sq'][deads] = 0.
132 | ## encoder bias
133 | state_dict[2]['exp_avg'][deads] = 0.
134 | state_dict[2]['exp_avg_sq'][deads] = 0.
135 | ## decoder weight
136 | state_dict[3]['exp_avg'][:,deads] = 0.
137 | state_dict[3]['exp_avg_sq'][:,deads] = 0.
138 |
139 | def lp_norm(self, f, p):
140 | norm_sq = f.pow(p).sum(dim=-1)
141 | if self.sparsity_function == 'Lp^p':
142 | return norm_sq.mean()
143 | elif self.sparsity_function == 'Lp':
144 | return norm_sq.pow(1/p).mean()
145 | else:
146 | raise ValueError("Sparsity function must be 'Lp' or 'Lp^p'")
147 |
148 | def loss(self, x: t.Tensor, step:int, logging=False):
149 | sparsity_scale = self.sparsity_warmup_fn(step)
150 |
151 | # Compute loss terms
152 | x_hat, f = self.ae(x, output_features=True)
153 | recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean()
154 | lp_loss = self.lp_norm(f, self.p)
155 | scaled_lp_loss = lp_loss * self.sparsity_coeff * sparsity_scale
156 | self.lp_loss = lp_loss
157 | self.scaled_lp_loss = scaled_lp_loss
158 |
159 | if self.next_p is not None:
160 | lp_loss_next = self.lp_norm(f, self.next_p)
161 | self.sparsity_queue.append([self.lp_loss.item(), lp_loss_next.item()])
162 | self.sparsity_queue = self.sparsity_queue[-self.sparsity_queue_length:]
163 |
164 | if step in self.sparsity_update_steps:
165 | # check to make sure we don't update on repeat step:
166 | if step >= self.sparsity_update_steps[self.p_step_count]:
167 | # Adapt sparsity penalty alpha
168 | if self.next_p is not None:
169 | local_sparsity_new = t.tensor([i[0] for i in self.sparsity_queue]).mean()
170 | local_sparsity_old = t.tensor([i[1] for i in self.sparsity_queue]).mean()
171 | self.sparsity_coeff = self.sparsity_coeff * (local_sparsity_new / local_sparsity_old).item()
172 | # Update p
173 | self.p = self.p_values[self.p_step_count].item()
174 | if self.p_step_count < self.n_sparsity_updates-1:
175 | self.next_p = self.p_values[self.p_step_count+1].item()
176 | else:
177 | self.next_p = self.p_end
178 | self.p_step_count += 1
179 |
180 | # Update dead feature count
181 | if self.steps_since_active is not None:
182 | # update steps_since_active
183 | deads = (f == 0).all(dim=0)
184 | self.steps_since_active[deads] += 1
185 | self.steps_since_active[~deads] = 0
186 |
187 | if logging is False:
188 | return recon_loss + scaled_lp_loss
189 | else:
190 | loss_log = {
191 | 'p' : self.p,
192 | 'next_p' : self.next_p,
193 | 'lp_loss' : lp_loss.item(),
194 | 'scaled_lp_loss' : scaled_lp_loss.item(),
195 | 'sparsity_coeff' : self.sparsity_coeff,
196 | }
197 | return x, x_hat, f, loss_log
198 |
199 |
200 | def update(self, step, activations):
201 | activations = activations.to(self.device)
202 |
203 | self.optimizer.zero_grad()
204 | loss = self.loss(activations, step, logging=False)
205 | loss.backward()
206 | self.optimizer.step()
207 | self.scheduler.step()
208 |
209 | if self.resample_steps is not None and step % self.resample_steps == self.resample_steps - 1:
210 | self.resample_neurons(self.steps_since_active > self.resample_steps / 2, activations)
211 |
212 | @property
213 | def config(self):
214 | return {
215 | 'trainer_class' : "PAnnealTrainer",
216 | 'dict_class' : "AutoEncoder",
217 | 'activation_dim' : self.activation_dim,
218 | 'dict_size' : self.dict_size,
219 | 'lr' : self.lr,
220 | 'sparsity_function' : self.sparsity_function,
221 | 'sparsity_penalty' : self.sparsity_coeff,
222 | 'p_start' : self.p_start,
223 | 'p_end' : self.p_end,
224 | 'anneal_start' : self.anneal_start,
225 | 'sparsity_queue_length' : self.sparsity_queue_length,
226 | 'n_sparsity_updates' : self.n_sparsity_updates,
227 | 'warmup_steps' : self.warmup_steps,
228 | 'sparsity_warmup_steps': self.sparsity_warmup_steps,
229 | 'decay_start': self.decay_start,
230 | 'resample_steps' : self.resample_steps,
231 | 'steps' : self.steps,
232 | 'seed' : self.seed,
233 | 'layer' : self.layer,
234 | 'lm_name' : self.lm_name,
235 | 'wandb_name' : self.wandb_name,
236 | 'submodule_name' : self.submodule_name,
237 | }
238 |
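239 | if __name__ == "__main__":
240 |     # Illustrative sketch, not part of the training pipeline: as p anneals toward 0, the Lp^p
241 |     # penalty computed in lp_norm() approaches the L0 count of active features, which is what
242 |     # makes the annealed penalty increasingly sparsity-seeking. Run from the repo root with
243 |     # `python -m dictionary_learning.trainers.p_anneal`.
244 |     f = t.tensor([[0.0, 0.5, 2.0, 0.0]])
245 |     for p in (1.0, 0.5, 0.1):
246 |         print(f"p={p}: Lp^p = {f.pow(p).sum(dim=-1).item():.3f}")
247 |     print("L0 count:", (f != 0).sum(dim=-1).item())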
--------------------------------------------------------------------------------
/dictionary_learning/trainers/standard.py:
--------------------------------------------------------------------------------
1 | """
2 | Implements the standard SAE training scheme.
3 | """
4 | import torch as t
5 | from typing import Optional
6 |
7 | from ..trainers.trainer import SAETrainer, get_lr_schedule, get_sparsity_warmup_fn, ConstrainedAdam
8 | from ..config import DEBUG
9 | from ..dictionary import AutoEncoder
10 | from collections import namedtuple
11 |
12 | class StandardTrainer(SAETrainer):
13 | """
14 | Standard SAE training scheme following Towards Monosemanticity. Decoder column norms are constrained to 1.
15 | """
16 | def __init__(self,
17 | steps: int, # total number of steps to train for
18 | activation_dim: int,
19 | dict_size: int,
20 | layer: int,
21 | lm_name: str,
22 | dict_class=AutoEncoder,
23 | lr:float=1e-3,
24 | l1_penalty:float=1e-1,
25 | warmup_steps:int=1000, # lr warmup period at start of training and after each resample
26 | sparsity_warmup_steps:Optional[int]=2000, # sparsity warmup period at start of training
27 | decay_start:Optional[int]=None, # decay learning rate after this many steps
28 | resample_steps:Optional[int]=None, # how often to resample neurons
29 | seed:Optional[int]=None,
30 | device=None,
31 | wandb_name:Optional[str]='StandardTrainer',
32 | submodule_name:Optional[str]=None,
33 | ):
34 | super().__init__(seed)
35 |
36 | assert layer is not None and lm_name is not None
37 | self.layer = layer
38 | self.lm_name = lm_name
39 | self.submodule_name = submodule_name
40 |
41 | if seed is not None:
42 | t.manual_seed(seed)
43 | t.cuda.manual_seed_all(seed)
44 |
45 | # initialize dictionary
46 | self.ae = dict_class(activation_dim, dict_size)
47 |
48 | self.lr = lr
49 | self.l1_penalty=l1_penalty
50 | self.warmup_steps = warmup_steps
51 | self.sparsity_warmup_steps = sparsity_warmup_steps
52 | self.steps = steps
53 | self.decay_start = decay_start
54 | self.wandb_name = wandb_name
55 |
56 | if device is None:
57 | self.device = 'cuda' if t.cuda.is_available() else 'cpu'
58 | else:
59 | self.device = device
60 | self.ae.to(self.device)
61 |
62 | self.resample_steps = resample_steps
63 | if self.resample_steps is not None:
64 | # how many steps since each neuron was last activated?
65 | self.steps_since_active = t.zeros(self.ae.dict_size, dtype=int).to(self.device)
66 | else:
67 | self.steps_since_active = None
68 |
69 | self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr)
70 |
71 | lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps, sparsity_warmup_steps)
72 | self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn)
73 |
74 | self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps)
75 |
76 | def resample_neurons(self, deads, activations):
77 | with t.no_grad():
78 | if deads.sum() == 0: return
79 | print(f"resampling {deads.sum().item()} neurons")
80 |
81 | # compute loss for each activation
82 | losses = (activations - self.ae(activations)).norm(dim=-1)
83 |
84 | # sample input to create encoder/decoder weights from
85 | n_resample = min([deads.sum(), losses.shape[0]])
86 | indices = t.multinomial(losses, num_samples=n_resample, replacement=False)
87 | sampled_vecs = activations[indices]
88 |
89 | # get norm of the living neurons
90 | alive_norm = self.ae.encoder.weight[~deads].norm(dim=-1).mean()
91 |
92 | # resample first n_resample dead neurons
93 | deads[deads.nonzero()[n_resample:]] = False
94 | self.ae.encoder.weight[deads] = sampled_vecs * alive_norm * 0.2
95 | self.ae.decoder.weight[:,deads] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T
96 | self.ae.encoder.bias[deads] = 0.
97 |
98 |
99 | # reset Adam parameters for dead neurons
100 | state_dict = self.optimizer.state_dict()['state']
101 | ## encoder weight
102 | state_dict[1]['exp_avg'][deads] = 0.
103 | state_dict[1]['exp_avg_sq'][deads] = 0.
104 | ## encoder bias
105 | state_dict[2]['exp_avg'][deads] = 0.
106 | state_dict[2]['exp_avg_sq'][deads] = 0.
107 | ## decoder weight
108 | state_dict[3]['exp_avg'][:,deads] = 0.
109 | state_dict[3]['exp_avg_sq'][:,deads] = 0.
110 |
111 | def loss(self, x, step: int, logging=False, **kwargs):
112 |
113 | sparsity_scale = self.sparsity_warmup_fn(step)
114 |
115 | x_hat, f = self.ae(x, output_features=True)
116 | l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean()
117 | recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean()
118 | l1_loss = f.norm(p=1, dim=-1).mean()
119 |
120 | if self.steps_since_active is not None:
121 | # update steps_since_active
122 | deads = (f == 0).all(dim=0)
123 | self.steps_since_active[deads] += 1
124 | self.steps_since_active[~deads] = 0
125 |
126 | loss = recon_loss + self.l1_penalty * sparsity_scale * l1_loss
127 |
128 | if not logging:
129 | return loss
130 | else:
131 | return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])(
132 | x, x_hat, f,
133 | {
134 | 'l2_loss' : l2_loss.item(),
135 | 'mse_loss' : recon_loss.item(),
136 | 'sparsity_loss' : l1_loss.item(),
137 | 'loss' : loss.item()
138 | }
139 | )
140 |
141 |
142 | def update(self, step, activations):
143 | activations = activations.to(self.device)
144 |
145 | self.optimizer.zero_grad()
146 | loss = self.loss(activations, step=step)
147 | loss.backward()
148 | self.optimizer.step()
149 | self.scheduler.step()
150 |
151 | if self.resample_steps is not None and step % self.resample_steps == 0:
152 | self.resample_neurons(self.steps_since_active > self.resample_steps / 2, activations)
153 |
154 | @property
155 | def config(self):
156 | return {
157 | 'dict_class': 'AutoEncoder',
158 | 'trainer_class' : 'StandardTrainer',
159 | 'activation_dim': self.ae.activation_dim,
160 | 'dict_size': self.ae.dict_size,
161 | 'lr' : self.lr,
162 | 'l1_penalty' : self.l1_penalty,
163 | 'warmup_steps' : self.warmup_steps,
164 | 'resample_steps' : self.resample_steps,
165 | 'sparsity_warmup_steps' : self.sparsity_warmup_steps,
166 | 'steps' : self.steps,
167 | 'decay_start' : self.decay_start,
168 | 'seed' : self.seed,
169 | 'device' : self.device,
170 | 'layer' : self.layer,
171 | 'lm_name' : self.lm_name,
172 | 'wandb_name': self.wandb_name,
173 | 'submodule_name': self.submodule_name,
174 | }
175 |
176 |
177 | class StandardTrainerAprilUpdate(SAETrainer):
178 | """
179 | Standard SAE training scheme following the Anthropic April update. Decoder column norms are NOT constrained to 1.
180 |     It does not support resampling or ghost gradients, and will typically have fewer dead neurons than the standard trainer.
181 | """
182 | def __init__(self,
183 | steps: int, # total number of steps to train for
184 | activation_dim: int,
185 | dict_size: int,
186 | layer: int,
187 | lm_name: str,
188 | dict_class=AutoEncoder,
189 | lr:float=1e-3,
190 | l1_penalty:float=1e-1,
191 | warmup_steps:int=1000, # lr warmup period at start of training
192 | sparsity_warmup_steps:Optional[int]=2000, # sparsity warmup period at start of training
193 | decay_start:Optional[int]=None, # decay learning rate after this many steps
194 | seed:Optional[int]=None,
195 | device=None,
196 | wandb_name:Optional[str]='StandardTrainerAprilUpdate',
197 | submodule_name:Optional[str]=None,
198 | ):
199 | super().__init__(seed)
200 |
201 | assert layer is not None and lm_name is not None
202 | self.layer = layer
203 | self.lm_name = lm_name
204 | self.submodule_name = submodule_name
205 |
206 | if seed is not None:
207 | t.manual_seed(seed)
208 | t.cuda.manual_seed_all(seed)
209 |
210 | # initialize dictionary
211 | self.ae = dict_class(activation_dim, dict_size)
212 |
213 | self.lr = lr
214 | self.l1_penalty=l1_penalty
215 | self.warmup_steps = warmup_steps
216 | self.sparsity_warmup_steps = sparsity_warmup_steps
217 | self.steps = steps
218 | self.decay_start = decay_start
219 | self.wandb_name = wandb_name
220 |
221 | if device is None:
222 | self.device = 'cuda' if t.cuda.is_available() else 'cpu'
223 | else:
224 | self.device = device
225 | self.ae.to(self.device)
226 |
227 | self.optimizer = t.optim.Adam(self.ae.parameters(), lr=lr)
228 |
229 | lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps=None, sparsity_warmup_steps=sparsity_warmup_steps)
230 | self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn)
231 |
232 | self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps)
233 |
234 | def loss(self, x, step: int, logging=False, **kwargs):
235 |
236 | sparsity_scale = self.sparsity_warmup_fn(step)
237 |
238 | x_hat, f = self.ae(x, output_features=True)
239 | l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean()
240 | recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean()
241 | l1_loss = (f * self.ae.decoder.weight.norm(p=2, dim=0)).sum(dim=-1).mean()
242 |
243 | loss = recon_loss + self.l1_penalty * sparsity_scale * l1_loss
244 |
245 | if not logging:
246 | return loss
247 | else:
248 | return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])(
249 | x, x_hat, f,
250 | {
251 | 'l2_loss' : l2_loss.item(),
252 | 'mse_loss' : recon_loss.item(),
253 | 'sparsity_loss' : l1_loss.item(),
254 | 'loss' : loss.item()
255 | }
256 | )
257 |
258 |
259 | def update(self, step, activations):
260 | activations = activations.to(self.device)
261 |
262 | self.optimizer.zero_grad()
263 | loss = self.loss(activations, step=step)
264 | loss.backward()
265 | t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0)
266 | self.optimizer.step()
267 | self.scheduler.step()
268 |
269 | @property
270 | def config(self):
271 | return {
272 | 'dict_class': 'AutoEncoder',
273 | 'trainer_class' : 'StandardTrainerAprilUpdate',
274 | 'activation_dim': self.ae.activation_dim,
275 | 'dict_size': self.ae.dict_size,
276 | 'lr' : self.lr,
277 | 'l1_penalty' : self.l1_penalty,
278 | 'warmup_steps' : self.warmup_steps,
279 | 'sparsity_warmup_steps' : self.sparsity_warmup_steps,
280 | 'steps' : self.steps,
281 | 'decay_start' : self.decay_start,
282 | 'seed' : self.seed,
283 | 'device' : self.device,
284 | 'layer' : self.layer,
285 | 'lm_name' : self.lm_name,
286 | 'wandb_name': self.wandb_name,
287 | 'submodule_name': self.submodule_name,
288 | }
289 |
290 |
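291 | if __name__ == "__main__":
292 |     # Illustrative sketch, not part of the training pipeline: the April-update penalty above
293 |     # weights each feature activation by its decoder column norm, so it only coincides with the
294 |     # plain L1 penalty when decoder columns have unit norm. Run from the repo root with
295 |     # `python -m dictionary_learning.trainers.standard`.
296 |     ae = AutoEncoder(16, 64)
297 |     x_hat, f = ae(t.randn(8, 16), output_features=True)
298 |     plain_l1 = f.norm(p=1, dim=-1).mean()
299 |     weighted_l1 = (f * ae.decoder.weight.norm(p=2, dim=0)).sum(dim=-1).mean()
300 |     print(f"plain L1: {plain_l1.item():.3f}   decoder-weighted L1: {weighted_l1.item():.3f}")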
--------------------------------------------------------------------------------
/dictionary_learning/trainers/top_k.py:
--------------------------------------------------------------------------------
1 | """
2 | Implements the SAE training scheme from https://arxiv.org/abs/2406.04093.
3 | Significant portions of this code have been copied from https://github.com/EleutherAI/sae/blob/main/sae
4 | """
5 |
6 | import einops
7 | import torch as t
8 | import torch.nn as nn
9 | from collections import namedtuple
10 | from typing import Optional
11 |
12 | from ..config import DEBUG
13 | from ..dictionary import Dictionary
14 | from ..trainers.trainer import (
15 | SAETrainer,
16 | get_lr_schedule,
17 | set_decoder_norm_to_unit_norm,
18 | remove_gradient_parallel_to_decoder_directions,
19 | )
20 |
21 |
22 | @t.no_grad()
23 | def geometric_median(points: t.Tensor, max_iter: int = 100, tol: float = 1e-5):
24 |     """Compute the geometric median of `points`. Used for initializing the decoder bias."""
25 | # Initialize our guess as the mean of the points
26 | guess = points.mean(dim=0)
27 | prev = t.zeros_like(guess)
28 |
29 | # Weights for iteratively reweighted least squares
30 | weights = t.ones(len(points), device=points.device)
31 |
32 | for _ in range(max_iter):
33 | prev = guess
34 |
35 | # Compute the weights
36 | weights = 1 / t.norm(points - guess, dim=1)
37 |
38 | # Normalize the weights
39 | weights /= weights.sum()
40 |
41 | # Compute the new geometric median
42 | guess = (weights.unsqueeze(1) * points).sum(dim=0)
43 |
44 | # Early stopping condition
45 | if t.norm(guess - prev) < tol:
46 | break
47 |
48 | return guess
49 |
50 |
51 | class AutoEncoderTopK(Dictionary, nn.Module):
52 | """
53 | The top-k autoencoder architecture and initialization used in https://arxiv.org/abs/2406.04093
54 | NOTE: (From Adam Karvonen) There is an unmaintained implementation using Triton kernels in the topk-triton-implementation branch.
55 | We abandoned it as we didn't notice a significant speedup and it added complications, which are noted
56 | in the AutoEncoderTopK class docstring in that branch.
57 |
58 | With some additional effort, you can train a Top-K SAE with the Triton kernels and modify the state dict for compatibility with this class.
59 |     Notably, the Triton kernels currently require the decoder to be stored as an nn.Parameter, not an nn.Linear, and the decoder weights must also
60 | be stored in the same shape as the encoder.
61 | """
62 |
63 | def __init__(self, activation_dim: int, dict_size: int, k: int):
64 | super().__init__()
65 | self.activation_dim = activation_dim
66 | self.dict_size = dict_size
67 |
68 | assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer"
69 | self.register_buffer("k", t.tensor(k, dtype=t.int))
70 | self.register_buffer("threshold", t.tensor(-1.0, dtype=t.float32))
71 |
72 | self.decoder = nn.Linear(dict_size, activation_dim, bias=False)
73 | self.decoder.weight.data = set_decoder_norm_to_unit_norm(
74 | self.decoder.weight, activation_dim, dict_size
75 | )
76 |
77 | self.encoder = nn.Linear(activation_dim, dict_size)
78 | self.encoder.weight.data = self.decoder.weight.T.clone()
79 | self.encoder.bias.data.zero_()
80 |
81 | self.b_dec = nn.Parameter(t.zeros(activation_dim))
82 |
83 | def encode(self, x: t.Tensor, return_topk: bool = False, use_threshold: bool = False):
84 | post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec))
85 |
86 | if use_threshold:
87 | encoded_acts_BF = post_relu_feat_acts_BF * (post_relu_feat_acts_BF > self.threshold)
88 | if return_topk:
89 | post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1)
90 | return encoded_acts_BF, post_topk.values, post_topk.indices, post_relu_feat_acts_BF
91 | else:
92 | return encoded_acts_BF
93 |
94 | post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1)
95 |
96 | # We can't split immediately due to nnsight
97 | tops_acts_BK = post_topk.values
98 | top_indices_BK = post_topk.indices
99 |
100 | buffer_BF = t.zeros_like(post_relu_feat_acts_BF)
101 | encoded_acts_BF = buffer_BF.scatter_(dim=-1, index=top_indices_BK, src=tops_acts_BK)
102 |
103 | if return_topk:
104 | return encoded_acts_BF, tops_acts_BK, top_indices_BK, post_relu_feat_acts_BF
105 | else:
106 | return encoded_acts_BF
107 |
108 | def decode(self, x: t.Tensor) -> t.Tensor:
109 | return self.decoder(x) + self.b_dec
110 |
111 | def forward(self, x: t.Tensor, output_features: bool = False):
112 | encoded_acts_BF = self.encode(x)
113 | x_hat_BD = self.decode(encoded_acts_BF)
114 | if not output_features:
115 | return x_hat_BD
116 | else:
117 | return x_hat_BD, encoded_acts_BF
118 |
119 | def scale_biases(self, scale: float):
120 | self.encoder.bias.data *= scale
121 | self.b_dec.data *= scale
122 | if self.threshold >= 0:
123 | self.threshold *= scale
124 |
125 | def from_pretrained(path, k: Optional[int] = None, device=None):
126 | """
127 | Load a pretrained autoencoder from a file.
128 | """
129 | state_dict = t.load(path)
130 | dict_size, activation_dim = state_dict["encoder.weight"].shape
131 |
132 | if k is None:
133 | k = state_dict["k"].item()
134 | elif "k" in state_dict and k != state_dict["k"].item():
135 | raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']")
136 |
137 | autoencoder = AutoEncoderTopK(activation_dim, dict_size, k)
138 | autoencoder.load_state_dict(state_dict)
139 | if device is not None:
140 | autoencoder.to(device)
141 | return autoencoder
142 |
143 |
144 | class TopKTrainer(SAETrainer):
145 | """
146 | Top-K SAE training scheme.
147 | """
148 |
149 | def __init__(
150 | self,
151 | steps: int, # total number of steps to train for
152 | activation_dim: int,
153 | dict_size: int,
154 | k: int,
155 | layer: int,
156 | lm_name: str,
157 | dict_class: type = AutoEncoderTopK,
158 | lr: Optional[float] = None,
159 | auxk_alpha: float = 1 / 32, # see Appendix A.2
160 | warmup_steps: int = 1000,
161 | decay_start: Optional[int] = None, # when does the lr decay start
162 | threshold_beta: float = 0.999,
163 | threshold_start_step: int = 1000,
164 | seed: Optional[int] = None,
165 | device: Optional[str] = None,
166 | wandb_name: str = "AutoEncoderTopK",
167 | submodule_name: Optional[str] = None,
168 | ):
169 | super().__init__(seed)
170 |
171 | assert layer is not None and lm_name is not None
172 | self.layer = layer
173 | self.lm_name = lm_name
174 | self.submodule_name = submodule_name
175 |
176 | self.wandb_name = wandb_name
177 | self.steps = steps
178 | self.decay_start = decay_start
179 | self.warmup_steps = warmup_steps
180 | self.k = k
181 | self.threshold_beta = threshold_beta
182 | self.threshold_start_step = threshold_start_step
183 |
184 | if seed is not None:
185 | t.manual_seed(seed)
186 | t.cuda.manual_seed_all(seed)
187 |
188 | # Initialise autoencoder
189 | self.ae = dict_class(activation_dim, dict_size, k)
190 | if device is None:
191 | self.device = "cuda" if t.cuda.is_available() else "cpu"
192 | else:
193 | self.device = device
194 | self.ae.to(self.device)
195 |
196 | if lr is not None:
197 | self.lr = lr
198 | else:
199 | # Auto-select LR using 1 / sqrt(d) scaling law from Figure 3 of the paper
200 | scale = dict_size / (2**14)
201 | self.lr = 2e-4 / scale**0.5
202 |
203 | self.auxk_alpha = auxk_alpha
204 | self.dead_feature_threshold = 10_000_000
205 | self.top_k_aux = activation_dim // 2 # Heuristic from B.1 of the paper
206 | self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device)
207 | self.logging_parameters = ["effective_l0", "dead_features", "pre_norm_auxk_loss"]
208 | self.effective_l0 = -1
209 | self.dead_features = -1
210 | self.pre_norm_auxk_loss = -1
211 |
212 | # Optimizer and scheduler
213 | self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999))
214 |
215 | lr_fn = get_lr_schedule(steps, warmup_steps, decay_start=decay_start)
216 |
217 | self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn)
218 |
219 | def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor):
220 | dead_features = self.num_tokens_since_fired >= self.dead_feature_threshold
221 | self.dead_features = int(dead_features.sum())
222 |
223 | if self.dead_features > 0:
224 | k_aux = min(self.top_k_aux, self.dead_features)
225 |
226 | auxk_latents = t.where(dead_features[None], post_relu_acts_BF, -t.inf)
227 |
228 | # Top-k dead latents
229 | auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False)
230 |
231 | auxk_buffer_BF = t.zeros_like(post_relu_acts_BF)
232 | auxk_acts_BF = auxk_buffer_BF.scatter_(dim=-1, index=auxk_indices, src=auxk_acts)
233 |
234 | # Note: decoder(), not decode(), as we don't want to apply the bias
235 | x_reconstruct_aux = self.ae.decoder(auxk_acts_BF)
236 | l2_loss_aux = (
237 | (residual_BD.float() - x_reconstruct_aux.float()).pow(2).sum(dim=-1).mean()
238 | )
239 |
240 | self.pre_norm_auxk_loss = l2_loss_aux
241 |
242 | # normalization from OpenAI implementation: https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/kernels.py#L614
243 | residual_mu = residual_BD.mean(dim=0)[None, :].broadcast_to(residual_BD.shape)
244 | loss_denom = (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean()
245 | normalized_auxk_loss = l2_loss_aux / loss_denom
246 |
247 | return normalized_auxk_loss.nan_to_num(0.0)
248 | else:
249 | self.pre_norm_auxk_loss = -1
250 | return t.tensor(0, dtype=residual_BD.dtype, device=residual_BD.device)
251 |
252 | def update_threshold(self, top_acts_BK: t.Tensor):
253 | device_type = "cuda" if top_acts_BK.is_cuda else "cpu"
254 | with t.autocast(device_type=device_type, enabled=False), t.no_grad():
255 | active = top_acts_BK.clone().detach()
256 | active[active <= 0] = float("inf")
257 | min_activations = active.min(dim=1).values.to(dtype=t.float32)
258 | min_activation = min_activations.mean()
259 |
260 | B, K = active.shape
261 | assert len(active.shape) == 2
262 | assert min_activations.shape == (B,)
263 |
264 | if self.ae.threshold < 0:
265 | self.ae.threshold = min_activation
266 | else:
267 | self.ae.threshold = (self.threshold_beta * self.ae.threshold) + (
268 | (1 - self.threshold_beta) * min_activation
269 | )
270 |
271 | def loss(self, x, step=None, logging=False):
272 | # Run the SAE
273 | f, top_acts_BK, top_indices_BK, post_relu_acts_BF = self.ae.encode(
274 | x, return_topk=True, use_threshold=False
275 | )
276 |
277 | if step > self.threshold_start_step:
278 | self.update_threshold(top_acts_BK)
279 |
280 | x_hat = self.ae.decode(f)
281 |
282 | # Measure goodness of reconstruction
283 | e = x - x_hat
284 |
285 | # Update the effective L0 (again, should just be K)
286 | self.effective_l0 = top_acts_BK.size(1)
287 |
288 |         # Update "number of tokens since fired" for each feature
289 | num_tokens_in_step = x.size(0)
290 | did_fire = t.zeros_like(self.num_tokens_since_fired, dtype=t.bool)
291 | did_fire[top_indices_BK.flatten()] = True
292 | self.num_tokens_since_fired += num_tokens_in_step
293 | self.num_tokens_since_fired[did_fire] = 0
294 |
295 | l2_loss = e.pow(2).sum(dim=-1).mean()
296 | auxk_loss = (
297 | self.get_auxiliary_loss(e.detach(), post_relu_acts_BF) if self.auxk_alpha > 0 else 0
298 | )
299 |
300 | loss = l2_loss + self.auxk_alpha * auxk_loss
301 |
302 | if not logging:
303 | return loss
304 | else:
305 | return namedtuple("LossLog", ["x", "x_hat", "f", "losses"])(
306 | x,
307 | x_hat,
308 | f,
309 | {"l2_loss": l2_loss.item(), "auxk_loss": auxk_loss.item(), "loss": loss.item()},
310 | )
311 |
312 | def update(self, step, x):
313 | # Initialise the decoder bias
314 | if step == 0:
315 | median = geometric_median(x)
316 | median = median.to(self.ae.b_dec.dtype)
317 | self.ae.b_dec.data = median
318 |
319 | # compute the loss
320 | x = x.to(self.device)
321 | loss = self.loss(x, step=step)
322 | loss.backward()
323 |
324 | # clip grad norm and remove grads parallel to decoder directions
325 | self.ae.decoder.weight.grad = remove_gradient_parallel_to_decoder_directions(
326 | self.ae.decoder.weight,
327 | self.ae.decoder.weight.grad,
328 | self.ae.activation_dim,
329 | self.ae.dict_size,
330 | )
331 | t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0)
332 |
333 | # do a training step
334 | self.optimizer.step()
335 | self.optimizer.zero_grad()
336 | self.scheduler.step()
337 |
338 | # Make sure the decoder is still unit-norm
339 | self.ae.decoder.weight.data = set_decoder_norm_to_unit_norm(
340 | self.ae.decoder.weight, self.ae.activation_dim, self.ae.dict_size
341 | )
342 |
343 | return loss.item()
344 |
345 | @property
346 | def config(self):
347 | return {
348 | "trainer_class": "TopKTrainer",
349 | "dict_class": "AutoEncoderTopK",
350 | "lr": self.lr,
351 | "steps": self.steps,
352 | "auxk_alpha": self.auxk_alpha,
353 | "warmup_steps": self.warmup_steps,
354 | "decay_start": self.decay_start,
355 | "threshold_beta": self.threshold_beta,
356 | "threshold_start_step": self.threshold_start_step,
357 | "seed": self.seed,
358 | "activation_dim": self.ae.activation_dim,
359 | "dict_size": self.ae.dict_size,
360 | "k": self.ae.k.item(),
361 | "device": self.device,
362 | "layer": self.layer,
363 | "lm_name": self.lm_name,
364 | "wandb_name": self.wandb_name,
365 | "submodule_name": self.submodule_name,
366 | }
367 |
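368 | if __name__ == "__main__":
369 |     # Illustrative sketch, not part of the training pipeline: the Top-K bottleneck in
370 |     # AutoEncoderTopK keeps at most k latents per example (fewer if some of the selected values
371 |     # are exactly zero after the ReLU). Run from the repo root with
372 |     # `python -m dictionary_learning.trainers.top_k`.
373 |     ae = AutoEncoderTopK(activation_dim=16, dict_size=64, k=4)
374 |     x_hat, f = ae(t.randn(8, 16), output_features=True)
375 |     print(x_hat.shape, f.shape, (f != 0).sum(dim=-1))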
--------------------------------------------------------------------------------
/dictionary_learning/trainers/trainer.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Callable
2 | import torch
3 | import einops
4 |
5 |
6 | class SAETrainer:
7 | """
8 | Generic class for implementing SAE training algorithms
9 | """
10 |
11 | def __init__(self, seed=None):
12 | self.seed = seed
13 | self.logging_parameters = []
14 |
15 | def update(
16 | self,
17 | step, # index of step in training
18 | activations, # of shape [batch_size, d_submodule]
19 | ):
20 | pass # implemented by subclasses
21 |
22 | def get_logging_parameters(self):
23 | stats = {}
24 | for param in self.logging_parameters:
25 | if hasattr(self, param):
26 | stats[param] = getattr(self, param)
27 | else:
28 | print(f"Warning: {param} not found in {self}")
29 | return stats
30 |
31 | @property
32 | def config(self):
33 | return {
34 | "wandb_name": "trainer",
35 | }
36 |
37 |
38 | class ConstrainedAdam(torch.optim.Adam):
39 | """
40 | A variant of Adam where some of the parameters are constrained to have unit norm.
41 | Note: This should be used with a decoder that is nn.Linear, not nn.Parameter.
42 | If nn.Parameter, the dim argument to norm should be 1.
43 | """
44 |
45 | def __init__(
46 | self, params, constrained_params, lr: float, betas: tuple[float, float] = (0.9, 0.999)
47 | ):
48 | super().__init__(params, lr=lr, betas=betas)
49 | self.constrained_params = list(constrained_params)
50 |
51 | def step(self, closure=None):
52 | with torch.no_grad():
53 | for p in self.constrained_params:
54 | normed_p = p / p.norm(dim=0, keepdim=True)
55 | # project away the parallel component of the gradient
56 | p.grad -= (p.grad * normed_p).sum(dim=0, keepdim=True) * normed_p
57 | super().step(closure=closure)
58 | with torch.no_grad():
59 | for p in self.constrained_params:
60 | # renormalize the constrained parameters
61 | p /= p.norm(dim=0, keepdim=True)
62 |
63 |
64 | # The next two functions could be replaced with the ConstrainedAdam Optimizer
65 | @torch.no_grad()
66 | def set_decoder_norm_to_unit_norm(
67 | W_dec_DF: torch.nn.Parameter, activation_dim: int, d_sae: int
68 | ) -> torch.Tensor:
69 | """There's a major footgun here: we use this with both nn.Linear and nn.Parameter decoders.
70 | nn.Linear stores the decoder weights in a transposed format (d_model, d_sae). So, we pass the dimensions in
71 | to catch this error."""
72 |
73 | D, F = W_dec_DF.shape
74 |
75 | assert D == activation_dim
76 | assert F == d_sae
77 |
78 | eps = torch.finfo(W_dec_DF.dtype).eps
79 | norm = torch.norm(W_dec_DF.data, dim=0, keepdim=True)
80 | W_dec_DF.data /= norm + eps
81 | return W_dec_DF.data
82 |
83 |
84 | @torch.no_grad()
85 | def remove_gradient_parallel_to_decoder_directions(
86 | W_dec_DF: torch.Tensor,
87 | W_dec_DF_grad: torch.Tensor,
88 | activation_dim: int,
89 | d_sae: int,
90 | ) -> torch.Tensor:
91 | """There's a major footgun here: we use this with both nn.Linear and nn.Parameter decoders.
92 | nn.Linear stores the decoder weights in a transposed format (d_model, d_sae). So, we pass the dimensions in
93 | to catch this error."""
94 |
95 | D, F = W_dec_DF.shape
96 | assert D == activation_dim
97 | assert F == d_sae
98 |
99 | normed_W_dec_DF = W_dec_DF / (torch.norm(W_dec_DF, dim=0, keepdim=True) + 1e-6)
100 |
101 | parallel_component = einops.einsum(
102 | W_dec_DF_grad,
103 | normed_W_dec_DF,
104 | "d_in d_sae, d_in d_sae -> d_sae",
105 | )
106 | W_dec_DF_grad -= einops.einsum(
107 | parallel_component,
108 | normed_W_dec_DF,
109 | "d_sae, d_in d_sae -> d_in d_sae",
110 | )
111 | return W_dec_DF_grad
112 |
113 |
114 | def get_lr_schedule(
115 | total_steps: int,
116 | warmup_steps: int,
117 | decay_start: Optional[int] = None,
118 | resample_steps: Optional[int] = None,
119 | sparsity_warmup_steps: Optional[int] = None,
120 | ) -> Callable[[int], float]:
121 | """
122 | Creates a learning rate schedule function with linear warmup followed by an optional decay phase.
123 |
124 | Note: resample_steps creates a repeating warmup pattern instead of the standard phases, but
125 | is rarely used in practice.
126 |
127 | Args:
128 | total_steps: Total number of training steps
129 | warmup_steps: Steps for linear warmup from 0 to 1
130 | decay_start: Optional step to begin linear decay to 0
131 | resample_steps: Optional period for repeating warmup pattern
132 | sparsity_warmup_steps: Used for validation with decay_start
133 |
134 | Returns:
135 | Function that computes LR scale factor for a given step
136 | """
137 | if decay_start is not None:
138 | assert resample_steps is None, (
139 | "decay_start and resample_steps are currently mutually exclusive."
140 | )
141 | assert 0 <= decay_start < total_steps, "decay_start must be >= 0 and < steps."
142 | assert decay_start > warmup_steps, "decay_start must be > warmup_steps."
143 | if sparsity_warmup_steps is not None:
144 | assert decay_start > sparsity_warmup_steps, (
145 | "decay_start must be > sparsity_warmup_steps."
146 | )
147 |
148 | assert 0 <= warmup_steps < total_steps, "warmup_steps must be >= 0 and < steps."
149 |
150 | if resample_steps is None:
151 |
152 | def lr_schedule(step: int) -> float:
153 | if step < warmup_steps:
154 | # Warm-up phase
155 | return step / warmup_steps
156 |
157 | if decay_start is not None and step >= decay_start:
158 | # Decay phase
159 | return (total_steps - step) / (total_steps - decay_start)
160 |
161 | # Constant phase
162 | return 1.0
163 | else:
164 | assert 0 < resample_steps < total_steps, "resample_steps must be > 0 and < steps."
165 |
166 | def lr_schedule(step: int) -> float:
167 | return min((step % resample_steps) / warmup_steps, 1.0)
168 |
169 | return lr_schedule
170 |
171 |
172 | def get_sparsity_warmup_fn(
173 | total_steps: int, sparsity_warmup_steps: Optional[int] = None
174 | ) -> Callable[[int], float]:
175 | """
176 | Return a function that computes a scale factor for sparsity penalty at a given step.
177 |
178 | If `sparsity_warmup_steps` is None or 0, returns 1.0 for all steps.
179 | Otherwise, scales from 0.0 up to 1.0 across `sparsity_warmup_steps`.
180 | """
181 |
182 | if sparsity_warmup_steps is not None:
183 | assert 0 <= sparsity_warmup_steps < total_steps, (
184 | "sparsity_warmup_steps must be >= 0 and < steps."
185 | )
186 |
187 | def scale_fn(step: int) -> float:
188 | if not sparsity_warmup_steps:
189 | # If it's None or zero, we just return 1.0
190 | return 1.0
191 | else:
192 | # Gradually increase from 0.0 -> 1.0 as step goes from 0 -> sparsity_warmup_steps
193 | return min(step / sparsity_warmup_steps, 1.0)
194 |
195 | return scale_fn
196 |
--------------------------------------------------------------------------------
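As a reference for how the schedule above behaves, here is a minimal, illustrative sketch (not part of the repository) of the function returned by `get_lr_schedule`: a linear warmup to 1.0, a constant plateau, and an optional linear decay to 0.0 once `decay_start` is reached. The step counts are made up for the example, and it assumes the repository root is on `PYTHONPATH`.

```python
# Illustrative only: inspect the LR scale factor returned by get_lr_schedule,
# assuming total_steps=1000, warmup_steps=100, decay_start=800.
from dictionary_learning.trainers.trainer import get_lr_schedule

schedule = get_lr_schedule(total_steps=1000, warmup_steps=100, decay_start=800)

for step in [0, 50, 100, 500, 800, 900, 999]:
    # warmup:  step / 100              (step < 100)
    # plateau: 1.0                     (100 <= step < 800)
    # decay:   (1000 - step) / 200     (step >= 800)
    print(step, round(schedule(step), 3))
```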
/dictionary_learning/training.py:
--------------------------------------------------------------------------------
1 | """
2 | Training dictionaries
3 | """
4 |
5 | import json
6 | import torch.multiprocessing as mp
7 | import os
8 | from queue import Empty
9 | from typing import Optional
10 | from contextlib import nullcontext
11 | from itertools import cycle
12 |
13 | import torch as t
14 | from tqdm import tqdm
15 |
16 | import wandb
17 |
18 | from .dictionary import AutoEncoder
19 | from .evaluation import evaluate
20 | from .trainers.standard import StandardTrainer
21 | from .trainers.matroyshka_batch_top_k import MatroyshkaBatchTopKSAE
22 |
23 |
24 | def new_wandb_process(config, log_queue, entity, project):
25 | wandb.init(entity=entity, project=project, config=config, name=config["wandb_name"])
26 | while True:
27 | try:
28 | log = log_queue.get(timeout=1)
29 | if log == "DONE":
30 | break
31 | wandb.log(log)
32 | except Empty:
33 | continue
34 | wandb.finish()
35 |
36 |
37 | def log_stats(
38 | trainers,
39 | step: int,
40 | act: t.Tensor,
41 | activations_split_by_head: bool,
42 | transcoder: bool,
43 | log_queues: list=[],
44 | verbose: bool=False,
45 | ):
46 | with t.no_grad():
47 | # quick hack to make sure all trainers get the same x
48 | z = act.clone()
49 | for i, trainer in enumerate(trainers):
50 | log = {}
51 | act = z.clone()
52 | if activations_split_by_head: # x.shape: [batch, pos, n_heads, d_head]
53 | act = act[..., i, :]
54 | if not transcoder:
55 | act, act_hat, f, losslog = trainer.loss(act, step=step, logging=True)
56 |
57 | # L0
58 | l0 = (f != 0).float().sum(dim=-1).mean().item()
59 | # fraction of variance explained
60 | total_variance = t.var(act, dim=0).sum()
61 | residual_variance = t.var(act - act_hat, dim=0).sum()
62 | frac_variance_explained = 1 - residual_variance / total_variance
63 | log[f"frac_variance_explained"] = frac_variance_explained.item()
64 | else: # transcoder
65 | x, x_hat, f, losslog = trainer.loss(act, step=step, logging=True)
66 |
67 | # L0
68 | l0 = (f != 0).float().sum(dim=-1).mean().item()
69 |
70 | if verbose:
71 | print(f"Step {step}: L0 = {l0}, frac_variance_explained = {frac_variance_explained}")
72 |
73 | # log parameters from training
74 | log.update({f"{k}": v.cpu().item() if isinstance(v, t.Tensor) else v for k, v in losslog.items()})
75 | log[f"l0"] = l0
76 | trainer_log = trainer.get_logging_parameters()
77 | for name, value in trainer_log.items():
78 | if isinstance(value, t.Tensor):
79 | value = value.cpu().item()
80 | log[f"{name}"] = value
81 |
82 | if log_queues:
83 | log_queues[i].put(log)
84 |
85 | def get_norm_factor(data, steps: int) -> float:
86 | """Per Section 3.1, find a fixed scalar factor so activation vectors have unit mean squared norm.
87 | This is very helpful for hyperparameter transfer between different layers and models.
88 | Use more steps for more accurate results.
89 | https://arxiv.org/pdf/2408.05147
90 |
91 | If experiencing troubles with hyperparameter transfer between models, it may be worth instead normalizing to the square root of d_model.
92 | https://transformer-circuits.pub/2024/april-update/index.html#training-saes"""
93 | total_mean_squared_norm = 0
94 | count = 0
95 |
96 | for step, act_BD in enumerate(tqdm(data, total=steps, desc="Calculating norm factor")):
97 | if step > steps:
98 | break
99 |
100 | count += 1
101 | mean_squared_norm = t.mean(t.sum(act_BD ** 2, dim=1))
102 | total_mean_squared_norm += mean_squared_norm
103 |
104 | average_mean_squared_norm = total_mean_squared_norm / count
105 | norm_factor = t.sqrt(average_mean_squared_norm).item()
106 |
107 | print(f"Average mean squared norm: {average_mean_squared_norm}")
108 | print(f"Norm factor: {norm_factor}")
109 |
110 | return norm_factor
111 |
112 | # Assumes only one trainer and one log queue
113 | def validation(val_data, autocast_dtype, trainer, log_queue, norm_factor):
114 | for use_threshold in [False, True]:
115 | l0s = []
116 | l2s = []
117 | fracs = []
118 |
119 | for act in val_data:
120 | act = act.detach().clone()
121 | act = act.to(dtype=autocast_dtype)
122 | act /= norm_factor
123 | with t.no_grad():
124 | f = trainer.ae.encode(act, use_threshold=use_threshold)
125 | act_hat = trainer.ae.decode(f)
126 | e = act - act_hat
127 |
128 | # Sparsity - L0
129 | l0 = (f != 0).float().sum(dim=-1).mean().item()
130 |
131 | # Reconstruction - L2
132 | l2 = e.pow(2).sum(dim=-1).mean().item()
133 |
134 | # Reconstruction - Fraction of variance explained
135 | total_variance = t.var(act, dim=0).sum()
136 | residual_variance = t.var(e, dim=0).sum()
137 | frac_variance_explained = 1 - residual_variance / total_variance
138 | frac = frac_variance_explained.item()
139 |
140 | l0s.append(l0)
141 | l2s.append(l2)
142 | fracs.append(frac)
143 |
144 | threshold_str = "true" if use_threshold else "false"
145 |
146 | log = {
147 | f"val_threshold_{threshold_str}/sparsity_l0": t.mean(t.tensor(l0s)).item(),
148 | f"val_threshold_{threshold_str}/reconstruction_l2": t.mean(t.tensor(l2s)).item(),
149 | f"val_threshold_{threshold_str}/frac_variance_explained": t.mean(t.tensor(fracs)).item(),
150 | }
151 |
152 | log_queue.put(log)
153 |
154 |
155 | def trainSAE(
156 | data,
157 | val_data,
158 | trainer_configs: list[dict],
159 | steps: int,
160 | use_wandb:bool=False,
161 | wandb_entity:str="",
162 | wandb_project:str="",
163 | save_steps:Optional[list[int]]=None,
164 | save_dir:Optional[str]=None,
165 | log_steps:Optional[int]=None,
166 | activations_split_by_head:bool=False,
167 | transcoder:bool=False,
168 | run_cfg:dict={},
169 | normalize_activations:bool=True,
170 | verbose:bool=False,
171 | device:str="cuda",
172 | autocast_dtype: t.dtype = t.float32,
173 | ):
174 | """
175 | Train SAEs using the given trainers
176 |
177 | If normalize_activations is True, the activations will be normalized to have unit mean squared norm.
178 | The autoencoders weights will be scaled before saving, so the activations don't need to be scaled during inference.
179 | This is very helpful for hyperparameter transfer between different layers and models.
180 |
181 | Setting autocast_dtype to t.bfloat16 provides a significant speedup with minimal change in performance.
182 | """
183 |
184 | device_type = "cuda" if "cuda" in device else "cpu"
185 | autocast_context = nullcontext() if device_type == "cpu" else t.autocast(device_type=device_type, dtype=autocast_dtype)
186 |
187 | trainers = []
188 | for i, config in enumerate(trainer_configs):
189 | if "wandb_name" in config:
190 | config["wandb_name"] = f"{config['wandb_name']}_trainer_{i}"
191 | trainer_class = config["trainer"]
192 | del config["trainer"]
193 | trainers.append(trainer_class(**config))
194 |
195 | wandb_processes = []
196 | log_queues = []
197 |
198 | if use_wandb:
199 | # Note: If encountering wandb and CUDA related errors, try setting start method to spawn in the if __name__ == "__main__" block
200 | # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.set_start_method
201 | # Everything should work fine with the default fork method but it may not be as robust
202 | for i, trainer in enumerate(trainers):
203 | log_queue = mp.Queue()
204 | log_queues.append(log_queue)
205 | wandb_config = trainer.config | run_cfg
206 | # Make sure wandb config doesn't contain any CUDA tensors
207 | wandb_config = {k: v.cpu().item() if isinstance(v, t.Tensor) else v
208 | for k, v in wandb_config.items()}
209 | wandb_process = mp.Process(
210 | target=new_wandb_process,
211 | args=(wandb_config, log_queue, wandb_entity, wandb_project),
212 | )
213 | wandb_process.start()
214 | wandb_processes.append(wandb_process)
215 |
216 | # make save dirs, export config
217 | if save_dir is not None:
218 | save_dirs = [
219 | os.path.join(save_dir, f"trainer_{i}") for i in range(len(trainer_configs))
220 | ]
221 | for trainer, dir in zip(trainers, save_dirs):
222 | os.makedirs(dir, exist_ok=True)
223 | # save config
224 | config = {"trainer": trainer.config}
225 | try:
226 | config["buffer"] = data.config
227 | except:
228 | pass
229 | with open(os.path.join(dir, "config.json"), "w") as f:
230 | json.dump(config, f, indent=4)
231 | else:
232 | save_dirs = [None for _ in trainer_configs]
233 |
234 | if normalize_activations:
235 | norm_factor = get_norm_factor(data, steps=100)
236 |
237 | for trainer in trainers:
238 | trainer.config["norm_factor"] = norm_factor
239 | # Verify that all autoencoders have a scale_biases method
240 | trainer.ae.scale_biases(1.0)
241 |
242 | # def rand_cycle(iterable):
243 | # while True:
244 | # for x in iterable:
245 | # yield x
246 |
247 | for step, act in enumerate(tqdm(cycle(data), total=steps)):
248 |
249 | act = act.detach().clone() # TODO: maybe remove if activation dataset modified
250 | act = act.to(dtype=autocast_dtype)
251 |
252 | if normalize_activations:
253 | act /= norm_factor
254 |
255 | if step >= steps:
256 | break
257 |
258 | # logging
259 | if (use_wandb or verbose) and step % log_steps == 0:
260 | log_stats(
261 | trainers, step, act, activations_split_by_head, transcoder, log_queues=log_queues, verbose=verbose
262 | )
263 |
264 | # logging validation
265 | # if (use_wandb or verbose) and step % save_steps == 0:
266 | if use_wandb and step % log_steps == 0:
267 | validation(val_data, autocast_dtype, trainers[0], log_queues[0], norm_factor)
268 |
269 | # saving
270 | if save_steps is not None and step in save_steps:
271 | for dir, trainer in zip(save_dirs, trainers):
272 | if dir is not None:
273 |
274 | if normalize_activations:
275 | # Temporarily scale up biases for checkpoint saving
276 | trainer.ae.scale_biases(norm_factor)
277 |
278 | if not os.path.exists(os.path.join(dir, "checkpoints")):
279 | os.mkdir(os.path.join(dir, "checkpoints"))
280 |
281 | checkpoint = {k: v.cpu() for k, v in trainer.ae.state_dict().items()}
282 | t.save(
283 | checkpoint,
284 | os.path.join(dir, "checkpoints", f"ae_{step}.pt"),
285 | )
286 |
287 | if normalize_activations:
288 | trainer.ae.scale_biases(1 / norm_factor)
289 |
290 | # training
291 | for trainer in trainers:
292 | with autocast_context:
293 | trainer.update(step, act)
294 |
295 | # save final SAEs
296 | for save_dir, trainer in zip(save_dirs, trainers):
297 | if normalize_activations:
298 | trainer.ae.scale_biases(norm_factor)
299 | if save_dir is not None:
300 | final = {k: v.cpu() for k, v in trainer.ae.state_dict().items()}
301 | t.save(final, os.path.join(save_dir, "ae.pt"))
302 |
303 | # Signal wandb processes to finish
304 | if use_wandb:
305 | for queue in log_queues:
306 | queue.put("DONE")
307 | for process in wandb_processes:
308 | process.join()
309 |
--------------------------------------------------------------------------------
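The normalization logic in `trainSAE` is easy to miss: activations are divided by a single scalar (`get_norm_factor`) so that their mean squared norm is roughly 1 during training, and the biases of saved checkpoints are rescaled (`scale_biases`) so the saved SAE can be applied to raw, unnormalized activations. A minimal sketch of the statistic itself, on made-up tensors:

```python
# Illustrative only: the scalar computed by get_norm_factor and its effect.
import torch

acts = torch.randn(4096, 768) * 3.0                        # stand-in for a batch of activations
norm_factor = torch.sqrt((acts ** 2).sum(dim=1).mean())    # sqrt of the mean squared norm
normed = acts / norm_factor

print((normed ** 2).sum(dim=1).mean())                     # ~1.0 by construction
```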
/dictionary_learning/utils.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import zstandard as zstd
3 | import io
4 | import json
5 | import os
6 | from nnsight import LanguageModel
7 |
8 | from .trainers.top_k import AutoEncoderTopK
9 | from .trainers.batch_top_k import BatchTopKSAE
10 | from .trainers.matroyshka_batch_top_k import MatroyshkaBatchTopKSAE
11 | from .dictionary import (
12 | AutoEncoder,
13 | GatedAutoEncoder,
14 | AutoEncoderNew,
15 | JumpReluAutoEncoder,
16 | )
17 |
18 |
19 | def hf_dataset_to_generator(dataset_name, split="train", streaming=True):
20 | dataset = load_dataset(dataset_name, split=split, streaming=streaming)
21 |
22 | def gen():
23 | for x in iter(dataset):
24 | yield x["text"]
25 |
26 | return gen()
27 |
28 |
29 | def zst_to_generator(data_path):
30 | """
31 | Load a dataset from a .jsonl.zst file.
32 | Each jsonl entry is assumed to have a 'text' field.
33 | """
34 | compressed_file = open(data_path, "rb")
35 | dctx = zstd.ZstdDecompressor()
36 | reader = dctx.stream_reader(compressed_file)
37 | text_stream = io.TextIOWrapper(reader, encoding="utf-8")
38 |
39 | def generator():
40 | for line in text_stream:
41 | yield json.loads(line)["text"]
42 |
43 | return generator()
44 |
45 |
46 | def get_nested_folders(path: str) -> list[str]:
47 | """
48 | Recursively get a list of folders that contain an ae.pt file, starting the search from the given path
49 | """
50 | folder_names = []
51 |
52 | for root, dirs, files in os.walk(path):
53 | if "ae.pt" in files:
54 | folder_names.append(root)
55 |
56 | return folder_names
57 |
58 |
59 | def load_dictionary(base_path: str, device: str) -> tuple:
60 | ae_path = f"{base_path}/ae.pt"
61 | config_path = f"{base_path}/config.json"
62 |
63 | with open(config_path, "r") as f:
64 | config = json.load(f)
65 |
66 | dict_class = config["trainer"]["dict_class"]
67 |
68 | if dict_class == "AutoEncoder":
69 | dictionary = AutoEncoder.from_pretrained(ae_path, device=device)
70 | elif dict_class == "GatedAutoEncoder":
71 | dictionary = GatedAutoEncoder.from_pretrained(ae_path, device=device)
72 | elif dict_class == "AutoEncoderNew":
73 | dictionary = AutoEncoderNew.from_pretrained(ae_path, device=device)
74 | elif dict_class == "AutoEncoderTopK":
75 | k = config["trainer"]["k"]
76 | dictionary = AutoEncoderTopK.from_pretrained(ae_path, k=k, device=device)
77 | elif dict_class == "BatchTopKSAE":
78 | k = config["trainer"]["k"]
79 | dictionary = BatchTopKSAE.from_pretrained(ae_path, k=k, device=device)
80 | elif dict_class == "MatroyshkaBatchTopKSAE":
81 | k = config["trainer"]["k"]
82 | dictionary = MatroyshkaBatchTopKSAE.from_pretrained(ae_path, k=k, device=device)
83 | elif dict_class == "JumpReluAutoEncoder":
84 | dictionary = JumpReluAutoEncoder.from_pretrained(ae_path, device=device)
85 | else:
86 | raise ValueError(f"Dictionary class {dict_class} not supported")
87 |
88 | return dictionary, config
89 |
90 |
91 | def get_submodule(model: LanguageModel, layer: int):
92 | """Gets the residual stream submodule"""
93 | model_name = model._model_key
94 |
95 | if "pythia" in model_name:
96 | return model.gpt_neox.layers[layer]
97 | elif "gemma" in model_name:
98 | return model.model.layers[layer]
99 | else:
100 | raise ValueError(f"Please add submodule for model {model_name}")
101 |
--------------------------------------------------------------------------------
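A hedged usage sketch for `load_dictionary`: the directory below is a placeholder (any folder written by `trainSAE`, i.e. containing `ae.pt` and `config.json`, works), and `encode`/`decode` are assumed to follow the base `AutoEncoder` interface.

```python
# Hypothetical usage of load_dictionary on a trained run directory.
import torch
from dictionary_learning.utils import get_nested_folders, load_dictionary

folders = get_nested_folders("./output_dir")            # every subfolder that contains ae.pt
sae, config = load_dictionary(folders[0], device="cpu")

x = torch.randn(2, config["trainer"]["activation_dim"])
f = sae.encode(x)        # sparse codes
x_hat = sae.decode(f)    # reconstruction
print(f.shape, x_hat.shape)
```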
/encode_images.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os.path
3 | import argparse
4 | import tqdm
5 | from utils import get_dataset, get_model
6 |
7 | def get_args_parser():
8 | parser = argparse.ArgumentParser("Encode images", add_help=False)
9 | parser.add_argument("--embeddings_path")
10 | parser.add_argument("--model_name", default="clip", type=str)
11 | parser.add_argument("--dataset_name", default="imagenet", type=str)
12 | parser.add_argument("--data_path", default="/shared-network/inat2021", type=str)
13 | parser.add_argument("--split", default="train", type=str)
14 | parser.add_argument("--batch_size", default=128, type=int)
15 | parser.add_argument("--num_workers", default=10, type=int)
16 | parser.add_argument("--device", default="cuda:0")
17 | return parser
18 |
19 |
20 | if __name__ == "__main__":
21 | args = get_args_parser().parse_args()
22 |
23 | if os.path.exists(args.embeddings_path):
24 | print(f"Embeddings already saved at {args.embeddings_path}")
25 | else:
26 | model, processor = get_model(args)
27 | ds, dl = get_dataset(args, preprocess=None, processor=processor, split=args.split)
28 |
29 | embeddings = []
30 | pbar = tqdm.tqdm(dl)
31 | for image in pbar:
32 | with torch.no_grad():
33 | output = model.encode(image)
34 | embeddings.append(output.detach().cpu())
35 |
36 | embeddings = torch.cat(embeddings, dim=0)
37 | os.makedirs(os.path.dirname(args.embeddings_path), exist_ok=True)
38 | torch.save(embeddings, args.embeddings_path)
39 | print(f"Embeddings shape: {embeddings.shape}")
40 | print(f"Saved embeddings to {args.embeddings_path}")
41 |
42 |
--------------------------------------------------------------------------------
/find_hai_indices.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tqdm
3 | import os
4 | import torch
5 | from torchvision import transforms
6 | from utils import get_dataset
7 | import argparse
8 | from datasets.activations import ActivationsDataset, ChunkedActivationsDataset
9 | from torch.utils.data import DataLoader, Subset
10 | from itertools import combinations
11 | from collections import Counter
12 | import math
13 |
14 | def parse_args():
15 | parser = argparse.ArgumentParser(description="Find indices of highest activating images from activations")
16 | parser.add_argument('--activations_dir', type=str, required=True)
17 | parser.add_argument("--dataset_name", default="imagenet", type=str)
18 | parser.add_argument("--data_path", default="/shared-network/inat2021", type=str)
19 | parser.add_argument('--split', type=str, default='train')
20 | parser.add_argument('--k', type=int, default=16)
21 | parser.add_argument('--chunk_size', type=int)
22 | return parser.parse_args()
23 |
24 | if __name__ == "__main__":
25 | # Parse command line arguments
26 | args = parse_args()
27 | args.batch_size = 1 # not used
28 | args.num_workers = 0 # not used
29 |
30 | hai_indices_path = os.path.join(args.activations_dir, f"hai_indices_{args.k}")
31 | if os.path.exists(hai_indices_path):
32 | print(f"HAI indices already saved at {hai_indices_path}")
33 | else:
34 | print("Computing HAI indices", flush=True)
35 | print(f"Loading activations from {args.activations_dir}", flush=True)
36 | activations_dataset = ActivationsDataset(args.activations_dir, device=torch.device("cpu"))
37 | print(f"Dataset loaded. Total samples: {len(activations_dataset)}", flush=True)
38 |
39 | activations_dataloader = DataLoader(activations_dataset, batch_size=args.chunk_size, shuffle=False, num_workers=16)
40 | num_samples = len(activations_dataset)
41 |
42 | first_batch = next(iter(activations_dataloader))
43 | num_neurons = first_batch.shape[1]
44 | print(f"Number of neurons detected: {num_neurons}")
45 | num_chunks = math.ceil(num_neurons / args.chunk_size)
46 | print(f"Processing {num_chunks} chunks of {args.chunk_size} neurons each...", flush=True)
47 |
48 | importants = []
49 | worst_hais = []
50 | pbar = tqdm.tqdm(list(range(num_chunks)))
51 | for i in pbar:
52 | neuron_start = i * args.chunk_size
53 | neuron_end = min((i + 1) * args.chunk_size, num_neurons)
54 | activations_chunks = np.zeros((num_samples, neuron_end - neuron_start))
55 | for j, activations_chunk in enumerate(activations_dataloader):
56 | sample_start = j * args.chunk_size
57 | sample_end = min((j + 1) * args.chunk_size, num_samples)
58 | activations_chunk = activations_chunk.numpy()
59 | activations_chunks[sample_start:sample_end, :] = activations_chunk[:, neuron_start:neuron_end]
60 | for neuron in range(neuron_end - neuron_start):
61 | neuron_activations = activations_chunks[:, neuron]
62 | important = np.argsort(neuron_activations)[-args.k:]
63 | importants.append(important)
64 | worst_hai = neuron_activations[important[0]]
65 | worst_hais.append(worst_hai)
66 |
67 | hai_indices = np.array(importants)
68 | print(f"hai_indices.shape(): {hai_indices.shape}")
69 | np.save(hai_indices_path, hai_indices)
70 | print(f"Saved HAI indices to: {hai_indices_path}")
71 |
72 | worst_hai_indices_path = os.path.join(args.activations_dir, f"hai_indices_{args.k}_worst")
73 | worst_hais = np.array(worst_hais)
74 | np.save(worst_hai_indices_path, worst_hais)
75 | print(f"Saved worst HAI indices to: {worst_hai_indices_path}")
--------------------------------------------------------------------------------
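For intuition, the core of the script above reduces to an `argsort` per neuron: the HAI indices are the `k` images that activate a neuron most strongly, and the "worst" score is the weakest activation within that top-k (used later to filter out dead neurons). A tiny, self-contained illustration on random data:

```python
# Illustrative only: per-neuron top-k image indices and their weakest activation.
import numpy as np

activations = np.random.rand(1000, 4)    # (num_images, num_neurons) stand-in
k = 16

# argsort is ascending, so the last k rows are the strongest activations per neuron
hai_indices = np.argsort(activations, axis=0)[-k:].T        # (num_neurons, k)
worst_scores = activations[hai_indices[:, 0], np.arange(activations.shape[1])]

print(hai_indices.shape, worst_scores.shape)                # (4, 16) (4,)
```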
/imagenet_subset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 |
5 | def main(imagenet_root, output_dir):
6 | os.makedirs(output_dir, exist_ok=True)
7 |
8 | for class_folder in sorted(os.listdir(imagenet_root)):
9 | class_path = os.path.join(imagenet_root, class_folder)
10 |
11 | if os.path.isdir(class_path):
12 | images = sorted(os.listdir(class_path))
13 | if images:
14 | first_image = images[0]
15 | src_path = os.path.join(class_path, first_image)
16 | dest_path = os.path.join(output_dir, f"{class_folder}_{first_image}")
17 |
18 | shutil.copy(src_path, dest_path)
19 | print(f"Copied {first_image} from {class_folder}")
20 |
21 | print("Done selecting first images for each class!")
22 |
23 | if __name__ == "__main__":
24 | parser = argparse.ArgumentParser(description="Copy the first image from each ImageNet class folder.")
25 | parser.add_argument("--imagenet_root", required=True, help="Path to the ImageNet training dataset root directory")
26 | parser.add_argument("--output_dir", default="./images_imagenet", help="Output directory for selected images")
27 |
28 | args = parser.parse_args()
29 | main(args.imagenet_root, args.output_dir)
30 |
--------------------------------------------------------------------------------
/images/white.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExplainableML/sae-for-vlm/69e211dc4d0a31cfb3b3f6f682f78a1890677d5f/images/white.png
--------------------------------------------------------------------------------
/inat_depth.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tqdm
3 | import os
4 | import torch
5 | from utils import get_dataset
6 | import argparse
7 | from itertools import combinations
8 | from collections import Counter
9 |
10 |
11 | def parse_args():
12 | parser = argparse.ArgumentParser(description="Measure hierarchy in a Matryoshka SAE trained on iNaturalist")
13 | parser.add_argument('--activations_dir', type=str, required=True)
14 | parser.add_argument('--hai_indices_path', type=str, required=True)
15 | parser.add_argument("--data_path", default="/shared-network/inat2021", type=str)
16 | parser.add_argument('--split', type=str, default='train')
17 | parser.add_argument('--k', type=int, default=16)
18 | parser.add_argument('--group_fractions', type=float, nargs='+', required=True)
19 | return parser.parse_args()
20 |
21 |
22 | if __name__ == "__main__":
23 | # Parse command line arguments
24 | args = parse_args()
25 | args.batch_size = 1 # not used
26 | args.num_workers = 0 # not used
27 | args.dataset_name = 'inat' # fixed
28 |
29 | # Get HAI indices
30 | hai_indices = np.load(args.hai_indices_path)
31 | print(f"Loaded HAI indices found at {args.hai_indices_path}")
32 |
33 | # Get HAI worst scores
34 | worst_scores_path = f"{args.hai_indices_path[:-4]}_worst.npy"
35 | worst_scores = np.load(worst_scores_path)
36 | print(f"Loaded worst scores of HAI found at {worst_scores_path}")
37 |
38 | # Assign a taxonomy class path to each of the top-k images
39 | hai_classes = []
40 | num_neurons = hai_indices.shape[0]
41 | ds, _ = get_dataset(args, preprocess=None, processor=None, split=args.split)
42 |
43 | for neuron in range(num_neurons):
44 | hai_classes_neuron = []
45 | for i in range(args.k):
46 | class_index = hai_indices[neuron, i]
47 | image_path = ds.imgs[class_index][0]
48 | class_name = image_path.split(os.path.sep)[-2].split("_")[1:]
49 | hai_classes_neuron.append(class_name)
50 | hai_classes.append(hai_classes_neuron)
51 |
52 |
53 | # Compute pairwise LCA
54 | def get_lca(x, y):
55 | lca = 0
56 | for a, b in zip(x, y):
57 | if a == b:
58 | lca += 1
59 | else:
60 | break
61 | return lca
62 |
63 |
64 | lcas_majority = []
65 | lcas_mean = []
66 |
67 | for neuron in range(num_neurons):
68 | lcas_neuron = []
69 | hai_classes_neuron = hai_classes[neuron]
70 | for x, y in combinations(hai_classes_neuron, 2):
71 | lca = get_lca(x, y)
72 | lcas_neuron.append(lca)
73 |
74 | if lcas_neuron:
75 | lcas_mean.append(round(sum(lcas_neuron) / len(lcas_neuron)))
76 | lcas_majority.append(Counter(lcas_neuron).most_common(1)[0][0])
77 |
78 | # Compute avg depth of LCA
79 | assert np.isclose(sum(args.group_fractions), 1.0), "group_fractions must sum to 1.0"
80 | group_sizes = [int(f * num_neurons) for f in args.group_fractions[:-1]]
81 | group_sizes.append(num_neurons - sum(group_sizes)) # Ensure it adds up to num_neurons
82 |
83 | start_idx = 0
84 | depths_majority = []
85 | depths_mean = []
86 |
87 | for group_idx, group_size in enumerate(group_sizes):
88 | end_idx = start_idx + group_size
89 | valid_mask = worst_scores[start_idx:end_idx] != 0.0
90 | valid_lcas_majority = np.compress(valid_mask, lcas_majority[start_idx:end_idx])
91 | valid_lcas_mean = np.compress(valid_mask, lcas_mean[start_idx:end_idx])
92 | depths_majority.append(np.mean(valid_lcas_majority))
93 | depths_mean.append(np.mean(valid_lcas_mean))
94 | start_idx = end_idx
95 |
96 | num_excluded = np.sum(~valid_mask)
97 | percentage_excluded = (num_excluded / group_size) * 100
98 | print(f"Group {group_idx}: {percentage_excluded:.2f}% neurons excluded")
99 |
100 | print("Group-wise Average WordNet Depths (mean):", depths_mean)
101 | print("Group-wise Average WordNet Depths (majority):", depths_majority)
--------------------------------------------------------------------------------
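The LCA depth used above is simply the number of leading taxonomy levels that two class-name paths share (the paths come from the iNaturalist folder names). A tiny worked example with made-up class paths:

```python
# Illustrative only: LCA depth = number of leading taxonomy levels two paths share.
def get_lca(x, y):
    lca = 0
    for a, b in zip(x, y):
        if a == b:
            lca += 1
        else:
            break
    return lca

path_a = ["Animalia", "Chordata", "Aves", "Passeriformes"]
path_b = ["Animalia", "Chordata", "Aves", "Strigiformes"]
print(get_lca(path_a, path_b))   # 3 -> the paths agree down to the class Aves
```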
/metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os.path
3 | import argparse
4 | from datasets.activations import ActivationsDataset
5 | import os
6 |
7 | from torch.utils.data import DataLoader, Subset
8 | import tqdm
9 | import torch.nn.functional as F
10 |
11 | def get_args_parser():
12 | parser = argparse.ArgumentParser("Measure monosemanticity via weighted pairwise cosine similarity", add_help=False)
13 | parser.add_argument("--embeddings_path")
14 | parser.add_argument("--activations_dir")
15 | parser.add_argument("--output_subdir")
16 | parser.add_argument("--device", default="cpu")
17 | return parser
18 |
19 | def main(args):
20 | # Load embeddings
21 | embeddings = torch.load(args.embeddings_path, map_location=torch.device(args.device))
22 | print(f"Loaded embeddings found at {args.embeddings_path}")
23 | print(f"Embeddings shape: {embeddings.shape}")
24 |
25 | # Load activations
26 | activations_dataset = ActivationsDataset(args.activations_dir, device=torch.device(args.device), take_every=1)
27 | activations_dataloader = DataLoader(activations_dataset, batch_size=len(activations_dataset), shuffle=False)
28 | activations = next(iter(activations_dataloader))
29 | print(f"Loaded activations found at {args.activations_dir}")
30 | print(f"Activations shape: {activations.shape}")
31 |
32 | # Scale to 0-1 per neuron
33 | min_values = activations.min(dim=0, keepdim=True)[0]
34 | max_values = activations.max(dim=0, keepdim=True)[0]
35 | activations = (activations - min_values) / (max_values - min_values)
36 |
37 | # embeddings = embeddings - embeddings.mean(dim=0, keepdim=True)
38 | num_images, embed_dim = embeddings.shape
39 | num_neurons = activations.shape[1]
40 |
41 | # Initialize accumulators
42 | weighted_cosine_similarity_sum = torch.zeros(num_neurons, device=torch.device(args.device))
43 | weight_sum = torch.zeros(num_neurons, device=torch.device(args.device))
44 | batch_size = 100 # Set batch size
45 |
46 | for i in tqdm.tqdm(range(num_images), desc="Processing image pairs"):
47 | for j_start in range(i + 1, num_images, batch_size): # Process in batches
48 | j_end = min(j_start + batch_size, num_images)
49 |
50 | embeddings_i = embeddings[i].cuda() # (embedding_dim)
51 | embeddings_j = embeddings[j_start:j_end].cuda() # (batch_size, embedding_dim)
52 | activations_i = activations[i].cuda() # (num_neurons)
53 | activations_j = activations[j_start:j_end].cuda() # (batch_size, num_neurons)
54 |
55 | # Compute cosine similarity
56 | cosine_similarities = F.cosine_similarity(
57 | embeddings_i.unsqueeze(0).expand(j_end - j_start, -1), # Expanding to (batch_size, embedding_dim)
58 | embeddings_j,
59 | dim=1
60 | )
61 |
62 | # Compute weights and weighted similarities
63 | # Expanding activations_i to (1, num_neurons)
64 | weights = activations_i.unsqueeze(0) * activations_j # (batch_size, num_neurons)
65 | weighted_cosine_similarities = weights * cosine_similarities.unsqueeze(1) # (batch_size, num_neurons)
66 |
67 | weighted_cosine_similarities = torch.sum(weighted_cosine_similarities, dim=0) # (num_neurons)
68 | weighted_cosine_similarity_sum += weighted_cosine_similarities.cpu()
69 |
70 | weights = torch.sum(weights, dim=0) # (num_neurons)
71 | weight_sum += weights.cpu()
72 |
73 | monosemanticity = torch.where(weight_sum != 0, weighted_cosine_similarity_sum / weight_sum, torch.nan)
74 |
75 | os.makedirs(os.path.join(args.activations_dir, args.output_subdir), exist_ok=True)
76 | torch.save(monosemanticity, os.path.join(args.activations_dir, args.output_subdir, "all_neurons_scores.pth"))
77 |
78 | is_nan = torch.isnan(monosemanticity)
79 | nan_count = is_nan.sum()
80 | monosemanticity_mean = torch.mean(monosemanticity[~is_nan])
81 | monosemanticity_std = torch.std(monosemanticity[~is_nan])
82 |
83 | print(f"Monosemanticity: {monosemanticity_mean.item()} +- {monosemanticity_std.item()}")
84 | print(f"Dead neurons:", nan_count.item())
85 | print(f"Total neurons:", num_neurons)
86 |
87 | # Filter out NaNs
88 | valid_indices = ~torch.isnan(monosemanticity)
89 | valid_monosemanticity = monosemanticity[valid_indices]
90 | valid_indices = torch.nonzero(valid_indices).squeeze()
91 |
92 | # Get top 10 highest and lowest monosemantic neurons
93 | top_10_values, top_10_indices = torch.topk(valid_monosemanticity, 10)
94 | bottom_10_values, bottom_10_indices = torch.topk(valid_monosemanticity, 10, largest=False)
95 |
96 | # Map indices back to original positions
97 | top_10_indices = valid_indices[top_10_indices]
98 | bottom_10_indices = valid_indices[bottom_10_indices]
99 |
100 | # Print results
101 | print("Top 10 most monosemantic neurons:")
102 | for i, (idx, val) in enumerate(zip(top_10_indices, top_10_values)):
103 | print(f"{i + 1}. Neuron {idx.item()} - {val.item()}")
104 |
105 | print("\nBottom 10 least monosemantic neurons:")
106 | for i, (idx, val) in enumerate(zip(bottom_10_indices, bottom_10_values)):
107 | print(f"{i + 1}. Neuron {idx.item()} - {val.item()}")
108 |
109 | # Save to file
110 | output_path = os.path.join(args.activations_dir, args.output_subdir, "metric_stats_new.txt")
111 | with open(output_path, "w") as file:
112 | file.write(f"Monosemanticity: {monosemanticity_mean.item()} +- {monosemanticity_std.item()}\n")
113 | file.write(f"Dead neurons: {nan_count.item()}\n")
114 | file.write(f"Total neurons: {num_neurons}\n\n")
115 |
116 | file.write("Top 10 most monosemantic neurons:\n")
117 | for idx, val in zip(top_10_indices, top_10_values):
118 | file.write(f"Neuron {idx.item()} - {val.item()}\n")
119 |
120 | file.write("\nBottom 10 least monosemantic neurons:\n")
121 | for idx, val in zip(bottom_10_indices, bottom_10_values):
122 | file.write(f"Neuron {idx.item()} - {val.item()}\n")
123 |
124 |
125 | if __name__ == "__main__":
126 | args = get_args_parser()
127 | args = args.parse_args()
128 | main(args)
129 |
--------------------------------------------------------------------------------
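In formula form, the score computed above for neuron n is the activation-weighted average of pairwise embedding cosine similarities: MS_n = sum_{i<j} a_{i,n} a_{j,n} cos(e_i, e_j) / sum_{i<j} a_{i,n} a_{j,n}. Below is a dense, non-batched sketch on random tensors (illustrative only; the script itself streams over image pairs to keep memory bounded):

```python
# Illustrative only: a dense version of the Monosemanticity Score (MS) from metric.py,
#   MS_n = sum_{i<j} a_{i,n} * a_{j,n} * cos(e_i, e_j) / sum_{i<j} a_{i,n} * a_{j,n}
# where e_i are image embeddings and a_{i,n} are min-max scaled neuron activations.
import torch
import torch.nn.functional as F

embeddings = torch.randn(64, 512)     # (num_images, embed_dim) stand-in
activations = torch.rand(64, 128)     # (num_images, num_neurons), already scaled to [0, 1]

cos = F.cosine_similarity(embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=-1)  # (N, N)
weights = torch.einsum("in,jn->ijn", activations, activations)                       # (N, N, num_neurons)
mask = torch.triu(torch.ones(64, 64, dtype=torch.bool), diagonal=1)                  # pairs with i < j

numerator = (weights * cos.unsqueeze(-1))[mask].sum(dim=0)
denominator = weights[mask].sum(dim=0)
ms = torch.where(denominator > 0, numerator / denominator, torch.nan)
print(ms.shape)   # (num_neurons,)
```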
/models/clip.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoProcessor, CLIPVisionModelWithProjection
2 | import torch
3 | import torch.nn as nn
4 | from typing import Optional, Tuple
5 |
6 | class Clip:
7 | def __init__(self, model_name, device):
8 | self.device = device
9 | self.model = CLIPVisionModelWithProjection.from_pretrained(f"openai/{model_name}").to(device)
10 | self.processor = AutoProcessor.from_pretrained(f"openai/{model_name}")
11 | self.register = {}
12 | self.attach_methods = {
13 | 'in_mlp': self._attach_in_mlp,
14 | 'post_mlp': self._attach_post_mlp,
15 | 'post_mlp_residual': self._attach_post_mlp_residual,
16 | 'post_projection': self._attach_post_projection,
17 | }
18 |
19 | def encode(self, inputs):
20 | for hook in self.register.keys():
21 | self.register[hook] = []
22 | outputs = self.model(**inputs.to(self.device))
23 | # pooled_output = outputs.pooler_output
24 | # return pooled_output
25 | image_embeds = outputs.image_embeds
26 | return image_embeds
27 |
28 | def attach(self, attachment_point, layer, sae=None):
29 | if attachment_point in self.attach_methods:
30 | self.attach_methods[attachment_point](layer, sae)
31 | self.register[f'{attachment_point}_{layer}'] = []
32 | else:
33 | raise NotImplementedError(f"Attachment point {attachment_point} not implemented")
34 |
35 | def _attach_in_mlp(self, layer, sae):
36 | raise NotImplementedError
37 |
38 | def _attach_post_mlp(self, layer, sae):
39 | raise NotImplementedError
40 |
41 | def _attach_post_mlp_residual(self, layer, sae):
42 | self.model.vision_model.encoder.layers[layer] = CLIPEncoderLayerPostMlpResidual(
43 | self.model.vision_model.encoder.layers[layer],
44 | sae,
45 | layer,
46 | self.register,
47 | )
48 |
49 | def _attach_post_projection(self, layer, sae):
50 | self.model.visual_projection = CLIPProjectionLayer(
51 | self.model.visual_projection,
52 | sae,
53 | layer,
54 | self.register,
55 | )
56 |
57 | class CLIPProjectionLayer(nn.Module):
58 | def __init__(self, projector, sae, layer, register):
59 | super().__init__()
60 | self.projector = projector
61 | self.sae = sae
62 | self.layer = layer
63 | self.register = register
64 |
65 | def forward(self, inputs):
66 | outputs = self.projector(inputs)
67 | if self.sae is not None:
68 | outputs = self.sae.encode(outputs)
69 | self.register[f'post_projection_{self.layer}'].append(outputs.detach().cpu())
70 | outputs = self.sae.decode(outputs)
71 | else:
72 | self.register[f'post_projection_{self.layer}'].append(outputs.detach().cpu())
73 | return outputs
74 |
75 | class CLIPEncoderLayerPostMlpResidual(nn.Module):
76 | def __init__(self, base, sae, layer, register):
77 | super().__init__()
78 | self.embed_dim = base.embed_dim
79 | self.self_attn = base.self_attn
80 | self.layer_norm1 = base.layer_norm1
81 | self.mlp = base.mlp
82 | self.layer_norm2 = base.layer_norm2
83 |
84 | self.sae = sae
85 | self.layer = layer
86 | self.register = register
87 |
88 | def forward(
89 | self,
90 | hidden_states: torch.Tensor,
91 | attention_mask: torch.Tensor,
92 | causal_attention_mask: torch.Tensor,
93 | output_attentions: Optional[bool] = False,
94 | ) -> Tuple[torch.FloatTensor]:
95 | residual = hidden_states
96 |
97 | hidden_states = self.layer_norm1(hidden_states)
98 | hidden_states, attn_weights = self.self_attn(
99 | hidden_states=hidden_states,
100 | attention_mask=attention_mask,
101 | causal_attention_mask=causal_attention_mask,
102 | output_attentions=output_attentions,
103 | )
104 | hidden_states = residual + hidden_states
105 |
106 | residual = hidden_states
107 | hidden_states = self.layer_norm2(hidden_states)
108 | hidden_states = self.mlp(hidden_states)
109 | hidden_states = residual + hidden_states
110 |
111 | if self.sae is not None:
112 | hidden_states = self.sae.encode(hidden_states)
113 | self.register[f'post_mlp_residual_{self.layer}'].append(hidden_states.detach().cpu())
114 | hidden_states = self.sae.decode(hidden_states)
115 | else:
116 | self.register[f'post_mlp_residual_{self.layer}'].append(hidden_states.detach().cpu())
117 |
118 | outputs = (hidden_states,)
119 |
120 | if output_attentions:
121 | outputs += (attn_weights,)
122 |
123 | return outputs
124 |
--------------------------------------------------------------------------------
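A hedged end-to-end sketch of how this wrapper is meant to be used: attach a trained SAE at the post-MLP residual stream of one encoder layer, run a forward pass, and read the recorded sparse codes from `register`. The model name, layer index, SAE path, and image file are placeholders, and the SAE's `activation_dim` must match the hidden size of the chosen CLIP model.

```python
# Hypothetical usage sketch; paths, model name, and layer are placeholders.
import torch
from PIL import Image
from models.clip import Clip
from dictionary_learning.utils import load_dictionary

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = Clip("clip-vit-large-patch14", device=device)
sae, _ = load_dictionary("./output_dir/trainer_0", device=device)   # placeholder path

model.attach("post_mlp_residual", layer=22, sae=sae)

image = Image.open("images/white.png")
inputs = model.processor(images=[image], return_tensors="pt")
with torch.no_grad():
    _ = model.encode(inputs)

codes = model.register["post_mlp_residual_22"][0]   # (batch, tokens, dict_size)
print(codes.shape)
```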
/models/dino.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoImageProcessor, Dinov2Model
2 | import torch
3 |
4 | class Dino:
5 | def __init__(self, model_name="dinov2-base", device=torch.device("cuda")):
6 | self.device = device
7 | self.model = Dinov2Model.from_pretrained(f"facebook/{model_name}").to(device)
8 | self.processor = AutoImageProcessor.from_pretrained(f"facebook/{model_name}")
9 |
10 | def encode(self, inputs):
11 | outputs = self.model(**inputs.to(self.device))
12 | image_embeds = outputs.pooler_output
13 | return image_embeds
14 |
--------------------------------------------------------------------------------
/models/llava.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoProcessor, LlavaForConditionalGeneration
2 | import torch
3 | import torch.nn as nn
4 | from typing import Optional, Tuple
5 | import copy
6 |
7 | class Llava:
8 |
9 | def __init__(self, device):
10 | self.device = device
11 | self.layer = 22
12 | self.model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf",
13 | torch_dtype=torch.float16,
14 | device_map=self.device)
15 | self.processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
16 | self.base_CLIPEncoderLayerPostMlpResidual = copy.deepcopy(
17 | self.model.vision_tower.vision_model.encoder.layers[self.layer]
18 | )
19 |
20 | def prompt(self, text, image, max_tokens=5):
21 | conversation = [
22 | {
23 | "role": "user",
24 | "content": [
25 | {"type": "image"},
26 | {"type": "text", "text": text},
27 | ],
28 | },
29 | ]
30 | prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
31 | inputs = self.processor(images=[image], text=[prompt],
32 | padding=True, return_tensors="pt").to(self.model.device, torch.float16)
33 | generate_ids = self.model.generate(**inputs, max_new_tokens=max_tokens)
34 | output = self.processor.batch_decode(generate_ids, skip_special_tokens=True)
35 | output = [x.split('ASSISTANT: ')[-1] for x in output]
36 | return output
37 |
38 | def attach_and_fix(self, sae, neurons_to_fix={}, pre_zero=False):
39 | modified_sae = SAEWrapper(sae, neurons_to_fix, pre_zero)
40 | self.model.vision_tower.vision_model.encoder.layers[self.layer] = CLIPEncoderLayerPostMlpResidual(
41 | self.base_CLIPEncoderLayerPostMlpResidual,
42 | modified_sae,
43 | )
44 |
45 |
46 | class SAEWrapper(nn.Module):
47 |
48 | def __init__(self, sae, neurons_to_fix, pre_zero):
49 | super().__init__()
50 | self.sae = sae
51 | self.neurons_to_fix = neurons_to_fix
52 | self.pre_zero = pre_zero
53 |
54 | def encode(self, x):
55 | x = self.sae.encode(x)
56 | if self.pre_zero:
57 | x = torch.zeros_like(x)
58 | for neuron_id, value in self.neurons_to_fix.items():
59 | x[:, :, neuron_id] = value
60 | return x
61 |
62 | def decode(self, x):
63 | x = self.sae.decode(x)
64 | x = x.to(dtype=torch.float16)
65 | return x
66 |
67 |
68 |
69 |
70 | class CLIPEncoderLayerPostMlpResidual(nn.Module):
71 |
72 | def __init__(self, base, sae):
73 | super().__init__()
74 | self.embed_dim = base.embed_dim
75 | self.self_attn = base.self_attn
76 | self.layer_norm1 = base.layer_norm1
77 | self.mlp = base.mlp
78 | self.layer_norm2 = base.layer_norm2
79 | self.sae = sae
80 |
81 | def forward(
82 | self,
83 | hidden_states: torch.Tensor,
84 | attention_mask: torch.Tensor,
85 | causal_attention_mask: torch.Tensor,
86 | output_attentions: Optional[bool] = False,
87 | ) -> Tuple[torch.FloatTensor]:
88 | residual = hidden_states
89 | hidden_states = self.layer_norm1(hidden_states)
90 | hidden_states, attn_weights = self.self_attn(
91 | hidden_states=hidden_states,
92 | attention_mask=attention_mask,
93 | causal_attention_mask=causal_attention_mask,
94 | output_attentions=output_attentions,
95 | )
96 | hidden_states = residual + hidden_states
97 | residual = hidden_states
98 | hidden_states = self.layer_norm2(hidden_states)
99 | hidden_states = self.mlp(hidden_states)
100 | hidden_states = residual + hidden_states
101 | encoded_hidden_states = self.sae.encode(hidden_states)
102 | decoded_hidden_states = self.sae.decode(encoded_hidden_states)
103 | hidden_states = decoded_hidden_states
104 | outputs = (hidden_states,)
105 | if output_attentions:
106 | outputs += (attn_weights,)
107 | return outputs
--------------------------------------------------------------------------------
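A hedged sketch of the steering loop enabled by `attach_and_fix`: wrap a trained SAE, clamp a chosen SAE neuron to a constant value inside the vision tower, and prompt LLaVA as usual. The SAE path, neuron id, and clamp value are placeholders, and casting the SAE to float16 is an assumption made here so its weights match LLaVA's half-precision hidden states.

```python
# Hypothetical steering sketch; SAE path, neuron id, and clamp value are placeholders.
import torch
from PIL import Image
from models.llava import Llava
from dictionary_learning.utils import load_dictionary

device = "cuda:0"
llava = Llava(device)
sae, _ = load_dictionary("./output_dir/trainer_0", device=device)   # placeholder path
sae = sae.to(torch.float16)   # assumed: match LLaVA's float16 activations

# Clamp SAE neuron 1234 to 50.0 for every token; pre_zero=True would zero all other neurons first.
llava.attach_and_fix(sae, neurons_to_fix={1234: 50.0}, pre_zero=False)

image = Image.open("images/white.png")
print(llava.prompt("What is shown in the image?", image, max_tokens=5))
```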
/models/siglip.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoProcessor, SiglipVisionModel
2 | import torch
3 | import torch.nn as nn
4 | from typing import Optional, Tuple
5 |
6 |
7 | # siglip-so400m-patch14-384
8 | class Siglip:
9 | def __init__(self, model_name, device):
10 | self.device = device
11 | self.model = SiglipVisionModel.from_pretrained(f"google/{model_name}").to(device)
12 | self.processor = AutoProcessor.from_pretrained(f"google/{model_name}")
13 | self.register = {}
14 | self.attach_methods = {
15 | 'post_mlp_residual': self._attach_post_mlp_residual,
16 | 'post_projection': self._attach_post_projection,
17 | }
18 | self.sae = None
19 | self.layer = None
20 |
21 | def encode(self, inputs):
22 | for hook in self.register.keys():
23 | self.register[hook] = []
24 | outputs = self.model(**inputs.to(self.device))
25 | pooled_output = outputs.pooler_output
26 |
27 | if self.sae is not None:
28 | pooled_output = self.sae.encode(pooled_output)
29 | self.register[f'post_projection_{self.layer}'].append(pooled_output.detach().cpu())
30 | pooled_output = self.sae.decode(pooled_output)
31 | elif self.layer is not None:
32 | self.register[f'post_projection_{self.layer}'].append(pooled_output.detach().cpu())
33 |
34 | return pooled_output
35 |
36 | def attach(self, attachment_point, layer, sae=None):
37 | if attachment_point in self.attach_methods:
38 | self.attach_methods[attachment_point](layer, sae)
39 | self.register[f'{attachment_point}_{layer}'] = []
40 | else:
41 | raise NotImplementedError(f"Attachment point {attachment_point} not implemented")
42 |
43 | def _attach_post_mlp_residual(self, layer, sae):
44 | self.model.vision_model.encoder.layers[layer] = SiglipEncoderLayerPostMlpResidual(
45 | self.model.vision_model.encoder.layers[layer],
46 | sae,
47 | layer,
48 | self.register,
49 | )
50 |
51 | def _attach_post_projection(self, layer, sae):
52 | self.sae = sae
53 | self.layer = layer
54 |
55 |
56 | class SiglipEncoderLayerPostMlpResidual(nn.Module):
57 | def __init__(self, base, sae, layer, register):
58 | super().__init__()
59 | self.embed_dim = base.embed_dim
60 | self.self_attn = base.self_attn
61 | self.layer_norm1 = base.layer_norm1
62 | self.mlp = base.mlp
63 | self.layer_norm2 = base.layer_norm2
64 |
65 | self.sae = sae
66 | self.layer = layer
67 | self.register = register
68 |
69 | def forward(
70 | self,
71 | hidden_states: torch.Tensor,
72 | attention_mask: torch.Tensor,
73 | output_attentions: Optional[bool] = False,
74 | ) -> Tuple[torch.FloatTensor]:
75 | residual = hidden_states
76 |
77 | hidden_states = self.layer_norm1(hidden_states)
78 | hidden_states, attn_weights = self.self_attn(
79 | hidden_states=hidden_states,
80 | attention_mask=attention_mask,
81 | output_attentions=output_attentions,
82 | )
83 | hidden_states = residual + hidden_states
84 |
85 | residual = hidden_states
86 | hidden_states = self.layer_norm2(hidden_states)
87 | hidden_states = self.mlp(hidden_states)
88 | hidden_states = residual + hidden_states
89 |
90 | if self.sae is not None:
91 | hidden_states = self.sae.encode(hidden_states)
92 | self.register[f'post_mlp_residual_{self.layer}'].append(hidden_states.detach().cpu())
93 | hidden_states = self.sae.decode(hidden_states)
94 | else:
95 | self.register[f'post_mlp_residual_{self.layer}'].append(hidden_states.detach().cpu())
96 |
97 | outputs = (hidden_states,)
98 |
99 | if output_attentions:
100 | outputs += (attn_weights,)
101 |
102 | return outputs
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==1.1.1
2 | aiohappyeyeballs==2.4.3
3 | aiohttp==3.11.8
4 | aiosignal==1.3.1
5 | annotated-types==0.7.0
6 | anyio==4.7.0
7 | argon2-cffi==23.1.0
8 | argon2-cffi-bindings==21.2.0
9 | arrow==1.3.0
10 | asttokens==3.0.0
11 | async-lru==2.0.4
12 | attrs==24.2.0
13 | babel==2.16.0
14 | beautifulsoup4==4.12.3
15 | bidict==0.23.1
16 | bleach==6.2.0
17 | certifi==2024.8.30
18 | cffi==1.17.1
19 | charset-normalizer==3.4.0
20 | circuitsvis==1.43.2
21 | click==8.1.7
22 | comm==0.2.2
23 | contourpy==1.3.1
24 | cycler==0.12.1
25 | datasets==3.1.0
26 | debugpy==1.8.9
27 | decorator==5.1.1
28 | defusedxml==0.7.1
29 | diffusers==0.31.0
30 | dill==0.3.8
31 | docker-pycreds==0.4.0
32 | einops==0.8.0
33 | executing==2.1.0
34 | fastjsonschema==2.21.1
35 | filelock==3.16.1
36 | fonttools==4.55.0
37 | fqdn==1.5.1
38 | frozenlist==1.5.0
39 | fsspec==2024.9.0
40 | gitdb==4.0.11
41 | GitPython==3.1.43
42 | h11==0.14.0
43 | httpcore==1.0.7
44 | httpx==0.28.1
45 | huggingface-hub==0.26.3
46 | idna==3.10
47 | importlib_metadata==8.5.0
48 | ipykernel==6.29.5
49 | ipython==8.30.0
50 | ipywidgets==8.1.5
51 | isoduration==20.11.0
52 | jedi==0.19.2
53 | Jinja2==3.1.4
54 | joblib==1.4.2
55 | json5==0.10.0
56 | jsonpointer==3.0.0
57 | jsonschema==4.23.0
58 | jsonschema-specifications==2024.10.1
59 | jupyter==1.1.1
60 | jupyter-console==6.6.3
61 | jupyter-events==0.10.0
62 | jupyter-lsp==2.2.5
63 | jupyter_client==8.6.3
64 | jupyter_core==5.7.2
65 | jupyter_server==2.14.2
66 | jupyter_server_terminals==0.5.3
67 | jupyterlab==4.3.3
68 | jupyterlab_pygments==0.3.0
69 | jupyterlab_server==2.27.3
70 | jupyterlab_widgets==3.0.13
71 | kiwisolver==1.4.7
72 | llvmlite==0.43.0
73 | MarkupSafe==3.0.2
74 | matplotlib==3.9.3
75 | matplotlib-inline==0.1.7
76 | mistune==3.0.2
77 | mpmath==1.3.0
78 | multidict==6.1.0
79 | multiprocess==0.70.16
80 | nbclient==0.10.1
81 | nbconvert==7.16.4
82 | nbformat==5.10.4
83 | nest-asyncio==1.6.0
84 | networkx==3.4.2
85 | nltk==3.9.1
86 | nnsight==0.3.3
87 | notebook==7.3.1
88 | notebook_shim==0.2.4
89 | numba==0.60.0
90 | numpy==1.26.4
91 | nvidia-cublas-cu12==12.1.3.1
92 | nvidia-cuda-cupti-cu12==12.1.105
93 | nvidia-cuda-nvrtc-cu12==12.1.105
94 | nvidia-cuda-runtime-cu12==12.1.105
95 | nvidia-cudnn-cu12==8.9.2.26
96 | nvidia-cufft-cu12==11.0.2.54
97 | nvidia-curand-cu12==10.3.2.106
98 | nvidia-cusolver-cu12==11.4.5.107
99 | nvidia-cusparse-cu12==12.1.0.106
100 | nvidia-nccl-cu12==2.18.1
101 | nvidia-nvjitlink-cu12==12.6.85
102 | nvidia-nvtx-cu12==12.1.105
103 | overrides==7.7.0
104 | packaging==24.2
105 | pandas==2.2.3
106 | pandocfilters==1.5.1
107 | parso==0.8.4
108 | pexpect==4.9.0
109 | pillow==11.0.0
110 | platformdirs==4.3.6
111 | plotly==5.24.1
112 | prometheus_client==0.21.1
113 | prompt_toolkit==3.0.48
114 | propcache==0.2.0
115 | protobuf==5.29.0
116 | psutil==6.1.0
117 | ptyprocess==0.7.0
118 | pure_eval==0.2.3
119 | pyarrow==18.1.0
120 | pycparser==2.22
121 | pydantic==2.10.2
122 | pydantic_core==2.27.1
123 | Pygments==2.18.0
124 | pynndescent==0.5.13
125 | pyparsing==3.2.0
126 | python-dateutil==2.9.0.post0
127 | python-engineio==4.10.1
128 | python-json-logger==2.0.7
129 | python-socketio==5.11.4
130 | pytz==2024.2
131 | PyYAML==6.0.2
132 | pyzmq==26.2.0
133 | referencing==0.35.1
134 | regex==2024.11.6
135 | requests==2.32.3
136 | rfc3339-validator==0.1.4
137 | rfc3986-validator==0.1.1
138 | rpds-py==0.22.3
139 | safetensors==0.4.5
140 | scikit-learn==1.5.2
141 | scipy==1.14.1
142 | Send2Trash==1.8.3
143 | sentencepiece==0.2.0
144 | sentry-sdk==2.19.0
145 | setproctitle==1.3.4
146 | simple-websocket==1.1.0
147 | six==1.16.0
148 | smmap==5.0.1
149 | sniffio==1.3.1
150 | soupsieve==2.6
151 | stack-data==0.6.3
152 | sympy==1.13.3
153 | tenacity==9.0.0
154 | terminado==0.18.1
155 | threadpoolctl==3.5.0
156 | tinycss2==1.4.0
157 | tokenizers==0.20.3
158 | torch==2.1.2
159 | torchvision==0.16.2
160 | tornado==6.4.2
161 | tqdm==4.67.1
162 | traitlets==5.14.3
163 | transformers==4.46.3
164 | triton==2.1.0
165 | types-python-dateutil==2.9.0.20241206
166 | typing_extensions==4.12.2
167 | tzdata==2024.2
168 | umap-learn==0.5.7
169 | uri-template==1.3.0
170 | urllib3==2.2.3
171 | wandb==0.18.7
172 | wcwidth==0.2.13
173 | webcolors==24.11.1
174 | webencodings==0.5.1
175 | websocket-client==1.8.0
176 | widgetsnbextension==4.0.13
177 | wsproto==1.2.0
178 | xxhash==3.5.0
179 | yarl==1.18.0
180 | zipp==3.21.0
181 | zstandard==0.23.0
182 |
--------------------------------------------------------------------------------
/sae_train.py:
--------------------------------------------------------------------------------
1 | from dictionary_learning import ActivationBuffer, AutoEncoder, JumpReluAutoEncoder
2 | from dictionary_learning.trainers import *
3 | from dictionary_learning.training import trainSAE
4 | from torch.utils.data import DataLoader
5 | from datasets.activations import ActivationsDataset
6 | import torch
7 | from pathlib import Path
8 | import argparse
9 |
10 | def get_args_parser():
11 | parser = argparse.ArgumentParser("Train Sparse Autoencoder", add_help=False)
12 | parser.add_argument("--sae_model", default="jumprelu", type=str)
13 | parser.add_argument("--activations_dir", required=True, type=str)
14 | parser.add_argument("--val_activations_dir", required=True, type=str)
15 | parser.add_argument("--checkpoints_dir", default="./output_dir", type=str)
16 | parser.add_argument("--device", default="cuda:0")
17 | parser.add_argument("--expansion_factor", type=int, default=1)
18 | parser.add_argument("--lr", type=float)
19 | parser.add_argument("--batch_size", type=int, default=8192)
20 | parser.add_argument("--steps", type=int, default=10_000)
21 | parser.add_argument("--save_steps", type=int, default=1_000)
22 | parser.add_argument("--log_steps", type=int, default=50)
23 | # JumpRelu
24 | parser.add_argument("--bandwidth", type=float, default=0.001)
25 | parser.add_argument("--sparsity_penalty", type=float, default=0.1)
26 | # Standard
27 | parser.add_argument("--l1_penalty", type=float, default=0.1)
28 | parser.add_argument("--warmup_steps", type=int, default=0)
29 | parser.add_argument("--resample_steps", type=int, default=None)
30 | # TopK + Batch TopK
31 | parser.add_argument("--k", type=int, default=8)
32 | parser.add_argument("--auxk_alpha", type=float, default=1/32)
33 | parser.add_argument("--decay_start", type=int, default=1_000_000)
34 | # Batch TopK
35 | parser.add_argument("--threshold_beta", type=float, default=0.999)
36 | parser.add_argument("--threshold_start_step", type=int, default=1000)
37 | # MatryoshkaBatchTopK
38 | parser.add_argument("--group_fractions", type=float, nargs="+")
39 |
40 | return parser
41 |
42 | def train_sae(args):
43 | dataset = ActivationsDataset(args.activations_dir, device=torch.device(args.device))
44 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
45 |
46 | val_dataset = ActivationsDataset(args.val_activations_dir, device=torch.device(args.device))
47 | val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
48 |
49 | sample = next(iter(dataloader))
50 | print(sample.shape)
51 |
52 | activation_dim = sample.shape[1]
53 | dictionary_size = args.expansion_factor * activation_dim
54 |
55 | trainers = {
56 | 'jumprelu': JumpReluTrainer,
57 | 'standard': StandardTrainer,
58 | 'batch_top_k': BatchTopKTrainer,
59 | 'top_k': TopKTrainer,
60 | 'matroyshka_batch_top_k': MatroyshkaBatchTopKTrainer,
61 | }
62 |
63 | # autoencoders = {
64 | # 'jumprelu': JumpReluAutoEncoder,
65 | # 'standard': AutoEncoder,
66 | # 'batch_top_k': BatchTopKSAE,
67 | # 'top_k': AutoEncoderTopK,
68 | # }
69 |
70 | trainer_cfg = {
71 | "trainer": trainers[args.sae_model],
72 | "activation_dim": activation_dim,
73 | "dict_size": dictionary_size,
74 | "lr": args.lr,
75 | "device": args.device,
76 | "steps": args.steps,
77 | "layer": "",
78 | "lm_name": "",
79 | "submodule_name": ""
80 | }
81 |
82 | if args.sae_model == "jumprelu":
83 | trainer_cfg["bandwidth"] = args.bandwidth
84 | trainer_cfg["sparsity_penalty"] = args.sparsity_penalty
85 | if args.sae_model == "standard":
86 | trainer_cfg["l1_penalty"] = args.l1_penalty
87 | trainer_cfg["warmup_steps"] = args.warmup_steps
88 | trainer_cfg["resample_steps"] = args.resample_steps
89 | if args.sae_model == "top_k" or args.sae_model == "batch_top_k" or args.sae_model == "matroyshka_batch_top_k":
90 | trainer_cfg["k"] = args.k
91 | trainer_cfg["auxk_alpha"] = args.auxk_alpha
92 | trainer_cfg["decay_start"] = args.decay_start
93 | if args.sae_model == "batch_top_k" or args.sae_model == "matroyshka_batch_top_k":
94 | trainer_cfg["threshold_beta"] = args.threshold_beta
95 | trainer_cfg["threshold_start_step"] = args.threshold_start_step
96 | if args.sae_model == "matroyshka_batch_top_k":
97 | trainer_cfg["group_fractions"] = args.group_fractions
98 |
99 | dataset_name = Path(args.activations_dir).name
100 | save_dir = Path(args.checkpoints_dir) / f"{dataset_name}_{args.sae_model}_{args.k}_x{args.expansion_factor}"
101 | save_dir.mkdir(parents=True, exist_ok=True)
102 |
103 | ae = trainSAE(
104 | data=dataloader,
105 | val_data=val_dataloader,
106 | trainer_configs=[trainer_cfg],
107 | use_wandb=True,
108 | wandb_entity="mateuszpach",
109 | wandb_project="Clip SAE",
110 | steps=args.steps,
111 |         save_steps=list(range(0, args.steps, args.save_steps)),
112 | save_dir=save_dir,
113 | log_steps=args.log_steps,
114 | )
115 |
116 |
117 | if __name__ == "__main__":
118 | args = get_args_parser()
119 | args = args.parse_args()
120 | train_sae(args)
--------------------------------------------------------------------------------
/save_activations.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import os
4 | from torch.utils.data import DataLoader
5 | import tqdm
6 | import argparse
7 | from pathlib import Path
8 | from torchvision.datasets import ImageFolder
9 | from utils import get_dataset, get_model
10 | from torchvision.transforms import ToTensor
11 | from dictionary_learning import AutoEncoder
12 | from dictionary_learning.trainers import BatchTopKSAE, MatroyshkaBatchTopKSAE
13 |
14 |
15 | def get_args_parser():
16 | parser = argparse.ArgumentParser("Save activations used to train SAE", add_help=False)
17 | parser.add_argument("--batch_size", default=128, type=int)
18 | parser.add_argument("--sae_model", default=None, type=str)
19 | parser.add_argument("--model_name", default="clip", type=str)
20 | parser.add_argument("--attachment_point", default="post_mlp_residual", type=str)
21 | parser.add_argument("--layer", default=-1, type=int)
22 | parser.add_argument("--sae_path", default=None, type=str)
23 | parser.add_argument("--dataset_name", default="imagenet", type=str)
24 | parser.add_argument("--data_path", default="/shared-network/inat2021", type=str)
25 | parser.add_argument("--split", default="train", type=str)
26 | parser.add_argument("--num_workers", default=10, type=int)
27 | parser.add_argument("--output_dir", default="./output_dir", type=str)
28 | parser.add_argument("--cls_only", default=False, action="store_true")
29 | parser.add_argument("--mean_pool", default=False, action="store_true")
30 | parser.add_argument("--take_every", default=1, type=int)
31 | parser.add_argument("--random_k", default=-1, type=int)
32 | parser.add_argument("--save_every", default=50_000, type=int)
33 | parser.add_argument("--device", default="cuda:0")
34 | return parser
35 |
36 | def save_activations(activations, count, split, save_count, args):
37 | activations_tensor = torch.cat(activations, dim=0)
38 |     if args.take_every > 1:
39 |         # Keep every n-th activation (slice only dim 0 so already-pooled 2D tensors also work)
40 |         activations_tensor = activations_tensor[::args.take_every]
41 |
42 | if args.layer == -1:
43 | # Tokens already pooled
44 | activations_tensor = activations_tensor
45 | elif args.cls_only:
46 | # Keep only CLS token
47 | activations_tensor = activations_tensor[:, 0, :]
48 | elif args.mean_pool:
49 | # Mean pool tokens into one data point
50 | activations_tensor = torch.mean(activations_tensor, dim=1)
51 | elif args.random_k != -1:
52 | # Treat each token as a separate data point but pick random k tokens from each image
53 | batch_size, seq_len, hidden_dim = activations_tensor.shape
54 | indices = torch.randint(0, seq_len, (batch_size, args.random_k))
55 | activations_tensor = torch.stack([activations_tensor[i, indices[i], :] for i in range(batch_size)])
56 | activations_tensor = activations_tensor.reshape(-1, hidden_dim)
57 | else:
58 | # Treat each token as a separate data point and use all the tokens
59 | activations_tensor = activations_tensor.reshape(activations_tensor.shape[0] * activations_tensor.shape[1],
60 | activations_tensor.shape[2])
61 |
62 | filename = f"{args.dataset_name}_{split}_activations_{args.model_name}_{args.layer}_{args.attachment_point}_part{save_count + 1}.pt"
63 | save_path = os.path.join(args.output_dir, filename)
64 | torch.save(torch.tensor(activations_tensor.cpu().numpy()), save_path)
65 | print(f"Saved the activations at count {count} to {save_path}")
66 |
67 | def collect_activations(args):
68 | model, processor = get_model(args)
69 |
70 |     if args.sae_model is not None:
71 |         if args.sae_model == "standard":
72 |             sae = AutoEncoder.from_pretrained(args.sae_path).to(args.device)
73 |         elif args.sae_model == "batch_top_k":
74 |             sae = BatchTopKSAE.from_pretrained(args.sae_path).to(args.device)
75 |         elif args.sae_model == "matroyshka_batch_top_k":
76 |             sae = MatroyshkaBatchTopKSAE.from_pretrained(args.sae_path).to(args.device)
   |         else:
   |             raise ValueError(f"Unknown SAE model: {args.sae_model}")
77 |         print(f"Attached SAE from {args.sae_path}")
78 |     else:
79 |         sae = None
80 |         print("No SAE attached. Saving original activations.")
81 |
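   |     # The provided scripts use two attachment points: "post_projection" with layer -1
   |     # (final projected embedding, already pooled) and "post_mlp_residual" with an
   |     # intermediate layer index (per-token residual-stream activations).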
82 | model.attach(args.attachment_point, args.layer, sae=sae)
83 |
84 | ds, dl = get_dataset(args, preprocess=None, processor=processor, split=args.split)
85 | activations = []
86 | count = 0
87 | save_count = 0
88 | pbar = tqdm.tqdm(dl)
89 | for image in pbar:
90 |
91 | with torch.no_grad():
92 | model.encode(image)
93 | activations.extend(model.register[f"{args.attachment_point}_{args.layer}"])
94 |
95 | count += image['pixel_values'].shape[0]
96 | pbar.set_postfix({'Processed data points': count})
97 |
98 | if count >= args.save_every * (save_count + 1):
99 | save_activations(activations, count, args.split, save_count, args)
100 | activations = []
101 | save_count += 1
102 |
103 | if activations:
104 | save_activations(activations, count, args.split, save_count, args)
105 |
106 |
107 | if __name__ == "__main__":
108 | args = get_args_parser()
109 | args = args.parse_args()
110 | if args.output_dir:
111 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
112 | collect_activations(args)
--------------------------------------------------------------------------------
/scripts/matryoshka_hierarchy.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATASET_PATH="${INAT_PATH}"
4 |
5 | # 1. Save original activations
6 | for SPLIT in "train" "val"; do
7 | python save_activations.py \
8 | --batch_size 32 \
9 | --model_name "clip-vit-large-patch14-336" \
10 | --attachment_point "post_projection" \
11 | --layer "-1" \
12 | --dataset_name "inat" \
13 | --split "${SPLIT}" \
14 | --data_path "${DATASET_PATH}" \
15 | --num_workers 8 \
16 | --output_dir "./activations_dir/raw/inat_${SPLIT}_activations_clip-vit-large-patch14-336_-1_post_projection" \
17 | --cls_only \
18 | --save_every 100
19 | done
20 |
21 | # 2. Train SAE
22 | python sae_train.py \
23 | --sae_model "matroyshka_batch_top_k" \
24 | --activations_dir "activations_dir/raw/inat_train_activations_clip-vit-large-patch14-336_-1_post_projection" \
25 | --val_activations_dir "activations_dir/raw/inat_val_activations_clip-vit-large-patch14-336_-1_post_projection" \
26 | --checkpoints_dir "checkpoints_dir/matroyshka_batch_top_k_20_x2" \
27 | --expansion_factor 2 \
28 | --steps 110000 \
29 | --save_steps 20000 \
30 | --log_steps 10000 \
31 | --batch_size 4096 \
32 | --k 20 \
33 | --auxk_alpha 0.03 \
34 | --decay_start 109999 \
35 | --group_fractions 0.002 0.009 0.035 0.189 0.765
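   | # The five group fractions above define nested Matryoshka groups of increasing size;
   | # step 4 then measures, per group (level), the taxonomic LCA depth of each neuron's
   | # highest activating iNaturalist images.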
36 |
37 | # 3. Save SAE activations
38 | python save_activations.py \
39 | --batch_size 32 \
40 | --model_name "clip-vit-large-patch14-336" \
41 | --attachment_point "post_projection" \
42 | --layer "-1" \
43 | --dataset_name "inat" \
44 | --split "train" \
45 | --data_path "${DATASET_PATH}" \
46 | --num_workers 8 \
47 | --output_dir "./activations_dir/matroyshka_batch_top_k_20_x2/inat_train_activations_clip-vit-large-patch14-336_-1_post_projection" \
48 | --cls_only \
49 | --save_every 100 \
50 | --sae_model "matroyshka_batch_top_k" \
51 | --sae_path "./checkpoints_dir/matroyshka_batch_top_k_20_x2/inat_train_activations_clip-vit-large-patch14-336_-1_post_projection_matroyshka_batch_top_k_20_x2/trainer_0/checkpoints/ae_100000.pt"
52 |
53 | # 4. Compute LCA depth per level
54 | python find_hai_indices.py \
55 | --activations_dir "./activations_dir/matroyshka_batch_top_k_20_x2/inat_train_activations_clip-vit-large-patch14-336_-1_post_projection" \
56 | --dataset_name "inat" \
57 | --data_path "${DATASET_PATH}" \
58 | --split "train" \
59 | --k 16 \
60 | --chunk_size 1000
61 |
62 | python inat_depth.py \
63 | --activations_dir "./activations_dir/matroyshka_batch_top_k_20_x2/inat_train_activations_clip-vit-large-patch14-336_-1_post_projection" \
64 | --hai_indices_path "./activations_dir/matroyshka_batch_top_k_20_x2/inat_train_activations_clip-vit-large-patch14-336_-1_post_projection/hai_indices_16.npy" \
65 | --data_path "${DATASET_PATH}" \
66 | --split "train" \
67 | --k 16 \
68 | --group_fractions 0.002 0.009 0.035 0.189 0.765
69 |
--------------------------------------------------------------------------------
/scripts/mllm_steering.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATASET_PATH="${IMAGENET_PATH}"
4 |
5 | # 1. Save original activations
6 | for SPLIT in "train" "val"; do
7 | python save_activations.py \
8 | --batch_size 32 \
9 | --model_name "clip-vit-large-patch14-336" \
10 | --attachment_point "post_mlp_residual" \
11 | --layer 22 \
12 | --dataset_name "imagenet" \
13 | --split "${SPLIT}" \
14 | --data_path "${DATASET_PATH}" \
15 | --num_workers 8 \
16 | --output_dir "./activations_dir/raw/random_k_2/imagenet_${SPLIT}_activations_clip-vit-large-patch14-336_22_post_mlp_residual" \
17 | --random_k 2 \
18 | --save_every 100
19 | done
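   | # --random_k 2 stores two randomly chosen patch tokens per image as separate data
   | # points (see save_activations.py), so the layer-22 token activations stay tractable
   | # while still covering spatial tokens.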
20 |
21 | # 2. Train SAE
22 | python sae_train.py \
23 | --sae_model "matroyshka_batch_top_k" \
24 | --activations_dir "activations_dir/raw/random_k_2/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual" \
25 | --val_activations_dir "activations_dir/raw/random_k_2/imagenet_val_activations_clip-vit-large-patch14-336_22_post_mlp_residual" \
26 | --checkpoints_dir "checkpoints_dir/matroyshka_batch_top_k_20_x64/random_k_2/" \
27 | --expansion_factor 64 \
28 | --steps 110000 \
29 | --save_steps 20000 \
30 | --log_steps 10000 \
31 | --batch_size 4096 \
32 | --k 20 \
33 | --auxk_alpha 0.03 \
34 | --decay_start 109999 \
35 | --group_fractions 0.0625 0.125 0.25 0.5625
36 |
37 | # 3. Save SAE activations
38 | python save_activations.py \
39 | --batch_size 32 \
40 | --model_name "clip-vit-large-patch14-336" \
41 | --attachment_point "post_mlp_residual" \
42 | --layer 22 \
43 | --dataset_name "imagenet" \
44 | --split "train" \
45 | --data_path "${DATASET_PATH}" \
46 | --num_workers 8 \
47 | --output_dir "./activations_dir/matroyshka_batch_top_k_20_x64/mean_pool/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual" \
48 | --mean_pool \
49 | --save_every 100 \
50 | --sae_model "matroyshka_batch_top_k" \
51 | --sae_path "./checkpoints_dir/matroyshka_batch_top_k_20_x64/random_k_2/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual_matroyshka_batch_top_k_20_x64/trainer_0/checkpoints/ae_100000.pt"
52 |
53 | # 4. Visualize neurons
54 | python find_hai_indices.py \
55 | --activations_dir "./activations_dir/matroyshka_batch_top_k_20_x64/mean_pool/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual" \
56 | --dataset_name "imagenet" \
57 | --data_path "${DATASET_PATH}" \
58 | --split "train" \
59 | --k 16 \
60 | --chunk_size 1000
61 |
62 | python visualize_neurons.py \
63 | --output_dir "./activations_dir/matroyshka_batch_top_k_20_x64/mean_pool/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual" \
64 | --top_k 16 \
65 | --dataset_name "imagenet" \
66 | --data_path "${DATASET_PATH}" \
67 | --split "train" \
68 | --group_fractions 0.0625 0.125 0.25 0.5625 \
69 | --hai_indices_path "./activations_dir/matroyshka_batch_top_k_20_x64/mean_pool/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual/hai_indices_16.npy"
70 |
71 | # 5. Compute steering score
72 | python encode_images.py \
73 | --embeddings_path "embeddings_dir/imagenet_train_embeddings_clip-vit-base-patch32.pt" \
74 | --model_name "clip-vit-base-patch32" \
75 | --dataset_name "imagenet" \
76 | --split "train" \
77 | --data_path "${DATASET_PATH}" \
78 | --batch_size 128
79 |
80 | python imagenet_subset.py \
81 | --imagenet_root "${DATASET_PATH}" \
82 | --output_dir "./images_imagenet"
83 |
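   | # Two regimes follow, each run with and without steering as a baseline:
   | # (a) 1000 ImageNet images x the first 10 neurons, and
   | # (b) a single white image (images/white.png) x the first 1000 neurons.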
84 | # no steering (1000 images, 10 neurons)
85 | python steering_score.py \
86 | --hai_indices_path "./activations_dir/matroyshka_batch_top_k_20_x64/mean_pool/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual/hai_indices_16.npy" \
87 | --embeddings_path "./embeddings_dir/imagenet_train_embeddings_clip-vit-base-patch32.pt" \
88 | --sae_path "./checkpoints_dir/matroyshka_batch_top_k_20_x64/random_k_2/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual_matroyshka_batch_top_k_20_x64/trainer_0/checkpoints/ae_100000.pt" \
89 | --images_path "./images_imagenet/" \
90 | --no-pre_zero \
91 | --model_name "clip-vit-base-patch32" \
92 | --neuron_prefix 10 \
93 | --no-steer \
94 | --output_path "./llava_results_dir/1000/no_steering/"
95 |
96 | # steering (1000 images, 10 neurons)
97 | python steering_score.py \
98 | --hai_indices_path "./activations_dir/matroyshka_batch_top_k_20_x64/mean_pool/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual/hai_indices_16.npy" \
99 | --embeddings_path "./embeddings_dir/imagenet_train_embeddings_clip-vit-base-patch32.pt" \
100 | --sae_path "./checkpoints_dir/matroyshka_batch_top_k_20_x64/random_k_2/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual_matroyshka_batch_top_k_20_x64/trainer_0/checkpoints/ae_100000.pt" \
101 | --images_path "./images_imagenet/" \
102 | --no-pre_zero \
103 | --model_name "clip-vit-base-patch32" \
104 | --neuron_prefix 10 \
105 | --steer \
106 | --output_path "./llava_results_dir/1000/steering/"
107 |
108 | # steering (1 image, 1000 neurons)
109 | python steering_score.py \
110 | --hai_indices_path "./activations_dir/matroyshka_batch_top_k_20_x64/mean_pool/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual/hai_indices_16.npy" \
111 | --embeddings_path "./embeddings_dir/imagenet_train_embeddings_clip-vit-base-patch32.pt" \
112 | --sae_path "./checkpoints_dir/matroyshka_batch_top_k_20_x64/random_k_2/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual_matroyshka_batch_top_k_20_x64/trainer_0/checkpoints/ae_100000.pt" \
113 | --images_path "./images/" \
114 | --no-pre_zero \
115 | --model_name "clip-vit-base-patch32" \
116 | --neuron_prefix 1000 \
117 | --no-steer \
118 | --output_path "./llava_results_dir/1/no_steering/"
119 |
120 | python steering_score.py \
121 | --hai_indices_path "./activations_dir/matroyshka_batch_top_k_20_x64/mean_pool/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual/hai_indices_16.npy" \
122 | --embeddings_path "./embeddings_dir/imagenet_train_embeddings_clip-vit-base-patch32.pt" \
123 | --sae_path "./checkpoints_dir/matroyshka_batch_top_k_20_x64/random_k_2/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual_matroyshka_batch_top_k_20_x64/trainer_0/checkpoints/ae_100000.pt" \
124 | --images_path "./images/" \
125 | --no-pre_zero \
126 | --model_name "clip-vit-base-patch32" \
127 | --neuron_prefix 1000 \
128 | --steer \
129 | --output_path "./llava_results_dir/1/steering/"
130 |
131 | # 6. Compute baseline scores
132 | python similarity_baseline.py
133 |
134 | # 7. Find qualitative examples (e.g., steering the pencil neuron)
135 | python steering_qualitative.py
136 |
--------------------------------------------------------------------------------
/scripts/monosemanticity_score.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATASET="imagenet"
4 | DATASET_PATH="${IMAGENET_PATH}"
5 | LAYERS=("-1" 11 17 22 23)
6 | MODEL_NAME="clip-vit-large-patch14-336" # "siglip-so400m-patch14-384"
7 | EXPANSION_FACTORS=(64 16 8 4 2 1)
8 | VISION_ENCODER="clip-vit-base-patch32" # "dinov2-base"
9 |
10 | # 1. Save original activations
11 | for LAYER in "${LAYERS[@]}"; do
12 | if [ "${LAYER}" == "-1" ]; then
13 | POINT="post_projection"
14 | else
15 | POINT="post_mlp_residual"
16 | fi
17 | for SPLIT in "train" "val"; do
18 | python save_activations.py \
19 | --batch_size 32 \
20 | --model_name "${MODEL_NAME}" \
21 | --attachment_point "${POINT}" \
22 | --layer "${LAYER}" \
23 | --dataset_name "${DATASET}" \
24 | --split "${SPLIT}" \
25 | --data_path "${DATASET_PATH}" \
26 | --num_workers 8 \
27 | --output_dir "./activations_dir/raw/${DATASET}_${SPLIT}_activations_${MODEL_NAME}_${LAYER}_${POINT}" \
28 | --cls_only \
29 | --save_every 100
30 | done
31 | done
32 |
33 | # 2. Train SAE
34 | for LAYER in "${LAYERS[@]}"; do
35 | if [ "${LAYER}" == "-1" ]; then
36 | POINT="post_projection"
37 | else
38 | POINT="post_mlp_residual"
39 | fi
40 | for EXPANSION_FACTOR in "${EXPANSION_FACTORS[@]}"; do
41 | python sae_train.py \
42 | --sae_model "matroyshka_batch_top_k" \
43 | --activations_dir "activations_dir/raw/${DATASET}_train_activations_${MODEL_NAME}_${LAYER}_${POINT}" \
44 | --val_activations_dir "activations_dir/raw/${DATASET}_val_activations_${MODEL_NAME}_${LAYER}_${POINT}" \
45 | --checkpoints_dir "checkpoints_dir/matroyshka_batch_top_k_20_x${EXPANSION_FACTOR}" \
46 | --expansion_factor "${EXPANSION_FACTOR}" \
47 | --steps 110000 \
48 | --save_steps 20000 \
49 | --log_steps 10000 \
50 | --batch_size 4096 \
51 | --k 20 \
52 | --auxk_alpha 0.03 \
53 | --decay_start 109999 \
54 | --group_fractions 0.0625 0.125 0.25 0.5625
55 |
56 | python sae_train.py \
57 | --sae_model "batch_top_k" \
58 | --activations_dir "activations_dir/raw/${DATASET}_train_activations_${MODEL_NAME}_${LAYER}_${POINT}" \
59 | --val_activations_dir "activations_dir/raw/${DATASET}_val_activations_${MODEL_NAME}_${LAYER}_${POINT}" \
60 | --checkpoints_dir "checkpoints_dir/batch_top_k_20_x${EXPANSION_FACTOR}" \
61 | --expansion_factor "${EXPANSION_FACTOR}" \
62 | --steps 110000 \
63 | --save_steps 20000 \
64 | --log_steps 10000 \
65 | --batch_size 4096 \
66 | --k 20 \
67 | --auxk_alpha 0.03 \
68 | --decay_start 109999
69 | done
70 | done
71 |
72 | # 3. Save SAE activations
73 | for LAYER in "${LAYERS[@]}"; do
74 | if [ "${LAYER}" == "-1" ]; then
75 | POINT="post_projection"
76 | else
77 | POINT="post_mlp_residual"
78 | fi
79 | for EXPANSION_FACTOR in "${EXPANSION_FACTORS[@]}"; do
80 | for SAE_MODEL in "matroyshka_batch_top_k" "batch_top_k"; do
81 | python save_activations.py \
82 | --batch_size 32 \
83 | --model_name "${MODEL_NAME}" \
84 | --attachment_point "${POINT}" \
85 | --layer "${LAYER}" \
86 | --dataset_name "${DATASET}" \
87 | --split "val" \
88 | --data_path "${DATASET_PATH}" \
89 | --num_workers 8 \
90 | --output_dir "./activations_dir/${SAE_MODEL}_20_x${EXPANSION_FACTOR}/${DATASET}_val_activations_${MODEL_NAME}_${LAYER}_${POINT}" \
91 | --cls_only \
92 | --save_every 100 \
93 | --sae_model "${SAE_MODEL}" \
94 | --sae_path "./checkpoints_dir/${SAE_MODEL}_20_x${EXPANSION_FACTOR}/${DATASET}_train_activations_${MODEL_NAME}_${LAYER}_${POINT}_${SAE_MODEL}_20_x${EXPANSION_FACTOR}/trainer_0/checkpoints/ae_100000.pt"
95 | done
96 | done
97 | done
98 |
99 | # 4. Save vision encoder embeddings
100 | python encode_images.py \
101 | --embeddings_path "embeddings_dir/${DATASET}_val_embeddings_${VISION_ENCODER}.pt" \
102 | --model_name "${VISION_ENCODER}" \
103 | --dataset_name "${DATASET}" \
104 | --split "val" \
105 | --data_path "${DATASET_PATH}" \
106 | --batch_size 128
107 |
108 | # 5. Compute Monosemanticity Score
    | # Path written by encode_images.py in step 4
    | EMBEDDINGS_PATH="embeddings_dir/${DATASET}_val_embeddings_${VISION_ENCODER}.pt"
109 | for LAYER in "${LAYERS[@]}"; do
110 | if [ "${LAYER}" == "-1" ]; then
111 | POINT="post_projection"
112 | else
113 | POINT="post_mlp_residual"
114 | fi
115 | # SAE neurons
116 | for EXPANSION_FACTOR in "${EXPANSION_FACTORS[@]}"; do
117 | for SAE_MODEL in "matroyshka_batch_top_k" "batch_top_k"; do
118 | python metric.py \
119 | --activations_dir "activations_dir/${SAE_MODEL}_20_x${EXPANSION_FACTOR}/${DATASET}_val_activations_${MODEL_NAME}_${LAYER}_${POINT}" \
120 | --embeddings_path ${EMBEDDINGS_PATH} \
121 | --output_subdir "ms_${VISION_ENCODER}"
122 | done
123 | done
124 | # original neurons
125 | python metric.py \
126 | --activations_dir "activations_dir/raw/${DATASET}_val_activations_${MODEL_NAME}_${LAYER}_${POINT}" \
127 | --embeddings_path ${EMBEDDINGS_PATH} \
128 | --output_subdir "ms_${VISION_ENCODER}"
129 | done
130 |
131 | # 6. Visualize neurons (of selected SAE, using training set)
132 | python save_activations.py \
133 | --batch_size 32 \
134 | --model_name "${MODEL_NAME}" \
135 | --attachment_point "post_projection" \
136 | --layer "-1" \
137 | --dataset_name "${DATASET}" \
138 | --split "train" \
139 | --data_path "${DATASET_PATH}" \
140 | --num_workers 8 \
141 | --output_dir "./activations_dir/matroyshka_batch_top_k_20_x4/${DATASET}_train_activations_${MODEL_NAME}_-1_post_projection" \
142 | --cls_only \
143 | --save_every 100 \
144 | --sae_model "matroyshka_batch_top_k" \
145 | --sae_path "./checkpoints_dir/matroyshka_batch_top_k_20_x4/${DATASET}_train_activations_${MODEL_NAME}_-1_post_projection_matroyshka_batch_top_k_20_x4/trainer_0/checkpoints/ae_100000.pt"
146 |
147 | python find_hai_indices.py \
148 | --activations_dir "./activations_dir/matroyshka_batch_top_k_20_x4/${DATASET}_train_activations_${MODEL_NAME}_-1_post_projection" \
149 | --dataset_name "${DATASET}" \
150 | --data_path "${DATASET_PATH}" \
151 | --split "train" \
152 | --k 16 \
153 | --chunk_size 1000
154 |
155 | python visualize_neurons.py \
156 | --output_dir "./activations_dir/matroyshka_batch_top_k_20_x4/${DATASET}_train_activations_${MODEL_NAME}_-1_post_projection" \
157 | --top_k 16 \
158 | --dataset_name "${DATASET}" \
159 | --data_path "${DATASET_PATH}" \
160 | --split "train" \
161 | --group_fractions 0.0625 0.125 0.25 0.5625 \
162 | --hai_indices_path "./activations_dir/matroyshka_batch_top_k_20_x4/${DATASET}_train_activations_${MODEL_NAME}_-1_post_projection/hai_indices_16.npy"
163 |
--------------------------------------------------------------------------------
/steering_qualitative.py:
--------------------------------------------------------------------------------
1 | from models.llava import Llava
2 | from dictionary_learning.trainers import MatroyshkaBatchTopKSAE
3 | from PIL import Image
4 | import requests
5 |
6 | llava = Llava("cuda")
7 | sae_path = "checkpoints_dir/matroyshka_batch_top_k_20_x64/random_k_2/imagenet_train_activations_clip-vit-large-patch14-336_22_post_mlp_residual_matroyshka_batch_top_k_20_x64/trainer_0/checkpoints/ae_100000.pt"
8 | sae = MatroyshkaBatchTopKSAE.from_pretrained(sae_path).cuda()
9 | text = "Write me a short love poem"
10 | url = "https://img.freepik.com/free-photo/cement-texture_1194-5269.jpg?semt=ais_hybrid"
11 | image = Image.open(requests.get(url, stream=True).raw)
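   | # A minimal qualitative sweep: clamp one SAE latent (the "pencil" neuron) to
   | # increasing values of alpha while LLaVA answers an unrelated prompt about a plain
   | # cement-texture image, and watch the clamped concept surface in the output.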
12 | for neuron in [39]: # Note: ID of the pencil neuron may change after retraining SAE
13 | for alpha in [0, 30, 40, 50]:
14 | print(f"neuron {neuron}, alpha {alpha}")
15 | llava.attach_and_fix(sae=sae, neurons_to_fix={neuron: alpha}, pre_zero=False)
16 | output = llava.prompt(text, image, max_tokens=60)[0]
17 | print(output)
18 | print("======================================================")
--------------------------------------------------------------------------------
/steering_score.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tqdm
3 | import os
4 | import torch
5 | import torch.nn.functional as F
6 | from models.llava import Llava
7 | from utils import IdentitySAE, get_text_model
8 | import argparse
9 | from dictionary_learning.trainers import MatroyshkaBatchTopKSAE
10 | from PIL import Image
11 | import random
12 |
13 | def parse_args():
14 | parser = argparse.ArgumentParser(description="Compute CLIP-based score for steering accuracy")
15 | parser.add_argument('--hai_indices_path', type=str, required=True)
16 | parser.add_argument('--embeddings_path', type=str, required=True)
17 | parser.add_argument('--sae_path', type=str, default=None)
18 | parser.add_argument('--images_path', type=str, required=True)
19 | parser.add_argument("--model_name", type=str, required=True)
20 | parser.add_argument("--device", type=str, default="cuda")
21 | parser.add_argument("--pre_zero", action=argparse.BooleanOptionalAction)
22 | parser.add_argument('--output_path', type=str, required=True)
23 | parser.add_argument('--neuron_prefix', type=int, default=None)
24 | parser.add_argument("--steer", action=argparse.BooleanOptionalAction)
25 |
26 | return parser.parse_args()
27 |
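   | # Sketch of the scoring pipeline implemented below:
   | #   1. For each evaluated neuron, optionally clamp it (--steer) through the SAE
   | #      attached to LLaVA's vision encoder and ask LLaVA for a one-word label per image.
   | #   2. Embed the labels with the CLIP text encoder.
   | #   3. Score each label against the mean CLIP image embedding of the neuron's highest
   | #      activating images (HAI) via cosine similarity; report the mean and std.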
28 | if __name__ == "__main__":
29 | # Parse command line arguments
30 | args = parse_args()
31 | args.batch_size = 1 # not used
32 | args.num_workers = 0 # not used
33 |
34 | # Get HAI indices
35 | hai_indices = torch.from_numpy(np.load(args.hai_indices_path)).to(args.device)
36 | print(f"Loaded HAI indices found at {args.hai_indices_path}")
37 | print(f"hai_indices shape: {hai_indices.shape}")
38 |
39 | # Get image embeddings
40 | embeddings = torch.load(args.embeddings_path).to(args.device)
41 | print(f"Loaded embeddings found at {args.embeddings_path}")
42 | print(f"embeddings shape: {embeddings.shape}")
43 |
44 | # Compute mean image embedding of HAI per neuron
45 | hai_embeddings = embeddings[hai_indices]
46 | print(f"hai_embeddings shape: {hai_embeddings.shape}") # (num_neurons, k, embedding_dim)
47 | hai_embeddings = hai_embeddings.mean(dim=1)
48 | print(f"hai_embeddings shape: {hai_embeddings.shape}") # (num_neurons, embedding_dim)
49 |
50 | # Load LLaVA model
51 | llava = Llava(args.device)
52 | if args.sae_path:
53 | sae = MatroyshkaBatchTopKSAE.from_pretrained(args.sae_path).to(args.device)
54 | print(f"Attached SAE from {args.sae_path}")
55 | else:
56 | sae = IdentitySAE()
57 | print(f"Attached Identity SAE (scoring original neurons)")
58 |
59 | # Filter neurons if prefix is given
60 | num_neurons = hai_embeddings.shape[0]
61 | if args.neuron_prefix:
62 | neuron_indices = list(range(args.neuron_prefix))
63 | else:
64 | neuron_indices = list(range(num_neurons))
65 | print(f"Evaluating on {len(neuron_indices)} neurons")
66 |
67 | # Label images while clamping each neuron
68 |     image_files = [f for f in os.listdir(args.images_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
69 | print(f"Found {len(image_files)} images in {args.images_path}")
70 | text = "What is shown in this image? Use exactly one word!"
71 | labels = {neuron: [] for neuron in neuron_indices}
72 | if args.steer:
73 | print("Steering")
74 | else:
75 | print("Not steering")
76 | for neuron in tqdm.tqdm(neuron_indices, desc="Processing neurons"):
77 | if args.steer:
78 | llava.attach_and_fix(sae=sae, neurons_to_fix={neuron: 100}, pre_zero=args.pre_zero)
79 | for image_file in image_files:
80 | image_path = os.path.join(args.images_path, image_file)
81 | image = Image.open(image_path)
82 | label = llava.prompt(text, image, max_tokens=5)[0].split(" ")[0]
83 | labels[neuron].append((image_file, label))
84 |
85 | # Save labels
86 | labels_path = os.path.join(os.path.dirname(args.output_path), "labels.txt")
87 | os.makedirs(os.path.dirname(labels_path), exist_ok=True)
88 | with open(labels_path, "w") as f:
89 | for neuron, image_labels in labels.items():
90 | for image_file, label in image_labels:
91 | f.write(f"{neuron},{image_file},{label}\n")
92 | print(f"Labels saved to {labels_path}")
93 |
94 | # Compute text embeddings for neuron-clamped LLaVA labels
95 | text_encoder, tokenizer = get_text_model(args)
96 | label_embeddings = torch.zeros(len(neuron_indices), len(image_files), hai_embeddings.shape[1]).to(args.device)
97 | for i, (neuron, image_labels) in tqdm.tqdm(enumerate(labels.items()), desc="Computing text embeddings"):
98 | for j, label in enumerate(image_labels):
99 | with torch.no_grad():
100 | inputs = tokenizer([label[1]], padding=True, return_tensors="pt").to(args.device)
101 | outputs = text_encoder(**inputs)
102 | label_embeddings[i, j] = outputs.text_embeds
103 |
104 | # Compute cosine similarities
105 | cosine_similarities = []
106 | for i in range(label_embeddings.shape[1]):
107 | cosine_similarities.append(F.cosine_similarity(hai_embeddings[neuron_indices], label_embeddings[:, i], dim=1))
108 | cosine_similarities = torch.cat(cosine_similarities)
109 | print(cosine_similarities.shape)
110 | torch.save(cosine_similarities, os.path.join(os.path.dirname(args.output_path), "scores_per_neuron"))
111 | mean_cosine_similarity = cosine_similarities.mean().item()
112 | std_cosine_similarity = cosine_similarities.std().item()
113 |
114 | print("Mean Cosine Similarity:", mean_cosine_similarity)
115 | print("Standard Deviation Cosine Similarity:", std_cosine_similarity)
116 |
117 | # Save results
118 | with open(os.path.join(os.path.dirname(args.output_path), "metric.txt"), "w") as f:
119 | f.write(f"Mean Cosine Similarity: {mean_cosine_similarity}\n")
120 | f.write(f"Standard Deviation Cosine Similarity: {std_cosine_similarity}\n")
121 |
122 | print("Done")
--------------------------------------------------------------------------------
/uniqueness.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import os
4 | from itertools import combinations
5 | from tqdm import tqdm # Progress bar
6 |
7 |
8 | def parse_args():
9 | parser = argparse.ArgumentParser(description="Measure uniqueness of neurons with pairwise Jaccard Index")
10 | parser.add_argument('--activations_dir', type=str, required=True)
11 | parser.add_argument('--k', type=int, default=16)
12 | return parser.parse_args()
13 |
14 |
15 | def jaccard_index(set1, set2):
16 | """Compute Jaccard index between two sets."""
17 | intersection = len(set1.intersection(set2))
18 | union = len(set1.union(set2))
19 | return intersection / union if union != 0 else 0
20 |
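   | # Example: Jaccard({1, 2, 3, 4}, {3, 4, 5, 6}) = |{3, 4}| / |{1, 2, 3, 4, 5, 6}| = 2/6.
   | # A value close to 1 means two neurons share nearly the same highest activating images.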
21 |
22 | if __name__ == "__main__":
23 | args = parse_args()
24 |
25 | hai_indices_path = os.path.join(args.activations_dir, f"hai_indices_{args.k}.npy")
26 | worst_scores_path = os.path.join(args.activations_dir, f"hai_indices_{args.k}_worst.npy")
27 |
28 | hai_indices = np.load(hai_indices_path) # (num_neurons, k)
29 | worst_scores = np.load(worst_scores_path) # (num_neurons)
30 |
31 | print(f"Loaded HAI indices from {hai_indices_path}")
32 | print(f"Loaded worst scores from {worst_scores_path}")
33 |
34 |     # A zero "worst" score means the neuron is dead (or almost dead); filter those out
35 |     mask = worst_scores != 0
36 | hai_indices = hai_indices[mask]
37 |
38 | print(f"Removed {np.count_nonzero(~mask)} dead (or almost dead) neurons")
39 | print(f"Remaining neurons: {hai_indices.shape[0]}")
40 |
41 | # Compute pairwise Jaccard index
42 | num_neurons = hai_indices.shape[0]
43 | jaccard_scores = []
44 | index_pairs = []
45 |
46 | total_pairs = (num_neurons * (num_neurons - 1)) // 2 # Number of unique pairs
47 |
48 | for i, j in tqdm(combinations(range(num_neurons), 2), total=total_pairs, desc="Computing Jaccard Index"):
49 | set1, set2 = set(hai_indices[i]), set(hai_indices[j])
50 | jaccard = jaccard_index(set1, set2)
51 | jaccard_scores.append(jaccard)
52 | index_pairs.append((i, j))
53 |
54 | jaccard_scores = np.array(jaccard_scores)
55 |
56 | # Sort Jaccard scores
57 | sorted_indices = np.argsort(jaccard_scores) # Ascending order
58 |
59 | # Extract top and bottom 10 index pairs
60 | top_10 = [(index_pairs[i], jaccard_scores[i]) for i in sorted_indices[-10:]] # Top 10 highest Jaccard
61 | bottom_10 = [(index_pairs[i], jaccard_scores[i]) for i in sorted_indices[:10]] # Bottom 10 lowest Jaccard
62 |
63 | print(f"Total pairs: {total_pairs}")
64 | # Count how many have Jaccard Index > 0.1 * i
65 | for i in np.arange(0, 1, 0.1):
66 | high_similarity_count = np.count_nonzero(jaccard_scores > i)
67 | print(f"Pairs with Jaccard index > {i:.1f}: {high_similarity_count}")
68 | print(f"Ratio: {high_similarity_count / total_pairs:.4f}")
69 |
70 | print(f"Top 10 Indexes (Most Similar Pairs): {top_10}")
71 | print(f"Bottom 10 Indexes (Most Unique Pairs): {bottom_10}")
72 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader, Subset
2 | from torchvision.datasets import ImageNet, ImageFolder
3 | import torch.nn as nn
4 | from models.clip import Clip
5 | from models.dino import Dino
6 | from models.siglip import Siglip
7 | import os
8 | from transformers import AutoTokenizer, CLIPTextModelWithProjection
9 |
10 | def get_collate_fn(processor):
11 | def collate_fn(batch):
12 | images = [img[0] for img in batch]
13 | return processor(images=images, return_tensors="pt", padding=True)
14 | return collate_fn
15 |
16 | def get_dataset(args, preprocess, processor, split, subset=1.0):
17 | if args.dataset_name == 'cc3m':
18 | # if subset < 1.0:
19 | # raise NotImplementedError
20 | # return get_cc3m(args, preprocess, split)
21 | raise NotImplementedError
22 | elif args.dataset_name == 'inat_birds':
23 | ds = ImageFolder(root=os.path.join(args.data_path, split), transform=preprocess)
24 | elif args.dataset_name == 'inat':
25 | ds = ImageFolder(root=os.path.join(args.data_path, split), transform=preprocess)
26 | elif args.dataset_name == 'imagenet':
27 | ds = ImageNet(root=args.data_path, split=split, transform=preprocess)
28 | elif args.dataset_name == 'cub':
29 |         ds = ImageFolder(root=os.path.join(args.data_path, split), transform=preprocess)
   |     else:
   |         raise ValueError(f"Unknown dataset: {args.dataset_name}")
30 |
31 | keep_every = int(1.0 / subset)
32 | if keep_every > 1:
33 | ds = Subset(ds, list(range(0, len(ds), keep_every)))
34 | if processor is not None:
35 | dl = DataLoader(ds, batch_size=args.batch_size, shuffle=False,
36 | num_workers=args.num_workers, collate_fn=get_collate_fn(processor))
37 | else:
38 | dl = DataLoader(ds, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
39 |
40 | return ds, dl
41 |
42 | def get_model(args):
43 | if args.model_name.startswith('clip'):
44 | clip = Clip(args.model_name, args.device)
45 | return clip, clip.processor
46 | elif args.model_name.startswith('dino'):
47 | dino = Dino(args.model_name, args.device)
48 | return dino, dino.processor
49 | elif args.model_name.startswith('siglip'):
50 | siglip = Siglip(args.model_name, args.device)
51 | return siglip, siglip.processor
52 |
53 | def get_text_model(args):
54 | if args.model_name.startswith('clip'):
55 | model = CLIPTextModelWithProjection.from_pretrained(f"openai/{args.model_name}").to(args.device)
56 | tokenizer = AutoTokenizer.from_pretrained(f"openai/{args.model_name}")
57 |         return model, tokenizer
   |     raise NotImplementedError(f"No text encoder implemented for {args.model_name}")
58 |
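   | # Pass-through stand-in with the SAE encode/decode interface; steering_score.py uses
   | # it to score the original (non-SAE) neurons through the same code path.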
59 | class IdentitySAE(nn.Module):
60 | def encode(self, x):
61 | return x
62 | def decode(self, x):
63 | return x
64 |
--------------------------------------------------------------------------------
/visualize_neurons.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from matplotlib import pyplot as plt
3 | from PIL import Image
4 | import os
5 | from torchvision import transforms
6 | from utils import get_dataset
7 | import argparse
8 | from math import isclose
9 |
10 |
11 | def image_grid(imgs, rows, cols):
12 | assert len(imgs) == rows * cols, "Number of images must match rows * cols."
13 | w, h = imgs[0].size
14 | grid = Image.new("RGB", size=(cols * w, rows * h))
15 | for i, img in enumerate(imgs):
16 | grid.paste(img, box=(i % cols * w, i // cols * h))
17 | return grid
18 |
19 |
20 | def parse_args():
21 | parser = argparse.ArgumentParser(description="Visualize top-k activating images for neurons.")
22 | parser.add_argument('--output_dir', type=str, required=True)
23 | parser.add_argument('--top_k', type=int, default=16)
24 | parser.add_argument("--dataset_name", default="imagenet", type=str)
25 | parser.add_argument("--data_path", default="/shared-network/inat2021", type=str)
26 | parser.add_argument('--split', type=str, default='train')
27 | parser.add_argument('--visualization_size', type=int, default=224)
28 | parser.add_argument('--group_fractions', type=float, nargs='+')
29 | parser.add_argument('--hai_indices_path', type=str)
30 | return parser.parse_args()
31 |
32 |
33 | if __name__ == "__main__":
34 | # Parse command line arguments
35 | args = parse_args()
36 | args.batch_size = 1 # not used
37 | args.num_workers = 0 # not used
38 |
39 | importants = np.load(args.hai_indices_path)
40 | print(f"Loaded HAI indices found at {args.hai_indices_path}", flush=True)
41 | num_neurons = importants.shape[0]
42 |
43 | # Visualize selected images
44 | def _convert_to_rgb(image):
45 | return image.convert("RGB")
46 |
47 | visualization_preprocess = transforms.Compose([
48 | transforms.Resize(size=224, interpolation=Image.BICUBIC),
49 | transforms.CenterCrop(size=(224, 224)),
50 | _convert_to_rgb,
51 | ])
52 |
53 | ds, dl = get_dataset(args, preprocess=visualization_preprocess, processor=None, split=args.split, subset=1)
54 |
55 | os.makedirs(os.path.join(args.output_dir, 'hai'), exist_ok=True)
56 |
57 | assert isclose(sum(args.group_fractions), 1.0), "group_fractions must sum to 1.0"
58 | group_sizes = [int(f * num_neurons) for f in args.group_fractions[:-1]]
59 | group_sizes.append(num_neurons - sum(group_sizes))
60 |
61 | start_idx = 0
62 | for group_idx, group_size in enumerate(group_sizes):
63 | end_idx = start_idx + group_size
64 | group_neurons = range(start_idx, end_idx)
65 |
66 | for neuron_id, absolute_id in enumerate(group_neurons[:5000]):
67 | print(f"Visualizing neuron {neuron_id} (absolute {absolute_id}) in group {group_idx}", flush=True)
68 |
69 | important = importants[absolute_id]
70 | images = [ds[i][0] for i in important]
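   |             # top_k is assumed to be a perfect square (e.g. 16 -> 4x4), otherwise the
   |             # rows * cols assertion in image_grid fails for the s x s grid below.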
71 | s = int(np.sqrt(args.top_k))
72 | grid_image = image_grid(images[::-1], rows=s, cols=s)
73 |
74 | plt.imshow(grid_image)
75 | plt.axis('off')
76 | filename = f"group_{group_idx}_neuron_{neuron_id}_absolute_{absolute_id}.png"
77 | output_path = os.path.join(args.output_dir, 'tree', filename)
78 | os.makedirs(os.path.dirname(output_path), exist_ok=True)
79 | plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
80 | plt.close() # Close the plot to free memory
81 |
82 | start_idx = end_idx
83 |
--------------------------------------------------------------------------------