├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── NOTICE
├── README.md
├── analysis
│   ├── __init__.py
│   ├── analysis.ipynb
│   └── utils.py
├── assets
│   ├── demo_overview.png
│   ├── patchsae.gif
│   ├── ref_1.gif
│   ├── ref_2.gif
│   ├── ref_3.gif
│   ├── ref_imgs_munich.gif
│   └── sae_arch.gif
├── configs
│   ├── classnames
│   │   ├── caltech101_classnames.txt
│   │   ├── cifar100_cat_to_name.json
│   │   ├── food101_cat_to_name.json
│   │   ├── imagenet-sketch_classnames.txt
│   │   ├── imagenet_classnames.txt
│   │   ├── oxford_flowers_classnames.json
│   │   └── pet_cat_to_name.json
│   └── models
│       └── maple
│           └── vit_b16_c2_ep5_batch4_2ctx.yaml
├── demo.ipynb
├── requirements.txt
├── scripts
│   ├── 01_run_train.sh
│   ├── 02_run_compute_feature_data.sh
│   ├── 03_run_class_level.sh
│   └── 04_run_topk_eval.sh
├── src
│   ├── demo
│   │   ├── app.py
│   │   ├── core.py
│   │   └── utils.py
│   ├── models
│   │   ├── architecture
│   │   │   └── maple.py
│   │   ├── clip
│   │   │   ├── __init__.py
│   │   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   │   ├── clip.py
│   │   │   ├── model.py
│   │   │   └── simple_tokenizer.py
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   ├── default_config.py
│   │   │   └── maple.py
│   │   ├── templates
│   │   │   └── openai_imagenet_templates.py
│   │   └── utils.py
│   └── sae_training
│       ├── config.py
│       ├── hooked_vit.py
│       ├── sae_trainer.py
│       ├── sparse_autoencoder.py
│       ├── utils.py
│       └── vit_activations_store.py
└── tasks
    ├── README.md
    ├── classification_with_top_k_masking.py
    ├── compute_class_wise_sae_activation.py
    ├── compute_sae_feature_data.py
    ├── train_sae_vit.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | data/
3 | .vscode/
4 | out/
5 | .DS_Store
6 | *.zip
7 | wandb/
8 | logs/
9 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v5.0.0
4 | hooks:
5 | - id: end-of-file-fixer
6 | - id: trailing-whitespace
7 | - id: check-yaml
8 | - id: check-executables-have-shebangs
9 | - id: check-toml
10 | - id: check-added-large-files
11 | args: [--maxkb=15000]
12 | - id: check-case-conflict
13 | - id: check-merge-conflict
14 | - id: check-symlinks
15 | - id: debug-statements
16 | - id: detect-private-key
17 | - id: mixed-line-ending
18 | args: [--fix=lf]
19 | - id: requirements-txt-fixer
20 | - repo: https://github.com/astral-sh/ruff-pre-commit
21 | rev: v0.11.6
22 | hooks:
23 | - id: ruff
24 | args: [ --fix, --select=I ]
25 | - id: ruff-format
26 |
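# Usage note: after installing pre-commit (`pip install pre-commit`), register the
# hooks once with `pre-commit install`; run them against the whole repository with
# `pre-commit run --all-files`.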
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024-2025 Hyesu Lim, Jinho Choi, Jaegul Choo, Steffen Schneider
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Code from the following third-party repositories was used.
2 | Please include the contents of the LICENSE as well as this NOTICE file in all
3 | re-distributions of this code.
4 |
5 | =================================================================================
6 |
7 | https://github.com/muzairkhattak/multimodal-prompt-learning
8 |
9 | MIT License
10 |
11 | Copyright (c) 2022 Muhammad Uzair Khattak
12 | Copyright (c) 2021 Kaiyang Zhou
13 |
14 | Permission is hereby granted, free of charge, to any person obtaining a copy
15 | of this software and associated documentation files (the "Software"), to deal
16 | in the Software without restriction, including without limitation the rights
17 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18 | copies of the Software, and to permit persons to whom the Software is
19 | furnished to do so, subject to the following conditions:
20 |
21 | The above copyright notice and this permission notice shall be included in all
22 | copies or substantial portions of the Software.
23 |
24 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30 | SOFTWARE.
31 |
32 | =================================================================================
33 |
34 | https://github.com/muzairkhattak/PromptSRC
35 |
36 | MIT License
37 | Copyright (c) 2023 Muhammad Uzair Khattak
38 | Copyright (c) 2022 Muhammad Uzair Khattak
39 | Copyright (c) 2021 Kaiyang Zhou
40 |
41 | Permission is hereby granted, free of charge, to any person obtaining a copy
42 | of this software and associated documentation files (the "Software"), to deal
43 | in the Software without restriction, including without limitation the rights
44 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
45 | copies of the Software, and to permit persons to whom the Software is
46 | furnished to do so, subject to the following conditions:
47 |
48 | The above copyright notice and this permission notice shall be included in all
49 | copies or substantial portions of the Software.
50 |
51 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
52 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
53 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
54 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
55 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
56 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
57 | SOFTWARE.
58 |
59 | =================================================================================
60 |
61 | https://github.com/HugoFry/mats_sae_training_for_ViTs
62 |
63 | MIT License
64 |
65 | Copyright (c) 2023 Joseph Bloom
66 |
67 | Permission is hereby granted, free of charge, to any person obtaining a copy
68 | of this software and associated documentation files (the "Software"), to deal
69 | in the Software without restriction, including without limitation the rights
70 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
71 | copies of the Software, and to permit persons to whom the Software is
72 | furnished to do so, subject to the following conditions:
73 |
74 | The above copyright notice and this permission notice shall be included in all
75 | copies or substantial portions of the Software.
76 |
77 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
78 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
79 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
80 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
81 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
82 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
83 | SOFTWARE.
84 |
85 | =================================================================================
86 |
87 | https://github.com/jbloomAus/SAELens
88 |
89 | MIT License
90 |
91 | Copyright (c) 2023 Joseph Bloom
92 |
93 | Permission is hereby granted, free of charge, to any person obtaining a copy
94 | of this software and associated documentation files (the "Software"), to deal
95 | in the Software without restriction, including without limitation the rights
96 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
97 | copies of the Software, and to permit persons to whom the Software is
98 | furnished to do so, subject to the following conditions:
99 |
100 | The above copyright notice and this permission notice shall be included in all
101 | copies or substantial portions of the Software.
102 |
103 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
104 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
105 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
106 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
107 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
108 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
109 | SOFTWARE.
110 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PatchSAE: Sparse Autoencoders Reveal Selective Remapping of Visual Concepts During Adaptation
2 |
3 | [Project Page](https://dynamical-inference.ai/patchsae/)
4 | [arXiv](https://arxiv.org/abs/2412.05276)
5 | [OpenReview](https://openreview.net/forum?id=imT03YXlG2)
6 | [Hugging Face Demo](https://huggingface.co/spaces/dynamical-inference/patchsae-demo)
7 |
8 |
9 |
10 |
11 |
12 | ## 🚀 Quick Navigation
13 |
14 | - [Getting Started](#-getting-started)
15 | - [Interactive Demo](#-interactive-demo)
16 | - [Training & Analysis](#-patchsae-training-and-analysis)
17 | - [Status Updates](#-status-updates)
18 | - [License & Credits](#-license--credits)
19 |
20 | ## 🛠 Getting Started
21 |
22 | Set up your environment with these simple steps:
23 |
24 | ```bash
25 | # Create and activate environment
26 | conda create --name patchsae python=3.12
27 | conda activate patchsae
28 |
29 | # Install dependencies
30 | pip install -r requirements.txt
31 |
32 | # Always set PYTHONPATH before running any scripts
33 | cd patchsae
34 | PYTHONPATH=./ python src/demo/app.py
35 | ```
36 |
37 | ## 🎮 Interactive Demo
38 |
39 | ### Online Demo on Hugging Face 🤗 ([open the demo](https://huggingface.co/spaces/dynamical-inference/patchsae-demo))
40 |
41 | Explore our pre-computed images and SAE latents without any installation!
42 | > 💡 The demo may experience slowdowns due to network constraints. For optimal performance, consider disabling your VPN if you encounter any delays.
43 |
44 |
45 |
46 |
47 |
48 |
49 | ### Local Demo: Try Your Own Images
50 |
51 | Want to experiment with your own images? Follow these steps:
52 |
53 | #### 1. Setup Local Demo
54 |
55 | First, download the necessary files:
56 |
57 | - [out.zip](https://drive.google.com/file/d/1NJzF8PriKz_mopBY4l8_44R0FVi2uw2g/edit)
58 | - [data.zip](https://drive.google.com/file/d/1reuDjXsiMkntf1JJPLC5a3CcWuJ6Ji3Z/edit)
59 |
60 | You can download the files using `gdown` as follows:
61 |
62 | ```bash
63 | # Activate environment first (see Getting Started)
64 |
65 | # Download necessary files (35MB + 513MB)
66 | gdown --id 1NJzF8PriKz_mopBY4l8_44R0FVi2uw2g # out.zip
67 | gdown --id 1reuDjXsiMkntf1JJPLC5a3CcWuJ6Ji3Z # data.zip
68 |
69 | # Extract files
70 | unzip data.zip
71 | unzip out.zip
72 | ```
73 |
74 | > 💡 Need `gdown`? Install it with: `conda install conda-forge::gdown`
75 |
76 | Your folder structure should look like:
77 |
78 | ```
79 | patchsae/
80 | ├── configs/
81 | ├── data/ # From data.zip
82 | ├── out/ # From out.zip
83 | ├── src/
84 | │ └── demo/
85 | │ └── app.py
86 | ├── tasks/
87 | ├── requirements.txt
88 | └── ... (other files)
89 | ```
90 |
91 | #### 2. Launch the Demo
92 |
93 | ```bash
94 | PYTHONPATH=./ python src/demo/app.py
95 | ```
96 |
97 | ⚠️ **Note**:
98 | - The first run will automatically download datasets from Hugging Face (about 30 GB in total)
99 | - The demo runs on CPU by default
100 | - Access the interface at http://127.0.0.1:7860 (or the URL shown in the terminal)
101 |
102 | ## 📊 PatchSAE Training and Analysis
103 |
104 | - **Training Instructions**: See [tasks/README.md](./tasks/README.md)
105 | - **Analysis Notebooks**:
106 | - [demo.ipynb](./demo.ipynb)
107 | - [analysis.ipynb](./analysis/analysis.ipynb)
108 |
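Beyond the notebooks, the helper functions in [`analysis/utils.py`](./analysis/utils.py) can be used directly. The snippet below is a minimal, hypothetical sketch: the stats directory is a placeholder, and it assumes the precomputed `.pt` statistics (e.g. extracted from `out.zip`, or produced by the training and analysis tasks) are available locally, with `PYTHONPATH=./` set as in Getting Started.

```python
import torch
from datasets import load_dataset

from analysis.utils import DATASET_INFO, load_stats, plot_ref_images

device = torch.device("cpu")
stats_dir = "out/..."  # placeholder: directory containing the precomputed .pt statistic files

# Loads per-latent mean activations, sparsity, top activating images/labels, and entropy.
stats = load_stats(stats_dir, device)

# Load the reference dataset the statistics were computed on (ImageNet from the Hugging Face Hub).
info = DATASET_INFO["imagenet"]
dataset = load_dataset(info["path"], split=info["split"])

# Plot the top-10 maximally activating images for a single SAE latent.
plot_ref_images(stats, dataset, latent_idx=0, plot_top_k=10)
```

Note that `plot_ref_images` asserts that the dataset ordering matches the one used when the statistics were computed.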
109 | ## 📝 Status Updates
110 |
111 | - **Jan 13, 2025**: Training & analysis code works properly. A minor error remains in class-wise data loading when using ImageNet.
112 | - **Jan 09, 2025**: Analysis code works. Added evaluation during training and fixed an optimizer bug.
113 | - **Jan 07, 2025**: Added analysis code. Reproducibility tests completed (trained on ImageNet, tested on Oxford-Flowers).
114 | - **Jan 06, 2025**: Training code updated. Reproducibility testing in progress.
115 | - **Jan 02, 2025**: Training code incomplete in this version. Updates coming soon.
116 |
117 | ## 📜 License & Credits
118 |
119 | ### Reference Implementations
120 |
121 | - [SAE for ViT](https://github.com/HugoFry/mats_sae_training_for_ViTs)
122 | - [SAELens](https://github.com/jbloomAus/SAELens)
123 | - [Differentiable and Fast Geometric Median in NumPy and PyTorch](https://github.com/krishnap25/geom_median)
124 | - [Self-regulating Prompts: Foundational Model Adaptation without Forgetting [ICCV 2023]](https://github.com/muzairkhattak/PromptSRC)
125 |   - Used in: `configs/` and `src/models/`
126 | - [MaPLe: Multi-modal Prompt Learning CVPR 2023](https://github.com/muzairkhattak/multimodal-prompt-learning)
127 | - Used in: `configs/models/maple/...yaml` and `data/clip/maple/imagenet/model.pth.tar-2`
128 |
129 | ### License Notice
130 |
131 | Our code is distributed under the MIT license; please see the [LICENSE](LICENSE) file for details.
132 | The [NOTICE](NOTICE) file lists the licenses of all third-party code included in this repository.
133 | Please include the contents of the LICENSE and NOTICE files in all re-distributions of this code.
134 |
135 | ---
136 |
137 | ### Citation
138 |
139 | If you find our code or models useful in your work, please cite our [paper](https://arxiv.org/abs/2412.05276):
140 |
141 | ```
142 | @inproceedings{
143 | lim2025patchsae,
144 | title={Sparse autoencoders reveal selective remapping of visual concepts during adaptation},
145 | author={Hyesu Lim and Jinho Choi and Jaegul Choo and Steffen Schneider},
146 | booktitle={The Thirteenth International Conference on Learning Representations},
147 | year={2025},
148 | url={https://openreview.net/forum?id=imT03YXlG2}
149 | }
150 | ```
151 |
--------------------------------------------------------------------------------
/analysis/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dynamical-inference/patchsae/6b468b36fe4003724e1c5445180313c79b7a2d0c/analysis/__init__.py
--------------------------------------------------------------------------------
/analysis/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import matplotlib.pyplot as plt
4 | import pandas as pd
5 | import plotly.express as px
6 | import torch
7 |
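# The DATASET_INFO entries below appear to correspond to Hugging Face `datasets.load_dataset`
# arguments; a hypothetical usage sketch (not called in this module):
#   info = DATASET_INFO["caltech101"]
#   dataset = load_dataset(info["path"], split=info["split"], name=info.get("name"))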
8 | DATASET_INFO = {
9 | "imagenet": {
10 | "path": "evanarlian/imagenet_1k_resized_256",
11 | "split": "train",
12 | },
13 | "imagenet-sketch": {
14 | "path": "clip-benchmark/wds_imagenet_sketch",
15 | "split": "train",
16 | },
17 | "oxford_flowers": {
18 | "path": "nelorth/oxford-flowers",
19 | "split": "train",
20 | },
21 | "caltech101": {
22 | "path": "HuggingFaceM4/Caltech-101",
23 | "split": "train",
24 | "name": "with_background_category",
25 | },
26 | }
27 |
28 |
29 | def calculate_entropy(top_val, top_label, ignore_label_idx: int = None, eps=1e-9):
30 | dict_size = top_label.shape[0]
31 | entropy = torch.zeros(dict_size)
32 |
33 | for i in range(dict_size):
34 | unique_labels, counts = top_label[i].unique(return_counts=True)
35 | if ignore_label_idx is not None:
36 | counts = counts[unique_labels != ignore_label_idx]
37 | unique_labels = unique_labels[unique_labels != ignore_label_idx]
38 | if len(unique_labels) != 0:
39 | if counts.sum().item() < 10:
40 | entropy[i] = -1 # discount as too few datapoints!
41 | else:
42 | summed_probs = torch.zeros_like(unique_labels, dtype=top_val.dtype)
43 | for j, label in enumerate(unique_labels):
44 | summed_probs[j] = top_val[i][top_label[i] == label].sum().item()
45 | summed_probs = summed_probs / summed_probs.sum()
46 | entropy[i] = -torch.sum(summed_probs * torch.log(summed_probs + eps))
47 | else:
48 | entropy[i] = -1
49 | return entropy
50 |
51 |
52 | def load_stats(save_directory: str, device: torch.device):
53 | mean_acts = torch.load(
54 | os.path.join(save_directory, "sae_mean_acts.pt"),
55 | map_location=torch.device(device),
56 | )
57 | sparsity = torch.load(
58 | os.path.join(save_directory, "sae_sparsity.pt"),
59 | map_location=torch.device(device),
60 | )
61 | top_val = torch.load(
62 | os.path.join(save_directory, "max_activating_image_values.pt"),
63 | map_location=torch.device(device),
64 | )
65 | top_idx = torch.load(
66 | os.path.join(save_directory, "max_activating_image_indices.pt"),
67 | map_location=torch.device(device),
68 | )
69 | top_label = torch.load(
70 | os.path.join(save_directory, "max_activating_image_label_indices.pt"),
71 | map_location=torch.device(device),
72 | )
73 | try:
74 | top_entropy = torch.load(
75 | os.path.join(save_directory, "top_entropy.pt"),
76 | map_location=torch.device(device),
77 | )
78 | except: # noqa: E722
79 | print("Calculating top entropy")
80 | top_entropy = calculate_entropy(top_val, top_label)
81 | torch.save(top_entropy, os.path.join(save_directory, "top_entropy.pt"))
82 |
83 | print(f"Stats loaded from {save_directory}")
84 |
85 | stats = {
86 | "mean_acts": mean_acts.to(device),
87 | "sparsity": sparsity.to(device),
88 | "top_val": top_val.to(device),
89 | "top_idx": top_idx.to(device).to(torch.int64),
90 | "top_label": top_label.to(device).to(torch.int64),
91 | "top_entropy": top_entropy.to(device),
92 | }
93 | return stats
94 |
95 |
96 | def get_stats_scatter_plot(stats, mask=None, save_directory: str = None, eps=1e-9):
97 | if mask is None:
98 | mask = torch.ones_like(
99 | stats["sparsity"], dtype=torch.bool, device=stats["sparsity"].device
100 | )
101 |
102 | indices = torch.where(mask)[0]
103 | plotting_data = torch.stack(
104 | [
105 | torch.log10(stats["sparsity"][mask] + eps),
106 | torch.log10(stats["mean_acts"][mask] + eps),
107 | stats["top_entropy"][mask],
108 | indices,
109 | ],
110 | dim=0,
111 | )
112 | plotting_data = plotting_data.transpose(0, 1)
113 |
114 | x_label = "log10(sparsity)"
115 | y_label = "log10(mean_acts)"
116 | color_label = "entropy"
117 | hover_label = "index"
118 |
119 | df = pd.DataFrame(
120 |         plotting_data.cpu().numpy(), columns=[x_label, y_label, color_label, hover_label]
121 | )
122 | fig = px.scatter(
123 | df,
124 | x=x_label,
125 | y=y_label,
126 | color=color_label,
127 | marginal_x="histogram",
128 | marginal_y="histogram",
129 | opacity=0.5,
130 | hover_data=[hover_label],
131 | )
132 |
133 | if save_directory is not None:
134 | fig.write_image(os.path.join(save_directory, "scatter_plot.png"))
135 | fig.show()
136 |
137 |
138 | def plot_ref_images(stats, dataset, latent_idx: int, plot_top_k: int = 10, eps=1e-9):
139 | resize_size = 224
140 | num_cols = 5
141 | num_rows = plot_top_k // num_cols
142 |
143 | images = []
144 | labels = []
145 |
146 | for i, idx in enumerate(stats["top_idx"][latent_idx][:plot_top_k]):
147 | img = dataset[idx.item()]["image"]
148 | images.append(img.resize((resize_size, resize_size)))
149 | assert dataset[idx.item()]["label"] == stats["top_label"][latent_idx][i], (
150 | "label mismatch, try matching dataset shuffle seed"
151 | )
152 | labels.append(dataset[idx.item()]["label"])
153 |
154 | _, axes = plt.subplots(num_rows, num_cols, figsize=(4.5 * num_cols, 6 * num_rows))
155 |
156 | for i, (image, label) in enumerate(zip(images, labels)):
157 | ax = axes[i // num_cols, i % num_cols]
158 | ax.imshow(image)
159 | top_val_i = torch.log10(stats["top_val"][latent_idx][i] + eps).item()
160 | ax.set_title(f"{label}(a: {top_val_i:.2f})", fontsize=35)
161 | ax.axis("off")
162 |
163 | mean_acts = torch.log10(stats["mean_acts"][latent_idx] + eps).item()
164 | sparsity = torch.log10(stats["sparsity"][latent_idx] + eps).item()
165 | plt.suptitle(
166 | f"Index {latent_idx} (f: {sparsity:.2f}, a: {mean_acts:.2f}, e: {stats['top_entropy'][latent_idx]:.2f})\n",
167 | fontsize=35,
168 | )
169 |
170 | plt.tight_layout()
171 | plt.show()
172 |
--------------------------------------------------------------------------------
/assets/demo_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dynamical-inference/patchsae/6b468b36fe4003724e1c5445180313c79b7a2d0c/assets/demo_overview.png
--------------------------------------------------------------------------------
/assets/patchsae.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dynamical-inference/patchsae/6b468b36fe4003724e1c5445180313c79b7a2d0c/assets/patchsae.gif
--------------------------------------------------------------------------------
/assets/ref_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dynamical-inference/patchsae/6b468b36fe4003724e1c5445180313c79b7a2d0c/assets/ref_1.gif
--------------------------------------------------------------------------------
/assets/ref_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dynamical-inference/patchsae/6b468b36fe4003724e1c5445180313c79b7a2d0c/assets/ref_2.gif
--------------------------------------------------------------------------------
/assets/ref_3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dynamical-inference/patchsae/6b468b36fe4003724e1c5445180313c79b7a2d0c/assets/ref_3.gif
--------------------------------------------------------------------------------
/assets/ref_imgs_munich.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dynamical-inference/patchsae/6b468b36fe4003724e1c5445180313c79b7a2d0c/assets/ref_imgs_munich.gif
--------------------------------------------------------------------------------
/assets/sae_arch.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dynamical-inference/patchsae/6b468b36fe4003724e1c5445180313c79b7a2d0c/assets/sae_arch.gif
--------------------------------------------------------------------------------
/configs/classnames/caltech101_classnames.txt:
--------------------------------------------------------------------------------
1 | accordion
2 | airplanes
3 | anchor
4 | ant
5 | background_google
6 | barrel
7 | bass
8 | beaver
9 | binocular
10 | bonsai
11 | brain
12 | brontosaurus
13 | buddha
14 | butterfly
15 | camera
16 | cannon
17 | car_side
18 | ceiling_fan
19 | cellphone
20 | chair
21 | chandelier
22 | cougar_body
23 | cougar_face
24 | crab
25 | crayfish
26 | crocodile
27 | crocodile_head
28 | cup
29 | dalmatian
30 | dollar_bill
31 | dolphin
32 | dragonfly
33 | electric_guitar
34 | elephant
35 | emu
36 | euphonium
37 | ewer
38 | faces
39 | faces_easy
40 | ferry
41 | flamingo
42 | flamingo_head
43 | garfield
44 | gerenuk
45 | gramophone
46 | grand_piano
47 | hawksbill
48 | headphone
49 | hedgehog
50 | helicopter
51 | ibis
52 | inline_skate
53 | joshua_tree
54 | kangaroo
55 | ketch
56 | lamp
57 | laptop
58 | leopards
59 | llama
60 | lobster
61 | lotus
62 | mandolin
63 | mayfly
64 | menorah
65 | metronome
66 | minaret
67 | motorbikes
68 | nautilus
69 | octopus
70 | okapi
71 | pagoda
72 | panda
73 | pigeon
74 | pizza
75 | platypus
76 | pyramid
77 | revolver
78 | rhino
79 | rooster
80 | saxophone
81 | schooner
82 | scissors
83 | scorpion
84 | sea_horse
85 | snoopy
86 | soccer_ball
87 | stapler
88 | starfish
89 | stegosaurus
90 | stop_sign
91 | strawberry
92 | sunflower
93 | tick
94 | trilobite
95 | umbrella
96 | watch
97 | water_lilly
98 | wheelchair
99 | wild_cat
100 | windsor_chair
101 | wrench
102 | yin_yang
103 |
--------------------------------------------------------------------------------
/configs/classnames/cifar100_cat_to_name.json:
--------------------------------------------------------------------------------
1 | {"0": "apple", "1": "aquarium_fish", "2": "baby", "3": "bear", "4": "beaver", "5": "bed", "6": "bee", "7": "beetle", "8": "bicycle", "9": "bottle", "10": "bowl", "11": "boy", "12": "bridge", "13": "bus", "14": "butterfly", "15": "camel", "16": "can", "17": "castle", "18": "caterpillar", "19": "cattle", "20": "chair", "21": "chimpanzee", "22": "clock", "23": "cloud", "24": "cockroach", "25": "couch", "26": "cra", "27": "crocodile", "28": "cup", "29": "dinosaur", "30": "dolphin", "31": "elephant", "32": "flatfish", "33": "forest", "34": "fox", "35": "girl", "36": "hamster", "37": "house", "38": "kangaroo", "39": "keyboard", "40": "lamp", "41": "lawn_mower", "42": "leopard", "43": "lion", "44": "lizard", "45": "lobster", "46": "man", "47": "maple_tree", "48": "motorcycle", "49": "mountain", "50": "mouse", "51": "mushroom", "52": "oak_tree", "53": "orange", "54": "orchid", "55": "otter", "56": "palm_tree", "57": "pear", "58": "pickup_truck", "59": "pine_tree", "60": "plain", "61": "plate", "62": "poppy", "63": "porcupine", "64": "possum", "65": "rabbit", "66": "raccoon", "67": "ray", "68": "road", "69": "rocket", "70": "rose", "71": "sea", "72": "seal", "73": "shark", "74": "shrew", "75": "skunk", "76": "skyscraper", "77": "snail", "78": "snake", "79": "spider", "80": "squirrel", "81": "streetcar", "82": "sunflower", "83": "sweet_pepper", "84": "table", "85": "tank", "86": "telephone", "87": "television", "88": "tiger", "89": "tractor", "90": "train", "91": "trout", "92": "tulip", "93": "turtle", "94": "wardrobe", "95": "whale", "96": "willow_tree", "97": "wolf", "98": "woman", "99": "worm"}
2 |
--------------------------------------------------------------------------------
/configs/classnames/food101_cat_to_name.json:
--------------------------------------------------------------------------------
1 | {"0": "apple_pie", "1": "baby_back_ribs", "2": "baklava", "3": "beef_carpaccio", "4": "beef_tartare", "5": "beet_salad", "6": "beignets", "7": "bibimbap", "8": "bread_pudding", "9": "breakfast_burrito", "10": "bruschetta", "11": "caesar_salad", "12": "cannoli", "13": "caprese_salad", "14": "carrot_cake", "15": "ceviche", "16": "cheesecake", "17": "cheese_plate", "18": "chicken_curry", "19": "chicken_quesadilla", "20": "chicken_wings", "21": "chocolate_cake", "22": "chocolate_mousse", "23": "churros", "24": "clam_chowder", "25": "club_sandwich", "26": "crab_cakes", "27": "creme_brulee", "28": "croque_madame", "29": "cup_cakes", "30": "deviled_eggs", "31": "donuts", "32": "dumplings", "33": "edamame", "34": "eggs_benedict", "35": "escargots", "36": "falafel", "37": "filet_mignon", "38": "fish_and_chips", "39": "foie_gras", "40": "french_fries", "41": "french_onion_soup", "42": "french_toast", "43": "fried_calamari", "44": "fried_rice", "45": "frozen_yogurt", "46": "garlic_bread", "47": "gnocchi", "48": "greek_salad", "49": "grilled_cheese_sandwich", "50": "grilled_salmon", "51": "guacamole", "52": "gyoza", "53": "hamburger", "54": "hot_and_sour_soup", "55": "hot_dog", "56": "huevos_rancheros", "57": "hummus", "58": "ice_cream", "59": "lasagna", "60": "lobster_bisque", "61": "lobster_roll_sandwich", "62": "macaroni_and_cheese", "63": "macarons", "64": "miso_soup", "65": "mussels", "66": "nachos", "67": "omelette", "68": "onion_rings", "69": "oysters", "70": "pad_thai", "71": "paella", "72": "pancakes", "73": "panna_cotta", "74": "peking_duck", "75": "pho", "76": "pizza", "77": "pork_chop", "78": "poutine", "79": "prime_rib", "80": "pulled_pork_sandwich", "81": "ramen", "82": "ravioli", "83": "red_velvet_cake", "84": "risotto", "85": "samosa", "86": "sashimi", "87": "scallops", "88": "seaweed_salad", "89": "shrimp_and_grits", "90": "spaghetti_bolognese", "91": "spaghetti_carbonara", "92": "spring_rolls", "93": "steak", "94": "strawberry_shortcake", "95": "sushi", "96": "tacos", "97": "takoyaki", "98": "tiramisu", "99": "tuna_tartare", "100": "waffles"}
2 |
--------------------------------------------------------------------------------
/configs/classnames/imagenet_classnames.txt:
--------------------------------------------------------------------------------
1 | n01440764 tench
2 | n01443537 goldfish
3 | n01484850 great white shark
4 | n01491361 tiger shark
5 | n01494475 hammerhead shark
6 | n01496331 electric ray
7 | n01498041 stingray
8 | n01514668 rooster
9 | n01514859 hen
10 | n01518878 ostrich
11 | n01530575 brambling
12 | n01531178 goldfinch
13 | n01532829 house finch
14 | n01534433 junco
15 | n01537544 indigo bunting
16 | n01558993 American robin
17 | n01560419 bulbul
18 | n01580077 jay
19 | n01582220 magpie
20 | n01592084 chickadee
21 | n01601694 American dipper
22 | n01608432 kite (bird of prey)
23 | n01614925 bald eagle
24 | n01616318 vulture
25 | n01622779 great grey owl
26 | n01629819 fire salamander
27 | n01630670 smooth newt
28 | n01631663 newt
29 | n01632458 spotted salamander
30 | n01632777 axolotl
31 | n01641577 American bullfrog
32 | n01644373 tree frog
33 | n01644900 tailed frog
34 | n01664065 loggerhead sea turtle
35 | n01665541 leatherback sea turtle
36 | n01667114 mud turtle
37 | n01667778 terrapin
38 | n01669191 box turtle
39 | n01675722 banded gecko
40 | n01677366 green iguana
41 | n01682714 Carolina anole
42 | n01685808 desert grassland whiptail lizard
43 | n01687978 agama
44 | n01688243 frilled-necked lizard
45 | n01689811 alligator lizard
46 | n01692333 Gila monster
47 | n01693334 European green lizard
48 | n01694178 chameleon
49 | n01695060 Komodo dragon
50 | n01697457 Nile crocodile
51 | n01698640 American alligator
52 | n01704323 triceratops
53 | n01728572 worm snake
54 | n01728920 ring-necked snake
55 | n01729322 eastern hog-nosed snake
56 | n01729977 smooth green snake
57 | n01734418 kingsnake
58 | n01735189 garter snake
59 | n01737021 water snake
60 | n01739381 vine snake
61 | n01740131 night snake
62 | n01742172 boa constrictor
63 | n01744401 African rock python
64 | n01748264 Indian cobra
65 | n01749939 green mamba
66 | n01751748 sea snake
67 | n01753488 Saharan horned viper
68 | n01755581 eastern diamondback rattlesnake
69 | n01756291 sidewinder rattlesnake
70 | n01768244 trilobite
71 | n01770081 harvestman
72 | n01770393 scorpion
73 | n01773157 yellow garden spider
74 | n01773549 barn spider
75 | n01773797 European garden spider
76 | n01774384 southern black widow
77 | n01774750 tarantula
78 | n01775062 wolf spider
79 | n01776313 tick
80 | n01784675 centipede
81 | n01795545 black grouse
82 | n01796340 ptarmigan
83 | n01797886 ruffed grouse
84 | n01798484 prairie grouse
85 | n01806143 peafowl
86 | n01806567 quail
87 | n01807496 partridge
88 | n01817953 african grey parrot
89 | n01818515 macaw
90 | n01819313 sulphur-crested cockatoo
91 | n01820546 lorikeet
92 | n01824575 coucal
93 | n01828970 bee eater
94 | n01829413 hornbill
95 | n01833805 hummingbird
96 | n01843065 jacamar
97 | n01843383 toucan
98 | n01847000 duck
99 | n01855032 red-breasted merganser
100 | n01855672 goose
101 | n01860187 black swan
102 | n01871265 tusker
103 | n01872401 echidna
104 | n01873310 platypus
105 | n01877812 wallaby
106 | n01882714 koala
107 | n01883070 wombat
108 | n01910747 jellyfish
109 | n01914609 sea anemone
110 | n01917289 brain coral
111 | n01924916 flatworm
112 | n01930112 nematode
113 | n01943899 conch
114 | n01944390 snail
115 | n01945685 slug
116 | n01950731 sea slug
117 | n01955084 chiton
118 | n01968897 chambered nautilus
119 | n01978287 Dungeness crab
120 | n01978455 rock crab
121 | n01980166 fiddler crab
122 | n01981276 red king crab
123 | n01983481 American lobster
124 | n01984695 spiny lobster
125 | n01985128 crayfish
126 | n01986214 hermit crab
127 | n01990800 isopod
128 | n02002556 white stork
129 | n02002724 black stork
130 | n02006656 spoonbill
131 | n02007558 flamingo
132 | n02009229 little blue heron
133 | n02009912 great egret
134 | n02011460 bittern bird
135 | n02012849 crane bird
136 | n02013706 limpkin
137 | n02017213 common gallinule
138 | n02018207 American coot
139 | n02018795 bustard
140 | n02025239 ruddy turnstone
141 | n02027492 dunlin
142 | n02028035 common redshank
143 | n02033041 dowitcher
144 | n02037110 oystercatcher
145 | n02051845 pelican
146 | n02056570 king penguin
147 | n02058221 albatross
148 | n02066245 grey whale
149 | n02071294 killer whale
150 | n02074367 dugong
151 | n02077923 sea lion
152 | n02085620 Chihuahua
153 | n02085782 Japanese Chin
154 | n02085936 Maltese
155 | n02086079 Pekingese
156 | n02086240 Shih Tzu
157 | n02086646 King Charles Spaniel
158 | n02086910 Papillon
159 | n02087046 toy terrier
160 | n02087394 Rhodesian Ridgeback
161 | n02088094 Afghan Hound
162 | n02088238 Basset Hound
163 | n02088364 Beagle
164 | n02088466 Bloodhound
165 | n02088632 Bluetick Coonhound
166 | n02089078 Black and Tan Coonhound
167 | n02089867 Treeing Walker Coonhound
168 | n02089973 English foxhound
169 | n02090379 Redbone Coonhound
170 | n02090622 borzoi
171 | n02090721 Irish Wolfhound
172 | n02091032 Italian Greyhound
173 | n02091134 Whippet
174 | n02091244 Ibizan Hound
175 | n02091467 Norwegian Elkhound
176 | n02091635 Otterhound
177 | n02091831 Saluki
178 | n02092002 Scottish Deerhound
179 | n02092339 Weimaraner
180 | n02093256 Staffordshire Bull Terrier
181 | n02093428 American Staffordshire Terrier
182 | n02093647 Bedlington Terrier
183 | n02093754 Border Terrier
184 | n02093859 Kerry Blue Terrier
185 | n02093991 Irish Terrier
186 | n02094114 Norfolk Terrier
187 | n02094258 Norwich Terrier
188 | n02094433 Yorkshire Terrier
189 | n02095314 Wire Fox Terrier
190 | n02095570 Lakeland Terrier
191 | n02095889 Sealyham Terrier
192 | n02096051 Airedale Terrier
193 | n02096177 Cairn Terrier
194 | n02096294 Australian Terrier
195 | n02096437 Dandie Dinmont Terrier
196 | n02096585 Boston Terrier
197 | n02097047 Miniature Schnauzer
198 | n02097130 Giant Schnauzer
199 | n02097209 Standard Schnauzer
200 | n02097298 Scottish Terrier
201 | n02097474 Tibetan Terrier
202 | n02097658 Australian Silky Terrier
203 | n02098105 Soft-coated Wheaten Terrier
204 | n02098286 West Highland White Terrier
205 | n02098413 Lhasa Apso
206 | n02099267 Flat-Coated Retriever
207 | n02099429 Curly-coated Retriever
208 | n02099601 Golden Retriever
209 | n02099712 Labrador Retriever
210 | n02099849 Chesapeake Bay Retriever
211 | n02100236 German Shorthaired Pointer
212 | n02100583 Vizsla
213 | n02100735 English Setter
214 | n02100877 Irish Setter
215 | n02101006 Gordon Setter
216 | n02101388 Brittany dog
217 | n02101556 Clumber Spaniel
218 | n02102040 English Springer Spaniel
219 | n02102177 Welsh Springer Spaniel
220 | n02102318 Cocker Spaniel
221 | n02102480 Sussex Spaniel
222 | n02102973 Irish Water Spaniel
223 | n02104029 Kuvasz
224 | n02104365 Schipperke
225 | n02105056 Groenendael dog
226 | n02105162 Malinois
227 | n02105251 Briard
228 | n02105412 Australian Kelpie
229 | n02105505 Komondor
230 | n02105641 Old English Sheepdog
231 | n02105855 Shetland Sheepdog
232 | n02106030 collie
233 | n02106166 Border Collie
234 | n02106382 Bouvier des Flandres dog
235 | n02106550 Rottweiler
236 | n02106662 German Shepherd Dog
237 | n02107142 Dobermann
238 | n02107312 Miniature Pinscher
239 | n02107574 Greater Swiss Mountain Dog
240 | n02107683 Bernese Mountain Dog
241 | n02107908 Appenzeller Sennenhund
242 | n02108000 Entlebucher Sennenhund
243 | n02108089 Boxer
244 | n02108422 Bullmastiff
245 | n02108551 Tibetan Mastiff
246 | n02108915 French Bulldog
247 | n02109047 Great Dane
248 | n02109525 St. Bernard
249 | n02109961 husky
250 | n02110063 Alaskan Malamute
251 | n02110185 Siberian Husky
252 | n02110341 Dalmatian
253 | n02110627 Affenpinscher
254 | n02110806 Basenji
255 | n02110958 pug
256 | n02111129 Leonberger
257 | n02111277 Newfoundland dog
258 | n02111500 Great Pyrenees dog
259 | n02111889 Samoyed
260 | n02112018 Pomeranian
261 | n02112137 Chow Chow
262 | n02112350 Keeshond
263 | n02112706 brussels griffon
264 | n02113023 Pembroke Welsh Corgi
265 | n02113186 Cardigan Welsh Corgi
266 | n02113624 Toy Poodle
267 | n02113712 Miniature Poodle
268 | n02113799 Standard Poodle
269 | n02113978 Mexican hairless dog (xoloitzcuintli)
270 | n02114367 grey wolf
271 | n02114548 Alaskan tundra wolf
272 | n02114712 red wolf or maned wolf
273 | n02114855 coyote
274 | n02115641 dingo
275 | n02115913 dhole
276 | n02116738 African wild dog
277 | n02117135 hyena
278 | n02119022 red fox
279 | n02119789 kit fox
280 | n02120079 Arctic fox
281 | n02120505 grey fox
282 | n02123045 tabby cat
283 | n02123159 tiger cat
284 | n02123394 Persian cat
285 | n02123597 Siamese cat
286 | n02124075 Egyptian Mau
287 | n02125311 cougar
288 | n02127052 lynx
289 | n02128385 leopard
290 | n02128757 snow leopard
291 | n02128925 jaguar
292 | n02129165 lion
293 | n02129604 tiger
294 | n02130308 cheetah
295 | n02132136 brown bear
296 | n02133161 American black bear
297 | n02134084 polar bear
298 | n02134418 sloth bear
299 | n02137549 mongoose
300 | n02138441 meerkat
301 | n02165105 tiger beetle
302 | n02165456 ladybug
303 | n02167151 ground beetle
304 | n02168699 longhorn beetle
305 | n02169497 leaf beetle
306 | n02172182 dung beetle
307 | n02174001 rhinoceros beetle
308 | n02177972 weevil
309 | n02190166 fly
310 | n02206856 bee
311 | n02219486 ant
312 | n02226429 grasshopper
313 | n02229544 cricket insect
314 | n02231487 stick insect
315 | n02233338 cockroach
316 | n02236044 praying mantis
317 | n02256656 cicada
318 | n02259212 leafhopper
319 | n02264363 lacewing
320 | n02268443 dragonfly
321 | n02268853 damselfly
322 | n02276258 red admiral butterfly
323 | n02277742 ringlet butterfly
324 | n02279972 monarch butterfly
325 | n02280649 small white butterfly
326 | n02281406 sulphur butterfly
327 | n02281787 gossamer-winged butterfly
328 | n02317335 starfish
329 | n02319095 sea urchin
330 | n02321529 sea cucumber
331 | n02325366 cottontail rabbit
332 | n02326432 hare
333 | n02328150 Angora rabbit
334 | n02342885 hamster
335 | n02346627 porcupine
336 | n02356798 fox squirrel
337 | n02361337 marmot
338 | n02363005 beaver
339 | n02364673 guinea pig
340 | n02389026 common sorrel horse
341 | n02391049 zebra
342 | n02395406 pig
343 | n02396427 wild boar
344 | n02397096 warthog
345 | n02398521 hippopotamus
346 | n02403003 ox
347 | n02408429 water buffalo
348 | n02410509 bison
349 | n02412080 ram (adult male sheep)
350 | n02415577 bighorn sheep
351 | n02417914 Alpine ibex
352 | n02422106 hartebeest
353 | n02422699 impala (antelope)
354 | n02423022 gazelle
355 | n02437312 arabian camel
356 | n02437616 llama
357 | n02441942 weasel
358 | n02442845 mink
359 | n02443114 European polecat
360 | n02443484 black-footed ferret
361 | n02444819 otter
362 | n02445715 skunk
363 | n02447366 badger
364 | n02454379 armadillo
365 | n02457408 three-toed sloth
366 | n02480495 orangutan
367 | n02480855 gorilla
368 | n02481823 chimpanzee
369 | n02483362 gibbon
370 | n02483708 siamang
371 | n02484975 guenon
372 | n02486261 patas monkey
373 | n02486410 baboon
374 | n02487347 macaque
375 | n02488291 langur
376 | n02488702 black-and-white colobus
377 | n02489166 proboscis monkey
378 | n02490219 marmoset
379 | n02492035 white-headed capuchin
380 | n02492660 howler monkey
381 | n02493509 titi monkey
382 | n02493793 Geoffroy's spider monkey
383 | n02494079 common squirrel monkey
384 | n02497673 ring-tailed lemur
385 | n02500267 indri
386 | n02504013 Asian elephant
387 | n02504458 African bush elephant
388 | n02509815 red panda
389 | n02510455 giant panda
390 | n02514041 snoek fish
391 | n02526121 eel
392 | n02536864 silver salmon
393 | n02606052 rock beauty fish
394 | n02607072 clownfish
395 | n02640242 sturgeon
396 | n02641379 gar fish
397 | n02643566 lionfish
398 | n02655020 pufferfish
399 | n02666196 abacus
400 | n02667093 abaya
401 | n02669723 academic gown
402 | n02672831 accordion
403 | n02676566 acoustic guitar
404 | n02687172 aircraft carrier
405 | n02690373 airliner
406 | n02692877 airship
407 | n02699494 altar
408 | n02701002 ambulance
409 | n02704792 amphibious vehicle
410 | n02708093 analog clock
411 | n02727426 apiary
412 | n02730930 apron
413 | n02747177 trash can
414 | n02749479 assault rifle
415 | n02769748 backpack
416 | n02776631 bakery
417 | n02777292 balance beam
418 | n02782093 balloon
419 | n02783161 ballpoint pen
420 | n02786058 Band-Aid
421 | n02787622 banjo
422 | n02788148 baluster / handrail
423 | n02790996 barbell
424 | n02791124 barber chair
425 | n02791270 barbershop
426 | n02793495 barn
427 | n02794156 barometer
428 | n02795169 barrel
429 | n02797295 wheelbarrow
430 | n02799071 baseball
431 | n02802426 basketball
432 | n02804414 bassinet
433 | n02804610 bassoon
434 | n02807133 swimming cap
435 | n02808304 bath towel
436 | n02808440 bathtub
437 | n02814533 station wagon
438 | n02814860 lighthouse
439 | n02815834 beaker
440 | n02817516 military hat (bearskin or shako)
441 | n02823428 beer bottle
442 | n02823750 beer glass
443 | n02825657 bell tower
444 | n02834397 baby bib
445 | n02835271 tandem bicycle
446 | n02837789 bikini
447 | n02840245 ring binder
448 | n02841315 binoculars
449 | n02843684 birdhouse
450 | n02859443 boathouse
451 | n02860847 bobsleigh
452 | n02865351 bolo tie
453 | n02869837 poke bonnet
454 | n02870880 bookcase
455 | n02871525 bookstore
456 | n02877765 bottle cap
457 | n02879718 hunting bow
458 | n02883205 bow tie
459 | n02892201 brass memorial plaque
460 | n02892767 bra
461 | n02894605 breakwater
462 | n02895154 breastplate
463 | n02906734 broom
464 | n02909870 bucket
465 | n02910353 buckle
466 | n02916936 bulletproof vest
467 | n02917067 high-speed train
468 | n02927161 butcher shop
469 | n02930766 taxicab
470 | n02939185 cauldron
471 | n02948072 candle
472 | n02950826 cannon
473 | n02951358 canoe
474 | n02951585 can opener
475 | n02963159 cardigan
476 | n02965783 car mirror
477 | n02966193 carousel
478 | n02966687 tool kit
479 | n02971356 cardboard box / carton
480 | n02974003 car wheel
481 | n02977058 automated teller machine
482 | n02978881 cassette
483 | n02979186 cassette player
484 | n02980441 castle
485 | n02981792 catamaran
486 | n02988304 CD player
487 | n02992211 cello
488 | n02992529 mobile phone
489 | n02999410 chain
490 | n03000134 chain-link fence
491 | n03000247 chain mail
492 | n03000684 chainsaw
493 | n03014705 storage chest
494 | n03016953 chiffonier
495 | n03017168 bell or wind chime
496 | n03018349 china cabinet
497 | n03026506 Christmas stocking
498 | n03028079 church
499 | n03032252 movie theater
500 | n03041632 cleaver
501 | n03042490 cliff dwelling
502 | n03045698 cloak
503 | n03047690 clogs
504 | n03062245 cocktail shaker
505 | n03063599 coffee mug
506 | n03063689 coffeemaker
507 | n03065424 spiral or coil
508 | n03075370 combination lock
509 | n03085013 computer keyboard
510 | n03089624 candy store
511 | n03095699 container ship
512 | n03100240 convertible
513 | n03109150 corkscrew
514 | n03110669 cornet
515 | n03124043 cowboy boot
516 | n03124170 cowboy hat
517 | n03125729 cradle
518 | n03126707 construction crane
519 | n03127747 crash helmet
520 | n03127925 crate
521 | n03131574 infant bed
522 | n03133878 Crock Pot
523 | n03134739 croquet ball
524 | n03141823 crutch
525 | n03146219 cuirass
526 | n03160309 dam
527 | n03179701 desk
528 | n03180011 desktop computer
529 | n03187595 rotary dial telephone
530 | n03188531 diaper
531 | n03196217 digital clock
532 | n03197337 digital watch
533 | n03201208 dining table
534 | n03207743 dishcloth
535 | n03207941 dishwasher
536 | n03208938 disc brake
537 | n03216828 dock
538 | n03218198 dog sled
539 | n03220513 dome
540 | n03223299 doormat
541 | n03240683 drilling rig
542 | n03249569 drum
543 | n03250847 drumstick
544 | n03255030 dumbbell
545 | n03259280 Dutch oven
546 | n03271574 electric fan
547 | n03272010 electric guitar
548 | n03272562 electric locomotive
549 | n03290653 entertainment center
550 | n03291819 envelope
551 | n03297495 espresso machine
552 | n03314780 face powder
553 | n03325584 feather boa
554 | n03337140 filing cabinet
555 | n03344393 fireboat
556 | n03345487 fire truck
557 | n03347037 fire screen
558 | n03355925 flagpole
559 | n03372029 flute
560 | n03376595 folding chair
561 | n03379051 football helmet
562 | n03384352 forklift
563 | n03388043 fountain
564 | n03388183 fountain pen
565 | n03388549 four-poster bed
566 | n03393912 freight car
567 | n03394916 French horn
568 | n03400231 frying pan
569 | n03404251 fur coat
570 | n03417042 garbage truck
571 | n03424325 gas mask or respirator
572 | n03425413 gas pump
573 | n03443371 goblet
574 | n03444034 go-kart
575 | n03445777 golf ball
576 | n03445924 golf cart
577 | n03447447 gondola
578 | n03447721 gong
579 | n03450230 gown
580 | n03452741 grand piano
581 | n03457902 greenhouse
582 | n03459775 radiator grille
583 | n03461385 grocery store
584 | n03467068 guillotine
585 | n03476684 hair clip
586 | n03476991 hair spray
587 | n03478589 half-track
588 | n03481172 hammer
589 | n03482405 hamper
590 | n03483316 hair dryer
591 | n03485407 hand-held computer
592 | n03485794 handkerchief
593 | n03492542 hard disk drive
594 | n03494278 harmonica
595 | n03495258 harp
596 | n03496892 combine harvester
597 | n03498962 hatchet
598 | n03527444 holster
599 | n03529860 home theater
600 | n03530642 honeycomb
601 | n03532672 hook
602 | n03534580 hoop skirt
603 | n03535780 gymnastic horizontal bar
604 | n03538406 horse-drawn vehicle
605 | n03544143 hourglass
606 | n03584254 iPod
607 | n03584829 clothes iron
608 | n03590841 carved pumpkin
609 | n03594734 jeans
610 | n03594945 jeep
611 | n03595614 T-shirt
612 | n03598930 jigsaw puzzle
613 | n03599486 rickshaw
614 | n03602883 joystick
615 | n03617480 kimono
616 | n03623198 knee pad
617 | n03627232 knot
618 | n03630383 lab coat
619 | n03633091 ladle
620 | n03637318 lampshade
621 | n03642806 laptop computer
622 | n03649909 lawn mower
623 | n03657121 lens cap
624 | n03658185 letter opener
625 | n03661043 library
626 | n03662601 lifeboat
627 | n03666591 lighter
628 | n03670208 limousine
629 | n03673027 ocean liner
630 | n03676483 lipstick
631 | n03680355 slip-on shoe
632 | n03690938 lotion
633 | n03691459 music speaker
634 | n03692522 loupe magnifying glass
635 | n03697007 sawmill
636 | n03706229 magnetic compass
637 | n03709823 messenger bag
638 | n03710193 mailbox
639 | n03710637 tights
640 | n03710721 one-piece bathing suit
641 | n03717622 manhole cover
642 | n03720891 maraca
643 | n03721384 marimba
644 | n03724870 mask
645 | n03729826 matchstick
646 | n03733131 maypole
647 | n03733281 maze
648 | n03733805 measuring cup
649 | n03742115 medicine cabinet
650 | n03743016 megalith
651 | n03759954 microphone
652 | n03761084 microwave oven
653 | n03763968 military uniform
654 | n03764736 milk can
655 | n03769881 minibus
656 | n03770439 miniskirt
657 | n03770679 minivan
658 | n03773504 missile
659 | n03775071 mitten
660 | n03775546 mixing bowl
661 | n03776460 mobile home
662 | n03777568 ford model t
663 | n03777754 modem
664 | n03781244 monastery
665 | n03782006 monitor
666 | n03785016 moped
667 | n03786901 mortar and pestle
668 | n03787032 graduation cap
669 | n03788195 mosque
670 | n03788365 mosquito net
671 | n03791053 vespa
672 | n03792782 mountain bike
673 | n03792972 tent
674 | n03793489 computer mouse
675 | n03794056 mousetrap
676 | n03796401 moving van
677 | n03803284 muzzle
678 | n03804744 metal nail
679 | n03814639 neck brace
680 | n03814906 necklace
681 | n03825788 baby pacifier
682 | n03832673 notebook computer
683 | n03837869 obelisk
684 | n03838899 oboe
685 | n03840681 ocarina
686 | n03841143 odometer
687 | n03843555 oil filter
688 | n03854065 pipe organ
689 | n03857828 oscilloscope
690 | n03866082 overskirt
691 | n03868242 bullock cart
692 | n03868863 oxygen mask
693 | n03871628 product packet / packaging
694 | n03873416 paddle
695 | n03874293 paddle wheel
696 | n03874599 padlock
697 | n03876231 paintbrush
698 | n03877472 pajamas
699 | n03877845 palace
700 | n03884397 pan flute
701 | n03887697 paper towel
702 | n03888257 parachute
703 | n03888605 parallel bars
704 | n03891251 park bench
705 | n03891332 parking meter
706 | n03895866 railroad car
707 | n03899768 patio
708 | n03902125 payphone
709 | n03903868 pedestal
710 | n03908618 pencil case
711 | n03908714 pencil sharpener
712 | n03916031 perfume
713 | n03920288 Petri dish
714 | n03924679 photocopier
715 | n03929660 plectrum
716 | n03929855 Pickelhaube
717 | n03930313 picket fence
718 | n03930630 pickup truck
719 | n03933933 pier
720 | n03935335 piggy bank
721 | n03937543 pill bottle
722 | n03938244 pillow
723 | n03942813 ping-pong ball
724 | n03944341 pinwheel
725 | n03947888 pirate ship
726 | n03950228 drink pitcher
727 | n03954731 block plane
728 | n03956157 planetarium
729 | n03958227 plastic bag
730 | n03961711 plate rack
731 | n03967562 farm plow
732 | n03970156 plunger
733 | n03976467 Polaroid camera
734 | n03976657 pole
735 | n03977966 police van
736 | n03980874 poncho
737 | n03982430 pool table
738 | n03983396 soda bottle
739 | n03991062 plant pot
740 | n03992509 potter's wheel
741 | n03995372 power drill
742 | n03998194 prayer rug
743 | n04004767 printer
744 | n04005630 prison
745 | n04008634 missile
746 | n04009552 projector
747 | n04019541 hockey puck
748 | n04023962 punching bag
749 | n04026417 purse
750 | n04033901 quill
751 | n04033995 quilt
752 | n04037443 race car
753 | n04039381 racket
754 | n04040759 radiator
755 | n04041544 radio
756 | n04044716 radio telescope
757 | n04049303 rain barrel
758 | n04065272 recreational vehicle
759 | n04067472 fishing casting reel
760 | n04069434 reflex camera
761 | n04070727 refrigerator
762 | n04074963 remote control
763 | n04081281 restaurant
764 | n04086273 revolver
765 | n04090263 rifle
766 | n04099969 rocking chair
767 | n04111531 rotisserie
768 | n04116512 eraser
769 | n04118538 rugby ball
770 | n04118776 ruler measuring stick
771 | n04120489 sneaker
772 | n04125021 safe
773 | n04127249 safety pin
774 | n04131690 salt shaker
775 | n04133789 sandal
776 | n04136333 sarong
777 | n04141076 saxophone
778 | n04141327 scabbard
779 | n04141975 weighing scale
780 | n04146614 school bus
781 | n04147183 schooner
782 | n04149813 scoreboard
783 | n04152593 CRT monitor
784 | n04153751 screw
785 | n04154565 screwdriver
786 | n04162706 seat belt
787 | n04179913 sewing machine
788 | n04192698 shield
789 | n04200800 shoe store
790 | n04201297 shoji screen / room divider
791 | n04204238 shopping basket
792 | n04204347 shopping cart
793 | n04208210 shovel
794 | n04209133 shower cap
795 | n04209239 shower curtain
796 | n04228054 ski
797 | n04229816 balaclava ski mask
798 | n04235860 sleeping bag
799 | n04238763 slide rule
800 | n04239074 sliding door
801 | n04243546 slot machine
802 | n04251144 snorkel
803 | n04252077 snowmobile
804 | n04252225 snowplow
805 | n04254120 soap dispenser
806 | n04254680 soccer ball
807 | n04254777 sock
808 | n04258138 solar thermal collector
809 | n04259630 sombrero
810 | n04263257 soup bowl
811 | n04264628 keyboard space bar
812 | n04265275 space heater
813 | n04266014 space shuttle
814 | n04270147 spatula
815 | n04273569 motorboat
816 | n04275548 spider web
817 | n04277352 spindle
818 | n04285008 sports car
819 | n04286575 spotlight
820 | n04296562 stage
821 | n04310018 steam locomotive
822 | n04311004 through arch bridge
823 | n04311174 steel drum
824 | n04317175 stethoscope
825 | n04325704 scarf
826 | n04326547 stone wall
827 | n04328186 stopwatch
828 | n04330267 stove
829 | n04332243 strainer
830 | n04335435 tram
831 | n04336792 stretcher
832 | n04344873 couch
833 | n04346328 stupa
834 | n04347754 submarine
835 | n04350905 suit
836 | n04355338 sundial
837 | n04355933 sunglasses
838 | n04356056 sunglasses
839 | n04357314 sunscreen
840 | n04366367 suspension bridge
841 | n04367480 mop
842 | n04370456 sweatshirt
843 | n04371430 swim trunks / shorts
844 | n04371774 swing
845 | n04372370 electrical switch
846 | n04376876 syringe
847 | n04380533 table lamp
848 | n04389033 tank
849 | n04392985 tape player
850 | n04398044 teapot
851 | n04399382 teddy bear
852 | n04404412 television
853 | n04409515 tennis ball
854 | n04417672 thatched roof
855 | n04418357 front curtain
856 | n04423845 thimble
857 | n04428191 threshing machine
858 | n04429376 throne
859 | n04435653 tile roof
860 | n04442312 toaster
861 | n04443257 tobacco shop
862 | n04447861 toilet seat
863 | n04456115 torch
864 | n04458633 totem pole
865 | n04461696 tow truck
866 | n04462240 toy store
867 | n04465501 tractor
868 | n04467665 semi-trailer truck
869 | n04476259 tray
870 | n04479046 trench coat
871 | n04482393 tricycle
872 | n04483307 trimaran
873 | n04485082 tripod
874 | n04486054 triumphal arch
875 | n04487081 trolleybus
876 | n04487394 trombone
877 | n04493381 hot tub
878 | n04501370 turnstile
879 | n04505470 typewriter keyboard
880 | n04507155 umbrella
881 | n04509417 unicycle
882 | n04515003 upright piano
883 | n04517823 vacuum cleaner
884 | n04522168 vase
885 | n04523525 vaulted or arched ceiling
886 | n04525038 velvet fabric
887 | n04525305 vending machine
888 | n04532106 vestment
889 | n04532670 viaduct
890 | n04536866 violin
891 | n04540053 volleyball
892 | n04542943 waffle iron
893 | n04548280 wall clock
894 | n04548362 wallet
895 | n04550184 wardrobe
896 | n04552348 military aircraft
897 | n04553703 sink
898 | n04554684 washing machine
899 | n04557648 water bottle
900 | n04560804 water jug
901 | n04562935 water tower
902 | n04579145 whiskey jug
903 | n04579432 whistle
904 | n04584207 hair wig
905 | n04589890 window screen
906 | n04590129 window shade
907 | n04591157 Windsor tie
908 | n04591713 wine bottle
909 | n04592741 airplane wing
910 | n04596742 wok
911 | n04597913 wooden spoon
912 | n04599235 wool
913 | n04604644 split-rail fence
914 | n04606251 shipwreck
915 | n04612504 sailboat
916 | n04613696 yurt
917 | n06359193 website
918 | n06596364 comic book
919 | n06785654 crossword
920 | n06794110 traffic or street sign
921 | n06874185 traffic light
922 | n07248320 dust jacket
923 | n07565083 menu
924 | n07579787 plate
925 | n07583066 guacamole
926 | n07584110 consomme
927 | n07590611 hot pot
928 | n07613480 trifle
929 | n07614500 ice cream
930 | n07615774 popsicle
931 | n07684084 baguette
932 | n07693725 bagel
933 | n07695742 pretzel
934 | n07697313 cheeseburger
935 | n07697537 hot dog
936 | n07711569 mashed potatoes
937 | n07714571 cabbage
938 | n07714990 broccoli
939 | n07715103 cauliflower
940 | n07716358 zucchini
941 | n07716906 spaghetti squash
942 | n07717410 acorn squash
943 | n07717556 butternut squash
944 | n07718472 cucumber
945 | n07718747 artichoke
946 | n07720875 bell pepper
947 | n07730033 cardoon
948 | n07734744 mushroom
949 | n07742313 Granny Smith apple
950 | n07745940 strawberry
951 | n07747607 orange
952 | n07749582 lemon
953 | n07753113 fig
954 | n07753275 pineapple
955 | n07753592 banana
956 | n07754684 jackfruit
957 | n07760859 cherimoya (custard apple)
958 | n07768694 pomegranate
959 | n07802026 hay
960 | n07831146 carbonara
961 | n07836838 chocolate syrup
962 | n07860988 dough
963 | n07871810 meatloaf
964 | n07873807 pizza
965 | n07875152 pot pie
966 | n07880968 burrito
967 | n07892512 red wine
968 | n07920052 espresso
969 | n07930864 tea cup
970 | n07932039 eggnog
971 | n09193705 mountain
972 | n09229709 bubble
973 | n09246464 cliff
974 | n09256479 coral reef
975 | n09288635 geyser
976 | n09332890 lakeshore
977 | n09399592 promontory
978 | n09421951 sandbar
979 | n09428293 beach
980 | n09468604 valley
981 | n09472597 volcano
982 | n09835506 baseball player
983 | n10148035 bridegroom
984 | n10565667 scuba diver
985 | n11879895 rapeseed
986 | n11939491 daisy
987 | n12057211 yellow lady's slipper
988 | n12144580 corn
989 | n12267677 acorn
990 | n12620546 rose hip
991 | n12768682 horse chestnut seed
992 | n12985857 coral fungus
993 | n12998815 agaric
994 | n13037406 gyromitra
995 | n13040303 stinkhorn mushroom
996 | n13044778 earth star fungus
997 | n13052670 hen of the woods mushroom
998 | n13054560 bolete
999 | n13133613 corn cob
1000 | n15075141 toilet paper
1001 |
--------------------------------------------------------------------------------
/configs/classnames/oxford_flowers_classnames.json:
--------------------------------------------------------------------------------
1 | {"21": "fire lily", "3": "canterbury bells", "45": "bolero deep blue", "1": "pink primrose", "34": "mexican aster", "27": "prince of wales feathers", "7": "moon orchid", "16": "globe-flower", "25": "grape hyacinth", "26": "corn poppy", "79": "toad lily", "39": "siam tulip", "24": "red ginger", "67": "spring crocus", "35": "alpine sea holly", "32": "garden phlox", "10": "globe thistle", "6": "tiger lily", "93": "ball moss", "33": "love in the mist", "9": "monkshood", "102": "blackberry lily", "14": "spear thistle", "19": "balloon flower", "100": "blanket flower", "13": "king protea", "49": "oxeye daisy", "15": "yellow iris", "61": "cautleya spicata", "31": "carnation", "64": "silverbush", "68": "bearded iris", "63": "black-eyed susan", "69": "windflower", "62": "japanese anemone", "20": "giant white arum lily", "38": "great masterwort", "4": "sweet pea", "86": "tree mallow", "101": "trumpet creeper", "42": "daffodil", "22": "pincushion flower", "2": "hard-leaved pocket orchid", "54": "sunflower", "66": "osteospermum", "70": "tree poppy", "85": "desert-rose", "99": "bromelia", "87": "magnolia", "5": "english marigold", "92": "bee balm", "28": "stemless gentian", "97": "mallow", "57": "gaura", "40": "lenten rose", "47": "marigold", "59": "orange dahlia", "48": "buttercup", "55": "pelargonium", "36": "ruby-lipped cattleya", "91": "hippeastrum", "29": "artichoke", "71": "gazania", "90": "canna lily", "18": "peruvian lily", "98": "mexican petunia", "8": "bird of paradise", "30": "sweet william", "17": "purple coneflower", "52": "wild pansy", "84": "columbine", "12": "colt's foot", "11": "snapdragon", "96": "camellia", "23": "fritillary", "50": "common dandelion", "44": "poinsettia", "53": "primula", "72": "azalea", "65": "californian poppy", "80": "anthurium", "76": "morning glory", "37": "cape flower", "56": "bishop of llandaff", "60": "pink-yellow dahlia", "82": "clematis", "58": "geranium", "75": "thorn apple", "41": "barbeton daisy", "95": "bougainvillea", "43": "sword lily", "83": "hibiscus", "78": "lotus", "88": "cyclamen", "94": "foxglove", "81": "frangipani", "74": "rose", "89": "watercress", "73": "water lily", "46": "wallflower", "77": "passion flower", "51": "petunia"}
2 |
--------------------------------------------------------------------------------
/configs/classnames/pet_cat_to_name.json:
--------------------------------------------------------------------------------
1 | {"0": "abyssinian", "1": "american_bulldog", "2": "american_pit_bull_terrier", "3": "basset_hound", "4": "beagle", "5": "bengal", "6": "birman", "7": "bombay", "8": "boxer", "9": "british_shorthair", "10": "chihuahua", "11": "egyptian_mau", "12": "english_cocker_spaniel", "13": "english_setter", "14": "german_shorthaired", "15": "great_pyrenees", "16": "havanese", "17": "japanese_chin", "18": "keeshond", "19": "leonberger", "20": "maine_coon", "21": "miniature_pinscher", "22": "newfoundland", "23": "persian", "24": "pomeranian", "25": "pug", "26": "ragdoll", "27": "russian_blue", "28": "saint_bernard", "29": "samoyed", "30": "scottish_terrier", "31": "shiba_inu", "32": "siamese", "33": "sphynx", "34": "staffordshire_bull_terrier", "35": "wheaten_terrier", "36": "yorkshire_terrier"}
2 |
--------------------------------------------------------------------------------
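The class-name configs come in two layouts: plain-text lines pairing a WordNet ID with a human-readable name (imagenet_classnames.txt above) and JSON maps from a class index to a name (oxford_flowers_classnames.json and pet_cat_to_name.json). The repo reads them through get_all_classnames in tasks/utils.py, which is not shown in this section; the snippet below is only a minimal, hypothetical loader sketch for the two formats:

import json
from pathlib import Path


def load_wnid_classnames(path: str) -> dict[str, str]:
    """Parse lines like 'n04435653 tile roof' into {wnid: name}."""
    names = {}
    for line in Path(path).read_text().splitlines():
        if line.strip():
            wnid, name = line.split(" ", 1)
            names[wnid] = name
    return names


def load_index_classnames(path: str) -> dict[int, str]:
    """Parse {'21': 'fire lily', ...} into an int-keyed {index: name} dict."""
    return {int(k): v for k, v in json.loads(Path(path).read_text()).items()}


# Usage with the files in this repo:
# imagenet = load_wnid_classnames("configs/classnames/imagenet_classnames.txt")
# flowers = load_index_classnames("configs/classnames/oxford_flowers_classnames.json")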
/configs/models/maple/vit_b16_c2_ep5_batch4_2ctx.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 4
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.0035
18 | MAX_EPOCH: 5
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 20
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | MAPLE:
33 | N_CTX: 2
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
36 | PROMPT_DEPTH: 9
37 |
--------------------------------------------------------------------------------
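requirements.txt below pins yacs, the usual way to consume a Dassl/MaPLe-style YAML like the one above. The repo's real defaults live in src/models/config/default_config.py (not shown in this section), so starting from an empty node with new_allowed=True is only a minimal sketch of how the file can be merged and read:

from yacs.config import CfgNode as CN

# Minimal sketch: in the repo, defaults come from src/models/config/default_config.py.
cfg = CN(new_allowed=True)
cfg.merge_from_file("configs/models/maple/vit_b16_c2_ep5_batch4_2ctx.yaml")
cfg.freeze()

print(cfg.TRAINER.MAPLE.N_CTX)            # 2 learnable context tokens
print(cfg.MODEL.BACKBONE.NAME)            # "ViT-B/16"
print(cfg.OPTIM.LR, cfg.OPTIM.MAX_EPOCH)  # 0.0035 5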
/requirements.txt:
--------------------------------------------------------------------------------
1 | datasets==3.1.0
2 | einops==0.8.0
3 | ftfy==6.3.1
4 | geom-median==0.1.0
5 | gradio==5.8.0
6 | jaxtyping==0.2.36
7 | matplotlib==3.9.3
8 | nbformat
9 | numpy==1.26.4
10 | pandas==2.2.3
11 | Pillow==11.0.0
12 | plotly==5.24.1
13 | regex==2024.11.6
14 | Requests==2.32.3
15 | scipy==1.14.1
16 | setuptools==75.1.0
17 | torchvision==0.17.2
18 | tqdm==4.67.1
19 | transformer_lens==2.9.1
20 | transformers==4.46.3
21 | wandb==0.19.0
22 | yacs==0.1.8
23 |
--------------------------------------------------------------------------------
/scripts/01_run_train.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | PYTHONPATH=./ nohup python -u tasks/train_sae_vit.py \
4 | --batch_size 128 \
5 | --checkpoint_path patchsae_checkpoints \
6 | --n_checkpoints 10 \
7 | --use_ghost_grads \
8 | --log_to_wandb --wandb_project patchsae_test --wandb_entity hyesulim-hs \
9 | > logs/01_test_training.txt
10 |
--------------------------------------------------------------------------------
/scripts/02_run_compute_feature_data.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | PYTHONPATH=./ nohup python -u tasks/compute_sae_feature_data.py \
4 | --root_dir ./ \
5 | --dataset_name imagenet \
6 | --sae_path patchsae_checkpoints/YOUR-OWN-PATH/clip-vit-base-patch16_-2_resid_49152.pt \
7 | --vit_type base > logs/02_test_extract.txt
8 |
--------------------------------------------------------------------------------
/scripts/03_run_class_level.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | PYTHONPATH=./ nohup python -u tasks/compute_class_wise_sae_activation.py \
4 | --root_dir ./ \
5 | --dataset_name imagenet \
6 | --threshold 0.2 \
7 |     --sae_path patchsae_checkpoints/YOUR-OWN-PATH/clip-vit-base-patch16_-2_resid_49152.pt \

8 | --vit_type base > logs/03_test_class_level.txt 2> logs/03_test_class_level.err &
9 |
--------------------------------------------------------------------------------
/scripts/04_run_topk_eval.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | PYTHONPATH=./ python tasks/classification_with_top_k_masking.py \
4 | --root_dir ./ \
5 | --dataset_name imagenet \
6 | --sae_path patchsae_checkpoints/YOUR-OWN-PATH/clip-vit-base-patch16_-2_resid_49152.pt \
7 | --cls_wise_sae_activation_path ./out/feature_data/sae_openai/base/imagenet/cls_sae_cnt.npy \
8 | --vit_type base \
9 | > logs/04_test_topk_eval.txt
10 |
--------------------------------------------------------------------------------
/src/demo/app.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import gradio as gr
4 | import numpy as np
5 | import plotly.graph_objects as go
6 | import torch
7 | from PIL import Image, ImageDraw
8 | from plotly.subplots import make_subplots
9 |
10 | from src.demo.utils import load_sae_tester
11 |
12 | IMAGE_SIZE = 400
13 | DATASET_LIST = ["imagenet"]
14 | GRID_NUM = 14
15 |
16 |
17 | def get_grid_loc(evt, image):
18 | # Get click coordinates
19 | x, y = evt._data["index"][0], evt._data["index"][1]
20 |
21 | cell_width = image.width // GRID_NUM
22 | cell_height = image.height // GRID_NUM
23 |
24 | grid_x = x // cell_width
25 | grid_y = y // cell_height
26 | return grid_x, grid_y, cell_width, cell_height
27 |
28 |
29 | def plot_activation(
30 | evt: gr.EventData,
31 | current_image,
32 | activation,
33 | model_name: str,
34 | colors: tuple[str, str],
35 | ):
36 | """Plot activation distribution for the full image and optionally a selected tile"""
37 | mean_activation = activation.mean(0)
38 |
39 | tile_activation = None
40 | tile_x = None
41 | tile_y = None
42 |
43 | if evt is not None and evt._data is not None:
44 | tile_x, tile_y, _, _ = get_grid_loc(evt, current_image)
45 | token_idx = tile_y * GRID_NUM + tile_x + 1
46 | tile_activation = activation[token_idx]
47 |
48 | fig = create_activation_plot(
49 | mean_activation,
50 | tile_activation,
51 | tile_x,
52 | tile_y,
53 | model_name=model_name,
54 | colors=colors,
55 | )
56 |
57 | return fig
58 |
59 |
60 | def create_activation_plot(
61 | mean_activation,
62 | tile_activation=None,
63 | tile_x=None,
64 | tile_y=None,
65 | top_k=5,
66 | colors=("blue", "cyan"),
67 | model_name="CLIP",
68 | ):
69 | """Create plotly figure with activation traces and annotations"""
70 | fig = go.Figure()
71 |
72 | # Add trace for mean activation across full image
73 | model_label = model_name.split("-")[0]
74 | add_activation_trace(
75 | fig, mean_activation, f"{model_label} Image-level", colors[0], top_k
76 | )
77 |
78 | # Add trace for tile activation if provided
79 | if tile_activation is not None:
80 | add_activation_trace(
81 | fig,
82 | tile_activation,
83 | f"{model_label} Tile ({tile_x}, {tile_y})",
84 | colors[1],
85 | top_k,
86 | )
87 |
88 | # Update layout
89 | fig.update_layout(
90 | title="Activation Distribution",
91 | xaxis_title="SAE latent index",
92 | yaxis_title="Activation Value",
93 | template="plotly_white",
94 | legend=dict(orientation="h", yanchor="middle", y=0.5, xanchor="center", x=0.5),
95 | )
96 |
97 | return fig
98 |
99 |
100 | def add_activation_trace(fig, activation, label, color, top_k):
101 | """Add a single activation trace with annotations to the figure"""
102 | # Add line trace
103 | fig.add_trace(
104 | go.Scatter(
105 | x=np.arange(len(activation)),
106 | y=activation,
107 | mode="lines",
108 | name=label,
109 | line=dict(color=color, dash="solid"),
110 | showlegend=True,
111 | )
112 | )
113 |
114 | # Add annotations for top activations
115 | top_indices = np.argsort(activation)[::-1][:top_k]
116 | for idx in top_indices:
117 | fig.add_annotation(
118 | x=idx,
119 | y=activation[idx],
120 | text=str(idx),
121 | showarrow=True,
122 | arrowhead=2,
123 | ax=0,
124 | ay=-15,
125 | arrowcolor=color,
126 | opacity=0.7,
127 | )
128 |
129 |
130 | def plot_activation_distribution(
131 | evt: gr.EventData, current_image, clip_act, maple_act, model_name: str
132 | ):
133 | fig = make_subplots(
134 | rows=2,
135 | cols=1,
136 | shared_xaxes=True,
137 | subplot_titles=["CLIP Activation", f"{model_name} Activation"],
138 | )
139 |
140 | fig_clip = plot_activation(
141 | evt, current_image, clip_act, "CLIP", colors=("#00b4d8", "#90e0ef")
142 | )
143 | fig_maple = plot_activation(
144 | evt, current_image, maple_act, model_name, colors=("#ff5a5f", "#ffcad4")
145 | )
146 |
147 | def _attach_fig(fig, sub_fig, row, col, yref):
148 | for trace in sub_fig.data:
149 | fig.add_trace(trace, row=row, col=col)
150 |
151 | for annotation in sub_fig.layout.annotations:
152 | annotation.update(yref=yref)
153 | fig.add_annotation(annotation)
154 | return fig
155 |
156 | fig = _attach_fig(fig, fig_clip, row=1, col=1, yref="y1")
157 | fig = _attach_fig(fig, fig_maple, row=2, col=1, yref="y2")
158 |
159 | fig.update_xaxes(title_text="SAE Latent Index", row=2, col=1)
160 | fig.update_xaxes(title_text="SAE Latent Index", row=1, col=1)
161 | fig.update_yaxes(title_text="Activation Value", row=1, col=1)
162 | fig.update_yaxes(title_text="Activation Value", row=2, col=1)
163 | fig.update_layout(
164 | template="plotly_white",
165 | showlegend=True,
166 | legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5),
167 | margin=dict(l=20, r=20, t=40, b=20),
168 | )
169 |
170 | return fig
171 |
172 |
173 | def get_top_images(model_type, slider_value, toggle_btn):
174 | out_top_images = sae_tester[model_type].get_top_images(
175 | slider_value, top_k=5, show_seg_mask=toggle_btn
176 | )
177 |
178 | out_top_images = [plt_to_pil_direct(img) for img in out_top_images]
179 | return out_top_images
180 |
181 |
182 | def get_segmask(image, sae_act, slider_value):
183 | temp = sae_act[:, slider_value]
184 | mask = torch.Tensor(temp[1:,].reshape(14, 14)).view(1, 1, 14, 14)
185 | mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][
186 | 0
187 | ].numpy()
188 | mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
189 |
190 | base_opacity = 30
191 | image_array = np.array(image)[..., :3]
192 | rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
193 | rgba_overlay[..., :3] = image_array[..., :3]
194 |
195 | darkened_image = (image_array[..., :3] * (base_opacity / 255)).astype(np.uint8)
196 | rgba_overlay[mask == 0, :3] = darkened_image[mask == 0]
197 | rgba_overlay[..., 3] = 255 # Fully opaque
198 |
199 | return rgba_overlay
200 |
201 |
202 | def plt_to_pil_direct(fig):
203 | # Draw the canvas to render the figure
204 | fig.canvas.draw()
205 |
206 | # Convert the figure to a NumPy array
207 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
208 | width, height = fig.canvas.get_width_height()
209 | image = data.reshape((height, width, 3))
210 |
211 | # Create a PIL Image from the NumPy array
212 | return Image.fromarray(image)
213 |
214 |
215 | def show_segmentation_masks(
216 | selected_image, slider_value, sae_act, model_type, toggle_btn=False
217 | ):
218 | slider_value = int(slider_value.split("-")[-1])
219 | rgba_overlay = get_segmask(selected_image, sae_act, slider_value)
220 | top_images = get_top_images(model_type, slider_value, toggle_btn)
221 |
222 | act_values = []
223 | for dataset in REF_DATASET_LIST:
224 | act_value = sae_data_dict["mean_act_values"][dataset][slider_value, :5]
225 | act_value = [str(round(value.item(), 3)) for value in act_value]
226 | act_value = " | ".join(act_value)
227 | out = f"#### Activation values: {act_value}"
228 | act_values.append(out)
229 |
230 | return rgba_overlay, top_images, act_values
231 |
232 |
233 | def load_results(resized_image, radio_choice, clip_act, maple_act, toggle_btn):
234 | if clip_act is None:
235 | return None, None, None, None, None, None, None
236 |
237 | init_seg, init_tops, init_values = show_segmentation_masks(
238 | resized_image, radio_choice, clip_act, "CLIP", toggle_btn
239 | )
240 |
241 | slider_value = int(radio_choice.split("-")[-1])
242 | maple_init_seg = get_segmask(resized_image, maple_act, slider_value)
243 |
244 | out = (init_seg, maple_init_seg)
245 | out += tuple(init_tops)
246 | out += tuple(init_values)
247 | return out
248 |
249 |
250 | def load_image_and_act(image, clip_act, maple_act, model_name):
251 | resized_image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
252 | sae_tester["CLIP"].register_image(resized_image)
253 | clip_act = sae_tester["CLIP"].get_activation_distribution()
254 |
255 | sae_tester[model_name].register_image(resized_image)
256 | maple_act = sae_tester[model_name].get_activation_distribution()
257 |
258 | neuron_plot = plot_activation_distribution(
259 | None, resized_image, clip_act, maple_act, model_name
260 | )
261 |
262 | radio_names = get_init_radio_options(clip_act, maple_act)
263 | radio_choices = gr.Radio(
264 | choices=radio_names, label="Top activating SAE latent", value=radio_names[0]
265 | )
266 |     feature_idx = radio_names[0].split("-")[-1]
267 |     markdown_display = (
268 |         f"## Segmentation mask for the selected SAE latent - {feature_idx}"
269 | )
270 |
271 | return (
272 | resized_image,
273 | resized_image,
274 | neuron_plot,
275 | clip_act,
276 | maple_act,
277 | radio_choices,
278 | markdown_display,
279 | )
280 |
281 |
282 | def highlight_grid(evt: gr.EventData, image, clip_act, maple_act, model_name):
283 | grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
284 |
285 | highlighted_image = image.copy()
286 | draw = ImageDraw.Draw(highlighted_image)
287 | box = [
288 | grid_x * cell_width,
289 | grid_y * cell_height,
290 | (grid_x + 1) * cell_width,
291 | (grid_y + 1) * cell_height,
292 | ]
293 | draw.rectangle(box, outline="red", width=3)
294 |
295 | neuron_plot = plot_activation_distribution(
296 | evt, image, clip_act, maple_act, model_name
297 | )
298 |
299 | radio, choices = update_radio_options(clip_act, maple_act, grid_x, grid_y)
300 |     feature_idx = choices[0].split("-")[-1]
301 |     markdown_display = (
302 |         f"## Segmentation mask for the selected SAE latent - {feature_idx}"
303 | )
304 |
305 | return (highlighted_image, neuron_plot, radio, markdown_display)
306 |
307 |
308 | def get_init_radio_options(clip_act, maple_act):
309 | clip_neuron_dict = {}
310 | maple_neuron_dict = {}
311 |
312 |     def _get_top_activation(activations, neuron_dict, top_k=5):
313 | activations = activations.mean(0)
314 | top_neurons = list(np.argsort(activations)[::-1][:top_k])
315 | for top_neuron in top_neurons:
316 | neuron_dict[top_neuron] = activations[top_neuron]
317 | sorted_dict = dict(
318 | sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True)
319 | )
320 | return sorted_dict
321 |
322 |     clip_neuron_dict = _get_top_activation(clip_act, clip_neuron_dict)
323 |     maple_neuron_dict = _get_top_activation(maple_act, maple_neuron_dict)
324 |
325 | radio_choices = get_radio_names(clip_neuron_dict, maple_neuron_dict)
326 |
327 | return radio_choices
328 |
329 |
330 | def update_radio_options(clip_act, maple_act, grid_x, grid_y):
331 | def _sort_and_save_top_k(activations, neuron_dict, top_k=5):
332 | top_neurons = list(np.argsort(activations)[::-1][:top_k])
333 | for top_neuron in top_neurons:
334 | neuron_dict[top_neuron] = activations[top_neuron]
335 |
336 |     def _get_top_activation(activations, neuron_dict, token_idx):
337 | image_activation = activations.mean(0)
338 | _sort_and_save_top_k(image_activation, neuron_dict)
339 |
340 | tile_activations = activations[token_idx]
341 | _sort_and_save_top_k(tile_activations, neuron_dict)
342 |
343 | sorted_dict = dict(
344 | sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True)
345 | )
346 | return sorted_dict
347 |
348 | token_idx = grid_y * GRID_NUM + grid_x + 1
349 | clip_neuron_dict = {}
350 | maple_neuron_dict = {}
351 |     clip_neuron_dict = _get_top_activation(clip_act, clip_neuron_dict, token_idx)
352 |     maple_neuron_dict = _get_top_activation(maple_act, maple_neuron_dict, token_idx)
353 |
354 | clip_keys = list(clip_neuron_dict.keys())
355 | maple_keys = list(maple_neuron_dict.keys())
356 |
357 | common_keys = list(set(clip_keys).intersection(set(maple_keys)))
358 | clip_only_keys = list(set(clip_keys) - (set(maple_keys)))
359 | maple_only_keys = list(set(maple_keys) - (set(clip_keys)))
360 |
361 | common_keys.sort(
362 | key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True
363 | )
364 | clip_only_keys.sort(reverse=True)
365 | maple_only_keys.sort(reverse=True)
366 |
367 | out = []
368 | out.extend([f"common-{i}" for i in common_keys[:5]])
369 | out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]])
370 | out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]])
371 |
372 | radio_choices = gr.Radio(
373 | choices=out, label="Top activating SAE latent", value=out[0]
374 | )
375 | return radio_choices, out
376 |
377 |
378 | def get_radio_names(clip_neuron_dict, maple_neuron_dict):
379 | clip_keys = list(clip_neuron_dict.keys())
380 | maple_keys = list(maple_neuron_dict.keys())
381 |
382 | common_keys = list(set(clip_keys).intersection(set(maple_keys)))
383 | clip_only_keys = list(set(clip_keys) - (set(maple_keys)))
384 | maple_only_keys = list(set(maple_keys) - (set(clip_keys)))
385 |
386 | common_keys.sort(
387 | key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True
388 | )
389 | clip_only_keys.sort(reverse=True)
390 | maple_only_keys.sort(reverse=True)
391 |
392 | out = []
393 | out.extend([f"common-{i}" for i in common_keys[:5]])
394 | out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]])
395 | out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]])
396 |
397 | return out
398 |
399 |
400 | if __name__ == "__main__":
401 | parser = argparse.ArgumentParser()
402 | parser.add_argument(
403 | "--include-imagenet",
404 | action="store_true",
405 | default=False,
406 | help="Include ImageNet in the Demo",
407 | )
408 | args = parser.parse_args()
409 |
410 | sae_tester = load_sae_tester("./data/sae_weight/base/out.pt", args.include_imagenet)
411 | sae_data_dict = {"mean_act_values": {}}
412 | if args.include_imagenet:
413 | REF_DATASET_LIST = ["imagenet", "imagenet-sketch", "caltech101"]
414 | else:
415 | REF_DATASET_LIST = ["imagenet-sketch", "caltech101"]
416 | for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
417 | data = torch.load(
418 | f"./out/feature_data/sae_base/base/{dataset}/max_activating_image_values.pt",
419 | map_location="cpu",
420 | )
421 | sae_data_dict["mean_act_values"][dataset] = data
422 |
423 | with gr.Blocks(
424 | theme=gr.themes.Citrus(),
425 | css="""
426 | .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
427 | .image-row img { width: auto; height: 50px; } /* Set a uniform height for all images */
428 | """,
429 | ) as demo:
430 | with gr.Row():
431 | with gr.Column():
432 | # Left View: Image selection and click handling
433 | gr.Markdown("## Select input image and patch on the image")
434 |
435 | current_image = gr.State()
436 | clip_act = gr.State()
437 | maple_act = gr.State()
438 |
439 | image_display = gr.Image(type="pil", interactive=True)
440 |
441 | with gr.Column():
442 | gr.Markdown("## SAE latent activations of CLIP and MaPLE")
443 | model_options = [
444 | f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST
445 | ]
446 | model_selector = gr.Dropdown(
447 | choices=model_options,
448 | value=model_options[0],
449 | label="Select adapted model (MaPLe)",
450 | )
451 |
452 | neuron_plot = gr.Plot(label="Neuron Activation", show_label=False)
453 |
454 | with gr.Row():
455 | with gr.Column():
456 | markdown_display = gr.Markdown(
457 | "## Segmentation mask for the selected SAE latent - "
458 | )
459 | gr.Markdown("### Localize SAE latent activation using CLIP")
460 | seg_mask_display = gr.Image(type="pil", show_label=False)
461 |
462 | gr.Markdown("### Localize SAE latent activation using MaPLE")
463 | seg_mask_display_maple = gr.Image(type="pil", show_label=False)
464 |
465 | with gr.Column():
466 | radio_choices = gr.Radio(
467 | choices=[],
468 | label="Top activating SAE latent",
469 | )
470 | toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
471 |
472 | image_display_dict = {}
473 | activation_dict = {}
474 | for dataset in REF_DATASET_LIST:
475 | image_display_dict[dataset] = gr.Image(
476 | type="pil", label=dataset, show_label=False
477 | )
478 | activation_dict[dataset] = gr.Markdown("")
479 |
480 | image_display.upload(
481 | fn=load_image_and_act,
482 | inputs=[image_display, clip_act, maple_act, model_selector],
483 | outputs=[
484 | image_display,
485 | current_image,
486 | neuron_plot,
487 | clip_act,
488 | maple_act,
489 | radio_choices,
490 | markdown_display,
491 | ],
492 | )
493 |
494 | outputs = [seg_mask_display, seg_mask_display_maple]
495 | outputs += list(image_display_dict.values())
496 | outputs += list(activation_dict.values())
497 |
498 | radio_choices.change(
499 | fn=load_results,
500 | inputs=[current_image, radio_choices, clip_act, maple_act, toggle_btn],
501 | outputs=outputs,
502 | )
503 |
504 | toggle_btn.change(
505 | fn=load_results,
506 | inputs=[current_image, radio_choices, clip_act, maple_act, toggle_btn],
507 | outputs=outputs,
508 | )
509 |
510 | image_display.select(
511 | fn=highlight_grid,
512 | inputs=[current_image, clip_act, maple_act, model_selector],
513 | outputs=[
514 | image_display,
515 | neuron_plot,
516 | radio_choices,
517 | markdown_display,
518 | ],
519 | )
520 |
521 | radio_choices.change(
522 | fn=load_results,
523 | inputs=[current_image, radio_choices, clip_act, maple_act, toggle_btn],
524 | outputs=outputs,
525 | )
526 |
527 | demo.launch()
528 |
--------------------------------------------------------------------------------
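A detail worth keeping in mind when reading app.py (and core.py below): a click on the displayed image is mapped to a cell of the 14x14 patch grid and then to a ViT token index, where index 0 is the CLS token, so patch tokens occupy indices 1..196 and activations are sliced to 197 tokens. A tiny self-contained sketch of that arithmetic:

GRID_NUM = 14     # ViT-B/16 on 224x224 inputs gives a 14x14 patch grid
IMAGE_SIZE = 400  # the demo resizes uploaded images to 400x400 for display


def click_to_token_idx(x: int, y: int) -> int:
    """Map click coordinates on the displayed image to a ViT token index."""
    cell = IMAGE_SIZE // GRID_NUM
    grid_x, grid_y = x // cell, y // cell
    return grid_y * GRID_NUM + grid_x + 1  # +1 skips the CLS token at index 0


assert click_to_token_idx(0, 0) == 1        # top-left patch
assert click_to_token_idx(391, 391) == 196  # bottom-right patch; 1 + 14*14 = 197 tokens total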
/src/demo/core.py:
--------------------------------------------------------------------------------
1 | import os
2 | from copy import deepcopy
3 | from io import BytesIO
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import requests
8 | import torch
9 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
10 | from PIL import Image
11 |
12 |
13 | class UtilMixin:
14 | def _get_max_activating_images_and_labels(
15 | self, neuron_idx, dataset, max_activating_image_indices
16 | ):
17 | img_list = max_activating_image_indices[neuron_idx]
18 | images = []
19 | labels = []
20 | for i in img_list:
21 | try:
22 | images.append(dataset[i.item()]["image"])
23 | labels.append(dataset[i.item()]["label"])
24 | except Exception:
25 | images.append(dataset[i.item()]["jpg"])
26 | labels.append(dataset[i.item()]["cls"])
27 | return images, labels
28 |
29 | def _create_patches(self, patch=256):
30 | temp = self.processed_image["pixel_values"].clone()
31 | patches = temp[0].data.unfold(0, 3, 3)
32 | patches = patches.unfold(1, patch, patch)
33 | patches = patches.unfold(2, patch, patch)
34 | return patches
35 |
36 |
37 | class VisualizeMixin:
38 | def _plot_input_image(self):
39 | plt.imshow(self.input_image)
40 |
41 | def _plot_feature_mask(self, patches, feat_idx, mask=None, plot=True):
42 | if mask is None:
43 | mask = self.sae_act[0, :, feat_idx].cpu()
44 |
45 | fig, axs = plt.subplots(patches.size(1), patches.size(2), figsize=(6, 6))
46 | plt.subplots_adjust(wspace=0.01, hspace=0.01)
47 |
48 | for i in range(patches.size(1)):
49 | for j in range(patches.size(2)):
50 | patch = patches[0, i, j].permute(1, 2, 0)
51 | patch *= torch.tensor(self.vit.processor.image_processor.image_std)
52 | patch += torch.tensor(self.vit.processor.image_processor.image_mean)
53 | masked_patch = patch * mask[i * patches.size(2) + j + 1]
54 | masked_patch = (masked_patch - masked_patch.min()) / (
55 | masked_patch.max() - masked_patch.min() + 1e-8
56 | )
57 | axs[i, j].imshow(masked_patch)
58 | axs[i, j].axis("off")
59 |
60 | fig.suptitle(feat_idx)
61 | plt.close()
62 | return fig
63 |
64 | def _plot_patches(self, patches, highlight_patch_idx=None):
65 | fig, axs = plt.subplots(patches.size(1), patches.size(2), figsize=(6, 6))
66 | plt.subplots_adjust(wspace=0.01, hspace=0.01)
67 | for i in range(patches.size(1)):
68 | for j in range(patches.size(2)):
69 | patch = patches[0, i, j].permute(1, 2, 0)
70 | patch *= torch.tensor(self.vit.processor.image_processor.image_std)
71 | patch += torch.tensor(self.vit.processor.image_processor.image_mean)
72 | axs[i, j].imshow(patch)
73 | if i * patches.size(2) + j == highlight_patch_idx:
74 | for spine in axs[i, j].spines.values():
75 | spine.set_edgecolor("red")
76 | spine.set_linewidth(3)
77 | axs[i, j].set_xticks([])
78 | axs[i, j].set_yticks([])
79 | else:
80 | axs[i, j].axis("off")
81 | plt.show()
82 |
83 |     def _plot_union_top_neurons(
84 | self, top_k, union_top_neurons, token_idx, token_act, save=False
85 | ):
86 | print(f"Union of top {top_k} neurons: {union_top_neurons}")
87 |
88 | plt.figure(figsize=(10, 5))
89 | plt.plot(token_act)
90 | plt.plot(
91 | union_top_neurons,
92 | token_act[union_top_neurons],
93 | "ro",
94 | label="Top neurons",
95 | markersize=5,
96 | )
97 |
98 | # Annotate feature indices
99 | for idx in union_top_neurons:
100 | plt.text(
101 | idx, token_act[idx] + 0.05, str(idx), fontsize=9, ha="center"
102 | ) # Adjust the 0.05 value as needed for spacing
103 |
104 | plt.legend()
105 | plt.title(f"token {token_idx} activation")
106 |
107 | if save:
108 | img_name = os.path.basename(self.img_url).replace(".jpg", "")
109 | save_name = f"{self.save_dir}/{img_name}/activation/{token_idx}.jpg"
110 | os.makedirs(os.path.dirname(save_name), exist_ok=True)
111 | plt.savefig(save_name)
112 |
113 | plt.show()
114 | plt.close()
115 |
116 | def _plot_images(
117 | self,
118 | dataset_name,
119 | images,
120 | neuron_idx,
121 | labels=None,
122 | suptitle=None,
123 | top_k=5,
124 | save=False,
125 | ):
126 | images = [img.resize((224, 224)) for img in images]
127 | num_cols = min(top_k, 5)
128 | num_rows = (top_k + num_cols - 1) // num_cols
129 | fig, axes = plt.subplots(
130 | num_rows, num_cols, figsize=(4.5 * num_cols, 5 * num_rows)
131 | )
132 | axes = axes.flatten() # Flatten the 2D array of axes
133 |
134 | for i in range(top_k):
135 | axes[i].imshow(images[i]) # Display the image
136 | axes[i].axis("off") # Hide axes
137 | if labels is not None:
138 | class_name = self.class_names[dataset_name][int(labels[i])]
139 | axes[i].set_title(f"{labels[i]} {class_name}", fontsize=25)
140 | # plt.suptitle(suptitle)
141 | plt.tight_layout()
142 |
143 | if save:
144 | img_name = os.path.basename(self.img_url).replace(".jpg", "")
145 | save_name = (
146 | f"{self.save_dir}/{img_name}/top_images/{dataset_name}/{neuron_idx}.jpg"
147 | )
148 | os.makedirs(os.path.dirname(save_name), exist_ok=True)
149 | plt.savefig(save_name)
150 |
151 | plt.close()
152 |
153 | return fig
154 |
155 | def _fig_to_img(self, fig):
156 | canvas = FigureCanvas(fig)
157 | canvas.draw()
158 | img = np.frombuffer(canvas.tostring_rgb(), dtype="uint8")
159 | img = img.reshape(canvas.get_width_height()[::-1] + (3,))
160 | return img
161 |
162 | def _plot_multiple_images(self, figs, neuron_idx, top_k=5, save=False):
163 | # Create a new figure to hold all subplots
164 | num_plots = len(figs)
165 | cols = 1 # Number of columns in the subplot grid
166 | rows = (num_plots + cols - 1) // cols # Calculate rows required
167 |
168 | combined_fig = plt.figure(figsize=(20, 12)) # Adjust figsize as needed
169 |
170 | for i, fig in enumerate(figs):
171 | ax = combined_fig.add_subplot(rows, cols, i + 1)
172 | img = self._fig_to_img(fig)
173 | ax.imshow(img)
174 | ax.axis("off")
175 |
176 | if save:
177 | img_name = os.path.basename(self.img_url).replace(".jpg", "")
178 | save_name = f"{self.save_dir}/{img_name}/top_images/{neuron_idx}.jpg"
179 | os.makedirs(os.path.dirname(save_name), exist_ok=True)
180 | plt.savefig(save_name)
181 |
182 | combined_fig.show()
183 | # plt.close(combined_fig)
184 |
185 |
186 | class SAETester(VisualizeMixin, UtilMixin):
187 | def __init__(
188 | self,
189 | vit,
190 | cfg,
191 | sae,
192 | mean_acts,
193 | max_act_images,
194 | datasets,
195 | class_names,
196 | noisy_threshold=0.1,
197 | device="cpu",
198 | ):
199 | self.vit = vit
200 | self.cfg = cfg
201 | self.sae = sae
202 | self.mean_acts = mean_acts
203 | self.max_act_images = max_act_images
204 | self.datasets = datasets
205 | self.class_names = class_names
206 | self.noisy_threshold = noisy_threshold
207 | self.device = device
208 |
209 | def register_image(self, img_url: str) -> None:
210 | """Load and process an image from a URL or local path."""
211 | if isinstance(img_url, str):
212 | image = self._load_image(img_url)
213 | else:
214 | image = img_url
215 | self.input_image = image
216 | self.processed_image = self.vit.processor(
217 | images=image, text="", return_tensors="pt", padding=True
218 | )
219 |
220 | def _load_image(self, img_url: str) -> Image.Image:
221 | """Helper method to load image from URL or local path."""
222 | if "http" in img_url:
223 | response = requests.get(img_url)
224 | response.raise_for_status()
225 | return Image.open(BytesIO(response.content))
226 | return Image.open(img_url)
227 |
228 | @property
229 | def processed_image(self):
230 | return self._processed_image
231 |
232 | @processed_image.setter
233 | def processed_image(self, value):
234 | self._processed_image = value
235 |
236 | @property
237 | def input_image(self):
238 | return self._input_image
239 |
240 | @input_image.setter
241 | def input_image(self, value):
242 | self._input_image = value
243 |
244 | def show_input_image(self):
245 | self._plot_input_image()
246 |
247 | def run(
248 | self, highlight_patch_idx, patch_size=16, top_k=5, num_images=5, seg_mask=True
249 | ):
250 | # idx = 0 is cls token
251 | self.show_patches(
252 | highlight_patch_idx=highlight_patch_idx - 1, patch_size=patch_size
253 | )
254 | top_neurons = self.get_top_neurons(highlight_patch_idx, top_k=top_k)
255 | self.show_ref_images_of_neuron_indices(
256 |             top_neurons, top_k=num_images, seg_mask=seg_mask
257 | )
258 |
259 | def show_patches(self, highlight_patch_idx=None, patch_size=16):
260 | if not hasattr(self, "input_image"):
261 |             raise RuntimeError("register image first")
262 |
263 | patches = self._create_patches(patch=patch_size)
264 | self._plot_patches(patches.cpu().data, highlight_patch_idx=highlight_patch_idx)
265 |
266 | def show_segmentation_mask(self, feat_idx, patch_size=16, mask=None, plot=True):
267 | patches = self._create_patches(patch=patch_size)
268 | fig = self._plot_feature_mask(
269 |             patches.cpu().data, feat_idx, mask=mask, plot=plot
270 | )
271 | return fig
272 |
273 | def get_segmentation_mask(self, image, feat_idx: int):
274 | if image.mode == "L":
275 | image = image.convert("RGB")
276 |
277 | vit_act = self._run_vit_hook(image)
278 | sae_act = self._run_sae_hook(vit_act)
279 | token_act = sae_act[0].detach().cpu().numpy()
280 |         filtered_mean_act = self._filter_out_noisy_activation(token_act)
281 |
282 | temp = filtered_mean_act[:, feat_idx]
283 | mask = torch.Tensor(temp[1:,].reshape(14, 14)).view(1, 1, 14, 14)
284 | mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][
285 | 0
286 | ].numpy()
287 | mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
288 |
289 | base_opacity = 30
290 | image_array = np.array(image)[..., :3]
291 | rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
292 | rgba_overlay[..., :3] = image_array[..., :3]
293 |
294 | darkened_image = (image_array[..., :3] * (base_opacity / 255)).astype(np.uint8)
295 | rgba_overlay[mask == 0, :3] = darkened_image[mask == 0]
296 | rgba_overlay[..., 3] = 255 # Fully opaque
297 |
298 | return Image.fromarray(rgba_overlay)
299 |
300 | def get_top_neurons(self, token_idx=None, top_k=5, plot=True):
301 | if token_idx is None:
302 | token_acts, top_neurons, self.sae_act = self._get_img_acts_and_top_neurons(
303 | top_k=top_k
304 | )
305 | else:
306 | token_acts, top_neurons, self.sae_act = (
307 | self._get_token_acts_and_top_neurons(token_idx=token_idx, top_k=top_k)
308 | )
309 | if plot:
310 |             self._plot_union_top_neurons(top_k, top_neurons, token_idx, token_acts)
311 | return top_neurons
312 |
313 | def get_top_images(self, neuron_idx: int, top_k=5, show_seg_mask=False):
314 | out_top_images = []
315 | for dataset_name in self.datasets.keys():
316 | if self.datasets[dataset_name] is None:
317 | continue
318 | images, labels = self._get_max_activating_images_and_labels(
319 | neuron_idx,
320 | self.datasets[dataset_name],
321 | self.max_act_images[dataset_name],
322 | )
323 |
324 | if show_seg_mask:
325 | images = [
326 | self.get_segmentation_mask(img, neuron_idx)
327 | for img in images[:top_k]
328 | ]
329 | suptitle = f"{dataset_name} - {neuron_idx}"
330 | fig = self._plot_images(
331 | dataset_name,
332 | images,
333 | neuron_idx,
334 | labels,
335 | suptitle=suptitle,
336 | top_k=top_k,
337 | save=False,
338 | )
339 | out_top_images.append(fig)
340 |
341 | return out_top_images
342 |
343 | def show_ref_images_of_neuron_indices(
344 | self, neuron_indices: list[int], top_k=5, save=False, seg_mask=False
345 | ):
346 | for neuron_idx in neuron_indices:
347 | figs = self.get_top_images(neuron_idx, top_k=top_k, show_seg_mask=False)
348 |             self._plot_multiple_images(figs, neuron_idx, top_k=top_k, save=False)
349 |
350 | if seg_mask:
351 | figs = self.get_top_images(neuron_idx, top_k=top_k, show_seg_mask=True)
352 | self._plot_multiple_images(
353 |                     figs, neuron_idx, top_k=top_k, save=False
354 | )
355 |
356 | def get_activation_distribution(self):
357 | vit_act = self._run_vit_hook()
358 | sae_act = self._run_sae_hook(vit_act)
359 | token_act = sae_act[0].detach().cpu().numpy()
360 |         filtered_mean_act = self._filter_out_noisy_activation(token_act)
361 | self.sae_act = sae_act
362 | return filtered_mean_act
363 |
364 | def _get_img_acts_and_top_neurons(self, top_k=5, threshold=0.2):
365 | vit_act = self._run_vit_hook()
366 | sae_act = self._run_sae_hook(vit_act)
367 |
368 | token_act = sae_act[0].detach().cpu().numpy()
369 |         filtered_mean_act = self._filter_out_noisy_activation(token_act)
370 | token_act = (filtered_mean_act > threshold).sum(0)
371 | filtered_mean_act = filtered_mean_act.sum(0)
372 | top_neurons = np.argsort(filtered_mean_act)[::-1][:top_k]
373 |
374 | return token_act, top_neurons, sae_act
375 |
376 | def _get_token_acts_and_top_neurons(self, token_idx, top_k=5):
377 | vit_act = self._run_vit_hook()
378 | sae_act = self._run_sae_hook(vit_act)
379 |
380 | token_act = sae_act[0, token_idx, :].detach().cpu().numpy()
381 |         filtered_mean_act = self._filter_out_noisy_activation(token_act)
382 | top_neurons = np.argsort(filtered_mean_act)[::-1][:top_k]
383 |
384 | return token_act, top_neurons, sae_act
385 |
386 | def _run_vit_hook(self, image=None):
387 | if image is None:
388 | inputs = self.processed_image.to(self.device)
389 | else:
390 | inputs = self.vit.processor(
391 | images=image, text="", return_tensors="pt", padding=True
392 | )
393 | list_of_hook_locations = [(self.cfg.block_layer, self.cfg.module_name)]
394 | vit_out, vit_cache_dict = self.vit.run_with_cache(
395 | list_of_hook_locations, **inputs
396 | )
397 | vit_act = vit_cache_dict[(self.cfg.block_layer, self.cfg.module_name)]
398 | return vit_act
399 |
400 | def _run_sae_hook(self, vit_act):
401 | sae_out, sae_cache_dict = self.sae.run_with_cache(vit_act)
402 | sae_act = sae_cache_dict["hook_hidden_post"]
403 | if sae_act.shape[0] != 1:
404 | sae_act = sae_act.permute(1, 0, 2)
405 | return sae_act[:, :197, :]
406 |
407 |     def _filter_out_noisy_activation(self, features):
408 | noisy_features_indices = (
409 | (self.mean_acts["imagenet"] > self.noisy_threshold).nonzero()[0].tolist()
410 | )
411 | features_copy = deepcopy(features)
412 | if len(features_copy.shape) == 1:
413 | features_copy[noisy_features_indices] = 0
414 | elif len(features_copy.shape) == 2:
415 | features_copy[:, noisy_features_indices] = 0
416 | return features_copy
417 |
--------------------------------------------------------------------------------
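get_segmask in app.py and SAETester.get_segmentation_mask above build the patch-level overlay the same way: take one SAE latent's activations over the 196 patch tokens, reshape them onto the 14x14 grid, upsample to pixel resolution, and normalize. A stripped-down sketch with random stand-in activations (d_sae = 49152 matches the checkpoint names used in the scripts):

import numpy as np
import torch

d_sae = 49152                 # SAE width, as in clip-vit-base-patch16_-2_resid_49152.pt
sae_act = np.random.rand(197, d_sae).astype(np.float32)  # CLS + 196 patch tokens
feat_idx = 123                # arbitrary latent index for illustration
height, width = 400, 400      # display resolution used by the demo

# Drop the CLS token, lay the 196 patch activations on the 14x14 grid,
# and upsample to pixel resolution.
patch_act = sae_act[1:, feat_idx].reshape(14, 14)
mask = torch.tensor(patch_act).view(1, 1, 14, 14)
mask = torch.nn.functional.interpolate(mask, (height, width))[0, 0].numpy()

# Normalize to [0, 1]; the demo then darkens pixels where the mask is zero.
mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
print(mask.shape)  # (400, 400)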
/src/demo/utils.py:
--------------------------------------------------------------------------------
1 | from src.demo.core import SAETester
2 | from tasks.utils import (
3 | get_all_classnames,
4 | get_max_acts_and_images,
5 | get_sae_and_vit,
6 | load_datasets,
7 | )
8 |
9 |
10 | def load_sae_tester(sae_path, include_imagenet=False):
11 | datasets = load_datasets(include_imagenet=include_imagenet)
12 | classnames = get_all_classnames(datasets, data_root="./configs/classnames")
13 |
14 | root = "./out/feature_data"
15 | sae_runname = "sae_base"
16 | vit_name = "base"
17 |
18 | if include_imagenet is False:
19 | datasets["imagenet"] = None
20 |
21 | max_act_imgs, mean_acts = get_max_acts_and_images(
22 | datasets, root, sae_runname, vit_name
23 | )
24 |
25 | sae_tester = {}
26 |
27 | sae, vit, cfg = get_sae_and_vit(
28 | sae_path=sae_path,
29 | vit_type="base",
30 | device="cpu",
31 | backbone="openai/clip-vit-base-patch16",
32 | model_path=None,
33 | classnames=None,
34 | )
35 | sae_clip = SAETester(vit, cfg, sae, mean_acts, max_act_imgs, datasets, classnames)
36 |
37 | sae, vit, cfg = get_sae_and_vit(
38 | sae_path=sae_path,
39 | vit_type="maple",
40 | device="cpu",
41 | model_path="./data/clip/maple/imagenet/model.pth.tar-2",
42 | config_path="./configs/models/maple/vit_b16_c2_ep5_batch4_2ctx.yaml",
43 | backbone="openai/clip-vit-base-patch16",
44 | classnames=classnames["imagenet"],
45 | )
46 | sae_maple = SAETester(vit, cfg, sae, mean_acts, max_act_imgs, datasets, classnames)
47 | sae_tester["CLIP"] = sae_clip
48 | sae_tester["MaPLE-imagenet"] = sae_maple
49 | return sae_tester
50 |
--------------------------------------------------------------------------------
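Outside the Gradio app, the same SAETester objects can be driven directly. A hedged usage sketch: it assumes the SAE checkpoint, the MaPLe weights, and the precomputed feature data already sit at the paths hard-coded above, and "some_image.jpg" is a placeholder for any local image:

from PIL import Image

from src.demo.utils import load_sae_tester

# Assumes ./data/sae_weight/base/out.pt, ./data/clip/maple/... and
# ./out/feature_data/... exist as expected by load_sae_tester.
sae_tester = load_sae_tester("./data/sae_weight/base/out.pt", include_imagenet=False)

clip_tester = sae_tester["CLIP"]
clip_tester.register_image(Image.open("some_image.jpg").resize((400, 400)))

# (n_tokens, n_latents) array of SAE activations with noisy latents zeroed out.
acts = clip_tester.get_activation_distribution()
top_latents = acts.mean(0).argsort()[::-1][:5]
print(top_latents)  # indices of the five strongest image-level latents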
/src/models/architecture/maple.py:
--------------------------------------------------------------------------------
1 | """
2 | # Portions of this file are based on code from the “multimodal-prompt-learning” repository (MIT-licensed):
3 | # https://github.com/muzairkhattak/multimodal-prompt-learning/blob/main/trainers/maple.py
4 | """
5 |
6 | import copy
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | from src.models.clip.clip import tokenize
12 | from src.models.clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
13 |
14 | _tokenizer = _Tokenizer()
15 |
16 |
17 | class TextEncoder(nn.Module):
18 | def __init__(self, clip_model):
19 | super().__init__()
20 | self.transformer = clip_model.transformer
21 | self.positional_embedding = clip_model.positional_embedding
22 | self.ln_final = clip_model.ln_final
23 | self.text_projection = clip_model.text_projection
24 | self.dtype = clip_model.dtype
25 |
26 | def forward(self, prompts, tokenized_prompts, compound_prompts_deeper_text):
27 | x = prompts + self.positional_embedding.type(self.dtype)
28 | x = x.permute(1, 0, 2) # NLD -> LND
29 | # Pass as the list, as nn.sequential cannot process multiple arguments in the forward pass
30 | combined = [
31 | x,
32 | compound_prompts_deeper_text,
33 | 0,
34 | ] # third argument is the counter which denotes depth of prompt
35 | outputs = self.transformer(combined)
36 | x = outputs[0] # extract the x back from here
37 | x = x.permute(1, 0, 2) # LND -> NLD
38 | x = self.ln_final(x).type(self.dtype)
39 |
40 | # x.shape = [batch_size, n_ctx, transformer.width]
41 | # take features from the eot embedding (eot_token is the highest number in each sequence)
42 | x = (
43 | x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)]
44 | @ self.text_projection
45 | )
46 |
47 | return x
48 |
49 |
50 | class MultiModalPromptLearner(nn.Module):
51 | def __init__(self, cfg, classnames, clip_model):
52 | super().__init__()
53 | n_cls = len(classnames)
54 | n_ctx = cfg.TRAINER.MAPLE.N_CTX
55 | ctx_init = cfg.TRAINER.MAPLE.CTX_INIT
56 | dtype = clip_model.dtype
57 | ctx_dim = clip_model.ln_final.weight.shape[0]
58 | clip_imsize = clip_model.visual.input_resolution
59 | cfg_imsize = cfg.INPUT.SIZE[0]
60 | # Default is 1, which is compound shallow prompting
61 | assert cfg.TRAINER.MAPLE.PROMPT_DEPTH >= 1, (
62 | "For MaPLe, PROMPT_DEPTH should be >= 1"
63 | )
64 | self.compound_prompts_depth = (
65 | cfg.TRAINER.MAPLE.PROMPT_DEPTH
66 | ) # max=12, but will create 11 such shared prompts
67 | assert cfg_imsize == clip_imsize, (
68 |             f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"
69 | )
70 |
71 | if ctx_init and (n_ctx) <= 4:
72 | # use given words to initialize context vectors
73 | ctx_init = ctx_init.replace("_", " ")
74 | n_ctx = n_ctx
75 | prompt = tokenize(ctx_init)
76 | with torch.no_grad():
77 | embedding = clip_model.token_embedding(prompt).type(dtype)
78 | ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
79 | prompt_prefix = ctx_init
80 | else:
81 | # random initialization
82 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
83 | nn.init.normal_(ctx_vectors, std=0.02)
84 | prompt_prefix = " ".join(["X"] * n_ctx)
85 | print("MaPLe design: Multi-modal Prompt Learning")
86 | print(f'Initial context: "{prompt_prefix}"')
87 | print(f"Number of MaPLe context words (tokens): {n_ctx}")
88 | # These below, related to the shallow prompts
89 | # Linear layer so that the tokens will project to 512 and will be initialized from 768
90 | self.proj = nn.Linear(ctx_dim, 768).type(dtype)
91 | # self.proj.half()
92 | self.ctx = nn.Parameter(ctx_vectors)
93 | # These below parameters related to the shared prompts
94 | # Define the compound prompts for the deeper layers
95 |
96 | # Minimum can be 1, which defaults to shallow MaPLe
97 | # compound prompts
98 | self.compound_prompts_text = nn.ParameterList(
99 | [
100 | nn.Parameter(torch.empty(n_ctx, 512))
101 | for _ in range(self.compound_prompts_depth - 1)
102 | ]
103 | )
104 | for single_para in self.compound_prompts_text:
105 | nn.init.normal_(single_para, std=0.02)
106 | # Also make corresponding projection layers, for each prompt
107 | single_layer = nn.Linear(ctx_dim, 768)
108 | self.compound_prompt_projections = _get_clones(
109 | single_layer, self.compound_prompts_depth - 1
110 | )
111 |
112 | classnames = [name.replace("_", " ") for name in classnames]
113 | name_lens = [len(_tokenizer.encode(name)) for name in classnames]
114 | prompts = [prompt_prefix + " " + name + "." for name in classnames]
115 |
116 | tokenized_prompts = torch.cat([tokenize(p) for p in prompts]) # (n_cls, n_tkn)
117 | with torch.no_grad():
118 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
119 |
120 | # These token vectors will be saved when in save_model(),
121 | # but they should be ignored in load_model() as we want to use
122 | # those computed using the current class names
123 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
124 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS
125 |
126 | self.n_cls = n_cls
127 | self.n_ctx = n_ctx
128 | self.tokenized_prompts = tokenized_prompts # torch.Tensor
129 | self.name_lens = name_lens
130 |
131 | def construct_prompts(self, ctx, prefix, suffix, label=None):
132 | # dim0 is either batch_size (during training) or n_cls (during testing)
133 | # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim)
134 | # prefix: the sos token, with shape of (n_cls, 1, ctx_dim)
135 | # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim)
136 |
137 | if label is not None:
138 | prefix = prefix[label]
139 | suffix = suffix[label]
140 |
141 | prompts = torch.cat(
142 | [
143 | prefix, # (dim0, 1, dim)
144 | ctx, # (dim0, n_ctx, dim)
145 | suffix, # (dim0, *, dim)
146 | ],
147 | dim=1,
148 | )
149 |
150 | return prompts
151 |
152 | def forward(self):
153 | ctx = self.ctx
154 |
155 | if ctx.dim() == 2:
156 | ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
157 |
158 | prefix = self.token_prefix
159 | suffix = self.token_suffix
160 | prompts = self.construct_prompts(ctx, prefix, suffix)
161 |
162 | # Before returning, need to transform
163 | # prompts to 768 for the visual side
164 | visual_deep_prompts = []
165 | for index, layer in enumerate(self.compound_prompt_projections):
166 | visual_deep_prompts.append(layer(self.compound_prompts_text[index]))
167 | # Now the other way around
168 | # We will project the textual prompts from 512 to 768
169 | return (
170 | prompts,
171 | self.proj(self.ctx),
172 | self.compound_prompts_text,
173 | visual_deep_prompts,
174 | ) # pass here original, as for visual 768 is required
175 |
176 |
177 | class CustomCLIP(nn.Module):
178 | def __init__(self, cfg, classnames, clip_model):
179 | super().__init__()
180 | self.prompt_learner = MultiModalPromptLearner(cfg, classnames, clip_model)
181 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts
182 | self.image_encoder = clip_model.visual
183 | self.text_encoder = TextEncoder(clip_model)
184 | self.logit_scale = clip_model.logit_scale
185 | self.dtype = clip_model.dtype
186 |
187 | def get_text_features(self):
188 | prompts, _, deep_compound_prompts_text, _ = self.prompt_learner()
189 | tokenized_prompts = self.tokenized_prompts
190 | text_features = self.text_encoder(
191 | prompts, tokenized_prompts, deep_compound_prompts_text
192 | )
193 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
194 | return text_features
195 |
196 | def forward(self, input_ids, attention_mask, pixel_values):
197 |         # Although input_ids and attention_mask are not used, they are required to match the forward-pass **kwargs
198 | _, shared_ctx, _, deep_compound_prompts_vision = self.prompt_learner()
199 | image_features = self.image_encoder(
200 | pixel_values.type(self.dtype), shared_ctx, deep_compound_prompts_vision
201 | )
202 |
203 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
204 | return image_features
205 |
206 |
207 | def _get_clones(module, N):
208 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
209 |
--------------------------------------------------------------------------------
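MultiModalPromptLearner.construct_prompts is just a concatenation along the token dimension: the SOS prefix, the learnable context vectors, then the class-name-plus-EOS suffix, giving one 77-token prompt per class. A shape-only sketch with dummy tensors (no CLIP weights needed; the dimensions follow the comments in the code above):

import torch

n_cls, n_ctx, ctx_dim = 1000, 2, 512  # classes, context tokens, CLIP text width
seq_len = 77                          # CLIP context length

prefix = torch.randn(n_cls, 1, ctx_dim)                    # SOS embedding per class
ctx = torch.randn(n_ctx, ctx_dim).unsqueeze(0).expand(n_cls, -1, -1)
suffix = torch.randn(n_cls, seq_len - 1 - n_ctx, ctx_dim)  # class name, EOS, padding

prompts = torch.cat([prefix, ctx, suffix], dim=1)
print(prompts.shape)  # torch.Size([1000, 77, 512])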
/src/models/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import * # noqa: F403
2 |
--------------------------------------------------------------------------------
/src/models/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dynamical-inference/patchsae/6b468b36fe4003724e1c5445180313c79b7a2d0c/src/models/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/src/models/clip/clip.py:
--------------------------------------------------------------------------------
1 | """
2 | # Portions of this file are based on code from the “openai/CLIP” repository (MIT-licensed):
3 | # https://github.com/openai/CLIP/blob/main/clip/clip.py
4 | """
5 |
6 | import hashlib
7 | import os
8 | import urllib
9 | import warnings
10 | from typing import List, Union
11 |
12 | import torch
13 | from PIL import Image
14 | from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
15 | from tqdm import tqdm
16 |
17 | from .model import build_model
18 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
19 |
20 | try:
21 | from torchvision.transforms import InterpolationMode
22 |
23 | BICUBIC = InterpolationMode.BICUBIC
24 | except ImportError:
25 | BICUBIC = Image.BICUBIC
26 |
27 |
28 | if torch.__version__.split(".") < ["1", "7", "1"]:
29 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
30 |
31 |
32 | __all__ = ["available_models", "load", "tokenize"]
33 | _tokenizer = _Tokenizer()
34 |
35 | _MODELS = {
36 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
37 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
38 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
39 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
40 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
41 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
42 | }
43 |
44 |
45 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
46 | os.makedirs(root, exist_ok=True)
47 | filename = os.path.basename(url)
48 |
49 | expected_sha256 = url.split("/")[-2]
50 | download_target = os.path.join(root, filename)
51 |
52 | if os.path.exists(download_target) and not os.path.isfile(download_target):
53 | raise RuntimeError(f"{download_target} exists and is not a regular file")
54 |
55 | if os.path.isfile(download_target):
56 | if (
57 | hashlib.sha256(open(download_target, "rb").read()).hexdigest()
58 | == expected_sha256
59 | ):
60 | return download_target
61 | else:
62 | warnings.warn(
63 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
64 | )
65 |
66 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
67 | with tqdm(
68 | total=int(source.info().get("Content-Length")),
69 | ncols=80,
70 | unit="iB",
71 | unit_scale=True,
72 | ) as loop:
73 | while True:
74 | buffer = source.read(8192)
75 | if not buffer:
76 | break
77 |
78 | output.write(buffer)
79 | loop.update(len(buffer))
80 |
81 | if (
82 | hashlib.sha256(open(download_target, "rb").read()).hexdigest()
83 | != expected_sha256
84 | ):
85 | raise RuntimeError(
86 |             "Model has been downloaded but the SHA256 checksum does not match"
87 | )
88 |
89 | return download_target
90 |
91 |
92 | def _transform(n_px):
93 | return Compose(
94 | [
95 | Resize(n_px, interpolation=BICUBIC),
96 | CenterCrop(n_px),
97 | lambda image: image.convert("RGB"),
98 | ToTensor(),
99 | Normalize(
100 | (0.48145466, 0.4578275, 0.40821073),
101 | (0.26862954, 0.26130258, 0.27577711),
102 | ),
103 | ]
104 | )
105 |
106 |
107 | def available_models() -> List[str]:
108 | """Returns the names of available CLIP models"""
109 | return list(_MODELS.keys())
110 |
111 |
112 | def load(
113 | name: str,
114 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
115 | jit=False,
116 | ):
117 | """Load a CLIP model
118 |
119 | Parameters
120 | ----------
121 | name : str
122 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
123 |
124 | device : Union[str, torch.device]
125 | The device to put the loaded model
126 |
127 | jit : bool
128 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
129 |
130 | Returns
131 | -------
132 | model : torch.nn.Module
133 | The CLIP model
134 |
135 | preprocess : Callable[[PIL.Image], torch.Tensor]
136 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
137 | """
138 | if name in _MODELS:
139 | model_path = _download(_MODELS[name])
140 | elif os.path.isfile(name):
141 | model_path = name
142 | else:
143 | raise RuntimeError(
144 | f"Model {name} not found; available models = {available_models()}"
145 | )
146 |
147 | try:
148 | # loading JIT archive
149 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
150 | state_dict = None
151 | except RuntimeError:
152 | # loading saved state dict
153 | if jit:
154 | warnings.warn(
155 | f"File {model_path} is not a JIT archive. Loading as a state dict instead"
156 | )
157 | jit = False
158 | state_dict = torch.load(model_path, map_location="cpu")
159 |
160 | if not jit:
161 | model = build_model(state_dict or model.state_dict()).to(device)
162 | if str(device) == "cpu":
163 | model.float()
164 | return model, _transform(model.visual.input_resolution)
165 |
166 | # patch the device names
167 | device_holder = torch.jit.trace(
168 | lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
169 | )
170 | device_node = [
171 | n
172 | for n in device_holder.graph.findAllNodes("prim::Constant")
173 | if "Device" in repr(n)
174 | ][-1]
175 |
176 | def patch_device(module):
177 | try:
178 | graphs = [module.graph] if hasattr(module, "graph") else []
179 | except RuntimeError:
180 | graphs = []
181 |
182 | if hasattr(module, "forward1"):
183 | graphs.append(module.forward1.graph)
184 |
185 | for graph in graphs:
186 | for node in graph.findAllNodes("prim::Constant"):
187 | if "value" in node.attributeNames() and str(node["value"]).startswith(
188 | "cuda"
189 | ):
190 | node.copyAttributes(device_node)
191 |
192 | model.apply(patch_device)
193 | patch_device(model.encode_image)
194 | patch_device(model.encode_text)
195 |
196 | # patch dtype to float32 on CPU
197 | if str(device) == "cpu":
198 | float_holder = torch.jit.trace(
199 | lambda: torch.ones([]).float(), example_inputs=[]
200 | )
201 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
202 | float_node = float_input.node()
203 |
204 | def patch_float(module):
205 | try:
206 | graphs = [module.graph] if hasattr(module, "graph") else []
207 | except RuntimeError:
208 | graphs = []
209 |
210 | if hasattr(module, "forward1"):
211 | graphs.append(module.forward1.graph)
212 |
213 | for graph in graphs:
214 | for node in graph.findAllNodes("aten::to"):
215 | inputs = list(node.inputs())
216 | for i in [
217 | 1,
218 | 2,
219 | ]: # dtype can be the second or third argument to aten::to()
220 | if inputs[i].node()["value"] == 5:
221 | inputs[i].node().copyAttributes(float_node)
222 |
223 | model.apply(patch_float)
224 | patch_float(model.encode_image)
225 | patch_float(model.encode_text)
226 |
227 | model.float()
228 |
229 | return model, _transform(model.input_resolution.item())
230 |
231 |
232 | def tokenize(
233 | texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False
234 | ) -> torch.LongTensor:
235 | """
236 | Returns the tokenized representation of given input string(s)
237 |
238 | Parameters
239 | ----------
240 | texts : Union[str, List[str]]
241 | An input string or a list of input strings to tokenize
242 |
243 | context_length : int
244 | The context length to use; all CLIP models use 77 as the context length
245 |
246 | truncate: bool
247 | Whether to truncate the text in case its encoding is longer than the context length
248 |
249 | Returns
250 | -------
251 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
252 | """
253 | if isinstance(texts, str):
254 | texts = [texts]
255 |
256 | sot_token = _tokenizer.encoder["<|startoftext|>"]
257 | eot_token = _tokenizer.encoder["<|endoftext|>"]
258 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
259 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
260 |
261 | for i, tokens in enumerate(all_tokens):
262 | if len(tokens) > context_length:
263 | if truncate:
264 | tokens = tokens[:context_length]
265 | tokens[-1] = eot_token
266 | else:
267 | raise RuntimeError(
268 | f"Input {texts[i]} is too long for context length {context_length}"
269 | )
270 | result[i, : len(tokens)] = torch.tensor(tokens)
271 |
272 | return result
273 |
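A minimal usage sketch of the `tokenize` helper defined above (the prompt strings are illustrative):

```python
# Minimal sketch: tokenize a couple of prompts with the function above.
from src.models.clip import clip

tokens = clip.tokenize(["a photo of a cat", "a photo of a dog"])
print(tokens.shape)  # torch.Size([2, 77]); zero-padded after the <|endoftext|> token
print(tokens.dtype)  # torch.int64
```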
--------------------------------------------------------------------------------
/src/models/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | """
2 | # Portions of this file are based on code from the “openai/CLIP” repository (MIT-licensed):
3 | # https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
4 | """
5 |
6 | import gzip
7 | import html
8 | import os
9 | from functools import lru_cache
10 |
11 | import ftfy
12 | import regex as re
13 |
14 |
15 | @lru_cache()
16 | def default_bpe():
17 | return os.path.join(
18 | os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
19 | )
20 |
21 |
22 | @lru_cache()
23 | def bytes_to_unicode():
24 | """
25 |     Returns a list of utf-8 bytes and a corresponding list of unicode strings.
26 |     The reversible bpe codes work on unicode strings.
27 |     This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
28 |     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
29 |     This is a significant percentage of your normal, say, 32K bpe vocab.
30 |     To avoid that, we want lookup tables between utf-8 bytes and unicode strings,
31 |     and we avoid mapping to whitespace/control characters that the bpe code barfs on.
32 | """
33 | bs = (
34 | list(range(ord("!"), ord("~") + 1))
35 | + list(range(ord("¡"), ord("¬") + 1))
36 | + list(range(ord("®"), ord("ÿ") + 1))
37 | )
38 | cs = bs[:]
39 | n = 0
40 | for b in range(2**8):
41 | if b not in bs:
42 | bs.append(b)
43 | cs.append(2**8 + n)
44 | n += 1
45 | cs = [chr(n) for n in cs]
46 | return dict(zip(bs, cs))
47 |
48 |
49 | def get_pairs(word):
50 | """Return set of symbol pairs in a word.
51 | Word is represented as tuple of symbols (symbols being variable-length strings).
52 | """
53 | pairs = set()
54 | prev_char = word[0]
55 | for char in word[1:]:
56 | pairs.add((prev_char, char))
57 | prev_char = char
58 | return pairs
59 |
60 |
61 | def basic_clean(text):
62 | text = ftfy.fix_text(text)
63 | text = html.unescape(html.unescape(text))
64 | return text.strip()
65 |
66 |
67 | def whitespace_clean(text):
68 | text = re.sub(r"\s+", " ", text)
69 | text = text.strip()
70 | return text
71 |
72 |
73 | class SimpleTokenizer(object):
74 | def __init__(self, bpe_path: str = default_bpe()):
75 | self.byte_encoder = bytes_to_unicode()
76 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
77 | merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
78 | merges = merges[1 : 49152 - 256 - 2 + 1]
79 | merges = [tuple(merge.split()) for merge in merges]
80 | vocab = list(bytes_to_unicode().values())
81 |         vocab = vocab + [v + "</w>" for v in vocab]
82 | for merge in merges:
83 | vocab.append("".join(merge))
84 | vocab.extend(["<|startoftext|>", "<|endoftext|>"])
85 | self.encoder = dict(zip(vocab, range(len(vocab))))
86 | self.decoder = {v: k for k, v in self.encoder.items()}
87 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
88 | self.cache = {
89 | "<|startoftext|>": "<|startoftext|>",
90 | "<|endoftext|>": "<|endoftext|>",
91 | }
92 | self.pat = re.compile(
93 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
94 | re.IGNORECASE,
95 | )
96 |
97 | def bpe(self, token):
98 | if token in self.cache:
99 | return self.cache[token]
100 |         word = tuple(token[:-1]) + (token[-1] + "</w>",)
101 | pairs = get_pairs(word)
102 |
103 | if not pairs:
104 |             return token + "</w>"
105 |
106 | while True:
107 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
108 | if bigram not in self.bpe_ranks:
109 | break
110 | first, second = bigram
111 | new_word = []
112 | i = 0
113 | while i < len(word):
114 | try:
115 | j = word.index(first, i)
116 | new_word.extend(word[i:j])
117 | i = j
118 | except: # noqa: E722
119 | new_word.extend(word[i:])
120 | break
121 |
122 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
123 | new_word.append(first + second)
124 | i += 2
125 | else:
126 | new_word.append(word[i])
127 | i += 1
128 | new_word = tuple(new_word)
129 | word = new_word
130 | if len(word) == 1:
131 | break
132 | else:
133 | pairs = get_pairs(word)
134 | word = " ".join(word)
135 | self.cache[token] = word
136 | return word
137 |
138 | def encode(self, text):
139 | bpe_tokens = []
140 | text = whitespace_clean(basic_clean(text)).lower()
141 | for token in re.findall(self.pat, text):
142 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
143 | bpe_tokens.extend(
144 | self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
145 | )
146 | return bpe_tokens
147 |
148 | def decode(self, tokens):
149 | text = "".join([self.decoder[token] for token in tokens])
150 | text = (
151 | bytearray([self.byte_decoder[c] for c in text])
152 | .decode("utf-8", errors="replace")
153 |             .replace("</w>", " ")
154 | )
155 | return text
156 |
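A minimal sketch of an encode/decode round trip with the tokenizer above (the sample string is arbitrary):

```python
# Minimal sketch: BPE encode/decode round trip. Note that encode() does not add the
# <|startoftext|> / <|endoftext|> markers; tokenize() in clip.py does that.
from src.models.clip.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()      # uses the bundled bpe_simple_vocab_16e6.txt.gz
ids = tokenizer.encode("Hello, CLIP world!")
print(ids)                         # list of BPE token ids
print(tokenizer.decode(ids))       # lowercased, whitespace-normalized text
```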
--------------------------------------------------------------------------------
/src/models/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .maple import get_maple_config # noqa: F401
2 |
--------------------------------------------------------------------------------
/src/models/config/default_config.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode
2 |
3 |
4 | def get_default_config():
5 | """Initialize default configuration.
6 |
7 | Returns:
8 | CfgNode: Default config object containing all settings
9 | """
10 | cfg = CfgNode()
11 |
12 | # Basic settings
13 | cfg.VERSION = 1
14 | cfg.OUTPUT_DIR = "./output" # Directory for output files
15 | cfg.RESUME = "" # Path to previous output directory
16 | cfg.SEED = -1 # Negative for random, positive for fixed seed
17 | cfg.USE_CUDA = True
18 | cfg.VERBOSE = True # Print detailed info
19 |
20 | # Input settings
21 | cfg.INPUT = CfgNode()
22 | cfg.INPUT.SIZE = (224, 224)
23 | cfg.INPUT.INTERPOLATION = "bilinear"
24 | cfg.INPUT.TRANSFORMS = ()
25 | cfg.INPUT.NO_TRANSFORM = False
26 | cfg.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406] # ImageNet mean
27 | cfg.INPUT.PIXEL_STD = [0.229, 0.224, 0.225] # ImageNet std
28 | cfg.INPUT.CROP_PADDING = 4
29 | cfg.INPUT.RRCROP_SCALE = (0.08, 1.0)
30 | cfg.INPUT.CUTOUT_N = 1
31 | cfg.INPUT.CUTOUT_LEN = 16
32 | cfg.INPUT.GN_MEAN = 0.0
33 | cfg.INPUT.GN_STD = 0.15
34 | cfg.INPUT.RANDAUGMENT_N = 2
35 | cfg.INPUT.RANDAUGMENT_M = 10
36 | cfg.INPUT.COLORJITTER_B = 0.4
37 | cfg.INPUT.COLORJITTER_C = 0.4
38 | cfg.INPUT.COLORJITTER_S = 0.4
39 | cfg.INPUT.COLORJITTER_H = 0.1
40 | cfg.INPUT.RGS_P = 0.2
41 | cfg.INPUT.GB_P = 0.5
42 | cfg.INPUT.GB_K = 21
43 |
44 | # Dataset settings
45 | cfg.DATASET = CfgNode()
46 | cfg.DATASET.ROOT = ""
47 | cfg.DATASET.NAME = ""
48 | cfg.DATASET.SOURCE_DOMAINS = ()
49 | cfg.DATASET.TARGET_DOMAINS = ()
50 | cfg.DATASET.NUM_LABELED = -1
51 | cfg.DATASET.NUM_SHOTS = -1
52 | cfg.DATASET.VAL_PERCENT = 0.1
53 | cfg.DATASET.STL10_FOLD = -1
54 | cfg.DATASET.CIFAR_C_TYPE = ""
55 | cfg.DATASET.CIFAR_C_LEVEL = 1
56 | cfg.DATASET.ALL_AS_UNLABELED = False
57 |
58 | # Dataloader settings
59 | cfg.DATALOADER = CfgNode()
60 | cfg.DATALOADER.NUM_WORKERS = 4
61 | cfg.DATALOADER.K_TRANSFORMS = 1
62 | cfg.DATALOADER.RETURN_IMG0 = False
63 |
64 | cfg.DATALOADER.TRAIN_X = CfgNode()
65 | cfg.DATALOADER.TRAIN_X.SAMPLER = "RandomSampler"
66 | cfg.DATALOADER.TRAIN_X.BATCH_SIZE = 32
67 | cfg.DATALOADER.TRAIN_X.N_DOMAIN = 0
68 | cfg.DATALOADER.TRAIN_X.N_INS = 16
69 |
70 | cfg.DATALOADER.TRAIN_U = CfgNode()
71 | cfg.DATALOADER.TRAIN_U.SAME_AS_X = True
72 | cfg.DATALOADER.TRAIN_U.SAMPLER = "RandomSampler"
73 | cfg.DATALOADER.TRAIN_U.BATCH_SIZE = 32
74 | cfg.DATALOADER.TRAIN_U.N_DOMAIN = 0
75 | cfg.DATALOADER.TRAIN_U.N_INS = 16
76 |
77 | cfg.DATALOADER.TEST = CfgNode()
78 | cfg.DATALOADER.TEST.SAMPLER = "SequentialSampler"
79 | cfg.DATALOADER.TEST.BATCH_SIZE = 32
80 |
81 | # Model settings
82 | cfg.MODEL = CfgNode()
83 | cfg.MODEL.INIT_WEIGHTS = ""
84 | cfg.MODEL.BACKBONE = CfgNode()
85 | cfg.MODEL.BACKBONE.NAME = ""
86 | cfg.MODEL.BACKBONE.PRETRAINED = True
87 |
88 | cfg.MODEL.HEAD = CfgNode()
89 | cfg.MODEL.HEAD.NAME = ""
90 | cfg.MODEL.HEAD.HIDDEN_LAYERS = ()
91 | cfg.MODEL.HEAD.ACTIVATION = "relu"
92 | cfg.MODEL.HEAD.BN = True
93 | cfg.MODEL.HEAD.DROPOUT = 0.0
94 |
95 | # Optimization settings
96 | cfg.OPTIM = CfgNode()
97 | cfg.OPTIM.NAME = "adam"
98 | cfg.OPTIM.LR = 0.0003
99 | cfg.OPTIM.WEIGHT_DECAY = 5e-4
100 | cfg.OPTIM.MOMENTUM = 0.9
101 | cfg.OPTIM.SGD_DAMPNING = 0
102 | cfg.OPTIM.SGD_NESTEROV = False
103 | cfg.OPTIM.RMSPROP_ALPHA = 0.99
104 | cfg.OPTIM.ADAM_BETA1 = 0.9
105 | cfg.OPTIM.ADAM_BETA2 = 0.999
106 | cfg.OPTIM.STAGED_LR = False
107 | cfg.OPTIM.NEW_LAYERS = ()
108 | cfg.OPTIM.BASE_LR_MULT = 0.1
109 | cfg.OPTIM.LR_SCHEDULER = "single_step"
110 | cfg.OPTIM.STEPSIZE = (-1,)
111 | cfg.OPTIM.GAMMA = 0.1
112 | cfg.OPTIM.MAX_EPOCH = 10
113 | cfg.OPTIM.WARMUP_EPOCH = -1
114 | cfg.OPTIM.WARMUP_TYPE = "linear"
115 | cfg.OPTIM.WARMUP_CONS_LR = 1e-5
116 | cfg.OPTIM.WARMUP_MIN_LR = 1e-5
117 | cfg.OPTIM.WARMUP_RECOUNT = True
118 |
119 | # Training settings
120 | cfg.TRAIN = CfgNode()
121 | cfg.TRAIN.CHECKPOINT_FREQ = 0
122 | cfg.TRAIN.PRINT_FREQ = 10
123 | cfg.TRAIN.COUNT_ITER = "train_x"
124 |
125 | # Testing settings
126 | cfg.TEST = CfgNode()
127 | cfg.TEST.EVALUATOR = "Classification"
128 | cfg.TEST.PER_CLASS_RESULT = False
129 | cfg.TEST.COMPUTE_CMAT = False
130 | cfg.TEST.NO_TEST = False
131 | cfg.TEST.SPLIT = "test"
132 | cfg.TEST.FINAL_MODEL = "last_step"
133 |
134 | # Trainer settings
135 | cfg.TRAINER = CfgNode()
136 | cfg.TRAINER.NAME = ""
137 |
138 | # Domain adaptation settings
139 | cfg.TRAINER.MCD = CfgNode()
140 | cfg.TRAINER.MCD.N_STEP_F = 4
141 |
142 | cfg.TRAINER.MME = CfgNode()
143 | cfg.TRAINER.MME.LMDA = 0.1
144 |
145 | cfg.TRAINER.CDAC = CfgNode()
146 |
147 | return cfg
148 |
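A minimal sketch of how the default config is typically consumed (the override values are illustrative):

```python
# Minimal sketch: build the default yacs config and override a few fields.
from src.models.config.default_config import get_default_config

cfg = get_default_config()
cfg.OPTIM.LR = 1e-4                           # attribute-style override
cfg.DATALOADER.TRAIN_X.BATCH_SIZE = 64
cfg.merge_from_list(["OPTIM.MAX_EPOCH", 20])  # yacs key/value list override (e.g., from CLI args)
cfg.freeze()                                  # make the config immutable from here on
print(cfg.OPTIM.LR, cfg.OPTIM.MAX_EPOCH)
```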
--------------------------------------------------------------------------------
/src/models/config/maple.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode
2 |
3 | from src.models.config.default_config import get_default_config
4 |
5 |
6 | def get_maple_config(custom_clip_cfg=None):
7 | """Get configuration for MaPLe model.
8 |
9 | Args:
10 | custom_clip_cfg: Optional custom CLIP config
11 |
12 | Returns:
13 | Config object with MaPLe settings
14 | """
15 | cfg = get_default_config()
16 | cfg.TRAINER = CfgNode()
17 | cfg.TRAINER.MAPLE = CfgNode()
18 | cfg.TRAINER.MAPLE.N_CTX = 2 # number of context vectors at the vision branch
19 | cfg.TRAINER.MAPLE.CTX_INIT = (
20 | "a photo of a" # initialization words (only for language prompts)
21 | )
22 | cfg.TRAINER.MAPLE.PREC = "fp16" # fp16, fp32, amp
23 |     # If both variables below are set to 0, the config will degenerate to the CoOp model
24 | cfg.TRAINER.MAPLE.PROMPT_DEPTH = (
25 | 9 # Max 12, minimum 0, for 0 it will act as shallow IVLP prompting (J=1)
26 | )
27 | cfg.DATASET = CfgNode()
28 | cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
29 |
30 | cfg.TRAINER.NAME = "MaPLe"
31 | cfg.MODEL = CfgNode()
32 | cfg.MODEL.BACKBONE = CfgNode()
33 | cfg.MODEL.BACKBONE.NAME = "ViT-B/16"
34 |
35 | cfg.merge_from_file(custom_clip_cfg)
36 |
37 | return cfg
38 |
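A minimal sketch of calling `get_maple_config` with the MaPLe YAML shipped in this repo (the printed values depend on what the YAML overrides):

```python
# Minimal sketch: build the MaPLe config from the YAML under configs/models/maple/.
from src.models.config import get_maple_config

cfg = get_maple_config(
    custom_clip_cfg="configs/models/maple/vit_b16_c2_ep5_batch4_2ctx.yaml"
)
print(cfg.TRAINER.NAME)         # "MaPLe" unless the YAML overrides it
print(cfg.TRAINER.MAPLE.N_CTX)  # number of context vectors
```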
--------------------------------------------------------------------------------
/src/models/templates/openai_imagenet_templates.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is based on "openai/CLIP" repository (MIT-licensed):
3 | https://github.com/openai/CLIP/blob/main/data/prompts.md
4 | """
5 |
6 | openai_imagenet_template = [
7 | lambda c: f"a bad photo of a {c}.",
8 | lambda c: f"a photo of many {c}.",
9 | lambda c: f"a sculpture of a {c}.",
10 | lambda c: f"a photo of the hard to see {c}.",
11 | lambda c: f"a low resolution photo of the {c}.",
12 | lambda c: f"a rendering of a {c}.",
13 | lambda c: f"graffiti of a {c}.",
14 | lambda c: f"a bad photo of the {c}.",
15 | lambda c: f"a cropped photo of the {c}.",
16 | lambda c: f"a tattoo of a {c}.",
17 | lambda c: f"the embroidered {c}.",
18 | lambda c: f"a photo of a hard to see {c}.",
19 | lambda c: f"a bright photo of a {c}.",
20 | lambda c: f"a photo of a clean {c}.",
21 | lambda c: f"a photo of a dirty {c}.",
22 | lambda c: f"a dark photo of the {c}.",
23 | lambda c: f"a drawing of a {c}.",
24 | lambda c: f"a photo of my {c}.",
25 | lambda c: f"the plastic {c}.",
26 | lambda c: f"a photo of the cool {c}.",
27 | lambda c: f"a close-up photo of a {c}.",
28 | lambda c: f"a black and white photo of the {c}.",
29 | lambda c: f"a painting of the {c}.",
30 | lambda c: f"a painting of a {c}.",
31 | lambda c: f"a pixelated photo of the {c}.",
32 | lambda c: f"a sculpture of the {c}.",
33 | lambda c: f"a bright photo of the {c}.",
34 | lambda c: f"a cropped photo of a {c}.",
35 | lambda c: f"a plastic {c}.",
36 | lambda c: f"a photo of the dirty {c}.",
37 | lambda c: f"a jpeg corrupted photo of a {c}.",
38 | lambda c: f"a blurry photo of the {c}.",
39 | lambda c: f"a photo of the {c}.",
40 | lambda c: f"a good photo of the {c}.",
41 | lambda c: f"a rendering of the {c}.",
42 | lambda c: f"a {c} in a video game.",
43 | lambda c: f"a photo of one {c}.",
44 | lambda c: f"a doodle of a {c}.",
45 | lambda c: f"a close-up photo of the {c}.",
46 | lambda c: f"a photo of a {c}.",
47 | lambda c: f"the origami {c}.",
48 | lambda c: f"the {c} in a video game.",
49 | lambda c: f"a sketch of a {c}.",
50 | lambda c: f"a doodle of the {c}.",
51 | lambda c: f"a origami {c}.",
52 | lambda c: f"a low resolution photo of a {c}.",
53 | lambda c: f"the toy {c}.",
54 | lambda c: f"a rendition of the {c}.",
55 | lambda c: f"a photo of the clean {c}.",
56 | lambda c: f"a photo of a large {c}.",
57 | lambda c: f"a rendition of a {c}.",
58 | lambda c: f"a photo of a nice {c}.",
59 | lambda c: f"a photo of a weird {c}.",
60 | lambda c: f"a blurry photo of a {c}.",
61 | lambda c: f"a cartoon {c}.",
62 | lambda c: f"art of a {c}.",
63 | lambda c: f"a sketch of the {c}.",
64 | lambda c: f"a embroidered {c}.",
65 | lambda c: f"a pixelated photo of a {c}.",
66 | lambda c: f"itap of the {c}.",
67 | lambda c: f"a jpeg corrupted photo of the {c}.",
68 | lambda c: f"a good photo of a {c}.",
69 | lambda c: f"a plushie {c}.",
70 | lambda c: f"a photo of the nice {c}.",
71 | lambda c: f"a photo of the small {c}.",
72 | lambda c: f"a photo of the weird {c}.",
73 | lambda c: f"the cartoon {c}.",
74 | lambda c: f"art of the {c}.",
75 | lambda c: f"a drawing of the {c}.",
76 | lambda c: f"a photo of the large {c}.",
77 | lambda c: f"a black and white photo of a {c}.",
78 | lambda c: f"the plushie {c}.",
79 | lambda c: f"a dark photo of a {c}.",
80 | lambda c: f"itap of a {c}.",
81 | lambda c: f"graffiti of the {c}.",
82 | lambda c: f"a toy {c}.",
83 | lambda c: f"itap of my {c}.",
84 | lambda c: f"a photo of a cool {c}.",
85 | lambda c: f"a photo of a small {c}.",
86 | lambda c: f"a tattoo of the {c}.",
87 | ]
88 |
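A minimal sketch of expanding a class name with these templates (the class name is illustrative):

```python
# Minimal sketch: build the 80 prompt variants for one class name.
from src.models.templates.openai_imagenet_templates import openai_imagenet_template

prompts = [template("golden retriever") for template in openai_imagenet_template]
print(len(prompts))  # 80
print(prompts[0])    # "a bad photo of a golden retriever."
```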
--------------------------------------------------------------------------------
/src/models/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from functools import partial
4 |
5 | import torch
6 | from transformers import CLIPModel, CLIPProcessor
7 |
8 | from src.models.clip import clip
9 | from src.models.config.maple import get_maple_config
10 |
11 |
12 | def load_clip_model(cfg, model_type: str):
13 | """Load and configure a CLIP model."""
14 | backbone_name = cfg.MODEL.BACKBONE.NAME
15 | url = clip._MODELS[backbone_name]
16 | model_path = clip._download(url)
17 | try:
18 | # loading JIT archive
19 | model = torch.jit.load(model_path, map_location="cpu").eval()
20 | state_dict = None
21 | except RuntimeError:
22 | state_dict = torch.load(model_path, map_location="cpu")
23 |
24 | design_details = {
25 | "trainer": model_type,
26 | "vision_depth": 0,
27 | "language_depth": 0,
28 | "vision_ctx": 0,
29 | "language_ctx": 0,
30 | }
31 |
32 | if model_type == "maple":
33 | design_details["trainer"] = "MaPLe"
34 | design_details["maple_length"] = cfg.TRAINER.MAPLE.N_CTX
35 |
36 | model = clip.build_model(state_dict or model.state_dict(), design_details)
37 | return model
38 |
39 |
40 | def load_checkpoint(fpath: str) -> dict:
41 | """Load a model checkpoint file.
42 |
43 | Handles loading both Python 3 and Python 2 saved checkpoints by catching UnicodeDecodeError.
44 |
45 | Args:
46 | fpath: Path to the checkpoint file
47 |
48 | Returns:
49 | The loaded checkpoint dictionary
50 |
51 | Raises:
52 | ValueError: If fpath is None
53 | FileNotFoundError: If checkpoint file does not exist
54 | Exception: If checkpoint cannot be loaded
55 |
56 | Examples:
57 | >>> checkpoint = load_checkpoint('models/checkpoint.pth')
58 | """
59 | if fpath is None:
60 | raise ValueError("Checkpoint path cannot be None")
61 |
62 | if not os.path.exists(fpath):
63 | raise FileNotFoundError(f"No checkpoint file found at {fpath}")
64 |
65 | device = None if torch.cuda.is_available() else "cpu"
66 |
67 | try:
68 | return torch.load(fpath, map_location=device)
69 |
70 | except UnicodeDecodeError:
71 | # Handle Python 2 checkpoints
72 | pickle.load = partial(pickle.load, encoding="latin1")
73 | pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
74 | return torch.load(fpath, pickle_module=pickle, map_location=device)
75 |
76 | except Exception as e:
77 | raise Exception(f"Failed to load checkpoint from {fpath}: {str(e)}")
78 |
79 |
80 | def _remove_prompt_learner_tokens(state_dict: dict) -> dict:
81 | """Remove prompt learner token vectors from state dict."""
82 | token_keys = ["prompt_learner.token_prefix", "prompt_learner.token_suffix"]
83 | for key in token_keys:
84 | if key in state_dict:
85 | del state_dict[key]
86 | return state_dict
87 |
88 |
89 | def load_state_dict_without_prompt_learner(ckpt_path: str) -> dict:
90 | """Load checkpoint and remove prompt learner token vectors."""
91 | checkpoint = load_checkpoint(ckpt_path)
92 | state_dict = checkpoint["state_dict"]
93 | return _remove_prompt_learner_tokens(state_dict)
94 |
95 |
96 | def get_base_clip(backbone: str) -> tuple[CLIPModel, CLIPProcessor]:
97 | """Load base CLIP model and processor."""
98 | model = CLIPModel.from_pretrained(backbone)
99 | processor = CLIPProcessor.from_pretrained(backbone)
100 | return model, processor
101 |
102 |
103 | def get_adapted_clip(
104 | cfg,
105 | model_type: str,
106 | model_path: str,
107 | config_path: str,
108 | backbone: str,
109 | classnames: list[str],
110 | ) -> tuple[CLIPModel, CLIPProcessor]:
111 | """Load and configure adapted CLIP model with custom prompt learning.
112 |
113 | Args:
114 |         cfg: Model configuration (rebuilt from config_path when model_type is 'maple')
115 |         model_type: Type of prompt learning ('maple')
116 |         model_path / config_path: Paths to the adapted model checkpoint and its config file
117 |         backbone: Hugging Face CLIP id for the processor; classnames: list of class names
118 |
119 | Returns:
120 | Tuple of (model, processor)
121 | """
122 | if model_type == "maple":
123 | cfg = get_maple_config(custom_clip_cfg=config_path)
124 |
125 | clip_model = load_clip_model(cfg, model_type)
126 | model_statedict = load_state_dict_without_prompt_learner(model_path)
127 |
128 | model_types = {
129 | "maple": "src.models.architecture.maple",
130 | }
131 |
132 | if model_type not in model_types:
133 | raise ValueError(f"Unsupported model type: {model_type}")
134 |
135 | module = __import__(model_types[model_type], fromlist=["CustomCLIP"])
136 | model = module.CustomCLIP(cfg, classnames, clip_model)
137 |
138 | model.load_state_dict(model_statedict, strict=False)
139 |
140 | processor = CLIPProcessor.from_pretrained(backbone)
141 | return model, processor
142 |
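A minimal sketch of the two loaders above. The Hugging Face backbone id, checkpoint path, and class names are illustrative placeholders, not values fixed by this file:

```python
# Minimal sketch: load a plain CLIP and a MaPLe-adapted CLIP (paths/ids illustrative).
from src.models.utils import get_adapted_clip, get_base_clip

# Plain Hugging Face CLIP.
model, processor = get_base_clip("openai/clip-vit-base-patch16")

# Adapted CLIP (MaPLe). For model_type="maple", cfg is rebuilt from config_path,
# so a placeholder value can be passed here.
maple_model, maple_processor = get_adapted_clip(
    cfg=None,
    model_type="maple",
    model_path="/PATH/TO/MAPLE_CKPT/model.pth.tar-5",
    config_path="configs/models/maple/vit_b16_c2_ep5_batch4_2ctx.yaml",
    backbone="openai/clip-vit-base-patch16",
    classnames=["golden retriever", "tabby cat"],
)
```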
--------------------------------------------------------------------------------
/src/sae_training/config.py:
--------------------------------------------------------------------------------
1 | """
2 | # Portions of this file are based on code from the “jbloomAus/SAELens” and "HugoFry/mats_sae_training_for_ViTs" repositories (MIT-licensed):
3 | https://github.com/jbloomAus/SAELens/blob/main/sae_lens/config.py
4 | https://github.com/HugoFry/mats_sae_training_for_ViTs/blob/main/sae_training/config.py
5 | """
6 |
7 | from dataclasses import dataclass
8 | from typing import Optional
9 |
10 | import torch
11 | import wandb
12 |
13 |
14 | class Config:
15 | def __init__(self, config_dict):
16 | if not isinstance(config_dict, dict):
17 | config_dict = config_dict.__dict__
18 |
19 | for key, value in config_dict.items():
20 | if isinstance(value, dict):
21 | # Recursively convert nested dictionaries
22 | value = Config(value)
23 | setattr(self, key, value)
24 |
25 |
26 | @dataclass
27 | class ViTSAERunnerConfig:
28 | """
29 | Configuration for training a sparse autoencoder on a vision transformer.
30 | """
31 |
32 |     # Data Generating Function (Model + Training Distribution)
33 | custom_clip_ckpt_path: str = None
34 | class_token: bool = True
35 | image_width: int = 224
36 | image_height: int = 224
37 | model_name: str = "openai/clip-vit-base-patch32"
38 | module_name: str = "resid"
39 | block_layer: int = 10
40 | dataset_path: str = "evanarlian/imagenet_1k_resized_256"
41 | image_key: str = "image"
42 | label_key: str = "label"
43 | use_cached_activations: bool = False
44 | cached_activations_path: Optional[str] = (
45 |         None  # Defaults to "activations/{dataset}/{model}/{block_layer}_{module_name}"
46 | )
47 |
48 | # SAE Parameters
49 | d_in: int = 768
50 |
51 | # Activation Store Parameters
52 | total_training_tokens: int = 2_000_000
53 | n_batches_in_store: int = 32
54 | store_size: Optional[int] = None
55 | max_batch_size_for_vit_forward_pass: int = 1024
56 | create_dataloader: bool = True
57 |
58 | # Misc
59 | device: str = "cpu"
60 | seed: int = 42
61 | dtype: torch.dtype = torch.float32
62 |
63 | # SAE Parameters
64 | b_dec_init_method: str = "geometric_median"
65 | expansion_factor: int = 4
66 | from_pretrained_path: Optional[str] = None
67 | gated_sae: bool = False
68 |
69 | # Training Parameters
70 | l1_coefficient: float = 1e-3
71 | lr: float = 3e-4
72 | lr_scheduler_name: str = "constant" # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup
73 | lr_warm_up_steps: int = 500
74 | batch_size: int = 4096
75 | mse_cls_coefficient: float = 1.0
76 |
77 | # Resampling protocol args
78 | use_ghost_grads: bool = True
79 | feature_sampling_window: int = (
80 | 2000 # May need to change this since by default I will use ghost grads
81 | )
82 |     feature_sampling_method: str = "anthropic"  # None, "l2", or "anthropic"
83 | resample_batches: int = 32
84 | feature_reinit_scale: float = 0.2
85 |     dead_feature_window: int = 1000  # should be smaller than feature_sampling_window
86 | dead_feature_estimation_method: str = "no_fire"
87 | dead_feature_threshold: float = 1e-8
88 |
89 | # WANDB
90 | log_to_wandb: bool = True
91 | wandb_project: str = "mats-hugo"
92 | wandb_entity: str = None
93 | wandb_log_frequency: int = 10
94 |
95 | # Misc
96 | n_checkpoints: int = 0
97 | checkpoint_path: str = "checkpoints"
98 |
99 | image_key = "image"
100 | label_key = "label"
101 |
102 | def __post_init__(self):
103 | self.store_size = self.n_batches_in_store * self.batch_size
104 |
105 | # Autofill cached_activations_path unless the user overrode it
106 | if self.cached_activations_path is None:
107 | self.cached_activations_path = f"activations/{self.dataset_path.replace('/', '_')}/{self.model_name.replace('/', '_')}/{self.block_layer}_{self.module_name}"
108 |
109 | self.d_sae = self.d_in * self.expansion_factor
110 |
111 | self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}"
112 |
113 | if self.feature_sampling_method not in [None, "l2", "anthropic"]:
114 | raise ValueError(
115 | f"feature_sampling_method must be None, l2, or anthropic. Got {self.feature_sampling_method}"
116 | )
117 |
118 | if self.b_dec_init_method not in ["geometric_median", "mean", "zeros"]:
119 | raise ValueError(
120 | f"b_dec_init_method must be geometric_median, mean, or zeros. Got {self.b_dec_init_method}"
121 | )
122 | if self.b_dec_init_method == "zeros":
123 | print(
124 | "Warning: We are initializing b_dec to zeros. This is probably not what you want."
125 | )
126 |
127 | self.device = torch.device(self.device)
128 |
129 | unique_id = wandb.util.generate_id()
130 | self.checkpoint_path = f"{self.checkpoint_path}/{unique_id}"
131 |
132 | print(
133 | f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}"
134 | )
135 | # Print out some useful info:
136 |
137 | total_training_steps = self.total_training_tokens // self.batch_size
138 | print(f"Total training steps: {total_training_steps}")
139 |
140 | total_wandb_updates = total_training_steps // self.wandb_log_frequency
141 | print(f"Total wandb updates: {total_wandb_updates}")
142 |
143 | # how many times will we sample dead neurons?
144 | # assert self.dead_feature_window <= self.feature_sampling_window, "dead_feature_window must be smaller than feature_sampling_window"
145 | n_dead_feature_samples = total_training_steps // self.dead_feature_window
146 | n_feature_window_samples = total_training_steps // self.feature_sampling_window
147 | print(
148 | f"n_tokens_per_feature_sampling_window (millions): {(self.feature_sampling_window * self.batch_size) / 10**6}"
149 | )
150 | print(
151 | f"n_tokens_per_dead_feature_window (millions): {(self.dead_feature_window * self.batch_size) / 10**6}"
152 | )
153 | if self.feature_sampling_method is not None:
154 | print(f"We will reset neurons {n_dead_feature_samples} times.")
155 |
156 | if self.use_ghost_grads:
157 | print("Using Ghost Grads.")
158 |
159 | print(
160 | f"We will reset the sparsity calculation {n_feature_window_samples} times."
161 | )
162 | print(
163 | f"Number of tokens when resampling: {self.resample_batches * self.batch_size}"
164 | )
165 | # print("Number tokens in dead feature calculation window: ", self.dead_feature_window * self.train_batch_size)
166 | print(
167 | f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.batch_size:.2e}"
168 | )
169 |
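A minimal sketch of constructing the runner config above. Most values mirror the dataclass defaults; the expansion factor of 64 is only an assumption chosen so that `d_sae` matches the `SAE_DIM = 49152` constant used elsewhere in the repo:

```python
# Minimal sketch: build a runner config (values illustrative; see the lead-in note).
import torch

from src.sae_training.config import ViTSAERunnerConfig

cfg = ViTSAERunnerConfig(
    model_name="openai/clip-vit-base-patch32",
    block_layer=10,
    module_name="resid",
    class_token=False,            # False keeps all patch tokens (PatchSAE)
    d_in=768,
    expansion_factor=64,          # d_sae = 64 * 768 = 49152
    total_training_tokens=2_000_000,
    batch_size=4096,
    device="cuda" if torch.cuda.is_available() else "cpu",
    log_to_wandb=False,
)
print(cfg.d_sae)                  # 49152
print(cfg.run_name)               # e.g. "49152-L1-0.001-LR-0.0003-Tokens-2.000e+06"
```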
--------------------------------------------------------------------------------
/src/sae_training/hooked_vit.py:
--------------------------------------------------------------------------------
1 | """
2 | # Portions of this file are based on code from the "HugoFry/mats_sae_training_for_ViTs" repository (MIT-licensed):
3 | https://github.com/HugoFry/mats_sae_training_for_ViTs/blob/main/sae_training/hooked_vit.py
4 | """
5 |
6 | from contextlib import contextmanager
7 | from functools import partial
8 | from typing import Callable, List, Tuple
9 |
10 | import torch
11 | from jaxtyping import Float
12 | from torch import Tensor
13 | from torch.nn import functional as F
14 | from transformers import CLIPModel
15 |
16 |
17 | # The Hook class currently only supports hooking on the following locations:
18 | # 1 - residual stream post transformer block.
19 | # 2 - mlp activations.
20 | # More hooks can be added at a later date, but only post-module.
21 | class Hook:
22 | def __init__(
23 | self,
24 | block_layer: int,
25 | module_name: str,
26 | hook_fn: Callable,
27 | is_custom: bool = None,
28 | return_module_output=True,
29 | ):
30 | self.path_dict = {
31 | "resid": "",
32 | }
33 | assert module_name in self.path_dict.keys(), (
34 | f"Module name '{module_name}' not recognised."
35 | )
36 | self.return_module_output = return_module_output
37 | self.function = self.get_full_hook_fn(hook_fn)
38 | self.attr_path = self.get_attr_path(block_layer, module_name, is_custom)
39 |
40 | def get_full_hook_fn(self, hook_fn: Callable):
41 | def full_hook_fn(module, module_input, module_output):
42 | hook_fn_output = hook_fn(module_output[0])
43 | if self.return_module_output:
44 | return module_output
45 | else:
46 |                 return hook_fn_output  # note: the module output is not a bare tensor of activations but a tuple (tensor, ...)
47 |
48 | return full_hook_fn
49 |
50 | def get_attr_path(
51 | self, block_layer: int, module_name: str, is_custom: bool = None
52 | ) -> str:
53 | if is_custom:
54 | attr_path = f"image_encoder.transformer.resblocks[{block_layer}]"
55 | else:
56 | attr_path = f"vision_model.encoder.layers[{block_layer}]"
57 | attr_path += self.path_dict[module_name]
58 | return attr_path
59 |
60 | def get_module(self, model):
61 | return self.get_nested_attr(model, self.attr_path)
62 |
63 | def get_nested_attr(self, model, attr_path):
64 | """
65 | Gets a nested attribute from an object using a dot-separated path.
66 | """
67 | module = model
68 | attributes = attr_path.split(".")
69 | for attr in attributes:
70 | if "[" in attr:
71 | # Split at '[' and remove the trailing ']' from the index
72 | attr_name, index = attr[:-1].split("[")
73 | module = getattr(module, attr_name)[int(index)]
74 | else:
75 | module = getattr(module, attr)
76 | return module
77 |
78 |
79 | class HookedVisionTransformer:
80 | def __init__(self, model, processor, device="cuda"):
81 | self.model = model.to(device)
82 | self.processor = processor
83 |
84 | def run_with_cache(
85 | self,
86 | list_of_hook_locations: List[Tuple[int, str]],
87 | *args,
88 | return_type="output",
89 | **kwargs,
90 | ):
91 | cache_dict, list_of_hooks = self.get_caching_hooks(list_of_hook_locations)
92 | with self.hooks(list_of_hooks) as hooked_model:
93 | with torch.no_grad():
94 | output = hooked_model(*args, **kwargs)
95 |
96 | if return_type == "output":
97 | return output, cache_dict
98 | if return_type == "loss":
99 | return (
100 | self.contrastive_loss(output.logits_per_image, output.logits_per_text),
101 | cache_dict,
102 | )
103 | else:
104 | raise Exception(
105 | f"Unrecognised keyword argument return_type='{return_type}'. Must be either 'output' or 'loss'."
106 | )
107 |
108 | def get_caching_hooks(self, list_of_hook_locations: List[Tuple[int, str]]):
109 | """
110 |         Note that the cache dictionary is indexed by the tuple (block_layer, module_name).
111 | """
112 | cache_dict = {}
113 | list_of_hooks = []
114 |
115 | def save_activations(name, activations):
116 | cache_dict[name] = activations.detach()
117 |
118 | for block_layer, module_name in list_of_hook_locations:
119 | hook_fn = partial(save_activations, (block_layer, module_name))
120 | if isinstance(self.model, CLIPModel):
121 | is_custom = False
122 | else:
123 | is_custom = True
124 | hook = Hook(block_layer, module_name, hook_fn, is_custom=is_custom)
125 | list_of_hooks.append(hook)
126 | return cache_dict, list_of_hooks
127 |
128 | @torch.no_grad()
129 | def run_with_hooks(
130 | self, list_of_hooks: List[Hook], *args, return_type="output", **kwargs
131 | ):
132 | with self.hooks(list_of_hooks) as hooked_model:
133 | with torch.no_grad():
134 | output = hooked_model(*args, **kwargs)
135 | if return_type == "output":
136 | return output
137 | if return_type == "loss":
138 | return self.contrastive_loss(
139 | output.logits_per_image, output.logits_per_text
140 | )
141 | else:
142 | raise Exception(
143 | f"Unrecognised keyword argument return_type='{return_type}'. Must be either 'output' or 'loss'."
144 | )
145 |
146 | def train_with_hooks(
147 | self, list_of_hooks: List[Hook], *args, return_type="output", **kwargs
148 | ):
149 | with self.hooks(list_of_hooks) as hooked_model:
150 | output = hooked_model(*args, **kwargs)
151 | if return_type == "output":
152 | return output
153 | if return_type == "loss":
154 | return self.contrastive_loss(
155 | output.logits_per_image, output.logits_per_text
156 | )
157 | else:
158 | raise Exception(
159 | f"Unrecognised keyword argument return_type='{return_type}'. Must be either 'output' or 'loss'."
160 | )
161 |
162 | def contrastive_loss(
163 | self,
164 | logits_per_image: Float[Tensor, "n_images n_prompts"], # noqa: F722
165 | logits_per_text: Float[Tensor, "n_prompts n_images"], # noqa: F722
166 | ): # Assumes square matrices
167 | assert logits_per_image.size()[0] == logits_per_image.size()[1], (
168 | "The number of prompts does not match the number of images."
169 | )
170 | batch_size = logits_per_image.size()[0]
171 | labels = torch.arange(batch_size).long().to(logits_per_image.device)
172 | image_loss = F.cross_entropy(logits_per_image, labels)
173 | text_loss = F.cross_entropy(logits_per_text, labels)
174 | total_loss = (image_loss + text_loss) / 2
175 | return total_loss
176 |
177 | @contextmanager
178 | def hooks(self, hooks: List[Hook]):
179 | """
180 |
181 |         Context manager for running the model with forward hooks. It registers the
182 |         given hooks on the model, yields the hooked model so it can be run with a
183 |         forward pass, and then cleans up by removing the hooks afterwards.
184 |
185 |         Args:
186 |
187 |         hooks List[Hook]: A list of forward hooks to add to the model. Each hook
188 |         resolves the module it attaches to and wraps the hook function that is
189 |         called on that module's output.
190 |
191 |
192 | """
193 | hook_handles = []
194 | try:
195 | for hook in hooks:
196 |                 # Create a full hook function, with all the arguments needed for nn.Module.register_forward_hook().
197 | # The hook functions are added to the output of the module.
198 | module = hook.get_module(self.model)
199 | handle = module.register_forward_hook(hook.function)
200 | hook_handles.append(handle)
201 | yield self.model
202 | finally:
203 | for handle in hook_handles:
204 | handle.remove()
205 |
206 | def to(self, device):
207 | self.model = self.model.to(device)
208 |
209 | def __call__(self, *args, return_type="output", **kwargs):
210 | return self.forward(*args, return_type=return_type, **kwargs)
211 |
212 | def forward(self, *args, return_type="output", **kwargs):
213 | if return_type == "output":
214 | return self.model(*args, **kwargs)
215 | elif return_type == "loss":
216 | output = self.model(*args, **kwargs)
217 | return self.contrastive_loss(
218 | output.logits_per_image, output.logits_per_text
219 | )
220 | else:
221 | raise Exception(
222 | f"Unrecognised keyword argument return_type='{return_type}'. Must be either 'output' or 'loss'."
223 | )
224 |
225 | def eval(self):
226 | self.model.eval()
227 |
228 | def train(self):
229 | self.model.train()
230 |
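A minimal sketch of hooking a Hugging Face CLIP with the classes above; the backbone id and the blank placeholder image are illustrative:

```python
# Minimal sketch: cache residual-stream activations from one vision block.
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

from src.sae_training.hooked_vit import HookedVisionTransformer

backbone = "openai/clip-vit-base-patch32"
hooked = HookedVisionTransformer(
    CLIPModel.from_pretrained(backbone),
    CLIPProcessor.from_pretrained(backbone),
    device="cpu",
)
hooked.eval()

image = Image.new("RGB", (224, 224))  # placeholder image
inputs = hooked.processor(
    images=[image], text=["a photo"], return_tensors="pt", padding=True
)

# The cache is keyed by (block_layer, module_name); "resid" hooks the block output.
output, cache = hooked.run_with_cache([(10, "resid")], **inputs)
print(cache[(10, "resid")].shape)     # (batch, tokens, hidden), e.g. (1, 50, 768)
```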
--------------------------------------------------------------------------------
/src/sae_training/sae_trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | # Portions of this file are based on code from the “jbloomAus/SAELens” and "HugoFry/mats_sae_training_for_ViTs" repositories (MIT-licensed):
3 | https://github.com/jbloomAus/SAELens/blob/main/sae_lens/training/sae_trainer.py
4 | https://github.com/HugoFry/mats_sae_training_for_ViTs/blob/main/sae_training/config.py
5 | """
6 |
7 | from typing import Any
8 |
9 | import torch
10 | import wandb
11 | from tqdm import tqdm
12 |
13 | from src.sae_training.config import Config
14 | from src.sae_training.hooked_vit import Hook, HookedVisionTransformer
15 | from src.sae_training.sparse_autoencoder import SparseAutoencoder
16 | from src.sae_training.vit_activations_store import ViTActivationsStore
17 |
18 |
19 | class SAETrainer:
20 | def __init__(
21 | self,
22 | sae: SparseAutoencoder,
23 | model: HookedVisionTransformer,
24 | activation_store: ViTActivationsStore,
25 | cfg: Config,
26 | optimizer: torch.optim.Optimizer,
27 | scheduler: torch.optim.lr_scheduler._LRScheduler,
28 | device: torch.device,
29 | ):
30 | self.sae = sae
31 | self.model = model
32 | self.activation_store = activation_store
33 | self.cfg = cfg
34 | self.optimizer = optimizer
35 | self.scheduler = scheduler
36 | self.device = device
37 |
38 | self.act_freq_scores = torch.zeros(sae.cfg.d_sae, device=device)
39 | self.n_forward_passes_since_fired = torch.zeros(sae.cfg.d_sae, device=device)
40 | self.n_frac_active_tokens = 0
41 | self.n_training_tokens = 0
42 | self.ghost_grad_neuron_mask = None
43 | self.n_training_steps = 0
44 |
45 | self.checkpoint_thresholds = list(
46 | range(
47 | 0,
48 | cfg.total_training_tokens,
49 | cfg.total_training_tokens // self.cfg.n_checkpoints,
50 | )
51 | )[1:]
52 |
53 | def _build_sparsity_log_dict(self) -> dict[str, Any]:
54 | feature_freq = self.act_freq_scores / self.n_frac_active_tokens
55 | log_feature_freq = torch.log10(feature_freq + 1e-10).detach().cpu()
56 |
57 | return {
58 | "plots/feature_density_line_chart": wandb.Histogram(
59 | log_feature_freq.numpy()
60 | ),
61 | "metrics/mean_log10_feature_sparsity": log_feature_freq.mean().item(),
62 | }
63 |
64 | @torch.no_grad()
65 | def _reset_running_sparsity_stats(self) -> None:
66 | self.act_freq_scores = torch.zeros(self.cfg.d_sae, device=self.device)
67 | self.n_frac_active_tokens = 0
68 |
69 | def _train_step(
70 | self,
71 | sae_in: torch.Tensor,
72 | ):
73 | self.optimizer.zero_grad()
74 |
75 | self.sae.train()
76 | self.sae.set_decoder_norm_to_unit_norm()
77 |
78 | # log and then reset the feature sparsity every feature_sampling_window steps
79 | if (self.n_training_steps + 1) % self.cfg.feature_sampling_window == 0:
80 | if self.cfg.log_to_wandb:
81 | sparsity_log_dict = self._build_sparsity_log_dict()
82 | wandb.log(sparsity_log_dict, step=self.n_training_steps)
83 | self._reset_running_sparsity_stats()
84 |
85 | ghost_grad_neuron_mask = (
86 | self.n_forward_passes_since_fired > self.cfg.dead_feature_window
87 | ).bool()
88 | sae_out, feature_acts, loss_dict = self.sae(sae_in, ghost_grad_neuron_mask)
89 |
90 | with torch.no_grad():
91 | if self.cfg.class_token:
92 | did_fire = (feature_acts > 0).float().sum(-2) > 0
93 | self.act_freq_scores += (feature_acts.abs() > 0).float().sum(0)
94 |
95 | else:
96 | # default for PatchSAE
97 | did_fire = (((feature_acts > 0).float().sum(-2) > 0).sum(-2)) > 0
98 | self.act_freq_scores += (feature_acts.abs() > 0).float().sum(0).sum(0)
99 |
100 | self.n_forward_passes_since_fired += 1
101 | self.n_forward_passes_since_fired[did_fire] = 0
102 | self.n_frac_active_tokens += sae_out.size(0)
103 |
104 | self.ghost_grad_neuron_mask = ghost_grad_neuron_mask
105 |
106 | loss_dict["loss"].backward()
107 | self.sae.remove_gradient_parallel_to_decoder_directions()
108 |
109 | self.optimizer.step()
110 | self.scheduler.step()
111 |
112 | return sae_out, feature_acts, loss_dict
113 |
114 | def _calculate_sparsity_metrics(self) -> dict:
115 | """Calculate sparsity-related metrics."""
116 | feature_freq = self.act_freq_scores / self.n_frac_active_tokens
117 |
118 | return {
119 | "sparsity/mean_passes_since_fired": self.n_forward_passes_since_fired.mean().item(),
120 | "sparsity/n_passes_since_fired_over_threshold": self.ghost_grad_neuron_mask.sum().item(),
121 | "sparsity/below_1e-5": (feature_freq < 1e-5).float().mean().item(),
122 | "sparsity/below_1e-6": (feature_freq < 1e-6).float().mean().item(),
123 | "sparsity/dead_features": (feature_freq < self.cfg.dead_feature_threshold)
124 | .float()
125 | .mean()
126 | .item(),
127 | }
128 |
129 | @torch.no_grad()
130 | def _log_train_step(
131 | self,
132 | feature_acts: torch.Tensor,
133 | loss_dict: dict[str, torch.Tensor],
134 | sae_out: torch.Tensor,
135 | sae_in: torch.Tensor,
136 | ):
137 | """Log training metrics to wandb."""
138 | metrics = self._calculate_metrics(feature_acts, sae_out, sae_in)
139 | sparsity_metrics = self._calculate_sparsity_metrics()
140 |
141 | log_dict = {
142 | "losses/overall_loss": loss_dict["loss"].item(),
143 | "losses/mse_loss": loss_dict["mse_loss"].item(),
144 | "losses/l1_loss": loss_dict["l1_loss"].item(),
145 | "losses/ghost_grad_loss": loss_dict["mse_loss_ghost_resid"].item(),
146 | **metrics,
147 | **sparsity_metrics,
148 | "details/n_training_tokens": self.n_training_tokens,
149 | "details/current_learning_rate": self.optimizer.param_groups[0]["lr"],
150 | }
151 |
152 | wandb.log(log_dict, step=self.n_training_steps)
153 |
154 | @torch.no_grad()
155 | def _calculate_metrics(
156 | self, feature_acts: torch.Tensor, sae_out: torch.Tensor, sae_in: torch.Tensor
157 | ) -> dict:
158 | """Calculate model performance metrics."""
159 | if self.cfg.class_token:
160 | l0 = (feature_acts > 0).float().sum(-1).mean()
161 | else:
162 | l0 = (feature_acts > 0).float().sum(-1).mean(-1).mean()
163 | per_token_l2_loss = (sae_out - sae_in).pow(2).sum(dim=-1).mean().squeeze()
164 | total_variance = sae_in.pow(2).sum(-1).mean()
165 | explained_variance = 1 - per_token_l2_loss / total_variance
166 |
167 | return {
168 | "metrics/explained_variance": explained_variance.mean().item(),
169 | "metrics/explained_variance_std": explained_variance.std().item(),
170 | "metrics/l0": l0.item(),
171 | }
172 |
173 | @torch.no_grad()
174 | def _update_pbar(self, loss_dict, pbar, batch_size):
175 | pbar.set_description(
176 | f"{self.n_training_steps}| MSE Loss {loss_dict['mse_loss'].item():.3f} | L1 {loss_dict['l1_loss'].item():.3f}"
177 | )
178 | pbar.update(batch_size)
179 |
180 | @torch.no_grad()
181 | def _checkpoint_if_needed(self):
182 | if (
183 | self.checkpoint_thresholds
184 | and self.n_training_tokens > self.checkpoint_thresholds[0]
185 | ):
186 | self.save_checkpoint()
187 |             self.run_evals()
188 | self.checkpoint_thresholds.pop(0)
189 |
190 | def save_checkpoint(self, is_final=False):
191 | if is_final:
192 | path = f"{self.cfg.checkpoint_path}/final_{self.sae.get_name()}.pt"
193 | else:
194 | path = f"{self.cfg.checkpoint_path}/{self.n_training_tokens}_{self.sae.get_name()}.pt"
195 | self.sae.save_model(path)
196 |
197 | def fit(self) -> SparseAutoencoder:
198 | pbar = tqdm(total=self.cfg.total_training_tokens, desc="Training SAE")
199 |
200 | try:
201 | # Train loop
202 | while self.n_training_tokens < self.cfg.total_training_tokens:
203 | # Do a training step.
204 | sae_acts = self.activation_store.get_batch_activations()
205 | self.n_training_tokens += sae_acts.size(0)
206 |
207 | sae_out, feature_acts, loss_dict = self._train_step(sae_in=sae_acts)
208 |
209 | if (
210 | self.cfg.log_to_wandb
211 | and (self.n_training_steps + 1) % self.cfg.wandb_log_frequency == 0
212 | ):
213 | self._log_train_step(
214 | feature_acts=feature_acts,
215 | loss_dict=loss_dict,
216 | sae_out=sae_out,
217 | sae_in=sae_acts,
218 | )
219 |
220 | self._checkpoint_if_needed()
221 | self.n_training_steps += 1
222 | self._update_pbar(loss_dict, pbar, sae_out.size(0))
223 | finally:
224 | print("Saving final checkpoint")
225 | self.save_checkpoint(is_final=True)
226 | self.run_evals()
227 |
228 | pbar.close()
229 | return self.sae
230 |
231 | @torch.no_grad()
232 | def run_evals(self):
233 | self.sae.eval()
234 |
235 | def _create_hook(hook_fn):
236 | return Hook(
237 | self.sae.cfg.block_layer,
238 | self.sae.cfg.module_name,
239 | hook_fn,
240 | return_module_output=False,
241 | )
242 |
243 | def _zero_ablation_hook(activations):
244 | activations[:, 0, :] = torch.zeros_like(activations[:, 0, :]).to(
245 | activations.device
246 | )
247 | return (activations,)
248 |
249 | def _sae_reconstruction_hook(activations):
250 | activations[:, 0, :] = self.sae(activations[:, 0, :])[0]
251 | return (activations,)
252 |
253 | # Get model inputs and compute baseline loss
254 | # model_inputs = self.activation_store.get_batch_of_images_and_labels()
255 | model_inputs = self.activation_store.get_batch_model_inputs(process_labels=True)
256 | original_loss = self.model(return_type="loss", **model_inputs).item()
257 |
258 | # Compute loss with SAE reconstruction
259 | sae_hooks = [_create_hook(_sae_reconstruction_hook)]
260 | reconstruction_loss = self.model.run_with_hooks(
261 | sae_hooks, return_type="loss", **model_inputs
262 | ).item()
263 |
264 | # Compute loss with zeroed activations
265 | zero_hooks = [_create_hook(_zero_ablation_hook)]
266 | zero_ablation_loss = self.model.run_with_hooks(
267 | zero_hooks, return_type="loss", **model_inputs
268 | ).item()
269 |
270 | # Calculate reconstruction score
271 | reconstruction_score = (reconstruction_loss - original_loss) / (
272 | zero_ablation_loss - original_loss
273 | )
274 |
275 | # Log metrics if configured
276 | if self.cfg.log_to_wandb:
277 | wandb.log(
278 | {
279 | "metrics/contrastive_loss_score": reconstruction_score,
280 | "metrics/original_contrastive_loss": original_loss,
281 | "metrics/contrastive_loss_with_sae": reconstruction_loss,
282 | "metrics/contrastive_loss_with_ablation": zero_ablation_loss,
283 | },
284 | step=self.n_training_steps,
285 | )
286 |
287 | del model_inputs
288 | torch.cuda.empty_cache()
289 |
290 | self.sae.train()
291 |
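For intuition about the score logged in `run_evals`: it measures how much of the loss gap between the original model and the zero-ablated model remains after splicing the SAE reconstruction back in, so values near 0 mean the reconstruction barely hurts the contrastive loss. A tiny worked example with made-up numbers:

```python
# Illustrative numbers only (not measured values).
original_loss = 0.50          # contrastive loss of the unmodified model
reconstruction_loss = 0.55    # loss with the SAE reconstruction spliced in
zero_ablation_loss = 2.50     # loss with the hooked activations zeroed out

score = (reconstruction_loss - original_loss) / (zero_ablation_loss - original_loss)
print(score)                  # 0.025 -> the SAE preserves most of the model's behaviour
```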
--------------------------------------------------------------------------------
/src/sae_training/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Dict, Optional
3 |
4 | import torch
5 | import torch.optim as optim
6 | import torch.optim.lr_scheduler as lr_scheduler
7 |
8 | from src.sae_training.hooked_vit import HookedVisionTransformer
9 |
10 | SAE_DIM = 49152
11 |
12 |
13 | def process_model_inputs(
14 | batch: Dict, vit: HookedVisionTransformer, device: str, process_labels: bool = False
15 | ) -> torch.Tensor:
16 |     """Process input images (and optionally text labels) through the ViT processor."""
17 | if process_labels:
18 | labels = [f"A photo of a {label}" for label in batch["label"]]
19 | return vit.processor(
20 | images=batch["image"], text=labels, return_tensors="pt", padding=True
21 | ).to(device)
22 |
23 | return vit.processor(
24 | images=batch["image"], text="", return_tensors="pt", padding=True
25 | ).to(device)
26 |
27 |
28 | def get_model_activations(
29 | model: HookedVisionTransformer, inputs: dict, block_layer, module_name, class_token
30 | ) -> torch.Tensor:
31 | """Extract activations from a specific layer of the vision transformer model."""
32 | hook_location = (block_layer, module_name)
33 |
34 | # Run model forward pass and extract activations from cache
35 | _, cache = model.run_with_cache([hook_location], **inputs)
36 | activations = cache[hook_location]
37 |
38 | batch_size = inputs["pixel_values"].shape[0]
39 | if activations.shape[0] != batch_size:
40 | activations = activations.transpose(0, 1)
41 |
42 | # Extract class token if specified
43 | if class_token:
44 | activations = activations[0, :, :]
45 |
46 | return activations
47 |
48 |
49 | def get_scheduler(scheduler_name: Optional[str], optimizer: optim.Optimizer, **kwargs):
50 | def get_warmup_lambda(warm_up_steps, training_steps):
51 | def lr_lambda(steps):
52 | if steps < warm_up_steps:
53 | return (steps + 1) / warm_up_steps
54 | else:
55 | return (training_steps - steps) / (training_steps - warm_up_steps)
56 |
57 | return lr_lambda
58 |
59 | # heavily derived from hugging face although copilot helped.
60 | def get_warmup_cosine_lambda(warm_up_steps, training_steps, lr_end):
61 | def lr_lambda(steps):
62 | if steps < warm_up_steps:
63 | return (steps + 1) / warm_up_steps
64 | else:
65 | progress = (steps - warm_up_steps) / (training_steps - warm_up_steps)
66 | return lr_end + 0.5 * (1 - lr_end) * (1 + math.cos(math.pi * progress))
67 |
68 | return lr_lambda
69 |
70 | if scheduler_name is None or scheduler_name.lower() == "constant":
71 | return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda steps: 1.0)
72 | elif scheduler_name.lower() == "constantwithwarmup":
73 | warm_up_steps = kwargs.get("warm_up_steps", 500)
74 | return lr_scheduler.LambdaLR(
75 | optimizer,
76 | lr_lambda=lambda steps: min(1.0, (steps + 1) / warm_up_steps),
77 | )
78 | elif scheduler_name.lower() == "linearwarmupdecay":
79 | warm_up_steps = kwargs.get("warm_up_steps", 0)
80 | training_steps = kwargs.get("training_steps")
81 | lr_lambda = get_warmup_lambda(warm_up_steps, training_steps)
82 | return lr_scheduler.LambdaLR(optimizer, lr_lambda)
83 | elif scheduler_name.lower() == "cosineannealing":
84 | training_steps = kwargs.get("training_steps")
85 | eta_min = kwargs.get("lr_end", 0)
86 | return lr_scheduler.CosineAnnealingLR(
87 | optimizer, T_max=training_steps, eta_min=eta_min
88 | )
89 | elif scheduler_name.lower() == "cosineannealingwarmup":
90 | warm_up_steps = kwargs.get("warm_up_steps", 0)
91 | training_steps = kwargs.get("training_steps")
92 | eta_min = kwargs.get("lr_end", 0)
93 | lr_lambda = get_warmup_cosine_lambda(warm_up_steps, training_steps, eta_min)
94 | return lr_scheduler.LambdaLR(optimizer, lr_lambda)
95 | elif scheduler_name.lower() == "cosineannealingwarmrestarts":
96 | training_steps = kwargs.get("training_steps")
97 | eta_min = kwargs.get("lr_end", 0)
98 | num_cycles = kwargs.get("num_cycles", 1)
99 | T_0 = training_steps // num_cycles
100 | return lr_scheduler.CosineAnnealingWarmRestarts(
101 | optimizer, T_0=T_0, eta_min=eta_min
102 | )
103 | else:
104 | raise ValueError(f"Unsupported scheduler: {scheduler_name}")
105 |
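A minimal sketch of pairing `get_scheduler` with an optimizer; the parameter tensor and step counts are illustrative stand-ins:

```python
# Minimal sketch: linear warmup followed by cosine decay for the SAE optimizer.
import torch

from src.sae_training.utils import get_scheduler

params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for SAE parameters
optimizer = torch.optim.Adam(params, lr=3e-4)
scheduler = get_scheduler(
    "cosineannealingwarmup",
    optimizer,
    warm_up_steps=500,
    training_steps=10_000,                     # illustrative total step count
    lr_end=0.0,
)

for _ in range(10):
    optimizer.step()
    scheduler.step()
print(optimizer.param_groups[0]["lr"])         # still ramping up during warmup
```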
--------------------------------------------------------------------------------
/src/sae_training/vit_activations_store.py:
--------------------------------------------------------------------------------
1 | """
2 | # Portions of this file are based on code from the “jbloomAus/SAELens” and "HugoFry/mats_sae_training_for_ViTs" repositories (MIT-licensed):
3 | https://github.com/jbloomAus/SAELens/blob/main/sae_lens/training/activations_store.py
4 | https://github.com/HugoFry/mats_sae_training_for_ViTs/blob/main/sae_training/vit_activations_store.py
5 | """
6 |
7 | from torch.utils.data import DataLoader, TensorDataset
8 |
9 | from src.sae_training.hooked_vit import HookedVisionTransformer
10 | from src.sae_training.utils import get_model_activations, process_model_inputs
11 |
12 |
13 | class ViTActivationsStore:
14 | """
15 | Class for streaming tokens and generating and storing activations
16 | while training SAEs.
17 | """
18 |
19 | def __init__(
20 | self,
21 | dataset,
22 | batch_size: int,
23 | device: str,
24 | seed: int,
25 | model: HookedVisionTransformer,
26 | block_layer: int,
27 | module_name: str,
28 | class_token: bool,
29 | ):
30 | self.device = device
31 | self.model = model
32 | self.batch_size = batch_size
33 | self.block_layer = block_layer
34 | self.module_name = module_name
35 | self.class_token = class_token
36 |
37 | self.dataset = dataset.shuffle(seed=seed)
38 | self.dataset_iter = iter(self.dataset)
39 |
40 | def get_batch_model_inputs(self, process_labels=False):
41 |         """Build processed model inputs for a batch of data."""
42 | batch_dict = {"image": [], "label": []}
43 |
44 | def _add_data(current_item, batch_dict):
45 | for key, value in current_item.items():
46 | batch_dict[key].append(value)
47 |
48 | for _ in range(self.batch_size):
49 | try:
50 | current_item = next(self.dataset_iter)
51 | except StopIteration:
52 | self.dataset_iter = iter(self.dataset)
53 | current_item = next(self.dataset_iter)
54 | _add_data(current_item, batch_dict)
55 |
56 | inputs = process_model_inputs(
57 | batch_dict, self.model, self.device, process_labels=process_labels
58 | )
59 | return inputs
60 |
61 | def get_batch_activations(self):
62 |         """Get model activations for a batch of data."""
63 | inputs = self.get_batch_model_inputs()
64 | return get_model_activations(
65 | self.model, inputs, self.block_layer, self.module_name, self.class_token
66 | )
67 |
68 | def _create_new_dataloader(self) -> DataLoader:
69 | """Create a new dataloader with fresh activations"""
70 |         activations = self.get_batch_activations()
71 | dataset = TensorDataset(activations)
72 | return iter(DataLoader(dataset, batch_size=self.batch_size, shuffle=True))
73 |
74 | def get_next_batch(self):
75 | """Get next batch, creating new dataloader if current one is exhausted"""
76 | try:
77 |             return self.get_batch_activations()
78 | except StopIteration:
79 | self.dataloader = self._create_new_dataloader()
80 | return next(self.dataloader)
81 |
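A minimal sketch of wiring the store to a hooked CLIP and a Hugging Face dataset. The backbone id is illustrative; the dataset id matches the default in `ViTSAERunnerConfig`, and downloading it is left to `datasets`:

```python
# Minimal sketch: stream activation batches for SAE training (ids illustrative).
from datasets import load_dataset
from transformers import CLIPModel, CLIPProcessor

from src.sae_training.hooked_vit import HookedVisionTransformer
from src.sae_training.vit_activations_store import ViTActivationsStore

backbone = "openai/clip-vit-base-patch32"
hooked = HookedVisionTransformer(
    CLIPModel.from_pretrained(backbone),
    CLIPProcessor.from_pretrained(backbone),
    device="cpu",
)

dataset = load_dataset("evanarlian/imagenet_1k_resized_256", split="train")
store = ViTActivationsStore(
    dataset=dataset,
    batch_size=32,
    device="cpu",
    seed=42,
    model=hooked,
    block_layer=10,
    module_name="resid",
    class_token=False,   # keep all patch tokens (PatchSAE)
)
acts = store.get_batch_activations()
print(acts.shape)        # (batch, tokens, d_in) for a Hugging Face CLIPModel
```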
--------------------------------------------------------------------------------
/tasks/README.md:
--------------------------------------------------------------------------------
1 | # PatchSAE: Training & Usage Guide
2 |
3 | ## 📋 Table of Contents
4 | - [Train SAE](#-train-sae)
5 | - [Extract SAE Latent Data](#-extract-sae-latent-data)
6 | - [Compute Class-Level SAE Latents](#-compute-class-level-sae-latents)
7 | - [Steer Classification](#-steer-classification)
8 |
9 | ## 🔧 Train SAE
10 |
11 | Train a sparse autoencoder on CLIP features.
12 |
13 | ### Prerequisites
14 | - CLIP checkpoint
15 | - Training dataset (images only)
16 |
17 | ### Training Command
18 | ```bash
19 | PYTHONPATH=./ python tasks/train_sae_vit.py
20 | ```
21 | > 📝 Configuration files will be added soon
22 |
23 | ### Outputs
24 | - SAE checkpoint (`.pt` file)
25 |
26 | ### Monitoring
27 | - View our [training logs on W&B](https://api.wandb.ai/links/hyesulim-hs/7dx90sq0)
28 |
29 | ---
30 |
31 | ## 📊 Extract SAE Latent Data
32 |
33 | Extract and save SAE latent activations for downstream analysis.
34 |
35 | ### Prerequisites
36 | - CLIP checkpoint
37 | - SAE checkpoint (from training step)
38 | - Dataset (can differ from training dataset)
39 |
40 | ### Run with Original CLIP
41 | ```bash
42 | PYTHONPATH=./ python tasks/compute_sae_feature_data.py \
43 | --root_dir ./ \
44 | --dataset_name imagenet \
45 | --sae_path /PATH/TO/SAE_CKPT.pt \
46 | --vit_type base
47 | ```
48 |
49 | ### Run with Adapted CLIP (e.g., MaPLe)
50 | 1. Download MaPLe from the [official repo](https://github.com/muzairkhattak/multimodal-prompt-learning/tree/main?tab=readme-ov-file#model-zoo) or [Google Drive](https://drive.google.com/drive/folders/1EvuvgR8566bL0T7ucvAL3LFVwuUPMRas)
51 |
52 | 2. Run extraction:
53 | ```bash
54 | PYTHONPATH=./ python tasks/compute_sae_feature_data.py \
55 | --root_dir ./ \
56 | --dataset_name imagenet \
57 | --sae_path /PATH/TO/SAE_CKPT.pt \
58 | --vit_type maple \
59 | --model_path /PATH/TO/MAPLE_CKPT \ # e.g., .../model.pth.tar-5
60 | --config_path /PATH/TO/MAPLE_CFG \ # e.g., .../configs/models/maple/vit_b16_c2_ep5_batch4_2ctx.yaml
61 | ```
62 |
63 | ### Output Files
64 | All files will be saved to: `{root_dir}/out/feature_data/{vit_type}/{dataset_name}/`
65 |
66 | - `max_activating_image_indices.pt`
67 | - `max_activating_image_label_indices.pt`
68 | - `max_activating_image_values.pt`
69 | - `sae_mean_acts.pt`
70 | - `sae_sparsity.pt`
71 |
72 | ### Analysis
73 | Explore the extracted features with our [patchsae/analysis/analysis.ipynb](https://github.com/hyesulim/patchsae/blob/9f28fdc6ffb7beccb5c2b8ee629b6752b904aa23/analysis/analysis.ipynb)
74 |
75 | ---
76 |
77 | ## 🧩 Compute Class-Level SAE Latents
78 |
79 | Compute class-level SAE activation patterns.
80 |
81 | ### Prerequisites
82 | - CLIP checkpoint
83 | - SAE checkpoint
84 | - SAE feature data (from previous step)
85 | - Dataset (must be the SAME dataset used in the extraction step)
86 |
87 | ### Run with Original CLIP
88 | ```bash
89 | PYTHONPATH=./ python tasks/compute_class_wise_sae_activation.py \
90 | --root_dir ./ \
91 | --dataset_name imagenet \
92 | --threshold 0.2 \
93 | --sae_path /PATH/TO/SAE_CKPT.pt \
94 | --vit_type base
95 | ```
96 |
97 | ### Run with Adapted CLIP (e.g., MaPLe)
98 | ```bash
99 | PYTHONPATH=./ python tasks/compute_class_wise_sae_activation.py \
100 | --root_dir ./ \
101 | --dataset_name imagenet \
102 | --threshold 0.2 \
103 | --sae_path /PATH/TO/SAE_CKPT.pt \
104 | --vit_type maple \
105 | --model_path /PATH/TO/MAPLE_CKPT \ # e.g., .../model.pth.tar-5
106 | --config_path /PATH/TO/MAPLE_CFG \ # e.g., .../configs/models/maple/vit_b16_c2_ep5_batch4_2ctx.yaml
107 | ```
108 |
109 | ### Output File
110 | - `cls_sae_cnt.npy` - Matrix of shape `(num_classes, num_sae_latents)` counting, for each class, how often each SAE latent fires above the threshold (see the loading sketch below)
111 |
112 | ---
113 |
114 | ## 🎯 Steer Classification
115 |
116 | Evaluate classification while steering class-associated SAE latents on or off (top-k masking).
117 |
118 | ### Prerequisites
119 | - CLIP checkpoint
120 | - SAE checkpoint
121 | - Class-level activation data (`cls_sae_cnt.npy` from previous step)
122 | - Dataset (must be the SAME dataset used for class-level activations, though can be a different split)
123 |
124 | ### Run with Original CLIP
125 | ```bash
126 | PYTHONPATH=./ python tasks/classification_with_top_k_masking.py \
127 | --root_dir ./ \
128 | --dataset_name imagenet \
129 | --sae_path /PATH/TO/SAE_CKPT.pt \
130 | --cls_wise_sae_activation_path /PATH/TO/cls_sae_cnt.npy
131 | ```
132 |
133 | ### Run with Adapted CLIP (e.g., MaPLe)
134 | ```bash
135 | PYTHONPATH=./ python tasks/classification_with_top_k_masking.py \
136 | --root_dir ./ \
137 | --dataset_name imagenet \
138 | --sae_path /PATH/TO/SAE_CKPT.pt \
139 | --cls_wise_sae_activation_path /PATH/TO/cls_sae_cnt.npy \
140 | --vit_type maple \
141 |     --model_path /PATH/TO/MAPLE_CKPT \
142 |     --config_path /PATH/TO/MAPLE_CFG  # CKPT e.g. .../model.pth.tar-5, CFG e.g. .../configs/models/maple/vit_b16_c2_ep5_batch4_2ctx.yaml
143 | ```
144 |
145 | ### Output File
146 | Output will be saved under `{root_dir}/out/feature_data/sae_{run_name}/` (the same tree as the previous steps; see `setup_save_directory` in `tasks/utils.py`); a loading sketch follows the list below:
147 | - `metrics.csv` - Contains the class-wise true positive rate (TPR = TP/(TP+FN), i.e. per-class recall in %) for each masking configuration
148 |   - Results for both "on" and "off" steering conditions
149 |   - For k values in [1, 2, 5, 10, 50, 100, 500, 1000, 2000, SAE_DIM]
150 |
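A minimal sketch for reading the metrics with pandas. Because the CSV is written without a row index, the sketch re-labels the rows using the order in which `classification_with_top_k_masking.py` records them ("no_sae", then "on_k"/"off_k" for each k); treat that ordering as an assumption and double-check it against your file.

```python
import pandas as pd

TOPK_LIST = [1, 2, 5, 10, 50, 100, 500, 1000, 2000, 49152]  # SAE_DIM = 49152

# Placeholder path; adjust to your run.
metrics = pd.read_csv("PATH/TO/metrics.csv")

# Assumed row order: "no_sae", then on_k / off_k for each k (see note above).
metrics.index = ["no_sae"] + [f"{cond}_{k}" for k in TOPK_LIST for cond in ("on", "off")]

# Average per-class TPR (%) for each steering configuration.
print(metrics.mean(axis=1).round(2))
```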
--------------------------------------------------------------------------------
/tasks/classification_with_top_k_masking.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from collections import defaultdict
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | from torch.nn.utils.rnn import pad_sequence
8 | from tqdm import tqdm
9 |
10 | from src.models.templates.openai_imagenet_templates import openai_imagenet_template
11 | from src.sae_training.config import Config
12 | from src.sae_training.hooked_vit import Hook, HookedVisionTransformer
13 | from src.sae_training.sparse_autoencoder import SparseAutoencoder
14 | from tasks.utils import (
15 | SAE_DIM,
16 | get_sae_and_vit,
17 | load_and_organize_dataset,
18 | process_batch,
19 | setup_save_directory,
20 | )
21 |
22 | TOPK_LIST = [1, 2, 5, 10, 50, 100, 500, 1000, 2000, SAE_DIM]
23 | SAE_BIAS = -0.105131256516992  # constant offset subtracted from the clamped SAE reconstruction in the steering hooks
24 |
25 |
26 | def calculate_text_features(model, device, classnames):
27 | """Calculate mean text features across templates for each class."""
28 | mean_text_features = 0
29 |
30 | for template_fn in openai_imagenet_template:
31 | # Generate prompts and convert to token IDs
32 | prompts = [template_fn(c) for c in classnames]
33 | prompt_ids = [
34 | model.processor(
35 | text=p, return_tensors="pt", padding=False, truncation=True
36 | ).input_ids[0]
37 | for p in prompts
38 | ]
39 |
40 | # Process batch
41 | padded_prompts = pad_sequence(prompt_ids, batch_first=True).to(device)
42 |
43 | # Get text features
44 | with torch.no_grad():
45 | text_features = model.model.get_text_features(padded_prompts)
46 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
47 | mean_text_features += text_features
48 |
49 | return mean_text_features / len(openai_imagenet_template)
50 |
51 |
52 | def create_sae_hooks(vit_type, cfg, cls_features, sae, device, hook_type="on"):
53 | """Create SAE hooks based on model type and hook type."""
54 | # Setup clamping parameters
55 | clamp_feat_dim = torch.ones(SAE_DIM).bool()
56 | clamp_value = torch.zeros(SAE_DIM) if hook_type == "on" else torch.ones(SAE_DIM)
57 | clamp_value = clamp_value.to(device)
58 | clamp_value[cls_features] = 1.0 if hook_type == "on" else 0.0
59 |
60 | def process_activations(activations, is_maple=False):
61 | """Helper function to process activations with SAE"""
62 | act = activations.transpose(0, 1) if is_maple else activations
63 | processed = (
64 | sae.forward_clamp(
65 | act[:, :, :], clamp_feat_dim=clamp_feat_dim, clamp_value=clamp_value
66 | )[0]
67 | - SAE_BIAS
68 | )
69 | return processed
70 |
71 | def hook_fn_default(activations):
72 | activations[:, :, :] = process_activations(activations)
73 | return (activations,)
74 |
75 | def hook_fn_maple(activations):
76 | activations = process_activations(activations, is_maple=True)
77 | return activations.transpose(0, 1)
78 |
79 | # Create appropriate hook based on model type
80 | hook_fn = hook_fn_maple if vit_type == "maple" else hook_fn_default
81 | is_custom = vit_type == "maple"
82 |
83 | return [
84 | Hook(
85 | cfg.block_layer,
86 | cfg.module_name,
87 | hook_fn,
88 | return_module_output=False,
89 | is_custom=is_custom,
90 | )
91 | ]
92 |
93 |
94 | def get_predictions(vit, inputs, text_features, vit_type, hooks=None):
95 | """Get model predictions with optional hooks."""
96 | with torch.no_grad():
97 | if hooks:
98 | vit_out = vit.run_with_hooks(hooks, return_type="output", **inputs)
99 | else:
100 | vit_out = vit(return_type="output", **inputs)
101 |
102 | image_features = vit_out.image_embeds if vit_type == "base" else vit_out
103 | logit_scale = vit.model.logit_scale.exp()
104 | logits = logit_scale * image_features @ text_features.t()
105 | preds = logits.argmax(dim=-1)
106 |
107 | return preds.cpu().numpy().tolist()
108 |
109 |
110 | def classify_with_top_k_masking(
111 | class_data: list,
112 | cls_idx: int,
113 | sae: SparseAutoencoder,
114 | vit: HookedVisionTransformer,
115 | cls_sae_cnt: torch.Tensor,
116 | text_features: torch.Tensor,
117 | batch_size: int,
118 | device: str,
119 | vit_type: str,
120 | cfg: Config,
121 | ):
122 | """Classify images with top-k feature masking."""
123 | num_batches = (len(class_data) + batch_size - 1) // batch_size
124 |
125 | preds_dict = defaultdict(list)
126 |
127 | for batch_idx in range(num_batches):
128 | batch_start = batch_idx * batch_size
129 | batch_end = min((batch_idx + 1) * batch_size, len(class_data))
130 | batch_data = class_data[batch_start:batch_end]
131 |
132 | batch_inputs = process_batch(vit, batch_data, device)
133 |
134 | # Get predictions without SAE
135 | preds_dict["no_sae"].extend(
136 | get_predictions(vit, batch_inputs, text_features, vit_type)
137 | )
138 | torch.cuda.empty_cache()
139 |
140 | # Get top features for current class
141 | loaded_cls_sae_idx = cls_sae_cnt[cls_idx].argsort()[::-1]
142 |
143 | for topk in TOPK_LIST:
144 | cls_features = loaded_cls_sae_idx[:topk].tolist()
145 |
146 | # Get predictions with features ON
147 | hooks_on = create_sae_hooks(vit_type, cfg, cls_features, sae, device, "on")
148 | preds_dict[f"on_{topk}"].extend(
149 | get_predictions(vit, batch_inputs, text_features, vit_type, hooks_on)
150 | )
151 | torch.cuda.empty_cache()
152 |
153 | # Get predictions with features OFF
154 | hooks_off = create_sae_hooks(
155 | vit_type, cfg, cls_features, sae, device, "off"
156 | )
157 | preds_dict[f"off_{topk}"].extend(
158 | get_predictions(vit, batch_inputs, text_features, vit_type, hooks_off)
159 | )
160 | torch.cuda.empty_cache()
161 |
162 | return preds_dict
163 |
164 |
165 | def main(
166 | sae_path: str,
167 | vit_type: str,
168 | device: str,
169 | dataset_name: str,
170 | root_dir: str,
171 | save_name: str,
172 | backbone: str = "openai/clip-vit-base-patch16",
173 | batch_size: int = 8,
174 | model_path: str = None,
175 | config_path: str = None,
176 | cls_wise_sae_activation_path: str = None,
177 | ):
178 | class_feature_type = cls_wise_sae_activation_path.split("/")[-3]
179 | save_directory = setup_save_directory(
180 | root_dir, save_name, sae_path, f"{class_feature_type}_{vit_type}", dataset_name
181 | )
182 |
183 | classnames, data_by_class = load_and_organize_dataset(dataset_name)
184 |
185 | sae, vit, cfg = get_sae_and_vit(
186 | sae_path,
187 | vit_type,
188 | device,
189 | backbone,
190 | model_path=model_path,
191 | config_path=config_path,
192 | classnames=classnames,
193 | )
194 |
195 | cls_sae_cnt = np.load(cls_wise_sae_activation_path)
196 |
197 | if vit_type == "base":
198 | text_features = calculate_text_features(vit, device, classnames)
199 | else:
200 | text_features = vit.model.get_text_features()
201 |
202 | metrics_dict = {}
203 | for class_idx, classname in enumerate(tqdm(classnames)):
204 | preds_dict = classify_with_top_k_masking(
205 | data_by_class[classname],
206 | class_idx,
207 | sae,
208 | vit,
209 | cls_sae_cnt,
210 | text_features,
211 | batch_size,
212 | device,
213 | vit_type,
214 | cfg,
215 | )
216 |
217 | metrics_dict[class_idx] = {}
218 | for k, v in preds_dict.items():
219 | metrics_dict[class_idx][k] = v.count(class_idx) / len(v) * 100
220 |
221 | metrics_df = pd.DataFrame(metrics_dict)
222 | metrics_df.to_csv(f"{save_directory}/metrics.csv", index=False)
223 | print(f"metrics.csv saved at {save_directory}")
224 |
225 |
226 | if __name__ == "__main__":
227 | parser = argparse.ArgumentParser(
228 | description="Perform classification with top-k masking"
229 | )
230 | parser.add_argument("--root_dir", type=str, default=".", help="Root directory")
231 | parser.add_argument("--dataset_name", type=str, default="imagenet")
232 | parser.add_argument(
233 | "--sae_path", type=str, required=True, help="SAE ckpt path (ends with xxx.pt)"
234 | )
235 | parser.add_argument("--batch_size", type=int, default=128, help="batch size")
236 | parser.add_argument(
237 | "--cls_wise_sae_activation_path",
238 |         type=str, required=True,
239 | help="path for cls_sae_cnt.npy",
240 | )
241 | parser.add_argument(
242 | "--vit_type", type=str, default="base", help="choose between [base, maple]"
243 | )
244 | parser.add_argument(
245 | "--model_path",
246 | type=str,
247 | help="CLIP model path in the case of not using the default",
248 | )
249 | parser.add_argument(
250 | "--config_path",
251 | type=str,
252 | help="CLIP config path in the case of using maple",
253 | )
254 | parser.add_argument("--device", type=str, default="cuda")
255 |
256 | args = parser.parse_args()
257 |
258 | main(
259 | sae_path=args.sae_path,
260 | vit_type=args.vit_type,
261 | device=args.device,
262 | dataset_name=args.dataset_name,
263 | root_dir=args.root_dir,
264 | save_name="out/feature_data",
265 | batch_size=args.batch_size,
266 | model_path=args.model_path,
267 | config_path=args.config_path,
268 | cls_wise_sae_activation_path=args.cls_wise_sae_activation_path,
269 | )
270 |
--------------------------------------------------------------------------------
/tasks/compute_class_wise_sae_activation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from typing import Dict
4 |
5 | import numpy as np
6 | import torch
7 | from tqdm import tqdm
8 |
9 | from src.sae_training.config import Config
10 | from src.sae_training.hooked_vit import HookedVisionTransformer
11 | from src.sae_training.sparse_autoencoder import SparseAutoencoder
12 | from src.sae_training.utils import get_model_activations
13 | from tasks.utils import (
14 | SAE_DIM,
15 | get_sae_and_vit,
16 | load_and_organize_dataset,
17 | process_batch,
18 | setup_save_directory,
19 | )
20 |
21 |
22 | def get_sae_activations(
23 | model_activations: torch.Tensor, sae: SparseAutoencoder, threshold: float
24 | ) -> torch.Tensor:
25 | """Get binary SAE activations above threshold.
26 |
27 | Args:
28 | model_activations: Input activations from vision transformer
29 | sae: The sparse autoencoder model
30 | threshold: Activation threshold
31 |
32 | Returns:
33 |         Per-latent counts of activations above the threshold, summed over batch and patch dimensions
34 | """
35 | _, cache = sae.run_with_cache(model_activations)
36 | activations = cache["hook_hidden_post"] > threshold
37 | return activations.sum(dim=0).sum(dim=0)
38 |
39 |
40 | def process_class_batch(
41 | batch_data: list,
42 | sae: SparseAutoencoder,
43 | vit: HookedVisionTransformer,
44 | cfg: Config,
45 | threshold: float,
46 | device: str,
47 | ) -> np.ndarray:
48 | """Process a single batch of class data and get SAE feature activations."""
49 | batch_inputs = process_batch(vit, batch_data, device)
50 | transformer_activations = get_model_activations(
51 | vit, batch_inputs, cfg.block_layer, cfg.module_name, cfg.class_token
52 | )
53 | active_features = get_sae_activations(transformer_activations, sae, threshold)
54 | return active_features.cpu().numpy()
55 |
56 |
57 | def compute_class_feature_counts(
58 | class_data: list,
59 | sae: SparseAutoencoder,
60 | vit: HookedVisionTransformer,
61 | cfg: Config,
62 | batch_size: int,
63 | threshold: float,
64 | device: str,
65 | ) -> np.ndarray:
66 | """Compute SAE feature activation counts for a single class."""
67 | feature_counts = np.zeros(SAE_DIM)
68 | num_batches = (len(class_data) + batch_size - 1) // batch_size
69 |
70 | for batch_idx in range(num_batches):
71 | batch_start = batch_idx * batch_size
72 | batch_end = min((batch_idx + 1) * batch_size, len(class_data))
73 | batch_data = class_data[batch_start:batch_end]
74 |
75 | batch_counts = process_class_batch(batch_data, sae, vit, cfg, threshold, device)
76 | feature_counts += batch_counts
77 | torch.cuda.empty_cache()
78 |
79 | return feature_counts
80 |
81 |
82 | def compute_all_class_activations(
83 | classnames: list,
84 | data_by_class: Dict,
85 | sae: SparseAutoencoder,
86 | vit: HookedVisionTransformer,
87 | cfg: Config,
88 | batch_size: int,
89 | threshold: float,
90 | device: str,
91 | ) -> np.ndarray:
92 | """Compute SAE activation counts across all classes."""
93 | class_activation_counts = np.zeros((len(classnames), SAE_DIM))
94 |
95 | for class_idx, classname in enumerate(tqdm(classnames)):
96 | class_data = data_by_class[classname]
97 | class_counts = compute_class_feature_counts(
98 | class_data, sae, vit, cfg, batch_size, threshold, device
99 | )
100 | class_activation_counts[class_idx] = class_counts
101 |
102 | return class_activation_counts
103 |
104 |
105 | def main(
106 | sae_path: str,
107 | vit_type: str,
108 | device: str,
109 | dataset_name: str,
110 | root_dir: str,
111 | save_name: str,
112 | backbone: str = "openai/clip-vit-base-patch16",
113 | batch_size: int = 8,
114 | model_path: str = None,
115 | config_path: str = None,
116 | threshold: float = 0.2,
117 | ):
118 | """Main function to compute and save class-wise SAE activation counts."""
119 |
120 | save_directory = setup_save_directory(
121 | root_dir, save_name, sae_path, vit_type, dataset_name
122 | )
123 |
124 | classnames, data_by_class = load_and_organize_dataset(dataset_name)
125 |
126 | sae, vit, cfg = get_sae_and_vit(
127 | sae_path,
128 | vit_type,
129 | device,
130 | backbone,
131 | model_path=model_path,
132 | config_path=config_path,
133 | classnames=classnames,
134 | )
135 |
136 | class_activation_counts = compute_all_class_activations(
137 | classnames, data_by_class, sae, vit, cfg, batch_size, threshold, device
138 | )
139 |
140 | # Save results
141 | save_path = os.path.join(save_directory, "cls_sae_cnt.npy")
142 | np.save(save_path, class_activation_counts)
143 | print(f"Class activation counts saved at {save_directory}")
144 |
145 |
146 | if __name__ == "__main__":
147 | parser = argparse.ArgumentParser(description="Save class-wise SAE activation count")
148 | parser.add_argument("--root_dir", type=str, default=".", help="Root directory")
149 | parser.add_argument("--dataset_name", type=str, default="imagenet")
150 | parser.add_argument(
151 | "--sae_path", type=str, required=True, help="SAE ckpt path (ends with xxx.pt)"
152 | )
153 | parser.add_argument(
154 | "--vit_type", type=str, default="base", help="choose between [base, maple]"
155 | )
156 | parser.add_argument(
157 | "--threshold", type=float, default=0.2, help="threshold for SAE activation"
158 | )
159 | parser.add_argument("--batch_size", type=int, default=128)
160 | parser.add_argument(
161 | "--model_path",
162 | type=str,
163 | help="CLIP model path in the case of not using the default",
164 | )
165 | parser.add_argument(
166 | "--config_path",
167 | type=str,
168 | help="CLIP config path in the case of using maple",
169 | )
170 | parser.add_argument("--device", type=str, default="cuda")
171 |
172 | args = parser.parse_args()
173 |
174 | main(
175 | sae_path=args.sae_path,
176 | vit_type=args.vit_type,
177 | device=args.device,
178 | dataset_name=args.dataset_name,
179 | root_dir=args.root_dir,
180 | save_name="out/feature_data",
181 | batch_size=args.batch_size,
182 | model_path=args.model_path,
183 | config_path=args.config_path,
184 | threshold=args.threshold,
185 | )
186 |
--------------------------------------------------------------------------------
/tasks/compute_sae_feature_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Dict, Optional, Tuple
3 |
4 | import torch
5 | from datasets import Dataset, load_dataset
6 | from tqdm import tqdm
7 |
8 | from src.sae_training.config import Config
9 | from src.sae_training.hooked_vit import HookedVisionTransformer
10 | from src.sae_training.sparse_autoencoder import SparseAutoencoder
11 | from src.sae_training.utils import get_model_activations, process_model_inputs
12 | from tasks.utils import (
13 | DATASET_INFO,
14 | get_classnames,
15 | get_sae_activations,
16 | get_sae_and_vit,
17 | setup_save_directory,
18 | )
19 |
20 |
21 | def initialize_storage_tensors(
22 | d_sae: int, num_max: int, device: str
23 | ) -> Dict[str, torch.Tensor]:
24 | """Initialize tensors for storing results."""
25 | return {
26 | "max_activating_image_values": torch.zeros([d_sae, num_max]).to(device),
27 | "max_activating_image_indices": torch.zeros([d_sae, num_max]).to(device),
28 | "sae_sparsity": torch.zeros([d_sae]).to(device),
29 | "sae_mean_acts": torch.zeros([d_sae]).to(device),
30 | }
31 |
32 |
33 | def get_new_top_k(
34 | first_values: torch.Tensor,
35 | first_indices: torch.Tensor,
36 | second_values: torch.Tensor,
37 | second_indices: torch.Tensor,
38 | k: int,
39 | ) -> Tuple[torch.Tensor, torch.Tensor]:
40 | """Get top k values and indices from two sets of values/indices."""
41 | total_values = torch.cat([first_values, second_values], dim=1)
42 | total_indices = torch.cat([first_indices, second_indices], dim=1)
43 | new_values, indices_of_indices = torch.topk(total_values, k=k, dim=1)
44 | new_indices = torch.gather(total_indices, 1, indices_of_indices)
45 | return new_values, new_indices
46 |
47 |
48 | def compute_sae_statistics(
49 | sae_activations: torch.Tensor,
50 | ) -> Tuple[torch.Tensor, torch.Tensor]:
51 | """Compute mean activations and sparsity statistics for SAE features."""
52 | mean_acts = sae_activations.sum(dim=1)
53 | sparsity = (sae_activations > 0).sum(dim=1)
54 | return mean_acts, sparsity
55 |
56 |
57 | def get_top_activations(
58 | sae_activations: torch.Tensor, num_top_images: int, images_processed: int
59 | ) -> Tuple[torch.Tensor, torch.Tensor]:
60 | """Get top activating images and their indices."""
61 | top_k = min(num_top_images, sae_activations.size(1))
62 | values, indices = torch.topk(sae_activations, k=top_k, dim=1)
63 | indices += images_processed
64 | return values, indices
65 |
66 |
67 | def process_batch(
68 | batch: Dict,
69 | vit: HookedVisionTransformer,
70 | sae: SparseAutoencoder,
71 | cfg: Config,
72 | device: str,
73 | num_top_images: int,
74 | images_processed: int,
75 | storage: Dict[str, torch.Tensor],
76 | ) -> Tuple[Dict[str, torch.Tensor], int]:
77 | """Process a single batch of images and update feature statistics."""
78 | # Get model activations
79 | inputs = process_model_inputs(batch, vit, device)
80 | model_acts = get_model_activations(
81 | vit, inputs, cfg.block_layer, cfg.module_name, cfg.class_token
82 | )
83 | sae_acts = get_sae_activations(model_acts, sae).transpose(0, 1)
84 |
85 | # Update statistics
86 | mean_acts, sparsity = compute_sae_statistics(sae_acts)
87 | storage["sae_mean_acts"] += mean_acts
88 | storage["sae_sparsity"] += sparsity
89 |
90 | # Get top activating images
91 | values, indices = get_top_activations(sae_acts, num_top_images, images_processed)
92 |
93 | top_values, top_indices = get_new_top_k(
94 | storage["max_activating_image_values"],
95 | storage["max_activating_image_indices"],
96 | values,
97 | indices,
98 | num_top_images,
99 | )
100 |
101 | # Update processed image count
102 | images_processed += model_acts.size(0)
103 |
104 | return {
105 | "max_activating_image_values": top_values,
106 | "max_activating_image_indices": top_indices,
107 | "sae_sparsity": storage["sae_sparsity"],
108 | "sae_mean_acts": storage["sae_mean_acts"],
109 | }, images_processed
110 |
111 |
112 | def save_results(
113 | save_directory: str,
114 | storage: Dict[str, torch.Tensor],
115 | dataset: Dataset,
116 | label_field: Optional[str] = None,
117 | ) -> None:
118 | """Save results to disk."""
119 | if label_field and label_field in dataset.features:
120 | max_activating_image_label_indices = torch.tensor(
121 | [
122 | dataset[int(index)][label_field]
123 | for index in tqdm(
124 | storage["max_activating_image_indices"].flatten(),
125 | desc="getting image labels",
126 | )
127 | ]
128 | ).view(storage["max_activating_image_indices"].shape)
129 |
130 | torch.save(
131 | max_activating_image_label_indices,
132 | f"{save_directory}/max_activating_image_label_indices.pt",
133 | )
134 |
135 | torch.save(
136 | storage["max_activating_image_indices"],
137 | f"{save_directory}/max_activating_image_indices.pt",
138 | )
139 | torch.save(
140 | storage["max_activating_image_values"],
141 | f"{save_directory}/max_activating_image_values.pt",
142 | )
143 | torch.save(storage["sae_sparsity"], f"{save_directory}/sae_sparsity.pt")
144 | torch.save(storage["sae_mean_acts"], f"{save_directory}/sae_mean_acts.pt")
145 |
146 | print(f"Results saved to {save_directory}")
147 |
148 |
149 | @torch.inference_mode()
150 | def main(
151 | sae_path: str,
152 | vit_type: str,
153 | device: str,
154 | dataset_name: str,
155 | root_dir: str,
156 | save_name: str,
157 | backbone: str = "openai/clip-vit-base-patch16",
158 | number_of_max_activating_images: int = 10,
159 | seed: int = 1,
160 | batch_size: int = 8,
161 | model_path: str = None,
162 | config_path: str = None,
163 | ):
164 | """Main function to extract and save feature data."""
165 | torch.set_float32_matmul_precision("high")
166 | torch.cuda.empty_cache()
167 |
168 | dataset = load_dataset(**DATASET_INFO[dataset_name])
169 | dataset = dataset.shuffle(seed=seed)
170 | classnames = get_classnames(dataset_name, dataset)
171 |
172 | sae, vit, cfg = get_sae_and_vit(
173 | sae_path, vit_type, device, backbone, model_path, config_path, classnames
174 | )
175 |
176 | storage = initialize_storage_tensors(
177 | sae.cfg.d_sae, number_of_max_activating_images, device
178 | )
179 |
180 | # Process batches
181 | total_iterations = (len(dataset) + batch_size - 1) // batch_size
182 | num_processed = 0
183 |
184 | for iteration in tqdm(range(total_iterations)):
185 | batch_start = iteration * batch_size
186 | batch_end = (iteration + 1) * batch_size
187 | current_batch = dataset[batch_start:batch_end]
188 |
189 | storage, num_processed = process_batch(
190 | current_batch,
191 | vit,
192 | sae,
193 | cfg,
194 | device,
195 | number_of_max_activating_images,
196 | num_processed,
197 | storage,
198 | )
199 |
200 | # Finalize statistics
201 | storage["sae_mean_acts"] /= storage["sae_sparsity"]
202 | storage["sae_sparsity"] /= num_processed
203 |
204 | # Save results
205 | save_directory = setup_save_directory(
206 | root_dir, save_name, sae_path, vit_type, dataset_name
207 | )
208 | save_results(
209 | save_directory,
210 | storage,
211 | dataset,
212 | label_field="label" if "label" in dataset.features else None,
213 | )
214 |
215 |
216 | if __name__ == "__main__":
217 | parser = argparse.ArgumentParser(
218 | description="Process ViT SAE images and save feature data"
219 | )
220 | parser.add_argument("--root_dir", type=str, default=".", help="Root directory")
221 | parser.add_argument("--dataset_name", type=str, default="imagenet")
222 | parser.add_argument(
223 | "--sae_path", type=str, required=True, help="SAE ckpt path (ends with xxx.pt)"
224 | )
225 | parser.add_argument(
226 | "--batch_size",
227 | type=int,
228 | default=32,
229 | help="Batch size to compute model activations and sae features",
230 | )
231 | parser.add_argument(
232 | "--model_path",
233 | type=str,
234 | help="CLIP model path in the case of not using the default",
235 | )
236 | parser.add_argument(
237 | "--config_path",
238 | type=str,
239 | help="CLIP config path in the case of using maple",
240 | )
241 | parser.add_argument("--device", type=str, default="cuda")
242 | parser.add_argument(
243 | "--vit_type", type=str, default="base", help="choose between [base, maple]"
244 | )
245 | args = parser.parse_args()
246 |
247 | main(
248 | sae_path=args.sae_path,
249 | vit_type=args.vit_type,
250 | device=args.device,
251 | dataset_name=args.dataset_name,
252 | root_dir=args.root_dir,
253 | save_name="out/feature_data",
254 | batch_size=args.batch_size,
255 | model_path=args.model_path,
256 | config_path=args.config_path,
257 | )
258 |
--------------------------------------------------------------------------------
/tasks/train_sae_vit.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import wandb
5 | from datasets import load_dataset
6 |
7 | from src.sae_training.config import ViTSAERunnerConfig
8 | from src.sae_training.sae_trainer import SAETrainer
9 | from src.sae_training.sparse_autoencoder import SparseAutoencoder
10 | from src.sae_training.utils import get_scheduler
11 | from src.sae_training.vit_activations_store import ViTActivationsStore
12 | from tasks.utils import (
13 | DATASET_INFO,
14 | get_classnames,
15 | load_hooked_vit,
16 | )
17 |
18 | if __name__ == "__main__":
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument("--class_token", action="store_true", default=None)
21 | parser.add_argument("--image_width", type=int, default=224)
22 | parser.add_argument("--image_height", type=int, default=224)
23 | parser.add_argument(
24 | "--model_name", type=str, default="openai/clip-vit-base-patch16"
25 | )
26 | parser.add_argument("--module_name", type=str, default="resid")
27 | parser.add_argument("--block_layer", type=int, default=-2)
28 | parser.add_argument("--clip_dim", type=int, default=768)
29 |
30 | parser.add_argument("--dataset", type=str, default="imagenet")
31 | parser.add_argument("--use_cached_activations", action="store_true", default=None)
32 | parser.add_argument("--cached_activations_path", type=str)
33 | parser.add_argument("--expansion_factor", type=int, default=64)
34 | parser.add_argument("--b_dec_init_method", type=str, default="geometric_median")
35 | parser.add_argument("--gated_sae", action="store_true", default=None)
36 | # Training Parameters
37 | parser.add_argument("--lr", type=float, default=0.0004)
38 | parser.add_argument("--l1_coefficient", type=float, default=0.00008)
39 | parser.add_argument("--lr_scheduler_name", type=str, default="constantwithwarmup")
40 | parser.add_argument("--batch_size", type=int, default=16)
41 | parser.add_argument("--lr_warm_up_steps", type=int, default=500)
42 | parser.add_argument("--total_training_tokens", type=int, default=2_621_440)
43 | parser.add_argument("--n_batches_in_store", type=int, default=15)
44 | parser.add_argument("--mse_cls_coefficient", type=float, default=1.0)
45 | # Dead Neurons and Sparsity
46 | parser.add_argument("--use_ghost_grads", action="store_true", default=None)
47 | parser.add_argument("--feature_sampling_method")
48 | parser.add_argument("--feature_sampling_window", type=int, default=64)
49 | parser.add_argument("--dead_feature_window", type=int, default=64)
50 | parser.add_argument("--dead_feature_threshold", type=float, default=1e-6)
51 | # WANDB
52 | parser.add_argument("--log_to_wandb", action="store_true", default=None)
53 | parser.add_argument("--wandb_project", type=str, default="patch_sae")
54 | parser.add_argument("--wandb_entity", type=str, default="test")
55 | parser.add_argument("--wandb_log_frequency", type=int, default=20)
56 | # Misc
57 | parser.add_argument("--seed", type=int, default=42)
58 | parser.add_argument("--n_checkpoints", type=int, default=1)
59 | parser.add_argument("--checkpoint_path", type=str, default="out/checkpoints")
60 | parser.add_argument("--device", type=str, default="cuda")
61 | # resume
62 | parser.add_argument("--root_dir", type=str, default="")
63 | parser.add_argument("--resume", action="store_true", default=None)
64 | parser.add_argument("--run_name", type=str, default="train")
65 | parser.add_argument("--start_training_steps", type=int, default=0)
66 | parser.add_argument("--pt_name", type=str)
67 |
68 | parser.add_argument(
69 | "--vit_type", type=str, default="base", help="choose between [base, maple]"
70 | )
71 | parser.add_argument(
72 | "--model_path",
73 | type=str,
74 | help="CLIP model path in the case of not using the default",
75 | )
76 | parser.add_argument(
77 | "--config_path",
78 | type=str,
79 | help="CLIP config path in the case of using maple",
80 | )
81 | args = parser.parse_args()
82 |
83 | cfg = ViTSAERunnerConfig(
84 | class_token=args.class_token,
85 | image_width=args.image_width,
86 | image_height=args.image_height,
87 |         model_name=args.model_name,  # default already includes the "openai/" prefix
88 | module_name=args.module_name,
89 | block_layer=args.block_layer,
90 | dataset_path=DATASET_INFO[args.dataset]["path"],
91 | image_key="image",
92 | label_key="label",
93 | use_cached_activations=args.use_cached_activations,
94 | cached_activations_path=args.cached_activations_path,
95 | d_in=args.clip_dim,
96 | expansion_factor=args.expansion_factor,
97 | b_dec_init_method=args.b_dec_init_method,
98 | gated_sae=args.gated_sae,
99 | lr=args.lr,
100 | l1_coefficient=args.l1_coefficient,
101 | lr_scheduler_name=args.lr_scheduler_name,
102 | batch_size=args.batch_size,
103 | lr_warm_up_steps=args.lr_warm_up_steps,
104 | total_training_tokens=args.total_training_tokens,
105 | n_batches_in_store=args.n_batches_in_store,
106 | mse_cls_coefficient=args.mse_cls_coefficient,
107 | use_ghost_grads=args.use_ghost_grads,
108 | feature_sampling_method=args.feature_sampling_method,
109 | feature_sampling_window=args.feature_sampling_window,
110 | dead_feature_window=args.dead_feature_window,
111 | dead_feature_threshold=args.dead_feature_threshold,
112 | log_to_wandb=args.log_to_wandb,
113 | wandb_project=args.wandb_project,
114 | wandb_entity=args.wandb_entity,
115 | wandb_log_frequency=args.wandb_log_frequency,
116 | device=args.device,
117 | seed=args.seed,
118 | n_checkpoints=args.n_checkpoints,
119 | checkpoint_path=args.checkpoint_path,
120 | dtype=torch.float32,
121 | )
122 |
123 | print("Loading dataset")
124 | classnames = get_classnames(args.dataset)
125 | dataset = load_dataset(**DATASET_INFO[args.dataset])
126 |
127 | print("Loading SAE and ViT models")
128 | sae = SparseAutoencoder(cfg, args.device)
129 |
130 | vit = load_hooked_vit(
131 | cfg,
132 | args.vit_type,
133 | args.model_name,
134 | args.device,
135 | args.model_path,
136 | args.config_path,
137 | classnames,
138 | )
139 |
140 | print("Initializing ViTActivationsStore")
141 | activation_store = ViTActivationsStore(
142 | dataset,
143 | args.batch_size,
144 | args.device,
145 | args.seed,
146 | vit,
147 | args.block_layer,
148 | cfg.module_name,
149 | args.class_token,
150 | )
151 |
152 | optimizer = torch.optim.Adam(sae.parameters(), lr=sae.cfg.lr)
153 | scheduler = get_scheduler(args.lr_scheduler_name, optimizer=optimizer)
154 |
155 | print("Initializing SAE b_dec using activation_store")
156 | sae.initialize_b_dec(activation_store)
157 | sae.train()
158 |
159 | if cfg.log_to_wandb:
160 | wandb.init(project=cfg.wandb_project, config=cfg, name=cfg.run_name)
161 |
162 | sae_trainer = SAETrainer(
163 | sae, vit, activation_store, cfg, optimizer, scheduler, args.device
164 | )
165 | sae_trainer.fit()
166 |
--------------------------------------------------------------------------------
/tasks/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from collections import defaultdict
4 | from typing import Dict, Tuple
5 |
6 | import torch
7 | from datasets import Dataset, load_dataset
8 | from tqdm import tqdm
9 |
10 | from src.models.utils import get_adapted_clip, get_base_clip
11 | from src.sae_training.config import Config
12 | from src.sae_training.hooked_vit import HookedVisionTransformer
13 | from src.sae_training.sparse_autoencoder import SparseAutoencoder
14 |
15 | # Dataset configurations
16 | DATASET_INFO = {
17 | "imagenet": {
18 | "path": "evanarlian/imagenet_1k_resized_256",
19 | "split": "train",
20 | },
21 | "imagenet-sketch": {
22 | "path": "clip-benchmark/wds_imagenet_sketch",
23 | "split": "train",
24 | },
25 | "oxford_flowers": {
26 | "path": "nelorth/oxford-flowers",
27 | "split": "train",
28 | },
29 | "caltech101": {
30 | "path": "HuggingFaceM4/Caltech-101",
31 | "split": "train",
32 | "name": "with_background_category",
33 | },
34 | }
35 |
36 | SAE_DIM = 49152
37 |
38 |
39 | def load_sae(sae_path: str, device: str) -> tuple[SparseAutoencoder, Config]:
40 | """Load a sparse autoencoder model from a checkpoint file."""
41 | checkpoint = torch.load(sae_path, map_location="cpu")
42 |
43 | if "cfg" in checkpoint:
44 | cfg = Config(checkpoint["cfg"])
45 | else:
46 | cfg = Config(checkpoint["config"])
47 | sae = SparseAutoencoder(cfg, device)
48 | sae.load_state_dict(checkpoint["state_dict"])
49 | sae.eval().to(device)
50 |
51 | return sae, cfg
52 |
53 |
54 | def load_hooked_vit(
55 | cfg: Config,
56 | vit_type: str,
57 | backbone: str,
58 | device: str,
59 | model_path: str = None,
60 | config_path: str = None,
61 | classnames: list[str] = None,
62 | ) -> HookedVisionTransformer:
63 | """Load a vision transformer model with hooks."""
64 | if vit_type == "base":
65 | model, processor = get_base_clip(backbone)
66 | else:
67 | model, processor = get_adapted_clip(
68 | cfg, vit_type, model_path, config_path, backbone, classnames
69 | )
70 |
71 | return HookedVisionTransformer(model, processor, device=device)
72 |
73 |
74 | def get_sae_and_vit(
75 | sae_path: str,
76 | vit_type: str,
77 | device: str,
78 | backbone: str,
79 | model_path: str = None,
80 | config_path: str = None,
81 | classnames: list[str] = None,
82 | ) -> tuple[SparseAutoencoder, HookedVisionTransformer, Config]:
83 | """Load both SAE and ViT models."""
84 | sae, cfg = load_sae(sae_path, device)
85 | vit = load_hooked_vit(
86 | cfg, vit_type, backbone, device, model_path, config_path, classnames
87 | )
88 | return sae, vit, cfg
89 |
90 |
91 | def load_and_organize_dataset(dataset_name: str) -> Tuple[list, Dict]:
92 | # TODO: ERR for imagenet (gets killed after 75%)
93 | """
94 | Load dataset and organize data by class.
95 |     Returns the classnames and the data items grouped by class.
96 |     Required by classification_with_top_k_masking.py and compute_class_wise_sae_activation.py.
97 | """
98 | dataset = load_dataset(**DATASET_INFO[dataset_name])
99 | classnames = get_classnames(dataset_name, dataset)
100 |
101 | data_by_class = defaultdict(list)
102 | for data_item in tqdm(dataset):
103 | classname = classnames[data_item["label"]]
104 | data_by_class[classname].append(data_item)
105 |
106 | return classnames, data_by_class
107 |
108 |
109 | def get_classnames(
110 | dataset_name: str, dataset: Dataset = None, data_root: str = "./configs/classnames"
111 | ) -> list[str]:
112 | """Get class names for a dataset."""
113 |
114 | filename = f"{data_root}/{dataset_name}_classnames"
115 | txt_filename = filename + ".txt"
116 | json_filename = filename + ".json"
117 |
118 | if not os.path.exists(txt_filename) and not os.path.exists(json_filename):
119 | raise ValueError(f"Dataset {dataset_name} not supported")
120 |
121 | filename = json_filename if os.path.exists(json_filename) else txt_filename
122 |
123 | with open(filename, "r") as file:
124 | if dataset_name == "caltech101":
125 | class_names = [line.strip() for line in file.readlines()]
126 | elif dataset_name == "imagenet" or dataset_name == "imagenet-sketch":
127 | class_names = [
128 | " ".join(line.strip().split(" ")[1:]) for line in file.readlines()
129 | ]
130 | elif dataset_name == "oxford_flowers":
131 | assert dataset is not None, "Dataset must be provided for Oxford Flowers"
132 | new_class_dict = {}
133 | class_names = json.load(file)
134 | classnames_from_hf = dataset.features["label"].names
135 | for i, class_name in enumerate(classnames_from_hf):
136 | new_class_dict[i] = class_names[class_name]
137 | class_names = list(new_class_dict.values())
138 |
139 | else:
140 | raise ValueError(f"Dataset {dataset_name} not supported")
141 |
142 | return class_names
143 |
144 |
145 | def setup_save_directory(
146 | root_dir: str, save_name: str, sae_path: str, vit_type: str, dataset_name: str
147 | ) -> str:
148 | """Set and create the save directory path."""
149 | sae_run_name = sae_path.split("/")[-2]
150 | save_directory = (
151 | f"{root_dir}/{save_name}/sae_{sae_run_name}/{vit_type}/{dataset_name}"
152 | )
153 | os.makedirs(save_directory, exist_ok=True)
154 | return save_directory
155 |
156 |
157 | def get_sae_activations(
158 | model_activations: torch.Tensor, sae: SparseAutoencoder
159 | ) -> torch.Tensor:
160 | """Extract and process activations from the sparse autoencoder."""
161 | hook_name = "hook_hidden_post"
162 |
163 | # Run SAE forward pass and get activations from cache
164 | _, cache = sae.run_with_cache(model_activations)
165 | sae_activations = cache[hook_name]
166 |
167 | # Average across sequence length dimension if needed
168 | if len(sae_activations.size()) > 2:
169 | sae_activations = sae_activations.mean(dim=1)
170 |
171 | return sae_activations
172 |
173 |
174 | def process_batch(vit, batch_data, device):
175 | """Process a single batch of images."""
176 | images = [data["image"] for data in batch_data]
177 |
178 | inputs = vit.processor(
179 | images=images, text="", return_tensors="pt", padding=True
180 | ).to(device)
181 | return inputs
182 |
183 |
184 | def get_max_acts_and_images(
185 | datasets: dict, feat_data_root: str, sae_runname: str, vit_name: str
186 | ) -> tuple[dict, dict]:
187 | """Load and return maximum activations and mean activations for each dataset."""
188 | max_act_imgs = {}
189 | mean_acts = {}
190 |
191 | for dataset_name in datasets:
192 | # Load max activating image indices
193 | max_act_path = os.path.join(
194 | feat_data_root,
195 | f"{sae_runname}/{vit_name}/{dataset_name}",
196 | "max_activating_image_indices.pt",
197 | )
198 | max_act_imgs[dataset_name] = torch.load(max_act_path, map_location="cpu").to(
199 | torch.int32
200 | )
201 |
202 | # Load mean activations
203 | mean_acts_path = os.path.join(
204 | feat_data_root,
205 | f"{sae_runname}/{vit_name}/{dataset_name}",
206 | "sae_mean_acts.pt",
207 | )
208 | mean_acts[dataset_name] = torch.load(mean_acts_path, map_location="cpu").numpy()
209 |
210 | return max_act_imgs, mean_acts
211 |
212 |
213 | def load_datasets(include_imagenet: bool = False, seed: int = 1):
214 | """Load multiple datasets from HuggingFace."""
215 | if include_imagenet:
216 | return {
217 | "imagenet": load_dataset(
218 | "evanarlian/imagenet_1k_resized_256", split="train"
219 | ).shuffle(seed=seed),
220 | "imagenet-sketch": load_dataset(
221 | "clip-benchmark/wds_imagenet_sketch", split="test"
222 | ).shuffle(seed=seed),
223 | "caltech101": load_dataset(
224 | "HuggingFaceM4/Caltech-101",
225 | "with_background_category",
226 | split="train",
227 | ).shuffle(seed=seed),
228 | }
229 | else:
230 | return {
231 | "imagenet-sketch": load_dataset(
232 | "clip-benchmark/wds_imagenet_sketch", split="test"
233 | ).shuffle(seed=seed),
234 | "caltech101": load_dataset(
235 | "HuggingFaceM4/Caltech-101",
236 | "with_background_category",
237 | split="train",
238 | ).shuffle(seed=seed),
239 | }
240 |
241 |
242 | def get_all_classnames(datasets, data_root):
243 | """Get class names for all datasets."""
244 | class_names = {}
245 | for dataset_name, dataset in datasets.items():
246 | class_names[dataset_name] = get_classnames(dataset_name, dataset, data_root)
247 |
248 |     # imagenet classnames are always needed to build the classname list for maple
249 | if "imagenet" not in class_names:
250 | filename = f"{data_root}/imagenet_classnames"
251 | txt_filename = filename + ".txt"
252 | json_filename = filename + ".json"
253 |
254 | if not os.path.exists(txt_filename) and not os.path.exists(json_filename):
255 |             raise ValueError(f"imagenet classnames not found under {data_root}")
256 |
257 | filename = json_filename if os.path.exists(json_filename) else txt_filename
258 |
259 | with open(filename, "r") as file:
260 | class_names["imagenet"] = [
261 | " ".join(line.strip().split(" ")[1:]) for line in file.readlines()
262 | ]
263 |
264 | return class_names
265 |
--------------------------------------------------------------------------------